# Noisy DQN Training Notebook

This notebook trains a Noisy DQN agent to play Snake.

**Algorithm**: Noisy DQN with:
- NoisyLinear layers with parametric noise
- No epsilon-greedy exploration needed
- Network learns when to explore via noise parameters
- Experience replay + Target networks

In [None]:
# Cell 1: Imports
import sys
from pathlib import Path

# Resolve the directory two levels above the current working directory and
# put it at the front of the import search path (presumably so the `core`
# package imported below can be found - confirm against the repo layout).
project_root = Path.cwd().parent.parent
sys.path[:0] = [str(project_root)]

import torch
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime

from core.environment_vectorized import VectorizedSnakeEnv
from core.networks import NoisyDQN_MLP
from core.utils import ReplayBuffer, MetricsTracker, set_seed, get_device

# Report the resolved project root and the PyTorch / CUDA status for this run.
print(
    f"Project root: {project_root}",
    f"PyTorch version: {torch.__version__}",
    f"CUDA available: {torch.cuda.is_available()}",
    sep="\n",
)

In [None]:
# Cell 2: Configuration (papermill parameters)
# ============== CONFIGURATION ==============
# Every value is a plain `NAME = literal` assignment so it can be overridden
# as a notebook parameter.

# --- Environment (forwarded to VectorizedSnakeEnv) ---
GRID_SIZE = 10
NUM_ENVS = 256
MAX_STEPS = 1_000
ACTION_SPACE_TYPE = "relative"
STATE_REPRESENTATION = "feature"  # Only 'feature' supported for Noisy DQN
USE_FLOOD_FILL = False
USE_ENHANCED_FEATURES = False
USE_SELECTIVE_FEATURES = False

# --- Training ---
NUM_EPISODES = 500          # training stops once this many episodes finished
LEARNING_RATE = 1e-3        # Adam learning rate
BATCH_SIZE = 64             # transitions sampled per gradient step
BUFFER_SIZE = 100000        # replay buffer capacity
GAMMA = 0.99                # discount factor in the TD target
TARGET_UPDATE_FREQ = 1_000  # compared against the environment-step counter
MIN_BUFFER_SIZE = 1_000     # buffer size required before training starts
TRAIN_STEPS_RATIO = 0.03125  # gradient steps per vectorized step = NUM_ENVS * ratio

# --- Network ---
HIDDEN_DIMS = (128, 128)

# --- Noisy DQN - ENABLED ---
USE_NOISY = True
NOISY_SIGMA = 0.5  # Initial noise standard deviation

# --- Output ---
SAVE_DIR = "../../results/weights/noisy_dqn"
SEED = 42
LOG_INTERVAL = 50  # print progress every this many finished episodes

In [None]:
# Cell 3: Environment & Model Setup
# Seeding happens before anything else is constructed.
set_seed(SEED)
device = get_device()
print(f"Using device: {device}")
print(f"Noisy Networks: {USE_NOISY}")
print(f"Noise Sigma: {NOISY_SIGMA}")

# Directory that will receive the checkpoint written in the save cell.
save_dir = Path(SAVE_DIR)
save_dir.mkdir(parents=True, exist_ok=True)

env_kwargs = {
    "num_envs": NUM_ENVS,
    "grid_size": GRID_SIZE,
    "action_space_type": ACTION_SPACE_TYPE,
    "state_representation": STATE_REPRESENTATION,
    "max_steps": MAX_STEPS,
    "use_flood_fill": USE_FLOOD_FILL,
    "use_enhanced_features": USE_ENHANCED_FEATURES,
    "use_selective_features": USE_SELECTIVE_FEATURES,
    "device": device,
}
env = VectorizedSnakeEnv(**env_kwargs)

# Determine input dimension. When several feature flags are set the
# precedence is enhanced > selective > flood fill > base features.
if USE_ENHANCED_FEATURES:
    input_dim = 24
elif USE_SELECTIVE_FEATURES:
    input_dim = 19
elif USE_FLOOD_FILL:
    input_dim = 14
else:
    input_dim = 11

