In [None]:
# Imports
import os
import random
from pathlib import Path
from collections import deque
import torch
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt

In [None]:
# Change working directory to the root of the project.
# Later cells import `game.game` and `helper.helper`, so the project root is taken
# to be the directory that contains the `game` package. Step up to the parent only
# when that package is not already visible from the current directory: an
# unconditional chdir-to-parent climbs one level higher every time this cell is
# re-run, which then breaks those imports.
if not (Path.cwd() / "game").is_dir():
    os.chdir(Path.cwd().parent)

# Print the current working directory
print("Current working directory: ", os.getcwd())

In [None]:
# Import the game
from game.game import SnakeGameAI, Direction, Point

In [None]:
# Training hyperparameters
MAX_MEMORY = 100_000  # capacity of the agent's replay deque (its maxlen)
BATCH_SIZE = 1000     # number of transitions sampled per long-memory training step
LR = 0.001            # learning rate handed to the trainer's Adam optimizer

In [None]:
# Pytorch imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Helper functions
from helper.helper import plot

In [None]:
# Class for the neural network
class Linear_QNet(nn.Module):
    """Two-layer fully connected network: Linear -> ReLU -> Linear.

    Args:
        input_size: number of features in the input vector.
        hidden_size: width of the single hidden layer.
        output_size: number of values produced by the output layer.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # input_size -> hidden_size
        self.linear1 = nn.Linear(input_size, hidden_size)
        # hidden_size -> output_size
        self.linear2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Run `x` through the hidden layer with ReLU, then the output layer.

        The output layer has no activation, so the returned values are unbounded.
        """
        hidden = F.relu(self.linear1(x))
        return self.linear2(hidden)

    def save(self, file_name='lienar_qnet.pth'):
        """Write the model's state_dict to `./models/<file_name>`.

        The `./models` directory is resolved against the current working
        directory and is created if it does not exist yet.
        """
        target_dir = './models'
        if not os.path.exists(target_dir):
            os.makedirs(target_dir)
        torch.save(self.state_dict(), os.path.join(target_dir, file_name))

In [None]:
# Class for trainer
class QTrainer:
    """One-step Q-learning trainer for a Q-value network.

    Args:
        model: network mapping a batch of states to one Q-value per action.
        lr: learning rate for the Adam optimizer.
        gamma: discount factor applied to the bootstrapped next-state value.
    """

    def __init__(self, model, lr, gamma):
        self.lr = lr
        self.gamma = gamma
        self.model = model
        self.optimizer = optim.Adam(model.parameters(), lr=self.lr)  # Adam optimizer
        self.criterion = nn.MSELoss()  # Mean squared error loss

    def train_step(self, state, action, reward, next_state, done):
        """Perform one gradient step on a single transition or a batch of them.

        Args:
            state: one state vector, or a sequence of state vectors.
            action: one-hot action (or a sequence of them) that was taken.
            reward: scalar reward, or a sequence of rewards.
            next_state: state(s) observed after the action, shaped like `state`.
            done: bool (or a sequence of bools); True when the transition ended
                the episode, in which case the target is the reward alone.
        """
        # Go through np.asarray first: building a tensor directly from a sequence
        # of ndarrays (what a sampled replay batch is) is very slow.
        state = torch.tensor(np.asarray(state), dtype=torch.float)
        next_state = torch.tensor(np.asarray(next_state), dtype=torch.float)
        action = torch.tensor(np.asarray(action), dtype=torch.long)
        reward = torch.tensor(np.asarray(reward), dtype=torch.float)
        done = torch.tensor(np.asarray(done), dtype=torch.bool)

        # A single transition arrives without a batch axis: (x,) -> (1, x)
        if state.dim() == 1:
            state = state.unsqueeze(0)
            next_state = next_state.unsqueeze(0)
            action = action.unsqueeze(0)
            reward = reward.unsqueeze(0)
            done = done.unsqueeze(0)

        # 1: predicted Q values for the current states
        pred = self.model(state)

        # 2: Q_new = r + gamma * max(next predicted Q value), or just r when done.
        # The bootstrap term is a regression target, so it is computed without
        # gradient tracking and treated as a constant.
        with torch.no_grad():
            next_q_max = self.model(next_state).max(dim=1).values
        q_new = torch.where(done, reward, reward + self.gamma * next_q_max)

        # Only the Q-value of the action taken in EACH sample is replaced; the
        # argmax must therefore be per row, not over the whole batch.
        taken = action.argmax(dim=1)
        target = pred.detach().clone()
        target[torch.arange(target.shape[0]), taken] = q_new

        # Set gradients to zero before backpropagation
        self.optimizer.zero_grad()

        # 3: loss = (Q_new - Q_old)^2; untouched entries equal pred and add nothing
        loss = self.criterion(target, pred)
        loss.backward()

        # 4: update the weights
        self.optimizer.step()

