# PPO Two-Snake MLP Training Notebook

This notebook trains two PPO agents to compete in a two-snake environment.

**Algorithm**: PPO Two-Snake with:
- Vectorized environment (128 parallel games)
- 35-dimensional feature vector per snake
- Big snake: 256x256 network, Small snake: 128x128 network
- Direct co-evolution (both agents learn simultaneously)
- Rollout-based training with GAE (Generalized Advantage Estimation)

In [None]:
# Cell 1: Imports
import sys
from pathlib import Path

# The project root is taken to be two directories above the current working
# directory; it is put first on sys.path so the `core` package imported below
# can be found when the notebook is run from its own folder.
project_root = Path.cwd().parent.parent
sys.path.insert(0, str(project_root))

import torch
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime
import time

# Project-local modules (expected to resolve through the sys.path entry above).
from core.environment_two_snake_vectorized import VectorizedTwoSnakeEnv
from core.networks import PPO_Actor_MLP, PPO_Critic_MLP
from core.utils import set_seed, get_device

# Record the runtime environment in the notebook output.
print(f"Project root: {project_root}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

In [None]:
# Cell 2: Configuration (papermill parameters)
# ============== CONFIGURATION ==============
# Every value in this cell is a plain module-level constant read by later cells.

# Environment -- forwarded to VectorizedTwoSnakeEnv as num_envs, grid_size,
# target_food and max_steps; their exact semantics are defined by that class.
NUM_ENVS = 128
GRID_SIZE = 20
TARGET_FOOD = 10
MAX_STEPS = 1000

# Network sizes -- hidden layer widths passed to the actor/critic constructors.
BIG_HIDDEN_DIMS = (256, 256)  # Agent 1
SMALL_HIDDEN_DIMS = (128, 128)  # Agent 2

# PPO hyperparameters
ACTOR_LR = 0.0003  # Adam learning rate for both actors
CRITIC_LR = 0.0003  # Adam learning rate for both critics
# Transitions per rollout summed over all environments: the training loop runs
# ROLLOUT_STEPS // NUM_ENVS vectorized environment steps per rollout.
ROLLOUT_STEPS = 2048
BATCH_SIZE = 64  # minibatch size used inside ppo_update
EPOCHS_PER_ROLLOUT = 4  # shuffled passes over each rollout inside ppo_update
GAMMA = 0.99  # discount factor (default of compute_gae)
GAE_LAMBDA = 0.95  # GAE lambda (default of compute_gae)
CLIP_EPSILON = 0.2  # probability ratio is clamped to [1 - eps, 1 + eps]
# NOTE(review): VALUE_LOSS_COEF is not referenced by any cell visible in this
# notebook; the critic is trained with its own optimizer on an unscaled MSE loss.
VALUE_LOSS_COEF = 0.5
ENTROPY_COEF = 0.1  # weight of the entropy bonus subtracted from the actor loss
MAX_GRAD_NORM = 0.5  # gradient-norm clipping threshold for actor and critic

# Training
# TOTAL_STEPS counts transitions: the loop adds NUM_ENVS per vectorized step.
TOTAL_STEPS = 10000  # Short for testing (full: 250000)
# The training loop's logging condition works in units of LOG_INTERVAL * NUM_ENVS
# transitions.
LOG_INTERVAL = 100

# Output
# Relative path; it is wrapped in Path() and created in the setup cell, so it
# resolves against the current working directory.
SAVE_DIR = '../../results/weights/ppo_two_snake_mlp'
SEED = 42  # passed to set_seed() and to env.reset(seed=...)

In [None]:
# Cell 3: Environment & Model Setup
# Seed first: set_seed runs before any network is constructed below
# (presumably so weight initialisation is reproducible -- verify in core.utils).
set_seed(SEED)
device = get_device()
print(f"Using device: {device}")

# Checkpoint directory; created up front (including parents) so the save cell
# cannot fail on a missing folder.
save_dir = Path(SAVE_DIR)
save_dir.mkdir(parents=True, exist_ok=True)

# Create environment
env = VectorizedTwoSnakeEnv(
    num_envs=NUM_ENVS,
    grid_size=GRID_SIZE,
    target_food=TARGET_FOOD,
    max_steps=MAX_STEPS,
    device=device
)

# State and action dimensions
# NOTE(review): both values are hard-coded here rather than queried from `env`;
# they must match what the environment actually emits and accepts.
state_dim = 35  # 35-dimensional feature vector per snake
action_dim = 3  # relative actions

# Create Agent 1 (Big) networks
actor1 = PPO_Actor_MLP(state_dim, action_dim, BIG_HIDDEN_DIMS).to(device)
critic1 = PPO_Critic_MLP(state_dim, BIG_HIDDEN_DIMS).to(device)

# Create Agent 2 (Small) networks
actor2 = PPO_Actor_MLP(state_dim, action_dim, SMALL_HIDDEN_DIMS).to(device)
critic2 = PPO_Critic_MLP(state_dim, SMALL_HIDDEN_DIMS).to(device)

# Optimizers -- one Adam instance per network, so each actor and critic is
# stepped independently.
actor1_optim = optim.Adam(actor1.parameters(), lr=ACTOR_LR)
critic1_optim = optim.Adam(critic1.parameters(), lr=CRITIC_LR)
actor2_optim = optim.Adam(actor2.parameters(), lr=ACTOR_LR)
critic2_optim = optim.Adam(critic2.parameters(), lr=CRITIC_LR)

# Summary of the configuration. The two "hidden" lines index elements [0] and
# [1], so they assume each *_HIDDEN_DIMS tuple has at least two entries.
print(f"Environment: {GRID_SIZE}x{GRID_SIZE} grid, {NUM_ENVS} parallel games")
print(f"Agent 1 (Big): {BIG_HIDDEN_DIMS[0]}x{BIG_HIDDEN_DIMS[1]} hidden")
print(f"Agent 2 (Small): {SMALL_HIDDEN_DIMS[0]}x{SMALL_HIDDEN_DIMS[1]} hidden")
print(f"State dim: {state_dim}, Action dim: {action_dim}")

In [None]:
# Cell 4: PPO Helper Functions

class PPOBuffer:
    """Per-agent storage for one rollout of PPO transitions.

    Each call to ``add`` appends one time step to six parallel lists
    (``states``, ``actions``, ``rewards``, ``dones``, ``log_probs``,
    ``values``). ``get`` stacks every list into a tensor along a new leading
    dimension.
    """

    # Attribute names, in the order used by both ``add`` and ``get``.
    _FIELDS = ("states", "actions", "rewards", "dones", "log_probs", "values")

    def __init__(self, device):
        """Remember ``device`` and start with empty storage."""
        self.device = device
        self.clear()

    def clear(self):
        """Reset every per-field list to a fresh empty list."""
        for field_name in self._FIELDS:
            setattr(self, field_name, [])

    def add(self, state, action, reward, done, log_prob, value):
        """Append one time step; arguments follow the ``_FIELDS`` order."""
        step = (state, action, reward, done, log_prob, value)
        for field_name, item in zip(self._FIELDS, step):
            getattr(self, field_name).append(item)

    def get(self):
        """Return ``(states, actions, rewards, dones, log_probs, values)``.

        Each element is ``torch.stack`` of the corresponding list, so the
        stored items must be tensors and the buffer must not be empty.
        """
        return tuple(
            torch.stack(getattr(self, field_name)) for field_name in self._FIELDS
        )

def select_action(actor, critic, state):
    """Sample an action from the current policy and evaluate the state.

    Args:
        actor: Callable mapping ``state`` to unnormalised action logits; the
            last dimension indexes the actions.
        critic: Callable mapping ``state`` to a value estimate with a trailing
            singleton dimension, which is squeezed away here.
        state: Input forwarded unchanged to both ``actor`` and ``critic``.

    Returns:
        Tuple ``(action, log_prob, value)``: the sampled action indices, the
        log-probability of each sampled action under the policy, and the
        critic's value estimate. Everything is computed under
        ``torch.no_grad()``, so none of the results carry gradients.
    """
    with torch.no_grad():
        logits = actor(state)
        value = critic(state).squeeze(-1)
        # Parameterise the distribution by logits rather than softmax
        # probabilities: the probs path renormalises and clamps each
        # probability to [eps, 1 - eps] before taking the log, which makes
        # log-probabilities inexact for near-deterministic policies.
        dist = torch.distributions.Categorical(logits=logits)
        action = dist.sample()
        log_prob = dist.log_prob(action)
    return action, log_prob, value

def compute_gae(rewards, values, dones, next_value, gamma=GAMMA, gae_lambda=GAE_LAMBDA):
    """Compute Generalized Advantage Estimation over one rollout.

    Walks the rollout backwards along its first dimension, accumulating the
    discounted sum of TD errors. A ``done`` flag at step ``t`` zeroes both the
    bootstrap value and the carried-over advantage for that step.

    Args:
        rewards: Per-step rewards, indexed by time along the first dimension.
        values: Per-step value estimates, indexed the same way as ``rewards``.
        dones: Per-step termination flags (converted with ``.float()``).
        next_value: Value estimate used to bootstrap after the final step.
        gamma: Discount factor.
        gae_lambda: GAE smoothing factor.

    Returns:
        Tuple ``(advantages, returns)`` where ``advantages`` is the per-step
        advantages stacked along the first dimension and ``returns`` is
        ``advantages + values``.
    """
    horizon = len(rewards)
    newest_first = []
    running_gae = 0
    bootstrap = next_value

    for step in range(horizon - 1, -1, -1):
        not_done = 1.0 - dones[step].float()
        td_error = rewards[step] + gamma * bootstrap * not_done - values[step]
        running_gae = td_error + gamma * gae_lambda * not_done * running_gae
        newest_first.append(running_gae)
        # The value at this step is the bootstrap target for the step before it.
        bootstrap = values[step]

    # Entries were collected last-step-first; flip them back into time order.
    advantages = torch.stack(newest_first[::-1])
    returns = advantages + values
    return advantages, returns

def ppo_update(actor, critic, actor_optim, critic_optim, states, actions, old_log_probs, advantages, returns):
    """Run the clipped-surrogate PPO update for one agent on one rollout.

    All leading dimensions of the inputs are flattened into a single sample
    dimension. The samples are then visited in shuffled minibatches of
    ``BATCH_SIZE`` for ``EPOCHS_PER_ROLLOUT`` passes. The actor and critic are
    optimised separately, each with gradient-norm clipping at
    ``MAX_GRAD_NORM``.

    Reads the module-level hyperparameters ``EPOCHS_PER_ROLLOUT``,
    ``BATCH_SIZE``, ``CLIP_EPSILON``, ``ENTROPY_COEF`` and ``MAX_GRAD_NORM``.

    Args:
        actor: Policy network mapping states to action logits.
        critic: Value network mapping states to values with a trailing
            singleton dimension.
        actor_optim: Optimizer stepping ``actor``'s parameters.
        critic_optim: Optimizer stepping ``critic``'s parameters.
        states: Rollout states; the last dimension is the feature dimension.
        actions: Actions taken during the rollout.
        old_log_probs: Log-probabilities of ``actions`` under the policy that
            collected the rollout.
        advantages: Advantage estimates (normalised here before use).
        returns: Regression targets for the critic.

    Returns:
        Tuple ``(mean_actor_loss, mean_critic_loss)`` averaged over every
        minibatch update, or ``(0.0, 0.0)`` when no update was performed.
    """
    # Flatten. reshape (unlike view) also accepts non-contiguous inputs.
    states_flat = states.reshape(-1, states.shape[-1])
    actions_flat = actions.reshape(-1)
    old_log_probs_flat = old_log_probs.reshape(-1)
    advantages_flat = advantages.reshape(-1)
    returns_flat = returns.reshape(-1)

    num_samples = states_flat.size(0)
    if num_samples == 0:
        # Nothing to learn from; avoid NaN statistics and a division by zero.
        return 0.0, 0.0

    # Normalize advantages. std() of a single element is NaN, which would
    # poison every parameter, so a lone sample is left unnormalised.
    if num_samples > 1:
        advantages_flat = (advantages_flat - advantages_flat.mean()) / (advantages_flat.std() + 1e-8)

    total_actor_loss = 0.0
    total_critic_loss = 0.0
    n_updates = 0

    for _ in range(EPOCHS_PER_ROLLOUT):
        # Draw the permutation on the CPU (same random stream as before), then
        # move it to the data's device once per epoch instead of implicitly
        # transferring an index slice for every minibatch.
        indices = torch.randperm(num_samples).to(states_flat.device)

        for start in range(0, num_samples, BATCH_SIZE):
            batch_idx = indices[start:start + BATCH_SIZE]

            # Get batch
            batch_states = states_flat[batch_idx]
            batch_actions = actions_flat[batch_idx]
            batch_old_log_probs = old_log_probs_flat[batch_idx]
            batch_advantages = advantages_flat[batch_idx]
            batch_returns = returns_flat[batch_idx]

            # Forward pass. Building the distribution from logits avoids the
            # probability clamping of the probs path, which floors
            # log-probabilities and blocks gradients for very unlikely actions.
            logits = actor(batch_states)
            values = critic(batch_states).squeeze(-1)
            dist = torch.distributions.Categorical(logits=logits)
            log_probs = dist.log_prob(batch_actions)
            entropy = dist.entropy().mean()

            # PPO loss: pessimistic minimum of the unclipped and clipped
            # surrogate objectives, plus an entropy bonus for exploration.
            ratio = torch.exp(log_probs - batch_old_log_probs)
            surr1 = ratio * batch_advantages
            surr2 = torch.clamp(ratio, 1 - CLIP_EPSILON, 1 + CLIP_EPSILON) * batch_advantages
            actor_loss = -torch.min(surr1, surr2).mean() - ENTROPY_COEF * entropy

            critic_loss = F.mse_loss(values, batch_returns)

            # Update
            actor_optim.zero_grad()
            actor_loss.backward()
            torch.nn.utils.clip_grad_norm_(actor.parameters(), MAX_GRAD_NORM)
            actor_optim.step()

            critic_optim.zero_grad()
            critic_loss.backward()
            torch.nn.utils.clip_grad_norm_(critic.parameters(), MAX_GRAD_NORM)
            critic_optim.step()

            total_actor_loss += actor_loss.item()
            total_critic_loss += critic_loss.item()
            n_updates += 1

    if n_updates == 0:
        # EPOCHS_PER_ROLLOUT was zero: report zero losses instead of dividing by 0.
        return 0.0, 0.0
    return total_actor_loss / n_updates, total_critic_loss / n_updates

print("Helper functions defined.")

In [None]:
# Cell 5: Training Loop
# Alternates between (1) collecting a rollout from the vectorized environment
# with both agents acting, and (2) one PPO update per agent on its own buffer.
print("Starting PPO Two-Snake Training...")
print(f"Total steps: {TOTAL_STEPS:,}")
print()

start_time = time.time()
total_steps = 0  # transitions so far (NUM_ENVS are added per vectorized step)
round_count = 0  # finished games so far

# Tracking
all_win_rates = []
all_scores1 = []
all_scores2 = []
round_winners = []

# Buffers for both agents
buffer1 = PPOBuffer(device)
buffer2 = PPOBuffer(device)

# Vectorized steps per rollout. At least one, so a rollout is never empty
# (an empty buffer cannot be stacked) even when NUM_ENVS > ROLLOUT_STEPS.
steps_per_rollout = max(1, ROLLOUT_STEPS // NUM_ENVS)

# Logging cadence in transitions. The log check runs once per rollout, so it
# compares against a moving threshold. A modulo test would only fire when a
# rollout boundary happened to land in a narrow window, which with the default
# settings is never within a short run.
log_every = max(1, LOG_INTERVAL * NUM_ENVS)
next_log_step = log_every

state1, state2 = env.reset(seed=SEED)

while total_steps < TOTAL_STEPS:
    buffer1.clear()
    buffer2.clear()

    # Collect rollout
    for _ in range(steps_per_rollout):
        action1, log_prob1, value1 = select_action(actor1, critic1, state1)
        action2, log_prob2, value2 = select_action(actor2, critic2, state2)

        # VectorizedTwoSnakeEnv returns 6 values
        next_state1, next_state2, reward1, reward2, dones, info = env.step(action1, action2)

        # Both agents share the same per-environment done flags.
        buffer1.add(state1, action1, reward1, dones, log_prob1, value1)
        buffer2.add(state2, action2, reward2, dones, log_prob2, value2)

        # Track wins on episode ends
        # NOTE(review): assumes info['winners'], info['food_counts1'] and
        # info['food_counts2'] are indexable in step with info['done_envs'] --
        # confirm against the environment implementation.
        if dones.any():
            num_done = len(info['done_envs'])
            for i in range(num_done):
                winner = info['winners'][i]
                round_winners.append(int(winner))
                all_scores1.append(info['food_counts1'][i])
                all_scores2.append(info['food_counts2'][i])
                round_count += 1

        state1, state2 = next_state1, next_state2
        total_steps += NUM_ENVS

    # Compute advantages and update both agents
    # Bootstrap from the critics' estimate of the state following the rollout.
    with torch.no_grad():
        next_value1 = critic1(state1).squeeze(-1)
        next_value2 = critic2(state2).squeeze(-1)

    states1, actions1, rewards1, dones1, log_probs1, values1 = buffer1.get()
    states2, actions2, rewards2, dones2, log_probs2, values2 = buffer2.get()

    advantages1, returns1 = compute_gae(rewards1, values1, dones1, next_value1)
    advantages2, returns2 = compute_gae(rewards2, values2, dones2, next_value2)

    ppo_update(actor1, critic1, actor1_optim, critic1_optim, states1, actions1, log_probs1, advantages1, returns1)
    ppo_update(actor2, critic2, actor2_optim, critic2_optim, states2, actions2, log_probs2, advantages2, returns2)

    # Logging: whenever the threshold has been crossed, and always after the
    # final rollout so that even a short run records at least one data point.
    if total_steps >= next_log_step or total_steps >= TOTAL_STEPS:
        # Skip every threshold already passed so one long rollout produces a
        # single log line rather than a backlog.
        next_log_step = (total_steps // log_every + 1) * log_every

        win_rate = sum(1 for w in round_winners[-100:] if w == 1) / max(1, len(round_winners[-100:]))
        all_win_rates.append(win_rate)

        elapsed = time.time() - start_time
        sps = total_steps / elapsed if elapsed > 0 else 0

        recent_s1 = np.mean(all_scores1[-100:]) if all_scores1 else 0
        recent_s2 = np.mean(all_scores2[-100:]) if all_scores2 else 0

        print(f"Step {total_steps:,}/{TOTAL_STEPS:,} | "
              f"Rounds: {round_count} | "
              f"A1 Score: {recent_s1:.2f} | "
              f"A2 Score: {recent_s2:.2f} | "
              f"Win Rate: {win_rate:.2%} | "
              f"SPS: {sps:.0f}")

training_time = time.time() - start_time
print(f"\nTraining complete!")
print(f"Total time: {training_time:.1f}s")
print(f"Total rounds: {round_count}")

In [None]:
# Cell 6: Save Weights
# Writes one checkpoint per agent containing the actor/critic weights, both
# optimizer states and the step counter. `timestamp` is reused by the
# visualization cell for the figure file name.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")


def _arch_tag(hidden_dims):
    """Format hidden layer widths for a file name, e.g. (256, 256) -> '256x256'."""
    return "x".join(str(width) for width in hidden_dims)


# The architecture tag comes from the configured layer sizes, so the file name
# stays truthful when BIG_HIDDEN_DIMS / SMALL_HIDDEN_DIMS are overridden. With
# the default configuration the names are unchanged.

# Save Agent 1
agent1_path = save_dir / f"big_{_arch_tag(BIG_HIDDEN_DIMS)}_step{total_steps}_{timestamp}.pt"
torch.save({
    'actor': actor1.state_dict(),
    'critic': critic1.state_dict(),
    'actor_optimizer': actor1_optim.state_dict(),
    'critic_optimizer': critic1_optim.state_dict(),
    'total_steps': total_steps
}, agent1_path)
print(f"Agent 1 saved to: {agent1_path}")

# Save Agent 2
agent2_path = save_dir / f"small_{_arch_tag(SMALL_HIDDEN_DIMS)}_step{total_steps}_{timestamp}.pt"
torch.save({
    'actor': actor2.state_dict(),
    'critic': critic2.state_dict(),
    'actor_optimizer': actor2_optim.state_dict(),
    'critic_optimizer': critic2_optim.state_dict(),
    'total_steps': total_steps
}, agent2_path)
print(f"Agent 2 saved to: {agent2_path}")

In [None]:
# Cell 7: Visualization
# Draws a 2x2 summary figure (scores, win rate, win distribution, score
# histogram), saves it under <project_root>/results/figures and shows it.
# Relies on `timestamp` from the save-weights cell for the file name.
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

def smooth(data, window=50):
    """Return the moving average of ``data`` over ``window`` points.

    Uses a 'valid' convolution, so the result has ``len(data) - window + 1``
    entries. Inputs shorter than ``window`` are returned unchanged.
    """
    if len(data) < window:
        return data
    return np.convolve(data, np.ones(window)/window, mode='valid')

# Plot 1: Agent Scores
ax1 = axes[0, 0]
ax1.plot(all_scores1, alpha=0.3, label='Agent 1 (raw)', color='blue')
ax1.plot(all_scores2, alpha=0.3, label='Agent 2 (raw)', color='red')
if len(all_scores1) > 50:
    # The smoothed series starts at index window - 1 = 49, because the 'valid'
    # convolution drops the first 49 partial windows.
    ax1.plot(range(49, len(all_scores1)), smooth(all_scores1), label='Agent 1 (smooth)', color='blue')
    ax1.plot(range(49, len(all_scores2)), smooth(all_scores2), label='Agent 2 (smooth)', color='red')
ax1.set_xlabel('Round')
ax1.set_ylabel('Score')
ax1.set_title('Agent Scores')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Plot 2: Win Rate
ax2 = axes[0, 1]
ax2.plot(all_win_rates, color='green')
ax2.axhline(y=0.5, color='gray', linestyle='--')  # parity reference line
ax2.set_xlabel('Log Step')
ax2.set_ylabel('Agent 1 Win Rate')
ax2.set_title('Win Rate Over Time')
ax2.grid(True, alpha=0.3)

# Plot 3: Win Distribution
# Winner codes as used throughout the notebook: 1 = Agent 1, 2 = Agent 2, 0 = draw.
ax3 = axes[1, 0]
a1_wins = sum(1 for w in round_winners if w == 1)
a2_wins = sum(1 for w in round_winners if w == 2)
draws = sum(1 for w in round_winners if w == 0)
labels = ['Agent 1', 'Agent 2', 'Draw']
sizes = [a1_wins, a2_wins, draws]
colors = ['blue', 'red', 'gray']
if sum(sizes) > 0:
    ax3.pie(sizes, labels=labels, colors=colors, autopct='%1.1f%%')
else:
    # A pie chart normalises by the sum of its sizes. With no finished round
    # that sum is zero and every wedge would be NaN, so show a placeholder.
    ax3.text(0.5, 0.5, 'No completed rounds', ha='center', va='center',
             transform=ax3.transAxes)
    ax3.set_xticks([])
    ax3.set_yticks([])
ax3.set_title('Win Distribution')

# Plot 4: Score Distribution
ax4 = axes[1, 1]
ax4.hist(all_scores1, bins=20, alpha=0.5, label='Agent 1', color='blue')
ax4.hist(all_scores2, bins=20, alpha=0.5, label='Agent 2', color='red')
ax4.set_xlabel('Score')
ax4.set_ylabel('Frequency')
ax4.set_title('Score Distribution')
ax4.legend()

plt.suptitle(f'PPO Two-Snake Training - {total_steps:,} Steps', fontsize=14)
plt.tight_layout()

fig_dir = project_root / 'results' / 'figures'
fig_dir.mkdir(parents=True, exist_ok=True)
fig_path = fig_dir / f'ppo_two_snake_mlp_training_{timestamp}.png'
plt.savefig(fig_path, dpi=150, bbox_inches='tight')
print(f"Figure saved to: {fig_path}")

plt.show()

In [None]:
# Cell 8: Results Summary
# Prints aggregate results of the run. The win counts are computed here, so
# the cell does not depend on the visualization cell having been executed.
# Winner codes: 1 = Agent 1, 2 = Agent 2, 0 = draw.
a1_wins = sum(1 for w in round_winners if w == 1)
a2_wins = sum(1 for w in round_winners if w == 2)
draws = sum(1 for w in round_winners if w == 0)

# Guard the percentages against a run in which no round finished.
rounds_denom = max(1, round_count)

# Last-100 statistics are averaged over the rounds actually available. Dividing
# by a fixed 100 would under-report the win rate for runs with fewer than 100
# rounds and disagree with the win rate logged during training.
recent_winners = round_winners[-100:]
final_win_rate = sum(1 for w in recent_winners if w == 1) / max(1, len(recent_winners))
final_score1 = np.mean(all_scores1[-100:]) if all_scores1 else 0.0
final_score2 = np.mean(all_scores2[-100:]) if all_scores2 else 0.0

print("=" * 50)
print("TRAINING SUMMARY")
print("=" * 50)
print(f"Algorithm: PPO Two-Snake (Co-Evolution)")
print(f"Total Steps: {total_steps:,}")
print(f"Total Rounds: {round_count}")
print(f"Training Time: {training_time:.1f}s")
print()
print("Match Results:")
print(f"  Agent 1 Wins: {a1_wins} ({a1_wins/rounds_denom*100:.1f}%)")
print(f"  Agent 2 Wins: {a2_wins} ({a2_wins/rounds_denom*100:.1f}%)")
print(f"  Draws: {draws} ({draws/rounds_denom*100:.1f}%)")
print()
print("Final Performance (last 100 rounds):")
print(f"  Agent 1 Avg Score: {final_score1:.2f}")
print(f"  Agent 2 Avg Score: {final_score2:.2f}")
print(f"  Agent 1 Win Rate: {final_win_rate:.2%}")
print("=" * 50)