# PPO Two-Snake Curriculum Training Notebook (MLP)

This notebook trains PPO agents using curriculum learning for two-snake competitive play.

**Algorithm**: PPO with 5-stage curriculum:
1. Stage 0: vs StaticAgent (learn basic movement, 70% win threshold)
2. Stage 1: vs RandomAgent (handle unpredictability, 60% win threshold)
3. Stage 2: vs GreedyFoodAgent (compete for food, 55% win threshold)
4. Stage 3: vs frozen small network (Agent 2's actor acting greedily, not updated during this stage; 50% win threshold)
5. Stage 4: Co-evolution (both learning, no threshold)

In [None]:
# Cell 1: Imports
import sys
from pathlib import Path

# Two directory levels above the current working directory. It is placed at
# the front of sys.path, presumably so the `core` and `scripts` imports
# below resolve when the notebook is run from its own folder.
project_root = Path.cwd().parent.parent
sys.path.insert(0, str(project_root))

import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime
import time
# NOTE(review): dataclass, field, Optional and List are not referenced by any
# code visible in this notebook.
from dataclasses import dataclass, field
from typing import Optional, List

from core.environment_two_snake_vectorized import VectorizedTwoSnakeEnv
from core.networks import PPO_Actor_MLP, PPO_Critic_MLP
from core.utils import set_seed, get_device
from scripts.baselines.scripted_opponents import get_scripted_agent

# Quick environment report so a run's log shows where it executed.
print(f"Project root: {project_root}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

In [None]:
# Cell 2: Configuration (papermill parameters)
# ============== CONFIGURATION ==============
# Every value here is a plain top-level assignment that later cells read as a
# global.

# Environment
GRID_SIZE = 20
NUM_ENVS = 128
MAX_STEPS = 1000

# Network sizes
BIG_HIDDEN_DIMS = (256, 256)  # Agent 1
SMALL_HIDDEN_DIMS = (128, 128)  # Agent 2

# PPO hyperparameters
ACTOR_LR = 0.0003
CRITIC_LR = 0.0003
GAMMA = 0.99
GAE_LAMBDA = 0.95
CLIP_EPS = 0.2
ENTROPY_COEF = 0.01
# NOTE(review): VALUE_COEF is not referenced anywhere else in the visible
# notebook; ppo_update optimizes the critic loss unscaled with its own optimizer.
VALUE_COEF = 0.5
MAX_GRAD_NORM = 0.5
PPO_EPOCHS = 4
BATCH_SIZE = 64
# Transitions per rollout summed over all envs: the stage loops run
# max(1, ROLLOUT_STEPS // NUM_ENVS) vectorized env steps per rollout.
ROLLOUT_STEPS = 256

# Curriculum Stage Settings (min_steps for each)
# Counted in transitions: the stage loops add NUM_ENVS per vectorized env step.
# Stages 0-3 also stop unconditionally at 3x their minimum.
STAGE0_MIN_STEPS = 5000  # Static opponent
STAGE1_MIN_STEPS = 7500  # Random opponent
STAGE2_MIN_STEPS = 10000  # Greedy opponent
STAGE3_MIN_STEPS = 10000  # Frozen policy
STAGE4_MIN_STEPS = 20000  # Co-evolution

# Curriculum thresholds
# Minimum Agent 1 win rate needed to leave each stage (stage 4 has none).
STAGE0_WIN_THRESHOLD = 0.70
STAGE1_WIN_THRESHOLD = 0.60
STAGE2_WIN_THRESHOLD = 0.55
STAGE3_WIN_THRESHOLD = 0.50

# Target food per stage (progressive difficulty)
# NOTE(review): the values are not monotonic (10, 10, 4, 6, 8) -- confirm that
# stages 0-1 are meant to use a higher target than stages 2-4.
STAGE0_TARGET_FOOD = 10
STAGE1_TARGET_FOOD = 10
STAGE2_TARGET_FOOD = 4
STAGE3_TARGET_FOOD = 6
STAGE4_TARGET_FOOD = 8

# Output
SAVE_DIR = '../../results/weights/ppo_two_snake_mlp_curriculum'
SEED = 42
LOG_INTERVAL = 50

In [None]:
# Cell 3: Environment & Agent Setup
# Seeds the RNGs, builds the vectorized environment, both actor/critic pairs
# with their optimizers, and the scripted opponents. The networks are created
# after set_seed, so their construction order affects the initial weights.
set_seed(SEED)
device = get_device()
print(f"Using device: {device}")

save_dir = Path(SAVE_DIR)
save_dir.mkdir(parents=True, exist_ok=True)

# Create environment
env = VectorizedTwoSnakeEnv(
    num_envs=NUM_ENVS,
    grid_size=GRID_SIZE,
    max_steps=MAX_STEPS,
    target_food=STAGE0_TARGET_FOOD,
    device=device
)

# Get input dimensions from environment
# The feature width is read from the second axis of Agent 1's observation.
obs1, obs2 = env.reset()
input_dim = obs1.shape[1]  # Should be 35 for competitive features
output_dim = 3  # relative actions

print(f"Input dim: {input_dim}, Output dim: {output_dim}")

# Create Agent 1 (Big)
actor1 = PPO_Actor_MLP(input_dim, output_dim, BIG_HIDDEN_DIMS).to(device)
critic1 = PPO_Critic_MLP(input_dim, BIG_HIDDEN_DIMS).to(device)
actor1_optimizer = torch.optim.Adam(actor1.parameters(), lr=ACTOR_LR)
critic1_optimizer = torch.optim.Adam(critic1.parameters(), lr=CRITIC_LR)

# Create Agent 2 (Small)
actor2 = PPO_Actor_MLP(input_dim, output_dim, SMALL_HIDDEN_DIMS).to(device)
critic2 = PPO_Critic_MLP(input_dim, SMALL_HIDDEN_DIMS).to(device)
actor2_optimizer = torch.optim.Adam(actor2.parameters(), lr=ACTOR_LR)
critic2_optimizer = torch.optim.Adam(critic2.parameters(), lr=CRITIC_LR)

# Load scripted opponents
# Best-effort: a failure is only reported as a warning, so a missing agent
# surfaces later as a KeyError when a stage cell looks it up.
scripted_agents = {}
for agent_type in ['static', 'random', 'greedy']:
    try:
        scripted_agents[agent_type] = get_scripted_agent(agent_type, device=device)
        print(f"Loaded {agent_type} opponent")
    except Exception as e:
        print(f"Warning: Could not load {agent_type} agent: {e}")

print(f"\nEnvironment: {GRID_SIZE}x{GRID_SIZE} grid, {NUM_ENVS} parallel envs")
print(f"Agent 1 (Big): {BIG_HIDDEN_DIMS}")
print(f"Agent 2 (Small): {SMALL_HIDDEN_DIMS}")

In [None]:
# Cell 4: PPO Helper Functions

class PPOBuffer:
    """Per-agent storage for the transitions of one PPO rollout.

    Each ``add`` call records the tensors of one vectorized step; ``get``
    merges everything recorded since the last ``clear`` into six tensors.
    """

    # Names of the per-step lists, in the order used by ``add`` and ``get``.
    _FIELDS = ("states", "actions", "rewards", "dones", "log_probs", "values")

    def __init__(self, device):
        self.device = device
        self.clear()

    def clear(self):
        """Drop all recorded steps by giving every field a fresh empty list."""
        for field_name in self._FIELDS:
            setattr(self, field_name, [])

    def add(self, states, actions, rewards, dones, log_probs, values):
        """Append the tensors of one step to their respective lists."""
        step = (states, actions, rewards, dones, log_probs, values)
        for field_name, item in zip(self._FIELDS, step):
            getattr(self, field_name).append(item)

    def get(self):
        """Return ``(states, actions, rewards, dones, log_probs, values)``.

        Rewards and dones are stacked along a new leading (step) dimension;
        the other four are concatenated along their existing first dimension.
        """
        def flatten(chunks):
            return torch.cat(chunks, dim=0)

        return (
            flatten(self.states),
            flatten(self.actions),
            torch.stack(self.rewards),
            torch.stack(self.dones),
            flatten(self.log_probs),
            flatten(self.values),
        )

def select_actions(actor, critic, states):
    """Sample actions from the actor's categorical policy for ``states``.

    Returns ``(actions, log_probs, values)``: the sampled actions, their
    log-probabilities under the policy, and the critic output with its last
    dimension squeezed away. Everything is computed without gradient tracking.
    """
    with torch.no_grad():
        action_probs = F.softmax(actor(states), dim=-1)
        state_values = critic(states).squeeze(-1)
        policy = torch.distributions.Categorical(action_probs)
        sampled = policy.sample()
        return sampled, policy.log_prob(sampled), state_values

def select_greedy_actions(actor, states):
    """Return the highest-logit action for each state (used for the frozen policy).

    The actor is evaluated without gradient tracking and the argmax is taken
    over its last output dimension.
    """
    with torch.no_grad():
        return actor(states).argmax(dim=-1)

def compute_gae(rewards, values, dones, next_value, gamma=GAMMA, gae_lambda=GAE_LAMBDA):
    """Compute Generalized Advantage Estimation for one rollout.

    ``rewards`` is unpacked as ``(steps, num_envs)``; ``values`` and ``dones``
    are indexed per step the same way. ``next_value`` is the bootstrap value
    used after the last step. A step flagged done contributes neither a
    bootstrap value nor any advantage carried over from later steps.

    Returns ``(advantages, returns)``, both shaped like ``rewards``, where
    ``returns = advantages + values``.
    """
    steps, num_envs = rewards.shape
    advantages = torch.zeros_like(rewards)
    running_gae = torch.zeros(num_envs, device=rewards.device)

    # Walk backwards through time, carrying the value of the following step.
    following_value = next_value
    for t in range(steps - 1, -1, -1):
        not_done = 1.0 - dones[t].float()
        td_error = rewards[t] + gamma * following_value * not_done - values[t]
        running_gae = td_error + gamma * gae_lambda * not_done * running_gae
        advantages[t] = running_gae
        following_value = values[t]

    returns = advantages + values.view(steps, num_envs)
    return advantages, returns

def ppo_update(actor, critic, actor_optimizer, critic_optimizer, states, actions, old_log_probs, advantages, returns):
    """Run PPO_EPOCHS passes of clipped-surrogate PPO over one rollout.

    Args:
        actor, critic: policy and value networks, each stepped by its own
            optimizer (the critic loss is a plain MSE against ``returns``).
        actor_optimizer, critic_optimizer: optimizers for the two networks.
        states, actions, old_log_probs, advantages, returns: flattened rollout
            tensors indexed along their first dimension by sample.

    Returns:
        ``(actor_loss, critic_loss, entropy)`` averaged over all mini-batch
        updates, or three zeros when the rollout holds no samples.
    """
    num_samples = states.shape[0]

    # Normalize advantages. The standard deviation of a single element is NaN
    # and would poison the weights, so tiny rollouts are left unnormalized.
    if advantages.numel() > 1:
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

    total_actor_loss = 0.0
    total_critic_loss = 0.0
    total_entropy = 0.0
    num_updates = 0

    for _ in range(PPO_EPOCHS):
        # Draw a fresh permutation every epoch; reusing one permutation would
        # replay exactly the same mini-batches in every epoch.
        indices = torch.randperm(num_samples, device=states.device)

        for start in range(0, num_samples, BATCH_SIZE):
            batch_indices = indices[start:start + BATCH_SIZE]

            batch_states = states[batch_indices]
            batch_actions = actions[batch_indices]
            batch_old_log_probs = old_log_probs[batch_indices]
            batch_advantages = advantages[batch_indices]
            batch_returns = returns[batch_indices]

            # Actor loss: clipped surrogate objective minus an entropy bonus.
            # The distribution is built from softmax probabilities, matching
            # how select_actions produced old_log_probs.
            logits = actor(batch_states)
            probs = F.softmax(logits, dim=-1)
            dist = torch.distributions.Categorical(probs)
            new_log_probs = dist.log_prob(batch_actions)
            entropy = dist.entropy().mean()

            ratio = torch.exp(new_log_probs - batch_old_log_probs)
            surr1 = ratio * batch_advantages
            surr2 = torch.clamp(ratio, 1 - CLIP_EPS, 1 + CLIP_EPS) * batch_advantages
            actor_loss = -torch.min(surr1, surr2).mean() - ENTROPY_COEF * entropy

            # Critic loss
            values = critic(batch_states).squeeze(-1)
            critic_loss = F.mse_loss(values, batch_returns)

            # Update actor
            actor_optimizer.zero_grad()
            actor_loss.backward()
            torch.nn.utils.clip_grad_norm_(actor.parameters(), MAX_GRAD_NORM)
            actor_optimizer.step()

            # Update critic
            critic_optimizer.zero_grad()
            critic_loss.backward()
            torch.nn.utils.clip_grad_norm_(critic.parameters(), MAX_GRAD_NORM)
            critic_optimizer.step()

            total_actor_loss += actor_loss.item()
            total_critic_loss += critic_loss.item()
            total_entropy += entropy.item()
            num_updates += 1

    # An empty rollout performs no updates; report zeros instead of dividing by zero.
    if num_updates == 0:
        return 0.0, 0.0, 0.0
    return total_actor_loss / num_updates, total_critic_loss / num_updates, total_entropy / num_updates

# Global tracking
# Counters and histories shared by every stage cell below.
total_steps = total_rounds = 0
# One entry per finished round: the winner id and each agent's food count.
round_winners, all_scores1, all_scores2 = [], [], []
# One loss entry per PPO update, one win-rate entry per log line, and one
# summary dict per completed stage.
all_losses, all_win_rates, stage_results = [], [], []

# Separate rollout buffers for Agent 1 and Agent 2.
buffer1, buffer2 = PPOBuffer(device), PPOBuffer(device)

def calculate_win_rate(window=100):
    """Return the fraction of the most recent rounds whose winner is 1.

    At most ``window`` entries from the end of the global ``round_winners``
    list are considered; the result is 0.0 when there are none to look at.
    """
    span = min(window, len(round_winners))
    if span == 0:
        return 0.0
    wins = 0
    for outcome in round_winners[-span:]:
        if outcome == 1:
            wins += 1
    return wins / span

print("PPO helper functions defined.")

In [None]:
# Cell 5: Stage 0 - Static Opponent
# Agent 1 is trained with PPO while Agent 2's actions come from the scripted
# 'static' opponent. The stage ends once at least STAGE0_MIN_STEPS transitions
# were collected and the stage win rate reaches STAGE0_WIN_THRESHOLD, or
# unconditionally at 3x STAGE0_MIN_STEPS.
print("="*70)
print("STAGE 0: Static Opponent")
print("="*70)
print(f"Target food: {STAGE0_TARGET_FOOD}")
print(f"Min steps: {STAGE0_MIN_STEPS}")
print(f"Win rate threshold: {STAGE0_WIN_THRESHOLD}")
print("="*70 + "\n")

stage0_start = time.time()
stage_steps = 0
# Rounds recorded before this index were played earlier (a previous stage or
# a previous run of this cell) and must not count towards this stage's win rate.
stage_round_start = len(round_winners)
# Vectorized env steps per rollout; each one yields NUM_ENVS transitions.
rollout_len = max(1, ROLLOUT_STEPS // NUM_ENVS)
env.set_target_food(STAGE0_TARGET_FOOD)
obs1, obs2 = env.reset()

while True:
    buffer1.clear()

    # Collect rollout
    for _ in range(rollout_len):
        # Agent 1 uses PPO policy
        actions1, log_probs1, values1 = select_actions(actor1, critic1, obs1)
        # Agent 2 uses static scripted agent
        actions2 = scripted_agents['static'].select_action(env)

        next_obs1, next_obs2, r1, r2, dones, info = env.step(actions1, actions2)

        buffer1.add(obs1, actions1, r1, dones, log_probs1, values1)

        # Record the outcome of every round that finished on this step.
        if dones.any():
            num_done = len(info['done_envs'])
            for i in range(num_done):
                round_winners.append(int(info['winners'][i]))
                total_rounds += 1
                all_scores1.append(info['food_counts1'][i])
                all_scores2.append(info['food_counts2'][i])

        obs1 = next_obs1
        obs2 = next_obs2
        stage_steps += NUM_ENVS
        total_steps += NUM_ENVS

    # PPO update for agent 1
    # The bootstrap value is only a GAE target, so no autograd graph is needed.
    with torch.no_grad():
        next_value1 = critic1(obs1).squeeze(-1)
    states1, actions_b1, rewards1, dones1, log_probs_b1, values_b1 = buffer1.get()

    steps_per_env = len(buffer1.states)
    rewards1 = rewards1.view(steps_per_env, NUM_ENVS)
    values_v1 = values_b1.view(steps_per_env, NUM_ENVS)
    dones1 = dones1.view(steps_per_env, NUM_ENVS)

    advantages1, returns1 = compute_gae(rewards1, values_v1, dones1, next_value1)
    advantages1 = advantages1.view(-1)
    returns1 = returns1.view(-1)

    actor_loss, critic_loss, entropy = ppo_update(
        actor1, critic1, actor1_optimizer, critic1_optimizer,
        states1, actions_b1, log_probs_b1, advantages1, returns1
    )
    all_losses.append(actor_loss + critic_loss)

    # Agent 1 win rate over (at most) the last 100 rounds played in this stage.
    recent_rounds = round_winners[max(stage_round_start, len(round_winners) - 100):]
    win_rate = sum(1 for w in recent_rounds if w == 1) / len(recent_rounds) if recent_rounds else 0.0

    # Logging
    # total_steps advances by one rollout per iteration, so the modulo is
    # compared against the rollout size to log once per LOG_INTERVAL * NUM_ENVS.
    if total_steps % (LOG_INTERVAL * NUM_ENVS) < rollout_len * NUM_ENVS:
        all_win_rates.append(win_rate)
        avg_score1 = np.mean(all_scores1[-100:]) if all_scores1 else 0
        print(f"[Step {total_steps:>6}] Win Rate: {win_rate:.2%} | Score: {avg_score1:.1f} | Loss: {actor_loss + critic_loss:.4f}")

    # Check stage completion
    if stage_steps >= STAGE0_MIN_STEPS and win_rate >= STAGE0_WIN_THRESHOLD:
        break
    if stage_steps >= STAGE0_MIN_STEPS * 3:  # Safety limit
        print("Warning: Stage 0 reached max steps without meeting threshold")
        break

stage0_time = time.time() - stage0_start
# No rounds are recorded after the last loop iteration, so its win rate is final.
stage0_win_rate = win_rate
stage_results.append({
    'stage': 0,
    'name': 'Static',
    'time': stage0_time,
    'steps': stage_steps,
    'win_rate': stage0_win_rate
})

print(f"\n{'='*70}")
print(f"STAGE 0 COMPLETE")
print(f"Time: {stage0_time:.1f}s | Steps: {stage_steps:,} | Win Rate: {stage0_win_rate:.2%}")
print("="*70)

In [None]:
# Cell 6: Stage 1 - Random Opponent
# Agent 1 keeps training with PPO while Agent 2's actions come from the
# scripted 'random' opponent. The stage ends once at least STAGE1_MIN_STEPS
# transitions were collected and the stage win rate reaches
# STAGE1_WIN_THRESHOLD, or unconditionally at 3x STAGE1_MIN_STEPS.
print("="*70)
print("STAGE 1: Random Opponent")
print("="*70)
print(f"Target food: {STAGE1_TARGET_FOOD}")
print(f"Min steps: {STAGE1_MIN_STEPS}")
print(f"Win rate threshold: {STAGE1_WIN_THRESHOLD}")
print("="*70 + "\n")

stage1_start = time.time()
stage_steps = 0
# Rounds recorded before this index were played against an earlier opponent
# and must not count towards this stage's win rate.
stage_round_start = len(round_winners)
# Vectorized env steps per rollout; each one yields NUM_ENVS transitions.
rollout_len = max(1, ROLLOUT_STEPS // NUM_ENVS)
env.set_target_food(STAGE1_TARGET_FOOD)
obs1, obs2 = env.reset()

while True:
    buffer1.clear()

    # Collect rollout
    for _ in range(rollout_len):
        actions1, log_probs1, values1 = select_actions(actor1, critic1, obs1)
        actions2 = scripted_agents['random'].select_action(env)

        next_obs1, next_obs2, r1, r2, dones, info = env.step(actions1, actions2)

        buffer1.add(obs1, actions1, r1, dones, log_probs1, values1)

        # Record the outcome of every round that finished on this step.
        if dones.any():
            num_done = len(info['done_envs'])
            for i in range(num_done):
                round_winners.append(int(info['winners'][i]))
                total_rounds += 1
                all_scores1.append(info['food_counts1'][i])
                all_scores2.append(info['food_counts2'][i])

        obs1 = next_obs1
        obs2 = next_obs2
        stage_steps += NUM_ENVS
        total_steps += NUM_ENVS

    # PPO update
    # The bootstrap value is only a GAE target, so no autograd graph is needed.
    with torch.no_grad():
        next_value1 = critic1(obs1).squeeze(-1)
    states1, actions_b1, rewards1, dones1, log_probs_b1, values_b1 = buffer1.get()

    steps_per_env = len(buffer1.states)
    rewards1 = rewards1.view(steps_per_env, NUM_ENVS)
    values_v1 = values_b1.view(steps_per_env, NUM_ENVS)
    dones1 = dones1.view(steps_per_env, NUM_ENVS)

    advantages1, returns1 = compute_gae(rewards1, values_v1, dones1, next_value1)
    advantages1 = advantages1.view(-1)
    returns1 = returns1.view(-1)

    actor_loss, critic_loss, entropy = ppo_update(
        actor1, critic1, actor1_optimizer, critic1_optimizer,
        states1, actions_b1, log_probs_b1, advantages1, returns1
    )
    all_losses.append(actor_loss + critic_loss)

    # Agent 1 win rate over (at most) the last 100 rounds played in this stage.
    recent_rounds = round_winners[max(stage_round_start, len(round_winners) - 100):]
    win_rate = sum(1 for w in recent_rounds if w == 1) / len(recent_rounds) if recent_rounds else 0.0

    # total_steps advances by one rollout per iteration, so the modulo is
    # compared against the rollout size to log once per LOG_INTERVAL * NUM_ENVS.
    if total_steps % (LOG_INTERVAL * NUM_ENVS) < rollout_len * NUM_ENVS:
        all_win_rates.append(win_rate)
        avg_score1 = np.mean(all_scores1[-100:]) if all_scores1 else 0
        print(f"[Step {total_steps:>6}] Win Rate: {win_rate:.2%} | Score: {avg_score1:.1f} | Loss: {actor_loss + critic_loss:.4f}")

    # Check stage completion
    if stage_steps >= STAGE1_MIN_STEPS and win_rate >= STAGE1_WIN_THRESHOLD:
        break
    if stage_steps >= STAGE1_MIN_STEPS * 3:  # Safety limit
        print("Warning: Stage 1 reached max steps without meeting threshold")
        break

stage1_time = time.time() - stage1_start
# No rounds are recorded after the last loop iteration, so its win rate is final.
stage1_win_rate = win_rate
stage_results.append({
    'stage': 1,
    'name': 'Random',
    'time': stage1_time,
    'steps': stage_steps,
    'win_rate': stage1_win_rate
})

print(f"\n{'='*70}")
print(f"STAGE 1 COMPLETE")
print(f"Time: {stage1_time:.1f}s | Steps: {stage_steps:,} | Win Rate: {stage1_win_rate:.2%}")
print("="*70)

In [None]:
# Cell 7: Stage 2 - Greedy Opponent
# Agent 1 keeps training with PPO while Agent 2's actions come from the
# scripted 'greedy' opponent. The stage ends once at least STAGE2_MIN_STEPS
# transitions were collected and the stage win rate reaches
# STAGE2_WIN_THRESHOLD, or unconditionally at 3x STAGE2_MIN_STEPS.
print("="*70)
print("STAGE 2: Greedy Opponent")
print("="*70)
print(f"Target food: {STAGE2_TARGET_FOOD}")
print(f"Min steps: {STAGE2_MIN_STEPS}")
print(f"Win rate threshold: {STAGE2_WIN_THRESHOLD}")
print("="*70 + "\n")

stage2_start = time.time()
stage_steps = 0
# Rounds recorded before this index were played against an earlier opponent
# and must not count towards this stage's win rate.
stage_round_start = len(round_winners)
# Vectorized env steps per rollout; each one yields NUM_ENVS transitions.
rollout_len = max(1, ROLLOUT_STEPS // NUM_ENVS)
env.set_target_food(STAGE2_TARGET_FOOD)
obs1, obs2 = env.reset()

while True:
    buffer1.clear()

    # Collect rollout
    for _ in range(rollout_len):
        actions1, log_probs1, values1 = select_actions(actor1, critic1, obs1)
        actions2 = scripted_agents['greedy'].select_action(env)

        next_obs1, next_obs2, r1, r2, dones, info = env.step(actions1, actions2)

        buffer1.add(obs1, actions1, r1, dones, log_probs1, values1)

        # Record the outcome of every round that finished on this step.
        if dones.any():
            num_done = len(info['done_envs'])
            for i in range(num_done):
                round_winners.append(int(info['winners'][i]))
                total_rounds += 1
                all_scores1.append(info['food_counts1'][i])
                all_scores2.append(info['food_counts2'][i])

        obs1 = next_obs1
        obs2 = next_obs2
        stage_steps += NUM_ENVS
        total_steps += NUM_ENVS

    # PPO update
    # The bootstrap value is only a GAE target, so no autograd graph is needed.
    with torch.no_grad():
        next_value1 = critic1(obs1).squeeze(-1)
    states1, actions_b1, rewards1, dones1, log_probs_b1, values_b1 = buffer1.get()

    steps_per_env = len(buffer1.states)
    rewards1 = rewards1.view(steps_per_env, NUM_ENVS)
    values_v1 = values_b1.view(steps_per_env, NUM_ENVS)
    dones1 = dones1.view(steps_per_env, NUM_ENVS)

    advantages1, returns1 = compute_gae(rewards1, values_v1, dones1, next_value1)
    advantages1 = advantages1.view(-1)
    returns1 = returns1.view(-1)

    actor_loss, critic_loss, entropy = ppo_update(
        actor1, critic1, actor1_optimizer, critic1_optimizer,
        states1, actions_b1, log_probs_b1, advantages1, returns1
    )
    all_losses.append(actor_loss + critic_loss)

    # Agent 1 win rate over (at most) the last 100 rounds played in this stage.
    recent_rounds = round_winners[max(stage_round_start, len(round_winners) - 100):]
    win_rate = sum(1 for w in recent_rounds if w == 1) / len(recent_rounds) if recent_rounds else 0.0

    # total_steps advances by one rollout per iteration, so the modulo is
    # compared against the rollout size to log once per LOG_INTERVAL * NUM_ENVS.
    if total_steps % (LOG_INTERVAL * NUM_ENVS) < rollout_len * NUM_ENVS:
        all_win_rates.append(win_rate)
        avg_score1 = np.mean(all_scores1[-100:]) if all_scores1 else 0
        print(f"[Step {total_steps:>6}] Win Rate: {win_rate:.2%} | Score: {avg_score1:.1f} | Loss: {actor_loss + critic_loss:.4f}")

    # Check stage completion
    if stage_steps >= STAGE2_MIN_STEPS and win_rate >= STAGE2_WIN_THRESHOLD:
        break
    if stage_steps >= STAGE2_MIN_STEPS * 3:  # Safety limit
        print("Warning: Stage 2 reached max steps without meeting threshold")
        break

stage2_time = time.time() - stage2_start
# No rounds are recorded after the last loop iteration, so its win rate is final.
stage2_win_rate = win_rate
stage_results.append({
    'stage': 2,
    'name': 'Greedy',
    'time': stage2_time,
    'steps': stage_steps,
    'win_rate': stage2_win_rate
})

print(f"\n{'='*70}")
print(f"STAGE 2 COMPLETE")
print(f"Time: {stage2_time:.1f}s | Steps: {stage_steps:,} | Win Rate: {stage2_win_rate:.2%}")
print("="*70)

In [None]:
# Cell 8: Stage 3 - Frozen Policy Opponent
# Agent 1 keeps training with PPO while Agent 2 acts greedily from its own
# actor, which is not updated in this cell. The stage ends once at least
# STAGE3_MIN_STEPS transitions were collected and the stage win rate reaches
# STAGE3_WIN_THRESHOLD, or unconditionally at 3x STAGE3_MIN_STEPS.
# NOTE(review): in the visible notebook actor2's optimizer is first stepped in
# Stage 4, so the "frozen" opponent here appears to be an untrained network --
# confirm that is intended.
print("="*70)
print("STAGE 3: Frozen Policy Opponent")
print("="*70)
print(f"Target food: {STAGE3_TARGET_FOOD}")
print(f"Min steps: {STAGE3_MIN_STEPS}")
print(f"Win rate threshold: {STAGE3_WIN_THRESHOLD}")
print("="*70 + "\n")

stage3_start = time.time()
stage_steps = 0
# Rounds recorded before this index were played against an earlier opponent
# and must not count towards this stage's win rate.
stage_round_start = len(round_winners)
# Vectorized env steps per rollout; each one yields NUM_ENVS transitions.
rollout_len = max(1, ROLLOUT_STEPS // NUM_ENVS)
env.set_target_food(STAGE3_TARGET_FOOD)
obs1, obs2 = env.reset()

while True:
    buffer1.clear()

    # Collect rollout
    for _ in range(rollout_len):
        actions1, log_probs1, values1 = select_actions(actor1, critic1, obs1)
        # Frozen policy: use agent2's actor with greedy selection
        actions2 = select_greedy_actions(actor2, obs2)

        next_obs1, next_obs2, r1, r2, dones, info = env.step(actions1, actions2)

        buffer1.add(obs1, actions1, r1, dones, log_probs1, values1)

        # Record the outcome of every round that finished on this step.
        if dones.any():
            num_done = len(info['done_envs'])
            for i in range(num_done):
                round_winners.append(int(info['winners'][i]))
                total_rounds += 1
                all_scores1.append(info['food_counts1'][i])
                all_scores2.append(info['food_counts2'][i])

        obs1 = next_obs1
        obs2 = next_obs2
        stage_steps += NUM_ENVS
        total_steps += NUM_ENVS

    # PPO update (only agent 1)
    # The bootstrap value is only a GAE target, so no autograd graph is needed.
    with torch.no_grad():
        next_value1 = critic1(obs1).squeeze(-1)
    states1, actions_b1, rewards1, dones1, log_probs_b1, values_b1 = buffer1.get()

    steps_per_env = len(buffer1.states)
    rewards1 = rewards1.view(steps_per_env, NUM_ENVS)
    values_v1 = values_b1.view(steps_per_env, NUM_ENVS)
    dones1 = dones1.view(steps_per_env, NUM_ENVS)

    advantages1, returns1 = compute_gae(rewards1, values_v1, dones1, next_value1)
    advantages1 = advantages1.view(-1)
    returns1 = returns1.view(-1)

    actor_loss, critic_loss, entropy = ppo_update(
        actor1, critic1, actor1_optimizer, critic1_optimizer,
        states1, actions_b1, log_probs_b1, advantages1, returns1
    )
    all_losses.append(actor_loss + critic_loss)

    # Agent 1 win rate over (at most) the last 100 rounds played in this stage.
    recent_rounds = round_winners[max(stage_round_start, len(round_winners) - 100):]
    win_rate = sum(1 for w in recent_rounds if w == 1) / len(recent_rounds) if recent_rounds else 0.0

    # total_steps advances by one rollout per iteration, so the modulo is
    # compared against the rollout size to log once per LOG_INTERVAL * NUM_ENVS.
    if total_steps % (LOG_INTERVAL * NUM_ENVS) < rollout_len * NUM_ENVS:
        all_win_rates.append(win_rate)
        avg_score1 = np.mean(all_scores1[-100:]) if all_scores1 else 0
        print(f"[Step {total_steps:>6}] Win Rate: {win_rate:.2%} | Score: {avg_score1:.1f} | Loss: {actor_loss + critic_loss:.4f}")

    # Check stage completion
    if stage_steps >= STAGE3_MIN_STEPS and win_rate >= STAGE3_WIN_THRESHOLD:
        break
    if stage_steps >= STAGE3_MIN_STEPS * 3:  # Safety limit
        print("Warning: Stage 3 reached max steps without meeting threshold")
        break

stage3_time = time.time() - stage3_start
# No rounds are recorded after the last loop iteration, so its win rate is final.
stage3_win_rate = win_rate
stage_results.append({
    'stage': 3,
    'name': 'Frozen',
    'time': stage3_time,
    'steps': stage_steps,
    'win_rate': stage3_win_rate
})

print(f"\n{'='*70}")
print(f"STAGE 3 COMPLETE")
print(f"Time: {stage3_time:.1f}s | Steps: {stage_steps:,} | Win Rate: {stage3_win_rate:.2%}")
print("="*70)

In [None]:
# Cell 9: Stage 4 - Co-Evolution (Both Learning)
# Both agents sample from their own policies and are both updated with PPO
# after every rollout. There is no win-rate threshold: the stage runs until
# at least STAGE4_MIN_STEPS transitions have been collected.
print("="*70)
print("STAGE 4: Co-Evolution (Both Learning)")
print("="*70)
print(f"Target food: {STAGE4_TARGET_FOOD}")
print(f"Min steps: {STAGE4_MIN_STEPS}")
print("Win rate threshold: None (train for full duration)")
print("="*70 + "\n")

stage4_start = time.time()
stage_steps = 0
# Rounds recorded before this index were played in earlier stages and must
# not count towards this stage's win rate.
stage_round_start = len(round_winners)
# Vectorized env steps per rollout; each one yields NUM_ENVS transitions.
rollout_len = max(1, ROLLOUT_STEPS // NUM_ENVS)
env.set_target_food(STAGE4_TARGET_FOOD)
obs1, obs2 = env.reset()

while stage_steps < STAGE4_MIN_STEPS:
    buffer1.clear()
    buffer2.clear()

    for _ in range(rollout_len):
        # Both agents learning
        actions1, log_probs1, values1 = select_actions(actor1, critic1, obs1)
        actions2, log_probs2, values2 = select_actions(actor2, critic2, obs2)

        next_obs1, next_obs2, r1, r2, dones, info = env.step(actions1, actions2)

        buffer1.add(obs1, actions1, r1, dones, log_probs1, values1)
        buffer2.add(obs2, actions2, r2, dones, log_probs2, values2)

        # Record the outcome of every round that finished on this step.
        if dones.any():
            num_done = len(info['done_envs'])
            for i in range(num_done):
                round_winners.append(int(info['winners'][i]))
                total_rounds += 1
                all_scores1.append(info['food_counts1'][i])
                all_scores2.append(info['food_counts2'][i])

        obs1 = next_obs1
        obs2 = next_obs2
        stage_steps += NUM_ENVS
        total_steps += NUM_ENVS

        # Stop collecting as soon as the step budget is reached.
        if stage_steps >= STAGE4_MIN_STEPS:
            break

    # PPO update for both agents
    # The final (possibly shorter) rollout is trained on as well instead of
    # being thrown away; the while condition then ends the stage.
    steps_per_env = len(buffer1.states)

    # Agent 1 update
    # The bootstrap value is only a GAE target, so no autograd graph is needed.
    with torch.no_grad():
        next_value1 = critic1(obs1).squeeze(-1)
    states1, actions_b1, rewards1, dones1, log_probs_b1, values_b1 = buffer1.get()
    rewards1 = rewards1.view(steps_per_env, NUM_ENVS)
    values_v1 = values_b1.view(steps_per_env, NUM_ENVS)
    dones1 = dones1.view(steps_per_env, NUM_ENVS)
    advantages1, returns1 = compute_gae(rewards1, values_v1, dones1, next_value1)
    advantages1 = advantages1.view(-1)
    returns1 = returns1.view(-1)
    actor_loss1, critic_loss1, _ = ppo_update(
        actor1, critic1, actor1_optimizer, critic1_optimizer,
        states1, actions_b1, log_probs_b1, advantages1, returns1
    )

    # Agent 2 update
    with torch.no_grad():
        next_value2 = critic2(obs2).squeeze(-1)
    states2, actions_b2, rewards2, dones2, log_probs_b2, values_b2 = buffer2.get()
    rewards2 = rewards2.view(steps_per_env, NUM_ENVS)
    values_v2 = values_b2.view(steps_per_env, NUM_ENVS)
    dones2 = dones2.view(steps_per_env, NUM_ENVS)
    advantages2, returns2 = compute_gae(rewards2, values_v2, dones2, next_value2)
    advantages2 = advantages2.view(-1)
    returns2 = returns2.view(-1)
    actor_loss2, critic_loss2, _ = ppo_update(
        actor2, critic2, actor2_optimizer, critic2_optimizer,
        states2, actions_b2, log_probs_b2, advantages2, returns2
    )

    # Only Agent 1's loss is tracked in the shared loss history.
    all_losses.append(actor_loss1 + critic_loss1)

    # total_steps advances by (at most) one rollout per iteration, so the modulo
    # is compared against the rollout size to log once per LOG_INTERVAL * NUM_ENVS.
    if total_steps % (LOG_INTERVAL * NUM_ENVS) < rollout_len * NUM_ENVS:
        # Agent 1 win rate over (at most) the last 100 rounds of this stage.
        recent_rounds = round_winners[max(stage_round_start, len(round_winners) - 100):]
        win_rate = sum(1 for w in recent_rounds if w == 1) / len(recent_rounds) if recent_rounds else 0.0
        all_win_rates.append(win_rate)
        avg_score1 = np.mean(all_scores1[-100:]) if all_scores1 else 0
        avg_score2 = np.mean(all_scores2[-100:]) if all_scores2 else 0
        print(f"[Step {total_steps:>6}] Win Rate: {win_rate:.2%} | Scores: {avg_score1:.1f} vs {avg_score2:.1f}")

stage4_time = time.time() - stage4_start
# Recomputed here because the loop only evaluates the win rate on log steps.
recent_rounds = round_winners[max(stage_round_start, len(round_winners) - 100):]
stage4_win_rate = sum(1 for w in recent_rounds if w == 1) / len(recent_rounds) if recent_rounds else 0.0
stage_results.append({
    'stage': 4,
    'name': 'Co-Evolution',
    'time': stage4_time,
    'steps': stage_steps,
    'win_rate': stage4_win_rate
})

total_training_time = sum(r['time'] for r in stage_results)

print(f"\n{'='*70}")
print(f"STAGE 4 COMPLETE")
print(f"Time: {stage4_time:.1f}s | Steps: {stage_steps:,} | Win Rate: {stage4_win_rate:.2%}")
print("="*70)
print(f"\nTOTAL CURRICULUM TRAINING COMPLETE!")
print(f"Total time: {total_training_time:.1f}s ({total_training_time/60:.1f} min)")
print(f"Total steps: {total_steps:,}")

In [None]:
# Cell 10: Save Weights
# Writes one checkpoint per agent (network and optimizer state dicts plus a
# small config record) into save_dir, tagged with the current timestamp.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

# The size tag in each filename is derived from the configured hidden dims so
# it stays truthful when the network sizes are overridden; with the default
# configuration it yields the same "256x256" / "128x128" names as before.
big_tag = "x".join(str(dim) for dim in BIG_HIDDEN_DIMS)
small_tag = "x".join(str(dim) for dim in SMALL_HIDDEN_DIMS)

# Save Agent 1 (Big)
agent1_path = save_dir / f"big_{big_tag}_curriculum_{timestamp}.pt"
torch.save({
    'actor': actor1.state_dict(),
    'critic': critic1.state_dict(),
    'actor_optimizer': actor1_optimizer.state_dict(),
    'critic_optimizer': critic1_optimizer.state_dict(),
    'total_steps': total_steps,
    'config': {
        'hidden_dims': BIG_HIDDEN_DIMS,
        'grid_size': GRID_SIZE,
        'actor_lr': ACTOR_LR,
        'critic_lr': CRITIC_LR
    }
}, agent1_path)
print(f"Agent 1 (Big) saved to: {agent1_path}")

# Save Agent 2 (Small)
agent2_path = save_dir / f"small_{small_tag}_curriculum_{timestamp}.pt"
torch.save({
    'actor': actor2.state_dict(),
    'critic': critic2.state_dict(),
    'actor_optimizer': actor2_optimizer.state_dict(),
    'critic_optimizer': critic2_optimizer.state_dict(),
    'total_steps': total_steps,
    'config': {
        'hidden_dims': SMALL_HIDDEN_DIMS,
        'grid_size': GRID_SIZE,
        'actor_lr': ACTOR_LR,
        'critic_lr': CRITIC_LR
    }
}, agent2_path)
print(f"Agent 2 (Small) saved to: {agent2_path}")

In [None]:
# Cell 11: Visualization
# 2x2 grid of axes, filled below with scores, win rate, loss and per-stage win rate.
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

def smooth(data, window=50):
    """Return the moving average of ``data`` over ``window`` samples.

    Inputs shorter than ``window`` are handed back unchanged; otherwise the
    result only contains fully covered positions, i.e. it has
    ``len(data) - window + 1`` entries.
    """
    if len(data) >= window:
        kernel = np.ones(window) / window
        return np.convolve(data, kernel, mode='valid')
    return data

# Plot 1: Agent Scores
# Raw per-round food counts for both agents, overlaid with a moving average
# once more than 50 rounds are available.
ax1 = axes[0, 0]
score_series = (
    (all_scores1, 'Agent 1 (Big)', 'blue'),
    (all_scores2, 'Agent 2 (Small)', 'red'),
)
for series, series_label, series_color in score_series:
    ax1.plot(series, alpha=0.3, label=series_label, color=series_color)
if len(all_scores1) > 50:
    # The smoothed curve starts at index 49, the first fully covered window.
    for series, _, series_color in score_series:
        ax1.plot(range(49, len(series)), smooth(series), color=series_color, linewidth=2)
ax1.set(xlabel='Round', ylabel='Score (Food Eaten)', title='Agent Scores')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Plot 2: Win Rate Over Time
# One point per logged win rate, with a dashed reference line at 50%.
ax2 = axes[0, 1]
ax2.plot(all_win_rates, color='green')
ax2.axhline(y=0.5, color='gray', linestyle='--', label='50% (balanced)')
ax2.set(xlabel='Training Progress', ylabel='Agent 1 Win Rate', title='Win Rate Over Time')
ax2.legend()
ax2.grid(True, alpha=0.3)

# Plot 3: Training Loss
# One point per PPO update, plus a moving average when there are enough updates.
ax3 = axes[1, 0]
ax3.plot(all_losses, alpha=0.5)
if len(all_losses) > 50:
    smoothed_losses = smooth(all_losses)
    ax3.plot(range(49, len(all_losses)), smoothed_losses, color='blue', linewidth=2)
ax3.set(xlabel='Update', ylabel='Loss', title='PPO Loss')
ax3.grid(True, alpha=0.3)

# Plot 4: Stage Results
# One bar per completed stage showing its recorded win rate in percent.
ax4 = axes[1, 1]
stage_names = []
stage_win_rates = []
for result in stage_results:
    stage_names.append(result['name'])
    stage_win_rates.append(result['win_rate'] * 100)
colors = ['#2ecc71', '#3498db', '#e74c3c', '#9b59b6', '#f39c12']
bars = ax4.bar(stage_names, stage_win_rates, color=colors[:len(stage_results)])
ax4.set(xlabel='Stage', ylabel='Final Win Rate (%)', title='Win Rate by Curriculum Stage')
ax4.set_ylim(0, 100)
# Write each bar's percentage just above its top edge.
for bar, rate in zip(bars, stage_win_rates):
    label_x = bar.get_x() + bar.get_width()/2
    label_y = bar.get_height() + 1
    ax4.text(label_x, label_y, f'{rate:.1f}%', ha='center', va='bottom', fontsize=10)
ax4.grid(True, alpha=0.3, axis='y')

plt.suptitle('PPO Two-Snake Curriculum Training Results', fontsize=14)
plt.tight_layout()

# Save the figure under the project's results/figures folder; the file name
# reuses the timestamp created by the weight-saving cell.
fig_dir = project_root / 'results' / 'figures'
fig_dir.mkdir(parents=True, exist_ok=True)
fig_path = fig_dir / f'ppo_two_snake_curriculum_{timestamp}.png'
plt.savefig(fig_path, dpi=150, bbox_inches='tight')
print(f"Figure saved to: {fig_path}")

plt.show()

In [None]:
# Cell 12: Results Summary
# Prints overall totals, one table row per completed stage, and the averages
# over the last 100 recorded rounds.
heavy_rule = "=" * 70
light_rule = "-" * 70
total_minutes = total_training_time / 60

print(heavy_rule)
print("CURRICULUM TRAINING SUMMARY")
print(heavy_rule)
print("Algorithm: PPO Two-Snake Curriculum (MLP)")
print(f"Grid Size: {GRID_SIZE}x{GRID_SIZE}")
print(f"Total Steps: {total_steps:,}")
print(f"Total Rounds: {total_rounds:,}")
print(f"Total Training Time: {total_training_time:.1f}s ({total_minutes:.1f} min)")
print()
print("Stage Results:")
print(light_rule)
print(f"{'Stage':<15} {'Opponent':<15} {'Steps':>10} {'Time':>10} {'Win Rate':>12}")
print(light_rule)
for r in stage_results:
    stage_id, opponent = r['stage'], r['name']
    steps_taken, seconds, rate = r['steps'], r['time'], r['win_rate']
    print(f"Stage {stage_id:<8} {opponent:<15} {steps_taken:>10,} {seconds:>9.1f}s {rate:>11.2%}")
print(light_rule)
print()
print("Final Performance (last 100 rounds):")
recent_scores1 = all_scores1[-100:]
recent_scores2 = all_scores2[-100:]
print(f"  Agent 1 Avg Score: {np.mean(recent_scores1):.2f}")
print(f"  Agent 2 Avg Score: {np.mean(recent_scores2):.2f}")
print(f"  Final Win Rate: {calculate_win_rate():.2%}")
print(heavy_rule)