In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import sys
from tqdm import tqdm
import pickle
import os
# Add the project root directory to sys.path
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '../..')))

from src.ParObsSnakeEnv import ParObsSnakeEnv
from src.FullObsSnakeEnv import FullObsSnakeEnv
from src.utils import compute_metrics
import matplotlib.pyplot as plt
from torch.distributions import Categorical

In [2]:
class PolicyNetwork(nn.Module):
    def __init__(self, input_dim, output_dim, hidden_dim=128):
        super(PolicyNetwork, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return torch.softmax(x, dim=-1)

In [3]:
class REINFORCEAgent:
    def __init__(self, state_dim, action_dim, lr=0.001, gamma=0.99, device=None):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.gamma = gamma

        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.policy = PolicyNetwork(state_dim, action_dim).to(self.device)
        self.optimizer = optim.Adam(self.policy.parameters(), lr=lr)

        # Metrics for plotting
        self.episode_rewards = []
        self.episode_losses = []

    def choose_action(self, state):
        state = torch.FloatTensor(state).unsqueeze(0).to(self.device)  # Add batch dimension
        action_probs = self.policy(state)
        dist = Categorical(action_probs)
        action = dist.sample()
        log_prob = dist.log_prob(action)
        return action.item(), log_prob

    def remember(self, log_prob, reward):
        if not hasattr(self, 'log_probs'):
            self.log_probs = []
        if not hasattr(self, 'rewards'):
            self.rewards = []
        self.log_probs.append(log_prob)
        self.rewards.append(reward)

    def update_policy(self):
        """Update the policy using stored rewards and log probabilities."""
        if not hasattr(self, 'log_probs') or len(self.log_probs) == 0:
            return

        returns = self.compute_returns(self.rewards)
        loss = -torch.sum(torch.stack(self.log_probs) * returns)  # Negative log-prob * return

        # Update the policy
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Log the loss
        self.episode_losses.append(loss.item())

        # Clear memory
        self.log_probs.clear()
        self.rewards.clear()

    def compute_returns(self, rewards):
        """Compute discounted returns for an episode."""
        returns = []
        G = 0
        for reward in reversed(rewards):
            G = reward + self.gamma * G
            returns.insert(0, G)
        returns = torch.FloatTensor(returns).to(self.device)
        # Normalize returns to improve training stability
        if len(returns) > 1 and returns.std() > 1e-5:
            returns = (returns - returns.mean()) / (returns.std() + 1e-5)
        return returns

    def save(self, filename):
        """Saves the entire agent to a file."""
        state = {
            'policy_state_dict': self.policy.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'hyperparameters': {
                'state_dim': self.state_dim,
                'action_dim': self.action_dim,
                'gamma': self.gamma,
            },
        }
        with open(filename, 'wb') as f:
            pickle.dump(state, f)
        print(f"Agent saved to {filename}")

    @classmethod
    def load(cls, filename, lr=0.001):
        """Loads the agent from a file."""
        with open(filename, 'rb') as f:
            state = pickle.load(f)

        # Recreate the agent
        agent = cls(
            state['hyperparameters']['state_dim'],
            state['hyperparameters']['action_dim'],
            lr=lr,
            gamma=state['hyperparameters']['gamma'],
        )
        # Restore the agent's state
        agent.policy.load_state_dict(state['policy_state_dict'])
        agent.optimizer.load_state_dict(state['optimizer_state_dict'])
        print(f"Agent loaded from {filename}")
        return agent

    def train(self, env, episodes=1000, save_plots=False, plots_path='reinforce_training_plots.png'):
        for episode in tqdm(range(episodes), desc="Training", unit="episode"):
            state = env.reset()
            total_reward = 0
            done = False
            self.policy.train()

            while not done:
                action, log_prob = self.choose_action(state)
                next_state, reward, done, _ = env.step(action)
                self.remember(log_prob, reward)
                state = next_state
                total_reward += reward

            self.update_policy()
            self.episode_rewards.append(total_reward)

        if save_plots:
            self.save_plots(plots_path)
        self.policy.eval()


    def save_plots(self, plots_path):
        plots_dir = os.path.dirname(plots_path)
        os.makedirs(plots_dir, exist_ok=True)

        fig, axs = plt.subplots(1, 2, figsize=(15, 5))

        # Rewards per episode
        axs[0].plot(self.episode_rewards)
        axs[0].set_title("Episode Rewards")
        axs[0].set_xlabel("Episode")
        axs[0].set_ylabel("Total Reward")

        # Loss per episode
        axs[1].plot(self.episode_losses)
        axs[1].set_title("Loss Over Training")
        axs[1].set_xlabel("Episode")
        axs[1].set_ylabel("Loss")

        plt.tight_layout()
        plt.savefig(plots_path)
        plt.close()


In [4]:
grid_size = 10
# env = FullObsSnakeEnv(grid_size=grid_size, interact=False)
env = ParObsSnakeEnv(grid_size=grid_size, interact=False)
if isinstance(env, FullObsSnakeEnv):
    state_dim = env.observation_space.shape[0] * env.observation_space.shape[1] * env.observation_space.shape[2]
else:
    state_dim = env.observation_space.shape[0]

action_dim = env.action_space.n
agent = REINFORCEAgent(state_dim, action_dim)

In [5]:
num_episodes = 10000
agent.train(env, episodes=num_episodes)

Training: 100%|██████████| 10000/10000 [05:57<00:00, 27.98episode/s]


In [6]:
environment = 'full 'if isinstance(env, FullObsSnakeEnv) else 'par'

agent_name = f'polNet_agent_{environment}_{num_episodes}_{grid_size}.pkl'
model_weights_dir = os.path.join('../..', 'models', 'polNet')
os.makedirs(model_weights_dir, exist_ok=True)
agent_path = os.path.join(model_weights_dir, agent_name)

agent.save(agent_path)
# agent = REINFORCEAgent.load(agent_path)

Agent saved to ../../models/polNet/polNet_agent_par_10000_10.pkl


In [7]:
if isinstance(env, ParObsSnakeEnv):
    env = ParObsSnakeEnv(grid_size=2*grid_size, interact=False)

model_metrics_dir = os.path.join('../..', 'artifacts', 'models_stats', 'polNet')
os.makedirs(model_metrics_dir, exist_ok=True)

train_metrics_name = f'polNet_train_metrics_{environment}_{num_episodes}_{grid_size}.png'
train_metrics_path = os.path.join(model_metrics_dir, train_metrics_name)
agent.save_plots(train_metrics_path)

num_simulations = 100
sim_metrics_name = f'polNet_sim_metrics_{environment}_{num_episodes}_{env.grid_size}_{num_simulations}.json'
sim_metrics_path = os.path.join(model_metrics_dir, sim_metrics_name)
compute_metrics(agent, env, sim_metrics_path, num_simulations=num_simulations)

  2%|▏         | 2/100 [00:00<00:04, 19.98it/s]

Snake length: 2, Episode reward: 20
Snake length: 11, Episode reward: 787


  9%|▉         | 9/100 [00:00<00:03, 23.33it/s]

Snake length: 16, Episode reward: 1239
Snake length: 8, Episode reward: 525
Snake length: 2, Episode reward: 23
Snake length: 2, Episode reward: 17
Snake length: 3, Episode reward: 107
Snake length: 2, Episode reward: 31
Snake length: 6, Episode reward: 370


 12%|█▏        | 12/100 [00:00<00:04, 21.24it/s]

Snake length: 9, Episode reward: 615
Snake length: 4, Episode reward: 193
Snake length: 7, Episode reward: 444
Snake length: 5, Episode reward: 294


 15%|█▌        | 15/100 [00:00<00:03, 21.60it/s]

Snake length: 7, Episode reward: 468
Snake length: 1, Episode reward: -64
Snake length: 8, Episode reward: 560
Snake length: 10, Episode reward: 745


 18%|█▊        | 18/100 [00:00<00:04, 17.48it/s]

Snake length: 5, Episode reward: 314
Snake length: 11, Episode reward: 817


 23%|██▎       | 23/100 [00:01<00:05, 13.09it/s]

Snake length: 15, Episode reward: 1217
Snake length: 8, Episode reward: 534
Snake length: 4, Episode reward: 191
Snake length: 8, Episode reward: 530
Snake length: 1, Episode reward: -72
Snake length: 2, Episode reward: 19


 28%|██▊       | 28/100 [00:01<00:04, 15.87it/s]

Snake length: 4, Episode reward: 210
Snake length: 5, Episode reward: 265
Snake length: 11, Episode reward: 855
Snake length: 2, Episode reward: 31


 31%|███       | 31/100 [00:01<00:03, 18.16it/s]

Snake length: 7, Episode reward: 429
Snake length: 5, Episode reward: 296
Snake length: 8, Episode reward: 536
Snake length: 8, Episode reward: 515


 36%|███▌      | 36/100 [00:02<00:03, 17.25it/s]

Snake length: 6, Episode reward: 356
Snake length: 6, Episode reward: 384
Snake length: 4, Episode reward: 245
Snake length: 4, Episode reward: 199


 39%|███▉      | 39/100 [00:02<00:03, 19.07it/s]

Snake length: 7, Episode reward: 444
Snake length: 3, Episode reward: 125


 42%|████▏     | 42/100 [00:02<00:03, 15.37it/s]

Snake length: 11, Episode reward: 770
Snake length: 4, Episode reward: 186
Snake length: 8, Episode reward: 559
Snake length: 6, Episode reward: 361


 44%|████▍     | 44/100 [00:02<00:03, 15.51it/s]

Snake length: 9, Episode reward: 631
Snake length: 12, Episode reward: 904


 48%|████▊     | 48/100 [00:03<00:04, 12.30it/s]

Snake length: 9, Episode reward: 615
Snake length: 4, Episode reward: 210
Snake length: 12, Episode reward: 877


 51%|█████     | 51/100 [00:03<00:03, 13.11it/s]

Snake length: 4, Episode reward: 181
Snake length: 8, Episode reward: 511
Snake length: 9, Episode reward: 632


 55%|█████▌    | 55/100 [00:03<00:03, 14.18it/s]

Snake length: 10, Episode reward: 711
Snake length: 3, Episode reward: 102
Snake length: 5, Episode reward: 285
Snake length: 7, Episode reward: 461
Snake length: 8, Episode reward: 530


 59%|█████▉    | 59/100 [00:03<00:02, 15.78it/s]

Snake length: 4, Episode reward: 218
Snake length: 7, Episode reward: 442
Snake length: 6, Episode reward: 369
Snake length: 8, Episode reward: 520


 63%|██████▎   | 63/100 [00:03<00:02, 16.58it/s]

Snake length: 7, Episode reward: 460
Snake length: 4, Episode reward: 219
Snake length: 8, Episode reward: 541


 65%|██████▌   | 65/100 [00:04<00:02, 14.52it/s]

Snake length: 9, Episode reward: 649
Snake length: 6, Episode reward: 358
Snake length: 15, Episode reward: 1205


 70%|███████   | 70/100 [00:04<00:02, 10.95it/s]

Snake length: 20, Episode reward: 1615
Snake length: 6, Episode reward: 363
Snake length: 4, Episode reward: 208
Snake length: 8, Episode reward: 562


 74%|███████▍  | 74/100 [00:05<00:02, 10.70it/s]

Snake length: 8, Episode reward: 558
Snake length: 3, Episode reward: 123
Snake length: 7, Episode reward: 445
Snake length: 7, Episode reward: 496


 77%|███████▋  | 77/100 [00:05<00:01, 13.45it/s]

Snake length: 4, Episode reward: 193
Snake length: 4, Episode reward: 216
Snake length: 4, Episode reward: 198
Snake length: 2, Episode reward: 15
Snake length: 4, Episode reward: 179


 83%|████████▎ | 83/100 [00:05<00:01, 15.22it/s]

Snake length: 16, Episode reward: 1253
Snake length: 6, Episode reward: 373
Snake length: 2, Episode reward: 36
Snake length: 6, Episode reward: 357
Snake length: 5, Episode reward: 270


 85%|████████▌ | 85/100 [00:05<00:01, 14.63it/s]

Snake length: 14, Episode reward: 1059
Snake length: 3, Episode reward: 114
Snake length: 7, Episode reward: 430


 90%|█████████ | 90/100 [00:06<00:00, 14.56it/s]

Snake length: 14, Episode reward: 1066
Snake length: 6, Episode reward: 380
Snake length: 9, Episode reward: 599
Snake length: 2, Episode reward: 33


 92%|█████████▏| 92/100 [00:06<00:00, 15.01it/s]

Snake length: 10, Episode reward: 719
Snake length: 9, Episode reward: 657


 96%|█████████▌| 96/100 [00:06<00:00, 12.92it/s]

Snake length: 14, Episode reward: 1067
Snake length: 7, Episode reward: 476
Snake length: 6, Episode reward: 349
Snake length: 11, Episode reward: 768
Snake length: 1, Episode reward: -70


100%|██████████| 100/100 [00:06<00:00, 14.58it/s]

Snake length: 3, Episode reward: 108
Snake length: 8, Episode reward: 518





{'snake_lengths': [2,
  11,
  16,
  8,
  2,
  2,
  3,
  2,
  6,
  9,
  4,
  7,
  5,
  7,
  1,
  8,
  10,
  5,
  11,
  15,
  8,
  4,
  8,
  1,
  2,
  4,
  5,
  11,
  2,
  7,
  5,
  8,
  8,
  6,
  6,
  4,
  4,
  7,
  3,
  11,
  4,
  8,
  6,
  9,
  12,
  9,
  4,
  12,
  4,
  8,
  9,
  10,
  3,
  5,
  7,
  8,
  4,
  7,
  6,
  8,
  7,
  4,
  8,
  9,
  6,
  15,
  20,
  6,
  4,
  8,
  8,
  3,
  7,
  7,
  4,
  4,
  4,
  2,
  4,
  16,
  6,
  2,
  6,
  5,
  14,
  3,
  7,
  14,
  6,
  9,
  2,
  10,
  9,
  14,
  7,
  6,
  11,
  1,
  3,
  8],
 'episode_rewards': [20,
  787,
  1239,
  525,
  23,
  17,
  107,
  31,
  370,
  615,
  193,
  444,
  294,
  468,
  -64,
  560,
  745,
  314,
  817,
  1217,
  534,
  191,
  530,
  -72,
  19,
  210,
  265,
  855,
  31,
  429,
  296,
  536,
  515,
  356,
  384,
  245,
  199,
  444,
  125,
  770,
  186,
  559,
  361,
  631,
  904,
  615,
  210,
  877,
  181,
  511,
  632,
  711,
  102,
  285,
  461,
  530,
  218,
  442,
  369,
  520,
  460,
  219,
  541,
  649,
 

In [9]:
if isinstance(env, FullObsSnakeEnv):
    env.interact = True
else:
    env = ParObsSnakeEnv(grid_size=2*grid_size)
    
state = env.reset()
done = False
with torch.no_grad():
    while not done:
        action, _ = agent.choose_action(state)
        state, reward, done, _ = env.step(action)
        env.render()
        print(f"Reward: {reward}")

2024-12-04 13:51:28.193 python[90124:2021557] +[IMKClient subclass]: chose IMKClient_Modern
2024-12-04 13:51:28.193 python[90124:2021557] +[IMKInputSession subclass]: chose IMKInputSession_Modern


Reward: 1
Reward: 1
Reward: 1
Reward: -1
Reward: 1
Reward: 1
Reward: 76
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 76
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 76
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: -1
Reward: 1
Reward: 1
Reward: -1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 76
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: -1
Reward: 1
Reward: 76
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 76
Reward: 1
Reward: 1
Reward: 1
Reward: 76
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: -1
Reward: 1
Reward: 1
Reward: 1
Reward: -1
Reward: -1
Reward: 1
Reward: 1
Reward: 1
Reward