# Out-of-Core-ish MoE + EWC Multi-Atari: Day/Night Training (Gymnasium)

This notebook runs a **biologically inspired** loop:

- **Day:** play *K* Atari games sequentially (highly correlated stream). A **shared GRU router** infers context and gates a sparse set of experts.
- **Sleep:** shuffle replay drawn from **only those games encountered today** (giving more IID-like gradients), then run SGD updates.
- **Retention:** **EWC** (Elastic Weight Consolidation) protects the router, optionally the encoder, and the experts flagged per game.
- **Out-of-core mechanism:** a real `ExpertStore` pages experts across **GPU (HBM)** ↔ **CPU (DRAM)** ↔ **disk (NVMe)**.

**Important for persistence:** if you mount Google Drive, the run directory (metrics + expert shards) survives Colab restarts.


In [None]:
# Install deps (Colab)
# NOTE: We intentionally do NOT `pip install -U torch` in Colab because it can accidentally
# downgrade/replace the preinstalled CUDA build.
# The shell command below quietly (-q) installs/upgrades (-U) Gymnasium with its Atari
# extras, plus OpenCV, pandas, matplotlib and TensorBoard.
!pip -q install -U   'gymnasium[atari,accept-rom-license]'   opencv-python   pandas   matplotlib   tensorboard

# Import right after installing and print the version as a quick sanity check.
import gymnasium
print('Gymnasium version:', gymnasium.__version__)


In [None]:
# (Optional but recommended) Mount Google Drive for persistence
import os, time

# Base directory for all runs; stays None until a persistent location is confirmed.
RUN_BASE = None

# Best-effort Drive mount: any failure (not in Colab, mount refused, ...) is
# reported and we simply continue with a non-persistent location.
try:
    from google.colab import drive  # type: ignore
    drive.mount('/content/drive')
    if os.path.exists('/content/drive/MyDrive'):
        RUN_BASE = '/content/drive/MyDrive/ewc_moe_atari_runs'
except Exception as e:
    print('Drive mount skipped / not in Colab:', e)

if RUN_BASE is None:
    # Ephemeral (will be lost if runtime resets)
    RUN_BASE = '/content/ewc_moe_atari_runs'

# One sub-directory per run, named by the local start time (second resolution).
run_id = time.strftime('%Y%m%d_%H%M%S')
RUN_DIR = os.path.join(RUN_BASE, run_id)
try:
    os.makedirs(RUN_DIR, exist_ok=True)
except OSError as e:
    # Outside Colab, '/content' typically cannot be created. Fall back to a
    # directory under the current working directory instead of crashing.
    # os.getcwd() is absolute, so RUN_DIR stays valid after a later chdir.
    RUN_BASE = os.path.join(os.getcwd(), 'ewc_moe_atari_runs')
    RUN_DIR = os.path.join(RUN_BASE, run_id)
    print(f'Could not create run directory ({e}); falling back to:', RUN_BASE)
    os.makedirs(RUN_DIR, exist_ok=True)

print('RUN_DIR:', RUN_DIR)


In [None]:
# Make repo imports work regardless of Colab's current working directory.
# We locate the project root by searching for `src/envs/atari.py`.
import os
import sys
from pathlib import Path


def _find_project_root() -> Path:
    """Locate the project root: the directory that contains ``src/envs/atari.py``.

    Search order:
      1. The current working directory and each of its parents.
      2. The common unzip location ``/content/ewc_moe_atari_colab``.
      3. A breadth-first scan under ``/content``; candidates up to 4 levels
         below ``/content`` are checked.

    Returns:
        The first directory that contains the marker file, or the resolved
        current working directory if nothing matches.
    """
    # Local import so this block does not depend on the cell's import header.
    from collections import deque

    marker = Path('src') / 'envs' / 'atari.py'
    skipped_names = ('__pycache__', '.ipynb_checkpoints')

    def _has_marker(candidate: Path) -> bool:
        # An unreadable candidate counts as "no match" instead of aborting the search.
        try:
            return (candidate / marker).is_file()
        except OSError:
            return False

    cwd = Path.cwd().resolve()

    # 1) Check cwd and parents
    for p in (cwd, *cwd.parents):
        if _has_marker(p):
            return p

    # 2) Common unzip location
    common = Path('/content/ewc_moe_atari_colab').resolve()
    if _has_marker(common):
        return common

    # 3) Shallow BFS under /content (depth-limited)
    base = Path('/content').resolve()
    if base.exists():
        # deque gives O(1) pops from the left; list.pop(0) is O(n).
        queue = deque([(base, 0)])
        while queue:
            node, depth = queue.popleft()
            if _has_marker(node):
                return node
            if depth >= 4:
                continue
            try:
                for child in node.iterdir():
                    if child.is_dir() and child.name not in skipped_names:
                        queue.append((child, depth + 1))
            except OSError:
                # Directory vanished or is not listable: keep whatever children
                # were already queued and move on.
                continue

    # Fallback
    return cwd


# Resolve the project root once and derive the source directory from it.
PROJECT_DIR = _find_project_root()
SRC_DIR = str(PROJECT_DIR / 'src')

# Work from the project root so relative paths resolve from there.
os.chdir(PROJECT_DIR)

# Put `src` at the front of the import path, but only once, so re-running
# this cell does not add duplicates.
if not any(entry == SRC_DIR for entry in sys.path):
    sys.path.insert(0, SRC_DIR)

