## A neural-network solver for the Snake game (supervised pre-training followed by Q-learning)


Importing useful functions

In [None]:
from snake_functions import game,draw,next

Useful for the supervised learning part

In [None]:
from snake_resol import find_shortest_path

Game dimensions

In [None]:
# Board size: the grid is indexed as tab[i, j] with i in range(n) and j in range(m)
n = 8
m = 8

Redefining some functions from `snake_functions` so that each game update also returns what the Q-network needs (a game-over flag and a reward)

In [None]:
# Game class representing the Snake game state
class Game:
    """State of one Snake game played on the module-level ``n`` x ``m`` board.

    Attributes:
        tab: 2-D grid indexed as ``tab[i, j]`` whose cells hold "S" (snake),
            "F" (food) or "X" (empty).
        dir: Current heading as a ``(d1, d2)`` step.
        snake_list: Snake cells as ``(i, j)`` tuples, head first.
        score: Points collected so far (10 per food eaten).
    """

    # NOTE: the parameter name `dir` shadows the builtin; it is kept so that
    # existing positional/keyword callers keep working.
    def __init__(self, tab, dir, snake_list, score):
        self.tab = tab
        self.dir = dir
        self.snake_list = snake_list
        self.score = score

    def _place_food(self):
        """Drop food on a random cell that is not occupied by the snake.

        Leaves the board unchanged when the snake covers every cell, instead
        of retrying forever.
        """
        # Reduce coordinates modulo the board size so that a segment stored
        # with a negative index still blocks the cell it addresses in `tab`.
        occupied = {(i % n, j % m) for i, j in self.snake_list}
        free = [(i, j) for i in range(n) for j in range(m) if (i, j) not in occupied]
        if free:
            f1, f2 = random.choice(free)
            self.tab[f1, f2] = "F"

    def update(self):
        """Advance the game by one step in the current direction.

        Returns:
            A ``(game_over, reward)`` pair: ``(True, -10)`` when the head hits
            a wall or a snake cell, ``(False, 10)`` when it eats food and
            ``(False, -0.1)`` for any other move.
        """
        d1, d2 = self.dir

        # Make sure there is something to eat before the snake moves
        if not any("F" in self.tab[i] for i in range(n)):
            self._place_food()

        x, y = self.snake_list[0]
        xf, yf = next_position(x, y, d1, d2)

        # next_position wraps around the board, but walls are lethal here:
        # the wall test therefore uses the unwrapped step.
        hits_wall = x + d1 not in range(n) or y + d2 not in range(m)
        if hits_wall or self.tab[xf, yf] == "S":
            print("Game Over! Score: " + str(self.score))
            return True, -10

        ate_food = self.tab[xf, yf] == "F"
        old_tail = self.snake_list[-1]

        # Shift the body in place (the caller may hold a reference to the list):
        # the head takes the new cell and every segment takes its predecessor's.
        self.snake_list[:] = [(xf, yf)] + self.snake_list[:-1]

        if ate_food:
            # Grow by keeping the cell the tail would otherwise have vacated,
            # so the new segment is always adjacent to the body and on the board.
            self.snake_list.append(old_tail)
            self.score += 10
        else:
            self.tab[old_tail] = "X"

        # Re-mark every segment so the board agrees with snake_list even if the
        # caller did not draw the snake on `tab` beforehand.
        for cell in self.snake_list:
            self.tab[cell] = "S"

        if ate_food:
            # Placed after the move so the food can never land under the head
            self._place_food()
            return False, 10

        # Small per-step penalty while the food has not been reached
        return False, -0.1

# Function to calculate the next position of the snake based on the direction
def next_position(x, y, d1, d2):
    """Return the cell reached from ``(x, y)`` after a ``(d1, d2)`` step.

    The board wraps around: leaving one edge re-enters on the opposite side.
    Modulo arithmetic keeps the result inside ``range(n)`` x ``range(m)`` for
    any step size and any board width, including a 1-wide board.

    Args:
        x, y: Current cell.
        d1, d2: Step along the first and second grid axis.

    Returns:
        The ``(xf, yf)`` cell after the step.
    """
    # Python's % always yields a non-negative result for a positive modulus,
    # so a step to -1 lands on the last row/column.
    return (x + d1) % n, (y + d2) % m

