In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import random
from collections import deque
import numpy as np
import sys
import os
from tqdm import tqdm
import pickle

# Add the project root directory to sys.path
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '../..')))

from src.ParObsSnakeEnv import ParObsSnakeEnv
from src.FullObsSnakeEnv import FullObsSnakeEnv
from src.utils import compute_metrics
import matplotlib.pyplot as plt

In [2]:
class DQN(nn.Module):
    def __init__(self, input_dim, output_dim, hidden_dim=128):
        super(DQN, self).__init__()
        self.hidden_dim = hidden_dim
        self.fc1 = nn.Linear(input_dim, self.hidden_dim)
        self.fc2 = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.fc3 = nn.Linear(self.hidden_dim, output_dim)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [None]:
# Define the DQN agent
class DQNAgent:
    def __init__(self, state_dim, action_dim, lr=0.001, gamma=0.99, epsilon=1.0, epsilon_decay=0.995, epsilon_min=0.01, memory_size=10000, batch_size=64):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_decay = epsilon_decay
        self.epsilon_min = epsilon_min
        self.batch_size = batch_size

        self.memory = deque(maxlen=memory_size)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = DQN(state_dim, action_dim).to(self.device)
        self.target_model = DQN(state_dim, action_dim).to(self.device)
        self.target_model.load_state_dict(self.model.state_dict())
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
        self.criterion = nn.MSELoss()

        # Metrics for plotting
        self.episode_rewards = []
        self.episode_losses = []
        self.epsilon_values = []
        self.average_q_values = []

    def choose_action(self, state):
        if random.random() < self.epsilon:
            return random.randint(0, self.action_dim - 1), None
        state = torch.FloatTensor(state).unsqueeze(0).to(self.device)
        state = state.reshape((1, -1))
        with torch.no_grad():
            # print(">>>>", state.shape)
            q_values = self.model(state)
        return torch.argmax(q_values).item(), None

    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def replay(self):
        if len(self.memory) < self.batch_size:
            return

        batch = random.sample(self.memory, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        states = torch.FloatTensor(states).to(self.device)
        actions = torch.LongTensor(actions).to(self.device)
        rewards = torch.FloatTensor(rewards).to(self.device)
        next_states = torch.FloatTensor(next_states).to(self.device)
        dones = torch.FloatTensor(dones).to(self.device)

        states = states.reshape((self.batch_size, -1))
        next_states = next_states.reshape((self.batch_size, -1))

        # Compute current Q-values
        q_values = self.model(states)
        q_values = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)

        # Compute target Q-values
        next_q_values = self.target_model(next_states).max(1)[0]
        target_q_values = rewards + (1 - dones) * self.gamma * next_q_values

        # Update the Q-network
        loss = self.criterion(q_values, target_q_values)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Log the loss
        self.episode_losses.append(loss.item())

        # Track average Q-value
        self.average_q_values.append(q_values.mean().item())

        # Decay epsilon
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

    def update_target_model(self):
        self.target_model.load_state_dict(self.model.state_dict())
        
    def save(self, filename):
        """Saves the entire agent to a file with all tensors moved to CPU."""
        # Save the current device
        current_device = next(self.model.parameters()).device

        # Move models to CPU
        self.model.to('cpu')
        self.target_model.to('cpu')

        state = {
            'model_state_dict': self.model.state_dict(),
            'target_model_state_dict': self.target_model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'hyperparameters': {
                'state_dim': self.state_dim,
                'action_dim': self.action_dim,
                'gamma': self.gamma,
                'epsilon': self.epsilon,
                'epsilon_decay': self.epsilon_decay,
                'epsilon_min': self.epsilon_min,
                'batch_size': self.batch_size,
            },
            'memory': list(self.memory),  # Convert deque to list
        }

        with open(filename, 'wb') as f:
            pickle.dump(state, f)
        print(f"Agent saved to {filename}")

        # Move the models back to their original device
        self.model.to(current_device)
        self.target_model.to(current_device)
    
    def load(cls, filename, lr=0.001, device=None):
        """Loads the agent from a file."""
        with open(filename, 'rb') as f:
            state = pickle.load(f)

        # Determine the device to load the model onto
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Recreate the agent
        agent = cls(
            state['hyperparameters']['state_dim'],
            state['hyperparameters']['action_dim'],
            lr=lr,
            gamma=state['hyperparameters']['gamma'],
            epsilon=state['hyperparameters']['epsilon'],
            epsilon_decay=state['hyperparameters']['epsilon_decay'],
            epsilon_min=state['hyperparameters']['epsilon_min'],
            batch_size=state['hyperparameters']['batch_size'],
        )

        # Restore the agent's state
        agent.model.load_state_dict(state['model_state_dict'])
        agent.target_model.load_state_dict(state['target_model_state_dict'])
        agent.optimizer.load_state_dict(state['optimizer_state_dict'])
        agent.memory = deque(state['memory'], maxlen=len(state['memory']))

        # Move the model to the specified device
        agent.model = agent.model.to(device)
        agent.target_model = agent.target_model.to(device)

        print(f"Agent loaded from {filename} onto device: {device}")
        return agent

    # Train the agent
    @classmethod
    def train(self, env, episodes=1000, update_target_every=10, save_plots=False, plots_path='dqn_training_plots.png'):
        for episode in tqdm(range(episodes), desc="Training", unit='episode'):
            state = env.reset()
            total_reward = 0
            done = False

            while not done:
                action, _ = self.choose_action(state)
                next_state, reward, done, _ = env.step(action)
                self.remember(state, action, reward, next_state, done)
                self.replay()
                state = next_state
                total_reward += reward

            self.episode_rewards.append(total_reward)
            self.epsilon_values.append(self.epsilon)

            if (episode + 1) % update_target_every == 0:
                self.update_target_model()

        if save_plots:
            self.save_plots(plots_path)
    
    def save_plots(self, plots_path):
        plots_dir = os.path.dirname(plots_path)
        os.makedirs(plots_dir, exist_ok=True)

        fig, axs = plt.subplots(2, 2, figsize=(15, 10))

        # Rewards per episode
        axs[0, 0].plot(self.episode_rewards)
        axs[0, 0].set_title("Episode Rewards")
        axs[0, 0].set_xlabel("Episode")
        axs[0, 0].set_ylabel("Total Reward")

        # Loss per episode
        axs[0, 1].plot(self.episode_losses)
        axs[0, 1].set_title("Loss Over Training")
        axs[0, 1].set_xlabel("Episode")
        axs[0, 1].set_ylabel("Loss")

        # Epsilon decay
        axs[1, 0].plot(self.epsilon_values)
        axs[1, 0].set_title("Epsilon Decay")
        axs[1, 0].set_xlabel("Episode")
        axs[1, 0].set_ylabel("Epsilon Value")

        # Average Q-values
        axs[1, 1].plot(self.average_q_values)
        axs[1, 1].set_title("Average Q-Values")
        axs[1, 1].set_xlabel("Episode")
        axs[1, 1].set_ylabel("Average Q-Value")

        plt.tight_layout()
        plt.savefig(plots_path)
        plt.close()

