# Snake DQN Training (Colab)

Use this notebook on Google Colab to train the Snake DQN and download the checkpoint. Steps:
1. Set your git repository URL in the clone cell below.
2. Run the setup/install cell to clone the repo and install dependencies.
3. Adjust hyperparameters in the args cell if desired.
4. Run training.
5. Download the trained model.



In [1]:
# Clone the repository (set your git URL)
import os

REPO_URL = os.environ.get("SNAKE_REPO_URL", "https://github.com/SyntheticVis-Umut/Snake.git")  # change if using a fork
GITHUB_TOKEN = os.environ.get("SNAKE_GITHUB_TOKEN") or os.environ.get("GITHUB_TOKEN")

# Clone and setup
import shutil
import sys
import subprocess

if not REPO_URL:
    raise ValueError("Set REPO_URL (or env SNAKE_REPO_URL) before running this cell.")

if os.path.exists('/content/Snake'):
    shutil.rmtree('/content/Snake')

clone_url = REPO_URL
masked_url = REPO_URL
if GITHUB_TOKEN and REPO_URL.startswith("https://github.com/"):
    # Inject token for private repos (note: token will appear in Colab logs)
    clone_url = REPO_URL.replace("https://", f"https://{GITHUB_TOKEN}@")
    masked_url = REPO_URL.replace("https://", "https://<TOKEN>@")
    print("Using token from env for clone.")

clone_cmd = ["git", "clone", "--depth", "1", clone_url, "/content/Snake"]
print("Clone command (token masked):", " ".join(clone_cmd).replace(clone_url, masked_url))
result = subprocess.run(clone_cmd, capture_output=True, text=True)
if result.returncode != 0:
    print("git clone stdout:\n", result.stdout)
    print("git clone stderr:\n", result.stderr)
    sys.exit("git clone failed; check REPO_URL, token (if private), and permissions")

os.chdir('/content/Snake')

# Install deps (CUDA wheels on Colab are handled automatically by torch)
%pip install -q pygame torch numpy tqdm

print('CWD:', os.getcwd())
print('Repository cloned and ready!')


Clone command (token masked): git clone --depth 1 https://github.com/SyntheticVis-Umut/Snake.git /content/Snake
CWD: /content/Snake
Repository cloned and ready!


In [2]:
# Quick import test
from train import train
from types import SimpleNamespace
print('Imports ok')


pygame 2.6.1 (SDL 2.28.4, Python 3.12.12)
Hello from the pygame community. https://www.pygame.org/contribute.html
Imports ok


In [3]:
# Configure hyperparameters
args = SimpleNamespace(
    episodes=500,
    max_steps=10000,
    buffer_size=120000,
    batch_size=512,
    gamma=0.995,
    lr=2e-4,
    eps_start=0.6,
    eps_end=0.01,
    eps_decay=9000,
    target_update=500,
    warmup=6000,
    grid=(20, 20),
    seed=42,
    save_path="models/dqn_snake_colab.pt",
    resume=None,  # set to a checkpoint path to continue training
    device="cuda",  # force GPU (A100 on Colab); use "auto" to fall back
    grad_clip=1.0,  # gradient clipping for stability; set <=0 to disable
    double_dqn=True,  # use Double DQN targets for better stability
    eval_every=100,  # run greedy eval every N episodes (0 disables)
    eval_episodes=10,  # number of greedy episodes per eval
)


In [4]:
# Train
train(args)
print('Training done, saved to', args.save_path)


Using CUDA device: NVIDIA A100-SXM4-80GB
Episode 1/500 Reward: -11.60 Epsilon: 0.599 Best: -11.60
Episode 10/500 Reward: -14.60 Epsilon: 0.580 Best: -11.30
Episode 20/500 Reward: -12.40 Epsilon: 0.559 Best: -4.20
Episode 30/500 Reward: -11.30 Epsilon: 0.537 Best: -4.20
Episode 40/500 Reward: -11.20 Epsilon: 0.524 Best: -4.20
Episode 50/500 Reward: -13.10 Epsilon: 0.508 Best: -4.20
Episode 60/500 Reward: -13.80 Epsilon: 0.486 Best: -4.20
Episode 70/500 Reward: -13.60 Epsilon: 0.473 Best: -4.20
Episode 80/500 Reward: -13.80 Epsilon: 0.454 Best: -4.20
Episode 90/500 Reward: -11.90 Epsilon: 0.440 Best: -4.20
[Eval @ episode 100] mean: -10.97 median: -10.90 max: -10.90 std: 0.15
Episode 100/500 Reward: -15.70 Epsilon: 0.425 Best: -4.20
Episode 110/500 Reward: -12.60 Epsilon: 0.415 Best: -4.20
Episode 120/500 Reward: -13.60 Epsilon: 0.400 Best: -4.20
Episode 130/500 Reward: -11.10 Epsilon: 0.390 Best: -1.30
Episode 140/500 Reward: -15.20 Epsilon: 0.378 Best: -1.30
Episode 150/500 Reward: -0.

