# Out-of-Core MOE for Multi-Task RL

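Training notebook for the `ooc_moe` project. It clones the repository from GitHub (or pulls the latest changes if a clone already exists) so that code edits can be picked up and re-run quickly.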

In [None]:
# Install dependencies
!pip install -q torch numpy gymnasium opencv-python matplotlib

In [None]:
# Clone repo from GitHub (edit REPO_URL to match your repository)
import os
import subprocess

REPO_URL = "https://github.com/RespectableGlioma/ooc_moe.git"  # <-- Change this!
BRANCH = "main"  # Change if using a different branch


def _run_git(*args):
    """Run `git <args>`, echo its output, and raise RuntimeError on failure.

    Arguments are passed as a list (no shell), so REPO_URL / BRANCH are never
    interpreted by a shell.
    """
    proc = subprocess.run(['git', *args], capture_output=True, text=True)
    # Echo the captured streams so git's messages appear in the cell output
    for stream in (proc.stdout, proc.stderr):
        if stream and stream.strip():
            print(stream.strip())
    if proc.returncode != 0:
        raise RuntimeError(
            f"'git {' '.join(args)}' failed with exit code {proc.returncode}"
        )


if not os.path.exists('ooc_moe'):
    print(f'Cloning {REPO_URL}...')
    _run_git('clone', '--branch', BRANCH, '--depth', '1', REPO_URL)
    # Only reached when the clone succeeded
    print('Done!')
else:
    print('Repo exists, pulling latest changes...')
    _run_git('-C', 'ooc_moe', 'pull')
    # Only reached when the pull succeeded
    print('Updated!')

Cloning https://github.com/YOUR_USERNAME/ooc_moe.git...
Cloning into 'ooc_moe'...
fatal: could not read Username for 'https://github.com': No such device or address
Done!


: 

In [None]:
# Setup imports
import sys
import os

# Add repo root to path (where ooc_moe package lives)
REPO_PATH = '/content/ooc_moe'
# Use the Colab clone location when it exists, otherwise fall back to the
# current directory for local dev
_import_root = REPO_PATH if os.path.exists(REPO_PATH) else '.'
# Drop any earlier copies first so re-running this cell keeps exactly one
# entry, still at the front of sys.path
sys.path[:] = [entry for entry in sys.path if entry != _import_root]
sys.path.insert(0, _import_root)

import torch
import torch.nn.functional as F
import numpy as np
from collections import defaultdict, deque
import matplotlib.pyplot as plt
import time

from ooc_moe.models.moe_agent import MoERLAgent, MoERLAgentConfig
from ooc_moe.envs.atari_wrappers import create_dummy_envs
from ooc_moe.core.env_detector import EnvironmentDetectorTrainer

# Pick the compute device: CUDA when PyTorch reports it available, else CPU
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

print(f'PyTorch {torch.__version__} on {device}')
if device == 'cuda':
    # Report the name of CUDA device 0
    print(f'GPU: {torch.cuda.get_device_name(0)}')

: 

In [None]:
# Configuration
# Single dictionary of tunable values read by the cells below.
CONFIG = dict(
    num_experts=64,
    expert_dim=128,
    expert_hidden_dim=256,
    num_layers=2,
    num_heads=4,
    top_k=2,
    context_len=8,
    hbm_capacity=16,
    dram_capacity=32,
    num_games=5,
    steps_per_game=1000,
    lr=1e-4,
    # Specialization parameters
    aux_loss_weight=0.001,   # Reduced from 0.01 to allow specialization
    detector_lr=1e-3,        # Learning rate for env detector
    detector_train_freq=10,  # Train detector every N steps
)
print('Config:', CONFIG)

: 

In [None]:
# Create agent, environments, and detector trainer

# CONFIG entries that are forwarded to MoERLAgentConfig under the same name
_FORWARDED_KEYS = (
    'num_experts',
    'expert_dim',
    'expert_hidden_dim',
    'num_layers',
    'num_heads',
    'top_k',
    'context_len',
    'hbm_capacity',
    'dram_capacity',
)

config = MoERLAgentConfig(
    obs_shape=(1, 84, 84),
    num_actions=18,
    num_envs=CONFIG['num_games'],
    frame_stack=4,
    **{key: CONFIG[key] for key in _FORWARDED_KEYS},
)

agent = config.create_agent(device)
envs, env_names = create_dummy_envs(CONFIG['num_games'])
optimizer = torch.optim.Adam(agent.parameters(), lr=CONFIG['lr'])

# Create detector trainer for learning game->expert mappings
detector_trainer = EnvironmentDetectorTrainer(
    detector=agent.env_detector,
    lr=CONFIG['detector_lr'],
    env_loss_weight=1.0,
    expert_loss_weight=1.0,
)

total_params = config.estimate_parameter_count()["total"]
detector_freq = CONFIG["detector_train_freq"]
print(f'Agent created with {total_params:,} params')
print(f'Environments: {env_names}')
print(f'Detector trainer ready (trains every {detector_freq} steps)')

: 

In [None]:
# Training loop with detector training for specialization
#
# For every game in turn: run CONFIG['steps_per_game'] environment steps,
# apply one policy-gradient update per step, periodically train the
# environment detector, and record metrics in `results`.
results = {
    'cache_history': [],
    'expert_usage': defaultdict(lambda: defaultdict(int)),
    'rewards': defaultdict(list),
    'losses': [],
    'detector_losses': [],
    'prefetch_accuracy': [],
}


def _fresh_obs_buffer(first_obs):
    """Build the rolling context buffer for the start of an episode.

    The buffer holds `config.context_len` entries, each a tensor made from
    `first_obs` converted to float32 and divided by 255.
    """
    return deque(
        [torch.from_numpy(first_obs.astype(np.float32) / 255.0) for _ in range(config.context_len)],
        maxlen=config.context_len
    )


start_time = time.time()
global_step = 0

# try/finally: the expert store is shut down even when training raises or is
# interrupted part-way through (e.g. the cell is stopped by hand).
try:
    for game_id, (env, name) in enumerate(zip(envs, env_names)):
        print(f'\nTraining on {name} ({game_id+1}/{len(envs)})...')

        obs, _ = env.reset()
        obs_buffer = _fresh_obs_buffer(obs)

        episode_reward = 0
        episode_count = 0

        for step in range(CONFIG['steps_per_game']):
            global_step += 1

            # Build context
            context = torch.stack(list(obs_buffer), dim=0).unsqueeze(0).to(device)

            # Forward pass
            agent.train()
            output = agent(context, env_id=game_id, prefetch=True)

            # Track metrics
            for eid in output.expert_ids:
                results['expert_usage'][game_id][eid] += 1
            results['cache_history'].append(output.cache_stats['hit_rate'])

            # Store sample for detector training (most recent frame of the context)
            last_obs = context[:, -1].detach()
            detector_trainer.store_sample(
                obs=last_obs.squeeze(0),  # drop the leading batch dimension
                env_id=game_id,
                accessed_experts=output.expert_ids,
            )

            # Train detector periodically
            if global_step % CONFIG['detector_train_freq'] == 0:
                det_losses = detector_trainer.train_step(batch_size=32)
                results['detector_losses'].append(det_losses['total_loss'])

                # Track prefetch accuracy occasionally
                if global_step % (CONFIG['detector_train_freq'] * 10) == 0:
                    accuracy = detector_trainer.compute_prefetch_accuracy(recent_n=50)
                    results['prefetch_accuracy'].append(accuracy['f1'])

            # Sample action
            probs = F.softmax(output.action_logits, dim=-1)
            action = torch.multinomial(probs, 1).item()

            # Environment step. The 5-tuple is
            # (obs, reward, terminated, truncated, info); an episode is over
            # when it is terminated OR truncated.
            next_obs, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            episode_reward += reward

            # Policy gradient update with REDUCED aux loss for specialization
            log_prob = torch.log(probs[0, action] + 1e-8)
            loss = -log_prob * reward + CONFIG['aux_loss_weight'] * output.aux_loss

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(agent.parameters(), 0.5)
            optimizer.step()

            results['losses'].append(loss.item())

            # Update observation buffer
            obs_buffer.append(torch.from_numpy(next_obs.astype(np.float32) / 255.0))

            if done:
                results['rewards'][game_id].append(episode_reward)
                episode_reward = 0
                episode_count += 1
                # Start the next episode from a buffer filled with its first frame
                obs, _ = env.reset()
                obs_buffer = _fresh_obs_buffer(obs)

        # Log progress
        recent_cache = results['cache_history'][-CONFIG['steps_per_game']:]
        avg_reward = np.mean(results['rewards'][game_id]) if results['rewards'][game_id] else 0

        # Compute specialization metric for this game:
        # share of all expert accesses that went to the single most-used expert
        usage = results['expert_usage'][game_id]
        total_usage = sum(usage.values())
        top_expert_pct = max(usage.values()) / total_usage if total_usage > 0 else 0

        print(f'  Episodes: {episode_count}, Avg reward: {avg_reward:.2f}')
        print(f'  Cache hit rate: {np.mean(recent_cache):.2%}')
        print(f'  Unique experts: {len(usage)}, Top expert: {top_expert_pct:.1%}')
        if results['prefetch_accuracy']:
            print(f'  Detector prefetch F1: {results["prefetch_accuracy"][-1]:.2%}')

    elapsed = time.time() - start_time
    print(f'\nTraining complete in {elapsed:.1f}s')
finally:
    # Cleanup
    agent.expert_store.shutdown()

: 

In [None]:
# Results visualization
fig, axes = plt.subplots(2, 3, figsize=(18, 10))

# Cache hit rate
ax = axes[0, 0]
cache_history = results['cache_history']
ax.plot(cache_history, alpha=0.3, label='Raw')
window = 100
# The moving average only lines up with range(window-1, len) when there are at
# least `window` samples; with fewer, skip the smoothed curve instead of
# failing on an x/y length mismatch.
if len(cache_history) >= window:
    smoothed = np.convolve(cache_history, np.ones(window)/window, mode='valid')
    ax.plot(range(window-1, len(cache_history)), smoothed, 'r-', lw=2, label=f'Smoothed (w={window})')
# The mean is undefined for an empty history, so only draw it when data exists
if cache_history:
    mean_hit_rate = np.mean(cache_history)
    ax.axhline(mean_hit_rate, color='g', linestyle='--', label=f'Mean: {mean_hit_rate:.2%}')
# Dotted vertical lines mark the step at which training switched games
for i in range(1, CONFIG['num_games']):
    ax.axvline(i * CONFIG['steps_per_game'], color='gray', linestyle=':', alpha=0.5)
ax.set_xlabel('Step')
ax.set_ylabel('Cache Hit Rate')
ax.set_title('Cache Performance Over Training')
ax.legend()

# Expert usage heatmap (should show distinct bands with specialization)
# Rows are games, columns are experts; each row is normalised to sum to ~1.
ax = axes[0, 1]
usage_matrix = np.zeros((CONFIG['num_games'], CONFIG['num_experts']))
for gid, usage in results['expert_usage'].items():
    for eid, count in usage.items():
        # Ignore expert ids that fall outside the configured range
        if not eid < CONFIG['num_experts']:
            continue
        usage_matrix[gid, eid] = count
row_totals = usage_matrix.sum(axis=1, keepdims=True)
usage_matrix = usage_matrix / (row_totals + 1e-8)  # epsilon avoids 0/0 on unused rows
im = ax.imshow(usage_matrix, aspect='auto', cmap='hot')
ax.set(
    xlabel='Expert ID',
    ylabel='Game ID',
    title='Expert Usage by Game (Should Show Distinct Bands)',
)
plt.colorbar(im, ax=ax)

# Episode rewards
# One line per game that completed at least one episode.
ax = axes[0, 2]
for gid, rewards in results['rewards'].items():
    if not rewards:
        continue
    ax.plot(rewards, label=f'Game {gid}', alpha=0.7)
ax.set(xlabel='Episode', ylabel='Reward', title='Episode Rewards')
ax.legend()

# Expert overlap matrix
# Select the bottom-left panel; the overlap heatmap below is drawn on it.
ax = axes[1, 0]
def get_top_k(usage_dict, k=10):
    """Return the set of the `k` keys with the largest values in `usage_dict`.

    Keys are ranked by value in descending order; ties keep the mapping's
    iteration order. Fewer than `k` keys are returned when the mapping is
    smaller than `k`.
    """
    ranked_keys = sorted(usage_dict, key=lambda key: usage_dict[key], reverse=True)
    return set(ranked_keys[:k])

# Number of most-used experts compared per game. Used for the ranking, the
# normalisation, and the panel title so the three cannot drift apart.
OVERLAP_TOP_K = 10

# Rank each game's experts once up front instead of re-sorting inside the
# pairwise loop below.
top_sets = [
    get_top_k(results['expert_usage'][game], k=OVERLAP_TOP_K)
    for game in range(CONFIG['num_games'])
]

# overlap_matrix[i, j] = (shared top experts of games i and j) / OVERLAP_TOP_K
overlap_matrix = np.zeros((CONFIG['num_games'], CONFIG['num_games']))
for i, top_i in enumerate(top_sets):
    for j, top_j in enumerate(top_sets):
        overlap_matrix[i, j] = len(top_i & top_j) / OVERLAP_TOP_K

im = ax.imshow(overlap_matrix, cmap='Blues', vmin=0, vmax=1)
ax.set_xlabel('Game')
ax.set_ylabel('Game')
ax.set_title(f'Expert Overlap (Top {OVERLAP_TOP_K}) - Lower = Better Specialization')
# Annotate every cell with its overlap percentage
for i in range(CONFIG['num_games']):
    for j in range(CONFIG['num_games']):
        ax.text(j, i, f'{overlap_matrix[i,j]:.0%}', ha='center', va='center')
plt.colorbar(im, ax=ax)

# Detector loss
# Raw loss curve, plus a 50-point moving average once more than 50 values exist.
ax = axes[1, 1]
detector_history = results['detector_losses']
if detector_history:
    ax.plot(detector_history, alpha=0.7)
    if len(detector_history) > 50:
        window = 50
        box_filter = np.ones(window) / window
        smoothed = np.convolve(detector_history, box_filter, mode='valid')
        smoothed_x = range(window - 1, len(detector_history))
        ax.plot(smoothed_x, smoothed, 'r-', lw=2, label='Smoothed')
        ax.legend()
ax.set(
    xlabel='Training Step (detector)',
    ylabel='Loss',
    title='Detector Training Loss',
)

# Prefetch accuracy
# F1 value recorded at each accuracy checkpoint, with the mean as a dashed line.
ax = axes[1, 2]
f1_history = results['prefetch_accuracy']
if f1_history:
    f1_mean = np.mean(f1_history)
    ax.plot(f1_history, 'b-o', markersize=4)
    ax.axhline(f1_mean, color='r', linestyle='--', label=f'Mean: {f1_mean:.2%}')
    ax.legend()
ax.set(
    xlabel='Checkpoint',
    ylabel='Prefetch F1 Score',
    title='Detector Prefetch Accuracy',
)
ax.set_ylim(0, 1)

# Lay out the panels, write the figure to disk, then display it
plt.tight_layout()
plt.savefig('training_results.png', dpi=150)
plt.show()

: 

In [None]:
# Summary statistics
banner = '=' * 60
print(banner)
print('TRAINING SUMMARY')
print(banner)

# Cache hit rate over the whole run and over the last 500 recorded steps
hit_history = results['cache_history']
overall_hit = np.mean(hit_history)
final_hit = np.mean(hit_history[-500:])
print(f'\nOverall cache hit rate: {overall_hit:.2%}')
print(f'Final cache hit rate: {final_hit:.2%}')

# Detector performance
f1_scores = results['prefetch_accuracy']
if f1_scores:
    print(f'\nDetector prefetch F1 (final): {f1_scores[-1]:.2%}')
    print(f'Detector prefetch F1 (mean): {np.mean(f1_scores):.2%}')

print('\nPer-game statistics:')
for gid in range(CONFIG['num_games']):
    usage = results['expert_usage'][gid]
    rewards = results['rewards'][gid]
    total = sum(usage.values())
    # Five most-used experts, each with its share of this game's accesses
    top5 = sorted(usage.items(), key=lambda item: item[1], reverse=True)[:5]
    top5_shares = [(e, f"{c/total:.1%}") for e, c in top5]
    mean_reward = np.mean(rewards) if rewards else 0
    print(f'\n  Game {gid}:')
    print(f'    Unique experts: {len(usage)}')
    print(f'    Top 5 experts: {top5_shares}')
    print(f'    Avg episode reward: {mean_reward:.2f}')

print('\nExpert specialization (low overlap = good specialization):')
# Rank each game's experts once, rather than re-sorting for every pair
game_top_sets = [
    get_top_k(results['expert_usage'][game])
    for game in range(CONFIG['num_games'])
]
# Count shared top experts for every unordered pair of distinct games
overlaps = []
for i in range(CONFIG['num_games']):
    for j in range(i+1, CONFIG['num_games']):
        overlap = len(game_top_sets[i] & game_top_sets[j])
        overlaps.append(overlap)
        print(f'  Game {i} vs {j}: {overlap}/10 experts overlap')
# With fewer than two games there are no pairs, and np.mean([]) would yield
# nan plus a RuntimeWarning; report that explicitly instead.
if overlaps:
    print(f'  Average overlap: {np.mean(overlaps):.1f}/10')
else:
    print('  Average overlap: n/a (needs at least two games)')

# Expected partitions (based on initialization)
# Experts are split into equal contiguous ranges, one per game; the last game
# also receives any remainder left by the integer division.
print('\n' + '=' * 60)
print('EXPECTED EXPERT PARTITIONS (from env_expert_bias initialization)')
print('=' * 60)
experts_per_env = CONFIG['num_experts'] // CONFIG['num_games']
for gid in range(CONFIG['num_games']):
    start = gid * experts_per_env
    if gid < CONFIG['num_games'] - 1:
        end = start + experts_per_env
    else:
        end = CONFIG['num_experts']
    print(f'  Game {gid}: Experts {start}-{end-1}')
print('\nIf specialization is working, each game should predominantly use its partition!')

: 

In [None]:
# Save results
import json


def _json_default(value):
    """Fallback for json.dump: convert objects exposing .tolist() (e.g. NumPy
    scalars/arrays) to built-in Python values; reject anything else."""
    if hasattr(value, 'tolist'):
        return value.tolist()
    raise TypeError(f'Object of type {type(value).__name__} is not JSON serializable')


save_data = {
    'config': CONFIG,
    'cache_history': results['cache_history'],
    # JSON object keys must be built-in types, so coerce game and expert ids to int
    'expert_usage': {
        int(k): {int(e): c for e, c in v.items()}
        for k, v in results['expert_usage'].items()
    },
    'rewards': {int(k): list(v) for k, v in results['rewards'].items()},
    'summary': {
        'mean_cache_hit': float(np.mean(results['cache_history'])),
        'final_cache_hit': float(np.mean(results['cache_history'][-500:])),
        # No game pairs -> no overlap figure; store null rather than NaN
        'avg_overlap': float(np.mean(overlaps)) if overlaps else None,
    }
}

# Write the file before touching anything Colab-specific so results are
# persisted even when the download step is unavailable
with open('results.json', 'w') as f:
    json.dump(save_data, f, indent=2, default=_json_default)

print('Results saved to results.json')

# google.colab is only importable on Colab; elsewhere the files simply stay
# in the working directory
try:
    from google.colab import files
except ImportError:
    files = None
    print('google.colab not available; skipping browser download')

if files is not None:
    files.download('results.json')
    files.download('training_results.png')

: 

: 