# A2C Training Notebook

This notebook trains an A2C (Advantage Actor-Critic) agent to play Snake.

**Algorithm**: A2C (Synchronous Advantage Actor-Critic) with:
- Actor-Critic architecture
- N-step advantage estimation
- Short rollouts with value bootstrapping
- Lower variance than REINFORCE

In [None]:
# Cell 1: Imports
import sys
from pathlib import Path

# Resolve the directory two levels above the current working directory and
# put it first on the import path -- presumably so the `core.*` imports below
# resolve when the notebook is launched from its own folder (confirm layout).
project_root = Path.cwd().parent.parent
sys.path.insert(0, str(project_root))

import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime

from core.environment_vectorized import VectorizedSnakeEnv
from core.networks import PPO_Actor_MLP, PPO_Critic_MLP, PPO_Actor_CNN, PPO_Critic_CNN
from core.utils import MetricsTracker, set_seed, get_device

# Report where the project root resolved to and which PyTorch/CUDA setup is active.
for label, value in (
    ("Project root", project_root),
    ("PyTorch version", torch.__version__),
    ("CUDA available", torch.cuda.is_available()),
):
    print(f"{label}: {value}")

In [None]:
# Cell 2: Configuration (papermill parameters)
# ============== CONFIGURATION ==============
# Every name below is a plain module-level constant read by the later cells.

# Environment
GRID_SIZE = 10  # Board is GRID_SIZE x GRID_SIZE
NUM_ENVS = 256  # Number of environments stepped in parallel
MAX_STEPS = 1000  # Forwarded to the env as `max_steps` (presumably a per-episode step cap -- confirm)
ACTION_SPACE_TYPE = 'relative'  # 'absolute' or 'relative'; 'relative' -> 3 actions, anything else -> 4
STATE_REPRESENTATION = 'feature'  # 'feature' or 'grid'; 'feature' selects the MLP networks, otherwise the CNNs
USE_FLOOD_FILL = False  # With 'feature' states: 14 inputs when True, 11 when False

# Training
NUM_EPISODES = 500  # Short for testing; training stops once this many episodes have finished across all envs
ROLLOUT_STEPS = 5  # N-step for advantage estimation (env steps collected before each update)

# Learning rates (actor and critic use separate Adam optimizers)
ACTOR_LR = 0.0003
CRITIC_LR = 0.001

# A2C hyperparameters
GAMMA = 0.99  # Discount factor
ENTROPY_COEF = 0.01  # Entropy bonus coefficient (subtracted from the actor loss)
VALUE_COEF = 0.5  # Value loss coefficient. NOTE(review): not referenced by any other cell in this notebook
MAX_GRAD_NORM = 0.5  # Gradient clipping (max norm, applied to actor and critic separately)

# Network
HIDDEN_DIMS = (128, 128)  # Passed to the MLP actor/critic only; the CNN networks do not receive it

# Output
SAVE_DIR = '../../results/weights/a2c'  # Relative to the working directory; created if missing
SEED = 42  # Used for set_seed() and env.reset(seed=...)
LOG_INTERVAL = 50  # Print progress every LOG_INTERVAL completed episodes

In [None]:
# Cell 3: Environment & Model Setup
# Seed first so environment creation and network initialisation are repeatable.
set_seed(SEED)
device = get_device()
print(f"Using device: {device}")

# Make sure the checkpoint directory exists before training starts.
save_dir = Path(SAVE_DIR)
save_dir.mkdir(parents=True, exist_ok=True)

# Vectorized Snake environment (NUM_ENVS parallel games).
env = VectorizedSnakeEnv(
    num_envs=NUM_ENVS,
    grid_size=GRID_SIZE,
    action_space_type=ACTION_SPACE_TYPE,
    state_representation=STATE_REPRESENTATION,
    max_steps=MAX_STEPS,
    use_flood_fill=USE_FLOOD_FILL,
    device=device
)

# Number of discrete actions the policy chooses between.
output_dim = 3 if ACTION_SPACE_TYPE == 'relative' else 4

