# Snake DQN Training (Colab)

Use this notebook on Google Colab to train the Snake DQN and download the checkpoint. Steps:
1. Set your git repository URL in the clone cell below (or through the `SNAKE_REPO_URL` environment variable).
2. Run that cell to clone the repo and install dependencies.
3. Adjust hyperparameters in the args cell if desired.
4. Run training.
5. Download the trained model.



In [1]:
# Clone the repository (set your git URL)
import os
import shutil
import subprocess
import sys

REPO_URL = os.environ.get("SNAKE_REPO_URL", "https://github.com/SyntheticVis-Umut/Snake.git")  # change if using a fork
GITHUB_TOKEN = os.environ.get("SNAKE_GITHUB_TOKEN") or os.environ.get("GITHUB_TOKEN")

if not REPO_URL:
    raise ValueError("Set REPO_URL (or env SNAKE_REPO_URL) before running this cell.")

if os.path.exists('/content/Snake'):
    shutil.rmtree('/content/Snake')

clone_url = REPO_URL
masked_url = REPO_URL
if GITHUB_TOKEN and REPO_URL.startswith("https://github.com/"):
    # Inject token for private repos; only masked_url should ever be printed, never clone_url
    clone_url = REPO_URL.replace("https://", f"https://{GITHUB_TOKEN}@")
    masked_url = REPO_URL.replace("https://", "https://<TOKEN>@")
    print("Using token from env for clone.")

clone_cmd = ["git", "clone", "--depth", "1", clone_url, "/content/Snake"]
print("Clone command (token masked):", " ".join(clone_cmd).replace(clone_url, masked_url))
result = subprocess.run(clone_cmd, capture_output=True, text=True)
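if result.returncode != 0:
    # git may echo the clone URL in its messages, so mask the token before printing
    print("git clone stdout:\n", result.stdout.replace(clone_url, masked_url))
    print("git clone stderr:\n", result.stderr.replace(clone_url, masked_url))
    sys.exit("git clone failed; check REPO_URL, token (if private), and permissions")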

os.chdir('/content/Snake')

# Install deps (Colab runtimes normally ship a CUDA-enabled torch build, so pip leaves it in place)
%pip install -q pygame torch numpy tqdm

print('CWD:', os.getcwd())
print('Repository cloned and ready!')


Clone command (token masked): git clone --depth 1 https://github.com/SyntheticVis-Umut/Snake.git /content/Snake
CWD: /content/Snake
Repository cloned and ready!
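
The clone cell masks the token by string replacement on one known URL. As a more general sketch, the helper below strips any embedded credentials from a URL before it is printed. `mask_credentials` is a hypothetical name and is not part of the Snake repo; it relies only on the standard library's `urllib.parse`.

```python
from urllib.parse import urlsplit, urlunsplit


def mask_credentials(url: str, placeholder: str = "<TOKEN>") -> str:
    """Return url with any userinfo (token or user:password) replaced by placeholder."""
    parts = urlsplit(url)
    if "@" not in parts.netloc:
        return url  # no embedded credentials, nothing to hide
    # Keep only the host (and port) that follows the last "@"
    host = parts.netloc.rsplit("@", 1)[1]
    return urlunsplit(parts._replace(netloc=f"{placeholder}@{host}"))


# Example with a made-up token
print(mask_credentials("https://ghp_example123@github.com/user/Snake.git"))
```

Because it parses the URL instead of matching a fixed prefix, the same helper also covers `user:password@host` forms.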


In [2]:
# Quick import test
from train import train
from types import SimpleNamespace
print('Imports ok')


pygame 2.6.1 (SDL 2.28.4, Python 3.12.12)
Hello from the pygame community. https://www.pygame.org/contribute.html
Imports ok
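
The hyperparameter cell below sets `device="cuda"` and mentions `"auto"` as a fallback. How `train` resolves `"auto"` is not shown in this notebook, so the following is only a sketch of what such a resolution step could look like. `resolve_device` is a hypothetical helper; `torch.cuda.is_available()` is the real PyTorch call it depends on.

```python
def resolve_device(requested: str = "auto") -> str:
    """Map a requested device string to one that is usable on this machine."""
    try:
        import torch
        cuda_ok = torch.cuda.is_available()
    except ImportError:
        cuda_ok = False  # torch not installed: only CPU can be reported

    if requested == "auto":
        return "cuda" if cuda_ok else "cpu"
    if requested == "cuda" and not cuda_ok:
        raise RuntimeError("device='cuda' requested but no CUDA device is available")
    return requested


print("Resolved device:", resolve_device("auto"))
```

Running a check like this before the training cell makes a missing GPU runtime obvious up front, before any training has started.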


In [None]:
# Configure hyperparameters
args = SimpleNamespace(
    episodes=100000,
    max_steps=1000,
    buffer_size=50000,
    batch_size=512,
    gamma=0.99,
    lr=5e-4,
    eps_start=0.5,
    eps_end=0.02,
    eps_decay=5000,
    target_update=1000,
    warmup=2000,
    grid=(20, 20),
    seed=42,
    save_path="models/dqn_snake_colab.pt",
    resume=None,  # set to a checkpoint path to continue training
    device="cuda",  # force GPU (requires a GPU runtime, e.g. an A100 on Colab); use "auto" to fall back
    grad_clip=1.0,  # gradient clipping for stability; set <=0 to disable
    double_dqn=True,  # use Double DQN targets for better stability
)
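
The `double_dqn=True` flag refers to Double DQN targets. The repo's implementation is not shown here, so the sketch below follows the standard formulation: the online network picks the next action and the target network scores it. `double_dqn_target` and the toy Q-values are hypothetical, and plain Python lists stand in for tensors to keep the example dependency-free.

```python
def double_dqn_target(reward, done, gamma, q_online_next, q_target_next):
    """Double DQN target for a single transition.

    q_online_next / q_target_next hold per-action Q-values for the next
    state, from the online and target networks respectively.
    """
    if done:
        return reward  # no bootstrapping from a terminal state
    # The online network selects the greedy action...
    best_action = max(range(len(q_online_next)), key=q_online_next.__getitem__)
    # ...and the target network evaluates that action.
    return reward + gamma * q_target_next[best_action]


# Toy numbers, made up for illustration
target = double_dqn_target(
    reward=1.0,
    done=False,
    gamma=0.99,
    q_online_next=[0.2, 0.9, 0.1],
    q_target_next=[0.5, 0.4, 0.8],
)
print(target)
```

A vanilla DQN target would instead take the maximum of `q_target_next` directly, which tends to overestimate action values. Splitting selection and evaluation between the two networks is what reduces that bias.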


In [None]:
# Train
train(args)
print('Training done, saved to', args.save_path)


Using CUDA device: NVIDIA A100-SXM4-80GB
Episode 1/100000 Reward: -2.00 Epsilon: 0.498 Best: -2.00
Episode 10/100000 Reward: -3.80 Epsilon: 0.470 Best: -2.00
Episode 20/100000 Reward: -11.90 Epsilon: 0.444 Best: -2.00
Episode 30/100000 Reward: -12.40 Epsilon: 0.421 Best: -2.00
Episode 40/100000 Reward: -13.80 Epsilon: 0.393 Best: -2.00
Episode 50/100000 Reward: -12.50 Epsilon: 0.377 Best: -2.00
Episode 60/100000 Reward: -12.50 Epsilon: 0.366 Best: -2.00
Episode 70/100000 Reward: -11.10 Epsilon: 0.352 Best: -2.00
Episode 80/100000 Reward: -11.50 Epsilon: 0.334 Best: -2.00
Episode 90/100000 Reward: -12.20 Epsilon: 0.299 Best: -2.00
Episode 100/100000 Reward: -6.00 Epsilon: 0.270 Best: 0.30
Episode 110/100000 Reward: -3.50 Epsilon: 0.246 Best: 10.50
Episode 120/100000 Reward: -12.50 Epsilon: 0.211 Best: 14.40
Episode 130/100000 Reward: -11.30 Epsilon: 0.196 Best: 14.40
Episode 140/100000 Reward: 2.70 Epsilon: 0.172 Best: 46.30
Episode 150/100000 Reward: -2.80 Epsilon: 0.152 Best: 53.90
Ep
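
The `Epsilon` column in the log starts just under `eps_start=0.5` and falls toward `eps_end=0.02`. The exact schedule used by `train` is not shown in this notebook. A common choice, assumed here, is exponential decay over environment steps with `eps_decay` as the time constant. `epsilon_at` is a hypothetical helper for exploring how the three hyperparameters interact.

```python
import math


def epsilon_at(step: int, eps_start: float = 0.5, eps_end: float = 0.02,
               eps_decay: float = 5000) -> float:
    """Exponentially decayed exploration rate (assumed schedule, not confirmed by the repo)."""
    return eps_end + (eps_start - eps_end) * math.exp(-step / eps_decay)


# Print epsilon at a few step counts to see how quickly exploration fades
for step in (0, 1000, 5000, 20000):
    print(step, round(epsilon_at(step), 3))
```

Under this assumption, a larger `eps_decay` keeps the agent exploring for more steps, and `eps_end` sets the floor that epsilon approaches but never drops below.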

In [None]:
# Download the trained model (Colab)
from google.colab import files
files.download(args.save_path)
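
If training was interrupted before a checkpoint was written, the download step has nothing to send. As a sketch, a small check like the one below can run before `files.download`. `checkpoint_size_mib` is a hypothetical helper that uses only the standard library, and the demo writes a throwaway file so it also runs outside Colab.

```python
import os
import tempfile


def checkpoint_size_mib(path: str) -> float:
    """Return the checkpoint size in MiB, or raise if the file is missing or empty."""
    if not os.path.isfile(path):
        raise FileNotFoundError(f"No checkpoint at {path}; did training save one?")
    size = os.path.getsize(path)
    if size == 0:
        raise ValueError(f"Checkpoint at {path} is empty")
    return size / (1024 * 1024)


# Demo on a temporary file standing in for a real checkpoint
with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as tmp:
    tmp.write(b"\x00" * 2048)
demo_size = checkpoint_size_mib(tmp.name)
print(f"{demo_size:.4f} MiB")
os.remove(tmp.name)
```

In the notebook, the call would be `checkpoint_size_mib(args.save_path)` placed just before `files.download(args.save_path)`.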

