<h2>1. Building the Environment for the Classic Snake Game</h2>

In [None]:
import pygame
import time
import random

# Initialize every pygame module up front; the display and font calls below
# depend on it.
pygame.init()

# Color palette as (R, G, B) tuples
white = (255, 255, 255)  # "Play Again / Quit" prompt
yellow = (255, 255, 102)  # score text
black = (0, 0, 0)  # snake body
red = (213, 50, 80)  # "Game Over!" banner
green = (0, 255, 0)  # food
blue = (50, 153, 213)  # background fill
purple = (138, 43, 226)  # start-screen prompt

# Size of the game window passed to set_mode()
dis_width = 600
dis_height = 400

# Create the game display window; every drawing helper below paints onto `dis`
dis = pygame.display.set_mode((dis_width, dis_height))
pygame.display.set_caption('Snake Game')

# Timing and grid settings
clock = pygame.time.Clock()
snake_block = 10  # side of one snake/food square and the distance moved per frame
snake_speed = 15  # passed to clock.tick(): upper bound on frames per second

# Fonts: general messages, the score read-out and the animated title
font_style = pygame.font.SysFont('bahnschrift', 25)
score_font = pygame.font.SysFont('comicsansms', 35)
title_font = pygame.font.SysFont('comicsansms', 50)

# Paint every segment of the snake onto the display surface
def our_snake(snake_block, snake_list):
    """Draw each ``[x, y]`` entry of ``snake_list`` as a black square.

    ``snake_block`` is the side length of every square.
    """
    for segment in snake_list:
        left, top = segment[0], segment[1]
        square = [left, top, snake_block, snake_block]
        pygame.draw.rect(dis, black, square)

# Show the current score in the top-left corner
def your_score(score):
    """Render ``score`` in yellow with the score font and blit it at (10, 10)."""
    label = 'Your Score: ' + str(score)
    dis.blit(score_font.render(label, True, yellow), [10, 10])

# Write a line of text at an arbitrary position
def message(msg, color, pos):
    """Render ``msg`` in ``color`` with the general-purpose font and blit it at ``pos``."""
    dis.blit(font_style.render(msg, True, color), pos)

# Animated title: the text colour slides from yellow to red
def title_animation():
    """Redraw the 'Snake Game' title while its green channel drops from 255 to 0.

    The channel falls in steps of 5 at up to 20 frames per second, and the
    final frame is held for one second.
    """
    title_pos = [dis_width / 2 - 130, dis_height / 4]
    green_level = 255
    while green_level >= 0:
        dis.fill(blue)
        frame = title_font.render('Snake Game', True, (255, green_level, 0))
        dis.blit(frame, title_pos)
        pygame.display.update()
        clock.tick(20)
        green_level -= 5
    time.sleep(1)

# Start screen with instructions
def start_screen():
    """Play the title animation, then block until the player presses any key.

    Closing the window while waiting shuts pygame down and exits the program.
    """
    title_animation()
    dis.fill(blue)
    message('Press any key to Start', purple, [dis_width / 2 - 120, dis_height / 2])
    pygame.display.update()
    waiting = True
    while waiting:
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()
                quit()
            if event.type == pygame.KEYDOWN:
                waiting = False
        # Throttle the polling loop: without a tick it spins a CPU core at
        # 100% while the player is simply reading the prompt.
        clock.tick(30)

# Pick a free, grid-aligned cell for the next piece of food
def _random_food_cell(occupied):
    """Return ``[x, y]`` of a random grid cell that is not listed in ``occupied``.

    Both coordinates are multiples of ``snake_block`` and lie inside the
    window, so the snake head (which also moves in ``snake_block`` steps) can
    land exactly on the food. Cells are re-drawn until a free one is found;
    this assumes the snake never fills the entire board.
    """
    columns = dis_width // snake_block
    rows = dis_height // snake_block
    while True:
        cell = [random.randrange(columns) * snake_block,
                random.randrange(rows) * snake_block]
        if cell not in occupied:
            return cell


