# REINFORCE Training Notebook

This notebook trains a REINFORCE agent to play Snake.

**Algorithm**: REINFORCE (Monte Carlo Policy Gradient) with:
- Policy-only network (no critic/value function)
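- Episode-based updates: one gradient step per completed episode, using that episode's full discounted returns
- High-variance gradient estimates: plain REINFORCE is unbiased, but the per-episode return normalization used here (for stability) means the estimate is no longer strictly unbiased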
- Entropy bonus for exploration

In [None]:
# Cell 1: Imports
import sys
from pathlib import Path

# Assumes the working directory is two levels below the project root (e.g. the
# notebook's own folder) -- TODO confirm; Path.cwd() depends on where the
# kernel / papermill run was started.
project_root = Path.cwd().parent.parent
# Prepend the project root so the `core` package imported below resolves.
# This line must stay before the `core` imports.
sys.path.insert(0, str(project_root))

import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime

# Project-local modules (found via the sys.path entry added above).
from core.environment_vectorized import VectorizedSnakeEnv
from core.networks import PPO_Actor_MLP, PPO_Actor_CNN
from core.utils import MetricsTracker, set_seed, get_device

# Quick environment sanity report.
print(f"Project root: {project_root}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

In [None]:
# Cell 2: Configuration (papermill parameters)
# ============== CONFIGURATION ==============
# Defaults for a run; when executed via papermill (per the cell title) these
# names can be overridden, so keep them as simple `NAME = literal` lines.

# Environment
GRID_SIZE = 10  # side length of the square board (GRID_SIZE x GRID_SIZE)
NUM_ENVS = 256  # environments stepped together as one batch
MAX_STEPS = 1000  # passed to VectorizedSnakeEnv as max_steps (presumably a per-episode step cap)
ACTION_SPACE_TYPE = 'relative'  # 'absolute' or 'relative'; relative -> 3 actions, otherwise 4
STATE_REPRESENTATION = 'feature'  # 'feature' (MLP policy) or 'grid' (CNN policy)
USE_FLOOD_FILL = False  # feature mode only: 14 input features instead of 11

# Training
NUM_EPISODES = 500  # Short for testing; counts completed episodes summed over all envs
LEARNING_RATE = 0.001  # Adam learning rate

# REINFORCE hyperparameters
GAMMA = 0.99  # Discount factor
ENTROPY_COEF = 0.01  # Entropy bonus coefficient
MAX_GRAD_NORM = 0.5  # Gradient clipping (max global L2 norm of the policy gradients)

# Network
HIDDEN_DIMS = (128, 128)  # hidden layer sizes; used only by the MLP ('feature') policy

# Output
SAVE_DIR = '../../results/weights/reinforce'  # relative to the current working directory
SEED = 42  # passed to set_seed() and env.reset()
LOG_INTERVAL = 50  # print progress every LOG_INTERVAL completed episodes

In [None]:
# Cell 3: Environment & Model Setup
# Seed first so later random draws (e.g. network weight init) are reproducible.
set_seed(SEED)
device = get_device()
print(f"Using device: {device}")

save_dir = Path(SAVE_DIR)
save_dir.mkdir(exist_ok=True, parents=True)

# NUM_ENVS Snake games stepped together as one batch.
env = VectorizedSnakeEnv(
    device=device,
    num_envs=NUM_ENVS,
    grid_size=GRID_SIZE,
    max_steps=MAX_STEPS,
    action_space_type=ACTION_SPACE_TYPE,
    state_representation=STATE_REPRESENTATION,
    use_flood_fill=USE_FLOOD_FILL,
)

# Network input size: 11 features (14 with flood fill) in feature mode.
if STATE_REPRESENTATION != 'feature':
    input_dim = GRID_SIZE  # informational only; the CNN below is built from GRID_SIZE
else:
    input_dim = 14 if USE_FLOOD_FILL else 11

# Relative control has 3 actions, any other action space 4.
output_dim = 4 if ACTION_SPACE_TYPE != 'relative' else 3

# Policy network only -- REINFORCE has no critic.
policy = (
    PPO_Actor_MLP(input_dim, output_dim, HIDDEN_DIMS)
    if STATE_REPRESENTATION == 'feature'
    else PPO_Actor_CNN(GRID_SIZE, 3, output_dim)  # 3 presumably = input channels; TODO confirm
).to(device)

optimizer = torch.optim.Adam(policy.parameters(), lr=LEARNING_RATE)

# Episode statistics; window_size=100 (Cell 8 reports these as "last 100 episodes").
metrics = MetricsTracker(window_size=100)

print(f"Environment: {GRID_SIZE}x{GRID_SIZE} grid, {NUM_ENVS} parallel envs")
print(f"Policy parameters: {sum(param.numel() for param in policy.parameters())}")
print(f"Note: REINFORCE is policy-only (no critic network)")

In [None]:
# Cell 4: REINFORCE Helper Functions

class REINFORCEBuffer:
    """Collect per-step rollout tensors on the CPU and return them as batches.

    Each ``add`` appends one step's tensors (copied to CPU). ``get``
    concatenates every field along dim 0 and moves the results to ``device``.
    """

    # Field names, in the order add() receives and get() returns them.
    _FIELDS = ('states', 'actions', 'rewards', 'log_probs')

    def __init__(self, device):
        self.device = device  # destination device for tensors returned by get()
        self.clear()

    def clear(self):
        """Discard every stored step."""
        for field in self._FIELDS:
            setattr(self, field, [])

    def add(self, state, action, reward, log_prob):
        """Store one step; each tensor is copied to CPU before being kept."""
        for field, value in zip(self._FIELDS, (state, action, reward, log_prob)):
            getattr(self, field).append(value.cpu())

    def get(self):
        """Return (states, actions, rewards, log_probs), each concatenated on dim 0.

        torch.cat raises if nothing has been added yet.
        """
        return tuple(
            torch.cat(getattr(self, field), dim=0).to(self.device)
            for field in self._FIELDS
        )

    def size(self):
        """Number of stored rows: the summed dim-0 sizes of the added states."""
        return sum(chunk.shape[0] for chunk in self.states)

def select_actions(states, policy_net=None):
    """Sample one action per batch element from the policy.

    Args:
        states: batch of observations accepted by the policy network.
        policy_net: network mapping ``states`` to action logits; defaults to
            the notebook-level ``policy``.

    Returns:
        (actions, log_probs): sampled action indices and the log-probability
        of each sampled action under the current policy.
    """
    net = policy if policy_net is None else policy_net
    logits = net(states)
    # Build the distribution from logits so log-probs are computed with a
    # log-softmax, instead of taking the log of softmax probabilities (which
    # loses precision for unlikely actions).
    dist = torch.distributions.Categorical(logits=logits)
    actions = dist.sample()
    log_probs = dist.log_prob(actions)
    return actions, log_probs

def compute_returns(rewards, gamma=None, target_device=None):
    """Compute discounted Monte Carlo returns ``G_t = r_t + gamma * G_{t+1}``.

    Args:
        rewards: per-step scalar rewards of one episode, in time order.
        gamma: discount factor; defaults to the notebook-level ``GAMMA``.
        target_device: device of the returned tensor; defaults to the
            notebook-level ``device``.

    Returns:
        1-D float32 tensor with one return per reward (empty for no rewards).
    """
    if gamma is None:
        gamma = GAMMA
    if target_device is None:
        target_device = device
    returns = []
    running = 0.0
    # Accumulate from the last step backwards. Appending is O(1), unlike
    # insert(0, ...), which is O(n) per call; flip once at the end.
    for r in reversed(rewards):
        running = r + gamma * running
        returns.append(running)
    returns.reverse()
    return torch.tensor(returns, dtype=torch.float32, device=target_device)

def reinforce_update(states, actions, log_probs, returns):
    """Apply one REINFORCE gradient step using a single episode.

    Args:
        states: observations visited in the episode, stacked on dim 0.
        actions: actions taken at those states.
        log_probs: log-probabilities recorded at sampling time. Unused: the
            rollout computes them under ``torch.no_grad()``, so they are
            recomputed here with gradients enabled. Kept for interface
            compatibility.
        returns: discounted return for each step.

    Returns:
        (policy_loss, entropy) as Python floats, for logging.
    """
    # Normalize returns within the episode for stability. std() applies
    # Bessel's correction, so for a one-step episode it is NaN. That NaN would
    # flow through the loss into the gradients and, via the optimizer step,
    # into the weights. With a single step, only mean-center (giving 0), so
    # just the entropy term drives the update.
    if returns.numel() > 1:
        returns = (returns - returns.mean()) / (returns.std() + 1e-8)
    else:
        returns = returns - returns.mean()

    # Recompute log probs and entropy with gradients. Build the distribution
    # from logits so log-probs come from a log-softmax rather than log(softmax).
    logits = policy(states)
    dist = torch.distributions.Categorical(logits=logits)
    new_log_probs = dist.log_prob(actions)
    entropy = dist.entropy().mean()

    # Policy loss: -log_prob * return (maximize expected return)
    policy_loss = -(new_log_probs * returns).mean()

    # Total loss with entropy bonus (subtracted so entropy is maximized)
    loss = policy_loss - ENTROPY_COEF * entropy

    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(policy.parameters(), MAX_GRAD_NORM)
    optimizer.step()

    return policy_loss.item(), entropy.item()

print("REINFORCE helper functions defined.")

In [None]:
# Cell 5: Training Loop
import time

print("Starting REINFORCE Training...")
print(f"Device: {device}")
print(f"Episodes: {NUM_EPISODES}")
print()

# NOTE(review): this buffer is created but never filled; the per-env lists
# below hold the trajectories instead.
buffer = REINFORCEBuffer(device)
states = env.reset(seed=SEED)
start_time = time.time()

# Running reward / length totals of each env's in-progress episode.
episode_rewards = torch.zeros(NUM_ENVS, device=device)
episode_lengths = torch.zeros(NUM_ENVS, device=device)

episode = 0      # completed episodes, summed over all envs
total_steps = 0  # env transitions, summed over all envs

# Tracking for plots
all_rewards = []
all_scores = []
all_losses = []
all_entropies = []

# Per-environment trajectory of the episode currently in progress.
env_rewards = [[] for _ in range(NUM_ENVS)]
env_log_probs = [[] for _ in range(NUM_ENVS)]
env_states = [[] for _ in range(NUM_ENVS)]
env_actions = [[] for _ in range(NUM_ENVS)]

# NOTE(review): the policy is updated as soon as any single env finishes an
# episode, while the other envs keep extending trajectories that were partly
# sampled under earlier parameters. REINFORCE assumes on-policy data, so these
# updates are slightly off-policy.
while episode < NUM_EPISODES:
    with torch.no_grad():
        actions, log_probs = select_actions(states)

    next_states, rewards, dones, info = env.step(actions)

    # Store per-environment. Split each batch once: unbind(0) yields the same
    # row views as states[i]. Pull rewards to the host with one .tolist()
    # instead of NUM_ENVS separate .item() synchronizations.
    step_rewards = rewards.tolist()
    per_env = zip(states.unbind(0), actions.unbind(0), log_probs.unbind(0), step_rewards)
    for i, (state_i, action_i, log_prob_i, reward_i) in enumerate(per_env):
        env_states[i].append(state_i)
        env_actions[i].append(action_i)
        env_rewards[i].append(reward_i)
        env_log_probs[i].append(log_prob_i)

    episode_rewards += rewards
    episode_lengths += 1
    total_steps += NUM_ENVS

    if dones.any():
        for idx in torch.where(dones)[0].tolist():
            reward = episode_rewards[idx].item()
            length = episode_lengths[idx].item()
            score = info['scores'][idx].item()

            # Compute returns for this episode
            ep_returns = compute_returns(env_rewards[idx])
            ep_states = torch.stack(env_states[idx])
            ep_actions = torch.stack(env_actions[idx])
            ep_log_probs = torch.stack(env_log_probs[idx])

            # Update policy (one gradient step per finished episode)
            loss, entropy = reinforce_update(ep_states, ep_actions, ep_log_probs, ep_returns)
            all_losses.append(loss)
            all_entropies.append(entropy)

            # Record metrics
            metrics.add_episode(reward, length, score)
            all_rewards.append(reward)
            all_scores.append(score)

            episode += 1

            # Clear episode buffer
            env_rewards[idx] = []
            env_log_probs[idx] = []
            env_states[idx] = []
            env_actions[idx] = []
            episode_rewards[idx] = 0
            episode_lengths[idx] = 0

            if episode % LOG_INTERVAL == 0:
                stats = metrics.get_recent_stats()
                elapsed = time.time() - start_time
                fps = total_steps / elapsed if elapsed > 0 else 0
                print(f"Episode {episode}/{NUM_EPISODES} | "
                      f"Score: {stats['avg_score']:.2f} | "
                      f"Reward: {stats['avg_reward']:.2f} | "
                      f"FPS: {fps:.0f}")

            # Stop processing finished envs once the budget is reached; the
            # while-condition then ends training.
            if episode >= NUM_EPISODES:
                break

    states = next_states

training_time = time.time() - start_time
print(f"\nTraining complete!")
print(f"Total time: {training_time:.1f}s")
print(f"Total steps: {total_steps:,}")

In [None]:
# Cell 6: Save Weights
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
filename = f"reinforce_{GRID_SIZE}x{GRID_SIZE}_{NUM_EPISODES}ep_{timestamp}.pt"
filepath = save_dir / filename

torch.save({
    'policy': policy.state_dict(),
    'optimizer': optimizer.state_dict(),
    'episode': episode,
    'total_steps': total_steps,
    'config': {
        'grid_size': GRID_SIZE,
        'num_envs': NUM_ENVS,
        'hidden_dims': HIDDEN_DIMS,
        'learning_rate': LEARNING_RATE,
        'gamma': GAMMA,
        'entropy_coef': ENTROPY_COEF,
        'state_representation': STATE_REPRESENTATION,
        # Required to rebuild the network from this checkpoint: the output
        # dimension depends on the action space, and the feature-mode input
        # dimension on USE_FLOOD_FILL (see the setup cell).
        'action_space_type': ACTION_SPACE_TYPE,
        'use_flood_fill': USE_FLOOD_FILL,
        # Remaining settings needed to reproduce the run.
        'max_steps': MAX_STEPS,
        'max_grad_norm': MAX_GRAD_NORM,
        'seed': SEED,
    }
}, filepath)

print(f"Model saved to: {filepath}")

In [None]:
# Cell 7: Visualization
# 2x2 panel layout: rewards / scores on top, loss / entropy below.
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(12, 10))

def smooth(data, window=50):
    """Return the moving average of *data* over *window* points.

    If *data* is shorter than *window*, the same object is returned
    unchanged. Otherwise the result is an ndarray of length
    ``len(data) - window + 1`` (``np.convolve`` with ``mode='valid'``).
    """
    n_points = len(data)
    if n_points >= window:
        kernel = np.full(window, 1.0 / window)
        return np.convolve(data, kernel, mode='valid')
    return data

# Unpack the 2x2 grid row by row: [0, 0], [0, 1], [1, 0], [1, 1].
ax1, ax2, ax3, ax4 = axes.flat

# Plot 1: Episode Rewards -- raw curve plus a 50-episode moving average once
# there are more than 50 points (smooth() output starts at x = 49).
ax1.plot(all_rewards, alpha=0.3, label='Raw')
if len(all_rewards) > 50:
    ax1.plot(range(49, len(all_rewards)), smooth(all_rewards), label='Smoothed (50)')
ax1.set(xlabel='Episode', ylabel='Reward', title='Episode Rewards')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Plot 2: Episode Scores -- same layout as Plot 1.
ax2.plot(all_scores, alpha=0.3, label='Raw')
if len(all_scores) > 50:
    ax2.plot(range(49, len(all_scores)), smooth(all_scores), label='Smoothed (50)')
ax2.set(xlabel='Episode', ylabel='Score (Food Eaten)', title='Episode Scores')
ax2.legend()
ax2.grid(True, alpha=0.3)

# Plot 3: Policy Loss -- one point per update (one update per finished episode).
ax3.plot(all_losses, alpha=0.5)
ax3.set(xlabel='Episode', ylabel='Policy Loss', title='REINFORCE Policy Loss')
ax3.grid(True, alpha=0.3)

# Plot 4: Entropy -- mean policy entropy recorded at each update.
ax4.plot(all_entropies)
ax4.set(xlabel='Episode', ylabel='Entropy', title='Policy Entropy')
ax4.grid(True, alpha=0.3)

plt.suptitle(f'REINFORCE Training Results - {NUM_EPISODES} Episodes', fontsize=14)
plt.tight_layout()

# `timestamp` is defined by the save-weights cell, so the figure and the
# weights file share the same suffix.
fig_dir = project_root.joinpath('results', 'figures')
fig_dir.mkdir(parents=True, exist_ok=True)
fig_path = fig_dir / f'reinforce_training_{timestamp}.png'
plt.savefig(fig_path, dpi=150, bbox_inches='tight')
print(f"Figure saved to: {fig_path}")

plt.show()

In [None]:
# Cell 8: Results Summary
# Recent-window statistics from the MetricsTracker (window_size=100 in setup).
stats = metrics.get_recent_stats()

print("=" * 50, "TRAINING SUMMARY", "=" * 50, sep="\n")
print("Algorithm: REINFORCE (Monte Carlo Policy Gradient)")
print(f"Episodes: {episode}")
print(f"Total Steps: {total_steps:,}")
print(f"Training Time: {training_time:.1f}s")
print()
print("Note: REINFORCE uses complete episode returns (high variance).",
      "      No critic/baseline - gradients can be noisy.",
      sep="\n")
print()
print("Final Performance (last 100 episodes):")
print(f"  Average Score: {stats['avg_score']:.2f}")
print(f"  Average Reward: {stats['avg_reward']:.2f}")
print(f"  Average Length: {stats['avg_length']:.2f}")
print(f"  Max Score: {stats['max_score']}")
print()
# Whole-run statistics over every completed episode.
print("Overall Statistics:")
print(f"  Mean Score: {np.mean(all_scores):.2f} +/- {np.std(all_scores):.2f}")
print(f"  Max Score: {max(all_scores)}")
print(f"  Mean Reward: {np.mean(all_rewards):.2f} +/- {np.std(all_rewards):.2f}")
print("=" * 50)