# OLMoE Higher Criticism (HC) Routing Experiments

**Optimal Sparse Signal Detection for Adaptive Expert Selection**

This notebook implements and evaluates **Higher Criticism (HC)** statistical routing as a replacement for fixed Top-K routing in OLMoE.

---

## Research Hypothesis

HC routing will **optimally detect sparse signals** (relevant experts) using the Higher Criticism statistic, which is theoretically optimal for detecting sparse heterogeneous mixtures.

- **Fully adaptive** - Automatically finds optimal threshold (using adaptive threshold)
- **Sparse-first design** - Optimized for few relevant experts among many
- **Theoretical guarantee** - Donoho & Jin (2004) detection boundary

---

## What is Higher Criticism?

**Higher Criticism (Donoho & Jin, 2004)** is a statistical method for detecting **sparse heterogeneous mixtures**.

### Traditional Application
- Detecting weak but sparse signals in high-dimensional data
- Optimal detection rate for sparse mixture problems  
- More powerful than Higher Criticism for sparse signals

### Our Novel Application: Expert Routing

**Sparse Signal Framework:**
- **Signal**: A small number of highly relevant experts among 64 total
- **Noise**: Many irrelevant experts with random/low scores
- **Goal**: Detect the sparse set of truly relevant experts

**The HC Statistic:**

Given n p-values sorted ascending p‚Çç‚ÇÅ‚Çé ‚â§ p‚Çç‚ÇÇ‚Çé ‚â§ ... ‚â§ p‚Çç‚Çô‚Çé:
```
HC(i) = ‚àön √ó (i/n - p‚Çç·µ¢‚Çé) / ‚àö(p‚Çç·µ¢‚Çé(1 - p‚Çç·µ¢‚Çé))
```

Where:
- i/n = expected fraction under uniform null
- p‚Çç·µ¢‚Çé = observed i-th sorted p-value
- Denominator = standard error

### Expert Selection Procedure

1. Compute p-values using KDE: p_i = 1 - CDF(logit_i)
2. Sort p-values ascending
3. Compute HC score at each rank i for i ‚àà [1, Œ≤√ón]
4. Find i* = argmax(HC) - the optimal number of experts
5. Select top i* experts (smallest p-values)
6. Apply constraints: clamp to [min_k, max_k]
7. Renormalize weights to sum to 1

### HC vs HC Comparison

| Aspect | HC | HC |
|--------|----|----|
| Parameter | Œ± (FDR level) | Œ≤ (search fraction) |
| Selection | Threshold-based | Argmax-based |
| Tuning | Requires Œ± tuning | Œ≤=0.5 often works |
| Optimal for | Known error rate | Sparse signals |

---

## Experimental Design

### Configurations (24 total)

**BASELINE (4 configs):**
- `topk_8`, `topk_16`, `topk_32`, `topk_64`: OLMoE's native Top-K routing

**HC COMPARISON (4 configs):**
- HC with Œ≤=0.50, max_k ‚àà {8, 16, 32, 64}

**HC ROUTING (16 configs = 4 beta √ó 4 max_k):**

| max_k | Description | Research Question |
|-------|-------------|-------------------|
| 8 | Same ceiling as baseline | Fair comparison with OLMoE |
| 16 | 2x ceiling | Does HC benefit from more headroom? |
| 32 | 4x ceiling | Where is the saturation point? |
| 64 | Uncapped (all experts) | What does HC choose when fully free? |

**Beta values:**
- Œ≤ = 0.30: Strict (2-4 experts typical)
- Œ≤ = 0.40: Moderate-strict (4-6 experts typical)
- Œ≤ = 0.50: Moderate (5-7 experts typical) ‚Äî RECOMMENDED
- Œ≤ = 0.60: Loose (6-8 experts typical)

### Test Prompts (by complexity)

**Simple:**
- "The cat sat on the"
- "Hello, my name is"
- "The capital of France is"

**Medium:**
- "In machine learning, a neural network"
- "The process of photosynthesis involves"
- "Climate change refers to long-term shifts in"

**Complex:**
- "Explain the relationship between quantum entanglement and"
- "Compare and contrast the economic policies of"
- "The philosophical implications of consciousness suggest that"

**Technical:**
- "In Python, a decorator is a function that"
- "The time complexity of quicksort is"
- "Transformer attention mechanism computes"

### Metrics
- `avg_experts`: Mean experts per token
- `std_experts`: Standard deviation
- `min/max_experts`: Range
- `ceiling_hit_rate`: % hitting max_k limit
- `floor_hit_rate`: % at min_k
- `reduction_vs_baseline`: % fewer experts than Top-8
- `inference_time`: Speed comparison

---

## Implementation Method

**APPROACH 2: Direct Method Replacement**

This notebook uses **Direct Method Replacement** to patch OLMoE routing:
- Completely replaces `OlmoeTopKRouter.forward()` method
- Original TopK computation **NEVER executes** (efficient!)
- Custom forward uses HC routing directly
- Easily reversible via `unpatch()`

**Not using Approach 1 (Hooks)** because hooks still execute original TopK wastefully.

---

## Runs on
- ‚úÖ **Google Colab** (Recommended - GPU required)
- ‚úÖ Local Jupyter with GPU

---

## Quick Start (Google Colab)

1. Upload this notebook to Google Drive
2. Open with Google Colab
3. Enable GPU: `Runtime ‚Üí Change runtime type ‚Üí GPU ‚Üí T4/A100`
4. Run all cells

---

## 1. Environment Setup

In [None]:
import sys
import os

# Detect environment
IN_COLAB = 'google.colab' in sys.modules

print(f"Running in Google Colab: {IN_COLAB}")
print(f"Python version: {sys.version}")

# Set working directory
if IN_COLAB:
    from google.colab import drive
    print("\nüìÅ Mounting Google Drive...")
    drive.mount('/content/drive')

    WORK_DIR = '/content/drive/MyDrive/olmoe_hc_experiments'
    REPO_DIR = '/content/drive/MyDrive/MOE-with-feature-selection'
else:
    WORK_DIR = './olmoe_hc_experiments'
    REPO_DIR = None

os.makedirs(WORK_DIR, exist_ok=True)
os.chdir(WORK_DIR)
print(f"\n‚úÖ Working directory: {os.getcwd()}")

if IN_COLAB:
    print(f"‚úÖ Repository location: {REPO_DIR}")
from pathlib import Path

# Output directory for results
OUTPUT_DIR = Path(WORK_DIR) / 'results'
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
print(f"‚úÖ Output directory: {OUTPUT_DIR}")


## 2. GPU Configuration

In [None]:
import torch

print("=" * 70)
print("GPU CONFIGURATION")
print("=" * 70)

if torch.cuda.is_available():
    print(f"\n‚úÖ CUDA Available")
    print(f"   CUDA Version: {torch.version.cuda}")
    print(f"   GPU: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

    device = 'cuda'
    torch.cuda.empty_cache()
else:
    print("\n‚ùå GPU not available!")
    print("\n‚ö†Ô∏è  This notebook requires a GPU.")
    if IN_COLAB:
        print("   Enable GPU: Runtime ‚Üí Change runtime type ‚Üí T4/A100 GPU")
    raise Exception("GPU required for this experiment")

print(f"\n‚úÖ Device: {device}")
print("=" * 70)

## 3. Installation

In [None]:
%%bash
pip install -q torch transformers datasets pandas numpy matplotlib seaborn tqdm scipy
echo "‚úÖ All packages installed!"

In [None]:
import transformers
import datasets
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats as scipy_stats

print("Package Versions:")
print(f"  torch: {torch.__version__}")
print(f"  transformers: {transformers.__version__}")
print(f"  datasets: {datasets.__version__}")
print(f"  pandas: {pd.__version__}")
print(f"  numpy: {np.__version__}")
print("\n‚úÖ All imports successful!")

## 4. HC Routing Module Setup

In [None]:
print("=" * 70)
print("HC ROUTING MODULE SETUP")
print("=" * 70)

# ==========================================================================
# CONFIGURATION
# ==========================================================================
BRANCH = "main"  # Change this to switch branches (e.g., "dev", "experiment", "feature-x")

if IN_COLAB:
    # Check if repo exists in Drive
    if os.path.exists(REPO_DIR):
        print(f"\nüìÇ Repository exists in Google Drive")
        print(f"   Location: {REPO_DIR}")
        print(f"   Branch: {BRANCH}")
        print(f"\n   Fetching and checking out branch '{BRANCH}'...")
        !cd {REPO_DIR} && git fetch origin && git checkout {BRANCH} && git pull origin {BRANCH}
    else:
        print(f"\nüì• Cloning repository to Google Drive (branch: {BRANCH})...")
        !git clone --branch {BRANCH} https://github.com/AliABULIEL/MOE-with-feature-selection.git {REPO_DIR}

    framework_dir = REPO_DIR
else:
    framework_dir = os.path.abspath('..')

# Add to Python path
if framework_dir not in sys.path:
    sys.path.insert(0, framework_dir)
    print(f"\n‚úÖ Added to path: {framework_dir}")

# Verify hc_routing.py exists
hc_routing_file = os.path.join(framework_dir, 'hc_routing.py')
if os.path.exists(hc_routing_file):
    file_size = os.path.getsize(hc_routing_file)
    print(f"‚úÖ Found: hc_routing.py ({file_size:,} bytes)")
else:
    raise Exception("hc_routing.py not found!")

print("\n" + "=" * 70)
print("‚úÖ MODULE READY")
print("=" * 70)

In [None]:
# Import HC routing functions
if 'hc_routing' in sys.modules:
    del sys.modules['hc_routing']
# if 'hc_routing' in sys.modules:  # Not needed for HC-only
    # del sys.modules['hc_routing']  # Not needed for HC-only

# ‚úÖ CORRECTED IMPORTS - Using simplified beta-only API
from hc_routing import (
    higher_criticism_routing,
    compute_hc_statistic,
    compute_hc_routing_statistics,
    load_kde_models
)



print("‚úÖ HC routing module imported successfully!")
print("\n" + "=" * 70)
print("SIMPLIFIED BETA-ONLY API")
print("=" * 70)
print("""
Beta is the ONLY tuning parameter for HC routing:

| Beta Value | Behavior      | Expert Selection |
|------------|---------------|------------------|
| 'auto'     | Adaptive HC‚Å∫  | Self-tuning      |
| 1.0        | Full search   | Most experts     |
| 0.7        | Wide search   | More experts     |
| 0.5        | Medium search | Moderate         |
| 0.3        | Narrow search | Fewer experts    |

Usage:
  higher_criticism_routing(logits, beta=0.5, min_k=4, max_k=12)
""")


## 4.5 Import Comprehensive Framework Modules

Import the full evaluation framework with metrics, datasets, and visualizations.


In [None]:
print("=" * 70)
print("IMPORTING COMPREHENSIVE FRAMEWORK MODULES")
print("=" * 70)

import importlib

# ‚úÖ Import HC metrics (BOTH old and new classes)
try:
    if 'hc_routing_metrics' in sys.modules:
        importlib.reload(sys.modules['hc_routing_metrics'])
    from hc_routing_metrics import HCMetricsComputer, HCMetrics  # Old classes
    from hc_routing_metrics import UnifiedEvaluationMetrics, save_metrics  # NEW classes
    print("‚úÖ Imported HCMetricsComputer")
    print("‚úÖ Imported UnifiedEvaluationMetrics, save_metrics")
    metrics_computer = HCMetricsComputer()
except ImportError as e:
    print(f"‚ö†Ô∏è Could not import metrics: {e}")
    metrics_computer = None

# ‚úÖ Import HC logger
try:
    if 'hc_routing_logging' in sys.modules:
        importlib.reload(sys.modules['hc_routing_logging'])
    from hc_routing_logging import HCRoutingLogger
    print("‚úÖ Imported HCRoutingLogger")
except ImportError as e:
    print(f"‚ö†Ô∏è Could not import HCRoutingLogger: {e}")
    HCRoutingLogger = None

# ‚úÖ Import dataset evaluation functions
try:
    if 'hc_routing_evaluation' in sys.modules:
        importlib.reload(sys.modules['hc_routing_evaluation'])
    from hc_routing_evaluation import (
        load_wikitext, load_lambada, load_hellaswag,
        evaluate_perplexity, evaluate_lambada, evaluate_hellaswag
    )
    print("‚úÖ Imported dataset evaluation functions")
except ImportError as e:
    print(f"‚ö†Ô∏è Could not import evaluation functions: {e}")

# ‚úÖ Import visualization functions
try:
    if 'hc_routing_visualization' in sys.modules:
        importlib.reload(sys.modules['hc_routing_visualization'])
    from hc_routing_visualization import create_comprehensive_visualization
    print("‚úÖ Imported visualization functions")
except ImportError as e:
    print(f"‚ö†Ô∏è Could not import visualization functions: {e}")

print("\n" + "=" * 70)
print("‚úÖ FRAMEWORK MODULES READY")
print("=" * 70)

## 4.6 DEBUG_MODE Configuration

Configure fast testing vs full evaluation mode.


In [None]:
print("=" * 70)
print("DEBUG MODE CONFIGURATION")
print("=" * 70)# Toggle for fast testing vs full evaluation
DEBUG_MODE = True  # Set to True for quick testing
if DEBUG_MODE:
  # Fast testing configuration
  MAX_SAMPLES = 10
  # Very small sample for speed
  LOG_EVERY_N = 5   # Log every 5 tokens
  SAVE_PLOTS = True
  print("\n‚ö° DEBUG MODE: ENABLED")
  print("   ‚Ä¢ Max samples: 10 (fast testing)")
  print("   ‚Ä¢ Logging: Every 5 tokens")
  print("   ‚Ä¢ Plots: Generated for all experiments")
else:
  # Full evaluation configuration
  MAX_SAMPLES = 200  # Full benchmark evaluation
  LOG_EVERY_N = 100
  # Log every 100 tokens for efficiency
  SAVE_PLOTS = False  # Only save summaries, not per-token logs
  print("\nüéØ PRODUCTION MODE: ENABLED")
  print("   ‚Ä¢ Max samples: 200 (full evaluation)")
  print("   ‚Ä¢ Logging: Every 100 tokens")
  print("   ‚Ä¢ Plots: Summary only")
  print("\n" + "=" * 70)
  from pathlib import Path
  if IN_COLAB:
     OUTPUT_DIR = Path(WORK_DIR) / 'hc_comprehensive_results'
  else:
    OUTPUT_DIR = Path('./hc_comprehensive_results')
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    (OUTPUT_DIR / 'logs').mkdir(exist_ok=True)
    (OUTPUT_DIR / 'plots').mkdir(exist_ok=True)
    (OUTPUT_DIR / 'visualizations').mkdir(exist_ok=True)
print(f"\nüìÅ Output directory: {OUTPUT_DIR}")

# Routing method identifier
ROUTING_METHOD = 'hc'  # Used for logging and visualization

## 5. Load OLMoE Model

In [None]:
from transformers import OlmoeForCausalLM, AutoTokenizer
from tqdm import tqdm
import time

print("=" * 70)
print("LOADING OLMoE MODEL")
print("=" * 70)

MODEL_NAME = "allenai/OLMoE-1B-7B-0924"

print(f"\nModel: {MODEL_NAME}")
print("Loading...")

start_time = time.time()

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
print("‚úÖ Tokenizer loaded")

# Load model
model = OlmoeForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",
    torch_dtype=torch.bfloat16
)
model.eval()