print('PROJECT_DIR:', PROJECT_DIR)
print('SRC_DIR:', SRC_DIR)


In [None]:
# Quick environment smoke test
from envs import make_atari_env

# Build a single Breakout environment, reset it, and take a few random
# actions to confirm that the Atari stack is installed and steps correctly.
env = make_atari_env('ALE/Breakout-v5', seed=0, frame_stack=4, clip_rewards=True, full_action_space=True)
try:
    obs, info = env.reset()
    print('obs shape:', obs.shape, 'dtype:', obs.dtype, 'action_space:', env.action_space)

    for _ in range(5):
        obs, rew, terminated, truncated, info = env.step(env.action_space.sample())
        # An episode can end within these few steps; start a new one if so.
        if terminated or truncated:
            obs, info = env.reset()
finally:
    # Always release the environment, even if reset/step raised above.
    env.close()

print('ok')


In [None]:
# Run a small multi-day experiment
from config import Config
from training import DayNightTrainer
from logging_utils import RunLogger

# Pick a suite of games. (You can expand this list.)
GAMES = [
    'ALE/Breakout-v5',
    'ALE/SpaceInvaders-v5',
    'ALE/Pong-v5',
    'ALE/Seaquest-v5',
    'ALE/Qbert-v5',
    'ALE/BeamRider-v5',
]

cfg = Config()

# Longer days (more games per day). Tune these up as your runtime allows.
cfg.games_per_day = 5
cfg.day_steps_per_game = 2500
cfg.sleep_updates_per_game = 250

# Make the expert suite bigger to force paging.
cfg.num_experts = 64
cfg.expert_top_k = 2

# Tight budgets so we see real HBM/DRAM/NVMe behavior.
cfg.hbm_expert_capacity = 4
cfg.dram_expert_capacity = 12
cfg.enable_nvme_tier = True

# For faster iteration in Colab:
cfg.batch_size = 16
cfg.seq_len = 8

print(cfg)

NUM_DAYS = 3

logger = RunLogger(RUN_DIR, config=cfg)
try:
    trainer = DayNightTrainer(GAMES, cfg, seed=0, run_dir=RUN_DIR)

    for day in range(NUM_DAYS):
        print(f'=== DAY {day} ===')
        out = trainer.run_one_day(day)

        # Human-readable quick summary: one line per game played today.
        print('games_today:', out['games_today'])
        for g in out['games_today']:
            print(' ', g,
                  'n_eps', out['n_episodes'][g],
                  'last_return', out['episode_return_last'][g],
                  'mean_return', out['episode_return_mean'][g],
                  'hbm_hit_rate', out['day_cache'][g]['hbm_hit_rate'],
                  'nvme_reads', out['day_cache'][g]['nvme_reads'])

        logger.log(day, out)
finally:
    # Close the logger even if a day fails part-way, so it is not left open.
    logger.close()

print('Done. Logs saved in:', RUN_DIR)


In [None]:
# Load metrics + plot
import os
import matplotlib.pyplot as plt

from viz.metrics import load_jsonl, metrics_to_frame

# Read the run's metrics file and convert the loaded records into `df`.
metrics_path = os.path.join(RUN_DIR, 'metrics.jsonl')
records = load_jsonl(metrics_path)
df = metrics_to_frame(records)

print('df shape:', df.shape)
# Last expression of the cell: the notebook displays the first rows.
df.head()


In [None]:
# Plot episode return (mean) per game across days
import numpy as np

# Per-game mean-return columns are named like: episode_return_mean/ALE/Breakout-v5
cols = [name for name in df.columns if name.startswith('episode_return_mean/')]

# One figure, one line per game; the legend shows the game id without the prefix.
fig, ax = plt.subplots()
for column in cols:
    game_label = column.replace('episode_return_mean/', '')
    ax.plot(df['_step'], df[column], marker='o', label=game_label)
ax.set_xlabel('day')
ax.set_ylabel('mean episode return (day)')
ax.legend()
plt.show()


In [None]:
# Plot expert-store behavior: HBM hit rate + NVMe reads
cache_cols = [name for name in df.columns
              if name.startswith('day_cache/') and name.endswith('/hbm_hit_rate')]
read_cols = [name for name in df.columns
             if name.startswith('day_cache/') and name.endswith('/nvme_reads')]

# Both figures share the same layout; only the columns, the suffix stripped
# from the legend label, and the y-axis label differ.
for columns, suffix, y_label in (
    (cache_cols, '/hbm_hit_rate', 'HBM hit rate (day)'),
    (read_cols, '/nvme_reads', 'NVMe reads (count, day)'),
):
    plt.figure()
    for column in columns:
        game_label = column.replace('day_cache/', '').replace(suffix, '')
        plt.plot(df['_step'], df[column], marker='o', label=game_label)
    plt.xlabel('day')
    plt.ylabel(y_label)
    plt.legend()
    plt.show()


In [None]:
# TensorBoard (optional)
# Load the notebook extension, then open TensorBoard on the `tb` subfolder of RUN_DIR.
# NOTE(review): assumes the run wrote event files under RUN_DIR/tb; the writer is not
# shown in this notebook, so confirm against the logger.
%load_ext tensorboard
%tensorboard --logdir {os.path.join(RUN_DIR, 'tb')}
