# Proximal Policy Optimization (PPO) Training Notebook

This notebook trains a **PPO** agent on the Snake game.

**Grid Size:** 20×20 (optimal for PPO)

**Expected Performance:**
- Episodes to convergence: 4000-6000
- Final average score: 30-50 apples
- Training time: 25-35 minutes (CPU) / 10-15 minutes (GPU)

---

## 1. Setup and Imports

In [None]:
# For Google Colab
# !git clone https://github.com/MarinCervinschi/rl-snake.git
# %cd rl-snake

In [None]:
import sys
from pathlib import Path
sys.path.append(str(Path.cwd().parent))

import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from typing import List, Tuple

import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical

from game.config import GameConfig
from game.engine import SnakeGameEngine
from game.entities import Action, State
from utils.metrics import TrainingMetrics

# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"üñ•Ô∏è  Using device: {device}")

## 2. PPO Components

In [None]:
class ActorCriticNetwork(nn.Module):
    """Policy (actor) and value (critic) heads on top of one shared CNN trunk.

    The trunk is three 3x3 convolutions with padding 1 (so the spatial size
    is preserved) followed by a 512-unit fully connected layer; every hidden
    layer uses ``Tanh``. Inputs must therefore be shaped
    ``(batch, 3, grid_size, grid_size)``.

    Args:
        grid_size: Side length of the square input grid.
        num_actions: Number of discrete actions produced by the actor head.
    """

    def __init__(self, grid_size: int, num_actions: int = 4):
        super().__init__()

        # Shared convolutional trunk: 3 -> 32 -> 64 -> 64 channels.
        channel_plan = (3, 32, 64, 64)
        conv_layers = []
        for in_ch, out_ch in zip(channel_plan[:-1], channel_plan[1:]):
            conv_layers.append(nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1))
            conv_layers.append(nn.Tanh())
        self.shared_conv = nn.Sequential(*conv_layers)

        # Padding keeps height/width, so the flattened size is 64 * H * W.
        self.flat_size = 64 * grid_size * grid_size
        self.shared_fc = nn.Sequential(nn.Linear(self.flat_size, 512), nn.Tanh())

        # Two structurally identical heads; only the output width differs.
        self.actor = self._build_head(num_actions)
        self.critic = self._build_head(1)

        self._initialize_weights()

    @staticmethod
    def _build_head(out_features: int) -> nn.Sequential:
        """Return a 512 -> 256 -> ``out_features`` MLP head."""
        return nn.Sequential(
            nn.Linear(512, 256),
            nn.Tanh(),
            nn.Linear(256, out_features),
        )

    def forward(self, x):
        """Return ``(action_probs, state_values)`` for a batch of grids.

        ``action_probs`` has shape ``(batch, num_actions)`` and each row sums
        to one; ``state_values`` has shape ``(batch, 1)``.
        """
        trunk = self.shared_fc(self.shared_conv(x).flatten(start_dim=1))
        probabilities = self.actor(trunk).softmax(dim=-1)
        return probabilities, self.critic(trunk)

    def get_action_and_value(self, x, action=None):
        """Evaluate the policy on ``x``.

        When ``action`` is ``None`` an action is sampled from the policy;
        otherwise the supplied actions are scored.

        Returns:
            Tuple ``(action, log_prob, entropy, value)``.
        """
        probabilities, value = self.forward(x)
        policy = Categorical(probabilities)
        chosen = policy.sample() if action is None else action
        return chosen, policy.log_prob(chosen), policy.entropy(), value

    def _initialize_weights(self):
        """Xavier-uniform weights and zero biases for all conv/linear layers."""
        parametric = (
            layer
            for layer in self.modules()
            if isinstance(layer, (nn.Conv2d, nn.Linear))
        )
        for layer in parametric:
            nn.init.xavier_uniform_(layer.weight)
            if layer.bias is None:
                continue
            nn.init.constant_(layer.bias, 0)


