# Two-Snake Classic DQN Training Notebook

This notebook trains two competing DQN agents to play against each other.

**Algorithm**: Classic DQN Two-Snake Training with:
- Two independent DQN agents competing
- Episode-based training (not vectorized)
- Each agent has its own Q-network and replay buffer
- Win condition: the first snake to eat `TARGET_SCORE` food items wins the episode

In [None]:
# Cell 1: Imports
import sys
from pathlib import Path

# Make the project packages imported below (core/, agents/) resolvable.
# NOTE(review): assumes the notebook is executed from a directory two levels
# below the project root (consistent with SAVE_DIR = '../../results/...' in
# Cell 2) -- revisit if the notebook is moved.
project_root = Path.cwd().parent.parent
# Must happen before the project imports further down in this cell.
sys.path.insert(0, str(project_root))

import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime
import time

# Project-local modules; only importable after the sys.path insert above.
from core.environment_two_snake_classic import TwoSnakeCompetitiveEnv
from agents.vanilla_dqn import VanillaDQNAgent

print(f"Project root: {project_root}")

In [None]:
# Cell 2: Configuration (papermill parameters)
# ============== CONFIGURATION ==============
# Papermill overrides these by name, so the names must stay stable.

# --- Environment ---------------------------------------------------------
GRID_SIZE = 10                 # board is GRID_SIZE x GRID_SIZE cells
TARGET_SCORE = 10              # food items a snake must eat to win
MAX_STEPS_PER_EPISODE = 500    # hard cap on steps per episode

# --- Training schedule ---------------------------------------------------
NUM_EPISODES = 500             # short smoke-test run (full run: 10000)
EVAL_FREQ = 50                 # print a progress line every EVAL_FREQ episodes
SAVE_FREQ = 500                # checkpoint interval (not read by the cells of this notebook)

# --- DQN hyperparameters (shared by both agents) -------------------------
HIDDEN_SIZE = 128
LEARNING_RATE = 5e-4
GAMMA = 0.99
EPSILON_START = 1.0
EPSILON_MIN = 0.01
EPSILON_DECAY = 0.995          # multiplicative decay, applied once per episode
BUFFER_SIZE = 50_000
BATCH_SIZE = 64
TARGET_UPDATE_FREQ = 1_000

# --- Output / reproducibility --------------------------------------------
SAVE_DIR = '../../results/weights/competitive/classic'
SEED = 42

In [None]:
# Cell 3: Environment & Agent Setup
import random

# Seed every RNG the run may draw from. Seeding np.random alone leaves
# Python's `random` module and torch unseeded, so runs were not reproducible.
# NOTE(review): torch is seeded on the assumption that VanillaDQNAgent is
# torch-based (its checkpoints are saved as .pt files) -- confirm.
random.seed(SEED)
np.random.seed(SEED)
try:
    import torch
    torch.manual_seed(SEED)
except ImportError:  # torch not installed: nothing further to seed
    pass

save_dir = Path(SAVE_DIR)
save_dir.mkdir(parents=True, exist_ok=True)

# Create environment
env = TwoSnakeCompetitiveEnv(
    grid_size=GRID_SIZE,
    target_score=TARGET_SCORE,
    max_steps=MAX_STEPS_PER_EPISODE
)

state_size = env.observation_space.shape[0]  # 20 features
action_size = env.action_space.n  # 3 actions

print(f"State size: {state_size}, Action size: {action_size}")

# Both agents use identical hyperparameters; keeping them in one dict keeps
# the two constructions in sync. Each agent still gets its own network and
# replay buffer because VanillaDQNAgent is instantiated twice.
agent_kwargs = dict(
    state_size=state_size,
    action_size=action_size,
    hidden_size=HIDDEN_SIZE,
    learning_rate=LEARNING_RATE,
    gamma=GAMMA,
    epsilon=EPSILON_START,
    epsilon_min=EPSILON_MIN,
    epsilon_decay=EPSILON_DECAY,
    buffer_size=BUFFER_SIZE,
    batch_size=BATCH_SIZE,
    use_target_network=True,
    target_update_freq=TARGET_UPDATE_FREQ,
)