In [None]:
# Class for the agent
class Agent:
    """Epsilon-greedy Q-learning agent for the Snake game.

    Holds the replay memory, the Q-network and its trainer, and converts the
    game object into the 11-value state vector the network consumes.
    """

    # Offset used to probe the cells next to the head.
    # NOTE(review): assumed to equal the game's grid cell size - confirm against game.game.
    BLOCK_SIZE = 20

    # Constructor
    def __init__(self) -> None:
        self.n_games = 0  # number of finished games; drives the epsilon schedule
        self.epsilon = 0  # randomness (recomputed on every get_action call)
        self.gamma = 0.9  # discount rate
        self.memory = deque(maxlen=MAX_MEMORY)  # oldest transitions drop off automatically
        self.model = Linear_QNet(11, 256, 3)  # 11 state features -> 3 action values
        self.trainer = QTrainer(model=self.model, lr=LR, gamma=self.gamma)

    # Function to get the state of the game
    def get_state(self, game):
        """Build the 11-element state vector for the current game position.

        Layout: [danger straight, danger right, danger left,
                 moving left, moving right, moving up, moving down,
                 food left, food right, food up, food down].
        "Up" means smaller y, as the probes below show.

        Returns:
            numpy int array of eleven 0/1 flags.
        """
        head = game.snake[0]
        step = self.BLOCK_SIZE
        point_l = Point(head.x - step, head.y)
        point_r = Point(head.x + step, head.y)
        point_u = Point(head.x, head.y - step)
        point_d = Point(head.x, head.y + step)

        dir_l = game.direction == Direction.LEFT
        dir_r = game.direction == Direction.RIGHT
        dir_u = game.direction == Direction.UP
        dir_d = game.direction == Direction.DOWN

        state = [
            # Danger straight: collision in the cell ahead of the current heading
            (dir_r and game.is_collision(point_r)) or
            (dir_l and game.is_collision(point_l)) or
            (dir_u and game.is_collision(point_u)) or
            (dir_d and game.is_collision(point_d)),

            # Danger right: collision in the cell to the right of the heading
            (dir_u and game.is_collision(point_r)) or
            (dir_d and game.is_collision(point_l)) or
            (dir_l and game.is_collision(point_u)) or
            (dir_r and game.is_collision(point_d)),

            # Danger left: collision in the cell to the left of the heading
            (dir_d and game.is_collision(point_r)) or
            (dir_u and game.is_collision(point_l)) or
            (dir_r and game.is_collision(point_u)) or
            (dir_l and game.is_collision(point_d)),

            # Move direction
            dir_l,
            dir_r,
            dir_u,
            dir_d,

            # Food location relative to the head
            game.food.x < game.head.x,  # food left
            game.food.x > game.head.x,  # food right
            game.food.y < game.head.y,  # food up
            game.food.y > game.head.y   # food down
        ]

        return np.array(state, dtype=int)

    # Function to remember the state of the game
    def remember(self, state, action, reward, next_state, done):
        """Append one transition to the replay memory."""
        self.memory.append((state, action, reward, next_state, done))

    # Function to train the agent
    def train_long_memory(self):
        """Train on up to BATCH_SIZE transitions drawn from the replay memory."""
        # Nothing stored yet: unpacking zip(*[]) below would raise ValueError.
        if not self.memory:
            return

        if len(self.memory) > BATCH_SIZE:
            mini_sample = random.sample(self.memory, BATCH_SIZE)  # Randomly sample from memory
        else:
            mini_sample = self.memory

        states, actions, rewards, next_states, dones = zip(*mini_sample)

        self.trainer.train_step(states, actions, rewards, next_states, dones)

    # Function to train the agent
    def train_short_memory(self, state, action, reward, next_state, done):
        """Train on the single transition that was just observed."""
        self.trainer.train_step(state, action, reward, next_state, done)

    # Function to get the action
    def get_action(self, state):
        """Choose an action epsilon-greedily and return it one-hot encoded.

        epsilon is `80 - n_games` and is compared with a draw from [0, 200), so
        the chance of a random move starts at 40% and reaches zero after 80 games.

        Returns:
            list of three ints with exactly one 1 at the chosen action index.
        """
        self.epsilon = 80 - self.n_games

        final_move = [0, 0, 0]

        if np.random.randint(0, 200) < self.epsilon:
            # np.random.randint excludes the upper bound, so it must be the number
            # of actions; randint(0, 2) could never select the last action.
            move = np.random.randint(0, len(final_move))
        else:
            state0 = torch.tensor(state, dtype=torch.float)
            # Pure inference: no autograd graph needed
            with torch.no_grad():
                prediction = self.model(state0)
            move = torch.argmax(prediction).item()

        final_move[move] = 1
        return final_move

In [None]:
# Global function to train the model
def train(max_games=None):
    """Run the play/train loop for the Snake agent.

    Every step: read the state, pick a move, play it, train on that single
    transition and store it. Every finished game: reset the game, train on a
    batch from replay memory, print the score, save the model and update the plot.

    Args:
        max_games: stop after this many finished games. The default, None,
            keeps the original behaviour of looping until interrupted.

    Returns:
        (scores, mean_scores): per-game scores and their running mean. Only
        reachable when `max_games` is given.
    """
    # Store the scores
    plot_scores = []
    plot_mean_scores = []
    total_score = 0
    record = 0

    # Initialize the agent
    agent = Agent()

    # Initialize the game
    game = SnakeGameAI()

    # Training loop
    while max_games is None or agent.n_games < max_games:
        # Get old state
        state_old = agent.get_state(game)

        # Get move
        final_move = agent.get_action(state_old)

        # Perform move and get new state
        reward, done, score = game.play_step(final_move)
        state_new = agent.get_state(game)

        # Train short memory
        agent.train_short_memory(state_old, final_move, reward, state_new, done)

        # Remember
        agent.remember(state_old, final_move, reward, state_new, done)

        # Keep stepping until the game is over
        if not done:
            continue

        # Train long memory, plot result
        game.reset()
        agent.n_games += 1
        agent.train_long_memory()

        # Update the record score
        record = max(record, score)

        # Print results
        print('Game', agent.n_games, 'Score', score, 'Record', record)

        # Save the model (after every game, not only on a new record)
        agent.model.save()

        # Plot the results
        plot_scores.append(score)
        total_score += score
        plot_mean_scores.append(total_score / agent.n_games)
        plot(plot_scores, plot_mean_scores)

    return plot_scores, plot_mean_scores

In [None]:
# Start training. Called without arguments the training loop never exits on its
# own, so this cell keeps running until the kernel is interrupted.
train()