# BlockFFN + Mamba Hybrid: Routing Signal Visualization

This notebook visualizes the routing signals extracted from BlockFFN to understand:
1. How routing sparsity varies across tokens
2. Whether sparsity correlates with semantic importance
3. How routing patterns differ across layers

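Run `scripts/03_extract_routing.py` first to generate the routing data. Sections 5 and 6 additionally read the results written by `scripts/05_sweep_alpha.py` and `scripts/07_gated_hybrid.py`; each of those sections prints a reminder if its results file is missing.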

In [None]:
import json
import sys
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

# Add project root to path so the `config` import below resolves.
# NOTE(review): this assumes the kernel's working directory sits one level
# below the project root (e.g. a notebooks/ folder) — confirm for your setup.
project_root = Path.cwd().parent
# Guard the insert so re-running this cell does not stack duplicate entries.
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))

from config import OUTPUT_DIR, ROUTING_SIGNALS_PATH, ALPHA_SWEEP_RESULTS_PATH, GATED_HYBRID_RESULTS_PATH

# Plot defaults applied to the figures created after this cell runs.
plt.style.use('seaborn-v0_8-whitegrid')
plt.rcParams['figure.figsize'] = (12, 6)

## 1. Load Routing Data

In [None]:
# Load routing signals from JSON.
# `routing_data` is left as None when the file is missing; the cells below
# check it before plotting.
routing_path = project_root / ROUTING_SIGNALS_PATH

if routing_path.exists():
    # Read as UTF-8 explicitly: prompts may contain non-ASCII text, and the
    # platform default encoding is not UTF-8 everywhere.
    with open(routing_path, encoding='utf-8') as f:
        routing_data = json.load(f)
    print(f"Loaded routing data from {routing_path}")
    print(f"Number of prompts: {len(routing_data.get('prompts', []))}")
else:
    print(f"Routing data not found at {routing_path}")
    print("Run scripts/03_extract_routing.py first!")
    routing_data = None

## 2. Overall Sparsity Distribution

In [None]:
# Horizontal bar chart of each prompt's overall routing sparsity, followed by
# the mean and standard deviation across prompts.
prompts_data = routing_data.get('prompts', []) if routing_data else []

if prompts_data:
    # Extract overall sparsity for each prompt
    sparsities = [p['overall_sparsity'] for p in prompts_data]
    # Truncate long prompts so the y-axis labels stay readable.
    prompts = [p['prompt'][:50] + '...' if len(p['prompt']) > 50 else p['prompt'] for p in prompts_data]

    fig, ax = plt.subplots(figsize=(12, 5))
    ax.barh(range(len(sparsities)), sparsities, color='steelblue')
    ax.set_yticks(range(len(prompts)))
    ax.set_yticklabels(prompts, fontsize=9)
    ax.set_xlabel('Sparsity (fraction of zeros)', fontsize=12)
    ax.set_title('Overall Routing Sparsity by Prompt', fontsize=14)
    ax.axvline(x=0.7, color='red', linestyle='--', label='High sparsity threshold (0.7)')
    ax.legend()

    # Add value labels just to the right of each bar
    for i, val in enumerate(sparsities):
        ax.text(val + 0.01, i, f'{val:.2f}', va='center', fontsize=9)

    plt.tight_layout()
    plt.show()

    print(f"\nMean sparsity: {np.mean(sparsities):.3f}")
    print(f"Std sparsity: {np.std(sparsities):.3f}")
elif routing_data:
    # Without this guard an empty prompt list yields an empty chart and
    # nan statistics (np.mean of an empty list).
    print("Routing data contains no prompts; nothing to plot.")

## 3. Per-Layer Sparsity Heatmap

In [None]:
# Heatmap of sparsity per (prompt, layer), then a line plot of the per-layer
# average across prompts.
if routing_data:
    prompts_data = routing_data.get('prompts', [])

    # Collect every layer id that appears for at least one prompt.
    all_layers = set()
    for p in prompts_data:
        all_layers.update(p.get('per_layer_sparsity', {}).keys())

    # JSON object keys are strings; sort numerically rather than lexically.
    layers = sorted(int(l) for l in all_layers)

    if layers:
        # Start from NaN so a layer that is missing for some prompt is not
        # mistaken for a measured sparsity of 0 (i.e. fully dense).
        sparsity_matrix = np.full((len(prompts_data), len(layers)), np.nan)

        for i, p in enumerate(prompts_data):
            layer_sparsity = p.get('per_layer_sparsity', {})
            for j, layer in enumerate(layers):
                key = str(layer)
                if key in layer_sparsity:
                    sparsity_matrix[i, j] = layer_sparsity[key]

        fig, ax = plt.subplots(figsize=(14, 6))
        # NOTE(review): NaN cells should render with the colormap's "bad"
        # colour (blank by default) — confirm against the matplotlib version in use.
        im = ax.imshow(sparsity_matrix, aspect='auto', cmap='RdYlGn', vmin=0, vmax=1)

        # Label roughly ten evenly spaced layers on the x-axis for readability
        xtick_indices = range(0, len(layers), max(1, len(layers)//10))
        ax.set_xticks(list(xtick_indices))
        ax.set_xticklabels([layers[i] for i in xtick_indices])

        ax.set_yticks(range(len(prompts_data)))
        # Only append an ellipsis when the prompt was actually truncated.
        ax.set_yticklabels(
            [p['prompt'][:30] + '...' if len(p['prompt']) > 30 else p['prompt'] for p in prompts_data],
            fontsize=9,
        )

        ax.set_xlabel('Layer', fontsize=12)
        ax.set_ylabel('Prompt', fontsize=12)
        ax.set_title('Routing Sparsity Across Layers and Prompts', fontsize=14)

        cbar = plt.colorbar(im, ax=ax)
        cbar.set_label('Sparsity', fontsize=11)

        plt.tight_layout()
        plt.show()

        # Average sparsity per layer, ignoring prompts with no value for that
        # layer. Every column has at least one value because `layers` is the
        # union of the keys seen above.
        avg_per_layer = np.nanmean(sparsity_matrix, axis=0)

        fig, ax = plt.subplots(figsize=(12, 4))
        ax.plot(layers, avg_per_layer, 'o-', markersize=4)
        ax.fill_between(layers, avg_per_layer, alpha=0.3)
        ax.set_xlabel('Layer', fontsize=12)
        ax.set_ylabel('Average Sparsity', fontsize=12)
        ax.set_title('Average Routing Sparsity by Layer', fontsize=14)
        ax.axhline(y=0.7, color='red', linestyle='--', alpha=0.7, label='Threshold (0.7)')
        ax.legend()
        plt.tight_layout()
        plt.show()

## 4. Token-Level Analysis

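Examine how sparsity varies across tokens within a prompt. Do certain token types (punctuation, function words) consistently show higher sparsity? Note that the saved JSON does not include per-token sparsity, so the cell below only lists the tokens of the first prompt; answering this question requires extending `03_extract_routing.py` to save per-token data.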

In [None]:
# List the tokens of the first prompt. Re-derive the prompt list here instead
# of relying on a variable left behind by an earlier cell, so this cell also
# works when run directly after the loading cell.
prompts_data = routing_data.get('prompts', []) if routing_data else []

if prompts_data:
    # Use first prompt for detailed analysis
    example = prompts_data[0]
    tokens = example.get('tokens', [])

    print(f"Analyzing prompt: {example['prompt']!r}")
    print(f"Number of tokens: {len(tokens)}")
    print(f"\nTokens: {tokens}")

    # Note: We don't have per-token sparsity saved in JSON (too large)
    # This would require running the extractor again or saving more data
    print("\nNote: Per-token sparsity visualization requires running the model.")
    print("To see per-token sparsity, modify 03_extract_routing.py to save per-token data.")

## 5. Alpha Sweep Results

Visualize how perplexity changes with different alpha values (Mamba contribution).

In [None]:
# Plot perplexity (and throughput, when recorded) against the alpha values in
# the sweep results, with the 'baseline' row drawn as a reference line.
alpha_path = project_root / ALPHA_SWEEP_RESULTS_PATH

if alpha_path.exists():
    alpha_df = pd.read_csv(alpha_path)
    print(f"Loaded alpha sweep results from {alpha_path}")
    display(alpha_df)

    # Keep only rows whose alpha parses as a number. Coercion handles negative
    # and scientific-notation values and turns labels such as 'baseline' (or
    # malformed entries) into NaN, which are then dropped.
    alpha_numeric = pd.to_numeric(alpha_df['alpha'], errors='coerce')
    numeric_df = (
        alpha_df.assign(alpha=alpha_numeric)
        .loc[alpha_numeric.notna()]
        .sort_values('alpha')
    )

    if not numeric_df.empty and 'perplexity' in numeric_df.columns:
        fig, axes = plt.subplots(1, 2, figsize=(14, 5))

        # Perplexity vs Alpha
        ax1 = axes[0]
        ax1.plot(numeric_df['alpha'], numeric_df['perplexity'], 'o-', markersize=8, linewidth=2)
        ax1.set_xlabel('Alpha (0=Attention, 1=Mamba)', fontsize=12)
        ax1.set_ylabel('Perplexity', fontsize=12)
        ax1.set_title('Perplexity vs Alpha', fontsize=14)
        ax1.grid(True, alpha=0.3)

        # Add baseline reference (looked up in the unfiltered frame, since the
        # 'baseline' label is not numeric)
        baseline_row = alpha_df[alpha_df['alpha'] == 'baseline']
        if not baseline_row.empty:
            baseline_ppl = baseline_row['perplexity'].values[0]
            ax1.axhline(y=baseline_ppl, color='red', linestyle='--', label=f'Baseline: {baseline_ppl:.2f}')
            ax1.legend()

        # Throughput vs Alpha (if available)
        ax2 = axes[1]
        if 'tokens_per_second' in numeric_df.columns:
            ax2.plot(numeric_df['alpha'], numeric_df['tokens_per_second'], 'o-', markersize=8, linewidth=2, color='green')
            ax2.set_xlabel('Alpha', fontsize=12)
            ax2.set_ylabel('Tokens/Second', fontsize=12)
            ax2.set_title('Throughput vs Alpha', fontsize=14)
            ax2.grid(True, alpha=0.3)
        else:
            ax2.text(0.5, 0.5, 'Throughput data not available', ha='center', va='center', transform=ax2.transAxes)

        plt.tight_layout()
        plt.show()
else:
    print(f"Alpha sweep results not found at {alpha_path}")
    print("Run scripts/05_sweep_alpha.py first!")

## 6. Gated Hybrid Results

Compare routing-based gating vs fixed alpha.

In [None]:
# Bar chart of perplexity per gating mode, coloured by category, followed by
# the best routing-based and best fixed-alpha configurations.
gated_path = project_root / GATED_HYBRID_RESULTS_PATH

if gated_path.exists():
    gated_df = pd.read_csv(gated_path)
    print(f"Loaded gated hybrid results from {gated_path}")
    display(gated_df)

    if 'perplexity' in gated_df.columns:
        # Classify each mode once, case-insensitively, and reuse the same masks
        # for both the bar colours and the best-of selection so the two can
        # never disagree. astype(str) keeps a missing mode from raising.
        mode_labels = gated_df['mode'].astype(str)
        is_routing = mode_labels.str.contains('routing', case=False)
        is_fixed = mode_labels.str.contains('fixed', case=False)

        fig, ax = plt.subplots(figsize=(12, 6))

        # 'routing' takes precedence when a mode name matches both categories.
        colors = [
            'steelblue' if routing else 'orange' if fixed else 'gray'
            for routing, fixed in zip(is_routing, is_fixed)
        ]

        ax.barh(range(len(gated_df)), gated_df['perplexity'], color=colors)
        ax.set_yticks(range(len(gated_df)))
        ax.set_yticklabels(gated_df['mode'])
        ax.set_xlabel('Perplexity', fontsize=12)
        ax.set_title('Perplexity by Gating Mode', fontsize=14)

        # Add value labels
        for i, val in enumerate(gated_df['perplexity']):
            ax.text(val + 0.1, i, f'{val:.2f}', va='center', fontsize=9)

        # Legend
        from matplotlib.patches import Patch
        legend_elements = [
            Patch(facecolor='steelblue', label='Routing-based'),
            Patch(facecolor='orange', label='Fixed alpha'),
            Patch(facecolor='gray', label='Baseline')
        ]
        ax.legend(handles=legend_elements, loc='lower right')

        plt.tight_layout()
        plt.show()

        # Find best (lowest perplexity) in each category
        routing_results = gated_df[is_routing]
        fixed_results = gated_df[is_fixed]

        if not routing_results.empty:
            best_routing = routing_results.loc[routing_results['perplexity'].idxmin()]
            print(f"\nBest routing-based: {best_routing['mode']} (perplexity: {best_routing['perplexity']:.2f})")

        if not fixed_results.empty:
            best_fixed = fixed_results.loc[fixed_results['perplexity'].idxmin()]
            print(f"Best fixed alpha: {best_fixed['mode']} (perplexity: {best_fixed['perplexity']:.2f})")
else:
    print(f"Gated hybrid results not found at {gated_path}")
    print("Run scripts/07_gated_hybrid.py first!")

## 7. Summary and Conclusions

In [None]:
# Text summary of the three experiments. Each part is skipped when its input
# is unavailable. `alpha_path` and `gated_path` are defined by the alpha-sweep
# and gated-hybrid cells, which must have been run first.
print("=" * 60)
print("EXPERIMENT SUMMARY")
print("=" * 60)

if routing_data:
    prompt_sparsities = [p['overall_sparsity'] for p in routing_data.get('prompts', [])]
    if prompt_sparsities:
        mean_sparsity = np.mean(prompt_sparsities)
        print(f"\n1. Routing Sparsity: {mean_sparsity:.1%}")
        if mean_sparsity > 0.7:
            print("   -> HIGH: Most tokens activate few experts (good for Mamba)")
        else:
            print("   -> MODERATE: Mixed importance across tokens")
    else:
        # Avoid reporting "nan%" (and a misleading verdict) for an empty prompt list.
        print("\n1. Routing Sparsity: no prompts available")

if alpha_path.exists():
    alpha_df = pd.read_csv(alpha_path)
    print(f"\n2. Alpha Sweep:")
    baseline = alpha_df[alpha_df['alpha'] == 'baseline']
    if not baseline.empty:
        print(f"   Baseline perplexity: {baseline['perplexity'].values[0]:.2f}")

if gated_path.exists():
    gated_df = pd.read_csv(gated_path)
    print(f"\n3. Gated Hybrid:")

    # na=False keeps rows with a missing mode out of both groups instead of
    # producing NaN in the mask, which boolean indexing rejects.
    routing_results = gated_df[gated_df['mode'].str.contains('routing', case=False, na=False)]
    fixed_results = gated_df[gated_df['mode'].str.contains('fixed', case=False, na=False)]

    if not routing_results.empty and not fixed_results.empty:
        best_routing_ppl = routing_results['perplexity'].min()
        best_fixed_ppl = fixed_results['perplexity'].min()

        print(f"   Best routing-based perplexity: {best_routing_ppl:.2f}")
        print(f"   Best fixed-alpha perplexity: {best_fixed_ppl:.2f}")

        # Lower perplexity is better.
        if best_routing_ppl < best_fixed_ppl:
            print("\n   CONCLUSION: Routing-based gating outperforms fixed alpha!")
            print("   The hypothesis is SUPPORTED.")
        else:
            print("\n   CONCLUSION: Fixed alpha performs better.")
            print("   Further investigation needed.")

print("\n" + "=" * 60)