# Out-of-Core MOE for Multi-Task RL

This notebook implements a **Mixture of Experts** architecture with **tiered memory hierarchy** for sequential multi-task Atari learning.

## Core Hypothesis

Sequential training normally causes catastrophic forgetting. The hypothesis is that MoE combined with environment-aware routing turns this into an advantage:

1. Different games activate different experts - natural task separation
2. Game identity is temporally correlated - perfect for caching
3. Sequential presentation becomes a feature - cold experts stay protected

## 1. Setup

In [None]:
!pip install -q torch numpy
!pip install -q gymnasium opencv-python matplotlib

In [None]:
import torch
# Report the installed PyTorch version and whether a CUDA device is usable.
cuda_available = torch.cuda.is_available()
print(f'PyTorch: {torch.__version__}')
print(f'CUDA: {cuda_available}')
if cuda_available:
    # Only query the device name when a GPU is actually present.
    print(f'GPU: {torch.cuda.get_device_name(0)}')

In [None]:
import os
# Lay out the ooc_moe package skeleton that the %%writefile cells fill in.
package_dirs = ['ooc_moe', 'ooc_moe/core', 'ooc_moe/models', 'ooc_moe/envs']

# Creating the sub-packages also creates the 'ooc_moe' root directory.
for pkg_dir in package_dirs[1:]:
    os.makedirs(pkg_dir, exist_ok=True)

# Write an empty __init__.py in every directory so each is importable.
for pkg_dir in package_dirs:
    with open(f'{pkg_dir}/__init__.py', 'w'):
        pass

print('Project structure created')

## 2. Core Components

In [None]:
%%writefile ooc_moe/core/tiered_store.py
import torch
from torch import Tensor
from typing import Dict, List, Optional, Tuple, Set
from collections import OrderedDict
from dataclasses import dataclass, field
from threading import Lock, Thread
from queue import Queue, Empty
import time
from enum import Enum

class StorageTier(Enum):
    """Identifies which storage tier served an expert lookup.

    Members are used as the keys of ``CacheStats.tier_accesses``.
    """
    # NOTE(review): names suggest GPU memory / host memory / disk respectively;
    # in this file all three tiers are in-process data structures.
    HBM = 'hbm'
    DRAM = 'dram'
    NVME = 'nvme'

@dataclass
class CacheStats:
    """Running counters describing how expert lookups were served."""
    hits: int = 0
    misses: int = 0
    prefetch_hits: int = 0
    evictions: int = 0
    # Per-tier access counts; starts empty and is filled in by the store.
    tier_accesses: Dict[StorageTier, int] = field(default_factory=dict)

    @property
    def hit_rate(self) -> float:
        """Return hits / (hits + misses), or 0.0 before any lookup happened."""
        lookups = self.hits + self.misses
        if lookups <= 0:
            return 0.0
        return self.hits / lookups

class LRUCache:
    """Lock-protected least-recently-used cache that remembers prefetched keys.

    Entries live in an ``OrderedDict`` ordered from least to most recently
    used. Keys inserted with ``is_prefetch=True`` are tracked in
    ``self.prefetched`` until they are read once or evicted, so callers can
    tell whether a hit was made possible by a prefetch.
    """

    def __init__(self, capacity: int):
        """Create an empty cache holding at most ``capacity`` entries."""
        self.capacity = capacity
        self.cache = OrderedDict()
        self.lock = Lock()
        self.prefetched = set()

    def get(self, key: int):
        """Look up ``key`` and mark it as most recently used.

        Returns:
            ``None`` when the key is absent, otherwise ``(value, was_prefetched)``.
            The prefetched flag is cleared by the read, so it is reported once.
        """
        with self.lock:
            if key not in self.cache:
                return None
            self.cache.move_to_end(key)
            was_pf = key in self.prefetched
            self.prefetched.discard(key)
            return self.cache[key], was_pf

    def put(self, key: int, value, is_prefetch: bool = False):
        """Insert or overwrite ``key`` and return the evicted key, if any.

        A cache whose capacity is zero (or negative) stores nothing and
        returns ``None``; previously this raised ``KeyError`` because the
        eviction step popped from an empty ``OrderedDict``.
        """
        with self.lock:
            if self.capacity <= 0:
                return None
            evicted = None
            if key in self.cache:
                self.cache.move_to_end(key)
            elif len(self.cache) >= self.capacity:
                # Drop the least recently used entry to make room.
                evicted, _ = self.cache.popitem(last=False)
                self.prefetched.discard(evicted)
            self.cache[key] = value
            if is_prefetch:
                self.prefetched.add(key)
            return evicted

    def __contains__(self, key):
        """Membership test that does not affect LRU order or prefetch flags."""
        with self.lock:
            return key in self.cache

    def __len__(self):
        """Number of entries currently cached."""
        with self.lock:
            return len(self.cache)

