# Mixture of Experts (MoE) Implementation from Scratch

## Overview
This notebook implements a complete Mixture of Experts architecture from scratch using only PyTorch. We'll build:

1. **Expert Networks**: Individual feed-forward neural networks
2. **Router Mechanism**: Top-K routing with load balancing
3. **Sparse MoE Layer**: Combining expert outputs based on routing weights
4. **Complete Transformer**: Full transformer with MoE replacing the FFN layer
5. **Training Pipeline**: Pre-training on Shakespeare dataset
6. **Inference**: Text generation with the trained MoE model

## Learning Objectives
- Understand MoE architecture components
- Implement routing mechanisms and load balancing
- Build sparse computation patterns
- Train and evaluate MoE models
- Compare with traditional dense transformers

## Architecture Overview
```
Input → Multi-Head Attention → Layer Norm → MoE Layer → Output
                ↑                              ↓
           Residual Connection         Router + Experts
```

Let's start implementing each component step by step!

In [1]:
# Step 1: Import Required Libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import functional as F
import math
import random
import numpy as np
from dataclasses import dataclass
from typing import Optional
import requests
import os

# Set random seeds for reproducibility
torch.manual_seed(1337)
random.seed(1337)
np.random.seed(1337)

# Check device availability
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

# GPU memory optimization
if device == 'cuda':
    torch.cuda.empty_cache()
    print(f"GPU: {torch.cuda.get_device_name()}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

Using device: cpu


In [2]:
# Step 2: Load and Prepare Dataset
# Download Shakespeare dataset if not already present
if not os.path.exists('input.txt'):
    url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
    with open('input.txt', 'w', encoding='utf-8') as f:
        f.write(requests.get(url).text)

# Read the dataset
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

print(f"Dataset size: {len(text):,} characters")
print("Sample text:")
print(text[:500])

# Character-level tokenization
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(f"Vocabulary size: {vocab_size} characters")
print(f"Characters: {''.join(chars)}")

# Create character-to-integer and integer-to-character mappings
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}

def encode(s):
    """Encode string to list of integers"""
    return [stoi[c] for c in s]

def decode(l):
    """Decode list of integers to string"""
    return ''.join([itos[i] for i in l])

# Test encoding/decoding
test_text = "Hello World!"
encoded = encode(test_text)
decoded = decode(encoded)
print(f"Original: {test_text}")
print(f"Encoded: {encoded}")
print(f"Decoded: {decoded}")

Dataset size: 1,115,394 characters
Sample text:
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor
Vocabulary size: 65 characters
Characters: 
 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
Original: Hello World!
Encoded: [20, 43, 50, 50, 53, 1, 35, 53, 56, 50, 42, 2]
Decoded: Hello World!


In [3]:
# Step 3: Define Individual Expert Networks
class Expert(nn.Module):
    """
    Individual expert network - a simple feed-forward neural network
    Each expert specializes in processing certain types of tokens
    """
    def __init__(self, n_embed, dropout=0.1):
        super().__init__()
        # Expansion layer: embed_dim -> 4*embed_dim (standard in transformers)
        self.w1 = nn.Linear(n_embed, 4 * n_embed, bias=False)
        # Contraction layer: 4*embed_dim -> embed_dim  
        self.w2 = nn.Linear(4 * n_embed, n_embed, bias=False)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        # x shape: (batch_size, seq_len, n_embed)
        # Apply expansion, ReLU activation, contraction, and dropout
        return self.dropout(self.w2(F.relu(self.w1(x))))

# Test the Expert network
print("=== Testing Expert Network ===")
n_embed = 8  # Embedding dimension
expert = Expert(n_embed)

# Create test input: 1 batch, 4 tokens, 8-dim embeddings
test_input = torch.randn(1, 4, n_embed)
expert_output = expert(test_input)

print(f"Input shape: {test_input.shape}")
print(f"Expert output shape: {expert_output.shape}")
print(f"Expert parameters: {sum(p.numel() for p in expert.parameters()):,}")

# Verify the expert maintains input/output dimensions
assert test_input.shape == expert_output.shape, "Expert should maintain input shape"
print("✓ Expert network working correctly!")

=== Testing Expert Network ===
Input shape: torch.Size([1, 4, 8])
Expert output shape: torch.Size([1, 4, 8])
Expert parameters: 512
✓ Expert network working correctly!


In [4]:
# Step 4: Implement Router Mechanism
class TopKRouter(nn.Module):
    """
    Router that determines which experts should process each token
    Uses top-k routing to select the best experts for each token
    """
    def __init__(self, n_embed, num_experts, top_k):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        
        # Router network: maps input to expert scores
        self.router = nn.Linear(n_embed, num_experts, bias=False)
        
    def forward(self, x):
        # x shape: (batch_size, seq_len, n_embed)
        batch_size, seq_len, n_embed = x.shape
        
        # Step 1: Get routing scores (logits) for each expert
        # Shape: (batch_size, seq_len, num_experts)
        logits = self.router(x)
        
        # Step 2: Select top-k experts for each token
        # top_k_logits: (batch_size, seq_len, top_k)  
        # top_k_indices: (batch_size, seq_len, top_k)
        top_k_logits, top_k_indices = torch.topk(logits, self.top_k, dim=-1)
        
        # Step 3: Create mask for selected experts
        # Initialize with negative infinity (will become 0 after softmax)
        routing_weights = torch.full_like(logits, float('-inf'))
        
        # Set selected expert logits
        routing_weights.scatter_(-1, top_k_indices, top_k_logits)
        
        # Step 4: Apply softmax to get routing probabilities
        # Only selected experts will have non-zero probabilities
        routing_weights = F.softmax(routing_weights, dim=-1)
        
        return routing_weights, top_k_indices