# Function to draw the game on the screen using pygame
def draw(game):
    """Render the board and the score of `game` onto the pygame window.

    Every cell of ``game.tab`` is painted as a rectangle: black for the snake
    ("S"), red for food ("F") and dark gray for anything else. The score text
    is then blitted and the display refreshed.
    """
    SURF.fill(gris_clair)

    # Cell content -> fill colour; unknown content falls back to dark gray
    palette = {"S": noir, "F": rouge}

    for j in range(m):
        for i in range(n):
            colour = palette.get(game.tab[i, j], gris_fonce)
            cell_rect = [
                x0 + i * (x1 - x0) / (n - 1),
                y0 + (m - 1 - j) * (y1 - y0) / (m - 1),  # j grows upwards on screen
                (x1 - x0) / n,
                (y1 - y0) / m,
            ]
            pg.draw.rect(SURF, colour, cell_rect)

    score_img = font.render("Score: " + str(game.score), True, noir)
    SURF.blit(score_img, (900, 500))
    pg.display.update()


Network definition

In [None]:
# Import necessary libraries
import pygame as pg  # Library used for creating the game visuals and handling input
import random  # Standard library for generating random numbers
import numpy as np  # Library for numerical computations (arrays, matrices)
import torch  # PyTorch library for building neural networks and handling tensors
import torch.nn as nn  # Submodule for creating neural network layers and models
import torch.optim as optim  # Submodule for optimization algorithms like Adam
from collections import deque  # Data structure for efficient memory storage (double-ended queue)
import sys  # Standard library used for system-specific parameters and functions

# Game constants
# NOTE: the grid size (n, m) is set in the "Game dimensions" cell and is not overridden here.

# Initialize Pygame
# (must run before the display and font calls below)
pg.init()  # Initialize all the imported pygame modules
SURF = pg.display.set_mode((1450, 1000))  # Window surface that draw() paints on, 1450x1000 pixels
font = pg.font.SysFont(None, 30)  # Default system font, size 30, used by draw() for the score text

# Colors used in the game
# (RGB tuples; the French names are referenced by draw())
gris_clair = (220, 220, 220)  # Light gray: window background
gris_fonce = (150, 150, 150)  # Dark gray: empty cells
noir = (0, 0, 0)  # Black: snake cells and score text
rouge = (255, 0, 0)  # Red: food

# Initial coordinates for the game area
# draw() spreads the cell origins from (x0, y0) to (x1, y1), in pixels
x0, y0 = 200, 150  # Top-left corner of the game grid
x1, y1 = 800, 750  # Bottom-right corner of the game grid

# Hyper-parameters of the learning agent
GAMMA = 0.95           # discount applied to the bootstrapped value of the next state
EPSILON = 1.0          # initial probability of picking a random action
EPSILON_DECAY = 0.99   # epsilon is multiplied by this factor after each replay pass
MIN_EPSILON = 0.01     # epsilon is no longer decayed once it is not above this value
LEARNING_RATE = 1e-3   # step size of the Adam optimizer
BATCH_SIZE = 1_000     # number of transitions sampled per replay pass
MEMORY_SIZE = 100_000  # capacity of the replay buffer