class TieredExpertStore:
    """Expert-weight store with a simulated HBM / DRAM / NVMe hierarchy.

    All experts are created up front as CPU tensors in ``nvme_store`` (a plain
    dict standing in for cold storage). ``hbm_cache`` holds copies moved to
    ``device``; ``dram_cache`` holds CPU-side entries for experts that were
    evicted from HBM. Transfer latency is emulated with ``time.sleep`` when
    ``simulate_latency`` is true. A daemon thread services prefetch requests.

    Accounting: demand lookups (``get_experts``) update hits, misses,
    prefetch_hits and tier_accesses. Prefetch loads only affect ``evictions``
    and the later ``prefetch_hits`` of the demand lookups they helped.

    Args:
        num_experts: Number of experts to create.
        expert_dim: Input/output width of each expert.
        hidden_dim: Hidden width of each expert.
        hbm_capacity: Maximum number of experts resident in ``hbm_cache``.
        dram_capacity: Maximum number of experts resident in ``dram_cache``.
        device: Device the HBM copies are moved to.
        simulate_latency: Whether to sleep to emulate DRAM/NVMe access cost.
    """

    def __init__(self, num_experts, expert_dim, hidden_dim, hbm_capacity=32,
                 dram_capacity=128, device='cuda', simulate_latency=True):
        self.num_experts = num_experts
        self.expert_dim = expert_dim
        self.hidden_dim = hidden_dim
        self.device = device
        self.simulate_latency = simulate_latency
        self.hbm_capacity = hbm_capacity
        self.dram_capacity = dram_capacity

        self.hbm_cache = LRUCache(hbm_capacity)
        self.dram_cache = LRUCache(dram_capacity)
        self.nvme_store = {}
        self._init_experts()

        self.prefetch_queue = Queue()
        self.stop_prefetch = False
        self.stats = CacheStats()
        self.stats.tier_accesses = {t: 0 for t in StorageTier}
        # Counters are touched by both the caller and the prefetch worker.
        self._stats_lock = Lock()

        self.prefetch_worker = Thread(target=self._prefetch_loop, daemon=True)
        self.prefetch_worker.start()

    def _init_experts(self):
        """Create every expert's weights (small random) and zero biases on CPU."""
        for i in range(self.num_experts):
            self.nvme_store[i] = {
                'w1': torch.randn(self.hidden_dim, self.expert_dim) * 0.02,
                'w2': torch.randn(self.expert_dim, self.hidden_dim) * 0.02,
                'b1': torch.zeros(self.hidden_dim),
                'b2': torch.zeros(self.expert_dim),
            }

    def _prefetch_loop(self):
        """Worker loop: promote queued expert ids until ``stop_prefetch`` is set."""
        while not self.stop_prefetch:
            try:
                eid = self.prefetch_queue.get(timeout=0.1)
            except Empty:
                continue
            # Skip unknown ids: a KeyError here would silently end the worker.
            if eid in self.nvme_store:
                self._promote(eid, is_prefetch=True)

    def _load_into_hbm(self, expert_id, is_prefetch):
        """Copy an expert from DRAM (if cached there) or NVMe into HBM.

        Returns ``(params_on_device, source_tier)``. An expert evicted from
        HBM by this insertion is demoted into the DRAM cache.
        """
        cached = self.dram_cache.get(expert_id)
        if cached:
            params_cpu, _ = cached
            tier, delay = StorageTier.DRAM, 0.001
        else:
            params_cpu = self.nvme_store[expert_id]
            tier, delay = StorageTier.NVME, 0.01
        if self.simulate_latency:
            time.sleep(delay)
        params_gpu = {k: v.to(self.device) for k, v in params_cpu.items()}
        evicted = self.hbm_cache.put(expert_id, params_gpu, is_prefetch)
        if evicted is not None:
            # Demote instead of dropping, otherwise the DRAM tier stays empty
            # forever. The CPU master copy in nvme_store serves as DRAM entry.
            self.dram_cache.put(evicted, self.nvme_store[evicted])
            with self._stats_lock:
                self.stats.evictions += 1
        return params_gpu, tier

    def _promote(self, expert_id, is_prefetch=False):
        """Make ``expert_id`` resident in HBM and return its on-device params.

        For a prefetch of an expert that is already resident this is a no-op
        returning ``None``: reading the entry would clear its prefetched flag
        and count a hit although no demand access took place.

        Raises:
            KeyError: if ``expert_id`` is not a known expert.
        """
        if is_prefetch:
            if expert_id in self.hbm_cache:
                return None
            params_gpu, _ = self._load_into_hbm(expert_id, True)
            return params_gpu

        result = self.hbm_cache.get(expert_id)
        if result:
            params, was_pf = result
            with self._stats_lock:
                self.stats.hits += 1
                if was_pf:
                    self.stats.prefetch_hits += 1
                self.stats.tier_accesses[StorageTier.HBM] += 1
            return params

        params_gpu, tier = self._load_into_hbm(expert_id, False)
        with self._stats_lock:
            # A DRAM hit still counts as a hit; only NVMe loads are misses.
            if tier is StorageTier.DRAM:
                self.stats.hits += 1
            else:
                self.stats.misses += 1
            self.stats.tier_accesses[tier] += 1
        return params_gpu

    def get_experts(self, expert_ids, context_hash=None):
        """Return ``{expert_id: params}`` for the requested experts.

        ``context_hash`` is accepted for interface compatibility and unused.
        """
        return {eid: self._promote(eid) for eid in expert_ids}

    def prefetch(self, expert_ids):
        """Queue known experts that are not yet in HBM for background loading."""
        for eid in expert_ids:
            if eid in self.nvme_store and eid not in self.hbm_cache:
                self.prefetch_queue.put(eid)

    def get_stats_summary(self):
        """Return a snapshot dict of the cache counters and HBM occupancy."""
        with self._stats_lock:
            summary = {
                'hit_rate': self.stats.hit_rate,
                'total_hits': self.stats.hits,
                'total_misses': self.stats.misses,
                'prefetch_hits': self.stats.prefetch_hits,
                'evictions': self.stats.evictions,
            }
        # Guard the ratio so a zero-capacity HBM tier reports 0.0 occupancy.
        summary['hbm_occupancy'] = (
            len(self.hbm_cache) / self.hbm_capacity if self.hbm_capacity > 0 else 0.0)
        return summary

    def shutdown(self):
        """Ask the prefetch worker to stop and wait up to two seconds for it."""
        self.stop_prefetch = True
        self.prefetch_worker.join(timeout=2.0)