# Test the Router
print("=== Testing Router Mechanism ===")
num_experts = 3
top_k = 2
n_embed = 8

router = TopKRouter(n_embed, num_experts, top_k)

# Test input: 1 batch, 4 tokens, 8-dim embeddings
test_input = torch.randn(1, 4, n_embed)
routing_weights, expert_indices = router(test_input)

print(f"Input shape: {test_input.shape}")
print(f"Routing weights shape: {routing_weights.shape}")
print(f"Expert indices shape: {expert_indices.shape}")

print(f"\nRouting weights (each row sums to 1.0):")
print(routing_weights.squeeze(0))

print(f"\nSelected expert indices:")
print(expert_indices.squeeze(0))

# Verify that each row sums to 1.0 (probability distribution)
row_sums = routing_weights.sum(dim=-1)
print(f"\nRow sums (should be 1.0): {row_sums.squeeze(0)}")

print("✓ Router working correctly!")

=== Testing Router Mechanism ===
Input shape: torch.Size([1, 4, 8])
Routing weights shape: torch.Size([1, 4, 3])
Expert indices shape: torch.Size([1, 4, 2])

Routing weights (each row sums to 1.0):
tensor([[0.4686, 0.5314, 0.0000],
        [0.2901, 0.7099, 0.0000],
        [0.4520, 0.5480, 0.0000],
        [0.6347, 0.0000, 0.3653]], grad_fn=<SqueezeBackward1>)

Selected expert indices:
tensor([[1, 0],
        [1, 0],
        [1, 0],
        [0, 2]])

Row sums (should be 1.0): tensor([1., 1., 1., 1.], grad_fn=<SqueezeBackward1>)
✓ Router working correctly!


In [5]:
# Step 5: Implement Noisy Top-K Router (Advanced Load Balancing)
class NoisyTopKRouter(nn.Module):
    """
    Enhanced router with Gaussian noise for better load balancing
    Noise helps prevent certain experts from being consistently favored
    """
    def __init__(self, n_embed, num_experts, top_k, noise_std=0.1):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.noise_std = noise_std
        
        # Router network
        self.router = nn.Linear(n_embed, num_experts, bias=False)
        
    def forward(self, x):
        # x shape: (batch_size, seq_len, n_embed)
        batch_size, seq_len, n_embed = x.shape
        
        # Step 1: Get base routing scores
        logits = self.router(x)
        
        # Step 2: Add Gaussian noise during training for load balancing
        if self.training:
            noise = torch.randn_like(logits) * self.noise_std
            logits = logits + noise
        
        # Step 3: Select top-k experts
        top_k_logits, top_k_indices = torch.topk(logits, self.top_k, dim=-1)
        
        # Step 4: Create routing weights with softmax
        routing_weights = torch.full_like(logits, float('-inf'))
        routing_weights.scatter_(-1, top_k_indices, top_k_logits)
        routing_weights = F.softmax(routing_weights, dim=-1)
        
        return routing_weights, top_k_indices

# Test Noisy Router
print("=== Testing Noisy Top-K Router ===")
noisy_router = NoisyTopKRouter(n_embed, num_experts, top_k, noise_std=0.1)

# Test in training mode (with noise)
noisy_router.train()
routing_weights_train, _ = noisy_router(test_input)

# Test in evaluation mode (without noise)  
noisy_router.eval()
routing_weights_eval, _ = noisy_router(test_input)

print("Training mode routing weights (with noise):")
print(routing_weights_train.squeeze(0))

print("\nEvaluation mode routing weights (without noise):")
print(routing_weights_eval.squeeze(0))

print("\nDifference (shows effect of noise):")
print((routing_weights_train - routing_weights_eval).abs().max().item())

print("✓ Noisy router working correctly!")

=== Testing Noisy Top-K Router ===
Training mode routing weights (with noise):
tensor([[0.0000, 0.5040, 0.4960],
        [0.0000, 0.5704, 0.4296],
        [0.4679, 0.0000, 0.5321],
        [0.7194, 0.2806, 0.0000]], grad_fn=<SqueezeBackward1>)

Evaluation mode routing weights (without noise):
tensor([[0.0000, 0.5254, 0.4746],
        [0.0000, 0.6359, 0.3641],
        [0.4976, 0.0000, 0.5024],
        [0.7319, 0.2681, 0.0000]], grad_fn=<SqueezeBackward1>)

Difference (shows effect of noise):
0.06540459394454956
✓ Noisy router working correctly!


