In [None]:
#@title 🎧 Download Narration Audio & Play Introduction
import os as _os
if not _os.path.exists("/content/narration"):
    !pip install -q gdown
    import gdown
    gdown.download(id="1rhbO-3WtvU0YjQYQrCebpE_IATSZyRCX", output="/content/narration.zip", quiet=False)
    !unzip -q /content/narration.zip -d /content/narration
    !rm /content/narration.zip
    print(f"Loaded {len(_os.listdir('/content/narration'))} narration segments")
else:
    print("Narration audio already loaded.")

from IPython.display import Audio, display
display(Audio("/content/narration/00_intro.mp3"))

In [None]:
# 🔧 Setup: Run this cell first!
# Check GPU availability and install dependencies

import torch
import sys

# Check GPU
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"✅ GPU available: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    device = torch.device('cpu')
    print("⚠️ No GPU detected. Some cells may run slowly.")
    print("   Go to Runtime → Change runtime type → GPU")

print(f"\n📦 Python {sys.version.split()[0]}")
print(f"🔥 PyTorch {torch.__version__}")

# Set random seeds for reproducibility
import random
import numpy as np

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print(f"🎲 Random seed set to {SEED}")

%matplotlib inline

In [None]:
#@title 🎧 Listen: Motivation
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/01_motivation.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

# Sampling Strategies & Speed Tradeoffs in Diffusion LLMs

*Part 3 of the Vizuara series on Diffusion Language Models*
*Estimated time: 35 minutes*

# 🤖 AI Teaching Assistant

Need help with this notebook? Open the **AI Teaching Assistant**. It has already read this entire notebook and can help with concepts, code, and exercises.

**[👉 Open AI Teaching Assistant](https://pods.vizuara.ai/courses/diffusion-llms/practice/3/assistant)**

*Tip: Open it in a separate tab and work through this notebook side-by-side.*


## 1. Why Does This Matter?

Autoregressive models have a hard speed limit: to generate $L$ tokens, you need $L$ sequential forward passes. Period.

Diffusion models break this limit. You choose the number of denoising steps $S$, and each step processes *all* tokens in parallel. But this raises critical questions:

- **How many steps do we actually need?** Is 5 enough? Do we need 50?
- **Does the unmasking order matter?** Should we unmask the most confident tokens first, or use random order?
- **How much faster is diffusion?** Can we quantify the speedup?

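In this notebook, we will train a diffusion model and then systematically explore these tradeoffs. The punchline, which the experiments below let you check for yourself: quality tends to plateau early, so as a rough rule of thumb you can get about 80% of the quality with about 20% of the steps.

**Teaser: what you will build**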

A quality-vs-speed curve showing how generation quality degrades gracefully as we reduce the number of steps, and a side-by-side comparison of different remasking strategies.

In [None]:
# 🔧 Setup
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import math
import time

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

torch.manual_seed(42)
np.random.seed(42)

%matplotlib inline

In [None]:
#@title 🎧 Listen: Intuition
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/02_intuition.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 2. Building Intuition

### The Speed-Quality Tradeoff

Think of painting a wall. If you make one careful pass with the roller, you get good coverage. Two passes give you even coverage. By the third pass, the wall looks great.

But what if you only have time for **half a pass**? You would focus on the most visible areas first: the spots at eye level, near the door. You would skip the corners and behind the furniture.

This is how confidence-based remasking works. With limited steps, the model commits to the "easy" tokens first (function words, common patterns) and resolves the harder tokens (content words, rare combinations) in later steps, once more context is available. If you give it fewer steps, it still gets the easy tokens right but may stumble on the hard ones.

### 🤔 Think About This

With $S = 1$ step (single-shot generation), the diffusion model predicts all tokens independently and commits to all of them at once. With $S = L$ steps (one per token), it unmasks one token at a time, always conditioning on everything unmasked so far. Steps beyond $L$ cannot help, because there are no tokens left to unmask.

- Which extreme gives better quality?
- Which extreme is faster?
- Where is the sweet spot?
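
One way to see why the $S = 1$ extreme can hurt is a toy calculation. The sketch below assumes a made-up data distribution with only two valid sequences, `ABAB` and `BABA`, each equally likely. This distribution and the helper name `coherent_fraction` are illustrative and not part of the notebook. Every position is then 50/50 between `A` and `B`, so sampling each position independently from its marginal rarely lands on a valid sequence.

```python
from itertools import product

# Hypothetical toy distribution: the only valid sequences are ABAB and BABA.
VALID = {"ABAB", "BABA"}

def coherent_fraction(seq_len=4, alphabet="AB"):
    """Fraction of independently sampled sequences that are valid.

    Each position is uniform over the alphabet (its marginal under VALID),
    so every combination is equally likely and we can simply enumerate them.
    """
    combos = ["".join(c) for c in product(alphabet, repeat=seq_len)]
    return sum(c in VALID for c in combos) / len(combos)

print(coherent_fraction())  # 2 valid sequences out of 16 combinations
```

Committing to one token first and re-predicting the rest with that context removes the ambiguity, which is the intuition behind using more than one step.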

In [None]:
#@title 🎧 Listen: Math
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/03_math.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 3. The Mathematics

### Remasking Strategies

At each denoising step, the model predicts a token at every masked position. The **remasking strategy** decides which predictions to keep and which positions to mask again:

**1. Confidence-based remasking:**
Keep the top $k$ predictions ranked by confidence $p_\theta(x^i \mid x_t)$. With $s$ denoting the number of steps remaining (counting down from $S$ to 1):

$$k = \left\lfloor \frac{N_{\text{masked}}}{s} \right\rfloor$$

**Numerical example:** If 12 tokens are masked and $s = 4$ steps remain, we unmask $\lfloor 12/4 \rfloor = 3$ tokens (the 3 most confident predictions). The other 9 are remasked and revisited in later steps. On the final step ($s = 1$), $k = N_{\text{masked}}$, so every remaining position is filled.

**2. Random remasking:**
Keep $k$ predictions chosen uniformly at random (ignoring confidence). Same schedule for $k$.

**3. Linear schedule:**
Unmask a fixed number of tokens at each step: $k = N_{\text{total}} / S$ tokens per step, regardless of confidence. The countdown rule above works out to roughly the same rate, since $N_{\text{masked}} / s \approx N_{\text{total}} / S$ at every step. Section 4.2 implements strategies 1 and 2, plus a single-shot baseline ($S = 1$).

**What this means computationally:** Confidence-based is like a student answering exam questions easiest-first. Random is like answering in random order. Linear is like answering one per minute, regardless of difficulty.
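
To make the schedule concrete, here is a small sketch that tabulates how many tokens get unmasked at each step. The helper name `unmask_schedule` is hypothetical. The `max(1, ...)` guard mirrors the generation code in Section 4.2 and is an implementation detail rather than part of the formula.

```python
def unmask_schedule(n_tokens, n_steps):
    """Tokens unmasked per step under k = max(1, floor(N_masked / s))."""
    remaining = n_tokens
    schedule = []
    for s in range(n_steps, 0, -1):      # s = steps remaining
        if remaining == 0:
            schedule.append(0)           # nothing left to unmask
            continue
        k = min(remaining, max(1, remaining // s))
        schedule.append(k)
        remaining -= k
    return schedule

print(unmask_schedule(12, 4))    # [3, 3, 3, 3]
print(unmask_schedule(24, 10))   # [2, 2, 2, 2, 2, 2, 3, 3, 3, 3]
print(unmask_schedule(24, 32))   # 24 ones followed by 8 zeros
```

The last call suggests that, with this guard, asking for more steps than tokens spends the extra forward passes on an already complete sequence.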

In [None]:
#@title 🎧 Listen: Build Model
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/04_build_model.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 4. Let's Build It: Component by Component

### 4.1 Model and Data (Reused from Notebook 1)

In [None]:
VOCAB_SIZE = 16
SEQ_LEN = 24
MASK_TOKEN = 0
D_MODEL = 64
N_HEADS = 4
N_LAYERS = 3
BATCH_SIZE = 64

def generate_pattern_data(batch_size, seq_len, vocab_size):
    """Generate sequences that tile a random pattern of length 2-5.

    Tokens are drawn from 1..vocab_size-1; token 0 is reserved for MASK_TOKEN.
    """
    sequences = []
    for _ in range(batch_size):
        pattern_len = np.random.randint(2, 6)
        pattern = np.random.randint(1, vocab_size, size=pattern_len)
        seq = np.tile(pattern, seq_len // pattern_len + 1)[:seq_len]
        sequences.append(seq)
    return torch.tensor(np.array(sequences), dtype=torch.long, device=device)


class DiffusionLM(nn.Module):
    def __init__(self, vocab_size, d_model, n_heads, n_layers):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Embedding(SEQ_LEN, d_model)
        self.time_mlp = nn.Sequential(
            nn.Linear(1, d_model), nn.SiLU(), nn.Linear(d_model, d_model)
        )
        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, dim_feedforward=d_model*4,
            dropout=0.1, batch_first=True, norm_first=True
        )
        self.transformer = nn.TransformerEncoder(enc_layer, n_layers)
        self.head = nn.Linear(d_model, vocab_size)

    def forward(self, x_t, t):
        positions = torch.arange(x_t.size(1), device=x_t.device)
        h = self.embed(x_t) + self.pos_embed(positions)
        h = h + self.time_mlp(t).unsqueeze(1)
        h = self.transformer(h)
        return self.head(h)


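def mask_tokens(x_0, t):
    """Replace each token with MASK_TOKEN independently with probability t.

    t has shape (batch, 1) and broadcasts across sequence positions.
    """
    mask = torch.rand_like(x_0.float()) < t
    x_t = x_0.clone()
    x_t[mask] = MASK_TOKEN
    return x_t, mask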


# Train the model
model = DiffusionLM(VOCAB_SIZE, D_MODEL, N_HEADS, N_LAYERS).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

print("Training diffusion model...")
for step in range(2500):
    x_0 = generate_pattern_data(BATCH_SIZE, SEQ_LEN, VOCAB_SIZE)
    t = torch.rand(BATCH_SIZE, 1, device=device) * 0.98 + 0.02
    x_t, mask = mask_tokens(x_0, t)
    logits = model(x_t, t)
    if mask.sum() == 0:
        continue
    loss = F.cross_entropy(logits[mask], x_0[mask])
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    if (step+1) % 500 == 0:
        print(f"  Step {step+1}/2500 | Loss: {loss.item():.4f}")

model.eval()
print("Done!")
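
The masking step used in training hides each token independently with probability $t$, so on average a fraction $t$ of positions is masked. The following is a dependency-free sketch of that idea using plain Python lists instead of tensors. `mask_tokens_py` is a hypothetical stand-in for `mask_tokens`, meant only for building intuition.

```python
import random

MASK = 0  # same convention as MASK_TOKEN in the notebook

def mask_tokens_py(x0, t, rng):
    """Mask each token independently with probability t (list-based sketch)."""
    is_masked = [rng.random() < t for _ in x0]
    xt = [MASK if m else tok for tok, m in zip(x0, is_masked)]
    return xt, is_masked

rng = random.Random(0)
x0 = [1, 2, 3] * 8                       # a period-3 sequence of length 24
xt, is_masked = mask_tokens_py(x0, t=0.5, rng=rng)
print(xt)

# Averaged over many draws, the masked fraction approaches t.
n_trials = 5000
n_masked = sum(sum(mask_tokens_py(x0, 0.3, rng)[1]) for _ in range(n_trials))
frac = n_masked / (n_trials * len(x0))
print(f"masked fraction at t=0.3: {frac:.3f}")
```

Because the training loop draws $t$ between 0.02 and 1.0, the model sees everything from nearly clean to fully masked inputs.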

In [None]:
#@title 🎧 Listen: Strategies
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/05_strategies.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### 4.2 Three Remasking Strategies

In [None]:
@torch.no_grad()
def generate_confidence(model, seq_len=SEQ_LEN, n_steps=10):
    """Confidence-based remasking: unmask most confident predictions first."""
    x = torch.full((1, seq_len), MASK_TOKEN, dtype=torch.long, device=device)
    history = [x[0].cpu().clone()]

    for s in range(n_steps, 0, -1):
        t = torch.tensor([[s / n_steps]], device=device, dtype=torch.float)
        logits = model(x, t)
        probs = F.softmax(logits, dim=-1)

        sampled = torch.multinomial(probs.view(-1, VOCAB_SIZE), 1).view(1, -1)
        confidence = probs.gather(-1, sampled.unsqueeze(-1)).squeeze(-1)

        is_masked = (x == MASK_TOKEN)
        n_unmask = max(1, int(is_masked.sum().item() / s))

        conf = confidence.clone()
        conf[~is_masked] = -float('inf')
        _, top_idx = conf.topk(min(n_unmask, is_masked.sum().item()), dim=-1)
        x.scatter_(1, top_idx, sampled.gather(1, top_idx))
        history.append(x[0].cpu().clone())

    return x, history

In [None]:
@torch.no_grad()
def generate_random(model, seq_len=SEQ_LEN, n_steps=10):
    """Random remasking: unmask random masked positions each step."""
    x = torch.full((1, seq_len), MASK_TOKEN, dtype=torch.long, device=device)
    history = [x[0].cpu().clone()]

    for s in range(n_steps, 0, -1):
        t = torch.tensor([[s / n_steps]], device=device, dtype=torch.float)
        logits = model(x, t)
        probs = F.softmax(logits, dim=-1)
        sampled = torch.multinomial(probs.view(-1, VOCAB_SIZE), 1).view(1, -1)

        is_masked = (x == MASK_TOKEN)
        n_unmask = max(1, int(is_masked.sum().item() / s))

        # Randomly select among masked positions
        masked_indices = is_masked[0].nonzero(as_tuple=True)[0]
        if len(masked_indices) > 0:
            perm = torch.randperm(len(masked_indices), device=device)[:n_unmask]
            selected = masked_indices[perm]
            x[0, selected] = sampled[0, selected]

        history.append(x[0].cpu().clone())

    return x, history

In [None]:
@torch.no_grad()
def generate_single_shot(model, seq_len=SEQ_LEN):
    """Single-shot: predict all tokens at once (S=1).

    Uses greedy argmax rather than sampling, so in eval mode repeated calls return the same sequence.
    """
    x = torch.full((1, seq_len), MASK_TOKEN, dtype=torch.long, device=device)
    t = torch.tensor([[1.0]], device=device, dtype=torch.float)
    logits = model(x, t)
    x = logits.argmax(dim=-1)
    return x, [torch.full((seq_len,), MASK_TOKEN), x[0].cpu().clone()]
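
The two multi-step strategies differ only in which masked positions they commit to at each step. Here is a minimal list-based sketch of that selection, with made-up confidence values. The helper names `pick_confident` and `pick_random` are hypothetical and do not appear in the functions above.

```python
import random

def pick_confident(confidence, is_masked, k):
    """Indices of the k most confident masked positions."""
    masked = [i for i, m in enumerate(is_masked) if m]
    ranked = sorted(masked, key=lambda i: confidence[i], reverse=True)
    return sorted(ranked[:k])

def pick_random(is_masked, k, rng):
    """Indices of k masked positions chosen uniformly at random."""
    masked = [i for i, m in enumerate(is_masked) if m]
    return sorted(rng.sample(masked, min(k, len(masked))))

# Made-up confidences for a 5-token sequence; position 4 is already unmasked.
confidence = [0.90, 0.20, 0.70, 0.40, 0.95]
is_masked = [True, True, True, True, False]

print(pick_confident(confidence, is_masked, k=2))         # [0, 2]
print(pick_random(is_masked, k=2, rng=random.Random(0)))  # two of positions 0-3
```

Position 4 has the highest confidence but is never selected because it is already unmasked. The tensor code gets the same effect by setting the confidence of unmasked positions to `-inf` before `topk`.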

In [None]:
#@title 🎧 Listen: Quality Metric
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/06_quality_metric.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### 4.3 Quality Metric

In [None]:
def measure_pattern_quality(sequence):
    """Measure how well a generated sequence follows a repeating pattern.

    Returns a score from 0 to 1, where 1 means a perfect repeating pattern.
    For each candidate period (2-6), every token at index i >= pat_len is
    compared with the token at index i % pat_len, and the best score across
    periods is returned. A constant sequence also scores 1.
    """
    seq = sequence.cpu().numpy() if isinstance(sequence, torch.Tensor) else sequence
    best_score = 0

    for pat_len in range(2, 7):
        # Check if the sequence repeats with this period
        matches = 0
        total = 0
        for i in range(pat_len, len(seq)):
            if seq[i] == seq[i % pat_len]:
                matches += 1
            total += 1
        if total > 0:
            score = matches / total
            best_score = max(best_score, score)

    return best_score
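
A few hand-made sequences help calibrate what this score means. The sketch below restates the same metric on plain Python lists so it runs without a trained model. The name `pattern_quality` and the test sequences are illustrative assumptions.

```python
def pattern_quality(seq, min_len=2, max_len=6):
    """Best fraction of positions i >= p with seq[i] == seq[i % p], over periods p."""
    best = 0.0
    for p in range(min_len, max_len + 1):
        total = len(seq) - p
        if total <= 0:
            continue
        matches = sum(seq[i] == seq[i % p] for i in range(p, len(seq)))
        best = max(best, matches / total)
    return best

clean = [1, 2, 3] * 8      # perfect period-3 pattern, 24 tokens
noisy = list(clean)
noisy[10] = 9              # corrupt one token outside the reference prefix
constant = [5] * 24        # degenerate output: one token repeated

print(pattern_quality(clean))     # 1.0
print(pattern_quality(noisy))     # 20/21, about 0.952
print(pattern_quality(constant))  # 1.0
```

The last case is worth keeping in mind when reading the curves later: a model that collapses to a single repeated token, or leaves every position masked, would also score 1.0 under this metric.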

In [None]:
# 📊 Quick test: generate with each strategy and compare
print("Testing generation strategies...\n")

for name, gen_fn in [("Confidence-based", generate_confidence),
                      ("Random", generate_random),
                      ("Single-shot", generate_single_shot)]:
    scores = []
    for _ in range(20):
        if name == "Single-shot":
            seq, _ = gen_fn(model)
        else:
            seq, _ = gen_fn(model, n_steps=10)
        scores.append(measure_pattern_quality(seq[0]))

    print(f"{name:20s} | Quality: {np.mean(scores):.3f} ± {np.std(scores):.3f}")

In [None]:
#@title 🎧 Listen: Todo
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/07_todo.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 5. 🔧 Your Turn: Run the Step-Count Sweep

### TODO: Measure quality at different step counts

In [None]:
def sweep_step_counts(model, strategy_fn, step_counts, n_samples=30):
    """Measure generation quality at different numbers of denoising steps.

    Args:
        model: Trained DiffusionLM
        strategy_fn: Generation function (generate_confidence or generate_random)
        step_counts: List of step counts to test
        n_samples: Number of sequences to generate per step count

    Returns:
        mean_scores: List of mean quality scores
        std_scores: List of standard deviations
    """
    mean_scores = []
    std_scores = []

    for n_steps in step_counts:
        # ============ TODO ============
        # For each step count:
        # 1. Generate n_samples sequences using strategy_fn(model, n_steps=n_steps)
        # 2. Compute the pattern quality score for each
        # 3. Store the mean and std of scores

        scores = []
        for _ in range(n_samples):
            seq, _ = ???  # YOUR CODE: call strategy_fn
            score = ???   # YOUR CODE: measure quality
            scores.append(score)

        mean_scores.append(???)  # YOUR CODE
        std_scores.append(???)   # YOUR CODE
        # ==============================

        print(f"  Steps={n_steps:3d} | Quality: {mean_scores[-1]:.3f} ± {std_scores[-1]:.3f}")

    return mean_scores, std_scores

# Test
step_counts = [1, 2, 3, 5, 8, 12, 16, 24, 32]
print("Confidence-based remasking:")
# conf_means, conf_stds = sweep_step_counts(model, generate_confidence, step_counts)

In [None]:
# ✅ Verification
# Uncomment and run after completing the TODO above:
# assert len(conf_means) == len(step_counts), "Wrong number of results"
# assert all(0 <= s <= 1 for s in conf_means), "Scores should be between 0 and 1"
# print("✅ Sweep function works!")

In [None]:
#@title 🎧 Listen: Post Todo
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/08_post_todo.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### Solution

In [None]:
def sweep_step_counts(model, strategy_fn, step_counts, n_samples=30):
    """Measure generation quality at different step counts."""
    mean_scores = []
    std_scores = []

    for n_steps in step_counts:
        scores = []
        for _ in range(n_samples):
            seq, _ = strategy_fn(model, n_steps=n_steps)
            score = measure_pattern_quality(seq[0])
            scores.append(score)
        mean_scores.append(np.mean(scores))
        std_scores.append(np.std(scores))
        print(f"  Steps={n_steps:3d} | Quality: {mean_scores[-1]:.3f} ± {std_scores[-1]:.3f}")

    return mean_scores, std_scores


step_counts = [1, 2, 3, 5, 8, 12, 16, 24, 32]

print("Confidence-based remasking:")
conf_means, conf_stds = sweep_step_counts(model, generate_confidence, step_counts)

print("\nRandom remasking:")
rand_means, rand_stds = sweep_step_counts(model, generate_random, step_counts)

In [None]:
#@title 🎧 Listen: Quality Curve
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/09_quality_curve.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 6. Putting It All Together: The Quality vs Speed Curve

In [None]:
# 📊 Quality vs Number of Steps
fig, ax = plt.subplots(figsize=(10, 6))

ax.plot(step_counts, conf_means, 'o-', color='#1565c0', linewidth=2.5,
        markersize=8, label='Confidence-based', zorder=5)
ax.fill_between(step_counts,
                np.array(conf_means) - np.array(conf_stds),
                np.array(conf_means) + np.array(conf_stds),
                color='#1565c0', alpha=0.15)

ax.plot(step_counts, rand_means, 's--', color='#e65100', linewidth=2.5,
        markersize=8, label='Random', zorder=5)
ax.fill_between(step_counts,
                np.array(rand_means) - np.array(rand_stds),
                np.array(rand_means) + np.array(rand_stds),
                color='#e65100', alpha=0.15)

# Mark the sweet spot: the first (smallest) step count whose quality exceeds 90% of the best
best_idx = np.argmax(np.array(conf_means) > 0.9 * max(conf_means))
if best_idx > 0:
    ax.axvline(x=step_counts[best_idx], color='gray', linestyle=':', alpha=0.5)
    ax.annotate(f'Sweet spot\n~{step_counts[best_idx]} steps',
                xy=(step_counts[best_idx], conf_means[best_idx]),
                xytext=(step_counts[best_idx] + 5, conf_means[best_idx] - 0.1),
                fontsize=11, arrowprops=dict(arrowstyle='->', color='gray'))

ax.set_xlabel('Number of Denoising Steps', fontsize=13)
ax.set_ylabel('Pattern Quality Score', fontsize=13)
ax.set_title('Quality vs Speed: How Many Steps Do We Need?', fontsize=15)
ax.legend(fontsize=12)
ax.grid(True, alpha=0.3)
ax.set_ylim(0, 1.05)

plt.tight_layout()
plt.show()
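
The sweet-spot rule in the plot (the smallest step count that exceeds 90% of the best score) can also be written as a standalone helper. This is a sketch under the assumption that `step_counts` is sorted in ascending order. The function name `sweet_spot` and the demo numbers are made up for illustration and are not results from this notebook.

```python
def sweet_spot(step_counts, mean_scores, fraction=0.9):
    """Smallest step count whose score exceeds `fraction` of the best score."""
    threshold = fraction * max(mean_scores)
    for steps, score in zip(step_counts, mean_scores):
        if score > threshold:
            return steps
    return step_counts[-1]  # fallback; not reached when fraction < 1 and the best score is positive

# Made-up scores for illustration only.
demo_steps = [1, 2, 3, 5, 8]
demo_scores = [0.50, 0.70, 0.88, 0.93, 0.95]
print(sweet_spot(demo_steps, demo_scores))  # 3, since 0.88 > 0.9 * 0.95 = 0.855
```

Changing `fraction` makes the cost of a stricter quality target explicit: with these demo numbers, `fraction=0.95` moves the sweet spot from 3 steps to 5.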

In [None]:
#@title 🎧 Listen: Speed
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/10_speed.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### Speed Benchmark

In [None]:
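def benchmark_speed(model, step_counts, n_runs=20):
    """Measure tokens/second at different step counts."""
    tokens_per_second = []

    def sync():
        # CUDA kernels run asynchronously; wait for them before reading the clock
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    # Warm-up run so one-time setup costs are not charged to the first measurement
    _ = generate_confidence(model, n_steps=step_counts[0])

    for n_steps in step_counts:
        times = []
        for _ in range(n_runs):
            sync()
            start = time.perf_counter()
            _ = generate_confidence(model, n_steps=n_steps)
            sync()
            elapsed = time.perf_counter() - start
            times.append(elapsed)

        avg_time = np.mean(times)
        tps = SEQ_LEN / avg_time
        tokens_per_second.append(tps)

    return tokens_per_second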


print("Benchmarking speed...")
speed_results = benchmark_speed(model, step_counts)

for steps, tps in zip(step_counts, speed_results):
    print(f"  Steps={steps:3d} | {tps:,.0f} tokens/sec")

In [None]:
# 📊 Speed vs Quality combined plot
fig, ax1 = plt.subplots(figsize=(10, 6))
ax2 = ax1.twinx()

line1 = ax1.plot(step_counts, conf_means, 'o-', color='#1565c0',
                  linewidth=2.5, markersize=8, label='Quality')
line2 = ax2.plot(step_counts, speed_results, 's--', color='#2e7d32',
                  linewidth=2.5, markersize=8, label='Speed')

ax1.set_xlabel('Number of Denoising Steps', fontsize=13)
ax1.set_ylabel('Quality Score', fontsize=13, color='#1565c0')
ax2.set_ylabel('Tokens / Second', fontsize=13, color='#2e7d32')
ax1.tick_params(axis='y', labelcolor='#1565c0')
ax2.tick_params(axis='y', labelcolor='#2e7d32')

lines = line1 + line2
labels = [l.get_label() for l in lines]
ax1.legend(lines, labels, fontsize=12, loc='center right')
ax1.set_title('The Quality–Speed Tradeoff', fontsize=15)
ax1.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("\n💡 Key takeaway: Quality plateaus quickly, while generation time grows roughly linearly with the step count")
print("   (so tokens/sec falls roughly as 1/steps). The sweet spot is usually 5-15 steps for this model.")
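
A back-of-the-envelope model helps interpret the speed curve. The sketch below assumes every denoising step costs the same wall-clock time and ignores fixed overheads, so measured numbers will deviate from it. The function name `ideal_tokens_per_second` and the value of `STEP_TIME_S` are hypothetical.

```python
def ideal_tokens_per_second(seq_len, n_steps, step_time_s):
    """Throughput if every denoising step costs the same wall-clock time."""
    return seq_len / (n_steps * step_time_s)

SEQ_LEN = 24          # sequence length used in this notebook
STEP_TIME_S = 0.002   # hypothetical cost of one forward pass, for illustration

for n_steps in [1, 6, 24]:
    tps = ideal_tokens_per_second(SEQ_LEN, n_steps, STEP_TIME_S)
    print(f"steps={n_steps:2d} -> {tps:,.0f} tokens/sec")

# Under this model the speedup from 24 steps to 6 steps is 24 / 6 = 4x,
# independent of the per-step cost.
speedup = (ideal_tokens_per_second(SEQ_LEN, 6, STEP_TIME_S)
           / ideal_tokens_per_second(SEQ_LEN, 24, STEP_TIME_S))
print(f"speedup: {speedup:.1f}x")
```

In this idealized picture throughput is proportional to $1/S$, which is why halving the step count doubles tokens per second while the quality curve may barely move.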

In [None]:
#@title 🎧 Listen: Trajectories
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/11_trajectories.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 7. 📊 Visualizing the Unmasking Trajectory

In [None]:
def compare_trajectories(model, n_steps=10):
    """Show side-by-side unmasking trajectories for different strategies."""
    _, hist_conf = generate_confidence(model, n_steps=n_steps)
    _, hist_rand = generate_random(model, n_steps=n_steps)

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(16, 6))

    for ax, history, title, color in [
        (ax1, hist_conf, 'Confidence-Based (Easy tokens first)', '#1565c0'),
        (ax2, hist_rand, 'Random Remasking', '#e65100')
    ]:
        # Create a 2D grid: rows = steps, columns = positions
        n_rows = len(history)
        grid = np.zeros((n_rows, SEQ_LEN, 3))

        for row, seq in enumerate(history):
            tokens = seq.numpy()
            for pos in range(SEQ_LEN):
                if tokens[pos] == MASK_TOKEN:
                    grid[row, pos] = [0.15, 0.15, 0.15]
                else:
                    # tab20 has enough entries to give each of the 16 token ids its own color
                    c = plt.cm.tab20(tokens[pos] / VOCAB_SIZE)[:3]
                    grid[row, pos] = c

        ax.imshow(grid, aspect='auto', interpolation='nearest')
        ax.set_ylabel('Step', fontsize=11)
        ax.set_xlabel('Token Position', fontsize=11)
        ax.set_title(title, fontsize=13, color=color)
        ax.set_yticks(range(0, n_rows, max(1, n_rows//5)))

    plt.suptitle('Unmasking Trajectories: Which Tokens Appear First?', fontsize=15)
    plt.tight_layout()
    plt.show()

compare_trajectories(model, n_steps=12)

In [None]:
#@title 🎧 Listen: Final
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/12_final.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 8. 🎯 Final Output: Generation Quality Grid

In [None]:
def quality_grid(model, step_counts_display=(1, 3, 5, 10, 20), n_sequences=6):
    """Show generated sequences at different step counts side by side."""
    fig, axes = plt.subplots(n_sequences, len(step_counts_display),
                              figsize=(3 * len(step_counts_display), n_sequences * 0.9))

    for col, n_steps in enumerate(step_counts_display):
        for row in range(n_sequences):
            ax = axes[row, col]
            seq, _ = generate_confidence(model, n_steps=n_steps)
            tokens = seq[0].cpu().numpy()
            quality = measure_pattern_quality(seq[0])

            img = np.zeros((1, SEQ_LEN, 3))
            for pos in range(SEQ_LEN):
                if tokens[pos] == MASK_TOKEN:
                    img[0, pos] = [0.15, 0.15, 0.15]
                else:
                    # tab20 has enough entries to give each of the 16 token ids its own color
                    img[0, pos] = plt.cm.tab20(tokens[pos] / VOCAB_SIZE)[:3]

            ax.imshow(img, aspect='auto', interpolation='nearest')
            ax.set_xticks([])
            ax.set_yticks([])

            if row == 0:
                ax.set_title(f'{n_steps} steps', fontsize=12)
            if col == len(step_counts_display) - 1:
                ax.text(SEQ_LEN + 0.5, 0.5, f'{quality:.0%}',
                        va='center', fontsize=9, color='gray')

    plt.suptitle('Generation Quality at Different Step Counts\n'
                 '(Dark = leftover masks, Colors = tokens, Right = quality score)',
                 fontsize=14, y=1.04)
    plt.tight_layout()
    plt.show()

    print("🎉 Notice how even 5 steps can produce recognizable patterns!")
    print("   This flexibility is one reason Mercury is reported to run up to 10x faster than autoregressive models.")

quality_grid(model)

In [None]:
#@title 🎧 Listen: Closing
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/13_closing.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 9. Reflection and Next Steps

### 🤔 Reflection Questions

1. **Diminishing returns:** Why does quality plateau after a certain number of steps? What is the model doing in those extra steps that does not help much?

2. **Confidence vs random:** Why does confidence-based remasking outperform random? Think about what information the model has after each step.

3. **Real-world scaling:** Mercury is reported to reach 1,109 tokens/sec on an H100 with ~25 steps. Cutting that to 10 steps would give roughly 2.5x more speed, assuming the per-step cost stays constant. What applications would benefit from this tradeoff?

### 🏆 Optional Challenges

1. **Cosine schedule:** Instead of unmasking $N/s$ tokens per step, try a cosine schedule: unmask more tokens in early steps (when predictions are coarser) and fewer in late steps (when fine details matter). Does this improve quality?

2. **Temperature annealing:** Start with high temperature (more randomness) in early steps and reduce it in later steps. Compare to fixed temperature.

3. **Semi-autoregressive:** Generate text in blocks of 8 tokens from left to right. Within each block, use diffusion. Implement this hybrid and measure quality vs pure diffusion.

**Next notebook:** We will build a **real diffusion language model on the TinyStories dataset**, training it to generate coherent short stories through iterative unmasking!