In [1]:
from google.colab import drive

import os
from pathlib import Path

# Attach Google Drive so everything written below survives a Colab runtime reset.
drive.mount("/content/drive", force_remount=False)

# Directory layout: <BASE_DIR>/<repo checkout>/<one sub-directory per run>.
BASE_DIR = Path("/content/drive/MyDrive/Colab_Notebooks")
REPO_DIR = BASE_DIR.joinpath("ewc_moe_atari")
RUN_NAME = "ale_run1"  # change to ale_run2, etc.
RUN_DIR = REPO_DIR.joinpath(RUN_NAME)

# NOTE: RUN_DIR lives inside REPO_DIR, so creating it with parents=True
# also creates REPO_DIR when that directory does not exist yet.
for _needed_dir in (BASE_DIR, RUN_DIR):
    _needed_dir.mkdir(parents=True, exist_ok=True)

# Echo the resolved locations for a quick visual check.
for _label, _location in (
    ("BASE_DIR:", BASE_DIR),
    ("REPO_DIR:", REPO_DIR),
    ("RUN_DIR :", RUN_DIR),
):
    print(_label, _location)


Mounted at /content/drive


In [2]:
import subprocess, sys
from pathlib import Path

# Git remote that is cloned into REPO_DIR. The <YOUR_ORG_OR_USER> part is a
# placeholder and must be replaced with the real GitHub user/org before cloning.
REPO_URL = "https://github.com/<YOUR_ORG_OR_USER>/ewc_moe_atari.git"  # <-- set this

def run(cmd, **kwargs):
    """Echo a command, execute it, and raise if it exits non-zero.

    Args:
        cmd: Either an argv sequence (list/tuple whose items are ``str``,
            ``bytes`` or path-like objects) or a single command given as a
            string/bytes/path. A single command is echoed verbatim and passed
            to ``subprocess.check_call`` unchanged, so supply ``shell=True``
            in ``kwargs`` when shell parsing is intended.
        **kwargs: Forwarded unchanged to ``subprocess.check_call``.

    Returns:
        The value returned by ``subprocess.check_call`` (0 on success).

    Raises:
        subprocess.CalledProcessError: If the command exits with a non-zero
            status.
    """
    import os
    import shlex

    if isinstance(cmd, (str, bytes, os.PathLike)):
        shown = os.fsdecode(cmd)
    else:
        # os.fsdecode accepts str, bytes and path-like items, so argv entries
        # such as Path objects no longer break the echo; quoting each item
        # keeps the echoed line unambiguous and copy-pasteable.
        shown = " ".join(shlex.quote(os.fsdecode(arg)) for arg in cmd)

    # Flush so the echo is emitted before the child process starts writing.
    print("+", shown, flush=True)
    return subprocess.check_call(cmd, **kwargs)

# Clone or pull.
#
# REPO_DIR may already exist without being a checkout: RUN_DIR lives inside
# REPO_DIR, so creating RUN_DIR with parents=True leaves behind a skeleton of
# empty directories. `git clone` refuses a non-empty target, so a skeleton
# that contains directories only (no files, no symlinks) is removed before
# cloning. Anything that holds real content is never touched.
git_dir = REPO_DIR / ".git"
if git_dir.exists():
    run(["git", "-C", str(REPO_DIR), "pull", "--ff-only"])
else:
    # Fail early with a clear message instead of a cryptic git error.
    if "<YOUR_ORG_OR_USER>" in REPO_URL:
        raise ValueError(
            "REPO_URL still contains the <YOUR_ORG_OR_USER> placeholder; "
            "set it to the real repository URL before cloning."
        )

    if REPO_DIR.exists():
        stale_entries = list(REPO_DIR.rglob("*")) if REPO_DIR.is_dir() else []
        only_empty_dirs = REPO_DIR.is_dir() and all(
            entry.is_dir() and not entry.is_symlink() for entry in stale_entries
        )
        if not only_empty_dirs:
            raise RuntimeError(f"{REPO_DIR} exists but is not a git repo. Delete it or git clone into a new folder.")
        # Deepest paths first, so every rmdir() sees an already-empty directory.
        for stale_dir in sorted(stale_entries, key=lambda entry: len(entry.parts), reverse=True):
            stale_dir.rmdir()
        REPO_DIR.rmdir()

    run(["git", "clone", REPO_URL, str(REPO_DIR)])

