In [12]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import sys
from tqdm import tqdm
import pickle
import os
# Add the project root directory to sys.path
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '../..')))

from src.ParObsSnakeEnv import ParObsSnakeEnv
from src.FullObsSnakeEnv import FullObsSnakeEnv
from src.utils import compute_metrics
import matplotlib.pyplot as plt
from torch.distributions import Categorical

In [13]:
class PolicyNetwork(nn.Module):
    def __init__(self, input_dim, output_dim, hidden_dim=128):
        super(PolicyNetwork, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return torch.softmax(x, dim=-1)

In [14]:
class REINFORCEAgent:
    def __init__(self, state_dim, action_dim, lr=0.001, gamma=0.99, device=None):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.gamma = gamma

        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.policy = PolicyNetwork(state_dim, action_dim).to(self.device)
        self.optimizer = optim.Adam(self.policy.parameters(), lr=lr)

        # Metrics for plotting
        self.episode_rewards = []
        self.episode_losses = []

    def choose_action(self, state):
        state = torch.FloatTensor(state).unsqueeze(0).to(self.device)  # Add batch dimension
        action_probs = self.policy(state)
        dist = Categorical(action_probs)
        action = dist.sample()
        log_prob = dist.log_prob(action)
        return action.item(), log_prob

    def remember(self, log_prob, reward):
        if not hasattr(self, 'log_probs'):
            self.log_probs = []
        if not hasattr(self, 'rewards'):
            self.rewards = []
        self.log_probs.append(log_prob)
        self.rewards.append(reward)

    def update_policy(self):
        """Update the policy using stored rewards and log probabilities."""
        if not hasattr(self, 'log_probs') or len(self.log_probs) == 0:
            return

        returns = self.compute_returns(self.rewards)
        loss = -torch.sum(torch.stack(self.log_probs) * returns)  # Negative log-prob * return

        # Update the policy
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Log the loss
        self.episode_losses.append(loss.item())

        # Clear memory
        self.log_probs.clear()
        self.rewards.clear()

    def compute_returns(self, rewards):
        """Compute discounted returns for an episode."""
        returns = []
        G = 0
        for reward in reversed(rewards):
            G = reward + self.gamma * G
            returns.insert(0, G)
        returns = torch.FloatTensor(returns).to(self.device)
        # Normalize returns to improve training stability
        if len(returns) > 1 and returns.std() > 1e-5:
            returns = (returns - returns.mean()) / (returns.std() + 1e-5)
        return returns

    def save(self, filename):
        """Saves the entire agent to a file."""
        state = {
            'policy_state_dict': self.policy.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'hyperparameters': {
                'state_dim': self.state_dim,
                'action_dim': self.action_dim,
                'gamma': self.gamma,
            },
        }
        with open(filename, 'wb') as f:
            pickle.dump(state, f)
        print(f"Agent saved to {filename}")

    @classmethod
    def load(cls, filename, lr=0.001):
        """Loads the agent from a file."""
        with open(filename, 'rb') as f:
            state = pickle.load(f)

        # Recreate the agent
        agent = cls(
            state['hyperparameters']['state_dim'],
            state['hyperparameters']['action_dim'],
            lr=lr,
            gamma=state['hyperparameters']['gamma'],
        )
        # Restore the agent's state
        agent.policy.load_state_dict(state['policy_state_dict'])
        agent.optimizer.load_state_dict(state['optimizer_state_dict'])
        print(f"Agent loaded from {filename}")
        return agent

    def train(self, env, episodes=1000, save_plots=False, plots_path='reinforce_training_plots.png'):
        for episode in tqdm(range(episodes), desc="Training", unit="episode"):
            state = env.reset()
            total_reward = 0
            done = False
            self.policy.train()

            while not done:
                action, log_prob = self.choose_action(state)
                next_state, reward, done, _ = env.step(action)
                self.remember(log_prob, reward)
                state = next_state
                total_reward += reward

            self.update_policy()
            self.episode_rewards.append(total_reward)

        if save_plots:
            self.save_plots(plots_path)
        self.policy.eval()


    def save_plots(self, plots_path):
        plots_dir = os.path.dirname(plots_path)
        os.makedirs(plots_dir, exist_ok=True)

        fig, axs = plt.subplots(1, 2, figsize=(15, 5))

        # Rewards per episode
        axs[0].plot(self.episode_rewards)
        axs[0].set_title("Episode Rewards")
        axs[0].set_xlabel("Episode")
        axs[0].set_ylabel("Total Reward")

        # Loss per episode
        axs[1].plot(self.episode_losses)
        axs[1].set_title("Loss Over Training")
        axs[1].set_xlabel("Episode")
        axs[1].set_ylabel("Loss")

        plt.tight_layout()
        plt.savefig(plots_path)
        plt.close()


In [15]:
grid_size = 10
# env = FullObsSnakeEnv(grid_size=grid_size, interact=False)
env = ParObsSnakeEnv(grid_size=grid_size, interact=False)
if isinstance(env, FullObsSnakeEnv):
    state_dim = env.observation_space.shape[0] * env.observation_space.shape[1] * env.observation_space.shape[2]
else:
    state_dim = env.observation_space.shape[0]

action_dim = env.action_space.n
agent = REINFORCEAgent(state_dim, action_dim)

In [16]:
num_episodes = 10000
agent.train(env, episodes=num_episodes)

Training: 100%|██████████| 10000/10000 [06:07<00:00, 27.20episode/s]


In [17]:
environment = 'full 'if isinstance(env, FullObsSnakeEnv) else 'par'

agent_name = f'polNet_agent_{environment}_{num_episodes}_{grid_size}.pkl'
model_weights_dir = os.path.join('../..', 'models', 'polNet')
os.makedirs(model_weights_dir, exist_ok=True)
agent_path = os.path.join(model_weights_dir, agent_name)

agent.save(agent_path)
# agent = REINFORCEAgent.load(agent_path)

Agent saved to ../../models/polNet/polNet_agent_par_10000_10.pkl


In [18]:
if isinstance(env, ParObsSnakeEnv):
    env = ParObsSnakeEnv(grid_size=2*grid_size, interact=False)

model_metrics_dir = os.path.join('../..', 'artifacts', 'models_stats', 'polNet')
os.makedirs(model_metrics_dir, exist_ok=True)

train_metrics_name = f'polNet_train_metrics_{environment}_{num_episodes}_{grid_size}.png'
train_metrics_path = os.path.join(model_metrics_dir, train_metrics_name)
agent.save_plots(train_metrics_path)

num_simulations = 100
sim_metrics_name = f'polNet_sim_metrics_{environment}_{num_episodes}_{env.grid_size}_{num_simulations}.json'
sim_metrics_path = os.path.join(model_metrics_dir, sim_metrics_name)
compute_metrics(agent, env, sim_metrics_path, num_simulations=num_simulations)

  3%|▎         | 3/100 [00:00<00:06, 14.04it/s]

Snake length: 1, Episode reward: -74
Snake length: 4, Episode reward: 198
Snake length: 8, Episode reward: 552
Snake length: 4, Episode reward: 228


  7%|▋         | 7/100 [00:00<00:05, 15.92it/s]

Snake length: 6, Episode reward: 352
Snake length: 7, Episode reward: 467
Snake length: 5, Episode reward: 304
Snake length: 9, Episode reward: 619


 13%|█▎        | 13/100 [00:00<00:04, 19.64it/s]

Snake length: 7, Episode reward: 451
Snake length: 4, Episode reward: 182
Snake length: 4, Episode reward: 184
Snake length: 2, Episode reward: 21
Snake length: 7, Episode reward: 450
Snake length: 2, Episode reward: 15


 17%|█▋        | 17/100 [00:01<00:04, 17.44it/s]

Snake length: 12, Episode reward: 891
Snake length: 10, Episode reward: 724
Snake length: 5, Episode reward: 251
Snake length: 4, Episode reward: 192
Snake length: 4, Episode reward: 191


 22%|██▏       | 22/100 [00:01<00:04, 17.16it/s]

Snake length: 7, Episode reward: 442
Snake length: 6, Episode reward: 385
Snake length: 9, Episode reward: 625


 25%|██▌       | 25/100 [00:01<00:04, 16.50it/s]

Snake length: 6, Episode reward: 372
Snake length: 2, Episode reward: 41
Snake length: 8, Episode reward: 573
Snake length: 5, Episode reward: 273


 28%|██▊       | 28/100 [00:01<00:03, 18.71it/s]

Snake length: 3, Episode reward: 108
Snake length: 7, Episode reward: 426
Snake length: 4, Episode reward: 180


 32%|███▏      | 32/100 [00:01<00:04, 14.92it/s]

Snake length: 17, Episode reward: 1336
Snake length: 9, Episode reward: 632
Snake length: 4, Episode reward: 215
Snake length: 5, Episode reward: 263


 37%|███▋      | 37/100 [00:02<00:04, 15.69it/s]

Snake length: 8, Episode reward: 530
Snake length: 5, Episode reward: 274
Snake length: 8, Episode reward: 546
Snake length: 12, Episode reward: 872


 39%|███▉      | 39/100 [00:02<00:04, 13.02it/s]

Snake length: 12, Episode reward: 937
Snake length: 7, Episode reward: 494
Snake length: 6, Episode reward: 393


 44%|████▍     | 44/100 [00:02<00:03, 16.96it/s]

Snake length: 7, Episode reward: 446
Snake length: 3, Episode reward: 113
Snake length: 5, Episode reward: 255
Snake length: 5, Episode reward: 275


 48%|████▊     | 48/100 [00:02<00:03, 15.71it/s]

Snake length: 12, Episode reward: 887
Snake length: 7, Episode reward: 446
Snake length: 7, Episode reward: 454
Snake length: 10, Episode reward: 684


 53%|█████▎    | 53/100 [00:03<00:02, 18.59it/s]

Snake length: 4, Episode reward: 189
Snake length: 8, Episode reward: 548
Snake length: 6, Episode reward: 350
Snake length: 5, Episode reward: 259
Snake length: 8, Episode reward: 517
Snake length: 5, Episode reward: 263


 58%|█████▊    | 58/100 [00:03<00:02, 16.48it/s]

Snake length: 21, Episode reward: 1705
Snake length: 8, Episode reward: 534
Snake length: 4, Episode reward: 193
Snake length: 4, Episode reward: 180
Snake length: 6, Episode reward: 379


 62%|██████▏   | 62/100 [00:03<00:02, 15.03it/s]

Snake length: 10, Episode reward: 694
Snake length: 9, Episode reward: 614
Snake length: 6, Episode reward: 397
Snake length: 4, Episode reward: 199


 66%|██████▌   | 66/100 [00:04<00:02, 15.38it/s]

Snake length: 9, Episode reward: 635
Snake length: 11, Episode reward: 764
Snake length: 5, Episode reward: 282
Snake length: 4, Episode reward: 197


 71%|███████   | 71/100 [00:04<00:01, 15.26it/s]

Snake length: 15, Episode reward: 1141
Snake length: 7, Episode reward: 441
Snake length: 4, Episode reward: 189
Snake length: 6, Episode reward: 384


 74%|███████▍  | 74/100 [00:04<00:01, 16.16it/s]

Snake length: 6, Episode reward: 342
Snake length: 3, Episode reward: 99
Snake length: 10, Episode reward: 727
Snake length: 1, Episode reward: -73
Snake length: 7, Episode reward: 467


 79%|███████▉  | 79/100 [00:04<00:01, 15.02it/s]

Snake length: 9, Episode reward: 645
Snake length: 14, Episode reward: 1061
Snake length: 6, Episode reward: 376


 81%|████████  | 81/100 [00:05<00:01, 14.86it/s]

Snake length: 6, Episode reward: 367
Snake length: 6, Episode reward: 395
Snake length: 4, Episode reward: 201


 86%|████████▌ | 86/100 [00:05<00:00, 15.64it/s]

Snake length: 8, Episode reward: 577
Snake length: 6, Episode reward: 375
Snake length: 4, Episode reward: 184
Snake length: 7, Episode reward: 461


 90%|█████████ | 90/100 [00:05<00:00, 15.96it/s]

Snake length: 7, Episode reward: 444
Snake length: 6, Episode reward: 364
Snake length: 7, Episode reward: 450
Snake length: 6, Episode reward: 373


 92%|█████████▏| 92/100 [00:05<00:00, 15.92it/s]

Snake length: 4, Episode reward: 186
Snake length: 12, Episode reward: 866
Snake length: 1, Episode reward: -71
Snake length: 8, Episode reward: 548


 98%|█████████▊| 98/100 [00:06<00:00, 17.79it/s]

Snake length: 11, Episode reward: 801
Snake length: 5, Episode reward: 299
Snake length: 2, Episode reward: 10
Snake length: 8, Episode reward: 524
Snake length: 7, Episode reward: 446


100%|██████████| 100/100 [00:06<00:00, 16.07it/s]

Snake length: 5, Episode reward: 279





{'snake_lengths': [1,
  4,
  8,
  4,
  6,
  7,
  5,
  9,
  7,
  4,
  4,
  2,
  7,
  2,
  12,
  10,
  5,
  4,
  4,
  7,
  6,
  9,
  6,
  2,
  8,
  5,
  3,
  7,
  4,
  17,
  9,
  4,
  5,
  8,
  5,
  8,
  12,
  12,
  7,
  6,
  7,
  3,
  5,
  5,
  12,
  7,
  7,
  10,
  4,
  8,
  6,
  5,
  8,
  5,
  21,
  8,
  4,
  4,
  6,
  10,
  9,
  6,
  4,
  9,
  11,
  5,
  4,
  15,
  7,
  4,
  6,
  6,
  3,
  10,
  1,
  7,
  9,
  14,
  6,
  6,
  6,
  4,
  8,
  6,
  4,
  7,
  7,
  6,
  7,
  6,
  4,
  12,
  1,
  8,
  11,
  5,
  2,
  8,
  7,
  5],
 'episode_rewards': [-74,
  198,
  552,
  228,
  352,
  467,
  304,
  619,
  451,
  182,
  184,
  21,
  450,
  15,
  891,
  724,
  251,
  192,
  191,
  442,
  385,
  625,
  372,
  41,
  573,
  273,
  108,
  426,
  180,
  1336,
  632,
  215,
  263,
  530,
  274,
  546,
  872,
  937,
  494,
  393,
  446,
  113,
  255,
  275,
  887,
  446,
  454,
  684,
  189,
  548,
  350,
  259,
  517,
  263,
  1705,
  534,
  193,
  180,
  379,
  694,
  614,
  397,
  199,
  635,
 

In [21]:
if isinstance(env, FullObsSnakeEnv):
    env.interact = True
else:
    env = ParObsSnakeEnv(grid_size=2*grid_size)
    
state = env.reset()
done = False
with torch.no_grad():
    while not done:
        action, _ = agent.choose_action(state)
        state, reward, done, _ = env.step(action)
        env.render()
        print(f"Reward: {reward}")

Reward: 1
Reward: 1
Reward: 1
Reward: -1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: -1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 76
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 76
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 76
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 76
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 76
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: -1
Reward: 1
Reward: 1
Reward: 76
Reward: 1
Reward: 1
Reward: 76
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
