# Snake DQN Training (Colab)

Use this notebook on Google Colab to train the Snake DQN and download the checkpoint. Steps:
1. Set your git repository URL in the clone cell below.
2. Run the setup/install cell to clone the repo and install dependencies.
3. Adjust hyperparameters in the args cell if desired.
4. Run training.
5. Download the trained model.

## New Features (Latest Update):
- **CNN-based observations**: Uses image input (3 channels: body, head, food) instead of 11 boolean features - better at learning spatial patterns and avoiding self-traps
- **N-step returns**: Multi-step learning for longer-term planning (default: 5 steps)
- **Double DQN**: Already enabled for better stability
- **Huber loss + gradient clipping**: For stable training
- **MCTS Planning** (NEW!): After training, use `--planner mcts` in `watch_agent.py` to perform Monte Carlo Tree Search with 10-100+ step lookahead for smarter decisions

Default settings use CNN + 1-step for stability. Change `observation_type="features"` to use the old approach.

## Using Planning for Smarter Play:
After training, you can use the trained model with lookahead planning:
- **MCTS**: `python watch_agent.py --model models/dqn_snake_colab.pt --planner mcts --mcts-simulations 200 --mcts-depth 100`
- **Simple Lookahead**: `python watch_agent.py --model models/dqn_snake_colab.pt --planner lookahead --lookahead-depth 20`
- **Greedy (no planning)**: `python watch_agent.py --model models/dqn_snake_colab.pt --planner none`


In [None]:
# Mount Google Drive (for saving checkpoints)
from google.colab import drive
drive.mount('/content/drive')
print('✓ Google Drive mounted at /content/drive')

# Clone the repository (set your git URL)
import os

REPO_URL = os.environ.get("SNAKE_REPO_URL", "https://github.com/SyntheticVis-Umut/Snake.git")  # change if using a fork
GITHUB_TOKEN = os.environ.get("SNAKE_GITHUB_TOKEN") or os.environ.get("GITHUB_TOKEN")

# Clone and setup
import shutil
import sys
import subprocess

if not REPO_URL:
    raise ValueError("Set REPO_URL (or env SNAKE_REPO_URL) before running this cell.")

# Ensure /content directory exists (should exist in Colab, but check anyway)
os.makedirs('/content', exist_ok=True)

if os.path.exists('/content/Snake'):
    shutil.rmtree('/content/Snake')

clone_url = REPO_URL
masked_url = REPO_URL
if GITHUB_TOKEN and REPO_URL.startswith("https://github.com/"):
    # Inject token for private repos (note: token will appear in Colab logs)
    clone_url = REPO_URL.replace("https://", f"https://{GITHUB_TOKEN}@")
    masked_url = REPO_URL.replace("https://", "https://<TOKEN>@")
    print("Using token from env for clone.")

clone_cmd = ["git", "clone", "--depth", "1", clone_url, "/content/Snake"]
print("Clone command (token masked):", " ".join(clone_cmd).replace(clone_url, masked_url))
# Use cwd parameter to ensure git runs from a valid directory
result = subprocess.run(clone_cmd, capture_output=True, text=True, cwd='/content')
if result.returncode != 0:
    print("git clone stdout:\n", result.stdout)
    print("git clone stderr:\n", result.stderr)
    sys.exit("git clone failed; check REPO_URL, token (if private), and permissions")

os.chdir('/content/Snake')

# Install deps (CUDA wheels on Colab are handled automatically by torch)
%pip install -q pygame numpy tqdm

print('CWD:', os.getcwd())
print('Repository cloned and ready!')


Clone command (token masked): git clone --depth 1 https://github.com/SyntheticVis-Umut/Snake.git /content/Snake
CWD: /content/Snake
Repository cloned and ready!


In [2]:
# Quick import test
from train import train
from types import SimpleNamespace
print('Imports ok')


pygame 2.6.1 (SDL 2.28.4, Python 3.12.12)
Hello from the pygame community. https://www.pygame.org/contribute.html
Imports ok


In [None]:
# Configure hyperparameters
args = SimpleNamespace(
    episodes=30000,
    max_steps=1000,
    buffer_size=200000,
    batch_size=128,
    gamma=0.99,
    lr=2.5e-4,
    eps_start=1.0,
    eps_end=0.1,
    eps_decay=120000,
    target_update=500,
    warmup=20000,
    grid=(20, 20),
    seed=42,
    save_path="models/dqn_snake_colab.pt",
    resume=None,  # set to a checkpoint path to continue training
    device="auto",  # use GPU if available, else CPU
    grad_clip=1.0,  # gradient clipping for stability; set <=0 to disable
    double_dqn=True,  # Double DQN for stability
    dueling=True,  # Dueling heads for better value/advantage separation
    per_alpha=0.4,  # Prioritized replay alpha (lower reduces overfitting to noisy TD errors)
    per_beta=0.2,   # PER beta; increase later if desired
    eval_every=500,  # run greedy eval every N episodes (0 disables)
    eval_episodes=10,  # number of greedy episodes per eval
    observation_type="image",  # "features" (11 booleans) or "image" (CNN input - RECOMMENDED)
    n_step=1,  # standard 1-step; more stable early on
    save_every=200,  # periodic checkpointing for resume
)


In [None]:
# Train with automatic Drive backup
# Checkpoints will be saved to both local path and Google Drive every save_every episodes
import os
import torch

# Verify Google Drive is mounted
drive_mount = "/content/drive"
if not os.path.exists(drive_mount):
    raise RuntimeError(f"Google Drive not mounted! Please run the first cell to mount Drive.")