load_time = time.time() - start_time

print(f"‚úÖ Model loaded in {load_time:.1f}s")
print(f"\nModel Configuration:")
print(f"  ‚Ä¢ Architecture: {model.config.model_type}")
print(f"  ‚Ä¢ Hidden size: {model.config.hidden_size}")
print(f"  ‚Ä¢ Num layers: {model.config.num_hidden_layers}")
print(f"  ‚Ä¢ Num experts: {model.config.num_experts}")
print(f"  ‚Ä¢ Experts per token (Top-K): {model.config.num_experts_per_tok}")
print(f"  ‚Ä¢ Vocab size: {model.config.vocab_size}")

NUM_LAYERS = model.config.num_hidden_layers
NUM_EXPERTS = model.config.num_experts
DEFAULT_K = model.config.num_experts_per_tok

print(f"\nüìä OLMoE Routing Info:")
print(f"  ‚Ä¢ Total experts per layer: {NUM_EXPERTS}")
print(f"  ‚Ä¢ Default Top-K: {DEFAULT_K}")
print(f"  ‚Ä¢ Routing happens in {NUM_LAYERS} layers")

print("\n" + "=" * 70)
print("‚úÖ MODEL READY")
print("=" * 70)

## 6. OLMoE Router Integration

In [None]:
# Cell 18: HCRouterPatcher with patch_mode Implementation
# ======================================================================

import torch.nn.functional as F
from typing import Dict, List, Any, Union, Optional, Tuple
from collections import defaultdict
import re

# =============================================================================
# HELPER FUNCTIONS FOR PATCH_MODE
# =============================================================================

def _extract_layer_idx(module_name: str) -> int:
    """
    Extract transformer layer index from MoE block module name.

    Args:
        module_name: Full module path (e.g., 'model.layers.15.mlp')

    Returns:
        Integer layer index

    Raises:
        ValueError: If layer index cannot be extracted
    """
    match = re.search(r'(?:^|\.)layers\.(\d+)(?:\.|$)', module_name)
    if match:
        return int(match.group(1))
    raise ValueError(
        f"Cannot extract layer index from module name: '{module_name}'. "
        f"Expected pattern containing 'layers.<idx>' e.g., 'model.layers.15.mlp'"
    )


def _select_last_n_moe_layers(
    unique_sorted_indices: List[int],
    patch_mode: Optional[int]
) -> Tuple[List[int], bool]:
    """
    Select which layer indices to patch based on patch_mode.

    Args:
        unique_sorted_indices: Sorted list of unique MoE layer indices
        patch_mode: Number of final layers to patch (None/<=0 = all)

    Returns:
        Tuple of (selected_indices, was_clamped)
    """
    if patch_mode is None or patch_mode <= 0:
        return (unique_sorted_indices[:], False)

    if patch_mode > len(unique_sorted_indices):
        return (unique_sorted_indices[:], True)  # clamped

    return (unique_sorted_indices[-patch_mode:], False)


# =============================================================================
# HC ROUTER PATCHER CLASS
# =============================================================================

class HCRouterPatcher:
    """
    Patches OLMoE MoE blocks for HC routing ONLY (Approach 2).

    For TopK baselines, use NATIVE OLMoE by reloading with:
        model, tokenizer = load_model_with_topk(k=16)

    This is cleaner because:
    - TopK uses original, tested OLMoE code
    - HC uses this custom patcher for adaptive selection
    - No renormalization bugs possible for TopK

    NEW: patch_mode parameter controls which layers get HC routing
    - patch_mode=1 ‚Üí only last layer uses HC
    - patch_mode=4 ‚Üí only last 4 layers use HC
    - patch_mode=None ‚Üí all layers use HC (default, backward compatible)

    Usage:
        # For TopK: use load_model_with_topk(k) - see Cell 24
        # For HC (all layers):
        patcher = HCRouterPatcher(model)
        patcher.patch_with_hc(beta=0.5, min_k=2, max_k=8)

        # For HC (only last 4 layers):
        patcher.patch_with_hc(beta=0.5, min_k=2, max_k=8, patch_mode=4)

        # ... run evaluation ...
        stats = patcher.get_stats()
        status = patcher.get_patch_status()
        patcher.unpatch()
    """

    def __init__(self, model):
        self.model = model
        self.moe_blocks = []
        self.original_forwards = {}
        self.stats = defaultdict(list)
        self.stats['hc_max_values'] = []
        self.stats['hc_max_ranks'] = []
        self.patched = False
        self.current_config = {}

        # NEW: Track which layers are actually patched
        self.patched_layer_indices: List[int] = []
        self.patched_block_names: List[str] = []

        # Logger state
        self._logger = None
        self._log_every_n = 100
        self._token_counter = 0
        self._sample_counter = 0

        self._find_moe_blocks()

    def _find_moe_blocks(self):
        """Locate all OlmoeSparseMoeBlock modules."""
        for name, module in self.model.named_modules():
            if module.__class__.__name__ == 'OlmoeSparseMoeBlock':
                self.moe_blocks.append((name, module))

        if len(self.moe_blocks) == 0:
            raise ValueError("No OlmoeSparseMoeBlock modules found!")

        print(f"‚úÖ Found {len(self.moe_blocks)} MoE blocks")

    def set_external_logger(self, logger):
        """Set or update the external logger."""
        self._logger = logger
        if logger:
            print(f"‚úÖ External logger attached: {logger.experiment_name}")
        else:
            print("‚ÑπÔ∏è  External logger disabled")

    def start_sample(self):
        """Call before processing each new sample."""
        self._sample_counter += 1
        self._token_counter = 0

    def patch_with_hc(
        self,
        beta: float = 0.5,
        temperature: float = 1.0,
        min_k: int = 2,
        max_k: int = 8,
        collect_stats: bool = True,
        logger=None,
        log_every_n_tokens: int = 100,
        patch_mode: Optional[int] = None  # NEW PARAMETER
    ):
        """
        Patch model to use HC routing.

        Args:
            beta: Search fraction (0.0-1.0) - main tuning parameter
                  Lower beta = stricter selection = fewer experts
                  Higher beta = looser selection = more experts
            temperature: Softmax temperature
            min_k: Minimum experts (safety floor)
            max_k: Maximum experts (ceiling)
            collect_stats: Whether to collect statistics
            logger: Optional logger for detailed stats
            log_every_n_tokens: Logging frequency
            patch_mode: NEW - Controls which layers to patch
                  - None/0/negative: patch ALL layers (default, backward compatible)
                  - 1: patch only final layer
                  - 2: patch only last 2 layers
                  - N: patch only last N layers
        """
        self.unpatch()
        self.stats.clear()

        # Store config
        self.current_config = {
            'beta': beta,
            'temperature': temperature,
            'min_k': min_k,
            'max_k': max_k,
            'patch_mode': patch_mode  # NEW
        }

        # Logger setup
        self._logger = logger
        self._log_every_n = log_every_n_tokens
        self._token_counter = 0
        self._sample_counter = 0

        # Load KDE models
        kde_models = None
        try:
            kde_models = load_kde_models()
            if kde_models:
                print(f"   üìä Loaded KDE models for {len(kde_models)} layers")
        except:
            print(f"   ‚ö†Ô∏è No KDE models - using empirical p-values")

        patcher_self = self

        def create_hc_forward(layer_name, moe_block_ref):
            # Extract layer index
            layer_idx = 0
            if 'layers.' in layer_name:
                try:
                    parts = layer_name.split('.')
                    for i, part in enumerate(parts):
                        if part == 'layers' and i + 1 < len(parts):
                            layer_idx = int(parts[i + 1])
                            break
                except:
                    layer_idx = 0

            def hc_forward(hidden_states):
                batch_size, seq_len, hidden_dim = hidden_states.shape
                hidden_states_flat = hidden_states.view(-1, hidden_dim)

                # Get router logits
                router_logits = moe_block_ref.gate(hidden_states_flat)

                # Apply HC routing
                routing_weights, selected_experts, expert_counts, hc_stats = higher_criticism_routing(
                    router_logits,
                    beta=beta,
                    temperature=temperature,
                    min_k=min_k,
                    max_k=max_k,
                    layer_idx=layer_idx,
                    kde_models=kde_models,
                    logger=patcher_self._logger,
                    log_every_n_tokens=patcher_self._log_every_n,
                    sample_idx=patcher_self._sample_counter,
                    token_idx=patcher_self._token_counter,
                    return_stats=True
                )
                if hc_stats and collect_stats:
                    if 'hc_max_values' in hc_stats:
                        patcher_self.stats['hc_max_values'].extend(hc_stats['hc_max_values'])
                    if 'hc_max_ranks' in hc_stats:
                        patcher_self.stats['hc_max_ranks'].extend(hc_stats['hc_max_ranks'])

                # Update token counter
                patcher_self._token_counter += seq_len

                # Collect stats
                if collect_stats:
                    patcher_self.stats['expert_counts'].extend(expert_counts.flatten().cpu().tolist())
                    patcher_self.stats['layer_names'].extend([layer_name] * expert_counts.numel())

                # Dispatch to experts
                final_hidden_states = torch.zeros_like(hidden_states_flat)
                num_experts = routing_weights.shape[1]

                for expert_idx in range(num_experts):
                    expert_mask = routing_weights[:, expert_idx] > 0
                    if expert_mask.any():
                        expert_input = hidden_states_flat[expert_mask]
                        expert_output = moe_block_ref.experts[expert_idx](expert_input)
                        w = routing_weights[expert_mask, expert_idx].unsqueeze(-1)
                        final_hidden_states[expert_mask] += w * expert_output

                output = final_hidden_states.view(batch_size, seq_len, hidden_dim)
                return output, router_logits

            return hc_forward

        # =================================================================
        # PATCH MODE: Select which layers to patch
        # =================================================================

        # 1. Build candidates with layer indices
        candidates = []
        for name, moe_block in self.moe_blocks:
            layer_idx = _extract_layer_idx(name)
            candidates.append((layer_idx, name, moe_block))

        # 2. Compute unique sorted MoE layer indices
        all_layer_indices = [c[0] for c in candidates]
        unique_moe_layers = sorted(set(all_layer_indices))

        # 3. Select which layers to patch based on patch_mode
        selected_indices, was_clamped = _select_last_n_moe_layers(unique_moe_layers, patch_mode)
        selected_set = set(selected_indices)

        # 4. Compute display string for reporting
        if patch_mode is None or patch_mode <= 0:
            mode_display = "ALL"
        else:
            mode_display = str(patch_mode)

        # 5. Reset tracking
        self.patched_layer_indices = selected_indices[:]
        self.patched_block_names = []

        # 6. Patch ONLY selected blocks
        for layer_idx, name, moe_block in candidates:
            if layer_idx in selected_set:
                self.original_forwards[name] = moe_block.forward
                moe_block.forward = create_hc_forward(name, moe_block)
                self.patched_block_names.append(name)

        self.patched = True

        # 7. Print patch report
        print(f"‚úÖ HC Routing Patch Report")
        print(f"   Total MoE blocks discovered: {len(self.moe_blocks)}")
        print(f"   Unique MoE layer indices: {unique_moe_layers}")
        print(f"   patch_mode requested: {mode_display}")
        if was_clamped:
            print(f"   ‚ö†Ô∏è CLAMPED: requested {patch_mode} but only {len(unique_moe_layers)} MoE layers available")
        print(f"   Layers patched: {len(selected_indices)}/{len(unique_moe_layers)}")
        print(f"   Patched layer indices: {self.patched_layer_indices}")
        print(f"   Patched blocks: {len(self.patched_block_names)}")
        print(f"   HC params: Œ≤={beta}, min_k={min_k}, max_k={max_k}")

    def unpatch(self):
        """Restore original forward methods for patched blocks only. Idempotent."""
        if not self.patched:
            return

        restored_count = 0
        for name, moe_block in self.moe_blocks:
            if name in self.original_forwards:
                moe_block.forward = self.original_forwards[name]
                restored_count += 1

        self.original_forwards.clear()
        self.patched = False
        self.current_config = {}
        self.patched_layer_indices = []
        self.patched_block_names = []

        print(f"‚úÖ Unpatched {restored_count} MoE blocks")

    def get_stats(self) -> Dict[str, Any]:
        """Get routing statistics."""
        if not self.stats['expert_counts']:
            return {}

        counts = np.array(self.stats['expert_counts'])
        
        stats = {
            'avg_experts': float(np.mean(counts)),
            'std_experts': float(np.std(counts)),
            'min_experts': int(np.min(counts)),
            'max_experts': int(np.max(counts)),
            'median_experts': float(np.median(counts)),
            'total_tokens': len(counts),
            'distribution': np.bincount(counts.astype(int)).tolist()
        }
        
        # ADD: HC statistics if available
        if 'hc_max_values' in self.stats and self.stats['hc_max_values']:
            hc_vals = np.array(self.stats['hc_max_values'])
            stats['hc_max_mean'] = float(np.mean(hc_vals))
            stats['hc_max_std'] = float(np.std(hc_vals))
            stats['hc_threshold_ranks'] = self.stats.get('hc_max_ranks', [])
        
        # ADD: Layer-wise statistics if available
        if 'layer_names' in self.stats and self.stats['layer_names']:
            layer_stats = {}
            layers = np.array(self.stats['layer_names'])
            for layer_name in np.unique(layers):
                mask = layers == layer_name
                layer_counts = counts[mask]
                layer_stats[layer_name] = {
                    'avg': float(np.mean(layer_counts)),
                    'std': float(np.std(layer_counts))
                }
            stats['layer_wise'] = layer_stats
        
        return stats

    def get_patch_status(self) -> Dict[str, Any]:
        """Get detailed patching status for introspection."""
        return {
            'patched': self.patched,
            'total_moe_blocks': len(self.moe_blocks),
            'patched_block_count': len(self.patched_block_names),
            'patched_layer_indices': self.patched_layer_indices,
            'patched_block_names': self.patched_block_names,
            'config': self.current_config
        }

    def get_status(self) -> str:
        """Get current patcher status as string."""
        if not self.patched:
            return "Not patched (using native OLMoE)"

        mode = self.current_config.get('patch_mode')
        mode_str = "ALL" if mode is None or mode <= 0 else str(mode)

        return (f"HC patched: Œ≤={self.current_config.get('beta')}, "
                f"max_k={self.current_config.get('max_k')}, "
                f"patch_mode={mode_str}, "
                f"layers={self.patched_layer_indices}")

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.unpatch()


