<a href="https://colab.research.google.com/github/IVANTAKE/Assignment1/blob/main/Assignment1%E2%80%94Wang_Yuxuan.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install matplotlib



In [None]:
import numpy as np
import random
import matplotlib.pyplot as plt
from matplotlib.patches import Circle
from IPython import display
import time

In [None]:
# ---- Board settings ----
GRID_SIZE = 20   # not referenced by the code visible in this file
TILE_COUNT = 20  # the board is TILE_COUNT x TILE_COUNT cells

# ---- Q-learning hyper-parameters ----
LEARNING_RATE = 0.5  # step size of each Q-value update
DISCOUNT = 0.95      # weight given to the best next-state value

# ---- Epsilon-greedy exploration schedule ----
INITIAL_EPSILON = 1.0  # exploration rate at the start of training
MIN_EPSILON = 0.01     # lower bound the exploration rate decays towards
EPSILON_DECAY = 0.995  # multiplicative decay applied once per episode

# ---- Training length ----
EPISODE_COUNT = 1000          # number of training episodes
MAX_STEPS_PER_EPISODE = 200   # step cap checked by the environment

# Unit (dx, dy) moves on the grid; by these names 'UP' is the -y direction.
DIRECTIONS = dict(
    UP=(0, -1),
    DOWN=(0, 1),
    LEFT=(-1, 0),
    RIGHT=(1, 0),
)

In [None]:
# snake game environment
class SnakeEnv:
    """Snake on a square grid with a small reset/step/get_state interface.

    Cells are ``(x, y)`` tuples. The y axis grows downwards, matching
    ``DIRECTIONS`` where ``'UP'`` is ``(0, -1)``. Actions are relative to the
    current heading: 0 = keep straight, 1 = turn right (clockwise),
    2 = turn left (counter-clockwise). Any other value is treated as 0.

    Args:
        tile_count: Board side length in cells; defaults to ``TILE_COUNT``.
        max_steps: Maximum number of steps per episode; defaults to
            ``MAX_STEPS_PER_EPISODE``.
    """

    def __init__(self, tile_count=None, max_steps=None):
        # Resolve defaults at call time so the module constants stay the
        # single source of truth when no override is given.
        self.tile_count = TILE_COUNT if tile_count is None else tile_count
        self.max_steps = MAX_STEPS_PER_EPISODE if max_steps is None else max_steps
        self.reset()

    def reset(self):
        """Start a new episode: one-cell snake at the board centre."""
        centre = self.tile_count // 2
        self.snake_body = [(centre, centre)]  # head is always element 0
        self.direction = random.choice(list(DIRECTIONS.values()))
        self.food_pos = self._generate_food()
        self.score = 0
        self.is_done = False
        self.step_count = 0

    def _generate_food(self):
        """Return a random free cell, or ``None`` when the board is full."""
        occupied = set(self.snake_body)
        free_cells = [
            (x, y)
            for x in range(self.tile_count)
            for y in range(self.tile_count)
            if (x, y) not in occupied
        ]
        # Choosing among free cells (instead of retrying random cells)
        # cannot loop forever when the snake covers the whole board.
        return random.choice(free_cells) if free_cells else None

    @staticmethod
    def _turn(direction, action):
        """Return the heading obtained by applying ``action`` to ``direction``."""
        dx, dy = direction
        if action == 1:  # turn right
            return (-dy, dx)
        if action == 2:  # turn left
            return (dy, -dx)
        return (dx, dy)  # straight (and any unknown action)

    def _is_blocked(self, cell):
        """Return True if ``cell`` is off the board or occupied by the snake."""
        x, y = cell
        off_board = not (0 <= x < self.tile_count and 0 <= y < self.tile_count)
        return off_board or cell in self.snake_body

    def get_state(self):
        """Return the observation as a tuple of eleven 0/1 integers.

        Layout:
            0-2:  the cell reached by action 0 / 1 / 2 is blocked
            3-6:  heading is +x / -x / +y / -y
            7-10: food lies in +x / -x / +y / -y relative to the head
                  (all zero when there is no food left)
        """
        head_x, head_y = self.snake_body[0]
        dir_x, dir_y = self.direction

        # Derive the danger flags from the same rotation that step() uses so
        # that feature i always describes the outcome of action i.
        danger = []
        for action in (0, 1, 2):
            dx, dy = self._turn(self.direction, action)
            danger.append(self._is_blocked((head_x + dx, head_y + dy)))

        heading = [dir_x == 1, dir_x == -1, dir_y == 1, dir_y == -1]

        if self.food_pos is None:
            food = [False, False, False, False]
        else:
            food_x, food_y = self.food_pos
            food = [food_x > head_x, food_x < head_x,
                    food_y > head_y, food_y < head_y]

        return tuple(map(int, danger + heading + food))

    def step(self, action):
        """Apply ``action`` and advance the game by one move.

        Returns:
            ``(state, reward, done)``. The reward is -10 for hitting a wall or
            the snake's own body, 100 for eating the food and -0.1 for any
            other move. ``done`` is also set once ``max_steps`` moves have
            been made in the episode.
        """
        self.direction = self._turn(self.direction, action)
        head_x, head_y = self.snake_body[0]
        new_head = (head_x + self.direction[0], head_y + self.direction[1])

        if self._is_blocked(new_head):
            # Collision: the snake does not move and the episode ends.
            self.is_done = True
            reward = -10
        else:
            self.snake_body.insert(0, new_head)
            if new_head == self.food_pos:
                # Keep the tail so the snake grows by one cell.
                self.score += 1
                self.food_pos = self._generate_food()
                if self.food_pos is None:
                    self.is_done = True  # board completely filled
                reward = 100
            else:
                self.snake_body.pop()
                reward = -0.1

        self.step_count += 1
        # ">=" so an episode lasts at most max_steps moves, not max_steps + 1.
        if self.step_count >= self.max_steps:
            self.is_done = True

        return self.get_state(), reward, self.is_done

    def render(self):
        """Draw the current board inline and pause briefly."""
        fig, ax = plt.subplots(figsize=(5, 5))
        ax.set_xlim(0, self.tile_count)
        # Reversed limits put y = 0 at the top, so 'UP' (0, -1) moves up.
        ax.set_ylim(self.tile_count, 0)
        ax.set_xticks([])
        ax.set_yticks([])

        # Draw the snake as circles, with the head in a lighter green.
        for i, (x, y) in enumerate(self.snake_body):
            color = 'limegreen' if i == 0 else 'green'
            ax.add_patch(Circle((x + 0.5, y + 0.5), 0.5, color=color))

        # Draw the fruit, if any is left on the board.
        if self.food_pos is not None:
            food_x, food_y = self.food_pos
            ax.add_patch(Circle((food_x + 0.5, food_y + 0.5), 0.5, color='red'))

        ax.set_title(f"Score: {self.score}")
        display.display(fig)
        display.clear_output(wait=True)
        time.sleep(0.1)
        # Close the figure; otherwise every frame stays open in pyplot's
        # figure registry for the whole training run.
        plt.close(fig)

