# Lesson 10: Mixture-of-Experts Transformer (Toy)

Mixture-of-Experts (MoE) layers let a model have **more parameters without proportional compute**.
A router decides which expert MLP(s) should process each token, so only a small subset of
experts run per token. This notebook builds a tiny decoder-only Transformer with a toy
MoE feedforward layer to make the idea concrete.

We keep the code small and readable rather than highly optimized.


## 1) Setup
We use PyTorch and keep defaults small so it runs quickly.


In [None]:
import math
import random
import re
from collections import Counter

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# Use the GPU when one is available; otherwise everything runs on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Seed both PyTorch's and Python's RNGs for more reproducible runs.
torch.manual_seed(42)
random.seed(42)

# Small defaults for a toy run
D_MODEL = 256  # model width (embedding / hidden size)
N_HEADS = 4  # attention heads per transformer block
N_LAYERS = 4  # number of transformer blocks
N_EXPERTS = 4  # experts in each MoE feedforward layer
TOP_K = 1  # experts kept per token by the router
SEQ_LEN = 128  # sequence length (tokens) of each training window
BATCH_SIZE = 32  # sequences per batch
LR = 3e-4  # AdamW learning rate
MAX_STEPS = 200  # maximum number of training steps
EVAL_EVERY = 50  # print the losses and routing counts every this many steps
AUX_WEIGHT = 0.01  # multiplier applied to the load-balancing auxiliary loss

# Bigger knobs (commented)
# D_MODEL = 512
# N_LAYERS = 8
# N_EXPERTS = 8
# TOP_K = 2
# MAX_STEPS = 1000

## 2) Dataset and tokenizer
We try `wikitext-2-raw-v1` via the Hugging Face `datasets` package. If it cannot be loaded
(for example, the package is not installed or the download fails), we fall back to a tiny
built-in text. The tokenizer is a simple regex word splitter to keep things transparent.


In [None]:
def load_wikitext_or_fallback():
    """Return ``(train_text, valid_text)`` for the demo.

    Tries to load WikiText-2 (raw) through the Hugging Face ``datasets``
    package. If that fails for any reason (package missing, download error,
    unexpected dataset layout, ...), the reason is printed and a small
    built-in corpus is returned for both splits so the notebook still runs.

    Returns:
        A pair of strings: the training text and the validation text.
    """
    try:
        from datasets import load_dataset

        ds = load_dataset("wikitext", "wikitext-2-raw-v1")
        train_text = "\n\n".join(ds["train"]["text"])
        valid_text = "\n\n".join(ds["validation"]["text"])
        return train_text, valid_text
    except Exception as e:
        # Deliberately broad: this is a best-effort convenience loader and
        # any failure should lead to the offline fallback, not a crash.
        print("Falling back to a tiny built-in corpus. Reason:", e)
        passage = (
            "In the beginning the universe was created. This has made a lot of people very angry. "
            "It is a truth universally acknowledged, that a single man in possession of a good fortune, "
            "must be in want of a wife. The quick brown fox jumps over the lazy dog. "
        )
        # A single copy of the passage is only a few dozen tokens: shorter
        # than one 128-token training window, and almost every word would
        # occur fewer than twice and be mapped to <unk> by a min_freq=2
        # vocabulary. Repeating it yields a corpus that can actually be
        # windowed and trained on.
        fallback_repeats = 200
        fallback = passage * fallback_repeats
        return fallback, fallback

def tokenize(text):
    """Lowercase ``text`` and split it into word and punctuation tokens."""
    # A token is either a run of word characters or one single character
    # that is neither a word character nor whitespace.
    token_pattern = r"\w+|[^\w\s]"
    lowered = text.lower()
    return [match.group() for match in re.finditer(token_pattern, lowered)]

def build_vocab(tokens, min_freq=2):
    """Build token<->id mappings from a token sequence.

    Ids 0-3 are reserved for ``<pad>``, ``<unk>``, ``<bos>`` and ``<eos>``.
    Every other token that occurs at least ``min_freq`` times gets the next
    free id, in order of first appearance.

    Returns:
        ``(vocab, id2tok)`` where ``vocab`` maps token -> id and ``id2tok``
        is the inverse mapping.
    """
    specials = ("<pad>", "<unk>", "<bos>", "<eos>")
    vocab = {special: index for index, special in enumerate(specials)}

    frequencies = Counter(tokens)
    for token, count in frequencies.items():
        # Skip rare tokens and anything that collides with a special token.
        if count < min_freq or token in vocab:
            continue
        vocab[token] = len(vocab)

    id2tok = dict(zip(vocab.values(), vocab.keys()))
    return vocab, id2tok

def encode(tokens, vocab):
    """Map each token to its id, using the ``<unk>`` id for unknown tokens."""

    def lookup(token):
        return vocab.get(token, vocab["<unk>"])

    return list(map(lookup, tokens))