# Create patcher instance (will be used for HC experiments only)
patcher = HCRouterPatcher(model)

print("\n" + "=" * 70)
print("‚úÖ HC PATCHER READY (Approach 2 + patch_mode)")
print("=" * 70)
print("   For TopK: Use load_model_with_topk(k=...) - NATIVE OLMoE")
print("   For HC:   Use patcher.patch_with_hc(...) - Custom adaptive routing")
print("   NEW: Use patch_mode to control which layers get HC routing")


In [None]:
"""
CELL: PATCH_MODE SANITY TESTS
=============================
Verify patch_mode feature works correctly.
"""

print("=" * 70)
print("PATCH_MODE SANITY TESTS")
print("=" * 70)

# -----------------------------------------------------------------------------
# TEST 1: Selection helper function
# -----------------------------------------------------------------------------
print("\nüìã TEST 1: _select_last_n_moe_layers()")

# Full 16-layer model
all_16 = list(range(16))
assert _select_last_n_moe_layers(all_16, None) == (all_16, False), "None should select all"
assert _select_last_n_moe_layers(all_16, 0) == (all_16, False), "0 should select all"
assert _select_last_n_moe_layers(all_16, -5) == (all_16, False), "negative should select all"
assert _select_last_n_moe_layers(all_16, 1) == ([15], False), "1 should select [15]"
assert _select_last_n_moe_layers(all_16, 4) == ([12, 13, 14, 15], False), "4 should select last 4"
assert _select_last_n_moe_layers(all_16, 16) == (all_16, False), "16 should select all (exact)"
assert _select_last_n_moe_layers(all_16, 100) == (all_16, True), "100 should clamp"

# Sparse MoE layers
sparse = [0, 2, 7, 15]
assert _select_last_n_moe_layers(sparse, 1) == ([15], False), "sparse: 1 -> [15]"
assert _select_last_n_moe_layers(sparse, 2) == ([7, 15], False), "sparse: 2 -> [7, 15]"
assert _select_last_n_moe_layers(sparse, 3) == ([2, 7, 15], False), "sparse: 3 -> [2, 7, 15]"
assert _select_last_n_moe_layers(sparse, 4) == (sparse, False), "sparse: 4 -> all"
assert _select_last_n_moe_layers(sparse, 10) == (sparse, True), "sparse: 10 -> clamp"

print("   ‚úÖ All selection tests passed!")

# -----------------------------------------------------------------------------
# TEST 2: Layer index extraction
# -----------------------------------------------------------------------------
print("\nüìã TEST 2: _extract_layer_idx()")

assert _extract_layer_idx('model.layers.0.mlp') == 0
assert _extract_layer_idx('model.layers.15.mlp.experts') == 15
assert _extract_layer_idx('transformer.layers.7.moe_block') == 7
assert _extract_layer_idx('layers.3.block') == 3
assert _extract_layer_idx('model.layers.10.mlp') == 10

# Test failure case
try:
    _extract_layer_idx('invalid.path.no.layers')
    assert False, "Should have raised ValueError"
except ValueError as e:
    assert 'invalid.path.no.layers' in str(e)
    print("   ‚úÖ ValueError raised correctly for invalid path")

print("   ‚úÖ All extraction tests passed!")

# -----------------------------------------------------------------------------
# TEST 3: Live patch_mode tests with actual model
# -----------------------------------------------------------------------------
print("\nüìã TEST 3: Live patch_mode tests")

# Get total unique layers for reference
total_unique_layers = len(set(_extract_layer_idx(name) for name, _ in patcher.moe_blocks))
print(f"   Model has {total_unique_layers} unique MoE layers")

# Test patch_mode=1 (only last layer)
print("\n   Testing patch_mode=1...")
patcher.unpatch()
patcher.patch_with_hc(beta=0.5, max_k=8, patch_mode=1)
assert len(patcher.patched_layer_indices) == 1, f"Expected 1 layer, got {len(patcher.patched_layer_indices)}"
print(f"   ‚úÖ patch_mode=1: patched layers = {patcher.patched_layer_indices}")

# Test patch_mode=4 (last 4 layers)
print("\n   Testing patch_mode=4...")
patcher.unpatch()
patcher.patch_with_hc(beta=0.5, max_k=8, patch_mode=4)
assert len(patcher.patched_layer_indices) == 4, f"Expected 4 layers, got {len(patcher.patched_layer_indices)}"
print(f"   ‚úÖ patch_mode=4: patched layers = {patcher.patched_layer_indices}")

# Test patch_mode=None (all layers - backward compatible)
print("\n   Testing patch_mode=None (all)...")
patcher.unpatch()
patcher.patch_with_hc(beta=0.5, max_k=8, patch_mode=None)
assert len(patcher.patched_layer_indices) == total_unique_layers, "Should patch all layers"
print(f"   ‚úÖ patch_mode=None: patched all {len(patcher.patched_layer_indices)} layers")

# Test default (no patch_mode arg - should be same as None)
print("\n   Testing default (no patch_mode arg)...")
patcher.unpatch()
patcher.patch_with_hc(beta=0.5, max_k=8)  # No patch_mode argument
assert len(patcher.patched_layer_indices) == total_unique_layers, "Default should patch all"
print(f"   ‚úÖ Default: patched all {len(patcher.patched_layer_indices)} layers")

# Test clamping (patch_mode > available)
print("\n   Testing patch_mode=100 (should clamp)...")
patcher.unpatch()
patcher.patch_with_hc(beta=0.5, max_k=8, patch_mode=100)
assert len(patcher.patched_layer_indices) == total_unique_layers, "Should clamp to all"
print(f"   ‚úÖ patch_mode=100: clamped to {len(patcher.patched_layer_indices)} layers")

# Test unpatch idempotency
print("\n   Testing unpatch idempotency...")
patcher.unpatch()
patcher.unpatch()  # Should not error
patcher.unpatch()  # Should not error
print("   ‚úÖ unpatch() is idempotent")

# Test get_patch_status()
print("\n   Testing get_patch_status()...")
patcher.patch_with_hc(beta=0.5, max_k=8, patch_mode=2)
status = patcher.get_patch_status()
assert status['patched'] == True
assert status['patched_block_count'] == len(patcher.patched_block_names)
assert status['patched_layer_indices'] == patcher.patched_layer_indices
assert 'patch_mode' in status['config']
print(f"   ‚úÖ get_patch_status() works correctly")
print(f"      Status: {status}")

# Test get_status()
print("\n   Testing get_status()...")
status_str = patcher.get_status()
print(f"   Status string: {status_str}")
assert "patch_mode=2" in status_str
assert "layers=" in status_str
print("   ‚úÖ get_status() works correctly")

# Clean up
patcher.unpatch()

print("\n" + "=" * 70)
print("üéâ ALL PATCH_MODE TESTS PASSED!")
print("=" * 70)


## 6.5 VERIFICATION: Routing Actually Changed

**CRITICAL TEST:** Prove that HC routing is actually working (not just simulation)

In [None]:
print("=" * 70)
print("VERIFICATION TEST: HC ROUTING IS ACTUALLY WORKING (Approach 2)")
print("=" * 70)

verification_prompt = "The capital of France is"
inputs = tokenizer(verification_prompt, return_tensors='pt').to(device)