# Drive checkpoint path
drive_save_path = "/content/drive/MyDrive/" + os.path.basename(args.save_path)
args.drive_save_path = drive_save_path
print(f"Drive checkpoint path: {drive_save_path}")

# Check for existing checkpoint in Drive to resume from
if os.path.exists(drive_save_path):
    args.resume = drive_save_path
    print(f"✓ Found existing checkpoint in Drive: {drive_save_path}")
    try:
        ckpt = torch.load(drive_save_path, map_location="cpu")
        print("Last saved episode:", ckpt.get("episode"))
        print("Last frame idx:", ckpt.get("frame_idx"))
        print("Best score:", ckpt.get("best_score"))
    except Exception as e:
        print(f"Warning: Could not read checkpoint info: {e}")
else:
    print("No existing checkpoint found in Drive. Starting fresh training.")

print(f"Periodic saves every {args.save_every} episodes (if save_every > 0)")

train(args, drive_save_path=drive_save_path)

print("Training done, saved to", args.save_path)
print(f"Final checkpoint also saved to Drive: {drive_save_path}")


Using CUDA device: NVIDIA A100-SXM4-80GB
Episode 1/500 Reward: -11.60 Epsilon: 0.599 Best: -11.60
Episode 10/500 Reward: -14.60 Epsilon: 0.580 Best: -11.30
Episode 20/500 Reward: -12.40 Epsilon: 0.559 Best: -4.20
Episode 30/500 Reward: -11.30 Epsilon: 0.537 Best: -4.20
Episode 40/500 Reward: -11.20 Epsilon: 0.524 Best: -4.20
Episode 50/500 Reward: -13.10 Epsilon: 0.508 Best: -4.20
Episode 60/500 Reward: -13.80 Epsilon: 0.486 Best: -4.20
Episode 70/500 Reward: -13.60 Epsilon: 0.473 Best: -4.20
Episode 80/500 Reward: -13.80 Epsilon: 0.454 Best: -4.20
Episode 90/500 Reward: -11.90 Epsilon: 0.440 Best: -4.20
[Eval @ episode 100] mean: -10.97 median: -10.90 max: -10.90 std: 0.15
Episode 100/500 Reward: -15.70 Epsilon: 0.425 Best: -4.20
Episode 110/500 Reward: -12.60 Epsilon: 0.415 Best: -4.20
Episode 120/500 Reward: -13.60 Epsilon: 0.400 Best: -4.20
Episode 130/500 Reward: -11.10 Epsilon: 0.390 Best: -1.30
Episode 140/500 Reward: -15.20 Epsilon: 0.378 Best: -1.30
Episode 150/500 Reward: -0.

In [None]:
# Download the trained model file
# This will automatically download the .pt file to your local machine
from google.colab import files
import os

model_path = args.save_path
if os.path.exists(model_path):
    print(f'Downloading {model_path}...')
    files.download(model_path)
    print('✓ Model downloaded successfully!')
    print(f'File saved as: {os.path.basename(model_path)}')
else:
    print(f'⚠ Warning: Model file not found at {model_path}')
    print('Make sure training completed successfully.')

In [None]:
# Manual greedy evaluation of the latest checkpoint
import torch
from src.dqn import QNetwork, CNNQNetwork
from src.env import SnakeEnv
from train import evaluate_policy

ckpt = torch.load(args.save_path, map_location=args.device)
checkpoint_args = ckpt.get("args", {})
observation_type = checkpoint_args.get("observation_type", args.observation_type)
grid_size = checkpoint_args.get("grid", args.grid)
dueling = checkpoint_args.get("dueling", args.dueling)

env = SnakeEnv(
    grid_size=tuple(grid_size),
    render_mode=None,
    seed=args.seed + 999,
    observation_type=observation_type,
)
action_dim = len(env.ACTIONS)

policy_state = ckpt["policy_state_dict"]

# Determine model type from checkpoint
if observation_type == "image" or any(k.startswith("conv.") for k in policy_state.keys()):
    # CNN model
    policy_net = CNNQNetwork(grid_size=tuple(grid_size), output_dim=action_dim, dueling=dueling).to(args.device)
    print("Using CNN model (image observations)")
else:
    # MLP model
    sample_state = env.reset()
    state_dim = sample_state.shape[0]
    hidden_size = policy_state["feature.0.weight"].shape[0]
    policy_net = QNetwork(state_dim, action_dim, hidden=hidden_size, dueling=dueling).to(args.device)
    print(f"Using MLP model (feature observations, hidden={hidden_size}, dueling={dueling})")

policy_net.load_state_dict(policy_state)

mean_r, median_r, max_r, std_r = evaluate_policy(
    policy_net, env, episodes=20, device=args.device, max_steps=args.max_steps, observation_type=observation_type
)
print({"mean": mean_r, "median": median_r, "max": max_r, "std": std_r})

env.close()


{'mean': 229.625, 'median': 237.3000030517578, 'max': 378.0, 'std': 76.93509674072266}


In [12]:
# Save model to Google Drive (Alternative download method)
# Run this cell if direct download didn't work
from google.colab import drive
import shutil

# Mount Google Drive (you'll need to authorize once)
drive.mount('/content/drive')

# Copy model to Drive
drive_path = '/content/drive/MyDrive/dqn_snake_colab.pt'
shutil.copy(args.save_path, drive_path)
print(f'✓ Model saved to Google Drive: {drive_path}')
print('Now download it from Google Drive to your local machine!')



ValueError: mount failed