In [6]:
# Step 6: Implement Complete Sparse MoE Layer
class SparseMoE(nn.Module):
    """
    Complete Sparse Mixture of Experts layer
    Combines routing mechanism with expert networks
    """
    def __init__(self, n_embed, num_experts, top_k, dropout=0.1, use_noise=True):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.n_embed = n_embed
        
        # Create expert networks
        self.experts = nn.ModuleList([
            Expert(n_embed, dropout) for _ in range(num_experts)
        ])
        
        # Router selection
        if use_noise:
            self.router = NoisyTopKRouter(n_embed, num_experts, top_k)
        else:
            self.router = TopKRouter(n_embed, num_experts, top_k)
            
    def forward(self, x):
        # x shape: (batch_size, seq_len, n_embed)
        batch_size, seq_len, n_embed = x.shape
        
        # Step 1: Get routing weights and expert indices
        routing_weights, expert_indices = self.router(x)
        
        # Step 2: Process input through all experts
        # We process all tokens through all experts, then select outputs
        expert_outputs = []
        for expert in self.experts:
            expert_output = expert(x)  # Shape: (batch_size, seq_len, n_embed)
            expert_outputs.append(expert_output)
        
        # Stack expert outputs: (num_experts, batch_size, seq_len, n_embed)
        expert_outputs = torch.stack(expert_outputs, dim=0)
        
        # Step 3: Combine expert outputs using routing weights
        # Reshape for efficient computation
        x_flat = x.view(-1, n_embed)  # (batch_size * seq_len, n_embed)
        routing_weights_flat = routing_weights.view(-1, self.num_experts)  # (batch_size * seq_len, num_experts)
        
        # Initialize output
        final_output = torch.zeros_like(x_flat)
        
        # For each expert, add its weighted contribution
        for expert_idx in range(self.num_experts):
            # Get expert output for all tokens
            expert_output_flat = expert_outputs[expert_idx].view(-1, n_embed)
            
            # Get routing weights for this expert
            expert_weights = routing_weights_flat[:, expert_idx:expert_idx+1]  # (batch_size * seq_len, 1)
            
            # Add weighted expert output
            final_output += expert_weights * expert_output_flat
        
        # Reshape back to original dimensions
        final_output = final_output.view(batch_size, seq_len, n_embed)
        
        return final_output

# Test the complete MoE layer
print("=== Testing Complete Sparse MoE Layer ===")
moe_layer = SparseMoE(
    n_embed=8,
    num_experts=3, 
    top_k=2,
    dropout=0.1,
    use_noise=True
)

# Test forward pass
test_input = torch.randn(1, 4, 8)
moe_output = moe_layer(test_input)

print(f"Input shape: {test_input.shape}")
print(f"MoE output shape: {moe_output.shape}")
print(f"MoE parameters: {sum(p.numel() for p in moe_layer.parameters()):,}")

# Verify output shape matches input shape
assert test_input.shape == moe_output.shape, "MoE should maintain input shape"

print(f"\nInput sample:")
print(test_input.squeeze(0))

print(f"\nMoE output sample:")
print(moe_output.squeeze(0))

print("✓ Sparse MoE layer working correctly!")

=== Testing Complete Sparse MoE Layer ===
Input shape: torch.Size([1, 4, 8])
MoE output shape: torch.Size([1, 4, 8])
MoE parameters: 1,560

Input sample:
tensor([[ 2.2792, -0.3402, -0.7501,  0.2942,  0.7626,  2.6536,  0.4730,  0.0147],
        [-0.1513,  1.5333,  1.0515, -0.4613,  2.0802,  0.8309, -0.8416, -0.1644],
        [-1.0877,  0.0698, -1.1470, -0.5624, -0.1978,  0.8101, -1.0031,  0.6105],
        [ 1.2420,  0.5707, -0.0135, -1.0993,  1.3919,  0.9944,  0.5453,  0.1693]])

MoE output sample:
tensor([[-0.2407,  0.0637, -0.0806,  0.1530,  0.1621, -0.2777,  0.0332,  0.2249],
        [ 0.2400, -0.3806,  0.0677, -0.2340, -0.1102, -0.2391, -0.0081,  0.1539],
        [-0.1661, -0.0328,  0.1538, -0.0882,  0.0367,  0.1104,  0.0214, -0.0748],
        [-0.1191, -0.1103, -0.2469,  0.0279, -0.0493, -0.3956, -0.0544,  0.3399]],
       grad_fn=<SqueezeBackward1>)
✓ Sparse MoE layer working correctly!


In [7]:
# Step 7: Implement Attention Mechanisms
class Head(nn.Module):
    """Single attention head for multi-head attention"""
    def __init__(self, n_embed, head_size, block_size, dropout=0.1):
        super().__init__()
        self.head_size = head_size
        self.block_size = block_size
        
        # Key, Query, Value projections
        self.key = nn.Linear(n_embed, head_size, bias=False)
        self.query = nn.Linear(n_embed, head_size, bias=False) 
        self.value = nn.Linear(n_embed, head_size, bias=False)
        
        # Causal mask (lower triangular)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        B, T, C = x.shape  # batch, time, channels
        
        # Get key, query, value
        k = self.key(x)   # (B, T, head_size)
        q = self.query(x) # (B, T, head_size)
        v = self.value(x) # (B, T, head_size)
        
        # Scaled dot-product attention
        # Compute attention scores
        wei = q @ k.transpose(-2, -1) * (self.head_size ** -0.5)  # (B, T, T)
        
        # Apply causal mask (prevent looking at future tokens)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        
        # Apply softmax
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)
        
        # Apply attention to values
        out = wei @ v  # (B, T, head_size)
        return out

class MultiHeadAttention(nn.Module):
    """Multi-head attention mechanism"""
    def __init__(self, n_embed, n_head, block_size, dropout=0.1):
        super().__init__()
        assert n_embed % n_head == 0, "n_embed must be divisible by n_head"
        
        head_size = n_embed // n_head
        self.heads = nn.ModuleList([
            Head(n_embed, head_size, block_size, dropout) 
            for _ in range(n_head)
        ])
        
        # Output projection
        self.proj = nn.Linear(n_embed, n_embed)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        # Concatenate outputs from all heads
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        
        # Apply output projection and dropout
        out = self.dropout(self.proj(out))
        return out

# Test attention mechanisms
print("=== Testing Attention Mechanisms ===")
n_embed = 64
n_head = 8
block_size = 32
batch_size = 2
seq_len = 16

