# Multi-Token Prediction (MTP) Implementation from Scratch

## Overview
This notebook implements DeepSeek's Multi-Token Prediction (MTP) mechanism from scratch. MTP is one of three key innovations in the DeepSeek V3 architecture, alongside Multi-Head Latent Attention (MLA) and Mixture of Experts (MoE).

## What We'll Build
1. **RMS Normalization Class**: For normalizing hidden states and embeddings
2. **Multi-Token Prediction Class**: Complete MTP implementation with causal chains
3. **Forward Pass**: Generate multiple future tokens for each input token
4. **Loss Calculation**: Compute loss between predicted and target tokens

## Key Concepts
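- **Prediction Depth (k)**: Number of future tokens to predict (e.g., k=3 means predicting 3 tokens ahead). In the code this is the `num_heads` argument, and `k` is the zero-based loop index over depths
- **Causal Chain**: Each prediction head uses the hidden state from the previous head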
- **Input Requirements**: Hidden state + input embedding for each prediction head
- **Output**: Multi-dimensional tensor containing predictions for all input positions and depths
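
To make the index bookkeeping concrete, here is a small sketch in plain Python of which future positions each input position covers. It assumes the zero-based convention used in the implementation below (depth index `k` looks `k + 1` tokens ahead); the helper name `mtp_index_map` is made up for illustration.

```python
def mtp_index_map(seq_len, depth):
    """Map each usable input position i to the future positions it covers.

    Assumes depth index k (zero-based) refers to position i + k + 1,
    so only positions with `depth` tokens after them are usable.
    """
    max_i = seq_len - depth  # number of usable input positions
    return {i: [i + k + 1 for k in range(depth)] for i in range(max_i)}


index_map = mtp_index_map(seq_len=8, depth=3)
for i, future in index_map.items():
    print(f"i={i} -> future positions {future}")
```

With a sequence of 8 tokens and depth 3, only positions 0 through 4 are usable, because position 5 and later do not have three tokens after them.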

Let's implement this step by step!

In [None]:
# Step 0: Import Required Packages
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# Set random seed for reproducibility
torch.manual_seed(42)

# Check device availability
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

print("✅ Packages imported successfully!")

In [None]:
# Step 1: Define RMS Normalization Class
class RMSNorm(nn.Module):
    """
    Root Mean Square Normalization
    
    RMS normalization formula: x / sqrt(mean(x²) + ε)
    This differs from LayerNorm which uses: (x - mean(x)) / sqrt(var(x) + ε)
    """
    def __init__(self, d_model, eps=1e-8):
        super().__init__()
        self.eps = eps
        self.d_model = d_model
        
    def forward(self, x):
        # Calculate RMS: sqrt(mean(x²))
        # x shape: (..., d_model)
        
        # Step 1: Square all elements
        x_squared = x ** 2
        
        # Step 2: Take mean along the last dimension (d_model)
        mean_squared = x_squared.mean(dim=-1, keepdim=True)
        
        # Step 3: Add epsilon for numerical stability, then take the square root
        rms = torch.sqrt(mean_squared + self.eps)
        
        # Step 4: Normalize by dividing each element by RMS
        return x / rms

# Test RMS Normalization
print("=== Testing RMS Normalization ===")
d_model = 8
rms_norm = RMSNorm(d_model)

# Create test tensor
test_tensor = torch.randn(2, 4, d_model)  # (batch_size, seq_len, d_model)
print(f"Input shape: {test_tensor.shape}")
print(f"Input sample:\n{test_tensor[0, 0, :]}")

# Apply RMS normalization
normalized = rms_norm(test_tensor)
print(f"Output shape: {normalized.shape}")
print(f"Output sample:\n{normalized[0, 0, :]}")

# Verify normalization (RMS should be approximately 1.0)
rms_value = torch.sqrt((normalized[0, 0, :] ** 2).mean())
print(f"RMS of normalized tensor: {rms_value:.6f} (should be ≈ 1.0)")

print("✅ RMS Normalization working correctly!")
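
The `RMSNorm` class above has no trainable parameters. Many RMSNorm implementations also multiply the normalized output by a learnable per-feature gain. The sketch below shows that variant; the class name `RMSNormWithGain` and its `weight` attribute are assumptions for illustration and are not used elsewhere in this notebook.

```python
import torch
import torch.nn as nn


class RMSNormWithGain(nn.Module):
    """RMS normalization followed by a learnable per-feature scale.

    Hypothetical variant for illustration; the notebook's RMSNorm has no parameters.
    """
    def __init__(self, d_model, eps=1e-8):
        super().__init__()
        self.eps = eps
        # Gain starts at ones, so the layer initially behaves like plain RMS normalization
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.weight * (x / rms)


norm = RMSNormWithGain(d_model=8)
out = norm(torch.randn(2, 4, 8))
rms_after = out.pow(2).mean(dim=-1).sqrt()  # per-position RMS, shape (2, 4)
print(rms_after)
```

At initialization the per-position RMS should be close to 1.0, as in the parameter-free version; during training the gain lets each feature be rescaled.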

In [None]:
# Step 2: Multi-Token Prediction Class (Main Implementation)
class SimpleMTP(nn.Module):
    """
    Simple Multi-Token Prediction implementation based on DeepSeek's approach
    
    Key Features:
    - Predicts multiple tokens (depth k) for each input position
    - Maintains causal chain between prediction heads
    - Uses RMS normalization before merging hidden states and embeddings
    """
    def __init__(self, d_model, vocab_size, num_heads, nhead=8, dim_feedforward=2048):
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.num_heads = num_heads  # This is prediction depth, not attention heads
        
        # RMS normalization (parameter-free, so one instance serves every depth)
        self.rms_norm = RMSNorm(d_model)
        
        # Projection layer: (2*d_model) -> d_model, shared across all depths
        # We concatenate hidden_state (d_model) + input_embedding (d_model) = 2*d_model
        self.projections = nn.Linear(2 * d_model, d_model)
        
        # Transformer encoder layer for processing merged embeddings
        self.transformer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            batch_first=True,
            dropout=0.1
        )
        
        # Shared unembedding matrix: d_model -> vocab_size
        self.unembedding = nn.Linear(d_model, vocab_size)
        
        # Token embeddings
        self.embeddings = nn.Embedding(vocab_size, d_model)
        
    def forward(self, input_tokens):
        """
        Forward pass for Multi-Token Prediction
        
        Args:
            input_tokens: (batch_size, seq_len) - Input token indices
            
        Returns:
            logits: (batch_size, max_i, num_heads, vocab_size) - Predictions for each position and depth
        """
        batch_size, seq_len = input_tokens.shape
        
        # Get token embeddings
        embeds = self.embeddings(input_tokens)  # (batch_size, seq_len, d_model)
        
        # Calculate maximum input position we can predict from
        # We need 'num_heads' future positions to exist
        max_i = seq_len - self.num_heads
        
        # Initialize output list to store predictions
        all_predictions = []
        
        # Outer loop: Iterate over input token positions (i = 0, 1, 2, ..., max_i-1)
        for i in range(max_i):
            
            # Initialize h_previous for this input position
            # h_previous starts as the embedding of the current token
            h_previous = embeds[:, i:i+1, :]  # (batch_size, 1, d_model)
            
            # Store predictions for this input position across all depths
            position_predictions = []
            
            # Inner loop: Iterate over prediction depths (k = 0, 1, 2, ..., num_heads-1)
            for k in range(self.num_heads):
                
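                # Future position used at this depth: the zero-based depth index k
                # looks k + 1 tokens ahead of input position i
                future_pos = i + k + 1
                
                # Get input embedding at the future position.
                # Note: compute_mtp_loss scores depth k against this same position, so the
                # head is given the embedding of the token it is trained to predict.
                # In DeepSeek-V3, the module that receives this embedding predicts the
                # token one position later.
                token_embedding = embeds[:, future_pos:future_pos+1, :]  # (batch_size, 1, d_model)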
                
                # === HEAD OPERATIONS ===
                
                # Step 1: RMS Normalization of both inputs
                h_norm = self.rms_norm(h_previous)      # (batch_size, 1, d_model)
                e_norm = self.rms_norm(token_embedding)  # (batch_size, 1, d_model)
                
                # Step 2: Merge (concatenate) normalized hidden state and embedding
                merged = torch.cat([h_norm, e_norm], dim=-1)  # (batch_size, 1, 2*d_model)
                
                # Step 3: Linear projection back to d_model
                projected = self.projections(merged)  # (batch_size, 1, d_model)
                
                # Step 4: Pass through transformer block
                h_current = self.transformer(projected)  # (batch_size, 1, d_model)
                
                # Step 5: Generate logits using shared unembedding matrix
                logits_k = self.unembedding(h_current)  # (batch_size, 1, vocab_size)
                
                # Store prediction for this depth
                position_predictions.append(logits_k)
                
                # Update h_previous for next iteration (causal chain)
                h_previous = h_current
            
            # Concatenate predictions for this position along the depth dimension
            # From list of (batch_size, 1, vocab_size) to (batch_size, num_heads, vocab_size)
            position_logits = torch.cat(position_predictions, dim=1)  # (batch_size, num_heads, vocab_size)
            all_predictions.append(position_logits)
        
        # Stack all position predictions
        # From list of (batch_size, num_heads, vocab_size) to (batch_size, max_i, num_heads, vocab_size)
        final_logits = torch.stack(all_predictions, dim=1)  # (batch_size, max_i, num_heads, vocab_size)
        
        return final_logits

print("✅ SimpleMTP class defined successfully!")
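
The nested loops above run the transformer block once per `(i, k)` pair. Because depth `k` always reads the embedding at position `i + k + 1`, the embeddings needed by all input positions at one depth form a contiguous slice of the sequence. The sketch below checks that claim on a random tensor; the variable names are chosen for illustration and the sizes are arbitrary.

```python
import torch

torch.manual_seed(0)
batch_size, seq_len, d_model, depth = 2, 8, 4, 3
embeds = torch.randn(batch_size, seq_len, d_model)
max_i = seq_len - depth

for k in range(depth):
    # Loop version: one embedding per input position i, as in the forward pass above
    looped = torch.cat(
        [embeds[:, i + k + 1 : i + k + 2, :] for i in range(max_i)], dim=1
    )
    # Slice version: all input positions for depth k at once
    sliced = embeds[:, k + 1 : k + 1 + max_i, :]
    assert torch.equal(looped, sliced)
    print(f"depth k={k}: slice shape {tuple(sliced.shape)}")
```

One possible use of this: reshape each slice to `(batch_size * max_i, 1, d_model)` and make a single transformer call per depth, which keeps every position as a length-1 sequence like the loop does.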

In [None]:
# Step 3: Generate Next Tokens (Testing the Model)

# Define model hyperparameters
batch_size = 1          # Number of sequences to process
seq_len = 8             # Length of input sequence (T)
d_model = 8             # Embedding dimension
vocab_size = 5000       # Vocabulary size
num_heads = 3           # Prediction depth (number of future tokens to predict)

print("=== Model Configuration ===")
print(f"Batch size: {batch_size}")
print(f"Sequence length: {seq_len}")
print(f"Model dimension: {d_model}")
print(f"Vocabulary size: {vocab_size}")
print(f"Prediction depth: {num_heads}")
print(f"Predictable positions: {seq_len - num_heads} (positions 0 to {seq_len - num_heads - 1})")

# Create model instance
model = SimpleMTP(
    d_model=d_model,
    vocab_size=vocab_size,
    num_heads=num_heads,
    nhead=4,  # Number of attention heads in transformer
    dim_feedforward=32
)

# Move to device
model = model.to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
print(f"\nModel parameters: {total_params:,}")

# Create batch of input tokens (randomly sampled)
input_tokens = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
print(f"\nInput tokens shape: {input_tokens.shape}")
print(f"Sample input tokens: {input_tokens[0].tolist()}")

# Forward pass through model
# Switch to eval mode first so dropout is disabled and the outputs are deterministic
model.eval()
with torch.no_grad():
    logits = model(input_tokens)

print(f"\n=== Output Analysis ===")
print(f"Output logits shape: {logits.shape}")
print(f"Expected shape: (batch_size={batch_size}, max_i={seq_len-num_heads}, num_heads={num_heads}, vocab_size={vocab_size})")

# Analyze output dimensions
batch_dim, max_i_dim, heads_dim, vocab_dim = logits.shape
print(f"\nDimension breakdown:")
print(f"  Batch dimension: {batch_dim} (batch_size)")
print(f"  Position dimension: {max_i_dim} (T - D = {seq_len} - {num_heads} = {seq_len-num_heads})")
print(f"  Heads dimension: {heads_dim} (prediction depth)")
print(f"  Vocabulary dimension: {vocab_dim} (vocab_size)")

# Test specific predictions
print(f"\n=== Prediction Examples ===")

# Example 1: Predictions for input position i=0
print(f"1. Predictions for input position i=0:")
i0_logits = logits[0, 0, :, :]  # Shape: (num_heads, vocab_size)
i0_predictions = torch.argmax(i0_logits, dim=-1)  # Shape: (num_heads,)
print(f"   Predicted tokens: {i0_predictions.tolist()}")
print(f"   Shape: {i0_predictions.shape}")

# Example 2: Predictions for first head (k=0) across all positions
print(f"\n2. Predictions for head k=0 across all positions:")
k0_logits = logits[0, :, 0, :]  # Shape: (max_i, vocab_size)
k0_predictions = torch.argmax(k0_logits, dim=-1)  # Shape: (max_i,)
print(f"   Predicted tokens: {k0_predictions.tolist()}")
print(f"   Shape: {k0_predictions.shape}")

# Example 3: Single prediction (i=0, k=0)
print(f"\n3. Single prediction for i=0, k=0:")
single_logits = logits[0, 0, 0, :]  # Shape: (vocab_size,)
single_prediction = torch.argmax(single_logits)
print(f"   Predicted token: {single_prediction.item()}")
print(f"   Logits shape: {single_logits.shape}")

print("\n✅ Model forward pass completed successfully!")
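
As a sanity check on the printed parameter count, the total can be worked out by hand. The sketch below assumes the default layout of `nn.TransformerEncoderLayer` (a fused attention input projection, an output projection, two feed-forward linears, and two LayerNorms, all with biases); the helper `count_mtp_params` is hypothetical.

```python
def count_mtp_params(d_model, vocab_size, dim_feedforward):
    """Expected parameter count for the SimpleMTP layout (hypothetical helper)."""
    embeddings = vocab_size * d_model
    unembedding = d_model * vocab_size + vocab_size           # weight + bias
    projection = 2 * d_model * d_model + d_model              # weight + bias
    attention = (
        (3 * d_model * d_model + 3 * d_model)                 # fused q/k/v projection
        + (d_model * d_model + d_model)                       # output projection
    )
    feedforward = (
        (d_model * dim_feedforward + dim_feedforward)         # first linear
        + (dim_feedforward * d_model + d_model)               # second linear
    )
    layer_norms = 2 * (2 * d_model)                           # two LayerNorms, weight + bias
    return embeddings + unembedding + projection + attention + feedforward + layer_norms


expected = count_mtp_params(d_model=8, vocab_size=5000, dim_feedforward=32)
print(f"Expected parameters: {expected:,}")
```

For the configuration above this comes to 86,008, of which 85,000 sit in the embedding and unembedding tables. The number of attention heads does not change the count.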

In [None]:
# Step 4: Loss Function Calculation

def compute_mtp_loss(logits, input_tokens):
    """
    Compute Multi-Token Prediction loss
    
    Args:
        logits: (batch_size, max_i, num_heads, vocab_size) - Model predictions
        input_tokens: (batch_size, seq_len) - Original input tokens
    
    Returns:
        loss: Scalar tensor representing average loss across all predictions
    """
    batch_size, max_i, num_heads, vocab_size = logits.shape
    total_loss = 0.0
    total_predictions = 0
    
    # Iterate over input positions
    for i in range(max_i):
        
        # Iterate over prediction depths for this position
        for k in range(num_heads):
            
            # Get predicted logits for position i, depth k
            predicted_logits = logits[:, i, k, :]  # (batch_size, vocab_size)
            
            # Get target token at position i + k + 1 (future position)
            target_pos = i + k + 1
            target_tokens = input_tokens[:, target_pos]  # (batch_size,)
            
            # Compute cross-entropy loss for this prediction
            loss_ik = F.cross_entropy(predicted_logits, target_tokens)
            
            # Add to total loss
            total_loss += loss_ik
            total_predictions += 1
    
    # Return average loss
    return total_loss / total_predictions

# Test loss calculation
print("=== Loss Function Calculation ===")

# Compute loss
loss = compute_mtp_loss(logits, input_tokens)
print(f"Average MTP loss: {loss.item():.6f}")

# Break down loss calculation for understanding
batch_size, max_i, num_heads, vocab_size = logits.shape
print(f"\nLoss breakdown:")
print(f"  Total positions: {max_i}")
print(f"  Predictions per position: {num_heads}")
print(f"  Total predictions: {max_i * num_heads}")
print(f"  Average loss per prediction: {loss.item():.6f}")

# Detailed loss analysis for first few predictions
print(f"\n=== Detailed Loss Analysis ===")
individual_losses = []

for i in range(min(3, max_i)):  # Show first 3 positions
    print(f"\nPosition i={i}:")
    
    for k in range(num_heads):
        # Get prediction and target
        predicted_logits = logits[0, i, k, :]  # (vocab_size,)
        target_pos = i + k + 1
        target_token = input_tokens[0, target_pos].item()
        
        # Compute loss for this specific prediction
        loss_ik = F.cross_entropy(predicted_logits.unsqueeze(0), 
                                 input_tokens[0, target_pos:target_pos+1])
        individual_losses.append(loss_ik.item())
        
        # Get predicted token
        predicted_token = torch.argmax(predicted_logits).item()
        
        print(f"  Depth k={k}: Target={target_token}, Predicted={predicted_token}, Loss={loss_ik.item():.4f}")

print(f"\nIndividual losses: {[f'{l:.4f}' for l in individual_losses]}")
print(f"Mean of individual losses: {sum(individual_losses) / len(individual_losses):.6f}")

# Demonstrate backpropagation
print(f"\n=== Backpropagation Demo ===")
model.train()

# Create optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Forward pass
logits = model(input_tokens)
loss = compute_mtp_loss(logits, input_tokens)

print(f"Loss before backprop: {loss.item():.6f}")

# Backward pass
optimizer.zero_grad()
loss.backward()

# Check gradients
total_grad_norm = 0.0
for name, param in model.named_parameters():
    if param.grad is not None:
        grad_norm = param.grad.norm().item()
        total_grad_norm += grad_norm ** 2

total_grad_norm = total_grad_norm ** 0.5
print(f"Total gradient norm: {total_grad_norm:.6f}")

# Update parameters
optimizer.step()

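# Forward pass after update
# Note: the model is still in train mode, so dropout is active in both passes
# and the reported loss change includes some dropout noise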
with torch.no_grad():
    new_logits = model(input_tokens)
    new_loss = compute_mtp_loss(new_logits, input_tokens)

print(f"Loss after one step: {new_loss.item():.6f}")
print(f"Loss change: {new_loss.item() - loss.item():.6f}")

print("\n✅ Loss calculation and backpropagation working correctly!")
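
The looped loss averages one cross-entropy term per `(i, k)` pair. Since every term covers the same number of batch elements, the same value can be obtained with a single `F.cross_entropy` call over flattened logits. The sketch below assumes the target convention used above (`i + k + 1`); the function name `compute_mtp_loss_vectorized` is made up for illustration, and the check uses small random tensors rather than the model.

```python
import torch
import torch.nn.functional as F


def compute_mtp_loss_vectorized(logits, input_tokens):
    """Single-call version of the looped loss (hypothetical helper)."""
    batch_size, max_i, depth, vocab_size = logits.shape
    i_idx = torch.arange(max_i, device=logits.device).unsqueeze(1)  # (max_i, 1)
    k_idx = torch.arange(depth, device=logits.device).unsqueeze(0)  # (1, depth)
    target_pos = i_idx + k_idx + 1                                  # (max_i, depth)
    targets = input_tokens[:, target_pos]                           # (batch_size, max_i, depth)
    return F.cross_entropy(logits.reshape(-1, vocab_size), targets.reshape(-1))


torch.manual_seed(0)
demo_logits = torch.randn(2, 5, 3, 11)
demo_tokens = torch.randint(0, 11, (2, 8))

# Reference: the same nested loop as compute_mtp_loss
looped = torch.stack([
    F.cross_entropy(demo_logits[:, i, k, :], demo_tokens[:, i + k + 1])
    for i in range(5) for k in range(3)
]).mean()
vectorized = compute_mtp_loss_vectorized(demo_logits, demo_tokens)
print(f"looped={looped.item():.6f}  vectorized={vectorized.item():.6f}")
```

The two values should agree up to floating-point rounding.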

In [None]:
# Visualization and Understanding

def visualize_mtp_predictions(input_tokens, logits, max_examples=3):
    """
    Visualize Multi-Token Prediction results in a clear format
    """
    print("=== Multi-Token Prediction Visualization ===")
    
    batch_size, max_i, num_heads, vocab_size = logits.shape
    seq_len = input_tokens.shape[1]
    
    print(f"Input sequence: {input_tokens[0].tolist()}")
    print(f"Sequence length: {seq_len}")
    print(f"Prediction depth: {num_heads}")
    print(f"Predictable positions: {max_i} (positions 0 to {max_i-1})")
    
    print(f"\n{'='*80}")
    print(f"{'Input Pos':<10} {'Predictions':<50} {'Targets':<20}")
    print(f"{'='*80}")
    
    for i in range(min(max_examples, max_i)):
        # Get predictions for this position
        position_logits = logits[0, i, :, :]  # (num_heads, vocab_size)
        predictions = torch.argmax(position_logits, dim=-1)  # (num_heads,)
        
        # Get target tokens
        targets = []
        for k in range(num_heads):
            target_pos = i + k + 1
            targets.append(input_tokens[0, target_pos].item())
        
        # Format output
        pred_str = f"{predictions.tolist()}"
        target_str = f"{targets}"
        
        print(f"i={i:<8} {pred_str:<50} {target_str:<20}")
        
        # Show which future positions these correspond to
        future_positions = [i + k + 1 for k in range(num_heads)]
        print(f"{'':>10} Future positions: {future_positions}")
        print()

# Run visualization
visualize_mtp_predictions(input_tokens, logits)

print("=== Understanding the Causal Chain ===")
print("""
The key innovation in DeepSeek's MTP is the causal chain between prediction heads:

1. **Head 1 (k=0)**: 
   - Input: h₀ + embedding at position i+1
     (h₀ is the token embedding at position i in this notebook; in DeepSeek V3 it is the main model's hidden state)
   - Output: hidden_state₁ + prediction₁

2. **Head 2 (k=1)**:
   - Input: hidden_state₁ (from Head 1) + embedding at position i+2  
   - Output: hidden_state₂ + prediction₂

3. **Head 3 (k=2)**:
   - Input: hidden_state₂ (from Head 2) + embedding at position i+3
   - Output: hidden_state₃ + prediction₃

This creates dependencies through the hidden states: prediction₂ depends on hidden_state₁, and prediction₃ depends on hidden_state₂.
This differs from the original Meta paper, where the prediction heads were independent.
""")

In [None]:
# Experimental Variations and Comparisons

# Comparison: Independent vs Causal MTP
class IndependentMTP(nn.Module):
    """
    Independent Multi-Token Prediction (like original Meta paper)
    Each head predicts independently without hidden state passing
    """
    def __init__(self, d_model, vocab_size, num_heads, nhead=8, dim_feedforward=2048):
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.num_heads = num_heads
        
        self.rms_norm = RMSNorm(d_model)
        self.projections = nn.Linear(2 * d_model, d_model)
        self.transformer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward,
            batch_first=True, dropout=0.1
        )
        self.unembedding = nn.Linear(d_model, vocab_size)
        self.embeddings = nn.Embedding(vocab_size, d_model)
        
    def forward(self, input_tokens):
        batch_size, seq_len = input_tokens.shape
        embeds = self.embeddings(input_tokens)
        max_i = seq_len - self.num_heads
        all_predictions = []
        
        for i in range(max_i):
            # Use same initial hidden state for ALL heads (no causality)
            h_initial = embeds[:, i:i+1, :]
            position_predictions = []
            
            for k in range(self.num_heads):
                future_pos = i + k + 1
                token_embedding = embeds[:, future_pos:future_pos+1, :]
                
                # Always use initial hidden state (no causal chain)
                h_norm = self.rms_norm(h_initial)
                e_norm = self.rms_norm(token_embedding)
                merged = torch.cat([h_norm, e_norm], dim=-1)
                projected = self.projections(merged)
                h_current = self.transformer(projected)
                logits_k = self.unembedding(h_current)
                
                position_predictions.append(logits_k)
            
            position_logits = torch.cat(position_predictions, dim=1)
            all_predictions.append(position_logits)
        
        return torch.stack(all_predictions, dim=1)

# Compare the two approaches
print("=== Comparing Causal vs Independent MTP ===")

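# Create both models
# Each model is randomly initialized, so the differences below reflect
# initialization as well as architecture
causal_model = SimpleMTP(d_model, vocab_size, num_heads, nhead=4, dim_feedforward=32).to(device)
independent_model = IndependentMTP(d_model, vocab_size, num_heads, nhead=4, dim_feedforward=32).to(device)

# Use eval mode so dropout does not add noise to the comparison
causal_model.eval()
independent_model.eval()

# Forward pass with same input
with torch.no_grad():
    causal_logits = causal_model(input_tokens)
    independent_logits = independent_model(input_tokens)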

# Compare predictions
print(f"Input tokens: {input_tokens[0].tolist()}")
print(f"\nPredictions for position i=0:")

causal_preds = torch.argmax(causal_logits[0, 0, :, :], dim=-1)
independent_preds = torch.argmax(independent_logits[0, 0, :, :], dim=-1)

print(f"Causal MTP:      {causal_preds.tolist()}")
print(f"Independent MTP: {independent_preds.tolist()}")
print(f"Difference:      {(causal_preds != independent_preds).sum().item()} out of {num_heads} predictions differ")

# Compute losses
causal_loss = compute_mtp_loss(causal_logits, input_tokens)
independent_loss = compute_mtp_loss(independent_logits, input_tokens)

print(f"\nLoss comparison:")
print(f"Causal MTP loss:      {causal_loss.item():.6f}")
print(f"Independent MTP loss: {independent_loss.item():.6f}")
print(f"Difference:           {causal_loss.item() - independent_loss.item():.6f}")

print("\n✅ Comparison completed!")
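
The comparison above uses two separately initialized models, so it cannot isolate the effect of the causal chain. The sketch below isolates it on a toy scale: one shared `nn.Linear` stands in for the projection and transformer block, and the only difference between the two runs is whether the hidden state is passed forward. All names here (`head`, `run_heads`, `future_embeds`) are assumptions for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, depth = 4, 3
head = nn.Linear(2 * d_model, d_model)          # stand-in for projection + transformer block
h0 = torch.randn(1, d_model)                    # initial hidden state
future_embeds = torch.randn(depth, 1, d_model)  # one future-token embedding per depth


def run_heads(chained):
    """Return the hidden state produced at each depth."""
    h_prev, outputs = h0, []
    for k in range(depth):
        h_k = torch.tanh(head(torch.cat([h_prev, future_embeds[k]], dim=-1)))
        outputs.append(h_k)
        if chained:
            h_prev = h_k  # causal chain: pass the new hidden state forward
    return outputs


with torch.no_grad():
    chained_out = run_heads(chained=True)
    independent_out = run_heads(chained=False)

for k in range(depth):
    same = torch.allclose(chained_out[k], independent_out[k])
    print(f"depth k={k}: identical = {same}")
```

With shared weights the first depth matches exactly, since both variants start from the same hidden state, and later depths diverge. In the notebook itself, `SimpleMTP` and `IndependentMTP` declare the same submodules, so copying weights with `independent_model.load_state_dict(causal_model.state_dict())` should give a similarly controlled comparison.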

In [None]:
# Training Simulation and Key Insights

def simulate_training_step(model, input_tokens, num_steps=5):
    """
    Simulate a few training steps to show loss improvement
    """
    print("=== Training Simulation ===")
    
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    
    initial_logits = None
    
    for step in range(num_steps):
        # Forward pass
        logits = model(input_tokens)
        loss = compute_mtp_loss(logits, input_tokens)
        
        if step == 0:
            # Detach so the stored copy does not keep the autograd graph alive
            initial_logits = logits.detach().clone()
        
        print(f"Step {step + 1}: Loss = {loss.item():.6f}")
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    # Final evaluation
    model.eval()
    with torch.no_grad():
        final_logits = model(input_tokens)
        final_loss = compute_mtp_loss(final_logits, input_tokens)
    
    print(f"Final loss: {final_loss.item():.6f}")
    
    # Compare predictions before and after training
    print(f"\n=== Prediction Changes ===")
    initial_preds = torch.argmax(initial_logits[0, 0, :, :], dim=-1)
    final_preds = torch.argmax(final_logits[0, 0, :, :], dim=-1)
    
    # Targets for input position i=0; read the depth from the logits
    # instead of relying on a global variable
    depth = final_logits.shape[2]
    targets = [input_tokens[0, k + 1].item() for k in range(depth)]
    
    print(f"Targets:             {targets}")
    print(f"Initial predictions: {initial_preds.tolist()}")
    print(f"Final predictions:   {final_preds.tolist()}")
    
    # Check accuracy
    target_tensor = torch.tensor(targets, device=final_preds.device)
    initial_accuracy = (initial_preds == target_tensor).float().mean().item()
    final_accuracy = (final_preds == target_tensor).float().mean().item()
    
    print(f"Initial accuracy: {initial_accuracy:.2%}")
    print(f"Final accuracy:   {final_accuracy:.2%}")
    print(f"Improvement:      {final_accuracy - initial_accuracy:.2%}")

# Run training simulation
simulate_training_step(causal_model, input_tokens)

print(f"\n{'='*60}")
print("🎯 KEY INSIGHTS FROM IMPLEMENTATION")
print(f"{'='*60}")

insights = [
    "1. **Causal Dependencies**: Each prediction head uses the hidden state from the previous head",
    "2. **RMS Normalization**: Applied before merging hidden state and input embedding",
    "3. **Shared Unembedding**: Same vocabulary projection used across all heads",
    "4. **Sequence Boundaries**: Can only predict from positions with sufficient future context",
    "5. **Loss Aggregation**: Sum losses across all positions and depths, then average",
    "6. **Training Benefits**: Richer gradients from multiple prediction targets per token",
    "7. **Inference Strategy**: The MTP modules can be discarded at inference; DeepSeek V3 also notes they can be reused for speculative decoding"
]

for insight in insights:
    print(insight)

print(f"\n{'='*60}")
print("🚀 EXTENSIONS TO EXPERIMENT WITH")
print(f"{'='*60}")

extensions = [
    "• **Variable Depth**: Different prediction depths for different positions",
    "• **Attention in Heads**: Add attention mechanisms within each prediction head",
    "• **Learnable Weights**: Weighted combination of predictions from different depths",
    "• **Hierarchical MTP**: Multi-scale token prediction (characters, sub-words, words)",
    "• **Conditional MTP**: Prediction depth based on input token uncertainty",
    "• **Efficient Implementation**: Parallel processing of independent predictions"
]

for extension in extensions:
    print(extension)

print(f"\n✅ Multi-Token Prediction implementation completed successfully!")
print("This implementation captures the core concepts from DeepSeek's MTP architecture.")
print("You can now experiment with different configurations and extend the functionality!")
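
As a starting point for the "Learnable Weights" extension, the loss can be split by depth before it is aggregated. The DeepSeek-V3 report describes averaging the per-depth losses and scaling the result by a weighting factor λ before adding it to the main training loss. The sketch below follows that shape under this notebook's target convention; the function name `depth_weighted_mtp_loss` and the default `lam=0.3` are assumptions for illustration, and the inputs are random tensors.

```python
import torch
import torch.nn.functional as F


def depth_weighted_mtp_loss(logits, input_tokens, lam=0.3):
    """Average cross-entropy per depth, then scale the mean over depths by `lam`.

    `lam` is a placeholder weighting factor chosen for illustration.
    """
    batch_size, max_i, depth, vocab_size = logits.shape
    per_depth = []
    for k in range(depth):
        # Targets for depth k: positions k+1 .. k+max_i (same convention as compute_mtp_loss)
        targets = input_tokens[:, k + 1 : k + 1 + max_i]        # (batch_size, max_i)
        loss_k = F.cross_entropy(
            logits[:, :, k, :].reshape(-1, vocab_size), targets.reshape(-1)
        )
        per_depth.append(loss_k)
    per_depth = torch.stack(per_depth)                           # (depth,)
    return lam * per_depth.mean(), per_depth


torch.manual_seed(0)
demo_logits = torch.randn(2, 5, 3, 11)
demo_tokens = torch.randint(0, 11, (2, 8))
total, per_depth = depth_weighted_mtp_loss(demo_logits, demo_tokens)
print(f"per-depth losses: {[round(v, 4) for v in per_depth.tolist()]}")
print(f"weighted total:   {total.item():.4f}")
```

Returning the per-depth losses alongside the total makes it easy to log how quickly each depth learns, or to replace the uniform mean with learned weights.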