print(f"\nTest prompt: '{verification_prompt}'")
print(f"\nRunning 3 tests:\n")

# TEST 1: Native OLMoE (K=8) - No patching
print("TEST 1: Native OLMoE (TopK=8) - NO PATCHING")
print("-" * 70)
patcher.unpatch()  # Ensure no patching

with torch.no_grad():
    outputs_native = model.generate(**inputs, max_new_tokens=10, do_sample=False)

generated_native = tokenizer.decode(outputs_native[0], skip_special_tokens=True)
native_k = model.config.num_experts_per_tok

print(f"  Generated: '{generated_native}'")
print(f"  TopK from config: {native_k}")
print(f"  ‚úÖ Using NATIVE OLMoE routing (fixed K={native_k})")

# TEST 2: HC routing with strict beta - should use FEWER than 8
print("\n\nTEST 2: HC Routing (Œ≤=0.30, max_k=8) - STRICT")
print("-" * 70)
patcher.patch_with_hc(beta=0.30, min_k=2, max_k=8, collect_stats=True)

with torch.no_grad():
    outputs_strict = model.generate(**inputs, max_new_tokens=10, do_sample=False)

stats_strict = patcher.get_stats()
generated_strict = tokenizer.decode(outputs_strict[0], skip_special_tokens=True)

print(f"  Generated: '{generated_strict}'")
print(f"  Avg experts: {stats_strict['avg_experts']:.2f}")
print(f"  Range: [{stats_strict['min_experts']}, {stats_strict['max_experts']}]")
print(f"  Std: {stats_strict['std_experts']:.2f}")

# Check if routing changed
if stats_strict['avg_experts'] < 7.5:
    print(f"  ‚úÖ SUCCESS: Expert count is VARIABLE ({stats_strict['avg_experts']:.2f} < 8)")
    print("  ‚úÖ HC ROUTING IS WORKING!")
    test2_pass = True
else:
    print(f"  ‚ùå FAILURE: Expert count too close to 8 ({stats_strict['avg_experts']:.2f})")
    print("  ‚ùå Routing may not be working correctly!")
    test2_pass = False

patcher.unpatch()

# TEST 3: HC routing with loose beta - should use MORE than strict
print("\n\nTEST 3: HC Routing (Œ≤=0.60, max_k=8) - LOOSE")
print("-" * 70)
patcher.patch_with_hc(beta=0.60, min_k=2, max_k=8, collect_stats=True)

with torch.no_grad():
    outputs_loose = model.generate(**inputs, max_new_tokens=10, do_sample=False)

stats_loose = patcher.get_stats()
generated_loose = tokenizer.decode(outputs_loose[0], skip_special_tokens=True)

print(f"  Generated: '{generated_loose}'")
print(f"  Avg experts: {stats_loose['avg_experts']:.2f}")
print(f"  Range: [{stats_loose['min_experts']}, {stats_loose['max_experts']}]")
print(f"  Std: {stats_loose['std_experts']:.2f}")

# Check if beta affects selection
beta_effect = stats_loose['avg_experts'] >= stats_strict['avg_experts']
if beta_effect:
    print(f"  ‚úÖ SUCCESS: Œ≤=0.60 uses ‚â• experts than Œ≤=0.30")
    print(f"  ‚úÖ Beta parameter is working correctly!")
    test3_pass = True
else:
    print(f"  ‚ö†Ô∏è  WARNING: Expected Œ≤=0.60 ‚â• Œ≤=0.30")
    test3_pass = False

patcher.unpatch()

# FINAL VERDICT
print("\n\n" + "=" * 70)
print("VERIFICATION SUMMARY (Approach 2)")
print("=" * 70)

if test2_pass:
    print("\nüéâ ALL CRITICAL TESTS PASSED!")
    print("\n‚úÖ Native OLMoE: Fixed K=8 (original code path)")
    print(f"‚úÖ Strict HC (Œ≤=0.30): {stats_strict['avg_experts']:.2f} experts (VARIABLE)")
    print(f"‚úÖ Loose HC (Œ≤=0.60): {stats_loose['avg_experts']:.2f} experts (VARIABLE)")
    print("‚úÖ Output quality maintained (text is coherent)")
    print("\nüéØ APPROACH 2 WORKING!")
    print("   ‚Ä¢ TopK ‚Üí Native OLMoE (config.num_experts_per_tok)")
    print("   ‚Ä¢ HC ‚Üí HCRouterPatcher (adaptive selection)")
else:
    print("\n‚ùå VERIFICATION FAILED")
    print("   Expert counts are not varying as expected.")

print("\n" + "=" * 70)

## 7.5 Load Benchmark Datasets

Load WikiText-2, LAMBADA, and HellaSwag for comprehensive evaluation.


In [None]:
print("=" * 70)
print("LOADING BENCHMARK DATASETS")
print("=" * 70)

# Configure sample count
MAX_SAMPLES = 100  # Samples per dataset for comprehensive evaluation
if 'MAX_SAMPLES' not in globals():
    MAX_SAMPLES = 100  # Default if DEBUG_MODE section was not run

print(f"üìä Using MAX_SAMPLES = {MAX_SAMPLES}")
EVAL_DATASETS = {}

# Load WikiText-2
try:
    print("\nüìö Loading WikiText-2...")
    wikitext_data = load_wikitext(max_samples=MAX_SAMPLES)
    EVAL_DATASETS['wikitext'] = wikitext_data
    print(f"   ‚úÖ Loaded {len(wikitext_data)} samples")
except Exception as e:
    print(f"   ‚ö†Ô∏è Failed to load WikiText: {e}")

# Load LAMBADA
try:
    print("\nüìö Loading LAMBADA...")
    lambada_data = load_lambada(max_samples=MAX_SAMPLES)
    EVAL_DATASETS['lambada'] = lambada_data
    print(f"   ‚úÖ Loaded {len(lambada_data)} samples")
except Exception as e:
    print(f"   ‚ö†Ô∏è Failed to load LAMBADA: {e}")

# Load HellaSwag
try:
    print("\nüìö Loading HellaSwag...")
    hellaswag_data = load_hellaswag(max_samples=MAX_SAMPLES)
    EVAL_DATASETS['hellaswag'] = hellaswag_data
    print(f"   ‚úÖ Loaded {len(hellaswag_data)} samples")
except Exception as e:
    print(f"   ‚ö†Ô∏è Failed to load HellaSwag: {e}")

print("\n" + "=" * 70)
print("‚úÖ BENCHMARK DATASETS READY")
print("=" * 70)
print(f"\nDataset Summary:")
for name, data in EVAL_DATASETS.items():
    count = len(data) if hasattr(data, '__len__') else 0
    print(f"  ‚Ä¢ {name}: {count} samples")


## 8. Experiment Configurations

In [None]:
"""
SECTION: MODEL LOADING FUNCTIONS (Approach 2)
==============================================
- TopK baselines: load_model_with_topk(k) - Uses NATIVE OLMoE code
- HC routing: Use HCRouterPatcher (defined above)

This approach is cleaner because TopK uses original OLMoE implementation!
"""

print("=" * 70)
print("MODEL LOADING FUNCTIONS (Approach 2)")
print("=" * 70)

from transformers import AutoConfig, AutoTokenizer, OlmoeForCausalLM

# Global model cache to avoid reloading same config
_model_cache = {}

def load_model_with_topk(k: int = 8, use_cache: bool = True):
    """
    Load OLMoE model with specific top-k value via config.
    
    This uses NATIVE OLMoE routing - no patching needed!
    The model's original forward() runs with the specified K.
    
    Args:
        k: Number of experts per token (num_experts_per_tok)
        use_cache: Whether to cache and reuse loaded models
        
    Returns:
        model: OLMoE model with specified top-k
        tokenizer: Tokenizer
    """
    global _model_cache
    
    cache_key = f"topk_{k}"
    
    if use_cache and cache_key in _model_cache:
        print(f"‚úÖ Using cached model (TopK={k})")
        return _model_cache[cache_key]['model'], _model_cache[cache_key]['tokenizer']
    
    print(f"üì• Loading OLMoE with TopK={k}...")
    
    # Load config and modify top-k
    config = AutoConfig.from_pretrained(MODEL_NAME)
    original_k = config.num_experts_per_tok
    config.num_experts_per_tok = k
    
    print(f"   Config: num_experts_per_tok = {original_k} ‚Üí {k}")
    
    # Load model with modified config
    new_model = OlmoeForCausalLM.from_pretrained(
        MODEL_NAME,
        config=config,
        device_map="auto",
        torch_dtype=torch.bfloat16
    )
    new_model.eval()
    
    # Load tokenizer
    new_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    if new_tokenizer.pad_token is None:
        new_tokenizer.pad_token = new_tokenizer.eos_token
    
    if use_cache:
        _model_cache[cache_key] = {'model': new_model, 'tokenizer': new_tokenizer}
    
    print(f"‚úÖ Model loaded with NATIVE TopK={k}")
    
    return new_model, new_tokenizer


def clear_model_cache():
    """Clear all cached models to free memory."""
    global _model_cache
    
    for key in list(_model_cache.keys()):
        del _model_cache[key]
    
    _model_cache = {}
    
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    
    print("‚úÖ Model cache cleared")


def get_current_topk(model) -> int:
    """Get the current top-k value from model config."""
    return model.config.num_experts_per_tok


print("\n‚úÖ Functions defined:")
print("   ‚Ä¢ load_model_with_topk(k) - Load model with NATIVE TopK routing")
print("   ‚Ä¢ clear_model_cache() - Free memory")
print("   ‚Ä¢ get_current_topk(model) - Check current K value")

print("\n" + "=" * 70)

# =============================================================================
# EXPERIMENT CONFIGURATIONS (Approach 2)
# =============================================================================

from dataclasses import dataclass
from typing import Optional

@dataclass
class RoutingConfig:
    """Configuration for a routing experiment."""
    name: str
    routing_type: str  # 'topk' (native) or 'hc' (patcher)
    
    # TopK parameters (native OLMoE via config)
    k: int = 8
    
    # HC parameters (patcher)
    beta: Optional[float] = None
    min_k: int = 2
    max_k: int = 8
    temperature: float = 1.0
    patch_mode: Optional[int] = None


# =============================================================================
# DEFINE ALL CONFIGURATIONS
# =============================================================================

configs = []

# TOPK BASELINES (Native OLMoE) - Use load_model_with_topk()
topk_values = [4, 8, 16, 32, 64]

for k in topk_values:
    configs.append(RoutingConfig(
        name=f'{k}experts_topk_baseline',
        routing_type='topk',  # Changed from 'baseline' to 'topk'
        k=k
    ))

# HC ROUTING CONFIGURATIONS (Use HCRouterPatcher)
beta_values = [0.20, 0.60]
max_k_values = [16, 32, 64]
patch_mode = [1, 2, 4]

for max_k in max_k_values:
    for beta in beta_values:
        for patch in patch_mode:
            configs.append(RoutingConfig(
                name=f'{max_k}experts_hc_b{int(beta*100):03d}',
                routing_type='hc',
                beta=beta,
                max_k=max_k,
                min_k=2,
                temperature=1.0,
                patch_mode=patch  
            ))

print(f"\nüìã Total Configurations: {len(configs)}")
print(f"   ‚Ä¢ TopK (native): {len([c for c in configs if c.routing_type == 'topk'])}")
print(f"   ‚Ä¢ HC (patcher): {len([c for c in configs if c.routing_type == 'hc'])}")

print(f"\nüìä TopK values: {topk_values}")
print(f"üìä HC Beta values: {beta_values}")
print(f"üìä HC Max-K values: {max_k_values}")

print(f"\nüìù Configuration List:")
print("-" * 60)
for i, cfg in enumerate(configs[:12]):
    if cfg.routing_type == 'topk':
        print(f"  {i+1:2d}. {cfg.name:<30} [Native TopK, k={cfg.k}]")
    else:
        print(f"  {i+1:2d}. {cfg.name:<30} [HC Patcher, Œ≤={cfg.beta}, max_k={cfg.max_k}]")

if len(configs) > 12:
    print(f"  ... and {len(configs) - 12} more")

print(f"\n‚úÖ Configuration setup complete!")
print(f"   Ready to run {len(configs)} experiments")
print("\n" + "=" * 70)

## 9. Run Experiments

This section runs all 21 configurations on all test prompts.

In [None]:
import json
from datetime import datetime