# Test single head
head = Head(n_embed, n_embed // n_head, block_size)
test_input = torch.randn(batch_size, seq_len, n_embed)
head_output = head(test_input)

print(f"Single head input shape: {test_input.shape}")
print(f"Single head output shape: {head_output.shape}")

# Test multi-head attention
mha = MultiHeadAttention(n_embed, n_head, block_size)
mha_output = mha(test_input)

print(f"Multi-head attention output shape: {mha_output.shape}")
print(f"MHA parameters: {sum(p.numel() for p in mha.parameters()):,}")

# Verify shape preservation
assert test_input.shape == mha_output.shape, "MHA should preserve input shape"
print("✓ Attention mechanisms working correctly!")

=== Testing Attention Mechanisms ===
Single head input shape: torch.Size([2, 16, 64])
Single head output shape: torch.Size([2, 16, 8])
Multi-head attention output shape: torch.Size([2, 16, 64])
MHA parameters: 16,448
✓ Attention mechanisms working correctly!


In [8]:
# Step 8: Implement Transformer Block with MoE
class TransformerBlock(nn.Module):
    """
    Transformer block that replaces FFN with MoE layer
    Architecture: LayerNorm -> MultiHeadAttention -> LayerNorm -> MoE -> Residual connections
    """
    def __init__(self, n_embed, n_head, block_size, num_experts, top_k, dropout=0.1):
        super().__init__()
        
        # Multi-head attention
        self.sa = MultiHeadAttention(n_embed, n_head, block_size, dropout)
        
        # MoE layer (replaces traditional FFN)
        self.moe = SparseMoE(n_embed, num_experts, top_k, dropout, use_noise=True)
        
        # Layer normalizations
        self.ln1 = nn.LayerNorm(n_embed)
        self.ln2 = nn.LayerNorm(n_embed)
        
    def forward(self, x):
        # First sub-layer: Multi-head attention with residual connection
        # Pre-norm: LayerNorm -> Attention -> Residual
        x = x + self.sa(self.ln1(x))
        
        # Second sub-layer: MoE with residual connection  
        # Pre-norm: LayerNorm -> MoE -> Residual
        x = x + self.moe(self.ln2(x))
        
        return x

# Test transformer block
print("=== Testing Transformer Block with MoE ===")
transformer_block = TransformerBlock(
    n_embed=64,
    n_head=8, 
    block_size=32,
    num_experts=8,
    top_k=2,
    dropout=0.1
)

# Test forward pass
test_input = torch.randn(2, 16, 64)  # (batch, seq_len, embed_dim)
block_output = transformer_block(test_input)

print(f"Input shape: {test_input.shape}")
print(f"Block output shape: {block_output.shape}")
print(f"Block parameters: {sum(p.numel() for p in transformer_block.parameters()):,}")

# Verify shape preservation
assert test_input.shape == block_output.shape, "Transformer block should preserve shape"

# Check for gradient flow
loss = block_output.sum()
loss.backward()

gradients_exist = any(p.grad is not None for p in transformer_block.parameters())
print(f"Gradients computed: {gradients_exist}")

print("✓ Transformer block with MoE working correctly!")

=== Testing Transformer Block with MoE ===
Input shape: torch.Size([2, 16, 64])
Block output shape: torch.Size([2, 16, 64])
Block parameters: 279,360
Gradients computed: True
✓ Transformer block with MoE working correctly!


In [9]:
# Step 9: Complete MoE Language Model
class MoELanguageModel(nn.Module):
    """
    Complete language model with Mixture of Experts
    """
    def __init__(self, vocab_size, n_embed, n_head, n_layer, block_size, 
                 num_experts, top_k, dropout=0.1):
        super().__init__()
        self.block_size = block_size
        self.vocab_size = vocab_size
        
        # Input embeddings
        self.token_embedding_table = nn.Embedding(vocab_size, n_embed)
        self.position_embedding_table = nn.Embedding(block_size, n_embed)
        
        # Transformer blocks with MoE
        self.blocks = nn.Sequential(*[
            TransformerBlock(n_embed, n_head, block_size, num_experts, top_k, dropout)
            for _ in range(n_layer)
        ])
        
        # Output layers
        self.ln_f = nn.LayerNorm(n_embed)  # Final layer norm
        self.lm_head = nn.Linear(n_embed, vocab_size)  # Language modeling head
        
        # Initialize weights
        self.apply(self._init_weights)
        
    def _init_weights(self, module):
        """Initialize weights using Xavier/Glorot initialization"""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
    
    def forward(self, idx, targets=None):
        B, T = idx.shape
        
        # Input embeddings
        tok_emb = self.token_embedding_table(idx)  # (B, T, n_embed)
        pos_emb = self.position_embedding_table(torch.arange(T, device=idx.device))  # (T, n_embed)
        
        # Combine token and position embeddings
        x = tok_emb + pos_emb  # (B, T, n_embed)
        
        # Process through transformer blocks
        x = self.blocks(x)  # (B, T, n_embed)
        
        # Final layer norm and output projection
        x = self.ln_f(x)  # (B, T, n_embed)
        logits = self.lm_head(x)  # (B, T, vocab_size)
        
        # Calculate loss if targets are provided
        if targets is not None:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        else:
            loss = None
            
        return logits, loss
    
    def generate(self, idx, max_new_tokens, temperature=1.0):
        """Generate new tokens using the trained model"""
        for _ in range(max_new_tokens):
            # Crop context to block_size
            idx_cond = idx[:, -self.block_size:]
            
            # Get predictions
            logits, _ = self(idx_cond)
            
            # Focus on the last time step and apply temperature
            logits = logits[:, -1, :] / temperature  # (B, C)
            
            # Sample from distribution
            probs = F.softmax(logits, dim=-1)  # (B, C)
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            
            # Append to sequence
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
            
        return idx

# Test the complete model
print("=== Testing Complete MoE Language Model ===")

# Model hyperparameters
model_config = {
    'vocab_size': vocab_size,  # From dataset preparation
    'n_embed': 64,
    'n_head': 8,
    'n_layer': 4,
    'block_size': 32,
    'num_experts': 8,
    'top_k': 2,
    'dropout': 0.1
}

# Create model
model = MoELanguageModel(**model_config)

# Move to device
model = model.to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
print(f"Total model parameters: {total_params:,}")

# Test forward pass
batch_size = 4
seq_len = 16
test_idx = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
test_targets = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)

