# Topic 13: Mixture of Experts (MoE) - Scaling Models with Sparse Computation

## Learning Objectives

By the end of this notebook, you will:
- Understand why sparse MoE models scale better than dense models
- Learn different routing mechanisms (top-k, softmax, expert choice)
- Implement MoE from scratch with full technical detail
- Master load balancing techniques and auxiliary losses
- Know how to handle routing collapse and expert specialization
- Understand how MoE is used in Mixtral, DeepSeek-V3, and (reportedly) GPT-4
- Learn expert parallelism strategies for distributed training

---

## 1. The Big Picture: Why Mixture of Experts?

### The Scaling Problem with Dense Models

**Traditional approach**: To improve model quality, increase parameters
- GPT-2: 1.5B parameters
- GPT-3: 175B parameters
- GPT-4: ~1.8T parameters (rumored)

**Problem**: Every token activates **all parameters**
- 175B model: Every forward pass uses all 175B params
- Compute scales linearly with parameters
- Memory bandwidth becomes bottleneck
- Inference cost is prohibitive

### The Sparse Solution: Mixture of Experts

**Key insight**: Not all parameters are needed for every input!

**MoE approach**:
1. Split model into **expert sub-networks**
2. Use **router** to select which experts to activate
3. Each token only uses **subset of total parameters**

**Example**: 8 experts, top-2 routing
- Total parameters: 8 × expert_size
- Active parameters per token: 2 × expert_size
- **Compute**: Same as 2-expert model
- **Capacity**: Same as 8-expert model

### The Magic: Decoupling Parameters from Compute

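```
Dense Model:
  Parameters:               175B
  Active params per token:  175B  (≈ 350 GFLOPs forward; ~2 FLOPs per active param)
  Compute grows linearly with total parameters

MoE Model (illustrative: 8 experts, top-2):
  Parameters:               ~1.8T (10x more!)
  Active params per token:  ~450B (≈ 900 GFLOPs forward: 2.5x, not 10x!)
  Compute grows with active parameters, i.e. sub-linearly in total parameters
```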

**Result**: 10x more parameters, only 2.5x more compute!

### Real-World Impact

**GPT-4** (rumored architecture; not officially confirmed):
- 16 experts, ~1.8T total parameters
- Top-2 routing
- ~220B active parameters per token
- Per-token compute comparable to a dense ~220B model
- Total capacity of a much larger model

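**DeepSeek-V3** (December 2024):
- 256 routed experts plus 1 shared expert per MoE layer, 671B total parameters
- Top-8 routing over fine-grained routed experts; the shared expert is always active
- ~37B active parameters per token
- Reported benchmark results competitive with GPT-4o, at a reportedly low training cost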

**Mixtral 8x7B**:
- 8 experts, 47B total parameters
- Top-2 routing
- ~13B active parameters per token
- Matches or outperforms Llama 2 70B on most benchmarks reported by Mistral AI, with ~5x fewer active parameters

### The Trade-offs

**Advantages**:
- 💚 **Scaling efficiency**: Sub-linear compute growth
- 💚 **Specialization**: Experts learn different skills
- 💚 **Quality**: Often better than dense models at the same per-token compute

**Challenges**:
- ⚠️ **Load balancing**: Experts must be used evenly
- ⚠️ **Routing collapse**: Router may only use few experts
- ⚠️ **Memory**: All parameters must be stored, even though only a few experts run per token
- ⚠️ **Communication**: Expert parallelism needs fast interconnect

---

In [None]:
# Setup
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from typing import Tuple, Optional, List
import seaborn as sns

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Set style for better plots
plt.style.use('default')
sns.set_palette("husl")

---

## 2. MoE Architecture: Components

### Core Components

An MoE layer consists of three parts:

```
Input (batch, seq_len, d_model)
       ↓
1. ROUTER: Scores experts for each token
   → Router logits (batch, seq_len, num_experts)
       ↓
2. TOP-K SELECTION: Choose top-k experts per token
   → Selected experts + routing weights
       ↓
3. EXPERTS: Parallel feedforward networks
   → Expert outputs combined by routing weights
       ↓
Output (batch, seq_len, d_model)
```

### Mathematical Formulation

For input token $x \in \mathbb{R}^{d_{model}}$:

**1. Router logits**:
$$
h = x W_g \quad \text{where } W_g \in \mathbb{R}^{d_{model} \times n_{experts}}
$$

**2. Top-k selection**:
$$
\text{TopK}(h) = \{e_1, ..., e_k\} \quad \text{(indices of top k experts)}
$$

**3. Routing weights** (softmax over selected):
$$
p_i = \frac{\exp(h_{e_i})}{\sum_{j=1}^k \exp(h_{e_j})} \quad \text{for } i \in \{1, ..., k\}
$$

**4. Expert outputs**:
$$
E_i(x) = \text{Expert}_i(x) \quad \text{(feedforward network)}
$$

**5. Final output**:
$$
y = \sum_{i=1}^k p_i \cdot E_{e_i}(x)
$$
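To make the five equations concrete, here is a minimal single-token sketch with toy sizes. The random matrices `W_g` and `expert_mats` are illustrative stand-ins: each "expert" here is a single linear map rather than a full FFN.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d_model, n_experts, k = 4, 3, 2

x = torch.randn(d_model)                      # one token
W_g = torch.randn(d_model, n_experts)         # router weights
# Toy experts: one random linear map each (stand-ins for feedforward networks)
expert_mats = [torch.randn(d_model, d_model) for _ in range(n_experts)]

h = x @ W_g                                   # 1. router logits, shape (n_experts,)
top_logits, top_idx = torch.topk(h, k)        # 2. TopK(h): the k largest logits and their indices
p = F.softmax(top_logits, dim=-1)             # 3. softmax over the k selected logits only
# 4-5. weighted sum of the selected experts' outputs; unselected experts are never run
y = sum(p[i] * (x @ expert_mats[int(top_idx[i])]) for i in range(k))

print("selected experts:", top_idx.tolist())
print("weights:", p.tolist())
print("output shape:", tuple(y.shape))
```

Only `k` of the `n_experts` expert computations happen, which is where the compute savings come from.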

### Design Choices

**Number of experts**:
- More experts = more capacity, but load balancing gets harder
- Common: 8 (Mixtral), 16 (GPT-4, rumored), 256 routed experts (DeepSeek-V3)

**Top-k value**:
- k=1: Cheapest (used by Switch Transformer), but each token relies on a single expert
- k=2: Most common, a good balance (Mixtral; reportedly GPT-4)
- k=8: Used by DeepSeek-V3, whose experts are much smaller ("fine-grained") than a dense FFN

**Expert architecture**:
- Typically FFN layers (same as the transformer FFN)
- Activation varies (SwiGLU is common)
- Size: Often the same as a dense FFN (Mixtral); fine-grained designs such as DeepSeek-V3 use smaller experts

---

## 3. Implementing Experts and Router

### Step 1: Expert Network

Each expert is a standard feedforward network.

In [None]:
class Expert(nn.Module):
    """Single expert: a feedforward network"""
    
    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff)
        self.w2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (*, d_model)
        Returns:
            (*, d_model)
        """
        # Standard FFN: Linear → ReLU → Dropout → Linear
        return self.w2(self.dropout(F.relu(self.w1(x))))


class SwiGLUExpert(nn.Module):
    """Expert using SwiGLU activation (used in LLaMA, Mixtral)"""
    
    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        # SwiGLU needs 3 weight matrices
        self.w1 = nn.Linear(d_model, d_ff, bias=False)
        self.w2 = nn.Linear(d_ff, d_model, bias=False)
        self.w3 = nn.Linear(d_model, d_ff, bias=False)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        SwiGLU(x) = (Swish(W1·x) ⊙ W3·x) W2
        """
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


# Demo
d_model = 512
d_ff = 2048

expert_relu = Expert(d_model, d_ff).to(device)
expert_swiglu = SwiGLUExpert(d_model, d_ff).to(device)

x = torch.randn(4, 128, d_model, device=device)

out_relu = expert_relu(x)
out_swiglu = expert_swiglu(x)

print("Expert Networks")
print("="*60)
print(f"Input: {x.shape}")
print(f"ReLU Expert output: {out_relu.shape}")
print(f"SwiGLU Expert output: {out_swiglu.shape}")
print(f"\nReLU Expert params: {sum(p.numel() for p in expert_relu.parameters()):,}")
print(f"SwiGLU Expert params: {sum(p.numel() for p in expert_swiglu.parameters()):,}")
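print("\n💡 At equal d_ff, SwiGLU has ~1.5x more parameters (three matrices instead of two)")
print("   It tends to give better quality, which is why LLaMA and Mixtral use it")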
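The ~1.5x figure follows directly from counting weights. Models that use SwiGLU often shrink the hidden size to about 2/3 of `d_ff` to keep the parameter count comparable to a two-matrix FFN. This arithmetic sketch assumes the demo's `d_model`/`d_ff` and simply truncates 2/3 of `d_ff` (real models may also round it, e.g. to a hardware-friendly multiple).

```python
d_model, d_ff = 512, 2048

# Expert: two nn.Linear layers with biases (d_model -> d_ff -> d_model)
relu_ffn = 2 * d_model * d_ff + d_ff + d_model
# SwiGLUExpert: three bias-free nn.Linear layers at the same d_ff
swiglu_same = 3 * d_model * d_ff
# SwiGLU with the hidden size shrunk to ~2/3 to roughly match the ReLU FFN
d_ff_matched = int(2 * d_ff / 3)
swiglu_matched = 3 * d_model * d_ff_matched

print(relu_ffn, swiglu_same, swiglu_matched)  # 2099712 3145728 2096640
```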

### Step 2: Router

The router decides which experts to use for each token.

In [None]:
class TopKRouter(nn.Module):
    """Top-k router: selects k experts per token"""
    
    def __init__(self, d_model: int, num_experts: int, top_k: int = 2):
        super().__init__()
        self.d_model = d_model
        self.num_experts = num_experts
        self.top_k = top_k
        
        # Router is a simple linear layer
        self.gate = nn.Linear(d_model, num_experts, bias=False)
    
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            x: (batch, seq_len, d_model)
        
        Returns:
            expert_indices: (batch, seq_len, top_k) - which experts to use
            expert_weights: (batch, seq_len, top_k) - how much to weight each
            router_logits: (batch, seq_len, num_experts) - raw logits (for aux loss)
        """
        # Compute router logits
        router_logits = self.gate(x)  # (batch, seq_len, num_experts)
        
        # Get top-k experts
        top_k_logits, expert_indices = torch.topk(router_logits, self.top_k, dim=-1)
        # expert_indices: (batch, seq_len, top_k)
        # top_k_logits: (batch, seq_len, top_k)
        
        # Softmax over selected experts to get weights
        expert_weights = F.softmax(top_k_logits, dim=-1)
        # expert_weights: (batch, seq_len, top_k)
        
        return expert_indices, expert_weights, router_logits


# Demo router
num_experts = 8
top_k = 2

router = TopKRouter(d_model, num_experts, top_k).to(device)
x = torch.randn(4, 128, d_model, device=device)

expert_indices, expert_weights, router_logits = router(x)

print("\nTop-K Router")
print("="*60)
print(f"Input: {x.shape}")
print(f"Router logits: {router_logits.shape}")
print(f"Selected experts: {expert_indices.shape}")
print(f"Expert weights: {expert_weights.shape}")

print(f"\nExample token routing:")
print(f"  Token 0, sample 0:")
print(f"    Selected experts: {expert_indices[0, 0].cpu().numpy()}")
print(f"    Expert weights: {expert_weights[0, 0].cpu().numpy()}")
print(f"    Weights sum to: {expert_weights[0, 0].sum().item():.4f}")

# Analyze routing distribution
expert_usage = torch.zeros(num_experts, device=device)
for i in range(num_experts):
    expert_usage[i] = (expert_indices == i).sum().item()

expert_usage = expert_usage / expert_usage.sum() * 100  # Convert to percentage

print(f"\nExpert usage distribution:")
for i in range(num_experts):
    print(f"  Expert {i}: {expert_usage[i].item():.1f}%")
print(f"\n⚠️ Ideally all experts should be used ~equally ({100/num_experts:.1f}% each)")
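`TopKRouter` applies the softmax only to the k selected logits (Mixtral style). A common alternative computes the softmax over all experts first and then keeps the top-k; Switch Transformer, for example, gates each token with the full-softmax probability of its single chosen expert. The helper `softmax_then_topk` below is a hypothetical name for this sketch, not a library function.

```python
import torch
import torch.nn.functional as F


def softmax_then_topk(router_logits: torch.Tensor, k: int, renormalize: bool = False):
    """
    Alternative gating: softmax over ALL experts first, then keep the top-k.

    Args:
        router_logits: (..., num_experts) raw router scores
        k: number of experts per token
        renormalize: if True, rescale the kept probabilities to sum to 1
    Returns:
        indices, weights: each (..., k)
    """
    probs = F.softmax(router_logits, dim=-1)          # distribution over every expert
    weights, indices = torch.topk(probs, k, dim=-1)   # softmax is monotonic: same experts as topk(logits)
    if renormalize:
        # Renormalizing makes the weights identical to softmax over the top-k logits
        weights = weights / weights.sum(dim=-1, keepdim=True)
    return indices, weights


# Usage: without renormalization the kept weights sum to less than 1
logits = torch.randn(4, 8)
idx, w = softmax_then_topk(logits, k=2)
print(idx.shape, w.sum(dim=-1))
```

Without renormalization, the gate values also reflect how much probability the router put outside the chosen experts; with it, both routers produce the same weights.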

---

## 4. Complete MoE Layer Implementation

Now let's combine router and experts into a full MoE layer.

In [None]:
class MixtureOfExpertsLayer(nn.Module):
    """Complete Mixture of Experts layer"""
    
    def __init__(
        self,
        d_model: int,
        d_ff: int,
        num_experts: int,
        top_k: int = 2,
        expert_type: str = "relu",  # "relu" or "swiglu"
        dropout: float = 0.1
    ):
        super().__init__()
        self.d_model = d_model
        self.num_experts = num_experts
        self.top_k = top_k
        
        # Router
        self.router = TopKRouter(d_model, num_experts, top_k)
        
        # Experts
        if expert_type == "relu":
            self.experts = nn.ModuleList([
                Expert(d_model, d_ff, dropout) for _ in range(num_experts)
            ])
        elif expert_type == "swiglu":
            self.experts = nn.ModuleList([
                SwiGLUExpert(d_model, d_ff) for _ in range(num_experts)
            ])
        else:
            raise ValueError(f"Unknown expert_type: {expert_type}")
    
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, dict]:
        """
        Args:
            x: (batch, seq_len, d_model)
        
        Returns:
            output: (batch, seq_len, d_model)
            aux_info: Dict with routing statistics for load balancing
        """
        batch_size, seq_len, d_model = x.shape
        
        # 1. Route tokens to experts
        expert_indices, expert_weights, router_logits = self.router(x)
        # expert_indices: (batch, seq_len, top_k)
        # expert_weights: (batch, seq_len, top_k)
        
        # 2. Flatten for easier processing
        x_flat = x.view(-1, d_model)  # (batch * seq_len, d_model)
        expert_indices_flat = expert_indices.view(-1, self.top_k)  # (batch * seq_len, top_k)
        expert_weights_flat = expert_weights.view(-1, self.top_k)  # (batch * seq_len, top_k)
        
        # 3. Initialize output
        output_flat = torch.zeros_like(x_flat)
        
        # 4. Process each expert
        # This is the naive implementation - see later for optimized version
        for expert_idx in range(self.num_experts):
            # Find all tokens that selected this expert
            expert_mask = (expert_indices_flat == expert_idx)  # (batch * seq_len, top_k)
            
            # Get tokens assigned to this expert
            token_expert_indices = expert_mask.any(dim=-1)  # (batch * seq_len,)
            
            if token_expert_indices.any():
                # Get tokens for this expert
                expert_input = x_flat[token_expert_indices]  # (num_tokens, d_model)
                
                # Process through expert
                expert_output = self.experts[expert_idx](expert_input)
                
                # Get this expert's routing weight for each of its tokens.
                # An expert appears at most once in a token's top-k, so masking the
                # token's top-k weights and summing picks out exactly that weight
                # (and keeps the weights' dtype, which matters under mixed precision).
                expert_weights_full = expert_weights_flat[token_expert_indices]  # (num_tokens, top_k)
                expert_mask_full = expert_mask[token_expert_indices]  # (num_tokens, top_k)
                expert_weights_for_tokens = (expert_weights_full * expert_mask_full).sum(dim=-1)  # (num_tokens,)
                
                # Weighted contribution
                output_flat[token_expert_indices] += expert_output * expert_weights_for_tokens.unsqueeze(-1)
        
        # 5. Reshape back
        output = output_flat.view(batch_size, seq_len, d_model)
        
        # 6. Collect auxiliary information for load balancing
        aux_info = {
            'router_logits': router_logits,
            'expert_indices': expert_indices,
            'expert_weights': expert_weights
        }
        
        return output, aux_info


# Demo MoE layer
print("\nMixture of Experts Layer")
print("="*70)

d_model = 512
d_ff = 2048
num_experts = 8
top_k = 2

moe = MixtureOfExpertsLayer(
    d_model, d_ff, num_experts, top_k, expert_type="relu"
).to(device)

batch_size = 4
seq_len = 128
x = torch.randn(batch_size, seq_len, d_model, device=device)

output, aux_info = moe(x)

print(f"Input: {x.shape}")
print(f"Output: {output.shape}")
print(f"\nAuxiliary info keys: {list(aux_info.keys())}")

# Analyze expert usage
expert_indices = aux_info['expert_indices']
expert_usage = torch.zeros(num_experts, device=device)
for i in range(num_experts):
    expert_usage[i] = (expert_indices == i).sum().item()

total_assignments = expert_usage.sum()
expert_usage_pct = expert_usage / total_assignments * 100

print(f"\nExpert usage distribution:")
for i in range(num_experts):
    bar = '█' * int(expert_usage_pct[i].item() / 2)
    print(f"  Expert {i}: {expert_usage_pct[i].item():5.1f}% {bar}")

ideal_pct = 100 / num_experts
print(f"\n💡 Ideal: {ideal_pct:.1f}% per expert")
print(f"   Std. dev. of usage: {expert_usage_pct.std().item():.2f} percentage points")

# Calculate parameters
total_params = sum(p.numel() for p in moe.parameters())
dense_equivalent = Expert(d_model, d_ff).to(device)
dense_params = sum(p.numel() for p in dense_equivalent.parameters())

print(f"\nParameter comparison:")
print(f"  MoE total params: {total_params:,}")
print(f"  Dense FFN params: {dense_params:,}")
print(f"  MoE / Dense ratio: {total_params / dense_params:.1f}x")
print(f"\n  Active expert params per token: ~{top_k * dense_params:,} ({top_k} experts, router excluded)")
print(f"  Total expert capacity: {num_experts}x one dense FFN")
print(f"  FFN compute per token: ~{top_k}x one dense FFN")
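The loop above scans every token once per expert. A common alternative, sketched below under the assumption that each expert maps `(n, d_model) -> (n, d_model)`, flattens the (token, expert) assignments, sorts them by expert id so each expert's tokens form one contiguous slice, and scatters the weighted outputs back with `index_add`. `sorted_dispatch_moe` is a hypothetical helper name, not part of any library.

```python
import torch
import torch.nn as nn


def sorted_dispatch_moe(x_flat: torch.Tensor,
                        expert_indices: torch.Tensor,
                        expert_weights: torch.Tensor,
                        experts: nn.ModuleList) -> torch.Tensor:
    """
    Args:
        x_flat: (T, d_model) flattened tokens
        expert_indices: (T, top_k) chosen expert ids
        expert_weights: (T, top_k) routing weights
        experts: E modules mapping (n, d_model) -> (n, d_model)
    Returns:
        (T, d_model) weighted sum of each token's selected expert outputs
    """
    T, top_k = expert_indices.shape
    flat_expert = expert_indices.reshape(-1)                     # (T * top_k,)
    flat_weight = expert_weights.reshape(-1)                     # (T * top_k,)
    # Row-major flattening: assignment i * top_k + j belongs to token i
    flat_token = torch.arange(T, device=x_flat.device).repeat_interleave(top_k)

    # Sort assignments by expert id so each expert's tokens are contiguous
    order = torch.argsort(flat_expert)
    sorted_token, sorted_weight = flat_token[order], flat_weight[order]
    counts = torch.bincount(flat_expert, minlength=len(experts)).tolist()

    out = torch.zeros_like(x_flat)
    start = 0
    for e, n in enumerate(counts):
        if n > 0:
            idx = sorted_token[start:start + n]
            y = experts[e](x_flat[idx]) * sorted_weight[start:start + n].unsqueeze(-1)
            out = out.index_add(0, idx, y)  # scatter weighted outputs back to token rows
        start += n
    return out


# Usage with toy linear experts
torch.manual_seed(0)
T, d, E, k = 10, 6, 4, 2
x_demo = torch.randn(T, d)
experts_demo = nn.ModuleList([nn.Linear(d, d) for _ in range(E)])
w_demo, idx_demo = torch.topk(torch.randn(T, E).softmax(-1), k, dim=-1)
out_demo = sorted_dispatch_moe(x_demo, idx_demo, w_demo, experts_demo)
print(out_demo.shape)  # torch.Size([10, 6])
```

This still runs one matmul per expert; optimized implementations typically replace that loop with fused or grouped kernels.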

---

## 5. The Load Balancing Problem

### What is Load Imbalance?

**Problem**: Router may send all tokens to a few popular experts
- Expert 0: 80% of tokens
- Expert 1: 15% of tokens
- Experts 2-7: 5% of tokens total

**Why this is bad**:
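1. **Underutilization**: Rarely chosen experts receive few gradient updates and stay undertrained
2. **Bottleneck**: Popular experts become a compute bottleneck (and, with capacity limits, drop tokens)
3. **Capacity waste**: You pay to store 8 experts but effectively use 2
4. **Routing collapse**: The imbalance can reinforce itself (favored experts improve, so the router favors them more) until nearly all tokens go to one or a few experts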

### Solution: Auxiliary Load Balancing Loss

**Approach**: Add penalty when experts are used unevenly

#### Load Balancing Loss (Switch Transformer)

For $N$ tokens and $E$ experts:

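**1. Fraction of tokens assigned to each expert** (written for top-1 routing; with top-k routing, count all $kN$ assignments instead, as the code below does):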
$$
f_i = \frac{1}{N} \sum_{j=1}^N \mathbb{1}[\arg\max(h_j) = i]
$$

**2. Router probability mass for each expert**:
$$
P_i = \frac{1}{N} \sum_{j=1}^N p_{j,i}
$$
where $p_{j,i}$ is the router probability for expert $i$ on token $j$

**3. Load balancing loss**:
$$
\mathcal{L}_{bal} = E \cdot \sum_{i=1}^E f_i \cdot P_i
$$

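**Why this works**:
- The loss is large when the same experts receive many tokens ($f_i$ high) *and* high router probability ($P_i$ high)
- $f_i$ comes from a hard argmax and has no gradient; the gradient flows through $P_i$, pushing probability mass away from overloaded experts
- The factor $E$ keeps the loss value under uniform routing independent of the number of experts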
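As a sanity check on the scale of $\mathcal{L}_{bal}$, consider two limiting cases computed directly from the definition (the second assumes the router's probabilities agree with its hard assignments):

$$
f_i = P_i = \frac{1}{E} \;\; \forall i \quad\Rightarrow\quad \mathcal{L}_{bal} = E \cdot \sum_{i=1}^{E} \frac{1}{E} \cdot \frac{1}{E} = 1
$$

$$
f = P = (1, 0, \dots, 0) \quad\Rightarrow\quad \mathcal{L}_{bal} = E \cdot (1 \cdot 1) = E
$$

So with 8 experts, perfectly uniform routing gives 1.0 and full collapse onto a single expert gives 8.0.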

**Total training loss**:
$$
\mathcal{L}_{total} = \mathcal{L}_{task} + \alpha \cdot \mathcal{L}_{bal}
$$
where $\alpha$ is a small coefficient (Switch Transformer used $\alpha = 0.01$)

In [None]:
def load_balancing_loss(router_logits: torch.Tensor, expert_indices: torch.Tensor) -> torch.Tensor:
    """
    Compute load balancing auxiliary loss (Switch Transformer)
    
    Args:
        router_logits: (batch, seq_len, num_experts) - raw router logits
        expert_indices: (batch, seq_len, top_k) - selected expert indices
    
    Returns:
        loss: Scalar load balancing loss
    """
    batch_size, seq_len, num_experts = router_logits.shape
    top_k = expert_indices.shape[-1]
    
    # 1. Router probabilities (softmax over all experts)
    router_probs = F.softmax(router_logits, dim=-1)  # (batch, seq_len, num_experts)
    
    # 2. Average probability per expert (P_i)
    P = router_probs.mean(dim=[0, 1])  # (num_experts,)
    
    # 3. Fraction of tokens assigned to each expert (f_i)
    # Create one-hot encoding of expert assignments
    num_tokens = batch_size * seq_len * top_k
    expert_mask = F.one_hot(expert_indices, num_experts).float()  # (batch, seq_len, top_k, num_experts)
    expert_mask = expert_mask.sum(dim=2)  # Sum over top_k: (batch, seq_len, num_experts)
    f = expert_mask.sum(dim=[0, 1]) / num_tokens  # (num_experts,)
    
    # 4. Load balancing loss: num_experts * sum(f_i * P_i)
    loss = num_experts * (f * P).sum()
    
    return loss


# Demo load balancing loss
print("\nLoad Balancing Loss")
print("="*70)

# Simulate balanced vs imbalanced routing
batch, seq_len, num_experts = 4, 128, 8
top_k = 2

# Case 1: Balanced routing (random logits, so top-k picks spread across experts)
balanced_logits = torch.randn(batch, seq_len, num_experts)
balanced_indices = balanced_logits.topk(top_k, dim=-1).indices

loss_balanced = load_balancing_loss(balanced_logits, balanced_indices)

# Case 2: Imbalanced routing (boost the first 2 experts so top-k nearly always picks them)
imbalanced_logits = torch.randn(batch, seq_len, num_experts)
imbalanced_logits[..., :2] += 3.0  # Boost first 2 experts
imbalanced_indices = imbalanced_logits.topk(top_k, dim=-1).indices

loss_imbalanced = load_balancing_loss(imbalanced_logits, imbalanced_indices)

print(f"Balanced routing loss: {loss_balanced.item():.4f}")
print(f"Imbalanced routing loss: {loss_imbalanced.item():.4f}")
print(f"\n💡 Imbalanced loss is {loss_imbalanced / loss_balanced:.2f}x higher!")
print("   This penalty encourages the router to use all experts.")

# Visualize expert usage
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Balanced case
expert_usage_balanced = torch.zeros(num_experts)
for i in range(num_experts):
    expert_usage_balanced[i] = (balanced_indices == i).sum().item()
expert_usage_balanced = expert_usage_balanced / expert_usage_balanced.sum() * 100

axes[0].bar(range(num_experts), expert_usage_balanced.numpy(), color='green', alpha=0.7)
axes[0].axhline(y=100/num_experts, color='r', linestyle='--', label='Ideal')
axes[0].set_xlabel('Expert Index', fontsize=12)
axes[0].set_ylabel('Usage (%)', fontsize=12)
axes[0].set_title(f'Balanced Routing (Loss={loss_balanced.item():.3f})', fontsize=14)
axes[0].legend()
axes[0].grid(True, alpha=0.3, axis='y')

# Imbalanced case
expert_usage_imbalanced = torch.zeros(num_experts)
for i in range(num_experts):
    expert_usage_imbalanced[i] = (imbalanced_indices == i).sum().item()
expert_usage_imbalanced = expert_usage_imbalanced / expert_usage_imbalanced.sum() * 100

axes[1].bar(range(num_experts), expert_usage_imbalanced.numpy(), color='red', alpha=0.7)
axes[1].axhline(y=100/num_experts, color='r', linestyle='--', label='Ideal')
axes[1].set_xlabel('Expert Index', fontsize=12)
axes[1].set_ylabel('Usage (%)', fontsize=12)
axes[1].set_title(f'Imbalanced Routing (Loss={loss_imbalanced.item():.3f})', fontsize=14)
axes[1].legend()
axes[1].grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

print("\n📊 Higher load balancing loss → More imbalanced expert usage")

---

## 6. Advanced Routing: Expert Choice

### The Capacity Problem

**Traditional top-k routing problem**:
- Token chooses experts
- Multiple tokens can choose same expert
- Expert may get overwhelmed (capacity overflow)
- Need to drop tokens or increase capacity
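Token dropping under a capacity limit can be sketched for top-1 routing as follows: each token takes the next slot in its chosen expert's queue in token order, and tokens beyond `capacity` are dropped (in Switch Transformer, dropped tokens skip the expert and pass through the residual connection). `capacity_mask` is a hypothetical helper; real implementations may prioritize tokens differently.

```python
import torch
import torch.nn.functional as F


def capacity_mask(expert_idx: torch.Tensor, num_experts: int, capacity: int) -> torch.Tensor:
    """
    Args:
        expert_idx: (T,) chosen expert per token (top-1), in token order
        num_experts: number of experts E
        capacity: max tokens each expert accepts
    Returns:
        keep: (T,) bool, True if the token fits in its expert's buffer
    """
    one_hot = F.one_hot(expert_idx, num_experts)                       # (T, E)
    # Running count per expert gives each token its 0-based slot in that expert's queue
    slot = ((torch.cumsum(one_hot, dim=0) - 1) * one_hot).sum(dim=-1)  # (T,)
    return slot < capacity


# Example: expert 0 is chosen by 4 of 5 tokens but only has room for 2
expert_idx = torch.tensor([0, 0, 0, 1, 0])
keep = capacity_mask(expert_idx, num_experts=2, capacity=2)
print(keep.tolist())  # [True, True, False, True, False]
```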

### Expert Choice Routing (Zhou et al., 2022)

**Reverse the process**: Experts choose tokens!

```
Traditional: Each token picks top-k experts
  Token 1 → Expert 0, Expert 2
  Token 2 → Expert 0, Expert 3
  Token 3 → Expert 0, Expert 1
  → Expert 0 is overwhelmed!

Expert Choice: Each expert picks its top-C tokens (C = capacity)
  Expert 0 → Token 1, Token 5
  Expert 1 → Token 3, Token 7
  Expert 2 → Token 2, Token 4
  → Perfectly balanced!
```

### Algorithm

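1. Compute router scores for all (token, expert) pairs
2. For each expert, select the $C$ tokens with the highest scores, where the capacity $C$ is set from the number of tokens, the number of experts, and a capacity factor
3. Each expert processes exactly $C$ tokens; a token may be picked by several experts or by none
4. A token not selected by any expert skips the MoE and is carried forward by the residual connection

**Caveat**: each expert chooses among all tokens in the batch, so a token's routing depends on the other tokens. This makes expert choice awkward for autoregressive decoding.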

In [None]:
class ExpertChoiceRouter(nn.Module):
    """Expert Choice routing: experts select tokens (Zhou et al., 2022)"""
    
    def __init__(self, d_model: int, num_experts: int, capacity_factor: float = 1.25):
        super().__init__()
        self.d_model = d_model
        self.num_experts = num_experts
        self.capacity_factor = capacity_factor
        
        self.gate = nn.Linear(d_model, num_experts, bias=False)
    
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            x: (batch, seq_len, d_model)
        
        Returns:
            expert_assignment: (num_experts, capacity) - which tokens each expert gets
            expert_weights: (num_experts, capacity) - weights for each assignment
            router_logits: (batch, seq_len, num_experts) - for aux loss
        """
        batch_size, seq_len, _ = x.shape
        num_tokens = batch_size * seq_len
        
        # Token-to-expert affinities: softmax over experts for each token
        router_logits = self.gate(x)  # (batch, seq_len, num_experts)
        router_probs = F.softmax(router_logits.view(num_tokens, self.num_experts), dim=-1)
        
        # Capacity: how many tokens each expert processes (kept within [1, num_tokens])
        capacity = int(num_tokens / self.num_experts * self.capacity_factor)
        capacity = max(1, min(capacity, num_tokens))
        
        # Each expert (row of the transposed matrix) selects its top-`capacity` tokens.
        # The gating weight is the token's affinity for that expert; re-normalizing
        # across an expert's tokens would shrink every weight to roughly 1/capacity.
        expert_weights, expert_assignment = torch.topk(router_probs.t(), capacity, dim=-1)
        # expert_weights, expert_assignment: (num_experts, capacity)
        
        return expert_assignment, expert_weights, router_logits


# Demo expert choice routing
print("\nExpert Choice Routing")
print("="*70)

ec_router = ExpertChoiceRouter(d_model, num_experts=8, capacity_factor=1.25).to(device)
x = torch.randn(4, 128, d_model, device=device)

expert_assignment, expert_weights, router_logits = ec_router(x)

batch_size, seq_len, _ = x.shape
num_tokens = batch_size * seq_len
capacity = expert_assignment.shape[1]

print(f"Input: {x.shape}")
print(f"Number of tokens: {num_tokens}")
print(f"Capacity per expert: {capacity}")
print(f"Expert assignment shape: {expert_assignment.shape}")
print(f"Expert weights shape: {expert_weights.shape}")

# Check how many tokens each expert gets
print(f"\nTokens per expert:")
for i in range(8):
    print(f"  Expert {i}: {capacity} tokens (guaranteed)")

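# Count how many experts picked each token (0 means the token skips the MoE)
picks_per_token = torch.bincount(expert_assignment.flatten(), minlength=num_tokens)
print(f"\n💡 Expert Choice balances load by construction:")
print(f"   Each expert processes exactly {capacity} tokens, so no auxiliary balancing loss is needed.")
print(f"   Trade-off: tokens picked by no expert: {(picks_per_token == 0).sum().item()} / {num_tokens}; "
      f"max experts on one token: {picks_per_token.max().item()}")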
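The router above only decides assignments. A minimal sketch of the combine step, assuming each expert maps `(n, d_model) -> (n, d_model)`: gather each expert's chosen tokens, scale the outputs by the gating weights, and scatter-add them back. Tokens picked by several experts sum their contributions; tokens picked by none get zeros, so only the residual connection carries them. `expert_choice_combine` is a hypothetical name for illustration.

```python
import torch
import torch.nn as nn


def expert_choice_combine(x_flat, expert_assignment, expert_weights, experts):
    """
    Args:
        x_flat: (T, d_model) flattened tokens
        expert_assignment: (E, C) token indices chosen by each expert
        expert_weights: (E, C) gating weight for each chosen token
        experts: sequence of E modules mapping (n, d_model) -> (n, d_model)
    Returns:
        (T, d_model) MoE output (zeros for tokens no expert selected)
    """
    out = torch.zeros_like(x_flat)
    for e, expert in enumerate(experts):
        idx = expert_assignment[e]                                  # (C,)
        y = expert(x_flat[idx]) * expert_weights[e].unsqueeze(-1)   # (C, d_model)
        out = out.index_add(0, idx, y)  # accumulates if a token was chosen by several experts
    return out


# Toy check with identity "experts": expert 0 picks tokens 0 and 1, expert 1 picks tokens 1 and 2
x_toy = torch.randn(4, 3)
assignment = torch.tensor([[0, 1], [1, 2]])
weights = torch.tensor([[0.5, 0.2], [0.3, 0.4]])
out = expert_choice_combine(x_toy, assignment, weights, [nn.Identity(), nn.Identity()])
# In a transformer block, the layer output would then be x_flat + out (residual connection)
```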

---

## 7. Real-World MoE: Mixtral Architecture

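Let's implement a complete transformer block with MoE that follows Mixtral's layer layout (pre-norm RMSNorm, SwiGLU experts, top-2 routing). For simplicity it uses PyTorch's standard `nn.MultiheadAttention` instead of Mixtral's grouped-query attention with rotary position embeddings.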

In [None]:
class RMSNorm(nn.Module):
    """RMS Layer Normalization"""
    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
        return x / rms * self.weight


class MixtralMoEBlock(nn.Module):
    """
    Complete Mixtral-style transformer block with MoE
    
    Architecture:
    - RMSNorm → Attention → Residual
    - RMSNorm → MoE FFN → Residual
    """
    
    def __init__(
        self,
        d_model: int,
        num_heads: int,
        d_ff: int,
        num_experts: int = 8,
        top_k: int = 2,
        dropout: float = 0.1
    ):
        super().__init__()
        
        # Attention
        self.attn_norm = RMSNorm(d_model)
        self.attention = nn.MultiheadAttention(
            d_model, num_heads, dropout=dropout, batch_first=True
        )
        
        # MoE FFN
        self.ffn_norm = RMSNorm(d_model)
        self.moe = MixtureOfExpertsLayer(
            d_model, d_ff, num_experts, top_k,
            expert_type="swiglu", dropout=dropout
        )
    
    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, dict]:
        """
        Args:
            x: (batch, seq_len, d_model)
            mask: Optional attention mask
        
        Returns:
            output: (batch, seq_len, d_model)
            aux_info: MoE routing statistics
        """
        # Attention block (pre-norm: normalize once, reuse as query, key, and value)
        h = self.attn_norm(x)
        attn_out, _ = self.attention(h, h, h, attn_mask=mask, need_weights=False)
        x = x + attn_out
        
        # MoE FFN block
        moe_out, aux_info = self.moe(self.ffn_norm(x))
        x = x + moe_out
        
        return x, aux_info


# Demo Mixtral-style block
print("\nMixtral-Style MoE Transformer Block")
print("="*70)

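d_model = 4096
num_heads = 32
d_ff = 14336  # Mixtral 8x7B's per-expert FFN hidden size
# Note: this one block has ~1.5B parameters (~6 GB in fp32); reduce sizes if memory is tight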
num_experts = 8
top_k = 2

mixtral_block = MixtralMoEBlock(
    d_model, num_heads, d_ff, num_experts, top_k
).to(device)

batch_size = 2
seq_len = 256
x = torch.randn(batch_size, seq_len, d_model, device=device)

output, aux_info = mixtral_block(x)

print(f"Configuration:")
print(f"  d_model: {d_model}")
print(f"  Heads: {num_heads}")
print(f"  d_ff: {d_ff}")
print(f"  Experts: {num_experts}")
print(f"  Top-k: {top_k}")

print(f"\nShapes:")
print(f"  Input: {x.shape}")
print(f"  Output: {output.shape}")

total_params = sum(p.numel() for p in mixtral_block.parameters())
print(f"\nTotal parameters: {total_params:,}")

# Calculate active parameters
attn_params = sum(p.numel() for p in mixtral_block.attention.parameters())
expert_params = sum(p.numel() for p in mixtral_block.moe.experts[0].parameters())
active_params = attn_params + (top_k * expert_params)

print(f"Active parameters per token: {active_params:,}")
print(f"Total / active parameter ratio: {total_params / active_params:.1f}x")

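print(f"\n✅ This block uses Mixtral 8x7B's per-layer dimensions.")
print(f"   Mixtral stacks 32 such layers (with grouped-query attention) plus embeddings:")
print(f"   ~47B total params, ~13B active per token")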
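The ~47B / ~13B figures can be roughly reproduced from Mixtral 8x7B's public configuration. This arithmetic sketch assumes 32 layers, 32 query heads with 8 key/value heads (head dim 128), a 32,000-token vocabulary, and separate input and output embedding matrices. It ignores the small RMSNorm weights and, following the usual reporting convention, counts embeddings as active.

```python
# Assumed Mixtral 8x7B configuration values
d_model, d_ff, n_layers = 4096, 14336, 32
n_experts, top_k = 8, 2
n_heads, n_kv_heads, vocab = 32, 8, 32000
head_dim = d_model // n_heads  # 128

expert = 3 * d_model * d_ff                                         # SwiGLU: w1, w2, w3, no biases
attn = 2 * d_model * d_model + 2 * d_model * n_kv_heads * head_dim  # q, o full; k, v grouped (GQA)
router = d_model * n_experts                                        # bias-free gate
embed = 2 * vocab * d_model                                         # input embedding + output head

total = n_layers * (n_experts * expert + attn + router) + embed
active = n_layers * (top_k * expert + attn + router) + embed
print(f"total ≈ {total / 1e9:.1f}B, active ≈ {active / 1e9:.1f}B")  # total ≈ 46.7B, active ≈ 12.9B
```

Almost all of the gap between total and active comes from the 6 unused experts per layer; attention and embeddings are shared by every token.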

---

## 8. Training MoE: Complete Example

Let's train a small MoE model and monitor expert usage.

In [None]:
def train_moe_demo():
    """Train a small MoE layer on random data and monitor expert usage"""
    
    # Small model for demonstration
    d_model = 256
    d_ff = 512
    num_experts = 8
    top_k = 2
    
    model = MixtureOfExpertsLayer(
        d_model, d_ff, num_experts, top_k, expert_type="relu"
    ).to(device)
    
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    
    # Training loop
    num_steps = 200
    batch_size = 16
    seq_len = 64
    
    expert_usage_history = []
    loss_history = []
    lb_loss_history = []
    
    print("Training MoE Model...")
    
    for step in range(num_steps):
        # Dummy data
        x = torch.randn(batch_size, seq_len, d_model, device=device)
        target = torch.randn(batch_size, seq_len, d_model, device=device)
        
        # Forward pass
        output, aux_info = model(x)
        
        # Task loss (MSE for demonstration)
        task_loss = F.mse_loss(output, target)
        
        # Load balancing loss
        lb_loss = load_balancing_loss(
            aux_info['router_logits'],
            aux_info['expert_indices']
        )
        
        # Total loss
        alpha = 0.01  # Load balancing weight
        total_loss = task_loss + alpha * lb_loss
        
        # Backward pass
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
        
        # Track expert usage
        expert_indices = aux_info['expert_indices']
        expert_usage = torch.zeros(num_experts, device=device)
        for i in range(num_experts):
            expert_usage[i] = (expert_indices == i).sum().item()
        expert_usage = expert_usage / expert_usage.sum()
        
        expert_usage_history.append(expert_usage.cpu().numpy())
        loss_history.append(task_loss.item())
        lb_loss_history.append(lb_loss.item())
        
        if (step + 1) % 50 == 0:
            print(f"Step {step+1}/{num_steps}: Loss={task_loss.item():.4f}, "
                  f"LB Loss={lb_loss.item():.4f}")
    
    print("\n✅ Training complete!")
    
    return expert_usage_history, loss_history, lb_loss_history


# Run training
expert_usage_history, loss_history, lb_loss_history = train_moe_demo()

# Visualize results
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# Plot 1: Training loss
axes[0, 0].plot(loss_history, linewidth=2)
axes[0, 0].set_xlabel('Step', fontsize=12)
axes[0, 0].set_ylabel('Task Loss', fontsize=12)
axes[0, 0].set_title('Training Loss', fontsize=14)
axes[0, 0].grid(True, alpha=0.3)

# Plot 2: Load balancing loss
axes[0, 1].plot(lb_loss_history, color='orange', linewidth=2)
axes[0, 1].set_xlabel('Step', fontsize=12)
axes[0, 1].set_ylabel('Load Balancing Loss', fontsize=12)
axes[0, 1].set_title('Load Balancing Loss', fontsize=14)
axes[0, 1].grid(True, alpha=0.3)

# Plot 3: Expert usage over time
expert_usage_array = np.array(expert_usage_history)
for i in range(8):
    axes[1, 0].plot(expert_usage_array[:, i], label=f'Expert {i}', linewidth=2)
axes[1, 0].axhline(y=1/8, color='r', linestyle='--', label='Ideal (12.5%)')
axes[1, 0].set_xlabel('Step', fontsize=12)
axes[1, 0].set_ylabel('Usage Fraction', fontsize=12)
axes[1, 0].set_title('Expert Usage Over Time', fontsize=14)
axes[1, 0].legend(ncol=3, fontsize=8)
axes[1, 0].grid(True, alpha=0.3)

# Plot 4: Final expert usage distribution
final_usage = expert_usage_array[-1] * 100
axes[1, 1].bar(range(8), final_usage, color='green', alpha=0.7)
axes[1, 1].axhline(y=12.5, color='r', linestyle='--', label='Ideal (12.5%)')
axes[1, 1].set_xlabel('Expert Index', fontsize=12)
axes[1, 1].set_ylabel('Usage (%)', fontsize=12)
axes[1, 1].set_title('Final Expert Usage Distribution', fontsize=14)
axes[1, 1].legend()
axes[1, 1].grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

print("\n📊 Key Observations:")
print("  - Load balancing loss should decrease as routing becomes more balanced")
print("  - Expert usage should move toward ~12.5% each (1/8)")
print("  - Some deviation from uniform is expected and can reflect expert specialization")
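The load-balancing loss in the plot usually flattens out above zero. For reference, here is a minimal NumPy sketch that assumes the Switch Transformer formulation $\alpha E \sum_i f_i P_i$. Here $f_i$ is the fraction of tokens dispatched to expert $i$, and $P_i$ is the mean router probability for expert $i$. The notebook's own loss may be scaled differently. Switch Transformers note that this loss is minimized under uniform routing, where it equals $\alpha$. Full collapse onto a single expert gives $\alpha E$.

```python
import numpy as np


def switch_load_balancing_loss(expert_index, router_probs, num_experts, alpha=1.0):
    """Switch-Transformer-style auxiliary loss: alpha * E * sum_i f_i * P_i.

    expert_index: (N,) integer top-1 expert chosen for each token
    router_probs: (N, E) softmax router probabilities for each token
    """
    # f_i: fraction of tokens dispatched to expert i
    f = np.bincount(expert_index, minlength=num_experts) / len(expert_index)
    # P_i: mean router probability assigned to expert i across tokens
    p = router_probs.mean(axis=0)
    return alpha * num_experts * float(np.sum(f * p))


E, N = 8, 800

# Perfect balance: tokens spread evenly, uniform router probabilities
balanced_idx = np.arange(N) % E
balanced_probs = np.full((N, E), 1.0 / E)

# Collapse: every token goes to expert 0 with probability 1
collapsed_idx = np.zeros(N, dtype=int)
collapsed_probs = np.zeros((N, E))
collapsed_probs[:, 0] = 1.0

print(switch_load_balancing_loss(balanced_idx, balanced_probs, E))    # 1.0
print(switch_load_balancing_loss(collapsed_idx, collapsed_probs, E))  # 8.0
```

Under this formulation, a curve that levels off near α (not zero) indicates the balanced regime.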

---

## Mini Exercises

### Exercise 1: Calculate MoE Efficiency

For a model with:
- 64 experts
- Top-8 routing
- Each expert has 1B parameters

Calculate:
1. Total parameters
2. Active parameters per token
3. Parameter efficiency ratio (total / active parameters)
4. How does this compare to a dense model with the same number of active parameters?

In [None]:
# Your code here


In [None]:
# Solution
num_experts = 64
top_k = 8
params_per_expert = 1_000_000_000  # 1B

total_params = num_experts * params_per_expert
active_params = top_k * params_per_expert
efficiency = total_params / active_params

print("MoE Model Analysis")
print("="*60)
print("Configuration:")
print(f"  Experts: {num_experts}")
print(f"  Top-k: {top_k}")
print(f"  Parameters per expert: {params_per_expert:,}")
print("\nResults:")
print(f"  Total parameters: {total_params/1e9:.1f}B")
print(f"  Active parameters: {active_params/1e9:.1f}B")
print(f"  Efficiency ratio (total/active): {efficiency:.1f}x")
print("\nComparison to dense model:")
print(f"  Dense {active_params/1e9:.1f}B: uses all {active_params/1e9:.1f}B params per token")
print(f"  MoE {total_params/1e9:.1f}B: uses only {active_params/1e9:.1f}B params per token")
print(f"  MoE holds {efficiency:.1f}x more expert parameters at the same per-token expert compute")
print("  (Attention, embedding, and router parameters are not counted here)")
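Real MoE models also contain parameters that every token uses, such as attention and embeddings. These pull the total/active ratio below num_experts / top_k. The following sketch adds a hypothetical `shared_params` term. It also applies the common rule of thumb that a forward pass costs about 2 FLOPs per active parameter per token. That estimate ignores the cost of attention over the context, so treat it as a rough figure.

```python
def moe_param_summary(num_experts, top_k, params_per_expert, shared_params=0):
    """Return (total, active per token, total/active) for an MoE model.

    shared_params: always-active parameters such as attention and embeddings
    (a hypothetical knob; the exercise above effectively sets it to zero).
    """
    total = shared_params + num_experts * params_per_expert
    active = shared_params + top_k * params_per_expert
    return total, active, total / active


# Experts only, as in the exercise: the ratio is exactly num_experts / top_k
print(moe_param_summary(64, 8, 1_000_000_000))

# With a hypothetical 4B always-active shared parameters, the ratio shrinks
total, active, ratio = moe_param_summary(64, 8, 1_000_000_000, shared_params=4_000_000_000)
# Rule of thumb: ~2 FLOPs per active parameter per token for a forward pass
fwd_flops_per_token = 2 * active
print(f"total={total/1e9:.0f}B active={active/1e9:.0f}B ratio={ratio:.2f}x "
      f"~{fwd_flops_per_token/1e9:.0f} GFLOPs/token")
```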

### Exercise 2: Implement Z-Loss

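The router z-loss, introduced in ST-MoE (Zoph et al., 2022), is another auxiliary loss used in some MoE models. It penalizes large router logits, which improves training stability. For $N$ tokens, $E$ experts, and router logits $h_{i,j}$ (token $i$, expert $j$):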

$$
\mathcal{L}_z = \frac{1}{N} \sum_{i=1}^N \left( \log \sum_{j=1}^E \exp(h_{i,j}) \right)^2
$$

Implement this loss function.

In [None]:
# Your code here


In [None]:
# Solution
def z_loss(router_logits: torch.Tensor) -> torch.Tensor:
    """
    Z-loss: Penalizes large router logits
    
    Args:
        router_logits: (batch, seq_len, num_experts)
    
    Returns:
        loss: Scalar
    """
    # log(sum(exp(logits))) for each token
    log_z = torch.logsumexp(router_logits, dim=-1)  # (batch, seq_len)
    
    # Square and average
    loss = (log_z ** 2).mean()
    
    return loss

# Test
router_logits = torch.randn(4, 128, 8)
loss = z_loss(router_logits)
print(f"Z-loss: {loss.item():.4f}")

# Test with large logits (should have higher loss)
large_logits = torch.randn(4, 128, 8) * 10
large_loss = z_loss(large_logits)
print(f"Z-loss (large logits): {large_loss.item():.4f}")
print("\n💡 Z-loss penalizes large router logits to improve stability")
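The solution uses `torch.logsumexp` instead of computing `log(sum(exp(h)))` directly. The naive form overflows once logits get large, which is exactly the regime the z-loss is meant to penalize. PyTorch documents its `logsumexp` as numerically stabilized. The NumPy sketch below shows the standard max-shift trick; it is not a claim about PyTorch's internal implementation.

```python
import numpy as np


def naive_log_z(logits):
    # Direct log(sum(exp(h))): exp overflows to inf for logits above ~709 in float64
    return np.log(np.sum(np.exp(logits), axis=-1))


def stable_log_z(logits):
    # Max-shift trick: log sum exp(h) = m + log sum exp(h - m), with m = max_j h_j
    m = np.max(logits, axis=-1, keepdims=True)
    return np.squeeze(m, axis=-1) + np.log(np.sum(np.exp(logits - m), axis=-1))


def z_loss_np(logits):
    # Mean over tokens of the squared log-partition function
    return float(np.mean(stable_log_z(logits) ** 2))


small = np.array([[1.0, 2.0, 3.0]])
large = small + 999.0  # same logits shifted up by 999

with np.errstate(over="ignore"):
    print(naive_log_z(large))  # [inf]
print(stable_log_z(small), stable_log_z(large))
print(z_loss_np(small), z_loss_np(large))  # z-loss grows with the logit magnitude
```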

### Exercise 3: Visualize Expert Specialization

Create a visualization showing which types of tokens each expert receives.
Use synthetic token categories (e.g., Gaussian noise with different scales or means) and check whether the router sends different categories to different experts.

In [None]:
# Your code here


In [None]:
# Solution
def visualize_expert_specialization():
    """Show how experts specialize for different token types"""
    
    d_model = 256
    num_experts = 4
    # Note: this router is freshly initialized and never trained in this demo
    router = TopKRouter(d_model, num_experts, top_k=1).to(device)
    
    # Create 4 token types with different characteristics
    num_tokens_per_type = 200
    token_types = []
    token_labels = []
    
    # Type 0: Small values
    tokens_0 = torch.randn(num_tokens_per_type, d_model, device=device) * 0.5
    token_types.append(tokens_0)
    token_labels.extend([0] * num_tokens_per_type)
    
    # Type 1: Large values
    tokens_1 = torch.randn(num_tokens_per_type, d_model, device=device) * 2.0
    token_types.append(tokens_1)
    token_labels.extend([1] * num_tokens_per_type)
    
    # Type 2: Positive bias
    tokens_2 = torch.randn(num_tokens_per_type, d_model, device=device) + 1.0
    token_types.append(tokens_2)
    token_labels.extend([2] * num_tokens_per_type)
    
    # Type 3: Negative bias
    tokens_3 = torch.randn(num_tokens_per_type, d_model, device=device) - 1.0
    token_types.append(tokens_3)
    token_labels.extend([3] * num_tokens_per_type)
    
    # Combine all tokens
    all_tokens = torch.cat(token_types, dim=0).unsqueeze(0)  # (1, 800, d_model)
    
    # Route tokens
    with torch.no_grad():
        expert_indices, _, _ = router(all_tokens)
    
    expert_indices = expert_indices.squeeze(0).squeeze(-1).cpu().numpy()  # (800,)
    token_labels = np.array(token_labels)
    
    # Create confusion matrix: token_type × expert
    confusion = np.zeros((4, num_experts))
    for token_type in range(4):
        mask = token_labels == token_type
        for expert_idx in range(num_experts):
            confusion[token_type, expert_idx] = (expert_indices[mask] == expert_idx).sum()
    
    # Normalize to percentages
    confusion = confusion / confusion.sum(axis=1, keepdims=True) * 100
    
    # Visualize
    plt.figure(figsize=(10, 8))
    sns.heatmap(confusion, annot=True, fmt='.1f', cmap='YlOrRd',
                xticklabels=[f'Expert {i}' for i in range(num_experts)],
                yticklabels=['Type 0\n(small)', 'Type 1\n(large)', 
                            'Type 2\n(+bias)', 'Type 3\n(-bias)'],
                cbar_kws={'label': 'Percentage (%)'})
    plt.xlabel('Expert', fontsize=12)
    plt.ylabel('Token Type', fontsize=12)
    plt.title('Expert Specialization by Token Type', fontsize=14)
    plt.tight_layout()
    plt.show()
    
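    print("📊 Heatmap shows which experts handle which token types")
    print("   Darker colors = more tokens of that type routed to that expert")
    print("\n💡 This router is untrained, so any pattern reflects its random initial weights;")
    print("   genuine specialization only emerges when the router is trained with the experts.")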

visualize_expert_specialization()
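Whether Type 0 and Type 1 can be separated at all depends on the router's form. Suppose `TopKRouter` is a linear layer without a bias; this is an assumption, so check its definition above. Scaling a token by a positive constant then scales all of its logits equally, so the top-1 choice cannot change. The two "scale" types are therefore routed the same way in distribution, and only the mean-shifted types can differ. Here is a NumPy sketch with a hypothetical random router matrix `W`:

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, num_experts = 256, 4

# Hypothetical bias-free linear router: logits = tokens @ W.T
W = rng.standard_normal((num_experts, d_model)) / np.sqrt(d_model)


def route_top1(tokens, W, b=None):
    logits = tokens @ W.T
    if b is not None:
        logits = logits + b
    return logits.argmax(axis=-1)


x = rng.standard_normal((200, d_model))
small_tokens = 0.5 * x  # like Type 0
large_tokens = 2.0 * x  # like Type 1: same directions, 4x the scale

# Without a bias, scaling a token by c > 0 scales every logit by c, so argmax is unchanged
print(np.array_equal(route_top1(small_tokens, W), route_top1(large_tokens, W)))  # True

# A mean shift instead adds the fixed vector W @ mu to every token's logits,
# which can change which experts the tokens prefer
mu = np.ones(d_model)
print(np.bincount(route_top1(x + mu, W), minlength=num_experts),
      np.bincount(route_top1(x - mu, W), minlength=num_experts))
```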

---

## Comprehensive Exercise: Build a Complete MoE Model

Create a complete MoE language model with:
1. Multiple MoE layers
2. Load balancing loss
3. Expert choice routing (optional)
4. Training loop with expert usage monitoring
5. Comparison to dense baseline

Requirements:
- 4 layers
- 8 experts per layer
- Top-2 routing
- Track expert usage per layer
- Compare parameters and compute vs dense model

In [None]:
# Your code here


In [None]:
# Solution
from typing import List, Tuple
class CompleteMoEModel(nn.Module):
    """Complete MoE model with multiple layers"""
    
    def __init__(
        self,
        d_model: int,
        d_ff: int,
        num_layers: int,
        num_experts: int,
        top_k: int
    ):
        super().__init__()
        self.num_layers = num_layers
        
        self.layers = nn.ModuleList([
            MixtureOfExpertsLayer(d_model, d_ff, num_experts, top_k, "relu")
            for _ in range(num_layers)
        ])
        
        self.norm = nn.LayerNorm(d_model)
    
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[dict]]:
        aux_infos = []
        
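        for layer in self.layers:
            # Residual connection around each MoE block, matching DenseModel below
            # (assumes MixtureOfExpertsLayer returns only the expert output, without a residual)
            moe_out, aux_info = layer(x)
            x = x + moe_out
            aux_infos.append(aux_info)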
        
        x = self.norm(x)
        return x, aux_infos


# Build models
d_model = 512
d_ff = 2048
num_layers = 4
num_experts = 8
top_k = 2

moe_model = CompleteMoEModel(d_model, d_ff, num_layers, num_experts, top_k).to(device)

# Dense baseline for comparison
class DenseModel(nn.Module):
    def __init__(self, d_model: int, d_ff: int, num_layers: int):
        super().__init__()
        self.layers = nn.ModuleList([
            Expert(d_model, d_ff) for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(d_model)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            x = x + layer(x)
        return self.norm(x)

dense_model = DenseModel(d_model, d_ff, num_layers).to(device)

# Compare models
print("Model Comparison")
print("="*70)

moe_params = sum(p.numel() for p in moe_model.parameters())
dense_params = sum(p.numel() for p in dense_model.parameters())

# Approximation: treats every MoE parameter as an expert parameter, so top_k/num_experts
# of them count as active. Router and LayerNorm params are tiny, so the error is small.
moe_active_approx = moe_params / num_experts * top_k

print("\nMoE Model:")
print(f"  Layers: {num_layers}")
print(f"  Experts per layer: {num_experts}")
print(f"  Top-k: {top_k}")
print(f"  Total parameters: {moe_params:,}")
print(f"  Active params/token: ~{moe_active_approx:,.0f}")

print("\nDense Model:")
print(f"  Layers: {num_layers}")
print(f"  Total parameters: {dense_params:,}")
print(f"  Active params/token: {dense_params:,}")

print("\nRatios:")
print(f"  MoE has {moe_params/dense_params:.1f}x the total parameters of the dense model")
print(f"  MoE uses ~{moe_active_approx/dense_params:.1f}x the active parameters of the dense model")

# Test forward pass
batch_size = 4
seq_len = 128
x = torch.randn(batch_size, seq_len, d_model, device=device)

with torch.no_grad():  # inference-only check, so no autograd graph is needed
    moe_out, aux_infos = moe_model(x)
    dense_out = dense_model(x)

print(f"\nForward pass:")
print(f"  Input: {x.shape}")
print(f"  MoE output: {moe_out.shape}")
print(f"  Dense output: {dense_out.shape}")

# Analyze expert usage per layer
print(f"\nExpert usage by layer:")
for layer_idx, aux_info in enumerate(aux_infos):
    expert_indices = aux_info['expert_indices']
    # Count how many (token, top-k slot) assignments each expert received
    counts = torch.bincount(expert_indices.flatten(), minlength=num_experts).float()
    expert_usage = counts / counts.sum() * 100
    
    print(f"\n  Layer {layer_idx}:")
    for i in range(num_experts):
        bar = '█' * int(expert_usage[i].item() / 3)
        print(f"    Expert {i}: {expert_usage[i].item():5.1f}% {bar}")

print("\n✅ Complete MoE model with multi-layer expert routing!")
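To sanity-check the printed counts, you can compute the expert parameters by hand. This sketch assumes each `Expert` consists of two biased `nn.Linear` layers (d_model→d_ff→d_model). It ignores router and LayerNorm parameters. If that assumption holds, the model totals printed above should come out slightly larger than these figures.

```python
def moe_ffn_budget(d_model, d_ff, num_layers, num_experts, top_k, bias=True):
    """Count FFN/expert parameters, assuming Linear(d_model, d_ff) -> Linear(d_ff, d_model)."""
    per_expert = 2 * d_model * d_ff + ((d_ff + d_model) if bias else 0)
    return {
        "per_expert": per_expert,
        "moe_total": num_layers * num_experts * per_expert,
        "moe_active": num_layers * top_k * per_expert,
        "dense_total": num_layers * per_expert,  # one FFN per layer
    }


budget = moe_ffn_budget(d_model=512, d_ff=2048, num_layers=4, num_experts=8, top_k=2)
for name, count in budget.items():
    print(f"{name:12s} {count:>12,}")
print(f"total ratio:  {budget['moe_total'] / budget['dense_total']:.1f}x")   # 8.0x
print(f"active ratio: {budget['moe_active'] / budget['dense_total']:.1f}x")  # 2.0x
```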

---

## Key Takeaways

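1. **MoE decouples parameters from compute**: With E experts and top-k routing, expert parameters scale with E while per-token expert compute scales only with k (e.g., 8 experts with top-2 routing store 4x the expert parameters each token actually uses)
2. **Top-k routing**: Each token is processed by only a subset of experts (k=1 in Switch Transformers, k=2 in Mixtral, k=8 in DeepSeek-V3)
3. **Load balancing is critical**: Without it, the router can collapse onto a few experts; an auxiliary loss is the standard remedy
4. **Expert specialization**: Experts tend to specialize during training, though the Mixtral analysis found the patterns to be more syntactic than topic-based
5. **Trade-offs**: All experts must be stored in memory even though only k run per token, and expert parallelism adds all-to-all communication
6. **Expert choice routing**: Experts select their top tokens up to a fixed capacity, so load is balanced by construction, though some tokens may receive no expert
7. **Production usage**: DeepSeek-V3 and Mixtral use MoE, and GPT-4 reportedly does too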
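The balance guarantee in item 6 comes from inverting the selection: experts pick tokens instead of tokens picking experts. Below is a minimal NumPy sketch. It assumes a precomputed token-by-expert score matrix and a capacity of N·k/E tokens per expert. The function and variable names are illustrative and not taken from the Expert Choice paper's code.

```python
import numpy as np


def expert_choice_routing(scores, capacity):
    """Expert choice: each expert (column) picks its `capacity` highest-scoring tokens (rows).

    scores: (N, E) token-to-expert affinities
    returns: (N, E) boolean assignment matrix
    """
    num_experts = scores.shape[1]
    # For each expert, indices of its top-`capacity` tokens (descending score)
    top_tokens = np.argsort(-scores, axis=0)[:capacity]  # (capacity, E)
    assignment = np.zeros(scores.shape, dtype=bool)
    assignment[top_tokens, np.arange(num_experts)] = True
    return assignment


rng = np.random.default_rng(0)
N, E, k = 16, 4, 2
capacity = N * k // E  # an average of k experts per token -> 8 tokens per expert
scores = rng.random((N, E))  # hypothetical affinities, e.g. router softmax outputs

assignment = expert_choice_routing(scores, capacity)
print(assignment.sum(axis=0))  # tokens per expert: [8 8 8 8], balanced by construction
print(assignment.sum(axis=1))  # experts per token: varies, and can be 0
```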

## Modern MoE Architectures (2025)

**GPT-4** (rumored):
- 16 experts per layer
- Top-2 routing
- ~1.8T total parameters
- ~220B active per token

**DeepSeek-V3**:
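- 256 routed experts plus 1 shared expert per MoE layer
- Top-8 routing over the routed experts, balanced mainly by an auxiliary-loss-free scheme (per-expert bias terms adjusted during training) plus a small sequence-wise balance loss
- 671B total parameters
- ~37B active per token
- Reported performance comparable to leading closed-source models such as GPT-4o on many benchmarks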

**Mixtral 8x7B**:
- 8 experts per layer
- Top-2 routing
- 47B total parameters
- ~13B active per token
- Matches or outperforms Llama 2 70B on the benchmarks reported in the Mixtral paper
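The name "8x7B" is misleading. Eight 7B models would total 56B parameters, but only the FFN blocks are replicated per expert, while attention and embeddings are shared. Working from the rounded figures above, a rough two-equation estimate separates the two parts. It ignores router parameters, and the rounding of 47B and 13B limits its precision.

```python
# Rounded figures from the lists above (billions of parameters)
models = {
    "GPT-4 (rumored)": (1800, 220),
    "DeepSeek-V3": (671, 37),
    "Mixtral 8x7B": (47, 13),
}
for name, (total, active) in models.items():
    print(f"{name:16s} active = {active / total:5.1%} of total, total/active = {total / active:4.1f}x")

# Mixtral: total = S + 8F and active = S + 2F, where S = shared params (attention,
# embeddings, ...) and F = one expert's params summed over all layers. Solve for F and S:
total, active, E, k = 47, 13, 8, 2
F = (total - active) / (E - k)  # 34 / 6 ≈ 5.7B per expert (all layers)
S = active - k * F              # ≈ 1.7B shared
print(f"Mixtral: per-expert ≈ {F:.1f}B, shared ≈ {S:.1f}B")
```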

**Implementation Tips**:
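- Use a load-balancing mechanism: typically an auxiliary loss (α≈0.01-0.1; Switch Transformers used 0.01), or an auxiliary-loss-free scheme as in DeepSeek-V3
- Monitor expert usage during training
- Consider expert choice routing for guaranteed balance
- Use SwiGLU experts (as Mixtral does) for better quality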
- Implement expert parallelism for large models
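For the monitoring tip, two scalar summaries are convenient to log per layer alongside per-expert usage bars: a normalized entropy and a coefficient of variation. Here is a hedged NumPy sketch. The function name and the choice of metrics are illustrative, not a standard API.

```python
import numpy as np


def expert_usage_stats(expert_indices, num_experts):
    """Summarize routing balance from an array of chosen expert ids (any shape).

    Returns (usage fractions, normalized entropy in [0, 1], coefficient of variation).
    Normalized entropy is 1.0 for perfectly uniform usage and 0.0 for full collapse;
    the coefficient of variation is 0.0 for uniform usage and grows with imbalance.
    """
    counts = np.bincount(np.ravel(expert_indices), minlength=num_experts)
    frac = counts / counts.sum()
    nonzero = frac[frac > 0]  # experts with zero usage contribute 0 to the entropy
    entropy = float(np.sum(nonzero * np.log(1.0 / nonzero)) / np.log(num_experts))
    cv = float(frac.std() / frac.mean())
    return frac, entropy, cv


E = 8
balanced = np.arange(1024) % E         # every expert used equally
collapsed = np.zeros(1024, dtype=int)  # every token sent to expert 0
print(expert_usage_stats(balanced, E)[1:])
print(expert_usage_stats(collapsed, E)[1:])
```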

---

## Next Steps

Continue to: [Topic 14: torch.compile & Performance Optimization](14_torch_compile.ipynb)

---

## Further Reading

- [Switch Transformers: Scaling to Trillion Parameter Models](https://arxiv.org/abs/2101.03961) (2021)
- [GLaM: Efficient Scaling of Language Models with Mixture-of-Experts](https://arxiv.org/abs/2112.06905) (2021)
- [Mixtral of Experts](https://arxiv.org/abs/2401.04088) (2024)
- [DeepSeek-V3 Technical Report](https://arxiv.org/abs/2412.19437) (2024)
- [Mixture-of-Experts with Expert Choice Routing](https://arxiv.org/abs/2202.09368) (2022)
- [ST-MoE: Designing Stable and Transferable Sparse Expert Models](https://arxiv.org/abs/2202.08906) (2022)