def run_inference(
    prompt: str,
    max_new_tokens: int = 20,
    collect_routing: bool = True
) -> Dict[str, Any]:
    """
    Run inference on a single prompt.

    Returns:
        Dictionary with:
            - generated_text: str
            - num_tokens: int
            - inference_time: float
            - routing_stats: dict (if collect_routing=True)
    """
    inputs = tokenizer(prompt, return_tensors='pt').to(device)

    start_time = time.time()

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id
        )

    inference_time = time.time() - start_time

    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    num_tokens = outputs.shape[1]

    result = {
        'generated_text': generated_text,
        'num_tokens': num_tokens,
        'inference_time': inference_time
    }

    if collect_routing:
        result['routing_stats'] = patcher.get_stats()

    return result


def run_configuration(
    config: RoutingConfig,
    prompts: List[str],
    prompt_complexities: List[str],
    max_new_tokens: int = 20
) -> Dict[str, Any]:
    """
    Run a single configuration on all prompts.

    Returns:
        Dictionary with aggregated results.
    """
    # Setup routing
    if config.routing_type == 'baseline':
        # Use OLMoE's native routing (no patching)
        patcher.unpatch()
        print(f"  Running with OLMoE native Top-{config.k} routing")
    elif config.routing_type == 'hc':
        patcher.patch_with_hc(
            beta=config.beta,
            temperature=config.temperature,
            min_k=config.min_k,
            max_k=config.max_k,
            collect_stats=True,
            patch_mode=config.patch_mode

        )
                # logger=logger,
        # log_every_n_tokens=LOG_EVERY_N
        print(f"  Running with HC routing (Œ≤={config.beta}, max_k={config.max_k})")

    # Run inference on all prompts
    all_results = []
    expert_counts_all = []

    for prompt, complexity in tqdm(
        zip(prompts, prompt_complexities),
        total=len(prompts),
        desc=f"  {config.name}",
        leave=False
    ):
        patcher.stats.clear()  # Clear stats for this prompt

        result = run_inference(prompt, max_new_tokens=max_new_tokens)
        result['prompt'] = prompt
        result['complexity'] = complexity

        all_results.append(result)

        if config.routing_type == 'hc' and 'routing_stats' in result:
            stats = result['routing_stats']
            if 'avg_experts' in stats:
                expert_counts_all.extend(
                    [stats['avg_experts']] * result['num_tokens']
                )

    # Aggregate results
    total_tokens = sum(r['num_tokens'] for r in all_results)
    total_time = sum(r['inference_time'] for r in all_results)

    aggregated = {
        'config_name': config.name,
        'routing_type': config.routing_type,
        'num_prompts': len(prompts),
        'total_tokens': total_tokens,
        'total_time': total_time,
        'avg_time_per_prompt': total_time / len(prompts),
        'tokens_per_second': total_tokens / total_time if total_time > 0 else 0,
        'detailed_results': all_results
    }

    # Add HC-specific metrics
    if config.routing_type == 'hc':
        # Aggregate routing stats across all prompts
        all_expert_counts = []
        for r in all_results:
            if 'routing_stats' in r and r['routing_stats']:
                if 'distribution' in r['routing_stats']:
                    dist = r['routing_stats']['distribution']
                    for k, count in enumerate(dist):
                        all_expert_counts.extend([k] * count)

        if all_expert_counts:
            all_expert_counts = np.array(all_expert_counts)
            aggregated['avg_experts'] = float(np.mean(all_expert_counts))
            aggregated['std_experts'] = float(np.std(all_expert_counts))
            aggregated['min_experts'] = int(np.min(all_expert_counts))
            aggregated['max_experts'] = int(np.max(all_expert_counts))
            aggregated['median_experts'] = float(np.median(all_expert_counts))

            # Ceiling/floor hit rates
            aggregated['ceiling_hit_rate'] = float(
                np.sum(all_expert_counts == config.max_k) / len(all_expert_counts) * 100
            )
            aggregated['floor_hit_rate'] = float(
                np.sum(all_expert_counts == config.min_k) / len(all_expert_counts) * 100
            )

            # Reduction vs baseline
            baseline_experts = 8  # OLMoE default
            aggregated['reduction_vs_baseline'] = float(
                (baseline_experts - aggregated['avg_experts']) / baseline_experts * 100
            )

        aggregated['beta'] = config.beta
        aggregated['max_k'] = config.max_k
        aggregated['min_k'] = config.min_k
    else:
        aggregated['k'] = config.k
        aggregated['avg_experts'] = config.k
        aggregated['std_experts'] = 0.0
        aggregated['min_experts'] = config.k
        aggregated['max_experts'] = config.k

    return aggregated

print("‚úÖ Inference functions defined")

## 9.5 Comprehensive Benchmark Evaluation

Run evaluation on WikiText (perplexity), LAMBADA (accuracy), and HellaSwag (accuracy).


In [None]:
"""
SECTION 9.5: COMPREHENSIVE BENCHMARK EVALUATION (Approach 2)
============================================================
- TopK: Uses NATIVE OLMoE via load_model_with_topk(k)
- HC: Uses HCRouterPatcher on base model (k=8)

This is cleaner because TopK uses the original OLMoE code path!
"""

print("=" * 70)
print("COMPREHENSIVE BENCHMARK EVALUATION (Approach 2)")
print("=" * 70)
print("   TopK ‚Üí Native OLMoE (config.num_experts_per_tok)")
print("   HC   ‚Üí HCRouterPatcher on base model")
print("=" * 70)

if 'EVAL_DATASETS' not in globals() or not EVAL_DATASETS:
    print("‚ö†Ô∏è No datasets loaded. Skipping benchmark evaluation.")
    print("   Run Section 7.5 to load datasets first.")
    comprehensive_results = []