# The main game loop where the magic happens!
def gameLoop():
    """Run the interactive Snake game until the player quits.

    Shows the start screen, then loops over: read input, move the head, check
    wall/self collisions, draw the frame, and grow the snake when it eats.
    After a crash a game-over screen offers C (play again) or Q (quit);
    closing the window quits from either screen. On exit pygame is shut down
    and the interpreter is asked to quit.
    """
    start_screen()

    # Arrow key -> (dx, dy) of a single step, in pixels
    arrow_steps = {
        pygame.K_LEFT: (-snake_block, 0),
        pygame.K_RIGHT: (snake_block, 0),
        pygame.K_UP: (0, -snake_block),
        pygame.K_DOWN: (0, snake_block),
    }

    game_over = False
    game_close = False

    # Initial position of the snake
    x1 = dis_width / 2
    y1 = dis_height / 2

    # Initial direction (no movement)
    x1_change = 0
    y1_change = 0

    snake_list = []
    length_of_snake = 1

    # Spawn the first piece of food, never directly under the head
    foodx, foody = _random_food_cell([[x1, y1]])

    while not game_over:

        while game_close:
            dis.fill(blue)
            message('Game Over!', red, [dis_width / 2 - 100, dis_height / 4])
            your_score(length_of_snake - 1)
            message('Press C-Play Again or Q-Quit', white, [dis_width / 6, dis_height / 1.5])
            pygame.display.update()

            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    # Closing the window must also work on this screen
                    game_over = True
                    game_close = False
                elif event.type == pygame.KEYDOWN:
                    if event.key == pygame.K_q:
                        game_over = True
                        game_close = False
                    elif event.key == pygame.K_c:
                        # Reset game variables to start a new game
                        x1 = dis_width / 2
                        y1 = dis_height / 2
                        x1_change = 0
                        y1_change = 0
                        snake_list = []
                        length_of_snake = 1
                        foodx, foody = _random_food_cell([[x1, y1]])
                        game_close = False

            # Keep the idle game-over screen from spinning the CPU
            clock.tick(snake_speed)

        if game_over:
            # The player quit from the game-over screen; skip the extra frame
            break

        # Step taken on the previous frame. Reversal is judged against this,
        # not against keys pressed earlier within the same frame.
        prev_dx, prev_dy = x1_change, y1_change

        # Handling arrow key presses for snake movement
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                game_over = True
            elif event.type == pygame.KEYDOWN and event.key in arrow_steps:
                dx, dy = arrow_steps[event.key]
                # A snake with a body cannot turn 180 degrees: the head would
                # run straight into its own neck.
                if length_of_snake > 1 and (dx, dy) == (-prev_dx, -prev_dy):
                    continue
                x1_change, y1_change = dx, dy

        x1 += x1_change
        y1 += y1_change

        # Leaving the window ends the round before the off-screen head is drawn
        if x1 >= dis_width or x1 < 0 or y1 >= dis_height or y1 < 0:
            game_close = True
            continue

        dis.fill(blue)
        pygame.draw.rect(dis, green, [foodx, foody, snake_block, snake_block])  # Draw the food
        snake_head = [x1, y1]
        snake_list.append(snake_head)
        if len(snake_list) > length_of_snake:
            del snake_list[0]

        # The head overlapping any older segment means the snake bit itself
        if snake_head in snake_list[:-1]:
            game_close = True

        our_snake(snake_block, snake_list)
        your_score(length_of_snake - 1)

        pygame.display.update()

        # Check if the snake eats the food
        if x1 == foodx and y1 == foody:
            length_of_snake += 1  # Grow the snake!
            foodx, foody = _random_food_cell(snake_list)

        clock.tick(snake_speed)

    pygame.quit()
    quit()

# Start the game as soon as this cell runs; the call returns only when the
# player quits. Steer the snake with the arrow keys.
gameLoop()


<h2>2. Training the Model to Play the Snake Game</h2>


In [None]:
import pygame
import random
import numpy as np
from collections import deque
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# Initialize pygame
pygame.init()

# Display settings (window size handed to set_mode and used as the default
# board size of SnakeGameAI)
dis_width = 600
dis_height = 400

# Colors as (R, G, B) tuples
white = (255, 255, 255)  # score text
red = (200, 0, 0)  # food
green = (0, 200, 0)  # snake
black = (0, 0, 0)  # background

# Create display
dis = pygame.display.set_mode((dis_width, dis_height))
pygame.display.set_caption('Snake RL Agent')

clock = pygame.time.Clock()

# Snake block size and speed
snake_block = 20  # side of one grid cell and the distance moved per step
snake_speed = 100  # clock.tick() cap; kept high so training episodes run fast

# Agent parameters
MAX_MEMORY = 100_000  # maximum number of transitions kept in the replay deque
BATCH_SIZE = 1000  # transitions sampled per long-memory training step
LR = 0.001  # learning rate handed to the Adam optimizer

# Heading codes shared by the game and the agent
class Direction:
    """Integer codes for the four headings the snake can travel in."""

    LEFT, RIGHT, UP, DOWN = range(4)