# Make sure the run directory exists again (it may have been part of the
# skeleton removed above, and a fresh clone does not necessarily contain it).
RUN_DIR.mkdir(parents=True, exist_ok=True)

# Install deps (prefer newest stable). Everything goes through the kernel's
# own interpreter; pip itself is upgraded first.
pip_upgrade = [sys.executable, "-m", "pip", "install", "-U"]
run(pip_upgrade + ["pip"])

# Gymnasium notes:
# - Newer Gymnasium + ale-py packaging often includes ROMs with gymnasium[atari].
# - Gymnasium v1 requires importing ale_py before gym.make for Atari envs.
atari_packages = [
    "gymnasium[atari]",
    "ale-py",
    "tqdm",
    "opencv-python-headless",
]
run(pip_upgrade + atari_packages)

# Smoke test for Atari env creation. The snippet runs in a separate
# interpreter process: it builds ALE/Pong-v5, resets it once, closes it and
# prints a confirmation line. Any failure surfaces as a non-zero exit
# status, which makes run() raise.
code = r"""
import gymnasium as gym
import ale_py
gym.register_envs(ale_py)  # helpful, and required in some setups
env = gym.make("ALE/Pong-v5")
obs, info = env.reset()
env.close()
print("Atari sanity check OK")
"""
sanity_check_argv = [sys.executable, "-c", code]
run(sanity_check_argv)


In [None]:
import os, shlex, subprocess
import sys
from pathlib import Path

# train_ale.py is referenced by a relative path, so run from the repo root.
os.chdir(str(REPO_DIR))

# tee opens console.log as soon as the pipeline starts, so the run directory
# has to exist before the command is launched.
RUN_DIR.mkdir(parents=True, exist_ok=True)

# Pass --resume when a checkpoint is present; otherwise pass --reset_disk.
ckpt = RUN_DIR / "checkpoint_last.pt"
resume_flag = "--resume" if ckpt.exists() else ""
reset_disk_flag = "" if ckpt.exists() else "--reset_disk"

# Build the command as an argv list and quote each item individually.
# The previous triple-quoted f-string ended with a newline, which pushed the
# appended "|& tee ..." onto a line of its own: bash ran training without
# logging and then aborted with a syntax error.
train_argv = [
    # Same interpreter the dependencies were installed into, independent of
    # whatever PATH the login shell sets up.
    sys.executable, "train_ale.py",
    "--run_dir", str(RUN_DIR),
    resume_flag,
    reset_disk_flag,
    "--checkpoint_every", "20",
    "--writeback_policy", "periodic", "--writeback_every", "20",
    "--n_experts", "512", "--n_envs", "512",
    "--gpu_slots", "16", "--cpu_cache", "64",
    "--games", "Pong,Breakout,SpaceInvaders,Seaquest",
    "--env_backend", "ale_vec", "--vec_envs_per_game", "8",
    "--batch_size", "512", "--steps", "200",
    "--lr", "1e-2", "--optim", "adamw",
    "--prefetch", "--prefetch_gpu", "--prefetch_gpu_max", "0",
    "--io_workers", "4",
    "--sort_by_expert",
]
# Exactly one of resume_flag / reset_disk_flag is empty; drop it.
cmd = " ".join(shlex.quote(arg) for arg in train_argv if arg)

log_path = RUN_DIR / "console.log"
# pipefail makes the pipeline report the trainer's exit status instead of
# tee's, so a crashed run is not reported as success by check_call.
pipeline = f"set -o pipefail; {cmd} 2>&1 | tee -a {shlex.quote(str(log_path))}"
bash_cmd = f"bash -lc {shlex.quote(pipeline)}"

print("Running command:\n", cmd)
subprocess.check_call(bash_cmd, shell=True)