In [None]:
%%writefile ooc_moe/core/moe_layers.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from typing import Dict, List, Optional, Tuple, NamedTuple
from .tiered_store import TieredExpertStore

class RouterOutput(NamedTuple):
    """Bundle of tensors returned by ``ExpertRouter.forward``."""
    # Indices of the top-k experts chosen per input row (from ``torch.topk``).
    expert_indices: Tensor
    # Softmax over the selected top-k logits, aligned with ``expert_indices``.
    expert_weights: Tensor
    # Router logits over all experts (including the noise added when training).
    router_logits: Tensor
    # Auxiliary loss computed by the router: num_experts * sum(frac * probs).
    load_balancing_loss: Tensor

class ExpertRouter(nn.Module):
    """Noisy top-k gate that picks experts for each input row.

    Args:
        input_dim: Width of the incoming features.
        num_experts: Number of experts to route over.
        top_k: Number of experts selected per row.
        noise_std: Standard deviation of the Gaussian noise added to the
            logits when routing in training mode.
    """

    def __init__(self, input_dim, num_experts, top_k=2, noise_std=0.1):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.noise_std = noise_std
        self.router = nn.Linear(input_dim, num_experts, bias=False)
        nn.init.normal_(self.router.weight, std=0.01)

    def forward(self, x, training=True):
        """Route ``x`` and return a ``RouterOutput``.

        ``x`` is expected to be 2-D (rows, input_dim): the balancing term
        divides by ``x.shape[0]`` and sums the one-hot selection over dim 1.
        """
        logits = self.router(x)
        add_noise = training and self.noise_std > 0
        if add_noise:
            logits = logits + torch.randn_like(logits) * self.noise_std

        selected_logits, selected_idx = torch.topk(logits, self.top_k, dim=-1)
        gate_weights = F.softmax(selected_logits, dim=-1)

        # Load balancing: fraction of rows sent to each expert times the mean
        # routing probability of that expert, scaled by the expert count.
        num_rows = x.shape[0]
        selection = F.one_hot(selected_idx, self.num_experts).float().sum(dim=1)
        row_fraction = selection.sum(dim=0) / num_rows
        mean_prob = F.softmax(logits, dim=-1).mean(dim=0)
        balance_loss = self.num_experts * (row_fraction * mean_prob).sum()

        return RouterOutput(selected_idx, gate_weights, logits, balance_loss)

class TieredMoELayer(nn.Module):
    """Residual mixture-of-experts feed-forward layer backed by an expert store.

    The layer owns only the router, layer norm and dropout; expert weights are
    fetched from ``expert_store`` on every forward pass.

    Args:
        input_dim: Feature width of the inputs.
        hidden_dim: Accepted for interface symmetry; not used by this layer.
        expert_store: Object exposing ``num_experts`` and ``get_experts``.
        top_k: Number of experts combined per row.
        dropout: Dropout probability applied to the expert mixture.
    """

    def __init__(self, input_dim, hidden_dim, expert_store, top_k=2, dropout=0.1):
        super().__init__()
        self.expert_store = expert_store
        self.top_k = top_k
        self.router = ExpertRouter(input_dim, expert_store.num_experts, top_k)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(input_dim)
        self.last_expert_indices = None

    def forward(self, x, context_hash=None):
        """Apply the MoE block and return ``(output, load_balancing_loss)``.

        A 3-D input is flattened to rows for routing and restored afterwards;
        any other input is processed as given.
        """
        is_sequence = len(x.shape) == 3
        if is_sequence:
            batch, seq_len, width = x.shape
            x = x.reshape(batch * seq_len, width)

        normed = self.layer_norm(x)
        routing = self.router(normed, self.training)
        self.last_expert_indices = routing.expert_indices

        # Fetch every expert needed by this batch in one store call.
        needed = routing.expert_indices.unique().tolist()
        expert_params = self.expert_store.get_experts(needed, context_hash)

        mixture = torch.zeros_like(normed)
        for slot in range(self.top_k):
            slot_experts = routing.expert_indices[:, slot]
            slot_weights = routing.expert_weights[:, slot]
            for expert_id in slot_experts.unique().tolist():
                rows = slot_experts == expert_id
                weights = expert_params[expert_id]
                hidden = F.relu(normed[rows] @ weights['w1'].T + weights['b1'])
                expert_out = hidden @ weights['w2'].T + weights['b2']
                mixture[rows] += slot_weights[rows].unsqueeze(-1) * expert_out

        # Residual connection around the (dropped-out) expert mixture.
        result = x + self.dropout(mixture)
        if is_sequence:
            result = result.reshape(batch, seq_len, width)
        return result, routing.load_balancing_loss

    def get_routing_stats(self):
        """Summarise the expert assignment of the most recent forward pass.

        Returns an empty dict before the first forward call.
        """
        if self.last_expert_indices is None:
            return {}
        flat_ids = self.last_expert_indices.flatten()
        ids, counts = torch.unique(flat_ids, return_counts=True)
        usage = dict(zip(ids.tolist(), counts.tolist()))
        return {'num_unique_experts': len(ids), 'expert_usage': usage}