In [5]:
# Manual greedy evaluation of the latest checkpoint
import torch
from src.dqn import QNetwork
from src.env import SnakeEnv
from train import evaluate_policy

ckpt = torch.load(args.save_path, map_location=args.device)

env = SnakeEnv(grid_size=tuple(args.grid), render_mode=None, seed=args.seed + 999)
state_dim = env.reset().shape[0]
action_dim = len(env.ACTIONS)
policy_net = QNetwork(state_dim, action_dim).to(args.device)
policy_net.load_state_dict(ckpt["policy_state_dict"])

mean_r, median_r, max_r, std_r = evaluate_policy(
    policy_net, env, episodes=20, device=args.device, max_steps=args.max_steps
)
print({"mean": mean_r, "median": median_r, "max": max_r, "std": std_r})

env.close()



{'mean': 229.625, 'median': 237.3000030517578, 'max': 378.0, 'std': 76.93509674072266}


In [10]:
# Manual greedy evaluation of the latest checkpoint
import torch
from src.dqn import QNetwork
from src.env import SnakeEnv
from train import evaluate_policy

ckpt = torch.load(args.save_path, map_location=args.device)

env = SnakeEnv(grid_size=tuple(args.grid), render_mode=None, seed=args.seed + 999)
state_dim = env.reset().shape[0]
action_dim = len(env.ACTIONS)
policy_net = QNetwork(state_dim, action_dim).to(args.device)
policy_net.load_state_dict(ckpt["policy_state_dict"])

mean_r, median_r, max_r, std_r = evaluate_policy(
    policy_net, env, episodes=20, device=args.device, max_steps=args.max_steps
)
print({"mean": mean_r, "median": median_r, "max": max_r, "std": std_r})

env.close()



{'mean': 229.625, 'median': 237.3000030517578, 'max': 378.0, 'std': 76.93509674072266}


In [11]:
# Download the trained model (Colab)
# Option 1: Direct browser download (works in Colab web interface)
# Option 2: Save to Google Drive (recommended for VS Code + Colab)

import os
from google.colab import files

# Check if file exists
if not os.path.exists(args.save_path):
    print(f"ERROR: Model file not found at {args.save_path}")
    print("Make sure training completed and saved a checkpoint.")
else:
    file_size = os.path.getsize(args.save_path)
    print(f"Model file found: {args.save_path} ({file_size:,} bytes)")
    print("\n" + "="*60)
    print("CHOOSE DOWNLOAD METHOD:")
    print("="*60)
    
    # Option 1: Direct download
    print("\n[Option 1] Direct browser download:")
    print("  - Works in Colab web interface")
    print("  - May not work in VS Code + Colab extension")
    print("  - Check your browser's download folder")
    try:
        files.download(args.save_path)
        print("  ✓ Download initiated!")
    except Exception as e:
        print(f"  ✗ Direct download failed: {e}")
        print("  → Try Option 2 (Google Drive) instead")
    
    # Option 2: Save to Google Drive
    print("\n[Option 2] Save to Google Drive (RECOMMENDED):")
    print("  - Works everywhere (Colab web, VS Code, etc.)")
    print("  - Persistent storage")
    print("  - Then download from Drive to your computer")
    print("\n  Run this cell to mount Drive and save:")
    print("  " + "-"*56)
    print("  from google.colab import drive")
    print("  drive.mount('/content/drive')")
    print("  import shutil")
    print(f"  shutil.copy('{args.save_path}', '/content/drive/MyDrive/dqn_snake_colab.pt')")
    print("  print('Model saved to Google Drive! Download from Drive.')")
    print("  " + "-"*56)
    
    print("\n" + "="*60)



Model file found: models/dqn_snake_colab.pt (77,493 bytes)

CHOOSE DOWNLOAD METHOD:

[Option 1] Direct browser download:
  - Works in Colab web interface
  - May not work in VS Code + Colab extension
  - Check your browser's download folder


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

  ✓ Download initiated!

[Option 2] Save to Google Drive (RECOMMENDED):
  - Works everywhere (Colab web, VS Code, etc.)
  - Persistent storage
  - Then download from Drive to your computer

  Run this cell to mount Drive and save:
  --------------------------------------------------------
  from google.colab import drive
  drive.mount('/content/drive')
  import shutil
  shutil.copy('models/dqn_snake_colab.pt', '/content/drive/MyDrive/dqn_snake_colab.pt')
  print('Model saved to Google Drive! Download from Drive.')
  --------------------------------------------------------



In [12]:
# Save model to Google Drive (Alternative download method)
# Run this cell if direct download didn't work
from google.colab import drive
import shutil

# Mount Google Drive (you'll need to authorize once)
drive.mount('/content/drive')

# Copy model to Drive
drive_path = '/content/drive/MyDrive/dqn_snake_colab.pt'
shutil.copy(args.save_path, drive_path)
print(f'✓ Model saved to Google Drive: {drive_path}')
print('Now download it from Google Drive to your local machine!')



ValueError: mount failed