# Pick input size and network family from the state representation.
# The actor is always built before the critic so seeded initialisation
# happens in a fixed order.
if STATE_REPRESENTATION == 'feature':
    input_dim = 14 if USE_FLOOD_FILL else 11
    actor = PPO_Actor_MLP(input_dim, output_dim, HIDDEN_DIMS).to(device)
    critic = PPO_Critic_MLP(input_dim, HIDDEN_DIMS).to(device)
else:
    input_dim = GRID_SIZE
    actor = PPO_Actor_CNN(GRID_SIZE, 3, output_dim).to(device)
    critic = PPO_Critic_CNN(GRID_SIZE, 3).to(device)

# One Adam optimizer per network, each with its own learning rate.
actor_optimizer = torch.optim.Adam(actor.parameters(), lr=ACTOR_LR)
critic_optimizer = torch.optim.Adam(critic.parameters(), lr=CRITIC_LR)

# Rolling episode statistics.
metrics = MetricsTracker(window_size=100)

actor_param_count = sum(p.numel() for p in actor.parameters())
critic_param_count = sum(p.numel() for p in critic.parameters())

print(f"Environment: {GRID_SIZE}x{GRID_SIZE} grid, {NUM_ENVS} parallel envs")
print(f"Actor parameters: {actor_param_count}")
print(f"Critic parameters: {critic_param_count}")
print(f"N-step rollout: {ROLLOUT_STEPS}")

In [None]:
# Cell 4: A2C Helper Functions

def select_actions(states):
    """Sample one action per state and evaluate the critic, without gradients.

    Uses the module-level ``actor`` and ``critic`` networks.

    Args:
        states: Batch of observations accepted by ``actor`` and ``critic``.

    Returns:
        Tuple ``(actions, log_probs, values)``:
        ``actions`` are sampled from the categorical policy, ``log_probs`` are
        the log-probabilities of those samples, and ``values`` is the critic
        output with its trailing dimension squeezed away. None of the tensors
        carries autograd history.
    """
    with torch.no_grad():
        logits = actor(states)
        values = critic(states)
        # Build the distribution straight from logits: this is the same
        # policy as softmax -> Categorical(probs) but avoids taking the log
        # of rounded probabilities.
        dist = torch.distributions.Categorical(logits=logits)
        actions = dist.sample()
        log_probs = dist.log_prob(actions)
    # Squeeze only the last axis (assumes the critic returns (batch, 1) --
    # confirm). A bare squeeze() would also drop the batch axis when the
    # batch holds a single environment.
    return actions, log_probs, values.squeeze(-1)

def compute_nstep_returns(rewards, values, dones, next_value, gamma=None):
    """Compute discounted n-step returns, bootstrapping from ``next_value``.

    Walks the rollout backwards applying
    ``R_t = r_t + gamma * R_{t+1} * (1 - done_t)``, so the return is cut at
    every step flagged as done.

    Args:
        rewards: Sequence indexed by time step; ``rewards[t]`` is a tensor.
        values: Unused; kept so existing callers keep working.
        dones: Sequence indexed by time step; ``dones[t]`` is a tensor of
            termination flags (converted with ``.float()``).
        next_value: Value estimate for the state following the last step.
        gamma: Discount factor. Defaults to the module-level ``GAMMA``.

    Returns:
        Tensor stacking the per-step returns along a new leading time axis.

    Raises:
        RuntimeError: From ``torch.stack`` if the rollout is empty.
    """
    if gamma is None:
        gamma = GAMMA

    running_return = next_value
    newest_first = []
    for t in reversed(range(len(rewards))):
        not_done = 1.0 - dones[t].float()
        running_return = rewards[t] + gamma * running_return * not_done
        # Append and reverse once at the end; insert(0, ...) shifts the whole
        # list on every step.
        newest_first.append(running_return)

    newest_first.reverse()
    return torch.stack(newest_first)

