# EWC + MoE Router (RNN) Multi-Atari: Day/Night Training (Gymnasium)

This notebook is a **Colab-friendly scaffold** to test a hypothesis:

- You can act in a correlated **online** stream ("day") but perform most SGD updates during a shuffled **replay** phase ("sleep").
- Long-term retention is handled via **Elastic Weight Consolidation (EWC)** on shared parameters (router / trunk and optionally per-game experts).
- Capacity and out-of-core motivation: a **shared recurrent router** gates a set of **experts** (MoE), and the experts can later be swapped across HBM/DRAM/NVMe.
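
The router/expert split described above can be illustrated with a small plain-Python sketch. It assumes an Elman-style recurrence (`h' = tanh(W_h h + W_x x)`) that emits one logit per expert, followed by top-k softmax gating, so only k experts run per step. That sparsity is what makes it plausible to keep unused experts in slower memory. The names `router_step`, `top_k_gate` and `mix_experts`, the recurrence and the top-k rule are all hypothetical choices for illustration, not the API of this repository's `src/` package.

```python
import math

def matvec(w, v):
    """Dense matrix-vector product on plain lists (row-major `w`)."""
    return [sum(a * b for a, b in zip(row, v)) for row in w]

def router_step(h, x, w_h, w_x, w_out):
    """One step of a toy (hypothetical) Elman-style recurrent router.

    h_new = tanh(W_h @ h + W_x @ x); logits = W_out @ h_new, one logit per expert.
    """
    pre = [a + b for a, b in zip(matvec(w_h, h), matvec(w_x, x))]
    h_new = [math.tanh(p) for p in pre]
    return h_new, matvec(w_out, h_new)

def top_k_gate(logits, k):
    """Keep the k largest logits and softmax over them only (weights sum to 1)."""
    if not 1 <= k <= len(logits):
        raise ValueError("k must be in [1, num_experts]")
    # Ties are broken in favor of the lower expert index.
    top = sorted(range(len(logits)), key=lambda i: (-logits[i], i))[:k]
    m = max(logits[i] for i in top)  # subtract the max for numerical stability
    exps = {i: math.exp(logits[i] - m) for i in top}
    z = sum(exps.values())
    return {i: e / z for i, e in exps.items()}

def mix_experts(features, experts, gate):
    """Evaluate only the gated experts and return the weighted sum of their outputs."""
    out = None
    for i, w in gate.items():
        y = experts[i](features)
        out = [w * v for v in y] if out is None else [o + w * v for o, v in zip(out, y)]
    return out

# Toy usage: 2-dim state/input, 4 "experts" that just scale their input.
w_h = [[0.5, 0.0], [0.0, 0.5]]
w_x = [[1.0, 0.0], [0.0, 1.0]]
w_out = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]]
experts = [lambda f, s=s: [s * v for v in f] for s in (1.0, 2.0, 3.0, 4.0)]

h = [0.0, 0.0]  # router state carried across time steps
for x in ([1.0, 0.0], [0.0, 1.0]):
    h, logits = router_step(h, x, w_h, w_x, w_out)
    gate = top_k_gate(logits, k=2)  # only 2 of the 4 experts are needed this step
    y = mix_experts(x, experts, gate)
    print(sorted(gate), [round(v, 3) for v in y])
```

Because the router state `h` is carried across steps, expert choice can depend on recent history (for example, which game is being played), not just the current frame.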

**Practical constraints handled here:**
- Gymnasium (modern API)
- Atari preprocessing via `AtariPreprocessing`
- Frame stacking via `FrameStackObservation` (`FrameStack` was renamed to `FrameStackObservation` in Gymnasium v1.0)
- Unified action space using `full_action_space=True` (Discrete(18) across games)
- Reward clipping to `[-1, 1]`
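
Conceptually, frame stacking and reward clipping reduce to the small sketch below. It is an illustration of the idea, not Gymnasium's implementation. `FrameStacker` and `clip_reward` are hypothetical names, and padding the stack at reset by repeating the first frame is an assumption about the chosen configuration. The notebook itself relies on the real `AtariPreprocessing` and `FrameStackObservation` wrappers listed above.

```python
from collections import deque

def clip_reward(r, low=-1.0, high=1.0):
    """Clamp a scalar reward into [low, high]; e.g. +30 -> 1.0, -1 stays -1.0."""
    return max(low, min(high, float(r)))

class FrameStacker:
    """Keep the most recent `k` frames as one observation (hypothetical helper).

    On reset the first frame is repeated `k` times so the stack is always full;
    each step appends the newest frame and the deque drops the oldest one.
    """

    def __init__(self, k=4):
        self.k = k
        self.frames = deque(maxlen=k)

    def reset(self, frame):
        self.frames.clear()
        self.frames.extend([frame] * self.k)
        return tuple(self.frames)

    def step(self, frame):
        self.frames.append(frame)
        return tuple(self.frames)

# Frames are strings here; with AtariPreprocessing's default 84x84 grayscale
# output they would be arrays, giving a stacked observation of shape (4, 84, 84).
stacker = FrameStacker(k=4)
obs = stacker.reset("f0")
for t in range(1, 6):
    obs = stacker.step(f"f{t}")
print(obs, [clip_reward(r) for r in (0, 3, -7, 0.5)])
```

Clipping keeps reward scales comparable across games (Pong, Breakout and Space Invaders award very different point values), which matters when one network is trained on all of them.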

> This is a *smoke-test* configuration by default (a few thousand environment steps). Atari at scale is expensive.
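
The day/night split can be made concrete with a minimal sleep-phase sampler, assuming each game's day experience is logged in order and replayed as fixed-length windows. The function `sample_sleep_batch` and the uniform sampling over games and start offsets are hypothetical, not necessarily what `DayNightTrainer` does; `batch_size=16` and `seq_len=8` simply mirror the `Config` used later in this notebook.

```python
import random

def sample_sleep_batch(day_log, batch_size, seq_len, rng):
    """Draw `batch_size` random windows of `seq_len` consecutive transitions.

    `day_log` maps a game id to the transitions recorded in order during the day.
    Games and start offsets are drawn uniformly at random, so one batch mixes games
    and breaks the temporal correlation of the online stream, while every window
    stays contiguous in time (which a recurrent router needs).
    """
    eligible = [g for g, steps in day_log.items() if len(steps) >= seq_len]
    if not eligible:
        raise ValueError("no game has at least seq_len recorded steps")
    batch = []
    for _ in range(batch_size):
        game = rng.choice(eligible)
        steps = day_log[game]
        start = rng.randrange(len(steps) - seq_len + 1)
        batch.append((game, steps[start:start + seq_len]))
    return batch

rng = random.Random(0)
# "Day": integers stand in for transitions, appended in the order they were experienced.
day_log = {"ALE/Pong-v5": list(range(0, 100)), "ALE/Breakout-v5": list(range(1000, 1050))}

# "Sleep": many updates, each on a freshly shuffled mini-batch of short sequences.
for _ in range(3):
    batch = sample_sleep_batch(day_log, batch_size=16, seq_len=8, rng=rng)
    # A real trainer would compute the RL loss (plus any EWC penalty) on `batch` here.
print(sorted({game for game, _ in batch}))
```

This sketch ignores episode ends; a real replay buffer would need to reject or mask windows that cross an episode boundary.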


In [None]:
# Install deps (Colab)
!pip -q install -U 'gymnasium[atari]' ale-py opencv-python 'torch>=2.1.0' 'torchvision>=0.16.0'

import gymnasium
print('Gymnasium version:', gymnasium.__version__)


In [None]:
# Add this repository's src/ directory to sys.path so its modules can be imported
import os, sys

# If you uploaded the zip to Colab and unzipped it, set PROJECT_DIR accordingly.
# For example:
#   PROJECT_DIR = '/content/ewc_moe_atari_colab'
# Here we assume the notebook is in the project root.
PROJECT_DIR = os.getcwd()
SRC_DIR = os.path.join(PROJECT_DIR, 'src')

# Check if we are potentially in the wrong directory
if not os.path.exists(SRC_DIR):
    print(f"Warning: 'src' directory not found at {SRC_DIR}.")
    print("Please ensure you are running this notebook from the project root or set PROJECT_DIR correctly.")
else:
    if SRC_DIR not in sys.path:
        # Prepend so modules in src/ (e.g. config) take precedence over same-named installed packages.
        sys.path.insert(0, SRC_DIR)
    print('PROJECT_DIR:', PROJECT_DIR)
    print('SRC_DIR:', SRC_DIR)
    print('Added src to sys.path')


In [None]:
# Quick env smoke test
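import gymnasium as gym
import ale_py

# With Gymnasium >= 1.0, the ALE/ environment ids are registered when ale_py is imported;
# register_envs makes that dependency explicit (harmless if src/ already imports ale_py).
gym.register_envs(ale_py)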

from envs.atari import make_atari_env

env_id = 'ALE/Pong-v5'

env = make_atari_env(env_id, seed=0, frame_stack=4, clip_rewards=True, full_action_space=True)
obs, info = env.reset()
print('obs shape:', obs.shape, 'dtype:', obs.dtype)
print('action space:', env.action_space)

obs, r, terminated, truncated, info = env.step(env.action_space.sample())
print('step -> reward:', r, 'done:', terminated or truncated)

env.close()


In [None]:
# Run a small Day/Night experiment on multiple Atari games
import torch

from config import Config
from training.day_night import DayNightTrainer

GAMES = [
    'ALE/Pong-v5',
    'ALE/Breakout-v5',
    'ALE/SpaceInvaders-v5',
]

cfg = Config(
    num_experts=8,
    games_per_day=2,
    day_steps_per_game=1500,          # keep small for smoke test
    sleep_updates_per_game=100,       # keep small for smoke test
    batch_size=16,
    seq_len=8,
    ewc_lambda=0.2,
    top_experts_per_game=3,
    protect_experts=True,
    protect_encoder=False,
)

trainer = DayNightTrainer(GAMES, cfg, seed=0)

for day in range(2):
    result = trainer.run_one_day(day)
    print('\n=== DAY', day, '===')
    print('games_today:', result['games_today'])
    print('cache_hit_rate:', result['cache_hit_rate'])
    print('flagged_experts:', result['flagged_experts'])
    print('episode_returns (per-game):')
    for g, rets in result['episode_returns'].items():
        print(' ', g, 'n_eps', len(rets), 'last_return', (rets[-1] if rets else None))
    print('sleep:', result['sleep'])
    print('ewc_tasks:', result['ewc_tasks'])

print('\nDone.')
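
The `ewc_lambda`, `protect_experts` and `top_experts_per_game` settings above point to the standard EWC penalty `loss = task_loss + (λ/2) · Σ_i F_i (θ_i − θ*_i)²`, where `F` is a diagonal Fisher estimate and `θ*` are the parameters saved when a task was consolidated. Below is a minimal plain-Python sketch, assuming one `(anchor, fisher)` pair per consolidated task and an optional set of protected parameter names. The parameter names (`router`, `expert_3`) and helper functions are hypothetical; `DayNightTrainer` may scale λ differently (λ vs λ/2) and will operate on PyTorch tensors rather than lists.

```python
def diagonal_fisher(per_sample_grads):
    """Diagonal Fisher estimate: mean of squared per-sample gradients.

    `per_sample_grads` is a list of dicts {param_name: [g_1, ..., g_n]}, one per sample.
    """
    n = len(per_sample_grads)
    fisher = {}
    for grads in per_sample_grads:
        for name, g in grads.items():
            acc = fisher.setdefault(name, [0.0] * len(g))
            for j, gj in enumerate(g):
                acc[j] += gj * gj / n
    return fisher

def ewc_penalty(params, anchor, fisher, lam, protected=None):
    """(lam / 2) * sum_i F_i * (theta_i - theta*_i)^2 for one consolidated task.

    `anchor` holds the parameter values saved when the task was consolidated.
    If `protected` is given, only parameter names in it are penalized.
    """
    total = 0.0
    for name, theta in params.items():
        if name not in fisher or (protected is not None and name not in protected):
            continue
        for t, a, f in zip(theta, anchor[name], fisher[name]):
            total += f * (t - a) ** 2
    return 0.5 * lam * total

def total_ewc_penalty(params, tasks, lam, protected=None):
    """Sum the penalty over every consolidated task, given as (anchor, fisher) pairs."""
    return sum(ewc_penalty(params, a, f, lam, protected) for a, f in tasks)

grads = [{"router": [1.0, 0.0], "expert_3": [2.0]},
         {"router": [3.0, 2.0], "expert_3": [0.0]}]
fisher = diagonal_fisher(grads)               # router: [5.0, 2.0], expert_3: [2.0]
anchor = {"router": [0.0, 0.0], "expert_3": [1.0]}
params = {"router": [0.5, -1.0], "expert_3": [2.0]}

full = ewc_penalty(params, anchor, fisher, lam=0.2)
router_only = ewc_penalty(params, anchor, fisher, lam=0.2, protected={"router"})
print(round(full, 6), round(router_only, 6))
```

Restricting `protected` to the router plus each game's most-used experts keeps rarely used experts free to adapt, which is one plausible reading of `protect_experts=True` with `top_experts_per_game=3`.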