agent1 = VanillaDQNAgent(**agent_kwargs)
agent2 = VanillaDQNAgent(**agent_kwargs)

print(f"Environment: {GRID_SIZE}x{GRID_SIZE} grid")
print(f"Target score: {TARGET_SCORE} food items")
print(f"Agent 1: DQN with {HIDDEN_SIZE} hidden neurons")
print(f"Agent 2: DQN with {HIDDEN_SIZE} hidden neurons")

In [None]:
# Cell 4: Training Loop
# Plays NUM_EPISODES head-to-head episodes. Both agents act every step, learn
# online from their own transition via train_step, and decay epsilon once per
# episode. Per-episode rewards, scores and agent 1's cumulative win rate are
# recorded for the visualization and summary cells.
print("Starting Two-Snake Classic DQN Training...")
print(f"Episodes: {NUM_EPISODES}")
print()

start_time = time.time()

# Statistics tracking
stats = {
    'agent1_wins': 0,
    'agent2_wins': 0,
    'draws': 0,
    'agent1_total_reward': 0.0,
    'agent2_total_reward': 0.0,
}

# For plotting
all_agent1_rewards = []
all_agent2_rewards = []
all_agent1_scores = []
all_agent2_scores = []
all_win_rates = []
total_steps = 0

for episode in range(1, NUM_EPISODES + 1):
    observations, _ = env.reset()
    state1 = observations['agent1']
    state2 = observations['agent2']

    episode_reward1 = 0
    episode_reward2 = 0
    done = False
    step = 0
    # Bound up front so the post-episode bookkeeping cannot raise NameError
    # when the step loop does not run (e.g. MAX_STEPS_PER_EPISODE <= 0).
    info = {}

    while not done and step < MAX_STEPS_PER_EPISODE:
        # Get actions from both agents
        action1 = agent1.select_action(state1, training=True)
        action2 = agent2.select_action(state2, training=True)

        # Step environment (dict-based API)
        actions = {'agent1': action1, 'agent2': action2}
        next_observations, rewards, terminated, truncated, info = env.step(actions)

        next_state1 = next_observations['agent1']
        next_state2 = next_observations['agent2']
        reward1 = rewards['agent1']
        reward2 = rewards['agent2']

        # Either flag ends the episode, but only `terminated` is a true
        # terminal state. A time-limit cut-off (`truncated`) still has future
        # value and must not be fed to the learner as terminal.
        # NOTE(review): assumes train_step uses its `done` argument to mask
        # the bootstrap term of the TD target, as in standard DQN -- confirm.
        terminal1 = terminated['agent1']
        terminal2 = terminated['agent2']
        done1 = terminal1 or truncated['agent1']
        done2 = terminal2 or truncated['agent2']
        done = done1 or done2

        # Train both agents
        agent1.train_step(state1, action1, reward1, next_state1, terminal1)
        agent2.train_step(state2, action2, reward2, next_state2, terminal2)

        episode_reward1 += reward1
        episode_reward2 += reward2

        state1 = next_state1
        state2 = next_state2
        step += 1
        total_steps += 1

    # Decay epsilon
    agent1.decay_epsilon()
    agent2.decay_epsilon()

    # Track winner: winner 1 / 2 is a win for that agent; any other value, or
    # a missing 'winner' key, is counted as a draw.
    winner = info.get('winner')
    if winner == 1:
        stats['agent1_wins'] += 1
    elif winner == 2:
        stats['agent2_wins'] += 1
    else:
        stats['draws'] += 1

    # These totals were declared in `stats` but previously never updated.
    stats['agent1_total_reward'] += episode_reward1
    stats['agent2_total_reward'] += episode_reward2

    # Store metrics
    all_agent1_rewards.append(episode_reward1)
    all_agent2_rewards.append(episode_reward2)
    all_agent1_scores.append(info.get('score1', 0))
    all_agent2_scores.append(info.get('score2', 0))

    # `episode` starts at 1, so the division is always safe.
    win_rate1 = stats['agent1_wins'] / episode
    all_win_rates.append(win_rate1)

    # Logging
    if episode % EVAL_FREQ == 0:
        elapsed = time.time() - start_time
        eps_per_sec = episode / elapsed if elapsed > 0 else 0

        # Recent window stats
        window = min(100, episode)
        recent_a1_scores = all_agent1_scores[-window:]
        recent_a2_scores = all_agent2_scores[-window:]

        print(f"Episode {episode}/{NUM_EPISODES} | "
              f"A1 Score: {np.mean(recent_a1_scores):.2f} | "
              f"A2 Score: {np.mean(recent_a2_scores):.2f} | "
              f"Win Rate: {win_rate1:.2%} | "
              f"Eps: {agent1.epsilon:.3f} | "
              f"Speed: {eps_per_sec:.1f} ep/s")