def a2c_update(states, actions, log_probs, returns, values):
    """Run one A2C gradient step on the actor and one on the critic.

    Uses the module-level ``actor``/``critic`` networks and their optimizers.

    Args:
        states: Flattened batch of observations.
        actions: Actions taken in ``states``.
        log_probs: Rollout-time log-probabilities. Unused (the distribution is
            recomputed with gradients); kept for interface compatibility.
        returns: N-step return targets, one per sample.
        values: Rollout-time critic estimates used as the advantage baseline.

    Returns:
        Tuple ``(actor_loss, critic_loss, entropy)`` as Python floats. The
        reported actor loss excludes the entropy bonus.
    """
    # Advantages are constants for the policy gradient.
    advantages = (returns - values).detach()
    if advantages.numel() > 1:
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
    # With a single sample std() is NaN, which would poison every parameter
    # on the optimizer step, so the raw advantage is used instead.

    # Recompute the action distribution with gradients enabled. Building it
    # from logits is the same policy as softmax -> Categorical(probs) but
    # numerically more stable for log_prob/entropy.
    logits = actor(states)
    # Squeeze only the trailing axis (assumes critic output is (batch, 1) --
    # confirm) so a batch of one keeps its batch axis and matches `returns`.
    new_values = critic(states).squeeze(-1)
    dist = torch.distributions.Categorical(logits=logits)
    new_log_probs = dist.log_prob(actions)
    entropy = dist.entropy().mean()

    # Actor loss (policy gradient with advantage)
    actor_loss = -(new_log_probs * advantages).mean()

    # Critic loss (value function regression onto the n-step returns)
    critic_loss = F.mse_loss(new_values, returns.detach())

    # Update actor; the entropy bonus is subtracted to encourage exploration.
    actor_optimizer.zero_grad()
    (actor_loss - ENTROPY_COEF * entropy).backward()
    torch.nn.utils.clip_grad_norm_(actor.parameters(), MAX_GRAD_NORM)
    actor_optimizer.step()

    # Update critic
    critic_optimizer.zero_grad()
    critic_loss.backward()
    torch.nn.utils.clip_grad_norm_(critic.parameters(), MAX_GRAD_NORM)
    critic_optimizer.step()

    return actor_loss.item(), critic_loss.item(), entropy.item()

# Confirmation that this cell ran and the helper functions above are defined.
print("A2C helper functions defined.")

In [None]:
# Cell 5: Training Loop
import time

# Announce the run; the trailing empty string produces the blank line.
print(
    "Starting A2C Training...",
    f"Device: {device}",
    f"Episodes: {NUM_EPISODES}",
    f"N-step: {ROLLOUT_STEPS}",
    "",
    sep="\n",
)

# Initial observations; the clock starts only after the reset.
states = env.reset(seed=SEED)
start_time = time.time()

# Per-environment accumulators for the episode currently in progress.
episode_rewards = torch.zeros(NUM_ENVS, device=device)
episode_lengths = torch.zeros(NUM_ENVS, device=device)

# Finished-episode counter and total environment steps taken.
episode, total_steps = 0, 0

# Histories kept for plotting: per finished episode, then per update.
all_rewards, all_scores = [], []
all_actor_losses, all_critic_losses, all_entropies = [], [], []