class MoETransformerBlock(nn.Module):
    """Pre-norm self-attention followed by a tiered MoE feed-forward layer.

    Args:
        dim: Model width.
        num_heads: Number of attention heads.
        expert_store: Expert store handed to the MoE layer.
        ffn_hidden_dim: Hidden width passed through to the MoE layer.
        top_k: Experts combined per token.
        dropout: Dropout probability for attention and the MoE layer.
    """

    def __init__(self, dim, num_heads, expert_store, ffn_hidden_dim, top_k=2, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads, dropout=dropout, batch_first=True)
        self.attn_norm = nn.LayerNorm(dim)
        self.attn_drop = nn.Dropout(dropout)
        self.moe = TieredMoELayer(dim, ffn_hidden_dim, expert_store, top_k, dropout)

    def forward(self, x, attention_mask=None, context_hash=None):
        """Return ``(features, aux_loss)`` for a batch-first sequence ``x``."""
        normed = self.attn_norm(x)
        attended, _ = self.attn(
            normed, normed, normed, attn_mask=attention_mask, need_weights=False)
        # Residual around attention; the MoE layer adds its own residual.
        after_attention = x + self.attn_drop(attended)
        features, aux_loss = self.moe(after_attention, context_hash)
        return features, aux_loss

    def get_last_routing_stats(self):
        """Expose the MoE layer's routing statistics for the last forward pass."""
        return self.moe.get_routing_stats()

In [None]:
%%writefile ooc_moe/core/env_detector.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple, NamedTuple

class PredictionOutput(NamedTuple):
    """Result of ``EnvironmentDetector.predict``."""
    # Raw output of the detector's environment head.
    env_logits: torch.Tensor
    # Sigmoid of the expert head's output, one value per expert.
    expert_probs: torch.Tensor
    # Indices of the experts with the highest ``expert_probs`` (at most top_k).
    prefetch_set: List[int]
    # Index of the most probable environment under softmax(env_logits).
    predicted_env: int
    # Softmax probability of ``predicted_env``.
    confidence: float

class EnvironmentDetector(nn.Module):
    """Small CNN that guesses the environment and which experts to prefetch.

    Args:
        obs_shape: ``(channels, height, width)`` of a single frame.
        num_envs: Number of environment classes.
        num_experts: Number of experts to score.
        hidden_dim: Width of the shared projection.
        history_len: Number of frames stacked along the channel axis.
    """

    def __init__(self, obs_shape, num_envs, num_experts, hidden_dim=256, history_len=4):
        super().__init__()
        c, h, w = obs_shape
        self.num_experts = num_experts
        self.encoder = nn.Sequential(
            nn.Conv2d(c * history_len, 32, 8, stride=4), nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2), nn.ReLU(),
            nn.Conv2d(64, 64, 3, stride=1), nn.ReLU(),
            nn.Flatten(),
        )
        # Run a dummy frame through the encoder to discover its output width.
        with torch.no_grad():
            out_size = self.encoder(torch.zeros(1, c * history_len, h, w)).shape[1]
        self.proj = nn.Sequential(nn.Linear(out_size, hidden_dim), nn.ReLU())
        self.env_head = nn.Linear(hidden_dim, num_envs)
        self.expert_head = nn.Linear(hidden_dim, num_experts)

    def forward(self, obs):
        """Return ``(env_logits, expert_probs)`` for a batch of stacked frames."""
        h = self.proj(self.encoder(obs))
        return self.env_head(h), torch.sigmoid(self.expert_head(h))

    @torch.no_grad()
    def predict(self, obs, top_k=32):
        """Predict the environment and a prefetch set without tracking gradients.

        Works for any batch size: environment probabilities and expert scores
        are averaged over the batch before taking the argmax / top-k, which is
        identical to the per-sample result when the batch holds one sample.
        The module runs in eval mode for the prediction and its previous
        train/eval mode is restored afterwards.

        Returns:
            PredictionOutput with the raw per-sample ``env_logits`` and
            ``expert_probs`` plus the pooled ``prefetch_set``,
            ``predicted_env`` and ``confidence``.
        """
        was_training = self.training
        self.eval()
        try:
            env_logits, expert_probs = self.forward(obs)
        finally:
            # Do not leave a training caller's submodule stuck in eval mode.
            self.train(was_training)

        env_probs = F.softmax(env_logits, dim=-1)
        pooled_env = env_probs.reshape(-1, env_probs.shape[-1]).mean(dim=0)
        pred_env = int(pooled_env.argmax().item())
        conf = float(pooled_env.max().item())

        pooled_experts = expert_probs.reshape(-1, expert_probs.shape[-1]).mean(dim=0)
        k = min(top_k, pooled_experts.numel())
        prefetch = pooled_experts.topk(k).indices.tolist()
        return PredictionOutput(env_logits, expert_probs, prefetch, pred_env, conf)

    def get_prefetch_set(self, obs, top_k=32):
        """Convenience wrapper returning only the prefetch list of ``predict``."""
        return self.predict(obs, top_k).prefetch_set

