# DQN Two-Snake Curriculum Training Notebook

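This notebook trains two DQN agents with curriculum learning for competitive two-snake play. Agent 1 uses a larger MLP (256x256 by default) and Agent 2 a smaller MLP (128x128 by default). Agent 1 learns in every stage; Agent 2 only learns in Stage 4.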

**Algorithm**: DQN with 5-stage curriculum:
1. Stage 0: vs StaticAgent (learn basic movement, 70% win threshold)
2. Stage 1: vs RandomAgent (handle unpredictability, 60% win threshold)
3. Stage 2: vs GreedyFoodAgent (compete for food, 55% win threshold)
4. Stage 3: vs Agent 2's small network acting greedily with frozen weights (50% win threshold; Agent 2 has not been trained before this stage)
5. Stage 4: Co-evolution (both learning, no threshold)

In [None]:
# Cell 1: Imports
import sys
from pathlib import Path

# Make the project's packages (core/, scripts/) importable by the imports below.
# NOTE(review): assumes the notebook runs two directory levels below the project
# root (cwd -> parent -> parent == root) — confirm if the notebook is moved.
project_root = Path.cwd().parent.parent
sys.path.insert(0, str(project_root))

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime
import time

from core.environment_two_snake_vectorized import VectorizedTwoSnakeEnv
from core.networks import DQN_MLP
from core.utils import ReplayBuffer, EpsilonScheduler, set_seed, get_device
from scripts.baselines.scripted_opponents import get_scripted_agent

# Report where the project was resolved and which PyTorch / CUDA setup is active.
for _label, _value in (
    ("Project root", project_root),
    ("PyTorch version", torch.__version__),
    ("CUDA available", torch.cuda.is_available()),
):
    print(f"{_label}: {_value}")

In [None]:
# Cell 2: Configuration (papermill parameters)
# ============== CONFIGURATION ==============
# Papermill may override any of these names; later cells read them as globals.
# Step counts are environment transitions: one vectorized iteration adds NUM_ENVS.

# --- Environment ---
GRID_SIZE = 20       # board side length (GRID_SIZE x GRID_SIZE)
NUM_ENVS = 128       # environments stepped in parallel
MAX_STEPS = 1_000    # per-environment step cap handed to the env

# --- Network sizes (hidden-layer widths) ---
BIG_HIDDEN_DIMS = (256, 256)    # Agent 1
SMALL_HIDDEN_DIMS = (128, 128)  # Agent 2

# --- DQN hyperparameters ---
LEARNING_RATE = 1e-3
GAMMA = 0.99
BUFFER_SIZE = 50_000
BATCH_SIZE = 64
TARGET_UPDATE_FREQ = 1_000  # target-sync interval, compared against total_steps
TRAIN_STEPS_RATIO = 1 / 8   # stage loops train when stage_steps % int(1 / ratio) == 0

# --- Exploration (linear decay per stage) ---
EPSILON_START = 0.8
EPSILON_END = 0.01

# --- Minimum transitions per curriculum stage ---
STAGE0_MIN_STEPS = 5_000   # Static opponent
STAGE1_MIN_STEPS = 7_500   # Random opponent
STAGE2_MIN_STEPS = 10_000  # Greedy opponent
STAGE3_MIN_STEPS = 10_000  # Frozen policy
STAGE4_MIN_STEPS = 20_000  # Co-evolution

# --- Win-rate thresholds for advancing past stages 0-3 ---
STAGE0_WIN_THRESHOLD = 0.70
STAGE1_WIN_THRESHOLD = 0.60
STAGE2_WIN_THRESHOLD = 0.55
STAGE3_WIN_THRESHOLD = 0.50

# --- Food needed to finish a round, per stage ---
STAGE0_TARGET_FOOD = 10
STAGE1_TARGET_FOOD = 10
STAGE2_TARGET_FOOD = 4
STAGE3_TARGET_FOOD = 6
STAGE4_TARGET_FOOD = 8

# --- Output ---
SAVE_DIR = '../../results/weights/dqn_two_snake_curriculum'
SEED = 42
LOG_INTERVAL = 100

In [None]:
# Cell 3: Environment & Agent Setup
set_seed(SEED)
device = get_device()
print(f"Using device: {device}")

save_dir = Path(SAVE_DIR)
save_dir.mkdir(parents=True, exist_ok=True)

# A single vectorized env is shared by all stages; stage cells only change target_food.
env = VectorizedTwoSnakeEnv(
    num_envs=NUM_ENVS,
    grid_size=GRID_SIZE,
    max_steps=MAX_STEPS,
    target_food=STAGE0_TARGET_FOOD,
    device=device,
)

# Observation width is read from a real reset rather than hard-coded.
obs1, obs2 = env.reset()
input_dim = obs1.shape[1]  # expected 35 for competitive features (not checked here)
output_dim = 3  # relative actions

print(f"Input dim: {input_dim}, Output dim: {output_dim}")


def _build_agent(hidden_dims):
    """Create ``(policy_net, target_net, optimizer, replay_buffer)`` for one agent.

    The target network starts as an exact copy of the policy network and is put
    in eval mode; it changes only via later explicit ``load_state_dict`` calls.
    """
    policy = DQN_MLP(input_dim, output_dim, hidden_dims).to(device)
    target = DQN_MLP(input_dim, output_dim, hidden_dims).to(device)
    target.load_state_dict(policy.state_dict())
    target.eval()
    opt = optim.Adam(policy.parameters(), lr=LEARNING_RATE)
    buf = ReplayBuffer(capacity=BUFFER_SIZE, seed=SEED)
    return policy, target, opt, buf


# Agent 1 (Big) is built before Agent 2 (Small) so the seeded init order is fixed.
policy_net1, target_net1, optimizer1, buffer1 = _build_agent(BIG_HIDDEN_DIMS)
policy_net2, target_net2, optimizer2, buffer2 = _build_agent(SMALL_HIDDEN_DIMS)

# Scripted opponents are loaded best-effort: a failure is only reported here and
# surfaces later as a missing key when the corresponding stage looks it up.
scripted_agents = {}
for opponent_name in ('static', 'random', 'greedy'):
    try:
        scripted_agents[opponent_name] = get_scripted_agent(opponent_name, device=device)
    except Exception as err:
        print(f"Warning: Could not load {opponent_name} agent: {err}")
    else:
        print(f"Loaded {opponent_name} opponent")

print(f"\nEnvironment: {GRID_SIZE}x{GRID_SIZE} grid, {NUM_ENVS} parallel envs")
print(f"Agent 1 (Big): {BIG_HIDDEN_DIMS}")
print(f"Agent 2 (Small): {SMALL_HIDDEN_DIMS}")

In [None]:
# Cell 4: DQN Helper Functions

def select_action(policy_net, states, epsilon):
    """Select one action per environment with an epsilon-greedy policy.

    Args:
        policy_net: Callable mapping a ``(num_envs, obs_dim)`` batch to
            ``(num_envs, num_actions)`` Q-values.
        states: Batch of observations, one row per environment.
        epsilon: Probability of replacing the greedy action with a uniformly
            random one (applied independently per row).

    Returns:
        Tensor of ``num_envs`` action indices on the same device as the Q-values.
    """
    with torch.no_grad():
        q_values = policy_net(states)
    greedy_actions = q_values.argmax(dim=1)

    # Size the random draws from the Q-head itself (instead of a hard-coded 3)
    # and keep them on the Q-values' device so torch.where never mixes devices.
    num_envs, num_actions = q_values.shape
    q_device = q_values.device
    random_mask = torch.rand(num_envs, device=q_device) < epsilon
    random_actions = torch.randint(0, num_actions, (num_envs,), device=q_device)
    return torch.where(random_mask, random_actions, greedy_actions)

def select_greedy_action(policy_net, states):
    """Return the argmax-Q action for every row of ``states`` (no exploration).

    Used for the frozen opponent; runs without building an autograd graph.
    """
    with torch.no_grad():
        best_actions = torch.argmax(policy_net(states), dim=1)
    return best_actions

def train_step(policy_net, target_net, optimizer, replay_buffer, *, batch_size=None, gamma=None):
    """Perform one DQN update (MSE TD loss against a target network).

    Samples a minibatch from ``replay_buffer`` and regresses Q(s, a) from
    ``policy_net`` toward ``r + gamma * max_a' Q_target(s', a') * (1 - done)``.
    The gradient norm is clipped to 10.0 before ``optimizer`` steps.

    Args:
        policy_net: Online Q-network being trained.
        target_net: Network used for the bootstrap target (not modified here).
        optimizer: Optimizer over ``policy_net``'s parameters.
        replay_buffer: Object exposing ``is_ready(n)`` and ``sample(n)``.
            ``sample`` returns ``(states, actions, rewards, next_states, dones)``
            tensors.
        batch_size: Minibatch size; defaults to the global ``BATCH_SIZE``.
        gamma: Discount factor; defaults to the global ``GAMMA``.

    Returns:
        The loss as a Python float, or ``None`` when the buffer cannot yet
        supply ``batch_size`` samples.
    """
    if batch_size is None:
        batch_size = BATCH_SIZE
    if gamma is None:
        gamma = GAMMA
    if not replay_buffer.is_ready(batch_size):
        return None

    # Put the batch on whatever device the policy network lives on.
    net_device = next(policy_net.parameters()).device
    states, actions, rewards, next_states, dones = (
        t.to(net_device) for t in replay_buffer.sample(batch_size)
    )

    # Q(s, a) for the actions actually taken; gather() requires int64 indices.
    q_values = policy_net(states).gather(1, actions.long().unsqueeze(1)).squeeze(1)

    # Bootstrap target. Rewards/dones are cast to the Q dtype: `1 - dones` raises
    # on bool tensors, and float64 rewards would promote the target to float64.
    with torch.no_grad():
        rewards = rewards.to(q_values.dtype)
        not_done = 1.0 - dones.to(q_values.dtype)
        next_q_values = target_net(next_states).max(dim=1)[0]
        target_q_values = rewards + gamma * next_q_values * not_done

    loss = nn.functional.mse_loss(q_values, target_q_values)

    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(policy_net.parameters(), 10.0)
    optimizer.step()

    return loss.item()

def store_transitions(buffer, states, actions, rewards, next_states, dones):
    """Push one transition per environment row into ``buffer``.

    Every input is copied to host memory as a NumPy array first; row ``i`` of
    each array forms the ``i``-th ``buffer.push(...)`` call, in row order.
    """
    host_states = states.cpu().numpy()
    host_actions = actions.cpu().numpy()
    host_rewards = rewards.cpu().numpy()
    host_next = next_states.cpu().numpy()
    host_dones = dones.cpu().numpy()

    for row, state in enumerate(host_states):
        buffer.push(state, host_actions[row], host_rewards[row], host_next[row], host_dones[row])

# ---- Global tracking shared by every stage cell ----
total_steps = 0    # environment transitions collected across all stages
total_rounds = 0   # completed rounds across all stages

round_winners = []   # int(info['winners'][i]) per completed round; calculate_win_rate counts 1 as an Agent 1 win
all_scores1 = []     # Agent 1 food count per completed round
all_scores2 = []     # Agent 2 food count per completed round
all_losses = []      # Agent 1 losses returned by train_step
all_win_rates = []   # calculate_win_rate() snapshot at each log line
stage_results = []   # one summary dict appended per finished stage

# Exploration rates are plain floats (EpsilonScheduler is not used); each stage
# cell re-initialises the ones it decays.
epsilon1 = EPSILON_START
epsilon2 = EPSILON_START

def calculate_win_rate(window=100):
    """Fraction of the most recent ``window`` rounds won by Agent 1.

    Reads the module-level ``round_winners`` list, where a winner value of 1
    counts as an Agent 1 win. If fewer than ``window`` rounds exist, all of them
    are used; with no rounds at all the result is 0.0.
    """
    span = min(window, len(round_winners))
    if span == 0:
        return 0.0
    return round_winners[-span:].count(1) / span

print("DQN helper functions defined.")

In [None]:
# Cell 5: Stage 0 - Static Opponent
# Agent 1 trains against the scripted 'static' opponent; Agent 2 is not trained here.
print("="*70)
print("STAGE 0: Static Opponent")
print("="*70)
print(f"Target food: {STAGE0_TARGET_FOOD}")
print(f"Min steps: {STAGE0_MIN_STEPS}")
print(f"Win rate threshold: {STAGE0_WIN_THRESHOLD}")
print("="*70 + "\n")

stage0_start = time.time()
stage_steps = 0  # env transitions collected in this stage (NUM_ENVS per iteration)
env.target_food = STAGE0_TARGET_FOOD
obs1, obs2 = env.reset()

# Reset epsilon for this stage: linear decay to EPSILON_END over STAGE0_MIN_STEPS transitions.
epsilon1 = EPSILON_START
epsilon_decay1 = (EPSILON_START - EPSILON_END) / STAGE0_MIN_STEPS

while True:
    # Agent 1 uses epsilon-greedy
    actions1 = select_action(policy_net1, obs1, epsilon1)
    # Agent 2 uses static scripted agent
    actions2 = scripted_agents['static'].select_action(env)

    next_obs1, next_obs2, r1, r2, dones, info = env.step(actions1, actions2)

    # Only Agent 1 learns in this stage, so only its transitions are stored.
    store_transitions(buffer1, obs1, actions1, r1, next_obs1, dones)

    # Track completed rounds.
    # NOTE(review): winners/food_counts are indexed 0..len(done_envs)-1, which assumes the
    # env reports them per finished env (aligned with info['done_envs']) — confirm.
    if dones.any():
        num_done = len(info['done_envs'])
        for i in range(num_done):
            round_winners.append(int(info['winners'][i]))
            total_rounds += 1
            all_scores1.append(info['food_counts1'][i])
            all_scores2.append(info['food_counts2'][i])

    obs1 = next_obs1
    obs2 = next_obs2
    prev_total_steps = total_steps
    stage_steps += NUM_ENVS
    total_steps += NUM_ENVS

    # Training. stage_steps grows by NUM_ENVS, so when NUM_ENVS is a multiple of
    # int(1 / TRAIN_STEPS_RATIO) (true for the defaults) this runs every iteration.
    if stage_steps % max(1, int(1 / TRAIN_STEPS_RATIO)) == 0:
        loss = train_step(policy_net1, target_net1, optimizer1, buffer1)
        if loss is not None:
            all_losses.append(loss)

    # Sync the target network whenever total_steps crosses a multiple of
    # TARGET_UPDATE_FREQ. An exact `total_steps % TARGET_UPDATE_FREQ == 0` test
    # only fires every lcm(NUM_ENVS, TARGET_UPDATE_FREQ) steps (16,000 with the
    # defaults), so the target network would never be refreshed in this stage.
    if total_steps // TARGET_UPDATE_FREQ > prev_total_steps // TARGET_UPDATE_FREQ:
        target_net1.load_state_dict(policy_net1.state_dict())

    # Decay epsilon (one iteration = NUM_ENVS transitions)
    epsilon1 = max(EPSILON_END, epsilon1 - epsilon_decay1 * NUM_ENVS)

    # Logging. NOTE: total_steps is always a multiple of NUM_ENVS, so this exact-modulo
    # test fires every lcm(NUM_ENVS, LOG_INTERVAL) steps (3,200 with the defaults).
    if total_steps % LOG_INTERVAL == 0:
        win_rate = calculate_win_rate()
        all_win_rates.append(win_rate)
        avg_score1 = np.mean(all_scores1[-100:]) if all_scores1 else 0
        avg_loss = np.mean(all_losses[-10:]) if all_losses else 0
        print(f"[Step {total_steps:>6}] Win Rate: {win_rate:.2%} | Score: {avg_score1:.1f} | Loss: {avg_loss:.4f} | Eps: {epsilon1:.3f}")

    # Check stage completion (the win rate is computed over the global round history)
    if stage_steps >= STAGE0_MIN_STEPS and calculate_win_rate() >= STAGE0_WIN_THRESHOLD:
        break
    if stage_steps >= STAGE0_MIN_STEPS * 3:  # Safety limit
        print("Warning: Stage 0 reached max steps without meeting threshold")
        break

stage0_time = time.time() - stage0_start
stage0_win_rate = calculate_win_rate()
stage_results.append({
    'stage': 0,
    'name': 'Static',
    'time': stage0_time,
    'steps': stage_steps,
    'win_rate': stage0_win_rate
})

print(f"\n{'='*70}")
print("STAGE 0 COMPLETE")
print(f"Time: {stage0_time:.1f}s | Steps: {stage_steps:,} | Win Rate: {stage0_win_rate:.2%}")
print("="*70)

In [None]:
# Cell 6: Stage 1 - Random Opponent
# Agent 1 keeps training against the scripted 'random' opponent; Agent 2 is not trained here.
print("="*70)
print("STAGE 1: Random Opponent")
print("="*70)
print(f"Target food: {STAGE1_TARGET_FOOD}")
print(f"Min steps: {STAGE1_MIN_STEPS}")
print(f"Win rate threshold: {STAGE1_WIN_THRESHOLD}")
print("="*70 + "\n")

stage1_start = time.time()
stage_steps = 0  # env transitions collected in this stage (NUM_ENVS per iteration)
env.target_food = STAGE1_TARGET_FOOD
obs1, obs2 = env.reset()

# Reset epsilon for this stage: linear decay to EPSILON_END over STAGE1_MIN_STEPS transitions.
epsilon1 = EPSILON_START
epsilon_decay1 = (EPSILON_START - EPSILON_END) / STAGE1_MIN_STEPS

while True:
    actions1 = select_action(policy_net1, obs1, epsilon1)
    actions2 = scripted_agents['random'].select_action(env)

    next_obs1, next_obs2, r1, r2, dones, info = env.step(actions1, actions2)

    # Only Agent 1 learns in this stage, so only its transitions are stored.
    store_transitions(buffer1, obs1, actions1, r1, next_obs1, dones)

    # Track completed rounds.
    # NOTE(review): winners/food_counts are indexed 0..len(done_envs)-1, which assumes the
    # env reports them per finished env (aligned with info['done_envs']) — confirm.
    if dones.any():
        num_done = len(info['done_envs'])
        for i in range(num_done):
            round_winners.append(int(info['winners'][i]))
            total_rounds += 1
            all_scores1.append(info['food_counts1'][i])
            all_scores2.append(info['food_counts2'][i])

    obs1 = next_obs1
    obs2 = next_obs2
    prev_total_steps = total_steps
    stage_steps += NUM_ENVS
    total_steps += NUM_ENVS

    # Training. stage_steps grows by NUM_ENVS, so when NUM_ENVS is a multiple of
    # int(1 / TRAIN_STEPS_RATIO) (true for the defaults) this runs every iteration.
    if stage_steps % max(1, int(1 / TRAIN_STEPS_RATIO)) == 0:
        loss = train_step(policy_net1, target_net1, optimizer1, buffer1)
        if loss is not None:
            all_losses.append(loss)

    # Sync the target network whenever total_steps crosses a multiple of
    # TARGET_UPDATE_FREQ. An exact `total_steps % TARGET_UPDATE_FREQ == 0` test
    # only fires every lcm(NUM_ENVS, TARGET_UPDATE_FREQ) steps (16,000 with the
    # defaults), so the target network would be refreshed only rarely.
    if total_steps // TARGET_UPDATE_FREQ > prev_total_steps // TARGET_UPDATE_FREQ:
        target_net1.load_state_dict(policy_net1.state_dict())

    # Decay epsilon (one iteration = NUM_ENVS transitions)
    epsilon1 = max(EPSILON_END, epsilon1 - epsilon_decay1 * NUM_ENVS)

    # Logging. NOTE: total_steps is always a multiple of NUM_ENVS, so this exact-modulo
    # test fires every lcm(NUM_ENVS, LOG_INTERVAL) steps (3,200 with the defaults).
    if total_steps % LOG_INTERVAL == 0:
        win_rate = calculate_win_rate()
        all_win_rates.append(win_rate)
        avg_score1 = np.mean(all_scores1[-100:]) if all_scores1 else 0
        avg_loss = np.mean(all_losses[-10:]) if all_losses else 0
        print(f"[Step {total_steps:>6}] Win Rate: {win_rate:.2%} | Score: {avg_score1:.1f} | Loss: {avg_loss:.4f} | Eps: {epsilon1:.3f}")

    # Check stage completion (the win rate is computed over the global round history)
    if stage_steps >= STAGE1_MIN_STEPS and calculate_win_rate() >= STAGE1_WIN_THRESHOLD:
        break
    if stage_steps >= STAGE1_MIN_STEPS * 3:  # Safety limit
        print("Warning: Stage 1 reached max steps without meeting threshold")
        break

stage1_time = time.time() - stage1_start
stage1_win_rate = calculate_win_rate()
stage_results.append({
    'stage': 1,
    'name': 'Random',
    'time': stage1_time,
    'steps': stage_steps,
    'win_rate': stage1_win_rate
})

print(f"\n{'='*70}")
print("STAGE 1 COMPLETE")
print(f"Time: {stage1_time:.1f}s | Steps: {stage_steps:,} | Win Rate: {stage1_win_rate:.2%}")
print("="*70)

In [None]:
# Cell 7: Stage 2 - Greedy Opponent
# Agent 1 trains against the scripted 'greedy' food-seeking opponent; Agent 2 is not trained here.
print("="*70)
print("STAGE 2: Greedy Opponent")
print("="*70)
print(f"Target food: {STAGE2_TARGET_FOOD}")
print(f"Min steps: {STAGE2_MIN_STEPS}")
print(f"Win rate threshold: {STAGE2_WIN_THRESHOLD}")
print("="*70 + "\n")

stage2_start = time.time()
stage_steps = 0  # env transitions collected in this stage (NUM_ENVS per iteration)
env.target_food = STAGE2_TARGET_FOOD
obs1, obs2 = env.reset()

# Reset epsilon for this stage: linear decay to EPSILON_END over STAGE2_MIN_STEPS transitions.
epsilon1 = EPSILON_START
epsilon_decay1 = (EPSILON_START - EPSILON_END) / STAGE2_MIN_STEPS

while True:
    actions1 = select_action(policy_net1, obs1, epsilon1)
    actions2 = scripted_agents['greedy'].select_action(env)

    next_obs1, next_obs2, r1, r2, dones, info = env.step(actions1, actions2)

    # Only Agent 1 learns in this stage, so only its transitions are stored.
    store_transitions(buffer1, obs1, actions1, r1, next_obs1, dones)

    # Track completed rounds.
    # NOTE(review): winners/food_counts are indexed 0..len(done_envs)-1, which assumes the
    # env reports them per finished env (aligned with info['done_envs']) — confirm.
    if dones.any():
        num_done = len(info['done_envs'])
        for i in range(num_done):
            round_winners.append(int(info['winners'][i]))
            total_rounds += 1
            all_scores1.append(info['food_counts1'][i])
            all_scores2.append(info['food_counts2'][i])

    obs1 = next_obs1
    obs2 = next_obs2
    prev_total_steps = total_steps
    stage_steps += NUM_ENVS
    total_steps += NUM_ENVS

    # Training. stage_steps grows by NUM_ENVS, so when NUM_ENVS is a multiple of
    # int(1 / TRAIN_STEPS_RATIO) (true for the defaults) this runs every iteration.
    if stage_steps % max(1, int(1 / TRAIN_STEPS_RATIO)) == 0:
        loss = train_step(policy_net1, target_net1, optimizer1, buffer1)
        if loss is not None:
            all_losses.append(loss)

    # Sync the target network whenever total_steps crosses a multiple of
    # TARGET_UPDATE_FREQ. An exact `total_steps % TARGET_UPDATE_FREQ == 0` test
    # only fires every lcm(NUM_ENVS, TARGET_UPDATE_FREQ) steps (16,000 with the
    # defaults), so the target network would be refreshed only rarely.
    if total_steps // TARGET_UPDATE_FREQ > prev_total_steps // TARGET_UPDATE_FREQ:
        target_net1.load_state_dict(policy_net1.state_dict())

    # Decay epsilon (one iteration = NUM_ENVS transitions)
    epsilon1 = max(EPSILON_END, epsilon1 - epsilon_decay1 * NUM_ENVS)

    # Logging. NOTE: total_steps is always a multiple of NUM_ENVS, so this exact-modulo
    # test fires every lcm(NUM_ENVS, LOG_INTERVAL) steps (3,200 with the defaults).
    if total_steps % LOG_INTERVAL == 0:
        win_rate = calculate_win_rate()
        all_win_rates.append(win_rate)
        avg_score1 = np.mean(all_scores1[-100:]) if all_scores1 else 0
        avg_loss = np.mean(all_losses[-10:]) if all_losses else 0
        print(f"[Step {total_steps:>6}] Win Rate: {win_rate:.2%} | Score: {avg_score1:.1f} | Loss: {avg_loss:.4f} | Eps: {epsilon1:.3f}")

    # Check stage completion (the win rate is computed over the global round history)
    if stage_steps >= STAGE2_MIN_STEPS and calculate_win_rate() >= STAGE2_WIN_THRESHOLD:
        break
    if stage_steps >= STAGE2_MIN_STEPS * 3:  # Safety limit
        print("Warning: Stage 2 reached max steps without meeting threshold")
        break

stage2_time = time.time() - stage2_start
stage2_win_rate = calculate_win_rate()
stage_results.append({
    'stage': 2,
    'name': 'Greedy',
    'time': stage2_time,
    'steps': stage_steps,
    'win_rate': stage2_win_rate
})

print(f"\n{'='*70}")
print("STAGE 2 COMPLETE")
print(f"Time: {stage2_time:.1f}s | Steps: {stage_steps:,} | Win Rate: {stage2_win_rate:.2%}")
print("="*70)

In [None]:
# Cell 8: Stage 3 - Frozen Policy Opponent
# Agent 1 trains against Agent 2's network acting greedily with frozen weights.
# NOTE: no earlier cell trains policy_net2 (only Stage 4 does), so this opponent
# plays the small network's initial, untrained weights.
print("="*70)
print("STAGE 3: Frozen Policy Opponent")
print("="*70)
print(f"Target food: {STAGE3_TARGET_FOOD}")
print(f"Min steps: {STAGE3_MIN_STEPS}")
print(f"Win rate threshold: {STAGE3_WIN_THRESHOLD}")
print("="*70 + "\n")

stage3_start = time.time()
stage_steps = 0  # env transitions collected in this stage (NUM_ENVS per iteration)
env.target_food = STAGE3_TARGET_FOOD
obs1, obs2 = env.reset()

# Reset epsilon for this stage: linear decay to EPSILON_END over STAGE3_MIN_STEPS transitions.
epsilon1 = EPSILON_START
epsilon_decay1 = (EPSILON_START - EPSILON_END) / STAGE3_MIN_STEPS

while True:
    actions1 = select_action(policy_net1, obs1, epsilon1)
    # Frozen policy: use agent2's policy with greedy selection
    actions2 = select_greedy_action(policy_net2, obs2)

    next_obs1, next_obs2, r1, r2, dones, info = env.step(actions1, actions2)

    # Only Agent 1 learns in this stage, so only its transitions are stored.
    store_transitions(buffer1, obs1, actions1, r1, next_obs1, dones)

    # Track completed rounds.
    # NOTE(review): winners/food_counts are indexed 0..len(done_envs)-1, which assumes the
    # env reports them per finished env (aligned with info['done_envs']) — confirm.
    if dones.any():
        num_done = len(info['done_envs'])
        for i in range(num_done):
            round_winners.append(int(info['winners'][i]))
            total_rounds += 1
            all_scores1.append(info['food_counts1'][i])
            all_scores2.append(info['food_counts2'][i])

    obs1 = next_obs1
    obs2 = next_obs2
    prev_total_steps = total_steps
    stage_steps += NUM_ENVS
    total_steps += NUM_ENVS

    # Training. stage_steps grows by NUM_ENVS, so when NUM_ENVS is a multiple of
    # int(1 / TRAIN_STEPS_RATIO) (true for the defaults) this runs every iteration.
    if stage_steps % max(1, int(1 / TRAIN_STEPS_RATIO)) == 0:
        loss = train_step(policy_net1, target_net1, optimizer1, buffer1)
        if loss is not None:
            all_losses.append(loss)

    # Sync the target network whenever total_steps crosses a multiple of
    # TARGET_UPDATE_FREQ. An exact `total_steps % TARGET_UPDATE_FREQ == 0` test
    # only fires every lcm(NUM_ENVS, TARGET_UPDATE_FREQ) steps (16,000 with the
    # defaults), so the target network would be refreshed only rarely.
    if total_steps // TARGET_UPDATE_FREQ > prev_total_steps // TARGET_UPDATE_FREQ:
        target_net1.load_state_dict(policy_net1.state_dict())

    # Decay epsilon (one iteration = NUM_ENVS transitions)
    epsilon1 = max(EPSILON_END, epsilon1 - epsilon_decay1 * NUM_ENVS)

    # Logging. NOTE: total_steps is always a multiple of NUM_ENVS, so this exact-modulo
    # test fires every lcm(NUM_ENVS, LOG_INTERVAL) steps (3,200 with the defaults).
    if total_steps % LOG_INTERVAL == 0:
        win_rate = calculate_win_rate()
        all_win_rates.append(win_rate)
        avg_score1 = np.mean(all_scores1[-100:]) if all_scores1 else 0
        avg_loss = np.mean(all_losses[-10:]) if all_losses else 0
        print(f"[Step {total_steps:>6}] Win Rate: {win_rate:.2%} | Score: {avg_score1:.1f} | Loss: {avg_loss:.4f} | Eps: {epsilon1:.3f}")

    # Check stage completion (the win rate is computed over the global round history)
    if stage_steps >= STAGE3_MIN_STEPS and calculate_win_rate() >= STAGE3_WIN_THRESHOLD:
        break
    if stage_steps >= STAGE3_MIN_STEPS * 3:  # Safety limit
        print("Warning: Stage 3 reached max steps without meeting threshold")
        break

stage3_time = time.time() - stage3_start
stage3_win_rate = calculate_win_rate()
stage_results.append({
    'stage': 3,
    'name': 'Frozen',
    'time': stage3_time,
    'steps': stage_steps,
    'win_rate': stage3_win_rate
})

print(f"\n{'='*70}")
print("STAGE 3 COMPLETE")
print(f"Time: {stage3_time:.1f}s | Steps: {stage_steps:,} | Win Rate: {stage3_win_rate:.2%}")
print("="*70)

In [None]:
# Cell 9: Stage 4 - Co-Evolution (Both Learning)
# Both agents explore, store transitions and train; runs for exactly STAGE4_MIN_STEPS.
print("="*70)
print("STAGE 4: Co-Evolution (Both Learning)")
print("="*70)
print(f"Target food: {STAGE4_TARGET_FOOD}")
print(f"Min steps: {STAGE4_MIN_STEPS}")
print("Win rate threshold: None (train for full duration)")
print("="*70 + "\n")

stage4_start = time.time()
stage_steps = 0  # env transitions collected in this stage (NUM_ENVS per iteration)
env.target_food = STAGE4_TARGET_FOOD
obs1, obs2 = env.reset()

# Agent 2 only trains in this stage; keep its losses instead of discarding them.
all_losses2 = []

# Reset epsilon for both agents: linear decay over STAGE4_MIN_STEPS transitions.
epsilon1 = EPSILON_START
epsilon2 = EPSILON_START
epsilon_decay1 = (EPSILON_START - EPSILON_END) / STAGE4_MIN_STEPS
epsilon_decay2 = (EPSILON_START - EPSILON_END) / STAGE4_MIN_STEPS

while stage_steps < STAGE4_MIN_STEPS:
    # Both agents learning with epsilon-greedy
    actions1 = select_action(policy_net1, obs1, epsilon1)
    actions2 = select_action(policy_net2, obs2, epsilon2)

    next_obs1, next_obs2, r1, r2, dones, info = env.step(actions1, actions2)

    # Store transitions for both agents (each from its own observation/reward view)
    store_transitions(buffer1, obs1, actions1, r1, next_obs1, dones)
    store_transitions(buffer2, obs2, actions2, r2, next_obs2, dones)

    # Track completed rounds.
    # NOTE(review): winners/food_counts are indexed 0..len(done_envs)-1, which assumes the
    # env reports them per finished env (aligned with info['done_envs']) — confirm.
    if dones.any():
        num_done = len(info['done_envs'])
        for i in range(num_done):
            round_winners.append(int(info['winners'][i]))
            total_rounds += 1
            all_scores1.append(info['food_counts1'][i])
            all_scores2.append(info['food_counts2'][i])

    obs1 = next_obs1
    obs2 = next_obs2
    prev_total_steps = total_steps
    stage_steps += NUM_ENVS
    total_steps += NUM_ENVS

    # Train both agents. stage_steps grows by NUM_ENVS, so with the defaults this
    # runs every iteration.
    if stage_steps % max(1, int(1 / TRAIN_STEPS_RATIO)) == 0:
        loss1 = train_step(policy_net1, target_net1, optimizer1, buffer1)
        loss2 = train_step(policy_net2, target_net2, optimizer2, buffer2)
        if loss1 is not None:
            all_losses.append(loss1)
        if loss2 is not None:
            all_losses2.append(loss2)

    # Sync both target networks whenever total_steps crosses a multiple of
    # TARGET_UPDATE_FREQ. An exact `total_steps % TARGET_UPDATE_FREQ == 0` test
    # only fires every lcm(NUM_ENVS, TARGET_UPDATE_FREQ) steps (16,000 with the
    # defaults), so Agent 2's target network could go (almost) unsynced.
    if total_steps // TARGET_UPDATE_FREQ > prev_total_steps // TARGET_UPDATE_FREQ:
        target_net1.load_state_dict(policy_net1.state_dict())
        target_net2.load_state_dict(policy_net2.state_dict())

    # Decay epsilon for both (one iteration = NUM_ENVS transitions)
    epsilon1 = max(EPSILON_END, epsilon1 - epsilon_decay1 * NUM_ENVS)
    epsilon2 = max(EPSILON_END, epsilon2 - epsilon_decay2 * NUM_ENVS)

    # Logging. NOTE: total_steps is always a multiple of NUM_ENVS, so this exact-modulo
    # test fires every lcm(NUM_ENVS, LOG_INTERVAL) steps (3,200 with the defaults).
    if total_steps % LOG_INTERVAL == 0:
        win_rate = calculate_win_rate()
        all_win_rates.append(win_rate)
        avg_score1 = np.mean(all_scores1[-100:]) if all_scores1 else 0
        avg_score2 = np.mean(all_scores2[-100:]) if all_scores2 else 0
        avg_loss1 = np.mean(all_losses[-10:]) if all_losses else 0
        avg_loss2 = np.mean(all_losses2[-10:]) if all_losses2 else 0
        print(f"[Step {total_steps:>6}] Win Rate: {win_rate:.2%} | Scores: {avg_score1:.1f} vs {avg_score2:.1f} | Loss: {avg_loss1:.4f} / {avg_loss2:.4f} | Eps: {epsilon1:.3f}")

stage4_time = time.time() - stage4_start
stage4_win_rate = calculate_win_rate()
stage_results.append({
    'stage': 4,
    'name': 'Co-Evolution',
    'time': stage4_time,
    'steps': stage_steps,
    'win_rate': stage4_win_rate
})

total_training_time = sum(r['time'] for r in stage_results)

print(f"\n{'='*70}")
print("STAGE 4 COMPLETE")
print(f"Time: {stage4_time:.1f}s | Steps: {stage_steps:,} | Win Rate: {stage4_win_rate:.2%}")
print("="*70)
print("\nTOTAL CURRICULUM TRAINING COMPLETE!")
print(f"Total time: {total_training_time:.1f}s ({total_training_time/60:.1f} min)")
print(f"Total steps: {total_steps:,}")

In [None]:
# Cell 10: Save Weights
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")


def _dims_tag(hidden_dims):
    """Render hidden-layer widths as e.g. '256x256' for use in file names."""
    return "x".join(str(width) for width in hidden_dims)


def _save_agent(path, policy_net, target_net, optimizer, hidden_dims):
    """Write one agent's networks, optimizer state and rebuild config to ``path``."""
    torch.save({
        'policy_net': policy_net.state_dict(),
        'target_net': target_net.state_dict(),
        'optimizer': optimizer.state_dict(),
        'total_steps': total_steps,
        'config': {
            'hidden_dims': hidden_dims,
            'grid_size': GRID_SIZE,
            'learning_rate': LEARNING_RATE,
            # Needed to rebuild DQN_MLP(input_dim, output_dim, hidden_dims) at load time.
            'input_dim': input_dim,
            'output_dim': output_dim,
        },
    }, path)


# File names follow the configured widths, so papermill overrides of the hidden
# dims are reflected (the defaults give big_256x256 / small_128x128 as before).
agent1_path = save_dir / f"big_{_dims_tag(BIG_HIDDEN_DIMS)}_curriculum_{timestamp}.pt"
_save_agent(agent1_path, policy_net1, target_net1, optimizer1, BIG_HIDDEN_DIMS)
print(f"Agent 1 (Big) saved to: {agent1_path}")

agent2_path = save_dir / f"small_{_dims_tag(SMALL_HIDDEN_DIMS)}_curriculum_{timestamp}.pt"
_save_agent(agent2_path, policy_net2, target_net2, optimizer2, SMALL_HIDDEN_DIMS)
print(f"Agent 2 (Small) saved to: {agent2_path}")

In [None]:
# Cell 11: Visualization
def smooth(data, window=50):
    """Return the 'valid'-mode moving average of ``data``.

    Series shorter than ``window`` are returned unchanged.
    """
    if len(data) >= window:
        kernel = np.ones(window) / window
        return np.convolve(data, kernel, mode='valid')
    return data


fig, axes = plt.subplots(2, 2, figsize=(14, 10))
ax_scores, ax_win, ax_loss, ax_stage = axes.flat

# Panel 1 (top-left): per-round scores, raw series first, then the moving averages.
score_series = (
    (all_scores1, 'Agent 1 (Big)', 'blue'),
    (all_scores2, 'Agent 2 (Small)', 'red'),
)
for series, series_label, series_color in score_series:
    ax_scores.plot(series, alpha=0.3, label=series_label, color=series_color)
if len(all_scores1) > 50:
    # smooth() drops the first window-1 = 49 points, hence the x offset.
    for series, _, series_color in score_series:
        ax_scores.plot(range(49, len(series)), smooth(series), color=series_color, linewidth=2)
ax_scores.set_xlabel('Round')
ax_scores.set_ylabel('Score (Food Eaten)')
ax_scores.set_title('Agent Scores')
ax_scores.legend()
ax_scores.grid(True, alpha=0.3)

# Panel 2 (top-right): win-rate snapshot recorded at each log line.
ax_win.plot(all_win_rates, color='green')
ax_win.axhline(y=0.5, color='gray', linestyle='--', label='50% (balanced)')
ax_win.set_xlabel('Training Progress')
ax_win.set_ylabel('Agent 1 Win Rate')
ax_win.set_title('Win Rate Over Time')
ax_win.legend()
ax_win.grid(True, alpha=0.3)

# Panel 3 (bottom-left): Agent 1 DQN loss per update.
ax_loss.plot(all_losses, alpha=0.5)
if len(all_losses) > 50:
    ax_loss.plot(range(49, len(all_losses)), smooth(all_losses), color='blue', linewidth=2)
ax_loss.set_xlabel('Update')
ax_loss.set_ylabel('Loss')
ax_loss.set_title('DQN Loss')
ax_loss.grid(True, alpha=0.3)

# Panel 4 (bottom-right): final win rate of each completed stage, labelled.
stage_names = [r['name'] for r in stage_results]
stage_win_rates = [r['win_rate'] * 100 for r in stage_results]
colors = ['#2ecc71', '#3498db', '#e74c3c', '#9b59b6', '#f39c12']
bars = ax_stage.bar(stage_names, stage_win_rates, color=colors[:len(stage_results)])
ax_stage.set_xlabel('Stage')
ax_stage.set_ylabel('Final Win Rate (%)')
ax_stage.set_title('Win Rate by Curriculum Stage')
ax_stage.set_ylim(0, 100)
for bar, rate in zip(bars, stage_win_rates):
    bar_center = bar.get_x() + bar.get_width() / 2
    ax_stage.text(bar_center, bar.get_height() + 1, f'{rate:.1f}%',
                  ha='center', va='bottom', fontsize=10)
ax_stage.grid(True, alpha=0.3, axis='y')

plt.suptitle('DQN Two-Snake Curriculum Training Results', fontsize=14)
plt.tight_layout()

fig_dir = project_root / 'results' / 'figures'
fig_dir.mkdir(parents=True, exist_ok=True)
fig_path = fig_dir / f'dqn_two_snake_curriculum_{timestamp}.png'
plt.savefig(fig_path, dpi=150, bbox_inches='tight')
print(f"Figure saved to: {fig_path}")

plt.show()

In [None]:
# Cell 12: Results Summary
print("=" * 70)
print("CURRICULUM TRAINING SUMMARY")
print("=" * 70)
print("Algorithm: DQN Two-Snake Curriculum")
print(f"Grid Size: {GRID_SIZE}x{GRID_SIZE}")
print(f"Total Steps: {total_steps:,}")
print(f"Total Rounds: {total_rounds:,}")
print(f"Total Training Time: {total_training_time:.1f}s ({total_training_time/60:.1f} min)")
print()
print("Stage Results:")
print("-" * 70)
print(f"{'Stage':<15} {'Opponent':<15} {'Steps':>10} {'Time':>10} {'Win Rate':>12}")
print("-" * 70)
for r in stage_results:
    # "Stage " (6 chars) + 9 = 15 columns and the win rate is padded to 12, so each
    # row lines up with the header columns above.
    print(f"Stage {r['stage']:<9} {r['name']:<15} {r['steps']:>10,} {r['time']:>9.1f}s {r['win_rate']:>12.2%}")
print("-" * 70)
print()
print("Final Performance (last 100 rounds):")
for agent_label, scores in (("Agent 1", all_scores1), ("Agent 2", all_scores2)):
    recent = scores[-100:]
    # np.mean([]) is nan (with a RuntimeWarning) when no round has finished yet.
    avg_text = f"{np.mean(recent):.2f}" if recent else "n/a"
    print(f"  {agent_label} Avg Score: {avg_text}")
print(f"  Final Win Rate: {calculate_win_rate():.2%}")
print("=" * 70)