# Lab 2.4.4: MoE Router Analysis - SOLUTIONS

Complete solutions for the MoE router exercises.

In [None]:
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

## Key Concepts Demonstrated

In [None]:
# Router weight extraction function
def extract_router_weights(model, layer_idx=0):
    '''
    Extract the router (gate) weight matrix of one MoE layer.

    For Mixtral: model.model.layers[layer_idx].block_sparse_moe.gate.weight
    For DeepSeek: model.model.layers[layer_idx].mlp.gate.weight

    A module is treated as a router when one dot-separated component of its
    name is exactly ``gate`` (case-insensitive) and it carries a non-None
    ``weight``. Matching whole components keeps projections such as
    ``gate_proj`` from being mistaken for the router.

    Args:
        model: Object exposing ``named_modules()`` (e.g. a torch.nn.Module).
        layer_idx: Layer to read. It is compared with the first numeric
            component of each router's module name (the ``3`` in
            ``layers.3.mlp.gate``). Routers whose name has no numeric
            component are addressed by their position among such routers.

    Returns:
        A detached copy of the selected router weight, or None when no
        router is found for ``layer_idx``.
    '''
    unindexed_routers = []
    for name, module in model.named_modules():
        parts = name.lower().split('.')
        if 'gate' not in parts or getattr(module, 'weight', None) is None:
            continue

        layer_numbers = [int(part) for part in parts if part.isdecimal()]
        if not layer_numbers:
            # No layer number in the path; remember it for positional lookup.
            unindexed_routers.append(module)
        elif layer_numbers[0] == layer_idx:
            return module.weight.data.clone()

    if 0 <= layer_idx < len(unindexed_routers):
        return unindexed_routers[layer_idx].weight.data.clone()
    return None

# Load balancing loss implementation
def compute_load_balancing_loss(router_logits, num_experts=None):
    '''
    Compute auxiliary loss for load balancing.

    Loss = num_experts * sum(expert_usage^2)

    where expert_usage is the softmax routing probability of each expert
    averaged over all tokens. The loss is 1.0 when usage is perfectly
    uniform and approaches num_experts when one expert receives everything,
    so minimizing it encourages uniform distribution across experts.

    Args:
        router_logits: Tensor whose last dimension indexes experts. All
            leading dimensions (tokens, or batch and sequence) are averaged
            together.
        num_experts: Scale factor of the loss. Defaults to the size of the
            last dimension of ``router_logits`` so the uniform-routing
            minimum stays at 1.0 for any expert count.

    Returns:
        Scalar tensor holding the loss.
    '''
    if num_experts is None:
        num_experts = router_logits.shape[-1]

    probs = F.softmax(router_logits, dim=-1)
    # Collapse every leading axis so usage is a mean over all tokens, not
    # only over the first axis.
    expert_usage = probs.reshape(-1, probs.shape[-1]).mean(dim=0)
    return num_experts * (expert_usage ** 2).sum()

# Demo: compare the auxiliary loss for balanced and skewed router logits
print('Load Balancing Loss Examples:')
print('=' * 50)

# Uniform (ideal): identical logits give every expert the same probability
uniform = torch.zeros(100, 8)
uniform_loss = compute_load_balancing_loss(uniform)
print(f'Uniform routing loss: {uniform_loss:.4f}')

# Imbalanced (bad): a larger logit on expert 0 pulls most probability to it
imbalanced = torch.zeros_like(uniform)
imbalanced[:, 0] = 5.0
imbalanced_loss = compute_load_balancing_loss(imbalanced)
print(f'Imbalanced routing loss: {imbalanced_loss:.4f}')

print('\n Lower loss = better load balance')

In [None]:
# Top-k routing analysis
def analyze_topk_routing(logits, top_k=2):
    '''
    Analyze top-k expert selection.

    Args:
        logits: Router logits whose last dimension indexes experts; any
            number of leading (token / batch) dimensions is accepted.
        top_k: Number of experts selected per token.

    Returns:
        dict with:
            'expert_counts': {expert index: number of times it was selected},
                with an entry for every expert (0 when never selected).
            'mean_weight_ratio': mean over tokens of the renormalized weight
                of the 1st choice divided by that of the 2nd choice; NaN when
                top_k < 2 because there is no second choice to compare with.
            'used_experts': number of experts selected at least once.
    '''
    top_values, top_indices = torch.topk(logits, k=top_k, dim=-1)
    # Renormalize over the selected experts only.
    weights = F.softmax(top_values, dim=-1)

    # Count expert usage in one pass instead of one comparison per expert.
    num_experts = logits.shape[-1]
    counts = torch.bincount(top_indices.reshape(-1), minlength=num_experts)
    expert_counts = dict(enumerate(counts.tolist()))

    if top_k >= 2:
        # topk returns values in descending order, so index 0 is the top choice.
        # Index the last axis so inputs with extra leading dims work too.
        mean_weight_ratio = (weights[..., 0] / weights[..., 1]).mean().item()
    else:
        mean_weight_ratio = float('nan')

    return {
        'expert_counts': expert_counts,
        'mean_weight_ratio': mean_weight_ratio,
        'used_experts': sum(1 for c in expert_counts.values() if c > 0),
    }

# Demo with synthetic data: 500 tokens routed over 64 experts
logits = torch.randn(500, 64)
logits[:, :5].add_(1.0)  # Bias towards first 5 experts

results = analyze_topk_routing(logits, top_k=2)

print('\nTop-2 Routing Analysis:')
print(f'Experts used: {results["used_experts"]}/64')
print(f'Weight ratio (1st/2nd): {results["mean_weight_ratio"]:.2f}')

# Plot per-expert selection counts against the uniform-routing expectation
selection_counts = [results['expert_counts'].get(i, 0) for i in range(64)]

fig, ax = plt.subplots(figsize=(12, 4))
ax.bar(range(64), selection_counts)
# tokens * top_k / experts = count each expert would get if routing were even
ax.axhline(y=500*2/64, color='red', linestyle='--', label='Expected')
ax.set_xlabel('Expert Index')
ax.set_ylabel('Selection Count')
ax.set_title('Expert Selection Distribution')
ax.legend()
fig.tight_layout()
plt.show()