with torch.no_grad():
    logits, loss = model(test_idx, test_targets)

print(f"Input shape: {test_idx.shape}")
print(f"Logits shape: {logits.shape}")
print(f"Loss: {loss.item():.4f}")

# Test generation
print("\n=== Testing Text Generation ===")
with torch.no_grad():
    # Start with a random token
    start_idx = torch.randint(0, vocab_size, (1, 1), device=device)
    generated = model.generate(start_idx, max_new_tokens=20)
    generated_text = decode(generated[0].cpu().tolist())
    print(f"Generated text: '{generated_text}'")

print("✓ Complete MoE Language Model working correctly!")

=== Testing Complete MoE Language Model ===
Total model parameters: 1,128,001
Input shape: torch.Size([4, 16])
Logits shape: torch.Size([64, 65])
Loss: 4.1748

=== Testing Text Generation ===
Generated text: 'b:RdQhkU!mmAOLa,&gbhW'
✓ Complete MoE Language Model working correctly!
Generated text: 'b:RdQhkU!mmAOLa,&gbhW'
✓ Complete MoE Language Model working correctly!


In [10]:
# Step 10: Data Preparation and Batch Creation
def create_data_split(text, train_ratio=0.9):
    """Split data into training and validation sets"""
    data = torch.tensor(encode(text), dtype=torch.long)
    n = int(train_ratio * len(data))
    train_data = data[:n]
    val_data = data[n:]
    return train_data, val_data

def get_batch(data, batch_size, block_size, device):
    """Generate a batch of input-target pairs"""
    # Randomly select starting positions
    ix = torch.randint(len(data) - block_size, (batch_size,))
    
    # Create input sequences
    x = torch.stack([data[i:i+block_size] for i in ix])
    
    # Create target sequences (shifted by 1)
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    
    return x.to(device), y.to(device)

@torch.no_grad()
def estimate_loss(model, train_data, val_data, batch_size, block_size, eval_iters=100):
    """Estimate training and validation loss"""
    out = {}
    model.eval()
    
    for split, data in [('train', train_data), ('val', val_data)]:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(data, batch_size, block_size, device)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    
    model.train()
    return out

# Prepare data
print("=== Preparing Training Data ===")
train_data, val_data = create_data_split(text, train_ratio=0.9)

print(f"Training data size: {len(train_data):,} tokens")
print(f"Validation data size: {len(val_data):,} tokens")

# Test batch creation
batch_size = 4
block_size = 32

X_batch, Y_batch = get_batch(train_data, batch_size, block_size, device)
print(f"Batch input shape: {X_batch.shape}")
print(f"Batch target shape: {Y_batch.shape}")

# Show example input-target pair
print(f"\nExample input sequence:")
print(f"'{decode(X_batch[0].cpu().tolist())}'")
print(f"\nCorresponding target sequence:")
print(f"'{decode(Y_batch[0].cpu().tolist())}'")

# Verify that target is input shifted by 1
print("\nVerifying input-target relationship:")
for i in range(min(10, block_size)):
    input_char = decode([X_batch[0][i].item()])
    target_char = decode([Y_batch[0][i].item()])
    next_input_char = decode([X_batch[0][i+1].item()]) if i+1 < block_size else "END"
    print(f"Position {i}: input='{input_char}' -> target='{target_char}' (matches next input='{next_input_char}')")

print("✓ Data preparation working correctly!")

=== Preparing Training Data ===
Training data size: 1,003,854 tokens
Validation data size: 111,540 tokens
Batch input shape: torch.Size([4, 32])
Batch target shape: torch.Size([4, 32])

Example input sequence:
'
What is't, knave?

Servant:
An '

Corresponding target sequence:
'What is't, knave?

Servant:
An h'

