In [1]:
# CDS524 Assignment 1
# Snake Game
# Student Name: HUANG Xinghua
# Student ID: 3160617

In [2]:
!pip install pygame numpy matplotlib



In [3]:
import pygame
import numpy as np
import random
from collections import deque
from tqdm import tqdm  # Progress bar for training

pygame 2.6.1 (SDL 2.28.4, Python 3.12.4)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [4]:
# Define Actions
action_space = ["LEFT", "RIGHT", "STRAIGHT"]

In [5]:
# Initialize Pygame
pygame.init()

(5, 0)

In [6]:
# Snake Game Class
class SnakeGame:
    def __init__(self, grid_size=10):
        self.grid_size = grid_size
        self.cell_size = 40  # Size of each grid cell
        self.width = self.grid_size * self.cell_size
        self.height = self.grid_size * self.cell_size
        self.window = pygame.display.set_mode((self.width, self.height))
        pygame.display.set_caption("Q-Learning Snake Game")

        self.font = pygame.font.Font(None, 36)
        self.score = 0
        self.reset()

    def reset(self):
        # Resets the game state at the start of each episode.
        self.snake = deque([(self.grid_size // 2, self.grid_size // 2)])  # Start at center
        self.food = self.place_food()
        self.direction = (0, 1)  # Moving right initially
        self.done = False
        self.score = 0
        return self.get_state()

    def place_food(self):
        # Places food at a random location not occupied by the snake.
        while True:
            food_x = random.randint(0, self.grid_size - 1)
            food_y = random.randint(0, self.grid_size - 1)
            if (food_x, food_y) not in self.snake:
                return (food_x, food_y)

    def get_state(self):
        #Returns an optimized state representation.
        head_x, head_y = self.snake[0]
        food_x, food_y = self.food

        # Relative food direction
        food_left = 1 if food_x < head_x else 0
        food_right = 1 if food_x > head_x else 0
        food_up = 1 if food_y < head_y else 0
        food_down = 1 if food_y > head_y else 0

        # Danger detection (binary values)
        danger_left = 1 if (head_x - 1, head_y) in self.snake or head_x - 1 < 0 else 0
        danger_right = 1 if (head_x + 1, head_y) in self.snake or head_x + 1 >= self.grid_size else 0
        danger_up = 1 if (head_x, head_y - 1) in self.snake or head_y - 1 < 0 else 0
        danger_down = 1 if (head_x, head_y + 1) in self.snake or head_y + 1 >= self.grid_size else 0

        return (food_left, food_right, food_up, food_down, danger_left, danger_right, danger_up, danger_down)

    def step(self, action):
        # Moves the snake based on the action and returns (new_state, reward, done).
        if self.done:
            return self.get_state(), 0, True

        # Determine new direction
        if action == "LEFT":
            self.direction = (-self.direction[1], self.direction[0])  # Rotate left
        elif action == "RIGHT":
            self.direction = (self.direction[1], -self.direction[0])  # Rotate right

        # Move the snake
        new_head = (self.snake[0][0] + self.direction[0], self.snake[0][1] + self.direction[1])

        # Check for collisions
        if (new_head in self.snake) or (new_head[0] < 0 or new_head[0] >= self.grid_size or new_head[1] < 0 or new_head[1] >= self.grid_size):
            self.done = True
            return self.get_state(), -100, True  # Game over penalty

        # Add new head position
        self.snake.appendleft(new_head)

        # Reward system
        reward = -0.1  # Small penalty for each move
        if new_head == self.food:
            self.score += 10
            reward = 10  # Strong reward for eating food
            self.food = self.place_food()  # Generate new food
        else:
            self.snake.pop()  # Move forward

        return self.get_state(), reward, self.done

    def render(self, action=None, reward=0):
        # Renders the game using Pygame and displays UI elements.
        self.window.fill((0, 0, 0))  # Black background

        # Draw snake
        for segment in self.snake:
            pygame.draw.rect(self.window, (0, 255, 0), (segment[1] * self.cell_size, segment[0] * self.cell_size, self.cell_size, self.cell_size))

        # Draw food
        pygame.draw.rect(self.window, (255, 0, 0), (self.food[1] * self.cell_size, self.food[0] * self.cell_size, self.cell_size, self.cell_size))

        # Display score
        score_text = self.font.render(f"Score: {self.score}", True, (255, 255, 255))
        self.window.blit(score_text, (10, 10))

        # Display action taken
        if action:
            action_text = self.font.render(f"Action: {action}", True, (255, 255, 255))
            self.window.blit(action_text, (10, 40))

        # Display reward
        reward_text = self.font.render(f"Reward: {reward}", True, (255, 255, 255))
        self.window.blit(reward_text, (10, 70))

        pygame.display.flip()
        pygame.time.delay(100)  # Slow down the game for visibility

In [7]:
# Q-learning Training
env = SnakeGame(grid_size=10)

In [None]:
# Q-table & hyperparameters
q_table = np.zeros((2, 2, 2, 2, 2, 2, 2, 2, len(action_space)))
alpha = 0.5
gamma = 0.9
epsilon = 1.0
epsilon_decay = 0.997
epsilon_min = 0.05
episodes = 5000
max_steps_per_episode = 500

# Training loop
for episode in tqdm(range(episodes), desc="Training Progress"):
    state = env.reset()
    done = False
    steps = 0

    while not done and steps < max_steps_per_episode:
        pygame.event.pump()  # Prevents UI from freezing

        state_idx = tuple(state)  # Convert state to tuple for Q-table indexing

        # Epsilon-greedy action selection
        if random.uniform(0, 1) < epsilon:
            action_idx = random.choice(range(len(action_space)))
        else:
            action_idx = np.argmax(q_table[state_idx])

        action = action_space[action_idx]
        new_state, reward, done = env.step(action)
        env.render(action, reward)  # Keep UI responsive

        new_state_idx = tuple(new_state)
        best_next_action = np.argmax(q_table[new_state_idx])

        # Q-learning update
        q_table[state_idx][action_idx] += alpha * (reward + gamma * q_table[new_state_idx][best_next_action] - q_table[state_idx][action_idx])

        state = new_state
        steps += 1

    epsilon = max(epsilon * epsilon_decay, epsilon_min)  # Decay exploration

print("Training finished!")

Training Progress:   6%|▌         | 303/5000 [12:44<4:07:08,  3.16s/it] 

In [None]:
# Keep UI Open
running = True
while running:
    pygame.event.pump()
    for event in pygame.event.get():
        if event.type == pygame.QUIT:
            running = False

pygame.quit()  # Properly exit Pygame