In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import sys
from tqdm import tqdm
import pickle
import os
# Add the project root directory to sys.path
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '../..')))

from src.ParObsSnakeEnv import ParObsSnakeEnv
from src.FullObsSnakeEnv import FullObsSnakeEnv
from src.utils import compute_metrics
import matplotlib.pyplot as plt
from torch.distributions import Categorical

In [2]:
# Define the Policy Network
class PolicyNetwork(nn.Module):
    def __init__(self, input_dim, output_dim, hidden_dim=64):
        super(PolicyNetwork, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return torch.softmax(x, dim=-1)

In [3]:
class REINFORCEAgent:
    def __init__(self, state_dim, action_dim, lr=0.001, gamma=0.99):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.gamma = gamma

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.policy = PolicyNetwork(state_dim, action_dim).to(self.device)
        self.optimizer = optim.Adam(self.policy.parameters(), lr=lr)

        # Metrics for plotting
        self.episode_rewards = []
        self.episode_losses = []

    def choose_action(self, state):
        state = torch.FloatTensor(state).unsqueeze(0).to(self.device)  # Add batch dimension
        action_probs = self.policy(state)
        dist = Categorical(action_probs)
        action = dist.sample()
        log_prob = dist.log_prob(action)
        return action.item(), log_prob

    def remember(self, log_prob, reward):
        if not hasattr(self, 'log_probs'):
            self.log_probs = []
        if not hasattr(self, 'rewards'):
            self.rewards = []
        self.log_probs.append(log_prob)
        self.rewards.append(reward)

    def update_policy(self):
        """Update the policy using stored rewards and log probabilities."""
        if not hasattr(self, 'log_probs') or len(self.log_probs) == 0:
            return

        returns = self.compute_returns(self.rewards)
        loss = -torch.sum(torch.stack(self.log_probs) * returns)  # Negative log-prob * return

        # Update the policy
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Log the loss
        self.episode_losses.append(loss.item())

        # Clear memory
        self.log_probs.clear()
        self.rewards.clear()

    def compute_returns(self, rewards):
        """Compute discounted returns for an episode."""
        returns = []
        G = 0
        for reward in reversed(rewards):
            G = reward + self.gamma * G
            returns.insert(0, G)
        returns = torch.FloatTensor(returns).to(self.device)
        # Normalize returns to improve training stability
        if len(returns) > 1 and returns.std() > 1e-5:
            returns = (returns - returns.mean()) / (returns.std() + 1e-5)
        return returns

    def save(self, filename):
        """Saves the entire agent to a file."""
        state = {
            'policy_state_dict': self.policy.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'hyperparameters': {
                'state_dim': self.state_dim,
                'action_dim': self.action_dim,
                'gamma': self.gamma,
            },
        }
        with open(filename, 'wb') as f:
            pickle.dump(state, f)
        print(f"Agent saved to {filename}")

    @classmethod
    def load(cls, filename, lr=0.001):
        """Loads the agent from a file."""
        with open(filename, 'rb') as f:
            state = pickle.load(f)

        # Recreate the agent
        agent = cls(
            state['hyperparameters']['state_dim'],
            state['hyperparameters']['action_dim'],
            lr=lr,
            gamma=state['hyperparameters']['gamma'],
        )
        # Restore the agent's state
        agent.policy.load_state_dict(state['policy_state_dict'])
        agent.optimizer.load_state_dict(state['optimizer_state_dict'])
        print(f"Agent loaded from {filename}")
        return agent

    def train(self, env, episodes=1000, save_plots=False, plots_path='reinforce_training_plots.png'):
        for episode in tqdm(range(episodes), desc="Training", unit="episode"):
            state = env.reset()
            total_reward = 0
            done = False

            while not done:
                action, log_prob = self.choose_action(state)
                next_state, reward, done, _ = env.step(action)
                self.remember(log_prob, reward)
                state = next_state
                total_reward += reward

            self.update_policy()
            self.episode_rewards.append(total_reward)

        if save_plots:
            self.save_plots(plots_path)

    def save_plots(self, plots_path):
        plots_dir = os.path.dirname(plots_path)
        os.makedirs(plots_dir, exist_ok=True)

        fig, axs = plt.subplots(1, 2, figsize=(15, 5))

        # Rewards per episode
        axs[0].plot(self.episode_rewards)
        axs[0].set_title("Episode Rewards")
        axs[0].set_xlabel("Episode")
        axs[0].set_ylabel("Total Reward")

        # Loss per episode
        axs[1].plot(self.episode_losses)
        axs[1].set_title("Loss Over Training")
        axs[1].set_xlabel("Episode")
        axs[1].set_ylabel("Loss")

        plt.tight_layout()
        plt.savefig(plots_path)
        plt.close()


In [4]:
grid_size = 10
# env = FullObsSnakeEnv(grid_size=grid_size, interact=False)
env = ParObsSnakeEnv(grid_size=grid_size, interact=False)
if isinstance(env, FullObsSnakeEnv):
    state_dim = env.observation_space.shape[0] * env.observation_space.shape[1] * env.observation_space.shape[2]
else:
    state_dim = env.observation_space.shape[0]

action_dim = env.action_space.n
agent = REINFORCEAgent(state_dim, action_dim)

In [5]:
num_episodes = 15000
agent.train(env, episodes=num_episodes)

Training:  42%|████▏     | 6353/15000 [02:39<03:37, 39.72episode/s] 


KeyboardInterrupt: 

In [None]:
environment = 'full 'if isinstance(env, FullObsSnakeEnv) else 'par'

agent_name = f'polNet_agent_{environment}_{num_episodes}_{grid_size}.pkl'
model_weights_dir = os.path.join('../..', 'models', 'polNet')
os.makedirs(model_weights_dir, exist_ok=True)
agent_path = os.path.join(model_weights_dir, agent_name)

agent.save(agent_path)
# agent = REINFORCEAgent.load(agent_path)

Agent saved to ../../models/polNet/polNet_agent_par_15000_10.pkl


In [None]:
if isinstance(env, ParObsSnakeEnv):
    env = ParObsSnakeEnv(grid_size=2*grid_size, interact=False)

model_metrics_dir = os.path.join('../..', 'artifacts', 'models_stats', 'polNet')
os.makedirs(model_metrics_dir, exist_ok=True)

train_metrics_name = f'polNet_train_metrics_{environment}_{num_episodes}_{grid_size}.png'
train_metrics_path = os.path.join(model_metrics_dir, train_metrics_name)
agent.save_plots(train_metrics_path)

num_simulations = 100
sim_metrics_name = f'polNet_sim_metrics_{environment}_{num_episodes}_{env.grid_size}_{num_simulations}.jsn'
sim_metrics_path = os.path.join(model_metrics_dir, sim_metrics_name)
compute_metrics(agent, env, sim_metrics_path, num_simulations=num_simulations)

  1%|          | 1/100 [00:00<00:11,  8.86it/s]

Snake length: 2, Episode reward: 12
Snake length: 1, Episode reward: -75


  6%|▌         | 6/100 [00:00<00:07, 12.61it/s]

Snake length: 1, Episode reward: -77
Snake length: 1, Episode reward: -73
Snake length: 1, Episode reward: -75
Snake length: 3, Episode reward: 74


  8%|▊         | 8/100 [00:00<00:08, 10.24it/s]

Snake length: 1, Episode reward: -80
Snake length: 3, Episode reward: 118


 10%|█         | 10/100 [00:01<00:13,  6.61it/s]

Snake length: 2, Episode reward: 20
Snake length: 1, Episode reward: -75


 13%|█▎        | 13/100 [00:01<00:11,  7.63it/s]

Snake length: 2, Episode reward: 0
Snake length: 1, Episode reward: -75
Snake length: 2, Episode reward: 3


 14%|█▍        | 14/100 [00:01<00:11,  7.23it/s]

Snake length: 4, Episode reward: 156
Snake length: 2, Episode reward: -5
Snake length: 1, Episode reward: -80
Snake length: 1, Episode reward: -76


 21%|██        | 21/100 [00:02<00:06, 12.10it/s]

Snake length: 2, Episode reward: 17
Snake length: 1, Episode reward: -71
Snake length: 1, Episode reward: -77
Snake length: 1, Episode reward: -78


 24%|██▍       | 24/100 [00:02<00:05, 12.91it/s]

Snake length: 1, Episode reward: -66
Snake length: 1, Episode reward: -67
Snake length: 1, Episode reward: -71
Snake length: 1, Episode reward: -80
Snake length: 1, Episode reward: -67


 27%|██▋       | 27/100 [00:02<00:04, 15.09it/s]

Snake length: 1, Episode reward: -73
Snake length: 1, Episode reward: -68
Snake length: 1, Episode reward: -79
Snake length: 1, Episode reward: -76
Snake length: 1, Episode reward: -78


 36%|███▌      | 36/100 [00:03<00:03, 18.62it/s]

Snake length: 2, Episode reward: 9
Snake length: 1, Episode reward: -81
Snake length: 3, Episode reward: 111
Snake length: 1, Episode reward: -80
Snake length: 2, Episode reward: 11


 39%|███▉      | 39/100 [00:03<00:05, 11.72it/s]

Snake length: 1, Episode reward: -68
Snake length: 1, Episode reward: -93
Snake length: 1, Episode reward: -78
Snake length: 1, Episode reward: -64


 44%|████▍     | 44/100 [00:03<00:03, 14.01it/s]

Snake length: 1, Episode reward: -78
Snake length: 1, Episode reward: -66
Snake length: 1, Episode reward: -77
Snake length: 3, Episode reward: 94
Snake length: 1, Episode reward: -88
Snake length: 1, Episode reward: -63


 49%|████▉     | 49/100 [00:04<00:05,  8.68it/s]

Snake length: 4, Episode reward: 189
Snake length: 2, Episode reward: 15
Snake length: 1, Episode reward: -74
Snake length: 1, Episode reward: -73


 51%|█████     | 51/100 [00:04<00:04,  9.94it/s]

Snake length: 1, Episode reward: -75
Snake length: 1, Episode reward: -87
Snake length: 1, Episode reward: -65
Snake length: 1, Episode reward: -70


 62%|██████▏   | 62/100 [00:05<00:01, 20.24it/s]

Snake length: 2, Episode reward: 1
Snake length: 1, Episode reward: -64
Snake length: 1, Episode reward: -52
Snake length: 1, Episode reward: -77
Snake length: 1, Episode reward: -69
Snake length: 1, Episode reward: -78
Snake length: 1, Episode reward: -62
Snake length: 1, Episode reward: -84


 65%|██████▌   | 65/100 [00:05<00:02, 16.53it/s]

Snake length: 3, Episode reward: 102
Snake length: 1, Episode reward: -58
Snake length: 1, Episode reward: -71
Snake length: 1, Episode reward: -62


 68%|██████▊   | 68/100 [00:06<00:02, 11.53it/s]

Snake length: 3, Episode reward: 110
Snake length: 2, Episode reward: 10
Snake length: 1, Episode reward: -62
Snake length: 2, Episode reward: 12
Snake length: 1, Episode reward: -77
Snake length: 1, Episode reward: -78
Snake length: 1, Episode reward: -73
Snake length: 1, Episode reward: -64


 78%|███████▊  | 78/100 [00:06<00:01, 14.92it/s]

Snake length: 2, Episode reward: -4
Snake length: 1, Episode reward: -72
Snake length: 1, Episode reward: -67
Snake length: 1, Episode reward: -80


 82%|████████▏ | 82/100 [00:06<00:01, 15.07it/s]

Snake length: 2, Episode reward: -10
Snake length: 1, Episode reward: -78
Snake length: 1, Episode reward: -72
Snake length: 1, Episode reward: -78
Snake length: 1, Episode reward: -73


 89%|████████▉ | 89/100 [00:07<00:00, 11.35it/s]

Snake length: 3, Episode reward: 102
Snake length: 1, Episode reward: -64
Snake length: 1, Episode reward: -75
Snake length: 1, Episode reward: -71
Snake length: 1, Episode reward: -77
Snake length: 2, Episode reward: 17
Snake length: 2, Episode reward: 8
Snake length: 1, Episode reward: -73


 92%|█████████▏| 92/100 [00:08<00:00,  9.73it/s]

Snake length: 2, Episode reward: 25


 94%|█████████▍| 94/100 [00:08<00:00,  8.72it/s]

Snake length: 3, Episode reward: 93
Snake length: 3, Episode reward: 89
Snake length: 1, Episode reward: -74
Snake length: 1, Episode reward: -78
Snake length: 1, Episode reward: -62
Snake length: 1, Episode reward: -80


100%|██████████| 100/100 [00:08<00:00, 11.58it/s]

Snake length: 1, Episode reward: -58
Snake length: 2, Episode reward: 20





{'snake_lengths': [2,
  1,
  1,
  1,
  1,
  3,
  1,
  3,
  2,
  1,
  2,
  1,
  2,
  4,
  2,
  1,
  1,
  2,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  2,
  1,
  3,
  1,
  2,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  3,
  1,
  1,
  4,
  2,
  1,
  1,
  1,
  1,
  1,
  1,
  2,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  3,
  1,
  1,
  1,
  3,
  2,
  1,
  2,
  1,
  1,
  1,
  1,
  2,
  1,
  1,
  1,
  2,
  1,
  1,
  1,
  1,
  3,
  1,
  1,
  1,
  1,
  2,
  2,
  1,
  2,
  3,
  3,
  1,
  1,
  1,
  1,
  1,
  2],
 'episode_rewards': [12,
  -75,
  -77,
  -73,
  -75,
  74,
  -80,
  118,
  20,
  -75,
  0,
  -75,
  3,
  156,
  -5,
  -80,
  -76,
  17,
  -71,
  -77,
  -78,
  -66,
  -67,
  -71,
  -80,
  -67,
  -73,
  -68,
  -79,
  -76,
  -78,
  9,
  -81,
  111,
  -80,
  11,
  -68,
  -93,
  -78,
  -64,
  -78,
  -66,
  -77,
  94,
  -88,
  -63,
  189,
  15,
  -74,
  -73,
  -75,
  -87,
  -65,
  -70,
  1,
  -64,
  -52,
  -77,
  -69,
  -78,
  -62,
  -84,
  102,
  -58,
  -71,
  -62,
  110,
  10,
  -6

In [7]:
if isinstance(env, FullObsSnakeEnv):
    env.interact = True
else:
    env = ParObsSnakeEnv(grid_size=2*grid_size)
    
state = env.reset()
done = False
while not done:
    action, _ = agent.choose_action(state)
    state, reward, done, _ = env.step(action)
    env.render()
    print(f"Reward: {reward}")

Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: -1
Reward: 1
Reward: 1
Reward: -1
Reward: -1
Reward: -1
Reward: -1
Reward: -1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: -1
Reward: 1
Reward: 1
Reward: 1
Reward: -1
Reward: -1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: -1
Reward: 1
Reward: 1
Reward: -1
Reward: 1
Reward: 76
Reward: -1
Reward: 1
Reward: 1
Reward: 76
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: -1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: -1
Reward: 1
Reward: 76
Reward: -75
