# Snake Game Deep RL - Google Colab Setup

This notebook sets up and runs the Snake RL project on Google Colab.

## Features
- Free GPU access (subject to availability)
- Step-by-step setup in notebook cells
- Save models to Google Drive
- Visualize results inline

## Step 1: Mount Google Drive (Optional but Recommended)

Mount your Google Drive to save models and results persistently.

In [None]:
from google.colab import drive
drive.mount('/content/drive')

# Set your drive path
DRIVE_PATH = '/content/drive/MyDrive/drl_snake'  # Change this to your preferred path
import os
os.makedirs(DRIVE_PATH, exist_ok=True)
print(f"Drive mounted! Checkpoint directory: {DRIVE_PATH}")
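
If the Drive mount is skipped, later cells still need a directory to write checkpoints to. Below is a minimal sketch of a fallback, assuming the default Colab mount point `/content/drive/MyDrive`. The helper name `resolve_checkpoint_root` and the local fallback path are hypothetical and not part of the project.

```python
import os

def resolve_checkpoint_root(drive_root="/content/drive/MyDrive",
                            subdir="drl_snake",
                            local_fallback="/content/drl_local"):
    """Return a writable checkpoint directory, preferring Google Drive.

    Hypothetical helper: falls back to ephemeral local storage when the
    Drive mount point does not exist (for example, if Step 1 was skipped).
    """
    base = drive_root if os.path.isdir(drive_root) else local_fallback
    path = os.path.join(base, subdir)
    os.makedirs(path, exist_ok=True)
    return path
```

Anything written to the local fallback disappears when the Colab session ends, so download it before disconnecting.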

## Step 2: Install Dependencies

In [None]:
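# Install required packages.
# Colab runtimes normally ship with a CUDA-enabled PyTorch build, so the torch
# line can usually be skipped; keep it only if you need the CUDA 11.8 wheels.
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install gymnasium matplotlib pyyaml tqdm tensorboard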

# Verify GPU availability
import torch
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"CUDA version: {torch.version.cuda}")
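
Colab images typically preinstall several of these packages, so it can be worth checking what is missing before reinstalling anything. The sketch below uses only the standard library; the function name `missing_packages` is illustrative.

```python
import importlib.util

def missing_packages(names):
    """Return the import names from `names` that cannot be found."""
    return [name for name in names if importlib.util.find_spec(name) is None]

# Import names, not pip names: the `pyyaml` distribution is imported as `yaml`.
required = ["torch", "gymnasium", "matplotlib", "yaml", "tqdm", "tensorboard"]
print("Missing packages:", missing_packages(required) or "none")
```

Only the packages this reports need a `pip install`.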

## Step 3: Upload Project Files

**Option A: Upload from your local machine**
- Use the file browser on the left to upload the `drl` folder so that it ends up at `/content/drl`
- Or upload individual files as needed, keeping the same directory layout
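
Uploading many files one by one is slow, and it is easy to lose the directory layout. One alternative is to upload a single zip archive and extract it. This is a sketch under assumptions: the archive name `drl.zip` and the helper `extract_project` are hypothetical.

```python
import os
import zipfile

def extract_project(zip_path, dest_dir="/content"):
    """Extract an uploaded project archive and return its top-level entries."""
    with zipfile.ZipFile(zip_path) as archive:
        archive.extractall(dest_dir)
        top_level = sorted({name.split("/")[0] for name in archive.namelist()})
    return top_level

# In Colab, after uploading a hypothetical drl.zip through the file browser:
# extract_project("/content/drl.zip")
```

If the archive contains a top-level `drl/` folder, extracting into `/content` produces the `/content/drl` layout the later steps expect.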

**Option B: Clone from GitHub (if you've pushed to GitHub)**

In [None]:
# Option B: Clone from GitHub (uncomment and modify if using)
# !git clone https://github.com/yourusername/drl.git /content/drl
# %cd /content/drl

## Step 4: Set Up Project Structure in Colab

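Run this cell even if you already uploaded the files: it defines `PROJECT_DIR` and adds `src/` to the Python path, both of which later cells rely on. Creating directories that already exist is harmless.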

In [None]:
import os
import sys

# Set project directory (adjust if you uploaded to a different location)
PROJECT_DIR = '/content/drl'  # Change if needed

# Create project structure
os.makedirs(f'{PROJECT_DIR}/src/environments', exist_ok=True)
os.makedirs(f'{PROJECT_DIR}/src/agents', exist_ok=True)
os.makedirs(f'{PROJECT_DIR}/src/networks', exist_ok=True)
os.makedirs(f'{PROJECT_DIR}/src/utils', exist_ok=True)
os.makedirs(f'{PROJECT_DIR}/src/experiments', exist_ok=True)
os.makedirs(f'{PROJECT_DIR}/configs', exist_ok=True)
os.makedirs(f'{PROJECT_DIR}/logs/snake', exist_ok=True)
os.makedirs(f'{PROJECT_DIR}/checkpoints/snake', exist_ok=True)

# Add to Python path
sys.path.insert(0, f'{PROJECT_DIR}/src')

print(f"Project directory: {PROJECT_DIR}")
print("Python path updated")

## Step 5: Upload Source Files

**Important**: Upload all Python files from your local `src/` directory into `/content/drl/src/`, keeping the same subfolders:
- `src/environments/snake_env.py`
- `src/environments/snake_renderer.py`
- `src/environments/__init__.py`
- `src/agents/dqn_agent.py`
- `src/agents/ppo_discrete_agent.py`
- `src/agents/__init__.py`
- `src/networks/dqn_network.py`
- `src/networks/__init__.py`
- `src/utils/replay_buffer.py`
- `src/utils/training.py`
- `src/utils/visualization.py`
- `src/utils/__init__.py`
- `src/experiments/train_snake.py`
- `src/experiments/evaluate_snake.py`
- `src/experiments/__init__.py`

**Or use the file browser to upload the entire `src/` folder.**
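
A quick way to confirm the upload is complete is to check the listed paths before attempting any imports. The sketch below copies the file list from above; the helper name `find_missing_files` is illustrative.

```python
import os

# Paths copied from the list above, relative to the project directory.
REQUIRED_FILES = [
    "src/environments/snake_env.py",
    "src/environments/snake_renderer.py",
    "src/environments/__init__.py",
    "src/agents/dqn_agent.py",
    "src/agents/ppo_discrete_agent.py",
    "src/agents/__init__.py",
    "src/networks/dqn_network.py",
    "src/networks/__init__.py",
    "src/utils/replay_buffer.py",
    "src/utils/training.py",
    "src/utils/visualization.py",
    "src/utils/__init__.py",
    "src/experiments/train_snake.py",
    "src/experiments/evaluate_snake.py",
    "src/experiments/__init__.py",
]

def find_missing_files(project_dir, required=REQUIRED_FILES):
    """Return the required files that are not present under `project_dir`."""
    return [rel for rel in required
            if not os.path.isfile(os.path.join(project_dir, rel))]

# Example: print(find_missing_files("/content/drl") or "All source files present")
```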

## Step 6: Create Config File

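This cell writes `configs/snake_config.yaml`, overwriting any copy you uploaded. It also defines `config_path`, which Step 8 uses, so either run it or set `config_path` yourself: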

In [None]:
import yaml

config_content = {
    'environment': {
        'grid_size': 20,
        'state_representation': 'feature',  # 'grid', 'feature', or 'image'
        'initial_length': 3,
        'reward_food': 10.0,
        'reward_death': -10.0,
        'reward_step': -0.1,
        'reward_distance': 0.0
    },
    'dqn': {
        'learning_rate': 1e-4,
        'gamma': 0.99,
        'epsilon_start': 1.0,
        'epsilon_end': 0.01,
        'epsilon_decay': 0.995,
        'replay_buffer_size': 100000,
        'batch_size': 64,
        'target_update_frequency': 1000,
        'network': [128, 128, 64],
        'activation': 'relu'
    },
    'ppo': {
        'learning_rate': 3e-4,
        'gamma': 0.99,
        'gae_lambda': 0.95,
        'clip_epsilon': 0.2,
        'value_coef': 0.5,
        'entropy_coef': 0.01,
        'max_grad_norm': 0.5,
        'update_epochs': 10,
        'batch_size': 64,
        'network': [128, 128, 64],
        'activation': 'relu'
    },
    'training': {
        'algorithm': 'dqn',  # 'dqn' or 'ppo'
        'total_episodes': 2000,  # Reduced for Colab demo
        'eval_frequency': 100,
        'save_frequency': 500,
        'update_frequency': 4,
        'log_dir': f'{PROJECT_DIR}/logs/snake',
        'checkpoint_dir': f'{PROJECT_DIR}/checkpoints/snake',
        'experiment_name': 'snake_dqn_colab'
    },
    'evaluation': {
        'num_episodes': 10,
        'render': False,
        'save_videos': False
    }
}

config_path = f'{PROJECT_DIR}/configs/snake_config.yaml'
with open(config_path, 'w') as f:
    yaml.dump(config_content, f)

print(f"Config file created at: {config_path}")
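
`epsilon_decay` is easier to choose once you know how many decay events it implies. The sketch below assumes multiplicative decay, `epsilon = max(epsilon_end, epsilon * epsilon_decay)`. Whether `DQNAgent` applies it per environment step, per training step, or per episode is not shown here, so read the result as a count of decay events.

```python
import math

def decay_events_to_floor(epsilon_start, epsilon_end, epsilon_decay):
    """Number of multiplicative decays needed for epsilon to reach epsilon_end."""
    return math.ceil(math.log(epsilon_end / epsilon_start) / math.log(epsilon_decay))

print(decay_events_to_floor(1.0, 0.01, 0.995))  # 919
```

If the decay happens inside `train_step`, which the loop in Step 9 calls at most once every `update_frequency = 4` episodes, then 2000 episodes give at most 500 decay events. That would leave epsilon near 0.995^500 ≈ 0.08 rather than at the 0.01 floor.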

## Step 7: Verify Imports

Check that all modules can be imported:

In [None]:
# Change to project directory
import os
os.chdir(PROJECT_DIR)

# Import modules
try:
    from environments import SnakeEnv
    from agents import DQNAgent, PPODiscreteAgent
    from utils.training import MetricsTracker, evaluate_agent
    from utils.visualization import plot_training_curves
    print("✓ All modules imported successfully!")
except ImportError as e:
    print(f"✗ Import error: {e}")
    print("Please upload all source files first.")

## Step 8: Quick Test - Create Environment

In [None]:
import numpy as np
import torch

# Set random seed
np.random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)

# Load config
with open(config_path, 'r') as f:
    config = yaml.safe_load(f)

# Create environment
env = SnakeEnv(
    grid_size=config['environment']['grid_size'],
    state_representation=config['environment']['state_representation'],
    initial_length=config['environment']['initial_length'],
    reward_food=config['environment']['reward_food'],
    reward_death=config['environment']['reward_death'],
    reward_step=config['environment']['reward_step'],
    reward_distance=config['environment']['reward_distance']
)

# Test environment
state, info = env.reset()
print("✓ Environment created successfully!")
print(f"State shape: {state.shape if hasattr(state, 'shape') else len(state)}")
print(f"Action space: {env.action_space}")
print(f"Initial info: {info}")
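
Before training, a random rollout is a cheap check that `reset` and `step` behave as expected. This is a sketch, assuming `SnakeEnv` follows the Gymnasium API used later in this notebook (a five-value `step` return) and that `env.action_space` provides `sample()`. The `max_steps` cap is an added safeguard, not a project setting.

```python
def random_rollout(env, max_steps=200):
    """Play one episode with random actions; return (total_reward, steps)."""
    state, info = env.reset()
    total_reward, steps = 0.0, 0
    for _ in range(max_steps):
        action = env.action_space.sample()
        state, reward, terminated, truncated, info = env.step(action)
        total_reward += reward
        steps += 1
        if terminated or truncated:
            break
    return total_reward, steps

# Example: total_reward, steps = random_rollout(env)
```

With the rewards configured above, a random agent should usually finish with a negative return from the step penalty and the death penalty.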

## Step 9: Train Agent

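The first cell below builds the agent selected by `training.algorithm` in the config. The second cell runs the training loop for `total_episodes` episodes (2000 in the config above), so expect it to take a while.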

In [None]:
from tqdm import tqdm

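# Get state shape. Gymnasium spaces always define `.shape`, and Discrete spaces
# report an empty tuple, so fall back to `n` when the shape is empty or missing.
obs_space = env.observation_space
if getattr(obs_space, 'shape', None):
    state_shape = obs_space.shape
else:
    state_shape = (obs_space.n,)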

# Create agent
algorithm = config['training']['algorithm'].lower()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
print(f"Algorithm: {algorithm.upper()}")

if algorithm == "dqn":
    agent = DQNAgent(
        state_shape=state_shape,
        num_actions=env.action_space.n,
        learning_rate=config['dqn']['learning_rate'],
        gamma=config['dqn']['gamma'],
        epsilon_start=config['dqn']['epsilon_start'],
        epsilon_end=config['dqn']['epsilon_end'],
        epsilon_decay=config['dqn']['epsilon_decay'],
        replay_buffer_size=config['dqn']['replay_buffer_size'],
        batch_size=config['dqn']['batch_size'],
        target_update_frequency=config['dqn']['target_update_frequency'],
        hidden_sizes=config['dqn']['network'],
        activation=config['dqn']['activation'],
        state_representation=config['environment']['state_representation'],
        device=device,
        seed=42
    )
elif algorithm == "ppo":
    agent = PPODiscreteAgent(
        state_shape=state_shape,
        num_actions=env.action_space.n,
        learning_rate=config['ppo']['learning_rate'],
        gamma=config['ppo']['gamma'],
        gae_lambda=config['ppo']['gae_lambda'],
        clip_epsilon=config['ppo']['clip_epsilon'],
        value_coef=config['ppo']['value_coef'],
        entropy_coef=config['ppo']['entropy_coef'],
        max_grad_norm=config['ppo']['max_grad_norm'],
        update_epochs=config['ppo']['update_epochs'],
        batch_size=config['ppo']['batch_size'],
        hidden_sizes=config['ppo']['network'],
        activation=config['ppo']['activation'],
        state_representation=config['environment']['state_representation'],
        device=device,
        seed=42
    )
else:
    raise ValueError(f"Unknown algorithm: {algorithm}")

print(f"✓ Agent created: {type(agent).__name__}")
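
The `replay_buffer_size` passed to the DQN agent largely determines its host-memory footprint. Below is a rough estimate, assuming the buffer stores both `state` and `next_state` as float32 arrays. The actual `ReplayBuffer` layout is not shown here, and the 20x20 observation shape is only an illustrative value for a grid representation.

```python
import math

def replay_state_bytes(capacity, state_shape, bytes_per_value=4):
    """Approximate bytes needed to store state and next_state per transition."""
    values_per_state = math.prod(state_shape)
    return 2 * capacity * values_per_state * bytes_per_value

# A 20x20 observation stored as float32 with the configured capacity of 100000:
print(replay_state_bytes(100_000, (20, 20)) / 1e6, "MB")  # 320.0 MB
```

Compact feature vectors cost far less than grid or image observations at the same capacity.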

In [None]:
# Training loop
metrics_tracker = MetricsTracker()
total_episodes = config['training']['total_episodes']
eval_frequency = config['training']['eval_frequency']
save_frequency = config['training']['save_frequency']
update_frequency = config['training']['update_frequency']

print(f"Starting training for {total_episodes} episodes...")
print(f"Evaluation every {eval_frequency} episodes")
print(f"Saving checkpoint every {save_frequency} episodes")

best_score = -np.inf

for episode in tqdm(range(total_episodes), desc="Training"):
    state, info = env.reset()
    episode_reward = 0
    episode_length = 0
    done = False
    
    # Collect episode
    while not done:
        if algorithm == "dqn":
            action = agent.act(state, deterministic=False)
            # Overwrite `info` each step so the score recorded after the episode
            # comes from the final step rather than from env.reset()
            next_state, reward, terminated, truncated, info = env.step(action)
            done = terminated or truncated
            agent.store_transition(state, action, reward, next_state, done)
        else:  # ppo
            action, log_prob, value = agent.act(state, deterministic=False)
            next_state, reward, terminated, truncated, info = env.step(action)
            done = terminated or truncated
            agent.store_transition(state, action, reward, log_prob, value, done)
        
        episode_reward += reward
        episode_length += 1
        state = next_state
    
    # Train agent and record exactly one metrics entry per episode
    episode_metrics = {
        "reward": episode_reward,
        "score": info.get("score", 0),
        "length": episode_length,
    }
    if algorithm == "dqn":
        episode_metrics["epsilon"] = agent.epsilon
        buffer_ready = len(agent.replay_buffer) >= agent.batch_size
        if buffer_ready and episode % update_frequency == 0:
            metrics = agent.train_step()
            episode_metrics["loss"] = metrics.get("loss", None)
            episode_metrics["epsilon"] = metrics.get("epsilon", None)
    else:  # ppo
        if episode % update_frequency == 0 and len(agent.states) > 0:
            metrics = agent.train_step()
            episode_metrics["loss"] = metrics.get("loss", None)
    metrics_tracker.record_episode(**episode_metrics)
    
    # Evaluation
    if (episode + 1) % eval_frequency == 0:
        eval_results = evaluate_agent(env, agent, num_episodes=5, deterministic=True)
        stats = metrics_tracker.get_statistics(window=100)
        
        print(f"\nEpisode {episode + 1}")
        print(f"  Recent Avg Reward: {stats.get('mean_reward', 0):.2f}")
        print(f"  Recent Avg Score: {stats.get('mean_score', 0):.2f}")
        print(f"  Recent Avg Length: {stats.get('mean_length', 0):.2f}")
        print(f"  Eval Avg Score: {eval_results['mean_score']:.2f}")
        print(f"  Eval Max Score: {eval_results['max_score']:.2f}")
        
        # Save best model
        if eval_results['mean_score'] > best_score:
            best_score = eval_results['mean_score']
            checkpoint_path = f'{PROJECT_DIR}/checkpoints/snake/best_model.pth'
            agent.save(checkpoint_path)
            print(f"  ✓ Saved best model (score: {best_score:.2f})")
    
    # Save checkpoint
    if (episode + 1) % save_frequency == 0:
        checkpoint_path = f'{PROJECT_DIR}/checkpoints/snake/checkpoint_ep{episode+1}.pth'
        agent.save(checkpoint_path)
        metrics_tracker.save(f'{PROJECT_DIR}/checkpoints/snake/metrics.json')
        print(f"  ✓ Saved checkpoint at episode {episode + 1}")

print("\n✓ Training complete!")
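
The loop above writes checkpoints only under `PROJECT_DIR`, which is lost if the session disconnects before Step 12 runs. Below is a sketch of a periodic backup to Drive. The helper name `backup_files` and the suffix filter are illustrative choices.

```python
import os
import shutil

def backup_files(src_dir, dst_dir, suffixes=(".pth", ".json")):
    """Copy checkpoint and metrics files to dst_dir; return the copied names."""
    os.makedirs(dst_dir, exist_ok=True)
    copied = []
    for name in sorted(os.listdir(src_dir)):
        src_path = os.path.join(src_dir, name)
        if os.path.isfile(src_path) and name.endswith(suffixes):
            shutil.copy2(src_path, os.path.join(dst_dir, name))
            copied.append(name)
    return copied

# Example, inside the save_frequency branch of the loop:
# backup_files(f'{PROJECT_DIR}/checkpoints/snake', DRIVE_PATH)
```

Calling it right after `agent.save(...)` keeps the Drive copy at most `save_frequency` episodes behind.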

## Step 10: Visualize Training Progress

In [None]:
# Plot training curves
plot_training_curves(metrics_tracker, window=50)
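
The `window=50` argument presumably controls how much the curves are smoothed. The implementation of `plot_training_curves` is not shown here, so the sketch below only illustrates the usual meaning of such a window: a trailing moving average.

```python
def moving_average(values, window):
    """Trailing mean over `window` values; returns len(values) - window + 1 points."""
    if window <= 0:
        raise ValueError("window must be positive")
    if len(values) < window:
        return []
    running = sum(values[:window])
    averages = [running / window]
    for i in range(window, len(values)):
        # Slide the window: add the newest value, drop the oldest
        running += values[i] - values[i - window]
        averages.append(running / window)
    return averages

print(moving_average([1, 2, 3, 4, 5], 2))  # [1.5, 2.5, 3.5, 4.5]
```

A larger window gives a smoother curve but hides short-lived changes and needs more episodes before the first point appears.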

## Step 11: Evaluate Trained Agent

In [None]:
# Evaluate trained agent, using the episode count from the config (10 by default)
eval_results = evaluate_agent(
    env, agent,
    num_episodes=config['evaluation']['num_episodes'],
    deterministic=True
)

print("\n" + "="*50)
print("Final Evaluation Results")
print("="*50)
print(f"Mean Reward: {eval_results['mean_reward']:.2f} ± {eval_results['std_reward']:.2f}")
print(f"Max Reward: {eval_results['max_reward']:.2f}")
print(f"Mean Score: {eval_results['mean_score']:.2f} ± {eval_results['std_score']:.2f}")
print(f"Max Score: {eval_results['max_score']:.2f}")
print(f"Mean Length: {eval_results['mean_length']:.2f} ± {eval_results['std_length']:.2f}")
print("="*50)

## Step 12: Save Model to Google Drive

In [None]:
# Save model to Google Drive
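# Check that DRIVE_PATH is defined and that the directory actually exists
if 'DRIVE_PATH' in globals() and os.path.isdir(DRIVE_PATH):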
    import shutil
    
    # Save final model
    final_model_path = f'{DRIVE_PATH}/snake_final_model.pth'
    agent.save(final_model_path)
    print(f"✓ Final model saved to: {final_model_path}")
    
    # Save metrics
    metrics_path = f'{DRIVE_PATH}/metrics.json'
    metrics_tracker.save(metrics_path)
    print(f"✓ Metrics saved to: {metrics_path}")
    
    # Copy best model if it exists
    best_model_local = f'{PROJECT_DIR}/checkpoints/snake/best_model.pth'
    if os.path.exists(best_model_local):
        best_model_drive = f'{DRIVE_PATH}/snake_best_model.pth'
        shutil.copy(best_model_local, best_model_drive)
        print(f"✓ Best model copied to: {best_model_drive}")
else:
    print("Google Drive not mounted. Models saved locally.")
    print("Note: Local files are deleted when the Colab session ends.")

## Step 13: Download Results (Optional)

If you didn't mount Google Drive, download the model files:

In [None]:
# Download model files
from google.colab import files

# Download best model
if os.path.exists(f'{PROJECT_DIR}/checkpoints/snake/best_model.pth'):
    files.download(f'{PROJECT_DIR}/checkpoints/snake/best_model.pth')

# Download metrics
if os.path.exists(f'{PROJECT_DIR}/checkpoints/snake/metrics.json'):
    files.download(f'{PROJECT_DIR}/checkpoints/snake/metrics.json')

## Tips for Colab Usage

1. **GPU Runtime**: Go to Runtime → Change runtime type → Select GPU (T4 or better)

2. **Session Limits**:
   - Free Colab: sessions last up to about 12 hours and may disconnect earlier when idle
   - Save checkpoints frequently
   - Use Google Drive for persistence, since local files are lost when the session ends

3. **Memory Management**:
   - Reduce `batch_size` if you get out-of-memory (OOM) errors
   - Reduce `replay_buffer_size` for DQN
   - Use smaller network sizes

4. **Faster Training**:
   - Use GPU runtime
   - Reduce `total_episodes` for quick tests
   - Use `feature` state representation (faster than `image`)

5. **Resume Training**:
   - Load checkpoint: `agent.load('path/to/checkpoint.pth')`
   - Continue the training loop from the saved episode number, which is part of the checkpoint filename (for example `checkpoint_ep500.pth`)

6. **TensorBoard in Colab** (shows data only if event files have been written to the log directory):
   ```python
   %load_ext tensorboard
   %tensorboard --logdir /content/drl/logs/snake
   ```
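
To resume training as described in tip 5, first find the most recent checkpoint. The sketch below relies on the `checkpoint_ep{N}.pth` naming used in Step 9. The helper name `latest_checkpoint` is illustrative, and whether `agent.load` also restores the optimizer state, epsilon, or the replay buffer is not shown here.

```python
import os
import re

CHECKPOINT_PATTERN = re.compile(r"checkpoint_ep(\d+)\.pth$")

def latest_checkpoint(checkpoint_dir):
    """Return (path, episode) for the highest-numbered checkpoint, or (None, 0)."""
    best_path, best_episode = None, 0
    for name in os.listdir(checkpoint_dir):
        match = CHECKPOINT_PATTERN.match(name)
        # Compare episode numbers as integers so ep1000 ranks above ep500
        if match and int(match.group(1)) > best_episode:
            best_episode = int(match.group(1))
            best_path = os.path.join(checkpoint_dir, name)
    return best_path, best_episode

# Example:
# path, start_episode = latest_checkpoint(f'{PROJECT_DIR}/checkpoints/snake')
# if path:
#     agent.load(path)
# for episode in range(start_episode, total_episodes): ...
```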