else:
    print(f"\nExperiment Scope:")
    print(f"  ‚Ä¢ Configurations: {len(configs)}")
    print(f"  ‚Ä¢ Datasets: {list(EVAL_DATASETS.keys())}")
    print(f"  ‚Ä¢ Samples per dataset: {MAX_SAMPLES}")
    print(f"  ‚Ä¢ Total experiments: {len(configs) * len(EVAL_DATASETS)}")

    # Configure logging
    if 'LOG_EVERY_N' not in globals():
        LOG_EVERY_N = 100

    comprehensive_results = []
    benchmark_start = time.time()
    
    # Track current model state for Approach 2
    current_model = model  # Start with the base model (k=8)
    current_tokenizer = tokenizer
    current_model_k = 8
    current_patcher = None

    for dataset_name, dataset_data in EVAL_DATASETS.items():
        print(f"\n{'='*70}")
        print(f"EVALUATING ON: {dataset_name.upper()}")
        print(f"{'='*70}")

        for i, config in enumerate(configs):
            print(f"\n[{i+1}/{len(configs)}] {config.name} on {dataset_name}")
            print("-" * 50)

            config_start = time.time()

            # =================================================================
            # SETUP ROUTING (Approach 2)
            # =================================================================

            if config.routing_type == 'topk':
                # =====================================================
                # NATIVE TOPK: Load model with specific K via config
                # =====================================================
                if current_model_k != config.k:
                    print(f"  üì• Loading NATIVE TopK={config.k} model...")
                    current_model, current_tokenizer = load_model_with_topk(k=config.k, use_cache=True)
                    current_model_k = config.k
                    current_patcher = None  # No patcher for native TopK
                
                eval_model = current_model
                eval_tokenizer = current_tokenizer
                print(f"  üîß Using NATIVE TopK={config.k} (original OLMoE code)")
                
            else:  # config.routing_type == 'hc'
                # =====================================================
                # HC: Use base model (K=8) with HCRouterPatcher
                # =====================================================
                if current_model_k != 8:
                    print(f"  üì• Loading base model (K=8) for HC patching...")
                    current_model, current_tokenizer = load_model_with_topk(k=8, use_cache=True)
                    current_model_k = 8
                
                eval_model = current_model
                eval_tokenizer = current_tokenizer
                
                # Create patcher and apply HC
                current_patcher = HCRouterPatcher(eval_model)
                current_patcher.patch_with_hc(
                    beta=config.beta,
                    min_k=config.min_k,
                    max_k=config.max_k,
                    collect_stats=True,
                    log_every_n_tokens=LOG_EVERY_N, 
                    patch_mode=config.patch_mode
                )
                print(f"  üîß Using HC patcher (Œ≤={config.beta}, max_k={config.max_k}, patch_mode={config.patch_mode})")

            # =================================================================
            # INITIALIZE RESULT
            # =================================================================

            result = {
                'config_name': config.name,
                'routing_type': config.routing_type,
                'dataset': dataset_name,
                'k_or_max_k': config.k if config.routing_type == 'topk' else config.max_k,
                'beta': config.beta if config.routing_type == 'hc' else None,
                'min_k': config.min_k if config.routing_type == 'hc' else None,
                'patch_mode': config.patch_mode if config.routing_type == 'hc' else None,
                'patched_layers': (
                    current_patcher.patched_layer_indices 
                    if current_patcher and hasattr(current_patcher, 'patched_layer_indices') 
                    else []
                ),
                'num_patched_layers': (
                    len(current_patcher.patched_layer_indices) 
                    if current_patcher and hasattr(current_patcher, 'patched_layer_indices') 
                    else (0 if config.routing_type == 'topk' else 16)
                ),
            }

            # Start sample counter for HC
            if current_patcher is not None:
                current_patcher.start_sample()

            # =================================================================
            # RUN EVALUATION
            # =================================================================

            eval_result = {}  # Initialize to avoid reference errors
            
            try:
                if dataset_name == 'wikitext':
                    eval_result = evaluate_perplexity(
                        model=eval_model,
                        tokenizer=eval_tokenizer,
                        dataset=dataset_data,
                        patcher=current_patcher,  # None for TopK, patcher for HC
                        device=device,
                        max_length=512,
                        log_routing=(config.routing_type == 'hc'),
                        output_dir=str(OUTPUT_DIR),
                        experiment_name=f"{config.name}_{dataset_name}",
                        log_every_n=LOG_EVERY_N
                    )
                    result['perplexity'] = eval_result['perplexity']
                    if 'perplexity_token_weighted' in eval_result:
                        result['perplexity_token_weighted'] = eval_result['perplexity_token_weighted']
                    if 'perplexity_sample_weighted' in eval_result:
                        result['perplexity_sample_weighted'] = eval_result['perplexity_sample_weighted']
                    print(f"  ‚úÖ Perplexity: {eval_result['perplexity']:.2f}")

                elif dataset_name == 'lambada':
                    eval_result = evaluate_lambada(
                        model=eval_model,
                        tokenizer=eval_tokenizer,
                        dataset=dataset_data,
                        patcher=current_patcher,
                        device=device,
                        log_routing=(config.routing_type == 'hc'),
                        output_dir=str(OUTPUT_DIR),
                        experiment_name=f"{config.name}_{dataset_name}",
                        log_every_n=LOG_EVERY_N
                    )
                    result['lambada_accuracy'] = eval_result['accuracy']
                    print(f"  ‚úÖ LAMBADA Accuracy: {eval_result['accuracy']:.4f}")

                elif dataset_name == 'hellaswag':
                    eval_result = evaluate_hellaswag(
                        model=eval_model,
                        tokenizer=eval_tokenizer,
                        dataset=dataset_data,
                        patcher=current_patcher,
                        device=device,
                        log_routing=(config.routing_type == 'hc'),
                        output_dir=str(OUTPUT_DIR),
                        experiment_name=f"{config.name}_{dataset_name}",
                        log_every_n=LOG_EVERY_N
                    )
                    result['hellaswag_accuracy'] = eval_result['accuracy']
                    print(f"  ‚úÖ HellaSwag Accuracy: {eval_result['accuracy']:.4f}")

            except Exception as e:
                print(f"  ‚ùå Evaluation failed: {e}")
                import traceback
                print(traceback.format_exc())
                result['error'] = str(e)

            # =================================================================
            # GET ROUTING STATISTICS
            # =================================================================

            config_time = time.time() - config_start

            if config.routing_type == 'topk':
                # TopK: fixed expert count (native OLMoE)
                result['avg_experts'] = float(config.k)
                result['std_experts'] = 0.0
                result['min_experts'] = config.k
                result['max_experts'] = config.k
                result['adaptive_range'] = 0
                result['ceiling_hit_rate'] = 100.0
                result['floor_hit_rate'] = 0.0
                result['mid_range_rate'] = 0.0
                result['reduction_vs_baseline'] = (8 - config.k) / 8 * 100
                result['tokens_per_second'] = eval_result.get('total_tokens', 0) / config_time if config_time > 0 else 0
                
                # No HC stats for TopK
                result['hc_max_mean'] = None
                result['hc_max_std'] = None
                result['hc_threshold_ranks'] = None
                result['layer_wise'] = None
                
            else:  # HC
                stats = current_patcher.get_stats() if current_patcher else {}
                if stats:
                    # Basic expert statistics
                    result['avg_experts'] = stats['avg_experts']
                    result['std_experts'] = stats['std_experts']
                    result['min_experts'] = stats['min_experts']
                    result['max_experts'] = stats['max_experts']
                    
                    # =========================================================
                    # NEW: Transfer HC statistics for visualization
                    # =========================================================
                    if 'hc_max_mean' in stats:
                        result['hc_max_mean'] = stats['hc_max_mean']
                        result['hc_max_std'] = stats.get('hc_max_std', 0.0)
                    else:
                        result['hc_max_mean'] = None
                        result['hc_max_std'] = None
                        
                    if 'hc_threshold_ranks' in stats:
                        result['hc_threshold_ranks'] = stats['hc_threshold_ranks']
                    else:
                        result['hc_threshold_ranks'] = None
                        
                    if 'layer_wise' in stats:
                        result['layer_wise'] = stats['layer_wise']
                        # Also flatten for DataFrame compatibility
                        for layer_name, layer_data in stats['layer_wise'].items():
                            # Extract layer index from name like 'model.layers.15.mlp'
                            try:
                                if 'layers.' in layer_name:
                                    layer_idx = layer_name.split('layers.')[1].split('.')[0]
                                else:
                                    layer_idx = layer_name
                                result[f'layer_{layer_idx}_avg'] = layer_data['avg']
                                result[f'layer_{layer_idx}_std'] = layer_data['std']
                            except:
                                pass
                    else:
                        result['layer_wise'] = None
                    
                    # Compute additional metrics from raw expert counts
                    expert_counts = np.array(current_patcher.stats.get('expert_counts', []))
                    if len(expert_counts) > 0:
                        result['adaptive_range'] = int(np.max(expert_counts) - np.min(expert_counts))
                        result['ceiling_hit_rate'] = float((expert_counts >= config.max_k).sum() / len(expert_counts) * 100)
                        result['floor_hit_rate'] = float((expert_counts <= config.min_k).sum() / len(expert_counts) * 100)
                        result['mid_range_rate'] = 100.0 - result['ceiling_hit_rate'] - result['floor_hit_rate']
                        
                        # Selection entropy
                        try:
                            from scipy.stats import entropy as scipy_entropy
                            counts_dist = np.bincount(expert_counts.astype(int), minlength=65)
                            counts_dist = counts_dist / (counts_dist.sum() + 1e-10)
                            result['selection_entropy'] = float(scipy_entropy(counts_dist + 1e-10))
                        except:
                            result['selection_entropy'] = 0.0
                        
                        # Tokens per second
                        total_tokens = stats.get('total_tokens', len(expert_counts))
                        result['tokens_per_second'] = total_tokens / config_time if config_time > 0 else 0
                    else:
                        result['tokens_per_second'] = eval_result.get('total_tokens', 0) / config_time if config_time > 0 else 0
                    
                    result['reduction_vs_baseline'] = (8 - stats['avg_experts']) / 8 * 100
                    
                    print(f"  üìä Avg Experts: {stats['avg_experts']:.2f} ¬± {stats['std_experts']:.2f}")
                    print(f"  üìä Range: [{stats['min_experts']}, {stats['max_experts']}]")
                    print(f"  üìä Reduction: {result['reduction_vs_baseline']:.1f}%")
                    if 'hc_max_mean' in stats:
                        print(f"  üìä HC Max: {stats['hc_max_mean']:.3f} ¬± {stats.get('hc_max_std', 0):.3f}")
                
                # Unpatch for next iteration
                if current_patcher:
                    current_patcher.unpatch()

            result['elapsed_time'] = config_time
            print(f"  ‚è±Ô∏è Completed in {config_time:.1f}s")

            comprehensive_results.append(result)

            # Clear GPU cache periodically
            if torch.cuda.is_available() and i % 5 == 0:
                torch.cuda.empty_cache()

    # =========================================================================
    # CLEANUP
    # =========================================================================
    
    benchmark_time = time.time() - benchmark_start

    # =========================================================================
    # RESULTS SUMMARY
    # =========================================================================

    print("\n" + "=" * 70)
    print("COMPREHENSIVE BENCHMARK EVALUATION COMPLETE!")
    print("=" * 70)
    print(f"\n‚è±Ô∏è Total time: {benchmark_time / 60:.1f} minutes")
    print(f"üìä Experiments completed: {len(comprehensive_results)}")

    # Create summary DataFrame
    results_df = pd.DataFrame(comprehensive_results)

    # =========================================================================
    # DISPLAY RESULTS BY DATASET
    # =========================================================================

    print("\n" + "=" * 70)
    print("RESULTS BY DATASET")
    print("=" * 70)

    for dataset_name in EVAL_DATASETS.keys():
        dataset_results = results_df[results_df['dataset'] == dataset_name]

        print(f"\nüìä {dataset_name.upper()}")
        print("-" * 50)

        if dataset_name == 'wikitext' and 'perplexity' in dataset_results.columns:
            cols = ['config_name', 'routing_type', 'avg_experts', 'perplexity', 'reduction_vs_baseline']
            display_df = dataset_results[cols].copy()
            display_df['perplexity'] = display_df['perplexity'].round(2)
            display_df['avg_experts'] = display_df['avg_experts'].round(2)
            display_df['reduction_vs_baseline'] = display_df['reduction_vs_baseline'].round(1)
            print(display_df.sort_values('perplexity').to_string(index=False))

        elif dataset_name == 'lambada' and 'lambada_accuracy' in dataset_results.columns:
            cols = ['config_name', 'routing_type', 'avg_experts', 'lambada_accuracy', 'reduction_vs_baseline']
            display_df = dataset_results[cols].copy()
            display_df['lambada_accuracy'] = display_df['lambada_accuracy'].round(4)
            display_df['avg_experts'] = display_df['avg_experts'].round(2)
            display_df['reduction_vs_baseline'] = display_df['reduction_vs_baseline'].round(1)
            print(display_df.sort_values('lambada_accuracy', ascending=False).to_string(index=False))

        elif dataset_name == 'hellaswag' and 'hellaswag_accuracy' in dataset_results.columns:
            cols = ['config_name', 'routing_type', 'avg_experts', 'hellaswag_accuracy', 'reduction_vs_baseline']
            display_df = dataset_results[cols].copy()
            display_df['hellaswag_accuracy'] = display_df['hellaswag_accuracy'].round(4)
            display_df['avg_experts'] = display_df['avg_experts'].round(2)
            display_df['reduction_vs_baseline'] = display_df['reduction_vs_baseline'].round(1)
            print(display_df.sort_values('hellaswag_accuracy', ascending=False).to_string(index=False))

    # =========================================================================
    # SAVE RESULTS
    # =========================================================================

    results_path = OUTPUT_DIR / 'comprehensive_results_approach2.csv'
    results_df.to_csv(results_path, index=False)
    print(f"\n‚úÖ Results saved to: {results_path}")

    # =========================================================================
    # KEY FINDINGS
    # =========================================================================

    print("\n" + "=" * 70)
    print("KEY FINDINGS (Approach 2)")
    print("=" * 70)

    topk_baseline = results_df[(results_df['routing_type'] == 'topk') & (results_df['k_or_max_k'] == 8)]
    hc_results = results_df[results_df['routing_type'] == 'hc']

    if len(topk_baseline) > 0 and len(hc_results) > 0:
        print("\nüìà Baseline (Native TopK=8) vs HC Routing:")
        print("-" * 50)

        for dataset_name in EVAL_DATASETS.keys():
            baseline_row = topk_baseline[topk_baseline['dataset'] == dataset_name]
            hc_ds = hc_results[hc_results['dataset'] == dataset_name]

            if len(baseline_row) > 0 and len(hc_ds) > 0:
                print(f"\n  {dataset_name.upper()}:")
                baseline_row = baseline_row.iloc[0]
                
                if dataset_name == 'wikitext' and 'perplexity' in baseline_row:
                    baseline_ppl = baseline_row['perplexity']
                    print(f"    Baseline (Native TopK=8): PPL={baseline_ppl:.2f}")
                    best_hc = hc_ds.loc[hc_ds['perplexity'].idxmin()]
                    ppl_diff = ((best_hc['perplexity'] - baseline_ppl) / baseline_ppl) * 100
                    print(f"    Best HC ({best_hc['config_name']}): PPL={best_hc['perplexity']:.2f} ({ppl_diff:+.1f}%), "
                          f"Experts={best_hc['avg_experts']:.2f}")
                    
                elif dataset_name == 'lambada' and 'lambada_accuracy' in baseline_row:
                    baseline_acc = baseline_row['lambada_accuracy']
                    print(f"    Baseline (Native TopK=8): Acc={baseline_acc:.4f}")
                    best_hc = hc_ds.loc[hc_ds['lambada_accuracy'].idxmax()]
                    acc_diff = ((best_hc['lambada_accuracy'] - baseline_acc) / baseline_acc) * 100 if baseline_acc > 0 else 0
                    print(f"    Best HC ({best_hc['config_name']}): Acc={best_hc['lambada_accuracy']:.4f} ({acc_diff:+.1f}%), "
                          f"Experts={best_hc['avg_experts']:.2f}")

    print("\n" + "=" * 70)
    print("‚úÖ APPROACH 2 EVALUATION COMPLETE")
    print("=" * 70)
    print("   ‚Ä¢ TopK baselines used NATIVE OLMoE (via config)")
    print("   ‚Ä¢ HC routing used HCRouterPatcher")
    print("   ‚Ä¢ Results saved to:", results_path)

## 10. Save Comprehensive Results


In [None]:
"""
SECTION: SAVING COMPREHENSIVE RESULTS
=====================================
Copy this entire cell to replace the broken one in your notebook.
"""

print("=" * 70)
print("SAVING COMPREHENSIVE RESULTS")
print("=" * 70)

# Create output directories
from pathlib import Path
import json

if 'OUTPUT_DIR' not in globals():
    if IN_COLAB:
        OUTPUT_DIR = Path(WORK_DIR) / 'hc_comprehensive_results'
    else:
        OUTPUT_DIR = Path('./hc_comprehensive_results')

OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
(OUTPUT_DIR / 'logs').mkdir(exist_ok=True)
(OUTPUT_DIR / 'visualizations').mkdir(exist_ok=True)

print(f"üìÅ Output directory: {OUTPUT_DIR}")

# Prefer comprehensive benchmark results, fallback to test prompt results
if 'comprehensive_results' in globals() and comprehensive_results:
    results_to_save = comprehensive_results
    print("‚úÖ Using comprehensive benchmark results")
elif 'all_experiment_results' in globals() and all_experiment_results:
    # Add missing columns to test prompt results for compatibility
    for result in all_experiment_results:
        if 'routing_type' in result and result['routing_type'] == 'baseline':
            result['routing_type'] = 'topk'  # Match viz expectations
        if 'k' in result and 'k_or_max_k' not in result:
            result['k_or_max_k'] = result['k']
        elif 'max_k' in result and 'k_or_max_k' not in result:
            result['k_or_max_k'] = result['max_k']
        if 'dataset' not in result:
            result['dataset'] = 'test_prompts'
        if 'mid_range_rate' not in result and 'ceiling_hit_rate' in result:
            result['mid_range_rate'] = 100 - result.get('ceiling_hit_rate', 0) - result.get('floor_hit_rate', 0)
    results_to_save = all_experiment_results
    print("‚ö†Ô∏è Using test prompt results (no benchmark data)")
else:
    print("‚ö†Ô∏è No results to save!")
    results_to_save = []