# Main loop: alternate between collecting a ROLLOUT_STEPS-long rollout from
# all environments and performing one A2C update, until NUM_EPISODES episodes
# have finished (counted across every environment).
while episode < NUM_EPISODES:
    # Collect n-step rollout
    rollout_states = []
    rollout_actions = []
    rollout_log_probs = []
    rollout_rewards = []
    rollout_values = []
    rollout_dones = []

    for _ in range(ROLLOUT_STEPS):
        actions, log_probs, values = select_actions(states)
        next_states, rewards, dones, info = env.step(actions)

        rollout_states.append(states)
        rollout_actions.append(actions)
        rollout_log_probs.append(log_probs)
        rollout_rewards.append(rewards)
        rollout_values.append(values)
        rollout_dones.append(dones)

        episode_rewards += rewards
        episode_lengths += 1
        total_steps += NUM_ENVS  # every environment advanced by one step

        if dones.any():
            # Record each finished episode and reset its accumulators.
            done_indices = torch.where(dones)[0]
            for idx in done_indices:
                reward = episode_rewards[idx].item()
                length = episode_lengths[idx].item()
                score = info['scores'][idx].item()

                metrics.add_episode(reward, length, score)
                all_rewards.append(reward)
                all_scores.append(score)

                episode += 1
                episode_rewards[idx] = 0
                episode_lengths[idx] = 0

                if episode % LOG_INTERVAL == 0:
                    stats = metrics.get_recent_stats()
                    elapsed = time.time() - start_time
                    fps = total_steps / elapsed if elapsed > 0 else 0
                    print(f"Episode {episode}/{NUM_EPISODES} | "
                          f"Score: {stats['avg_score']:.2f} | "
                          f"Reward: {stats['avg_reward']:.2f} | "
                          f"FPS: {fps:.0f}")

                if episode >= NUM_EPISODES:
                    break

        states = next_states
        if episode >= NUM_EPISODES:
            break

    # Episode target reached mid-rollout: stop without updating on the
    # partial rollout.
    if episode >= NUM_EPISODES:
        break

    # Compute bootstrap value for the state after the last rollout step.
    # Squeeze only the trailing axis (assumes critic output is (N, 1) --
    # confirm) so a single-environment batch keeps its batch axis.
    with torch.no_grad():
        next_value = critic(states).squeeze(-1)

    # Stack rollout
    rollout_states = torch.stack(rollout_states)  # (T, N, ...)
    rollout_actions = torch.stack(rollout_actions)  # (T, N)
    rollout_log_probs = torch.stack(rollout_log_probs)  # (T, N)
    rollout_rewards = torch.stack(rollout_rewards)  # (T, N)
    rollout_values = torch.stack(rollout_values)  # (T, N)
    rollout_dones = torch.stack(rollout_dones)  # (T, N)

    # Compute returns
    returns = compute_nstep_returns(rollout_rewards, rollout_values, rollout_dones, next_value)

    # Flatten for update. Only the time and environment axes are merged, so
    # states with more than one feature axis (e.g. the 'grid' representation)
    # keep their per-sample shape and the batch size matches the other
    # flattened tensors.
    flat_states = rollout_states.flatten(0, 1)
    flat_actions = rollout_actions.view(-1)
    flat_log_probs = rollout_log_probs.view(-1)
    flat_returns = returns.view(-1)
    flat_values = rollout_values.view(-1)

    # A2C update
    actor_loss, critic_loss, entropy = a2c_update(
        flat_states, flat_actions, flat_log_probs, flat_returns, flat_values
    )
    all_actor_losses.append(actor_loss)
    all_critic_losses.append(critic_loss)
    all_entropies.append(entropy)

# Wall-clock time since `start_time`; reused by the results summary cell.
training_time = time.time() - start_time
print(
    "\nTraining complete!",
    f"Total time: {training_time:.1f}s",
    f"Total steps: {total_steps:,}",
    sep="\n",
)

In [None]:
# Cell 6: Save Weights
# Timestamped checkpoint name; `timestamp` is reused for the figure file name
# in the visualization cell.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
filename = f"a2c_{GRID_SIZE}x{GRID_SIZE}_{NUM_EPISODES}ep_{timestamp}.pt"
filepath = save_dir / filename

# Save both networks, both optimizers, training progress and the settings
# needed to rebuild the networks. The action space type and the flood-fill
# flag are included because they determine the output and input sizes.
torch.save({
    'actor': actor.state_dict(),
    'critic': critic.state_dict(),
    'actor_optimizer': actor_optimizer.state_dict(),
    'critic_optimizer': critic_optimizer.state_dict(),
    'episode': episode,
    'total_steps': total_steps,
    'config': {
        'grid_size': GRID_SIZE,
        'num_envs': NUM_ENVS,
        'hidden_dims': HIDDEN_DIMS,
        'actor_lr': ACTOR_LR,
        'critic_lr': CRITIC_LR,
        'gamma': GAMMA,
        'rollout_steps': ROLLOUT_STEPS,
        'state_representation': STATE_REPRESENTATION,
        'action_space_type': ACTION_SPACE_TYPE,
        'use_flood_fill': USE_FLOOD_FILL,
        'max_steps': MAX_STEPS,
        'entropy_coef': ENTROPY_COEF,
        'max_grad_norm': MAX_GRAD_NORM,
        'seed': SEED,
    }
}, filepath)