In [None]:
%%writefile ooc_moe/models/moe_agent.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Optional, Tuple, NamedTuple
import math
from ..core.tiered_store import TieredExpertStore
from ..core.moe_layers import MoETransformerBlock
from ..core.env_detector import EnvironmentDetector

class AgentOutput(NamedTuple):
    """Everything ``MoERLAgent.forward`` returns for one call."""
    # Output of the action head applied to the last timestep's features.
    action_logits: torch.Tensor
    # Output of the value head (final Linear has a single output unit).
    value: torch.Tensor
    # De-duplicated ids of the experts routed to by any block in this call.
    expert_ids: List[int]
    # Sum of the load-balancing losses returned by the blocks.
    aux_loss: torch.Tensor
    # Environment index predicted by the detector, or -1 when prefetch is off.
    env_prediction: int
    # Snapshot from ``expert_store.get_stats_summary()``.
    cache_stats: Dict

class ObsEncoder(nn.Module):
    """Convolutional encoder mapping stacked frames to a feature vector.

    Args:
        obs_shape: ``(channels, height, width)`` of a single frame.
        out_dim: Width of the produced feature vector.
        frame_stack: Number of frames stacked along the channel axis.
    """

    def __init__(self, obs_shape, out_dim=512, frame_stack=4):
        super().__init__()
        channels, height, width = obs_shape
        in_channels = channels * frame_stack

        conv_layers = [
            nn.Conv2d(in_channels, 32, 8, stride=4), nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2), nn.ReLU(),
            nn.Conv2d(64, 64, 3, stride=1), nn.ReLU(),
            nn.Flatten(),
        ]
        self.conv = nn.Sequential(*conv_layers)

        # Probe with a zero frame to learn the flattened conv output width.
        with torch.no_grad():
            probe = torch.zeros(1, in_channels, height, width)
            flat_features = self.conv(probe).shape[1]

        self.proj = nn.Sequential(
            nn.Linear(flat_features, out_dim),
            nn.ReLU(),
            nn.Linear(out_dim, out_dim),
        )

    def forward(self, obs):
        """Encode a batch of stacked frames into ``out_dim`` features."""
        conv_features = self.conv(obs)
        return self.proj(conv_features)

class PosEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to a batch-first sequence.

    Args:
        dim: Feature width; both even and odd widths are supported.
        max_len: Number of positions precomputed in the ``pe`` buffer.
    """

    def __init__(self, dim, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, dim)
        pos = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div = torch.exp(torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim))
        angles = pos * div
        pe[:, 0::2] = torch.sin(angles)
        # With an odd ``dim`` there is one fewer odd column than even columns,
        # so the cosine half is trimmed to fit instead of raising on assignment.
        pe[:, 1::2] = torch.cos(angles[:, :dim // 2])
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(1)`` positions."""
        return x + self.pe[:, :x.size(1)]