# Noisy DQN only supports MLP (feature-based state).
# Both networks are built from the same constructor arguments, policy first.
network_kwargs = {
    "input_dim": input_dim,
    "output_dim": env.action_space.n,
    "hidden_dims": HIDDEN_DIMS,
    "sigma_init": NOISY_SIGMA,
}
policy_net = NoisyDQN_MLP(**network_kwargs).to(device)
target_net = NoisyDQN_MLP(**network_kwargs).to(device)

# The target network starts as an exact copy of the policy network.
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

optimizer = torch.optim.Adam(policy_net.parameters(), lr=LEARNING_RATE)
replay_buffer = ReplayBuffer(capacity=BUFFER_SIZE, seed=SEED)
metrics = MetricsTracker(window_size=100)

parameter_count = sum(p.numel() for p in policy_net.parameters())
print(f"Environment: {GRID_SIZE}x{GRID_SIZE} grid, {NUM_ENVS} parallel envs")
print(f"Policy network: {parameter_count} parameters")
print("Note: Noisy DQN uses parametric noise - no epsilon-greedy needed!")

In [None]:
# Cell 4: Training Loop
import time
import torch.nn.functional as F

def select_actions(states):
    """Return the greedy action for each state under freshly sampled noise.

    The policy network's noise is resampled on every call and the action with
    the highest Q-value along dimension 1 is chosen; no epsilon-greedy step is
    involved.

    Args:
        states: Batch of observations accepted by ``policy_net``
            (presumably shaped ``(num_envs, input_dim)`` - confirm).

    Returns:
        Tensor of action indices, one per row of ``states``.
    """
    with torch.no_grad():
        # Resample the parametric noise so each call explores differently.
        policy_net.reset_noise()
        greedy_actions = policy_net(states).argmax(dim=1)
    return greedy_actions

def train_step():
    """Run one gradient step of the policy network on a sampled replay batch.

    Samples ``BATCH_SIZE`` transitions, builds the one-step TD target from the
    target network, and minimises the smooth-L1 (Huber) loss with gradient-norm
    clipping at 1.0.

    Returns:
        The loss as a Python float, or ``None`` when the replay buffer reports
        that it cannot yet provide a batch of ``BATCH_SIZE``.
    """
    if not replay_buffer.is_ready(BATCH_SIZE):
        return None

    # Reset noise for both networks before their forward passes.
    policy_net.reset_noise()
    target_net.reset_noise()

    states, actions, rewards, next_states, dones = replay_buffer.sample(BATCH_SIZE)
    states = states.to(device)
    actions = actions.to(device)
    rewards = rewards.to(device)
    next_states = next_states.to(device)
    dones = dones.to(device)

    # Q(s, a) for the actions actually taken. Only the gathered column is
    # squeezed: a bare .squeeze() would also drop the batch dimension when
    # BATCH_SIZE == 1, leaving a 0-d tensor that no longer matches the target.
    current_q_values = policy_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)

    with torch.no_grad():
        # Max over actions of the target network's estimate for the next state.
        next_q_values = target_net(next_states).max(1)[0]
        # NOTE(review): assumes `dones` comes back from the buffer as a numeric
        # 0/1 tensor; `1 - dones` is not defined for bool tensors - confirm.
        target_q_values = rewards + GAMMA * next_q_values * (1 - dones)

    loss = F.smooth_l1_loss(current_q_values, target_q_values)

    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(policy_net.parameters(), 1.0)
    optimizer.step()

    return loss.item()

# Main training loop: step all environments in lockstep, store every
# transition, train the policy network, periodically sync the target network
# and record statistics for each finished episode until NUM_EPISODES is reached.
print("Starting Noisy DQN Training...")
print(f"Device: {device}")
print(f"Episodes: {NUM_EPISODES}")
print("Note: Using parametric noise for exploration (no epsilon decay)")
print()

states = env.reset(seed=SEED)
start_time = time.time()

# Running totals for the episode currently in progress in each environment.
episode_rewards = torch.zeros(NUM_ENVS, device=device)
episode_lengths = torch.zeros(NUM_ENVS, dtype=torch.long, device=device)

episode = 0
total_steps = 0
# Value of total_steps at the most recent target-network sync.
last_target_update = 0