print(f"Model saved to: {filepath}")

In [None]:
# Cell 7: Visualization
# 2x2 grid of panels; `axes` is indexed as axes[row, col] by the plots below.
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

def smooth(data, window=50):
    """Return the moving average of `data` over `window` consecutive points.

    Inputs shorter than `window` are handed back unchanged. Otherwise only
    fully-overlapping windows are kept, so the result has
    ``len(data) - window + 1`` entries.
    """
    if len(data) >= window:
        kernel = np.full(window, 1.0 / window)
        return np.convolve(data, kernel, mode='valid')
    return data

# Panels in reading order: rewards, scores, losses, entropy.
ax1, ax2, ax3, ax4 = axes.ravel()

# Plots 1 and 2: per-episode series (raw plus 50-episode moving average).
for panel, series, y_label, panel_title in (
    (ax1, all_rewards, 'Reward', 'Episode Rewards'),
    (ax2, all_scores, 'Score (Food Eaten)', 'Episode Scores'),
):
    panel.plot(series, alpha=0.3, label='Raw')
    if len(series) > 50:
        # The moving average has 49 fewer points, so it starts at x = 49.
        panel.plot(range(49, len(series)), smooth(series), label='Smoothed (50)')
    panel.set_xlabel('Episode')
    panel.set_ylabel(y_label)
    panel.set_title(panel_title)
    panel.legend()
    panel.grid(True, alpha=0.3)

# Plot 3: Actor and Critic Loss (one point per update)
for loss_history, loss_label in (
    (all_actor_losses, 'Actor Loss'),
    (all_critic_losses, 'Critic Loss'),
):
    ax3.plot(loss_history, label=loss_label, alpha=0.7)
ax3.set_xlabel('Update')
ax3.set_ylabel('Loss')
ax3.set_title('A2C Losses')
ax3.legend()
ax3.grid(True, alpha=0.3)

# Plot 4: Entropy (one point per update)
ax4.plot(all_entropies)
ax4.set_xlabel('Update')
ax4.set_ylabel('Entropy')
ax4.set_title('Policy Entropy')
ax4.grid(True, alpha=0.3)

plt.suptitle(f'A2C Training Results - {NUM_EPISODES} Episodes', fontsize=14)
plt.tight_layout()

# Write the figure under <project_root>/results/figures before showing it.
fig_dir = project_root / 'results' / 'figures'
fig_dir.mkdir(parents=True, exist_ok=True)
fig_path = fig_dir / f'a2c_training_{timestamp}.png'
plt.savefig(fig_path, dpi=150, bbox_inches='tight')
print(f"Figure saved to: {fig_path}")

plt.show()

In [None]:
# Cell 8: Results Summary
# Recent-window statistics from the metrics tracker.
stats = metrics.get_recent_stats()
banner = "=" * 50

# Header
print(banner, "TRAINING SUMMARY", banner, sep="\n")

# Run description (the trailing empty string prints the blank separator line)
print(
    "Algorithm: A2C (Advantage Actor-Critic)",
    f"N-step rollout: {ROLLOUT_STEPS}",
    f"Episodes: {episode}",
    f"Total Steps: {total_steps:,}",
    f"Training Time: {training_time:.1f}s",
    "",
    sep="\n",
)

# Recent performance
print(
    "Final Performance (last 100 episodes):",
    f"  Average Score: {stats['avg_score']:.2f}",
    f"  Average Reward: {stats['avg_reward']:.2f}",
    f"  Average Length: {stats['avg_length']:.2f}",
    f"  Max Score: {stats['max_score']}",
    "",
    sep="\n",
)

# Whole-run statistics over every finished episode
score_mean, score_std = np.mean(all_scores), np.std(all_scores)
reward_mean, reward_std = np.mean(all_rewards), np.std(all_rewards)
print(
    "Overall Statistics:",
    f"  Mean Score: {score_mean:.2f} +/- {score_std:.2f}",
    f"  Max Score: {max(all_scores)}",
    f"  Mean Reward: {reward_mean:.2f} +/- {reward_std:.2f}",
    banner,
    sep="\n",
)