# Snake DQN Training (Colab)

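Use this notebook on Google Colab to train the Snake DQN agent and download the resulting checkpoint. A GPU runtime is recommended (Runtime → Change runtime type). Steps:
1. Set the repository URL in the clone cell below, or export `SNAKE_REPO_URL` (plus `SNAKE_GITHUB_TOKEN` for a private repository).
2. Run the setup/install cell to clone the repo and install dependencies.
3. Adjust hyperparameters in the args cell if desired.
4. Run training.
5. Optionally run the greedy evaluation cell to score the saved checkpoint.
6. Download the trained model.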



In [1]:
# Clone the repository (set your git URL)
import os

REPO_URL = os.environ.get("SNAKE_REPO_URL", "https://github.com/SyntheticVis-Umut/Snake.git")  # change if using a fork
GITHUB_TOKEN = os.environ.get("SNAKE_GITHUB_TOKEN") or os.environ.get("GITHUB_TOKEN")

# Clone and setup
import shutil
import sys
import subprocess

if not REPO_URL:
    raise ValueError("Set REPO_URL (or env SNAKE_REPO_URL) before running this cell.")

if os.path.exists('/content/Snake'):
    shutil.rmtree('/content/Snake')

clone_url = REPO_URL
masked_url = REPO_URL
if GITHUB_TOKEN and REPO_URL.startswith("https://github.com/"):
    # Inject token for private repos (note: token will appear in Colab logs)
    clone_url = REPO_URL.replace("https://", f"https://{GITHUB_TOKEN}@")
    masked_url = REPO_URL.replace("https://", "https://<TOKEN>@")
    print("Using token from env for clone.")

clone_cmd = ["git", "clone", "--depth", "1", clone_url, "/content/Snake"]
print("Clone command (token masked):", " ".join(clone_cmd).replace(clone_url, masked_url))
result = subprocess.run(clone_cmd, capture_output=True, text=True)
if result.returncode != 0:
    print("git clone stdout:\n", result.stdout)
    print("git clone stderr:\n", result.stderr)
    sys.exit("git clone failed; check REPO_URL, token (if private), and permissions")

os.chdir('/content/Snake')

# Install deps (CUDA wheels on Colab are handled automatically by torch)
%pip install -q pygame torch numpy tqdm

print('CWD:', os.getcwd())
print('Repository cloned and ready!')


Clone command (token masked): git clone --depth 1 https://github.com/SyntheticVis-Umut/Snake.git /content/Snake
CWD: /content/Snake
Repository cloned and ready!


In [2]:
# Quick import test: fail fast here if the clone/install cell did not succeed,
# rather than deep inside the training run.
from types import SimpleNamespace

from train import train

print('Imports ok')


pygame 2.6.1 (SDL 2.28.4, Python 3.12.12)
Hello from the pygame community. https://www.pygame.org/contribute.html
Imports ok


In [None]:
# Configure hyperparameters.
#
# `device` is resolved here to a concrete torch device string: "cuda" when the
# runtime has a GPU (e.g. an A100 on Colab), otherwise "cpu". The evaluation
# cell passes args.device straight to torch.load(map_location=...) and
# Module.to(...). Both need a real device name, not a sentinel such as "auto".
# NOTE(review): assumes train() accepts "cpu" as well as "cuda" — confirm in train.py.
import torch
from types import SimpleNamespace

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
if DEVICE == "cpu":
    print(
        "WARNING: no CUDA GPU detected; training on CPU will be slow. "
        "Select a GPU runtime in Colab for faster training."
    )

args = SimpleNamespace(
    episodes=20000,
    max_steps=1000,
    buffer_size=100000,
    batch_size=256,
    gamma=0.995,
    lr=3e-4,
    eps_start=0.6,
    eps_end=0.01,
    eps_decay=8000,
    target_update=500,
    warmup=4000,
    grid=(20, 20),
    seed=42,
    save_path="models/dqn_snake_colab.pt",
    resume=None,  # set to a checkpoint path to continue training
    device=DEVICE,  # "cuda" when available, else "cpu"
    grad_clip=1.0,  # gradient clipping for stability; set <=0 to disable
    double_dqn=True,  # use Double DQN targets for better stability
    eval_every=100,  # run greedy eval every N episodes (0 disables)
    eval_episodes=10,  # number of greedy episodes per eval
)


In [4]:
# Train. This runs the full episode loop defined in the repo's train.py;
# interrupting the kernel stops it early (whether a checkpoint is written
# on interrupt is decided inside train() and is not visible here).
train(args)
print(f"Training done, saved to {args.save_path}")


Using CUDA device: NVIDIA A100-SXM4-80GB
Episode 1/20000 Reward: -11.60 Epsilon: 0.599 Best: -11.60
Episode 10/20000 Reward: -14.60 Epsilon: 0.577 Best: -11.30
Episode 20/20000 Reward: -12.10 Epsilon: 0.555 Best: -11.20
Episode 30/20000 Reward: -15.90 Epsilon: 0.530 Best: -8.30
Episode 40/20000 Reward: -14.30 Epsilon: 0.515 Best: -8.30
Episode 50/20000 Reward: -12.10 Epsilon: 0.496 Best: -8.30
Episode 60/20000 Reward: -2.90 Epsilon: 0.477 Best: -2.90
Episode 70/20000 Reward: -14.40 Epsilon: 0.456 Best: -2.90
Episode 80/20000 Reward: -11.20 Epsilon: 0.441 Best: -2.90
Episode 90/20000 Reward: -13.20 Epsilon: 0.427 Best: -2.90
Episode 100/20000 Reward: -13.10 Epsilon: 0.411 Best: -2.90
Episode 110/20000 Reward: -12.10 Epsilon: 0.395 Best: -2.90
Episode 120/20000 Reward: -15.80 Epsilon: 0.380 Best: -2.90
Episode 130/20000 Reward: -20.70 Epsilon: 0.361 Best: -2.90
Episode 140/20000 Reward: -15.30 Epsilon: 0.338 Best: 5.20
Episode 150/20000 Reward: -1.70 Epsilon: 0.322 Best: 5.20
Episode 160

KeyboardInterrupt: 

In [None]:
# Manual greedy evaluation of the latest checkpoint written to args.save_path.
import os

import torch
from src.dqn import QNetwork
from src.env import SnakeEnv
from train import evaluate_policy

if not os.path.exists(args.save_path):
    # Training may have been interrupted before any checkpoint was written.
    raise FileNotFoundError(
        f"No checkpoint at {args.save_path!r}; run (or resume) training first."
    )

# NOTE: torch.load unpickles the file — only load checkpoints you produced or trust.
ckpt = torch.load(args.save_path, map_location=args.device)

# Offset the seed so evaluation episodes differ from the training seed.
env = SnakeEnv(grid_size=tuple(args.grid), render_mode=None, seed=args.seed + 999)
try:
    state_dim = env.reset().shape[0]
    action_dim = len(env.ACTIONS)
    policy_net = QNetwork(state_dim, action_dim).to(args.device)
    policy_net.load_state_dict(ckpt["policy_state_dict"])
    policy_net.eval()  # inference mode; no-op unless QNetwork has dropout/batch-norm

    mean_r, median_r, max_r, std_r = evaluate_policy(
        policy_net, env, episodes=20, device=args.device, max_steps=args.max_steps
    )
    print({"mean": mean_r, "median": median_r, "max": max_r, "std": std_r})
finally:
    # Release the environment even if loading or evaluation fails.
    env.close()



In [None]:
# Download the trained checkpoint to the local machine.
# Only works inside Google Colab, where `google.colab` is available.
from google.colab import files

checkpoint_path = args.save_path
files.download(checkpoint_path)