In [4]:
grid_size = 10
# env = FullObsSnakeEnv(grid_size=grid_size, interact=False)
env = ParObsSnakeEnv(grid_size=grid_size, interact=False)
if isinstance(env, FullObsSnakeEnv):
    state_dim = env.observation_space.shape[0] * env.observation_space.shape[1] * env.observation_space.shape[2]
else:
    state_dim = env.observation_space.shape[0]

action_dim = env.action_space.n
agent = DQNAgent(state_dim, action_dim)

In [5]:
num_episodes = 300
agent.train(env, episodes=num_episodes)

  states = torch.FloatTensor(states).to(self.device)
Training: 100%|██████████| 300/300 [01:03<00:00,  4.70episode/s]


In [None]:
environment = 'full 'if isinstance(env, FullObsSnakeEnv) else 'par'

agent_name = f'dqn_agent_{environment}_{num_episodes}_{grid_size}.pkl'
model_weights_dir = os.path.join('../..', 'models', 'dqn')
os.makedirs(model_weights_dir, exist_ok=True)
agent_path = os.path.join(model_weights_dir, agent_name)

agent.save(agent_path)
agent = DQNAgent.load(agent_path)

Agent saved to ../../models/dqn/dqn_agent_par_300_10.pkl


In [None]:
if isinstance(env, ParObsSnakeEnv):
    env = ParObsSnakeEnv(grid_size=2*grid_size, interact=False)

model_metrics_dir = os.path.join('../..', 'artifacts', 'models_stats', 'dqn')
os.makedirs(model_metrics_dir, exist_ok=True)

train_metrics_name = f'dqn_train_metrics_{environment}_{num_episodes}_{grid_size}.png'
train_metrics_path = os.path.join(model_metrics_dir, train_metrics_name)
agent.save_plots(train_metrics_path)

num_simulations = 100
sim_metrics_name = f'dqn_sim_metrics_{environment}_{num_episodes}_{env.grid_size}_{num_simulations}.json'
sim_metrics_path = os.path.join(model_metrics_dir, sim_metrics_name)
compute_metrics(agent, env, sim_metrics_path, num_simulations=num_simulations)

  2%|▏         | 2/100 [00:00<00:07, 12.76it/s]