if results_to_save:
    # Create comprehensive DataFrame
    results_df = pd.DataFrame(results_to_save)

    # Save CSV
    csv_path = OUTPUT_DIR / 'hc_comprehensive_results.csv'
    results_df.to_csv(csv_path, index=False)
    print(f"‚úÖ Saved CSV: {csv_path}")

    # Save JSON
    json_path = OUTPUT_DIR / 'hc_comprehensive_results.json'
    results_df.to_json(json_path, orient='records', indent=2)
    print(f"‚úÖ Saved JSON: {json_path}")

    # Save per-config summary JSONs (dual file logging)
    logs_dir = OUTPUT_DIR / 'logs'
    for result in results_to_save:
        config_name = result.get('config_name', 'unknown')
        dataset = result.get('dataset', 'mixed')

        summary_file = logs_dir / f"{config_name}_{dataset}.json"
        with open(summary_file, 'w') as f:
            json.dump(result, f, indent=2, default=str)

    print(f"‚úÖ Saved {len(results_to_save)} individual log files to {logs_dir}")

    # Display summary tables
    print("\n" + "=" * 70)
    print("RESULTS SUMMARY")
    print("=" * 70)

    # Show top results
    if 'avg_experts' in results_df.columns:
        print("\nTop 10 configurations by average experts:")
        display_cols = ['config_name', 'avg_experts']
        for col in ['reduction_vs_baseline', 'dataset']:
            if col in results_df.columns:
                display_cols.append(col)
        top_df = results_df.nsmallest(10, 'avg_experts')[display_cols]
        print(top_df.to_string(index=False))

    # Quality metrics by dataset
    if 'dataset' in results_df.columns:
        for dataset in results_df['dataset'].unique():
            subset = results_df[results_df['dataset'] == dataset]
            print(f"\nüìä {dataset.upper()} Results:")

            display_cols = ['config_name']
            if 'perplexity' in subset.columns:
                display_cols.append('perplexity')
            if 'lambada_accuracy' in subset.columns:
                display_cols.append('lambada_accuracy')
            # if 'hellaswag_accuracy' in subset.columns:
            #     display_cols.append('hellaswag_accuracy')
            if 'avg_experts' in subset.columns:
                display_cols.append('avg_experts')

            if len(display_cols) > 1:
                print(subset[display_cols].head(10).to_string(index=False))

else:
    print("‚ö†Ô∏è No results to save!")
    results_df = pd.DataFrame()  # Create empty DataFrame

print("\n" + "=" * 70)

# Routing method identifier
ROUTING_METHOD = 'hc'  # Used for logging and visualization

## 11. Comprehensive Visualizations


In [None]:
print("=" * 70)
print("GENERATING COMPREHENSIVE VISUALIZATIONS")
print("=" * 70)

# =============================================================================
# PATCH_MODE ANALYSIS VISUALIZATIONS
# =============================================================================

def create_patch_mode_visualizations(results_df, output_dir):
    """
    Create visualizations specifically for patch_mode analysis.
    
    Shows:
    1. Performance vs number of patched layers
    2. Comparison of patch_mode settings
    3. Efficiency analysis (quality vs compute)
    """
    import matplotlib.pyplot as plt
    import seaborn as sns
    
    # Filter HC results only
    hc_df = results_df[results_df['routing_type'] == 'hc'].copy()
    
    if len(hc_df) == 0:
        print("‚ö†Ô∏è No HC results for patch_mode visualization")
        return None
    
    # Add patch_mode_label column
    hc_df['patch_mode_label'] = hc_df['patch_mode'].apply(
        lambda x: f'Last {int(x)}' if pd.notna(x) and x > 0 else 'All 16'
    )
    
    # Add num_patched for ordering
    hc_df['num_patched'] = hc_df['patch_mode'].apply(
        lambda x: int(x) if pd.notna(x) and x > 0 else 16
    )
    
    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    fig.suptitle('HC Routing: patch_mode Analysis', fontsize=16, fontweight='bold')
    
    # Color palette
    palette = sns.color_palette('viridis', n_colors=len(hc_df['patch_mode_label'].unique()))
    
    # =========================================================================
    # Plot 1: Perplexity vs Patch Mode (WikiText)
    # =========================================================================
    ax1 = axes[0, 0]
    wiki_df = hc_df[hc_df['dataset'] == 'wikitext']
    if len(wiki_df) > 0 and 'perplexity' in wiki_df.columns:
        wiki_grouped = wiki_df.groupby(['patch_mode_label', 'num_patched']).agg({
            'perplexity': 'mean'
        }).reset_index().sort_values('num_patched')
        
        bars = ax1.bar(wiki_grouped['patch_mode_label'], wiki_grouped['perplexity'], 
                       color=palette[:len(wiki_grouped)])
        ax1.set_xlabel('Layers Patched with HC')
        ax1.set_ylabel('Perplexity')
        ax1.set_title('WikiText Perplexity by patch_mode')
        ax1.tick_params(axis='x', rotation=45)
        
        # Add value labels
        for bar, val in zip(bars, wiki_grouped['perplexity']):
            ax1.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.5, 
                    f'{val:.1f}', ha='center', va='bottom', fontsize=9)
    else:
        ax1.text(0.5, 0.5, 'No WikiText data', ha='center', va='center')
        ax1.set_title('WikiText Perplexity by patch_mode')
    
    # =========================================================================
    # Plot 2: Average Experts vs Patch Mode
    # =========================================================================
    ax2 = axes[0, 1]
    if 'avg_experts' in hc_df.columns:
        experts_grouped = hc_df.groupby(['patch_mode_label', 'num_patched']).agg({
            'avg_experts': 'mean'
        }).reset_index().sort_values('num_patched')
        
        bars = ax2.bar(experts_grouped['patch_mode_label'], experts_grouped['avg_experts'],
                       color=palette[:len(experts_grouped)])
        ax2.axhline(y=8, color='red', linestyle='--', label='TopK=8 baseline')
        ax2.set_xlabel('Layers Patched with HC')
        ax2.set_ylabel('Average Experts Selected')
        ax2.set_title('Expert Selection by patch_mode')
        ax2.tick_params(axis='x', rotation=45)
        ax2.legend()
        
        for bar, val in zip(bars, experts_grouped['avg_experts']):
            ax2.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.1,
                    f'{val:.2f}', ha='center', va='bottom', fontsize=9)
    else:
        ax2.text(0.5, 0.5, 'No expert data', ha='center', va='center')
        ax2.set_title('Expert Selection by patch_mode')
    
    # =========================================================================
    # Plot 3: Throughput vs Patch Mode
    # =========================================================================
    ax3 = axes[0, 2]
    if 'tokens_per_second' in hc_df.columns:
        tps_grouped = hc_df.groupby(['patch_mode_label', 'num_patched']).agg({
            'tokens_per_second': 'mean'
        }).reset_index().sort_values('num_patched')
        
        bars = ax3.bar(tps_grouped['patch_mode_label'], tps_grouped['tokens_per_second'],
                       color=palette[:len(tps_grouped)])
        ax3.set_xlabel('Layers Patched with HC')
        ax3.set_ylabel('Tokens/Second')
        ax3.set_title('Throughput by patch_mode')
        ax3.tick_params(axis='x', rotation=45)
    else:
        ax3.text(0.5, 0.5, 'No throughput data', ha='center', va='center')
        ax3.set_title('Throughput by patch_mode')
    
    # =========================================================================
    # Plot 4: Heatmap - Beta vs Patch Mode (Perplexity)
    # =========================================================================
    ax4 = axes[1, 0]
    wiki_df = hc_df[hc_df['dataset'] == 'wikitext']
    if len(wiki_df) > 0 and 'perplexity' in wiki_df.columns and 'beta' in wiki_df.columns:
        pivot = wiki_df.pivot_table(
            values='perplexity', 
            index='beta', 
            columns='patch_mode_label',
            aggfunc='mean'
        )
        # Reorder columns by num_patched
        col_order = hc_df.groupby('patch_mode_label')['num_patched'].first().sort_values().index
        pivot = pivot[[c for c in col_order if c in pivot.columns]]
        
        sns.heatmap(pivot, annot=True, fmt='.1f', cmap='RdYlGn_r', ax=ax4)
        ax4.set_title('Perplexity: Beta √ó patch_mode')
        ax4.set_xlabel('Layers Patched')
        ax4.set_ylabel('Beta')
    else:
        ax4.text(0.5, 0.5, 'Insufficient data for heatmap', ha='center', va='center')
        ax4.set_title('Perplexity: Beta √ó patch_mode')
    
    # =========================================================================
    # Plot 5: Efficiency Frontier (Perplexity vs Experts)
    # =========================================================================
    ax5 = axes[1, 1]
    wiki_df = hc_df[hc_df['dataset'] == 'wikitext']
    if len(wiki_df) > 0 and 'perplexity' in wiki_df.columns and 'avg_experts' in wiki_df.columns:
        for pm_label in wiki_df['patch_mode_label'].unique():
            subset = wiki_df[wiki_df['patch_mode_label'] == pm_label]
            ax5.scatter(subset['avg_experts'], subset['perplexity'], 
                       label=pm_label, s=100, alpha=0.7)
        
        ax5.set_xlabel('Average Experts')
        ax5.set_ylabel('Perplexity')
        ax5.set_title('Efficiency Frontier: Quality vs Compute')
        ax5.legend(title='Layers Patched')
        ax5.axvline(x=8, color='red', linestyle='--', alpha=0.5, label='TopK=8')
    else:
        ax5.text(0.5, 0.5, 'Insufficient data', ha='center', va='center')
        ax5.set_title('Efficiency Frontier')
    
    # =========================================================================
    # Plot 6: Summary Statistics Table
    # =========================================================================
    ax6 = axes[1, 2]
    ax6.axis('off')
    
    # Create summary
    summary_data = []
    for pm_label in sorted(hc_df['patch_mode_label'].unique(), 
                           key=lambda x: hc_df[hc_df['patch_mode_label']==x]['num_patched'].iloc[0]):
        subset = hc_df[hc_df['patch_mode_label'] == pm_label]
        wiki_subset = subset[subset['dataset'] == 'wikitext']
        
        row = {
            'Layers': pm_label,
            'Configs': len(subset),
        }
        if 'perplexity' in wiki_subset.columns and len(wiki_subset) > 0:
            row['PPL'] = f"{wiki_subset['perplexity'].mean():.1f}"
        if 'avg_experts' in subset.columns:
            row['Experts'] = f"{subset['avg_experts'].mean():.2f}"
        if 'reduction_vs_baseline' in subset.columns:
            row['Reduction'] = f"{subset['reduction_vs_baseline'].mean():.1f}%"
        
        summary_data.append(row)
    
    if summary_data:
        summary_df = pd.DataFrame(summary_data)
        table = ax6.table(
            cellText=summary_df.values,
            colLabels=summary_df.columns,
            loc='center',
            cellLoc='center'
        )
        table.auto_set_font_size(False)
        table.set_fontsize(10)
        table.scale(1.2, 1.5)
        ax6.set_title('Summary by patch_mode', fontweight='bold', pad=20)
    
    plt.tight_layout()
    
    # Save
    viz_path = output_dir / 'visualizations' / 'patch_mode_analysis.png'
    plt.savefig(viz_path, dpi=300, bbox_inches='tight')
    print(f"‚úÖ Saved patch_mode visualization: {viz_path}")
    plt.show()
    
    return viz_path


# Run patch_mode visualization if we have results
if 'results_df' in globals() and results_df is not None and len(results_df) > 0:
    if 'patch_mode' in results_df.columns:
        print("\n" + "=" * 70)
        print("PATCH_MODE ANALYSIS")
        print("=" * 70)
        create_patch_mode_visualizations(results_df, OUTPUT_DIR)
    else:
        print("‚ö†Ô∏è patch_mode column not in results - skipping patch_mode visualization")


