#  MoE Routing Analysis Demo

**Speculative Expert Routing for Mixtral-8x7B**

This notebook demonstrates comprehensive analysis of Mixture-of-Experts (MoE) routing patterns collected from Mixtral-8x7B-Instruct-v0.1-FP8 during inference on HumanEval and GSM8K benchmarks.

---

**Contents:**
1. [Routing Analysis Dashboard](#section-1) - Token journey visualization & expert distributions
2. [Statistical Visualizations](#section-2) - Heatmaps, accuracy curves, path analysis
3. [Statistical Summary Tables](#section-3) - Conditional probabilities & model comparisons
4. [Domain Comparison](#section-4) - HumanEval vs GSM8K patterns

In [None]:
# ============================================
# Setup & Imports
# ============================================
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from collections import Counter, defaultdict
from typing import Dict, List, Tuple
import warnings
warnings.filterwarnings('ignore')

# Plotting settings: global matplotlib style and defaults used by every
# figure below (individual cells still override figsize where needed).
# NOTE(review): the 'seaborn-v0_8-*' style names presumably require a recent
# matplotlib release — confirm the minimum supported version.
plt.style.use('seaborn-v0_8-whitegrid')
plt.rcParams['figure.figsize'] = (12, 6)
plt.rcParams['font.size'] = 11

# Try importing interactive widgets (optional). HAS_WIDGETS records whether
# the import succeeded; the interactive journey selector checks it and falls
# back to a static plot when ipywidgets is missing.
try:
    from ipywidgets import interact, Dropdown, IntSlider, Output
    import ipywidgets as widgets
    HAS_WIDGETS = True
except ImportError:
    HAS_WIDGETS = False
    print("ipywidgets not available - interactive features disabled")

print(" Setup complete!")

In [None]:
# ============================================
# Data Loading
# ============================================

# Root of the collected routing data. The path is relative, so it is resolved
# against the notebook's current working directory.
DATA_DIR = Path("routing_data_collected")
# One sub-directory per benchmark; the loaders below glob "*.jsonl" in each.
HUMANEVAL_TOKENS_DIR = DATA_DIR / "humaneval_tokens"
GSM8K_TOKENS_DIR = DATA_DIR / "gsm8k_tokens"

def load_token_journey(filepath: Path) -> List[dict]:
    """Load one token's routing journey from a JSON Lines file.

    Each non-blank line of the file is parsed as one JSON document and the
    parsed records are returned in file order. Blank / whitespace-only lines
    are skipped.

    Args:
        filepath: Path to the ``.jsonl`` file to read.

    Returns:
        The parsed records, one per non-blank line.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # JSON text is UTF-8; be explicit so decoding does not depend on the
    # platform's default locale encoding.
    with open(filepath, 'r', encoding='utf-8') as f:
        return [json.loads(line) for line in f if line.strip()]

def load_all_journeys(token_dir: Path, max_tokens: int = None) -> List[List[dict]]:
    """Load token journeys from every ``*.jsonl`` file in a directory.

    Files are processed in sorted filename order so that a truncated load is
    deterministic.

    Args:
        token_dir: Directory to search (non-recursively) for ``*.jsonl`` files.
        max_tokens: Maximum number of files to load. ``None`` (the default)
            loads every file; ``0`` loads nothing.

    Returns:
        One journey (list of parsed records) per loaded file.
    """
    files = sorted(token_dir.glob("*.jsonl"))
    # Compare against None explicitly: a plain truthiness test would treat
    # max_tokens=0 as "no limit" and load the whole directory.
    if max_tokens is not None:
        files = files[:max_tokens]
    return [load_token_journey(f) for f in files]

# Discover the journey files once per benchmark. The sorted lists are reused
# for both the counts printed below and the sampled load, so each directory is
# scanned a single time.
print("Loading token journeys...")
humaneval_files = sorted(HUMANEVAL_TOKENS_DIR.glob("*.jsonl"))
gsm8k_files = sorted(GSM8K_TOKENS_DIR.glob("*.jsonl"))

print(f" Found {len(humaneval_files):,} HumanEval token journeys")
print(f" Found {len(gsm8k_files):,} GSM8K token journeys")
print(f" Total: {len(humaneval_files) + len(gsm8k_files):,} token journeys")

# Later cells index into these datasets directly, so point out a missing or
# empty data directory here instead of failing further down with IndexError.
if not humaneval_files:
    print(f" WARNING: no *.jsonl files found in {HUMANEVAL_TOKENS_DIR}")
if not gsm8k_files:
    print(f" WARNING: no *.jsonl files found in {GSM8K_TOKENS_DIR}")

# Load a subset for interactive exploration (a full load can take a minute).
SAMPLE_SIZE = 5000  # Max journeys per benchmark; adjust based on memory
humaneval_journeys = [load_token_journey(f) for f in humaneval_files[:SAMPLE_SIZE]]
gsm8k_journeys = [load_token_journey(f) for f in gsm8k_files[:SAMPLE_SIZE]]
all_journeys = humaneval_journeys + gsm8k_journeys

print(f"\n Loaded {len(all_journeys):,} token journeys for analysis")

---
<a id='section-1'></a>
##  Section 1: Routing Analysis Dashboard

Visualize individual token journeys through the 32 MoE layers and aggregate expert usage patterns.

In [None]:
# ============================================
# 1.1 Single Token Journey Visualization
# ============================================

def plot_token_journey(journey: List[dict], title: str = "Token Journey"):
    """Draw a two-panel figure for one token's journey and print a summary.

    The left panel shows the top-1 and top-2 expert chosen at each layer; the
    right panel shows the top-1 gating probability per layer as bars coloured
    by that probability. Two summary lines are printed after the figure.

    Args:
        journey: Records for one token. Each record must provide ``'layer'``,
            ``'experts'`` (at least two entries) and ``'gating_probs'`` (at
            least one entry).
        title: Prefix used in both panel titles.

    Note:
        Axis ticks are hard-coded for expert ids 0-7 and layers 0-31.
    """
    layer_ids = [rec['layer'] for rec in journey]
    primary = [rec['experts'][0] for rec in journey]
    secondary = [rec['experts'][1] for rec in journey]
    primary_prob = [rec['gating_probs'][0] for rec in journey]

    fig, (path_ax, conf_ax) = plt.subplots(1, 2, figsize=(14, 5))

    # --- Left panel: which experts were picked at each layer ---
    path_ax.plot(layer_ids, primary, 'o-', color='#1f77b4',
                 label='Top-1 Expert', linewidth=2, markersize=8)
    path_ax.plot(layer_ids, secondary, 's--', color='#ff7f0e',
                 label='Top-2 Expert', linewidth=1.5, markersize=6, alpha=0.6)
    path_ax.set_xlabel('Layer')
    path_ax.set_ylabel('Expert ID')
    path_ax.set_title(f'{title} - Expert Selection Path')
    path_ax.set_yticks(range(8))
    path_ax.set_xticks(range(0, 32, 4))
    path_ax.legend(loc='upper right')
    path_ax.grid(True, alpha=0.3)

    # --- Right panel: top-1 gating probability, coloured red -> green ---
    bar_colors = plt.cm.RdYlGn(np.array(primary_prob))
    conf_ax.bar(layer_ids, primary_prob, color=bar_colors,
                edgecolor='black', linewidth=0.5)
    conf_ax.axhline(y=0.5, color='red', linestyle='--', alpha=0.5,
                    label='50% threshold')
    conf_ax.set_xlabel('Layer')
    conf_ax.set_ylabel('Top-1 Gating Probability')
    conf_ax.set_title(f'{title} - Routing Confidence')
    conf_ax.set_ylim(0, 1)
    conf_ax.legend()

    plt.tight_layout()
    plt.show()

    # Text summary: distinct top-1 experts visited and mean top-1 probability.
    print(f" Unique experts used: {len(set(primary))}/8")
    print(f" Average routing confidence: {np.mean(primary_prob):.2%}")

# Example: visualize the first loaded HumanEval journey (index 0 of the
# sample). Indexing raises IndexError if no HumanEval journeys were loaded.
# NOTE(review): the "P0 Token 0" title presumes the first file in sorted order
# is problem 0 / token 0 — confirm against the file naming scheme.
sample_journey = humaneval_journeys[0]
plot_token_journey(sample_journey, title="HumanEval P0 Token 0")

In [None]:
# ============================================
# 1.2 Interactive Token Journey Selector
# ============================================

if HAS_WIDGETS:
    # Create interactive widget: a dataset dropdown plus a token-index slider.
    # The slider's upper bound is min(SAMPLE_SIZE-1, 100); it is not derived
    # from how many journeys were actually loaded, which is why the callback
    # re-checks the index before plotting.
    @interact(
        dataset=Dropdown(options=['humaneval', 'gsm8k'], value='humaneval'),
        token_idx=IntSlider(min=0, max=min(SAMPLE_SIZE-1, 100), value=0, description='Token #')
    )
    def interactive_journey(dataset, token_idx):
        """Plot the selected token's journey from the chosen dataset."""
        journeys = humaneval_journeys if dataset == 'humaneval' else gsm8k_journeys
        if token_idx < len(journeys):
            plot_token_journey(journeys[token_idx], f"{dataset.upper()} Token {token_idx}")
        else:
            print(f"Token index {token_idx} out of range")
else:
    # Static fallback: plot the first GSM8K journey (raises IndexError if
    # no GSM8K journeys were loaded).
    print("Interactive widgets unavailable. Showing static examples:")
    plot_token_journey(gsm8k_journeys[0], "GSM8K P0 Token 0")

In [None]:
# ============================================
# 1.3 Layer-wise Expert Distribution
# ============================================

def compute_layer_expert_distribution(
    journeys: List[List[dict]],
    n_layers: int = 32,
    n_experts: int = 8,
) -> np.ndarray:
    """Compute the per-layer share of top-1 expert selections, in percent.

    Args:
        journeys: Token journeys; every record must provide ``'layer'`` and
            ``'experts'`` (whose first entry is the top-1 expert id).
        n_layers: Number of layers (rows of the result). Defaults to 32.
        n_experts: Number of experts (columns of the result). Defaults to 8.

    Returns:
        Array of shape ``(n_layers, n_experts)``. Each row with at least one
        record sums to 100; a row for a layer that never appears is all zeros.

    Raises:
        IndexError: If a record's layer or expert id is out of range.
    """
    counts = np.zeros((n_layers, n_experts))
    for journey in journeys:
        for record in journey:
            counts[record['layer'], record['experts'][0]] += 1

    # Normalize to percentages. Layers with no records would otherwise yield
    # 0/0 = NaN; dividing by 1 instead keeps those rows at zero, the same
    # guard compute_transition_matrix applies to its empty rows.
    layer_totals = counts.sum(axis=1, keepdims=True)
    layer_totals[layer_totals == 0] = 1
    return counts / layer_totals * 100

# Compute the layer x expert top-1 usage percentages over the combined sample.
dist_all = compute_layer_expert_distribution(all_journeys)

# Create heatmap. The matrix is transposed so layers run along the x-axis and
# experts along the y-axis.
fig, ax = plt.subplots(figsize=(14, 8))
sns.heatmap(dist_all.T, annot=False, cmap='YlOrRd', 
            xticklabels=range(32), yticklabels=range(8),
            cbar_kws={'label': 'Usage %'}, ax=ax)
ax.set_xlabel('Layer', fontsize=12)
ax.set_ylabel('Expert ID', fontsize=12)
ax.set_title('Expert Usage Distribution Across Layers (Top-1 Selection)', fontsize=14)
plt.tight_layout()
plt.show()

# Identify layer-specific biases: list, per layer, every expert whose top-1
# share exceeds 20% (an even split over 8 experts would be 12.5% each).
print("\n Layer-Specific Expert Preferences (>20% usage):")
for layer in range(32):
    dominant = np.where(dist_all[layer] > 20)[0]
    if len(dominant) > 0:
        prefs = ", ".join([f"E{e}({dist_all[layer, e]:.1f}%)" for e in dominant])
        print(f"   Layer {layer:2d}: {prefs}")

In [None]:
# ============================================
# 1.4 Gating Probability Distribution
# ============================================

def collect_gating_probs(journeys: List[List[dict]]) -> Tuple[List[float], List[float]]:
    """Flatten the first two gating probabilities of every record.

    Args:
        journeys: Token journeys; every record must provide ``'gating_probs'``
            with at least two entries.

    Returns:
        ``(top1, top2)``: two lists with one value per record, in journey
        order then record order.
    """
    pairs = [
        (step['gating_probs'][0], step['gating_probs'][1])
        for path in journeys
        for step in path
    ]
    if not pairs:
        return [], []
    firsts, seconds = zip(*pairs)
    return list(firsts), list(seconds)

# Flatten every record's top-1 / top-2 gating probability across the sample.
top1_probs, top2_probs = collect_gating_probs(all_journeys)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Top-1 distribution (dashed red line marks the mean)
axes[0].hist(top1_probs, bins=50, color='#1f77b4', edgecolor='black', alpha=0.7)
axes[0].axvline(np.mean(top1_probs), color='red', linestyle='--', 
                label=f'Mean: {np.mean(top1_probs):.3f}')
axes[0].set_xlabel('Top-1 Gating Probability')
axes[0].set_ylabel('Frequency')
axes[0].set_title('Top-1 Expert Confidence Distribution')
axes[0].legend()

# Top-2 distribution (same layout as the top-1 panel)
axes[1].hist(top2_probs, bins=50, color='#ff7f0e', edgecolor='black', alpha=0.7)
axes[1].axvline(np.mean(top2_probs), color='red', linestyle='--',
                label=f'Mean: {np.mean(top2_probs):.3f}')
axes[1].set_xlabel('Top-2 Gating Probability')
axes[1].set_ylabel('Frequency')
axes[1].set_title('Top-2 Expert Confidence Distribution')
axes[1].legend()

plt.tight_layout()
plt.show()

# Text summary; "high confidence" is the fraction of records whose top-1
# gating probability exceeds 0.8.
print(f"\n Gating Statistics:")
print(f"   Top-1 Mean: {np.mean(top1_probs):.3f}, Std: {np.std(top1_probs):.3f}")
print(f"   Top-2 Mean: {np.mean(top2_probs):.3f}, Std: {np.std(top2_probs):.3f}")
print(f"   High confidence (>0.8): {np.mean(np.array(top1_probs) > 0.8):.1%} of decisions")

---
<a id='section-2'></a>
##  Section 2: Statistical Visualizations

Deep analysis of inter-layer dependencies and predictability patterns.

In [None]:
# ============================================
# 2.1 Inter-layer Affinity Heatmaps (32x32)
# ============================================

def compute_transition_matrix(journeys: List[List[dict]], source_layer: int, target_layer: int) -> np.ndarray:
    """Estimate P(top-1 expert at target_layer | top-1 expert at source_layer).

    Layers are looked up by position in each journey (``journey[layer]``), not
    by the record's ``'layer'`` field. Journeys too short to contain both
    positions are ignored.

    Args:
        journeys: Token journeys; records must provide ``'experts'``.
        source_layer: Position of the conditioning layer within a journey.
        target_layer: Position of the predicted layer within a journey.

    Returns:
        An 8x8 array indexed ``[source_expert, target_expert]``. Rows for
        source experts that were observed sum to 1; rows for experts never
        observed at ``source_layer`` are all zeros.
    """
    pair_counts = np.zeros((8, 8))
    deepest = max(source_layer, target_layer)
    for path in journeys:
        if deepest >= len(path):
            continue  # journey does not reach both layers
        origin = path[source_layer]['experts'][0]
        destination = path[target_layer]['experts'][0]
        pair_counts[origin, destination] += 1

    # Row-normalise; empty rows are divided by 1 so they stay zero, not NaN.
    totals = pair_counts.sum(axis=1, keepdims=True)
    totals[totals == 0] = 1
    return pair_counts / totals

def compute_affinity_matrix(journeys: List[List[dict]]) -> np.ndarray:
    """Build a 32x32 matrix of average transition predictability.

    For every ordered pair of distinct layers ``(source, target)`` the score
    is the mean, over the 8 source experts, of the largest conditional
    probability in that expert's row of the transition matrix. The diagonal
    is left at 0.

    Args:
        journeys: Token journeys, passed through to
            ``compute_transition_matrix``.

    Returns:
        Array of shape ``(32, 32)`` indexed ``[source_layer, target_layer]``.
    """
    scores = np.zeros((32, 32))
    for source in range(32):
        for target in range(32):
            if source == target:
                continue
            transition = compute_transition_matrix(journeys, source, target)
            # Row-wise max = probability of the most likely target expert.
            scores[source, target] = transition.max(axis=1).mean()
    return scores

print("Computing inter-layer affinity matrix (this may take a moment)...")
affinity_matrix = compute_affinity_matrix(all_journeys)

# Plot heatmap (rows = source layer, columns = target layer)
fig, ax = plt.subplots(figsize=(12, 10))
sns.heatmap(affinity_matrix, cmap='viridis', 
            xticklabels=range(32), yticklabels=range(32),
            cbar_kws={'label': 'Avg Max Transition Prob'}, ax=ax)
ax.set_xlabel('Target Layer', fontsize=12)
ax.set_ylabel('Source Layer', fontsize=12)
ax.set_title('Inter-Layer Routing Affinity\n(Higher = More Predictable Transitions)', fontsize=14)
plt.tight_layout()
plt.show()

# Analyze patterns: forward adjacent pairs (L -> L+1) versus every pair of
# layers more than one apart (in either direction).
adjacent_avg = np.mean([affinity_matrix[i, i+1] for i in range(31)])
non_adjacent_avg = np.mean([affinity_matrix[i, j] for i in range(32) for j in range(32) if abs(i-j) > 1])
print(f"\n Affinity Analysis:")
print(f"   Adjacent layers (L→L+1) avg predictability: {adjacent_avg:.3f}")
print(f"   Non-adjacent layers avg predictability: {non_adjacent_avg:.3f}")
# Use an explicit sign format so a negative advantage prints as "-x%" rather
# than "+-x%", and skip the ratio when the denominator is zero.
if non_adjacent_avg > 0:
    adjacent_advantage = (adjacent_avg - non_adjacent_avg) / non_adjacent_avg * 100
    print(f"   Adjacent advantage: {adjacent_advantage:+.1f}%")
else:
    print("   Adjacent advantage: n/a (non-adjacent predictability is 0)")

In [None]:
# ============================================
# 2.2 Sample Transition Matrices (Layer L → L+1)
# ============================================

# Show detailed transition matrices for selected adjacent layer pairs
layer_pairs = [(0, 1), (15, 16), (30, 31)]  # Early, middle, late

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# One annotated 8x8 heatmap per pair: rows are the expert at the source
# layer, columns the expert at the target layer.
for idx, (src, tgt) in enumerate(layer_pairs):
    trans = compute_transition_matrix(all_journeys, src, tgt)
    sns.heatmap(trans, annot=True, fmt='.2f', cmap='Blues',
                xticklabels=range(8), yticklabels=range(8),
                cbar=False, ax=axes[idx])
    axes[idx].set_xlabel(f'Expert @ Layer {tgt}')
    axes[idx].set_ylabel(f'Expert @ Layer {src}')
    axes[idx].set_title(f'P(E@L{tgt} | E@L{src})')

plt.suptitle('Conditional Expert Selection Probabilities', fontsize=14, y=1.02)
plt.tight_layout()
plt.show()

In [None]:
# ============================================
# 2.3 Prediction Accuracy Curves (1-hop to 31-hop)
# ============================================

def compute_hop_accuracy(journeys: List[List[dict]], hop_distance: int) -> float:
    """Score a most-frequent-transition predictor ``hop_distance`` layers ahead.

    For every source layer ``L`` with ``L + hop_distance < 32``, the predictor
    maps the top-1 expert at ``L`` to the most frequent top-1 expert observed
    at ``L + hop_distance``. The returned value is the fraction of
    (journey, source layer) pairs for which that prediction matches.

    The accuracy is in-sample: the transition table is built from the same
    ``journeys`` it is scored on.

    Args:
        journeys: Token journeys; records must provide ``'experts'``. Layers
            are looked up by position within each journey.
        hop_distance: Number of layers between source and target.

    Returns:
        Accuracy in ``[0, 1]``, or ``0.0`` when there is nothing to score.
    """
    correct = 0
    total = 0

    for src_layer in range(32 - hop_distance):
        tgt_layer = src_layer + hop_distance
        trans = compute_transition_matrix(journeys, src_layer, tgt_layer)
        # Most likely target expert for each source expert; computed once per
        # layer pair instead of once per journey.
        predicted_next = trans.argmax(axis=1)

        for journey in journeys:
            # Same length guard as compute_transition_matrix: journeys that
            # do not reach both layers are skipped instead of raising
            # IndexError.
            if not (src_layer < len(journey) and tgt_layer < len(journey)):
                continue
            src_expert = journey[src_layer]['experts'][0]
            tgt_expert = journey[tgt_layer]['experts'][0]
            if predicted_next[src_expert] == tgt_expert:
                correct += 1
            total += 1

    return correct / total if total > 0 else 0.0

print("Computing hop accuracy (this takes a few minutes)...")
hop_accuracies = []
hop_distances = list(range(1, 32))

# hop_accuracies[k] holds the accuracy (in percent) for hop distance k + 1.
for hop in hop_distances:
    acc = compute_hop_accuracy(all_journeys, hop)
    hop_accuracies.append(acc * 100)  # Convert to percentage
    if hop % 5 == 0:
        print(f"   Hop {hop:2d}: {acc*100:.2f}%")

# Plot accuracy against hop distance, with the uniform-random baseline
# (1 of 8 experts = 12.5%) as a reference line.
fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(hop_distances, hop_accuracies, 'o-', linewidth=2, markersize=6, color='#2ca02c')
ax.axhline(y=12.5, color='red', linestyle='--', alpha=0.7, label='Random Baseline (12.5%)')
ax.fill_between(hop_distances, 12.5, hop_accuracies, alpha=0.2, color='green')

ax.set_xlabel('Hop Distance (layers ahead)', fontsize=12)
ax.set_ylabel('Prediction Accuracy (%)', fontsize=12)
ax.set_title('Expert Prediction Accuracy vs. Layer Distance\n(Using Most Frequent Transition)', fontsize=14)
ax.set_xticks(range(1, 32, 2))
ax.set_ylim(0, max(hop_accuracies) + 5)
ax.legend(loc='upper right')
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"\n Key Findings:")
print(f"   1-hop accuracy: {hop_accuracies[0]:.2f}%")
print(f"   5-hop accuracy: {hop_accuracies[4]:.2f}%")
print(f"   Accuracy at 31-hop: {hop_accuracies[-1]:.2f}%")
# Report the first hop distance whose accuracy falls below 15%. When accuracy
# never drops that low, say so explicitly instead of printing a made-up
# distance.
first_hop_below_15 = next(
    (hop for hop, accuracy in zip(hop_distances, hop_accuracies) if accuracy < 15),
    None,
)
if first_hop_below_15 is None:
    print("   First hop distance with accuracy < 15%: none (stays at or above 15%)")
else:
    print(f"   First hop distance with accuracy < 15%: {first_hop_below_15}")

In [None]:
# ============================================
# 2.4 Path Frequency Analysis
# ============================================

def get_path_signature(journey: List[dict]) -> Tuple[int, ...]:
    """Return the journey's sequence of top-1 expert ids as a hashable tuple.

    The tuple has one entry per record, in record order.
    """
    signature = []
    for step in journey:
        signature.append(step['experts'][0])
    return tuple(signature)

# Count how often each full top-1 expert path occurs in the sample.
path_counts = Counter(get_path_signature(journey) for journey in all_journeys)

# Statistics
total_tokens = len(all_journeys)
unique_paths = len(path_counts)
top_10_paths = path_counts.most_common(10)  # may hold fewer than 10 entries

print(f" Path Analysis:")
print(f"   Total tokens analyzed: {total_tokens:,}")
print(f"   Unique 32-layer paths: {unique_paths:,}")
if total_tokens:
    print(f"   Path uniqueness ratio: {unique_paths/total_tokens:.2%}")
else:
    print("   Path uniqueness ratio: n/a (no tokens loaded)")

# Plot top paths. Labels and colours are sized from the number of paths
# actually found, so a sample with fewer than 10 distinct paths still plots.
fig, ax = plt.subplots(figsize=(12, 5))
path_freqs = [count for _, count in top_10_paths]
path_labels = [f"Path {rank}" for rank in range(1, len(path_freqs) + 1)]

bars = ax.bar(path_labels, path_freqs,
              color=plt.cm.tab10.colors[:len(path_freqs)], edgecolor='black')
ax.set_ylabel('Frequency')
ax.set_title('Top 10 Most Common Expert Routing Paths', fontsize=14)

# Add percentage labels just above each bar. The padding is in points, so the
# labels stay next to the bars whatever the count scale is (a fixed data-unit
# offset pushes them off the axes when counts are small).
ax.bar_label(bars,
             labels=[f'{freq / total_tokens * 100:.2f}%' for freq in path_freqs],
             padding=3, fontsize=9)
ax.margins(y=0.1)  # headroom for the labels

plt.tight_layout()
plt.show()

# Show actual paths (first 8 and last 4 layers of each)
print("\n Top 5 Most Common Paths:")
for i, (path, count) in enumerate(top_10_paths[:5]):
    path_str = '→'.join(str(e) for e in path[:8]) + '...→' + '→'.join(str(e) for e in path[-4:])
    print(f"   {i+1}. [{path_str}] (n={count}, {count/total_tokens:.2%})")

---
<a id='section-3'></a>
##  Section 3: Statistical Summary Tables

Quantitative comparison of routing prediction methods.

In [None]:
# ============================================
# 3.1 Model Performance Comparison
# ============================================

# Accuracies (in percent) from experiment results.
results_df = pd.DataFrame([
    {'Model': 'Random Baseline', 'Top-1 Accuracy': 12.50, 'Top-2 Accuracy': 25.00},
    {'Model': 'XGBoost (Tree)', 'Top-1 Accuracy': 36.16, 'Top-2 Accuracy': 53.26},
    {'Model': 'Lookup Rules', 'Top-1 Accuracy': 42.17, 'Top-2 Accuracy': 56.15},
    {'Model': 'MLP (Baseline)', 'Top-1 Accuracy': 44.64, 'Top-2 Accuracy': 62.63},
    {'Model': 'MLP (Optimized)', 'Top-1 Accuracy': 55.60, 'Top-2 Accuracy': 73.24},
])

# Derive the improvement factor from the Top-1 numbers instead of hard-coding
# it, so the column cannot drift from the accuracies it summarises
# (e.g. 55.60 / 12.50 = 4.448 -> "4.4×").
random_top1 = results_df.loc[results_df['Model'] == 'Random Baseline', 'Top-1 Accuracy'].iloc[0]
results_df['Improvement vs Random'] = [
    '—' if model == 'Random Baseline' else f"{top1 / random_top1:.1f}×"
    for model, top1 in zip(results_df['Model'], results_df['Top-1 Accuracy'])
]

print("═" * 70)
print("                    MODEL PERFORMANCE COMPARISON")
print("═" * 70)
print(results_df.to_string(index=False))
print("═" * 70)

# Visualization: grouped bars, Top-1 and Top-2 side by side per model
fig, ax = plt.subplots(figsize=(10, 6))
x = np.arange(len(results_df))
width = 0.35

bars1 = ax.bar(x - width/2, results_df['Top-1 Accuracy'], width, label='Top-1', color='#1f77b4')
bars2 = ax.bar(x + width/2, results_df['Top-2 Accuracy'], width, label='Top-2', color='#2ca02c')

ax.axhline(y=12.5, color='red', linestyle='--', alpha=0.5, label='Random (12.5%)')
ax.set_ylabel('Accuracy (%)')
ax.set_title('Expert Prediction Accuracy by Model', fontsize=14)
ax.set_xticks(x)
ax.set_xticklabels(results_df['Model'], rotation=15, ha='right')
ax.legend()
ax.set_ylim(0, 80)

# Add value labels to both bar groups (Top-1 and Top-2)
for bar in (*bars1, *bars2):
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 1, 
            f'{bar.get_height():.1f}', ha='center', va='bottom', fontsize=9)

plt.tight_layout()
plt.show()

In [None]:
# ============================================
# 3.2 Detailed Conditional Probability Table
# ============================================

# Most predictable adjacent transition (L -> L+1). max() returns the first
# source layer with the highest score and always yields a pair, so the cell
# no longer fails on `None` when every adjacent affinity is 0.
best_src = max(range(31), key=lambda layer: affinity_matrix[layer, layer + 1])
best_pair = (best_src, best_src + 1)
best_predictability = affinity_matrix[best_pair]

print(f"\n Most Predictable Transition: Layer {best_pair[0]} → Layer {best_pair[1]}")
print(f"   Predictability Score: {best_predictability:.3f}")

# Show detailed matrix: rows are the expert at the source layer, columns the
# expert at the target layer.
trans = compute_transition_matrix(all_journeys, best_pair[0], best_pair[1])
trans_df = pd.DataFrame(trans, 
                        index=[f'From E{i}' for i in range(8)],
                        columns=[f'To E{i}' for i in range(8)])

print(f"\nConditional Probability Matrix: P(E@L{best_pair[1]} | E@L{best_pair[0]})")
print("─" * 60)
print(trans_df.round(3).to_string())

# Highlight strong transitions (conditional probability above 0.3)
print("\n Strong Transitions (>30%):")
for src in range(8):
    for tgt in range(8):
        if trans[src, tgt] > 0.3:
            print(f"   E{src} → E{tgt}: {trans[src, tgt]:.1%}")

In [None]:
# ============================================
# 3.3 Cross-Layer Correlation Analysis
# ============================================

def compute_expert_correlation(journeys: List[List[dict]]) -> np.ndarray:
    """Correlate top-1 expert ids between every pair of the 32 layers.

    A ``(num_tokens, 32)`` matrix is filled with each token's top-1 expert id
    per layer, using the record's ``'layer'`` field as the column. Cells for
    layers a journey does not contain keep their initial value of 0.

    Args:
        journeys: Token journeys; records must provide ``'layer'`` and
            ``'experts'``.

    Returns:
        The ``(32, 32)`` Pearson correlation matrix from ``np.corrcoef``.

    NOTE(review): expert ids are categorical labels, so a Pearson coefficient
    on the raw ids depends on the arbitrary numbering — confirm this is the
    intended measure.
    """
    token_by_layer = np.zeros((len(journeys), 32))
    for row, path in enumerate(journeys):
        for step in path:
            token_by_layer[row, step['layer']] = step['experts'][0]

    # np.corrcoef treats rows as variables, so put layers on the rows.
    layers_as_rows = token_by_layer.T
    return np.corrcoef(layers_as_rows)

print("Computing cross-layer correlations...")
corr_matrix = compute_expert_correlation(all_journeys)

# Heatmap of the 32x32 correlation matrix. The colour scale is limited to
# [-0.3, 0.3], so stronger correlations saturate at the end colours.
fig, ax = plt.subplots(figsize=(12, 10))
mask = np.eye(32, dtype=bool)  # Mask diagonal (each layer's self-correlation)
sns.heatmap(corr_matrix, cmap='RdBu_r', center=0, 
            mask=mask, vmin=-0.3, vmax=0.3,
            xticklabels=range(32), yticklabels=range(32),
            cbar_kws={'label': 'Pearson Correlation'}, ax=ax)
ax.set_xlabel('Layer', fontsize=12)
ax.set_ylabel('Layer', fontsize=12)
ax.set_title('Cross-Layer Expert Selection Correlations', fontsize=14)
plt.tight_layout()
plt.show()

# Find strongest correlations: collect each unordered pair of layers at least
# two apart, then rank by absolute correlation.
print("\n Strongest Non-Adjacent Correlations:")
corr_pairs = []
for i in range(32):
    for j in range(i+2, 32):  # Skip adjacent
        corr_pairs.append((i, j, corr_matrix[i, j]))
corr_pairs.sort(key=lambda x: abs(x[2]), reverse=True)

# Print the five largest in magnitude, keeping the sign.
for i, j, corr in corr_pairs[:5]:
    print(f"   Layer {i} ↔ Layer {j}: r = {corr:+.3f}")

---
<a id='section-4'></a>
##  Section 4: Domain Comparison (HumanEval vs GSM8K)

Compare routing patterns between code generation (HumanEval) and math reasoning (GSM8K) tasks.

In [None]:
# ============================================
# 4.1 Expert Usage by Domain
# ============================================

# Per-domain layer x expert top-1 usage percentages; these two matrices are
# also used by the domain-difference cell that follows.
dist_humaneval = compute_layer_expert_distribution(humaneval_journeys)
dist_gsm8k = compute_layer_expert_distribution(gsm8k_journeys)

fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# Left: HumanEval. Transposed so layers are on the x-axis, experts on the y-axis.
sns.heatmap(dist_humaneval.T, cmap='Blues', annot=False,
            xticklabels=range(32), yticklabels=range(8),
            cbar_kws={'label': 'Usage %'}, ax=axes[0])
axes[0].set_xlabel('Layer')
axes[0].set_ylabel('Expert ID')
axes[0].set_title('HumanEval (Code Generation)', fontsize=13)

# Right: GSM8K, same layout. Each panel has its own colour scale.
sns.heatmap(dist_gsm8k.T, cmap='Oranges', annot=False,
            xticklabels=range(32), yticklabels=range(8),
            cbar_kws={'label': 'Usage %'}, ax=axes[1])
axes[1].set_xlabel('Layer')
axes[1].set_ylabel('Expert ID')
axes[1].set_title('GSM8K (Math Reasoning)', fontsize=13)

plt.suptitle('Expert Usage Patterns by Domain', fontsize=14, y=1.02)
plt.tight_layout()
plt.show()

In [None]:
# ============================================
# 4.2 Domain Difference Analysis
# ============================================

# Compute difference in usage percentage points: positive where HumanEval
# uses an expert more than GSM8K, negative where GSM8K uses it more.
diff_matrix = dist_humaneval - dist_gsm8k

fig, ax = plt.subplots(figsize=(14, 6))
sns.heatmap(diff_matrix.T, cmap='RdBu_r', center=0, 
            vmin=-15, vmax=15,
            xticklabels=range(32), yticklabels=range(8),
            cbar_kws={'label': 'HumanEval − GSM8K (%)'}, ax=ax)
ax.set_xlabel('Layer', fontsize=12)
ax.set_ylabel('Expert ID', fontsize=12)
# 'RdBu_r' runs from blue (low) to red (high), so positive differences
# (HumanEval-preferred) are red and negative ones (GSM8K-preferred) are blue.
ax.set_title('Expert Usage Difference (Red=HumanEval prefers, Blue=GSM8K prefers)', fontsize=13)
plt.tight_layout()
plt.show()

# Find domain-specific experts: list every (layer, expert) cell whose usage
# differs by more than 5 percentage points between the two domains.
print("\n Domain-Specific Expert Preferences (>5% difference):")
print("\nHumanEval-preferred (Code):")
for layer in range(32):
    for expert in range(8):
        if diff_matrix[layer, expert] > 5:
            print(f"   Layer {layer:2d}, Expert {expert}: +{diff_matrix[layer, expert]:.1f}%")

print("\nGSM8K-preferred (Math):")
for layer in range(32):
    for expert in range(8):
        if diff_matrix[layer, expert] < -5:
            print(f"   Layer {layer:2d}, Expert {expert}: {diff_matrix[layer, expert]:.1f}%")

In [None]:
# ============================================
# 4.3 Aggregate Domain Statistics
# ============================================

def compute_domain_stats(journeys: List[List[dict]], name: str) -> dict:
    """Summarise top-1 routing behaviour for one domain.

    Args:
        journeys: Token journeys; records must provide ``'experts'`` and
            ``'gating_probs'``.
        name: Label stored under the ``'Domain'`` key.

    Returns:
        A dict with the domain label, the number of journeys, the mean top-1
        gating probability over all records, the mean number of distinct
        top-1 experts per journey, and the percentage of records whose top-1
        gating probability exceeds 0.8.
    """
    confidences = []
    distinct_per_token = []

    for path in journeys:
        chosen = {step['experts'][0] for step in path}
        confidences += [step['gating_probs'][0] for step in path]
        distinct_per_token.append(len(chosen))

    high_conf_pct = np.mean(np.array(confidences) > 0.8) * 100
    summary = {}
    summary['Domain'] = name
    summary['Tokens'] = len(journeys)
    summary['Avg Confidence'] = np.mean(confidences)
    summary['Avg Unique Experts'] = np.mean(distinct_per_token)
    summary['High Conf (>0.8)'] = high_conf_pct
    return summary

# One summary dict per domain, computed from the loaded samples.
stats = [
    compute_domain_stats(humaneval_journeys, 'HumanEval'),
    compute_domain_stats(gsm8k_journeys, 'GSM8K')
]

# Build the table, then replace the numeric columns with display strings
# (after this point those columns hold text, not numbers).
stats_df = pd.DataFrame(stats)
stats_df['Avg Confidence'] = stats_df['Avg Confidence'].apply(lambda x: f"{x:.3f}")
stats_df['Avg Unique Experts'] = stats_df['Avg Unique Experts'].apply(lambda x: f"{x:.1f}/8")
stats_df['High Conf (>0.8)'] = stats_df['High Conf (>0.8)'].apply(lambda x: f"{x:.1f}%")

print("\n" + "═" * 70)
print("                     DOMAIN COMPARISON SUMMARY")
print("═" * 70)
print(stats_df.to_string(index=False))
print("═" * 70)

---
##  Summary

This notebook demonstrated:

1. **Routing Analysis Dashboard**: Visualized individual token journeys and aggregate expert usage patterns across 32 MoE layers

2. **Statistical Visualizations**: 
   - Inter-layer affinity heatmaps showing transition predictability
   - Prediction accuracy decay curves from 1-hop to 31-hop
   - Path frequency analysis revealing routing diversity

3. **Statistical Summary Tables**:
   - Model comparison: the optimized MLP achieves 55.60% Top-1 and 73.24% Top-2 accuracy — a +345% relative Top-1 improvement over the 12.5% random baseline (≈4.4×)
   - Conditional probability matrices for expert transitions
   - Cross-layer correlation coefficients

4. **Domain Comparison**: HumanEval vs GSM8K routing patterns show task-specific expert preferences

---
*Data: 44,283 token journeys from Mixtral-8x7B-Instruct-v0.1-FP8 (the analysis cells above load at most `SAMPLE_SIZE` = 5,000 journeys per benchmark).*