def decode(ids, id2tok):
    """Turn ids back into a space-separated string; unknown ids become ``<unk>``."""
    words = [id2tok.get(token_id, "<unk>") for token_id in ids]
    return " ".join(words)

# Load the corpus and wrap each split in begin/end-of-sequence markers.
train_text, valid_text = load_wikitext_or_fallback()
train_tokens = ["<bos>", *tokenize(train_text), "<eos>"]
valid_tokens = ["<bos>", *tokenize(valid_text), "<eos>"]

# The vocabulary is built from the training split only; validation tokens
# that are missing from it are encoded as <unk>.
vocab, id2tok = build_vocab(train_tokens, min_freq=2)
train_ids = encode(train_tokens, vocab)
valid_ids = encode(valid_tokens, vocab)

# Cap both splits for a fast demo if the dataset is large. Slicing leaves
# the contents of a shorter list unchanged.
max_tokens = 200_000
train_ids = train_ids[:max_tokens]
valid_ids = valid_ids[:max_tokens]

print(f"Vocab size: {len(vocab)}")
print(f"Train tokens: {len(train_ids)}")

class LMDataset(Dataset):
    """Sliding-window next-token dataset over a flat list of token ids.

    Item ``i`` is the pair ``(x, y)`` where ``x`` is the ``seq_len`` ids
    starting at position ``i`` and ``y`` is the same window shifted one
    position to the right, i.e. the next-token targets for ``x``.
    """

    def __init__(self, ids, seq_len):
        self.seq_len = seq_len
        self.ids = ids

    def __len__(self):
        # One window per start position that still leaves room for the
        # shifted target window; never negative.
        n_windows = len(self.ids) - self.seq_len
        return n_windows if n_windows > 0 else 0

    def __getitem__(self, idx):
        start, stop = idx, idx + self.seq_len
        inputs = self.ids[start:stop]
        targets = self.ids[start + 1 : stop + 1]
        return (
            torch.tensor(inputs, dtype=torch.long),
            torch.tensor(targets, dtype=torch.long),
        )

# Training batches are drawn in shuffled order; validation keeps its order.
train_loader = DataLoader(
    dataset=LMDataset(ids=train_ids, seq_len=SEQ_LEN),
    batch_size=BATCH_SIZE,
    shuffle=True,
)
valid_loader = DataLoader(
    dataset=LMDataset(ids=valid_ids, seq_len=SEQ_LEN),
    batch_size=BATCH_SIZE,
)

## 3) Baseline Transformer decoder (minimal)
We build a tiny decoder-only Transformer using PyTorch modules.
The feedforward block is a standard 2-layer MLP.


In [None]:
class FeedForward(nn.Module):
    """Two-layer MLP applied to the last dimension: Linear -> GELU -> Linear -> Dropout.

    Args:
        d_model: input and output width.
        d_ff: hidden width; a falsy value (``None`` or ``0``) selects
            ``4 * d_model``.
        dropout: dropout probability applied to the output.
    """

    def __init__(self, d_model, d_ff=None, dropout=0.1):
        super().__init__()
        if not d_ff:
            d_ff = 4 * d_model
        layers = [
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        projected = self.net(x)
        return projected

class TransformerBlock(nn.Module):
    """Pre-LayerNorm decoder block: self-attention, then a feedforward module.

    Both sub-layers are wrapped in a residual connection. ``ffn`` may return
    either a tensor or a ``(out, aux_loss, stats)`` tuple; in the first case
    the block reports a zero auxiliary loss and ``None`` stats.
    """

    def __init__(self, d_model, n_heads, ffn):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=n_heads,
            dropout=0.1,
            batch_first=True,
        )
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = ffn

    def forward(self, x, attn_mask):
        # Attention sub-layer; the mask keeps tokens from seeing the future.
        normed = self.ln1(x)
        attended = self.attn(
            normed, normed, normed, attn_mask=attn_mask, need_weights=False
        )[0]
        x = x + attended

        # Feedforward sub-layer; accept both plain and MoE-style results.
        result = self.ffn(self.ln2(x))
        if isinstance(result, tuple):
            ffn_out, aux_loss, stats = result
        else:
            ffn_out = result
            aux_loss = torch.tensor(0.0, device=x.device)
            stats = None
        return x + ffn_out, aux_loss, stats