training_time = time.time() - start_time
print(f"\nTraining complete!")
print(f"Total time: {training_time:.1f}s")
print(f"Total steps: {total_steps:,}")

In [None]:
# Cell 5: Save Weights
# One timestamp is shared by both checkpoints (and reused by the figure file
# name in the visualization cell) so artifacts from a single run pair up.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
checkpoint_suffix = f"_dqn_{NUM_EPISODES}ep_{timestamp}.pt"

agent1_path = save_dir / ("agent1" + checkpoint_suffix)
agent1.save(str(agent1_path))
print(f"Agent 1 saved to: {agent1_path}")

agent2_path = save_dir / ("agent2" + checkpoint_suffix)
agent2.save(str(agent2_path))
print(f"Agent 2 saved to: {agent2_path}")

In [None]:
# Cell 6: Visualization
# 2x2 dashboard: scores, rewards, cumulative win rate, outcome distribution.
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(12, 10))

def smooth(data, window=50):
    """Return the moving average of ``data`` over ``window`` samples.

    Inputs shorter than ``window`` are returned unchanged (same object).
    Otherwise a 'valid'-mode convolution is used, so the result has
    ``len(data) - window + 1`` points; the first one averages data[0:window].
    """
    if len(data) >= window:
        kernel = np.full(window, 1.0 / window)
        return np.convolve(data, kernel, mode='valid')
    return data

# Plot 1: Agent Scores
# Moving-average width used for every smoothed curve. smooth(..., mode='valid')
# returns len(data) - SMOOTH_WINDOW + 1 points, so smoothed curves are drawn
# from x = SMOOTH_WINDOW - 1 to stay aligned with the raw per-episode data.
# (Previously the offset 49 and the threshold 50 were hard-coded copies of
# smooth()'s default window.)
SMOOTH_WINDOW = 50

ax1 = axes[0, 0]
ax1.plot(all_agent1_scores, alpha=0.3, label='Agent 1 (raw)', color='blue')
ax1.plot(all_agent2_scores, alpha=0.3, label='Agent 2 (raw)', color='red')
if len(all_agent1_scores) > SMOOTH_WINDOW:
    ax1.plot(range(SMOOTH_WINDOW - 1, len(all_agent1_scores)),
             smooth(all_agent1_scores, window=SMOOTH_WINDOW),
             label='Agent 1 (smooth)', color='blue')
    ax1.plot(range(SMOOTH_WINDOW - 1, len(all_agent2_scores)),
             smooth(all_agent2_scores, window=SMOOTH_WINDOW),
             label='Agent 2 (smooth)', color='red')
ax1.set_xlabel('Episode')
ax1.set_ylabel('Score (Food Eaten)')
ax1.set_title('Agent Scores')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Plot 2: Agent Rewards
ax2 = axes[0, 1]
ax2.plot(all_agent1_rewards, alpha=0.3, label='Agent 1', color='blue')
ax2.plot(all_agent2_rewards, alpha=0.3, label='Agent 2', color='red')
if len(all_agent1_rewards) > SMOOTH_WINDOW:
    ax2.plot(range(SMOOTH_WINDOW - 1, len(all_agent1_rewards)),
             smooth(all_agent1_rewards, window=SMOOTH_WINDOW),
             label='Agent 1 (smooth)', color='blue')
    ax2.plot(range(SMOOTH_WINDOW - 1, len(all_agent2_rewards)),
             smooth(all_agent2_rewards, window=SMOOTH_WINDOW),
             label='Agent 2 (smooth)', color='red')
