In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from game import Snake
import os
import wandb
import random
from collections import deque

### Hyperparameters

In [17]:
# Initial exploration probability for the epsilon-greedy policy.
EPSILON = 1
# NOTE(review): the collection loops multiply epsilon by this factor after every
# step, so 0.095 pushes exploration to ~0 within a handful of steps. A value
# close to 1 (e.g. 0.995) may have been intended — confirm before changing.
EPSILON_DECAY = 0.095
ACTION_SPACE = [0,1,2,3]
# Discount factor applied to the bootstrapped next-state value.
GAMMA = 0.99
LR = 1e-3
# Length of the state vector fed to the networks.
INPUT_SIZE = 8
# One Q-value per action. Deriving this from ACTION_SPACE keeps the greedy
# argmax from ever returning an index that is not a valid action.
OUTPUT_SIZE = len(ACTION_SPACE)
# HIDDEN_SIZE = 64
# Number of gradient steps performed by train().
EPOCHS = 1000
# The target network is synced (and an evaluation game played) every N epochs.
FREQ_OF_TARGET_NN_UPDATE = 20
# Capacity of the replay deque and number of environment steps collected.
REPLAY_MEMORY_SIZE = 500000

In [3]:
# Pick the best available backend: CUDA GPU first, then Apple-Silicon MPS,
# falling back to the CPU when neither accelerator is present.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

### Policy

In [4]:
def policy(epsilon, action_space, highest_estimated_action):
    """Epsilon-greedy action selection.

    Args:
        epsilon: Exploration probability. A uniform draw from [0, 1) that falls
            below this value triggers a random action, so 0 is fully greedy
            and 1 is fully random.
        action_space: Non-empty sequence of valid actions.
        highest_estimated_action: The greedy action, i.e. the index of the
            highest estimated Q-value. May be a Python int or a 0-dim tensor
            (for example the result of ``torch.argmax``).

    Returns:
        A member of ``action_space`` when exploring, otherwise the greedy
        action converted to a plain ``int``.
    """
    if random.random() < epsilon:
        # Return an actual member of the action space rather than a position
        # in it; the two only coincide when action_space is [0, 1, ..., n-1].
        return random.choice(action_space)

    # Normalise to a Python int so both branches return the same kind of value
    # and callers never end up storing device tensors as actions.
    return int(highest_estimated_action)

### Loss function

In [5]:
def loss_function(active_prediction, gamma, immediate_reward, action, done, final_state, target_nn):
    """Squared one-step TD error for a single transition.

    The target is ``r`` for terminal transitions and
    ``r + gamma * max_a' Q_target(s', a')`` otherwise.

    Args:
        active_prediction: 1-D tensor of Q-values produced by the online
            network for the transition's initial state.
        gamma: Discount factor.
        immediate_reward: Reward observed for the transition.
        action: Index of the action that was taken.
        done: Truthy when the transition ended the episode.
        final_state: State reached after the action, in a form accepted by
            ``torch.tensor``.
        target_nn: Callable mapping a state tensor to a tensor of Q-values.

    Returns:
        0-dim tensor holding ``(q_target - Q(s, action)) ** 2``.
    """
    q_taken = active_prediction[action]

    if done:
        # No future return after a terminal state.
        q_target = immediate_reward
    else:
        # Run the target network on the same device as the online prediction
        # instead of relying on a module-level `device` global.
        next_state = torch.tensor(final_state, dtype=torch.float32).to(active_prediction.device)
        # The target is a constant for the gradient step, so skip autograd.
        with torch.no_grad():
            # Bootstrap from the best next action, not from the action taken
            # in the previous state.
            best_next_value = target_nn(next_state).max()
        q_target = immediate_reward + (gamma * best_next_value)

    return (q_target - q_taken) ** 2

### Neural Network