all_rewards = []
all_scores = []
all_losses = []

# Gradient steps per vectorized environment step (loop-invariant).
num_train_steps = max(1, int(NUM_ENVS * TRAIN_STEPS_RATIO))

while episode < NUM_EPISODES:
    actions = select_actions(states)
    next_states, rewards, dones, info = env.step(actions)

    episode_rewards += rewards
    episode_lengths += 1

    # Transfer each batch to the host once instead of once per environment;
    # the buffer still receives one (row array, scalar, scalar, row array,
    # scalar) tuple per environment.
    states_np = states.cpu().numpy()
    next_states_np = next_states.cpu().numpy()
    for transition in zip(
        states_np,
        actions.tolist(),
        rewards.tolist(),
        next_states_np,
        dones.tolist(),
    ):
        replay_buffer.push(*transition)

    if replay_buffer.is_ready(MIN_BUFFER_SIZE):
        for _ in range(num_train_steps):
            loss = train_step()
            if loss is not None:
                metrics.add_loss(loss)
                all_losses.append(loss)

    # total_steps advances by NUM_ENVS per iteration, so it is rarely an exact
    # multiple of TARGET_UPDATE_FREQ (with 256 and 1000 only every 32,000
    # steps). Sync as soon as at least TARGET_UPDATE_FREQ steps have elapsed
    # since the previous sync instead of testing for an exact multiple.
    if total_steps - last_target_update >= TARGET_UPDATE_FREQ:
        target_net.load_state_dict(policy_net.state_dict())
        last_target_update = total_steps

    if dones.any():
        done_indices = torch.where(dones)[0]
        for idx in done_indices:
            reward = episode_rewards[idx].item()
            length = episode_lengths[idx].item()
            score = info['scores'][idx].item()

            metrics.add_episode(reward, length, score)
            all_rewards.append(reward)
            all_scores.append(score)

            episode += 1
            # Start fresh accumulators for the next episode in this slot.
            episode_rewards[idx] = 0
            episode_lengths[idx] = 0

            if episode % LOG_INTERVAL == 0:
                stats = metrics.get_recent_stats()
                elapsed = time.time() - start_time
                fps = total_steps / elapsed if elapsed > 0 else 0
                print(f"Episode {episode}/{NUM_EPISODES} | "
                      f"Score: {stats['avg_score']:.2f} | "
                      f"Reward: {stats['avg_reward']:.2f} | "
                      f"FPS: {fps:.0f}")

            # Stop counting once the episode budget is reached, even if more
            # environments finished in the same step.
            if episode >= NUM_EPISODES:
                break

    states = next_states
    total_steps += NUM_ENVS

training_time = time.time() - start_time
print(f"\nTraining complete!")
print(f"Total time: {training_time:.1f}s")
print(f"Total steps: {total_steps:,}")

In [None]:
# Cell 5: Save Weights
# The timestamp makes each run's file name unique; it is reused for the figure.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
filename = f"noisy_dqn_{GRID_SIZE}x{GRID_SIZE}_{NUM_EPISODES}ep_{timestamp}.pt"
filepath = save_dir / filename

# Hyperparameters stored alongside the weights.
run_config = {
    'grid_size': GRID_SIZE,
    'num_envs': NUM_ENVS,
    'hidden_dims': HIDDEN_DIMS,
    'learning_rate': LEARNING_RATE,
    'gamma': GAMMA,
    'use_noisy': USE_NOISY,
    'noisy_sigma': NOISY_SIGMA,
    'state_representation': STATE_REPRESENTATION,
}

# Network and optimizer state plus training progress counters.
checkpoint = {
    'policy_net': policy_net.state_dict(),
    'target_net': target_net.state_dict(),
    'optimizer': optimizer.state_dict(),
    'episode': episode,
    'total_steps': total_steps,
    'config': run_config,
}
torch.save(checkpoint, filepath)

print(f"Model saved to: {filepath}")

In [None]:
# Cell 6: Visualization
# One figure with a 2x2 grid of axes, indexed below as axes[row, col].
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(12, 10))