# Point dataclass for positions
@dataclass
class Point:
    """Position of one grid cell (a snake segment or the food).

    ``@dataclass`` generates a value-based ``__eq__``, which is what makes
    ``head == food`` and ``pt in snake`` work in the game class.
    """
    # NOTE(review): annotated as int, but SnakeGameAI.reset() builds the head
    # from w/2 and h/2, which are floats; dataclasses do not enforce the type.
    x: int  # horizontal coordinate
    y: int  # vertical coordinate

# Q-network: one hidden layer with ReLU
class Linear_QNet(nn.Module):
    """Fully connected network mapping a state vector to one value per action.

    The layers are stored as ``linear1`` and ``linear2``; those names are part
    of the saved state dict.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.linear1 = nn.Linear(in_features=input_size, out_features=hidden_size)
        self.linear2 = nn.Linear(in_features=hidden_size, out_features=output_size)

    def forward(self, x):
        """Return the raw (un-activated) output of the second layer."""
        hidden = torch.relu(self.linear1(x))
        return self.linear2(hidden)

    def save(self, file_name='model.pth'):
        """Write the model's state dict to ``file_name``."""
        weights = self.state_dict()
        torch.save(weights, file_name)

# QTrainer for training the model
class QTrainer:
    """One-step Q-learning trainer for a model that maps states to action values.

    Holds an Adam optimizer over the model's parameters and an MSE loss.
    """

    def __init__(self, model, lr, gamma):
        """Store the model and hyper-parameters.

        Args:
            model: network called as ``model(states)``, returning one row of
                action values per state.
            lr: learning rate for Adam.
            gamma: discount factor applied to the best next-state value.
        """
        self.lr = lr
        self.gamma = gamma
        self.model = model
        self.optimizer = optim.Adam(model.parameters(), lr=self.lr)
        self.criterion = nn.MSELoss()

    def train_step(self, state, action, reward, next_state, done):
        """Run one gradient step on a single transition or a batch of them.

        Args:
            state: one state vector, or a sequence of state vectors.
            action: one-hot action (or a sequence of them); the index of the
                1 selects which output is trained.
            reward: scalar reward, or a sequence of rewards.
            next_state: state(s) observed after the action.
            done: bool (or sequence of bools); True marks a terminal
                transition, whose target is the reward alone.
        """
        state = torch.tensor(np.array(state), dtype=torch.float)
        next_state = torch.tensor(np.array(next_state), dtype=torch.float)
        action = torch.tensor(np.array(action), dtype=torch.long)
        reward = torch.tensor(np.array(reward), dtype=torch.float)

        if state.dim() == 1:
            # Single transition: add a batch axis so both cases share one path
            state = state.unsqueeze(0)
            next_state = next_state.unsqueeze(0)
            action = action.unsqueeze(0)
            reward = reward.unsqueeze(0)
            done = (done,)
        done = torch.tensor(np.array(done), dtype=torch.bool)

        # Predicted Q values for the current states (this is what gets trained)
        pred = self.model(state)

        # The bootstrap target is a constant for this step. Building it under
        # no_grad keeps gradients from flowing through model(next_state),
        # which would otherwise also pull the target towards the prediction.
        with torch.no_grad():
            best_next = self.model(next_state).max(dim=1).values
            q_new = torch.where(done, reward, reward + self.gamma * best_next)
            target = pred.detach().clone()
            rows = torch.arange(target.shape[0])
            # Only the entry of the action actually taken receives a new value
            target[rows, action.argmax(dim=1)] = q_new

        self.optimizer.zero_grad()
        loss = self.criterion(pred, target)
        loss.backward()

        self.optimizer.step()