In [6]:
#[head_x, head_y, direction, food_x, food_y, danger_straight, danger_right, danger_left]
class DQN_base_ff(nn.Module):
    """Single-hidden-layer feed-forward Q-network.

    Maps a state vector of length ``input_size`` to ``output_size`` Q-values
    through one ReLU-activated hidden layer of width ``hidden_size``.
    """

    def __init__(self, input_size, output_size, hidden_size) -> None:
        super().__init__()
        # Attribute names double as state_dict keys, so they are kept stable.
        self.linear1 = nn.Linear(input_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Return the Q-value estimates for state tensor ``x``."""
        hidden = F.relu(self.linear1(x))
        return self.linear2(hidden)

    def save(self, file_name="model.pth"):
        """Write the state_dict to ``./model/<file_name>``, creating the folder if needed."""
        target_dir = "./model"
        if not os.path.exists(target_dir):
            os.makedirs(target_dir)

        torch.save(self.state_dict(), os.path.join(target_dir, file_name))

In [35]:
#[head_x, head_y, direction, food_x, food_y, danger_straight, danger_right, danger_left]
class DQN_scaled_ff(nn.Module):
    """Deeper feed-forward Q-network that widens and then narrows.

    Layout: ``input_linear`` -> ``linear_sequence1`` (width ``2 * hidden_size``)
    -> ``linear_sequence2`` (narrowing to ``hidden_size / 2``) ->
    ``output_linear``.

    NOTE(review): several ``nn.Linear`` layers are stacked back-to-back with no
    activation between them (and none follows ``input_linear``). Adjacent
    linear layers compose into a single affine map, so they add parameters but
    no extra expressiveness — confirm this is intended.
    """

    def __init__(self, input_size, output_size, hidden_size) -> None:
        super().__init__()
        wide = int(hidden_size * 2)
        narrow = int(hidden_size / 2)

        # Sub-modules are created in forward order; attribute names and the
        # positions inside each Sequential determine the state_dict keys.
        self.input_linear = nn.Linear(input_size, hidden_size)
        self.linear_sequence1 = self._tower(hidden_size, wide, wide)
        self.linear_sequence2 = self._tower(wide, hidden_size, narrow)
        self.output_linear = nn.Linear(narrow, output_size)

    @staticmethod
    def _tower(in_features, mid_features, width):
        """Build the nine-module Linear/ReLU stack shared by both sequences."""
        return nn.Sequential(
            nn.Linear(in_features, mid_features),
            nn.ReLU(),
            nn.Linear(mid_features, width),
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Linear(width, width),
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Linear(width, width),
        )

    def forward(self, x):
        """Return the Q-value estimates for state tensor ``x``."""
        stages = (
            self.input_linear,
            self.linear_sequence1,
            self.linear_sequence2,
            self.output_linear,
        )
        for stage in stages:
            x = stage(x)
        return x

    def save(self, file_name="model.pth"):
        """Write the state_dict to ``./model/<file_name>``, creating the folder if needed."""
        target_dir = "./model"
        if not os.path.exists(target_dir):
            os.makedirs(target_dir)

        torch.save(self.state_dict(), os.path.join(target_dir, file_name))

In [7]:
def get_batch(size, state_action_pairs):
    """Return ``size`` distinct entries of ``state_action_pairs`` chosen at random.

    Positions are sampled without replacement, so ``size`` must not exceed
    ``len(state_action_pairs)`` (``random.sample`` raises ``ValueError`` if it does).
    """
    positions = random.sample(range(len(state_action_pairs)), size)
    picked = []
    for pos in positions:
        picked.append(state_action_pairs[pos])
    return picked

In [18]:
def _run_greedy_evaluation(active_nn, device, policy, action_space):
    """Play one evaluation game with epsilon = 0 and return ``(steps, score)``."""
    steps = 0
    score = 0
    evaluation_game = Snake(evaluation=True)

    # Acting only: no gradients are needed while the network plays.
    with torch.no_grad():
        while True:
            state = torch.tensor(evaluation_game.get_state(), dtype=torch.float32).to(device)
            # Hand the game a plain int rather than a 0-dim device tensor.
            greedy_action = int(torch.argmax(active_nn(state)))
            action = policy(0, action_space, greedy_action)
            done = evaluation_game.move_with_action(action=action)
            steps += 1

            if evaluation_game.get_immediate_reward() == 1:
                score += 1

            if done:
                break

    return steps, score


def train(active_nn, target_nn, state_action_pairs, replay_memory_size, device, loss_function, optimizer, gamma, epochs, freq_of_target_nn_update, epsilon, action_space, policy, model_name, model_type, learning_rate, run_name, save_model=False):
    """Train ``active_nn`` from a replay buffer, logging progress to Weights & Biases.

    Each epoch draws one transition uniformly from ``state_action_pairs`` and
    takes a single optimizer step on it. Every ``freq_of_target_nn_update``
    epochs the target network is synced to the online network and one greedy
    evaluation game is played.

    Args:
        active_nn: Online Q-network being optimised.
        target_nn: Target Q-network used by ``loss_function`` for bootstrapping.
        state_action_pairs: Indexable replay buffer of
            ``(initial_state, action, reward, final_state, done)`` tuples.
        replay_memory_size: Nominal buffer capacity; only logged. Sampling
            uses the buffer's actual length.
        device: Torch device the networks live on.
        loss_function: Callable with the signature
            ``(prediction, gamma, reward, action, done, final_state, target_nn)``.
        optimizer: Optimizer over ``active_nn``'s parameters.
        gamma: Discount factor.
        epochs: Number of gradient steps.
        freq_of_target_nn_update: Epoch interval for target sync + evaluation.
        epsilon: Logged only; evaluation always acts greedily.
        action_space: Sequence of valid actions, forwarded to ``policy``.
        policy: Callable ``(epsilon, action_space, greedy_action) -> action``.
        model_name, model_type, learning_rate, run_name: Run metadata for logging.
        save_model: When true, calls ``active_nn.save(f"{model_name}.pth")``.

    Returns:
        Tuple ``(active_nn, target_nn)``.

    Raises:
        ValueError: If ``epochs > 0`` but the replay buffer is empty.
    """
    buffer_length = len(state_action_pairs)
    if epochs > 0 and buffer_length == 0:
        raise ValueError("state_action_pairs is empty; collect experience before training")

    wandb.init(
        project="snake_game_rl",
        name=run_name,
        config={
            "model_name": model_name,
            "model_type": model_type,
            "epsilon": epsilon,
            "gamma": gamma,
            "learning_rate": learning_rate,
            "replay_memory_size": replay_memory_size,
            "epochs": epochs,
            "freq_of_target_nn_update": freq_of_target_nn_update,
        }
    )

    # Pre-initialised so the summary below is well defined even if no epoch
    # (or no evaluation) ever ran.
    loss_value = None
    steps = None
    score = None

    try:
        training_table = wandb.Table(columns=["epoch", "loss"])
        evaluation_table = wandb.Table(columns=["epoch", "steps", "score"])

        for epoch in range(epochs):
            # Sample from what is actually stored; the buffer may hold fewer
            # transitions than its nominal capacity.
            transition = state_action_pairs[random.randrange(buffer_length)]
            state, action, reward, next_state, done = transition

            optimizer.zero_grad()
            state_tensor = torch.tensor(state, dtype=torch.float32).to(device)
            action_pred = active_nn(state_tensor)
            loss = loss_function(action_pred, gamma, reward, action, done, next_state, target_nn)

            loss.backward()
            optimizer.step()

            loss_value = loss.item()

            wandb.log({
                "epoch": epoch,
                "training_loss": loss_value
            })

            training_table.add_data(str(epoch), str(loss_value))

            print(f"Epoch {epoch}, Loss: {loss_value}")

            if epoch % freq_of_target_nn_update == 0:
                target_nn.load_state_dict(active_nn.state_dict())
                # Freeze parameter-by-parameter. Assigning to
                # `target_nn.requires_grad_` would only overwrite the method,
                # and calling it is unsafe if a caller has already shadowed it
                # on the instance with such an assignment.
                for parameter in target_nn.parameters():
                    parameter.requires_grad_(False)

                print("Evaluating active network")
                steps, score = _run_greedy_evaluation(active_nn, device, policy, action_space)

                wandb.log({
                    "epoch": epoch,
                    "evaluation_steps": steps,
                    "evaluation_score": score
                })

                evaluation_table.add_data(str(epoch), str(steps), str(score))

                print(f"Evaluation - Steps: {steps}, Score: {score}")

        wandb.log({"training_metrics_table": training_table})
        wandb.log({"evaluation_metrics_table": evaluation_table})

        def _as_text(value):
            return "N/A" if value is None else str(value)

        summary_data = [
            ["Model Name", model_name],
            ["Model Type", model_type],
            ["Epsilon", str(epsilon)],
            ["Gamma", str(gamma)],
            ["Learning Rate", str(learning_rate)],
            ["Replay Memory Size", str(replay_memory_size)],
            ["Epochs", str(epochs)],
            ["Target NN Update Frequency", str(freq_of_target_nn_update)],
            ["Final Loss", _as_text(loss_value)],
            ["Final Evaluation Steps", _as_text(steps)],
            ["Final Evaluation Score", _as_text(score)]
        ]

        summary_table = wandb.Table(columns=["Parameter", "Value"], data=summary_data)
        wandb.log({"training_summary_table": summary_table})

        if save_model:
            active_nn.save(f"{model_name}.pth")
    finally:
        # Close the run even when training is interrupted or raises.
        wandb.finish()

    return active_nn, target_nn


### Training Base Model

In [8]:
# Replay buffer. Each entry is the tuple appended by the collection loop:
# (initial_state, action, reward, final_state, done).
# Once maxlen entries are stored, the deque discards the oldest on append.
state_action_pairs = deque(maxlen=REPLAY_MEMORY_SIZE) 

In [None]:
HIDDEN_SIZE = 16
# Untrained network that supplies the greedy action during collection.
active_nn = DQN_scaled_ff(INPUT_SIZE, OUTPUT_SIZE, HIDDEN_SIZE).to(device=device)

# Decay a local copy so that re-running this cell always starts from the
# configured EPSILON instead of whatever value a previous run left behind.
epsilon = EPSILON

game = Snake()
# Acting only: no gradients are needed while filling the replay buffer.
with torch.no_grad():
    for step in range(REPLAY_MEMORY_SIZE):
        initial_state = game.get_state()
        state_tensor = torch.tensor(initial_state, dtype=torch.float32).to(device)
        # Convert to a plain int so the buffer never stores device tensors.
        greedy_action = int(torch.argmax(active_nn(state_tensor)))
        action = policy(epsilon, ACTION_SPACE, greedy_action)
        done = game.move_with_action(action=action)
        final_state = game.get_state()
        reward = game.get_immediate_reward()

        state_action_pairs.append((initial_state, action, reward, final_state, done))
        epsilon *= EPSILON_DECAY

        # Start a fresh game once the current one has ended.
        if done:
            game = Snake()


In [16]:
# NOTE(review): HIDDEN_SIZE is whatever the most recently executed cell
# assigned; it is not set in the hyperparameter cell.
active_nn = DQN_base_ff(INPUT_SIZE, OUTPUT_SIZE, HIDDEN_SIZE).to(device=device)
target_nn = DQN_base_ff(INPUT_SIZE, OUTPUT_SIZE, HIDDEN_SIZE).to(device=device)
# Start the target network as an exact copy of the online network.
target_nn.load_state_dict(active_nn.state_dict())
# requires_grad_ is a method: call it to freeze the target's parameters.
# Assigning to it would only replace the method with a bool.
target_nn.requires_grad_(False)
optimizer = optim.AdamW(active_nn.parameters(), LR)

active_nn, target_nn = train(
    active_nn=active_nn,
    target_nn=target_nn,
    state_action_pairs=state_action_pairs,
    replay_memory_size=REPLAY_MEMORY_SIZE,
    device=device,
    loss_function=loss_function,
    optimizer=optimizer,
    gamma=GAMMA,
    epochs=EPOCHS,
    freq_of_target_nn_update=FREQ_OF_TARGET_NN_UPDATE,
    epsilon=EPSILON,
    action_space=ACTION_SPACE,
    policy=policy,
    model_name="initial_base_model",
    model_type="linear",
    learning_rate=LR,
    run_name="initial_base_model_run"
)

Epoch 0, Loss: 0.29449227452278137
Evaluating active network
Evaluation - Steps: 20, Score: 1
Epoch 1, Loss: 0.00926232896745205
Epoch 2, Loss: 0.012865720316767693
Epoch 3, Loss: 9.183608926832676e-05
Epoch 4, Loss: 0.1681126058101654
Epoch 5, Loss: 0.017012910917401314
Epoch 6, Loss: 1.627926576475147e-05
Epoch 7, Loss: 0.0009350040927529335
Epoch 8, Loss: 13.79163646697998
Epoch 9, Loss: 0.008541463874280453
Epoch 10, Loss: 0.016976360231637955
Epoch 11, Loss: 0.04102009907364845
Epoch 12, Loss: 0.006098002661019564
Epoch 13, Loss: 0.0025676083751022816
Epoch 14, Loss: 0.00011074024223489687
Epoch 15, Loss: 0.04613468423485756
Epoch 16, Loss: 0.0014864255208522081
Epoch 17, Loss: 0.002015576232224703
Epoch 18, Loss: 0.0122377909719944
Epoch 19, Loss: 0.20847086608409882
Epoch 20, Loss: 0.0025835803244262934
Evaluating active network
Evaluation - Steps: 20, Score: 0
Epoch 21, Loss: 0.04392904415726662
Epoch 22, Loss: 0.027787817642092705
Epoch 23, Loss: 0.019414285197854042
Epoch 24,

[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


0,1
epoch,▁▁▁▁▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇█████
evaluation_score,█▁▁▁▁█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█
evaluation_steps,▂▂▁▂▂▂▁▁▁▁▂▁▁▁▃▂▄▂▂▂▃▂▇▂▇▂▂▃▂▂▃▁▂▃▂▄▂▂▂█
training_loss,▁▁▁█▁▁▃▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▅▁▁▁▁▁▃▁▁▁▁▁▁▁▁

0,1
epoch,999.0
evaluation_score,1.0
evaluation_steps,64.0
training_loss,0.00058


### Training Scaled FF Model

In [39]:
# Overrides for the scaled feed-forward run; these replace the values from the
# hyperparameter cell for every cell executed afterwards.
REPLAY_MEMORY_SIZE = 50_000_000
HIDDEN_SIZE = 64
EPOCHS = 1_000_000

In [40]:
# Fresh replay buffer for the scaled run. Each entry is the tuple appended by
# the collection loop: (initial_state, action, reward, final_state, done).
# Once maxlen entries are stored, the deque discards the oldest on append.
state_action_pairs = deque(maxlen=REPLAY_MEMORY_SIZE) 

In [41]:
# Untrained network that supplies the greedy action during collection.
active_nn = DQN_scaled_ff(INPUT_SIZE, OUTPUT_SIZE, HIDDEN_SIZE).to(device=device)

# Decay a local copy so that re-running this cell always starts from the
# configured EPSILON instead of whatever value a previous run left behind.
epsilon = EPSILON

game = Snake()
# Acting only: no gradients are needed while filling the replay buffer.
with torch.no_grad():
    for step in range(REPLAY_MEMORY_SIZE):
        initial_state = game.get_state()
        state_tensor = torch.tensor(initial_state, dtype=torch.float32).to(device)
        # Convert to a plain int so the buffer never stores device tensors.
        greedy_action = int(torch.argmax(active_nn(state_tensor)))
        action = policy(epsilon, ACTION_SPACE, greedy_action)
        done = game.move_with_action(action=action)
        final_state = game.get_state()
        reward = game.get_immediate_reward()

        state_action_pairs.append((initial_state, action, reward, final_state, done))
        epsilon *= EPSILON_DECAY

        # Start a fresh game once the current one has ended.
        if done:
            game = Snake()


KeyboardInterrupt: 

In [None]:
active_nn = DQN_scaled_ff(INPUT_SIZE, OUTPUT_SIZE, HIDDEN_SIZE).to(device=device)
target_nn = DQN_scaled_ff(INPUT_SIZE, OUTPUT_SIZE, HIDDEN_SIZE).to(device=device)
# Start the target network as an exact copy of the online network.
target_nn.load_state_dict(active_nn.state_dict())
# requires_grad_ is a method: call it to freeze the target's parameters.
# Assigning to it would only replace the method with a bool.
target_nn.requires_grad_(False)
optimizer = optim.AdamW(active_nn.parameters(), LR)

active_nn, target_nn = train(
    active_nn=active_nn,
    target_nn=target_nn,
    state_action_pairs=state_action_pairs,
    replay_memory_size=REPLAY_MEMORY_SIZE,
    device=device,
    loss_function=loss_function,
    optimizer=optimizer,
    gamma=GAMMA,
    epochs=EPOCHS,
    freq_of_target_nn_update=FREQ_OF_TARGET_NN_UPDATE,
    epsilon=EPSILON,
    action_space=ACTION_SPACE,
    policy=policy,
    model_name="scaled_ff_model",
    model_type="scaled_ff",
    learning_rate=LR,
    run_name="scaled_ff_model_run"
)