class TransformerLM(nn.Module):
    """Tiny decoder-only language model with learned position embeddings.

    Args:
        vocab_size: number of token ids.
        d_model: model width.
        n_heads: attention heads per block.
        n_layers: number of transformer blocks.
        ffn_factory: zero-argument callable returning a fresh feedforward
            module for each block (dense or MoE).
        max_seq_len: number of learned positions, i.e. the longest sequence
            ``forward`` accepts. ``None`` (the default) uses the module-level
            ``SEQ_LEN``.
    """

    def __init__(self, vocab_size, d_model, n_heads, n_layers, ffn_factory, max_seq_len=None):
        super().__init__()
        if max_seq_len is None:
            max_seq_len = SEQ_LEN
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_seq_len, d_model)
        self.blocks = nn.ModuleList([
            TransformerBlock(d_model, n_heads, ffn_factory())
            for _ in range(n_layers)
        ])
        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, idx):
        """Run the model on token ids of shape ``(batch, seq)``.

        Returns:
            ``(logits, total_aux, last_stats)``: logits of shape
            ``(batch, seq, vocab_size)``, the sum of the blocks' auxiliary
            losses, and the stats of the last block that reported any
            (``None`` if no block did).

        Raises:
            ValueError: if the sequence is longer than the position table.
        """
        _, seq_len = idx.shape
        max_len = self.pos_emb.num_embeddings
        if seq_len > max_len:
            # Fail early with a readable message instead of an out-of-range
            # embedding lookup.
            raise ValueError(
                f"sequence length {seq_len} exceeds the model's maximum of {max_len} positions"
            )

        # (seq, d_model) position embeddings broadcast over the batch dimension.
        pos = torch.arange(seq_len, device=idx.device)
        x = self.tok_emb(idx) + self.pos_emb(pos)

        # Causal mask: True values are blocked
        attn_mask = torch.triu(torch.ones(seq_len, seq_len, device=idx.device), diagonal=1).bool()

        total_aux = torch.tensor(0.0, device=idx.device)
        last_stats = None
        for block in self.blocks:
            x, aux_loss, stats = block(x, attn_mask)
            total_aux = total_aux + aux_loss
            if stats is not None:
                last_stats = stats

        x = self.ln_f(x)
        logits = self.head(x)
        return logits, total_aux, last_stats

