# Out-of-core MoE + ALE (Atari)

This Colab notebook:

- mounts Google Drive (for persistence)
- clones or updates the repo into Drive
- installs dependencies (Gymnasium + ALE)
- runs `train_ale.py` with persistent `--run_dir` and `--disk_root`
- auto-resumes if a checkpoint exists

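> If the runtime disconnects, rerun **Mount Drive**, then **Run training**; training resumes automatically from the last checkpoint. If the runtime was reset (fresh VM), also rerun the clone/pull and **Install dependencies** cells before **Run training**, since installed packages do not persist on Drive.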


In [None]:
# --- 1) Mount Google Drive (required for persistence) ---
# Everything that must survive a runtime disconnect (repo clone, run outputs,
# expert store) lives under the mounted Drive folder configured below.
from google.colab import drive
drive.mount("/content/drive", force_remount=False)

from pathlib import Path
import os, time

# Where the repo is cloned on Drive.
BASE_DIR = Path("/content/drive/MyDrive/Colab_Notebooks")
REPO_NAME = "ewc_moe_atari"
REPO_URL  = "https://github.com/RespectableGlioma/ewc_moe_atari.git"  # change if you fork
REPO_DIR  = BASE_DIR / REPO_NAME

# Choose a run name (folder) for outputs/checkpoints/expert-store
RUN_NAME  = "ale_run1"  # change to ale_run2, etc.
RUN_DIR   = REPO_DIR / RUN_NAME

# Persist the cold tier (expert store) on Drive
DISK_ROOT = RUN_DIR / "ooc_disk"

# NOTE: RUN_DIR sits inside REPO_DIR, so creating it (parents=True) also
# creates REPO_DIR itself, even before the repo has been cloned.
for _directory in (BASE_DIR, RUN_DIR, DISK_ROOT):
    _directory.mkdir(parents=True, exist_ok=True)

# Echo the resolved locations so the user can confirm them.
for _label, _location in (
    ("REPO_DIR :", REPO_DIR),
    ("RUN_DIR  :", RUN_DIR),
    ("DISK_ROOT:", DISK_ROOT),
):
    print(_label, _location)


In [None]:
# --- 2) Clone or pull the repo into Drive ---
import subprocess, sys

def sh(cmd):
    """Echo a shell command line, then execute it through the shell.

    Args:
        cmd: Command line to run; it is passed to the shell as-is.

    Raises:
        subprocess.CalledProcessError: If the command exits with a non-zero
            status.
    """
    print(f"+ {cmd}")
    subprocess.check_call(cmd, shell=True)

import shlex, shutil, tempfile

# REPO_DIR cannot be used as the "already cloned?" signal: RUN_DIR lives
# inside it and is created in step 1, so REPO_DIR exists even on a fresh
# Drive. Decide based on what the folder actually contains instead.
if (REPO_DIR / ".git").exists():
    # Existing clone: fast-forward it to the latest upstream commit.
    sh(f"git -C {shlex.quote(str(REPO_DIR))} pull --ff-only")
elif (REPO_DIR / "train_ale.py").exists():
    # Sources were placed here by hand (no git metadata); use them as-is.
    print(f"Using existing folder (not a git repo): {REPO_DIR}")
else:
    # No sources yet. git refuses to clone into a non-empty directory (the
    # run folder is already in there), so clone into a scratch directory next
    # to REPO_DIR and merge the result in; existing run folders are kept.
    with tempfile.TemporaryDirectory(dir=str(BASE_DIR)) as _scratch:
        _fresh_clone = os.path.join(_scratch, REPO_NAME)
        sh(f"git clone {shlex.quote(REPO_URL)} {shlex.quote(_fresh_clone)}")
        shutil.copytree(_fresh_clone, REPO_DIR, dirs_exist_ok=True)

print("Repo ready:", REPO_DIR)

# Quick sanity check: ensure the repo has GPU-prefetch support (train_ale expects it)
# Fails fast with a RuntimeError if the store module is absent or does not
# define a `prefetch_to_gpu(` function.
import re
ts = REPO_DIR / "ooc_moe" / "tiered_store.py"
if not ts.exists():
    raise RuntimeError("Missing ooc_moe/tiered_store.py in repo.")

txt = ts.read_text()
if not re.search(r"def\s+prefetch_to_gpu\s*\(", txt):
    raise RuntimeError(
        "This repo copy is missing TieredExpertStore.prefetch_to_gpu(). "
        "Pull the latest changes (or use the colab-ready snapshot) before running."
    )


## Install dependencies

Gymnasium + ALE packaging has changed over time. The most recent Gymnasium release notes indicate that with `ale-py>=0.9` the ROMs are packaged, so installing Atari should just be `pip install "gymnasium[atari]"`, and you must `import ale_py` before creating Atari envs.

The cell below installs a modern stack and runs a small sanity check.


In [None]:
# --- 3) Install dependencies (idempotent) ---
import sys, subprocess, textwrap

def pip(cmd):
    """Run ``<this interpreter> -m pip <cmd>`` and fail loudly on error.

    Args:
        cmd: pip arguments as a single shell-style string, e.g.
            ``'install -U "gymnasium[atari]"'``. Quoting follows POSIX shell
            rules, so quoted arguments reach pip without the quote characters.

    Raises:
        subprocess.CalledProcessError: If pip exits with a non-zero status.
    """
    import shlex

    print("+ pip", cmd)
    # No shell is involved here, so the string must be tokenised shell-style:
    # a plain str.split() would hand pip the literal text '"gymnasium[atari]"'
    # (quotes included), which pip rejects as an invalid requirement.
    subprocess.check_call([sys.executable, "-m", "pip", *shlex.split(cmd)])

# Bring pip itself up to date before installing anything else.
pip("install -U pip")

# Install repo requirements
requirements_file = REPO_DIR / "requirements.txt"
install_requirements_argv = [
    sys.executable, "-m", "pip", "install", "-r", str(requirements_file),
]
subprocess.check_call(install_requirements_argv)

# Install Atari dependencies (modern Gymnasium/ALE)
pip('install -U "gymnasium[atari]" ale-py tqdm opencv-python-headless')

# Sanity check: Gymnasium v1+ requires importing ale_py before env creation
# The snippet runs in a child interpreter (not in this kernel), so it sees
# the packages exactly as installed above.
sanity_check_source = r'''
import gymnasium as gym
import ale_py
gym.register_envs(ale_py)
env = gym.make("ALE/Pong-v5")
obs, info = env.reset()
env.close()
print("Atari sanity check OK")
'''
sanity_check_argv = [sys.executable, "-c", sanity_check_source]
subprocess.check_call(sanity_check_argv)


## Run training

- Writes logs/metrics/checkpoints into `RUN_DIR`
- Writes the disk tier (expert weights + optimizer state) into `DISK_ROOT`
- Auto-resumes if `checkpoint_last.pt` exists


In [None]:
# --- 4) Run train_ale.py (auto-resume) ---
import os, shlex, subprocess
from pathlib import Path

import sys

# train_ale.py is invoked by relative path, so run from the repo root.
os.chdir(str(REPO_DIR))

# A saved checkpoint means a previous run exists: pass --resume.
# Otherwise start a new run and pass --reset_disk instead.
ckpt = RUN_DIR / "checkpoint_last.pt"
resume_flag = "--resume" if ckpt.exists() else ""
reset_disk_flag = "" if ckpt.exists() else "--reset_disk"

LOG_PATH = RUN_DIR / "console.log"

# Build the command as an argv list so every value is quoted exactly once.
# sys.executable is the interpreter the dependencies were installed for.
train_argv = [
    sys.executable, "train_ale.py",
    "--run_dir", str(RUN_DIR),
    "--disk_root", str(DISK_ROOT),
    resume_flag,
    reset_disk_flag,
    "--checkpoint_every", "20",
    "--writeback_policy", "periodic", "--writeback_every", "20",
    "--n_experts", "512", "--n_envs", "512",
    "--gpu_slots", "16", "--cpu_cache", "64",
    "--games", "Pong,Breakout,SpaceInvaders,Seaquest",
    "--env_backend", "ale_vec", "--vec_envs_per_game", "8",
    "--batch_size", "512", "--steps", "200",
    "--lr", "1e-2", "--optim", "adamw",
    "--prefetch", "--prefetch_gpu", "--prefetch_gpu_max", "0",
    "--io_workers", "4",
    "--sort_by_expert",
]
# Drop whichever of the two optional flags is empty for this run.
cmd = shlex.join(arg for arg in train_argv if arg)

print("Running:\n", cmd)

# The command and the pipe must be on ONE line: previously the command string
# ended with a newline, so `|& tee` started a new line and the training output
# was never piped into console.log. `pipefail` makes a training failure
# propagate instead of being masked by tee's (successful) exit status.
bash_script = f"set -o pipefail; {cmd} |& tee -a {shlex.quote(str(LOG_PATH))}"
subprocess.check_call(["bash", "-lc", bash_script])


## Optional: quick metrics peek

This prints the last 10 lines (or fewer, if the file is shorter) of `metrics.jsonl` if present.


In [None]:
from pathlib import Path
# Show the tail of the metrics log written into RUN_DIR, if there is one.
metrics_path = RUN_DIR / "metrics.jsonl"
if not metrics_path.exists():
    print("No metrics.jsonl yet.")
else:
    # Strip surrounding whitespace first so trailing blank lines are not
    # counted among the last ten entries.
    recent_records = metrics_path.read_text().strip().splitlines()[-10:]
    for record in recent_records:
        print(record)