ax2.set_xlabel('Episode')
ax2.set_ylabel('Episode Reward')
ax2.set_title('Agent Rewards')
ax2.legend()
ax2.grid(True, alpha=0.3)

# Plot 3: Win Rate
ax3 = axes[1, 0]
ax3.plot(all_win_rates, color='green')
ax3.axhline(y=0.5, color='gray', linestyle='--', label='50% (balanced)')
ax3.set_xlabel('Episode')
ax3.set_ylabel('Agent 1 Win Rate')
ax3.set_title('Cumulative Win Rate')
ax3.legend()
ax3.grid(True, alpha=0.3)

# Plot 4: Win Distribution
ax4 = axes[1, 1]
labels = ['Agent 1 Wins', 'Agent 2 Wins', 'Draws']
sizes = [stats['agent1_wins'], stats['agent2_wins'], stats['draws']]
colors = ['blue', 'red', 'gray']
if sum(sizes) > 0:
    ax4.pie(sizes, labels=labels, colors=colors, autopct='%1.1f%%', startangle=90)
else:
    # A pie of all-zero wedges is undefined; show a placeholder instead.
    ax4.text(0.5, 0.5, 'No episodes recorded', ha='center', va='center')
    ax4.axis('off')
ax4.set_title('Match Outcome Distribution')

plt.suptitle(f'Two-Snake Classic DQN Training - {NUM_EPISODES} Episodes', fontsize=14)
plt.tight_layout()

fig_dir = project_root / 'results' / 'figures'
fig_dir.mkdir(parents=True, exist_ok=True)
fig_path = fig_dir / f'two_snake_classic_training_{timestamp}.png'
plt.savefig(fig_path, dpi=150, bbox_inches='tight')
print(f"Figure saved to: {fig_path}")

plt.show()

In [None]:
# Cell 7: Results Summary
# Rates and averages use the number of episodes actually recorded by the
# training loop. That equals NUM_EPISODES for a completed run but is smaller
# if training was interrupted. max(..., 1) avoids ZeroDivisionError when
# nothing was recorded.
episodes_played = len(all_agent1_scores)
pct_denominator = max(episodes_played, 1)

print("=" * 50)
print("TRAINING SUMMARY")
print("=" * 50)
print(f"Algorithm: Classic DQN Two-Snake Competitive")
print(f"Episodes: {episodes_played}")
print(f"Total Steps: {total_steps:,}")
print(f"Training Time: {training_time:.1f}s")
print()
print("Match Results:")
print(f"  Agent 1 Wins: {stats['agent1_wins']} ({stats['agent1_wins']/pct_denominator*100:.1f}%)")
print(f"  Agent 2 Wins: {stats['agent2_wins']} ({stats['agent2_wins']/pct_denominator*100:.1f}%)")
print(f"  Draws: {stats['draws']} ({stats['draws']/pct_denominator*100:.1f}%)")
print()
window = min(100, episodes_played)
# Label reflects the real window size instead of a hard-coded "100".
print(f"Final Performance (last {window} episodes):")
if window > 0:
    print(f"  Agent 1 Avg Score: {np.mean(all_agent1_scores[-window:]):.2f}")
    print(f"  Agent 2 Avg Score: {np.mean(all_agent2_scores[-window:]):.2f}")
    print(f"  Agent 1 Avg Reward: {np.mean(all_agent1_rewards[-window:]):.2f}")
    print(f"  Agent 2 Avg Reward: {np.mean(all_agent2_rewards[-window:]):.2f}")
else:
    # np.mean([]) would emit a RuntimeWarning and print nan.
    print("  (no episodes recorded)")
print("=" * 50)