# Notebook 14: N-gram Speculation & Lookahead Decoding

---

## Inference Engineering Course

Welcome to Notebook 14! In this notebook, we explore **n-gram speculation** and **lookahead decoding** -- two powerful techniques that accelerate text generation by predicting multiple future tokens at once.

### What You Will Learn

| Topic | Description |
|-------|-------------|
| **N-gram Dictionaries** | Build frequency tables from input text to predict next tokens |
| **Speculative Generation** | Use n-gram matches to propose candidate continuations |
| **Lookahead Decoding** | Verify multiple speculated tokens in a single forward pass |
| **Speedup Measurement** | Quantify how much faster we go with n-gram hits |
| **Visualization** | Analyze match rates across different text types |

### Why N-gram Speculation?

Standard autoregressive decoding generates **one token at a time**. Each token requires a full forward pass through the model. N-gram speculation exploits a simple insight:

> **If the same n-gram patterns appear repeatedly in text, we can guess the next few tokens cheaply and let the model verify all of them in a single forward pass.**

This is especially powerful for:
- Repetitive text (legal documents, code, structured data)
- Text with recurring phrases
- Long-context generation where earlier patterns repeat

---

## Part 1: Setup & Installations

We will use only the Python standard library plus NumPy and matplotlib for visualization. No GPU required!

In [None]:
# Standard library imports
import time
import random
import collections
from collections import defaultdict, Counter
from typing import List, Dict, Tuple, Optional

# Visualization
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import numpy as np

# Set plotting style
plt.style.use('default')
plt.rcParams['figure.figsize'] = (12, 6)
plt.rcParams['font.size'] = 12
plt.rcParams['axes.grid'] = True
plt.rcParams['grid.alpha'] = 0.3

print("All imports successful!")
print("No GPU required for this notebook.")

## Part 2: Understanding N-grams

An **n-gram** is a contiguous sequence of `n` items from a given text. For our purposes, items are **words** (though they could be characters or subword tokens).

| N | Name | Example (from "the cat sat on the mat") |
|---|------|-------------------------------------------|
| 1 | Unigram | `the`, `cat`, `sat`, `on`, `the`, `mat` |
| 2 | Bigram | `the cat`, `cat sat`, `sat on`, `on the`, `the mat` |
| 3 | Trigram | `the cat sat`, `cat sat on`, `sat on the`, `on the mat` |
| 4 | 4-gram | `the cat sat on`, `cat sat on the`, `sat on the mat` |
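
The rows of this table can be reproduced with a sliding window over the token list. The sketch below is a minimal illustration; `extract_ngrams` is a hypothetical helper name, not something defined elsewhere in this notebook.

```python
from typing import List, Tuple


def extract_ngrams(tokens: List[str], n: int) -> List[Tuple[str, ...]]:
    """Return every contiguous window of n tokens, in order of appearance."""
    # A sequence of L tokens contains L - n + 1 windows of size n.
    return [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]


words = "the cat sat on the mat".split()
for n in (1, 2, 3, 4):
    grams = [" ".join(g) for g in extract_ngrams(words, n)]
    print(f"n={n}: {grams}")
```

Note that a text shorter than `n` tokens simply yields an empty list, which is why very short inputs produce no n-gram matches.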

### The Key Idea

If we see the bigram `"on the"` and our n-gram dictionary says it is frequently followed by `"mat"`, we can **speculate** that the next token is `"mat"` without running the language model.

In [None]:
def simple_tokenize(text: str) -> List[str]:
    """
    Simple whitespace tokenizer: lowercases the text and splits on whitespace.
    Punctuation stays attached to its word (e.g. "mat." is a single token).
    In practice, you would use a proper tokenizer (BPE, SentencePiece, etc.)
    """
    # str.split() with no arguments already discards empty strings and
    # surrounding whitespace, so no further cleanup is needed
    return text.lower().split()

# Test the tokenizer
sample = "The cat sat on the mat. The cat sat on the rug."
tokens = simple_tokenize(sample)
print(f"Input:  {sample}")
print(f"Tokens: {tokens}")
print(f"Count:  {len(tokens)} tokens")

## Part 3: Building an N-gram Dictionary

The n-gram dictionary maps each n-gram prefix to a **distribution of next tokens**. This tells us: given a particular context of `n-1` words, what word is most likely to follow?

### Data Structure

```
ngram_dict = {
    ("the", "cat"): Counter({"sat": 2, "ran": 1}),
    ("cat", "sat"): Counter({"on": 2}),
    ("sat", "on"): Counter({"the": 2}),
    ...
}
```

The **key** is a tuple of `n-1` tokens (the context), and the **value** is a Counter of next tokens with their frequencies.
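
To make the "distribution of next tokens" concrete, here is a small sketch that builds such a mapping for a toy sentence and normalizes one Counter into relative frequencies. The helper names `build_context_table` and `to_probabilities` are assumptions made for this illustration; the notebook's own class stores raw counts only.

```python
from collections import Counter, defaultdict
from typing import Dict, List, Tuple


def build_context_table(tokens: List[str], n: int) -> Dict[Tuple[str, ...], Counter]:
    """Map each (n-1)-token context to a Counter of the tokens that followed it."""
    table: Dict[Tuple[str, ...], Counter] = defaultdict(Counter)
    for i in range(len(tokens) - n + 1):
        context = tuple(tokens[i:i + n - 1])
        table[context][tokens[i + n - 1]] += 1
    return table


def to_probabilities(counts: Counter) -> Dict[str, float]:
    """Normalize raw follow-up counts into relative frequencies."""
    total = sum(counts.values())
    return {token: count / total for token, count in counts.items()}


toy = "the cat sat on the mat the cat sat on the rug the cat ran".split()
table = build_context_table(toy, n=3)
print(table[("the", "cat")])
print(to_probabilities(table[("the", "cat")]))
```

Keeping raw counts rather than probabilities makes incremental updates cheap, since adding an observation is a single integer increment.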

In [None]:
class NgramDictionary:
    """
    Builds and queries an n-gram dictionary from text.
    
    The dictionary maps n-gram prefixes (context) to a distribution
    of next tokens, enabling speculative prediction.
    """
    
    def __init__(self, n: int = 3):
        """
        Args:
            n: The size of the n-gram (e.g., 3 means trigram, 
               so context is 2 tokens)
        """
        self.n = n
        self.context_size = n - 1  # How many tokens form the context
        self.dictionary = defaultdict(Counter)
        self.total_entries = 0
    
    def build_from_tokens(self, tokens: List[str]):
        """
        Build the n-gram dictionary from a list of tokens.
        
        For each position i, we take tokens[i:i+n-1] as the context
        and tokens[i+n-1] as the next token.
        """
        self.dictionary.clear()
        self.total_entries = 0
        
        for i in range(len(tokens) - self.context_size):
            # Context: n-1 tokens
            context = tuple(tokens[i:i + self.context_size])
            # Next token
            next_token = tokens[i + self.context_size]
            # Record the observation
            self.dictionary[context][next_token] += 1
            self.total_entries += 1
    
    def predict_next(self, context: Tuple[str, ...]) -> Optional[str]:
        """
        Predict the most likely next token given a context.
        Returns None if context not found in dictionary.
        """
        if context in self.dictionary:
            # Return the most common next token
            return self.dictionary[context].most_common(1)[0][0]
        return None
    
    def predict_sequence(self, context: Tuple[str, ...], 
                         max_length: int = 5) -> List[str]:
        """
        Predict a sequence of tokens greedily using n-gram lookups.
        Keeps predicting until no match is found or max_length reached.
        """
        predictions = []
        current_context = list(context)
        
        for _ in range(max_length):
            ctx = tuple(current_context[-self.context_size:])
            next_token = self.predict_next(ctx)
            if next_token is None:
                break
            predictions.append(next_token)
            current_context.append(next_token)
        
        return predictions
    
    def get_stats(self) -> Dict:
        """Return statistics about the dictionary."""
        unique_contexts = len(self.dictionary)
        avg_next_tokens = np.mean([len(v) for v in self.dictionary.values()]) if self.dictionary else 0
        return {
            'n': self.n,
            'unique_contexts': unique_contexts,
            'total_entries': self.total_entries,
            'avg_next_tokens_per_context': round(avg_next_tokens, 2),
        }

print("NgramDictionary class defined successfully!")

In [None]:
# Build a trigram dictionary from our sample text
sample_text = """
The cat sat on the mat. The cat sat on the rug. The dog sat on the mat.
The cat ran to the door. The dog ran to the yard. The cat sat on the mat again.
A bird flew over the mat. The cat watched the bird. The cat sat on the mat.
"""

tokens = simple_tokenize(sample_text)
print(f"Total tokens: {len(tokens)}")
print(f"Unique tokens: {len(set(tokens))}")
print()

# Build dictionaries for different n values
for n in [2, 3, 4]:
    ngram = NgramDictionary(n=n)
    ngram.build_from_tokens(tokens)
    stats = ngram.get_stats()
    print(f"\n--- {n}-gram Dictionary ---")
    for key, value in stats.items():
        print(f"  {key}: {value}")

In [None]:
# Explore the trigram dictionary
trigram = NgramDictionary(n=3)
trigram.build_from_tokens(tokens)

print("=" * 60)
print("TRIGRAM DICTIONARY (showing top entries)")
print("=" * 60)

# Sort by total frequency
sorted_entries = sorted(
    trigram.dictionary.items(),
    key=lambda x: sum(x[1].values()),
    reverse=True
)

for context, next_tokens in sorted_entries[:15]:
    context_str = ' '.join(context)
    total = sum(next_tokens.values())
    predictions = ', '.join(f'"{tok}"({cnt})' for tok, cnt in next_tokens.most_common(3))
    print(f"  [{context_str:20s}] -> {predictions}  (total: {total})")

## Part 4: N-gram Speculative Token Generation

Now let's use our n-gram dictionary for **speculative generation**. The idea:

1. Look at the last `n-1` tokens as context
2. If there's a match in our dictionary, **speculate** the next token(s)
3. Verify the speculation against a "target model" (simulated)
4. Every accepted token saves one single-token forward pass, because the whole batch is checked in one verification pass

### The Speculation Loop

```
while not done:
    context = last (n-1) tokens
    speculation = ngram_dict.predict_sequence(context, k)
    
    if speculation:  # We have a guess!
        # Verify against the target model
        accepted = verify(speculation, target_model)
        if accepted > 0:
            # Several tokens for the cost of one verification pass
            append accepted tokens
        else:
            # Fall back to normal decoding
            token = target_model.generate_one()
    else:
        # No n-gram match, normal decoding
        token = target_model.generate_one()
```
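
The `verify` step in the pseudocode amounts to finding the longest prefix of the speculation that agrees with what the target model would emit. A minimal sketch follows, assuming the target continuation is already known as a plain list; in a real system it would come from one batched forward pass. `count_accepted` is a hypothetical name used only here.

```python
from typing import List


def count_accepted(speculated: List[str], target: List[str]) -> int:
    """Length of the longest prefix of `speculated` that matches `target`."""
    accepted = 0
    for guess, truth in zip(speculated, target):
        if guess != truth:
            break  # Everything after the first mismatch is discarded
        accepted += 1
    return accepted


# What the target model would emit next (assumed known, for illustration only).
target_next = ["on", "the", "mat", "."]
speculated = ["on", "the", "rug", "."]

n_ok = count_accepted(speculated, target_next)
print(f"accepted {n_ok} of {len(speculated)} speculated tokens")
```

Tokens after the first mismatch are dropped even if they happen to match later positions, because they were predicted from a context that turned out to be wrong.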

In [None]:
class SimulatedLM:
    """
    A simulated language model that generates from a fixed text.
    
    This lets us test speculation without needing a real LLM.
    The 'model' simply returns the next token from the reference text.
    """
    
    def __init__(self, reference_tokens: List[str], latency_ms: float = 50.0):
        """
        Args:
            reference_tokens: The "ground truth" text the model would generate
            latency_ms: Simulated per-token latency in milliseconds
        """
        self.reference = reference_tokens
        self.latency_ms = latency_ms
        self.position = 0
        self.forward_passes = 0
    
    def generate_one(self) -> Optional[str]:
        """Generate one token (simulated with latency)."""
        if self.position >= len(self.reference):
            return None
        # Simulate model latency
        time.sleep(self.latency_ms / 1000.0)
        self.forward_passes += 1
        token = self.reference[self.position]
        self.position += 1
        return token
    
    def verify_tokens(self, tokens: List[str]) -> int:
        """
        Verify a batch of speculated tokens against the reference.
        Returns the number of tokens accepted (from the start).
        
        This simulates a single forward pass that checks all tokens.
        """
        time.sleep(self.latency_ms / 1000.0)  # One forward pass for all
        self.forward_passes += 1
        
        accepted = 0
        for i, token in enumerate(tokens):
            if self.position + i >= len(self.reference):
                break
            if token == self.reference[self.position + i]:
                accepted += 1
            else:
                break  # First mismatch stops acceptance
        
        self.position += max(accepted, 1)  # Always advance at least 1
        return accepted
    
    def reset(self):
        """Reset the model state."""
        self.position = 0
        self.forward_passes = 0

print("SimulatedLM class defined!")

In [None]:
def standard_decoding(model: SimulatedLM, num_tokens: int) -> Dict:
    """
    Standard autoregressive decoding: one token at a time.
    """
    model.reset()
    generated = []
    start_time = time.time()
    
    for _ in range(num_tokens):
        token = model.generate_one()
        if token is None:
            break
        generated.append(token)
    
    elapsed = time.time() - start_time
    return {
        'method': 'Standard',
        'tokens_generated': len(generated),
        'forward_passes': model.forward_passes,
        'time_seconds': round(elapsed, 3),
        'tokens_per_second': round(len(generated) / elapsed, 1) if elapsed > 0 else 0,
        'generated_text': ' '.join(generated)
    }


def ngram_speculative_decoding(
    model: SimulatedLM, 
    ngram_dict: NgramDictionary,
    num_tokens: int,
    speculation_length: int = 4
) -> Dict:
    """
    N-gram speculative decoding: use n-gram predictions to skip forward passes.
    """
    model.reset()
    generated = []
    start_time = time.time()
    
    total_speculations = 0
    total_accepted = 0
    speculation_attempts = 0
    
    while len(generated) < num_tokens:
        # Try n-gram speculation
        speculation = []
        if len(generated) >= ngram_dict.context_size:
            context = tuple(generated[-ngram_dict.context_size:])
            speculation = ngram_dict.predict_sequence(context, speculation_length)
        
        if speculation:
            # We have speculated tokens -- verify them
            speculation_attempts += 1
            total_speculations += len(speculation)
            accepted = model.verify_tokens(speculation)
            total_accepted += accepted
            
            if accepted > 0:
                generated.extend(speculation[:accepted])
            elif model.position <= len(model.reference):
                # Speculation failed. The verify call already advanced
                # position by 1, so take the token it produced.
                generated.append(model.reference[model.position - 1])
            else:
                # Reference text exhausted: stop instead of looping forever
                break
        else:
            # No n-gram match, standard decoding
            token = model.generate_one()
            if token is None:
                break
            generated.append(token)
    
    elapsed = time.time() - start_time
    acceptance_rate = (total_accepted / total_speculations * 100) if total_speculations > 0 else 0
    
    return {
        'method': 'N-gram Speculative',
        'tokens_generated': len(generated),
        'forward_passes': model.forward_passes,
        'time_seconds': round(elapsed, 3),
        'tokens_per_second': round(len(generated) / elapsed, 1) if elapsed > 0 else 0,
        'speculation_attempts': speculation_attempts,
        'total_speculated': total_speculations,
        'total_accepted': total_accepted,
        'acceptance_rate': round(acceptance_rate, 1),
        'generated_text': ' '.join(generated)
    }

print("Decoding functions defined!")

In [None]:
# Create a repetitive text for testing
# (Repetitive text benefits most from n-gram speculation)
repetitive_text = """
The quick brown fox jumps over the lazy dog. The quick brown fox jumps over the lazy cat.
The quick brown fox jumps over the lazy dog. The slow brown fox jumps over the lazy dog.
The quick brown fox jumps over the lazy dog. The quick brown fox runs past the lazy dog.
The quick brown fox jumps over the lazy dog. The quick brown fox jumps over the lazy dog.
""" * 3  # Repeat for more data

tokens = simple_tokenize(repetitive_text)
print(f"Text has {len(tokens)} tokens, {len(set(tokens))} unique")

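# Build trigram dictionary from the same text the simulated model will emit
# (a best-case setup, so the acceptance rate measured here is optimistic)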
ngram = NgramDictionary(n=3)
ngram.build_from_tokens(tokens)
print(f"Dictionary has {len(ngram.dictionary)} unique contexts")

# Create simulated model (with 20ms latency per forward pass)
model = SimulatedLM(tokens, latency_ms=20.0)

# Run standard decoding
num_tokens = min(40, len(tokens))
print(f"\nGenerating {num_tokens} tokens...")
print("\n" + "=" * 60)

result_standard = standard_decoding(model, num_tokens)
print(f"STANDARD DECODING:")
for k, v in result_standard.items():
    if k != 'generated_text':
        print(f"  {k}: {v}")

print("\n" + "-" * 60)

result_speculative = ngram_speculative_decoding(model, ngram, num_tokens)
print(f"N-GRAM SPECULATIVE DECODING:")
for k, v in result_speculative.items():
    if k != 'generated_text':
        print(f"  {k}: {v}")

# Calculate speedup
if result_speculative['time_seconds'] > 0:
    speedup = result_standard['time_seconds'] / result_speculative['time_seconds']
    print(f"\nSPEEDUP: {speedup:.2f}x faster with n-gram speculation!")

## Part 5: Implementing Lookahead Decoding

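**Lookahead decoding** extends the speculation idea by maintaining a **window** of potential future tokens and drawing its n-gram candidates from the text generated so far. Instead of relying only on a pre-built dictionary, we:

1. Build n-gram candidates from the generated text so far
2. Maintain multiple parallel speculation branches
3. Verify the best branch in a single forward pass

The simplified decoder in this part implements step 1 and verifies a single greedy branch of up to `window_size` tokens; its `max_branches` parameter is accepted but not used.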

### Lookahead vs Simple Speculation

| Feature | Simple N-gram | Lookahead |
|---------|--------------|----------|
| Speculation depth | Fixed | Dynamic window |
| Branches | Single best | Multiple candidates |
| N-gram source | Pre-built dictionary | Dynamic from generation |
| Adaptiveness | Static | Adapts to new text |
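
The "Multiple candidates" row can be illustrated with a small sketch that proposes one branch per top-ranked next token and keeps the branch with the longest accepted prefix. This is an illustration under stated assumptions, not the published lookahead algorithm: `build_table`, `propose_branches`, `accepted_prefix`, and `pick_best` are hypothetical helpers, and the target continuation is passed in directly instead of coming from a batched forward pass.

```python
from collections import Counter, defaultdict
from typing import Dict, List, Tuple

Table = Dict[Tuple[str, ...], Counter]


def build_table(tokens: List[str], n: int) -> Table:
    """Map each (n-1)-token context to a Counter of observed next tokens."""
    table: Table = defaultdict(Counter)
    for i in range(len(tokens) - n + 1):
        table[tuple(tokens[i:i + n - 1])][tokens[i + n - 1]] += 1
    return table


def propose_branches(table: Table, context: Tuple[str, ...],
                     max_branches: int = 3, depth: int = 4) -> List[List[str]]:
    """One branch per top-ranked first token; each branch is extended greedily."""
    if context not in table:
        return []
    branches = []
    for first, _ in table[context].most_common(max_branches):
        branch = [first]
        ctx = list(context[1:]) + [first]
        while len(branch) < depth and tuple(ctx) in table:
            nxt = table[tuple(ctx)].most_common(1)[0][0]
            branch.append(nxt)
            ctx = ctx[1:] + [nxt]
        branches.append(branch)
    return branches


def accepted_prefix(branch: List[str], target: List[str]) -> int:
    """Number of leading tokens in `branch` that agree with `target`."""
    n = 0
    for guess, truth in zip(branch, target):
        if guess != truth:
            break
        n += 1
    return n


def pick_best(branches: List[List[str]], target: List[str]) -> Tuple[List[str], int]:
    """Return the branch with the longest accepted prefix, and that length."""
    best = max(branches, key=lambda b: accepted_prefix(b, target), default=[])
    return best, accepted_prefix(best, target)


history = "the cat sat on the mat . the cat ran to the door . the cat".split()
table = build_table(history, n=3)
branches = propose_branches(table, ("the", "cat"), max_branches=2, depth=4)
best, n_ok = pick_best(branches, target=["ran", "to", "the", "yard"])
print(branches)
print(best, n_ok)
```

With a single greedy branch, the speculation after `the cat` would have committed to only one of the two continuations seen so far; carrying both lets the verifier pick whichever agrees with the target.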

In [None]:
class LookaheadDecoder:
    """
    Implements a simplified version of Lookahead Decoding.
    
    Key idea: as we generate text, we dynamically build an n-gram
    dictionary from the generated text itself. This means the
    dictionary grows and improves as generation proceeds.
    """
    
    def __init__(self, n: int = 3, window_size: int = 5, max_branches: int = 3):
        self.n = n
        self.window_size = window_size  # How far ahead to speculate
        self.max_branches = max_branches  # Max parallel branches (not used by this simplified decoder)
        self.dynamic_dict = NgramDictionary(n=n)
        
    def decode(self, model: SimulatedLM, num_tokens: int, 
               seed_text: List[str] = None) -> Dict:
        """
        Perform lookahead decoding.
        
        The n-gram dictionary is built dynamically from already-generated text.
        """
        model.reset()
        generated = list(seed_text) if seed_text else []
        
        # Skip model position to account for seed
        model.position = len(generated)
        
        start_time = time.time()
        total_speculated = 0
        total_accepted = 0
        speculation_rounds = 0
        
        while len(generated) < num_tokens + len(seed_text or []):
            # Dynamically rebuild n-gram dict from generated text so far
            if len(generated) >= self.n:
                self.dynamic_dict.build_from_tokens(generated)
            
            # Try to speculate
            speculation = []
            if len(generated) >= self.dynamic_dict.context_size:
                context = tuple(generated[-self.dynamic_dict.context_size:])
                speculation = self.dynamic_dict.predict_sequence(
                    context, self.window_size
                )
            
            if speculation:
                speculation_rounds += 1
                total_speculated += len(speculation)
                accepted = model.verify_tokens(speculation)
                total_accepted += accepted
                
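                if accepted > 0:
                    generated.extend(speculation[:accepted])
                elif model.position <= len(model.reference):
                    # verify_tokens already advanced position by 1
                    generated.append(model.reference[model.position - 1])
                else:
                    # Reference text exhausted: stop instead of looping forever
                    break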
            else:
                token = model.generate_one()
                if token is None:
                    break
                generated.append(token)
        
        elapsed = time.time() - start_time
        acceptance_rate = (total_accepted / total_speculated * 100) if total_speculated > 0 else 0
        
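        # Count only newly generated tokens; the seed was supplied, not generated
        new_tokens = len(generated) - len(seed_text or [])
        
        return {
            'method': 'Lookahead',
            'tokens_generated': new_tokens,
            'forward_passes': model.forward_passes,
            'time_seconds': round(elapsed, 3),
            'tokens_per_second': round(new_tokens / elapsed, 1) if elapsed > 0 else 0,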
            'speculation_rounds': speculation_rounds,
            'total_speculated': total_speculated,
            'total_accepted': total_accepted,
            'acceptance_rate': round(acceptance_rate, 1),
        }

print("LookaheadDecoder class defined!")

In [None]:
# Compare all three methods
repetitive_tokens = simple_tokenize(repetitive_text)
num_gen = min(30, len(repetitive_tokens) - 5)

# Standard decoding
model = SimulatedLM(repetitive_tokens, latency_ms=15.0)
r_standard = standard_decoding(model, num_gen)

# N-gram speculative
model = SimulatedLM(repetitive_tokens, latency_ms=15.0)
ngram = NgramDictionary(n=3)
ngram.build_from_tokens(repetitive_tokens)
r_ngram = ngram_speculative_decoding(model, ngram, num_gen, speculation_length=4)

# Lookahead decoding
model = SimulatedLM(repetitive_tokens, latency_ms=15.0)
lookahead = LookaheadDecoder(n=3, window_size=4)
# Give it a seed of first 3 tokens
seed = repetitive_tokens[:3]
r_lookahead = lookahead.decode(model, num_gen, seed_text=seed)

# Display results
print("=" * 70)
print(f"{'Method':<25} {'Tokens':<10} {'FW Passes':<12} {'Time (s)':<10} {'Tok/s':<10}")
print("=" * 70)
for r in [r_standard, r_ngram, r_lookahead]:
    print(f"{r['method']:<25} {r['tokens_generated']:<10} {r['forward_passes']:<12} {r['time_seconds']:<10} {r['tokens_per_second']:<10}")

## Part 6: Measuring Speedup from N-gram Hits

The speedup from n-gram speculation depends critically on the **hit rate** -- how often our n-gram predictions are correct. Let's analyze this across different scenarios.
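
One way to get a less optimistic estimate of the hit rate is to build the table from one part of the text and score it on the rest. The sketch below does this with a simple front/back split; `held_out_hit_rate` and its `train_fraction` parameter are assumptions introduced for this example, not part of the notebook's API.

```python
from collections import Counter, defaultdict
from typing import Dict, List


def held_out_hit_rate(tokens: List[str], n: int = 3,
                      train_fraction: float = 0.5) -> Dict[str, float]:
    """Build the table from the first part of `tokens`, score it on the rest."""
    split = int(len(tokens) * train_fraction)
    train, test = tokens[:split], tokens[split:]
    k = n - 1  # Context length

    table = defaultdict(Counter)
    for i in range(len(train) - k):
        table[tuple(train[i:i + k])][train[i + k]] += 1

    hits = misses = no_match = 0
    for i in range(k, len(test)):
        context = tuple(test[i - k:i])
        if context not in table:
            no_match += 1  # Context never seen in the training part
        elif table[context].most_common(1)[0][0] == test[i]:
            hits += 1
        else:
            misses += 1

    total = hits + misses + no_match
    matched = hits + misses
    return {
        "hit_rate": hits / total if total else 0.0,        # correct / all positions
        "match_rate": matched / total if total else 0.0,   # context found / all positions
        "precision": hits / matched if matched else 0.0,   # correct / context found
    }


loop = "the server returns a response . the client sends a request .".split() * 4
result = held_out_hit_rate(loop, n=3)
print(result)
```

Separating `match_rate` (was there any prediction?) from `precision` (was the prediction right?) shows whether a low hit rate comes from sparse coverage or from ambiguous contexts.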

In [None]:
def measure_ngram_hit_rate(text: str, n: int = 3) -> Dict:
    """
    Measure how often n-gram predictions match the actual next token.
    
    For simplicity, we build the dictionary from the full text and measure
    on that same text. This is an optimistic, in-sample estimate; a stricter
    leave-one-out evaluation would exclude each position from the dictionary
    before predicting it.
    """
    tokens = simple_tokenize(text)
    ngram_dict = NgramDictionary(n=n)
    ngram_dict.build_from_tokens(tokens)
    
    hits = 0
    misses = 0
    no_match = 0
    context_size = n - 1
    
    for i in range(context_size, len(tokens)):
        context = tuple(tokens[i - context_size:i])
        prediction = ngram_dict.predict_next(context)
        actual = tokens[i]
        
        if prediction is None:
            no_match += 1
        elif prediction == actual:
            hits += 1
        else:
            misses += 1
    
    total = hits + misses + no_match
    return {
        'n': n,
        'total_predictions': total,
        'hits': hits,
        'misses': misses,
        'no_match': no_match,
        'hit_rate': round(hits / total * 100, 1) if total > 0 else 0,
        'match_rate': round((hits + misses) / total * 100, 1) if total > 0 else 0,
    }

# Test texts with different repetition levels
texts = {
    'Highly Repetitive': """
        The server returns a response. The client sends a request.
        The server returns a response. The client sends a request.
        The server returns a response. The client sends a request.
        The server returns a response. The client sends a request.
    """ * 3,
    
    'Code-like': """
        for i in range of n do begin
        if x is greater than y then return x end
        for j in range of m do begin
        if a is greater than b then return a end
        for i in range of n do begin
        if x is greater than y then return y end
    """ * 3,
    
    'Semi-Repetitive': """
        The cat sat on the mat near the window.
        The dog lay on the rug by the fireplace.
        The bird perched on the branch above the garden.
        The cat sat on the mat watching the bird.
        The dog lay on the rug chewing a bone.
    """ * 2,
    
    'Diverse (low repeat)': """
        Quantum mechanics describes nature at the smallest scales.
        Einstein proposed general relativity in nineteen fifteen.
        The double slit experiment reveals wave particle duality.
        Heisenberg uncertainty principle limits simultaneous measurements.
        Schrodinger equation governs quantum state evolution.
    """,
}

print(f"{'Text Type':<25} {'N':>3} {'Hit Rate':>10} {'Match Rate':>12} {'Hits':>6} {'Total':>7}")
print("=" * 70)

results_by_type = {}
for text_type, text in texts.items():
    results_by_type[text_type] = []
    for n in [2, 3, 4, 5]:
        result = measure_ngram_hit_rate(text, n=n)
        results_by_type[text_type].append(result)
        print(f"{text_type:<25} {n:>3} {result['hit_rate']:>9}% {result['match_rate']:>11}% {result['hits']:>6} {result['total_predictions']:>7}")
    print()

## Part 7: Visualizations

Let's create rich visualizations to understand n-gram speculation behavior.
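
The speedup panel in the first plot relies on a geometric-series estimate of how many tokens one forward pass produces. The sketch below checks that estimate numerically, assuming each speculated token is accepted independently with probability `p` and that a pass yields at most `k` tokens. Both are simplifying assumptions, and the function names are hypothetical.

```python
def expected_tokens_per_pass(p: float, k: int) -> float:
    """Term-by-term sum of the series 1 + p + p**2 + ... + p**(k-1)."""
    return sum(p ** j for j in range(k))


def closed_form(p: float, k: int) -> float:
    """(1 - p**k) / (1 - p), with the p == 1 limit handled separately."""
    return float(k) if p == 1.0 else (1 - p ** k) / (1 - p)


for p in (0.0, 0.5, 0.8, 1.0):
    row = ", ".join(f"k={k}: {closed_form(p, k):.2f}" for k in (2, 3, 4, 5))
    print(f"p={p:.1f} -> {row}")
```

The estimate is bounded by `k` and flattens quickly at low acceptance rates, which is why longer speculation only pays off on highly predictable text.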

In [None]:
# Visualization 1: N-gram hit rates across text types
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# Left: Grouped bar chart of hit rates
ax = axes[0]
text_types = list(results_by_type.keys())
n_values = [2, 3, 4, 5]
x = np.arange(len(text_types))
width = 0.18
colors = ['#2196F3', '#4CAF50', '#FF9800', '#F44336']

for i, n in enumerate(n_values):
    hit_rates = [results_by_type[tt][i]['hit_rate'] for tt in text_types]
    bars = ax.bar(x + i * width, hit_rates, width, label=f'n={n}', color=colors[i], alpha=0.85)
    # Add value labels on bars
    for bar, val in zip(bars, hit_rates):
        if val > 0:
            ax.text(bar.get_x() + bar.get_width()/2., bar.get_height() + 1,
                   f'{val:.0f}%', ha='center', va='bottom', fontsize=8)

ax.set_xlabel('Text Type', fontsize=12)
ax.set_ylabel('Hit Rate (%)', fontsize=12)
ax.set_title('N-gram Hit Rate by Text Type & N', fontsize=14, fontweight='bold')
ax.set_xticks(x + width * 1.5)
ax.set_xticklabels(text_types, rotation=15, ha='right', fontsize=10)
ax.legend(title='N-gram Size', fontsize=10)
ax.set_ylim(0, 110)

# Right: Theoretical speedup vs hit rate
ax = axes[1]
hit_rates = np.linspace(0, 1, 100)
spec_lengths = [2, 3, 4, 5]

for spec_len in spec_lengths:
    # Speedup estimate: expected tokens per forward pass.
    # Assumes each speculated token is accepted independently with probability p
    # and that one verification pass yields at most k tokens:
    # E[tokens per pass] = 1 + p + p^2 + ... + p^(k-1) = (1 - p^k) / (1 - p)
    speedups = []
    for p in hit_rates:
        if p < 0.999:
            expected = (1 - p**spec_len) / (1 - p)
        else:
            expected = spec_len
        speedups.append(expected)
    ax.plot(hit_rates * 100, speedups, linewidth=2.5, label=f'k={spec_len}')

ax.set_xlabel('Acceptance Rate (%)', fontsize=12)
ax.set_ylabel('Expected Speedup (x)', fontsize=12)
ax.set_title('Theoretical Speedup vs Acceptance Rate', fontsize=14, fontweight='bold')
ax.legend(title='Speculation Length k', fontsize=10)
ax.set_xlim(0, 100)

plt.tight_layout()
plt.savefig('/tmp/ngram_hit_rates.png', dpi=150, bbox_inches='tight')
plt.show()
print("Plot saved!")

In [None]:
# Visualization 2: N-gram prediction certainty by text type
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

for idx, (text_type, text) in enumerate(texts.items()):
    ax = axes[idx // 2][idx % 2]
    tokens = simple_tokenize(text)
    
    # Build n-gram dict and analyze prediction certainty
    ngram = NgramDictionary(n=3)
    ngram.build_from_tokens(tokens)
    
    # For each context, measure how dominant the top prediction is
    certainties = []
    for context, next_tokens in ngram.dictionary.items():
        total = sum(next_tokens.values())
        if total > 0:
            # Certainty = probability of most likely next token
            max_count = next_tokens.most_common(1)[0][1]
            certainty = max_count / total
            certainties.append(certainty)
    
    if certainties:
        ax.hist(certainties, bins=20, color=colors[idx], alpha=0.7, edgecolor='black')
        ax.axvline(np.mean(certainties), color='red', linestyle='--', linewidth=2,
                   label=f'Mean: {np.mean(certainties):.2f}')
    
    ax.set_title(f'{text_type}', fontsize=13, fontweight='bold')
    ax.set_xlabel('Prediction Certainty', fontsize=11)
    ax.set_ylabel('Number of Contexts', fontsize=11)
    ax.legend(fontsize=10)
    ax.set_xlim(0, 1.05)

plt.suptitle('N-gram Prediction Certainty Distribution (Trigrams)',
             fontsize=15, fontweight='bold', y=1.02)
plt.tight_layout()
plt.show()

In [None]:
# Visualization 3: Speculation timeline -- showing accepted vs rejected tokens
fig, ax = plt.subplots(figsize=(16, 5))

# Simulate a speculation timeline
np.random.seed(42)
timeline = []
pos = 0
step = 0

while pos < 50:
    # Randomly decide if n-gram match exists
    has_match = np.random.random() < 0.6  # 60% match rate
    
    if has_match:
        spec_len = np.random.randint(2, 6)
        # How many get accepted?
        accepted = np.random.randint(0, spec_len + 1)
        if accepted > 0:
            for i in range(accepted):
                timeline.append(('accepted', step, pos + i))
            for i in range(accepted, spec_len):
                timeline.append(('rejected', step, pos + i))
            pos += accepted
        else:
            for i in range(spec_len):
                timeline.append(('rejected', step, pos + i))
            timeline.append(('standard', step, pos))
            pos += 1
    else:
        timeline.append(('standard', step, pos))
        pos += 1
    step += 1

# Plot
for token_type, step_num, token_pos in timeline:
    if token_type == 'accepted':
        ax.scatter(token_pos, step_num, c='#4CAF50', s=120, zorder=5, marker='s')
    elif token_type == 'rejected':
        ax.scatter(token_pos, step_num, c='#F44336', s=80, zorder=4, marker='x', linewidths=2)
    else:
        ax.scatter(token_pos, step_num, c='#2196F3', s=100, zorder=5, marker='o')

# Add legend
legend_elements = [
    plt.scatter([], [], c='#4CAF50', s=120, marker='s', label='Accepted (batch-verified)'),
    plt.scatter([], [], c='#F44336', s=80, marker='x', linewidths=2, label='Rejected'),
    plt.scatter([], [], c='#2196F3', s=100, marker='o', label='Standard (1 fwd pass)'),
]
ax.legend(handles=legend_elements, loc='upper left', fontsize=11)

ax.set_xlabel('Token Position', fontsize=12)
ax.set_ylabel('Decoding Step', fontsize=12)
ax.set_title('N-gram Speculation Timeline: Tokens Generated per Step',
             fontsize=14, fontweight='bold')
ax.invert_yaxis()

plt.tight_layout()
plt.show()
print("Green squares = accepted speculations (verified together in one forward pass)")
print("Red X = rejected speculations (wasted)")
print("Blue circles = standard decoding (one forward pass each)")

In [None]:
# Visualization 4: Speedup benchmark across text types
fig, ax = plt.subplots(figsize=(12, 6))

text_types_list = list(texts.keys())
standard_times = []
speculative_times = []
speedups = []

for text_type in text_types_list:
    text = texts[text_type]
    toks = simple_tokenize(text)
    num_gen = min(25, len(toks) - 3)
    
    # Standard
    model = SimulatedLM(toks, latency_ms=10.0)
    r_std = standard_decoding(model, num_gen)
    standard_times.append(r_std['time_seconds'])
    
    # Speculative
    model = SimulatedLM(toks, latency_ms=10.0)
    ngram = NgramDictionary(n=3)
    ngram.build_from_tokens(toks)
    r_spec = ngram_speculative_decoding(model, ngram, num_gen)
    speculative_times.append(r_spec['time_seconds'])
    
    sp = r_std['time_seconds'] / r_spec['time_seconds'] if r_spec['time_seconds'] > 0 else 1.0
    speedups.append(sp)

x = np.arange(len(text_types_list))
width = 0.35

bars1 = ax.bar(x - width/2, standard_times, width, label='Standard Decoding',
               color='#F44336', alpha=0.8)
bars2 = ax.bar(x + width/2, speculative_times, width, label='N-gram Speculative',
               color='#4CAF50', alpha=0.8)

# Add speedup annotations
for i, sp in enumerate(speedups):
    ax.annotate(f'{sp:.1f}x', xy=(i, max(standard_times[i], speculative_times[i])),
               xytext=(0, 10), textcoords='offset points',
               ha='center', fontsize=12, fontweight='bold', color='#1565C0')

ax.set_xlabel('Text Type', fontsize=12)
ax.set_ylabel('Time (seconds)', fontsize=12)
ax.set_title('Decoding Speed Comparison by Text Type', fontsize=14, fontweight='bold')
ax.set_xticks(x)
ax.set_xticklabels(text_types_list, rotation=15, ha='right', fontsize=10)
ax.legend(fontsize=11)

plt.tight_layout()
plt.show()

## Part 8: Effect of N-gram Size on Speculation Quality

There is a fundamental tradeoff with n-gram size:

| Larger N | Smaller N |
|----------|----------|
| More specific contexts | More general contexts |
| Higher precision when matched | Lower precision |
| Fewer matches (sparse) | More matches (dense) |
| Better for long repeated phrases | Better for short patterns |
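
The tradeoff in the table can be seen on a toy sequence before plotting anything. The sketch below is self-contained and does not use the notebook's `NgramDictionary` or `measure_ngram_hit_rate`; the helper `context_stats` and the toy sentence are assumptions made for illustration. It also assumes that an n-gram of size `n` uses the previous `n - 1` tokens as context, which may differ from how the notebook's class defines `n`.

```python
from collections import Counter, defaultdict

def context_stats(tokens, n):
    """Return (unique_contexts, coverage, precision) for n-gram size n.

    The table is built online: each position is predicted from the
    tokens seen so far, and only then added to the table.
    """
    ctx_len = n - 1
    table = defaultdict(Counter)
    matched = correct = total = 0
    for i in range(ctx_len, len(tokens)):
        ctx = tuple(tokens[i - ctx_len:i])
        actual = tokens[i]
        total += 1
        if ctx in table:
            matched += 1
            # Predict the most frequent continuation seen so far
            predicted = table[ctx].most_common(1)[0][0]
            if predicted == actual:
                correct += 1
        table[ctx][actual] += 1
    coverage = matched / total if total else 0.0
    precision = correct / matched if matched else 0.0
    return len(table), coverage, precision

# Hypothetical toy text: the word after "means" depends on the color
# that appeared two tokens earlier.
toy_tokens = ("red light means stop green light means go " * 3).split()

for n in (2, 3, 4):
    size, cov, prec = context_stats(toy_tokens, n)
    print(f"n={n}: contexts={size}, coverage={cov:.0%}, precision={prec:.0%}")
```

Running it shows the table growing (6, 7, 8 contexts) and coverage shrinking at each step, while precision only reaches 100% at n=4, where the context finally includes the color word that determines what follows `means`. On real text the trend is smoother, but the same forces are at work.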

In [None]:
# Analyze the precision-coverage tradeoff for different n values
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

test_text = (repetitive_text + 
             texts['Code-like'] + 
             texts['Semi-Repetitive']) * 2

n_values = range(2, 8)
precisions = []
coverages = []
dict_sizes = []

test_tokens = simple_tokenize(test_text)  # tokenize once; the text is the same for every n

for n in n_values:
    result = measure_ngram_hit_rate(test_text, n=n)
    attempts = result['hits'] + result['misses']
    precision = result['hits'] / attempts * 100 if attempts > 0 else 0
    coverage = result['match_rate']
    precisions.append(precision)
    coverages.append(coverage)
    
    # Dictionary size = number of unique contexts stored for this n
    ngram_d = NgramDictionary(n=n)
    ngram_d.build_from_tokens(test_tokens)
    dict_sizes.append(len(ngram_d.dictionary))

# Left: Precision vs Coverage
ax = axes[0]
ax.plot(list(n_values), precisions, 'o-', color='#4CAF50', linewidth=2.5, 
        markersize=10, label='Precision')
ax.plot(list(n_values), coverages, 's-', color='#2196F3', linewidth=2.5, 
        markersize=10, label='Coverage')
ax.set_xlabel('N-gram Size (N)', fontsize=12)
ax.set_ylabel('Percentage (%)', fontsize=12)
ax.set_title('Precision vs Coverage by N-gram Size', fontsize=14, fontweight='bold')
ax.legend(fontsize=12)
ax.set_xticks(list(n_values))

# Right: Dictionary size
ax = axes[1]
ax.bar(list(n_values), dict_sizes, color='#FF9800', alpha=0.8, edgecolor='black')
ax.set_xlabel('N-gram Size (N)', fontsize=12)
ax.set_ylabel('Unique Contexts', fontsize=12)
ax.set_title('Dictionary Size by N-gram Size', fontsize=14, fontweight='bold')
ax.set_xticks(list(n_values))

for n, size in zip(n_values, dict_sizes):
    ax.text(n, size + max(dict_sizes)*0.02, str(size), ha='center', fontsize=10)

plt.tight_layout()
plt.show()

## Part 9: Key Takeaways

### What We Learned

1. **N-gram dictionaries** can predict next tokens without running an LLM, providing "free" tokens when predictions are correct.

2. **Hit rate depends on text repetitiveness**: highly repetitive text (code, legal docs, structured data) benefits enormously; diverse creative text benefits less.

3. **N-gram size tradeoff**: larger n-grams are more precise but match less often. Trigrams (n=3) often offer a good balance.

4. **Lookahead decoding** improves on static n-grams by dynamically building the dictionary from already-generated text.

5. **Verification is cheap**: checking several speculated tokens in one forward pass costs roughly the same as generating a single token, because decoding is typically limited by memory bandwidth rather than compute.
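
Takeaways 2 and 5 combine into a back-of-the-envelope cost model. The sketch below assumes that each speculated token is accepted independently with probability `p`, that verifying `k` speculated tokens costs one forward pass, and that every pass also yields one token from the model itself (the correction at the first mismatch, or a bonus token when everything is accepted). These are simplifying assumptions, and `expected_tokens_per_pass` is a hypothetical helper, not part of the notebook.

```python
def expected_tokens_per_pass(p: float, k: int) -> float:
    """Expected tokens produced by one verification pass.

    Assumes k speculated tokens, each accepted independently with
    probability p, plus one token that the model always contributes.
    """
    if not 0.0 <= p <= 1.0:
        raise ValueError("p must be in [0, 1]")
    if k < 0:
        raise ValueError("k must be non-negative")
    # 1 + p + p^2 + ... + p^k: the i-th speculated token only counts
    # if every earlier speculated token was accepted too.
    return sum(p ** i for i in range(k + 1))

for p in (0.2, 0.5, 0.8):
    row = ", ".join(
        f"k={k}: {expected_tokens_per_pass(p, k):.2f}" for k in (1, 3, 5)
    )
    print(f"p={p}: {row}")
```

Under these assumptions the result is also an upper bound on the speedup over standard decoding, since it ignores n-gram lookup time and any extra cost of a longer verification pass. It makes the dependence on text type explicit: at low acceptance rates, speculating more tokens adds almost nothing.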

### Connection to Real Systems

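- **LLMA** (Yang et al., Microsoft) speeds up generation by copying matching token spans from reference text, such as retrieved documents, and verifying them in a single pass
- **Lookahead Decoding** (Fu et al.) applies n-gram generation and verification to real Transformer models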
- **REST** (Retrieval-based Speculative Decoding) retrieves draft tokens from a datastore

### Next Up: Notebook 15

In the next notebook, we move from simple n-gram prediction to **Speculative Decoding with Draft Models**, where a small neural network proposes tokens and a large model verifies them.

## Exercises

Try these exercises to deepen your understanding:

### Exercise 1: Variable N-gram Speculation
Modify the `NgramDictionary` to try multiple n-gram sizes (5, 4, 3, 2) in order, using the longest matching context first.

In [None]:
# Exercise 1: Implement a multi-level n-gram dictionary
# Hint: Build dictionaries for n=2,3,4,5 and try the largest n first

class MultiLevelNgramDict:
    def __init__(self, max_n: int = 5):
        self.max_n = max_n
        self.dictionaries = {}
        # TODO: Create NgramDictionary for each n from 2 to max_n
        pass
    
    def build_from_tokens(self, tokens: List[str]):
        # TODO: Build all dictionaries
        pass
    
    def predict_next(self, tokens: List[str]) -> Optional[str]:
        # TODO: Try largest n first, fall back to smaller n
        pass

# Your code here...
print("Exercise 1: Implement the MultiLevelNgramDict class above!")

### Exercise 2: Measure Real-World Text
Try the n-gram analysis on different types of real text. What hit rates do you observe?

In [None]:
# Exercise 2: Analyze n-gram hit rates on different text types
# Try: Python code, JSON data, news articles, poetry, etc.

your_text = """
# Paste some text here and analyze the n-gram hit rates!
# Try different types of text:
# - Python code
# - JSON/XML structured data  
# - Legal document text
# - News article
# - Poetry or creative writing
"""

# Uncomment and run:
# for n in [2, 3, 4, 5]:
#     result = measure_ngram_hit_rate(your_text, n=n)
#     print(f"n={n}: hit_rate={result['hit_rate']}%, coverage={result['match_rate']}%")

print("Exercise 2: Paste your own text above and analyze!")

### Exercise 3: Adaptive Speculation Length
Instead of a fixed speculation length, adapt it based on recent acceptance rates.

In [None]:
# Exercise 3: Implement adaptive speculation length
# If recent acceptance rate is high -> speculate more tokens
# If recent acceptance rate is low -> speculate fewer tokens

def adaptive_ngram_decoding(model, ngram_dict, num_tokens,
                            min_spec=1, max_spec=8, window=10):
    """
    TODO: Implement adaptive speculation that adjusts k based on
    a rolling window of recent acceptance rates.
    
    Hint: Keep a deque of last `window` acceptance rates.
    If mean acceptance > 0.7, increase spec length.
    If mean acceptance < 0.3, decrease spec length.
    """
    pass

print("Exercise 3: Implement the adaptive_ngram_decoding function!")

---

**End of Notebook 14: N-gram Speculation & Lookahead Decoding**

Next: [Notebook 15 - Speculative Decoding (Draft-Target)](./15_speculative_decoding.ipynb)