# DQN (Deep Q-Network) model definition
class DQN(nn.Module):
    """Fully connected Q-network with a single 512-unit ReLU hidden layer.

    Args:
        input_dim: Length of the state vector.
        output_dim: Number of actions; the network emits one raw Q-value each.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        hidden_dim = 512
        self.fc1 = nn.Linear(input_dim, hidden_dim)   # state -> hidden
        self.fc2 = nn.Linear(hidden_dim, output_dim)  # hidden -> Q-values

    def forward(self, x):
        """Return the unactivated Q-values for the batch of states `x`."""
        hidden = torch.relu(self.fc1(x))
        q_values = self.fc2(hidden)
        return q_values

# RL Agent class that interacts with the environment
class Agent:
    """Epsilon-greedy agent backed by a `DQN`, trainable in two modes.

    ``replay('unsupervised')`` performs one-step Q-learning updates on sampled
    transitions; any other `method` value performs supervised updates that
    imitate a greedy "step towards the food" heuristic.

    Args:
        state_size: Length of the state vector fed to the network.
        action_size: Number of actions (network outputs).
        device: Stored on the instance only; the model and all tensors are
            left on the default device.
    """

    # Action index -> (d1, d2) step, used to label states in supervised mode
    _DIRECTIONS = ((1, 0), (0, 1), (-1, 0), (0, -1))

    def __init__(self, state_size, action_size, device='mps'):
        self.device = device
        self.state_size = state_size
        self.action_size = action_size
        self.memory = deque(maxlen=MEMORY_SIZE)  # replay buffer, oldest entries dropped first
        self.gamma = GAMMA
        self.epsilon = EPSILON
        self.epsilon_min = MIN_EPSILON
        self.epsilon_decay = EPSILON_DECAY
        self.model = DQN(state_size, action_size)
        self.optimizer = optim.Adam(self.model.parameters(), lr=LEARNING_RATE)
        self.criterion = nn.MSELoss()

    def memorize(self, state, action, reward, next_state, done):
        """Append one transition to the replay buffer."""
        self.memory.append((state, action, reward, next_state, done))

    def act(self, state):
        """Pick an action for `state` with an epsilon-greedy policy.

        Returns:
            A random action index with probability `epsilon`, otherwise the
            index of the largest predicted Q-value.
        """
        if np.random.rand() <= self.epsilon:
            return random.randrange(self.action_size)
        state = torch.FloatTensor(state).unsqueeze(0)  # add the batch dimension
        with torch.no_grad():  # inference only, no graph needed
            act_values = self.model(state)
        return torch.argmax(act_values[0]).item()

    def replay(self, method):
        """Train on a random batch of stored transitions, then decay epsilon.

        Does nothing until the buffer holds at least `BATCH_SIZE` transitions.

        Args:
            method: ``'unsupervised'`` for Q-learning updates; any other value
                selects the supervised (heuristic-imitation) updates.
        """
        if len(self.memory) < BATCH_SIZE:
            return

        minibatch = random.sample(self.memory, BATCH_SIZE)
        for state, action, reward, next_state, done in minibatch:
            if method == 'unsupervised':
                self._q_learning_step(state, action, reward, next_state, done)
            else:
                self._supervised_step(state)

        # Decay epsilon to reduce exploration over time
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

    def _q_learning_step(self, state, action, reward, next_state, done):
        """Apply one gradient step towards the one-step Q-learning target."""
        state_t = torch.FloatTensor(state).unsqueeze(0)

        # The bootstrapped target is a constant: compute it without a graph
        target = reward
        if not done:
            next_t = torch.FloatTensor(next_state).unsqueeze(0)
            with torch.no_grad():
                target = reward + self.gamma * torch.max(self.model(next_t)[0]).item()

        prediction = self.model(state_t)
        # Only the taken action's Q-value is pushed towards the target; the
        # other entries equal the prediction and contribute zero loss.
        target_q = prediction.detach().clone()
        target_q[0][action] = target

        self.optimizer.zero_grad()
        loss = self.criterion(prediction, target_q)
        loss.backward()
        self.optimizer.step()

    @staticmethod
    def _cells(mask):
        """Return the ``(i, j)`` positions where the boolean grid `mask` is true."""
        return [(int(i), int(j)) for i, j in np.argwhere(mask)]

    def _supervised_step(self, state):
        """Apply one gradient step towards the greedy "approach the food" action.

        The board is recovered from the first ``n * m`` entries of `state`.
        Samples without a head or without food on the board are skipped,
        because no label can be derived from them.
        """
        grid = np.array(state[:n * m], dtype=float).reshape((n, m))

        # Stored states are scaled by their norm, so cell codes are recovered
        # by value rather than by rank: the head carries the largest code (3),
        # which fixes the scale even when the body or the food is missing.
        peak = grid.max()
        if peak <= 0:
            return
        codes = np.rint(grid * 3 / peak).astype(int)

        head_cells = self._cells(codes == 3)
        food_cells = self._cells(codes == 2)
        if not head_cells or not food_cells:
            return
        s1, s2 = head_cells[0]
        f1, f2 = food_cells[0]
        occupied = set(head_cells) | set(self._cells(codes == 1))

        # Distance to the food from each neighbouring cell; cells that are off
        # the board or on the snake are sent far away so they are never chosen
        # unless every move is blocked.
        distances = []
        for d1, d2 in self._DIRECTIONS:
            i, j = s1 + d1, s2 + d2
            if i not in range(n) or j not in range(m) or (i, j) in occupied:
                i, j = 1000, 1000
            distances.append(np.sqrt((i - f1) ** 2 + (j - f2) ** 2))
        best_action = int(np.argmin(distances))

        # One-hot label for the heuristic's choice
        label = torch.zeros(len(self._DIRECTIONS))
        label[best_action] = 1.0

        state_t = torch.FloatTensor(state).unsqueeze(0)
        self.optimizer.zero_grad()
        loss = self.criterion(self.model(state_t)[0], label)
        loss.backward()
        self.optimizer.step()


Training loop (a pygame window opens so you can watch the agent while it trains)

In [None]:
def state_to_numeric(state, snakelist):
    """Encode the character board as a numeric ``n`` x ``m`` array.

    Args:
        state: 2-D board whose cells hold 'S' (snake), 'F' (food) or 'X' (empty).
        snakelist: Snake cells as ``(i, j)`` tuples; the first one is the head.

    Returns:
        A float array where 0 is an empty (or unrecognised) cell, 1 a snake
        cell, 2 the food and 3 the snake's head.
    """
    cell_codes = {"S": 1, "F": 2}  # everything else, including "X", stays 0

    numeric_state = np.zeros((n, m))
    for cell in np.ndindex(n, m):
        numeric_state[cell] = cell_codes.get(state[cell], 0)

    # The head is written last so it overrides whatever the board shows there
    head_x, head_y = snakelist[0]
    numeric_state[head_x, head_y] = 3
    return numeric_state

def numeric_to_state(state):
    """Convert a numeric board back to its character representation.

    Cells are decoded by value (1 = snake body, 2 = food, 3 = snake head), not
    by the rank of the values present, so a board that lacks one of the codes
    is still decoded correctly. A board that was divided by a positive
    constant (for example a normalised state) is rescaled first, taking its
    largest entry to be the head.

    Args:
        state: 2-D numeric array with at least ``n`` x ``m`` entries.

    Returns:
        An ``n`` x ``m`` array of 'S' (snake), 'F' (food) and 'X' (empty).
    """
    real_state = np.full((n, m), 'X')
    grid = np.asarray(state, dtype=float)

    # Exact codes are used as they are; anything else is treated as a scaled
    # board and brought back so that its maximum becomes the head code 3.
    if not np.isin(grid, (0, 1, 2, 3)).all():
        peak = grid.max()
        if peak <= 0:
            return real_state  # nothing recognisable on the board
        grid = np.rint(grid * 3 / peak)

    symbols = {1: 'S', 2: 'F', 3: 'S'}
    for i in range(n):
        for j in range(m):
            symbol = symbols.get(int(grid[i, j]))
            if symbol is not None:
                real_state[i, j] = symbol

    return real_state

# Total reward of each finished episode; main() appends one value per game over,
# and the results cell plots the list.
reward_list = []

# Action index -> (d1, d2) step applied to the snake's heading
_ACTION_DIRECTIONS = ((1, 0), (0, 1), (-1, 0), (0, -1))


def _encode_state(game_state):
    """Return the network input for `game_state`.

    The numeric board is flattened, the heading is appended, and the vector is
    scaled to unit Euclidean norm (left unchanged if its norm is zero).
    """
    grid = state_to_numeric(game_state.tab, game_state.snake_list)
    vector = np.concatenate((grid.reshape(n * m), list(game_state.dir)))
    norm = np.linalg.norm(vector)
    return vector / (1 if norm == 0 else norm)


def main():
    """Train the agent on successive Snake games while rendering them.

    Each episode starts from a fresh board with a three-segment snake and one
    piece of food. The agent is trained with supervised updates while its
    replay buffer holds fewer than 50,000 transitions and with Q-learning
    updates afterwards. Closing the pygame window stops the training.
    """
    state_size = n * m + 2  # board cells plus the two heading components
    action_size = len(_ACTION_DIRECTIONS)
    agent = Agent(state_size, action_size)
    episodes = 100000

    for e in range(episodes):
        env_tab = np.full((n, m), "X")
        env_dir = (0, 1)
        d1, d2 = env_dir

        # Keep the head at least two cells away from the low edge along the
        # heading so both trailing segments fit on the board.
        x = random.randint(2 * d1, n - 1)
        y = random.randint(2 * d2, m - 1)
        env_snake = [(x, y), (x - d1, y - d2), (x - 2 * d1, y - 2 * d2)]

        # Draw the snake on the board so the first state and the first
        # collision check already see its body.
        for cell in env_snake:
            env_tab[cell] = "S"

        # Food goes on any cell the snake does not occupy
        free_cells = [(i, j) for i in range(n) for j in range(m) if (i, j) not in env_snake]
        f1, f2 = random.choice(free_cells)
        env_tab[f1, f2] = "F"

        score = 0
        game_state = Game(env_tab, env_dir, env_snake, score)

        state = _encode_state(game_state)
        done = False
        reward_epoch = 0  # total reward collected in this episode

        while not done:
            # Drain the event queue so the window stays responsive, and stop
            # training when the user closes it.
            for event in pg.event.get():
                if event.type == pg.QUIT:
                    pg.quit()
                    return

            action = agent.act(state)
            game_state.dir = _ACTION_DIRECTIONS[action]

            done, reward = game_state.update()
            reward_epoch += reward

            next_state = _encode_state(game_state)
            agent.memorize(state, action, reward, next_state, done)
            state = next_state

            draw(game_state)
            pg.time.wait(10)  # short pause so the rendering can be followed

            if done:
                print(f"Game over! Final score: {game_state.score}")
                reward_list.append(reward_epoch)
                break

            # Supervised warm-up while the buffer is filling, Q-learning afterwards
            if len(agent.memory) < 50_000:
                agent.replay('supervised')
            else:
                agent.replay('unsupervised')

        print(f"Episode {e+1}/{episodes}, Score: {game_state.score}, Epsilon: {agent.epsilon:.2f}")

# Start training when this code is run as a script or as a notebook cell
if __name__ == "__main__":
    main()


Results

In [None]:
import matplotlib.pyplot as plt
# Total reward per finished episode, in the order the episodes were played
plt.plot(reward_list)