# Notebook 15: Speculative Decoding (Draft-Target)

---

## Inference Engineering Course

Welcome to Notebook 15! Here we implement **Speculative Decoding**, one of the most important inference optimization techniques for large language models.

### What You Will Learn

| Topic | Description |
|-------|-------------|
| **Speculative Decoding** | The core algorithm: draft model proposes, target model verifies |
| **Draft Model** | Small, fast model that generates candidate tokens |
| **Verification** | Target model checks all drafts in a single forward pass |
| **Acceptance Rate** | How often draft tokens are accepted by the target |
| **Speedup Analysis** | When and why speculative decoding helps |

### The Core Insight

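> A small draft model generates K candidate tokens cheaply, then the large target model verifies all K tokens in a **single forward pass** (roughly the cost of generating one token when decoding is memory-bound). If every draft token is accepted, one target pass yields K+1 tokens, so the speedup can approach (K+1)x, minus the drafting overhead.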

### Key Property: Lossless Quality

Unlike pruning or quantization, speculative decoding produces **the exact same output distribution** as the target model alone. The draft model only affects speed, never quality.

---

## Part 1: Setup & Installations

We will first build understanding with simulations, then (optionally) run with real models using HuggingFace Transformers.

In [None]:
%%capture
!pip install transformers accelerate torch matplotlib numpy

In [None]:
import time
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from typing import List, Dict, Tuple, Optional
from collections import defaultdict
import warnings
warnings.filterwarnings('ignore')

plt.style.use('default')
plt.rcParams['figure.figsize'] = (12, 6)
plt.rcParams['font.size'] = 12
plt.rcParams['axes.grid'] = True
plt.rcParams['grid.alpha'] = 0.3

print("Imports complete!")

## Part 2: The Speculative Decoding Algorithm

### How It Works

```
Input: Target model M_target, Draft model M_draft, speculation length K

while not done:
    # Step 1: Draft phase (cheap)
    for i in 1..K:
        draft_token[i], draft_prob[i] = M_draft.sample(context)
        context = context + draft_token[i]
    
    # Step 2: Verification phase (one forward pass)
    target_probs[1..K+1] = M_target.forward(context_with_drafts)
    
    # Step 3: Accept/Reject
    for i in 1..K:
        if random() < min(1, target_prob[i] / draft_prob[i]):
            ACCEPT token[i]
        else:
            REJECT token[i] and all subsequent
            Sample correction token from adjusted distribution
            break
```

### Why It Works

The acceptance criterion `min(1, p_target / p_draft)` ensures:
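- Tokens the target likes at least as much as the draft does (`p_target >= p_draft`) are **always accepted**
- Tokens where the draft is overconfident (`p_target < p_draft`) are accepted only with probability `p_target / p_draft`, so a large gap means they are **often rejected**
- Combined with resampling rejected positions from the residual distribution (see Part 3), the final distribution **exactly matches** what the target model would produce alone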
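
The lossless claim can be checked empirically on a toy example. The sketch below is not part of the simulator that follows: `speculative_step` and the two three-token distributions are illustrative assumptions. It runs one accept/reject step many times and compares the emitted tokens with the target distribution.

```python
import numpy as np

def speculative_step(p_target, p_draft, rng):
    """One accept/reject step; returns the token that is finally emitted."""
    token = rng.choice(len(p_draft), p=p_draft)        # draft proposes a token
    if rng.random() < min(1.0, p_target[token] / p_draft[token]):
        return token                                   # accepted as-is
    # Rejected: resample from the normalized residual max(0, p_target - p_draft)
    residual = np.maximum(p_target - p_draft, 0.0)
    return rng.choice(len(residual), p=residual / residual.sum())

rng = np.random.default_rng(0)
p_target = np.array([0.5, 0.3, 0.2])
p_draft = np.array([0.2, 0.3, 0.5])   # deliberately mismatched draft (assumption)

samples = [speculative_step(p_target, p_draft, rng) for _ in range(20_000)]
empirical = np.bincount(samples, minlength=3) / len(samples)
print("target   :", p_target)
print("empirical:", empirical.round(3))
```

Even with a draft that ranks the tokens in the opposite order, the empirical frequencies should land close to the target probabilities. The mismatch costs acceptance rate, not correctness.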

In [None]:
class SpeculativeDecodingSimulator:
    """
    Simulates speculative decoding with configurable parameters.
    
    This uses probability distributions rather than real models,
    allowing us to study the algorithm's behavior in detail.
    """
    
    def __init__(
        self,
        vocab_size: int = 100,
        draft_latency_ms: float = 5.0,    # Time for one draft token
        target_latency_ms: float = 50.0,   # Time for one target forward pass
        agreement_level: float = 0.7,       # How similar draft & target are
    ):
        self.vocab_size = vocab_size
        self.draft_latency = draft_latency_ms / 1000.0
        self.target_latency = target_latency_ms / 1000.0
        self.agreement_level = agreement_level
        
        # Generate base distributions for the "language"
        # Using Zipf's law to simulate natural language token frequencies
        self.base_dist = np.array([1.0 / (i + 1) for i in range(vocab_size)])
        self.base_dist /= self.base_dist.sum()
    
    def _get_target_distribution(self, position: int) -> np.ndarray:
        """Get the target model's distribution at a given position."""
        # Use a local, position-seeded generator so the distribution is reproducible
        # without reseeding the global RNG that drives token sampling and acceptance.
        rng = np.random.default_rng(position * 7 + 13)
        noise = rng.dirichlet(np.ones(self.vocab_size) * 0.5)
        dist = 0.7 * self.base_dist + 0.3 * noise
        return dist / dist.sum()
    
    def _get_draft_distribution(self, position: int) -> np.ndarray:
        """Get the draft model's distribution (correlated with target)."""
        target_dist = self._get_target_distribution(position)
        rng = np.random.default_rng(position * 11 + 29)  # Local RNG, same reason as above
        noise = rng.dirichlet(np.ones(self.vocab_size) * 0.5)
        
        # Mix target distribution with noise (agreement_level controls similarity)
        draft_dist = self.agreement_level * target_dist + (1 - self.agreement_level) * noise
        return draft_dist / draft_dist.sum()
    
    def standard_decoding(self, num_tokens: int) -> Dict:
        """Standard autoregressive decoding."""
        start_time = time.time()
        tokens = []
        forward_passes = 0
        
        for pos in range(num_tokens):
            time.sleep(self.target_latency)
            target_dist = self._get_target_distribution(pos)
            token = np.random.choice(self.vocab_size, p=target_dist)
            tokens.append(token)
            forward_passes += 1
        
        elapsed = time.time() - start_time
        return {
            'method': 'Standard',
            'tokens': len(tokens),
            'forward_passes': forward_passes,
            'time_s': round(elapsed, 3),
            'tok_per_s': round(len(tokens) / elapsed, 1),
        }
    
    def speculative_decoding(self, num_tokens: int, K: int = 4) -> Dict:
        """
        Speculative decoding with draft-target verification.
        
        Args:
            num_tokens: Number of tokens to generate
            K: Number of draft tokens per speculation round
        """
        start_time = time.time()
        tokens = []
        draft_forward_passes = 0
        target_forward_passes = 0
        total_drafted = 0
        total_accepted = 0
        acceptance_history = []  # Track per-round acceptance
        
        pos = 0
        while len(tokens) < num_tokens:
            # === DRAFT PHASE ===
            draft_tokens = []
            draft_probs = []
            
            for i in range(K):
                time.sleep(self.draft_latency)  # Draft is fast
                draft_dist = self._get_draft_distribution(pos + i)
                token = np.random.choice(self.vocab_size, p=draft_dist)
                draft_tokens.append(token)
                draft_probs.append(draft_dist[token])
                draft_forward_passes += 1
            
            # === VERIFICATION PHASE (single forward pass) ===
            time.sleep(self.target_latency)  # One target forward pass
            target_forward_passes += 1
            
            accepted_this_round = 0
            for i in range(K):
                target_dist = self._get_target_distribution(pos + i)
                target_prob = target_dist[draft_tokens[i]]
                draft_prob = draft_probs[i]
                
                # Acceptance criterion
                acceptance_ratio = min(1.0, target_prob / max(draft_prob, 1e-10))
                
                if np.random.random() < acceptance_ratio:
                    # ACCEPT
                    tokens.append(draft_tokens[i])
                    accepted_this_round += 1
                    total_accepted += 1
                else:
                    # REJECT - sample from corrected distribution
                    correction = np.maximum(target_dist - self._get_draft_distribution(pos + i), 0)
                    if correction.sum() > 0:
                        correction /= correction.sum()
                        corrected_token = np.random.choice(self.vocab_size, p=correction)
                    else:
                        corrected_token = np.random.choice(self.vocab_size, p=target_dist)
                    tokens.append(corrected_token)
                    # The correction token is not counted as an accepted draft token
                    break
            
            # If all K were accepted, sample one more from target
            if accepted_this_round == K and len(tokens) < num_tokens:
                target_dist = self._get_target_distribution(pos + K)
                bonus_token = np.random.choice(self.vocab_size, p=target_dist)
                tokens.append(bonus_token)
            
            total_drafted += K
            acceptance_history.append(accepted_this_round / K)
            pos += accepted_this_round + 1
        
        elapsed = time.time() - start_time
        tokens = tokens[:num_tokens]  # Trim to exact count
        
        return {
            'method': f'Speculative (K={K})',
            'tokens': len(tokens),
            'draft_forward_passes': draft_forward_passes,
            'target_forward_passes': target_forward_passes,
            'total_forward_passes': draft_forward_passes + target_forward_passes,
            'time_s': round(elapsed, 3),
            'tok_per_s': round(len(tokens) / elapsed, 1),
            'total_drafted': total_drafted,
            'total_accepted': total_accepted,
            'acceptance_rate': round(total_accepted / total_drafted * 100, 1) if total_drafted > 0 else 0,
            'acceptance_history': acceptance_history,
        }

print("SpeculativeDecodingSimulator defined!")

In [None]:
# Run a basic comparison
sim = SpeculativeDecodingSimulator(
    vocab_size=100,
    draft_latency_ms=5.0,
    target_latency_ms=40.0,
    agreement_level=0.75,
)

NUM_TOKENS = 30
print(f"Generating {NUM_TOKENS} tokens...")
print("=" * 70)

# Standard decoding
r_std = sim.standard_decoding(NUM_TOKENS)
print(f"\nSTANDARD DECODING:")
print(f"  Time: {r_std['time_s']}s")
print(f"  Forward passes: {r_std['forward_passes']}")
print(f"  Throughput: {r_std['tok_per_s']} tok/s")

# Speculative decoding with different K values
for K in [2, 4, 6]:
    r_spec = sim.speculative_decoding(NUM_TOKENS, K=K)
    speedup = r_std['time_s'] / r_spec['time_s'] if r_spec['time_s'] > 0 else 0
    print(f"\nSPECULATIVE (K={K}):")
    print(f"  Time: {r_spec['time_s']}s")
    print(f"  Target FW passes: {r_spec['target_forward_passes']}")
    print(f"  Draft FW passes: {r_spec['draft_forward_passes']}")
    print(f"  Acceptance rate: {r_spec['acceptance_rate']}%")
    print(f"  Throughput: {r_spec['tok_per_s']} tok/s")
    print(f"  SPEEDUP: {speedup:.2f}x")

## Part 3: Understanding the Acceptance Criterion

The acceptance probability is:

$$P(\text{accept}) = \min\left(1, \frac{p_{\text{target}}(x)}{p_{\text{draft}}(x)}\right)$$

This means:
- If `p_target >= p_draft`: **always accept** (target agrees or likes the token even more)
- If `p_target < p_draft`: accept with probability `p_target / p_draft` (draft is overconfident)

When a token is rejected, we sample from the **residual distribution**:

$$p_{\text{correction}}(x) \propto \max(0, p_{\text{target}}(x) - p_{\text{draft}}(x))$$

This ensures the combined process produces the exact target distribution.
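
A short sketch of why this holds, abbreviating `p = p_target` and `q = p_draft` and writing `beta` for the overall acceptance probability, which equals the sum over tokens of `min(p(x), q(x))`. The normalizer of the residual is then `1 - beta`, because `max(0, p - q) = p - min(p, q)` and `p` sums to 1:

$$\begin{aligned}
P(\text{emit } x) &= \underbrace{q(x)\,\min\left(1, \frac{p(x)}{q(x)}\right)}_{\text{drafted and accepted}} + \underbrace{(1 - \beta)\,\frac{\max(0,\, p(x) - q(x))}{\sum_{x'} \max(0,\, p(x') - q(x'))}}_{\text{rejected, then resampled}} \\
&= \min(p(x), q(x)) + \max(0,\, p(x) - q(x)) \\
&= p(x)
\end{aligned}$$

The first term is the probability of reaching `x` through an accepted draft. The second is the probability of any rejection, `1 - beta`, times the probability that the residual then picks `x`.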

In [None]:
# Visualize the acceptance criterion
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

vocab = ['the', 'cat', 'sat', 'on', 'mat', 'dog', 'ran', 'big', 'red', 'and']

# Scenario 1: High agreement (good draft model)
target_probs_1 = np.array([0.25, 0.20, 0.15, 0.12, 0.10, 0.06, 0.05, 0.03, 0.02, 0.02])
draft_probs_1 = np.array([0.22, 0.18, 0.16, 0.11, 0.11, 0.07, 0.06, 0.04, 0.03, 0.02])

# Scenario 2: Medium agreement
target_probs_2 = np.array([0.25, 0.20, 0.15, 0.12, 0.10, 0.06, 0.05, 0.03, 0.02, 0.02])
draft_probs_2 = np.array([0.10, 0.10, 0.10, 0.10, 0.10, 0.10, 0.10, 0.10, 0.10, 0.10])

# Scenario 3: Poor agreement (bad draft model)
target_probs_3 = np.array([0.25, 0.20, 0.15, 0.12, 0.10, 0.06, 0.05, 0.03, 0.02, 0.02])
draft_probs_3 = np.array([0.02, 0.03, 0.05, 0.07, 0.10, 0.13, 0.15, 0.17, 0.14, 0.14])

scenarios = [
    (target_probs_1, draft_probs_1, 'High Agreement\n(Good Draft Model)'),
    (target_probs_2, draft_probs_2, 'Medium Agreement\n(Uniform Draft)'),
    (target_probs_3, draft_probs_3, 'Low Agreement\n(Mismatched Draft)'),
]

for ax, (t_probs, d_probs, title) in zip(axes, scenarios):
    x = np.arange(len(vocab))
    width = 0.35
    
    ax.bar(x - width/2, t_probs, width, label='Target', color='#2196F3', alpha=0.8)
    ax.bar(x + width/2, d_probs, width, label='Draft', color='#FF9800', alpha=0.8)
    
    # Expected acceptance rate: sum_x q(x) * min(1, p(x)/q(x)) = sum_x min(p(x), q(x))
    expected_accept = np.minimum(t_probs, d_probs).sum()
    
    ax.set_title(f'{title}\nE[accept] = {expected_accept:.1%}', fontsize=12, fontweight='bold')
    ax.set_xticks(x)
    ax.set_xticklabels(vocab, rotation=45, fontsize=9)
    ax.set_ylabel('Probability', fontsize=11)
    ax.legend(fontsize=10)
    ax.set_ylim(0, 0.35)

plt.suptitle('Draft vs Target Distributions: Impact on Acceptance Rate',
             fontsize=14, fontweight='bold', y=1.05)
plt.tight_layout()
plt.show()

## Part 4: Speedup vs Acceptance Rate Tradeoff

The speedup of speculative decoding depends on:

1. **Acceptance rate (alpha)**: The probability that a single draft token is accepted; higher is better
2. **Speculation length (K)**: More tokens per round
3. **Draft-to-target cost ratio (c)**: `cost_draft / cost_target`

### Theoretical Speedup Formula

For per-token acceptance rate `alpha` (assumed independent across positions) and speculation length `K`:

$$\text{Speedup} = \frac{\text{Expected tokens per round}}{\text{Cost per round}} = \frac{1 - \alpha^{K+1}}{(1 - \alpha)(K \cdot c + 1)}$$

where `c` is the ratio of draft cost to target cost, so the round cost `K * c + 1` is measured in units of one target forward pass.
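
The numerator is the geometric sum `1 + alpha + ... + alpha^K`: a round always emits one token (correction or bonus) plus however many drafts are accepted before the first rejection. As a sanity check, the sketch below simulates rounds with independent accept/reject coin flips and compares the average with the closed form. The helper name `simulate_tokens_per_round` and the values `alpha = 0.8`, `K = 4`, `c = 0.1` are illustrative assumptions.

```python
import numpy as np

def simulate_tokens_per_round(alpha, K, rounds, rng):
    """Average tokens per round when each draft is accepted i.i.d. with prob. alpha."""
    accepts = rng.random((rounds, K)) < alpha          # True = draft token accepted
    # cumprod stays 1 until the first rejection, so its row sum is the
    # length of the leading run of accepted drafts
    n_accepted = np.cumprod(accepts, axis=1).sum(axis=1)
    return (n_accepted + 1).mean()                     # +1: correction or bonus token

alpha, K, c = 0.8, 4, 0.1
rng = np.random.default_rng(0)
simulated = simulate_tokens_per_round(alpha, K, 100_000, rng)
closed_form = (1 - alpha ** (K + 1)) / (1 - alpha)
speedup = closed_form / (K * c + 1)

print(f"simulated tokens/round:   {simulated:.3f}")
print(f"closed-form tokens/round: {closed_form:.3f}")
print(f"theoretical speedup:      {speedup:.2f}x")
```

For these values the closed form gives about 3.36 tokens per round at a cost of 1.4 target passes, so about 2.4x.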

In [None]:
def theoretical_speedup(alpha: float, K: int, cost_ratio: float) -> float:
    """
    Calculate theoretical speedup of speculative decoding.
    
    Args:
        alpha: Acceptance rate (0 to 1)
        K: Number of draft tokens per round
        cost_ratio: draft_cost / target_cost (typically 0.05-0.2)
    """
    if alpha >= 0.9999:
        return (K + 1) / (K * cost_ratio + 1)
    
    expected_tokens = (1 - alpha**(K + 1)) / (1 - alpha)
    cost_per_round = K * cost_ratio + 1  # K draft passes + 1 target pass
    return expected_tokens / cost_per_round

# Plot speedup surface
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

alphas = np.linspace(0.01, 0.99, 100)

# Plot 1: Speedup vs alpha for different K
ax = axes[0]
cost_ratio = 0.1
for K in [1, 2, 4, 6, 8]:
    speedups = [theoretical_speedup(a, K, cost_ratio) for a in alphas]
    ax.plot(alphas * 100, speedups, linewidth=2.5, label=f'K={K}')

ax.axhline(y=1, color='gray', linestyle='--', alpha=0.5, label='Baseline (1x)')
ax.set_xlabel('Acceptance Rate (%)', fontsize=12)
ax.set_ylabel('Speedup (x)', fontsize=12)
ax.set_title(f'Speedup vs Acceptance Rate\n(cost ratio = {cost_ratio})', fontsize=13, fontweight='bold')
ax.legend(fontsize=10)
ax.set_ylim(0, 8)

# Plot 2: Speedup vs K for different acceptance rates
ax = axes[1]
K_values = range(1, 13)
cost_ratio = 0.1
for alpha in [0.3, 0.5, 0.7, 0.85, 0.95]:
    speedups = [theoretical_speedup(alpha, K, cost_ratio) for K in K_values]
    ax.plot(list(K_values), speedups, 'o-', linewidth=2, markersize=6, 
            label=f'alpha={alpha:.0%}')

ax.axhline(y=1, color='gray', linestyle='--', alpha=0.5)
ax.set_xlabel('Speculation Length (K)', fontsize=12)
ax.set_ylabel('Speedup (x)', fontsize=12)
ax.set_title(f'Speedup vs K\n(cost ratio = {cost_ratio})', fontsize=13, fontweight='bold')
ax.legend(fontsize=10)

# Plot 3: Speedup vs cost ratio
ax = axes[2]
cost_ratios = np.linspace(0.01, 0.5, 100)
K = 4
for alpha in [0.5, 0.7, 0.85, 0.95]:
    speedups = [theoretical_speedup(alpha, K, c) for c in cost_ratios]
    ax.plot(cost_ratios, speedups, linewidth=2.5, label=f'alpha={alpha:.0%}')

ax.axhline(y=1, color='gray', linestyle='--', alpha=0.5)
ax.set_xlabel('Cost Ratio (draft/target)', fontsize=12)
ax.set_ylabel('Speedup (x)', fontsize=12)
ax.set_title(f'Speedup vs Cost Ratio\n(K={K})', fontsize=13, fontweight='bold')
ax.legend(fontsize=10)

plt.tight_layout()
plt.show()

print("Key insights:")
print("1. Higher acceptance rate -> higher speedup (obvious but important)")
print("2. Optimal K depends on acceptance rate (higher alpha allows larger K)")
print("3. Smaller draft model (lower cost ratio) enables more speculation")

## Part 5: Detailed Simulation -- Step by Step

Let's trace through speculative decoding step by step to see exactly what happens at each round.

In [None]:
def detailed_speculative_trace(num_rounds: int = 8, K: int = 4, 
                                agreement: float = 0.75):
    """
    Run speculative decoding with detailed per-round tracing.
    """
    np.random.seed(42)
    vocab_size = 50
    
    rounds = []
    total_tokens = 0
    total_target_passes = 0
    
    for round_idx in range(num_rounds):
        # Generate target and draft distributions for each position
        accepted_in_round = 0
        draft_tokens = []
        statuses = []  # 'accepted', 'rejected', 'correction'
        
        for i in range(K):
            pos = total_tokens + i
            
            # Create correlated distributions
            base = np.random.dirichlet(np.ones(vocab_size) * 0.3)
            target_dist = base.copy()
            noise = np.random.dirichlet(np.ones(vocab_size) * 0.5)
            draft_dist = agreement * target_dist + (1 - agreement) * noise
            draft_dist /= draft_dist.sum()
            
            # Draft samples
            token = np.random.choice(vocab_size, p=draft_dist)
            draft_tokens.append(token)
            
            # Acceptance check
            accept_prob = min(1.0, target_dist[token] / max(draft_dist[token], 1e-10))
            
            if np.random.random() < accept_prob:
                statuses.append('accepted')
                accepted_in_round += 1
            else:
                statuses.append('rejected')
                # All subsequent are not evaluated
                for j in range(i + 1, K):
                    statuses.append('skipped')
                    draft_tokens.append(-1)
                statuses.append('correction')  # Correction token
                break
        
        if accepted_in_round == K:
            statuses.append('bonus')  # Bonus token from target
        
        tokens_this_round = accepted_in_round + 1  # +1 for correction or bonus
        total_tokens += tokens_this_round
        total_target_passes += 1
        
        rounds.append({
            'round': round_idx + 1,
            'drafted': K,
            'accepted': accepted_in_round,
            'tokens_generated': tokens_this_round,
            'statuses': statuses,
        })
    
    return rounds, total_tokens, total_target_passes

# Run trace
rounds, total_tok, total_passes = detailed_speculative_trace(num_rounds=10, K=4)

print("=" * 80)
print(f"{'Round':>6} {'Drafted':>8} {'Accepted':>9} {'Generated':>10}  Status Sequence")
print("=" * 80)

for r in rounds:
    status_str = ' '.join([s[0].upper() for s in r['statuses']])
    print(f"{r['round']:>6} {r['drafted']:>8} {r['accepted']:>9} {r['tokens_generated']:>10}  [{status_str}]")

print("=" * 80)
print(f"Total tokens: {total_tok}")
print(f"Total target forward passes: {total_passes}")
print(f"Effective tokens per target pass: {total_tok / total_passes:.1f}")
print(f"\nLegend: A=Accepted, R=Rejected, S=Skipped, C=Correction, B=Bonus")

In [None]:
# Visualize the speculation trace
fig, ax = plt.subplots(figsize=(16, 6))

color_map = {
    'accepted': '#4CAF50',
    'rejected': '#F44336',
    'skipped': '#BDBDBD',
    'correction': '#FF9800',
    'bonus': '#2196F3',
}

token_pos = 0
for r in rounds:
    round_idx = r['round'] - 1
    for i, status in enumerate(r['statuses']):
        color = color_map.get(status, '#BDBDBD')
        marker = 's' if status == 'accepted' else ('x' if status in ('rejected', 'skipped') else 'D')
        size = 150 if status != 'skipped' else 80
        ax.scatter(token_pos + i, round_idx, c=color, s=size, marker=marker,
                  edgecolors='black' if status != 'skipped' else 'gray',
                  linewidths=1.5, zorder=5)
    token_pos += r['tokens_generated']

# Legend
legend_elements = [
    mpatches.Patch(color='#4CAF50', label='Accepted (free!)'),
    mpatches.Patch(color='#F44336', label='Rejected'),
    mpatches.Patch(color='#BDBDBD', label='Skipped'),
    mpatches.Patch(color='#FF9800', label='Correction token'),
    mpatches.Patch(color='#2196F3', label='Bonus token'),
]
ax.legend(handles=legend_elements, loc='upper right', fontsize=10)

ax.set_xlabel('Token Position in Sequence', fontsize=12)
ax.set_ylabel('Speculation Round', fontsize=12)
ax.set_title('Speculative Decoding Trace: Token Status per Round',
             fontsize=14, fontweight='bold')
ax.invert_yaxis()
ax.set_yticks(range(len(rounds)))
ax.set_yticklabels([f'Round {r["round"]}' for r in rounds])

plt.tight_layout()
plt.show()

## Part 6: Comprehensive Benchmarking

Let's benchmark speculative decoding across different configurations to find optimal settings.

In [None]:
# Benchmark across different agreement levels and K values
agreement_levels = [0.3, 0.5, 0.7, 0.85, 0.95]
K_values = [1, 2, 3, 4, 5, 6, 8]
NUM_TOKENS = 20

results_grid = np.zeros((len(agreement_levels), len(K_values)))
acceptance_grid = np.zeros((len(agreement_levels), len(K_values)))

print("Running benchmark (this takes ~1-2 minutes)...")
for i, agreement in enumerate(agreement_levels):
    sim = SpeculativeDecodingSimulator(
        vocab_size=100,
        draft_latency_ms=3.0,
        target_latency_ms=30.0,
        agreement_level=agreement,
    )
    
    # Baseline
    r_std = sim.standard_decoding(NUM_TOKENS)
    baseline_time = r_std['time_s']
    
    for j, K in enumerate(K_values):
        r_spec = sim.speculative_decoding(NUM_TOKENS, K=K)
        speedup = baseline_time / r_spec['time_s'] if r_spec['time_s'] > 0 else 1.0
        results_grid[i, j] = speedup
        acceptance_grid[i, j] = r_spec['acceptance_rate']
    
    print(f"  Agreement {agreement:.0%} done.")

print("Benchmark complete!")

In [None]:
# Heatmap of speedups
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# Speedup heatmap
ax = axes[0]
im = ax.imshow(results_grid, cmap='RdYlGn', aspect='auto', vmin=0.5, vmax=4.0)
ax.set_xticks(range(len(K_values)))
ax.set_xticklabels([str(k) for k in K_values])
ax.set_yticks(range(len(agreement_levels)))
ax.set_yticklabels([f'{a:.0%}' for a in agreement_levels])
ax.set_xlabel('Speculation Length (K)', fontsize=12)
ax.set_ylabel('Draft-Target Agreement', fontsize=12)
ax.set_title('Speedup (x)', fontsize=14, fontweight='bold')

# Add text annotations
for i in range(len(agreement_levels)):
    for j in range(len(K_values)):
        text = f'{results_grid[i, j]:.1f}x'
        color = 'white' if results_grid[i, j] > 2.5 or results_grid[i, j] < 1.0 else 'black'
        ax.text(j, i, text, ha='center', va='center', fontsize=9, 
                fontweight='bold', color=color)

plt.colorbar(im, ax=ax, shrink=0.8)

# Acceptance rate heatmap
ax = axes[1]
im2 = ax.imshow(acceptance_grid, cmap='Blues', aspect='auto', vmin=0, vmax=100)
ax.set_xticks(range(len(K_values)))
ax.set_xticklabels([str(k) for k in K_values])
ax.set_yticks(range(len(agreement_levels)))
ax.set_yticklabels([f'{a:.0%}' for a in agreement_levels])
ax.set_xlabel('Speculation Length (K)', fontsize=12)
ax.set_ylabel('Draft-Target Agreement', fontsize=12)
ax.set_title('Acceptance Rate (%)', fontsize=14, fontweight='bold')

for i in range(len(agreement_levels)):
    for j in range(len(K_values)):
        text = f'{acceptance_grid[i, j]:.0f}%'
        color = 'white' if acceptance_grid[i, j] > 60 else 'black'
        ax.text(j, i, text, ha='center', va='center', fontsize=9, 
                fontweight='bold', color=color)

plt.colorbar(im2, ax=ax, shrink=0.8)

plt.suptitle('Speculative Decoding: Speedup & Acceptance Rate Grid',
             fontsize=15, fontweight='bold', y=1.02)
plt.tight_layout()
plt.show()

## Part 7: Optimal K Selection

A critical question: **how many draft tokens (K) should we generate per round?**

- Too few: not enough speculation benefit
- Too many: wasting draft compute on tokens that will be rejected

The optimal K depends on the acceptance rate and the cost ratio.
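
A related question is when speculation pays off at all. Under the same independence assumption as the speedup formula, the sketch below finds the break-even acceptance rate at which the theoretical speedup reaches 1x. The helpers `speedup` and `break_even_alpha` are hypothetical names introduced for illustration, and the search assumes a cost ratio below 1.

```python
def speedup(alpha: float, K: int, c: float) -> float:
    """Expected tokens per round divided by round cost (in target-pass units)."""
    expected_tokens = sum(alpha ** i for i in range(K + 1))  # 1 + alpha + ... + alpha^K
    return expected_tokens / (K * c + 1)

def break_even_alpha(K: int, c: float, tol: float = 1e-9) -> float:
    """Smallest alpha with speedup >= 1, via bisection (speedup increases with alpha)."""
    lo, hi = 0.0, 1.0   # speedup(0) < 1 and speedup(1) > 1 whenever 0 < c < 1
    while hi - lo > tol:
        mid = (lo + hi) / 2
        if speedup(mid, K, c) < 1.0:
            lo = mid
        else:
            hi = mid
    return hi

for K in (1, 2, 4, 8):
    print(f"K={K}: speculation helps once alpha exceeds {break_even_alpha(K, c=0.1):.3f}")
```

For `K = 1` the condition reduces to `alpha > c`, since the speedup is `(1 + alpha) / (1 + c)`. A single draft token is worth generating only if it is accepted more often than its relative cost.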

In [None]:
# Find optimal K for different scenarios
fig, ax = plt.subplots(figsize=(12, 7))

cost_ratios = [0.05, 0.1, 0.15, 0.2]
alpha_range = np.linspace(0.1, 0.98, 50)
K_range = range(1, 16)

for c_idx, cost_ratio in enumerate(cost_ratios):
    optimal_Ks = []
    optimal_speedups = []
    
    for alpha in alpha_range:
        best_K = 1
        best_speedup = 0
        for K in K_range:
            sp = theoretical_speedup(alpha, K, cost_ratio)
            if sp > best_speedup:
                best_speedup = sp
                best_K = K
        optimal_Ks.append(best_K)
        optimal_speedups.append(best_speedup)
    
    ax.plot(alpha_range * 100, optimal_Ks, 'o-', markersize=4, linewidth=2,
            label=f'cost ratio = {cost_ratio}')

ax.set_xlabel('Acceptance Rate (%)', fontsize=12)
ax.set_ylabel('Optimal K', fontsize=12)
ax.set_title('Optimal Speculation Length (K) vs Acceptance Rate',
             fontsize=14, fontweight='bold')
ax.legend(fontsize=11)
ax.set_ylim(0, 16)

plt.tight_layout()
plt.show()

print("Key insight: As acceptance rate increases, it becomes profitable")
print("to speculate more tokens per round. A cheaper draft model (lower")
print("cost ratio) also allows more aggressive speculation.")

## Part 8: Real Model Speculative Decoding (Optional)

If you have a GPU available on Colab (Runtime -> Change runtime type -> T4 GPU), you can run speculative decoding with real HuggingFace models.

We use GPT-2 small as the draft model and GPT-2 medium as the target model.
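
The real-model cells use greedy decoding, where verification reduces to comparing the draft's tokens with the target's argmax at each position. The sketch below shows that comparison on plain Python lists, without any model. `greedy_verify` is a hypothetical helper, not a Transformers API, and unlike the simplified loop in this notebook it also takes the bonus token when every draft matches.

```python
from typing import List, Tuple

def greedy_verify(draft_tokens: List[int], target_argmax: List[int]) -> Tuple[int, List[int]]:
    """
    draft_tokens:  K tokens proposed by the draft model.
    target_argmax: K+1 greedy target predictions; entry i is the target's choice
                   given the prompt plus draft_tokens[:i].
    Returns (number of draft tokens accepted, tokens to append this round).
    """
    emitted = []
    for i, drafted in enumerate(draft_tokens):
        if drafted != target_argmax[i]:
            emitted.append(target_argmax[i])   # first mismatch: take the target's token
            return i, emitted
        emitted.append(drafted)                # match: keep the draft token
    emitted.append(target_argmax[len(draft_tokens)])  # all matched: bonus token
    return len(draft_tokens), emitted

print(greedy_verify([5, 9, 2, 7], [5, 9, 4, 1, 3]))   # mismatch at index 2
print(greedy_verify([5, 9, 2, 7], [5, 9, 2, 7, 3]))   # all four drafts accepted
```

Every round appends at least one token, so the loop always makes progress even when the first draft token is wrong.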

In [None]:
# Check GPU availability
import torch
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Device: {device}")

if device == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    print("No GPU available. The real model section will be skipped.")
    print("To enable GPU: Runtime -> Change runtime type -> T4 GPU")

In [None]:
# Real model speculative decoding (only runs if GPU available)
if device == 'cuda':
    from transformers import AutoModelForCausalLM, AutoTokenizer
    import torch.nn.functional as F
    
    print("Loading models...")
    # Draft model: GPT-2 small (~124M params)
    draft_tokenizer = AutoTokenizer.from_pretrained('gpt2')
    draft_model = AutoModelForCausalLM.from_pretrained('gpt2').to(device)
    draft_model.eval()
    
    # Target model: GPT-2 medium (~355M params)
    target_tokenizer = AutoTokenizer.from_pretrained('gpt2-medium')
    target_model = AutoModelForCausalLM.from_pretrained('gpt2-medium').to(device)
    target_model.eval()
    
    print(f"Draft model: GPT-2 ({sum(p.numel() for p in draft_model.parameters()) / 1e6:.0f}M params)")
    print(f"Target model: GPT-2 Medium ({sum(p.numel() for p in target_model.parameters()) / 1e6:.0f}M params)")
    print("Models loaded!")
else:
    print("Skipping real model loading (no GPU).")
    print("The simulation results above demonstrate the same concepts.")

In [None]:
if device == 'cuda':
    @torch.no_grad()
    def real_standard_decoding(prompt: str, max_new_tokens: int = 50) -> Dict:
        """Standard autoregressive decoding with GPT-2 Medium."""
        input_ids = target_tokenizer.encode(prompt, return_tensors='pt').to(device)
        
        start_time = time.time()
        generated = target_model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # Greedy for fair comparison
        )
        elapsed = time.time() - start_time
        
        new_tokens = generated.shape[1] - input_ids.shape[1]
        text = target_tokenizer.decode(generated[0], skip_special_tokens=True)
        
        return {
            'method': 'Standard (GPT-2 Medium)',
            'new_tokens': new_tokens,
            'time_s': round(elapsed, 3),
            'tok_per_s': round(new_tokens / elapsed, 1),
            'text': text,
        }
    
    @torch.no_grad()
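    def real_speculative_decoding(prompt: str, max_new_tokens: int = 50, K: int = 4) -> Dict:
        """
        Simple greedy speculative decoding with GPT-2 small (draft) and medium (target).
        
        Note: both models re-run the full sequence on every call here (no KV cache),
        while `generate()` in the baseline caches keys/values by default, so the
        measured speedup is pessimistic.
        """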
        input_ids = target_tokenizer.encode(prompt, return_tensors='pt').to(device)
        current_ids = input_ids.clone()
        
        start_time = time.time()
        total_accepted = 0
        total_drafted = 0
        new_tokens = 0
        
        while new_tokens < max_new_tokens:
            # Draft phase: generate K tokens with small model
            draft_ids = current_ids.clone()
            draft_tokens = []
            draft_probs_list = []
            
            for _ in range(K):
                outputs = draft_model(draft_ids)
                logits = outputs.logits[:, -1, :]
                probs = F.softmax(logits, dim=-1)
                token = torch.argmax(probs, dim=-1)  # Greedy
                draft_tokens.append(token.item())
                draft_probs_list.append(probs[0, token.item()].item())
                draft_ids = torch.cat([draft_ids, token.unsqueeze(0)], dim=1)
            
            # Verification: single forward pass through target model
            target_outputs = target_model(draft_ids)
            target_logits = target_outputs.logits
            
            # Check each draft token against the target's greedy choice
            matched = 0   # Draft tokens the target agreed with
            accepted = 0  # Tokens appended this round (matches plus one correction)
            for i in range(K):
                # Logits at index (len(current_ids) - 1 + i) predict draft position i
                target_probs = F.softmax(target_logits[:, current_ids.shape[1] - 1 + i, :], dim=-1)
                target_token = torch.argmax(target_probs, dim=-1).item()
                accepted += 1  # Every checked position yields exactly one token
                
                if target_token == draft_tokens[i]:
                    matched += 1
                else:
                    # Use target's token instead and stop verifying
                    draft_tokens[i] = target_token
                    break
            
            # Do not overshoot the requested length
            accepted = min(accepted, max_new_tokens - new_tokens)
            
            total_drafted += K
            total_accepted += matched  # Corrections are not counted as acceptances
            
            # Add accepted tokens
            new_token_ids = torch.tensor([draft_tokens[:accepted]], device=device)
            current_ids = torch.cat([current_ids, new_token_ids], dim=1)
            new_tokens += accepted
        
        elapsed = time.time() - start_time
        text = target_tokenizer.decode(current_ids[0], skip_special_tokens=True)
        
        return {
            'method': f'Speculative (K={K})',
            'new_tokens': new_tokens,
            'time_s': round(elapsed, 3),
            'tok_per_s': round(new_tokens / elapsed, 1),
            'acceptance_rate': round(total_accepted / total_drafted * 100, 1),
            'text': text,
        }
    
    # Run comparison
    prompt = "The future of artificial intelligence is"
    print(f"Prompt: '{prompt}'\n")
    
    r_std = real_standard_decoding(prompt, max_new_tokens=40)
    print(f"STANDARD: {r_std['tok_per_s']} tok/s ({r_std['time_s']}s)")
    print(f"  Text: {r_std['text'][:200]}...\n")
    
    for K in [3, 5]:
        r_spec = real_speculative_decoding(prompt, max_new_tokens=40, K=K)
        speedup = r_std['time_s'] / r_spec['time_s'] if r_spec['time_s'] > 0 else 0
        print(f"SPECULATIVE (K={K}): {r_spec['tok_per_s']} tok/s ({r_spec['time_s']}s) | Accept: {r_spec['acceptance_rate']}% | Speedup: {speedup:.2f}x")
        print(f"  Text: {r_spec['text'][:200]}...\n")
else:
    print("GPU not available -- real model benchmarks skipped.")
    print("The simulation results above demonstrate the same principles.")

## Part 9: Visualizing Speedup vs Acceptance Rate Tradeoff

Let's create a comprehensive visualization showing the relationship between acceptance rate and achieved speedup.

In [None]:
# Run many simulations to get empirical data points
np.random.seed(42)
empirical_data = []

for agreement in np.linspace(0.2, 0.98, 15):
    for K in [2, 3, 4, 5, 6]:
        for cost_ratio_mult in [0.5, 1.0, 1.5]:
            sim = SpeculativeDecodingSimulator(
                vocab_size=100,
                draft_latency_ms=3.0 * cost_ratio_mult,
                target_latency_ms=30.0,
                agreement_level=agreement,
            )
            r_std = sim.standard_decoding(15)
            r_spec = sim.speculative_decoding(15, K=K)
            speedup = r_std['time_s'] / r_spec['time_s'] if r_spec['time_s'] > 0 else 1.0
            empirical_data.append({
                'acceptance_rate': r_spec['acceptance_rate'],
                'speedup': speedup,
                'K': K,
                'cost_ratio': 3.0 * cost_ratio_mult / 30.0,
            })

# Plot
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# Left: Scatter plot colored by K
ax = axes[0]
for K in [2, 3, 4, 5, 6]:
    points = [d for d in empirical_data if d['K'] == K]
    rates = [p['acceptance_rate'] for p in points]
    speedups = [p['speedup'] for p in points]
    ax.scatter(rates, speedups, s=60, alpha=0.6, label=f'K={K}')

ax.axhline(y=1, color='gray', linestyle='--', alpha=0.5, label='No speedup')
ax.set_xlabel('Acceptance Rate (%)', fontsize=12)
ax.set_ylabel('Speedup (x)', fontsize=12)
ax.set_title('Empirical: Speedup vs Acceptance Rate', fontsize=14, fontweight='bold')
ax.legend(fontsize=10)

# Right: Show the "sweet spot"
ax = axes[1]
acceptance_bins = np.arange(0, 105, 10)
for K in [2, 4, 6]:
    points = [d for d in empirical_data if d['K'] == K]
    bin_speedups = []
    bin_centers = []
    for b_start in acceptance_bins[:-1]:
        b_end = b_start + 10
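        # Close the last bin on the right so runs with exactly 100% acceptance are not dropped
        bin_pts = [p['speedup'] for p in points
                   if p['acceptance_rate'] >= b_start
                   and (p['acceptance_rate'] < b_end or b_end >= 100)]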
        if bin_pts:
            bin_speedups.append(np.mean(bin_pts))
            bin_centers.append(b_start + 5)
    
    ax.plot(bin_centers, bin_speedups, 'o-', linewidth=2.5, markersize=8, label=f'K={K}')

ax.axhline(y=1, color='gray', linestyle='--', alpha=0.5)
ax.fill_between([50, 90], [0, 0], [5, 5], alpha=0.1, color='green')
ax.text(70, 0.3, 'Sweet Spot', ha='center', fontsize=12, color='green', fontweight='bold')
ax.set_xlabel('Acceptance Rate (%)', fontsize=12)
ax.set_ylabel('Average Speedup (x)', fontsize=12)
ax.set_title('Average Speedup by Acceptance Rate Bin', fontsize=14, fontweight='bold')
ax.legend(fontsize=11)
ax.set_ylim(0, 4)

plt.tight_layout()
plt.show()

## Part 10: Key Takeaways

### Summary

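1. **Speculative decoding is lossless**: With exact-match verification under greedy decoding, or rejection-sampling verification under sampling, the output distribution exactly matches the target model's. Only speed changes, not quality.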

2. **Acceptance rate is king**: Speedup is determined primarily by how often the target agrees with the draft model's tokens. Well-matched model pairs typically reach acceptance rates of roughly 50-80%.

3. **Draft model choice matters**: The draft model should be:
   - Much cheaper than the target (low cost ratio)
   - Highly correlated with the target (high acceptance rate)
   - Common choice: same architecture family, smaller size (e.g., Llama-7B draft for Llama-70B target)

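4. **Optimal K varies**: Higher acceptance rates allow larger K, because a draft token only counts if every token before it was accepted; at low acceptance rates the later drafts are mostly wasted work. Typical values: K=3-5 for most scenarios.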

5. **Best use cases**:
   - Memory-bandwidth-bound inference (large models on few GPUs), where verifying K tokens in one pass costs about the same as generating one
   - Batch size = 1 (latency-sensitive applications)
   - Tasks where draft model aligns well with target (continuation, code, etc.)
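
Points 2 and 4 above can be made concrete with a small calculation. This is a sketch under simplifying assumptions, not the notebook's own helper. It assumes each draft token is accepted independently with probability `alpha`, and that every cycle also yields one token from the target (a correction on rejection, a bonus token on full acceptance). It also assumes one target pass costs 1 unit and one draft step costs `cost_ratio` units. The names `expected_tokens_per_cycle`, `expected_speedup`, and `best_K` are hypothetical.

```python
def expected_tokens_per_cycle(alpha: float, K: int) -> float:
    """Expected tokens emitted per draft-verify cycle (geometric series in alpha)."""
    if alpha >= 1.0:
        return K + 1.0
    return (1.0 - alpha ** (K + 1)) / (1.0 - alpha)

def expected_speedup(alpha: float, K: int, cost_ratio: float) -> float:
    """Speedup over standard decoding; cost_ratio = draft step time / target step time."""
    cycle_cost = K * cost_ratio + 1.0  # K draft steps plus one target pass
    return expected_tokens_per_cycle(alpha, K) / cycle_cost

def best_K(alpha: float, cost_ratio: float, K_max: int = 16) -> int:
    """K in [1, K_max] that maximizes the expected speedup."""
    return max(range(1, K_max + 1),
               key=lambda K: expected_speedup(alpha, K, cost_ratio))

for alpha in (0.5, 0.7, 0.9):
    K = best_K(alpha, cost_ratio=0.1)
    print(f"alpha={alpha:.1f}: best K={K}, "
          f"speedup={expected_speedup(alpha, K, 0.1):.2f}x")
```

Under these assumptions the best K grows as `alpha` increases, and at `alpha = 0` the expected speedup drops below 1 because the draft steps are pure overhead.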

### Real-World Examples

| System | Draft Model | Target Model | Reported Speedup |
|--------|-------------|-------------|------------------|
| Google (2023) | T5-small | T5-XXL (11B) | 2-3x |
| DeepMind (2023) | 4B draft model | Chinchilla-70B | 2-2.5x |
| Medusa | Extra decoding heads on the target (no separate draft model) | Any LLM (heads are trained per model) | 2-3x |

---

## Exercises

### Exercise 1: Adaptive K
Implement a version that adjusts K based on recent acceptance rates.

In [None]:
# Exercise 1: Implement adaptive K
# If acceptance rate is high (>80%), increase K
# If acceptance rate is low (<40%), decrease K

def adaptive_speculative_decoding(sim, num_tokens, initial_K=4):
    """
    TODO: Implement speculative decoding with adaptive K.
    Track a rolling window of acceptance rates and adjust K accordingly.
    """
    pass

print("Exercise 1: Implement adaptive K selection!")

### Exercise 2: Multi-Draft Speculation
Instead of one draft sequence, generate multiple candidate sequences, verify them all, and keep the one with the longest accepted prefix.

In [None]:
# Exercise 2: Multi-draft speculation
# Generate N different draft sequences (using sampling)
# Verify all N in parallel
# Accept the one with the longest accepted prefix (tokens after the first mismatch do not count)

def multi_draft_speculative(sim, num_tokens, K=4, num_drafts=3):
    """
    TODO: Generate multiple draft sequences and pick the best.
    """
    pass

print("Exercise 2: Implement multi-draft speculation!")

### Exercise 3: Break-even Analysis
For a given draft model cost, find the minimum acceptance rate needed for speculative decoding to be faster than standard decoding.

In [None]:
# Exercise 3: Break-even analysis
# For a given cost_ratio and K, find the minimum alpha where speedup > 1

def find_breakeven_alpha(cost_ratio: float, K: int) -> float:
    """
    TODO: Find the minimum acceptance rate where speculative
    decoding breaks even with standard decoding.
    
    Hint: Use binary search on alpha, checking theoretical_speedup > 1
    """
    pass

# Test: find_breakeven_alpha(0.1, 4) should return ~0.2-0.3
print("Exercise 3: Implement break-even analysis!")

---

**End of Notebook 15: Speculative Decoding (Draft-Target)**

Next: [Notebook 16 - Embedding Models & Cosine Similarity](./16_embedding_models.ipynb)