# Mixture of Experts (MoE) from Scratch

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adiel2012/deep-learning-abc/blob/main/mixture_of_experts.ipynb)

This notebook implements a Mixture of Experts (MoE) layer from scratch, of the kind used in models such as Switch Transformer, Mixtral, and the DeepSeek MoE models.

**Key idea:** Replace the dense FFN with multiple "expert" FFNs and a learned router that sends each token to only 1-2 experts. This enables massive models where only a fraction of parameters are active per token.

We cover:
1. Dense FFN baseline
2. Top-1 routing (Switch Transformer style)
3. Top-2 routing (Mixtral style)
4. Load balancing loss
5. Scaling analysis (total vs. active parameters)

In [None]:
!pip install torch matplotlib

In [None]:
import torch
import matplotlib.pyplot as plt
import math

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

## 0. Mathematical Foundations

### The Scaling Problem

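Larger models tend to perform better, but in a dense model the compute per token grows linearly with the parameter count. A forward pass costs roughly 2 FLOPs per parameter, so a dense 175B-parameter model needs about 350 GFLOPs per token.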

**MoE insight:** Not every token needs every parameter. Replace the dense FFN with $E$ experts and activate only $k$ experts per token:

$$\text{MoE}(x) = \sum_{i \in \text{TopK}} g_i(x) \cdot \text{Expert}_i(x)$$

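where $g(x) = \text{softmax}(x \cdot W_{\text{router}})$ is the router output, $g_i(x)$ is its $i$-th component, and TopK is the set of the $k$ experts with the highest router scores for this token. Some variants (including the top-2 routing in section 3) renormalize the selected $g_i$ so they sum to 1.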
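
To make the formula concrete, here is a minimal sketch for a single token. The logits and the scalar "expert outputs" are made-up numbers chosen for illustration, not values from any model.

```python
import torch

# Hypothetical router logits for one token over E = 4 experts
logits = torch.tensor([2.0, 1.0, 0.0, -1.0])

g = torch.softmax(logits, dim=-1)      # router probabilities g(x); they sum to 1
top_w, top_idx = g.topk(2)             # keep the k = 2 largest scores
top_w_norm = top_w / top_w.sum()       # optional renormalization over the selected experts

# Stand-in expert outputs: one scalar per expert, purely illustrative
expert_out = torch.tensor([10.0, 20.0, 30.0, 40.0])

# Weighted sum over the selected experts only; experts 2 and 3 are never evaluated
moe_out = (top_w_norm * expert_out[top_idx]).sum()

print('selected experts:', top_idx.tolist())
print('normalized weights:', [round(w, 3) for w in top_w_norm.tolist()])
print('MoE output:', round(moe_out.item(), 3))
```

With these logits the two selected experts are 0 and 1, and after renormalization the weights are about 0.731 and 0.269.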

### Example: Mixtral 8x7B
- 8 experts, activate 2 per token
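- Total: ~47B parameters → Active: ~13B per token (the total is below 8 × 7B = 56B because only the FFN blocks are replicated; attention and embeddings are shared)
- Reported by its authors to match or outperform Llama 2 70B on most benchmarks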
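
The 47B and 13B figures can be reproduced with back-of-the-envelope arithmetic. The sketch below assumes the configuration values of the publicly released Mixtral 8x7B model (hidden size, FFN width, layer count, grouped-query attention shape, vocabulary size); these are assumptions, not something this notebook verifies, and normalization weights are ignored.

```python
# Assumed Mixtral 8x7B configuration (taken as given, not verified here)
d_model, d_ff, n_layers = 4096, 14336, 32
n_experts, top_k = 8, 2
vocab = 32000
n_heads, n_kv_heads, head_dim = 32, 8, 128

# One SwiGLU expert has three matrices: gate, up, and down projections
expert = 3 * d_model * d_ff

# Attention: Q and O are full-width; K and V are narrower (grouped-query attention)
attn = 2 * d_model * n_heads * head_dim + 2 * d_model * n_kv_heads * head_dim

router = d_model * n_experts
embed = 2 * vocab * d_model  # input embedding + output head, assumed untied

# Parameters every token touches, regardless of routing
shared = n_layers * (attn + router) + embed

total = shared + n_layers * n_experts * expert
active = shared + n_layers * top_k * expert

print(f'total  ~ {total / 1e9:.1f}B')
print(f'active ~ {active / 1e9:.1f}B')
```

Under these assumptions the script prints about 46.7B total and 12.9B active, which is where the rounded 47B and 13B come from.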

### Load Balancing

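Left unregularized, the router tends to collapse onto a few experts: an expert that is picked slightly more often early on receives more gradient updates, improves, and is picked even more. The **auxiliary load balancing loss** penalizes uneven expert usage: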

$$\mathcal{L}_{\text{balance}} = E \cdot \sum_{i=1}^{E} f_i \cdot p_i$$

where $f_i$ = fraction of tokens routed to expert $i$, and $p_i$ = mean router probability for expert $i$.
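
As a quick check on the scale of this loss, consider two extremes (a worked example, assuming top-1 assignments). Under perfectly uniform routing, $f_i = p_i = 1/E$ for every expert, so

$$\mathcal{L}_{\text{balance}} = E \cdot \sum_{i=1}^{E} \frac{1}{E} \cdot \frac{1}{E} = E \cdot \frac{E}{E^2} = 1$$

If instead every token is sent to a single expert $j$ with router probability close to 1, then $f_j = 1$, $p_j \approx 1$, and all other terms vanish:

$$\mathcal{L}_{\text{balance}} \approx E \cdot 1 \cdot 1 = E$$

So with $E = 8$ experts, a balanced router scores about 1 and a fully collapsed one scores about 8.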

## 1. Dense FFN Baseline

In [None]:
def dense_ffn(x, W1, b1, W2, b2):
    """Standard dense FFN: all parameters used for every token."""
    return torch.relu(x @ W1 + b1) @ W2 + b2

# Setup
torch.manual_seed(42)
d_model = 32
d_ff = 128
batch, seq_len = 2, 8

x = torch.randn(batch, seq_len, d_model, device=device)
W1 = torch.randn(d_model, d_ff, device=device) * 0.1
b1 = torch.zeros(d_ff, device=device)
W2 = torch.randn(d_ff, d_model, device=device) * 0.1
b2 = torch.zeros(d_model, device=device)

out_dense = dense_ffn(x, W1, b1, W2, b2)
dense_params = d_model * d_ff + d_ff + d_ff * d_model + d_model

print(f'Dense FFN:')
print(f'  Output shape: {out_dense.shape}')
print(f'  Total parameters: {dense_params}')
print(f'  Active parameters per token: {dense_params} (100%)')

## 2. MoE with Top-1 Routing (Switch Transformer)

Each token goes to exactly one expert.

In [None]:
def init_experts(n_experts, d_model, d_ff, device):
    """Initialize multiple expert FFN networks."""
    experts = []
    for _ in range(n_experts):
        expert = {
            'W1': torch.randn(d_model, d_ff, device=device) * 0.1,
            'b1': torch.zeros(d_ff, device=device),
            'W2': torch.randn(d_ff, d_model, device=device) * 0.1,
            'b2': torch.zeros(d_model, device=device),
        }
        experts.append(expert)
    return experts

def expert_ffn(x, expert):
    """Apply a single expert FFN."""
    return torch.relu(x @ expert['W1'] + expert['b1']) @ expert['W2'] + expert['b2']

def moe_top1(x, router_weights, experts):
    """Mixture of Experts with top-1 routing.
    
    Args:
        x: (batch, seq_len, d_model)
        router_weights: (d_model, n_experts) — learned routing matrix
        experts: list of expert parameter dicts
    
    Returns:
        output: (batch, seq_len, d_model)
        router_probs: (batch*seq_len, n_experts) — routing probabilities
        expert_indices: (batch*seq_len,) — which expert was selected
    """
    batch, seq_len, d_model = x.shape
    n_experts = len(experts)
    
    # Flatten batch and sequence dimensions
    x_flat = x.reshape(-1, d_model)  # (batch*seq_len, d_model); reshape also handles non-contiguous input
    n_tokens = x_flat.shape[0]
    
    # Router: compute probabilities for each expert
    router_logits = x_flat @ router_weights  # (n_tokens, n_experts)
    router_probs = torch.softmax(router_logits, dim=-1)  # (n_tokens, n_experts)
    
    # Select top-1 expert per token
    expert_weights, expert_indices = router_probs.max(dim=-1)  # (n_tokens,), (n_tokens,)
    
    # Route each token to its selected expert
    output = torch.zeros_like(x_flat)
    for i in range(n_experts):
        mask = (expert_indices == i)  # which tokens go to expert i
        if mask.any():
            expert_input = x_flat[mask]  # (n_selected, d_model)
            expert_output = expert_ffn(expert_input, experts[i])
            output[mask] = expert_weights[mask].unsqueeze(-1) * expert_output
    
    return output.view(batch, seq_len, d_model), router_probs, expert_indices

# Test
torch.manual_seed(42)
n_experts = 4
d_ff_expert = d_ff  # Each expert is same size as the dense FFN

experts = init_experts(n_experts, d_model, d_ff_expert, device)
router_W = torch.randn(d_model, n_experts, device=device) * 0.1

out_moe, probs, indices = moe_top1(x, router_W, experts)

total_moe_params = n_experts * dense_params + d_model * n_experts  # experts + router
active_moe_params = dense_params  # only 1 expert active per token

print(f'MoE Top-1 ({n_experts} experts):')
print(f'  Output shape: {out_moe.shape}')
print(f'  Total parameters: {total_moe_params}')
print(f'  Active parameters per token: {active_moe_params} ({100*active_moe_params/total_moe_params:.1f}%)')
print(f'\nRouting decisions (expert index per token):')
print(f'  {indices.view(batch, seq_len).cpu()}')
print(f'\nRouter probabilities (first 4 tokens):')
for t in range(4):
    p = probs[t].cpu()
    print(f'  Token {t}: {[f"{v:.3f}" for v in p.tolist()]} → Expert {indices[t].item()}')
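
One reason top-1 layers scale the expert output by the router probability, as `moe_top1` does, is that the argmax selection itself has no gradient. The sketch below illustrates this with stand-in experts (a single linear map each, a simplification assumed here for brevity): the only differentiable path from the output back to the router weights is the gate value.

```python
import torch

torch.manual_seed(0)
n_tokens, d_model, n_experts = 6, 8, 3

x = torch.randn(n_tokens, d_model)
router_W = (torch.randn(d_model, n_experts) * 0.1).requires_grad_()
# Stand-in experts: one (d_model, d_model) matrix per expert (hypothetical)
expert_W = torch.randn(n_experts, d_model, d_model) * 0.1

probs = torch.softmax(x @ router_W, dim=-1)
gate, idx = probs.max(dim=-1)  # gate is differentiable, idx is not

# Apply each token's selected expert: expert_W[idx] has shape (n_tokens, d_model, d_model)
expert_out = torch.einsum('td,tde->te', x, expert_W[idx])

# Scaling by the gate makes the output a differentiable function of router_W
out = gate.unsqueeze(-1) * expert_out
out.sum().backward()

print('router grad shape:', tuple(router_W.grad.shape))
print('router grad is non-zero:', bool(router_W.grad.abs().sum() > 0))
```

Without the `gate` factor, `out` would not depend on `router_W` through any differentiable operation, and the router could only be trained through an auxiliary loss.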

In [None]:
# Visualize routing decisions
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Router probabilities heatmap
im = axes[0].imshow(probs.detach().cpu().numpy(), cmap='YlOrRd', aspect='auto')
axes[0].set_xlabel('Expert')
axes[0].set_ylabel('Token index (batch*seq_len)')
axes[0].set_title('Router Probabilities per Token')
plt.colorbar(im, ax=axes[0])

# Expert load distribution
expert_counts = [(indices == i).sum().item() for i in range(n_experts)]
colors = plt.cm.Set2(range(n_experts))
axes[1].bar(range(n_experts), expert_counts, color=colors)
axes[1].axhline(y=len(indices) / n_experts, color='red', linestyle='--', label=f'Ideal ({len(indices)//n_experts} each)')
axes[1].set_xlabel('Expert')
axes[1].set_ylabel('Number of tokens')
axes[1].set_title('Expert Load Distribution')
axes[1].legend()

plt.tight_layout()
plt.show()

## 3. MoE with Top-2 Routing (Mixtral Style)

Each token is processed by 2 experts, with outputs weighted by router probabilities.

In [None]:
def moe_top2(x, router_weights, experts):
    """Mixture of Experts with top-2 routing (Mixtral style).
    
    Each token goes to 2 experts; outputs are weighted by
    normalized router probabilities.
    """
    batch, seq_len, d_model = x.shape
    n_experts = len(experts)
    
    x_flat = x.reshape(-1, d_model)  # reshape also handles non-contiguous input
    n_tokens = x_flat.shape[0]
    
    # Router
    router_logits = x_flat @ router_weights
    router_probs = torch.softmax(router_logits, dim=-1)
    
    # Select top-2 experts per token
    top2_weights, top2_indices = router_probs.topk(2, dim=-1)  # (n_tokens, 2)
    
    # Normalize weights for selected experts (they should sum to 1)
    top2_weights = top2_weights / top2_weights.sum(dim=-1, keepdim=True)
    
    # Route tokens to experts
    output = torch.zeros_like(x_flat)
    for k in range(2):  # for each of the 2 selected experts
        for i in range(n_experts):
            mask = (top2_indices[:, k] == i)
            if mask.any():
                expert_input = x_flat[mask]
                expert_output = expert_ffn(expert_input, experts[i])
                output[mask] += top2_weights[mask, k].unsqueeze(-1) * expert_output
    
    return output.view(batch, seq_len, d_model), router_probs, top2_indices, top2_weights

# Test
torch.manual_seed(42)
n_experts = 8  # Mixtral uses 8 experts
experts_8 = init_experts(n_experts, d_model, d_ff_expert, device)
router_W_8 = torch.randn(d_model, n_experts, device=device) * 0.1

out_top2, probs_8, top2_idx, top2_w = moe_top2(x, router_W_8, experts_8)

total_params_8 = n_experts * dense_params + d_model * n_experts
active_params_8 = 2 * dense_params  # 2 experts active per token

print(f'MoE Top-2 ({n_experts} experts, like Mixtral):')
print(f'  Output shape: {out_top2.shape}')
print(f'  Total parameters: {total_params_8}')
print(f'  Active per token: {active_params_8} ({100*active_params_8/total_params_8:.1f}%)')
print(f'\nRouting decisions (top-2 experts per token):')
for t in range(min(8, batch * seq_len)):
    e1, e2 = top2_idx[t].cpu().tolist()
    w1, w2 = top2_w[t].cpu().tolist()
    print(f'  Token {t}: Expert {e1} (w={w1:.3f}) + Expert {e2} (w={w2:.3f})')
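
The nested loops in `moe_top2` visit every (slot, expert) pair. A possible alternative, sketched here, loops once over experts and gathers all tokens that selected that expert in any slot. The function name `moe_topk` and the tuple layout `(W1, b1, W2, b2)` for experts are assumptions made for this example, not the dict format used above.

```python
import torch

def moe_topk(x_flat, router_W, experts, k=2):
    """Top-k routing with a single loop over experts (illustrative sketch).

    x_flat:   (n_tokens, d_model)
    router_W: (d_model, n_experts)
    experts:  list of (W1, b1, W2, b2) tuples (hypothetical layout)
    """
    probs = torch.softmax(x_flat @ router_W, dim=-1)
    topk_w, topk_idx = probs.topk(k, dim=-1)             # both (n_tokens, k)
    topk_w = topk_w / topk_w.sum(dim=-1, keepdim=True)   # renormalize over the k picks

    out = torch.zeros_like(x_flat)
    for e, (W1, b1, W2, b2) in enumerate(experts):
        # Token rows and slot columns where expert e was selected
        token_idx, slot = torch.where(topk_idx == e)
        if token_idx.numel() == 0:
            continue  # this expert received no tokens
        h = torch.relu(x_flat[token_idx] @ W1 + b1) @ W2 + b2
        # Accumulate the weighted expert output into the selecting tokens' rows
        out.index_add_(0, token_idx, topk_w[token_idx, slot].unsqueeze(-1) * h)
    return out, topk_idx, topk_w

torch.manual_seed(0)
n_tokens, d_model, d_ff, n_experts = 16, 32, 64, 8
x = torch.randn(n_tokens, d_model)
router_W = torch.randn(d_model, n_experts) * 0.1
experts = [
    (torch.randn(d_model, d_ff) * 0.1, torch.zeros(d_ff),
     torch.randn(d_ff, d_model) * 0.1, torch.zeros(d_model))
    for _ in range(n_experts)
]

out, idx, w = moe_topk(x, router_W, experts, k=2)
print(out.shape, idx.shape, w.shape)
```

Each expert runs once on one contiguous batch of tokens, which tends to matter more as the number of experts grows. The result should match the nested-loop version up to floating-point error.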

## 4. Load Balancing Loss

Without load balancing, the router tends to collapse — always sending tokens to the same 1-2 experts.

In [None]:
def load_balancing_loss(router_probs, expert_indices, n_experts):
    """Compute auxiliary load balancing loss.
    
    Encourages uniform expert utilization.
    Loss = E * sum(f_i * p_i) where:
      f_i = fraction of tokens routed to expert i
      p_i = mean router probability for expert i
    
    Minimum when all experts get equal traffic.
    """
    n_tokens = router_probs.shape[0]
    
    # f_i: fraction of tokens assigned to each expert
    # For top-2, a token counts for both selected experts
    if expert_indices.dim() == 2:  # top-k
        f = torch.zeros(n_experts, device=router_probs.device)
        for k in range(expert_indices.shape[1]):
            for i in range(n_experts):
                f[i] += (expert_indices[:, k] == i).float().sum()
        f = f / (n_tokens * expert_indices.shape[1])
    else:  # top-1
        f = torch.zeros(n_experts, device=router_probs.device)
        for i in range(n_experts):
            f[i] = (expert_indices == i).float().mean()
    
    # p_i: mean router probability for each expert
    p = router_probs.mean(dim=0)  # (n_experts,)
    
    # Loss
    loss = n_experts * (f * p).sum()
    
    return loss, f, p

# Compute for our top-2 routing
loss, f, p = load_balancing_loss(probs_8, top2_idx, n_experts)
print(f'Load Balancing Loss: {loss.item():.4f}')
print(f'  (Ideal = 1.0 when perfectly balanced)')
print(f'\nExpert utilization (f_i = fraction of tokens):')
for i in range(n_experts):
    bar = '█' * int(f[i].item() * 80)
    print(f'  Expert {i}: {f[i].item():.3f} {bar}')
print(f'\nMean router probability (p_i):')
for i in range(n_experts):
    bar = '█' * int(p[i].item() * 80)
    print(f'  Expert {i}: {p[i].item():.3f} {bar}')
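
The Python loops that build `f` in `load_balancing_loss` can be replaced by a single counting call. The helper below is a sketch (the name `expert_fractions` is made up for this example) that handles both top-1 and top-k index tensors using `torch.bincount`.

```python
import torch

def expert_fractions(expert_indices, n_experts):
    """Fraction of routing slots assigned to each expert.

    Accepts top-1 indices of shape (n_tokens,) or top-k indices of shape
    (n_tokens, k). The result sums to 1 in both cases.
    """
    flat = expert_indices.reshape(-1)  # one entry per (token, slot) pair
    counts = torch.bincount(flat, minlength=n_experts).float()
    return counts / flat.numel()

# Four tokens, top-2 routing over four experts (hand-written indices)
idx_top2 = torch.tensor([[0, 1], [0, 2], [3, 0], [1, 2]])
f = expert_fractions(idx_top2, n_experts=4)
print(f)  # tensor([0.3750, 0.2500, 0.2500, 0.1250])
```

`minlength` matters here: without it, experts with the highest indices that received no tokens would be dropped from the count vector.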

In [None]:
# Demonstrate: what happens with and without load balancing
# Simulate routing collapse vs balanced routing
n_tokens_sim = 1000
n_experts_sim = 8

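# Scenario 1: Collapsed routing (router mass concentrated on expert 0; all traffic is
# hand-assigned to experts 0 and 1 for illustration, so the other six experts stay idle)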
collapsed_probs = torch.zeros(n_tokens_sim, n_experts_sim, device=device)
collapsed_probs[:, 0] = 0.8
collapsed_probs[:, 1:] = 0.2 / (n_experts_sim - 1)
collapsed_idx = torch.zeros(n_tokens_sim, dtype=torch.long, device=device)
collapsed_idx[n_tokens_sim//2:] = 1

# Scenario 2: Balanced routing
balanced_probs = torch.ones(n_tokens_sim, n_experts_sim, device=device) / n_experts_sim
balanced_idx = torch.arange(n_tokens_sim, device=device) % n_experts_sim

loss_collapsed, f_c, _ = load_balancing_loss(collapsed_probs, collapsed_idx, n_experts_sim)
loss_balanced, f_b, _ = load_balancing_loss(balanced_probs, balanced_idx, n_experts_sim)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

colors = plt.cm.Set2(range(n_experts_sim))
axes[0].bar(range(n_experts_sim), f_c.cpu().numpy(), color=colors)
axes[0].set_title(f'Collapsed Routing\nLoss = {loss_collapsed.item():.2f}')
axes[0].set_xlabel('Expert')
axes[0].set_ylabel('Fraction of tokens')
axes[0].axhline(y=1/n_experts_sim, color='red', linestyle='--')

axes[1].bar(range(n_experts_sim), f_b.cpu().numpy(), color=colors)
axes[1].set_title(f'Balanced Routing\nLoss = {loss_balanced.item():.2f}')
axes[1].set_xlabel('Expert')
axes[1].set_ylabel('Fraction of tokens')
axes[1].axhline(y=1/n_experts_sim, color='red', linestyle='--', label='Ideal')
axes[1].legend()

plt.suptitle('Load Balancing: Collapsed vs Balanced Expert Usage', fontsize=13)
plt.tight_layout()
plt.show()

print(f'Loss ratio: collapsed/balanced = {loss_collapsed.item()/loss_balanced.item():.1f}x')
print('→ Higher loss penalizes uneven routing, pushing toward balance')
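
In training, the balancing term is added to the main objective with a small weight. The sketch below shows one way to wire that up; the weight `alpha = 0.01` follows the value used in the Switch Transformer paper and should be treated as a starting point, and `task_loss` is a placeholder for the model's real loss.

```python
import torch

torch.manual_seed(0)
n_tokens, d_model, n_experts = 64, 16, 4
alpha = 0.01  # auxiliary-loss weight (assumed; tune per setup)

x = torch.randn(n_tokens, d_model)
router_W = (torch.randn(d_model, n_experts) * 0.1).requires_grad_()

probs = torch.softmax(x @ router_W, dim=-1)
idx = probs.argmax(dim=-1)  # hard top-1 assignment

# f comes from hard assignments, so it carries no gradient
f = torch.bincount(idx, minlength=n_experts).float() / n_tokens
# p is a mean of softmax outputs, so gradients reach router_W through it
p = probs.mean(dim=0)
aux_loss = n_experts * (f * p).sum()

task_loss = torch.tensor(0.0)  # placeholder for the real objective
total_loss = task_loss + alpha * aux_loss
total_loss.backward()

print(f'aux loss: {aux_loss.item():.4f}')
print('f requires grad:', f.requires_grad)
print('router grad norm:', router_W.grad.norm().item())
```

Because `f` is piecewise constant, the gradient treats it as a fixed weighting: experts that currently receive more tokens have their mean probability `p` pushed down harder.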

## 5. Scaling Analysis

How MoE enables efficient scaling.

In [None]:
# Compare dense vs MoE parameter efficiency
d_model_real = 4096
d_ff_real = 14336  # FFN width used by Mistral 7B / Mixtral 8x7B

# W1 + W2, ignoring biases (the two-matrix FFN from this notebook; SwiGLU FFNs use three matrices)
dense_ffn_params = 2 * d_model_real * d_ff_real

print('=' * 70)
print('SCALING: Dense vs MoE Parameter Efficiency')
print(f'd_model={d_model_real}, d_ff={d_ff_real}')
print('=' * 70)

configs = [
    ('Dense', 1, 1),
    ('MoE 4x (top-1)', 4, 1),
    ('MoE 8x (top-2)', 8, 2),
    ('MoE 16x (top-2)', 16, 2),
]

print(f'{"Config":<20} {"Total Params":<18} {"Active Params":<18} {"Active %":<10} {"Total/Active":<12}')
print('-' * 70)

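for name, n_exp, top_k in configs:
    router = d_model_real * n_exp if n_exp > 1 else 0  # a dense FFN has no router
    total = n_exp * dense_ffn_params + router  # experts + router
    active = top_k * dense_ffn_params + router  # selected experts + router
    pct = 100 * active / total
    ratio = total / active  # parameters stored per parameter used (a capacity ratio, not a wall-clock speedup)
    print(f'{name:<20} {total/1e6:>12.1f}M  {active/1e6:>12.1f}M  {pct:>7.1f}%  {ratio:>7.1f}x')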

print('-' * 70)
print('\nMixtral 8x7B equivalent:')
print(f'  47B total params, 13B active → about 3.6x more parameters stored than used per token')

In [None]:
# Visualize the parameter vs compute tradeoff
expert_counts = [1, 2, 4, 8, 16, 32, 64]
base_params = 7  # billion params for one expert set

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

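# Total params vs active params
# Simplification: the whole 7B block is treated as replicated per expert. In real models
# only the FFN is replicated, so totals grow more slowly (Mixtral 8x7B is ~47B, not 56B).
total_p = [n * base_params for n in expert_counts]
active_top1 = [base_params for _ in expert_counts]
active_top2 = [min(2, n) * base_params for n in expert_counts]  # cannot activate more experts than exist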

axes[0].plot(expert_counts, total_p, 'o-', linewidth=2, label='Total parameters', color='C0')
axes[0].plot(expert_counts, active_top1, 's--', linewidth=2, label='Active (top-1)', color='C1')
axes[0].plot(expert_counts, active_top2, '^--', linewidth=2, label='Active (top-2)', color='C2')
axes[0].set_xlabel('Number of Experts')
axes[0].set_ylabel('Parameters (B)')
axes[0].set_title('MoE: Total vs Active Parameters')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Efficiency ratio
efficiency_top1 = [t / a for t, a in zip(total_p, active_top1)]
efficiency_top2 = [t / a for t, a in zip(total_p, active_top2)]

axes[1].plot(expert_counts, efficiency_top1, 'o-', linewidth=2, label='Top-1 routing', color='C1')
axes[1].plot(expert_counts, efficiency_top2, 's-', linewidth=2, label='Top-2 routing', color='C2')
axes[1].set_xlabel('Number of Experts')
axes[1].set_ylabel('Knowledge/Compute Ratio')
axes[1].set_title('MoE Efficiency: More Experts = More Knowledge per FLOP')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# Summary
print('=' * 75)
print('COMPARISON: MoE Routing Strategies')
print('=' * 75)
print(f'{"Property":<25} {"Top-1 (Switch)":<20} {"Top-2 (Mixtral)":<20}')
print('-' * 75)
print(f'{"Experts active/token":<25} {"1":<20} {"2":<20}')
print(f'{"Compute per token":<25} {"1x FFN":<20} {"2x FFN":<20}')
print(f'{"Quality":<25} {"Good":<20} {"Better":<20}')
print(f'{"Load balancing":<25} {"Critical":<20} {"Important":<20}')
print(f'{"Key model":<25} {"Switch Transformer":<20} {"Mixtral 8x7B":<20}')
print(f'{"Year":<25} {"2021":<20} {"2023":<20}')
print('=' * 75)

## Summary

In this notebook we implemented from scratch:

1. **Dense FFN baseline** — all parameters active for every token

2. **Top-1 MoE (Switch Transformer)** — each token routed to 1 expert. Maximum efficiency, slight quality tradeoff.

3. **Top-2 MoE (Mixtral style)** — each token routed to 2 experts with weighted outputs. Better quality than top-1.

4. **Load balancing loss** — auxiliary loss that penalizes uneven expert utilization, preventing routing collapse.

**Key insight:** MoE decouples model capacity (total parameters) from compute cost (active parameters). Mixtral 8x7B has about 47B total parameters but uses only about 13B per token, and its authors report that it matches or outperforms Llama 2 70B, a dense model roughly 5x its active size, on most benchmarks.