def smooth(data, window=50):
    """Return the moving average of ``data`` over ``window`` points.

    Args:
        data: Sequence of numbers.
        window: Number of consecutive points averaged together.

    Returns:
        A NumPy array of length ``len(data) - window + 1`` ('valid'
        convolution), or ``data`` itself, unchanged, when it holds fewer than
        ``window`` points.
    """
    if len(data) >= window:
        # Uniform kernel: every point in the window gets weight 1 / window.
        kernel = np.ones(window) / window
        return np.convolve(data, kernel, mode='valid')
    return data

# Panels 1 and 2 (top row) share one layout: the raw per-episode trace plus a
# 50-episode moving average once more than 50 episodes are available.
ax1 = axes[0, 0]
ax2 = axes[0, 1]
for panel, series, y_label, panel_title in (
    (ax1, all_rewards, 'Reward', 'Episode Rewards'),
    (ax2, all_scores, 'Score (Food Eaten)', 'Episode Scores'),
):
    panel.plot(series, alpha=0.3, label='Raw')
    if len(series) > 50:
        # The smoothed curve has 49 fewer points than the raw one, so its
        # x-axis starts at episode index 49.
        panel.plot(range(49, len(series)), smooth(series), label='Smoothed (50)')
    panel.set_xlabel('Episode')
    panel.set_ylabel(y_label)
    panel.set_title(panel_title)
    panel.legend()
    panel.grid(True, alpha=0.3)

# Panel 3: training loss, subsampled to roughly 1000 points.
ax3 = axes[1, 0]
if all_losses:
    stride = max(1, len(all_losses) // 1000)
    ax3.plot(all_losses[::stride], alpha=0.5)
    ax3.set_xlabel('Training Step (subsampled)')
    ax3.set_ylabel('Loss')
    ax3.set_title('Training Loss')
    ax3.grid(True, alpha=0.3)

# Plot 4: Score distribution instead of epsilon (no epsilon in noisy DQN)
ax4 = axes[1, 1]
mean_score = np.mean(all_scores)
ax4.hist(all_scores, bins=30, edgecolor='black', alpha=0.7)
ax4.axvline(mean_score, color='red', linestyle='--', label=f'Mean: {mean_score:.2f}')
ax4.set_xlabel('Score')
ax4.set_ylabel('Frequency')
ax4.set_title('Score Distribution')
ax4.legend()
ax4.grid(True, alpha=0.3)

plt.suptitle(f'Noisy DQN Training Results - {NUM_EPISODES} Episodes', fontsize=14)
plt.tight_layout()

# Write the figure to <project_root>/results/figures before displaying it.
fig_dir = project_root / 'results' / 'figures'
fig_dir.mkdir(parents=True, exist_ok=True)
fig_path = fig_dir / f'noisy_dqn_training_{timestamp}.png'
plt.savefig(fig_path, dpi=150, bbox_inches='tight')
print(f"Figure saved to: {fig_path}")

plt.show()

In [None]:
# Cell 7: Results Summary
# Recent-window statistics come from the metrics tracker; the overall
# statistics are computed over every recorded episode.
stats = metrics.get_recent_stats()
rule = "=" * 50

summary_lines = [
    rule,
    "TRAINING SUMMARY",
    rule,
    "Algorithm: Noisy DQN",
    f"Exploration: Parametric noise (sigma={NOISY_SIGMA})",
    f"Episodes: {episode}",
    f"Total Steps: {total_steps:,}",
    f"Training Time: {training_time:.1f}s",
    "",
    "Final Performance (last 100 episodes):",
    f"  Average Score: {stats['avg_score']:.2f}",
    f"  Average Reward: {stats['avg_reward']:.2f}",
    f"  Average Length: {stats['avg_length']:.2f}",
    f"  Max Score: {stats['max_score']}",
    "",
    "Overall Statistics:",
    f"  Mean Score: {np.mean(all_scores):.2f} +/- {np.std(all_scores):.2f}",
    f"  Max Score: {max(all_scores)}",
    f"  Mean Reward: {np.mean(all_rewards):.2f} +/- {np.std(all_rewards):.2f}",
    rule,
]
for line in summary_lines:
    print(line)