class RolloutBuffer:
    """Append-only storage for one rollout of transitions.

    Each transition contributes one entry to six parallel lists: state,
    action, log-probability, reward, value estimate and done flag.
    """

    def __init__(self):
        self.states, self.actions, self.log_probs = [], [], []
        self.rewards, self.values, self.dones = [], [], []

    def _columns(self):
        """Return the six storage lists in their canonical order."""
        return (
            self.states,
            self.actions,
            self.log_probs,
            self.rewards,
            self.values,
            self.dones,
        )

    def push(self, state, action, log_prob, reward, value, done):
        """Append a single transition to the buffer."""
        transition = (state, action, log_prob, reward, value, done)
        for column, item in zip(self._columns(), transition):
            column.append(item)

    def get(self):
        """Return the stored data as a 6-tuple of NumPy arrays.

        Actions are ``int64``; every other column is ``float32``.
        """
        dtypes = (np.float32, np.int64, np.float32, np.float32, np.float32, np.float32)
        return tuple(
            np.array(column, dtype=dtype)
            for column, dtype in zip(self._columns(), dtypes)
        )

    def clear(self):
        """Drop every stored transition."""
        for column in self._columns():
            column.clear()

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.states)

## 3. PPO Agent

In [None]:
class PPOAgent:
    """Proximal Policy Optimization agent.

    Transitions are accumulated through :meth:`train`. Once
    ``rollout_length`` transitions are stored, the agent runs ``epochs``
    passes of shuffled minibatch updates using the clipped surrogate
    objective, a squared-error value loss and an entropy bonus, then empties
    the buffer.

    Args:
        grid_size: Side length of the square grid; forwarded to the network.
        learning_rate: Step size for the Adam optimizer.
        discount_factor: Reward discount factor (gamma).
        gae_lambda: Lambda of generalized advantage estimation.
        clip_epsilon: Half-width of the probability-ratio clipping interval.
        value_coef: Weight of the value loss in the total loss.
        entropy_coef: Weight of the entropy term in the total loss.
        rollout_length: Number of buffered transitions that triggers an update.
        batch_size: Minibatch size used within each epoch.
        epochs: Number of passes over the rollout per update.
        max_grad_norm: Threshold for gradient-norm clipping.
    """

    def __init__(
        self,
        grid_size: int = 20,
        learning_rate: float = 0.0003,
        discount_factor: float = 0.99,
        gae_lambda: float = 0.95,
        clip_epsilon: float = 0.2,
        value_coef: float = 0.5,
        entropy_coef: float = 0.01,
        rollout_length: int = 2048,
        batch_size: int = 64,
        epochs: int = 4,
        max_grad_norm: float = 0.5,
    ):
        self.grid_size = grid_size
        self.discount_factor = discount_factor
        self.gae_lambda = gae_lambda
        self.clip_epsilon = clip_epsilon
        self.value_coef = value_coef
        self.entropy_coef = entropy_coef
        self.rollout_length = rollout_length
        self.batch_size = batch_size
        self.epochs = epochs
        self.max_grad_norm = max_grad_norm

        # `device` is the module-level torch.device selected at import time.
        self.device = device
        self.policy = ActorCriticNetwork(grid_size).to(self.device)
        self.optimizer = optim.Adam(self.policy.parameters(), lr=learning_rate)

        self.buffer = RolloutBuffer()
        self.steps = 0  # transitions seen so far
        self.updates = 0  # completed rollout updates
        # One entry per minibatch optimisation step.
        self.policy_losses = []
        self.value_losses = []
        self.entropy_losses = []

    def get_action(self, state: "State") -> "Action":
        """Sample an action for ``state`` from the current policy."""
        state_tensor = self._state_to_tensor(state)
        with torch.no_grad():
            action_probs, _ = self.policy(state_tensor)
        dist = Categorical(action_probs)
        return Action(dist.sample().item())

    def train(self, state: "State", action: "Action", reward: float, next_state: "State", done: bool):
        """Record one transition and update the policy when the rollout is full.

        Args:
            state: State in which ``action`` was taken.
            action: Action that was executed.
            reward: Reward received for the transition.
            next_state: State reached after the action. Used to bootstrap the
                value target when the rollout ends in the middle of an episode.
            done: Whether the episode terminated with this transition.
        """
        state_array = state.to_tensor()
        state_tensor = self._state_to_tensor(state)
        action_tensor = torch.tensor([action.value], device=self.device)

        with torch.no_grad():
            _, log_prob, _, value = self.policy.get_action_and_value(state_tensor, action_tensor)

        self.buffer.push(state_array, action.value, log_prob.item(), reward, value.item(), done)
        self.steps += 1

        if len(self.buffer) >= self.rollout_length:
            # A rollout rarely ends exactly on an episode boundary. If the
            # episode is still running, the return of the final transition
            # must include the critic's estimate of what follows; treating it
            # as zero would bias the last advantages of every rollout.
            last_value = 0.0 if done else self._estimate_value(next_state)
            self._update(last_value)
            self.buffer.clear()

    def _estimate_value(self, state: "State") -> float:
        """Return the critic's value estimate for ``state`` without gradients."""
        with torch.no_grad():
            _, value = self.policy(self._state_to_tensor(state))
        return value.item()

    def _update(self, last_value: float = 0.0):
        """Run the PPO optimisation over the buffered rollout.

        Args:
            last_value: Value estimate of the state following the final
                buffered transition; ignored by the advantage computation
                when that transition ended its episode.
        """
        states, actions, old_log_probs, rewards, values, dones = self.buffer.get()

        # Compute advantages using GAE; returns are the critic's regression
        # targets and must be taken before the advantages are normalised.
        advantages = self._compute_gae(rewards, values, dones, last_value)
        returns = advantages + values
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        # Convert to tensors
        states_t = torch.as_tensor(states, dtype=torch.float32, device=self.device)
        actions_t = torch.as_tensor(actions, dtype=torch.long, device=self.device)
        old_log_probs_t = torch.as_tensor(old_log_probs, dtype=torch.float32, device=self.device)
        advantages_t = torch.as_tensor(advantages, dtype=torch.float32, device=self.device)
        returns_t = torch.as_tensor(returns, dtype=torch.float32, device=self.device)

        num_samples = len(states)

        # Train for multiple epochs, reshuffling the rollout each time.
        for _ in range(self.epochs):
            indices = np.arange(num_samples)
            np.random.shuffle(indices)

            for start in range(0, num_samples, self.batch_size):
                batch_idx = indices[start:start + self.batch_size]

                batch_states = states_t[batch_idx]
                batch_actions = actions_t[batch_idx]
                batch_old_log_probs = old_log_probs_t[batch_idx]
                batch_advantages = advantages_t[batch_idx]
                batch_returns = returns_t[batch_idx]

                _, new_log_probs, entropy, values_pred = self.policy.get_action_and_value(
                    batch_states, batch_actions
                )

                # PPO clipped objective
                ratio = torch.exp(new_log_probs - batch_old_log_probs)
                surr1 = ratio * batch_advantages
                surr2 = torch.clamp(ratio, 1 - self.clip_epsilon, 1 + self.clip_epsilon) * batch_advantages
                policy_loss = -torch.min(surr1, surr2).mean()

                # Value loss. Squeeze only the trailing unit dimension so a
                # minibatch of size one keeps its batch axis.
                values_pred = values_pred.squeeze(-1)
                value_loss = ((values_pred - batch_returns) ** 2).mean()

                # Entropy loss (negated so that minimising it raises entropy)
                entropy_loss = -entropy.mean()

                # Total loss
                loss = policy_loss + self.value_coef * value_loss + self.entropy_coef * entropy_loss

                # Optimize
                self.optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
                self.optimizer.step()

                self.policy_losses.append(policy_loss.item())
                self.value_losses.append(value_loss.item())
                self.entropy_losses.append(entropy_loss.item())

        self.updates += 1

    def _compute_gae(self, rewards, values, dones, last_value=0.0):
        """Compute generalized advantage estimates for one rollout.

        Args:
            rewards: 1-D array of per-step rewards.
            values: 1-D array of critic estimates for the visited states.
            dones: 1-D array of 0/1 flags; 1 marks a terminal transition.
            last_value: Value estimate of the state after the final step,
                used as the bootstrap target when the rollout stops before
                the episode does. Defaults to ``0.0`` (no bootstrap).

        Returns:
            Array shaped like ``rewards`` holding the advantage of each step.
        """
        advantages = np.zeros_like(rewards)
        last_gae = 0.0
        next_value = last_value

        for t in reversed(range(len(rewards))):
            # A terminal step cuts both the bootstrap and the GAE recursion.
            not_done = 1.0 - dones[t]
            delta = rewards[t] + self.discount_factor * next_value * not_done - values[t]
            last_gae = delta + self.discount_factor * self.gae_lambda * not_done * last_gae
            advantages[t] = last_gae
            next_value = values[t]

        return advantages

    def _state_to_tensor(self, state: "State"):
        """Convert ``state`` to a float tensor with a leading batch axis on the agent's device."""
        state_array = state.to_tensor()
        return torch.FloatTensor(state_array).unsqueeze(0).to(self.device)

    def save(self, filepath: str):
        """Write network/optimizer state and counters to ``filepath``.

        Missing parent directories are created.
        """
        Path(filepath).parent.mkdir(parents=True, exist_ok=True)
        torch.save(
            {
                "policy_state": self.policy.state_dict(),
                "optimizer_state": self.optimizer.state_dict(),
                "steps": self.steps,
                "updates": self.updates,
            },
            filepath,
        )
        print(f"üíæ Model saved to {filepath}")

    def load(self, filepath: str):
        """Restore a checkpoint written by :meth:`save`.

        Prints a warning and leaves the agent untouched when ``filepath``
        does not exist.
        """
        if not Path(filepath).exists():
            print("‚ö†Ô∏è  No saved model found")
            return
        checkpoint = torch.load(filepath, map_location=self.device)
        self.policy.load_state_dict(checkpoint["policy_state"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state"])
        self.steps = checkpoint["steps"]
        self.updates = checkpoint["updates"]
        print("‚úÖ Model loaded")

## 4. Configuration

In [None]:
# --- Run setup -------------------------------------------------------------
GRID_SIZE = 20      # side length of the square board
EPISODES = 5_000    # number of training episodes

# --- PPO hyperparameters ---------------------------------------------------
LEARNING_RATE = 3e-4     # Adam step size
DISCOUNT_FACTOR = 0.99   # gamma
GAE_LAMBDA = 0.95        # lambda for generalized advantage estimation
CLIP_EPSILON = 0.2       # clipping range of the probability ratio
VALUE_COEF = 0.5         # weight of the value loss
ENTROPY_COEF = 0.01      # weight of the entropy term
ROLLOUT_LENGTH = 2048    # transitions collected per update
BATCH_SIZE = 64          # minibatch size
EPOCHS = 4               # passes over each rollout

# --- Output locations ------------------------------------------------------
MODEL_PATH = "models/ppo.pkl"
RESULTS_DIR = "results/ppo_{0}x{0}".format(GRID_SIZE)

for _banner in (
    f"üéÆ Training PPO on {GRID_SIZE}√ó{GRID_SIZE} grid",
    f"üìà Episodes: {EPISODES:,}",
    f"üñ•Ô∏è  Device: {device}",
):
    print(_banner)

## 5. Initialize

In [None]:
# Build the environment from the configured grid size.
config = GameConfig(grid_size=GRID_SIZE)
game = SnakeGameEngine(config)

# Collect the agent hyperparameters in one place and hand them over together.
_agent_settings = {
    "grid_size": GRID_SIZE,
    "learning_rate": LEARNING_RATE,
    "discount_factor": DISCOUNT_FACTOR,
    "gae_lambda": GAE_LAMBDA,
    "clip_epsilon": CLIP_EPSILON,
    "value_coef": VALUE_COEF,
    "entropy_coef": ENTROPY_COEF,
    "rollout_length": ROLLOUT_LENGTH,
    "batch_size": BATCH_SIZE,
    "epochs": EPOCHS,
}
agent = PPOAgent(**_agent_settings)

# Metrics tracker configured with the results directory.
metrics = TrainingMetrics(save_dir=RESULTS_DIR)
print("‚úÖ Initialized")

## 6. Training Loop

In [None]:
# Best score seen across all episodes so far.
record_score = 0
pbar = tqdm(range(1, EPISODES + 1), desc="Training")

try:
    for episode in pbar:
        state = game.reset()
        done = False
        episode_reward = 0
        steps = 0

        # Play one episode, feeding every transition to the agent.
        while not done:
            action = agent.get_action(state)
            reward, done, score = game.step(action)
            next_state = game.get_state()

            agent.train(state, action, reward, next_state, done)

            state = next_state
            episode_reward += reward
            steps += 1

        metrics.record_episode(episode, score, steps, episode_reward)
        record_score = max(record_score, score)

        pbar.set_postfix(
            {
                "Avg": f"{metrics.get_recent_average_score():.2f}",
                "Best": record_score,
                "Updates": agent.updates,
            }
        )
except KeyboardInterrupt:
    # Stopping early is supported: the cell ends cleanly so later cells can
    # still save and inspect the partially trained agent.
    print("\n\n‚ö†Ô∏è  Training interrupted by user")
else:
    # Only claim completion when every episode actually ran.
    print("\n‚úÖ Training complete!")
finally:
    # Release the progress bar even when the loop was interrupted.
    pbar.close()

## 7. Save and Visualize

In [None]:
# Write the agent checkpoint to MODEL_PATH.
agent.save(MODEL_PATH)
# Print the summary produced by the metrics tracker.
metrics.print_summary()
# Plot the recorded metrics on screen; save=False presumably skips writing
# the figure to disk — confirm against TrainingMetrics.plot.
metrics.plot(show=True, save=False)
# NOTE(review): plot() is already called with show=True, so this extra
# show() may be redundant — confirm.
plt.show()

## 8. Loss Analysis

In [None]:
# Plot the three recorded loss curves side by side (skipped when no update ran).
if agent.policy_losses:
    # (title, series, extra plot kwargs) for each panel, left to right.
    panels = (
        ('Policy Loss', agent.policy_losses, {}),
        ('Value Loss', agent.value_losses, {'color': 'orange'}),
        ('Entropy Loss', agent.entropy_losses, {'color': 'green'}),
    )

    fig, axes = plt.subplots(1, 3, figsize=(15, 4))

    for ax, (title, series, line_kwargs) in zip(axes, panels):
        ax.plot(series, alpha=0.3, **line_kwargs)
        ax.set_title(title)
        ax.set_xlabel('Update Step')
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

## 9. Test Agent

In [None]:
# Play a handful of evaluation episodes with the trained agent.
test_episodes = 10
test_scores = []

for ep in range(test_episodes):
    state, done = game.reset(), False

    # Step until the engine reports the episode is over.
    while not done:
        action = agent.get_action(state)
        reward, done, score = game.step(action)
        state = game.get_state()

    test_scores.append(score)
    print(f"Episode {ep+1}: {score}")

mean_score = np.mean(test_scores)
best_score = max(test_scores)
print(f"\nAverage: {mean_score:.2f}")
print(f"Best: {best_score}")

## 10. Export for Google Colab (Optional)

If running on Google Colab, you can download the trained model:

In [None]:
# Uncomment to download model in Colab
# from google.colab import files
# files.download(MODEL_PATH)