# Snake DQN Training (Colab)

Use this notebook on Google Colab to train the Snake DQN and download the checkpoint. Steps:
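1. If you are using a fork, set the `SNAKE_REPO_URL` environment variable (or edit `REPO_URL`) in the setup cell below; otherwise the default repository URL is used.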
2. Run the setup/install cell to clone the repo and install dependencies.
3. Adjust hyperparameters in the args cell if desired.
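4. Run training. The args cell forces `device="cuda"`, so select a GPU runtime (Runtime → Change runtime type) or set `device="auto"` to allow a fallback.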
5. Download the trained model.



In [3]:
# Clone the repository (set your git URL)
import os

REPO_URL = os.environ.get("SNAKE_REPO_URL", "https://github.com/SyntheticVis-Umut/Snake.git")  # change if using a fork

# Clone and setup
import shutil
import sys

if not REPO_URL:
    raise ValueError("Set REPO_URL (or env SNAKE_REPO_URL) before running this cell.")

if os.path.exists('/content/Snake'):
    shutil.rmtree('/content/Snake')

exit_code = os.system(f'git clone {REPO_URL} /content/Snake')
if exit_code != 0:
    sys.exit("git clone failed; check REPO_URL and permissions")

os.chdir('/content/Snake')

# Install deps (CUDA wheels on Colab are handled automatically by torch)
%pip install -q pygame torch numpy tqdm

print('CWD:', os.getcwd())
print('Repository cloned and ready!')


CWD: /content/Snake
Repository cloned and ready!


In [None]:
# Sanity check: the training entry point must import from the cloned repo
# (the setup cell made the repo root the working directory).
from types import SimpleNamespace

from train import train

print('Imports ok')


In [None]:
# Configure hyperparameters. Each key below becomes an attribute of `args`, which
# the next cell passes to `train`. The meaning of each field is defined in train.py.
from types import SimpleNamespace

_hparams = {
    "episodes": 100_000,
    "max_steps": 1000,
    "buffer_size": 50_000,
    "batch_size": 512,
    "gamma": 0.99,
    "lr": 5e-4,
    "eps_start": 0.5,
    "eps_end": 0.02,
    "eps_decay": 5000,
    "target_update": 1000,
    "warmup": 2000,
    "grid": (20, 20),
    "seed": 42,
    "save_path": "models/dqn_snake_colab.pt",
    "resume": None,  # set to a checkpoint path to continue training
    "device": "cuda",  # force GPU (A100 on Colab); use "auto" to fall back
}
args = SimpleNamespace(**_hparams)
del _hparams  # keep only `args` in the notebook namespace


In [None]:
# Run the full training loop with the configured `args`, then report the
# checkpoint location taken from args.save_path.
train(args)
print(f'Training done, saved to {args.save_path}')


In [None]:
# Send the checkpoint at args.save_path to the browser as a file download.
# Only works inside Colab, since it relies on the google.colab package.
from google.colab import files

files.download(args.save_path)