# Epsilon-greedy Q-learning agent
class Agent:
    """Chooses moves for the snake and trains the Q-network from experience."""

    def __init__(self):
        self.n_games = 0
        self.epsilon = 0  # exploration level, recomputed on every get_action()
        self.gamma = 0.9  # discount rate
        self.memory = deque(maxlen=MAX_MEMORY)  # oldest transitions drop off first
        self.model = Linear_QNet(11, 256, 3)
        self.trainer = QTrainer(self.model, lr=LR, gamma=self.gamma)

    def get_state(self, game):
        """Encode the game as 11 flags (0/1) in an integer numpy array.

        Order: danger straight / right / left, heading left / right / up /
        down, food left / right / up / down of the head.
        """
        head = game.snake[0]
        heading = game.direction

        # Neighbouring cells listed clockwise: the entry after the current
        # heading is a right turn, the entry before it a left turn.
        clockwise = [
            (Direction.RIGHT, Point(head.x + snake_block, head.y)),
            (Direction.DOWN, Point(head.x, head.y + snake_block)),
            (Direction.LEFT, Point(head.x - snake_block, head.y)),
            (Direction.UP, Point(head.x, head.y - snake_block)),
        ]
        headings = [code for code, _ in clockwise]

        if heading in headings:
            here = headings.index(heading)
            danger_straight = game.is_collision(clockwise[here][1])
            danger_right = game.is_collision(clockwise[(here + 1) % 4][1])
            danger_left = game.is_collision(clockwise[(here - 1) % 4][1])
        else:
            danger_straight = danger_right = danger_left = False

        features = [
            danger_straight,
            danger_right,
            danger_left,
            # Move direction
            heading == Direction.LEFT,
            heading == Direction.RIGHT,
            heading == Direction.UP,
            heading == Direction.DOWN,
            # Food location relative to the head
            game.food.x < game.head.x,
            game.food.x > game.head.x,
            game.food.y < game.head.y,
            game.food.y > game.head.y,
        ]
        return np.array(features, dtype=int)

    def remember(self, state, action, reward, next_state, done):
        """Append one transition to the replay memory."""
        transition = (state, action, reward, next_state, done)
        self.memory.append(transition)

    def train_long_memory(self):
        """Train on up to BATCH_SIZE transitions drawn from the replay memory."""
        batch = self.memory
        if len(batch) > BATCH_SIZE:
            batch = random.sample(batch, BATCH_SIZE)

        states, actions, rewards, next_states, dones = zip(*batch)
        self.trainer.train_step(states, actions, rewards, next_states, dones)

    def train_short_memory(self, state, action, reward, next_state, done):
        """Train on the single transition that was just played."""
        self.trainer.train_step(state, action, reward, next_state, done)

    def get_action(self, state):
        """Return a one-hot move ``[straight, right, left]``.

        A random move is chosen while ``randint(0, 200)`` falls below
        ``80 - n_games``; otherwise the move with the highest predicted value.
        """
        self.epsilon = 80 - self.n_games
        explore = random.randint(0, 200) < self.epsilon
        if explore:
            move = random.randint(0, 2)
        else:
            q_values = self.model(torch.tensor(state, dtype=torch.float))
            move = torch.argmax(q_values).item()

        final_move = [0, 0, 0]
        final_move[move] = 1
        return final_move

