In [None]:
#@title üéß Download Narration Audio & Play Introduction
import os as _os
if not _os.path.exists("/content/narration"):
    !pip install -q gdown
    import gdown
    gdown.download(id="1RJjttCvltRK-j5XaI_Tp752cibGKRYMf", output="/content/narration.zip", quiet=False)
    !unzip -q /content/narration.zip -d /content/narration
    !rm /content/narration.zip
    print(f"Loaded {len(_os.listdir('/content/narration'))} narration segments")
else:
    print("Narration audio already loaded.")

from IPython.display import Audio, display
display(Audio("/content/narration/02_00_intro.mp3"))


In [None]:
#@title üéß Code Walkthrough: Setup Check
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/02_01_setup_check.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")


In [None]:
# 🔧 Setup: Run this cell first!
# Check GPU availability and install dependencies

import torch
import sys

# Check GPU
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"✅ GPU available: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    device = torch.device('cpu')
    print("⚠️ No GPU detected. Some cells may run slowly.")
    print("   Go to Runtime → Change runtime type → GPU")

print(f"\n📦 Python {sys.version.split()[0]}")
print(f"🔥 PyTorch {torch.__version__}")

# Set random seeds for reproducibility
import random
import numpy as np

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print(f"üé≤ Random seed set to {SEED}")

%matplotlib inline

In [None]:
#@title üéß Listen: Why This Matters Concept
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/02_02_why_this_matters_concept.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")


# Sampling Strategies for Text Generation

*Part 2 of the Vizuara series on Inference & Scaling*

*Estimated time: 50 minutes*

## 1. Why Does This Matter?

You have a language model that can generate the next token efficiently (thanks to the KV cache). But the model does not output a single token -- it outputs a **probability distribution** over the entire vocabulary. For a GPT-2-sized vocabulary, that is 50,257 numbers, each telling you how likely that token is to come next.

How do you pick one?

This is not a trivial question. Pick wrong, and your model produces either boring, repetitive text (always choosing the most likely token) or incoherent gibberish (sampling too randomly). The difference between a useful chatbot and a useless one can come down to the sampling strategy.

In this notebook, we will:
1. Implement **greedy decoding** and see why it fails
2. Build **temperature scaling** and understand exactly how it reshapes distributions
3. Implement **top-k sampling** and see its limitations
4. Implement **top-p (nucleus) sampling** and understand why it dominates in practice
5. Compare all strategies head-to-head on real generated text

In [None]:
#@title üéß Code Walkthrough: Setup Imports
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/02_03_setup_imports.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")


In [None]:
# Setup -- run this cell first
!pip install -q torch matplotlib numpy

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

%matplotlib inline

torch.manual_seed(42)
np.random.seed(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

In [None]:
#@title üéß Listen: Building Intuition
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/02_04_building_intuition.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")


## 2. Building Intuition

Imagine you are completing this sentence: "The cat sat on the ___"

Your brain considers many possibilities: "mat", "couch", "floor", "table", "roof", "windowsill"... Some are much more likely than others. "Mat" and "couch" feel natural. "Roof" is unusual but creative. "Refrigerator" is possible but weird. "Quantum" makes no sense at all.

A language model faces this exact choice at every single token. The model's probability distribution might look like:

| Token | Probability |
|-------|------------|
| mat | 0.25 |
| couch | 0.18 |
| floor | 0.15 |
| table | 0.12 |
| bed | 0.08 |
| roof | 0.04 |
| ... | ... |
| quantum | 0.000001 |

**Greedy decoding** always picks "mat." Always. Every single time. This makes the model predictable and repetitive.

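**Pure random sampling** draws from the full distribution, so it will occasionally pick "quantum" -- which is almost certainly wrong. Any single absurd token is rare, but a vocabulary of ~50,000 tokens has a long tail of them, and their combined probability can be large enough that, over a long generation, the model eventually picks something nonsensical.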

The art of sampling is finding the sweet spot: enough randomness for creativity, enough constraint for coherence.
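To make the two extremes concrete, here is a minimal sketch that uses the probabilities from the table above. It assumes the elided rows can be lumped into one hypothetical `<other>` bucket holding the remaining 0.18 of the mass; a real model spreads that mass over tens of thousands of separate tokens.

```python
import numpy as np

# Probabilities copied from the table above. "<other>" is a hypothetical
# bucket standing in for the elided rows: 1 - 0.82 = 0.18 of the mass.
tokens = ["mat", "couch", "floor", "table", "bed", "roof", "<other>"]
probs = np.array([0.25, 0.18, 0.15, 0.12, 0.08, 0.04, 0.18])

rng = np.random.default_rng(0)
n_draws = 1000

# Greedy decoding: the argmax never changes, so every draw is identical
greedy_picks = [tokens[int(np.argmax(probs))] for _ in range(n_draws)]

# Pure sampling: each draw follows the full distribution, tail included
sampled_picks = rng.choice(tokens, size=n_draws, p=probs)

print("Greedy picks:", set(greedy_picks))
values, counts = np.unique(sampled_picks, return_counts=True)
for v, c in zip(values, counts):
    print(f"{v:>8}: {c / n_draws:.3f}")
```

Under this lumping, pure sampling lands outside the six listed tokens roughly 18% of the time, while greedy never leaves "mat".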

In [None]:
#@title üéß Listen: The Mathematics
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/02_05_the_mathematics.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")


## 3. The Mathematics

### Temperature Scaling

Given raw logits $z_v$ for each token $v$ in the vocabulary, temperature scaling modifies the softmax:

$$P(v) = \frac{\exp(z_v / \tau)}{\sum_{v' \in V} \exp(z_{v'} / \tau)}$$

where $\tau$ is the temperature parameter.

**Why does this work?** Dividing by $\tau < 1$ amplifies the differences between logits (making the distribution sharper). Dividing by $\tau > 1$ dampens the differences (making it flatter).

Mathematically, as $\tau \to 0$:
$$P(v) \to \begin{cases} 1 & \text{if } v = \arg\max(z) \\ 0 & \text{otherwise} \end{cases}$$

This is greedy decoding. As $\tau \to \infty$, all probabilities become equal: $P(v) \to 1/|V|$. This is uniform random sampling.
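Both limits can be checked numerically. The sketch below applies the temperature-scaled softmax from the formula above to four made-up logits; a very small $\tau$ stands in for $\tau \to 0$ and a very large one for $\tau \to \infty$.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([2.0, 1.0, 0.5, -1.0])  # made-up logits, |V| = 4

def softmax_with_temperature(z, tau):
    """P(v) = exp(z_v / tau) / sum_v' exp(z_v' / tau)."""
    return F.softmax(z / tau, dim=-1)

for tau in [0.01, 0.5, 1.0, 2.0, 100.0]:
    p = softmax_with_temperature(logits, tau)
    print(f"tau={tau:>6}: {[round(x, 3) for x in p.tolist()]}")

# Small tau: essentially one-hot at the argmax (greedy decoding)
p_cold = softmax_with_temperature(logits, 0.01)
# Large tau: essentially uniform, 1/|V| = 0.25 for each token
p_hot = softmax_with_temperature(logits, 100.0)
```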

### Top-k Sampling

Let $V_k$ be the set of $k$ tokens with the highest probabilities. Top-k sampling sets:

$$P_{\text{top-k}}(v) = \begin{cases} \frac{P(v)}{\sum_{v' \in V_k} P(v')} & \text{if } v \in V_k \\ 0 & \text{otherwise} \end{cases}$$
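As a worked example, apply the formula with $k = 3$ to the six probabilities listed in the Section 2 table. Assuming, as the table suggests, that every elided token is less probable than "floor", the tail does not affect the result: the survivors are renormalized by $0.25 + 0.18 + 0.15 = 0.58$. A minimal sketch:

```python
import torch

# mat, couch, floor, table, bed, roof (from the Section 2 table)
probs = torch.tensor([0.25, 0.18, 0.15, 0.12, 0.08, 0.04])

def top_k_renormalize(p, k):
    """Zero out everything outside V_k, then divide by the mass of V_k."""
    top_vals, top_idx = torch.topk(p, k)
    filtered = torch.zeros_like(p)
    filtered[top_idx] = top_vals
    return filtered / filtered.sum()

p_k = top_k_renormalize(probs, k=3)
print([round(x, 3) for x in p_k.tolist()])  # mat, couch, floor get all the mass
```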

### Top-p (Nucleus) Sampling

Let $V_p$ be the smallest set of tokens such that $\sum_{v \in V_p} P(v) \geq p$. Top-p sampling sets:

$$P_{\text{top-p}}(v) = \begin{cases} \frac{P(v)}{\sum_{v' \in V_p} P(v')} & \text{if } v \in V_p \\ 0 & \text{otherwise} \end{cases}$$

The key difference: $|V_k| = k$ is fixed, but $|V_p|$ varies from step to step. When the model is confident, $|V_p|$ is small; when it is uncertain, $|V_p|$ is large. This adaptivity is a large part of why top-p is so widely used in practice.
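A small sketch of that adaptivity, computing $|V_p|$ directly from its definition on two made-up distributions (values chosen for illustration only):

```python
import torch

def nucleus_size(probs, p):
    """|V_p|: number of top-ranked tokens needed for their total probability to reach p."""
    sorted_probs, _ = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # Tokens strictly below the threshold, plus the one that crosses it.
    # min() guards against float rounding leaving the total just under p.
    return min(int((cumulative < p).sum().item()) + 1, probs.numel())

# Made-up distributions for illustration
confident = torch.tensor([0.95, 0.03, 0.01, 0.01])
uncertain = torch.tensor([0.30, 0.25, 0.20, 0.10, 0.08, 0.07])

for name, dist in [("confident", confident), ("uncertain", uncertain)]:
    print(f"{name}: top-k (k=3) keeps 3 tokens, top-p (p=0.9) keeps {nucleus_size(dist, 0.9)}")
```

Top-k keeps three tokens in both cases; top-p keeps one when the model is confident and five when it is not.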

In [None]:
#@title üéß Transition: Lets Build It
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/02_06_lets_build_it.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")


## 4. Let's Build It -- Component by Component

### 4.1 Implementing Each Sampling Strategy

In [None]:
#@title üéß Code Walkthrough: Implementing Strategies
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/02_07_implementing_strategies.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")


In [None]:
def greedy_decode(logits):
    """Always pick the most probable token."""
    return torch.argmax(logits, dim=-1, keepdim=True)


def temperature_sample(logits, temperature=1.0):
    """Sample with temperature scaling."""
    scaled_logits = logits / temperature
    probs = F.softmax(scaled_logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)


def top_k_sample(logits, k=10, temperature=1.0):
    """Sample from the top-k most probable tokens."""
    scaled_logits = logits / temperature

    # Find the k-th largest value
    top_k_values, _ = torch.topk(scaled_logits, k, dim=-1)
    min_top_k = top_k_values[:, -1:]  # The k-th largest value

    # Mask everything below the threshold
    filtered_logits = scaled_logits.clone()
    filtered_logits[filtered_logits < min_top_k] = float('-inf')

    probs = F.softmax(filtered_logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)


def top_p_sample(logits, p=0.9, temperature=1.0):
    """Sample from the nucleus (smallest set of tokens with cumulative prob >= p)."""
    scaled_logits = logits / temperature

    # Sort in descending order
    sorted_logits, sorted_indices = torch.sort(scaled_logits, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

    # Remove tokens with cumulative probability above the threshold
    # Shift right so we keep the first token that crosses the threshold
    sorted_mask = cumulative_probs - F.softmax(sorted_logits, dim=-1) >= p
    sorted_logits[sorted_mask] = float('-inf')

    # Sample from filtered distribution
    probs = F.softmax(sorted_logits, dim=-1)
    sampled_index = torch.multinomial(probs, num_samples=1)

    # Map back to original indices
    original_index = sorted_indices.gather(-1, sampled_index)
    return original_index


# Quick test with synthetic logits
test_logits = torch.tensor([[2.0, 1.5, 1.0, 0.5, 0.1, -0.5, -1.0, -2.0, -3.0, -5.0]])

print("Greedy:", greedy_decode(test_logits).item())
print("Temperature=0.5:", temperature_sample(test_logits, temperature=0.5).item())
print("Temperature=1.0:", temperature_sample(test_logits, temperature=1.0).item())
print("Temperature=2.0:", temperature_sample(test_logits, temperature=2.0).item())
print("Top-k (k=3):", top_k_sample(test_logits, k=3).item())
print("Top-p (p=0.9):", top_p_sample(test_logits, p=0.9).item())
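The masking line in `top_p_sample` compares an *exclusive* cumulative sum (`cumulative_probs - probs`, the mass of the tokens ranked above each position) against `p`. A token is dropped only if the tokens ahead of it already reach `p`, so the token that crosses the threshold survives. A standalone check on made-up probabilities:

```python
import torch

sorted_probs = torch.tensor([0.5, 0.3, 0.15, 0.05])  # already sorted, descending
p = 0.7

inclusive = torch.cumsum(sorted_probs, dim=-1)  # [0.5, 0.8, 0.95, 1.0]
exclusive = inclusive - sorted_probs            # [0.0, 0.5, 0.8, 0.95]

naive_mask = inclusive >= p    # drops index 1 too, leaving only 0.5 < p
shifted_mask = exclusive >= p  # drops from index 2 on, keeping 0.5 + 0.3 >= p

naive_keep = (~naive_mask).nonzero().flatten().tolist()
shifted_keep = (~shifted_mask).nonzero().flatten().tolist()
print("naive keeps:  ", naive_keep)
print("shifted keeps:", shifted_keep)
```

Without the shift, the kept set would fall short of the definition of $V_p$, whose total mass must be at least $p$.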

In [None]:
#@title üéß What to Look For: Viz Temp Intro
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/02_08_viz_temp_intro.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")


### Visualization Checkpoint 1: How Temperature Reshapes the Distribution

In [None]:
#@title üéß What to Look For: Viz Temp Explain
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/02_09_viz_temp_explain.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")


In [None]:
# Visualize temperature effects on a real distribution
logits = torch.tensor([3.0, 2.5, 2.0, 1.5, 1.0, 0.5, 0.0, -0.5, -1.0, -2.0])
tokens = [f"tok_{i}" for i in range(len(logits))]
temperatures = [0.3, 0.5, 1.0, 2.0, 5.0]

fig, axes = plt.subplots(1, len(temperatures), figsize=(20, 4), sharey=True)

for ax, temp in zip(axes, temperatures):
    probs = F.softmax(logits / temp, dim=-1).numpy()
    colors = plt.cm.RdYlBu_r(np.linspace(0.2, 0.8, len(probs)))
    ax.bar(range(len(probs)), probs, color=colors, edgecolor='white', linewidth=0.5)
    ax.set_title(f'T = {temp}', fontsize=12, fontweight='bold')
    ax.set_xlabel('Token index')
    if ax == axes[0]:
        ax.set_ylabel('Probability')
    ax.set_ylim(0, 1.0)
    # Show entropy
    entropy = -np.sum(probs * np.log(probs + 1e-10))
    ax.text(0.95, 0.95, f'H={entropy:.2f}', transform=ax.transAxes,
            ha='right', va='top', fontsize=9, style='italic')

plt.suptitle('Temperature Scaling: From Greedy (low T) to Uniform (high T)', fontsize=14, y=1.02)
plt.tight_layout()
plt.show()
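The `H=` labels in the plot above are entropies in nats. A short sketch, using the same toy logits, suggests what to expect: entropy rises steadily with temperature and approaches the uniform-distribution bound $\ln|V| = \ln 10 \approx 2.303$.

```python
import math
import torch
import torch.nn.functional as F

# Same toy logits as the plot above
logits = torch.tensor([3.0, 2.5, 2.0, 1.5, 1.0, 0.5, 0.0, -0.5, -1.0, -2.0])

def entropy(probs):
    """Shannon entropy in nats: H = -sum_v p(v) log p(v)."""
    return float(-(probs * torch.log(probs.clamp_min(1e-12))).sum())

temps = [0.1, 0.3, 0.5, 1.0, 2.0, 5.0, 50.0]
entropies = [entropy(F.softmax(logits / t, dim=-1)) for t in temps]
for t, h in zip(temps, entropies):
    print(f"T={t:>5}: H={h:.3f}")
print(f"Upper bound ln|V| = {math.log(len(logits)):.3f}")
```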

In [None]:
#@title üéß Listen: Topk Vs Topp Intro
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/02_10_topk_vs_topp_intro.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")


### 4.2 Top-k vs Top-p: The Adaptivity Advantage

In [None]:
#@title üéß What to Look For: Topk Vs Topp Explain
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/02_11_topk_vs_topp_explain.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")


In [None]:
# Demonstrate the key difference between top-k and top-p

# Case 1: Model is CONFIDENT (one token dominates)
confident_logits = torch.tensor([[5.0, 1.0, 0.5, 0.1, -0.5, -1.0, -2.0, -3.0, -4.0, -5.0]])
confident_probs = F.softmax(confident_logits, dim=-1).squeeze()

# Case 2: Model is UNCERTAIN (many tokens are likely)
uncertain_logits = torch.tensor([[1.0, 0.9, 0.85, 0.8, 0.7, 0.6, 0.5, 0.3, 0.1, -0.1]])
uncertain_probs = F.softmax(uncertain_logits, dim=-1).squeeze()

fig, axes = plt.subplots(2, 3, figsize=(16, 9))

for row, (probs, title) in enumerate([(confident_probs, 'Confident'), (uncertain_probs, 'Uncertain')]):
    # Original distribution
    axes[row, 0].bar(range(10), probs.numpy(), color='steelblue', alpha=0.8)
    axes[row, 0].set_title(f'{title}: Original Distribution')
    axes[row, 0].set_ylim(0, max(probs.numpy()) * 1.2)

    # Top-k (k=3) -- always selects exactly 3
    top_k_mask = torch.zeros_like(probs, dtype=torch.bool)
    _, top_k_idx = torch.topk(probs, 3)
    top_k_mask[top_k_idx] = True
    top_k_probs = probs.clone()
    top_k_probs[~top_k_mask] = 0
    top_k_probs = top_k_probs / top_k_probs.sum()

    colors_k = ['#e74c3c' if m else '#ddd' for m in top_k_mask]
    axes[row, 1].bar(range(10), top_k_probs.numpy(), color=colors_k, edgecolor='white')
    axes[row, 1].set_title(f'{title}: Top-k (k=3) -- {top_k_mask.sum().item()} tokens')
    axes[row, 1].set_ylim(0, max(top_k_probs.numpy()) * 1.2)

    # Top-p (p=0.9) -- adapts!
    sorted_probs, sorted_idx = torch.sort(probs, descending=True)
    cumsum = torch.cumsum(sorted_probs, dim=0)
    # Find cutoff: first index where cumsum >= 0.9
    n_included = (cumsum < 0.9).sum().item() + 1
    top_p_mask = torch.zeros_like(probs, dtype=torch.bool)
    top_p_mask[sorted_idx[:n_included]] = True
    top_p_probs = probs.clone()
    top_p_probs[~top_p_mask] = 0
    top_p_probs = top_p_probs / top_p_probs.sum()

    colors_p = ['#2ecc71' if m else '#ddd' for m in top_p_mask]
    axes[row, 2].bar(range(10), top_p_probs.numpy(), color=colors_p, edgecolor='white')
    axes[row, 2].set_title(f'{title}: Top-p (p=0.9) -- {top_p_mask.sum().item()} tokens')
    axes[row, 2].set_ylim(0, max(top_p_probs.numpy()) * 1.2)

for ax in axes.flat:
    ax.set_xlabel('Token index')
    ax.set_ylabel('Probability')

plt.suptitle('Top-k uses fixed count; Top-p adapts to model confidence', fontsize=14, y=1.02)
plt.tight_layout()
plt.show()

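# Recompute per case: `cumsum` left over from the loop above belongs to the last (uncertain) case only
for probs, title in [(confident_probs, 'Confident'), (uncertain_probs, 'Uncertain')]:
    cs = torch.cumsum(torch.sort(probs, descending=True).values, dim=0)
    print(f"{title} model: top-k selects 3 tokens, top-p selects {(cs < 0.9).sum().item() + 1} tokens")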

In [None]:
#@title üéß Before You Start: Todo1 Intro
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/02_12_todo1_intro.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")


## 5. Your Turn

### TODO 1: Implement Combined Sampling

In practice, top-p and temperature are used **together**. Implement a function that applies temperature first, then top-p filtering, then samples.

In [None]:
#@title üéß Before You Start: Todo1 Impl
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/02_13_todo1_impl.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")


In [None]:
# TODO: Implement combined temperature + top-p sampling
def sample_combined(logits, temperature=0.8, top_p=0.9):
    """
    Apply temperature scaling, then top-p filtering, then sample.

    Steps:
    1. Divide logits by temperature
    2. Sort by probability (descending)
    3. Compute cumulative probabilities
    4. Mask tokens beyond the top-p threshold
    5. Sample from the filtered distribution

    Args:
        logits: (batch_size, vocab_size) raw logits
        temperature: temperature parameter (lower = more deterministic)
        top_p: nucleus probability threshold
    Returns:
        next_token: (batch_size, 1) sampled token indices
    """
    # YOUR CODE HERE
    pass


# Test your implementation
# test_logits = torch.tensor([[3.0, 2.0, 1.5, 1.0, 0.5, 0.0, -1.0, -2.0]])
# for _ in range(10):
#     token = sample_combined(test_logits, temperature=0.8, top_p=0.9)
#     print(f"Sampled token: {token.item()}")

In [None]:
#@title üéß Before You Start: Todo2 Intro
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/02_14_todo2_intro.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")


### TODO 2: Measure Diversity vs. Quality

Using each strategy, draw 50 samples from the same logits. Measure the **diversity** (number of unique tokens sampled) and visualize how often each token is chosen.

In [None]:
#@title üéß Before You Start: Todo2 Impl
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/02_15_todo2_impl.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")


In [None]:
# TODO: Generate multiple completions with different strategies
# and compare their diversity.
#
# Strategies to compare:
# 1. Greedy (baseline -- always the same token, no diversity)
# 2. Temperature = 0.3 (low diversity)
# 3. Temperature = 1.0 (medium diversity)
# 4. Top-p = 0.9, Temperature = 0.8 (production setting)
#
# For each strategy:
# - Generate 50 single-token samples from the same logits
# - Count the number of unique tokens sampled
# - Plot a histogram of which tokens were selected
#
# YOUR CODE HERE
# test_logits = torch.tensor([[3.0, 2.5, 2.0, 1.5, 1.0, 0.5, 0.0, -0.5, -1.0, -2.0]])
# strategies = {
#     "Greedy": lambda l: greedy_decode(l),
#     "Temp=0.3": lambda l: temperature_sample(l, 0.3),
#     "Temp=1.0": lambda l: temperature_sample(l, 1.0),
#     "Top-p=0.9": lambda l: top_p_sample(l, 0.9, 0.8),
# }
#
# n_samples = 50
# fig, axes = plt.subplots(1, 4, figsize=(20, 4), sharey=True)
# ...

In [None]:
#@title üéß What to Look For: Viz Repetition Intro
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/02_16_viz_repetition_intro.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")


### Visualization Checkpoint 2: How Sampling Strategies Affect Repetition

In [None]:
#@title üéß What to Look For: Viz Repetition Explain
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/02_17_viz_repetition_explain.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")


In [None]:
# Show how different strategies affect repetition.
# We simulate a fixed sequence of per-step logits and show how
# greedy always picks the same token while the sampling strategies explore.

logits_sequence = [
    torch.tensor([[3.0, 2.5, 2.0, 1.0, 0.5]]),
    torch.tensor([[2.8, 2.6, 2.1, 1.2, 0.3]]),
    torch.tensor([[3.2, 2.4, 1.9, 1.1, 0.4]]),
    torch.tensor([[2.9, 2.7, 2.0, 0.9, 0.5]]),
    torch.tensor([[3.1, 2.3, 2.2, 1.3, 0.1]]),
] * 4  # Repeat to get 20 steps

strategies = {
    "Greedy": lambda l: greedy_decode(l).item(),
    "Temp=0.5": lambda l: temperature_sample(l, 0.5).item(),
    "Temp=1.0": lambda l: temperature_sample(l, 1.0).item(),
    "Top-p=0.9": lambda l: top_p_sample(l, 0.9, 0.8).item(),
}

fig, axes = plt.subplots(1, 4, figsize=(18, 4), sharey=True)

for ax, (name, strategy) in zip(axes, strategies.items()):
    tokens = [strategy(l) for l in logits_sequence]
    ax.plot(tokens, 'o-', markersize=6, linewidth=1.5)
    ax.set_title(f'{name}\n{len(set(tokens))} unique out of {len(tokens)}')
    ax.set_xlabel('Step')
    ax.set_ylabel('Token ID')
    ax.set_ylim(-0.5, 4.5)
    ax.set_yticks(range(5))

plt.suptitle('Greedy is Repetitive; Sampling Explores', fontsize=14, y=1.02)
plt.tight_layout()
plt.show()

In [None]:
#@title üéß Transition: Putting It Together Intro
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/02_18_putting_it_together_intro.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")


## 6. Putting It All Together

Let us train a tiny character-level model and see how each sampling strategy produces different text.

In [None]:
#@title üéß Code Walkthrough: Model Training
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/02_19_model_training.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")


In [None]:
# Build a small model (reuse the architecture from Notebook 1 or define inline)
class TinyTransformer(nn.Module):
    def __init__(self, vocab_size, d_model=64, n_heads=4, n_layers=2, max_len=256):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_len, d_model)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, dim_feedforward=4*d_model,
            dropout=0.1, activation='gelu', batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        self.head = nn.Linear(d_model, vocab_size)
        self.d_model = d_model

    def forward(self, x):
        B, T = x.shape
        pos = torch.arange(T, device=x.device)
        mask = nn.Transformer.generate_square_subsequent_mask(T, device=x.device)
        h = self.tok_emb(x) + self.pos_emb(pos)
        h = self.transformer(h, mask=mask, is_causal=True)
        return self.head(h)


# Prepare a small text dataset
text = ("To be or not to be that is the question. "
        "Whether tis nobler in the mind to suffer the slings and arrows "
        "of outrageous fortune or to take arms against a sea of troubles "
        "and by opposing end them. To die to sleep no more and by a sleep "
        "to say we end the heartache. ") * 10

chars = sorted(set(text))
c2i = {c: i for i, c in enumerate(chars)}
i2c = {i: c for c, i in c2i.items()}
data = torch.tensor([c2i[c] for c in text], device=device)

# Train
model = TinyTransformer(len(chars)).to(device)
opt = torch.optim.Adam(model.parameters(), lr=3e-3)
seq_len = 48

model.train()
for epoch in range(150):
    total_loss, n = 0, 0
    for i in range(0, len(data) - seq_len - 1, seq_len):
        x = data[i:i+seq_len].unsqueeze(0)
        y = data[i+1:i+seq_len+1].unsqueeze(0)
        logits = model(x)
        loss = F.cross_entropy(logits.view(-1, len(chars)), y.view(-1))
        opt.zero_grad()
        loss.backward()
        opt.step()
        total_loss += loss.item()
        n += 1
    if (epoch + 1) % 50 == 0:
        print(f"Epoch {epoch+1}: loss = {total_loss/n:.4f}")

print("Training complete!")

In [None]:
#@title üéß Transition: Training Results Intro
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/02_20_training_results_intro.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")


## 7. Training and Results

In [None]:
#@title üéß Code Walkthrough: Generate Text
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/02_21_generate_text.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")


In [None]:
# Generate text with different sampling strategies
model.eval()

prompt_text = "To be or "
prompt_ids = torch.tensor([[c2i[c] for c in prompt_text]], device=device)

@torch.no_grad()  # inference only: no need to build autograd graphs
def generate_text(model, prompt, n_tokens, sample_fn):
    """Generate text using a given sampling function."""
    tokens = prompt.clone()
    for _ in range(n_tokens):
        logits = model(tokens)
        next_logits = logits[:, -1, :]
        next_token = sample_fn(next_logits)
        tokens = torch.cat([tokens, next_token], dim=1)
    return ''.join([i2c[t.item()] for t in tokens[0]])

print("=" * 60)
strategies = [
    ("Greedy", lambda l: greedy_decode(l)),
    ("Temp=0.3", lambda l: temperature_sample(l, 0.3)),
    ("Temp=0.8", lambda l: temperature_sample(l, 0.8)),
    ("Temp=1.5", lambda l: temperature_sample(l, 1.5)),
    ("Top-k=5", lambda l: top_k_sample(l, k=5, temperature=0.8)),
    ("Top-p=0.9", lambda l: top_p_sample(l, p=0.9, temperature=0.8)),
]

for name, fn in strategies:
    text_out = generate_text(model, prompt_ids, 80, fn)
    print(f"\n{name}:")
    print(f"  {text_out}")

print("\n" + "=" * 60)

In [None]:
#@title üéß What to Look For: Final Output Intro
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/02_22_final_output_intro.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")


## 8. Final Output

In [None]:
#@title üéß What to Look For: Final Output Viz
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/02_23_final_output_viz.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")


In [None]:
# Summary visualization: token diversity across strategies
model.eval()

n_runs = 30
gen_len = 40

diversity_data = {}
for name, fn in strategies:
    all_tokens = []
    with torch.no_grad():  # inference only: no need to build autograd graphs
        for _ in range(n_runs):
            tokens = prompt_ids.clone()
            for _ in range(gen_len):
                logits = model(tokens)
                next_logits = logits[:, -1, :]
                next_token = fn(next_logits)
                tokens = torch.cat([tokens, next_token], dim=1)
            generated = tokens[0, prompt_ids.shape[1]:].tolist()
            all_tokens.append(tuple(generated))

    unique_sequences = len(set(all_tokens))
    avg_unique_tokens = np.mean([len(set(seq)) for seq in all_tokens])
    diversity_data[name] = {
        'unique_seqs': unique_sequences,
        'avg_unique_tokens': avg_unique_tokens,
    }

# Plot
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

names = list(diversity_data.keys())
unique_seqs = [diversity_data[n]['unique_seqs'] for n in names]
unique_toks = [diversity_data[n]['avg_unique_tokens'] for n in names]

ax1.barh(names, unique_seqs, color='steelblue', alpha=0.8)
ax1.set_xlabel(f'Unique Sequences (out of {n_runs} runs)')
ax1.set_title('Sequence Diversity')

ax2.barh(names, unique_toks, color='coral', alpha=0.8)
ax2.set_xlabel(f'Avg Unique Tokens per Sequence (out of {gen_len})')
ax2.set_title('Token Diversity')

plt.suptitle('Sampling Strategy Comparison', fontsize=14, y=1.02)
plt.tight_layout()
plt.show()

In [None]:
#@title üéß Wrap-Up: Reflection And Next Steps
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/02_24_reflection_and_next_steps.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")


## 9. Reflection and Next Steps

### What We Learned

1. **Greedy decoding** is deterministic and repetitive. It always picks the most probable token, leading to degenerate loops in longer sequences.

2. **Temperature** controls the sharpness of the distribution. Low temperature approaches greedy decoding; high temperature approaches uniform random sampling. It is a single knob for the diversity-quality tradeoff.

3. **Top-k sampling** restricts sampling to a fixed number of tokens. This works but is inflexible -- the optimal $k$ depends on how confident the model is at each step.

4. **Top-p (nucleus) sampling** adapts the candidate set dynamically based on the model's confidence. When the model is sure, few tokens are considered; when it is uncertain, more options remain available. This adaptivity makes it a common default in production systems.

5. In practice, **temperature + top-p** are used together: temperature controls overall randomness, and top-p prevents sampling from the dangerous tail.

### Production Defaults

| System | Temperature | Top-p | Notes |
|--------|------------|-------|-------|
| OpenAI API | 1.0 | 1.0 | Adjustable per request |
| Claude | ~0.7 | ~0.9 | Tuned internally |
| LLaMA | 0.6 | 0.9 | Recommended defaults |
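With temperature 1.0 and top-p 1.0 (the OpenAI API defaults above), both knobs are effectively neutral: dividing by 1.0 leaves the logits unchanged, and a nucleus with p = 1.0 is the whole vocabulary. The sketch below is one possible way to express that; the function name `nucleus_distribution` is hypothetical, and skipping the filter at p >= 1.0 is a design choice of this sketch that sidesteps float rounding in the cumulative sum.

```python
import torch
import torch.nn.functional as F

def nucleus_distribution(logits, temperature=1.0, top_p=1.0):
    """Temperature-scaled, top-p-filtered distribution (hypothetical helper, no sampling)."""
    probs = F.softmax(logits / temperature, dim=-1)
    if top_p >= 1.0:
        # p = 1.0 keeps every token; skipping the filter avoids float rounding
        # in the cumulative sum accidentally dropping tiny-probability tokens
        return probs
    sorted_probs, sorted_idx = torch.sort(probs, descending=True, dim=-1)
    exclusive_cum = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
    sorted_probs = sorted_probs.masked_fill(exclusive_cum >= top_p, 0.0)
    # Scatter the filtered probabilities back to their original vocabulary positions
    filtered = torch.zeros_like(probs).scatter(-1, sorted_idx, sorted_probs)
    return filtered / filtered.sum(dim=-1, keepdim=True)

gen = torch.Generator().manual_seed(0)
logits = torch.randn(1, 50, generator=gen)  # made-up logits, |V| = 50
plain = F.softmax(logits, dim=-1)
defaults = nucleus_distribution(logits, temperature=1.0, top_p=1.0)
tuned = nucleus_distribution(logits, temperature=0.6, top_p=0.9)

print("defaults == plain softmax:", torch.allclose(defaults, plain))
print("tokens kept with T=0.6, p=0.9:", int((tuned > 0).sum()))
```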

### What is Next

In the next notebook, we will combine the KV cache (Notebook 1) with sampling strategies (this notebook) into a **complete inference engine** and benchmark it end-to-end.