# Snake Game Deep RL Experiments

This notebook provides interactive exploration and visualization of the Snake RL project.

## Features
- Train and evaluate DQN and PPO agents
- Visualize training progress
- Compare different algorithms and hyperparameters
- Watch trained agents play the game
- Analyze policy behavior

In [None]:
import sys
import os
sys.path.append('../src')

import numpy as np
import matplotlib.pyplot as plt
import torch
import yaml
from tqdm import tqdm

from environments import SnakeEnv
from agents import DQNAgent, PPODiscreteAgent
from utils.training import MetricsTracker, evaluate_agent
from utils.visualization import plot_training_curves, plot_algorithm_comparison

# Seed the NumPy and PyTorch RNGs (plus the CUDA RNG when a GPU is visible)
# so that runs of this notebook are repeatable, then report the setup.
SEED = 42
cuda_available = torch.cuda.is_available()

np.random.seed(SEED)
torch.manual_seed(SEED)
if cuda_available:
    torch.cuda.manual_seed(SEED)

print("Imports successful!")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {cuda_available}")
if cuda_available:
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 1. Load Configuration

In [None]:
# Load the experiment configuration. Later cells read its `environment`,
# `training` and `dqn` sections.
config_path = '../configs/snake_config.yaml'
with open(config_path, 'r') as config_file:
    config = yaml.safe_load(config_file)

env_cfg = config['environment']
train_cfg = config['training']

print("Configuration loaded:")
print(f"Grid size: {env_cfg['grid_size']}")
print(f"State representation: {env_cfg['state_representation']}")
print(f"Algorithm: {train_cfg['algorithm']}")
print(f"Total episodes: {train_cfg['total_episodes']}")

## 2. Create Environment and Agent

In [None]:
# Create the Snake environment from the `environment` section of the config.
# NOTE(review): the env keeps render_mode="human" for the training cell too;
# if SnakeEnv renders on every step in this mode, training will be slow —
# confirm against SnakeEnv.
env_cfg = config['environment']
env = SnakeEnv(
    grid_size=env_cfg['grid_size'],
    state_representation=env_cfg['state_representation'],
    initial_length=env_cfg['initial_length'],
    reward_food=env_cfg['reward_food'],
    reward_death=env_cfg['reward_death'],
    reward_step=env_cfg['reward_step'],
    reward_distance=env_cfg['reward_distance'],
    render_mode="human",  # Enable rendering
)

# Get state and action dimensions.
# Gymnasium-style spaces always define `.shape`, but it is `()` for Discrete
# spaces (and may be None for some custom spaces). So test for a non-empty
# shape rather than attribute presence before falling back to (n,).
obs_space = env.observation_space
obs_shape = getattr(obs_space, 'shape', None)
state_shape = obs_shape if obs_shape else (obs_space.n,)

print(f"State shape: {state_shape}")
print(f"Action space: {env.action_space}")
print(f"Number of actions: {env.action_space.n}")

# Test environment
state, info = env.reset()
print(f"\nInitial state shape: {state.shape if hasattr(state, 'shape') else len(state)}")
print(f"Initial info: {info}")

## 3. Quick Test - Random Agent

In [None]:
# Sanity-check the environment with uniformly random actions for up to
# 100 steps, rendering every 5th step.
# The action space has its own RNG that np.random.seed does not touch, so
# seed it explicitly to make the sampled actions reproducible.
env.action_space.seed(42)
env.reset()
done = False
steps = 0
total_reward = 0

while not done and steps < 100:
    action = env.action_space.sample()
    state, reward, terminated, truncated, info = env.step(action)
    done = terminated or truncated
    total_reward += reward
    steps += 1

    # Render every 5 steps
    if steps % 5 == 0:
        env.render()

print(f"Random agent: Score={info.get('score', 0)}, Steps={steps}, Reward={total_reward:.2f}")
# `env` is deliberately NOT closed here: the training, evaluation and
# playback cells below reuse this same instance, and using an environment
# after close() is not supported by the Gym(nasium) API. The playback cell
# closes it at the end.

## 4. Train DQN Agent

In [None]:
# Build the DQN agent. Prefer the GPU when PyTorch can see one.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Hyperparameters come from the `dqn` section of the config; the network
# input layout follows the environment's state representation.
dqn_cfg = config['dqn']
dqn_agent = DQNAgent(
    state_shape=state_shape,
    num_actions=env.action_space.n,
    learning_rate=dqn_cfg['learning_rate'],
    gamma=dqn_cfg['gamma'],
    epsilon_start=dqn_cfg['epsilon_start'],
    epsilon_end=dqn_cfg['epsilon_end'],
    epsilon_decay=dqn_cfg['epsilon_decay'],
    replay_buffer_size=dqn_cfg['replay_buffer_size'],
    batch_size=dqn_cfg['batch_size'],
    target_update_frequency=dqn_cfg['target_update_frequency'],
    hidden_sizes=dqn_cfg['network'],
    activation=dqn_cfg['activation'],
    state_representation=config['environment']['state_representation'],
    device=device,
    seed=42,
)