else:
    # Validate required columns
    results_df['hc_variant'] = results_df['beta'].apply(
        lambda x: f"Œ≤={x}" if pd.notna(x) else 'TopK')
    required_cols = ['routing_type', 'k_or_max_k', 'dataset', 'avg_experts']
    missing_cols = [c for c in required_cols if c not in results_df.columns]

    if missing_cols:
        print(f"\n‚ö†Ô∏è Missing columns for full visualization: {missing_cols}")
        print("Falling back to basic visualization...")

        # Fallback to basic visualization
        import matplotlib.pyplot as plt
        import seaborn as sns

        sns.set_style('whitegrid')
        fig, axes = plt.subplots(2, 2, figsize=(14, 10))
        fig.suptitle('OLMoE HC Routing Analysis', fontsize=16, fontweight='bold')

        # Filter HC results
        hc_df = results_df[results_df['routing_type'] == 'hc'] if 'routing_type' in results_df.columns else results_df

        # Plot 1: Average experts
        if 'avg_experts' in results_df.columns:
            ax1 = axes[0, 0]
            top_10 = results_df.nsmallest(10, 'avg_experts')
            ax1.barh(range(len(top_10)), top_10['avg_experts'])
            ax1.set_yticks(range(len(top_10)))
            ax1.set_yticklabels(top_10['config_name'], fontsize=8)
            ax1.set_title('Top 10: Fewest Experts')
            ax1.set_xlabel('Average Experts')

        # Plot 2: Throughput
        if 'tokens_per_second' in results_df.columns:
            ax2 = axes[0, 1]
            top_10 = results_df.nlargest(10, 'tokens_per_second')
            ax2.barh(range(len(top_10)), top_10['tokens_per_second'])
            ax2.set_yticks(range(len(top_10)))
            ax2.set_yticklabels(top_10['config_name'], fontsize=8)
            ax2.set_title('Top 10: Throughput')
            ax2.set_xlabel('Tokens/Second')

        # Plot 3: Reduction vs baseline
        if 'reduction_vs_baseline' in hc_df.columns and len(hc_df) > 0:
            ax3 = axes[1, 0]
            top_10 = hc_df.nlargest(10, 'reduction_vs_baseline')
            ax3.barh(range(len(top_10)), top_10['reduction_vs_baseline'])
            ax3.set_yticks(range(len(top_10)))
            ax3.set_yticklabels(top_10['config_name'], fontsize=8)
            ax3.set_title('Top 10: Expert Reduction')
            ax3.set_xlabel('Reduction (%)')

        # Plot 4: Summary stats
        ax4 = axes[1, 1]
        ax4.axis('off')
        summary_text = f"""
        Total Configs: {len(results_df)}
        HC Configs: {len(hc_df)}
        """
        if 'avg_experts' in results_df.columns:
            summary_text += f"\nMin Avg Experts: {results_df['avg_experts'].min():.2f}"
            summary_text += f"\nMax Avg Experts: {results_df['avg_experts'].max():.2f}"
        ax4.text(0.1, 0.5, summary_text, fontsize=12, va='center')

        plt.tight_layout()
        viz_path = OUTPUT_DIR / 'visualizations' / 'basic_analysis.png'
        plt.savefig(viz_path, dpi=300, bbox_inches='tight')
        print(f"‚úÖ Saved basic visualization: {viz_path}")
        plt.show()
    else:
        # Use comprehensive visualization
        try:
            viz_path = create_comprehensive_visualization(
                results_df=results_df,
                output_path=str(OUTPUT_DIR / 'visualizations' / 'hc_comprehensive_comparison.png')
            )

            if viz_path:
                print(f"‚úÖ Saved comprehensive visualization: {viz_path}")

                # Display the visualization
                from IPython.display import Image, display
                if Path(viz_path).exists():
                    display(Image(filename=str(viz_path)))
            else:
                print("‚ö†Ô∏è Visualization not created")
        except Exception as e:
            print(f"‚ö†Ô∏è Comprehensive visualization failed: {e}")
            import traceback
            print(traceback.format_exc())
            print("\nTry running the basic visualization fallback above.")

print("\n" + "=" * 70)
print("‚úÖ VISUALIZATIONS COMPLETE")
print("=" * 70)


## 12. Statistical Analysis

In [None]:
print("=" * 70)
print("COMPREHENSIVE STATISTICAL ANALYSIS")
print("=" * 70)

# Define hc_df and baseline_df from results
if 'results_df' not in globals() or results_df is None or len(results_df) == 0:
    print("\n‚ö†Ô∏è No results DataFrame available for analysis")
    print("   Run Sections 9.5 and 10 first to generate results.")
else:
    # Filter HC and baseline results
    hc_df = results_df[results_df['routing_type'] == 'hc'].copy() if 'routing_type' in results_df.columns else pd.DataFrame()
    baseline_df = results_df[results_df['routing_type'] == 'topk'].copy() if 'routing_type' in results_df.columns else pd.DataFrame()

    print(f"\nResults breakdown:")
    print(f"  ‚Ä¢ Total configurations: {len(results_df)}")
    print(f"  ‚Ä¢ Baseline (TopK): {len(baseline_df)}")
    print(f"  ‚Ä¢ HC routing: {len(hc_df)}")

    if len(hc_df) == 0:
        print("\n‚ö†Ô∏è No HC results found. Skipping HC-specific analysis.")
    else:
        # =====================================================================
        # 1. BASELINE COMPARISON
        # =====================================================================
        print("\n1. BASELINE COMPARISON")
        print("-" * 50)

        if len(baseline_df) > 0 and 'avg_experts' in baseline_df.columns:
            print("\nBaseline Configurations:")
            baseline_cols = ['config_name', 'avg_experts']
            if 'dataset' in baseline_df.columns:
                baseline_cols.append('dataset')
            if 'perplexity' in baseline_df.columns:
                baseline_cols.append('perplexity')
            display_cols = [c for c in baseline_cols if c in baseline_df.columns]
            print(baseline_df[display_cols].to_string(index=False))

        # =====================================================================
        # 2. HC ROUTING ANALYSIS
        # =====================================================================
        print("\n\n2. HC ROUTING ANALYSIS")
        print("-" * 50)

        # Best by reduction (most efficient)
        if 'reduction_vs_baseline' in hc_df.columns:
            print("\nTop 5 by Expert Reduction:")
            best_reduction = hc_df.nlargest(5, 'reduction_vs_baseline')
            display_cols = ['config_name', 'avg_experts', 'reduction_vs_baseline']
            if 'alpha' in best_reduction.columns:
                display_cols.append('alpha')
            if 'max_k' in best_reduction.columns:
                display_cols.append('max_k')
            display_cols = [c for c in display_cols if c in best_reduction.columns]
            print(best_reduction[display_cols].to_string(index=False))
        else:
            print("\n‚ö†Ô∏è Column 'reduction_vs_baseline' not found")

        # Best by low ceiling hit rate (not constrained)
        if 'ceiling_hit_rate' in hc_df.columns:
            print("\n\nTop 5 by Low Ceiling Hit Rate (unconstrained):")
            best_unconstrained = hc_df.nsmallest(5, 'ceiling_hit_rate')
            display_cols = ['config_name', 'avg_experts', 'ceiling_hit_rate']
            if 'alpha' in best_unconstrained.columns:
                display_cols.append('alpha')
            if 'max_k' in best_unconstrained.columns:
                display_cols.append('max_k')
            display_cols = [c for c in display_cols if c in best_unconstrained.columns]
            print(best_unconstrained[display_cols].to_string(index=False))
        else:
            print("\n‚ö†Ô∏è Column 'ceiling_hit_rate' not found")

        # Best by adaptive range (most dynamic)
        if 'adaptive_range' in hc_df.columns:
            print("\n\nTop 5 by Adaptive Range (most dynamic):")
            best_adaptive = hc_df.nlargest(5, 'adaptive_range')
            display_cols = ['config_name', 'adaptive_range', 'avg_experts']
            if 'dataset' in best_adaptive.columns:
                display_cols.append('dataset')
            display_cols = [c for c in display_cols if c in best_adaptive.columns]
            print(best_adaptive[display_cols].to_string(index=False))

        # =====================================================================
        # 3. SATURATION ANALYSIS
        # =====================================================================
        print("\n\n3. SATURATION ANALYSIS")
        print("-" * 50)

        if 'alpha' in hc_df.columns and 'max_k' in hc_df.columns:
            for beta_val in sorted(hc_df['beta'].dropna().unique()):
                subset = hc_df[hc_df['beta'] == beta_val].sort_values('max_k')
                if 'avg_experts' in subset.columns and len(subset) > 1:
                    avg_experts = subset['avg_experts'].values
                    max_ks = subset['max_k'].values

                    # Find where increase is < 5%
                    saturation_point = None
                    for i in range(1, len(avg_experts)):
                        if avg_experts[i-1] > 0:
                            pct_increase = (avg_experts[i] - avg_experts[i-1]) / avg_experts[i-1] * 100
                            if pct_increase < 5:
                                saturation_point = max_ks[i]
                                break

                    if saturation_point:
                        print(f"Œ≤={beta}: Saturates at max_k={saturation_point}")
                    else:
                        print(f"Œ≤={beta}: No saturation detected (benefits from higher max_k)")
        else:
            print("‚ö†Ô∏è Columns 'alpha' or 'max_k' not found")

        # =====================================================================
        # 4. RECOMMENDED CONFIGURATIONS
        # =====================================================================
        print("\n\n4. RECOMMENDED CONFIGURATIONS")
        print("-" * 50)

        print("\nüéØ Based on analysis:")
        print("\n  ‚Ä¢ For MAXIMUM EFFICIENCY:")
        print("    Use Œ≤=0.30, max_k=8 (lowest expert count)")
        print("\n  ‚Ä¢ For BALANCED PERFORMANCE:")
        print("    Use Œ≤=0.50, max_k=16 (good trade-off)")
        print("\n  ‚Ä¢ For QUALITY-CRITICAL TASKS:")
        print("    Use Œ≤=0.60, max_k=32 (closest to baseline quality)")

print("\n" + "=" * 70)


## 13. Generate Report

In [None]:
if 'results_df' not in globals() or results_df is None or len(results_df) == 0:
    print("‚ö†Ô∏è No results to generate report from")
    print("   Run Sections 9.5 and 10 first.")
else:
    report_path = OUTPUT_DIR / 'hc_routing_comprehensive_report.md'

    with open(report_path, 'w') as f:
        from datetime import datetime
        f.write("# OLMoE HC Routing Comprehensive Evaluation Report\n\n")
        f.write(f"**Date:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")
        f.write(f"**Model:** {MODEL_NAME}\n\n")

        f.write("---\n\n")

        f.write("## Executive Summary\n\n")
        f.write(f"- **Configurations tested:** {len(results_df)}\n")
        if 'routing_type' in results_df.columns:
            baseline_count = len(results_df[results_df['routing_type'] == 'topk'])
            hc_count = len(results_df[results_df['routing_type'] == 'hc'])
            f.write(f"  - Baselines: {baseline_count}\n")
            f.write(f"  - HC variants: {hc_count}\n")
        if 'dataset' in results_df.columns:
            datasets = results_df['dataset'].unique().tolist()
            f.write(f"- **Datasets evaluated:** {datasets}\n")
        f.write("\n")

        f.write("---\n\n")

        f.write("## Key Findings\n\n")

        # Best configurations
        hc_df = results_df[results_df['routing_type'] == 'hc'] if 'routing_type' in results_df.columns else results_df

        if len(hc_df) > 0 and 'avg_experts' in hc_df.columns:
            best = hc_df.nsmallest(1, 'avg_experts').iloc[0]
            f.write("### Best Efficiency (Fewest Experts)\n\n")
            f.write(f"- **Configuration:** {best['config_name']}\n")
            f.write(f"- **Avg Experts:** {best['avg_experts']:.2f}\n")
            if 'beta' in best:
                f.write(f"- **Beta:** {best['beta']}\n")
            if 'k_or_max_k' in best:
                f.write(f"- **max_k:** {best['k_or_max_k']}\n\n")

        if len(hc_df) > 0 and 'reduction_vs_baseline' in hc_df.columns:
            best_red = hc_df.nlargest(1, 'reduction_vs_baseline').iloc[0]
            f.write("### Best Expert Reduction\n\n")
            f.write(f"- **Configuration:** {best_red['config_name']}\n")
            f.write(f"- **Reduction:** {best_red['reduction_vs_baseline']:.1f}%\n")
            f.write(f"- **Avg Experts:** {best_red['avg_experts']:.2f}\n\n")

        f.write("---\n\n")

        f.write("## Recommendations\n\n")
        f.write("Based on comprehensive evaluation:\n\n")
        f.write("1. **Maximum Efficiency:** Œ≤=0.30, max_k=8\n")
        f.write("2. **Balanced Performance:** Œ≤=0.50, max_k=16\n")
        f.write("3. **Quality-Critical:** Œ≤=0.60, max_k=32\n\n")

        f.write("---\n\n")

        f.write("## Full Results\n\n")
        f.write(results_df.to_markdown(index=False))
        f.write("\n\n")

        f.write("---\n\n")
        f.write(f"Generated by HC Routing Framework\n")
        f.write(f"Output directory: {OUTPUT_DIR}\n")

    print(f"‚úÖ Generated comprehensive report: {report_path}")
