# Mixture of Experts (MoE) Architectures

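Mixture of Experts (MoE) models route each token to a small subset of specialized subnetworks (experts). Total parameter count can therefore grow with the number of experts, while per-token compute depends mainly on how many experts each token uses. This notebook builds sparse gating, dispatch logic, and load-balancing losses in PyTorch.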
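The sketch below makes the parameter/compute trade-off concrete by counting parameters. It assumes each expert is a two-layer MLP with biases (the same shape as the `Expert` class defined later), that every token activates exactly `k` experts, and it ignores the router's own small weight matrix.

```python
def expert_params(model_dim: int, hidden_dim: int) -> int:
    """Parameters in one two-layer MLP expert (weights plus biases)."""
    return (model_dim * hidden_dim + hidden_dim) + (hidden_dim * model_dim + model_dim)

def moe_param_summary(model_dim: int, hidden_dim: int, num_experts: int, k: int):
    per_expert = expert_params(model_dim, hidden_dim)
    total = num_experts * per_expert  # stored parameters grow with the number of experts
    active = k * per_expert           # parameters each token actually passes through
    return total, active

for num_experts in (4, 16, 64):
    total, active = moe_param_summary(model_dim=32, hidden_dim=64, num_experts=num_experts, k=2)
    print(f"experts={num_experts:3d}  total={total:7d}  active per token={active}")
```

Going from 4 to 64 experts multiplies the stored parameters by 16, while the active count per token stays fixed because `k` does not change.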

## Learning Objectives

- Implement top-k routing that selects a subset of experts per token.
- Dispatch tokens to experts, accumulate outputs, and combine them with router probabilities.
- Apply a load-balancing penalty and monitor router entropy to discourage expert collapse.
- Construct a reusable MoE feed-forward layer.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

torch.manual_seed(0)

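class TopKRouter(nn.Module):
    """Scores each token against every expert and keeps the k most probable experts."""

    def __init__(self, model_dim, num_experts, k=2):
        super().__init__()
        self.num_experts = num_experts
        self.k = k
        self.linear = nn.Linear(model_dim, num_experts)

    def forward(self, x):
        logits = self.linear(x)                # (tokens, num_experts)
        probs = F.softmax(logits, dim=-1)      # full router distribution per token
        # Keep the k most probable experts; their weights are not renormalized to sum to 1
        topk_probs, topk_idx = torch.topk(probs, self.k, dim=-1)
        return topk_probs, topk_idx, probs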

router = TopKRouter(32, 4, k=2)
tokens = torch.randn(8, 32)
scores, indices, probs = router(tokens)
print(scores.shape, indices.shape)
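`TopKRouter` returns the raw softmax probabilities of the selected experts, so each token's combine weights sum to less than 1. Some implementations renormalize over the selected experts instead. The sketch below shows that variant as an assumption, not as part of the router above; `renormalize_topk` is a hypothetical helper name.

```python
import torch
import torch.nn.functional as F

def renormalize_topk(logits: torch.Tensor, k: int):
    """Hypothetical helper: top-k selection with weights renormalized over the chosen experts."""
    probs = F.softmax(logits, dim=-1)
    topk_probs, topk_idx = torch.topk(probs, k, dim=-1)
    # Divide by the selected probability mass so each token's k weights sum to 1
    weights = topk_probs / topk_probs.sum(dim=-1, keepdim=True)
    return weights, topk_idx

logits = torch.randn(8, 4)
weights, idx = renormalize_topk(logits, k=2)
print(weights.sum(dim=-1))  # all ones, up to floating-point rounding
```

Renormalizing the top-k probabilities is mathematically the same as applying a softmax to the top-k logits alone, because the shared normalizer cancels.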


## Experts and Dispatch

Each expert is typically a small feed-forward network; here it is a two-layer MLP with a GELU activation. We dispatch each token to its k selected experts, apply each expert to the token, and sum the expert outputs weighted by the corresponding router probabilities.

In [None]:
class Expert(nn.Module):
    def __init__(self, model_dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(model_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, model_dim),
        )

    def forward(self, x):
        return self.net(x)

experts = nn.ModuleList([Expert(32, 64) for _ in range(4)])

def moe_forward(x, router, experts):
    topk_probs, topk_idx, _ = router(x)
    outputs = torch.zeros_like(x)
    # Token-by-token reference implementation: easy to follow, but slow
    for i, (weights, expert_ids, token) in enumerate(zip(topk_probs, topk_idx, x)):
        combined = 0.0
        for w, expert_id in zip(weights, expert_ids):
            # Weight each selected expert's output by its router probability
            combined += w * experts[int(expert_id)](token.unsqueeze(0)).squeeze(0)
        outputs[i] = combined
    return outputs

outputs = moe_forward(tokens, router, experts)
print(outputs.shape)
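The per-token loop in `moe_forward` calls an expert once per (token, expert) pair. A common alternative loops over experts instead: gather every token routed to an expert, run that expert once on the batch, and scatter the weighted results back with `index_add_`. The sketch below is self-contained, uses plain `nn.Linear` layers as stand-in experts, and `moe_forward_batched` is a hypothetical name. It checks itself against a token-by-token loop.

```python
import torch
import torch.nn as nn

def moe_forward_batched(x, topk_probs, topk_idx, experts):
    """Hypothetical expert-major dispatch: one call per expert instead of per token."""
    outputs = torch.zeros_like(x)
    for expert_id, expert in enumerate(experts):
        # Rows (tokens) and columns (top-k slots) where this expert was selected
        token_idx, slot = (topk_idx == expert_id).nonzero(as_tuple=True)
        if token_idx.numel() == 0:
            continue  # no tokens routed to this expert in this batch
        # Run the expert once on all of its tokens, then weight by the router probability
        weighted = topk_probs[token_idx, slot].unsqueeze(-1) * expert(x[token_idx])
        # Scatter-add back; a token that selected several experts accumulates their outputs
        outputs.index_add_(0, token_idx, weighted)
    return outputs

demo_x = torch.randn(8, 16)
demo_experts = nn.ModuleList([nn.Linear(16, 16) for _ in range(4)])
demo_probs = torch.softmax(torch.randn(8, 4), dim=-1)
demo_topk_probs, demo_topk_idx = torch.topk(demo_probs, 2, dim=-1)

# Reference result from a token-by-token loop, in the style of moe_forward
reference = torch.stack([
    sum(w * demo_experts[int(e)](t) for w, e in zip(ws, es))
    for ws, es, t in zip(demo_topk_probs, demo_topk_idx, demo_x)
])
batched = moe_forward_batched(demo_x, demo_topk_probs, demo_topk_idx, demo_experts)
print(torch.allclose(batched, reference, atol=1e-5))
```

Both versions compute the same weighted sum. The expert-major loop simply trades many tiny matrix multiplications for one batched call per expert.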


## Load Balancing Loss

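Without regularization, a router can collapse onto a small subset of experts. The favored experts receive more tokens and gradient updates, improve, and attract still more tokens, while the others go unused. A KL divergence penalty between the uniform distribution and the batch-averaged router probabilities encourages uniform utilization. The bar chart visualizes the average probability assigned to each expert.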
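Written out, `load_balancing_loss` computes the following, where $E$ is the number of experts, $T$ the number of tokens, $p_{t,i}$ the router probability of expert $i$ for token $t$, and $u_i = 1/E$ the uniform distribution. The $\epsilon$ in the code only guards the logarithm:

$$
\bar p_i = \frac{1}{T}\sum_{t=1}^{T} p_{t,i}, \qquad
\mathcal{L}_{\mathrm{KL}} = \mathrm{KL}(u \,\|\, \bar p) = \sum_{i=1}^{E} \frac{1}{E}\left(\log\frac{1}{E} - \log \bar p_i\right)
$$

The loss is zero exactly when $\bar p$ is uniform and positive otherwise. It grows without bound as any $\bar p_i \to 0$.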

In [None]:
def load_balancing_loss(probs, epsilon=1e-9):
    mean_usage = probs.mean(dim=0)
    uniform = torch.full_like(mean_usage, 1.0 / mean_usage.numel())
    kl = (uniform * (uniform.add(epsilon).log() - mean_usage.add(epsilon).log())).sum()
    return kl

lb_loss = load_balancing_loss(probs)
print(f"Load balancing loss: {lb_loss.item():.6f}")

avg_probs = probs.mean(dim=0).detach()
plt.bar(range(len(avg_probs)), avg_probs)
plt.xlabel("Expert ID")
plt.ylabel("Average probability")
plt.title("Expert utilization snapshot")
plt.show()
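The KL penalty above acts only on router probabilities. The Switch Transformer paper (Fedus et al., 2021) instead uses an auxiliary loss that also involves the hard routing decisions: $\alpha \cdot E \cdot \sum_i f_i P_i$. Here $f_i$ is the fraction of tokens dispatched to expert $i$ and $P_i$ is the mean router probability for expert $i$. The sketch below follows that form for top-1 routing. The function name `switch_aux_loss` is hypothetical, and the default `alpha=0.01` is an assumption chosen for illustration.

```python
import torch
import torch.nn.functional as F

def switch_aux_loss(router_probs: torch.Tensor, alpha: float = 0.01) -> torch.Tensor:
    """Switch-style auxiliary loss for top-1 routing (sketch).

    router_probs: (tokens, num_experts) softmax outputs.
    """
    num_experts = router_probs.shape[-1]
    top1 = router_probs.argmax(dim=-1)
    # f_i: fraction of tokens whose top-1 choice is expert i (not differentiable)
    f = F.one_hot(top1, num_experts).float().mean(dim=0)
    # P_i: mean router probability for expert i (gradients flow through this term)
    P = router_probs.mean(dim=0)
    return alpha * num_experts * torch.sum(f * P)

# Balanced: token t prefers expert t % 4 with probability 0.7
balanced = torch.full((8, 4), 0.1)
balanced[torch.arange(8), torch.arange(8) % 4] = 0.7
# Collapsed: every token sends all probability mass to expert 0
collapsed = F.one_hot(torch.zeros(8, dtype=torch.long), 4).float()
print(switch_aux_loss(balanced).item(), switch_aux_loss(collapsed).item())
```

With perfectly balanced routing the sum $E \sum_i f_i P_i$ equals 1, so the loss is `alpha`. Full collapse onto one expert raises it to `alpha * E`.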


## Mini Task – Capacity Limits

Practical MoE implementations cap the number of tokens each expert may process per batch (its *capacity*). This keeps each expert's workload bounded and its buffer size fixed. Implement a dispatcher that enforces a maximum capacity per expert and reports any dropped tokens.
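Capacity is usually derived from a *capacity factor* rather than hand-picked per batch. As an assumption for this sketch, we use the convention capacity = ceil(capacity_factor × tokens × k / num_experts). Under this convention a factor of 1.0 leaves exactly enough slots for perfectly balanced routing, and larger factors add slack. `expert_capacity` is a hypothetical helper name.

```python
import math

def expert_capacity(num_tokens: int, num_experts: int, k: int = 1, capacity_factor: float = 1.0) -> int:
    """Hypothetical helper: slots per expert under a capacity-factor convention."""
    # Each token makes k assignments; spread evenly, each expert receives num_tokens * k / num_experts
    return math.ceil(capacity_factor * num_tokens * k / num_experts)

print(expert_capacity(8, 4, k=2))                        # 4
print(expert_capacity(8, 4, k=2, capacity_factor=1.25))  # 5
```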

In [None]:
def moe_forward_with_capacity(x, router, experts, capacity=2):
    # TODO: dispatch tokens with capacity limits; print dropped tokens
    raise NotImplementedError


In [None]:
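def moe_forward_with_capacity(x, router, experts, capacity=2):
    topk_probs, topk_idx, _ = router(x)
    # Token indices accepted by each expert so far (first come, first served in token order)
    expert_buffers = {idx: [] for idx in range(len(experts))}
    outputs = torch.zeros_like(x)
    for token_idx, (weights, ids, token) in enumerate(zip(topk_probs, topk_idx, x)):
        dispatched = False
        for weight, expert_id in zip(weights, ids):
            # Convert to int: tensors hash by identity, so a tensor cannot look up an int dict key
            expert_id = int(expert_id)
            if len(expert_buffers[expert_id]) < capacity:
                outputs[token_idx] += weight * experts[expert_id](token.unsqueeze(0)).squeeze(0)
                expert_buffers[expert_id].append(token_idx)
                dispatched = True
            # A full expert skips only this assignment; the token still uses its other experts
        if not dispatched:
            print(f"Token {token_idx} dropped: all of its selected experts are at capacity")
    return outputs

capped_outputs = moe_forward_with_capacity(tokens, router, experts, capacity=1)
print(capped_outputs.shape)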
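The loop above assigns slots first-come, first-served in token order. The same priority rule can be computed without a Python loop. A cumulative sum over a one-hot assignment matrix numbers each token's position in its expert's queue. The sketch below illustrates this for top-1 routing only, and `capacity_mask` is a hypothetical name.

```python
import torch
import torch.nn.functional as F

def capacity_mask(expert_idx: torch.Tensor, num_experts: int, capacity: int) -> torch.Tensor:
    """Hypothetical helper: True where a top-1 assignment fits within expert capacity.

    expert_idx: (tokens,) chosen expert per token, prioritized in token order.
    """
    one_hot = F.one_hot(expert_idx, num_experts)        # (tokens, experts)
    # Queue position per token: 0 for the first arrival at its expert, 1 for the next, ...
    position = (one_hot.cumsum(dim=0) - 1) * one_hot    # zero out columns of other experts
    position_in_expert = position.sum(dim=-1)           # (tokens,)
    return position_in_expert < capacity

expert_idx = torch.tensor([0, 1, 0, 0, 2, 1, 3, 0])
keep = capacity_mask(expert_idx, num_experts=4, capacity=2)
print(keep)
print("dropped tokens:", (~keep).nonzero(as_tuple=True)[0].tolist())  # [3, 7]
```

Expert 0 is chosen by tokens 0, 2, 3, and 7, so with `capacity=2` only tokens 0 and 2 fit and tokens 3 and 7 are dropped.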


## Further Enhancements

- **Switch Transformers** route each token to a single expert (k=1), which simplifies routing and reduces computation and communication costs.
- **GShard** and **GLaM** shard MoE layers across many devices, using top-2 gating and auxiliary load-balancing losses.
- **Entropy regularization** of the router distribution can be added alongside load balancing to encourage exploration of experts.
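A typical pattern adds the router's auxiliary terms to the task loss with small coefficients. The sketch below assumes a load-balancing term plus an entropy bonus. The coefficients `lb_coef` and `entropy_coef` are hypothetical hyperparameters, and the sign on the entropy term determines whether routing is pushed toward exploration or toward confident choices.

```python
import torch
import torch.nn.functional as F

def router_entropy(probs: torch.Tensor, eps: float = 1e-9) -> torch.Tensor:
    """Mean per-token entropy of router distributions of shape (tokens, num_experts)."""
    return -(probs * (probs + eps).log()).sum(dim=-1).mean()

def total_loss(task_loss, lb_loss, probs, lb_coef=0.01, entropy_coef=0.001):
    """Hypothetical combined objective; both coefficients are illustrative assumptions."""
    # Subtracting the entropy rewards spread-out routing (exploration);
    # flipping the sign would instead push the router toward confident choices
    return task_loss + lb_coef * lb_loss - entropy_coef * router_entropy(probs)

demo_logits = torch.randn(8, 4, requires_grad=True)
demo_probs = F.softmax(demo_logits, dim=-1)
# task_loss and lb_loss are placeholder constants here; in training both depend on the model
loss = total_loss(task_loss=torch.tensor(1.0), lb_loss=torch.tensor(0.05), probs=demo_probs)
loss.backward()  # gradients reach the router logits through the entropy term
print(loss.item(), demo_logits.grad.shape)
```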

## Comprehensive Exercise – MoE Feed-Forward Layer

Build an `MoEFeedForward` module compatible with transformer feed-forward blocks. The module should:

- Accept a temperature parameter to anneal router logits.
- Return the combined output along with auxiliary metrics (load balancing loss, entropy, expert usage).
- Support configuring the number of experts, the top-k value, and the hidden dimension.

In [None]:
class MoEFeedForward(nn.Module):
    def __init__(self, model_dim, hidden_dim, num_experts, k=2, router_temp=1.0):
        super().__init__()
        # TODO: initialize router, experts, and temperature parameter

    def forward(self, x, return_aux=False):
        # TODO: apply routing, combine expert outputs, compute aux metrics
        raise NotImplementedError


In [None]:
class MoEFeedForward(nn.Module):
    def __init__(self, model_dim, hidden_dim, num_experts, k=2, router_temp=1.0):
        super().__init__()
        self.router = TopKRouter(model_dim, num_experts, k)
        self.experts = nn.ModuleList([Expert(model_dim, hidden_dim) for _ in range(num_experts)])
        # A buffer rather than a trainable parameter, so an external schedule can anneal it
        self.register_buffer("router_temp", torch.tensor(float(router_temp)))

    def forward(self, x, return_aux=False):
        # Scale the router logits, not the input: because the router has a bias,
        # dividing x by the temperature is not equivalent to dividing the logits
        logits = self.router.linear(x) / self.router_temp.clamp(min=0.1)
        probs = F.softmax(logits, dim=-1)
        topk_probs, topk_idx = torch.topk(probs, self.router.k, dim=-1)
        outputs = torch.zeros_like(x)
        expert_counts = torch.zeros(len(self.experts), device=x.device)
        for i, (weights, ids, token) in enumerate(zip(topk_probs, topk_idx, x)):
            for weight, expert_id in zip(weights, ids):
                expert_id = int(expert_id)
                outputs[i] += weight * self.experts[expert_id](token.unsqueeze(0)).squeeze(0)
                expert_counts[expert_id] += 1
        aux = {}
        if return_aux:
            lb = load_balancing_loss(probs)
            entropy = -(probs * (probs + 1e-9).log()).sum(dim=-1).mean()
            aux = {
                "load_balance": lb,
                "entropy": entropy,
                # Fraction of (token, expert) assignments handled by each expert
                "expert_usage": expert_counts / expert_counts.sum().clamp(min=1.0),
            }
        return outputs, aux

moe_layer = MoEFeedForward(32, 64, num_experts=4)
out, aux = moe_layer(tokens, return_aux=True)
print(out.shape, aux)
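To see what annealing the router temperature does, the sketch below isolates the effect. Dividing logits by a temperature above 1 flattens the router distribution and raises its entropy. A schedule that decays toward 1 then gradually sharpens routing. The linear schedule, its endpoints, and the helper names are assumptions for illustration, not a prescribed recipe.

```python
import torch
import torch.nn.functional as F

def linear_temperature(step: int, total_steps: int, start: float = 2.0, end: float = 1.0) -> float:
    """Hypothetical linear annealing schedule from `start` to `end`."""
    frac = min(step / max(total_steps, 1), 1.0)
    return start + (end - start) * frac

def mean_entropy(logits: torch.Tensor, temperature: float) -> float:
    """Mean per-token entropy of softmax(logits / temperature)."""
    probs = F.softmax(logits / temperature, dim=-1)
    return -(probs * (probs + 1e-9).log()).sum(dim=-1).mean().item()

router_logits = 3 * torch.randn(8, 4)
for step in (0, 50, 100):
    t = linear_temperature(step, total_steps=100)
    print(f"step={step:3d}  temperature={t:.2f}  router entropy={mean_entropy(router_logits, t):.3f}")
```

The entropy decreases as the temperature falls, because the entropy of a softmax is non-decreasing in its temperature for any fixed logits.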


## Further Reading

- Lepikhin et al. (2020) “GShard: Scaling Giant Models with Conditional Computation and Automatic Sharding”
- Fedus et al. (2021) “Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity”
- Shazeer et al. (2017) “Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer”
- DeepSpeed MoE and Fairseq MoE tutorials for distributed training