print("DQN agent created!")

In [None]:
# Training loop (reduced episodes for notebook demo).
#
# Each episode is rolled out with the agent's exploratory policy and every
# transition is stored in the replay buffer. Once the buffer holds at least
# one batch, a gradient step is taken every `update_frequency` environment
# steps, so the number of updates scales with the experience collected.
# Exactly one metrics record is written per episode.
# NOTE(review): if DQNAgent.train_step (rather than act) decays epsilon, the
# decay now advances per update — re-check `epsilon_decay` in the config.
metrics_tracker = MetricsTracker()
total_episodes = 1000  # Reduced for demo
update_frequency = 4  # environment steps between gradient updates
global_step = 0

print(f"Training DQN for {total_episodes} episodes...")

for episode in tqdm(range(total_episodes), desc="Training"):
    state, info = env.reset()
    episode_reward = 0
    episode_length = 0
    episode_losses = []
    done = False

    while not done:
        action = dqn_agent.act(state, deterministic=False)
        # Rebind `info` every step so that, after the loop, it describes
        # the final state of the episode (used for the recorded score).
        next_state, reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated

        dqn_agent.store_transition(state, action, reward, next_state, done)

        episode_reward += reward
        episode_length += 1
        global_step += 1
        state = next_state

        # Only sample once a full batch is available.
        buffer_ready = len(dqn_agent.replay_buffer) >= dqn_agent.batch_size
        if buffer_ready and global_step % update_frequency == 0:
            train_metrics = dqn_agent.train_step()
            loss = train_metrics.get("loss") if train_metrics else None
            if loss is not None:
                episode_losses.append(float(loss))

    metrics_tracker.record_episode(
        reward=episode_reward,
        score=info.get("score", 0),
        length=episode_length,
        # Mean loss over this episode's updates; None if no update ran.
        loss=float(np.mean(episode_losses)) if episode_losses else None,
        epsilon=dqn_agent.epsilon,
    )

    # Print progress via tqdm.write so the progress bar is not corrupted.
    if (episode + 1) % 100 == 0:
        stats = metrics_tracker.get_statistics(window=100)
        tqdm.write(
            f"\nEpisode {episode + 1}\n"
            f"  Avg Reward: {stats.get('mean_reward', 0):.2f}\n"
            f"  Avg Score: {stats.get('mean_score', 0):.2f}\n"
            f"  Avg Length: {stats.get('mean_length', 0):.2f}\n"
            f"  Epsilon: {stats.get('current_epsilon', 0):.3f}"
        )

print("\nTraining complete!")

## 5. Visualize Training Progress

In [None]:
# Plot the recorded training curves. `window` presumably sets the width of
# the moving average used for smoothing — confirm in utils.visualization.
smoothing_window = 50
plot_training_curves(metrics_tracker, window=smoothing_window)

## 6. Evaluate Trained Agent

In [None]:
# Evaluate the trained agent with its deterministic policy over 10 episodes,
# then print mean ± std and max for reward and score, and mean ± std length.
eval_results = evaluate_agent(env, dqn_agent, num_episodes=10, deterministic=True)
divider = "=" * 50

print("Evaluation Results:")
print(divider)
for label, key in (("Reward", "reward"), ("Score", "score")):
    mean_val = eval_results['mean_' + key]
    std_val = eval_results['std_' + key]
    print(f"Mean {label}: {mean_val:.2f} ± {std_val:.2f}")
    print(f"Max {label}: {eval_results['max_' + key]:.2f}")
print(f"Mean Length: {eval_results['mean_length']:.2f} ± {eval_results['std_length']:.2f}")
print(divider)

## 7. Watch Trained Agent Play

In [None]:
# Watch the trained agent play with its deterministic policy for at most
# `max_watch_steps` steps, rendering every frame. The environment is closed
# in a `finally` so an interrupted cell does not leave the window open.
max_watch_steps = 500
env.render_mode = "human"
state, info = env.reset()
done = False
steps = 0

print("Watching trained agent play...")
print("Close the plot window to continue")

try:
    while not done and steps < max_watch_steps:
        action = dqn_agent.act(state, deterministic=True)
        # Rebind `info` each step so it holds the final score afterwards,
        # not the value returned by reset().
        state, reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated
        steps += 1

        # Render every step
        env.render()
finally:
    env.close()

if done:
    print(f"\nGame Over! Score: {info.get('score', 0)}, Steps: {steps}")
else:
    print(f"\nStep limit reached. Score: {info.get('score', 0)}, Steps: {steps}")