# PPO Training Notebook

This notebook trains a PPO (Proximal Policy Optimization) agent to play Snake, collecting experience from many Snake environments running in parallel.

**Algorithm**: PPO with:
- Actor-Critic architecture (separate actor and critic networks, each with its own optimizer)
- Clipped surrogate objective
- Generalized Advantage Estimation (GAE)
- Multiple minibatch epochs over each collected rollout
- Entropy bonus for exploration

In [None]:
# Cell 1: Imports
import sys
from pathlib import Path

# Make the repository root importable so the `core.*` imports below resolve.
# Assumes the notebook is executed from a directory two levels below the
# repository root — TODO confirm if the notebook is moved.
project_root = Path.cwd().parent.parent
sys.path.insert(0, str(project_root))

import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime

from core.environment_vectorized import VectorizedSnakeEnv
from core.networks import PPO_Actor_MLP, PPO_Critic_MLP, PPO_Actor_CNN, PPO_Critic_CNN
from core.utils import MetricsTracker, set_seed, get_device

# Environment sanity check: where imports resolve from and which backend is available.
print(f"Project root: {project_root}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

In [None]:
# Cell 2: Configuration (papermill parameters)
# ============== CONFIGURATION ==============
# Any of the names below may be overridden by papermill when the notebook
# is executed as a parameterised job.

# --- Environment ---
GRID_SIZE = 10                    # board is GRID_SIZE x GRID_SIZE cells
NUM_ENVS = 256                    # parallel games in the vectorized env
MAX_STEPS = 1000                  # forwarded to VectorizedSnakeEnv(max_steps=...)
ACTION_SPACE_TYPE = 'relative'    # 'absolute' (4 actions) or 'relative' (3 actions)
STATE_REPRESENTATION = 'feature'  # 'feature' (MLP nets) or 'grid' (CNN nets)
USE_FLOOD_FILL = False            # feature mode only: 14 inputs instead of 11

# --- Training ---
NUM_EPISODES = 500                # Short for testing
ROLLOUT_STEPS = 2048              # transitions per PPO update, summed over all envs
BATCH_SIZE = 64                   # minibatch size inside each PPO epoch
EPOCHS_PER_ROLLOUT = 4            # PPO passes over every collected rollout

# --- Learning rates ---
ACTOR_LR = 3e-4
CRITIC_LR = 1e-3

# --- PPO hyperparameters ---
GAMMA = 0.99                      # discount factor
GAE_LAMBDA = 0.95                 # GAE bias/variance trade-off
CLIP_EPSILON = 0.2                # probability-ratio clipping range
VALUE_LOSS_COEF = 0.5             # weight of the critic loss
ENTROPY_COEF = 0.01               # weight of the entropy bonus
MAX_GRAD_NORM = 0.5               # gradient-norm clipping threshold

# --- Network ---
HIDDEN_DIMS = (128, 128)          # hidden layer widths of the MLP actor/critic

# --- Output ---
SAVE_DIR = '../../results/weights/ppo'
SEED = 42
LOG_INTERVAL = 50                 # print progress every LOG_INTERVAL episodes

In [None]:
# Cell 3: Environment & Model Setup
# Seed before anything that may consume randomness (env construction, weight init).
set_seed(SEED)
device = get_device()
print(f"Using device: {device}")

save_dir = Path(SAVE_DIR)
save_dir.mkdir(parents=True, exist_ok=True)

env = VectorizedSnakeEnv(
    num_envs=NUM_ENVS,
    grid_size=GRID_SIZE,
    action_space_type=ACTION_SPACE_TYPE,
    state_representation=STATE_REPRESENTATION,
    max_steps=MAX_STEPS,
    use_flood_fill=USE_FLOOD_FILL,
    device=device,
)

# Size of the policy head: 3 relative moves, otherwise 4 absolute directions.
output_dim = 4 if ACTION_SPACE_TYPE != 'relative' else 3

if STATE_REPRESENTATION == 'feature':
    # Hand-crafted feature vector: 11 values, 14 with flood-fill features.
    input_dim = 14 if USE_FLOOD_FILL else 11
    actor = PPO_Actor_MLP(input_dim, output_dim, HIDDEN_DIMS).to(device)
    critic = PPO_Critic_MLP(input_dim, HIDDEN_DIMS).to(device)
else:
    # Grid observations feed the CNN variants. The literal 3 is presumably the
    # number of input channels — TODO confirm against core.networks.
    input_dim = GRID_SIZE
    actor = PPO_Actor_CNN(GRID_SIZE, 3, output_dim).to(device)
    critic = PPO_Critic_CNN(GRID_SIZE, 3).to(device)

# Independent optimizers so actor and critic can use different learning rates.
actor_optimizer = torch.optim.Adam(actor.parameters(), lr=ACTOR_LR)
critic_optimizer = torch.optim.Adam(critic.parameters(), lr=CRITIC_LR)

metrics = MetricsTracker(window_size=100)

n_actor_params = sum(param.numel() for param in actor.parameters())
n_critic_params = sum(param.numel() for param in critic.parameters())
print(f"Environment: {GRID_SIZE}x{GRID_SIZE} grid, {NUM_ENVS} parallel envs")
print(f"Actor parameters: {n_actor_params}")
print(f"Critic parameters: {n_critic_params}")

In [None]:
# Cell 4: PPO Helper Functions

class PPOBuffer:
    """Rollout storage for on-policy PPO training.

    Every ``add`` call stores one batch of transitions (moved to the CPU), and
    ``size`` counts the stored transitions along dim 0. ``get`` concatenates
    all stored batches on dim 0 and moves them to ``device``.

    ``capacity`` is recorded for reference only; it is not enforced, so the
    caller decides how many batches to add before ``get``/``clear``.
    """

    # Per-transition fields, in the order add() accepts and get() returns them.
    _FIELDS = ('states', 'actions', 'rewards', 'dones', 'log_probs', 'values')

    def __init__(self, capacity, device):
        self.capacity = capacity
        self.device = device
        self.clear()

    def clear(self):
        """Drop every stored batch and reset the transition count."""
        for field in self._FIELDS:
            setattr(self, field, [])
        self.size = 0

    def add(self, state, action, reward, done, log_prob, value):
        """Append one batch of transitions; ``size`` grows by ``state.shape[0]``."""
        batch = (state, action, reward, done, log_prob, value)
        for field, tensor in zip(self._FIELDS, batch):
            getattr(self, field).append(tensor.cpu())
        self.size += state.shape[0]

    def get(self):
        """Return (states, actions, rewards, dones, log_probs, values).

        Each entry is the dim-0 concatenation of the stored batches, placed
        on ``self.device``.
        """
        return tuple(
            torch.cat(getattr(self, field), dim=0).to(self.device)
            for field in self._FIELDS
        )

def select_actions(states):
    """Sample actions from the current policy without tracking gradients.

    Uses the notebook-level ``actor`` and ``critic`` networks.

    Args:
        states: batch of observations accepted by both networks.

    Returns:
        (actions, log_probs, values): sampled action indices, their
        log-probabilities under the current policy, and the critic output
        exactly as the critic returns it (callers reshape it as needed).
    """
    with torch.no_grad():
        logits = actor(states)
        values = critic(states)
        # Building the distribution from logits (instead of softmax probs)
        # skips a redundant softmax and lets torch compute log-probs via
        # log-softmax rather than log(clamped probs), which keeps the
        # log-probs of unlikely actions accurate.
        dist = torch.distributions.Categorical(logits=logits)
        actions = dist.sample()
        log_probs = dist.log_prob(actions)
    return actions, log_probs, values

def compute_gae(rewards, values, dones, next_value, gamma=None, gae_lambda=None):
    """Compute Generalized Advantage Estimation for a rollout segment.

    The recursion runs backwards over dim 0 (time) using only element-wise
    tensor ops, so any trailing dims are processed in parallel. For example,
    a (steps, num_envs) rollout can be handled in a single call.

    Args:
        rewards: rewards r_t, indexed by time step along dim 0.
        values: critic estimates V(s_t), same shape as ``rewards``.
        dones: done flags (bool or numeric). A truthy ``dones[t]`` stops
            bootstrapping from step t into step t + 1.
        next_value: V(s_T) for the state after the last step, used to
            bootstrap the final step when it is not done.
        gamma: discount factor; defaults to the notebook's ``GAMMA``.
        gae_lambda: GAE lambda; defaults to the notebook's ``GAE_LAMBDA``.

    Returns:
        (advantages, returns), both shaped like ``rewards``, where
        ``returns = advantages + values``.
    """
    if gamma is None:
        gamma = GAMMA
    if gae_lambda is None:
        gae_lambda = GAE_LAMBDA

    not_done = 1.0 - dones.float()
    num_steps = len(rewards)

    advantages = []
    gae = 0
    for t in reversed(range(num_steps)):
        # Successor value: the bootstrap value after the final step, otherwise
        # V(s_{t+1}). It is masked by not_done[t] when the episode ended at t.
        next_value_t = next_value if t == num_steps - 1 else values[t + 1]
        delta = rewards[t] + gamma * next_value_t * not_done[t] - values[t]
        gae = delta + gamma * gae_lambda * not_done[t] * gae
        # Append then reverse once (O(n)); list.insert(0, ...) made this O(n^2).
        advantages.append(gae)
    advantages.reverse()

    advantages = torch.stack(advantages)
    returns = advantages + values
    return advantages, returns

def ppo_update(states, actions, old_log_probs, advantages, returns):
    """Run the PPO optimisation phase on one collected rollout.

    For EPOCHS_PER_ROLLOUT epochs, the rollout is shuffled and split into
    minibatches of BATCH_SIZE samples. Each minibatch gets one actor step and
    one critic step, using the notebook-level networks and optimizers:

        actor objective  = clipped surrogate loss - ENTROPY_COEF * mean entropy
        critic objective = VALUE_LOSS_COEF * MSE(V(s), returns)

    Args:
        states, actions, old_log_probs, advantages, returns: flattened rollout
            tensors that share dim 0. ``old_log_probs`` are the log-probs
            recorded when the actions were sampled.

    Returns:
        (policy_loss, critic_mse, entropy), each averaged over all minibatch
        updates. The reported policy loss is the clipped-surrogate term alone
        and the critic figure is the unscaled MSE, so both keep their meaning
        regardless of ENTROPY_COEF / VALUE_LOSS_COEF.
    """
    # Normalise advantages once per rollout. std() of a single element is
    # NaN, so only rescale when there are at least two samples.
    if advantages.numel() > 1:
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

    num_samples = states.size(0)
    total_actor_loss = 0.0
    total_critic_loss = 0.0
    total_entropy = 0.0
    n_updates = 0

    for _ in range(EPOCHS_PER_ROLLOUT):
        indices = torch.randperm(num_samples)

        for start in range(0, num_samples, BATCH_SIZE):
            batch_idx = indices[start:start + BATCH_SIZE]

            batch_states = states[batch_idx]
            batch_actions = actions[batch_idx]
            batch_old_log_probs = old_log_probs[batch_idx]
            batch_advantages = advantages[batch_idx]
            batch_returns = returns[batch_idx]

            # Re-evaluate the stored actions under the current policy.
            logits = actor(batch_states)
            # reshape(-1) flattens a (B, 1) critic output to (B,) without
            # collapsing a size-1 final minibatch to a 0-d scalar.
            values = critic(batch_states).reshape(-1)
            dist = torch.distributions.Categorical(logits=logits)
            log_probs = dist.log_prob(batch_actions)
            entropy = dist.entropy().mean()

            # PPO clipped objective
            ratio = torch.exp(log_probs - batch_old_log_probs)
            surr1 = ratio * batch_advantages
            surr2 = torch.clamp(ratio, 1 - CLIP_EPSILON, 1 + CLIP_EPSILON) * batch_advantages
            policy_loss = -torch.min(surr1, surr2).mean()
            # Entropy bonus: subtracting it rewards keeping the policy stochastic.
            actor_loss = policy_loss - ENTROPY_COEF * entropy

            critic_loss = F.mse_loss(values, batch_returns)

            # actor and critic are built as separate modules, so each loss has
            # its own graph and neither backward pass needs retain_graph.
            actor_optimizer.zero_grad()
            actor_loss.backward()
            torch.nn.utils.clip_grad_norm_(actor.parameters(), MAX_GRAD_NORM)
            actor_optimizer.step()

            critic_optimizer.zero_grad()
            (VALUE_LOSS_COEF * critic_loss).backward()
            torch.nn.utils.clip_grad_norm_(critic.parameters(), MAX_GRAD_NORM)
            critic_optimizer.step()

            total_actor_loss += policy_loss.item()
            total_critic_loss += critic_loss.item()
            total_entropy += entropy.item()
            n_updates += 1

    return total_actor_loss / n_updates, total_critic_loss / n_updates, total_entropy / n_updates

# Signals that this cell ran, so the helpers above exist in the notebook namespace.
print("PPO helper functions defined.")

In [None]:
# Cell 5: Training Loop
import time

print("Starting PPO Training...")
print(f"Device: {device}")
print(f"Episodes: {NUM_EPISODES}")
print(f"Rollout steps: {ROLLOUT_STEPS}")
print()

buffer = PPOBuffer(ROLLOUT_STEPS, device)
states = env.reset(seed=SEED)
start_time = time.time()

# Return and length of the episode currently running in each env.
episode_rewards = torch.zeros(NUM_ENVS, device=device)
episode_lengths = torch.zeros(NUM_ENVS, device=device)

episode = 0       # episodes completed so far, summed over all envs
total_steps = 0   # transitions collected (NUM_ENVS per vectorized step)

# Tracking for plots
all_rewards = []
all_scores = []
all_actor_losses = []
all_critic_losses = []
all_entropies = []

# ROLLOUT_STEPS is a transition budget summed over all envs; every vectorized
# env.step() contributes NUM_ENVS transitions.
steps_per_rollout = max(1, ROLLOUT_STEPS // NUM_ENVS)

while episode < NUM_EPISODES:
    # Collect rollout
    buffer.clear()

    for _ in range(steps_per_rollout):
        actions, log_probs, values = select_actions(states)
        next_states, rewards, dones, info = env.step(actions)

        # reshape(-1) rather than squeeze(): keeps a 1-D (NUM_ENVS,) tensor
        # even when NUM_ENVS == 1, so the buffer can concatenate it.
        buffer.add(states, actions, rewards, dones, log_probs, values.reshape(-1))

        episode_rewards += rewards
        episode_lengths += 1
        total_steps += NUM_ENVS

        if dones.any():
            done_indices = torch.where(dones)[0]
            for idx in done_indices:
                reward = episode_rewards[idx].item()
                length = episode_lengths[idx].item()
                score = info['scores'][idx].item()

                metrics.add_episode(reward, length, score)
                all_rewards.append(reward)
                all_scores.append(score)

                episode += 1
                episode_rewards[idx] = 0
                episode_lengths[idx] = 0

                if episode % LOG_INTERVAL == 0:
                    stats = metrics.get_recent_stats()
                    elapsed = time.time() - start_time
                    fps = total_steps / elapsed if elapsed > 0 else 0
                    print(f"Episode {episode}/{NUM_EPISODES} | "
                          f"Score: {stats['avg_score']:.2f} | "
                          f"Reward: {stats['avg_reward']:.2f} | "
                          f"FPS: {fps:.0f}")

                if episode >= NUM_EPISODES:
                    break

        states = next_states
        if episode >= NUM_EPISODES:
            break

    if episode >= NUM_EPISODES:
        break

    # Bootstrap value for the state after the last collected step
    # (reshape(-1) keeps it 1-D even when NUM_ENVS == 1).
    with torch.no_grad():
        next_value = critic(states).reshape(-1)

    buf_states, buf_actions, buf_rewards, buf_dones, buf_log_probs, buf_values = buffer.get()

    # The buffer is step-major (one add() per vectorized step), so row t of
    # these (steps, NUM_ENVS) views holds every env's transition at step t.
    steps_per_env = len(buffer.states)
    buf_rewards = buf_rewards.view(steps_per_env, NUM_ENVS)
    buf_values = buf_values.view(steps_per_env, NUM_ENVS)
    buf_dones = buf_dones.view(steps_per_env, NUM_ENVS)

    # compute_gae recurses over dim 0 (time) with element-wise tensor ops, so
    # a single call covers every env instead of NUM_ENVS separate Python calls.
    advantages, returns = compute_gae(buf_rewards, buf_values, buf_dones, next_value)
    # Flatten back to the buffer's step-major order to align with buf_states etc.
    advantages = advantages.reshape(-1)
    returns = returns.reshape(-1)

    # PPO update
    actor_loss, critic_loss, entropy = ppo_update(
        buf_states, buf_actions, buf_log_probs, advantages, returns
    )
    all_actor_losses.append(actor_loss)
    all_critic_losses.append(critic_loss)
    all_entropies.append(entropy)

training_time = time.time() - start_time
print("\nTraining complete!")
print(f"Total time: {training_time:.1f}s")
print(f"Total steps: {total_steps:,}")

In [None]:
# Cell 6: Save Weights
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
filename = f"ppo_{GRID_SIZE}x{GRID_SIZE}_{NUM_EPISODES}ep_{timestamp}.pt"
filepath = save_dir / filename

# The config must hold everything needed to rebuild the networks. Their
# input/output sizes depend on STATE_REPRESENTATION, USE_FLOOD_FILL and
# ACTION_SPACE_TYPE (see the setup cell). The training hyperparameters are
# stored too, so a checkpoint is self-describing.
torch.save({
    'actor': actor.state_dict(),
    'critic': critic.state_dict(),
    'actor_optimizer': actor_optimizer.state_dict(),
    'critic_optimizer': critic_optimizer.state_dict(),
    'episode': episode,
    'total_steps': total_steps,
    'config': {
        'grid_size': GRID_SIZE,
        'num_envs': NUM_ENVS,
        'hidden_dims': HIDDEN_DIMS,
        'actor_lr': ACTOR_LR,
        'critic_lr': CRITIC_LR,
        'gamma': GAMMA,
        'gae_lambda': GAE_LAMBDA,
        'clip_epsilon': CLIP_EPSILON,
        'state_representation': STATE_REPRESENTATION,
        # Needed to reconstruct network shapes / the environment.
        'action_space_type': ACTION_SPACE_TYPE,
        'use_flood_fill': USE_FLOOD_FILL,
        'max_steps': MAX_STEPS,
        # Remaining training hyperparameters, for reproducibility.
        'rollout_steps': ROLLOUT_STEPS,
        'batch_size': BATCH_SIZE,
        'epochs_per_rollout': EPOCHS_PER_ROLLOUT,
        'value_loss_coef': VALUE_LOSS_COEF,
        'entropy_coef': ENTROPY_COEF,
        'max_grad_norm': MAX_GRAD_NORM,
        'seed': SEED,
    }
}, filepath)

print(f"Model saved to: {filepath}")

In [None]:
# Cell 7: Visualization
# 2x2 panel: episode rewards, episode scores, PPO losses, policy entropy.
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

def smooth(data, window=50):
    """Return the moving average of *data* over *window* consecutive points.

    Uses ``np.convolve`` in 'valid' mode, so the result has
    ``len(data) - window + 1`` points. Inputs shorter than *window* are
    returned unchanged (the same object).
    """
    if len(data) >= window:
        kernel = np.ones(window) / window
        return np.convolve(data, kernel, mode='valid')
    return data

# Plots 1-2: per-episode reward and score, raw plus a 50-episode moving average.
ax1, ax2 = axes[0, 0], axes[0, 1]
for ax, series, ylabel, title in (
    (ax1, all_rewards, 'Reward', 'Episode Rewards'),
    (ax2, all_scores, 'Score (Food Eaten)', 'Episode Scores'),
):
    ax.plot(series, alpha=0.3, label='Raw')
    if len(series) > 50:
        # smooth() defaults to a 50-point 'valid' convolution, which yields
        # len(series) - 49 points; starting x at 49 places each average at
        # the last episode of its window.
        ax.plot(range(49, len(series)), smooth(series), label='Smoothed (50)')
    ax.set_xlabel('Episode')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)

# Plot 3: per-update mean actor and critic losses returned by ppo_update.
ax3 = axes[1, 0]
for series, label in ((all_actor_losses, 'Actor Loss'), (all_critic_losses, 'Critic Loss')):
    ax3.plot(series, label=label, alpha=0.7)
ax3.set_xlabel('PPO Update')
ax3.set_ylabel('Loss')
ax3.set_title('PPO Losses')
ax3.legend()
ax3.grid(True, alpha=0.3)

# Plot 4: per-update mean policy entropy.
ax4 = axes[1, 1]
ax4.plot(all_entropies)
ax4.set_xlabel('PPO Update')
ax4.set_ylabel('Entropy')
ax4.set_title('Policy Entropy')
ax4.grid(True, alpha=0.3)

plt.suptitle(f'PPO Training Results - {NUM_EPISODES} Episodes', fontsize=14)
plt.tight_layout()

# Reuse the checkpoint timestamp so the figure and weights can be paired.
fig_dir = project_root / 'results' / 'figures'
fig_dir.mkdir(parents=True, exist_ok=True)
fig_path = fig_dir / f'ppo_training_{timestamp}.png'
plt.savefig(fig_path, dpi=150, bbox_inches='tight')
print(f"Figure saved to: {fig_path}")

plt.show()

In [None]:
# Cell 8: Results Summary
# Rolling stats from MetricsTracker (window_size=100), followed by statistics
# over every recorded episode.
stats = metrics.get_recent_stats()
separator = "=" * 50

summary_lines = [
    separator,
    "TRAINING SUMMARY",
    separator,
    "Algorithm: PPO (Proximal Policy Optimization)",
    f"Episodes: {episode}",
    f"Total Steps: {total_steps:,}",
    f"Training Time: {training_time:.1f}s",
    "",
    "Hyperparameters:",
    f"  Clip Epsilon: {CLIP_EPSILON}",
    f"  GAE Lambda: {GAE_LAMBDA}",
    f"  Entropy Coef: {ENTROPY_COEF}",
    "",
    "Final Performance (last 100 episodes):",
    f"  Average Score: {stats['avg_score']:.2f}",
    f"  Average Reward: {stats['avg_reward']:.2f}",
    f"  Average Length: {stats['avg_length']:.2f}",
    f"  Max Score: {stats['max_score']}",
    "",
    "Overall Statistics:",
    f"  Mean Score: {np.mean(all_scores):.2f} +/- {np.std(all_scores):.2f}",
    f"  Max Score: {max(all_scores)}",
    f"  Mean Reward: {np.mean(all_rewards):.2f} +/- {np.std(all_rewards):.2f}",
    separator,
]
print("\n".join(summary_lines))