**How to Train a Snake AI**

![](./img/snek.jpg)

This guide shows how to train an AI to play snake using deep Q-learning. There are three components: the snake game, the agent that plays the snake game, and the model.

After the snake game has been made, we create an agent and train it to play snake using reinforcement learning, specifically deep Q-learning. In other words, we reward the agent depending on how well it is doing, and it tries to find the best next action based on that reward. At each step, the next action is chosen from the model's prediction for the current game state. The model is a feed-forward neural network with an input layer, a hidden layer, and an output layer.

First, create your own snake game. Create a file for your game and import all necessary modules. Initialize pygame and create a class for the different directions (up, down, left, right). Setting the speed to a higher number is recommended in order to speed up training. The agent typically starts to improve after 80-100 iterations.

In [1]:
import pygame
import random
from enum import Enum
from collections import namedtuple
import numpy as np

pygame.init()
font = pygame.font.Font('arial.ttf', 25)

class Direction(Enum):
    RIGHT = 1
    LEFT = 2
    UP = 3
    DOWN = 4

Point = namedtuple('Point', 'x, y')

# rgb colors
WHITE = (255, 255, 255)
RED = (200,0,0)
GREEN1 = (0, 155, 0)
GREEN2 = (0, 255, 0)
BLACK = (0,0,0)

BLOCK_SIZE = 20
SPEED = 140

pygame 2.5.2 (SDL 2.28.3, Python 3.11.5)
Hello from the pygame community. https://www.pygame.org/contribute.html


Next, create a class for the snake game. Initialize the display to your desired window size and create methods to reset the game state, randomly place food, check whether the game is over, and reward the agent depending on whether it got food or collided and ended the game.

In [2]:

class SnakeGameAI:
    # Initializing the game:
    def __init__(self, w=240, h=240): #12x12 grid
        self.w = w
        self.h = h
        # init display
        self.display = pygame.display.set_mode((self.w, self.h))
        pygame.display.set_caption('Snake')
        self.clock = pygame.time.Clock()
        self.reset()

    def reset(self):
        # init game state
        self.direction = Direction.RIGHT

        self.head = Point(self.w/2, self.h/2)
        self.snake = [self.head,
                      Point(self.head.x-BLOCK_SIZE, self.head.y),
                      Point(self.head.x-(2*BLOCK_SIZE), self.head.y)]

        self.score = 0
        self.food = None
        self._place_food()
        self.frame_iteration = 0

    # Randomly place food:
    def _place_food(self):
        x = random.randint(0, (self.w-BLOCK_SIZE )//BLOCK_SIZE )*BLOCK_SIZE
        y = random.randint(0, (self.h-BLOCK_SIZE )//BLOCK_SIZE )*BLOCK_SIZE
        self.food = Point(x, y)
        if self.food in self.snake:
            self._place_food()

    # Check for collisions:
    def is_collision(self, pt=None):
        if pt is None:
            pt = self.head
        # hits boundary
        if pt.x > self.w - BLOCK_SIZE or pt.x < 0 or pt.y > self.h - BLOCK_SIZE or pt.y < 0:
            return True
        # hits itself
        if pt in self.snake[1:]:
            return True
        return False
    
    # Game rules and rewards:
    def play_step(self, action):
        self.frame_iteration += 1
        # 1. collect user input
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()
                quit()
        
        # 2. move
        self._move(action) # update the head
        self.snake.insert(0, self.head)
        
        # 3. check if game over
        reward = 0
        game_over = False
        if self.is_collision() or self.frame_iteration > 100*len(self.snake):
            game_over = True
            reward = -10
            return reward, game_over, self.score

        # 4. place new food or just move
        if self.head == self.food:
            self.score += 1
            reward = 10
            self._place_food()
        else:
            self.snake.pop()
        
        # 5. update ui and clock
        self._update_ui()
        self.clock.tick(SPEED)
        # 6. return reward, game over and score
        return reward, game_over, self.score

    def _update_ui(self):
        self.display.fill(BLACK)

        for pt in self.snake:
            pygame.draw.rect(self.display, GREEN1, pygame.Rect(pt.x, pt.y, BLOCK_SIZE, BLOCK_SIZE))
            pygame.draw.rect(self.display, GREEN2, pygame.Rect(pt.x+4, pt.y+4, 12, 12))

        pygame.draw.rect(self.display, RED, pygame.Rect(self.food.x, self.food.y, BLOCK_SIZE, BLOCK_SIZE))

        text = font.render("Score: " + str(self.score), True, WHITE)
        self.display.blit(text, [0, 0])
        pygame.display.flip()
        
    # Move the snake:
    def _move(self, action):
        # [straight, right, left]

        clock_wise = [Direction.RIGHT, Direction.DOWN, Direction.LEFT, Direction.UP]
        idx = clock_wise.index(self.direction)

        if np.array_equal(action, [1, 0, 0]):
            new_dir = clock_wise[idx] # no change
        elif np.array_equal(action, [0, 1, 0]):
            next_idx = (idx + 1) % 4
            new_dir = clock_wise[next_idx] # right turn r -> d -> l -> u
        else: # [0, 0, 1]
            next_idx = (idx - 1) % 4
            new_dir = clock_wise[next_idx] # left turn r -> u -> l -> d

        self.direction = new_dir

        x = self.head.x
        y = self.head.y
        if self.direction == Direction.RIGHT:
            x += BLOCK_SIZE
        elif self.direction == Direction.LEFT:
            x -= BLOCK_SIZE
        elif self.direction == Direction.DOWN:
            y += BLOCK_SIZE
        elif self.direction == Direction.UP:
            y -= BLOCK_SIZE

        self.head = Point(x, y)

Now that the snake game has been made, we can create an agent and train it to play snake using reinforcement learning, specifically deep Q-learning. In brief, deep Q-learning works like trial and error: the agent learns over time which decisions are better or worse depending on the rewards. For example, the agent is rewarded for eating the apple, so it learns to go towards the apple. It is also punished (given a negative reward) for colliding with itself or the walls, so it learns to avoid doing this.

In deep Q-learning, a Q value estimates how good an action is in a given state, so better Q values should improve the snake's performance. We first initialize our model, then we choose either the action the model predicts to be the best move, based on observations from the game environment and prior knowledge, or a random move if we don't have enough information yet.
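The choice between a predicted move and a random move is commonly implemented as an epsilon-greedy rule. The sketch below is only an illustration and not the agent code used in this notebook: the function name `choose_action`, the `epsilon` parameter, and the `rng` argument are assumptions made for the example.

```python
import random

def choose_action(q_values, epsilon, rng=random):
    """Return a one-hot action [straight, right, left].

    q_values: three predicted Q values, one per action (hypothetical input).
    epsilon:  probability of exploring with a random move (assumed parameter).
    """
    action = [0, 0, 0]
    if rng.random() < epsilon:
        # explore: not enough information yet, so pick a random move
        idx = rng.randrange(3)
    else:
        # exploit: pick the move with the highest predicted Q value
        idx = max(range(3), key=lambda i: q_values[i])
    action[idx] = 1
    return action
```

One possible schedule, again an assumption, is to shrink `epsilon` as more games are played, so that early games are mostly random and later games mostly follow the model.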

A state is a representation of the environment observable by the agent. In this case, the state is a vector of 11 binary values (3 danger directions, 4 values for the direction the snake is currently moving in, and 4 food location directions). Each danger value is 0 if there is no danger nearby (walls or snake body) and 1 if there is. The snake's direction can be up, down, left, or right. The food location can be a combination of up, down, left, and/or right.
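To make the 11 values concrete, here is a minimal sketch of how such a state vector could be built. It is not the notebook's own `get_state`: the function name `build_state`, the `is_collision` callback argument, and the order of values inside each group are assumptions. `Direction`, `Point`, and `BLOCK_SIZE` are redefined here only so the example runs on its own.

```python
from collections import namedtuple
from enum import Enum

class Direction(Enum):
    RIGHT = 1
    LEFT = 2
    UP = 3
    DOWN = 4

Point = namedtuple('Point', 'x, y')
BLOCK_SIZE = 20

# Clockwise order, matching the game's _move method
CLOCKWISE = [Direction.RIGHT, Direction.DOWN, Direction.LEFT, Direction.UP]
# One-block offsets in screen coordinates (y grows downward)
STEP = {
    Direction.RIGHT: (BLOCK_SIZE, 0),
    Direction.DOWN: (0, BLOCK_SIZE),
    Direction.LEFT: (-BLOCK_SIZE, 0),
    Direction.UP: (0, -BLOCK_SIZE),
}

def build_state(head, direction, food, is_collision):
    """Return the 11 binary values as a list (ordering is an assumption)."""
    idx = CLOCKWISE.index(direction)
    straight = CLOCKWISE[idx]
    right = CLOCKWISE[(idx + 1) % 4]
    left = CLOCKWISE[(idx - 1) % 4]

    def danger(d):
        # 1 if moving one block in direction d would cause a collision
        dx, dy = STEP[d]
        return int(is_collision(Point(head.x + dx, head.y + dy)))

    return [
        # 3 danger values, relative to the current heading
        danger(straight), danger(right), danger(left),
        # 4 values for the current direction (exactly one is 1)
        int(direction == Direction.LEFT),
        int(direction == Direction.RIGHT),
        int(direction == Direction.UP),
        int(direction == Direction.DOWN),
        # 4 food location values
        int(food.x < head.x),  # food is to the left
        int(food.x > head.x),  # food is to the right
        int(food.y < head.y),  # food is above
        int(food.y > head.y),  # food is below
    ]
```

With the game class above, `game.is_collision` could be passed as the callback, for example `build_state(game.head, game.direction, game.food, game.is_collision)`.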

There are 3 outputs, one per possible action, and the model predicts which action is best. There are 3 outputs instead of 4 because the snake can't reverse into itself; it can only go straight, turn left, or turn right.

Deep Neural Network Architecture:

![](./img/DNN.png)

In the agent, we want to create these functions:
- a function that gets the state of the game
- a function that remembers the reward and outcome of each action
- a function that trains on long-term memory
- a function that trains on short-term memory
- a function that gets the next action based on the state
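As an illustration of the "remember" and long-term memory pieces, here is a minimal sketch of a replay memory built on `collections.deque`. It is not the agent class used in this notebook: the class name `ReplayMemory` and its method names are hypothetical, and the default sizes simply mirror the `MAX_MEMORY` and `BATCH_SIZE` constants used in this guide.

```python
import random
from collections import deque

MAX_MEMORY = 100_000
BATCH_SIZE = 1000

class ReplayMemory:
    """Hypothetical helper that stores past transitions for long-term training."""

    def __init__(self, capacity=MAX_MEMORY):
        # deque with maxlen drops the oldest transition once it is full
        self.buffer = deque(maxlen=capacity)

    def remember(self, state, action, reward, next_state, game_over):
        # store one transition as a tuple
        self.buffer.append((state, action, reward, next_state, game_over))

    def sample(self, batch_size=BATCH_SIZE):
        # use a random mini-batch if enough transitions exist, else use all of them
        if len(self.buffer) > batch_size:
            batch = random.sample(list(self.buffer), batch_size)
        else:
            batch = list(self.buffer)
        # regroup into (states, actions, rewards, next_states, game_overs)
        return tuple(zip(*batch))
```

Short-term training would then use only the most recent transition, while long-term training would use a batch returned by `sample()`.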

In [4]:
import torch
import random
import numpy as np
from collections import deque
from game import SnakeGameAI, Direction, Point
from model import Linear_QNet, QTrainer
from helper import plot
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

#can change these if you want to experiment:
MAX_MEMORY = 100_000
BATCH_SIZE = 1000
LR = 0.001 #learning rate

Now we can make a function that ties all of the agent's functions together and trains the agent. At each step it gets the current state, chooses a move, plays it, and trains on that single step. If the game is over, it resets the game, increments the game count, trains the agent with long-term memory, and updates the record if necessary. It also keeps track of scores and mean scores for plotting.

In [6]:
def train():
    plot_scores = []
    plot_mean_scores = []
    total_score = 0
    record = 0
    agent = Agent()
    game = SnakeGameAI()
    while True:
        # get curr state
        state_curr = agent.get_state(game)

        # get next move based on the state
        next_move = agent.get_action(state_curr)

        # perform move and get new state
        reward, game_over, score = game.play_step(next_move)
        state_new = agent.get_state(game)

        # train short memory
        agent.train_short_memory(state_curr, next_move, reward, state_new, game_over)

        # remember
        agent.remember(state_curr, next_move, reward, state_new, game_over)

        if game_over:
            # reset the game, then train long memory (aka replay memory or experience replay)
            game.reset()
            agent.n_games += 1
            agent.train_long_memory()

            if score > record:
                record = score
                agent.model.save()

            print('Game', agent.n_games, 'Score', score, 'Record:', record)
            
            #plotting
            plot_scores.append(score) #append score to plot list
            total_score += score #update total score
            mean_score = total_score / agent.n_games
            plot_mean_scores.append(mean_score) #append mean score to plot list
            plot(plot_scores, plot_mean_scores) #plot scores and mean scores

Before we can run the program, we will need to create a class for the model and a class for the trainer.

The model uses a simple forward function that:

1. Applies the first linear layer (self.linear1) to the input tensor x.
2. Applies the activation function to the output of the first linear layer.
3. Applies the second linear layer (self.linear2) to the output of the activation function.
4. Returns the final output tensor.

In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F

class Linear_QNet(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.linear1 = nn.Linear(input_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        #activation functions tried on the hidden layer (uncomment one to experiment):
        #x = F.relu(self.linear1(x)) #also good performer
        #x = F.gelu(self.linear1(x))
        x = F.leaky_relu(self.linear1(x)) #best performer
        #x = torch.tanh(self.linear1(x))
        #x = torch.sigmoid(self.linear1(x))
        #x = F.softmax(self.linear1(x), dim = 0)
        #x = F.elu(self.linear1(x))

        #output layer: one raw Q value per action, no activation
        x = self.linear2(x)
        return x

    #saves the model's state dictionary
    def save(self, file_name='model.pth'):
        model_folder_path = './model'
        if not os.path.exists(model_folder_path):
            os.makedirs(model_folder_path)

        file_name = os.path.join(model_folder_path, file_name)
        torch.save(self.state_dict(), file_name)

In our trainer, we use the Bellman equation, which can be simplified to:

![](./img/trainer.png)

What this means is that if the game is over, our target value Q new is set to the immediate reward received when the game ended. Otherwise, Q new is set to the immediate reward plus the discounted maximum predicted Q value for the next state (that is, the discount factor gamma multiplied by the highest predicted Q value of the next state).
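The rule above can be written as a few lines of plain Python. This is a sketch for illustration only: the function name `q_target` and the default `gamma=0.9` are assumptions, since this guide does not state the discount factor it uses.

```python
def q_target(reward, next_q_values, game_over, gamma=0.9):
    """Return the target Q value for one transition.

    reward:        immediate reward for the action taken
    next_q_values: predicted Q values for the next state (one per action)
    game_over:     True if the action ended the game
    gamma:         discount factor (0.9 is an assumed value)
    """
    if game_over:
        # no future to account for: the target is just the immediate reward
        return reward
    # immediate reward plus the discounted best predicted Q value of the next state
    return reward + gamma * max(next_q_values)
```

For example, with a reward of 10 for eating food, next-state predictions of `[1.0, 2.0, 0.5]`, and `gamma=0.9`, the target is 10 + 0.9 * 2.0. With a reward of -10 for a collision that ends the game, the target is simply -10.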

We will also use the loss function:

![](./img/loss.png)

This is the mean squared error between the target Q values and the predicted Q values.
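As a small worked example of this loss, the sketch below computes the mean squared error by hand in plain Python. The function name `mse_loss` is hypothetical; in the trainer itself the same quantity comes from PyTorch's `nn.MSELoss`, which averages over all elements by default.

```python
def mse_loss(target_q, predicted_q):
    """Mean of the squared differences between target and predicted Q values."""
    squared_errors = [(t - p) ** 2 for t, p in zip(target_q, predicted_q)]
    return sum(squared_errors) / len(squared_errors)

# Only the Q value of the action actually taken differs from the prediction,
# so only that entry contributes to the loss.
target = [1.0, 2.0, 3.0]
predicted = [1.0, 2.0, 5.0]
loss = mse_loss(target, predicted)  # (0 + 0 + 4) / 3
```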


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

class QTrainer:
    def __init__(self, model, lr, gamma):
        self.lr = lr
        self.gamma = gamma
        self.model = model
        self.optimizer = optim.Adam(model.parameters(), lr=self.lr) #optimization algo
        self.criterion = nn.MSELoss() #loss function (mean squared error)

    def train_step(self, state, action, reward, next_state, gameover):
        state = torch.tensor(state, dtype=torch.float)
        next_state = torch.tensor(next_state, dtype=torch.float)
        action = torch.tensor(action, dtype=torch.long)
        reward = torch.tensor(reward, dtype=torch.float)

        #a single transition has shape (n,), so unsqueeze to a batch of size 1
        if len(state.shape) == 1:
            state = torch.unsqueeze(state, 0)
            next_state = torch.unsqueeze(next_state, 0)
            action = torch.unsqueeze(action, 0)
            reward = torch.unsqueeze(reward, 0)
            gameover = (gameover, )

        # 1: predicted Q values with current state
        predQ = self.model(state)

        # 2: Q_new = r + y * max(next_predicted Q value) -> only do this if game not over
        targetQ = predQ.clone()
        for idx in range(len(gameover)):
            Q_new = reward[idx]
            if not gameover[idx]:
                # Bellman equation:
                Q_new = reward[idx] + self.gamma * torch.max(self.model(next_state[idx]))

            #only the Q value of the action that was taken gets a new target
            targetQ[idx][torch.argmax(action[idx]).item()] = Q_new

        # Apply the loss function:
        self.optimizer.zero_grad()
        loss = self.criterion(targetQ, predQ)
        loss.backward() #back propagation to compute gradients

        self.optimizer.step() #update model parameters based on the computed gradients

Extra potential experiments to improve performance:
- changing the reward values (e.g. giving a higher reward for eating the apple or a larger negative reward for colliding with something)
- testing different activation functions

I tested out more activation functions. A variety of built-in activation functions can be found in torch.nn already. Let's test out GELU first:

![](./img/geluform.png)

In [1]:
#gelu: (manual tanh approximation below, but this is also in torch.nn)
import math
import torch

def gelu(x):
    #tanh approximation of GELU (Hendrycks and Gimpel, 2016)
    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * x ** 3)))

But it looks like GELU does not perform well on the snake game:

![](./img/GELU.png)

And the same goes for many other activation functions other than ReLU and LeakyReLU:

Sigmoid:

![](./img/sigmoid.png)

Softmax:

![](./img/softmax.png)

Tanh:

![](./img/tanh.png)

ReLU:

![](./img/RELU.png)

LeakyReLU:

![](./img/LeakyRELU.png)

Papers referenced:

A. Sebastianelli, M. Tipaldi, S. L. Ullo and L. Glielmo, "A Deep Q-Learning based approach applied to the Snake game," 2021 29th Mediterranean Conference on Control and Automation (MED), Puglia, Italy, 2021, pp. 348-353, doi: 10.1109/MED51440.2021.9480232.

D. Hendrycks and K. Gimpel, "Gaussian Error Linear Units (GELUs)," arXiv preprint arXiv:1606.08415, 2016.