# Game class
class SnakeGameAI:
    """Snake environment driven one action at a time by an agent.

    The board is ``w`` x ``h`` pixels, divided into ``snake_block``-sized
    cells. ``snake`` is a list of ``Point`` with the head at index 0.
    """

    def __init__(self, w=dis_width, h=dis_height):
        self.w = w
        self.h = h
        self.reset()

    def reset(self):
        """Start a new episode: 3-segment snake heading right, score 0, new food."""
        self.direction = Direction.RIGHT
        # Snap the start cell to the grid. Food is always placed on multiples
        # of snake_block, so an off-grid head (e.g. w=500 -> 250 with 20px
        # cells) could never land on it.
        start_x = (self.w // 2 // snake_block) * snake_block
        start_y = (self.h // 2 // snake_block) * snake_block
        self.head = Point(start_x, start_y)
        self.snake = [self.head,
                      Point(self.head.x - snake_block, self.head.y),
                      Point(self.head.x - (2 * snake_block), self.head.y)]

        self.score = 0
        self.food = None
        self._place_food()
        self.frame_iteration = 0

    def _place_food(self):
        """Put the food on a random grid cell that the snake does not occupy."""
        max_col = (self.w - snake_block) // snake_block
        max_row = (self.h - snake_block) // snake_block
        # Retry in a loop rather than by recursion, so a long snake cannot
        # exhaust the recursion limit while a free cell is being searched.
        while True:
            candidate = Point(random.randint(0, max_col) * snake_block,
                              random.randint(0, max_row) * snake_block)
            if candidate not in self.snake:
                self.food = candidate
                return

    def play_step(self, action):
        """Advance the game by one move.

        Args:
            action: one-hot ``[straight, right, left]`` relative to the
                current heading.

        Returns:
            ``(reward, game_over, score)``. Reward is 10 when food is eaten,
            -10 when the episode ends, and 0 otherwise.
        """
        self.frame_iteration += 1
        # Quit event
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()
                quit()

        # Move
        self._move(action)  # update the head
        self.snake.insert(0, self.head)

        # Check if game over: a collision, or too many frames without
        # progress (the limit grows with the snake's length)
        reward = 0
        game_over = False
        if self.is_collision() or self.frame_iteration > 100 * len(self.snake):
            game_over = True
            reward = -10
            return reward, game_over, self.score

        # Place new food or just move
        if self.head == self.food:
            self.score += 1
            reward = 10
            self._place_food()
        else:
            # No food eaten: drop the tail so the length stays the same
            self.snake.pop()

        # Update UI and clock
        self._update_ui()
        clock.tick(snake_speed)

        return reward, game_over, self.score

    def is_collision(self, pt=None):
        """Return True if ``pt`` (default: the head) is off the board or on the body."""
        if pt is None:
            pt = self.head
        # Hits boundary
        if pt.x > self.w - snake_block or pt.x < 0 or pt.y > self.h - snake_block or pt.y < 0:
            return True
        # Hits itself (index 0 is the head, so it is excluded)
        if pt in self.snake[1:]:
            return True
        return False

    def _update_ui(self):
        """Redraw the background, snake, food and score, then flip the display."""
        dis.fill(black)

        for pt in self.snake:
            pygame.draw.rect(dis, green, pygame.Rect(pt.x, pt.y, snake_block, snake_block))

        pygame.draw.rect(dis, red, pygame.Rect(self.food.x, self.food.y, snake_block, snake_block))

        # `font` is a module-level global created outside this class
        text = font.render("Score: " + str(self.score), True, white)
        dis.blit(text, [0, 0])
        pygame.display.flip()

    def _move(self, action):
        """Turn according to ``action`` and move the head one cell forward.

        ``[1, 0, 0]`` keeps the heading, ``[0, 1, 0]`` turns right
        (clockwise); any other value turns left.
        """
        # [straight, right, left]
        clock_wise = [Direction.RIGHT, Direction.DOWN, Direction.LEFT, Direction.UP]
        idx = clock_wise.index(self.direction)

        if np.array_equal(action, [1, 0, 0]):
            new_dir = clock_wise[idx]  # no change
        elif np.array_equal(action, [0, 1, 0]):
            new_dir = clock_wise[(idx + 1) % 4]  # right turn r -> d -> l -> u
        else:  # [0, 0, 1]
            new_dir = clock_wise[(idx - 1) % 4]  # left turn r -> u -> l -> d

        self.direction = new_dir

        x = self.head.x
        y = self.head.y
        if self.direction == Direction.RIGHT:
            x += snake_block
        elif self.direction == Direction.LEFT:
            x -= snake_block
        elif self.direction == Direction.DOWN:
            y += snake_block
        elif self.direction == Direction.UP:
            y -= snake_block

        self.head = Point(x, y)

# Main training loop
def train(max_games=None):
    """Train the agent by letting it play Snake episode after episode.

    Each step trains on the transition just played and stores it in replay
    memory. At the end of every episode the game is reset, a replay batch is
    trained, and the model is saved whenever the score beats the record.

    Args:
        max_games: stop after this many finished episodes. The default,
            ``None``, trains until the process is interrupted or the window
            is closed.
    """
    record = 0
    agent = Agent()
    game = SnakeGameAI()

    while max_games is None or agent.n_games < max_games:
        # get old state
        state_old = agent.get_state(game)

        # get move
        final_move = agent.get_action(state_old)

        # perform move and get new state
        reward, done, score = game.play_step(final_move)
        state_new = agent.get_state(game)

        # train short memory
        agent.train_short_memory(state_old, final_move, reward, state_new, done)

        # remember
        agent.remember(state_old, final_move, reward, state_new, done)

        if done:
            # train long memory (experience replay) and report the result
            game.reset()
            agent.n_games += 1
            agent.train_long_memory()

            if score > record:
                record = score
                agent.model.save()

            print('Game', agent.n_games, 'Score', score, 'Record:', record)

# Font used by SnakeGameAI._update_ui for the score overlay. It is looked up
# as a global at draw time, so it only has to exist before train() runs.
font = pygame.font.SysFont('arial', 25)

# Start training when this cell/script is executed directly
if __name__ == '__main__':
    train()


<h2>3. Testing the Model</h2>

In [None]:
import pygame
import random
import numpy as np
from collections import deque
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# Initialize pygame before creating the display and the font
pygame.init()

# Window size; also the default board size of SnakeGameAI
dis_width = 600
dis_height = 400

# Colors as (R, G, B) tuples
white = (255, 255, 255)  # score text
red = (200, 0, 0)  # food
green = (0, 200, 0)  # snake
black = (0, 0, 0)  # background

dis = pygame.display.set_mode((dis_width, dis_height))
pygame.display.set_caption('Snake RL Agent')

clock = pygame.time.Clock()

snake_block = 20  # side of one grid cell and the distance moved per step
snake_speed = 100  # clock.tick() cap on steps per second

# Font used by SnakeGameAI._update_ui for the score overlay
font = pygame.font.SysFont('arial', 25)

# Agent hyper-parameters; Agent.__init__ reads MAX_MEMORY and LR, and
# train_long_memory reads BATCH_SIZE
MAX_MEMORY = 100_000
BATCH_SIZE = 1000
LR = 0.001

class Direction:
    """Integer codes for the four headings the snake can travel in."""

    LEFT, RIGHT, UP, DOWN = range(4)

@dataclass
class Point:
    """Position of one grid cell (a snake segment or the food).

    ``@dataclass`` generates a value-based ``__eq__``, which is what makes
    ``head == food`` and ``pt in snake`` work in the game class.
    """
    # NOTE(review): annotated as int, but SnakeGameAI.reset() builds the head
    # from w/2 and h/2, which are floats; dataclasses do not enforce the type.
    x: int  # horizontal coordinate
    y: int  # vertical coordinate

class Linear_QNet(nn.Module):
    """Fully connected network mapping a state vector to one value per action.

    The layers are stored as ``linear1`` and ``linear2``; those names must
    match the keys of the state dict loaded from ``model.pth``.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.linear1 = nn.Linear(in_features=input_size, out_features=hidden_size)
        self.linear2 = nn.Linear(in_features=hidden_size, out_features=output_size)

    def forward(self, x):
        """Return the raw (un-activated) output of the second layer."""
        hidden = torch.relu(self.linear1(x))
        return self.linear2(hidden)

    def save(self, file_name='model.pth'):
        """Write the model's state dict to ``file_name``."""
        weights = self.state_dict()
        torch.save(weights, file_name)

class QTrainer:
    """One-step Q-learning trainer for a model that maps states to action values.

    Holds an Adam optimizer over the model's parameters and an MSE loss.
    """

    def __init__(self, model, lr, gamma):
        """Store the model and hyper-parameters.

        Args:
            model: network called as ``model(states)``, returning one row of
                action values per state.
            lr: learning rate for Adam.
            gamma: discount factor applied to the best next-state value.
        """
        self.lr = lr
        self.gamma = gamma
        self.model = model
        self.optimizer = optim.Adam(model.parameters(), lr=self.lr)
        self.criterion = nn.MSELoss()

    def train_step(self, state, action, reward, next_state, done):
        """Run one gradient step on a single transition or a batch of them.

        Args:
            state: one state vector, or a sequence of state vectors.
            action: one-hot action (or a sequence of them); the index of the
                1 selects which output is trained.
            reward: scalar reward, or a sequence of rewards.
            next_state: state(s) observed after the action.
            done: bool (or sequence of bools); True marks a terminal
                transition, whose target is the reward alone.
        """
        state = torch.tensor(np.array(state), dtype=torch.float)
        next_state = torch.tensor(np.array(next_state), dtype=torch.float)
        action = torch.tensor(np.array(action), dtype=torch.long)
        reward = torch.tensor(np.array(reward), dtype=torch.float)

        if state.dim() == 1:
            # Single transition: add a batch axis so both cases share one path
            state = state.unsqueeze(0)
            next_state = next_state.unsqueeze(0)
            action = action.unsqueeze(0)
            reward = reward.unsqueeze(0)
            done = (done,)
        done = torch.tensor(np.array(done), dtype=torch.bool)

        # Predicted Q values for the current states (this is what gets trained)
        pred = self.model(state)

        # The bootstrap target is a constant for this step. Building it under
        # no_grad keeps gradients from flowing through model(next_state),
        # which would otherwise also pull the target towards the prediction.
        with torch.no_grad():
            best_next = self.model(next_state).max(dim=1).values
            q_new = torch.where(done, reward, reward + self.gamma * best_next)
            target = pred.detach().clone()
            rows = torch.arange(target.shape[0])
            # Only the entry of the action actually taken receives a new value
            target[rows, action.argmax(dim=1)] = q_new

        self.optimizer.zero_grad()
        loss = self.criterion(pred, target)
        loss.backward()

        self.optimizer.step()
        
class Agent:
    """Chooses moves for the snake and trains the Q-network from experience."""

    def __init__(self):
        self.n_games = 0
        self.epsilon = 0  # exploration level, recomputed on every get_action()
        self.gamma = 0.9  # discount rate
        self.memory = deque(maxlen=MAX_MEMORY)  # oldest transitions drop off first
        self.model = Linear_QNet(11, 256, 3)
        self.trainer = QTrainer(self.model, lr=LR, gamma=self.gamma)

    def get_state(self, game):
        """Encode the game as 11 flags (0/1) in an integer numpy array.

        Order: danger straight / right / left, heading left / right / up /
        down, food left / right / up / down of the head.
        """
        head = game.snake[0]
        heading = game.direction

        # Neighbouring cells listed clockwise: the entry after the current
        # heading is a right turn, the entry before it a left turn.
        clockwise = [
            (Direction.RIGHT, Point(head.x + snake_block, head.y)),
            (Direction.DOWN, Point(head.x, head.y + snake_block)),
            (Direction.LEFT, Point(head.x - snake_block, head.y)),
            (Direction.UP, Point(head.x, head.y - snake_block)),
        ]
        headings = [code for code, _ in clockwise]

        if heading in headings:
            here = headings.index(heading)
            danger_straight = game.is_collision(clockwise[here][1])
            danger_right = game.is_collision(clockwise[(here + 1) % 4][1])
            danger_left = game.is_collision(clockwise[(here - 1) % 4][1])
        else:
            danger_straight = danger_right = danger_left = False

        features = [
            danger_straight,
            danger_right,
            danger_left,
            # Move direction
            heading == Direction.LEFT,
            heading == Direction.RIGHT,
            heading == Direction.UP,
            heading == Direction.DOWN,
            # Food location relative to the head
            game.food.x < game.head.x,
            game.food.x > game.head.x,
            game.food.y < game.head.y,
            game.food.y > game.head.y,
        ]
        return np.array(features, dtype=int)

    def remember(self, state, action, reward, next_state, done):
        """Append one transition to the replay memory."""
        transition = (state, action, reward, next_state, done)
        self.memory.append(transition)

    def train_long_memory(self):
        """Train on up to BATCH_SIZE transitions drawn from the replay memory."""
        batch = self.memory
        if len(batch) > BATCH_SIZE:
            batch = random.sample(batch, BATCH_SIZE)

        states, actions, rewards, next_states, dones = zip(*batch)
        self.trainer.train_step(states, actions, rewards, next_states, dones)

    def train_short_memory(self, state, action, reward, next_state, done):
        """Train on the single transition that was just played."""
        self.trainer.train_step(state, action, reward, next_state, done)

    def get_action(self, state):
        """Return a one-hot move ``[straight, right, left]``.

        A random move is chosen while ``randint(0, 200)`` falls below
        ``80 - n_games``; otherwise the move with the highest predicted value.
        """
        self.epsilon = 80 - self.n_games
        explore = random.randint(0, 200) < self.epsilon
        if explore:
            move = random.randint(0, 2)
        else:
            q_values = self.model(torch.tensor(state, dtype=torch.float))
            move = torch.argmax(q_values).item()

        final_move = [0, 0, 0]
        final_move[move] = 1
        return final_move


class SnakeGameAI:
    """Snake environment driven one action at a time by an agent.

    The board is ``w`` x ``h`` pixels, divided into ``snake_block``-sized
    cells. ``snake`` is a list of ``Point`` with the head at index 0.
    """

    def __init__(self, w=dis_width, h=dis_height):
        self.w = w
        self.h = h
        self.reset()

    def reset(self):
        """Start a new episode: 3-segment snake heading right, score 0, new food."""
        self.direction = Direction.RIGHT
        # Snap the start cell to the grid. Food is always placed on multiples
        # of snake_block, so an off-grid head (e.g. w=500 -> 250 with 20px
        # cells) could never land on it.
        start_x = (self.w // 2 // snake_block) * snake_block
        start_y = (self.h // 2 // snake_block) * snake_block
        self.head = Point(start_x, start_y)
        self.snake = [self.head,
                      Point(self.head.x - snake_block, self.head.y),
                      Point(self.head.x - (2 * snake_block), self.head.y)]

        self.score = 0
        self.food = None
        self._place_food()
        self.frame_iteration = 0

    def _place_food(self):
        """Put the food on a random grid cell that the snake does not occupy."""
        max_col = (self.w - snake_block) // snake_block
        max_row = (self.h - snake_block) // snake_block
        # Retry in a loop rather than by recursion, so a long snake cannot
        # exhaust the recursion limit while a free cell is being searched.
        while True:
            candidate = Point(random.randint(0, max_col) * snake_block,
                              random.randint(0, max_row) * snake_block)
            if candidate not in self.snake:
                self.food = candidate
                return

    def play_step(self, action):
        """Advance the game by one move.

        Args:
            action: one-hot ``[straight, right, left]`` relative to the
                current heading.

        Returns:
            ``(reward, game_over, score)``. Reward is 10 when food is eaten,
            -10 when the episode ends, and 0 otherwise.
        """
        self.frame_iteration += 1

        # Closing the window aborts the whole program
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()
                quit()

        # Move the head and add it to the front of the body
        self._move(action)
        self.snake.insert(0, self.head)

        # The episode ends on a collision, or after too many frames without
        # progress (the limit grows with the snake's length)
        reward = 0
        game_over = False
        if self.is_collision() or self.frame_iteration > 100 * len(self.snake):
            game_over = True
            reward = -10
            return reward, game_over, self.score

        if self.head == self.food:
            self.score += 1
            reward = 10
            self._place_food()
        else:
            # No food eaten: drop the tail so the length stays the same
            self.snake.pop()

        self._update_ui()
        clock.tick(snake_speed)

        return reward, game_over, self.score

    def is_collision(self, pt=None):
        """Return True if ``pt`` (default: the head) is off the board or on the body."""
        if pt is None:
            pt = self.head
        # Outside the board
        if pt.x > self.w - snake_block or pt.x < 0 or pt.y > self.h - snake_block or pt.y < 0:
            return True
        # On the body (index 0 is the head, so it is excluded)
        if pt in self.snake[1:]:
            return True
        return False

    def _update_ui(self):
        """Redraw the background, snake, food and score, then flip the display."""
        dis.fill(black)

        for pt in self.snake:
            pygame.draw.rect(dis, green, pygame.Rect(pt.x, pt.y, snake_block, snake_block))

        pygame.draw.rect(dis, red, pygame.Rect(self.food.x, self.food.y, snake_block, snake_block))

        text = font.render("Score: " + str(self.score), True, white)  # Use the global font
        dis.blit(text, [0, 0])
        pygame.display.flip()

    def _move(self, action):
        """Turn according to ``action`` and move the head one cell forward.

        ``[1, 0, 0]`` keeps the heading, ``[0, 1, 0]`` turns right
        (clockwise); any other value turns left.
        """
        clock_wise = [Direction.RIGHT, Direction.DOWN, Direction.LEFT, Direction.UP]
        idx = clock_wise.index(self.direction)

        if np.array_equal(action, [1, 0, 0]):
            new_dir = clock_wise[idx]  # straight on
        elif np.array_equal(action, [0, 1, 0]):
            new_dir = clock_wise[(idx + 1) % 4]  # right turn
        else:
            new_dir = clock_wise[(idx - 1) % 4]  # left turn

        self.direction = new_dir

        x = self.head.x
        y = self.head.y
        if self.direction == Direction.RIGHT:
            x += snake_block
        elif self.direction == Direction.LEFT:
            x -= snake_block
        elif self.direction == Direction.DOWN:
            y += snake_block
        elif self.direction == Direction.UP:
            y -= snake_block

        self.head = Point(x, y)


def test(num_games=10):
    """Play ``num_games`` episodes greedily with the weights in ``model.pth``.

    No exploration and no training take place: every move is the one with the
    highest predicted value. The score of each game is printed.

    Args:
        num_games: number of episodes to play.

    Returns:
        The list of final scores, one per game.
    """
    agent = Agent()
    agent.model.load_state_dict(torch.load('model.pth'))  # Load the trained model
    agent.model.eval()  # inference only from here on
    game = SnakeGameAI()

    scores = []
    for _ in range(num_games):
        game.reset()
        done = False

        while not done:
            state = agent.get_state(game)
            with torch.no_grad():
                prediction = agent.model(torch.tensor(state, dtype=torch.float))

            final_move = [0, 0, 0]
            final_move[torch.argmax(prediction).item()] = 1
            # The reward is only meaningful during training
            _, done, score = game.play_step(final_move)

        print(f'Score: {score}')
        scores.append(score)

    return scores

# Evaluate the saved model when this cell/script is executed directly;
# expects model.pth in the current working directory.
if __name__ == '__main__':
    test(num_games=20)  # Test for 20 games