class MoERLAgent(nn.Module):
    """Transformer actor-critic whose feed-forward layers are tiered MoE blocks.

    One ``TieredExpertStore`` is shared by every block. An
    ``EnvironmentDetector`` looks at the newest observation to choose experts
    to prefetch before the blocks run.

    Args:
        obs_shape: ``(channels, height, width)`` of a single frame.
        num_actions: Size of the discrete action space.
        num_experts: Number of experts in the shared store.
        expert_dim: Model width (also the experts' input/output width).
        expert_hidden: Hidden width of each expert.
        num_layers: Number of transformer blocks.
        num_heads: Attention heads per block.
        top_k: Experts combined per token.
        context_len: Maximum sequence length (size of the positional table).
        frame_stack: Frames stacked along the channel axis of each observation.
        num_envs: Number of environment classes for the detector.
        hbm_cap: HBM cache capacity of the expert store.
        dram_cap: DRAM cache capacity of the expert store.
        dropout: Dropout probability.
        device: Device for the module's parameters and the experts' HBM copies.
    """

    def __init__(self, obs_shape, num_actions, num_experts, expert_dim=512, expert_hidden=2048,
                 num_layers=6, num_heads=8, top_k=2, context_len=32, frame_stack=4, num_envs=57,
                 hbm_cap=32, dram_cap=128, dropout=0.1, device='cuda'):
        super().__init__()
        self.context_len = context_len
        self.device = device

        self.expert_store = TieredExpertStore(
            num_experts, expert_dim, expert_hidden, hbm_cap, dram_cap, device, True)
        self.obs_enc = ObsEncoder(obs_shape, expert_dim, frame_stack)
        self.pos_enc = PosEncoding(expert_dim, context_len)

        self.blocks = nn.ModuleList([
            MoETransformerBlock(expert_dim, num_heads, self.expert_store, expert_hidden, top_k, dropout)
            for _ in range(num_layers)])

        self.action_head = nn.Sequential(nn.Linear(expert_dim, expert_dim), nn.ReLU(), nn.Linear(expert_dim, num_actions))
        self.value_head = nn.Sequential(nn.Linear(expert_dim, expert_dim), nn.ReLU(), nn.Linear(expert_dim, 1))
        self.env_detector = EnvironmentDetector(obs_shape, num_envs, num_experts, 256, frame_stack)
        self.experts_used = []

        # The expert store places its HBM copies on ``device``; the module's
        # own parameters and buffers must live there too, otherwise the first
        # forward pass on a GPU mixes CPU weights with CUDA activations.
        self.to(device)

    def forward(self, obs, env_id=None, prefetch=True):
        """Run the agent on a batch of observation sequences.

        Args:
            obs: Tensor whose first two dims are (batch, sequence); the rest is
                passed to the observation encoder per timestep.
            env_id: Accepted for interface compatibility; not used.
            prefetch: When true, ask the detector for a prefetch set based on
                the newest timestep and queue it on the expert store.

        Returns:
            AgentOutput computed from the features of the last timestep.
        """
        b, s = obs.shape[:2]
        device = obs.device

        env_pred = -1
        if prefetch:
            pred = self.env_detector.predict(obs[:, -1], top_k=32)
            env_pred = pred.predicted_env
            self.expert_store.prefetch(pred.prefetch_set)

        x = self.obs_enc(obs.reshape(b * s, *obs.shape[2:])).reshape(b, s, -1)
        x = self.pos_enc(x)

        # Additive causal mask: -inf strictly above the diagonal, 0 elsewhere.
        mask = torch.triu(torch.ones(s, s, device=device) * float('-inf'), diagonal=1)

        # Start from a tensor so aux_loss is a Tensor even with zero blocks.
        total_aux = x.new_zeros(())
        self.experts_used = []
        for block in self.blocks:
            x, aux = block(x, attention_mask=mask)
            total_aux = total_aux + aux
            stats = block.get_last_routing_stats()
            if 'expert_usage' in stats:
                self.experts_used.extend(stats['expert_usage'].keys())

        final = x[:, -1]
        return AgentOutput(
            self.action_head(final), self.value_head(final),
            list(set(self.experts_used)), total_aux, env_pred,
            self.expert_store.get_stats_summary())

    @torch.no_grad()
    def get_action(self, obs, deterministic=False):
        """Pick an action for a single observation sequence (batch size 1).

        Switches the module to eval mode and leaves it there. Returns
        ``(action, value)`` as Python scalars; the action is the argmax when
        ``deterministic`` is true, otherwise sampled from the softmax policy.
        """
        self.eval()
        out = self.forward(obs, prefetch=True)
        if deterministic:
            action = out.action_logits.argmax(dim=-1).item()
        else:
            probs = F.softmax(out.action_logits, dim=-1)
            action = torch.multinomial(probs, 1).item()
        return action, out.value.item()

class MoERLAgentConfig:
    """Plain container for the hyper-parameters used to build a ``MoERLAgent``."""

    def __init__(self, obs_shape=(1,84,84), num_actions=18, num_envs=57, frame_stack=4,
                 num_experts=256, expert_dim=512, expert_hidden=2048, num_layers=6,
                 num_heads=8, top_k=2, context_len=32, dropout=0.1, hbm_cap=32, dram_cap=128):
        # Environment / observation settings.
        self.obs_shape = obs_shape
        self.num_actions = num_actions
        self.num_envs = num_envs
        self.frame_stack = frame_stack
        # Expert pool.
        self.num_experts = num_experts
        self.expert_dim = expert_dim
        self.expert_hidden = expert_hidden
        # Transformer backbone.
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.top_k = top_k
        self.context_len = context_len
        self.dropout = dropout
        # Cache capacities of the expert store.
        self.hbm_cap = hbm_cap
        self.dram_cap = dram_cap

    def create_agent(self, device='cuda'):
        """Instantiate a ``MoERLAgent`` from this configuration."""
        return MoERLAgent(
            obs_shape=self.obs_shape,
            num_actions=self.num_actions,
            num_experts=self.num_experts,
            expert_dim=self.expert_dim,
            expert_hidden=self.expert_hidden,
            num_layers=self.num_layers,
            num_heads=self.num_heads,
            top_k=self.top_k,
            context_len=self.context_len,
            frame_stack=self.frame_stack,
            num_envs=self.num_envs,
            hbm_cap=self.hbm_cap,
            dram_cap=self.dram_cap,
            dropout=self.dropout,
            device=device,
        )

    def estimate_params(self):
        """Roughly count expert and attention parameters.

        Each expert contributes two ``expert_dim x expert_hidden`` matrices
        plus both bias vectors; each layer contributes four
        ``expert_dim x expert_dim`` attention matrices.
        """
        per_expert = (self.expert_dim * self.expert_hidden * 2
                      + self.expert_hidden + self.expert_dim)
        expert_total = self.num_experts * per_expert
        attention_total = self.num_layers * (4 * self.expert_dim * self.expert_dim)
        return {
            'expert_params': expert_total,
            'attention_params': attention_total,
            'total': expert_total + attention_total,
        }

In [None]:
%%writefile ooc_moe/envs/atari_wrappers.py
import numpy as np
from typing import Tuple, Optional, Dict
import gymnasium as gym
from gymnasium import spaces

