# Speculative Decoding: Accelerating Autoregressive Generation

## Overview

Speculative decoding speeds up autoregressive text generation by letting a smaller "draft" model propose several tokens ahead, which a larger "target" model then verifies in a single forward pass. This can substantially reduce the number of forward passes through the large model while preserving the target model's output distribution.

### Key Concepts:
- **Draft Model**: Small, fast model that generates candidate tokens
- **Target Model**: Large, high-quality model that verifies candidates
- **Acceptance Rate**: Fraction of draft tokens accepted by target model
- **Lookahead Distance**: Number of tokens the draft model predicts ahead

### Trade-offs:
- **Pros**: Higher throughput, reduced latency for large models
- **Cons**: Increased complexity, memory usage, dependency on draft model quality

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import List, Tuple, Optional
import time
from dataclasses import dataclass

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

print("Dependencies imported successfully!")

## Mathematical Foundation

### Standard Autoregressive Generation
In standard autoregressive generation, we sample tokens sequentially:

$$x_{t+1} \sim p_{\text{target}}(x_{t+1} | x_{1:t})$$

This requires one forward pass per token, making generation slow for large models.

### Speculative Decoding Algorithm

1. **Draft Phase**: Use the small model to generate $k$ candidate tokens autoregressively:
   $$\tilde{x}_{t+i} \sim p_{\text{draft}}(\cdot | x_{1:t}, \tilde{x}_{t+1:t+i-1}), \quad i = 1, \dots, k$$

2. **Verification Phase**: Target model computes probabilities for all candidates in parallel:
   $$p_{\text{target}}(\tilde{x}_{t+i} | x_{1:t}, \tilde{x}_{t+1:t+i-1})$$

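3. **Acceptance/Rejection**: For each position $i$, in order, with both probabilities conditioned on $x_{1:t}, \tilde{x}_{t+1:t+i-1}$:
   - Accept $\tilde{x}_{t+i}$ with probability $\min(1, \frac{p_{\text{target}}(\tilde{x}_{t+i})}{p_{\text{draft}}(\tilde{x}_{t+i})})$
   - If rejected, resample position $t+i$ from the residual distribution $p'(x) \propto \max(0, p_{\text{target}}(x) - p_{\text{draft}}(x))$, discard the remaining draft tokens, and stop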

### Key Insight
The algorithm maintains the exact same distribution as standard autoregressive sampling while potentially accepting multiple tokens per target model forward pass.
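
For a single position, this claim can be checked exactly rather than by sampling. The sketch below is separate from the notebook's pipeline: `p_target` and `p_draft` are made-up distributions over a three-token vocabulary, and the helper `speculative_output_distribution` is a hypothetical name introduced only for this illustration. It adds up the probability of emitting each token through either path (drafted and accepted, or resampled after a rejection).

```python
def speculative_output_distribution(p_target, p_draft):
    """Exact distribution of the token emitted by one accept-or-resample step."""
    # Probability that token x is drafted AND accepted:
    # q(x) * min(1, p(x) / q(x)) = min(p(x), q(x)).
    accepted = [min(p, q) for p, q in zip(p_target, p_draft)]
    reject_mass = 1.0 - sum(accepted)

    # Residual distribution used after a rejection: normalized max(0, p - q).
    residual = [max(0.0, p - q) for p, q in zip(p_target, p_draft)]
    total = sum(residual)
    if total > 0:
        residual = [r / total for r in residual]

    # Total probability of emitting x through either path.
    return [a + reject_mass * r for a, r in zip(accepted, residual)]


p_target = [0.6, 0.3, 0.1]  # illustrative target distribution (assumption)
p_draft = [0.2, 0.5, 0.3]   # illustrative draft distribution (assumption)

out = speculative_output_distribution(p_target, p_draft)
print(out)  # matches p_target up to floating-point rounding
```

The draft distribution cancels out of the result, which is why a poor draft model costs speed but not correctness.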

In [None]:
@dataclass
class SpeculativeDecodingConfig:
    """Configuration for speculative decoding."""
    lookahead_distance: int = 4  # Number of tokens to predict ahead
    temperature: float = 1.0     # Sampling temperature
    top_k: Optional[int] = None  # Top-k sampling
    top_p: Optional[float] = None # Top-p (nucleus) sampling

class SimpleLanguageModel(nn.Module):
    """Simple transformer-like model for demonstration."""
    
    def __init__(self, vocab_size: int, d_model: int, n_layers: int, n_heads: int, max_seq_len: int = 512):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.max_seq_len = max_seq_len
        
        # Embedding layers
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.position_embedding = nn.Embedding(max_seq_len, d_model)
        
        # Transformer layers
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=d_model * 4,
            dropout=0.1,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        
        # Output projection
        self.output_projection = nn.Linear(d_model, vocab_size)
        
        # Initialize weights
        self.apply(self._init_weights)
    
    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
    
    def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        batch_size, seq_len = input_ids.shape
        
        # Create position ids
        position_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0).expand(batch_size, -1)
        
        # Embeddings
        token_embeds = self.token_embedding(input_ids)
        pos_embeds = self.position_embedding(position_ids)
        hidden_states = token_embeds + pos_embeds
        
        # Create causal mask
        causal_mask = torch.triu(torch.ones(seq_len, seq_len, device=input_ids.device), diagonal=1).bool()
        
        # Transformer
        hidden_states = self.transformer(hidden_states, mask=causal_mask)
        
        # Output projection
        logits = self.output_projection(hidden_states)
        
        return logits

# Create toy models for demonstration
vocab_size = 1000
draft_model = SimpleLanguageModel(vocab_size=vocab_size, d_model=128, n_layers=2, n_heads=4)
target_model = SimpleLanguageModel(vocab_size=vocab_size, d_model=256, n_layers=6, n_heads=8)

print(f"Draft model parameters: {sum(p.numel() for p in draft_model.parameters()):,}")
print(f"Target model parameters: {sum(p.numel() for p in target_model.parameters()):,}")
print(f"Target model is {sum(p.numel() for p in target_model.parameters()) / sum(p.numel() for p in draft_model.parameters()):.1f}x larger")

## Sampling Utilities

We need utilities for sampling from probability distributions with temperature scaling and top-k/top-p filtering.
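
To get a feel for what temperature does before reading the tensor code, here is a small pure-Python sketch. The logits are arbitrary illustrative values and `softmax_with_temperature` is a hypothetical helper, not one of the notebook's utilities. Dividing logits by a temperature below 1 concentrates probability on the top token, while a temperature above 1 flattens the distribution.

```python
import math


def softmax_with_temperature(logits, temperature):
    """Softmax over logits / temperature, computed stably by subtracting the max."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    z = sum(exps)
    return [e / z for e in exps]


logits = [2.0, 1.0, 0.0]  # illustrative values (assumption)
for t in (0.5, 1.0, 2.0):
    probs = softmax_with_temperature(logits, t)
    print(f"T={t}: {[round(p, 3) for p in probs]}")
```

In speculative decoding the same temperature and filtering have to be applied to both the draft and target distributions, because the acceptance test compares the distributions that are actually sampled from.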

In [None]:
def apply_temperature(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    """Apply temperature scaling to logits."""
    return logits / temperature

def top_k_filtering(logits: torch.Tensor, top_k: int) -> torch.Tensor:
    """Apply top-k filtering to logits."""
    if top_k <= 0:
        return logits
    
    top_k = min(top_k, logits.size(-1))
    values, _ = torch.topk(logits, top_k, dim=-1)
    min_values = values[..., -1:]
    return torch.where(logits < min_values, torch.full_like(logits, float('-inf')), logits)

def top_p_filtering(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    """Apply top-p (nucleus) filtering to logits."""
    if top_p >= 1.0:
        return logits
    
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    
    # Remove tokens with cumulative probability above the threshold
    sorted_indices_to_remove = cumulative_probs > top_p
    # Shift the indices to the right to keep also the first token above the threshold
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0
    
    # Scatter sorted indices back to original indexing
    indices_to_remove = sorted_indices_to_remove.scatter(dim=-1, index=sorted_indices, src=sorted_indices_to_remove)
    logits = logits.masked_fill(indices_to_remove, float('-inf'))
    
    return logits

def sample_from_logits(logits: torch.Tensor, config: SpeculativeDecodingConfig) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sample tokens from logits with temperature and filtering."""
    # Apply temperature
    logits = apply_temperature(logits, config.temperature)
    
    # Apply top-k filtering
    if config.top_k is not None:
        logits = top_k_filtering(logits, config.top_k)
    
    # Apply top-p filtering
    if config.top_p is not None:
        logits = top_p_filtering(logits, config.top_p)
    
    # Convert to probabilities
    probs = F.softmax(logits, dim=-1)
    
    # Sample
    next_token = torch.multinomial(probs, num_samples=1)
    
    return next_token, probs

print("Sampling utilities defined!")

## Speculative Decoding Implementation

Now let's implement the core speculative decoding algorithm. The key insight is that we can verify multiple draft tokens in parallel using the target model, and accept/reject them based on the probability ratio.

In [None]:
class SpeculativeDecoder:
    """Speculative decoding implementation."""
    
    def __init__(self, draft_model: nn.Module, target_model: nn.Module, config: SpeculativeDecodingConfig):
        self.draft_model = draft_model
        self.target_model = target_model
        self.config = config
        
        # Statistics tracking
        self.stats = {
            'total_draft_tokens': 0,
            'accepted_tokens': 0,
            'rejection_positions': [],
            'target_model_calls': 0
        }
    
    def reset_stats(self):
        """Reset statistics tracking."""
        self.stats = {
            'total_draft_tokens': 0,
            'accepted_tokens': 0,
            'rejection_positions': [],
            'target_model_calls': 0
        }
    
    def draft_tokens(self, input_ids: torch.Tensor) -> Tuple[List[int], List[torch.Tensor]]:
        """Generate draft tokens using the small model."""
        draft_tokens = []
        draft_probs = []
        
        current_input = input_ids.clone()
        
        with torch.no_grad():
            for _ in range(self.config.lookahead_distance):
                # Get logits from draft model
                logits = self.draft_model(current_input)
                next_token_logits = logits[:, -1, :]  # Last position
                
                # Sample next token
                next_token, probs = sample_from_logits(next_token_logits, self.config)
                
                draft_tokens.append(next_token.item())
                draft_probs.append(probs)
                
                # Append to input for next iteration
                current_input = torch.cat([current_input, next_token], dim=1)
        
        return draft_tokens, draft_probs
    
    def verify_tokens(self, input_ids: torch.Tensor, draft_tokens: List[int], draft_probs: List[torch.Tensor]) -> Tuple[List[int], int]:
        """Verify draft tokens using the target model."""
        # Create extended input with all draft tokens
        draft_tensor = torch.tensor(draft_tokens, device=input_ids.device).unsqueeze(0)
        extended_input = torch.cat([input_ids, draft_tensor], dim=1)
        
        # Get target model predictions for all positions
        with torch.no_grad():
            target_logits = self.target_model(extended_input)
            self.stats['target_model_calls'] += 1
        
        accepted_tokens = []
        
        # Verify each draft token in order
        for i, (draft_token, draft_prob) in enumerate(zip(draft_tokens, draft_probs)):
            # Count every draft token that is actually checked
            self.stats['total_draft_tokens'] += 1
            
            # Logits at index j predict token j + 1, so draft token i is scored
            # at index len(prompt) + i - 1
            position_idx = input_ids.size(1) + i - 1
            target_logits_pos = target_logits[:, position_idx, :]
            
            # Apply same sampling configuration
            target_logits_pos = apply_temperature(target_logits_pos, self.config.temperature)
            if self.config.top_k is not None:
                target_logits_pos = top_k_filtering(target_logits_pos, self.config.top_k)
            if self.config.top_p is not None:
                target_logits_pos = top_p_filtering(target_logits_pos, self.config.top_p)
            
            target_probs = F.softmax(target_logits_pos, dim=-1)
            
            # Calculate acceptance probability min(1, p_target / p_draft)
            draft_prob_token = draft_prob[0, draft_token]
            target_prob_token = target_probs[0, draft_token]
            
            acceptance_prob = min(1.0, (target_prob_token / (draft_prob_token + 1e-10)).item())
            
            # Accept or reject
            if torch.rand(1).item() < acceptance_prob:
                accepted_tokens.append(draft_token)
                self.stats['accepted_tokens'] += 1
            else:
                # Rejection - resample from the residual distribution max(0, p_target - p_draft)
                adjusted_probs = torch.clamp(target_probs - draft_prob, min=0.0)
                if adjusted_probs.sum() <= 0:
                    # Degenerate case (distributions numerically identical): use the target
                    adjusted_probs = target_probs
                adjusted_probs = adjusted_probs / (adjusted_probs.sum() + 1e-10)
                
                resampled_token = torch.multinomial(adjusted_probs, num_samples=1)
                # The resampled token is emitted, but it is not an accepted draft token
                accepted_tokens.append(resampled_token.item())
                self.stats['rejection_positions'].append(i)
                break  # Stop at first rejection
        
        # Note: the full algorithm also samples one extra token from the target when
        # every draft token is accepted; that step is omitted here for simplicity.
        return accepted_tokens, len(accepted_tokens)
    
    def generate_step(self, input_ids: torch.Tensor) -> Tuple[torch.Tensor, int]:
        """Perform one speculative decoding step."""
        # Draft phase
        draft_tokens, draft_probs = self.draft_tokens(input_ids)
        
        # Verification phase
        accepted_tokens, num_accepted = self.verify_tokens(input_ids, draft_tokens, draft_probs)
        
        # Update input
        if accepted_tokens:
            accepted_tensor = torch.tensor(accepted_tokens, device=input_ids.device).unsqueeze(0)
            new_input = torch.cat([input_ids, accepted_tensor], dim=1)
        else:
            new_input = input_ids
        
        return new_input, num_accepted
    
    def generate(self, input_ids: torch.Tensor, max_new_tokens: int) -> torch.Tensor:
        """Generate text using speculative decoding."""
        current_input = input_ids.clone()
        generated_tokens = 0
        
        while generated_tokens < max_new_tokens:
            current_input, num_accepted = self.generate_step(current_input)
            generated_tokens += num_accepted
            
            if num_accepted == 0:
                break  # Safety guard; each step normally emits at least one token
        
        # The last step can overshoot, so trim to exactly max_new_tokens
        return current_input[:, :input_ids.size(1) + max_new_tokens]
    
    def get_acceptance_rate(self) -> float:
        """Fraction of verified draft tokens that the target model accepted."""
        if self.stats['total_draft_tokens'] == 0:
            return 0.0
        return self.stats['accepted_tokens'] / self.stats['total_draft_tokens']
    
    def get_throughput_multiplier(self) -> float:
        """Tokens emitted per target model call (standard generation emits 1).

        This ignores the cost of the draft model's forward passes.
        """
        if self.stats['target_model_calls'] == 0:
            return 1.0
        # Each rejection emits exactly one resampled token in addition to accepted drafts
        emitted = self.stats['accepted_tokens'] + len(self.stats['rejection_positions'])
        return emitted / self.stats['target_model_calls']

print("Speculative decoder implemented!")

## Standard Autoregressive Generation (Baseline)

Let's implement standard autoregressive generation for comparison.

In [None]:
class StandardGenerator:
    """Standard autoregressive generation for comparison."""
    
    def __init__(self, model: nn.Module, config: SpeculativeDecodingConfig):
        self.model = model
        self.config = config
        self.stats = {'model_calls': 0}
    
    def reset_stats(self):
        self.stats = {'model_calls': 0}
    
    def generate(self, input_ids: torch.Tensor, max_new_tokens: int) -> torch.Tensor:
        """Generate text using standard autoregressive sampling."""
        current_input = input_ids.clone()
        
        with torch.no_grad():
            for _ in range(max_new_tokens):
                # Get next token logits
                logits = self.model(current_input)
                next_token_logits = logits[:, -1, :]
                
                self.stats['model_calls'] += 1
                
                # Sample next token
                next_token, _ = sample_from_logits(next_token_logits, self.config)
                
                # Append to sequence
                current_input = torch.cat([current_input, next_token], dim=1)
        
        return current_input

print("Standard generator implemented!")

## Experimental Analysis

Now let's run experiments to analyze the performance of speculative decoding vs standard generation.

In [None]:
# Set models to evaluation mode
draft_model.eval()
target_model.eval()

# Experiment configuration
config = SpeculativeDecodingConfig(
    lookahead_distance=4,
    temperature=0.8,
    top_k=50
)

# Create generators
spec_decoder = SpeculativeDecoder(draft_model, target_model, config)
standard_generator = StandardGenerator(target_model, config)

# Test input
input_text = torch.randint(0, vocab_size, (1, 10))  # Random starting sequence
max_new_tokens = 50

print(f"Input shape: {input_text.shape}")
print(f"Generating {max_new_tokens} new tokens...")
print(f"Lookahead distance: {config.lookahead_distance}")

In [None]:
# Run speculative decoding
spec_decoder.reset_stats()
start_time = time.perf_counter()  # Monotonic clock, better suited to timing than time.time()
spec_output = spec_decoder.generate(input_text, max_new_tokens)
spec_time = time.perf_counter() - start_time

print("Speculative Decoding Results:")
print(f"Generated sequence length: {spec_output.shape[1]}")
print(f"New tokens generated: {spec_output.shape[1] - input_text.shape[1]}")
print(f"Time taken: {spec_time:.3f} seconds")
print(f"Target model calls: {spec_decoder.stats['target_model_calls']}")
print(f"Acceptance rate: {spec_decoder.get_acceptance_rate():.3f}")
print(f"Throughput multiplier: {spec_decoder.get_throughput_multiplier():.2f}x")
print(f"Rejection positions: {spec_decoder.stats['rejection_positions'][:10]}...")  # First 10

In [None]:
# Run standard generation
standard_generator.reset_stats()
start_time = time.perf_counter()  # Monotonic clock, better suited to timing than time.time()
standard_output = standard_generator.generate(input_text, max_new_tokens)
standard_time = time.perf_counter() - start_time

print("Standard Generation Results:")
print(f"Generated sequence length: {standard_output.shape[1]}")
print(f"New tokens generated: {standard_output.shape[1] - input_text.shape[1]}")
print(f"Time taken: {standard_time:.3f} seconds")
print(f"Target model calls: {standard_generator.stats['model_calls']}")
print(f"Speedup: {standard_time / spec_time:.2f}x")

## Analysis of Acceptance Rates

Let's analyze how acceptance rates vary with different parameters and model similarities.
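
Before measuring, it helps to know what to expect. The sketch below makes a strong simplifying assumption that the notebook does not make: every draft token is accepted independently with the same probability `alpha`. Under that assumption, and for the decoder variant used here (accepted drafts plus one resampled token on rejection, no extra token when all drafts are accepted), the expected number of tokens emitted per target call is $\sum_{i=0}^{k-1} \alpha^i$. The function name `expected_tokens_per_step` is hypothetical.

```python
def expected_tokens_per_step(alpha, k):
    """Expected tokens emitted per target call under i.i.d. acceptance with rate alpha."""
    # The step emits at least i + 1 tokens exactly when the first i drafts are all
    # accepted, which has probability alpha**i. Summing over i = 0..k-1 gives the mean.
    return sum(alpha ** i for i in range(k))


for alpha in (0.3, 0.6, 0.9):
    row = [round(expected_tokens_per_step(alpha, k), 2) for k in (2, 4, 6)]
    print(f"alpha={alpha}: k=2,4,6 -> {row}")
```

The gain from a longer lookahead flattens quickly when `alpha` is low, which is one reason a larger lookahead distance is not automatically better.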

In [None]:
def analyze_acceptance_rates(lookahead_distances: List[int], temperatures: List[float], num_trials: int = 5):
    """Analyze acceptance rates across different parameters."""
    results = []
    
    for lookahead in lookahead_distances:
        for temperature in temperatures:
            config = SpeculativeDecodingConfig(
                lookahead_distance=lookahead,
                temperature=temperature,
                top_k=50
            )
            
            decoder = SpeculativeDecoder(draft_model, target_model, config)
            
            acceptance_rates = []
            throughput_multipliers = []
            
            for trial in range(num_trials):
                decoder.reset_stats()
                test_input = torch.randint(0, vocab_size, (1, 8))
                decoder.generate(test_input, 30)
                
                acceptance_rates.append(decoder.get_acceptance_rate())
                throughput_multipliers.append(decoder.get_throughput_multiplier())
            
            results.append({
                'lookahead': lookahead,
                'temperature': temperature,
                'acceptance_rate_mean': np.mean(acceptance_rates),
                'acceptance_rate_std': np.std(acceptance_rates),
                'throughput_multiplier_mean': np.mean(throughput_multipliers),
                'throughput_multiplier_std': np.std(throughput_multipliers)
            })
    
    return results

# Run analysis
lookahead_distances = [2, 3, 4, 5, 6]
temperatures = [0.5, 0.8, 1.0, 1.2]

print("Running acceptance rate analysis...")
analysis_results = analyze_acceptance_rates(lookahead_distances, temperatures, num_trials=3)
print(f"Completed {len(analysis_results)} experiments")

In [None]:
# Create visualizations
fig, axes = plt.subplots(2, 2, figsize=(15, 12))

# Convert results to arrays for plotting
import pandas as pd
df = pd.DataFrame(analysis_results)

# 1. Acceptance rate vs lookahead distance
for temp in temperatures:
    temp_data = df[df['temperature'] == temp]
    axes[0, 0].plot(temp_data['lookahead'], temp_data['acceptance_rate_mean'], 
                   marker='o', label=f'T={temp}')
    axes[0, 0].fill_between(temp_data['lookahead'], 
                           temp_data['acceptance_rate_mean'] - temp_data['acceptance_rate_std'],
                           temp_data['acceptance_rate_mean'] + temp_data['acceptance_rate_std'],
                           alpha=0.3)

axes[0, 0].set_xlabel('Lookahead Distance')
axes[0, 0].set_ylabel('Acceptance Rate')
axes[0, 0].set_title('Acceptance Rate vs Lookahead Distance')
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)

# 2. Throughput multiplier vs lookahead distance
for temp in temperatures:
    temp_data = df[df['temperature'] == temp]
    axes[0, 1].plot(temp_data['lookahead'], temp_data['throughput_multiplier_mean'], 
                   marker='s', label=f'T={temp}')
    axes[0, 1].fill_between(temp_data['lookahead'], 
                           temp_data['throughput_multiplier_mean'] - temp_data['throughput_multiplier_std'],
                           temp_data['throughput_multiplier_mean'] + temp_data['throughput_multiplier_std'],
                           alpha=0.3)

axes[0, 1].set_xlabel('Lookahead Distance')
axes[0, 1].set_ylabel('Throughput Multiplier')
axes[0, 1].set_title('Throughput Multiplier vs Lookahead Distance')
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)

# 3. Acceptance rate vs temperature
for lookahead in lookahead_distances:
    lookahead_data = df[df['lookahead'] == lookahead]
    axes[1, 0].plot(lookahead_data['temperature'], lookahead_data['acceptance_rate_mean'], 
                   marker='o', label=f'L={lookahead}')

axes[1, 0].set_xlabel('Temperature')
axes[1, 0].set_ylabel('Acceptance Rate')
axes[1, 0].set_title('Acceptance Rate vs Temperature')
axes[1, 0].legend()
axes[1, 0].grid(True, alpha=0.3)

# 4. Heatmap of acceptance rates
pivot_data = df.pivot(index='temperature', columns='lookahead', values='acceptance_rate_mean')
sns.heatmap(pivot_data, annot=True, fmt='.3f', cmap='viridis', ax=axes[1, 1])
axes[1, 1].set_title('Acceptance Rate Heatmap')
axes[1, 1].set_xlabel('Lookahead Distance')
axes[1, 1].set_ylabel('Temperature')

plt.tight_layout()
plt.show()

# Print summary statistics
print("\nSummary Statistics:")
print(f"Best acceptance rate: {df['acceptance_rate_mean'].max():.3f} (T={df.loc[df['acceptance_rate_mean'].idxmax(), 'temperature']}, L={df.loc[df['acceptance_rate_mean'].idxmax(), 'lookahead']})")
print(f"Best throughput multiplier: {df['throughput_multiplier_mean'].max():.2f}x (T={df.loc[df['throughput_multiplier_mean'].idxmax(), 'temperature']}, L={df.loc[df['throughput_multiplier_mean'].idxmax(), 'lookahead']})")
print(f"Mean acceptance rate: {df['acceptance_rate_mean'].mean():.3f} ± {df['acceptance_rate_mean'].std():.3f}")
print(f"Mean throughput multiplier: {df['throughput_multiplier_mean'].mean():.2f}x ± {df['throughput_multiplier_mean'].std():.2f}x")

## Analysis of Rejection Patterns

Let's analyze where rejections typically occur in the lookahead sequence.
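
As a reference point for the histogram, consider the idealized case where each draft token is accepted independently with a fixed probability `alpha` (an assumption made only for this sketch). The first rejection then follows a truncated geometric distribution over the 0-based positions used in `rejection_positions`. The helper name `first_rejection_distribution` is hypothetical.

```python
def first_rejection_distribution(alpha, k):
    """Return (P(first rejection at position i) for i in 0..k-1, P(no rejection))."""
    # Position i is the first rejection when i drafts are accepted and the next is rejected.
    by_position = [(alpha ** i) * (1 - alpha) for i in range(k)]
    no_rejection = alpha ** k
    return by_position, no_rejection


by_position, no_rejection = first_rejection_distribution(alpha=0.7, k=5)
for i, p in enumerate(by_position):
    print(f"position {i}: {p:.3f}")
print(f"no rejection: {no_rejection:.3f}")
```

A histogram that falls off with position is therefore expected even when per-position acceptance is constant. A noticeably different shape would suggest that acceptance depends on the position within the lookahead window.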

In [None]:
def analyze_rejection_patterns(lookahead_distance: int = 4, num_samples: int = 100):
    """Analyze rejection patterns in speculative decoding."""
    config = SpeculativeDecodingConfig(
        lookahead_distance=lookahead_distance,
        temperature=1.0,
        top_k=50
    )
    
    decoder = SpeculativeDecoder(draft_model, target_model, config)
    decoder.reset_stats()
    
    # Run multiple generation steps
    for _ in range(num_samples):
        test_input = torch.randint(0, vocab_size, (1, 10))
        decoder.generate_step(test_input)
    
    return decoder.stats['rejection_positions']

# Analyze rejection patterns
rejection_positions = analyze_rejection_patterns(lookahead_distance=5, num_samples=200)

# Plot rejection position distribution
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

# Histogram of rejection positions
if rejection_positions:
    ax1.hist(rejection_positions, bins=range(6), alpha=0.7, edgecolor='black')
    ax1.set_xlabel('Rejection Position in Lookahead Sequence')
    ax1.set_ylabel('Frequency')
    ax1.set_title('Distribution of Rejection Positions')
    ax1.set_xticks(range(5))
    ax1.grid(True, alpha=0.3)
    
    # Calculate statistics
    mean_rejection_pos = np.mean(rejection_positions)
    median_rejection_pos = np.median(rejection_positions)
    
    ax1.axvline(mean_rejection_pos, color='red', linestyle='--', label=f'Mean: {mean_rejection_pos:.2f}')
    ax1.axvline(median_rejection_pos, color='orange', linestyle='--', label=f'Median: {median_rejection_pos:.2f}')
    ax1.legend()
    
    print(f"Rejection statistics (out of {len(rejection_positions)} rejections):")
    print(f"Mean rejection position: {mean_rejection_pos:.2f}")
    print(f"Median rejection position: {median_rejection_pos:.2f}")
    print(f"Most common rejection position: {max(set(rejection_positions), key=rejection_positions.count)}")
else:
    ax1.text(0.5, 0.5, 'No rejections occurred', transform=ax1.transAxes, 
             ha='center', va='center', fontsize=14)
    ax1.set_title('Distribution of Rejection Positions (No Rejections)')

# Cumulative acceptance probability by position:
# fraction of steps whose first i + 1 draft tokens were all accepted
num_steps = 200  # Must match num_samples passed to analyze_rejection_patterns above
lookahead = 5    # Must match lookahead_distance passed above
position_counts = [rejection_positions.count(i) for i in range(lookahead)]
cumulative_acceptance = []

for i in range(lookahead):
    # Steps whose first rejection happened at or before position i
    cumulative_rejections = sum(position_counts[:i+1])
    acceptance_rate = 1 - (cumulative_rejections / num_steps)
    cumulative_acceptance.append(acceptance_rate)

ax2.plot(range(5), cumulative_acceptance, marker='o', linewidth=2, markersize=8)
ax2.set_xlabel('Position in Lookahead Sequence')
ax2.set_ylabel('Cumulative Acceptance Rate')
ax2.set_title('Acceptance Rate by Position')
ax2.set_xticks(range(5))
ax2.grid(True, alpha=0.3)
ax2.set_ylim([0, 1])

plt.tight_layout()
plt.show()

## Model Similarity Analysis

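The effectiveness of speculative decoding depends heavily on how similar the draft and target models are. The toy experiment below varies a `similarity_factor` that controls how the draft model's weights are randomly initialized. It is only a rough stand-in for similarity, since no draft model here is trained or distilled from the target.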
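
The dependence on similarity can be made precise for a single position: the probability that a token drafted from $p_{\text{draft}}$ is accepted equals $\sum_x \min(p_{\text{target}}(x), p_{\text{draft}}(x))$, which is one minus the total variation distance between the two distributions. The sketch below checks this on made-up three-token distributions. The helper names and the numbers are illustrative assumptions.

```python
def expected_acceptance(p_target, p_draft):
    """Probability that a token sampled from p_draft passes the acceptance test."""
    # E_{x ~ q}[min(1, p(x) / q(x))] = sum_x min(p(x), q(x))
    return sum(min(p, q) for p, q in zip(p_target, p_draft))


def total_variation(p, q):
    """Total variation distance between two discrete distributions."""
    return 0.5 * sum(abs(a - b) for a, b in zip(p, q))


p_target = [0.6, 0.3, 0.1]  # illustrative target distribution (assumption)
drafts = {
    "identical": [0.6, 0.3, 0.1],
    "close": [0.4, 0.4, 0.2],
    "far": [0.1, 0.2, 0.7],
}
for name, p_draft in drafts.items():
    acc = expected_acceptance(p_target, p_draft)
    tv = total_variation(p_target, p_draft)
    print(f"{name}: acceptance={acc:.2f}, 1 - TV={1 - tv:.2f}")
```

The two printed columns agree for every draft, so reducing the distance between the draft and target distributions is what raises the acceptance rate.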

In [None]:
def create_similar_model(base_model: nn.Module, similarity_factor: float) -> nn.Module:
    """Create a smaller draft model whose random weight scale depends on similarity_factor.

    Note: no weights are copied from base_model; only its architecture sizes are used.
    similarity_factor is a crude proxy, not a measured similarity to the target.
    """
    similar_model = SimpleLanguageModel(
        vocab_size=base_model.vocab_size,
        d_model=base_model.d_model // 2,  # Smaller model
        n_layers=max(1, base_model.transformer.num_layers // 2),
        n_heads=max(1, base_model.transformer.layers[0].self_attn.num_heads // 2)
    )
    
    # Re-initialize every parameter with Gaussian noise scaled by similarity_factor
    with torch.no_grad():
        for param in similar_model.parameters():
            # Large-scale noise term that shrinks as similarity_factor grows
            noise = torch.randn_like(param) * (1 - similarity_factor)
            # Small-scale noise term (independent of base_model, despite the name)
            base_influence = torch.randn_like(param) * similarity_factor * 0.1
            param.data = noise + base_influence
    
    return similar_model

def test_model_similarity_impact(similarity_factors: List[float]):
    """Test how model similarity affects speculative decoding performance."""
    results = []
    
    config = SpeculativeDecodingConfig(
        lookahead_distance=4,
        temperature=1.0,
        top_k=50
    )
    
    for similarity in similarity_factors:
        # Create draft model with specified similarity
        draft = create_similar_model(target_model, similarity)
        draft.eval()
        
        decoder = SpeculativeDecoder(draft, target_model, config)
        
        # Test multiple times
        acceptance_rates = []
        throughput_multipliers = []
        
        for _ in range(5):
            decoder.reset_stats()
            test_input = torch.randint(0, vocab_size, (1, 8))
            decoder.generate(test_input, 20)
            
            acceptance_rates.append(decoder.get_acceptance_rate())
            throughput_multipliers.append(decoder.get_throughput_multiplier())
        
        results.append({
            'similarity': similarity,
            'acceptance_rate': np.mean(acceptance_rates),
            'acceptance_rate_std': np.std(acceptance_rates),
            'throughput_multiplier': np.mean(throughput_multipliers),
            'throughput_multiplier_std': np.std(throughput_multipliers)
        })
        
        print(f"Similarity {similarity:.2f}: Acceptance rate = {np.mean(acceptance_rates):.3f} ± {np.std(acceptance_rates):.3f}")
    
    return results

# Test different similarity levels
similarity_factors = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]
print("Testing impact of model similarity on speculative decoding...")
similarity_results = test_model_similarity_impact(similarity_factors)

In [None]:
# Plot similarity analysis results
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

similarities = [r['similarity'] for r in similarity_results]
acceptance_rates = [r['acceptance_rate'] for r in similarity_results]
acceptance_stds = [r['acceptance_rate_std'] for r in similarity_results]
throughput_multipliers = [r['throughput_multiplier'] for r in similarity_results]
throughput_stds = [r['throughput_multiplier_std'] for r in similarity_results]

# Acceptance rate vs similarity
ax1.errorbar(similarities, acceptance_rates, yerr=acceptance_stds, 
             marker='o', linewidth=2, markersize=8, capsize=5)
ax1.set_xlabel('Model Similarity Factor')
ax1.set_ylabel('Acceptance Rate')
ax1.set_title('Acceptance Rate vs Model Similarity')
ax1.grid(True, alpha=0.3)
ax1.set_xlim([-0.05, 1.05])

# Throughput multiplier vs similarity
ax2.errorbar(similarities, throughput_multipliers, yerr=throughput_stds, 
             marker='s', linewidth=2, markersize=8, capsize=5, color='orange')
ax2.set_xlabel('Model Similarity Factor')
ax2.set_ylabel('Throughput Multiplier')
ax2.set_title('Throughput Multiplier vs Model Similarity')
ax2.grid(True, alpha=0.3)
ax2.set_xlim([-0.05, 1.05])

plt.tight_layout()
plt.show()

print("\nKey Insights:")
print(f"- Highest acceptance rate: {max(acceptance_rates):.3f} at similarity {similarities[acceptance_rates.index(max(acceptance_rates))]:.2f}")
print(f"- Highest throughput: {max(throughput_multipliers):.2f}x at similarity {similarities[throughput_multipliers.index(max(throughput_multipliers))]:.2f}")
print(f"- Acceptance rate range: {min(acceptance_rates):.3f} - {max(acceptance_rates):.3f}")
print(f"- Throughput range: {min(throughput_multipliers):.2f}x - {max(throughput_multipliers):.2f}x")

## Trade-offs and Practical Considerations

Let's analyze the key trade-offs in speculative decoding implementation.
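
A simple cost model makes the main trade-off concrete. The sketch below rests on three assumptions that are not established by this notebook: draft tokens are accepted independently with rate `alpha`, verifying `k` tokens costs the same as one ordinary target forward pass, and each draft forward pass costs `cost_ratio` times a target pass. The function name `estimated_speedup` and all parameter values are hypothetical.

```python
def estimated_speedup(alpha, k, cost_ratio):
    """Rough speedup over standard decoding under the assumptions stated above."""
    # Expected tokens per step for the variant without an extra token on full acceptance.
    tokens_per_step = sum(alpha ** i for i in range(k))
    # Cost per step in units of one target pass: k draft passes plus one verification pass.
    cost_per_step = k * cost_ratio + 1.0
    # Standard decoding emits one token per target pass, so its rate is 1.
    return tokens_per_step / cost_per_step


for alpha in (0.3, 0.6, 0.9):
    for k in (2, 4, 8):
        s = estimated_speedup(alpha, k, cost_ratio=0.1)
        print(f"alpha={alpha}, k={k}: estimated speedup {s:.2f}x")
```

Under this model a low acceptance rate combined with a long lookahead gives a value below 1, meaning speculative decoding would be slower than the baseline.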

In [None]:
def analyze_tradeoffs():
    """Analyze key trade-offs in speculative decoding."""
    
    print("=" * 60)
    print("SPECULATIVE DECODING TRADE-OFFS ANALYSIS")
    print("=" * 60)
    
    print("\n1. COMPUTATIONAL OVERHEAD:")
    print("   Pros:")
    print("   + Reduced target model forward passes")
    print("   + Parallel verification of multiple tokens")
    print("   + Potential for significant speedup (1.5-3x typical)")
    print("   \n   Cons:")
    print("   - Draft model overhead (additional compute)")
    print("   - Memory overhead for storing draft sequences")
    print("   - Implementation complexity")
    
    print("\n2. QUALITY GUARANTEES:")
    print("   Pros:")
    print("   + Mathematically equivalent to standard sampling")
    print("   + No quality degradation when properly implemented")
    print("   + Same distribution as target model")
    print("   \n   Cons:")
    print("   - Requires careful implementation of rejection sampling")
    print("   - Numerical precision considerations")
    
    print("\n3. PERFORMANCE FACTORS:")
    print("   Critical Dependencies:")
    print("   - Draft model quality (higher similarity = better performance)")
    print("   - Lookahead distance (optimal range: 3-6 tokens)")
    print("   - Temperature settings (lower temp often better)")
    print("   - Hardware characteristics (memory bandwidth, compute ratio)")
    
    print("\n4. PRACTICAL DEPLOYMENT:")
    print("   Considerations:")
    print("   - Draft model training/distillation requirements")
    print("   - Memory usage (storing multiple model states)")
    print("   - Batching strategies (more complex with variable acceptance)")
    print("   - Hardware optimization opportunities")
    
    print("\n5. WHEN TO USE:")
    print("   Recommended:")
    print("   - Latency-critical applications")
    print("   - Large target models with available smaller variants")
    print("   - Interactive applications (chatbots, code completion)")
    print("   \n   Not Recommended:")
    print("   - Very small target models (overhead may dominate)")
    print("   - When draft model quality is poor")
    print("   - Memory-constrained environments")

analyze_tradeoffs()
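
The dependence on acceptance rate and lookahead distance listed above can be made concrete with a back-of-envelope model. The sketch below assumes each draft token is accepted independently with probability `alpha`, and that one draft forward pass costs a fraction `cost_ratio` of a target pass. Both are simplifications, and the function names are illustrative rather than part of the notebook's earlier code.

```python
def expected_tokens_per_step(alpha: float, k: int) -> float:
    """Expected tokens emitted per target forward pass.

    Assumes each of the k draft tokens is accepted independently with
    probability alpha. The target pass always contributes one extra token
    (a correction after a rejection, or a bonus token if all are accepted).
    """
    if alpha >= 1.0:
        return float(k + 1)
    return (1.0 - alpha ** (k + 1)) / (1.0 - alpha)


def expected_speedup(alpha: float, k: int, cost_ratio: float) -> float:
    """Estimated speedup over plain autoregressive decoding.

    cost_ratio is the assumed cost of one draft pass relative to one
    target pass, so a speculative step costs k * cost_ratio + 1.
    """
    return expected_tokens_per_step(alpha, k) / (k * cost_ratio + 1.0)


# Compare a few acceptance rates and lookahead distances (cost_ratio is assumed).
for alpha in (0.3, 0.6, 0.8):
    row = ", ".join(
        f"k={k}: {expected_speedup(alpha, k, 0.1):.2f}x" for k in (2, 4, 8)
    )
    print(f"alpha={alpha:.1f} -> {row}")
```

Under this model, a long lookahead combined with a low acceptance rate falls below 1x, which is why the lookahead distance should be tuned rather than fixed.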

## Summary and Conclusions

Based on our implementation and analysis, here are the key findings about speculative decoding:

In [None]:
def generate_summary_report():
    """Generate a comprehensive summary of findings."""
    
    print("=" * 70)
    print("SPECULATIVE DECODING: SUMMARY REPORT")
    print("=" * 70)
    
    print("\n🎯 KEY ALGORITHM INSIGHTS:")
    print("\n1. Probabilistic Acceptance: The algorithm maintains exact equivalence")
    print("   to standard autoregressive sampling through careful rejection sampling.")
    
    print("\n2. Parallel Verification: Multiple draft tokens can be verified in a")
    print("   single target model forward pass, enabling significant speedups.")
    
    print("\n3. Graceful Degradation: Even with poor draft models, the algorithm")
    print("   maintains correctness, though a low acceptance rate can erase the speedup.")
    
    print("\n📊 PERFORMANCE CHARACTERISTICS:")
    print("\n• Typical speedups: 1.5-3x for well-matched draft/target pairs")
    print("• Acceptance rates: 30-80% depending on model similarity")
    print("• Memory overhead: draft model weights plus a second KV cache (scales with draft size)")
    print("• Optimal lookahead: 3-6 tokens for most scenarios")
    
    print("\n🔧 IMPLEMENTATION BEST PRACTICES:")
    print("\n1. Draft Model Selection:")
    print("   - Use models trained on similar data as target")
    print("   - 4-10x smaller parameter count typically optimal")
    print("   - Consider knowledge distillation for better alignment")
    
    print("\n2. Hyperparameter Tuning:")
    print("   - Lower temperatures generally improve acceptance rates")
    print("   - Lookahead distance should be tuned based on hardware")
    print("   - Top-k/top-p filtering must be applied to both models before the acceptance test")
    
    print("\n3. System Optimization:")
    print("   - Batch processing requires careful sequence alignment")
    print("   - Memory-efficient implementations crucial for deployment")
    print("   - Hardware-specific optimizations can provide additional gains")
    
    print("\n🚀 FUTURE DIRECTIONS:")
    print("\n• Adaptive lookahead based on real-time acceptance rates")
    print("• Multi-level speculative decoding with cascaded draft models")
    print("• Integration with other acceleration techniques (quantization, etc.)")
    print("• Specialized hardware designs for speculative execution")
    
    print("\n✅ CONCLUSION:")
    print("\nSpeculative decoding represents a clever probabilistic approach to")
    print("accelerating autoregressive generation. When properly implemented with")
    print("well-matched draft models, it can provide substantial speedups while")
    print("preserving the target model's output distribution. The technique is")
    print("particularly valuable for interactive applications where latency is critical.")
    
    print("\n" + "=" * 70)

generate_summary_report()
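
The equivalence claim in the report can be checked empirically on a toy vocabulary. The following is a minimal sketch, not the notebook's decoder: `speculative_sample_one` is a hypothetical helper, and the two distributions are made-up values chosen only for illustration. It accepts a draft token with probability $\min(1, p_{\text{target}} / p_{\text{draft}})$ and, on rejection, resamples from the normalized residual $\max(0, p_{\text{target}} - p_{\text{draft}})$.

```python
import numpy as np


def speculative_sample_one(p_target, p_draft, rng):
    """Draw one token via draft-then-verify; returns (token, accepted)."""
    draft_token = rng.choice(len(p_draft), p=p_draft)
    accept_prob = min(1.0, p_target[draft_token] / p_draft[draft_token])
    if rng.random() < accept_prob:
        return draft_token, True
    # On rejection, resample from the normalized residual max(0, p - q).
    residual = np.maximum(p_target - p_draft, 0.0)
    residual /= residual.sum()
    return rng.choice(len(residual), p=residual), False


rng = np.random.default_rng(0)
p_target = np.array([0.5, 0.3, 0.15, 0.05])    # illustrative values
p_draft = np.array([0.25, 0.25, 0.25, 0.25])   # illustrative values

n = 20_000
counts = np.zeros(len(p_target))
accepted = 0
for _ in range(n):
    token, ok = speculative_sample_one(p_target, p_draft, rng)
    counts[token] += 1
    accepted += ok

empirical = counts / n
print("target   :", p_target)
print("empirical:", np.round(empirical, 3))
print("acceptance rate:", accepted / n)
```

The empirical frequencies should track `p_target` even though every proposal came from `p_draft`. The acceptance rate should land near $\sum_x \min(p_{\text{target}}(x), p_{\text{draft}}(x))$, which is 0.70 for these two distributions.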

## Exercise: Implementing Improvements

Try implementing these extensions to deepen your understanding:

1. **Adaptive Lookahead**: Modify the decoder to adjust lookahead distance based on recent acceptance rates

2. **Batch Processing**: Extend the implementation to handle multiple sequences simultaneously

3. **Multiple Draft Models**: Implement a version that uses multiple draft models of different sizes

4. **Performance Profiling**: Add detailed timing and memory usage tracking

5. **Real Model Integration**: Adapt the code to work with actual pre-trained models (e.g., GPT-2 variants)

6. **Tree-based Speculation**: Implement tree-based speculative decoding for handling multiple possible futures

Each of these extensions would provide deeper insights into the practical challenges and opportunities in deploying speculative decoding in production systems.
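
As a possible starting point for the first exercise, here is a minimal sketch of a lookahead controller. The class name, thresholds, and window size are all hypothetical choices, and it is not wired into the decoder from earlier in the notebook; it only shows the feedback loop between recent acceptance and the next lookahead distance.

```python
from collections import deque


class AdaptiveLookahead:
    """Hypothetical controller that tunes lookahead from recent acceptance."""

    def __init__(self, k_init=4, k_min=1, k_max=8, window=20, low=0.4, high=0.8):
        self.k = k_init
        self.k_min, self.k_max = k_min, k_max
        self.low, self.high = low, high          # assumed thresholds
        self.history = deque(maxlen=window)      # recent (accepted, drafted) pairs

    def update(self, num_accepted: int, num_drafted: int) -> int:
        """Record one speculative step and return the next lookahead."""
        self.history.append((num_accepted, num_drafted))
        accepted = sum(a for a, _ in self.history)
        drafted = sum(d for _, d in self.history)
        rate = accepted / drafted if drafted else 0.0
        if rate > self.high:
            self.k = min(self.k + 1, self.k_max)   # drafts are cheap wins: look further
        elif rate < self.low:
            self.k = max(self.k - 1, self.k_min)   # drafts are wasted: look less far
        return self.k


# Simulated acceptance counts per step (made-up values for illustration).
ctrl = AdaptiveLookahead()
for accepted in (4, 4, 1, 0, 0):
    k = ctrl.k
    n_acc = min(accepted, k)
    print(f"drafted={k}, accepted={n_acc}, next k={ctrl.update(n_acc, k)}")
```

A windowed average is used here so that a single bad step does not immediately shrink the lookahead; an exponential moving average would be a reasonable alternative.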