def count_params(model):
    """Return the total number of elements in the model's trainable parameters."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Dense reference model: every block gets its own standard feedforward MLP.
baseline = TransformerLM(
    len(vocab),
    D_MODEL,
    N_HEADS,
    N_LAYERS,
    ffn_factory=lambda: FeedForward(d_model=D_MODEL),
)
print(f"Baseline params: {count_params(baseline):,}")

## 4) MoE Feedforward
The router is a linear layer that produces a score for each expert per token.
We keep **top-k** experts per token, normalize their scores into gates, and
combine the experts' outputs with those weights.

We also compute a small auxiliary load-balancing loss that encourages tokens to be spread evenly across the experts.


In [None]:
class MoEFeedForward(nn.Module):
    """Toy Mixture-of-Experts feedforward layer with top-k routing.

    A linear router scores every expert for every token. The ``top_k``
    highest-scoring experts are kept and their outputs are mixed with gate
    weights:

    * ``top_k == 1``: the gate is the router's softmax probability (over all
      experts) of the chosen expert. A softmax over the single kept score
      would always be exactly 1.0, which gives the router no gradient.
    * ``top_k > 1``: the gates are a softmax over the kept scores, so they
      sum to 1 for each token.

    Args:
        d_model: input and output width.
        n_experts: number of expert MLPs.
        top_k: experts kept per token; must satisfy ``1 <= top_k <= n_experts``.
        d_ff: hidden width passed to each expert.
        dropout: dropout probability passed to each expert.

    Raises:
        ValueError: if ``top_k`` is outside ``[1, n_experts]``.
    """

    def __init__(self, d_model, n_experts, top_k=1, d_ff=None, dropout=0.1):
        super().__init__()
        if not 1 <= top_k <= n_experts:
            raise ValueError(
                f"top_k must be between 1 and n_experts={n_experts}, got {top_k}"
            )
        self.n_experts = n_experts
        self.top_k = top_k
        self.router = nn.Linear(d_model, n_experts)
        self.experts = nn.ModuleList([
            FeedForward(d_model, d_ff=d_ff, dropout=dropout)
            for _ in range(n_experts)
        ])

    def forward(self, x):
        """Route ``x`` of shape ``(batch, seq, d_model)`` through the experts.

        Returns:
            ``(out, aux_loss, stats)``: the mixed expert output with the same
            shape as ``x``, a scalar load-balancing loss, and
            ``{"counts": ...}`` holding the number of tokens whose top-1
            choice was each expert (a detached CPU tensor).
        """
        logits = self.router(x)  # (batch, seq, n_experts)
        probs = F.softmax(logits, dim=-1)
        topk_vals, topk_idx = logits.topk(self.top_k, dim=-1)

        if self.top_k == 1:
            # Gate by the router probability so the task loss can reach the
            # router; softmax over one value would be a constant 1.0.
            topk_gates = probs.gather(-1, topk_idx)
        else:
            topk_gates = F.softmax(topk_vals, dim=-1)

        # Scatter the gates back into the full expert dimension (zeros elsewhere)
        gates = torch.zeros_like(probs).scatter(-1, topk_idx, topk_gates)

        # Toy implementation: evaluate all experts, then mix with gates
        expert_outs = torch.stack([expert(x) for expert in self.experts], dim=2)
        out = torch.sum(expert_outs * gates.unsqueeze(-1), dim=2)

        # Routing statistics for the top-1 expert per token
        top1 = topk_idx[..., 0]
        counts = torch.bincount(top1.reshape(-1), minlength=self.n_experts)

        # Load-balancing loss: n_experts * sum_i(load_i * importance_i).
        # load is the fraction of tokens whose top-1 choice is expert i (no
        # gradient); importance is the mean router probability of expert i
        # over all experts' softmax, which is what carries the gradient.
        # The value is 1.0 when both are uniform.
        n_tokens = max(top1.numel(), 1)  # avoid 0/0 on an empty batch
        importance = probs.reshape(-1, self.n_experts).sum(dim=0) / n_tokens
        load = counts.float() / n_tokens
        aux_loss = self.n_experts * (importance * load).sum()

        stats = {"counts": counts.detach().cpu()}
        return out, aux_loss, stats

# Same architecture as the baseline, but every block's feedforward is an MoE layer.
moe_model = TransformerLM(
    len(vocab),
    D_MODEL,
    N_HEADS,
    N_LAYERS,
    ffn_factory=lambda: MoEFeedForward(D_MODEL, n_experts=N_EXPERTS, top_k=TOP_K),
)
print(f"MoE params: {count_params(moe_model):,}")

## 5) Training loop
We train briefly and periodically print the language modeling loss, the auxiliary loss, and the
routing counts reported by the last MoE layer. The auxiliary loss is down-weighted by
`AUX_WEIGHT`, so it only nudges the router to balance expert usage.


In [None]:
# Train the MoE model for at most MAX_STEPS batches (or until the loader is
# exhausted), logging every EVAL_EVERY steps.
model = moe_model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
loss_fn = nn.CrossEntropyLoss()

model.train()
step = 0
for x, y in train_loader:
    x, y = x.to(device), y.to(device)

    logits, aux_loss, stats = model(x)

    # Flatten (batch, seq, vocab) logits and (batch, seq) targets so that
    # cross-entropy treats every position as one classification example.
    flat_logits = logits.view(-1, len(vocab))
    flat_targets = y.view(-1)
    lm_loss = loss_fn(flat_logits, flat_targets)
    total_loss = lm_loss + AUX_WEIGHT * aux_loss

    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    if step % EVAL_EVERY == 0:
        print(f"step {step:4d} | lm_loss {lm_loss.item():.4f} | aux {aux_loss.item():.4f}")
        if stats is not None:
            print("routing counts:", stats["counts"].tolist())

    step += 1
    if step >= MAX_STEPS:
        break

## 6) Inference: generate text
We generate a short sample with greedy decoding.


In [None]:
def generate(model, prompt, max_new_tokens=50):
    """Greedily extend ``prompt`` by up to ``max_new_tokens`` tokens.

    The prompt is tokenized and encoded with the module-level vocabulary; an
    empty prompt starts from ``<bos>``. At each step the model sees at most
    the last ``SEQ_LEN`` ids and the highest-scoring next id is appended.
    Generation stops early once ``<eos>`` is produced. The model is switched
    to eval mode and left there.

    Returns:
        The decoded prompt plus continuation as a space-separated string.
    """
    model.eval()
    ids = encode(tokenize(prompt), vocab) or [vocab["<bos>"]]
    with torch.no_grad():
        for _ in range(max_new_tokens):
            context = ids[-SEQ_LEN:]
            x = torch.tensor([context], dtype=torch.long, device=device)
            logits, _aux, _stats = model(x)
            # Greedy choice: most likely token at the final position.
            next_id = logits[0, -1].argmax(dim=-1).item()
            ids.append(next_id)
            if next_id == vocab["<eos>"]:
                break
    return decode(ids, id2tok)

print(generate(model, "in the", max_new_tokens=40))

## 7) Scaling notes
Real MoE systems add several engineering ideas:
- Distributed experts: experts live on different devices and only selected experts run.
- Capacity factors: limit how many tokens each expert processes to avoid overload.
- Auxiliary losses: encourage balanced routing and stable training.
- Token dropping or overflow buffers when experts are full.

Our toy version computes all experts (slow) and ignores capacity limits. It is
meant for clarity, not performance.


## 8) Exercises
1) Change TOP_K to 2 and compare routing counts and loss.
2) Increase N_EXPERTS and see how the parameter count grows.
3) Add a temperature to the router logits before top-k selection.
4) Replace greedy decoding with top-k sampling in `generate`.
5) Try turning off the auxiliary loss and compare routing balance.
