# Lab 2.4.4: MoE Router Analysis

**Module:** 2.4 - Efficient Architectures  
**Time:** 2 hours  
**Difficulty:** ⭐⭐⭐⭐ (Advanced)

---

## 🎯 Learning Objectives

By the end of this lab, you will:
- [ ] Understand how the router/gating network works
- [ ] Extract and analyze router weights
- [ ] Visualize expert selection distribution
- [ ] Understand load balancing and auxiliary losses

---

## 📚 Prerequisites

- Completed: Lab 2.4.3 (MoE Exploration)
- Knowledge of: Softmax, top-k selection, loss functions

---

## 🌍 Real-World Context

**The Router Problem**

The router is the "brain" of MoE: it decides which experts process each token. A bad router:
- Uses only a few experts (wasting capacity)
- Creates training instability
- Fails to learn specialization

Understanding routers helps you:
- Debug underperforming MoE models
- Design better routing strategies
- Optimize inference for your use case

---

## 🧒 ELI5: The Router

> **Remember our hospital analogy?**
>
> The router is like the receptionist who decides which doctor you see.
>
> **How does the receptionist work?**
> 1. Look at your symptoms (input features)
> 2. Assign a "relevance score" for each doctor
> 3. Send you to the top 2 doctors (top-k routing)
> 4. Each doctor spends time proportional to their score (weighted combination)
>
> **The challenge:**
> - Don't send ALL patients to the same doctor (overload!)
> - Don't let any doctor sit idle (waste!)
> - This is the "load balancing" problem

### The Math

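```
Router input:   x ∈ ℝ^d                  (hidden state for one token)
Router weights: W ∈ ℝ^(n_experts × d)
Router logits:  logits = W @ x           (one score per expert)
Top-k set:      S = topk(logits, k)      (indices of the k highest scores)
Gate weights:   g = softmax(logits[S])   (normalized over the k selected experts)
Output:         y = Σ_{i ∈ S} g[i] × Expert_i(x)   (weighted sum of the selected experts' outputs)
```

Some models, including DeepSeekMoE, instead take the softmax over all experts first and use the top-k probabilities as gate weights without renormalizing them.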
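A minimal NumPy sketch of the routing equations above for a single token, using the top-k-then-softmax variant. The experts here are toy random linear maps rather than MLPs, and the sizes (d=8, 4 experts, k=2) are arbitrary choices for illustration:

```python
import numpy as np

def moe_forward(x, W, experts, k=2):
    """Route one token through its top-k experts: top-k on logits, softmax over the selected."""
    logits = W @ x                                # one score per expert
    selected = np.argsort(logits)[::-1][:k]       # indices of the k highest scores
    sel = logits[selected]
    gates = np.exp(sel - sel.max())
    gates /= gates.sum()                          # softmax over the selected experts only
    y = sum(g * experts[i](x) for g, i in zip(gates, selected))
    return y, selected, gates

rng = np.random.default_rng(0)
d, n_experts = 8, 4
W = rng.standard_normal((n_experts, d))           # router weights, shape [n_experts, d]
mats = [rng.standard_normal((d, d)) for _ in range(n_experts)]
experts = [lambda v, M=M: M @ v for M in mats]    # toy experts: random linear maps
x = rng.standard_normal(d)                        # one token's hidden state

y, selected, gates = moe_forward(x, W, experts, k=2)
print("selected experts:", selected, "gate weights:", gates.round(3))
```

Only the k selected experts are evaluated, which is where MoE saves compute: the other experts' weights are never touched for this token.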

---

## Part 1: Setup and Load Model

In [None]:
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict
from typing import Dict, List, Tuple, Optional
import gc

from transformers import AutoModelForCausalLM, AutoTokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name()}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory/1e9:.1f} GB")

In [None]:
# Load MoE model (same as Lab 2.4.3)
MODEL_NAME = "deepseek-ai/deepseek-moe-16b-base"

print(f"Loading {MODEL_NAME}...")

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)

print("✅ Model loaded!")

---

## Part 2: Extracting Router Weights

In [None]:
def find_router_layers(model) -> List[Tuple[str, torch.nn.Module]]:
    """
    Find all router/gate modules in an MoE model (one per MoE layer).

    Only modules whose *last* name component is exactly 'gate' or 'router'
    are matched, so SwiGLU projections named 'gate_proj' are skipped.
    """
    routers = []
    
    for name, module in model.named_modules():
        leaf = name.split('.')[-1].lower()
        if leaf in ('gate', 'router') and hasattr(module, 'weight'):
            routers.append((name, module))
    
    return routers

# Find routers
routers = find_router_layers(model)
print(f"Found {len(routers)} router layers")

if routers:
    for i, (name, module) in enumerate(routers[:5]):
        print(f"  Router {i}: {name}")
        print(f"           Weight shape: {tuple(module.weight.shape)}")
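Why match the last name component rather than any substring? In SwiGLU MLPs the projections are typically named `gate_proj`, so a substring test for "gate" also catches every dense and expert projection. The module names below are illustrative, modeled on Hugging Face DeepSeek-style and Mixtral-style naming; print `model.named_modules()` to see the real names for your model:

```python
names = [
    "model.layers.0.mlp.gate_proj",                 # dense SwiGLU projection
    "model.layers.1.mlp.gate",                      # MoE router
    "model.layers.1.mlp.experts.0.gate_proj",       # expert SwiGLU projection
    "model.layers.1.mlp.shared_experts.gate_proj",  # shared-expert projection
    "model.layers.2.block_sparse_moe.gate",         # Mixtral-style router
]

def substring_match(name):
    """Naive test: any 'gate' or 'router' substring."""
    return any(k in name.lower() for k in ("gate", "router"))

def leaf_match(name):
    """Stricter test: the last dotted component must be exactly 'gate' or 'router'."""
    return name.split(".")[-1].lower() in ("gate", "router")

print("substring:", [n for n in names if substring_match(n)])  # catches every entry
print("leaf:     ", [n for n in names if leaf_match(n)])       # routers only
```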

In [None]:
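def extract_router_weights(model, layer_idx: int = 0) -> Optional[torch.Tensor]:
    """
    Extract the router weight matrix from one MoE layer.
    
    `layer_idx` indexes the list returned by find_router_layers (MoE layers
    only). It can differ from the transformer layer index when some layers
    are dense; DeepSeek-MoE, for example, uses a dense FFN in its first layer.
    
    Returns:
        Float32 weight tensor of shape [num_experts, hidden_dim], or None
    """
    routers = find_router_layers(model)
    
    if layer_idx < len(routers):
        name, module = routers[layer_idx]
        return module.weight.detach().clone().float()
    
    return None

# Extract router weights from the first MoE layer
router_weights = extract_router_weights(model, layer_idx=0)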

if router_weights is not None:
    print(f"Router weight shape: {router_weights.shape}")
    print(f"  Interpretation: {router_weights.shape[0]} experts, {router_weights.shape[1]} hidden dim")
    
    num_experts = router_weights.shape[0]
    hidden_dim = router_weights.shape[1]
else:
    print("Could not extract router weights")
    num_experts = 64  # Default
    hidden_dim = 2048

---

## Part 3: Analyzing Router Weight Structure

In [None]:
if router_weights is not None:
    # Analyze weight statistics
    weight_norms = torch.norm(router_weights, dim=1).cpu().numpy()
    weight_means = router_weights.mean(dim=1).cpu().numpy()
    weight_stds = router_weights.std(dim=1).cpu().numpy()
    
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    
    # Weight norms per expert
    axes[0, 0].bar(range(num_experts), weight_norms, color='#3498DB')
    axes[0, 0].set_xlabel('Expert Index')
    axes[0, 0].set_ylabel('L2 Norm')
    axes[0, 0].set_title('Router Weight Norms by Expert', fontweight='bold')
    axes[0, 0].axhline(y=weight_norms.mean(), color='red', linestyle='--', 
                       label=f'Mean: {weight_norms.mean():.2f}')
    axes[0, 0].legend()
    
    # Weight distribution heatmap
    sample_weights = router_weights[:, :100].cpu().numpy()  # First 100 dims
    im = axes[0, 1].imshow(sample_weights, aspect='auto', cmap='RdBu_r')
    axes[0, 1].set_xlabel('Hidden Dimension (first 100)')
    axes[0, 1].set_ylabel('Expert Index')
    axes[0, 1].set_title('Router Weight Heatmap', fontweight='bold')
    plt.colorbar(im, ax=axes[0, 1])
    
    # Expert similarity matrix
    weights_norm = F.normalize(router_weights, dim=1)
    similarity = (weights_norm @ weights_norm.T).cpu().numpy()
    
    im = axes[1, 0].imshow(similarity, cmap='viridis')
    axes[1, 0].set_xlabel('Expert Index')
    axes[1, 0].set_ylabel('Expert Index')
    axes[1, 0].set_title('Expert Similarity (Cosine)', fontweight='bold')
    plt.colorbar(im, ax=axes[1, 0])
    
    # Per-expert standard deviation of router weights
    axes[1, 1].bar(range(num_experts), weight_stds, color='#E74C3C')
    axes[1, 1].set_xlabel('Expert Index')
    axes[1, 1].set_ylabel('Standard Deviation')
    axes[1, 1].set_title('Router Weight Std Dev by Expert', fontweight='bold')
    
    plt.tight_layout()
    plt.show()
    
    # Find similar experts
    print("\n🔍 Expert Similarity Analysis:")
    np.fill_diagonal(similarity, 0)  # Ignore self-similarity
    
    for i in range(min(5, num_experts)):
        most_similar = np.argmax(similarity[i])
        sim_score = similarity[i, most_similar]
        print(f"  Expert {i} most similar to Expert {most_similar} (cosine: {sim_score:.3f})")

### üîç What Router Weights Tell Us

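1. **Weight norms**: The logit for expert i is w_i · x = ‖w_i‖ ‖x‖ cos θ, so a larger norm scales that expert's scores up. It then wins the top-k more often for inputs it is aligned with, and its routing probabilities are sharper.
2. **Heatmap**: For the first 100 hidden dimensions, shows which input features push each expert's score up (red, positive weights) or down (blue, negative weights).
3. **Similarity**: Low cosine similarity between two experts' router rows means they respond to different input directions (diverse, good); high similarity means they compete for the same tokens, a hint of redundancy.
4. **Standard deviation**: Per row, std² = ‖w_i‖²/d − mean², so when row means are near zero this panel mostly repeats the norm panel. Treat it as a sanity check rather than an independent measure of how "selective" an expert is.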
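Two of these points can be checked with random toy weights (not the model's; all sizes are arbitrary). With `ddof=0`, each row's standard deviation satisfies std² = ‖w‖²/d − mean², so it is fixed by the norm and the mean. And scaling one expert's router row, which keeps its direction, increases how often that expert wins top-1 routing:

```python
import numpy as np

rng = np.random.default_rng(0)
n_experts, d, n_tokens = 8, 16, 2000
W = rng.standard_normal((n_experts, d))   # toy router weights
X = rng.standard_normal((n_tokens, d))    # toy token hidden states

# 1) Row std (ddof=0) is determined by the row norm and row mean
norms = np.linalg.norm(W, axis=1)
means = W.mean(axis=1)
stds = W.std(axis=1)
print("std**2 == norm**2/d - mean**2:", np.allclose(stds**2, norms**2 / d - means**2))

# 2) Scaling one router row (same direction, larger norm) wins it more tokens
def top1_counts(weights):
    return np.bincount((X @ weights.T).argmax(axis=1), minlength=n_experts)

before = top1_counts(W)
W_scaled = W.copy()
W_scaled[3] *= 3.0                        # triple expert 3's norm
after = top1_counts(W_scaled)
print("expert 3 top-1 count:", before[3], "->", after[3])
```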

---

## Part 4: Load Balancing Analysis

In [None]:
def analyze_load_distribution(model, tokenizer, texts: List[str], 
                              top_k: Optional[int] = None) -> Dict:
    """
    Analyze how tokens are distributed across experts in the first MoE layer.
    
    Args:
        top_k: Experts selected per token. Defaults to the model's
               `num_experts_per_tok` config value (falls back to 2).
    
    Returns:
        Dict with load statistics
    """
    if top_k is None:
        top_k = getattr(model.config, 'num_experts_per_tok', 2)
    
    # Hook to capture router logits
    router_logits = []
    
    def hook_fn(module, inputs, output):
        # Recompute logits from the router's input and weight: some routers
        # (e.g. DeepSeek's MoEGate) return (topk_idx, topk_weight, aux_loss)
        # instead of logits. Assumes the router has no bias term.
        hidden = inputs[0].reshape(-1, inputs[0].shape[-1])
        logits = F.linear(hidden.float(), module.weight.float())
        router_logits.append(logits.detach().cpu())
    
    # Attach hook to first router
    routers = find_router_layers(model)
    if not routers:
        return {}
    
    _, first_router = routers[0]
    hook = first_router.register_forward_hook(hook_fn)
    
    # Process texts; always remove the hook, even if a forward pass fails
    try:
        for text in texts:
            inputs = tokenizer(text, return_tensors="pt").to(model.device)
            with torch.no_grad():
                _ = model(**inputs)
    finally:
        hook.remove()
    
    # Count top-k selections per expert (vectorized)
    all_logits = torch.cat(router_logits, dim=0)  # [tokens, experts]
    total_tokens, num_experts = all_logits.shape
    top_experts = torch.topk(all_logits, k=min(top_k, num_experts), dim=-1).indices
    counts = torch.bincount(top_experts.flatten(), minlength=num_experts).tolist()
    expected_per_expert = total_tokens * top_k / num_experts
    
    return {
        'expert_counts': dict(enumerate(counts)),
        'total_tokens': total_tokens,
        'num_experts': num_experts,
        'top_k': top_k,
        'expected_per_expert': expected_per_expert,
        'actual_mean': np.mean(counts),
        'actual_std': np.std(counts),
        'max_load': max(counts),
        'min_load': min(counts),
        # 1 - coefficient of variation: 1.0 = perfectly even, can go negative
        'load_balance_score': 1 - (np.std(counts) / (np.mean(counts) + 1e-8)),
    }

# Analyze load distribution
test_texts = [
    "The quick brown fox jumps over the lazy dog.",
    "def fibonacci(n): return n if n <= 1 else fibonacci(n-1) + fibonacci(n-2)",
    "In the realm of mathematics, calculus provides tools for understanding change.",
    "SELECT * FROM users WHERE created_at > '2024-01-01' ORDER BY name;",
    "Once upon a time in a land far away, there lived a brave princess.",
] * 10  # Repeats scale the counts but add no new routing patterns; use varied text for a real analysis

print("Analyzing load distribution...")
load_stats = analyze_load_distribution(model, tokenizer, test_texts)

if load_stats:
    print(f"\n📊 Load Balancing Statistics:")
    print(f"   Total tokens processed: {load_stats['total_tokens']}")
    print(f"   Number of experts: {load_stats['num_experts']}")
    print(f"   Top-k routing: {load_stats['top_k']}")
    print(f"\n   Expected load per expert: {load_stats['expected_per_expert']:.1f}")
    print(f"   Actual mean load: {load_stats['actual_mean']:.1f}")
    print(f"   Actual std load: {load_stats['actual_std']:.1f}")
    print(f"   Max load: {load_stats['max_load']}")
    print(f"   Min load: {load_stats['min_load']}")
    print(f"\n   Load balance score: {load_stats['load_balance_score']:.3f} (1.0 = perfect)")
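The `load_balance_score` above is 1 minus the coefficient of variation of the per-expert counts, so it has no lower bound. Two alternative summaries with bounded ranges may be easier to compare across runs: the max/mean load ratio (1.0 = perfectly even, at most num_experts) and the entropy of the usage distribution divided by log(num_experts) (1.0 = uniform, 0.0 = a single expert). A minimal sketch; `load_summaries` is a hypothetical helper, not part of any library:

```python
import numpy as np

def load_summaries(counts):
    """Return (max/mean load ratio, normalized entropy) for per-expert token counts."""
    counts = np.asarray(counts, dtype=float)
    frac = counts / counts.sum()                  # usage fraction per expert
    max_over_mean = counts.max() / counts.mean()  # 1.0 = perfectly even
    nz = frac[frac > 0]                           # 0 * log(0) is treated as 0
    entropy = -(nz * np.log(nz)).sum() / np.log(len(counts))
    return max_over_mean, entropy

print(load_summaries([10, 10, 10, 10]))  # perfectly even
print(load_summaries([40, 0, 0, 0]))     # fully collapsed onto one expert
print(load_summaries([30, 10, 0, 0]))    # in between
```

In this notebook it could be applied as `load_summaries(list(load_stats['expert_counts'].values()))`.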

In [None]:
# Visualize load distribution
if load_stats and load_stats['expert_counts']:
    counts = load_stats['expert_counts']
    experts = sorted(counts.keys())
    loads = [counts[e] for e in experts]
    
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    
    # Bar chart of expert loads
    colors = ['#E74C3C' if l > load_stats['expected_per_expert'] * 1.5 
              else '#27AE60' if l < load_stats['expected_per_expert'] * 0.5
              else '#3498DB' for l in loads]
    
    axes[0].bar(experts, loads, color=colors)
    axes[0].axhline(y=load_stats['expected_per_expert'], color='red', 
                   linestyle='--', label='Expected', linewidth=2)
    axes[0].set_xlabel('Expert Index')
    axes[0].set_ylabel('Token Count')
    axes[0].set_title('Expert Load Distribution\n(Red = Overloaded, Green = Underutilized)', 
                     fontweight='bold')
    axes[0].legend()
    
    # Histogram of loads
    axes[1].hist(loads, bins=20, color='#3498DB', edgecolor='white')
    axes[1].axvline(x=load_stats['expected_per_expert'], color='red', 
                   linestyle='--', label='Expected', linewidth=2)
    axes[1].set_xlabel('Token Count per Expert')
    axes[1].set_ylabel('Number of Experts')
    axes[1].set_title('Distribution of Expert Loads', fontweight='bold')
    axes[1].legend()
    
    plt.tight_layout()
    plt.show()
    
    # Identify problematic experts
    print("\n⚠️ Load Balance Issues:")
    overloaded = [(e, c) for e, c in counts.items() 
                  if c > load_stats['expected_per_expert'] * 1.5]
    underused = [(e, c) for e, c in counts.items() 
                 if c < load_stats['expected_per_expert'] * 0.5]
    
    if overloaded:
        print(f"   Overloaded experts: {sorted(overloaded, key=lambda x: -x[1])[:5]}")
    if underused:
        print(f"   Underutilized experts: {sorted(underused, key=lambda x: x[1])[:5]}")
    
    if not overloaded and not underused:
        print("   ✅ Load is well balanced!")

---

## Part 5: Understanding Auxiliary Loss

MoE models use an **auxiliary loss** to encourage load balancing during training.
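For reference, here is a sketch of the widely used Switch Transformer form of this loss, written for top-k routing (DeepSeekMoE's expert-level balance loss has the same f·P structure up to constant factors). N is the number of experts, T the number of tokens in the batch, k the experts per token, W the router weights, and α a small coefficient:

$$
\mathcal{L}_{\text{aux}} = \alpha \, N \sum_{i=1}^{N} f_i \, P_i,
\qquad
f_i = \frac{1}{kT} \sum_{t=1}^{T} \mathbb{1}\left[\, i \in \operatorname{TopK}(W x_t) \,\right],
\qquad
P_i = \frac{1}{T} \sum_{t=1}^{T} \operatorname{softmax}(W x_t)_i
$$

Under perfectly uniform routing f_i = P_i = 1/N, so N Σ f_i P_i = 1. Because f_i comes from a hard top-k selection it carries no gradient; the router is trained through P_i.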

In [None]:
def compute_load_balancing_loss(router_logits: torch.Tensor, 
                                top_k: int = 2) -> torch.Tensor:
    """
    Compute a Switch-Transformer-style auxiliary load balancing loss.
    
    loss = num_experts * sum_i f_i * P_i, where
      f_i = fraction of the (num_tokens * top_k) routing assignments sent to expert i
      P_i = mean router probability assigned to expert i
    
    It equals 1.0 under uniform routing and grows as tokens (and
    probability mass) concentrate on fewer experts.
    
    Args:
        router_logits: [num_tokens, num_experts]
        top_k: Number of experts selected per token
    
    Returns:
        Scalar loss value
    """
    num_tokens, num_experts = router_logits.shape
    
    # P_i: average routing probability per expert (differentiable)
    routing_probs = F.softmax(router_logits, dim=-1)  # [tokens, experts]
    mean_probs = routing_probs.mean(dim=0)            # [experts]
    
    # f_i: fraction of top-k assignments per expert (hard counts, no gradient)
    top_indices = torch.topk(router_logits, k=top_k, dim=-1).indices
    counts = torch.bincount(top_indices.flatten(), minlength=num_experts)
    dispatch_frac = counts.float() / (num_tokens * top_k)  # sums to 1
    
    # Gradients flow only through mean_probs; dispatch_frac acts as a weight
    return num_experts * (dispatch_frac * mean_probs).sum()

# Demonstrate with synthetic data
print("📊 Auxiliary Loss Demonstration:")
print("=" * 50)

num_tokens = 100
num_experts_demo = 8

# Case 1: Uniform routing (ideal)
uniform_logits = torch.zeros(num_tokens, num_experts_demo)
uniform_loss = compute_load_balancing_loss(uniform_logits)
print(f"\nUniform routing logits (ideal):")
print(f"  Loss: {uniform_loss:.4f}")

# Case 2: Imbalanced routing (bad)
imbalanced_logits = torch.zeros(num_tokens, num_experts_demo)
imbalanced_logits[:, 0] = 5.0  # First expert heavily preferred
imbalanced_loss = compute_load_balancing_loss(imbalanced_logits)
print(f"\nImbalanced routing (expert 0 preferred):")
print(f"  Loss: {imbalanced_loss:.4f}")

# Case 3: Slight imbalance
slight_imbalance = torch.randn(num_tokens, num_experts_demo) * 0.5
slight_loss = compute_load_balancing_loss(slight_imbalance)
print(f"\nSlight imbalance (random):")
print(f"  Loss: {slight_loss:.4f}")

print(f"\n💡 Lower loss = better load balance")
print(f"   The auxiliary loss is added to the main training loss")
print(f"   to prevent expert collapse during training.")
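To check numbers by hand, here is a NumPy version of the Switch-style f·P loss, N Σ f_i P_i (`aux_loss_np` is a hypothetical helper name). Ties in the top-k are broken toward the lower expert index. For the expert-0-preferred case above (100 tokens, 8 experts, logit 5.0 on expert 0, top-2), every token picks expert 0 plus one other expert, so f = (0.5, 0.5, 0, ...) and the loss is 8 × (0.5·p₀ + 0.5·p_other) ≈ 3.85, with p₀ = e⁵/(e⁵+7) and p_other = 1/(e⁵+7):

```python
import numpy as np

def aux_loss_np(logits, k=2):
    """Switch-style balance loss N * sum_i f_i * P_i, computed with NumPy."""
    T, N = logits.shape
    probs = np.exp(logits - logits.max(axis=1, keepdims=True))
    probs /= probs.sum(axis=1, keepdims=True)                 # row-wise softmax
    P = probs.mean(axis=0)                                    # mean probability per expert
    top = np.argsort(-logits, axis=1, kind="stable")[:, :k]   # ties -> lower index first
    f = np.bincount(top.ravel(), minlength=N) / (T * k)       # dispatch fraction per expert
    return N * float((f * P).sum())

uniform = np.zeros((100, 8))
collapsed = np.zeros((100, 8))
collapsed[:, 0] = 5.0

p0 = np.exp(5) / (np.exp(5) + 7)   # expert 0's probability in the collapsed case
p_other = 1 / (np.exp(5) + 7)      # each of the other 7 experts
print("uniform:  ", aux_loss_np(uniform))
print("collapsed:", aux_loss_np(collapsed), "by hand:", 8 * (0.5 * p0 + 0.5 * p_other))
```

Uniform logits give exactly 1.0, since every P_i = 1/8 regardless of which experts the ties select.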

In [None]:
# Compare expert usage and auxiliary loss across three routing scenarios

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

scenarios = [
    ("Uniform (Low Loss)", torch.zeros(100, 8)),
    ("Random (Medium Loss)", torch.randn(100, 8) * 0.5),
    ("Collapsed (High Loss)", torch.zeros(100, 8)),
]
scenarios[2][1][:, 0] = 5.0  # Collapse to expert 0

for idx, (title, logits) in enumerate(scenarios):
    probs = F.softmax(logits, dim=-1)
    expert_usage = probs.mean(dim=0).numpy()
    loss = compute_load_balancing_loss(logits)
    
    axes[idx].bar(range(8), expert_usage, color='#3498DB')
    axes[idx].axhline(y=1/8, color='red', linestyle='--', label='Ideal')
    axes[idx].set_xlabel('Expert Index')
    axes[idx].set_ylabel('Usage Fraction')
    axes[idx].set_title(f'{title}\nLoss: {loss:.2f}', fontweight='bold')
    axes[idx].set_ylim(0, 1)
    axes[idx].legend()

plt.tight_layout()
plt.show()

---

## Part 6: Top-K Routing Deep Dive

In [None]:
def simulate_topk_routing(logits: torch.Tensor, top_k: int = 2) -> Dict:
    """
    Simulate top-k routing and analyze the results.
    
    Args:
        logits: [num_tokens, num_experts]
        top_k: Number of experts per token
    
    Returns:
        Dict with routing statistics
    """
    num_tokens, num_experts = logits.shape
    
    # Get top-k experts and their weights
    top_values, top_indices = torch.topk(logits, k=top_k, dim=-1)
    
    # Normalize weights (softmax over selected experts only)
    routing_weights = F.softmax(top_values, dim=-1)
    
    # Analyze
    expert_counts = defaultdict(int)
    expert_weight_sums = defaultdict(float)
    
    for token_idx in range(num_tokens):
        for k_idx in range(top_k):
            expert = top_indices[token_idx, k_idx].item()
            weight = routing_weights[token_idx, k_idx].item()
            expert_counts[expert] += 1
            expert_weight_sums[expert] += weight
    
    # First expert weight statistics
    first_expert_weights = routing_weights[:, 0]
    
    return {
        'expert_counts': dict(expert_counts),
        'expert_weight_sums': dict(expert_weight_sums),
        'mean_first_expert_weight': first_expert_weights.mean().item(),
        'std_first_expert_weight': first_expert_weights.std().item(),
        'routing_weights': routing_weights,
        'top_indices': top_indices,
    }

# Simulate routing with different scenarios
print("📊 Top-K Routing Analysis")
print("=" * 50)

# Generate realistic-ish router logits
num_tokens = 500
num_experts_sim = 64

# Simulate: some experts naturally preferred
base_logits = torch.randn(num_tokens, num_experts_sim)
# Add bias to a few "popular" experts
base_logits[:, :5] += 1.0

# With top-1, the softmax over a single selected logit is always 1.0 (std 0)
for top_k in [1, 2, 4]:
    results = simulate_topk_routing(base_logits, top_k=top_k)
    
    print(f"\nTop-{top_k} Routing:")
    print(f"  Mean first expert weight: {results['mean_first_expert_weight']:.3f}")
    print(f"  Std first expert weight: {results['std_first_expert_weight']:.3f}")
    
    # How many experts actually used?
    used_experts = len([c for c in results['expert_counts'].values() if c > 0])
    print(f"  Experts actually used: {used_experts}/{num_experts_sim}")
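`simulate_topk_routing` selects the top-k logits and then takes a softmax over just those. The DeepSeekMoE paper instead describes taking the softmax over all experts and keeping the top-k probabilities unnormalized. A small NumPy sketch with random logits (64 experts and k=6 are arbitrary choices here) shows that the two conventions differ only by a renormalization:

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

rng = np.random.default_rng(0)
logits = rng.standard_normal(64)        # one token's router logits over 64 experts
k = 6
top = np.argsort(logits)[::-1][:k]      # indices of the k highest logits

# A: top-k first, then softmax over the selected logits (gates sum to 1)
gates_a = softmax(logits[top])

# B: softmax over all experts, then keep the top-k probabilities (gates sum to < 1)
gates_b = softmax(logits)[top]

print("sum A:", gates_a.sum(), "sum B:", gates_b.sum())
print("renormalized B equals A:", np.allclose(gates_a, gates_b / gates_b.sum()))
```

Because B's gates sum to less than 1, the combined expert output is scaled down relative to A; which convention a model uses is fixed at training time, and inference code has to match it.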

In [None]:
# Visualize top-k routing behavior

results = simulate_topk_routing(base_logits, top_k=2)

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Expert selection frequency
counts = results['expert_counts']
experts = sorted(counts.keys())
freq = [counts.get(e, 0) for e in experts]

axes[0].bar(experts, freq, color='#3498DB')
axes[0].axhline(y=num_tokens * 2 / num_experts_sim, color='red', 
               linestyle='--', label='Expected')
axes[0].set_xlabel('Expert Index')
axes[0].set_ylabel('Selection Count')
axes[0].set_title('Expert Selection Frequency', fontweight='bold')
axes[0].legend()

# First vs second expert weight distribution
weights = results['routing_weights']
axes[1].hist(weights[:, 0].numpy(), bins=30, alpha=0.7, label='1st Expert', color='#27AE60')
axes[1].hist(weights[:, 1].numpy(), bins=30, alpha=0.7, label='2nd Expert', color='#E74C3C')
axes[1].set_xlabel('Routing Weight')
axes[1].set_ylabel('Frequency')
axes[1].set_title('Distribution of Expert Weights', fontweight='bold')
axes[1].legend()

# Weight ratio (1st / 2nd)
weight_ratio = weights[:, 0] / (weights[:, 1] + 1e-8)
axes[2].hist(weight_ratio.numpy(), bins=30, color='#9B59B6')
axes[2].axvline(x=1.0, color='red', linestyle='--', label='Equal weights')
axes[2].set_xlabel('Weight Ratio (1st / 2nd)')
axes[2].set_ylabel('Frequency')
axes[2].set_title('1st vs 2nd Expert Weight Ratio', fontweight='bold')
axes[2].legend()

plt.tight_layout()
plt.show()

print(f"\n💡 Observations:")
print(f"   - Mean weight ratio (1st / 2nd): {weight_ratio.mean():.2f}")
print(f"   - The 1st expert gets {weights[:, 0].mean():.1%} of the weight on average")
print(f"   - The 2nd expert contributes the remaining {weights[:, 1].mean():.1%}")

---

## ⚠️ Common Mistakes

### Mistake 1: Ignoring Load Balance During Fine-Tuning
```python
# ❌ Forgetting the auxiliary loss
loss = cross_entropy_loss

# ✅ Include load balancing
loss = cross_entropy_loss + alpha * load_balance_loss
# alpha is usually small, roughly 0.001-0.01 (Switch Transformer used 0.01)
```

### Mistake 2: Wrong Top-K for Your Use Case
```python
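# ❌ Always using top-2 when designing or training an MoE
# ✅ Consider:
#    - top-1: Fastest, but less expert diversity
#    - top-2: Good balance (used by coarse-grained MoEs such as Mixtral)
#    - top-4+: More capacity per token, slower inference (fine-grained MoEs
#      such as DeepSeekMoE route each token to 6 smaller experts)
# Note: a pretrained model expects the k it was trained with; changing k
# at inference time changes its behavior.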
```
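The speed side of this tradeoff is roughly linear in k, because each token runs k expert MLPs. A back-of-the-envelope sketch, assuming SwiGLU experts (gate, up and down projections) and illustrative sizes; it ignores the router, attention, any shared experts, and memory traffic:

```python
def expert_flops_per_token(k, d_model, d_expert):
    """Approximate matmul FLOPs per token spent in k SwiGLU experts."""
    # Each expert has 3 weight matrices of size d_model x d_expert (gate, up, down);
    # a matrix-vector product costs about 2 FLOPs (multiply + add) per weight.
    return k * 3 * 2 * d_model * d_expert

d_model, d_expert = 2048, 1408   # illustrative sizes (assumed, not read from a config)
for k in (1, 2, 4, 6):
    mflops = expert_flops_per_token(k, d_model, d_expert) / 1e6
    print(f"top-{k}: {mflops:.1f} MFLOPs per token in the experts")
```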

### Mistake 3: Not Monitoring Expert Usage
```python
import wandb

# ❌ Training without monitoring
# ✅ Log per-expert usage during training (expert_counts: {expert_idx: token_count})
wandb.log({f"expert_usage/expert_{i}": c for i, c in expert_counts.items()})
```

---

## 🎉 Checkpoint

You've learned:
- ✅ How the router selects experts for each token
- ✅ Extracting and analyzing router weights
- ✅ Load balancing and its importance
- ✅ The auxiliary loss mechanism
- ✅ Top-k routing behavior and tradeoffs

---

## 🧹 Cleanup

In [None]:
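# Drop the model and everything that still references it
# (the `routers` list holds live modules from the model)
for name in ['model', 'tokenizer', 'routers', 'router_weights']:
    if name in globals():
        del globals()[name]

gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

print("✅ Cleanup complete!")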