Snake length: 32, Episode reward: 2698
Snake length: 24, Episode reward: 1979
Snake length: 22, Episode reward: 1726
Snake length: 19, Episode reward: 1542


  7%|▋         | 7/100 [00:00<00:06, 15.24it/s]

Snake length: 34, Episode reward: 2845
Snake length: 20, Episode reward: 1592
Snake length: 26, Episode reward: 2120
Snake length: 27, Episode reward: 2172


 11%|█         | 11/100 [00:00<00:06, 13.31it/s]

Snake length: 35, Episode reward: 2914
Snake length: 30, Episode reward: 2443
Snake length: 38, Episode reward: 3284


 13%|█▎        | 13/100 [00:01<00:08, 10.78it/s]

Snake length: 38, Episode reward: 3188
Snake length: 42, Episode reward: 3535


 15%|█▌        | 15/100 [00:01<00:07, 11.40it/s]

Snake length: 29, Episode reward: 2410
Snake length: 22, Episode reward: 1759
Snake length: 43, Episode reward: 3618


 19%|█▉        | 19/100 [00:01<00:06, 12.33it/s]

Snake length: 31, Episode reward: 2629
Snake length: 21, Episode reward: 1661
Snake length: 23, Episode reward: 1869
Snake length: 23, Episode reward: 1878


 23%|██▎       | 23/100 [00:01<00:06, 12.65it/s]

Snake length: 30, Episode reward: 2434
Snake length: 13, Episode reward: 991
Snake length: 42, Episode reward: 3622


 25%|██▌       | 25/100 [00:01<00:05, 12.61it/s]

Snake length: 27, Episode reward: 2177
Snake length: 35, Episode reward: 2941
Snake length: 33, Episode reward: 2722


 27%|██▋       | 27/100 [00:02<00:06, 11.66it/s]

Snake length: 38, Episode reward: 3187
Snake length: 39, Episode reward: 3283


 31%|███       | 31/100 [00:02<00:06, 10.60it/s]

Snake length: 33, Episode reward: 2754
Snake length: 29, Episode reward: 2399
Snake length: 39, Episode reward: 3297


 33%|███▎      | 33/100 [00:02<00:05, 11.68it/s]

Snake length: 28, Episode reward: 2271
Snake length: 25, Episode reward: 2047
Snake length: 18, Episode reward: 1460


 37%|███▋      | 37/100 [00:03<00:05, 11.90it/s]

Snake length: 41, Episode reward: 3483
Snake length: 22, Episode reward: 1818
Snake length: 32, Episode reward: 2680


 41%|████      | 41/100 [00:03<00:04, 13.65it/s]

Snake length: 30, Episode reward: 2466
Snake length: 27, Episode reward: 2231
Snake length: 25, Episode reward: 2008
Snake length: 15, Episode reward: 1216


 43%|████▎     | 43/100 [00:03<00:04, 14.02it/s]

Snake length: 29, Episode reward: 2358
Snake length: 23, Episode reward: 1831
Snake length: 41, Episode reward: 3466


 47%|████▋     | 47/100 [00:03<00:04, 11.99it/s]

Snake length: 33, Episode reward: 2804
Snake length: 31, Episode reward: 2540
Snake length: 26, Episode reward: 2200


 50%|█████     | 50/100 [00:04<00:03, 13.45it/s]

Snake length: 25, Episode reward: 2028
Snake length: 10, Episode reward: 710
Snake length: 31, Episode reward: 2587
Snake length: 28, Episode reward: 2317


 54%|█████▍    | 54/100 [00:04<00:03, 13.89it/s]

Snake length: 18, Episode reward: 1399
Snake length: 21, Episode reward: 1668
Snake length: 36, Episode reward: 2972


 56%|█████▌    | 56/100 [00:04<00:03, 12.81it/s]

Snake length: 33, Episode reward: 2752
Snake length: 29, Episode reward: 2456
Snake length: 33, Episode reward: 2808


 60%|██████    | 60/100 [00:04<00:02, 13.60it/s]

Snake length: 13, Episode reward: 957
Snake length: 29, Episode reward: 2454
Snake length: 22, Episode reward: 1795
Snake length: 20, Episode reward: 1554


 63%|██████▎   | 63/100 [00:04<00:02, 14.10it/s]

Snake length: 17, Episode reward: 1306
Snake length: 38, Episode reward: 3228
Snake length: 20, Episode reward: 1616


 67%|██████▋   | 67/100 [00:05<00:02, 13.12it/s]