# Q-learning agent
class QAgent:
    """Tabular epsilon-greedy Q-learning agent over actions 0, 1 and 2.

    Q-values live in a dict keyed by ``(state, action)``; pairs that were
    never updated read as 0.0.

    Args:
        learning_rate: Step size of each update; defaults to ``LEARNING_RATE``.
        discount: Discount factor; defaults to ``DISCOUNT``.
        epsilon: Initial exploration rate; defaults to ``INITIAL_EPSILON``.
        min_epsilon: Floor for the exploration rate; defaults to
            ``MIN_EPSILON``.
        epsilon_decay: Per-call multiplicative decay; defaults to
            ``EPSILON_DECAY``.
    """

    def __init__(self, learning_rate=None, discount=None, epsilon=None,
                 min_epsilon=None, epsilon_decay=None):
        self.q_values = {}
        # Defaults are resolved here (not in the signature) so the module
        # constants are read when an agent is created.
        self.learning_rate = LEARNING_RATE if learning_rate is None else learning_rate
        self.discount = DISCOUNT if discount is None else discount
        self.epsilon = INITIAL_EPSILON if epsilon is None else epsilon
        self.min_epsilon = MIN_EPSILON if min_epsilon is None else min_epsilon
        self.epsilon_decay = EPSILON_DECAY if epsilon_decay is None else epsilon_decay

    def get_q(self, state, action):
        """Return the stored Q-value for ``(state, action)``, or 0.0."""
        return self.q_values.get((state, action), 0.0)

    def choose_action(self, state):
        """Pick an action: random with probability epsilon, else greedy.

        Ties in the greedy case resolve to the lowest action index. The
        result is always a plain ``int``.
        """
        if random.random() < self.epsilon:
            return random.randint(0, 2)
        q_values = [self.get_q(state, a) for a in range(3)]
        # np.argmax returns a NumPy integer; convert so both branches agree.
        return int(np.argmax(q_values))

    def update_q(self, state, action, reward, next_state, done=False):
        """Apply one Q-learning update for an observed transition.

        Args:
            state: State in which ``action`` was taken.
            action: Action that was taken.
            reward: Reward received for the transition.
            next_state: State observed afterwards.
            done: True when the transition ended the episode. The target is
                then just ``reward``; no future value is added.
        """
        old_q = self.get_q(state, action)
        if done:
            # Nothing follows a terminal transition, so do not bootstrap from
            # whatever state the environment reports after it.
            next_max_q = 0.0
        else:
            next_max_q = max(self.get_q(next_state, a) for a in range(3))
        target = reward + self.discount * next_max_q
        self.q_values[(state, action)] = old_q + self.learning_rate * (target - old_q)

    def decay_epsilon(self):
        """Multiply epsilon by the decay factor, never going below the floor."""
        self.epsilon = max(self.min_epsilon, self.epsilon * self.epsilon_decay)

# Training function
def training(render=True, episode_count=None):
    """Train a fresh QAgent on a fresh SnakeEnv and return the agent.

    Args:
        render: Draw the board after every step. Each frame sleeps inside
            ``env.render()``, so pass False to train without that delay.
        episode_count: Number of episodes to run; defaults to
            ``EPISODE_COUNT``.

    Returns:
        The trained ``QAgent``.
    """
    if episode_count is None:
        episode_count = EPISODE_COUNT

    env = SnakeEnv()
    agent = QAgent()

    for episode in range(episode_count):
        env.reset()
        state = env.get_state()
        total_reward = 0

        while not env.is_done:
            action = agent.choose_action(state)
            next_state, reward, done = env.step(action)
            # Pass the done flag so the last transition of an episode is not
            # credited with future value. The environment also sets it when
            # the step cap is hit, so that cut-off is treated as terminal too.
            agent.update_q(state, action, reward, next_state, done)
            state = next_state
            total_reward += reward

            if render:
                env.render()  # game render

        agent.decay_epsilon()
        print(f"Episode {episode + 1}, Score: {env.score}, Total Reward: {total_reward}, Epsilon: {agent.epsilon:.4f}")

    print("THE TRAINGING IS FINISHED.")
    return agent

# Entry point: start training only when this code runs as the main program,
# not when it is imported as a module.
if __name__ == "__main__":
    training()