# Notebook 5: Temperature, Top-k, and Top-p Sampling

## Inference Engineering Course

---

### What You'll Learn

When a language model generates text, it doesn't simply "choose" the next word. Instead, it produces a **probability distribution** over its entire vocabulary, and we must decide *how* to select from that distribution. This decision process is called **sampling**, and it has a profound impact on the quality, creativity, and reliability of generated text.

In this notebook, we will:

1. **Understand logits and softmax** - how raw model outputs become probabilities
2. **Implement temperature scaling** - controlling randomness in generation
3. **Implement top-k filtering** - limiting the candidate pool
4. **Implement top-p (nucleus) sampling** - adaptive candidate selection
5. **Visualize** how each strategy reshapes probability distributions
6. **Compare outputs** from different strategies on a real model

### Prerequisites
- Basic Python and NumPy
- Understanding of probability distributions
- Familiarity with neural network basics

### Runtime
- **No GPU required** for most cells
- GPU recommended for the real model generation section at the end

---

## 1. Setup and Imports

Let's install and import everything we need. We keep dependencies minimal so this runs on free Colab.

In [None]:
!pip install tiktoken matplotlib numpy torch transformers -q

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import torch
import torch.nn.functional as F
import tiktoken
from collections import Counter

# Reproducibility
np.random.seed(42)
torch.manual_seed(42)

# Plotting style
plt.rcParams['figure.figsize'] = (12, 5)
plt.rcParams['font.size'] = 12
plt.rcParams['axes.spines.top'] = False
plt.rcParams['axes.spines.right'] = False

print("All imports successful!")

## 2. From Logits to Probabilities: The Softmax Function

### What Are Logits?

A language model's final layer outputs a vector of **logits** - one value per token in the vocabulary. These are raw, unnormalized scores that can be any real number (positive, negative, or zero).

For example, if the vocabulary has 50,000 tokens and the model has processed the prompt *"The capital of France is"*, it will output 50,000 logits. The token for *"Paris"* will likely have a high logit, while *"banana"* will have a low one.

### The Softmax Transformation

To convert logits into a valid probability distribution, we use the **softmax function**:

$$P(x_i) = \frac{e^{x_i}}{\sum_{j} e^{x_j}}$$

This guarantees:
- All probabilities are between 0 and 1
- All probabilities sum to 1
- Higher logits get higher probabilities
- The relative ordering is preserved
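
One further property is worth deriving, because the implementation in the next cell relies on it: softmax is unchanged when the same constant $c$ is subtracted from every logit. A short derivation, using the same symbols as the definition above:

$$\frac{e^{x_i - c}}{\sum_{j} e^{x_j - c}} = \frac{e^{-c}\, e^{x_i}}{e^{-c} \sum_{j} e^{x_j}} = \frac{e^{x_i}}{\sum_{j} e^{x_j}} = P(x_i)$$

Choosing $c = \max_j x_j$ makes every exponent non-positive, so no term in the sum can overflow.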

Let's implement this from scratch and see it in action.

In [None]:
def softmax(logits):
    """Numerically stable softmax implementation.
    
    We subtract the max for numerical stability - this prevents
    overflow when computing exp() of large numbers, without
    changing the result (since it cancels out).
    """
    # Subtract max for numerical stability
    shifted = logits - np.max(logits)
    exp_values = np.exp(shifted)
    return exp_values / np.sum(exp_values)

# Simulate logits for a small vocabulary
vocab = ["Paris", "London", "Berlin", "the", "a", "is", "Rome", "Madrid", "banana", "cat"]
logits = np.array([8.5, 5.2, 4.1, 2.0, 1.5, 1.0, 3.8, 3.5, -2.0, -3.5])

probs = softmax(logits)

print("Token Probabilities after Softmax:")
print("=" * 45)
for token, logit, prob in sorted(zip(vocab, logits, probs), key=lambda x: -x[2]):
    bar = '█' * int(prob * 100)
    print(f"{token:>8s} | logit={logit:>6.1f} | prob={prob:.4f} | {bar}")
print(f"\nSum of probabilities: {probs.sum():.6f}")

Notice how softmax **amplifies differences** - "Paris" with a logit of 8.5 gets a much larger share of probability than you might expect from the raw numbers. This is because the exponential function grows very quickly.
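
To make the amplification concrete, here is a minimal sketch that compares "Paris" and "London" directly. It uses only the two logits from the toy vocabulary above and the fact that the softmax normalizer cancels in a ratio of probabilities; the variable names are illustrative.

```python
import math

# Logits for two tokens, copied from the toy vocabulary above
logit_paris, logit_london = 8.5, 5.2

# Both probabilities share the same denominator, so it cancels:
# P(Paris) / P(London) = exp(logit_paris - logit_london)
ratio = math.exp(logit_paris - logit_london)
print(f"P(Paris) / P(London) = {ratio:.1f}")
```

A logit gap of 3.3 turns into a probability ratio of about 27, which is why the bar for "Paris" dwarfs the others.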

Let's visualize this transformation.

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Sort by logit value for better visualization
sorted_indices = np.argsort(logits)[::-1]
sorted_vocab = [vocab[i] for i in sorted_indices]
sorted_logits = logits[sorted_indices]
sorted_probs = probs[sorted_indices]

colors = plt.cm.RdYlGn(np.linspace(0.2, 0.9, len(vocab)))

# Plot logits
axes[0].barh(range(len(vocab)), sorted_logits, color=colors)
axes[0].set_yticks(range(len(vocab)))
axes[0].set_yticklabels(sorted_vocab)
axes[0].set_xlabel('Logit Value')
axes[0].set_title('Raw Logits (Model Output)')
axes[0].invert_yaxis()

# Plot probabilities
axes[1].barh(range(len(vocab)), sorted_probs, color=colors)
axes[1].set_yticks(range(len(vocab)))
axes[1].set_yticklabels(sorted_vocab)
axes[1].set_xlabel('Probability')
axes[1].set_title('After Softmax (Probabilities)')
axes[1].invert_yaxis()

plt.tight_layout()
plt.suptitle('Logits → Softmax → Probabilities', fontsize=14, y=1.02)
plt.show()

## 3. Temperature Scaling

### The Core Idea

**Temperature** is the most basic sampling parameter. It controls the "sharpness" of the probability distribution by scaling the logits before softmax:

$$P(x_i) = \frac{e^{x_i / T}}{\sum_{j} e^{x_j / T}}$$

where $T$ is the temperature.

**Intuition:**
- **T = 1.0**: Standard softmax, no change
- **T < 1**: Distribution becomes sharper (more deterministic); as T → 0 it approaches greedy decoding
- **T > 1**: Distribution becomes flatter (more random), giving more diverse outputs
- **T → ∞**: Distribution approaches uniform (every token equally likely)

Think of it like this: **low temperature = confident and repetitive**, **high temperature = creative but risky**.

### Why Does This Work?

Dividing logits by T before softmax:
- When T < 1: magnifies the differences between logits → winner takes more
- When T > 1: shrinks the differences → more tokens get a fair chance
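
The effect on a pair of tokens can be sketched in a few lines. Dividing by T turns a logit gap of `d` into a probability ratio of `exp(d / T)`. The helper `softmax_t` and the two logits below are illustrative assumptions, written in plain Python so the block runs on its own.

```python
import math

def softmax_t(logits, temperature):
    """Softmax over a list of logits after dividing each by `temperature`."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

two_logits = [2.0, 1.0]  # illustrative values with a gap of 1.0

for T in (0.5, 1.0, 2.0):
    p = softmax_t(two_logits, T)
    print(f"T={T}: P = [{p[0]:.3f}, {p[1]:.3f}], ratio = {p[0] / p[1]:.2f}")
```

The ratio between the two tokens is `exp(1 / T)`: it grows as T falls below 1 and shrinks toward 1 as T rises.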

In [None]:
def softmax_with_temperature(logits, temperature=1.0):
    """Apply temperature scaling before softmax.
    
    Args:
        logits: Raw logit values (numpy array)
        temperature: Scaling factor. Must be > 0.
                     T < 1 = sharper, T > 1 = flatter
    Returns:
        Probability distribution (numpy array)
    """
    if temperature <= 0:
        raise ValueError("Temperature must be > 0. Use greedy decoding for T→0.")
    
    # Scale logits by temperature
    scaled_logits = logits / temperature
    
    # Apply softmax
    return softmax(scaled_logits)

# Demonstrate with different temperatures
temperatures = [0.1, 0.5, 1.0, 1.5, 3.0]

print(f"{'Token':>8s}", end=" | ")
for t in temperatures:
    print(f"T={t:<4.1f}", end=" | ")
print()
print("-" * 65)

for i, token in enumerate(vocab):
    print(f"{token:>8s}", end=" | ")
    for t in temperatures:
        p = softmax_with_temperature(logits, t)[i]
        print(f"{p:.4f}", end=" | ")
    print()

### Visualizing Temperature's Effect

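Let's see how temperature reshapes the probability distribution. Watch how it goes from nearly deterministic at T=0.1 to visibly flatter at T=3.0. Even at T=3.0 the distribution is still far from uniform here: the logit gap between "Paris" and the rest is large, so the top token keeps well over a third of the mass.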

In [None]:
fig, axes = plt.subplots(1, 5, figsize=(20, 5), sharey=True)

temperatures = [0.1, 0.5, 1.0, 1.5, 3.0]
temp_colors = ['#d62728', '#ff7f0e', '#2ca02c', '#1f77b4', '#9467bd']

for ax, temp, color in zip(axes, temperatures, temp_colors):
    probs_t = softmax_with_temperature(logits, temp)
    sorted_idx = np.argsort(probs_t)[::-1]
    
    bars = ax.bar(range(len(vocab)), probs_t[sorted_idx], color=color, alpha=0.8)
    ax.set_xticks(range(len(vocab)))
    ax.set_xticklabels([vocab[i] for i in sorted_idx], rotation=45, ha='right', fontsize=8)
    ax.set_title(f'T = {temp}', fontsize=14, fontweight='bold')
    ax.set_ylim(0, 1.05)
    
    # Annotate top probability
    top_prob = probs_t[sorted_idx[0]]
    ax.annotate(f'{top_prob:.2f}', xy=(0, top_prob), fontsize=10,
                ha='center', va='bottom', fontweight='bold')
    
    # Show entropy
    entropy = -np.sum(probs_t * np.log(probs_t + 1e-10))
    ax.text(0.95, 0.95, f'H={entropy:.2f}', transform=ax.transAxes,
            ha='right', va='top', fontsize=10,
            bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

axes[0].set_ylabel('Probability', fontsize=12)
plt.suptitle('Effect of Temperature on Token Probability Distribution', fontsize=16, y=1.02)
plt.tight_layout()
plt.show()

### Entropy as a Measure of Randomness

The **H** value shown on each plot is the **Shannon entropy** of the distribution. Higher entropy means more randomness:
- Low entropy (T=0.1): Almost all probability on one token → predictable
- High entropy (T=3.0): Probability spread across many tokens → unpredictable
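
As a reference for reading the H values, the two extremes can be checked directly. This is a small sketch with a hypothetical helper, `shannon_entropy`, which skips zero-probability terms rather than adding a small epsilon as the plotting code does.

```python
import math

def shannon_entropy(probs):
    """Shannon entropy in nats; terms with p == 0 contribute nothing."""
    return sum(-p * math.log(p) for p in probs if p > 0)

n = 10  # same size as the toy vocabulary
one_hot = [1.0] + [0.0] * (n - 1)   # all mass on one token
uniform = [1.0 / n] * n             # mass spread evenly

print(f"one-hot: H = {shannon_entropy(one_hot):.3f}")
print(f"uniform: H = {shannon_entropy(uniform):.3f} (log n = {math.log(n):.3f})")
```

Entropy is 0 for a one-hot distribution and `log(n)` for a uniform one, so every H value on the plots falls between those bounds.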

Let's plot entropy as a continuous function of temperature.

In [None]:
temp_range = np.linspace(0.01, 5.0, 200)
entropies = []
top1_probs = []
top3_probs = []

for t in temp_range:
    p = softmax_with_temperature(logits, t)
    entropy = -np.sum(p * np.log(p + 1e-10))
    entropies.append(entropy)
    sorted_p = np.sort(p)[::-1]
    top1_probs.append(sorted_p[0])
    top3_probs.append(sorted_p[:3].sum())

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Entropy vs Temperature
ax1.plot(temp_range, entropies, 'b-', linewidth=2)
max_entropy = np.log(len(vocab))
ax1.axhline(y=max_entropy, color='r', linestyle='--', alpha=0.5, label=f'Max entropy (uniform) = {max_entropy:.2f}')
ax1.axvline(x=1.0, color='gray', linestyle=':', alpha=0.5, label='T=1.0 (default)')
ax1.set_xlabel('Temperature')
ax1.set_ylabel('Shannon Entropy')
ax1.set_title('Entropy vs Temperature')
ax1.legend()

# Top-k cumulative probability vs Temperature
ax2.plot(temp_range, top1_probs, 'r-', linewidth=2, label='Top-1 token')
ax2.plot(temp_range, top3_probs, 'g-', linewidth=2, label='Top-3 tokens')
ax2.axvline(x=1.0, color='gray', linestyle=':', alpha=0.5, label='T=1.0 (default)')
ax2.set_xlabel('Temperature')
ax2.set_ylabel('Cumulative Probability')
ax2.set_title('Probability Concentration vs Temperature')
ax2.legend()

plt.tight_layout()
plt.show()

## 4. Top-k Sampling

### The Problem with Pure Temperature Sampling

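Even with moderate temperature, the model might occasionally sample very unlikely tokens. If "banana" has a probability of 0.001 at a given step, it will be drawn about once in every 1,000 samples from that step's distribution. Over a long generation these small tail probabilities add up, and a single out-of-place token can derail the text that follows.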
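
To put a rough number on how tail probability accumulates, the sketch below computes the chance of drawing an unlikely token at least once over a generation. It assumes, as a simplification, that the token has the same probability at every step and that steps are independent; real models do not behave this way exactly. The helper name is illustrative.

```python
def prob_at_least_once(per_step_prob, n_steps):
    """Chance that an event with a fixed per-step probability
    occurs at least once in n_steps independent draws."""
    return 1.0 - (1.0 - per_step_prob) ** n_steps

for n_steps in (100, 1000, 5000):
    print(f"{n_steps:>5d} tokens: {prob_at_least_once(0.001, n_steps):.3f}")
```

Under these assumptions, a 0.001 tail event shows up at least once in roughly 63% of 1,000-token generations.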

### The Top-k Solution

**Top-k sampling** restricts the candidate pool to the **k most probable tokens**, then renormalizes the probabilities among just those k tokens.

**Algorithm:**
1. Compute probabilities (with optional temperature)
2. Sort tokens by probability
3. Keep only the top k tokens
4. Set all other probabilities to 0
5. Renormalize so probabilities sum to 1
6. Sample from this filtered distribution

In [None]:
def top_k_filtering(logits, k, temperature=1.0):
    """Apply top-k filtering to logits.
    
    Args:
        logits: Raw logit values
        k: Number of top tokens to keep
        temperature: Temperature for softmax
    
    Returns:
        filtered_probs: Probability distribution with only top-k tokens
    """
    if k < 1:
        raise ValueError("k must be >= 1")
    
    # Apply temperature
    scaled_logits = logits / temperature
    
    if k < len(logits):
        # Indices of the k largest scaled logits
        top_k_indices = np.argsort(scaled_logits)[-k:]
        
        # Mask every other token with -inf so softmax gives it zero probability
        filtered_logits = np.full_like(scaled_logits, -np.inf)
        filtered_logits[top_k_indices] = scaled_logits[top_k_indices]
    else:
        # k covers the whole vocabulary: nothing to filter
        filtered_logits = scaled_logits
    
    # Softmax over filtered logits (exp(-inf) = 0)
    probs = softmax(filtered_logits)
    return probs

# Demonstrate top-k filtering
print("Top-k Filtering Demonstration")
print("=" * 55)

for k in [1, 3, 5, 10]:
    probs_k = top_k_filtering(logits, k=k)
    nonzero = np.sum(probs_k > 0)
    print(f"\nk={k} ({nonzero} tokens with non-zero probability):")
    for token, prob in sorted(zip(vocab, probs_k), key=lambda x: -x[1]):
        if prob > 0:
            print(f"  {token:>8s}: {prob:.4f}")

In [None]:
# Visualize top-k filtering
fig, axes = plt.subplots(1, 4, figsize=(18, 5), sharey=True)

k_values = [1, 3, 5, 10]
k_colors = ['#d62728', '#ff7f0e', '#2ca02c', '#1f77b4']

for ax, k, color in zip(axes, k_values, k_colors):
    probs_k = top_k_filtering(logits, k=k)
    sorted_idx = np.argsort(probs_k)[::-1]
    
    bar_colors = [color if probs_k[i] > 0 else '#cccccc' for i in sorted_idx]
    ax.bar(range(len(vocab)), probs_k[sorted_idx], color=bar_colors, alpha=0.8)
    ax.set_xticks(range(len(vocab)))
    ax.set_xticklabels([vocab[i] for i in sorted_idx], rotation=45, ha='right', fontsize=8)
    ax.set_title(f'Top-k = {k}', fontsize=14, fontweight='bold')
    ax.set_ylim(0, 1.05)

axes[0].set_ylabel('Probability', fontsize=12)
plt.suptitle('Top-k Sampling: Keeping Only the k Most Likely Tokens', fontsize=14, y=1.02)
plt.tight_layout()
plt.show()

### The Limitation of Top-k

The problem with top-k is that **k is fixed**. Consider two scenarios:

1. **High-confidence prediction**: The model is 95% sure the next word is "Paris". With k=10, we're including 9 mostly-irrelevant tokens.

2. **Low-confidence prediction**: The model is uncertain, with 50 tokens each having ~2% probability. With k=10, we're missing 80% of the reasonable probability mass.

We need something **adaptive** - enter **top-p sampling**.
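
The two scenarios can be checked with a few lines of arithmetic. This is a sketch with made-up distributions: the "confident" case assumes the remaining 5% is spread evenly over 49 other tokens, and `top_k_mass` is a hypothetical helper.

```python
def top_k_mass(probs, k):
    """Total probability captured by the k most likely tokens."""
    return sum(sorted(probs, reverse=True)[:k])

# Scenario 1: one token at 95%, the rest spread evenly (assumed shape)
confident = [0.95] + [0.05 / 49] * 49
# Scenario 2: 50 tokens at 2% each
uncertain = [0.02] * 50

for name, dist in (("confident", confident), ("uncertain", uncertain)):
    print(f"{name}: top-1 mass = {top_k_mass(dist, 1):.2f}, "
          f"top-10 mass = {top_k_mass(dist, 10):.2f}")
```

With k=10, the confident case gains almost nothing beyond its top token, while the uncertain case keeps only a fifth of the probability mass.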

## 5. Top-p (Nucleus) Sampling

### The Key Insight

**Top-p sampling** (also called **nucleus sampling**, from the [Holtzman et al. 2019](https://arxiv.org/abs/1904.09751) paper) adapts the number of candidate tokens based on the model's confidence.

Instead of keeping a fixed number of tokens, we keep the **smallest set of tokens** whose cumulative probability reaches or exceeds a threshold $p$. The size of that set is

$$k^{*} = \min \left\{ k : \sum_{i=1}^{k} P(x_{(i)}) \geq p \right\}$$

where $x_{(1)}, x_{(2)}, \ldots$ are tokens sorted by decreasing probability, and the nucleus is $\{x_{(1)}, \ldots, x_{(k^{*})}\}$.

**Algorithm:**
1. Compute probabilities (with optional temperature)
2. Sort tokens by probability (descending)
3. Compute cumulative sum
4. Find where the cumulative sum first reaches or exceeds p
5. Keep only tokens up to that point
6. Renormalize and sample
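
Before the NumPy implementation, the steps can be traced by hand on a five-token distribution. The sketch below uses only the standard library; `nucleus_size` and the probabilities are illustrative assumptions.

```python
from itertools import accumulate

def nucleus_size(probs, p):
    """Smallest number of top-ranked tokens whose cumulative probability is >= p."""
    ranked = sorted(probs, reverse=True)
    for size, cum in enumerate(accumulate(ranked), start=1):
        if cum >= p:
            return size
    return len(ranked)  # rounding can leave the running total just under p

probs = [0.5, 0.3, 0.1, 0.06, 0.04]  # cumulative: 0.5, 0.8, 0.9, 0.96, 1.0

for p in (0.4, 0.75, 0.85):
    print(f"p={p}: nucleus size = {nucleus_size(probs, p)}")
```

The token that pushes the running total over the threshold is included, so p=0.75 keeps two tokens here and p=0.85 keeps three.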

In [None]:
def top_p_filtering(logits, p, temperature=1.0):
    """Apply top-p (nucleus) filtering to logits.
    
    Args:
        logits: Raw logit values
        p: Cumulative probability threshold (0 < p <= 1)
        temperature: Temperature for initial softmax
    
    Returns:
        filtered_probs: Probability distribution with nucleus tokens only
    """
    # Apply temperature and get probabilities
    scaled_logits = logits / temperature
    probs = softmax(scaled_logits)
    
    # Sort probabilities in descending order
    sorted_indices = np.argsort(probs)[::-1]
    sorted_probs = probs[sorted_indices]
    
    # Compute cumulative sum
    cumulative_probs = np.cumsum(sorted_probs)
    
    # Find cutoff: first index where the cumulative prob reaches p (>= p).
    # The +1 converts that index to a count, so the token that pushes
    # us over the threshold is included.
    cutoff_idx = np.searchsorted(cumulative_probs, p) + 1
    cutoff_idx = min(cutoff_idx, len(probs))  # Safety bound
    
    # Keep only nucleus tokens
    nucleus_indices = sorted_indices[:cutoff_idx]
    
    # Create filtered distribution
    filtered_probs = np.zeros_like(probs)
    filtered_probs[nucleus_indices] = probs[nucleus_indices]
    
    # Renormalize
    filtered_probs = filtered_probs / filtered_probs.sum()
    
    return filtered_probs

# Demonstrate top-p filtering
print("Top-p Filtering Demonstration")
print("=" * 55)

for p in [0.5, 0.8, 0.9, 0.95, 1.0]:
    probs_p = top_p_filtering(logits, p=p)
    nonzero = np.sum(probs_p > 1e-8)
    print(f"\np={p} ({nonzero} tokens in nucleus):")
    for token, prob in sorted(zip(vocab, probs_p), key=lambda x: -x[1]):
        if prob > 1e-8:
            print(f"  {token:>8s}: {prob:.4f}")

In [None]:
# Visualize top-p filtering with cumulative probability
fig, axes = plt.subplots(2, 3, figsize=(18, 10))

p_values = [0.5, 0.8, 0.9, 0.95, 0.99, 1.0]

probs_base = softmax(logits)
sorted_idx = np.argsort(probs_base)[::-1]
sorted_probs = probs_base[sorted_idx]
cumulative = np.cumsum(sorted_probs)

for ax, p_val in zip(axes.flat, p_values):
    probs_p = top_p_filtering(logits, p=p_val)
    
    # Color bars: green if in nucleus, gray if not
    bar_colors = ['#2ca02c' if probs_p[i] > 1e-8 else '#cccccc' for i in sorted_idx]
    
    ax.bar(range(len(vocab)), probs_base[sorted_idx], color=bar_colors, alpha=0.7)
    
    # Overlay cumulative probability line
    ax2 = ax.twinx()
    ax2.plot(range(len(vocab)), cumulative, 'r--', linewidth=1.5, alpha=0.6)
    ax2.axhline(y=p_val, color='red', linestyle=':', alpha=0.4)
    ax2.set_ylim(0, 1.05)
    ax2.set_ylabel('Cumulative Prob', color='red', fontsize=8)
    
    n_tokens = np.sum(probs_p > 1e-8)
    ax.set_xticks(range(len(vocab)))
    ax.set_xticklabels([vocab[i] for i in sorted_idx], rotation=45, ha='right', fontsize=7)
    ax.set_title(f'p={p_val} ({n_tokens} tokens)', fontsize=12, fontweight='bold')
    ax.set_ylim(0, 1.0)

plt.suptitle('Top-p (Nucleus) Sampling: Adaptive Token Selection', fontsize=14, y=1.02)
plt.tight_layout()
plt.show()

## 6. Comparing All Sampling Strategies

Let's simulate many samples from each strategy and see how the token frequencies differ. This reveals the **diversity** of each method.

In [None]:
def sample_token(probs, n_samples=10000):
    """Sample tokens from a probability distribution."""
    indices = np.random.choice(len(probs), size=n_samples, p=probs)
    return indices

def greedy_decode(logits):
    """Greedy decoding: always pick the highest probability token."""
    probs = np.zeros_like(logits, dtype=float)
    probs[np.argmax(logits)] = 1.0
    return probs

n_samples = 50000

strategies = {
    'Greedy': greedy_decode(logits),
    'T=0.5': softmax_with_temperature(logits, 0.5),
    'T=1.0': softmax_with_temperature(logits, 1.0),
    'T=1.5': softmax_with_temperature(logits, 1.5),
    'Top-k=3': top_k_filtering(logits, k=3),
    'Top-p=0.9': top_p_filtering(logits, p=0.9),
    'T=0.7 + Top-p=0.9': top_p_filtering(logits, p=0.9, temperature=0.7),
}

fig, axes = plt.subplots(2, 4, figsize=(20, 9))

for idx, (name, probs) in enumerate(strategies.items()):
    if idx >= 7:
        break
    ax = axes[idx // 4][idx % 4]
    
    # Sample
    samples = sample_token(probs, n_samples)
    counts = Counter(samples)
    
    # Plot frequencies
    sorted_idx = np.argsort(probs)[::-1]
    freqs = [counts.get(i, 0) / n_samples for i in sorted_idx]
    expected = [probs[i] for i in sorted_idx]
    
    x = range(len(vocab))
    ax.bar(x, freqs, alpha=0.6, color='steelblue', label='Sampled freq')
    ax.bar(x, expected, alpha=0.4, color='orange', label='Expected prob')
    ax.set_xticks(x)
    ax.set_xticklabels([vocab[i] for i in sorted_idx], rotation=45, ha='right', fontsize=7)
    ax.set_title(name, fontsize=11, fontweight='bold')
    ax.set_ylim(0, 1.05)
    
    # Count unique tokens sampled
    n_unique = len(counts)
    ax.text(0.95, 0.95, f'{n_unique} unique', transform=ax.transAxes,
            ha='right', va='top', fontsize=9,
            bbox=dict(boxstyle='round', facecolor='lightyellow'))

# Remove empty subplot
axes[1][3].set_visible(False)
axes[0][0].legend(fontsize=8)

plt.suptitle(f'Sampling Strategy Comparison ({n_samples:,} samples each)', fontsize=14, y=1.02)
plt.tight_layout()
plt.show()

## 7. Combining Temperature with Top-p: The Industry Standard

In practice, most hosted LLM APIs (for example those from OpenAI, Anthropic, and Google) expose both **temperature** and **top-p** as request parameters. When both are set, the typical flow is:

1. Compute logits from the model
2. Apply **temperature** scaling
3. Apply **top-p** filtering
4. Sample from the filtered distribution
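
One consequence of this ordering is that the nucleus is computed on the temperature-adjusted distribution, so changing T also changes how many tokens survive top-p. The sketch below illustrates this on the toy logits from earlier sections; `nucleus_size_at` is a hypothetical helper written in plain Python, not part of any library.

```python
import math
from itertools import accumulate

def nucleus_size_at(logits, temperature, p):
    """Number of tokens kept by top-p when temperature is applied first."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = sorted((math.exp(s - m) for s in scaled), reverse=True)
    total = sum(exps)
    for size, cum in enumerate(accumulate(e / total for e in exps), start=1):
        if cum >= p:
            return size
    return len(exps)

# Same toy logits as the 10-token vocabulary used throughout this notebook
toy_logits = [8.5, 5.2, 4.1, 2.0, 1.5, 1.0, 3.8, 3.5, -2.0, -3.5]

for T in (0.7, 1.0, 1.5):
    print(f"T={T}: nucleus size at p=0.9 is {nucleus_size_at(toy_logits, T, 0.9)}")
```

On these logits the nucleus holds a single token at T=0.7 and T=1.0 and grows to three tokens at T=1.5, so a low temperature combined with top-p can behave like greedy decoding when the model is already confident.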

Commonly used starting points (rules of thumb, not provider defaults):
- **Factual/deterministic tasks**: T=0.0 (greedy) or T=0.3, top-p=0.9
- **Balanced generation**: T=0.7, top-p=0.9
- **Creative writing**: T=1.0, top-p=0.95
- **Brainstorming**: T=1.2, top-p=0.95

Let's implement the complete sampling pipeline.

In [None]:
def full_sampling_pipeline(logits, temperature=1.0, top_k=None, top_p=None):
    """Complete sampling pipeline combining temperature, top-k, and top-p.
    
    The order matters:
    1. Temperature scaling (applied to logits before softmax)
    2. Top-k filtering (if specified)
    3. Top-p filtering (if specified)
    4. Sample from the final distribution
    
    Args:
        logits: Raw logit values
        temperature: Temperature parameter (> 0)
        top_k: If set, keep only top k tokens
        top_p: If set, keep nucleus of tokens with cumulative prob >= p
    
    Returns:
        sampled_index: Index of the sampled token
        final_probs: The probability distribution used for sampling
    """
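    # Step 1: Temperature scaling (T must be positive; a very small T approximates greedy)
    if temperature <= 0:
        raise ValueError("temperature must be > 0")
    scaled_logits = logits / temperature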
    
    # Step 2: Top-k filtering (on logits)
    if top_k is not None and top_k < len(scaled_logits):
        top_k_indices = np.argsort(scaled_logits)[-top_k:]
        mask = np.full_like(scaled_logits, -np.inf)
        mask[top_k_indices] = scaled_logits[top_k_indices]
        scaled_logits = mask
    
    # Convert to probabilities
    probs = softmax(scaled_logits)
    
    # Step 3: Top-p filtering (on probabilities)
    if top_p is not None and top_p < 1.0:
        sorted_indices = np.argsort(probs)[::-1]
        sorted_probs = probs[sorted_indices]
        cumulative = np.cumsum(sorted_probs)
        
        cutoff = np.searchsorted(cumulative, top_p) + 1
        nucleus_indices = sorted_indices[:cutoff]
        
        filtered = np.zeros_like(probs)
        filtered[nucleus_indices] = probs[nucleus_indices]
        probs = filtered / filtered.sum()
    
    # Step 4: Sample
    sampled_index = np.random.choice(len(probs), p=probs)
    
    return sampled_index, probs

# Run the pipeline multiple times and collect statistics
configs = [
    {'temperature': 0.001, 'top_k': None, 'top_p': None, 'name': 'Greedy (T≈0)'},
    {'temperature': 0.7, 'top_k': None, 'top_p': None, 'name': 'T=0.7'},
    {'temperature': 0.7, 'top_k': None, 'top_p': 0.9, 'name': 'T=0.7, p=0.9'},
    {'temperature': 1.0, 'top_k': 5, 'top_p': None, 'name': 'T=1.0, k=5'},
    {'temperature': 0.7, 'top_k': 5, 'top_p': 0.9, 'name': 'T=0.7, k=5, p=0.9'},
]

print("Sampling Pipeline Results (1000 samples each)")
print("=" * 70)

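for config in configs:
    name = config['name']
    # Pass everything except the display name through to the pipeline
    params = {key: value for key, value in config.items() if key != 'name'}
    counts = Counter()
    for _ in range(1000):
        idx, _ = full_sampling_pipeline(logits, **params)
        counts[vocab[idx]] += 1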
    
    print(f"\n{name}:")
    for token, count in counts.most_common():
        bar = '█' * (count // 10)
        print(f"  {token:>8s}: {count:>4d}/1000 ({count/10:.1f}%) {bar}")

## 8. Real Model Generation with Different Sampling Strategies

Now let's use a real language model to see how sampling strategies affect actual text generation. We'll use GPT-2 (small), which runs fine on CPU.

> **Note**: This section loads a ~500MB model. It works on free Colab but takes a moment to download.

In [None]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Load GPT-2 small
print("Loading GPT-2 model and tokenizer...")
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.eval()
print(f"Model loaded! Vocabulary size: {tokenizer.vocab_size:,}")
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

In [None]:
def generate_text(prompt, max_tokens=50, temperature=1.0, top_k=None, top_p=None):
    """Generate text using our custom sampling pipeline."""
    input_ids = tokenizer.encode(prompt, return_tensors='pt')
    generated = list(input_ids[0].numpy())
    
    with torch.no_grad():
        for _ in range(max_tokens):
            inputs = torch.tensor([generated])
            outputs = model(inputs)
            next_token_logits = outputs.logits[0, -1, :].numpy()
            
            # Use our sampling pipeline
            next_token, _ = full_sampling_pipeline(
                next_token_logits,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p
            )
            # Stop when the model emits GPT-2's end-of-text token,
            # without appending it to the output
            if next_token == tokenizer.eos_token_id:
                break
            generated.append(int(next_token))
    
    return tokenizer.decode(generated)

# Test different strategies
prompt = "The future of artificial intelligence is"

print(f"Prompt: \"{prompt}\"")
print("=" * 80)

configs = [
    ("Greedy (T=0.001)", {'temperature': 0.001}),
    ("T=0.7", {'temperature': 0.7}),
    ("T=0.7, top-p=0.9", {'temperature': 0.7, 'top_p': 0.9}),
    ("T=1.0, top-k=50", {'temperature': 1.0, 'top_k': 50}),
    ("T=1.2, top-p=0.95", {'temperature': 1.2, 'top_p': 0.95}),
]

for name, kwargs in configs:
    text = generate_text(prompt, max_tokens=40, **kwargs)
    print(f"\n[{name}]")
    print(text)
    print("-" * 80)

## 9. Visualizing Token Probabilities During Generation

Let's look at how the model's probability distribution evolves as it generates text. This gives insight into *when* the model is confident vs uncertain.

In [None]:
def generate_with_tracking(prompt, max_tokens=15, temperature=0.7, top_p=0.9):
    """Generate text and track probability distributions at each step."""
    input_ids = tokenizer.encode(prompt, return_tensors='pt')
    generated = list(input_ids[0].numpy())
    
    steps = []  # Track info at each generation step
    
    with torch.no_grad():
        for step in range(max_tokens):
            inputs = torch.tensor([generated])
            outputs = model(inputs)
            next_token_logits = outputs.logits[0, -1, :].numpy()
            
            # Get full probability distribution
            full_probs = softmax_with_temperature(next_token_logits, temperature)
            
            # Sample using our pipeline
            next_token, filtered_probs = full_sampling_pipeline(
                next_token_logits, temperature=temperature, top_p=top_p
            )
            
            # Track top-5 tokens and their probabilities
            top5_idx = np.argsort(full_probs)[-5:][::-1]
            top5_tokens = [tokenizer.decode([i]).strip() for i in top5_idx]
            top5_probs = full_probs[top5_idx]
            
            steps.append({
                'chosen_token': tokenizer.decode([next_token]).strip(),
                'chosen_prob': full_probs[next_token],
                'top5_tokens': top5_tokens,
                'top5_probs': top5_probs,
                'entropy': -np.sum(full_probs * np.log(full_probs + 1e-10)),
                'nucleus_size': np.sum(filtered_probs > 1e-8),
            })
            
            generated.append(next_token)
    
    return tokenizer.decode(generated), steps

text, steps = generate_with_tracking("The best way to learn programming is", max_tokens=15)
print(f"Generated: {text}\n")

# Visualize
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(16, 8), sharex=True)

x = range(len(steps))
chosen_tokens = [s['chosen_token'] for s in steps]
chosen_probs = [s['chosen_prob'] for s in steps]
entropies = [s['entropy'] for s in steps]
nucleus_sizes = [s['nucleus_size'] for s in steps]

# Plot 1: Chosen token probability + entropy
bars = ax1.bar(x, chosen_probs, color='steelblue', alpha=0.7, label='Chosen token prob')
ax1.set_ylabel('Probability', color='steelblue')
ax1_twin = ax1.twinx()
ax1_twin.plot(x, entropies, 'r-o', markersize=5, label='Entropy')
ax1_twin.set_ylabel('Entropy', color='red')
ax1.set_title('Token Probability and Entropy at Each Generation Step')
ax1.legend(loc='upper left')
ax1_twin.legend(loc='upper right')

# Plot 2: Nucleus size
ax2.bar(x, nucleus_sizes, color='green', alpha=0.7)
ax2.set_ylabel('Nucleus Size (# tokens)')
ax2.set_xlabel('Generation Step')
ax2.set_xticks(x)
ax2.set_xticklabels(chosen_tokens, rotation=45, ha='right')
ax2.set_title('Number of Tokens in Nucleus (top-p=0.9) at Each Step')

plt.tight_layout()
plt.show()

## 10. Repetition Penalty

One common issue with language model generation is **repetition** - the model gets stuck in loops. A simple fix is to apply a **repetition penalty** that reduces the logits of tokens that have already appeared.

The formula: for each previously generated token $t$:
- If $\text{logit}(t) > 0$: divide by penalty $\alpha$
- If $\text{logit}(t) < 0$: multiply by penalty $\alpha$

This makes already-seen tokens less likely.
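
The sign check matters because dividing a negative logit by a factor above 1 would move it toward zero and make the token more likely. A minimal sketch of the rule for a single logit, with a hypothetical helper `penalize`:

```python
def penalize(logit, penalty):
    """Sign-aware penalty from the rule above: always moves the logit downward."""
    return logit / penalty if logit > 0 else logit * penalty

penalty = 1.5
for logit in (5.0, -2.0):
    naive = logit / penalty  # divides regardless of sign
    print(f"logit={logit:+.1f}: sign-aware -> {penalize(logit, penalty):+.2f}, "
          f"naive divide -> {naive:+.2f}")
```

For the negative logit, the naive division raises it from -2.00 to -1.33, while the sign-aware rule lowers it to -3.00 as intended.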

In [None]:
def apply_repetition_penalty(logits, generated_tokens, penalty=1.2):
    """Apply repetition penalty to logits.
    
    Penalizes tokens that have already been generated,
    making them less likely to be repeated.
    
    Args:
        logits: Raw logit values
        generated_tokens: List of token indices already generated
        penalty: Penalty factor (1.0 = no penalty, >1.0 = penalize repeats)
    """
    penalized_logits = logits.copy()
    
    for token_id in set(generated_tokens):
        if penalized_logits[token_id] > 0:
            penalized_logits[token_id] /= penalty
        else:
            penalized_logits[token_id] *= penalty
    
    return penalized_logits

# Demonstrate the effect
demo_logits = np.array([5.0, 3.0, 2.0, 1.0, 0.5])
demo_tokens = ['hello', 'world', 'how', 'are', 'you']
already_generated = [0, 1]  # 'hello' and 'world' already appeared

original_probs = softmax(demo_logits)
penalized = apply_repetition_penalty(demo_logits, already_generated, penalty=1.5)
penalized_probs = softmax(penalized)

print("Effect of Repetition Penalty (alpha=1.5):")
print(f"{'Token':>8s} | {'Original':>10s} | {'Penalized':>10s} | {'Change':>10s}")
print("-" * 50)
for i, token in enumerate(demo_tokens):
    marker = " ← penalized" if i in already_generated else ""
    change = penalized_probs[i] - original_probs[i]
    print(f"{token:>8s} | {original_probs[i]:>10.4f} | {penalized_probs[i]:>10.4f} | {change:>+10.4f}{marker}")

## 11. Summary: The Complete Picture

Let's create a comprehensive comparison showing how all sampling parameters interact.

In [None]:
# Create a large vocabulary scenario (more realistic)
np.random.seed(42)
n_vocab = 1000
# Gaussian background logits with a few boosted tokens (few high, many low)
large_logits = np.random.randn(n_vocab) * 2
large_logits[0] = 8.0   # One dominant token
large_logits[1:5] = np.array([5.0, 4.5, 4.0, 3.5])  # A few strong competitors
large_logits[5:20] = np.random.uniform(1.0, 3.0, 15)  # Moderate tokens

fig, axes = plt.subplots(2, 3, figsize=(18, 10))

scenarios = [
    ('Greedy (T→0)', {'temperature': 0.01}),
    ('T=0.7', {'temperature': 0.7}),
    ('T=1.0', {'temperature': 1.0}),
    ('T=0.7, top-k=10', {'temperature': 0.7, 'top_k': 10}),
    ('T=0.7, top-p=0.9', {'temperature': 0.7, 'top_p': 0.9}),
    ('T=0.7, k=50, p=0.9', {'temperature': 0.7, 'top_k': 50, 'top_p': 0.9}),
]

for ax, (name, kwargs) in zip(axes.flat, scenarios):
    # full_sampling_pipeline applies temperature, then top-k, then top-p,
    # and returns the final distribution alongside a sampled index.
    # Only the distribution is needed for plotting.
    _, probs_filtered = full_sampling_pipeline(large_logits, **kwargs)
    
    # Sort and plot top 30
    sorted_idx = np.argsort(probs_filtered)[::-1][:30]
    
    colors = ['steelblue' if probs_filtered[i] > 1e-8 else '#cccccc' for i in sorted_idx]
    ax.bar(range(30), probs_filtered[sorted_idx], color=colors, alpha=0.8)
    ax.set_title(name, fontsize=12, fontweight='bold')
    ax.set_xlabel('Token rank')
    
    n_active = np.sum(probs_filtered > 1e-8)
    entropy = -np.sum(probs_filtered[probs_filtered > 0] * np.log(probs_filtered[probs_filtered > 0]))
    ax.text(0.95, 0.95, f'{n_active} active\nH={entropy:.2f}',
            transform=ax.transAxes, ha='right', va='top', fontsize=9,
            bbox=dict(boxstyle='round', facecolor='lightyellow'))

axes[0][0].set_ylabel('Probability')
axes[1][0].set_ylabel('Probability')
plt.suptitle('Sampling Strategies on a 1000-Token Vocabulary (showing top 30)', fontsize=14, y=1.02)
plt.tight_layout()
plt.show()

## 12. Key Takeaways

| Strategy | What It Does | Pros | Cons |
|----------|-------------|------|------|
| **Greedy** | Always picks the highest-probability token | Deterministic, fast | Repetitive, can get stuck in loops |
| **Temperature < 1** | Sharpens distribution | More focused | Low-probability tokens can still be sampled |
| **Temperature > 1** | Flattens distribution | More diverse/creative | Can produce nonsense |
| **Top-k** | Keeps only the k most likely tokens | Removes the long tail of unlikely tokens | Fixed k ignores how peaked or flat the distribution is |
| **Top-p** | Keeps the smallest set of tokens whose cumulative probability reaches p | Adapts to model confidence | Slightly more compute (sort plus cumulative sum) |
| **Combined** | Temperature + top-p (+ optional top-k) | Controls both sharpness and tail cutoff | More hyperparameters to tune |

### Practical Guidelines

1. **Start with T=0.7, top-p=0.9** as a sensible default
2. **Lower temperature** (0.1-0.5) for factual tasks where you want near-deterministic output
3. **Higher temperature** (0.8-1.2) for creative tasks
4. **Top-p** is generally preferred over top-k because the number of candidates it keeps adapts to the model's confidence
5. **Avoid T > 2.0** in practice - the distribution becomes so flat that outputs approach random noise
6. **Repetition penalty** (~1.1-1.3) helps prevent loops
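
As a rough illustration of how these defaults fit together, the sketch below applies temperature, then optional top-k, then top-p to a logit vector and samples one token. It is self-contained NumPy. The names `filter_distribution` and `sample_with_defaults` are hypothetical and are not the helpers defined earlier in this notebook, and the toy logits are made up for the example.

```python
import numpy as np

def filter_distribution(logits, temperature=0.7, top_k=None, top_p=0.9):
    """Return the renormalized distribution after temperature, top-k, and top-p."""
    scaled = np.asarray(logits, dtype=np.float64) / temperature

    # Optional top-k: mask everything below the k-th largest scaled logit.
    # Assumption: ties at the boundary are all kept, so more than k can survive.
    if top_k is not None and top_k < scaled.size:
        kth_largest = np.sort(scaled)[-top_k]
        scaled = np.where(scaled >= kth_largest, scaled, -np.inf)

    # Numerically stable softmax (masked entries become exactly 0)
    exp = np.exp(scaled - np.max(scaled))
    probs = exp / exp.sum()

    # Optional top-p: keep the smallest sorted prefix whose mass reaches top_p
    if top_p is not None and top_p < 1.0:
        order = np.argsort(probs)[::-1]
        cumulative = np.cumsum(probs[order])
        cutoff = int(np.searchsorted(cumulative, top_p)) + 1
        keep = order[:cutoff]
        filtered = np.zeros_like(probs)
        filtered[keep] = probs[keep]
        probs = filtered / filtered.sum()
    return probs

def sample_with_defaults(logits, rng, **kwargs):
    """Draw one token index from the filtered distribution."""
    probs = filter_distribution(logits, **kwargs)
    return int(rng.choice(probs.size, p=probs))

rng = np.random.default_rng(0)
toy_logits = np.array([4.0, 3.0, 2.0, 1.0, 0.0, -1.0])  # illustrative values
probs = filter_distribution(toy_logits)                  # T=0.7, top-p=0.9
token = sample_with_defaults(toy_logits, rng)
print(np.round(probs, 3), token)
```

With these toy logits, only the two most likely tokens survive the default T=0.7, top-p=0.9 filter, since together they already hold more than 90% of the probability mass.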

---

## Exercises

### Exercise 1: Implement Min-p Sampling
**Min-p** is a newer sampling strategy where you keep all tokens with probability >= min_p * max_probability. Implement it and compare with top-p.
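
To make the rule concrete before you implement it, here is a small worked example on a hand-picked five-token distribution (the numbers are illustrative, not model output). It only computes which tokens survive each filter. Temperature and renormalization are left for your implementation.

```python
import numpy as np

# Hypothetical probabilities for five tokens, sorted from high to low
probs = np.array([0.50, 0.30, 0.12, 0.06, 0.02])
min_p = 0.1

# Min-p: the threshold scales with the most likely token
threshold = min_p * probs.max()           # 0.1 * 0.50 = 0.05
min_p_keep = probs >= threshold

# Top-p with p=0.9 for comparison: smallest prefix whose mass reaches 0.9
cumulative = np.cumsum(probs)             # 0.50, 0.80, 0.92, 0.98, 1.00
top_p_cutoff = int(np.searchsorted(cumulative, 0.9)) + 1
top_p_keep = np.arange(probs.size) < top_p_cutoff

print("min-p keeps:", min_p_keep)
print("top-p keeps:", top_p_keep)
```

Here min-p keeps four tokens (everything at or above 0.05) while top-p keeps three, which shows that the two filters can disagree on the same distribution.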

In [None]:
def min_p_filtering(logits, min_p=0.1, temperature=1.0):
    """Implement min-p sampling.
    
    Keep all tokens with probability >= min_p * max_probability.
    
    Args:
        logits: Raw logit values
        min_p: Minimum probability ratio threshold
        temperature: Temperature for softmax
    
    Returns:
        Filtered and renormalized probability distribution
    """
    # TODO: Implement min-p filtering
    # 1. Apply temperature and compute probabilities
    # 2. Find the maximum probability
    # 3. Compute threshold = min_p * max_prob
    # 4. Keep only tokens with prob >= threshold
    # 5. Renormalize
    pass

# Test your implementation:
# result = min_p_filtering(logits, min_p=0.1)
# print(result)

### Exercise 2: Temperature Sweep
Generate 5 completions of the same prompt at each of the temperatures [0.3, 0.5, 0.7, 1.0, 1.5] using the GPT-2 model. Rate each output on coherence and creativity.

In [None]:
# TODO: Generate text at different temperatures and compare
# prompt = "Once upon a time in a land far away,"
# temperatures = [0.3, 0.5, 0.7, 1.0, 1.5]
# for each temperature, generate 5 completions and observe patterns

### Exercise 3: Adaptive Temperature
Implement a strategy where temperature adapts based on the model's confidence (entropy of the distribution). When the model is confident, use lower temperature; when uncertain, use higher temperature.
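
One ingredient you will likely need is a confidence score on a fixed scale. A common choice, shown here only as a hint, is Shannon entropy divided by its maximum value log(n), which lies in [0, 1] for any vocabulary size n > 1. The helper name `normalized_entropy` is hypothetical, and how you map the score to a temperature is up to you.

```python
import numpy as np

def normalized_entropy(probs):
    """Shannon entropy of `probs` divided by log(n), so the result is in [0, 1].

    Assumes `probs` sums to 1 and has more than one entry.
    """
    probs = np.asarray(probs, dtype=np.float64)
    nonzero = probs[probs > 0]             # treat 0 * log(0) as 0
    entropy = -np.sum(nonzero * np.log(nonzero))
    return float(entropy / np.log(probs.size))

confident = np.array([0.97, 0.01, 0.01, 0.01])   # model is nearly certain
uncertain = np.array([0.25, 0.25, 0.25, 0.25])   # model has no preference

print(normalized_entropy(confident))   # low, roughly 0.12
print(normalized_entropy(uncertain))   # 1.0 up to floating-point rounding
```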

In [None]:
def adaptive_temperature(logits, base_temp=1.0, min_temp=0.3, max_temp=1.5):
    """Compute an adaptive temperature based on the entropy of the softmax distribution.
    
    When the model is confident (low entropy), use lower temperature.
    When uncertain (high entropy), use higher temperature.
    
    TODO: Implement this function
    Hint: 
    1. Compute base probabilities with temp=1.0
    2. Compute entropy
    3. Map entropy to a temperature in [min_temp, max_temp]
    """
    pass

---

**Next up: Notebook 06 - KV Cache Mechanics** where we'll explore how language models speed up generation by caching key-value pairs from attention computations.