Snake length: 30, Episode reward: 2511
Snake length: 39, Episode reward: 3354
Snake length: 28, Episode reward: 2282


 71%|███████   | 71/100 [00:05<00:01, 17.50it/s]

Snake length: 5, Episode reward: 313
Snake length: 26, Episode reward: 2089
Snake length: 3, Episode reward: 101
Snake length: 17, Episode reward: 1317
Snake length: 16, Episode reward: 1209
Snake length: 22, Episode reward: 1758


 76%|███████▌  | 76/100 [00:05<00:01, 14.63it/s]

Snake length: 36, Episode reward: 3003
Snake length: 37, Episode reward: 3180
Snake length: 26, Episode reward: 2140


 78%|███████▊  | 78/100 [00:05<00:01, 14.09it/s]

Snake length: 25, Episode reward: 2114
Snake length: 30, Episode reward: 2480
Snake length: 28, Episode reward: 2350


 82%|████████▏ | 82/100 [00:06<00:01, 13.26it/s]

Snake length: 29, Episode reward: 2419
Snake length: 40, Episode reward: 3408
Snake length: 16, Episode reward: 1284


 86%|████████▌ | 86/100 [00:06<00:00, 17.27it/s]

Snake length: 21, Episode reward: 1682
Snake length: 3, Episode reward: 122
Snake length: 7, Episode reward: 460
Snake length: 20, Episode reward: 1620
Snake length: 38, Episode reward: 3161


 90%|█████████ | 90/100 [00:06<00:00, 14.81it/s]

Snake length: 15, Episode reward: 1164
Snake length: 38, Episode reward: 3229
Snake length: 23, Episode reward: 1898


 92%|█████████▏| 92/100 [00:06<00:00, 14.20it/s]

Snake length: 23, Episode reward: 1874
Snake length: 34, Episode reward: 2876
Snake length: 25, Episode reward: 2032
Snake length: 8, Episode reward: 546


 97%|█████████▋| 97/100 [00:07<00:00, 14.15it/s]

Snake length: 22, Episode reward: 1779
Snake length: 36, Episode reward: 3017
Snake length: 24, Episode reward: 1967


 99%|█████████▉| 99/100 [00:07<00:00, 13.70it/s]

Snake length: 34, Episode reward: 2871
Snake length: 20, Episode reward: 1600


100%|██████████| 100/100 [00:07<00:00, 13.22it/s]

Snake length: 46, Episode reward: 3921





{'snake_lengths': [32,
  24,
  22,
  19,
  34,
  20,
  26,
  27,
  35,
  30,
  38,
  38,
  42,
  29,
  22,
  43,
  31,
  21,
  23,
  23,
  30,
  13,
  42,
  27,
  35,
  33,
  38,
  39,
  33,
  29,
  39,
  28,
  25,
  18,
  41,
  22,
  32,
  30,
  27,
  25,
  15,
  29,
  23,
  41,
  33,
  31,
  26,
  25,
  10,
  31,
  28,
  18,
  21,
  36,
  33,
  29,
  33,
  13,
  29,
  22,
  20,
  17,
  38,
  20,
  30,
  39,
  28,
  5,
  26,
  3,
  17,
  16,
  22,
  36,
  37,
  26,
  25,
  30,
  28,
  29,
  40,
  16,
  21,
  3,
  7,
  20,
  38,
  15,
  38,
  23,
  23,
  34,
  25,
  8,
  22,
  36,
  24,
  34,
  20,
  46],
 'episode_rewards': [2698,
  1979,
  1726,
  1542,
  2845,
  1592,
  2120,
  2172,
  2914,
  2443,
  3284,
  3188,
  3535,
  2410,
  1759,
  3618,
  2629,
  1661,
  1869,
  1878,
  2434,
  991,
  3622,
  2177,
  2941,
  2722,
  3187,
  3283,
  2754,
  2399,
  3297,
  2271,
  2047,
  1460,
  3483,
  1818,
  2680,
  2466,
  2231,
  2008,
  1216,
  2358,
  1831,
  3466,
  2804,
  2540,
 

In [None]:
if isinstance(env, FullObsSnakeEnv):
    env.interact = True
else:
    env = ParObsSnakeEnv(grid_size=2*grid_size)
    
state = env.reset()
done = False
while not done:
    action, _ = agent.choose_action(state)
    state, reward, done, _ = env.step(action)
    env.render()
    print(f"Reward: {reward}")

Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 76
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 76
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 76
Reward: 1
Reward: 76
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 76
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 76
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 76
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Rew

KeyboardInterrupt: 

In [9]:
agent = DQNAgent(state_dim, action_dim)
agent.load("../../models/dqn/dqn_agent_par_300_10.pkl")

TypeError: 'DQNAgent' object is not callable