## This code provides a neural network (CNN) solution of the snake game


Importing useful functions

In [2]:
from snake_functions import game,draw,next

In [61]:
from snake_resol import find_shortest_path,find_food

Game dimensions

In [3]:
n,m=8,8

Package installation

In [None]:
!pip install tensorflow keras


CNN definition

In [2]:
import tensorflow as tf
import tensorflow.keras.layers as layers
import tensorflow.keras.models as models

# Example with n=8, m=8 (input_shape can be adjusted to (5, 5, 3) if needed)
input_shape = (n, m, 3)

model = models.Sequential()

# First convolutional layer
model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=input_shape, padding='same'))
model.add(layers.MaxPooling2D((2, 2)))

# Second convolutional layer
model.add(layers.Conv2D(64, (3, 3), activation='relu', padding='same'))

# Global average pooling (instead of MaxPooling to adapt to smaller input sizes)
model.add(layers.GlobalAveragePooling2D())

# Fully connected layer
model.add(layers.Dense(64, activation='relu'))

# Output layer (4 neurons for 4 possible directions)
model.add(layers.Dense(4, activation='softmax'))

# Compile the model
model.compile(optimizer='adam', 
              loss='categorical_crossentropy', 
              metrics=['accuracy'])

# Summary of the model
model.summary()


  super().__init__(


Redefining Game class to get a reward feedback (useful to avoid infinite movement loops) and highlighting snake head in draw function 

In [40]:
# Game class representing the Snake game state
class Game:
    def __init__(self, tab, dir, snake_list, score):
        self.tab = tab  # The game grid (table) where the snake and food are placed
        self.dir = dir  # The current direction of the snake (d1, d2)
        self.snake_list = snake_list  # List of coordinates representing the snake's body
        self.score = score  # The player's current score

    
    def update(self):
        "Method that updates the game state (snake movement, food generation, collision detection)"
        d1, d2 = self.dir  # The direction of the snake (change in x and y)
        current = list.copy(self.snake_list)  # Make a copy of the snake to track its previous state
        reward = -10  # Default reward (negative for game over, increased when food is eaten)

        # Check if there's food on the grid, and if not, place new food randomly
        if not any("F" in self.tab[i] for i in range(n)):
            f1, f2 = random.randint(0, n-1), random.randint(0, m-1)
            # Ensure that the food is not placed on the snake
            while (f1, f2) in self.snake_list:
                f1, f2 = random.randint(0, n-1), random.randint(0, m-1)
            self.tab[f1, f2] = "F"  # Place food ('F') on the grid

        # Loop through each segment of the snake
        for i in range(len(self.snake_list)):
            x, y = self.snake_list[i]  # Get current snake segment's coordinates

            if i == 0:  # Head of the snake (first segment)
                d1,d2=self.dir
                xf, yf = x+d1,y+d2  # Get the next position of the head based on the direction

                # Check for collision with snake body or wall
                if xf not in range(n) or yf not in range(m) or self.tab[xf, yf] == "S"  :
                    #print("Game Over! Score: " + str(self.score))  # Game over if collision occurs
                    return True, -10  # Return game over flag and penalty reward

                elif self.tab[xf, yf] == "F":  # Check if the head eats food
                    a, b = self.snake_list[-1]  # Get the last segment of the snake (tail)
                    # Extend the snake by adding a new segment at the tail's previous position
                    if a+d1 in range(n) and b+d2 in range(m):
                        self.snake_list.append((a+d1,b+d2))
                    # Place new food randomly after eating
                    f1, f2 = random.randint(0, n-1), random.randint(0, m-1)
                    while (f1, f2) in self.snake_list:
                        f1, f2 = random.randint(0, n-1), random.randint(0, m-1)
                    self.tab[f1, f2] = "F"
                    self.score += 10  # Increase the score for eating food
                    reward = 10  # Reward for eating food

                # Update the snake's head position
                self.snake_list[i] = xf, yf  # Move the head to the new position
                self.tab[xf, yf] = "S"  # Mark the new head position on the grid
                self.tab[x, y] = "X"  # Mark the previous head position as visited

                # Loop through the grid to find the food and assign a slight penalty (-0.1) for not eating it yet
                for k in range(n):
                    for l in range(m):
                        if self.tab[k, l] == "F":
                            f1, f2 = k, l
                            reward = -0.1

            else:  # Body of the snake (all other segments)
                xprev, yprev = current[i - 1]  # Move each body segment to the position of the segment ahead
                self.snake_list[i] = xprev, yprev
                self.tab[xprev, yprev] = "S"  # Mark the new body segment position
                self.tab[x, y] = "X"  # Mark the previous body segment position as visited

        # Check if there's food on the grid, and if not, place new food randomly
        if not any("F" in self.tab[i] for i in range(n)):
            f1, f2 = random.randint(0, n-1), random.randint(0, m-1)
            # Ensure that the food is not placed on the snake
            while (f1, f2) in self.snake_list:
                f1, f2 = random.randint(0, n-1), random.randint(0, m-1)
            self.tab[f1, f2] = "F"  # Place food ('F') on the grid
        
        return False, reward  # Return no game over and the current reward



def draw(game):
    "Function to draw the game on the screen using pygame"
    SURF.fill(gris_clair)  # Fill the screen with a light gray background

    # Loop through the grid to draw snake, food, and empty cells
    for j in range(m):
        for i in range(n):
            if game.tab[i, j] == "S":  # Draw the snake
                
                if (i,j)!=game.snake_list[0]:
                    clr=noir
                else: 
                    clr=(0,255,0)
                pg.draw.rect(SURF, clr, [x0 + i * (x1 - x0) / (n - 1), y0 + (m - 1 - j) * (y1 - y0) / (m - 1), (x1 - x0) / n, (y1 - y0) / m])
            elif game.tab[i, j] == "F":  # Draw the food
                pg.draw.rect(SURF, rouge, [x0 + i * (x1 - x0) / (n - 1), y0 + (m - 1 - j) * (y1 - y0) / (m - 1), (x1 - x0) / n, (y1 - y0) / m])
            else:  # Draw empty cells
                pg.draw.rect(SURF, gris_fonce, [x0 + i * (x1 - x0) / (n - 1), y0 + (m - 1 - j) * (y1 - y0) / (m - 1), (x1 - x0) / n, (y1 - y0) / m])

    # Render and display the score on the screen
    img = font.render("Score: " + str(game.score), True, noir)
    SURF.blit(img, (900, 500))  # Display the score at the specified position
    pg.display.update()  # Update the display with the new drawing


### Main training loop 

In [None]:

import numpy as np
import random

num_episodes=1000000


#Generating random game states and labeling 'good' move (here we take the move that reduces the distance to the food)

# Game State
X=np.zeros((num_episodes,n,m,3))

#Labels
Y=np.zeros((num_episodes,4))


for i in range(num_episodes):
    
    #Random snake head
    o1,o2=np.random.randint(n),np.random.randint(m)
    X[i,o1,o2,0]=1
    
    # Extending the snake by 1 in some random direction
    d1,d2=random.choice([(0,1),(1,0),(-1,0),(0,-1)])
    while(o1+d1 not in range(n) or o2+d2 not in range(m)):
        d1,d2=random.choice([(0,1),(1,0),(-1,0),(0,-1)])
    X[i,o1+d1,o2+d2,1]=1
    l=[(o1,o2),(o1+d1,o2+d2)]
    
    #Finding empty spot to put the food cell in
    f1,f2=np.random.randint(n),np.random.randint(m)
    while((f1,f2) in l):
        f1,f2=np.random.randint(n),np.random.randint(m)
    X[i,f1,f2,2]=1
    
    #Finding the heuristically best direction and labelling the good direction as this one
    possible_directions=[(0,1),(1,0),(-1,0),(0,-1)]
    possible_directions=[(f1-(o1+x))**2+ (f2-(o2+y))**2 if (o1+x) in range(n) and (o2+y) in range(m) and X[i,o1+x,o2+y,1]!=1 else 1000 for (x,y) in possible_directions]
    good_action=np.argmin(possible_directions)
    
    
    Y[i,good_action]=1
    


#Fitting the randomly generated data
model.fit(X, Y, epochs=1, batch_size=1)

    
    
    

In [23]:
def fit_fun(games,num_episodes):
    "X of dim : num_episodes n x m x 4"
    "Function used to fit the model over the games in games"
    labels=np.zeros((num_episodes,4))
    # Over all the games in games
    for k in range(num_episodes):
        
        #Finding the food item 
        for i in range(n):
            for j in range(m):
                if games[k,i,j,2]==1:
                    f1,f2=i,j
                    
        # Finding the heuristically best option for current gamestate
        possible_directions=[(0,1),(1,0),(-1,0),(0,-1)]
        possible_directions=[(f1-(o1+x))**2+ (f2-(o2+y))**2 if (o1+x) in range(n) and (o2+y) in range(m) and games[k,o1+x,o2+y,1]!=1 else 1000 for (x,y) in possible_directions]
        good_action=np.argmin(possible_directions)
        labels[k,good_action]=1
        
    #Fitting 
    model.fit(games,labels)

In [18]:
import pygame as pg
# Initialize Pygame
pg.init()  # Initialize all the imported pygame modules
SURF = pg.display.set_mode((1450, 1000))  # Set up the game window with resolution 1450x1000
font = pg.font.SysFont(None, 30)  # Initialize the font to display the score

# Colors used in the game
gris_clair = (220, 220, 220)  # Light gray
gris_fonce = (150, 150, 150)  # Dark gray
noir = (0, 0, 0)  # Black
rouge = (255, 0, 0)  # Red

# Initial coordinates for the game area
x0, y0 = 200, 150  # Top-left corner of the game grid
x1, y1 = 800, 750  # Bottom-right corner of the game grid


pygame 2.5.2 (SDL 2.28.3, Python 3.12.3)
Hello from the pygame community. https://www.pygame.org/contribute.html




Visualising Snake and fine tuning network

In [None]:
import tqdm  #For loading bar


num_games=10000

# To transform string game matrix to cnn input
def gametab_to_input(snake_list,food):
    
    # First dimension is to be able to fit into CNN as it requires four dimensions (first one corresponds to the number of games)
    res=np.zeros((1,n,m,3))
    f1,f2=food
    x,y=snake_list[0]
    res[0,x,y,0]=1
    for i,j in snake_list[1:]:
        res[0,i,j,1]=1
    res[0,f1,f2,2]=1
    return res


# States where the net predicted the wrong ouput
correction_list=[]

#Keeping track of scores
score_list=[]
for i in tqdm.tqdm(range(num_games)):
    
    # Declaring new random game
    new_tab=np.full((n,m),'X')
    
    #Random snake head and body
    x,y=np.random.randint(n),np.random.randint(m)
    poss_dir=[(0,1),(1,0),(-1,0),(0,-1)]
    d1,d2=random.choice(poss_dir)
    while(x-d1 not in range(n) or y-d2 not in range(m)):
        d1,d2=random.choice(poss_dir)
    new_snake=[(x,y),(x-d1,y-d2)]
    new_tab[x,y]='S'
    new_tab[x-d1,y-d2]='S'
    
    #Random food cell
    f1,f2=np.random.randint(n),np.random.randint(m)
    while(new_tab[f1,f2]=='S'):
        f1,f2=np.random.randint(n),np.random.randint(m)
    new_tab[f1,f2]='F'
    
    #New gamestate
    new_game=Game(new_tab,(d1,d2),new_snake,0)
    
    done=False
    timer=0
    
    while(not done):
        j=len(correction_list)
        #10 errors found (or 10 cases where game ended due to network missclassification)
        if j==10:
            
            #Fitting to adjust network output
            X=np.zeros((j,n,m,3))
            for l in range(j):
                X[l,:,:,:]=correction_list[l]
            fit_fun(X,j)
            correction_list=[]
        timer+=1
        
        #snake head
        x,y=new_game.snake_list[0]
        
        # If we get an 'infinite loop', we choose some random other action to get out of it
        if timer>100:
            
            #Choosing a random direction
            good_dir=[(o1,o2) for (o1,o2) in poss_dir if (o1+x) in range(n) and (o2+y) in range(m) and new_game.tab[x+o1,y+o2]!='S']
            some_dir=random.choice(good_dir)
            action=poss_dir.index(some_dir)
            done=True
            
        else:    
            
            #Finding food cell
            for i in range(n):
                for j in range(m):
                    if new_game.tab[i,j]=='F':
                        f1,f2=i,j
                        
            #Finding predicted action for current gamestate
            cnn_input=gametab_to_input(new_game.snake_list,(f1,f2))
            action=model.predict(gametab_to_input(new_game.snake_list,(f1,f2)),verbose=0)
        action=np.argmax(action)
        new_dir=poss_dir[action]
        new_game.dir=new_dir
        
        #Drawing gamestate
        draw(new_game)
        
        
        done,reward=new_game.update()
        
        # Game Over 
        if done==True:    
            correction_list.append(cnn_input)
            score_list.append(new_game.score)
        
        #This means that the snake found food ; we reset the timer because we're certain that
        #we're not in an infinite loop
        if reward>0:
            
            timer=0
            
        #Uncomment this to get slower game
        #pg.time.wait(1)
        
        
        
        
    

Saving the model

In [None]:
import torch
model.save_weights('/Users/Riyad/Documents/Perso_projects/Snake/Snake/model.weights.h5'.format(epoch=0))