# Out-of-Core-ish MoE + EWC Multi-Atari: Day/Night Training (Gymnasium)

This notebook runs a **biologically inspired** loop:

- **Day:** play *K* Atari games sequentially (highly correlated stream). A **shared GRU router** infers context and gates a sparse set of experts.
- **Sleep:** replay shuffled experience from **only the games played today** (which makes the gradients more IID-like) and run SGD updates on it.
- **Retention:** **EWC** (Elastic Weight Consolidation) penalizes changes to the router, optionally the encoder, and the experts flagged as important for each game.
- **Out-of-core mechanism:** a real `ExpertStore` pages experts across **GPU (HBM)** ↔ **CPU (DRAM)** ↔ **disk (NVMe)**.

**Important for persistence:** if you mount Google Drive, the run directory (metrics + expert shards) survives Colab restarts.


In [1]:
# Install deps (Colab)
# NOTE: We intentionally do NOT `pip install -U torch` in Colab because it can accidentally
# downgrade/replace the preinstalled CUDA build.
!pip -q install -U   'gymnasium[atari,accept-rom-license]'   opencv-python   pandas   matplotlib   tensorboard

import gymnasium
print('Gymnasium version:', gymnasium.__version__)


[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m91.2/91.2 kB[0m [31m7.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m52.8/52.8 kB[0m [31m5.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m72.9/72.9 MB[0m [31m11.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.4/12.4 MB[0m [31m133.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.7/8.7 MB[0m [31m121.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.5/5.5 MB[0m [31m135.2 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
google-colab 1.0.0 requires pandas==2.2.2, but you have pandas 2.3.3 which is incompatible.
tensorf

In [2]:
# (Optional but recommended) Mount Google Drive for persistence.
# Everything this run writes goes under RUN_DIR = RUN_BASE/<YYYYmmdd_HHMMSS>.
import os, time

RUN_BASE = None

try:
    from google.colab import drive  # type: ignore
    drive.mount('/content/drive')
    if os.path.exists('/content/drive/MyDrive'):
        RUN_BASE = '/content/drive/MyDrive/ewc_moe_atari_runs'
except Exception as e:
    # Best-effort: outside Colab the import fails, and a failed mount should not stop the
    # notebook either -- we fall back to a local directory below.
    print('Drive mount skipped / not in Colab:', e)

if RUN_BASE is None:
    if os.path.isdir('/content'):
        # Colab without Drive: ephemeral (will be lost if runtime resets).
        RUN_BASE = '/content/ewc_moe_atari_runs'
    else:
        # Not on Colab: `/content` usually does not exist and may not be creatable,
        # so keep runs next to the current working directory instead.
        RUN_BASE = os.path.abspath('ewc_moe_atari_runs')

run_id = time.strftime('%Y%m%d_%H%M%S')
RUN_DIR = os.path.join(RUN_BASE, run_id)
os.makedirs(RUN_DIR, exist_ok=True)

print('RUN_DIR:', RUN_DIR)


Mounted at /content/drive
RUN_DIR: /content/drive/MyDrive/ewc_moe_atari_runs/20260120_130029


In [3]:
# Make repo imports work regardless of Colab's current working directory.
# We locate the project root by searching for `src/envs/atari.py`.
import os
import sys
from pathlib import Path


def _find_project_root() -> Path:
    cwd = Path.cwd().resolve()

    # 1) Check cwd and parents
    for p in [cwd, *cwd.parents]:
        if (p / 'src' / 'envs' / 'atari.py').is_file():
            return p

    # 2) Common unzip location
    common = Path('/content/ewc_moe_atari_colab').resolve()
    if (common / 'src' / 'envs' / 'atari.py').is_file():
        return common

    # 3) Shallow BFS under /content (depth-limited)
    base = Path('/content').resolve()
    if base.exists():
        queue = [(base, 0)]
        while queue:
            node, depth = queue.pop(0)
            if (node / 'src' / 'envs' / 'atari.py').is_file():
                return node
            if depth < 4:
                try:
                    for child in node.iterdir():
                        if child.is_dir() and child.name not in ('__pycache__', '.ipynb_checkpoints'):
                            queue.append((child, depth + 1))
                except Exception:
                    pass

    # Fallback
    return cwd


PROJECT_DIR = _find_project_root()
if not (PROJECT_DIR / 'src' / 'envs' / 'atari.py').is_file():
    # _find_project_root() falls back to the CWD when nothing matches; say so here instead
    # of letting the later `from envs import ...` fail with a bare ModuleNotFoundError.
    print(f'WARNING: src/envs/atari.py not found under {PROJECT_DIR}; '
          'upload/unzip the project (or chdir into it) before running the next cells.')
os.chdir(PROJECT_DIR)

# Make `envs`, `config`, `training`, ... importable as top-level modules.
SRC_DIR = str(PROJECT_DIR / 'src')
if SRC_DIR not in sys.path:
    sys.path.insert(0, SRC_DIR)

print('PROJECT_DIR:', PROJECT_DIR)
print('SRC_DIR:', SRC_DIR)


PROJECT_DIR: /content/ewc_moe_atari
SRC_DIR: /content/ewc_moe_atari/src


In [4]:
# Quick environment smoke test: build one Atari env, reset it, take a few random steps.
from envs import make_atari_env

env = make_atari_env('ALE/Breakout-v5', seed=0, frame_stack=4, clip_rewards=True, full_action_space=True)
try:
    obs, info = env.reset()
    print('obs shape:', obs.shape, 'dtype:', obs.dtype, 'action_space:', env.action_space)

    for _ in range(5):
        obs, rew, terminated, truncated, info = env.step(env.action_space.sample())
        if terminated or truncated:
            obs, info = env.reset()
finally:
    # Release the emulator even if reset/step raises.
    env.close()
print('ok')


obs shape: (4, 84, 84) dtype: uint8 action_space: Discrete(18)
ok


In [5]:
# Run a small multi-day experiment
from config import Config
from training import DayNightTrainer
from logging_utils import RunLogger

# Pick a suite of games. (You can expand this list.)
GAMES = [
    'ALE/Breakout-v5',
    'ALE/SpaceInvaders-v5',
    'ALE/Pong-v5',
    'ALE/Seaquest-v5',
    'ALE/Qbert-v5',
    'ALE/BeamRider-v5',
]

cfg = Config()

# Longer days (more games per day). Tune these up as your runtime allows.
# NOTE(review): GAMES has only 6 entries, and the recorded day 0 played exactly those 6
# with games_per_day = 8, so the trainer appears to cap this at len(GAMES) -- confirm in
# DayNightTrainer before relying on values larger than len(GAMES).
cfg.games_per_day = 8
cfg.day_steps_per_game = 4000
cfg.sleep_updates_per_game = 400

# Make the expert suite bigger to force paging.
cfg.num_experts = 256           # make it "big enough" to matter
cfg.expert_top_k = 2

# Tight budgets so we see real HBM/DRAM/NVMe behavior.
cfg.hbm_expert_capacity = 4     # force HBM pressure
cfg.dram_expert_capacity = 16   # force DRAM pressure
cfg.enable_nvme_tier = True     # enable disk tier


# For faster iteration in Colab:
cfg.batch_size = 16
cfg.seq_len = 8

print(cfg)

logger = RunLogger(RUN_DIR, config=cfg)
trainer = DayNightTrainer(GAMES, cfg, seed=0, run_dir=RUN_DIR)

NUM_DAYS = 3

try:
    for day in range(NUM_DAYS):
        print(f'=== DAY {day} ===')
        out = trainer.run_one_day(day)

        # Human-readable quick summary
        print('games_today:', out['games_today'])
        for g in out['games_today']:
            print(' ', g,
                  'n_eps', out['n_episodes'][g],
                  'last_return', out['episode_return_last'][g],
                  'mean_return', out['episode_return_mean'][g],
                  'hbm_hit_rate', out['day_cache'][g]['hbm_hit_rate'],
                  'nvme_reads', out['day_cache'][g]['nvme_reads'])

        logger.log(day, out)
finally:
    # Always close the logger, even if a day crashes mid-run, so whatever was logged for
    # earlier days is not left in an unclosed writer.
    logger.close()
print('Done. Logs saved in:', RUN_DIR)


Config(frame_stack=4, reward_clip=True, num_experts=256, router_hidden_dim=128, expert_hidden_dim=256, feature_dim=512, expert_top_k=2, hbm_expert_capacity=4, dram_expert_capacity=16, enable_nvme_tier=True, pin_cpu_memory=True, gamma=0.99, learning_rate=0.0001, batch_size=16, seq_len=8, epsilon_start=1.0, epsilon_end=0.1, epsilon_decay_steps=50000, target_update_interval=1000, games_per_day=8, day_steps_per_game=4000, sleep_updates_per_game=400, salience_alpha=0.6, td_error_weight=1.0, policy_surprisal_weight=0.2, softmax_temp_for_surprisal=1.0, ewc_lambda=0.4, fisher_batches=25, top_experts_per_game=4, protect_encoder=False, protect_experts=True, log_every_sleep_steps=50)
=== DAY 0 ===
games_today: ['ALE/Seaquest-v5', 'ALE/BeamRider-v5', 'ALE/Breakout-v5', 'ALE/SpaceInvaders-v5', 'ALE/Pong-v5', 'ALE/Qbert-v5']
  ALE/Seaquest-v5 n_eps 6 last_return 2.0 mean_return 4.5 hbm_hit_rate 0.9998125 nvme_reads 3
  ALE/BeamRider-v5 n_eps 2 last_return 8.0 mean_return 11.5 hbm_hit_rate 1.0 nvme_r

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

In [None]:
# Re-mount Google Drive (e.g. after reconnecting to the runtime). Like the setup cell,
# this is best-effort: outside Colab `google.colab` is not importable, so skip instead
# of crashing a Restart & Run All.
try:
    from google.colab import drive  # type: ignore
    drive.mount('/content/drive')
except ImportError as e:
    print('Drive mount skipped / not in Colab:', e)

In [None]:
# Load the run's metrics.jsonl (presumably one record per logger.log(day, ...) call)
# into a DataFrame used by the plotting cells below.
import os
import matplotlib.pyplot as plt

from viz.metrics import load_jsonl, metrics_to_frame

metrics_path = os.path.join(RUN_DIR, 'metrics.jsonl')
records = load_jsonl(metrics_path)
df = metrics_to_frame(records)

print('df shape:', df.shape)
df.head()


In [None]:
# Plot episode return (mean) per game across days.
# Columns are like: episode_return_mean/ALE/Breakout-v5 ; the x-axis uses `_step`,
# assumed to be the day index passed to logger.log -- TODO confirm in metrics_to_frame.
import numpy as np

prefix = 'episode_return_mean/'
cols = [c for c in df.columns if c.startswith(prefix)]

fig, ax = plt.subplots()
for c in cols:
    # Slice off only the leading prefix (str.replace would strip it anywhere in the name).
    ax.plot(df['_step'], df[c], marker='o', label=c[len(prefix):])
ax.set_xlabel('day')
ax.set_ylabel('mean episode return (day)')
ax.set_title('Mean episode return per game')
if cols:
    # legend() with no labelled artists only emits a warning (e.g. an empty/crashed run).
    ax.legend()
plt.show()


In [None]:
# Plot expert-store behavior: HBM hit rate + NVMe reads.
# Columns look like: day_cache/<game>/<metric>, e.g. day_cache/ALE/Breakout-v5/hbm_hit_rate


def _plot_day_cache_metric(frame, metric, ylabel, title):
    """Plot ``day_cache/<game>/<metric>`` against ``_step`` for every game that has it.

    Args:
        frame: metrics DataFrame with a ``_step`` column and flattened ``day_cache/*`` columns.
        metric: metric suffix, e.g. ``'hbm_hit_rate'`` or ``'nvme_reads'``.
        ylabel: y-axis label.
        title: figure title.

    Returns:
        The list of column names that were plotted (possibly empty).
    """
    prefix, suffix = 'day_cache/', '/' + metric
    metric_cols = [c for c in frame.columns if c.startswith(prefix) and c.endswith(suffix)]

    fig, ax = plt.subplots()
    for c in metric_cols:
        # Label with just the game id between the prefix and the metric suffix.
        ax.plot(frame['_step'], frame[c], marker='o', label=c[len(prefix):-len(suffix)])
    ax.set_xlabel('day')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    if metric_cols:
        ax.legend()
    plt.show()
    return metric_cols


cache_cols = _plot_day_cache_metric(df, 'hbm_hit_rate', 'HBM hit rate (day)', 'Expert-store HBM hit rate per game')
read_cols = _plot_day_cache_metric(df, 'nvme_reads', 'NVMe reads (count, day)', 'Expert-store NVMe reads per game')


In [None]:
# TensorBoard (optional)
%load_ext tensorboard
%tensorboard --logdir {os.path.join(RUN_DIR, 'tb')}


In [None]:
# Sanity check (run after the trainer cell): show device placement, then run one forward
# pass through the expert store and print the HBM hit/miss counters it produced.
import torch  # no earlier cell imports torch; needed for from_numpy / no_grad below
from envs import make_atari_env

print("device:", trainer.device)
print("model expert[0] device:", next(trainer.model.experts[0].parameters()).device)
trainer.expert_store.reset_stats()  # clear counters accumulated before this check

# Force a single forward pass on a fresh Breakout observation.
probe_game = "ALE/Breakout-v5"
probe_env = make_atari_env(probe_game, seed=0, frame_stack=cfg.frame_stack, clip_rewards=True, full_action_space=True)
try:
    obs, _ = probe_env.reset()
    hidden = trainer.model.init_hidden(1, trainer.device)
    obs_t = torch.from_numpy(obs).unsqueeze(0).to(trainer.device)
    # Diagnostic only: no autograd graph is needed, so don't build (and keep alive) one.
    with torch.no_grad():
        _ = trainer.model(obs_t, hidden, expert_store=trainer.expert_store, top_k=cfg.expert_top_k)
finally:
    probe_env.close()

# reset_stats() returns an object exposing hbm_hits / hbm_misses / hit_rate();
# presumably these are the counters collected since the previous reset -- confirm.
stats = trainer.expert_store.reset_stats()
print("hbm_hits:", stats.hbm_hits, "hbm_misses:", stats.hbm_misses, "hit_rate:", stats.hit_rate())
