# 🐍 Reinforcement Learning for Snake: Zero to Hero

Welcome! This notebook is a **hands-on, runnable** tutorial on training an AI agent to play Snake.
We'll go from a random agent to a scaled-up PPO master, and you can run every cell.

> **Companion material**: Read [blog.md](blog.md) for the full narrative, or watch the [Live Visualization](https://Saheb.github.io/rl-snake/snake_learning_journey.html).

---
## 0. A Quick Primer: What is Reinforcement Learning?

RL is **learning by trial and error**. An *agent* takes *actions* in an *environment*,
and receives *rewards* (or penalties). Over time, it learns a *policy*: a strategy
that maximises its cumulative reward.

```
   Agent ──action──▶ Environment
     ▲                  │
     └──state, reward◀──┘
```
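In practice, "cumulative reward" usually means the *discounted return* $G_t = r_t + \gamma r_{t+1} + \gamma^2 r_{t+2} + \dots$. The discount factor $\gamma$ is the `discount_factor` argument passed to the Q-learning agent later on, and it makes sooner rewards count more than later ones. Here is a minimal sketch. The function name and the toy reward lists are illustrative and are not part of this repo:

```python
def discounted_return(rewards, gamma=0.99):
    """Return G_0 = r_0 + gamma*r_1 + gamma**2*r_2 + ... for one episode."""
    g = 0.0
    # Walk backwards so each step is just G_t = r_t + gamma * G_{t+1}
    for r in reversed(rewards):
        g = r + gamma * g
    return g

# Toy rewards: +1 for eating food, 0 otherwise
print(discounted_return([1, 0, 0], gamma=0.9))  # 1.0 (food eaten right away)
print(discounted_return([0, 0, 1], gamma=0.9))  # ≈ 0.81 (food eaten two steps later)
```

With $\gamma = 0.9$, food eaten two steps from now is worth about 81% of food eaten immediately. This nudges the agent towards shorter paths.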

### Two Families of RL

| | Value-Based ("The Accountant") | Policy-Based ("The Athlete") |
|---|---|---|
| **How it thinks** | Calculates the *worth* of every move | Learns *instincts* directly |
| **Famous algorithm** | **DQN** (Deep Q-Network) | **PPO** (Proximal Policy Optimization) |
| **Analogy** | A map that tells you the gold at every corner | An athlete who just *knows* where to throw the ball |
| **Scales to big boards?** | ❌ Q-tables explode | ✅ Generalises well |

---
## 1. Meet the Environment

Our game is a custom `SnakeGame` class. Let's create one and see what the agent "sees".

In [None]:
from snake_game import SnakeGame, GameState, EMPTY, SNAKE, FOOD
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

game = SnakeGame(board_size=5, seed=42)
state = game.reset(seed=42)

print(f"Board size : {game.board_size}x{game.board_size}")
print(f"Actions    : 0=Up, 1=Right, 2=Down, 3=Left")
print(f"Snake pos  : {list(game.snake_position)}")
print(f"Food pos   : {game.food_position}")
print()
game.print_board()

Let's **visualise** a board as a colour grid so we can see what's happening:

In [None]:
def plot_board(game, title="Snake Game"):
    """Render the board as a colour image."""
    cmap = mcolors.ListedColormap(['#1a1a2e', '#16c784', '#ff6b6b'])
    bounds = [-0.5, 0.5, 1.5, 2.5]
    norm = mcolors.BoundaryNorm(bounds, cmap.N)

    fig, ax = plt.subplots(figsize=(4, 4))
    ax.imshow(game.board, cmap=cmap, norm=norm)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.set_xticks([]); ax.set_yticks([])
    for i in range(game.board_size + 1):
        ax.axhline(i - 0.5, color='white', linewidth=0.5, alpha=0.3)
        ax.axvline(i - 0.5, color='white', linewidth=0.5, alpha=0.3)
    plt.tight_layout()
    plt.show()

plot_board(game, "Initial Board (🟢 Snake, 🔴 Food)")

---
## 2. A Random Agent (The Baseline)

Before any learning, let's see how a **random agent** performs.
This is our "Phase ?", the absolute floor.

In [None]:
import random

def run_random_agent(board_size=5, num_games=500, max_steps=200):
    """Run a random agent and collect scores."""
    game = SnakeGame(board_size=board_size)
    scores = []
    for _ in range(num_games):
        game.reset()
        for _ in range(max_steps):
            action = random.randint(0, 3)
            _, _, done, info = game.step(action)
            if done:
                break
        scores.append(info['score'])
    return scores

random_scores = run_random_agent()
print(f"Random Agent on 5x5:")
print(f"  Mean Score : {np.mean(random_scores):.2f}")
print(f"  Max Score  : {max(random_scores)}")
print(f"  Median     : {np.median(random_scores):.1f}")

As expected, the random agent barely scores: it usually dies within a few steps.

---
## 3. Phase 0: Tabular Q-Learning (The Accountant)

Our first real algorithm. The agent maintains a **Q-table**, a dictionary that maps
`(state, action)` → expected future reward. After every step, it updates the table:

$$Q(s, a) \leftarrow Q(s, a) + \alpha \Big[ r + \gamma \max_{a'} Q(s', a') - Q(s, a) \Big]$$
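The update can be sketched in plain Python, assuming a `defaultdict` Q-table keyed by `(state, action)` pairs. The `q_update` helper and its `done` flag are illustrative. The repo's `QLearningAgent.update_q_value` may handle terminal states differently.

```python
from collections import defaultdict

def q_update(q, s, a, r, s_next, actions, alpha=0.1, gamma=0.99, done=False):
    """Apply one Q-learning update to q[(s, a)] and return the new value."""
    # A terminal transition (the snake died) has no future value to bootstrap
    best_next = 0.0 if done else max(q[(s_next, a2)] for a2 in actions)
    td_target = r + gamma * best_next
    q[(s, a)] += alpha * (td_target - q[(s, a)])
    return q[(s, a)]

q = defaultdict(float)   # unseen (state, action) pairs start at 0.0
actions = [0, 1, 2, 3]   # 0=Up, 1=Right, 2=Down, 3=Left
q[("s1", 2)] = 5.0       # pretend action 2 already looks good in state "s1"
print(q_update(q, "s0", 1, r=1.0, s_next="s1", actions=actions))
# ≈ 0.595, i.e. 0.0 + 0.1 * (1.0 + 0.99 * 5.0 - 0.0)
```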

Let's train one **live** and watch the Q-table grow!

In [None]:
from train_tabular_q import QLearningAgent, DoubleQLearningAgent

# Train a fresh Q-Learning agent (small run for demo)
game = SnakeGame(board_size=5)
agent = QLearningAgent(learning_rate=0.1, discount_factor=0.99, epsilon=1.0)

scores = []
q_table_sizes = []  # Track how many states the agent discovers

for ep in range(2000):
    game.reset()
    state = agent.get_state_key(game, 0)
    action = agent.choose_action(state, game)
    done = False

    while not done:
        _, reward, done, info = game.step(action)
        next_state = agent.get_state_key(game, action)
        agent.update_q_value(state, action, reward, next_state)
        state = next_state
        action = agent.choose_action(next_state, game)

    scores.append(info['score'])
    q_table_sizes.append(len(agent.q_table))
    agent.epsilon = max(0.01, agent.epsilon * 0.995)

print(f"Training complete! Q-table has {len(agent.q_table):,} unique states.")
print(f"Last 100 episodes - Mean: {np.mean(scores[-100:]):.1f}, Max: {max(scores[-100:])}")

In [None]:
# Plot: Training Curve + Q-table Growth (side by side)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

# Smoothed training curve
window = 50
smoothed = np.convolve(scores, np.ones(window)/window, mode='valid')
ax1.plot(smoothed, color='#16c784', linewidth=1.5)
ax1.set_title('Training Curve (Q-Learning)', fontweight='bold')
ax1.set_xlabel('Episode')
ax1.set_ylabel('Score')
ax1.set_facecolor('#1a1a2e')
ax1.grid(alpha=0.2)

# Q-table growth
ax2.plot(q_table_sizes, color='#ff6b6b', linewidth=1.5)
ax2.set_title('Q-Table Size (States Discovered)', fontweight='bold')
ax2.set_xlabel('Episode')
ax2.set_ylabel('Unique States')
ax2.set_facecolor('#1a1a2e')
ax2.grid(alpha=0.2)

plt.tight_layout()
plt.show()

print("👆 Notice: The Q-table keeps growing. On a 10x10 board, this would explode!")

### Phase 0+: Double Q-Learning

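Standard Q-Learning tends to **overestimate** values. The `max` in its update uses the same noisy table both to pick the best next action and to score it, so lucky overestimates get selected and then reinforced. Double Q-Learning fixes this by
maintaining *two* tables and cross-checking:

- Table A picks the best action
- Table B evaluates it (and vice versa)

This reduces the overestimation bias and usually gives more **stable** and **reliable** learning.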
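The two-table update can be sketched as follows, using the standard Double Q-learning recipe (van Hasselt, 2010). The helper names are hypothetical. `DoubleQLearningAgent` in `train_tabular_q.py` may choose which table to update, or how to act, differently.

```python
import random
from collections import defaultdict

def double_q_update(q_a, q_b, s, a, r, s_next, actions,
                    alpha=0.1, gamma=0.99, done=False, rng=random):
    """One Double Q-learning step: a coin flip picks which table learns."""
    if rng.random() < 0.5:
        learner, evaluator = q_a, q_b
    else:
        learner, evaluator = q_b, q_a
    if done:
        td_target = r  # terminal: nothing to bootstrap from
    else:
        # The learning table picks the best next action...
        best = max(actions, key=lambda a2: learner[(s_next, a2)])
        # ...but the *other* table scores it, which curbs overestimation
        td_target = r + gamma * evaluator[(s_next, best)]
    learner[(s, a)] += alpha * (td_target - learner[(s, a)])

def greedy_action(q_a, q_b, s, actions):
    """Act greedily on the sum of both tables."""
    return max(actions, key=lambda a: q_a[(s, a)] + q_b[(s, a)])

q_a, q_b = defaultdict(float), defaultdict(float)
double_q_update(q_a, q_b, "s0", 1, r=1.0, s_next="s1", actions=[0, 1, 2, 3])
print(greedy_action(q_a, q_b, "s0", [0, 1, 2, 3]))  # 1
```

Acting on the sum of both tables is one common choice. Using their average gives the same greedy action.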

In [None]:
# Load the pre-trained models and compare
import pickle

def evaluate_agent(agent, board_size=5, num_games=200, max_steps=500):
    """Evaluate a trained agent."""
    game = SnakeGame(board_size=board_size)
    scores = []
    for _ in range(num_games):
        game.reset()
        action = 0
        for _ in range(max_steps):
            state_key = agent.get_state_key(game, action)
            old_eps = agent.epsilon
            agent.epsilon = 0  # Greedy during eval
            action = agent.choose_action(state_key, game)
            agent.epsilon = old_eps
            _, _, done, info = game.step(action)
            if done:
                break
        scores.append(info['score'])
    return scores

try:
    with open('tabular_q_5x5.pkl', 'rb') as f:
        q_agent = pickle.load(f)
    with open('tabular_double_q_5x5.pkl', 'rb') as f:
        dq_agent = pickle.load(f)

    q_scores = evaluate_agent(q_agent)
    dq_scores = evaluate_agent(dq_agent)

    print("=" * 40)
    print(f"{'Metric':<20} {'Q-Learning':>10} {'Double Q':>10}")
    print("-" * 40)
    print(f"{'Mean Score':<20} {np.mean(q_scores):>10.1f} {np.mean(dq_scores):>10.1f}")
    print(f"{'Max Score':<20} {max(q_scores):>10} {max(dq_scores):>10}")
    print(f"{'Std Dev':<20} {np.std(q_scores):>10.1f} {np.std(dq_scores):>10.1f}")
    print(f"{'Q-Table States':<20} {len(q_agent.q_table):>10,} {len(dq_agent.q_table_a):>10,}")
    print("=" * 40)
except FileNotFoundError:
    print("Pre-trained models not found. Run `python train_tabular_q.py` first!")
    print("  python train_tabular_q.py --type q --episodes 5000")
    print("  python train_tabular_q.py --type double_q --episodes 5000")

---
## 4. 🚧 The Wall: Why Tables Don't Scale

On a 5x5 board, tabular Q-Learning works well. But try a **10x10** board:

| Board | Possible States | Can Tabular Handle It? |
|---|---|---|
| 5×5 | ~2,000 | ✅ Easy |
| 8×8 | ~50,000+ | ⚠️ Barely |
| 10×10 | ~500,000+ | ❌ Impossible |

The Q-table can't memorise that many states. And even if it could,
**rewards are sparse**: a random snake on a 10x10 board might wander for 1000 steps
before accidentally eating food. There's nothing to learn from.

We need **neural networks** to *generalise* across similar states,
and **clever training strategies** to overcome the sparse reward problem.

---
## 5. The Solution: A Triple Threat

To conquer the 10x10 board, we combined three techniques:

### 🧠 Technique 1: PPO (Proximal Policy Optimization)
Instead of memorising a table, PPO uses a **neural network** that outputs
action *probabilities*. It's a policy-gradient method, meaning it directly
optimises "what action should I take?" rather than "what is this state worth?"

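The key idea: PPO clips the ratio between the new and old probability of each action (a common choice is ±20%), so a single update can't move the policy too far from the one that collected the data.
This gives stable, reliable learning.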
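Concretely, let $r_t = \pi_{\text{new}}(a_t \mid s_t) / \pi_{\text{old}}(a_t \mid s_t)$ be the probability ratio and $A_t$ the advantage. PPO maximises the average of $\min\big(r_t A_t,\ \text{clip}(r_t, 1-\epsilon, 1+\epsilon)\,A_t\big)$. Below is a minimal pure-Python sketch of that loss. It assumes $\epsilon = 0.2$, a common default; this notebook doesn't show the value used in `train_ppo_curriculum.py`. A real implementation would compute this on PyTorch tensors so it can be backpropagated.

```python
import math

def ppo_clip_loss(new_log_probs, old_log_probs, advantages, clip_eps=0.2):
    """Negative PPO clipped surrogate objective, averaged over a batch."""
    total = 0.0
    for new_lp, old_lp, adv in zip(new_log_probs, old_log_probs, advantages):
        ratio = math.exp(new_lp - old_lp)        # pi_new(a|s) / pi_old(a|s)
        clipped = min(max(ratio, 1 - clip_eps), 1 + clip_eps)
        # Keep the more pessimistic of the unclipped and clipped objectives
        total += min(ratio * adv, clipped * adv)
    return -total / len(advantages)              # negate: optimisers minimise

# A good move (advantage +1) whose probability doubled from 0.25 to 0.5:
# the update only gets credit for a 20% increase.
print(ppo_clip_loss([math.log(0.5)], [math.log(0.25)], [1.0]))  # ≈ -1.2
```

Because of the `min`, the clip only caps how much the objective can *improve*, never how much it can get worse. For a bad move whose probability grew, the full penalty still applies.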

### 🎓 Technique 2: Imitation Learning (Behavioral Cloning)
Before exploring on its own, we let the agent **watch an expert** play.
The expert is a REINFORCE agent trained on 5x5. We record its games and
pre-train the PPO network to mimic those moves. This gives the agent
"instincts" before it even starts exploring.

### üìà Technique 3: Curriculum Learning
We don't jump straight to 10x10. Instead:
1. Master **5x5** (easy, dense rewards)
2. Transfer brain → **8x8** (medium)
3. Transfer brain → **10x10** (the goal)

Each stage builds on the skills from the previous one.
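The control flow of a curriculum is simple: one model, several stages, and weights carried over between them. Here is a minimal sketch. `train_stage` is a hypothetical callback that stands in for the real PPO training loop in `train_ppo_curriculum.py`. This kind of transfer only works if the network's input size doesn't depend on the board. That holds for a fixed-length feature vector such as the 14 inputs used in the architecture below.

```python
def run_curriculum(model, board_sizes, train_stage):
    """Train one model on progressively larger boards, reusing its weights.

    `train_stage(model, board_size)` is a hypothetical callback that trains
    `model` in place on one board size and returns a summary of that stage.
    """
    history = []
    for size in board_sizes:
        # The same model object enters every stage, so whatever it learned on
        # the smaller board is the starting point on the bigger one.
        summary = train_stage(model, size)
        history.append((size, summary))
    return history

# Dummy stage that only records what it saw, to show the control flow
def fake_train_stage(model, board_size):
    model["boards_seen"].append(board_size)
    return {"board_size": board_size}

model = {"boards_seen": []}
history = run_curriculum(model, [5, 8, 10], fake_train_stage)
print(model["boards_seen"])  # [5, 8, 10]
```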

### The PPO Network Architecture

Let's peek inside the neural network that powers our final agent:

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical

class ActorCritic(nn.Module):
    """The brain of our PPO agent.
    
    - Actor: outputs action probabilities ("what should I do?")
    - Critic: estimates state value ("how good is my situation?")
    - Shared layers: both heads share a common feature extractor
    """
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.shared = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU()
        )
        self.actor = nn.Linear(hidden_size // 2, output_size)   # → 3 actions
        self.critic = nn.Linear(hidden_size // 2, 1)            # → 1 value
    
    def forward(self, x):
        shared = self.shared(x)
        return F.softmax(self.actor(shared), dim=-1), self.critic(shared)

# Create the network
net = ActorCritic(input_size=14, hidden_size=256, output_size=3)
print(net)
print(f"\nTotal parameters: {sum(p.numel() for p in net.parameters()):,}")

# Test with a dummy state
dummy = torch.randn(1, 14)
probs, value = net(dummy)
print(f"\nAction probabilities: {probs.detach().numpy().round(3)}")
print(f"State value estimate: {value.item():.3f}")

Notice the architecture:
- **14 input features** (danger sensors, food direction, movement direction, etc.)
- **256 → 128** shared hidden layers (the "common brain")
- **Actor head** ‚Üí 3 outputs (Straight / Turn Right / Turn Left)
- **Critic head** ‚Üí 1 output ("how good is this state?")
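The environment's `step()` takes absolute actions (0=Up, 1=Right, 2=Down, 3=Left), while the actor outputs relative moves, so something has to translate between them. The sketch below assumes the actor's outputs are ordered Straight, Turn Right, Turn Left, as listed above. The constant names are hypothetical, and this notebook doesn't show the repo's actual mapping.

```python
# Absolute actions, as printed by the environment: 0=Up, 1=Right, 2=Down, 3=Left.
# They go clockwise, so +1 is a right turn and -1 is a left turn.
STRAIGHT, TURN_RIGHT, TURN_LEFT = 0, 1, 2   # assumed order of the actor's outputs
TURN_OFFSET = {STRAIGHT: 0, TURN_RIGHT: 1, TURN_LEFT: -1}

def relative_to_absolute(current_direction, relative_action):
    """Map the actor's relative move to the env's absolute action (0-3)."""
    return (current_direction + TURN_OFFSET[relative_action]) % 4

print(relative_to_absolute(0, TURN_RIGHT))  # 1: heading Up, turn right -> Right
print(relative_to_absolute(0, TURN_LEFT))   # 3: heading Up, turn left  -> Left
```

Relative actions have one side effect: "reverse" is not an option. By contrast, the random agent's four absolute actions include reversing, which in many Snake implementations is either ignored or fatal.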

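Unlike a Q-table, which adds a new entry for every state it discovers, this network has a fixed size (37,252 parameters for the sizes above) no matter how big the board is. It can also generalise to states it has never seen!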
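PPO turns the critic's value estimates into *advantages*: how much better an action turned out than the critic expected. A common way to compute them is Generalised Advantage Estimation (GAE). The sketch below is an illustrative assumption. This notebook doesn't say whether `train_ppo_curriculum.py` uses GAE, plain returns minus values, or something else, and `gamma=0.99`, `lam=0.95` are just typical defaults.

```python
def compute_gae(rewards, values, last_value, dones, gamma=0.99, lam=0.95):
    """Generalised Advantage Estimation over one rollout.

    rewards, values, dones: per-step lists of equal length, where dones[t]
    is True if the episode ended right after step t.
    last_value: the critic's estimate for the state after the final step.
    Returns (advantages, returns); returns are the critic's training targets.
    """
    advantages = [0.0] * len(rewards)
    gae = 0.0
    next_value = last_value
    for t in reversed(range(len(rewards))):
        not_done = 0.0 if dones[t] else 1.0
        # TD error: how much better step t went than the critic predicted
        delta = rewards[t] + gamma * next_value * not_done - values[t]
        # Exponentially weighted sum of future TD errors, reset at episode ends
        gae = delta + gamma * lam * not_done * gae
        advantages[t] = gae
        next_value = values[t]
    returns = [adv + v for adv, v in zip(advantages, values)]
    return advantages, returns

# Toy rollout: food on steps 0 and 2, episode ends after step 2, critic predicts 0
adv, ret = compute_gae([1.0, 0.0, 1.0], [0.0, 0.0, 0.0], 0.0,
                       [False, False, True], gamma=1.0, lam=1.0)
print(adv)  # [2.0, 1.0, 1.0]
```

With `gamma=1` and `lam=1`, GAE reduces to the reward-to-go minus the critic's value, which is what the toy example shows.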

---
## 6. What *Failed* (And Why It Matters)

Not everything worked. Here's what we tried and what we learned:

| Approach | Board | Result | Why It Failed |
|---|---|---|---|
| **A2C** (Actor-Critic) | 5x5 | Max 4, Mean 0.2 | Bootstrap trap: critic gives bad estimates early, actor trusts them anyway |
| **Vanilla PPO** | 5x5 | Max 4, Mean 0.35 | Same bootstrap trap, with no expert guidance |
| **PPO + DQN demos** | 8x8 | Max 42, Mean 6.4 | DQN demos reached 96% cloning accuracy, but PPO couldn't replicate the value-based style |
| **PPO + REINFORCE demos** | 8x8 | Max 46, Mean 11.9 | ✅ Policy→Policy transfer works! |

### 💡 Key Insight: Algorithm Compatibility > Imitation Accuracy

DQN demos gave **96% behavioral cloning accuracy** (higher than REINFORCE's 82%),
but PPO performed *worse* with them! Why?

- **REINFORCE and PPO are both policy-gradient methods**: they "think" the same way
- **DQN is value-based**: it outputs Q-values, not probabilities
- Transferring from one policy method to another preserves the decision structure
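Behavioral cloning itself is plain supervised learning. Each recorded (state, expert action) pair is a labelled example, and training minimises cross-entropy. The "cloning accuracy" quoted above is then the fraction of expert moves the network reproduces. Below is a minimal NumPy sketch under simplifying assumptions: a linear softmax policy stands in for the `ActorCritic` network, and the toy one-hot states are made up for illustration.

```python
import numpy as np

def softmax(logits):
    z = logits - logits.max(axis=1, keepdims=True)   # subtract max for stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def behavioral_cloning(states, expert_actions, n_actions, lr=0.5, epochs=300):
    """Fit a linear softmax policy to expert (state, action) pairs by
    minimising cross-entropy with plain full-batch gradient descent."""
    n, d = states.shape
    w = np.zeros((d, n_actions))
    targets = np.eye(n_actions)[expert_actions]       # one-hot expert moves
    for _ in range(epochs):
        probs = softmax(states @ w)
        grad = states.T @ (probs - targets) / n        # gradient of mean cross-entropy
        w -= lr * grad
    return w

def cloning_accuracy(w, states, expert_actions):
    """Fraction of states where the policy's top action matches the expert's."""
    return float((softmax(states @ w).argmax(axis=1) == expert_actions).mean())

# Toy demos: 3 one-hot states; the expert always picks action == state index
states = np.eye(3)
expert_actions = np.array([0, 1, 2])
w = behavioral_cloning(states, expert_actions, n_actions=3)
print(f"cloning accuracy: {cloning_accuracy(w, states, expert_actions):.0%}")  # 100%
```

In the notebook's pipeline, the equivalent step pre-trains the PPO network on the expert's recorded games before PPO training takes over.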

---
## 7. Existing Training Plots

Here are the training curves from our full runs:

In [None]:
import os

plots = ['tabular_q_training.png', 'tabular_double_q_training.png']
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

for ax, plot_file in zip(axes, plots):
    if os.path.exists(plot_file):
        img = plt.imread(plot_file)
        ax.imshow(img)
        ax.set_title(plot_file.replace('.png', '').replace('_', ' ').title(), fontweight='bold')
    else:
        ax.text(0.5, 0.5, f'{plot_file}\nnot found', ha='center', va='center')
    ax.axis('off')

plt.suptitle('Full Training Runs (5000 episodes each)', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

---
## 8. 🏆 Final Results

| Phase | Strategy | Board | Best Score | Mean Score |
|---|---|---|---|---|
| 0 | Random Agent | 5x5 | ~2 | ~0.3 |
| 0 | Tabular Q-Learning | 5x5 | 24 (Perfect) | ~11 |
| 0+ | Double Q-Learning | 5x5 | 24 (Stable) | ~14 |
| 1 | PPO + Imitation (5x5) | 5x5 | 24 | ~20 |
| 2 | Curriculum → 8x8 | 8x8 | 46 | ~12 |
| **3** | **Curriculum → 10x10** | **10x10** | **64** | **~18** |

The agent fills over **60%** of a 10x10 board, starting from literally nothing!

---
## 9. Try It Yourself!

**Train your own agents:**
```bash
# Tabular (fast, runs in seconds)
python train_tabular_q.py --type q --episodes 5000
python train_tabular_q.py --type double_q --episodes 5000

# Full PPO curriculum (takes ~30 min)
python train_ppo_curriculum.py
```

**Watch the results:**
- Open `snake_learning_journey.html` in your browser to see the full evolution
- Open `snake_10x10_replay.html` to watch the best 10x10 games

**Explore the code:**
| File | What it does |
|---|---|
| `snake_game.py` | The game environment |
| `train_tabular_q.py` | Q-Learning and Double Q-Learning |
| `train_reinforce.py` | REINFORCE (expert for demos) |
| `train_ppo_curriculum.py` | PPO + Imitation + Curriculum |
| `visualize_journey.py` | Record games ‚Üí HTML visualization |

---
*Happy learning! 🐍*