# OLMoE Routing Experiments - Complete End-to-End Notebook

**Complete workflow from installation to results analysis**

This notebook runs on:
- ✅ Google Colab (GPU recommended)
- ✅ Local Jupyter (GPU or CPU)
- ✅ Kaggle, Paperspace, etc.

---

## Table of Contents

1. [Environment Setup](#1-environment-setup)
2. [GPU Configuration](#2-gpu-configuration)
3. [Installation](#3-installation)
4. [Custom Expert Selection & Model Patching](#4-custom-expert-selection--model-patching)
5. [Framework Setup](#5-framework-setup)
6. [Running Tests](#6-running-tests)
7. [Quick Experiment](#7-quick-experiment)
8. [Full Experiments](#8-full-experiments)
9. [Results Analysis](#9-results-analysis)
10. [Visualization](#10-visualization)

---

## 1. Environment Setup

Detect the runtime environment (Colab, Kaggle, or local) and set the working directory.

In [3]:
import sys
import os

# Detect environment
IN_COLAB = 'google.colab' in sys.modules
IN_KAGGLE = 'KAGGLE_KERNEL_RUN_TYPE' in os.environ

print(f"Running in Google Colab: {IN_COLAB}")
print(f"Running in Kaggle: {IN_KAGGLE}")
print(f"Python version: {sys.version}")
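# Set working directory.
# NOTE: this absolute path is specific to the author's machine. On Colab, Kaggle,
# or another machine, point WORK_DIR at a directory that exists there
# (the commented-out block below shows the intended per-environment choice).
WORK_DIR = '/Users/aliab/Desktop/GitHub/MOE-with-feature-selection/olmoe_experiments'
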
# if IN_COLAB:
#     # Mount Google Drive (optional - for saving results)
#     from google.colab import drive
#     drive.mount('/content/drive')
#     WORK_DIR = '/content/olmoe_experiments'
# else:
#     WORK_DIR = './olmoe_experiments'

os.makedirs(WORK_DIR, exist_ok=True)
os.chdir(WORK_DIR)
print(f"\nWorking directory: {os.getcwd()}")

Running in Google Colab: True
Running in Kaggle: False
Python version: 3.12.12 (main, Oct 10 2025, 08:52:57) [GCC 11.4.0]

Working directory: /Users/aliab/Desktop/GitHub/MOE-with-feature-selection/olmoe_experiments
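
The recorded output above shows a Colab session using a path that only exists on the author's machine. As a minimal sketch of choosing the directory per environment instead, assuming Colab's scratch space is `/content` and Kaggle's writable area is `/kaggle/working`, something like the following could be used. `resolve_work_dir` and its parameters are hypothetical names, not part of the framework.

```python
import os
import sys


def resolve_work_dir(base_name="olmoe_experiments", environ=None, modules=None):
    """Pick a working directory based on the detected environment.

    Hypothetical helper: `environ` and `modules` default to os.environ and
    sys.modules and exist only so the logic can be exercised in isolation.
    """
    environ = os.environ if environ is None else environ
    modules = sys.modules if modules is None else modules

    if "google.colab" in modules:
        # Assumption: Colab sessions keep scratch files under /content.
        return os.path.join("/content", base_name)
    if "KAGGLE_KERNEL_RUN_TYPE" in environ:
        # Assumption: Kaggle kernels can write under /kaggle/working.
        return os.path.join("/kaggle/working", base_name)
    # Local Jupyter: stay relative to wherever the notebook was started.
    return os.path.join(".", base_name)


print(resolve_work_dir())
```

The result can then be passed to `os.makedirs(..., exist_ok=True)` and `os.chdir(...)` exactly as in the cell above.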


## 2. GPU Configuration

Check GPU availability and select the device used by the experiments

In [4]:
import torch

# Check CUDA availability
print("=" * 70)
print("GPU CONFIGURATION")
print("=" * 70)

cuda_available = torch.cuda.is_available()
print(f"\nCUDA Available: {cuda_available}")

if cuda_available:
    print(f"CUDA Version: {torch.version.cuda}")
    print(f"Number of GPUs: {torch.cuda.device_count()}")
    
    for i in range(torch.cuda.device_count()):
        print(f"\nGPU {i}: {torch.cuda.get_device_name(i)}")
        print(f"  Memory: {torch.cuda.get_device_properties(i).total_memory / 1e9:.2f} GB")
    
    # Set device
    device = 'cuda'
    
    # Clear cache
    torch.cuda.empty_cache()
    print("\n✅ GPU is ready!")
else:
    device = 'cpu'
    print("\n⚠️  GPU not available. Using CPU (will be slower).")
    if IN_COLAB:
        print("\n💡 TIP: Enable GPU in Colab:")
        print("   Runtime → Change runtime type → Hardware accelerator → GPU")

print(f"\nDevice for experiments: {device}")
print("=" * 70)

GPU CONFIGURATION

CUDA Available: True
CUDA Version: 12.6
Number of GPUs: 1

GPU 0: NVIDIA A100-SXM4-40GB
  Memory: 42.47 GB

✅ GPU is ready!

Device for experiments: cuda


## 3. Installation

Install all required packages

In [5]:
%%bash
# Install dependencies
pip install -q torch transformers datasets pandas numpy matplotlib seaborn tqdm rich 
echo "✅ All packages installed!"

✅ All packages installed!


In [6]:
# Verify installations
import torch
import transformers
import datasets
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

print("Package Versions:")
print(f"  torch: {torch.__version__}")
print(f"  transformers: {transformers.__version__}")
print(f"  datasets: {datasets.__version__}")
print(f"  pandas: {pd.__version__}")
print(f"  numpy: {np.__version__}")
print("\n✅ All imports successful!")

Package Versions:
  torch: 2.8.0+cu126
  transformers: 4.57.1
  datasets: 4.0.0
  pandas: 2.2.2
  numpy: 2.0.2

✅ All imports successful!


## 4. Custom Expert Selection & Model Patching

**NEW: Support for custom forward pass with internal logging**

This section implements:
1. Custom expert selection (uniform weights)
2. Model patching to return router_logits
3. Internal logging of routing decisions

In [7]:
import torch
import torch.nn.functional as F
from typing import Tuple

def custom_select_experts(
    router_logits: torch.Tensor,
    top_k: int,
    num_experts: int
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Custom expert selection with uniform weights.
    
    This is equivalent to UniformRouting but integrated directly into the model.
    
    Args:
        router_logits: [tokens, num_experts] - Raw routing scores
        top_k: Number of experts to select
        num_experts: Total number of experts
    
    Returns:
        routing_weights: [tokens, top_k] - Uniform weights (1/top_k)
        selected_experts: [tokens, top_k] - Selected expert indices
    """
    # Convert logits to probabilities
    routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
    
    # Select top-k experts based on probabilities
    _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
    
    # KEY: Give each selected expert EQUAL probability (uniform routing)
    routing_weights = torch.ones_like(selected_experts, dtype=torch.float)
    routing_weights /= top_k
    
    return routing_weights.to(router_logits.dtype), selected_experts


def create_custom_forward(original_forward, top_k, num_experts):
    """
    Create a custom forward pass that:
    1. Uses custom_select_experts for routing
    2. Returns router_logits for analysis
    3. Enables internal logging of routing decisions
    """
    def new_forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        
        # Get router logits: (batch * sequence_length, n_experts)
        router_logits = self.gate(hidden_states)

        # Use custom expert selection
        routing_weights, selected_experts = custom_select_experts(
            router_logits,
            top_k=top_k,
            num_experts=num_experts
        )

        # Initialize output
        final_hidden_states = torch.zeros(
            (batch_size * sequence_length, hidden_dim), 
            dtype=hidden_states.dtype, 
            device=hidden_states.device
        )

        # Create expert mask for efficient indexing
        # expert_mask: [num_experts, top_k, tokens]
        expert_mask = torch.nn.functional.one_hot(
            selected_experts, 
            num_classes=num_experts
        ).permute(2, 1, 0)

        # Process each expert
        for expert_idx in range(num_experts):
            expert_layer = self.experts[expert_idx]
            
            # Get tokens assigned to this expert
            idx, top_x = torch.where(expert_mask[expert_idx])

            if top_x.numel() == 0:
                continue  # No tokens for this expert

            # Compute expert output with routing weights
            current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
            current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]

            # Accumulate results
            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
        
        # Reshape output
        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
        
        # Return both output and router_logits for analysis
        return final_hidden_states, router_logits
    
    return new_forward


def patch_model_with_custom_routing(model, top_k=None):
    """
    Patch OLMoE model to use custom routing with internal logging.
    
    Args:
        model: OLMoE model instance
        top_k: Number of experts to use (None = use model default)
    """
    if top_k is None:
        top_k = getattr(model.config, 'num_experts_per_tok', 8)
    
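    # OlmoeConfig stores the expert count as `num_experts`;
    # `num_local_experts` is the Mixtral-style name, kept here as a fallback.
    num_experts = getattr(
        model.config, 'num_experts',
        getattr(model.config, 'num_local_experts', 64)
    )
    
    print("Patching model with custom routing:")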
    print(f"  top_k: {top_k}")
    print(f"  num_experts: {num_experts}")
    
    patched_layers = 0
    
    # Patch all MoE layers
    for layer_idx, layer in enumerate(model.model.layers):
        if hasattr(layer, 'mlp') and hasattr(layer.mlp, 'experts'):
            # Save original forward
            original_forward = layer.mlp.forward
            
            # Create and apply custom forward
            layer.mlp.forward = create_custom_forward(
                original_forward, 
                top_k, 
                num_experts
            ).__get__(layer.mlp, layer.mlp.__class__)
            
            patched_layers += 1
    
    print(f"✅ Patched {patched_layers} MoE layers")
    return model


print("✅ Custom expert selection functions defined!")
print("\nFunctions available:")
print("  - custom_select_experts()")
print("  - create_custom_forward()")
print("  - patch_model_with_custom_routing()")

✅ Custom expert selection functions defined!

Functions available:
  - custom_select_experts()
  - create_custom_forward()
  - patch_model_with_custom_routing()
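
`patch_model_with_custom_routing` replaces `forward` on each MoE block instance by binding a plain function with `__get__`. The toy sketch below shows that binding mechanism on its own, without torch. `ToyBlock` and `make_logging_forward` are hypothetical stand-ins, not OLMoE classes.

```python
class ToyBlock:
    """Stand-in for an MoE block; not the real OLMoE class."""

    def __init__(self, scale):
        self.scale = scale

    def forward(self, x):
        return x * self.scale


def make_logging_forward(log):
    """Build a replacement forward that records its input before computing."""

    def new_forward(self, x):
        log.append(x)          # side channel for later analysis
        return x * self.scale  # same computation as the original

    return new_forward


calls = []
patched = ToyBlock(scale=3)
untouched = ToyBlock(scale=3)

# __get__ turns the plain function into a method bound to `patched` only.
patched.forward = make_logging_forward(calls).__get__(patched, ToyBlock)

print(patched.forward(2), untouched.forward(2), calls)
```

Because the replacement is stored on the instance, other instances of the same class keep their original behaviour. This is why the patching loop above has to visit every layer.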


## 5. Framework Setup

Load the routing experiments framework with custom patching support

In [25]:
# Framework Setup - Simple direct path addition
import os
import sys

print("=" * 70)
print("FRAMEWORK SETUP")
print("=" * 70)

# Direct path to the framework directory
framework_dir = '/Users/aliab/Desktop/GitHub/MOE-with-feature-selection'

# Add to Python path
if framework_dir not in sys.path:
    sys.path.insert(0, framework_dir)
    print(f"✅ Added to Python path: {framework_dir}")
else:
    print(f"✅ Already in Python path: {framework_dir}")

print(f"\nPython path (first 3 entries):")
for i, p in enumerate(sys.path[:3], 1):
    print(f"  {i}. {p}")

print("\n" + "=" * 70)
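# NOTE: `!cd` runs in a temporary subshell, so it does not change the notebook's
# working directory; use `%cd` or os.chdir() if a directory change is intended.
!cd '/Users/aliab/Desktop/GitHub/MOE-with-feature-selection'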

FRAMEWORK SETUP
✅ Already in Python path: /Users/aliab/Desktop/GitHub/MOE-with-feature-selection

Python path (first 3 entries):
  1. /Users/aliab/Desktop/GitHub/MOE-with-feature-selection
  2. /content
  3. /env/python



In [26]:
# Import with cache clearing
import sys
import importlib

print("📦 Importing framework modules...")
print(f"Looking in: {sys.path[0]}\n")

# Clear any cached imports
if 'olmoe_routing_experiments' in sys.modules:
    print("🔄 Clearing cached module...")
    del sys.modules['olmoe_routing_experiments']

# Try importing with detailed error handling
try:
    # Method 1: Direct import
    import olmoe_routing_experiments
    print("✅ Step 1: Module imported")
    
    # Import specific classes
    from olmoe_routing_experiments import (
        RoutingConfig,
        ExperimentResults,
        RoutingStrategy,
        RegularRouting,
        NormalizedRouting,
        UniformRouting,
        AdaptiveRouting,
        RoutingExperimentRunner,
        ModelPatchingUtils
    )
    
    print("✅ Step 2: All classes imported successfully!")
    print("\n📚 Available Components:")
    print("  • RoutingExperimentRunner")
    print("  • ModelPatchingUtils")
    print("  • RegularRouting, NormalizedRouting")
    print("  • UniformRouting, AdaptiveRouting")
    
except ModuleNotFoundError as e:
    print(f"❌ Module not found: {e}")
    print("\n🔧 DEBUGGING:")
    
    # Check if file exists
    import os
    file_path = os.path.join(sys.path[0], 'olmoe_routing_experiments.py')
    print(f"\n1. File exists? {os.path.exists(file_path)}")
    print(f"   Path: {file_path}")
    
    # Check file permissions
    if os.path.exists(file_path):
        import stat
        st = os.stat(file_path)
        print(f"2. File readable? {bool(st.st_mode & stat.S_IRUSR)}")
        print(f"3. File size: {st.st_size} bytes")
    
    # Check Python version
    print(f"4. Python version: {sys.version}")
    
    print("\n💡 ALTERNATIVE FIX:")
    print("   Run this in a new cell:")
    print("   !cp /Users/aliab/Desktop/GitHub/MOE-with-feature-selection/olmoe_routing_experiments.py .")
    print("   import olmoe_routing_experiments")
    
    raise

except Exception as e:
    print(f"❌ Unexpected error: {e}")
    import traceback
    traceback.print_exc()
    raise

📦 Importing framework modules...
Looking in: /Users/aliab/Desktop/GitHub/MOE-with-feature-selection

❌ Module not found: No module named 'olmoe_routing_experiments'

🔧 DEBUGGING:

1. File exists? False
   Path: /Users/aliab/Desktop/GitHub/MOE-with-feature-selection/olmoe_routing_experiments.py
4. Python version: 3.12.12 (main, Oct 10 2025, 08:52:57) [GCC 11.4.0]

💡 ALTERNATIVE FIX:
   Run this in a new cell:
   !cp /Users/aliab/Desktop/GitHub/MOE-with-feature-selection/olmoe_routing_experiments.py .
   import olmoe_routing_experiments


ModuleNotFoundError: No module named 'olmoe_routing_experiments'

In [27]:
!ls -lt
!cp /Users/aliab/Desktop/GitHub/MOE-with-feature-selection/olmoe_routing_experiments.py .
import olmoe_routing_experiments

total 0
cp: cannot stat '/Users/aliab/Desktop/GitHub/MOE-with-feature-selection/olmoe_routing_experiments.py': No such file or directory


ModuleNotFoundError: No module named 'olmoe_routing_experiments'
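
Both attempts above fail because `olmoe_routing_experiments.py` is not present at the hard-coded path in this session. A minimal sketch of a more defensive import follows: it searches a list of candidate directories and loads the first match with `importlib.util`. The helper name and the candidate list are assumptions, and the file still has to be cloned or uploaded into the runtime first.

```python
import importlib.util
import sys
from pathlib import Path


def import_from_candidates(module_name, candidate_dirs):
    """Import `module_name` from the first directory containing `<module_name>.py`.

    Hypothetical helper; returns None when no candidate has the file.
    """
    for directory in candidate_dirs:
        path = Path(directory) / f"{module_name}.py"
        if not path.is_file():
            continue
        spec = importlib.util.spec_from_file_location(module_name, path)
        module = importlib.util.module_from_spec(spec)
        sys.modules[module_name] = module  # register before executing, as `import` does
        spec.loader.exec_module(module)
        return module
    return None


# Candidate locations are assumptions; adjust them to where the repository lives.
framework = import_from_candidates(
    "olmoe_routing_experiments",
    [Path.cwd(), Path.cwd().parent],
)
if framework is None:
    print("olmoe_routing_experiments.py not found; clone or upload the repository first.")
```

Returning `None` instead of raising lets the notebook print a targeted hint before the cells that depend on the framework are run.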

### Extended Framework with Custom Patching Support

In [None]:
class ExtendedRoutingExperimentRunner(RoutingExperimentRunner):
    """
    Extended runner with support for custom model patching.
    """
    
    def __init__(self, *args, use_custom_routing=False, **kwargs):
        super().__init__(*args, **kwargs)
        self.use_custom_routing = use_custom_routing
        
        if use_custom_routing:
            print("\n🔧 CUSTOM ROUTING MODE ENABLED")
            print("   Model will be patched with custom_select_experts")
    
    def _set_expert_count(self, num_experts):
        """Override to support custom routing."""
        super()._set_expert_count(num_experts)
        
        if self.use_custom_routing:
            # Re-patch model with new expert count
            patch_model_with_custom_routing(self.model, top_k=num_experts)

print("✅ Extended framework with custom patching support ready!")

## 6. Running Tests

Validate the framework is working correctly

In [None]:
print("=" * 70)
print("RUNNING VALIDATION TESTS")
print("=" * 70)

# Test 1: Routing Strategies
print("\n[1/5] Testing routing strategies...")
torch.manual_seed(42)
logits = torch.randn(1, 1, 64)

regular = RegularRouting(num_experts=8)
normalized = NormalizedRouting(num_experts=8)
uniform = UniformRouting(num_experts=8)

reg_indices, reg_weights = regular.route(logits)
norm_indices, norm_weights = normalized.route(logits)
uni_indices, uni_weights = uniform.route(logits)

# Verify uniform has equal weights
assert torch.allclose(uni_weights, torch.ones_like(uni_weights) / 8, atol=1e-6)
print("   ✅ Routing strategies work correctly")

# Test 2: Custom Expert Selection
print("\n[2/5] Testing custom expert selection...")
router_logits = torch.randn(10, 64)  # 10 tokens, 64 experts
weights, indices = custom_select_experts(router_logits, top_k=8, num_experts=64)

assert weights.shape == (10, 8)
assert indices.shape == (10, 8)
assert torch.allclose(weights, torch.ones_like(weights) / 8, atol=1e-6)
print("   ✅ Custom expert selection works correctly")

# Test 3: Different expert counts produce different routing
print("\n[3/5] Testing expert count variation...")
weights_4, indices_4 = custom_select_experts(router_logits, top_k=4, num_experts=64)
weights_16, indices_16 = custom_select_experts(router_logits, top_k=16, num_experts=64)

assert weights_4.shape[-1] == 4
assert weights_16.shape[-1] == 16
print("   ✅ Different expert counts work correctly")

# Test 4: Verify uniform weights
print("\n[4/5] Testing uniform weight distribution...")
for k in [4, 8, 16]:
    w, _ = custom_select_experts(router_logits, top_k=k, num_experts=64)
    expected = torch.ones_like(w) / k
    assert torch.allclose(w, expected, atol=1e-6)
print("   ✅ Uniform weights verified for all k values")

# Test 5: Statistics tracking
print("\n[5/5] Testing statistics tracking...")
strategy = UniformRouting(num_experts=8)
for _ in range(5):
    test_logits = torch.randn(2, 5, 64)
    strategy.route(test_logits)

stats = strategy.get_summary_stats()
assert 'avg_entropy' in stats
assert 'avg_concentration' in stats
assert 'unique_experts' in stats
print("   ✅ Statistics tracking works correctly")

print("\n" + "=" * 70)
print("✅ ALL TESTS PASSED!")
print("=" * 70)
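
Test 5 only checks that an `avg_entropy` statistic is reported. Assuming that statistic is the Shannon entropy of each token's routing weights (an assumption; the framework's definition is not shown here), uniform weights over `k` experts should give `ln(k)`. The sketch below checks this in pure Python. `routing_entropy` is a hypothetical helper.

```python
import math


def routing_entropy(weights):
    """Shannon entropy (in nats) of one token's routing weights."""
    return -sum(w * math.log(w) for w in weights if w > 0)


for k in (4, 8, 16):
    uniform = [1.0 / k] * k
    # The two numbers printed per row should agree: entropy of uniform weights vs ln(k).
    print(k, round(routing_entropy(uniform), 4), round(math.log(k), 4))

# A peaked distribution over the same 4 experts has lower entropy than the uniform one.
print(round(routing_entropy([0.85, 0.05, 0.05, 0.05]), 4))
```

Under this definition, the entropy of uniform routing depends only on `k`, so any variation in the reported value across inputs would point to a different definition in the framework.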

## 7. Quick Experiment

Run a minimal experiment to verify everything works (~5 minutes)

**This experiment will:**
- Test 2 expert counts (8, 16)
- Test 2 strategies (regular, custom uniform)
- Evaluate on 50 samples
- Generate visualizations

In [None]:
print("=" * 70)
print("QUICK EXPERIMENT (Standard Routing)")
print("=" * 70)
print("\nConfiguration:")
print("  Expert counts: [8, 16]")
print("  Strategies: [regular, uniform]")
print("  Dataset: WikiText-2")
print("  Samples: 50")
print("  Estimated time: ~5 minutes")
print("\n" + "=" * 70)

# Create runner (standard mode)
runner_standard = RoutingExperimentRunner(
    model_name="allenai/OLMoE-1B-7B-0924",
    device=device,
    output_dir="./quick_experiment_standard"
)

# Run experiments
results_df_standard = runner_standard.run_all_experiments(
    expert_counts=[8, 16],
    strategies=['regular', 'uniform'],
    datasets=['wikitext'],
    max_samples=50
)

print("\n✅ Standard routing experiment complete!")

In [None]:
# Display results
print("\n📊 QUICK EXPERIMENT RESULTS (Standard Routing)\n")
print(results_df_standard[[
    'config', 'perplexity', 'token_accuracy', 
    'tokens_per_second', 'avg_entropy'
]].to_string(index=False))

# Best configuration
best_idx = results_df_standard['perplexity'].idxmin()
best = results_df_standard.loc[best_idx]

print("\n🏆 BEST CONFIGURATION:")
print(f"   Config: {best['config']}")
print(f"   Perplexity: {best['perplexity']:.2f}")
print(f"   Accuracy: {best['token_accuracy']:.4f}")
print(f"   Speed: {best['tokens_per_second']:.1f} tok/s")
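
The best configuration is chosen by lowest perplexity. For reference, perplexity is conventionally the exponential of the mean negative log-likelihood per token. The sketch below computes it from a list of natural-log token probabilities. How the framework aggregates tokens across samples is not shown in this notebook, so treat this as the textbook definition rather than the runner's exact code.

```python
import math


def perplexity(token_log_probs):
    """Perplexity = exp of the mean negative log-likelihood (natural log)."""
    if not token_log_probs:
        raise ValueError("need at least one token")
    mean_nll = -sum(token_log_probs) / len(token_log_probs)
    return math.exp(mean_nll)


# Hypothetical log-probabilities a model assigned to four target tokens.
confident = [math.log(0.5)] * 4
unsure = [math.log(0.1)] * 4
print(perplexity(confident), perplexity(unsure))
```

A model that assigns probability `p` to every target token has perplexity `1/p`, so lower values mean the model finds the text less surprising.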

### Quick Experiment with Custom Routing (Internal Logging)

**This uses the custom forward pass with internal router_logits logging**

In [None]:
print("=" * 70)
print("QUICK EXPERIMENT (Custom Routing with Internal Logging)")
print("=" * 70)
print("\nConfiguration:")
print("  Expert counts: [8, 16]")
print("  Routing: Custom uniform (patched forward pass)")
print("  Dataset: WikiText-2")
print("  Samples: 50")
print("  Features: Internal router_logits logging")
print("\n" + "=" * 70)

# Create runner with custom routing enabled
runner_custom = ExtendedRoutingExperimentRunner(
    model_name="allenai/OLMoE-1B-7B-0924",
    device=device,
    output_dir="./quick_experiment_custom",
    use_custom_routing=True
)

# Patch the model
patch_model_with_custom_routing(runner_custom.model, top_k=8)

# Run experiments (only uniform strategy makes sense with custom routing)
results_df_custom = runner_custom.run_all_experiments(
    expert_counts=[8, 16],
    strategies=['uniform'],  # Custom routing is uniform
    datasets=['wikitext'],
    max_samples=50
)

print("\n✅ Custom routing experiment complete!")

In [None]:
# Display results
print("\n📊 QUICK EXPERIMENT RESULTS (Custom Routing)\n")
print(results_df_custom[[
    'config', 'perplexity', 'token_accuracy', 
    'tokens_per_second', 'avg_entropy'
]].to_string(index=False))

# Compare with standard
print("\n📈 COMPARISON: Custom vs Standard Uniform Routing\n")

comparison_data = []
for expert_count in [8, 16]:
    std_row = results_df_standard[
        (results_df_standard['num_experts'] == expert_count) & 
        (results_df_standard['strategy'] == 'uniform')
    ].iloc[0]
    
    custom_row = results_df_custom[
        results_df_custom['num_experts'] == expert_count
    ].iloc[0]
    
    comparison_data.append({
        'Expert Count': expert_count,
        'Standard PPL': f"{std_row['perplexity']:.2f}",
        'Custom PPL': f"{custom_row['perplexity']:.2f}",
        'Difference': f"{custom_row['perplexity'] - std_row['perplexity']:.2f}"
    })

comparison_df = pd.DataFrame(comparison_data)
print(comparison_df.to_string(index=False))
print("\n💡 Note: Small differences are expected due to implementation details.")

## 8. Full Experiments

Run comprehensive experiments with all configurations (~30-60 minutes)

**⚠️ WARNING: This will take significant time and GPU resources!**

In [None]:
# Full experiments are off by default; set the flag below to True to run them
RUN_FULL_EXPERIMENTS = False  # Set to True to run

if RUN_FULL_EXPERIMENTS:
    print("=" * 70)
    print("FULL EXPERIMENTS")
    print("=" * 70)
    print("\nConfiguration:")
    print("  Expert counts: [4, 8, 16, 32, 64]")
    print("  Strategies: [regular, normalized, uniform]")
    print("  Datasets: [wikitext, lambada]")
    print("  Samples: 500 per dataset")
    print("  Total experiments: 30")
    print("  Estimated time: ~60 minutes")
    print("\n" + "=" * 70)
    
    # Create runner
    runner_full = RoutingExperimentRunner(
        model_name="allenai/OLMoE-1B-7B-0924",
        device=device,
        output_dir="./full_experiments"
    )
    
    # Run experiments
    results_df_full = runner_full.run_all_experiments(
        expert_counts=[4, 8, 16, 32, 64],
        strategies=['regular', 'normalized', 'uniform'],
        datasets=['wikitext', 'lambada'],
        max_samples=500
    )
    
    # Generate visualizations
    runner_full.visualize_results(results_df_full)
    
    # Generate report
    runner_full.generate_report(results_df_full)
    
    print("\n✅ Full experiments complete!")
    print(f"   Results: {runner_full.output_dir}")
else:
    print("⏭️  Full experiments skipped (set RUN_FULL_EXPERIMENTS = True to run)")
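
The "Total experiments: 30" figure printed above is the size of the full configuration grid. A short sketch that enumerates the same grid, assuming the runner evaluates every combination of expert count, strategy, and dataset:

```python
from itertools import product

expert_counts = [4, 8, 16, 32, 64]
strategies = ["regular", "normalized", "uniform"]
datasets = ["wikitext", "lambada"]

# One entry per (expert count, strategy, dataset) combination.
grid = list(product(expert_counts, strategies, datasets))
print(len(grid))   # 5 * 3 * 2 combinations
print(grid[:3])
```

Listing the grid up front makes it easy to estimate the runtime or to run only a slice of it before starting the full sweep.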

## 9. Results Analysis

Analyze experiment results with various methods

In [None]:
from analyze_results import ResultAnalyzer

# Use quick experiment results for analysis
analyzer = ResultAnalyzer("./quick_experiment_standard")

# Print summary
analyzer.print_summary()

In [None]:
# Compare strategies
comparison = analyzer.compare_strategies()
comparison

In [None]:
# Find optimal configuration
optimal = analyzer.find_optimal_config(
    quality_weight=0.7,  # 70% weight on quality
    speed_weight=0.3     # 30% weight on speed
)
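
The cell above weights quality at 0.7 and speed at 0.3, but the scoring rule inside `find_optimal_config` is not shown in this notebook. One plausible rule, offered only as an assumption, is a weighted sum of min-max normalized perplexity (lower is better) and throughput (higher is better). `score_configs` and the example rows below are hypothetical, and the numbers are made up to exercise the function.

```python
def score_configs(rows, quality_weight=0.7, speed_weight=0.3):
    """Rank configs by a weighted sum of min-max normalized quality and speed.

    Hypothetical scoring rule; `rows` are dicts with 'config', 'perplexity',
    and 'tokens_per_second' keys.
    """
    def normalize(values, higher_is_better):
        lo, hi = min(values), max(values)
        if hi == lo:
            return [1.0] * len(values)
        scaled = [(v - lo) / (hi - lo) for v in values]
        return scaled if higher_is_better else [1.0 - s for s in scaled]

    quality = normalize([r["perplexity"] for r in rows], higher_is_better=False)
    speed = normalize([r["tokens_per_second"] for r in rows], higher_is_better=True)
    scored = [
        (quality_weight * q + speed_weight * s, r["config"])
        for q, s, r in zip(quality, speed, rows)
    ]
    return sorted(scored, reverse=True)


# Made-up numbers purely to exercise the function; not experiment results.
example = [
    {"config": "a", "perplexity": 10.0, "tokens_per_second": 100.0},
    {"config": "b", "perplexity": 12.0, "tokens_per_second": 200.0},
]
print(score_configs(example))
```

With these weights the lower-perplexity configuration ranks first; reversing the weights favours the faster one, which shows how sensitive the "optimal" choice is to the weighting.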

In [None]:
# Analyze specific strategy
analyzer.analyze_strategy('uniform')

## 10. Visualization

Create custom visualizations from results

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Set style
sns.set_style("whitegrid")
plt.rcParams['figure.figsize'] = (14, 10)

# Load results
df = results_df_standard

# Create custom visualizations
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
fig.suptitle('OLMoE Routing Experiments Results', fontsize=16, fontweight='bold')

# 1. Perplexity by strategy
ax1 = axes[0, 0]
for strategy in df['strategy'].unique():
    strategy_df = df[df['strategy'] == strategy]
    ax1.plot(
        strategy_df['num_experts'], 
        strategy_df['perplexity'], 
        marker='o', 
        label=strategy,
        linewidth=2
    )
ax1.set_xlabel('Number of Experts')
ax1.set_ylabel('Perplexity (↓ better)')
ax1.set_title('Perplexity vs Expert Count')
ax1.legend()
ax1.grid(True, alpha=0.3)

# 2. Token accuracy by strategy
ax2 = axes[0, 1]
for strategy in df['strategy'].unique():
    strategy_df = df[df['strategy'] == strategy]
    ax2.plot(
        strategy_df['num_experts'], 
        strategy_df['token_accuracy'], 
        marker='s', 
        label=strategy,
        linewidth=2
    )
ax2.set_xlabel('Number of Experts')
ax2.set_ylabel('Token Accuracy (↑ better)')
ax2.set_title('Token Accuracy vs Expert Count')
ax2.legend()
ax2.grid(True, alpha=0.3)

# 3. Speed-quality trade-off
ax3 = axes[1, 0]
for strategy in df['strategy'].unique():
    strategy_df = df[df['strategy'] == strategy]
    ax3.scatter(
        strategy_df['perplexity'],
        strategy_df['tokens_per_second'],
        label=strategy,
        s=100,
        alpha=0.7
    )
ax3.set_xlabel('Perplexity (↓ better)')
ax3.set_ylabel('Tokens/Second (↑ better)')
ax3.set_title('Speed vs Quality Trade-off')
ax3.legend()
ax3.grid(True, alpha=0.3)

# 4. Routing entropy by strategy
ax4 = axes[1, 1]
for strategy in df['strategy'].unique():
    strategy_df = df[df['strategy'] == strategy]
    ax4.plot(
        strategy_df['num_experts'], 
        strategy_df['avg_entropy'], 
        marker='^', 
        label=strategy,
        linewidth=2
    )
ax4.set_xlabel('Number of Experts')
ax4.set_ylabel('Average Entropy')
ax4.set_title('Routing Entropy vs Expert Count')
ax4.legend()
ax4.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('custom_analysis.png', dpi=300, bbox_inches='tight')
plt.show()

print("✅ Visualizations created!")

### Compare Standard vs Custom Routing

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
fig.suptitle('Standard vs Custom Routing Comparison', fontsize=14, fontweight='bold')

# Get uniform routing data from both
std_uniform = results_df_standard[results_df_standard['strategy'] == 'uniform']
custom_uniform = results_df_custom

# 1. Perplexity comparison
ax1 = axes[0]
x = range(len(std_uniform))
width = 0.35
ax1.bar([i - width/2 for i in x], std_uniform['perplexity'], width, label='Standard', alpha=0.8)
ax1.bar([i + width/2 for i in x], custom_uniform['perplexity'], width, label='Custom (Patched)', alpha=0.8)
ax1.set_xlabel('Configuration')
ax1.set_ylabel('Perplexity')
ax1.set_title('Perplexity: Standard vs Custom')
ax1.set_xticks(x)
ax1.set_xticklabels([f"{row['num_experts']} exp" for _, row in std_uniform.iterrows()])
ax1.legend()
ax1.grid(True, alpha=0.3)

# 2. Entropy comparison
ax2 = axes[1]
ax2.bar([i - width/2 for i in x], std_uniform['avg_entropy'], width, label='Standard', alpha=0.8)
ax2.bar([i + width/2 for i in x], custom_uniform['avg_entropy'], width, label='Custom (Patched)', alpha=0.8)
ax2.set_xlabel('Configuration')
ax2.set_ylabel('Average Entropy')
ax2.set_title('Routing Entropy: Standard vs Custom')
ax2.set_xticks(x)
ax2.set_xticklabels([f"{row['num_experts']} exp" for _, row in std_uniform.iterrows()])
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('standard_vs_custom.png', dpi=300, bbox_inches='tight')
plt.show()

print("✅ Comparison visualization created!")

## Summary & Next Steps

### What We Accomplished

- ✅ **Environment Setup**: Configured GPU/CPU for optimal performance
- ✅ **Installation**: Installed all required packages
- ✅ **Custom Routing**: Implemented custom expert selection with internal logging
- ✅ **Model Patching**: Added support for patching MoE forward pass
- ✅ **Testing**: Validated all components work correctly
- ✅ **Quick Experiments**: Ran both standard and custom routing experiments
- ✅ **Analysis**: Analyzed results with multiple methods
- ✅ **Visualization**: Created comprehensive visualizations

### Key Findings

View the results above to understand:
1. How different expert counts affect quality (perplexity)
2. How routing strategies compare
3. Speed vs quality trade-offs
4. Routing entropy patterns
5. Differences between standard and custom routing

### Next Steps

1. **Run Full Experiments**: Set `RUN_FULL_EXPERIMENTS = True` for comprehensive analysis
2. **Custom Strategies**: Implement your own routing strategies
3. **More Datasets**: Test on additional datasets (PIQA, etc.)
4. **Deep Analysis**: Use the ResultAnalyzer for detailed comparisons
5. **Save Results**: Copy results to Google Drive for further analysis

### Files Generated

- `quick_experiment_standard/` - Standard routing results
- `quick_experiment_custom/` - Custom routing results
- `custom_analysis.png` - Custom visualizations
- `standard_vs_custom.png` - Comparison plot

### Documentation

For more details, see:
- `QUICKSTART.md` - Quick start guide
- `ROUTING_EXPERIMENTS_README.md` - Complete documentation
- `ARCHITECTURE.md` - System architecture
- `IMPLEMENTATION_SUMMARY.md` - Technical details

---

**Happy experimenting! 🚀**