Verifying input-target relationship:
Position 0: input='
' -> target='W' (matches next input='W')
Position 1: input='W' -> target='h' (matches next input='h')
Position 2: input='h' -> target='a' (matches next input='a')
Position 3: input='a' -> target='t' (matches next input='t')
Position 4: input='t' -> target=' ' (matches next input=' ')
Position 5: input=' ' -> target='i' (matches next input='i')
Position 6: input='i' -> target='s' (matches next input='s')
Position 7: input='s' -> target=''' (matches next input=''')
Position 8: input=''' -> target='t' (matches next input='t')
Position 9: input='t' -> target=',' (matches next input=',')
✓ Data preparation working correctly!


In [11]:
# Step 11: Training Configuration and Hyperparameters
@dataclass
class TrainingConfig:
    # Model hyperparameters
    vocab_size: int = vocab_size
    n_embed: int = 64
    n_head: int = 8
    n_layer: int = 4
    block_size: int = 32
    num_experts: int = 8
    top_k: int = 2
    dropout: float = 0.1
    
    # Training hyperparameters
    batch_size: int = 16
    learning_rate: float = 1e-3
    max_iters: int = 1000  # Increase to 50000+ for better results
    eval_interval: int = 100
    eval_iters: int = 50
    
    # Hardware
    device: str = device

# Create training configuration
config = TrainingConfig()

print("=== Training Configuration ===")
print(f"Model Architecture:")
print(f"  - Vocabulary size: {config.vocab_size}")
print(f"  - Embedding dimension: {config.n_embed}")
print(f"  - Number of heads: {config.n_head}")
print(f"  - Number of layers: {config.n_layer}")
print(f"  - Block size: {config.block_size}")
print(f"  - Number of experts: {config.num_experts}")
print(f"  - Top-K experts: {config.top_k}")
print(f"  - Dropout: {config.dropout}")

print(f"\nTraining Parameters:")
print(f"  - Batch size: {config.batch_size}")
print(f"  - Learning rate: {config.learning_rate}")
print(f"  - Max iterations: {config.max_iters}")
print(f"  - Evaluation interval: {config.eval_interval}")
print(f"  - Device: {config.device}")

# Create and initialize model
model = MoELanguageModel(
    vocab_size=config.vocab_size,
    n_embed=config.n_embed,
    n_head=config.n_head,
    n_layer=config.n_layer,
    block_size=config.block_size,
    num_experts=config.num_experts,
    top_k=config.top_k,
    dropout=config.dropout
).to(config.device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
print(f"\nModel Statistics:")
print(f"  - Total parameters: {total_params:,}")
print(f"  - Model size: {total_params * 4 / 1e6:.1f} MB (float32)")

# Create optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)

print("✓ Training configuration ready!")

=== Training Configuration ===
Model Architecture:
  - Vocabulary size: 65
  - Embedding dimension: 64
  - Number of heads: 8
  - Number of layers: 4
  - Block size: 32
  - Number of experts: 8
  - Top-K experts: 2
  - Dropout: 0.1

Training Parameters:
  - Batch size: 16
  - Learning rate: 0.001
  - Max iterations: 1000
  - Evaluation interval: 100
  - Device: cpu

Model Statistics:
  - Total parameters: 1,128,001
  - Model size: 4.5 MB (float32)
✓ Training configuration ready!
✓ Training configuration ready!


In [12]:
# Step 12: Training Loop
import time

def train_model(model, train_data, val_data, config, optimizer):
    """Complete training loop for MoE model"""
    
    print("=== Starting Training ===")
    start_time = time.time()
    
    # Training metrics
    train_losses = []
    val_losses = []
    
    for iter_num in range(config.max_iters):
        # Evaluate model periodically
        if iter_num % config.eval_interval == 0 or iter_num == config.max_iters - 1:
            losses = estimate_loss(
                model, train_data, val_data, 
                config.batch_size, config.block_size, config.eval_iters
            )
            
            elapsed_time = time.time() - start_time
            print(f"Step {iter_num:4d} | "
                  f"Train Loss: {losses['train']:.4f} | "
                  f"Val Loss: {losses['val']:.4f} | "
                  f"Time: {elapsed_time:.1f}s")
            
            train_losses.append(losses['train'])
            val_losses.append(losses['val'])
        
        # Get training batch
        xb, yb = get_batch(train_data, config.batch_size, config.block_size, config.device)
        
        # Forward pass
        logits, loss = model(xb, yb)
        
        # Backward pass
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
    
    total_time = time.time() - start_time
    print(f"\nTraining completed in {total_time:.1f}s")
    print(f"Final train loss: {train_losses[-1]:.4f}")
    print(f"Final val loss: {val_losses[-1]:.4f}")
    
    return train_losses, val_losses

# Start training
print("🚀 Starting MoE training...")
print(f"Training for {config.max_iters} iterations...")

train_losses, val_losses = train_model(model, train_data, val_data, config, optimizer)

print("✅ Training completed successfully!")

🚀 Starting MoE training...
Training for 1000 iterations...
=== Starting Training ===
Step    0 | Train Loss: 4.1660 | Val Loss: 4.1661 | Time: 3.3s
Step    0 | Train Loss: 4.1660 | Val Loss: 4.1661 | Time: 3.3s
Step  100 | Train Loss: 2.6180 | Val Loss: 2.6337 | Time: 21.0s
Step  100 | Train Loss: 2.6180 | Val Loss: 2.6337 | Time: 21.0s
Step  200 | Train Loss: 2.4624 | Val Loss: 2.4869 | Time: 38.5s
Step  200 | Train Loss: 2.4624 | Val Loss: 2.4869 | Time: 38.5s
Step  300 | Train Loss: 2.4151 | Val Loss: 2.4216 | Time: 76.6s
Step  300 | Train Loss: 2.4151 | Val Loss: 2.4216 | Time: 76.6s
Step  400 | Train Loss: 2.3164 | Val Loss: 2.3282 | Time: 123.0s
Step  400 | Train Loss: 2.3164 | Val Loss: 2.3282 | Time: 123.0s
Step  500 | Train Loss: 2.2516 | Val Loss: 2.2752 | Time: 196.5s
Step  500 | Train Loss: 2.2516 | Val Loss: 2.2752 | Time: 196.5s
Step  600 | Train Loss: 2.1812 | Val Loss: 2.2254 | Time: 277.0s
Step  600 | Train Loss: 2.1812 | Val Loss: 2.2254 | Time: 277.0s
Step  700 | Tra

In [13]:
# Step 13: Inference and Text Generation
def generate_text(model, prompt="", max_length=100, temperature=1.0, top_k=None):
    """Generate text using the trained MoE model"""
    model.eval()
    
    # Encode prompt or start with random token
    if prompt:
        context = torch.tensor(encode(prompt), dtype=torch.long, device=device).unsqueeze(0)
    else:
        # Start with a random character
        context = torch.randint(0, vocab_size, (1, 1), device=device)
    
    # Generate text
    with torch.no_grad():
        generated = model.generate(context, max_length, temperature)
    
    # Decode and return
    generated_text = decode(generated[0].cpu().tolist())
    model.train()
    return generated_text

# Test text generation with different settings
print("=== Testing Text Generation ===")

# Test 1: Random start, low temperature (more focused)
print("1. Random start, low temperature (focused):")
text1 = generate_text(model, prompt="", max_length=150, temperature=0.8)
print(f"Generated: '{text1}'")

print("\n" + "="*50)

# Test 2: Random start, high temperature (more creative)
print("2. Random start, high temperature (creative):")
text2 = generate_text(model, prompt="", max_length=150, temperature=1.2)
print(f"Generated: '{text2}'")

print("\n" + "="*50)

# Test 3: With prompt
print("3. With prompt:")
prompt = "ROMEO:"
text3 = generate_text(model, prompt=prompt, max_length=100, temperature=1.0)
print(f"Prompt: '{prompt}'")
print(f"Generated: '{text3}'")

print("\n" + "="*50)

# Test 4: Another character
print("4. Different character:")
prompt = "JULIET:"
text4 = generate_text(model, prompt=prompt, max_length=100, temperature=1.0)
print(f"Prompt: '{prompt}'")
print(f"Generated: '{text4}'")

print("\n✅ Text generation completed!")

=== Testing Text Generation ===
1. Random start, low temperature (focused):
Generated: 'for the duve will my not lods not! Yord these of to the seee!


FOpet bee the the beor of bear and the theve mugned andle in and that's I mim not that '

2. Random start, high temperature (creative):
Generated: 'for the duve will my not lods not! Yord these of to the seee!


FOpet bee the the beor of bear and the theve mugned andle in and that's I mim not that '

2. Random start, high temperature (creative):
Generated: 'do-radelfuf thuu, siousy ow ace SI me alle ow
nlibevh y thybh shimentle there thard foay with 'll Sexpus'sse hall'd frive'd go earterh
The hy livins, m'

3. With prompt:
Generated: 'do-radelfuf thuu, siousy ow ace SI me alle ow
nlibevh y thybh shimentle there thard foay with 'll Sexpus'sse hall'd frive'd go earterh
The hy livins, m'

3. With prompt:
Prompt: 'ROMEO:'
Generated: 'ROMEO:
Comade hencegod a then:
We cuip the be be thard a hour of inising and the wom mese,ry weeet
My to d'

In [None]:
# Step 14: Model Analysis and Expert Utilization
def analyze_expert_usage(model, data, num_batches=10):
    """Analyze how experts are being utilized"""
    model.eval()
    expert_usage = {}
    
    with torch.no_grad():
        for batch_idx in range(num_batches):
            # Get batch
            x, _ = get_batch(data, config.batch_size, config.block_size, config.device)
            
            # Ensure x is on the correct device and dtype
            x = x.to(config.device)
            
            # Track expert usage in each MoE layer
            for layer_idx, block in enumerate(model.blocks):
                if hasattr(block, 'moe'):
                    # Apply layer norm first, ensuring proper device/dtype
                    normalized_x = block.ln2(x)
                    
                    # Get routing decisions
                    routing_weights, expert_indices = block.moe.router(normalized_x)
                    
                    # Count expert usage
                    for expert_id in range(config.num_experts):
                        expert_mask = (expert_indices == expert_id)
                        usage_count = expert_mask.sum().item()
                        
                        key = f"layer_{layer_idx}_expert_{expert_id}"
                        if key not in expert_usage:
                            expert_usage[key] = 0
                        expert_usage[key] += usage_count
    
    model.train()
    return expert_usage

def print_expert_analysis(expert_usage, config):
    """Print expert utilization analysis"""
    print("=== Expert Utilization Analysis ===")
    
    for layer_idx in range(config.n_layer):
        print(f"\nLayer {layer_idx}:")
        layer_usage = []
        
        for expert_id in range(config.num_experts):
            key = f"layer_{layer_idx}_expert_{expert_id}"
            usage = expert_usage.get(key, 0)
            layer_usage.append(usage)
            print(f"  Expert {expert_id}: {usage:4d} tokens")
        
        # Calculate balance metrics
        total_usage = sum(layer_usage)
        if total_usage > 0:
            usage_percentages = [u/total_usage*100 for u in layer_usage]
            std_dev = np.std(usage_percentages)
            print(f"  Balance (std dev): {std_dev:.2f}% (lower is better)")
            print(f"  Most used expert: {np.argmax(layer_usage)} ({max(usage_percentages):.1f}%)")
            print(f"  Least used expert: {np.argmin(layer_usage)} ({min(usage_percentages):.1f}%)")

# Analyze expert usage
print("Analyzing expert utilization...")
expert_usage = analyze_expert_usage(model, val_data, num_batches=20)
print_expert_analysis(expert_usage, config)

# Model statistics
print(f"\n=== Model Statistics ===")
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")

# Parameters by component
attention_params = sum(p.numel() for block in model.blocks for p in block.sa.parameters())
moe_params = sum(p.numel() for block in model.blocks for p in block.moe.parameters())
embedding_params = sum(p.numel() for p in [model.token_embedding_table, model.position_embedding_table])
output_params = sum(p.numel() for p in [model.ln_f, model.lm_head])

print(f"Attention parameters: {attention_params:,} ({attention_params/sum(p.numel() for p in model.parameters())*100:.1f}%)")
print(f"MoE parameters: {moe_params:,} ({moe_params/sum(p.numel() for p in model.parameters())*100:.1f}%)")
print(f"Embedding parameters: {embedding_params:,} ({embedding_params/sum(p.numel() for p in model.parameters())*100:.1f}%)")
print(f"Output parameters: {output_params:,} ({output_params/sum(p.numel() for p in model.parameters())*100:.1f}%)")

# Active parameters (considering sparsity)
active_experts_per_token = config.top_k
expert_params_per_layer = moe_params // (config.n_layer * config.num_experts)
active_moe_params = expert_params_per_layer * active_experts_per_token * config.n_layer
active_total = attention_params + active_moe_params + embedding_params + output_params

print(f"\nActive parameters per forward pass:")
print(f"Active MoE parameters: {active_moe_params:,}")
print(f"Total active parameters: {active_total:,}")
print(f"Sparsity ratio: {active_total/sum(p.numel() for p in model.parameters())*100:.1f}%")

print("✅ Model analysis completed!")

Analyzing expert utilization...


RuntimeError: mixed dtype (CPU): all inputs must share same datatype.

# Step 15: Exercises and Extensions

## 🚀 Congratulations!
You've successfully implemented a complete Mixture of Experts model from scratch! This implementation includes:

✅ **Expert Networks**: Individual feed-forward neural networks  
✅ **Router Mechanism**: Top-K routing with load balancing  
✅ **Sparse MoE Layer**: Efficient expert combination  
✅ **Complete Transformer**: MoE-enabled transformer architecture  
✅ **Training Pipeline**: Full pre-training on Shakespeare dataset  
✅ **Inference System**: Text generation with trained model  
✅ **Analysis Tools**: Expert utilization metrics  

## 🔬 Exercises to Try

### 1. **Capacity Factor Implementation**
Add expert capacity constraints to prevent overloading:
```python
# Implement capacity factor to limit tokens per expert
def add_capacity_factor(self, capacity_factor=1.25):
    # Calculate expert capacity based on total tokens and capacity factor
    pass
```

### 2. **DeepSeek Auxiliary Loss-Free Load Balancing**
Implement DeepSeek's bias-based load balancing:
```python
# Add dynamic bias terms to router
class DeepSeekRouter(nn.Module):
    def __init__(self, n_embed, num_experts, top_k, update_rate=0.1):
        # Initialize bias terms
        self.bias_terms = nn.Parameter(torch.zeros(num_experts))
        # Implement bias update logic
        pass
```

### 3. **Shared Experts Architecture**
Add DeepSeek's shared experts alongside routed experts:
```python
# Implement shared + routed experts
class SharedMoE(nn.Module):
    def __init__(self, n_embed, num_shared, num_routed, top_k):
        self.shared_experts = nn.ModuleList([...])  # Always active
        self.routed_experts = nn.ModuleList([...])  # Selectively active
        pass
```

### 4. **Fine-Grained Expert Segmentation**
Implement more experts with smaller dimensions:
```python
# Increase expert count while maintaining parameter count
def create_segmented_experts(base_experts, segmentation_factor):
    # Split experts into smaller specialized experts
    pass
```

### 5. **Load Balancing Loss**
Add traditional auxiliary loss for comparison:
```python
# Implement Fi * Pi load balancing loss
def calculate_load_balance_loss(routing_weights, expert_assignments):
    # Calculate expert importance and token fractions
    # Minimize Fi * Pi for better balance
    pass
```

## 🎯 Advanced Challenges

### Performance Optimization
- **Memory Efficiency**: Implement expert parallelization
- **Speed Optimization**: Add CUDA kernels for routing
- **Dynamic Routing**: Implement learned routing strategies

### Architecture Innovations
- **Hierarchical Experts**: Multi-level expert organization
- **Mixture of Depths**: Variable computation per token
- **Adaptive Sparsity**: Dynamic top-k selection

### Training Improvements
- **Curriculum Learning**: Progressive expert activation
- **Expert Specialization**: Guided expert training
- **Regularization**: Novel techniques for expert diversity

## 📊 Experiment Ideas

1. **Compare Architectures**: Dense vs Sparse MoE performance
2. **Scaling Study**: Effect of expert count on quality
3. **Load Balancing**: Compare different balancing strategies
4. **Domain Adaptation**: Train experts on different text types
5. **Efficiency Analysis**: Measure computational savings

## 🔧 Production Considerations

When scaling this implementation:
- **Distributed Training**: Multi-GPU expert placement
- **Inference Optimization**: Expert caching strategies
- **Model Serving**: Efficient expert loading
- **Monitoring**: Expert utilization tracking

## 📚 Further Reading

- **Original MoE Paper**: "Outrageously Large Neural Networks" (Shazeer et al.)
- **Switch Transformer**: Improved MoE scaling (Fedus et al.)
- **DeepSeek Papers**: V2 and V3 innovations
- **GLaM**: Efficient MoE training (Du et al.)

## 🎉 Next Steps

You now have a solid foundation in MoE architectures! Use this implementation to:
- Research new routing mechanisms
- Experiment with expert specialization
- Scale to larger models and datasets
- Contribute to the open-source community

**Happy experimenting with Mixture of Experts!** 🚀