class DummyAtariEnv(gym.Env):
    """Synthetic stand-in for an Atari game with a game-specific visual pattern.

    Observations are ``obs_shape`` uint8 frame stacks built from a fixed
    per-game block pattern plus Gaussian noise and a step-dependent offset.
    The action equal to ``game_id % num_actions`` yields reward 1.0; any other
    action yields a small random reward below 0.1.

    Args:
        game_id: Selects the pattern, the rewarded action and the default seed.
        obs_shape: ``(frames, height, width)`` of an observation.
        num_actions: Size of the discrete action space.
    """

    def __init__(self, game_id=0, obs_shape=(4, 84, 84), num_actions=18):
        super().__init__()
        self.game_id = game_id
        self.obs_shape = obs_shape
        self.num_actions = num_actions
        self.observation_space = spaces.Box(low=0, high=255, shape=obs_shape, dtype=np.uint8)
        self.action_space = spaces.Discrete(num_actions)
        self._rng = np.random.RandomState(game_id)
        self._step = 0
        self._pattern = self._make_pattern()

    def _make_pattern(self):
        """Draw five game-dependent squares of intensity 128 on a blank frame."""
        # Follow the configured frame size instead of assuming 84x84.
        h, w = self.obs_shape[1], self.obs_shape[2]
        p = np.zeros((h, w), dtype=np.float32)
        s = 5 + (self.game_id % 10)
        for i in range(5):
            x = (self.game_id * 17 + i * 13) % w
            y = (self.game_id * 23 + i * 7) % h
            p[max(0, y - s):min(h, y + s), max(0, x - s):min(w, x + s)] = 128
        return p

    def reset(self, seed=None, **kw):
        """Restart the episode; reseed the RNG when ``seed`` is given.

        Returns ``(observation, info)`` with ``info['game_id']`` set.
        """
        # ``is not None`` so that seed=0 reseeds too (0 is falsy).
        if seed is not None:
            self._rng = np.random.RandomState(seed)
        self._step = 0
        return self._obs(), {'game_id': self.game_id}

    def step(self, action):
        """Advance one step and return ``(obs, reward, done, False, info)``.

        ``done`` is set after 1000 steps or with probability 0.001 per step
        and is reported in the terminated slot; truncated is always False.
        """
        self._step += 1
        r = 1.0 if action == self.game_id % self.num_actions else self._rng.random() * 0.1
        done = self._step >= 1000 or self._rng.random() < 0.001
        return self._obs(), r, done, False, {'game_id': self.game_id}

    def _obs(self):
        """Build the current frame stack from the pattern, noise and step offset."""
        h, w = self._pattern.shape
        noisy = self._pattern + self._rng.randn(h, w) * 20
        # Quantise to the 0..255 range, then widen: adding the offset to a
        # uint8 array would wrap around before the clip could take effect.
        base = np.clip(noisy, 0, 255).astype(np.uint8).astype(np.int16)
        frames = [
            np.clip(base + (self._step + i) % 50, 0, 255).astype(np.uint8)
            for i in range(self.obs_shape[0])
        ]
        return np.stack(frames)

def create_dummy_envs(n=10):
    """Build ``n`` dummy games (ids 0..n-1) and their ``Game_<id>`` names.

    Returns:
        ``(envs, names)`` as two parallel lists.
    """
    envs = []
    names = []
    for game_id in range(n):
        envs.append(DummyAtariEnv(game_id))
        names.append(f'Game_{game_id}')
    return envs, names

## 3. Test Components

In [None]:
import sys
sys.path.insert(0, '.')

from ooc_moe.core.tiered_store import TieredExpertStore
from ooc_moe.core.moe_layers import TieredMoELayer
from ooc_moe.core.env_detector import EnvironmentDetector
from ooc_moe.models.moe_agent import MoERLAgent, MoERLAgentConfig
from ooc_moe.envs.atari_wrappers import DummyAtariEnv, create_dummy_envs

print('All modules imported!')

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using device: {device}')

# Test tiered store: 64 experts of width 128 / hidden 256, HBM cap 8, DRAM cap 16.
store = TieredExpertStore(64, 128, 256, 8, 16, device, simulate_latency=False)
try:
    experts = store.get_experts([0, 1, 2])
    # Check the result instead of only printing it.
    assert set(experts) == {0, 1, 2}
    assert tuple(experts[0]['w1'].shape) == (256, 128)
    print(f'Got {len(experts)} experts')
    print(f'Stats: {store.get_stats_summary()}')
finally:
    # Always stop the prefetch worker thread, even if a check above fails.
    store.shutdown()
print('TieredExpertStore OK')

In [None]:
# Test full agent
config = MoERLAgentConfig(
    obs_shape=(1, 84, 84), num_actions=18, num_envs=5, frame_stack=4,
    num_experts=32, expert_dim=128, expert_hidden=256, num_layers=2,
    num_heads=4, top_k=2, context_len=4, hbm_cap=8, dram_cap=16)

print('Params:', config.estimate_params())

# Make sure the agent's parameters live on the same device as the
# observations; .to() is a no-op when they already do.
agent = config.create_agent(device).to(device)
try:
    # One batch holding a full context of stacked frames.
    obs = torch.randn(1, config.context_len, config.frame_stack, 84, 84, device=device)
    out = agent(obs, env_id=0, prefetch=True)

    print(f'Action logits: {out.action_logits.shape}')
    print(f'Value: {out.value.shape}')
    print(f'Experts used: {len(out.expert_ids)}')
    print(f'Cache hit rate: {out.cache_stats["hit_rate"]:.2%}')

    action, value = agent.get_action(obs)
    print(f'Action: {action}, Value: {value:.4f}')
