# Module 08: Mixture of Experts

Mixture of Experts (MoE) allows models to scale capacity without proportionally increasing compute.

## Learning Objectives
- Understand the MoE architecture
- Learn about routing mechanisms (top-k, soft)
- Implement load balancing
- Explore capacity factors and expert utilization

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

# Seed both random number generators with the same value so that every run of
# the notebook reproduces the same numbers.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

print("[OK] Libraries loaded")

## 1. MoE Concept

Instead of one large network, use multiple specialized "experts" and a router to select which experts process each input.

In [None]:
# Visualize MoE concept: a dense network (left) next to a sparsely activated
# mixture-of-experts layer (right).
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# NOTE: the rectangles below use `facecolor`, not `color`. On a matplotlib
# patch, `color` sets BOTH the face and the edge colour, so combining it with
# `edgecolor` silently drops the black outline and triggers a warning.

# Dense model
ax1.add_patch(plt.Rectangle((0.3, 0.1), 0.4, 0.8, fill=True, facecolor='lightblue', edgecolor='black'))
ax1.text(0.5, 0.5, 'Dense\nNetwork\n(all params\nused)', ha='center', va='center', fontsize=12)
ax1.arrow(0.1, 0.5, 0.15, 0, head_width=0.05, head_length=0.02, fc='black')
ax1.arrow(0.75, 0.5, 0.15, 0, head_width=0.05, head_length=0.02, fc='black')
ax1.text(0.05, 0.5, 'Input', ha='center', va='center')
ax1.text(0.95, 0.5, 'Output', ha='center', va='center')
ax1.set_xlim(0, 1)
ax1.set_ylim(0, 1)
ax1.set_title('Dense Network', fontsize=14)
ax1.axis('off')

# MoE model
ax2.add_patch(plt.Rectangle((0.35, 0.4), 0.15, 0.2, fill=True, facecolor='lightgreen', edgecolor='black'))
ax2.text(0.425, 0.5, 'Router', ha='center', va='center', fontsize=10)

colors = ['lightblue', 'lightcoral', 'lightyellow', 'lightgray']
for i, (y, c) in enumerate(zip([0.8, 0.55, 0.3, 0.05], colors)):
    # Experts 1 and 3 are drawn opaque to show them as the "selected" pair;
    # the others are faded out.
    alpha = 1.0 if i in [0, 2] else 0.3
    ax2.add_patch(plt.Rectangle((0.55, y), 0.15, 0.15, fill=True, facecolor=c,
                                 edgecolor='black', alpha=alpha))
    ax2.text(0.625, y + 0.075, f'E{i+1}', ha='center', va='center', fontsize=10)

ax2.arrow(0.1, 0.5, 0.2, 0, head_width=0.03, head_length=0.02, fc='black')
ax2.arrow(0.75, 0.5, 0.15, 0, head_width=0.03, head_length=0.02, fc='black')
ax2.text(0.05, 0.5, 'Input', ha='center', va='center')
ax2.text(0.95, 0.5, 'Output', ha='center', va='center')
ax2.set_xlim(0, 1)
ax2.set_ylim(0, 1)
ax2.set_title('Mixture of Experts (only 2 experts active)', fontsize=14)
ax2.axis('off')

plt.tight_layout()
plt.show()

print("[OK] MoE uses a subset of experts per input --> efficient scaling")

## 2. Basic Expert and Router

In [None]:
class Expert(nn.Module):
    """Feed-forward expert: Linear -> ReLU -> Linear.

    Args:
        input_dim: Number of input features.
        hidden_dim: Width of the hidden layer.
        output_dim: Number of output features.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # Layers are created in forward order and wrapped in a Sequential
        # stored as `net`, so the last layer is reachable as `net[-1]`.
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the expert network to `x` and return the result."""
        transformed = self.net(x)
        return transformed


class Router(nn.Module):
    """Linear gate that scores every expert for each input.

    Args:
        input_dim: Number of input features.
        num_experts: Number of experts to score.
    """

    def __init__(self, input_dim, num_experts):
        super().__init__()
        self.gate = nn.Linear(in_features=input_dim, out_features=num_experts)

    def forward(self, x):
        """Return unnormalised scores (logits), one per expert."""
        expert_logits = self.gate(x)
        return expert_logits


# Smoke test: build a small pool of experts plus a router and look at the
# routing distribution the (untrained) router produces.
input_dim = 64
hidden_dim = 128
output_dim = 64
num_experts = 4

# The experts are built before the router, then the input is sampled, so the
# random number generator is consumed in a fixed order.
expert_pool = [Expert(input_dim, hidden_dim, output_dim) for _ in range(num_experts)]
experts = nn.ModuleList(expert_pool)
router = Router(input_dim, num_experts)

x = torch.randn(8, input_dim)
routing_logits = router(x)
routing_probs = routing_logits.softmax(dim=-1)

first_sample_probs = routing_probs[0].detach().numpy().round(3)
print(f"Input shape: {x.shape}")
print(f"Routing probabilities shape: {routing_probs.shape}")
print(f"Sample routing probs: {first_sample_probs}")

## 3. Top-K Routing

In [None]:
def top_k_routing(routing_logits, k=2):
    """
    Pick the k highest-scoring experts along the last dimension.

    Args:
        routing_logits: Router scores; the last dimension indexes experts.
        k: Number of experts to keep per input.

    Returns:
        indices: Which experts to use (batch_size, k)
        weights: How much to weight each expert (batch_size, k); a softmax
            over the k selected logits only, so each row sums to 1.
    """
    best = torch.topk(routing_logits, k, dim=-1)

    # Renormalise over the selected experts only, not over all experts.
    selected_weights = F.softmax(best.values, dim=-1)

    return best.indices, selected_weights

# Demo: route the batch scored earlier and print the first few assignments
indices, weights = top_k_routing(routing_logits, k=2)

print("Top-2 Routing:")
num_shown = min(4, len(x))
for i, (chosen, gate) in enumerate(zip(indices[:num_shown], weights[:num_shown])):
    print(f"  Sample {i}: Experts {chosen.tolist()} with weights {gate.detach().numpy().round(3)}")

## 4. Complete MoE Layer

In [None]:
class MixtureOfExperts(nn.Module):
    """
    Mixture of Experts layer with top-k routing.

    Every input row ("token") is scored by the router, sent to its `top_k`
    highest-scoring experts, and the expert outputs are summed using the
    weights returned by `top_k_routing`.

    Args:
        input_dim: Size of the last dimension of the input.
        hidden_dim: Hidden width of every expert.
        output_dim: Size of the last dimension of the output.
        num_experts: Number of experts in the layer.
        top_k: Number of experts evaluated per token.
    """
    def __init__(self, input_dim, hidden_dim, output_dim, num_experts=4, top_k=2):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        # Stored so forward() does not have to reach into an expert's layers.
        self.output_dim = output_dim

        # Experts
        self.experts = nn.ModuleList([
            Expert(input_dim, hidden_dim, output_dim)
            for _ in range(num_experts)
        ])

        # Router
        self.router = Router(input_dim, num_experts)

        # For tracking expert utilization (running count of routed slots)
        self.register_buffer('expert_counts', torch.zeros(num_experts))

    def forward(self, x):
        """
        Route `x` through the selected experts.

        Args:
            x: Tensor of shape (..., input_dim). Leading dimensions are
                flattened into a single token dimension for routing.

        Returns:
            Tensor of shape (..., output_dim) with the same dtype and device
            as `x`.
        """
        leading_shape = x.shape[:-1]
        tokens = x.reshape(-1, x.shape[-1])

        # Get routing decisions
        routing_logits = self.router(tokens)
        top_k_indices, top_k_weights = top_k_routing(routing_logits, self.top_k)

        # Track expert usage with one vectorised count instead of a Python
        # loop over every routed slot.
        slot_counts = torch.bincount(top_k_indices.reshape(-1), minlength=self.num_experts)
        self.expert_counts += slot_counts.to(self.expert_counts.dtype)

        # Accumulator follows the input's dtype/device so the layer also
        # works after .double() / .half().
        output = tokens.new_zeros(tokens.size(0), self.output_dim)

        # Run each expert once on all tokens routed to it. top-k indices are
        # distinct within a row, so a token appears at most once per expert.
        for expert_idx, expert in enumerate(self.experts):
            token_idx, slot_idx = torch.where(top_k_indices == expert_idx)
            if token_idx.numel() == 0:
                continue
            expert_output = expert(tokens[token_idx])
            gate = top_k_weights[token_idx, slot_idx].unsqueeze(-1)
            output.index_add_(0, token_idx, (gate * expert_output).to(output.dtype))

        return output.reshape(*leading_shape, self.output_dim)

    def get_expert_utilization(self):
        """Return each expert's share of routed slots as a new tensor.

        The result sums to 1 once anything has been counted and is all zeros
        before that.
        """
        total = self.expert_counts.sum()
        # Counts are whole numbers, so clamping the divisor to 1 only changes
        # the 0/0 case. The division always allocates, so callers never get a
        # reference to the live buffer.
        return self.expert_counts / total.clamp(min=1)

    def reset_counts(self):
        """Zero the usage counters in place."""
        self.expert_counts.zero_()

# Smoke test: push one random batch through the layer and report how the
# routed slots were spread over the experts.
moe = MixtureOfExperts(64, 128, 64, num_experts=4, top_k=2)
x = torch.randn(32, 64)
output = moe(x)

utilization = moe.get_expert_utilization().numpy().round(3)
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Expert utilization: {utilization}")

## 5. Load Balancing Loss

Without load balancing, the router may always select the same experts.

In [None]:
def load_balancing_loss(routing_logits, num_experts, top_k=1):
    """
    Compute load balancing loss to encourage even expert utilization.

    Implements the auxiliary loss from the Switch Transformer paper:

        loss = num_experts * sum_i(f_i * P_i)

    where f_i is the fraction of routing slots actually dispatched to expert
    i (a hard top-k assignment) and P_i is the mean router probability of
    expert i. The loss is 1.0 when both are uniform and grows as routing
    concentrates on few experts. Gradients flow through P_i only; f_i comes
    from a non-differentiable selection.

    Args:
        routing_logits: Router outputs of shape (..., num_experts).
        num_experts: Number of experts (size of the last dimension).
        top_k: Experts selected per token when measuring f_i. The default of
            1 matches the Switch Transformer formulation.

    Returns:
        Scalar tensor holding the auxiliary loss.
    """
    logits = routing_logits.reshape(-1, routing_logits.shape[-1])

    # P_i: mean router probability per expert
    probs = F.softmax(logits, dim=-1)  # (tokens, num_experts)
    mean_probs = probs.mean(dim=0)

    # f_i: fraction of routing slots dispatched to each expert. Soft
    # probabilities alone can look uniform while every token still picks the
    # same expert, so the real assignments are counted here.
    selected = torch.topk(logits, top_k, dim=-1).indices  # (tokens, top_k)
    dispatch = F.one_hot(selected, num_classes=num_experts).sum(dim=1)
    dispatch_fraction = dispatch.to(probs.dtype).mean(dim=0) / top_k

    aux_loss = num_experts * (dispatch_fraction * mean_probs).sum()

    return aux_loss

# Demo: Compare balanced vs unbalanced routing on hand-made logits
# (32 tokens, 4 experts).

# Unbalanced: every token strongly prefers expert 0
unbalanced_logits = torch.zeros(32, 4)
unbalanced_logits[:, 0] = 10

# Balanced: identical logits give a uniform routing distribution
balanced_logits = torch.zeros(32, 4)

unbalanced_loss = load_balancing_loss(unbalanced_logits, 4)
balanced_loss = load_balancing_loss(balanced_logits, 4)

unbalanced_value = unbalanced_loss.item()
balanced_value = balanced_loss.item()
print(f"Unbalanced routing loss: {unbalanced_value:.4f}")
print(f"Balanced routing loss: {balanced_value:.4f}")
print("[OK] Load balancing loss is lower when experts are evenly used")

## 6. Training with MoE

In [None]:
class MoEClassifier(nn.Module):
    """Classifier made of one MoE layer followed by a linear head.

    Args:
        input_dim: Number of input features.
        hidden_dim: Hidden width of the experts and width of the MoE output
            fed to the classification head.
        num_classes: Number of output classes.
        num_experts: Number of experts in the MoE layer.
        top_k: Number of experts used per input.
        aux_loss_weight: Stored on the instance as `aux_loss_weight`; intended
            as the coefficient for the load-balancing loss during training.
    """
    def __init__(self, input_dim, hidden_dim, num_classes, num_experts=4, top_k=2,
                 aux_loss_weight=0.1):
        super().__init__()
        self.moe = MixtureOfExperts(input_dim, hidden_dim, hidden_dim, num_experts, top_k)
        self.classifier = nn.Linear(hidden_dim, num_classes)
        self.aux_loss_weight = aux_loss_weight

    def forward(self, x):
        """Return class logits for `x`."""
        features = self.moe(x)
        return self.classifier(features)

# Create data: 500 random points; the label depends only on the sign of the
# sum of the first two features.
X = torch.randn(500, 20)
y = (X[:, 0] + X[:, 1] > 0).long()

# Train
model = MoEClassifier(20, 64, 2, num_experts=4, top_k=2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

losses = []
utilizations = []

for epoch in range(100):
    # Start each epoch with fresh counters so utilization is per-epoch
    model.moe.reset_counts()

    optimizer.zero_grad()
    output = model(X)

    # Main loss
    main_loss = criterion(output, y)

    # Load balancing loss. The router is evaluated a second time here
    # because the MoE layer's forward pass returns only its output.
    routing_logits = model.moe.router(X)
    aux_loss = load_balancing_loss(routing_logits, model.moe.num_experts)

    # Total loss: use the model's own settings instead of repeating the
    # literals 0.1 and 4, so changing the model changes the training too.
    loss = main_loss + model.aux_loss_weight * aux_loss
    loss.backward()
    optimizer.step()

    losses.append(main_loss.item())
    utilizations.append(model.moe.get_expert_utilization().clone())

    if (epoch + 1) % 20 == 0:
        acc = (output.argmax(1) == y).float().mean()
        print(f"Epoch {epoch+1}: Loss={main_loss.item():.4f}, Acc={acc:.1%}")

# Plot the training loss (left) and per-epoch expert utilization (right)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

ax1.plot(losses)
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.set_title('Training Loss')

# One row per epoch, one column per expert. `.cpu()` is a no-op on CPU and
# keeps the conversion working if the model was moved to another device.
util_history = torch.stack(utilizations).cpu().numpy()

# Derive the expert count from the recorded data rather than hard-coding 4.
num_plotted_experts = util_history.shape[1]
for i in range(num_plotted_experts):
    ax2.plot(util_history[:, i], label=f'Expert {i+1}')
ax2.axhline(1.0 / num_plotted_experts, color='k', linestyle='--', label='Ideal (uniform)')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Utilization')
ax2.set_title('Expert Utilization')
ax2.legend()

plt.tight_layout()
plt.show()

## 7. Capacity Factor

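The capacity factor limits how many tokens each expert can process in a batch. Tokens routed to an expert that is already at capacity are dropped (that expert skips them).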

In [None]:
def compute_capacity(batch_size, num_experts, top_k, capacity_factor=1.25):
    """
    Compute expert capacity.

    capacity = ceil((batch_size * top_k / num_experts) * capacity_factor)

    The result is rounded up rather than truncated. Truncation returns 0 for
    small batches (so every token would be dropped), can fall below the
    perfectly balanced load when the product is fractional, and loses a slot
    to floating-point noise (100 * 1.15 evaluates to 114.99999999999999).

    Args:
        batch_size: Number of tokens in the batch.
        num_experts: Number of experts sharing the tokens.
        top_k: Number of experts each token is routed to.
        capacity_factor: Multiplier applied to the perfectly balanced load.

    Returns:
        Maximum number of tokens per expert, as an int.
    """
    import math  # local import: the file's top-level imports do not include math

    tokens_per_expert = batch_size * top_k / num_experts
    capacity = math.ceil(tokens_per_expert * capacity_factor)
    return capacity

# Example: how the per-expert budget changes with the capacity factor
batch_size, num_experts, top_k = 64, 8, 2

for cf in (1.0, 1.25, 1.5, 2.0):
    cap = compute_capacity(batch_size, num_experts, top_k, cf)
    print(f"Capacity factor {cf}: {cap} tokens per expert")

print("\n[OK] Higher capacity = more tokens can use each expert")
print("[OK] Lower capacity = more dropped tokens but faster compute")

## Summary

Key concepts covered:

1. **MoE Architecture**: Multiple experts + router for conditional computation
2. **Top-K Routing**: Select k experts per input
3. **Load Balancing**: Auxiliary loss to ensure even expert utilization
4. **Expert Utilization**: Track which experts are being used
5. **Capacity Factor**: Control maximum tokens per expert

## MoE Benefits

- Scale model capacity without proportional compute increase
- Different experts can specialize on different data patterns
- Sparse activation = efficient inference

## Next Steps

- [->] Module 09: Multi-Modal Architectures
- [->] Module 10: Capstone Project