# Deep Q-Network (DQN) Training Notebook

This notebook trains a **DQN with CNN** agent on the Snake game.

**Grid Size:** 10×10 (optimal for DQN)

**Expected Performance:**
- Episodes to convergence: 2000-3000
- Final average score: 15-25 apples
- Training time: 15-20 minutes (CPU) / 5-10 minutes (GPU)

---

## 1. Setup and Imports

In [None]:
# Add parent directory to path
import sys
from pathlib import Path

sys.path.append(str(Path.cwd().parent))

import random
from typing import List
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
import torch.optim as optim

from game.config import GameConfig
from game.engine import SnakeGameEngine
from game.entities import Action, State
from utils.metrics import TrainingMetrics

# Select the compute device: use CUDA when PyTorch reports an available GPU,
# otherwise fall back to the CPU. Every network and tensor below is moved here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"🖥️  Using device: {device}")

# For Google Colab
# !git clone https://github.com/YOUR_REPO/rl-snake.git
# %cd rl-snake

## 2. DQN Components

In [None]:
class ConvQNetwork(nn.Module):
    """
    Convolutional Q-value estimator for square, 3-channel grid observations.

    The model is a plain stack of ``Conv2d -> BatchNorm2d -> ReLU`` stages
    (3x3 kernels with padding 1, so the spatial size is preserved) followed
    by a three-layer fully connected head that emits one Q-value per action.

    Width and depth scale with ``grid_size``:

    - ``grid_size <= 10``: conv channels 32/64/64, hidden size 256
    - ``grid_size <= 15``: conv channels 32/64/128, hidden size 512
    - larger grids: conv channels 32/64/128/128, hidden size 512

    Args:
        grid_size: Side length of the square input grid.
        num_actions: Number of Q-values produced per input sample.
    """

    def __init__(self, grid_size: int, num_actions: int = 4):
        super().__init__()
        self.grid_size = grid_size
        self.num_actions = num_actions

        stage_widths, head_width = self._pick_architecture(grid_size)

        # Each stage consumes the previous stage's output channels; the first
        # stage reads the 3-channel observation.
        stage_inputs = [3] + stage_widths[:-1]
        stages = []
        for n_in, n_out in zip(stage_inputs, stage_widths):
            stages.append(nn.Conv2d(n_in, n_out, kernel_size=3, padding=1))
            stages.append(nn.BatchNorm2d(n_out))
            stages.append(nn.ReLU(inplace=True))
        self.conv = nn.Sequential(*stages)

        # Padding keeps the grid resolution, so the flattened feature vector
        # holds one value per cell for every channel of the last stage.
        self.flat_size = stage_widths[-1] * grid_size * grid_size

        half_width = head_width // 2
        self.fc = nn.Sequential(
            nn.Linear(self.flat_size, head_width),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
            nn.Linear(head_width, half_width),
            nn.ReLU(inplace=True),
            nn.Linear(half_width, num_actions),
        )

        self._initialize_weights()

    @staticmethod
    def _pick_architecture(grid_size):
        """Return ``(conv_channels, hidden_size)`` for the given grid size."""
        if grid_size > 15:
            return [32, 64, 128, 128], 512
        if grid_size > 10:
            return [32, 64, 128], 512
        return [32, 64, 64], 256

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a ``(batch, 3, grid, grid)`` tensor to ``(batch, num_actions)`` Q-values."""
        features = self.conv(x)
        batch = features.shape[0]
        return self.fc(features.view(batch, -1))

    def _initialize_weights(self):
        """Apply Kaiming-normal weights and zero biases; BatchNorm starts as identity."""
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
                continue

            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(
                    module.weight, mode="fan_out", nonlinearity="relu"
                )
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
                continue

            if isinstance(module, nn.Linear):
                nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
                nn.init.constant_(module.bias, 0)


class ReplayBuffer:
    """
    Proportional prioritized experience replay stored in a ring buffer.

    Transitions are sampled with probability proportional to
    ``priority ** alpha`` and the accompanying importance-sampling weights
    are computed with exponent ``beta``, which is annealed towards 1.0 by
    ``beta_increment`` on every call to :meth:`sample`.

    Args:
        capacity: Maximum number of stored transitions; the oldest entry is
            overwritten once the buffer is full.
        alpha: Prioritization exponent (0 gives uniform sampling).
        beta: Initial importance-sampling exponent.
        beta_increment: Amount added to ``beta`` after each sample, capped at 1.0.

    Raises:
        ValueError: If ``capacity`` is not positive.
    """

    def __init__(
        self,
        capacity: int = 100_000,
        alpha: float = 0.6,
        beta: float = 0.4,
        beta_increment: float = 0.001,
    ):
        if capacity <= 0:
            raise ValueError(f"capacity must be positive, got {capacity}")

        self.capacity = capacity
        self.alpha = alpha  # Prioritization exponent
        self.beta = beta  # Importance sampling exponent
        self.beta_increment = beta_increment

        self.buffer = []
        self.priorities = np.zeros(capacity, dtype=np.float32)
        self.position = 0  # Next slot to write (wraps around at capacity)
        self.size = 0  # Number of valid transitions currently stored

    def push(self, state, action, reward, next_state, done):
        """Store a transition, giving it the current maximum priority.

        New transitions get the highest known priority so that each one is
        likely to be sampled at least once before its TD error is known.
        """
        max_priority = self.priorities.max() if self.size > 0 else 1.0

        transition = (state, action, reward, next_state, done)
        if len(self.buffer) < self.capacity:
            self.buffer.append(transition)
        else:
            self.buffer[self.position] = transition

        self.priorities[self.position] = max_priority
        self.position = (self.position + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def sample(self, batch_size: int):
        """Draw a prioritized batch without replacement.

        Returns:
            ``None`` when fewer than ``batch_size`` transitions are stored.
            Otherwise a tuple ``(states, actions, rewards, next_states,
            dones, indices, weights)`` where the first five are stacked
            NumPy arrays, ``indices`` are the sampled buffer slots (to be
            passed to :meth:`update_priorities`) and ``weights`` are float32
            importance-sampling weights scaled so the largest is 1.0.
        """
        if self.size < batch_size:
            return None

        # Normalise in float64: float32 round-off across a large buffer can
        # leave the probabilities summing slightly off 1.0, which
        # np.random.choice validates against.
        scaled = self.priorities[: self.size].astype(np.float64) ** self.alpha
        probs = scaled / scaled.sum()

        indices = np.random.choice(self.size, batch_size, p=probs, replace=False)

        # Importance-sampling weights, normalised by the batch maximum so they
        # only ever scale the loss downwards.
        weights = (self.size * probs[indices]) ** (-self.beta)
        weights /= weights.max()

        batch = [self.buffer[idx] for idx in indices]
        states, actions, rewards, next_states, dones = zip(*batch)

        # Anneal beta towards 1.0 (full bias correction).
        self.beta = min(1.0, self.beta + self.beta_increment)

        return (
            np.array(states, dtype=np.float32),
            np.array(actions, dtype=np.int64),
            np.array(rewards, dtype=np.float32),
            np.array(next_states, dtype=np.float32),
            np.array(dones, dtype=np.float32),
            indices,
            weights.astype(np.float32),
        )

    def update_priorities(self, indices, td_errors):
        """Set the priority of each sampled slot to ``|td_error| + 1e-6``.

        The small constant keeps every priority strictly positive so no
        transition becomes impossible to sample.
        """
        slots = np.asarray(indices, dtype=np.int64).reshape(-1)
        magnitudes = np.abs(np.asarray(td_errors, dtype=np.float32)).reshape(-1)
        self.priorities[slots] = magnitudes + 1e-6

    def __len__(self):
        """Return the number of transitions currently stored."""
        return self.size

## 3. DQN Agent

In [None]:
class DQNAgent:
    """
    Double-DQN agent with a convolutional Q-network and prioritized replay.

    The agent keeps an online network and a periodically synchronised target
    network, explores with a linearly decaying epsilon-greedy policy after a
    purely random warmup phase, and learns from batches drawn from a
    ``ReplayBuffer`` using a Huber loss weighted by importance-sampling
    weights.

    Args:
        grid_size: Side length of the game grid (selects the network size).
        learning_rate: AdamW learning rate.
        discount_factor: Reward discount (gamma) used in the TD target.
        epsilon_start: Initial exploration rate.
        epsilon_end: Floor of the exploration rate.
        epsilon_frames: Number of post-warmup steps over which epsilon
            decays linearly from ``epsilon_start`` to ``epsilon_end``.
        batch_size: Number of transitions per gradient update.
        buffer_size: Capacity of the replay buffer.
        target_update_freq: Steps between target-network synchronisations.
        learning_starts: Number of initial steps taken with random actions
            and no gradient updates.
    """

    def __init__(
        self,
        grid_size: int,
        learning_rate: float = 0.0001,
        discount_factor: float = 0.99,
        # Exploration parameters (adaptive)
        epsilon_start: float = 1.0,
        epsilon_end: float = 0.05,  # Keep some exploration
        epsilon_frames: int = 500_000,  # Decay over more frames
        # Training parameters
        batch_size: int = 64,
        buffer_size: int = 100_000,
        target_update_freq: int = 2_000,  # Less frequent updates
        learning_starts: int = 10_000,  # Warmup period
    ):
        self.grid_size = grid_size
        self.discount_factor = discount_factor
        self.batch_size = batch_size
        self.target_update_freq = target_update_freq
        self.learning_starts = learning_starts

        # Linear epsilon schedule: a fixed amount is subtracted per step.
        self.epsilon = epsilon_start
        self.epsilon_start = epsilon_start
        self.epsilon_end = epsilon_end
        self.epsilon_frames = epsilon_frames
        # Guard against epsilon_frames <= 0, which would divide by zero; in
        # that case epsilon reaches its floor after a single step.
        self.epsilon_decay = (epsilon_start - epsilon_end) / max(1, epsilon_frames)

        # Online and target networks start with identical weights; the target
        # network is only ever used for inference.
        self.q_network = ConvQNetwork(grid_size).to(device)
        self.target_network = ConvQNetwork(grid_size).to(device)
        self.target_network.load_state_dict(self.q_network.state_dict())
        self.target_network.eval()

        # Optimizer with weight decay (L2 regularization)
        self.optimizer = optim.AdamW(
            self.q_network.parameters(), lr=learning_rate, weight_decay=1e-5
        )

        # Huber loss (robust to outliers); kept per-sample so that the
        # importance-sampling weights can be applied before averaging.
        self.criterion = nn.SmoothL1Loss(reduction="none")

        # Replay buffer
        self.memory = ReplayBuffer(buffer_size)

        # Training statistics
        self.steps = 0
        self.episodes_trained = 0
        self.losses: List[float] = []
        self.episode_rewards: List[float] = []

    def get_action(self, state: State, training: bool = True) -> Action:
        """
        Select an action with an epsilon-greedy policy.

        Args:
            state: Current game state.
            training: If False, always act greedily (no warmup randomness and
                no epsilon exploration).

        Returns:
            The chosen ``Action``.
        """
        # During warmup: pure exploration
        if training and self.steps < self.learning_starts:
            return random.choice(list(Action))

        # Epsilon-greedy (only during training)
        if training and random.random() < self.epsilon:
            return random.choice(list(Action))

        # Greedy action from the online network. Run it in eval mode so that
        # Dropout is disabled and BatchNorm uses its running statistics rather
        # than the statistics of this single sample; the previous mode is
        # restored afterwards so gradient updates are unaffected.
        state_tensor = self._state_to_tensor(state)
        was_training = self.q_network.training
        self.q_network.eval()
        try:
            with torch.no_grad():
                q_values = self.q_network(state_tensor)
        finally:
            self.q_network.train(was_training)
        return Action(q_values.argmax().item())

    def train(
        self, state: State, action: Action, reward: float, next_state: State, done: bool
    ):
        """
        Record one transition and, once warmup is over, run one update.

        Args:
            state: State in which ``action`` was taken.
            action: Action that was executed.
            reward: Reward received for the transition.
            next_state: State observed after the action.
            done: Whether the episode ended with this transition.
        """
        state_array = state.to_tensor()
        next_state_array = next_state.to_tensor()

        # Store in replay buffer
        self.memory.push(state_array, action.value, reward, next_state_array, done)

        self.steps += 1

        # Count finished episodes up front so that episodes played during the
        # warmup phase are not lost to the early returns below.
        if done:
            self.episodes_trained += 1

        # Update epsilon (linear decay)
        if self.steps > self.learning_starts:
            self.epsilon = max(self.epsilon_end, self.epsilon - self.epsilon_decay)

        # Don't train until warmup complete
        if self.steps < self.learning_starts:
            return

        # Don't train if buffer too small
        if len(self.memory) < self.batch_size:
            return

        # Perform training update
        loss = self._update_network()
        if loss is not None:
            self.losses.append(loss)

        # Update target network
        if self.steps % self.target_update_freq == 0:
            self.target_network.load_state_dict(self.q_network.state_dict())
            print(f"🔄 Updated target network at step {self.steps:,}")

    def _update_network(self):
        """
        Run one gradient step on a prioritized batch.

        Returns:
            The scalar loss value, or ``None`` if the buffer could not
            provide a batch.
        """
        # Sample batch
        batch = self.memory.sample(self.batch_size)
        if batch is None:
            return None
        states, actions, rewards, next_states, dones, indices, weights = batch
        weights_t = torch.FloatTensor(weights).to(device)

        # Convert to tensors
        states_t = torch.FloatTensor(states).to(device)
        actions_t = torch.LongTensor(actions).to(device)
        rewards_t = torch.FloatTensor(rewards).to(device)
        next_states_t = torch.FloatTensor(next_states).to(device)
        dones_t = torch.FloatTensor(dones).to(device)

        # Q-value of the action that was actually taken in each transition
        current_q = (
            self.q_network(states_t).gather(1, actions_t.unsqueeze(1)).squeeze(1)
        )

        # Target Q-values
        with torch.no_grad():
            # Double DQN: use online network to select, target to evaluate
            next_actions = self.q_network(next_states_t).argmax(1)
            next_q = (
                self.target_network(next_states_t)
                .gather(1, next_actions.unsqueeze(1))
                .squeeze(1)
            )

            # Terminal transitions (done == 1) do not bootstrap.
            target_q = rewards_t + (1 - dones_t) * self.discount_factor * next_q

        # Compute loss (element-wise for PER)
        td_errors = current_q - target_q
        losses = self.criterion(current_q, target_q)

        # Apply importance sampling weights
        loss = (losses * weights_t).mean()

        # Optimize
        self.optimizer.zero_grad()
        loss.backward()

        self.optimizer.step()

        # Feed the new TD errors back so the buffer can re-prioritize.
        self.memory.update_priorities(indices, td_errors.detach().cpu().numpy())

        return loss.item()

    def _state_to_tensor(self, state: State) -> torch.Tensor:
        """Convert a State into a float tensor on ``device`` with a leading batch axis."""
        state_array = state.to_tensor()
        return torch.FloatTensor(state_array).unsqueeze(0).to(device)

    def save(self, filepath: str):
        """Write networks, optimizer state and training counters to ``filepath``.

        Parent directories are created if they do not exist.
        """
        Path(filepath).parent.mkdir(parents=True, exist_ok=True)
        torch.save(
            {
                "q_network_state": self.q_network.state_dict(),
                "target_network_state": self.target_network.state_dict(),
                "optimizer_state": self.optimizer.state_dict(),
                "grid_size": self.grid_size,
                "epsilon": self.epsilon,
                "steps": self.steps,
                "episodes_trained": self.episodes_trained,
            },
            filepath,
        )
        print(f"💾 Model saved to {filepath}")
        print(f"   Episodes: {self.episodes_trained}")
        print(f"   Steps: {self.steps:,}")
        print(f"   Epsilon: {self.epsilon:.4f}")

    def load(self, filepath: str):
        """Load model and training state written by :meth:`save`.

        Prints a warning and leaves the agent untouched when ``filepath``
        does not exist.
        """
        if not Path(filepath).exists():
            print(f"⚠️  No saved model found at {filepath}")
            return

        # NOTE: torch.load deserializes the file; only load checkpoints that
        # come from a trusted source.
        checkpoint = torch.load(filepath, map_location=device)

        self.q_network.load_state_dict(checkpoint["q_network_state"])
        self.target_network.load_state_dict(checkpoint["target_network_state"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state"])
        self.grid_size = checkpoint["grid_size"]
        self.epsilon = checkpoint["epsilon"]
        self.steps = checkpoint["steps"]
        self.episodes_trained = checkpoint["episodes_trained"]

        print(f"✅ Model loaded from {filepath}")
        print(f"   Episodes: {self.episodes_trained}")
        print(f"   Steps: {self.steps:,}")

## 4. Configuration

In [None]:
# Training configuration
GRID_SIZE = 10
EPISODES = 3_000

# Agent hyperparameters
LEARNING_RATE = 0.0001
DISCOUNT_FACTOR = 0.99
EPSILON_START = 1.0
# NOTE: DQNAgent in this notebook decays epsilon linearly (via its
# `epsilon_frames` argument) and has no multiplicative-decay parameter, so
# EPSILON_DECAY has no matching constructor argument.
EPSILON_DECAY = 0.995
EPSILON_MIN = 0.01
BATCH_SIZE = 64
BUFFER_SIZE = 100_000
TARGET_UPDATE = 1_000

# Save locations
MODEL_PATH = "models/dqn_cnn.pkl"
RESULTS_DIR = f"results/dqn_{GRID_SIZE}x{GRID_SIZE}"

print(f"🎮 Training DQN on {GRID_SIZE}×{GRID_SIZE} grid")
print(f"📈 Episodes: {EPISODES:,}")
print(f"🖥️  Device: {device}")

## 5. Initialize Environment and Agent

In [None]:
# Environment
config = GameConfig(grid_size=GRID_SIZE)
game = SnakeGameEngine(config)

# Agent: pass the hyperparameters declared in the configuration cell so that
# they actually take effect instead of being silently replaced by the
# constructor defaults. EPSILON_DECAY is not passed because DQNAgent uses a
# linear schedule controlled by `epsilon_frames` (left at its default).
agent = DQNAgent(
    grid_size=GRID_SIZE,
    learning_rate=LEARNING_RATE,
    discount_factor=DISCOUNT_FACTOR,
    epsilon_start=EPSILON_START,
    epsilon_end=EPSILON_MIN,
    batch_size=BATCH_SIZE,
    buffer_size=BUFFER_SIZE,
    target_update_freq=TARGET_UPDATE,
)

# Metrics
metrics = TrainingMetrics(save_dir=RESULTS_DIR)

print("✅ Initialized")

## 6. Training Loop

In [None]:
# Main training loop: play EPISODES games, feeding every transition to the
# agent and recording per-episode statistics.
record_score = 0
pbar = tqdm(range(1, EPISODES + 1), desc="Training")

try:
    for episode in pbar:
        state = game.reset()
        done = False
        episode_reward = 0
        steps = 0

        while not done:
            action = agent.get_action(state)
            reward, done, score = game.step(action)
            next_state = game.get_state()

            # Stores the transition and (after warmup) runs one update.
            agent.train(state, action, reward, next_state, done)

            state = next_state
            episode_reward += reward
            steps += 1

        metrics.record_episode(episode, score, steps, episode_reward)

        record_score = max(record_score, score)

        # Update progress
        pbar.set_postfix(
            {
                "Avg": f"{metrics.get_recent_average_score():.2f}",
                "Best": record_score,
                "ε": f"{agent.epsilon:.3f}",
                "Buffer": len(agent.memory),
            }
        )
except KeyboardInterrupt:
    print("\n\n⚠️  Training interrupted by user")
else:
    # Only report completion when every episode actually ran.
    print("\n✅ Training complete!")
finally:
    # Release the progress bar even when training is interrupted.
    pbar.close()

## 7. Save Model

In [None]:
# Persist the agent (both networks, optimizer state, epsilon and step/episode
# counters) to MODEL_PATH; DQNAgent.save creates the parent directory if needed.
agent.save(MODEL_PATH)

## 8. Results and Visualization

In [None]:
# Print the summary of the episodes recorded by the TrainingMetrics instance.
metrics.print_summary()

In [None]:
# Plot the recorded training metrics inline.
# Presumably save=False skips writing the figure to RESULTS_DIR — confirm in utils.metrics.
metrics.plot(show=True, save=False)
# NOTE(review): likely redundant if plot(show=True) already displays the figure — confirm.
plt.show()

## 9. Training Loss Analysis

In [None]:
# Visualise the per-update training loss: raw curve with a moving average,
# plus a histogram, followed by summary statistics.
if agent.losses:
    losses = np.asarray(agent.losses, dtype=np.float64)
    window = 1000

    plt.figure(figsize=(12, 4))

    # Raw loss (labelled so the legend always has at least one entry)
    plt.subplot(1, 2, 1)
    plt.plot(losses, alpha=0.3, label='Raw loss')
    if len(losses) > window:
        smoothed = np.convolve(losses, np.ones(window) / window, mode='valid')
        # 'valid' mode drops the first window-1 points; plot each average at
        # the last step of its window so it lines up with the raw curve.
        smoothed_steps = np.arange(window - 1, len(losses))
        plt.plot(smoothed_steps, smoothed, linewidth=2, label=f'Moving Avg ({window})')
    plt.xlabel('Training Step')
    plt.ylabel('Loss')
    plt.title('Training Loss')
    plt.legend()
    plt.grid(True, alpha=0.3)

    # Loss distribution
    plt.subplot(1, 2, 2)
    plt.hist(losses, bins=50, alpha=0.7)
    plt.xlabel('Loss')
    plt.ylabel('Frequency')
    plt.title('Loss Distribution')
    plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    print("Loss Statistics:")
    print(f"  Mean: {np.mean(losses):.4f}")
    print(f"  Std: {np.std(losses):.4f}")
    print(f"  Min: {np.min(losses):.4f}")
    print(f"  Max: {np.max(losses):.4f}")

## 10. Test the Trained Agent

In [None]:
# Evaluate the trained agent with a purely greedy policy.
# training=False bypasses both the random warmup branch and epsilon
# exploration in DQNAgent.get_action, so the result does not depend on how
# many steps the agent has trained for, and agent.epsilon is left untouched
# in case training is resumed afterwards.
test_episodes = 10
test_scores = []

print(f"Testing for {test_episodes} episodes...")

for ep in range(test_episodes):
    state = game.reset()
    done = False

    while not done:
        action = agent.get_action(state, training=False)
        reward, done, score = game.step(action)
        state = game.get_state()

    test_scores.append(score)
    print(f"  Episode {ep+1}: Score = {score}")

print("\nTest Results:")
print(f"  Average: {np.mean(test_scores):.2f}")
print(f"  Best: {max(test_scores)}")
print(f"  Worst: {min(test_scores)}")

## 11. Export for Google Colab (Optional)

If running on Google Colab, you can download the trained model:

In [None]:
# Uncomment to download model in Colab
# from google.colab import files
# files.download(MODEL_PATH)