finally:
    # Stop the expert store's prefetch thread even when the smoke test fails.
    agent.expert_store.shutdown()
print('Agent OK!')

## 4. Training Demo

In [None]:
from collections import defaultdict, deque
import matplotlib.pyplot as plt
import numpy as np
# The loop below uses F.*; no earlier notebook-scope cell shown imports it
# (the imports inside the %%writefile cells only apply to those modules).
import torch.nn.functional as F

config = MoERLAgentConfig(
    obs_shape=(1, 84, 84), num_actions=18, num_envs=3, frame_stack=4,
    num_experts=64, expert_dim=128, expert_hidden=256, num_layers=2,
    num_heads=4, top_k=2, context_len=8, hbm_cap=16, dram_cap=32)

# Keep the agent's parameters on the same device as the observations.
agent = config.create_agent(device).to(device)
envs, names = create_dummy_envs(3)
# NOTE: expert weights are plain tensors held by the expert store, so they are
# not part of agent.parameters() and are not updated by this optimizer.
optimizer = torch.optim.Adam(agent.parameters(), lr=1e-4)

expert_usage = defaultdict(lambda: defaultdict(int))
cache_history = []
steps_per_game = 300


def _to_frame(observation):
    """Convert a uint8 observation to a float tensor scaled to [0, 1]."""
    return torch.from_numpy(observation.astype(np.float32) / 255.0)


def _fresh_context(observation):
    """Start a context window filled with copies of the first observation."""
    return deque((_to_frame(observation) for _ in range(config.context_len)),
                 maxlen=config.context_len)


try:
    # Games are presented one after another (sequential multi-task training).
    for gid, (env, name) in enumerate(zip(envs, names)):
        print(f'Training on {name}...')
        obs, _ = env.reset()
        buf = _fresh_context(obs)

        for step in range(steps_per_game):
            ctx = torch.stack(list(buf), dim=0).unsqueeze(0).to(device)
            agent.train()
            out = agent(ctx, env_id=gid, prefetch=True)

            for eid in out.expert_ids:
                expert_usage[gid][eid] += 1
            cache_history.append(out.cache_stats['hit_rate'])

            # log_softmax avoids log(0) when a probability underflows.
            log_probs = F.log_softmax(out.action_logits, dim=-1)
            action = torch.multinomial(log_probs.exp(), 1).item()

            next_obs, reward, done, _, _ = env.step(action)

            # Single-step REINFORCE term plus the load-balancing auxiliary loss.
            loss = -log_probs[0, action] * reward + 0.01 * out.aux_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            buf.append(_to_frame(next_obs))
            if done:
                obs, _ = env.reset()
                buf = _fresh_context(obs)

        print(f'  Cache hit rate: {np.mean(cache_history[-steps_per_game:]):.2%}')
finally:
    # Stop the prefetch thread even if training is interrupted or fails.
    agent.expert_store.shutdown()
print('Done!')

In [None]:
# Visualize results
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Left panel: cache hit rate reported after every training step.
axes[0].plot(cache_history)
axes[0].axhline(np.mean(cache_history), color='r', linestyle='--', label=f'Mean: {np.mean(cache_history):.2%}')
axes[0].set_xlabel('Step')
axes[0].set_ylabel('Cache Hit Rate')
axes[0].set_title('Cache Performance')
axes[0].legend()

# Right panel: for each game, its most-used experts ranked by usage count.
# Rank by count so the bars really are the top experts; taking the first
# dict keys would show whichever experts happened to be recorded first.
top_n = 15
for gid in sorted(expert_usage):
    ranked = sorted(expert_usage[gid].items(), key=lambda kv: kv[1], reverse=True)[:top_n]
    counts = [count for _, count in ranked]
    # Offset each game's bars so the groups sit side by side per rank.
    positions = [rank + gid * 0.25 for rank in range(len(ranked))]
    axes[1].bar(positions, counts, width=0.25, label=f'Game {gid}', alpha=0.7)

# The x axis is the rank within each game, not an expert id.
axes[1].set_xlabel('Expert rank (top 15 by usage)')
axes[1].set_ylabel('Usage Count')
axes[1].set_title('Expert Specialization')
axes[1].legend()

plt.tight_layout()
plt.show()

In [None]:
# Expert overlap analysis
print('Expert Overlap Analysis')
print('=' * 40)

# Rank every game's experts once, most used first, and reuse the ranking.
game_ids = range(3)
ranked_usage = {
    gid: sorted(expert_usage[gid].items(), key=lambda kv: kv[1], reverse=True)
    for gid in game_ids
}

# Share of each game's expert activations taken by its five busiest experts.
for gid in game_ids:
    total = sum(expert_usage[gid].values())
    shares = [(e, f"{c/total:.1%}") for e, c in ranked_usage[gid][:5]]
    print(f'Game {gid}: {shares}')

print()
# Pairwise overlap between the games' ten most-used experts.
top_ten = {gid: {e for e, _ in ranked_usage[gid][:10]} for gid in game_ids}
for i in game_ids:
    for j in range(i + 1, 3):
        shared = top_ten[i] & top_ten[j]
        print(f'Game {i} vs {j}: {len(shared)}/10 experts overlap')