In [None]:
# 🔧 Setup: Run this cell first!
# Check GPU availability and install dependencies

import torch
import sys

# Check GPU
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"✅ GPU available: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    device = torch.device('cpu')
    print("⚠️ No GPU detected. Some cells may run slowly.")
    print("   Go to Runtime → Change runtime type → GPU")

print(f"\n📦 Python {sys.version.split()[0]}")
print(f"🔥 PyTorch {torch.__version__}")

# Set random seeds for reproducibility
import random
import numpy as np

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print(f"🎲 Random seed set to {SEED}")

%matplotlib inline

In [None]:
#@title 🎧 Download Narration Audio & Play Introduction
import os as _os
if not _os.path.exists("/content/narration"):
    !pip install -q gdown
    import gdown
    gdown.download(id="1LnfxlgM23VTaqwxDE_vcyPrGp_tSwhZr", output="/content/narration.zip", quiet=False)
    !unzip -q /content/narration.zip -d /content/narration
    !rm /content/narration.zip
    print(f"Loaded {len(_os.listdir('/content/narration'))} narration segments")
else:
    print("Narration audio already loaded.")

from IPython.display import Audio, display
display(Audio("/content/narration/03_00_intro.mp3"))

# 🚀 Training a Diffusion LLM: The ELBO and Beyond

*Part 3 of the Vizuara series on Diffusion LLMs from Scratch*
*Estimated time: 40 minutes*

# 🤖 AI Teaching Assistant

Need help with this notebook? Open the **AI Teaching Assistant**. It has already read this entire notebook and can help with concepts, code, and exercises.

**[👉 Open AI Teaching Assistant](https://course-creator-brown.vercel.app/courses/diffusion-llms-from-scratch/practice/3/assistant)**

*Tip: Open it in a separate tab and work through this notebook side-by-side.*


## 1. Why Does This Matter?

In Notebook 2, we built a bidirectional Transformer that predicts masked tokens. We trained it on synthetic patterns. But two big questions remain:

1. **Is this mathematically rigorous?** Is "train BERT at all masking ratios" really a valid diffusion model?
2. **Does it work on real text?** Can we train on actual language and get a model that learns the structure?

In this notebook, we answer both. We will derive the **Evidence Lower Bound (ELBO)**, show it simplifies to masked language modeling, and train a complete diffusion LLM on real text.

**By the end of this notebook, you will:**
- Understand why the ELBO guarantees our approach is a valid diffusion model
- Train a diffusion LLM on TinyShakespeare (character-level)
- See the model learn to predict masked characters with increasing accuracy
- Analyze how performance varies across masking ratios

In [None]:
#@title 🎧 Listen: Intuition
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/03_01_intuition.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 2. Building Intuition

### The Difficulty Curriculum

Think of training across all masking ratios as giving the model a **curriculum**:

- **Low $t$ (few masks):** Most tokens are visible. The masked ones can be inferred from abundant context. This is the "easy homework."
- **High $t$ (many masks):** Very few tokens are visible. The model must rely on global patterns and priors. This is the "hard exam."

Training uniformly over $t \in [0, 1]$ exposes the model to both easy and hard examples. The easy examples teach it local patterns (bigrams, common phrases). The hard examples teach it global structure (sentence templates, long-range dependencies).

### Why Bidirectional Attention Is the Secret Weapon

Consider a sentence where 90% of tokens are masked. A left-to-right autoregressive model conditions only on the tokens before the current position, so near the start of the sequence it sees almost nothing. Our model sees the **scattered 10% of tokens across the entire sequence**, giving it signal from both the beginning and the end.
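
To make the contrast concrete, here is a minimal sketch in plain Python, separate from the notebook's model code. It masks a 64-token sequence at $t = 0.9$ and counts how many visible tokens a masked position could draw on. The name `visible_context` and its arguments are illustrative assumptions, not part of the notebook.

```python
import random

def visible_context(seq_len=64, t=0.9, seed=0):
    """Compare context available to bidirectional vs. prefix-only prediction.

    Returns (total_visible, avg_visible_to_the_left), where the average is
    taken over masked positions only.
    """
    rng = random.Random(seed)
    # True means the token survives masking (each token is masked with prob. t)
    visible = [rng.random() >= t for _ in range(seq_len)]
    total_visible = sum(visible)

    # A prefix-only predictor at masked position i can use visible[:i] only
    left_counts = [sum(visible[:i]) for i, v in enumerate(visible) if not v]
    avg_left = sum(left_counts) / len(left_counts) if left_counts else 0.0
    return total_visible, avg_left

total, avg_left = visible_context()
print(f"Bidirectional: every masked position sees {total} visible tokens")
print(f"Prefix-only:   a masked position sees {avg_left:.1f} on average")
```

The prefix-only average can never exceed the bidirectional count, and for positions near the start of the sequence it is close to zero.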

### 🤔 Think About This

If you were given a sentence where 90% of characters were hidden, but you could see random characters scattered throughout, could you guess the original? Often you could recover a good part of it, because English is highly redundant. The model exploits this same redundancy.

In [None]:
#@title 🎧 Listen: ELBO
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/03_02_elbo.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 3. The Mathematics

### The Evidence Lower Bound (ELBO)

We want to maximize the log-likelihood of our data:

$$\log p_\theta(x_0)$$

Directly computing this is intractable. So we optimize a lower bound called the **ELBO**:

$$\log p_\theta(x_0) \geq \text{ELBO} = \mathbb{E}_{q} \left[ \log p_\theta(x_0 \mid x_1) + \sum_{t=2}^{T} \log \frac{p_\theta(x_{t-1} \mid x_t)}{q(x_{t-1} \mid x_t, x_0)} \right]$$

(The prior term at $t = T$ is omitted: a fully masked $x_T$ does not depend on $\theta$.)

**What this says computationally:** The ELBO measures how well our reverse process (unmasking) matches the true reverse of the forward process (masking), $q(x_{t-1} \mid x_t, x_0)$. The bound is tight, meaning the ELBO equals the true log-likelihood, only when the two match exactly.
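
For readers who want to see where the inequality comes from, the standard argument is Jensen's inequality applied to the forward process $q(x_{1:T} \mid x_0)$. This is the generic variational derivation written in the notation used above, not something specific to this notebook:

$$\log p_\theta(x_0) = \log \mathbb{E}_{q(x_{1:T} \mid x_0)} \left[ \frac{p_\theta(x_{0:T})}{q(x_{1:T} \mid x_0)} \right] \geq \mathbb{E}_{q(x_{1:T} \mid x_0)} \left[ \log \frac{p_\theta(x_{0:T})}{q(x_{1:T} \mid x_0)} \right] = \text{ELBO}$$

The gap between the two sides is exactly a KL divergence, which is never negative:

$$\log p_\theta(x_0) - \text{ELBO} = D_{\mathrm{KL}}\left( q(x_{1:T} \mid x_0) \,\|\, p_\theta(x_{1:T} \mid x_0) \right) \geq 0$$

Factorizing $p_\theta$ and $q$ over time steps turns the single log-ratio into the per-step sum.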

### Worked Example

Suppose we have a 2-token vocabulary {A, B}, a 3-token sequence, and $T = 2$ steps. Treat each token as contributing one log-probability term (from the step at which it is unmasked), and assume the model assigns probability 0.9 to the correct token every time:

$$\text{ELBO} \approx 3 \times \log(0.9) \approx 3 \times (-0.1054) \approx -0.316$$

A perfect model gives ELBO = 0. As the model improves, the ELBO increases towards 0.
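
The arithmetic is easy to check in a couple of lines. This is a sketch of the toy calculation only; `toy_elbo` is a hypothetical helper, not part of the notebook's training code.

```python
import math

def toy_elbo(p_correct, n_tokens):
    """Sum of log-probabilities when every token is predicted with p_correct."""
    return n_tokens * math.log(p_correct)

print(f"{toy_elbo(0.9, 3):.3f}")   # -0.316
print(f"{toy_elbo(0.99, 3):.3f}")  # -0.030
print(f"{toy_elbo(1.0, 3):.3f}")   # 0.000
```

Raising the per-token probability from 0.9 to 0.99 moves the bound roughly ten times closer to zero.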

### The Beautiful Simplification

Here is the key result from the MDLM paper (NeurIPS 2024): **when you work through the math for masked diffusion, the ELBO simplifies to a weighted mixture of masked language modeling losses at different masking ratios.**

$$\text{ELBO} \propto \mathbb{E}_{t \sim U(0,1)} \left[ \frac{1}{t} \sum_{i : x_t^{(i)} = [\text{MASK}]} \log p_\theta(x_0^{(i)} \mid x_t) \right]$$

**What this means:** Each $\log p_\theta$ term is at most 0, so the ELBO is at most 0, and maximizing it is the same as minimizing a $1/t$-weighted cross-entropy loss on the masked positions. The theoretically rigorous diffusion training objective is essentially "train BERT at all masking ratios" with a $1/t$ importance weight. The theory confirms what the intuition suggested.

### 💡 Key Insight

The $1/t$ weighting means low-masking-ratio examples (where few tokens are masked) get *more weight* per masked token. This makes sense: when almost nothing is masked, each prediction is more informative about the model's understanding.

In practice, many implementations use uniform weighting (no $1/t$ factor) and it still works well.
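
Both weightings fit in one small function. The sketch below assumes the tensor shapes used in this notebook's training loop: logits `(B, L, V)`, integer targets `(B, L)`, a boolean mask `(B, L)`, and `t` of shape `(B, 1)`. The names `weighted_masked_loss` and `use_elbo_weight` are introduced here for illustration, and dividing by the total number of positions is an assumed normalization, not one taken from the notebook.

```python
import torch
import torch.nn.functional as F

def weighted_masked_loss(logits, targets, mask, t, use_elbo_weight=True):
    """Cross-entropy on masked positions, with optional 1/t weighting."""
    # Per-token cross-entropy; F.cross_entropy expects the class dim second
    ce = F.cross_entropy(logits.transpose(1, 2), targets, reduction="none")  # (B, L)
    ce = ce * mask.float()  # zero out unmasked positions

    if use_elbo_weight:
        # Weight each sample's summed loss by 1/t, then normalize by B * L
        per_sample = ce.sum(dim=1) / t.squeeze(1)
        return per_sample.sum() / mask.numel()

    # Uniform variant: plain mean over masked tokens
    return ce.sum() / mask.sum().clamp(min=1)
```

With `use_elbo_weight=False` the result matches `F.cross_entropy(logits[mask], targets[mask])`, which is what the training loop in this notebook computes. Because $t$ can be as small as 0.02 in that loop, the weighted version can produce noisier gradients, which is one plausible reason the uniform variant is popular.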

In [None]:
#@title 🎧 Listen: Data Setup
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/03_03_data_setup.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 4. Let's Build It, Component by Component

### 4.1 The TinyShakespeare Dataset

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import math
import urllib.request

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

torch.manual_seed(42)
np.random.seed(42)
%matplotlib inline

# Download TinyShakespeare
url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
filepath = "shakespeare.txt"
try:
    with open(filepath, 'r') as f:
        text = f.read()
except FileNotFoundError:
    urllib.request.urlretrieve(url, filepath)
    with open(filepath, 'r') as f:
        text = f.read()

print(f"Dataset size: {len(text):,} characters")
print(f"First 200 chars:\n{text[:200]}")

In [None]:
# Character-level tokenizer
chars = sorted(list(set(text)))
VOCAB_SIZE = len(chars) + 1  # +1 for [MASK] token
MASK_TOKEN = 0

# Map characters to IDs (starting from 1; 0 is reserved for MASK)
char_to_id = {ch: i + 1 for i, ch in enumerate(chars)}
id_to_char = {i + 1: ch for i, ch in enumerate(chars)}
id_to_char[MASK_TOKEN] = '[M]'

def encode(s):
    return [char_to_id[c] for c in s]

def decode(ids):
    return ''.join(id_to_char.get(i, '?') for i in ids)

print(f"Vocabulary size: {VOCAB_SIZE} (including [MASK])")
print(f"Characters: {''.join(chars[:30])}...")
print(f"\nEncoding test: 'Hello' → {encode('Hello')}")
print(f"Decoding test: {encode('Hello')} → '{decode(encode('Hello'))}'")

In [None]:
# Create training sequences
SEQ_LEN = 64
BATCH_SIZE = 64
D_MODEL = 128
N_HEADS = 4
N_LAYERS = 4

# Encode entire text
data = torch.tensor(encode(text), dtype=torch.long)
print(f"Total tokens: {len(data):,}")

# Split into training sequences
n_sequences = len(data) // SEQ_LEN
sequences = data[:n_sequences * SEQ_LEN].reshape(n_sequences, SEQ_LEN)
print(f"Training sequences: {sequences.shape[0]:,} of length {SEQ_LEN}")

# Train/val split (90/10)
n_train = int(0.9 * len(sequences))
train_data = sequences[:n_train].to(device)
val_data = sequences[n_train:].to(device)
print(f"Train: {len(train_data):,} | Val: {len(val_data):,}")

In [None]:
#@title 🎧 Listen: Masking And Model
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/03_04_masking_and_model.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### 4.2 The Forward Process

Same masking function as Notebook 2, but now on real text.

In [None]:
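def mask_tokens(x_0, t):
    """Mask each token independently with probability t.

    x_0: (B, L) token IDs. t: scalar or (B, 1) tensor, broadcast over positions.
    Returns the masked sequence x_t and the boolean mask of replaced positions.
    """
    # One uniform draw per token, on the same device as the input
    random_vals = torch.rand(x_0.shape, device=x_0.device)
    mask = random_vals < t
    x_t = x_0.clone()
    x_t[mask] = MASK_TOKEN
    return x_t, mask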

In [None]:
# 📊 Visualize masking on real Shakespeare text
sample = train_data[0:1]  # One sequence
print(f"Original: '{decode(sample[0].tolist())}'")
print()

for t_val in [0.2, 0.5, 0.8]:
    torch.manual_seed(0)
    t = torch.tensor([[t_val]], device=device)
    masked, _ = mask_tokens(sample, t)
    masked_str = decode(masked[0].tolist())
    print(f"t={t_val}: '{masked_str}'")

### 4.3 The Model

Same architecture as Notebook 2 but scaled up for real text.

In [None]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=512):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]


class DiffusionLM(nn.Module):
    def __init__(self, vocab_size, d_model, n_heads, n_layers, max_len=512):
        super().__init__()
        self.token_embed = nn.Embedding(vocab_size, d_model)
        self.pos_enc = PositionalEncoding(d_model, max_len)
        self.time_mlp = nn.Sequential(
            nn.Linear(1, d_model), nn.SiLU(), nn.Linear(d_model, d_model)
        )
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads,
            dim_feedforward=d_model * 4,
            dropout=0.1, batch_first=True, norm_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, n_layers)
        self.output_head = nn.Linear(d_model, vocab_size)

    def forward(self, x_t, t):
        h = self.token_embed(x_t)
        h = self.pos_enc(h)
        h = h + self.time_mlp(t).unsqueeze(1)
        h = self.transformer(h)
        return self.output_head(h)


model = DiffusionLM(VOCAB_SIZE, D_MODEL, N_HEADS, N_LAYERS).to(device)
n_params = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {n_params:,}")

In [None]:
#@title 🎧 Listen: Training
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/03_05_training.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### 4.4 The Full Training Loop

In [None]:
NUM_EPOCHS = 20
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, NUM_EPOCHS * (n_train // BATCH_SIZE)
)

train_losses = []
val_losses = []

print("Training on TinyShakespeare...")
for epoch in range(NUM_EPOCHS):
    # Shuffle training data
    perm = torch.randperm(len(train_data))
    epoch_losses = []

    for i in range(0, len(train_data) - BATCH_SIZE, BATCH_SIZE):
        batch = train_data[perm[i:i + BATCH_SIZE]]

        # Random masking ratio per sample
        t = torch.rand(BATCH_SIZE, 1, device=device) * 0.98 + 0.02

        # Mask tokens
        x_t, mask = mask_tokens(batch, t)

        # Predict and compute loss on masked positions
        logits = model(x_t, t)
        logits_masked = logits[mask]
        targets_masked = batch[mask]

        if logits_masked.shape[0] == 0:
            continue

        loss = F.cross_entropy(logits_masked, targets_masked)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

        epoch_losses.append(loss.item())

    # Validation loss (dropout off so the estimate reflects inference behaviour)
    model.eval()
    with torch.no_grad():
        val_t = torch.full((len(val_data), 1), 0.5, device=device)
        val_masked, val_mask = mask_tokens(val_data, val_t)
        val_logits = model(val_masked, val_t)
        v_loss = F.cross_entropy(val_logits[val_mask], val_data[val_mask]).item()
    model.train()  # re-enable dropout for the next epoch

    avg_train = np.mean(epoch_losses)
    train_losses.append(avg_train)
    val_losses.append(v_loss)

    if (epoch + 1) % 5 == 0:
        print(f"Epoch {epoch+1}/{NUM_EPOCHS} | Train: {avg_train:.3f} | Val: {v_loss:.3f}")

print("Training complete!")

In [None]:
# 📊 Training and validation loss curves
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(train_losses, label='Train', color='#1565c0', linewidth=2)
ax.plot(val_losses, label='Validation', color='#e53935', linewidth=2, linestyle='--')
ax.set_xlabel('Epoch', fontsize=12)
ax.set_ylabel('Cross-Entropy Loss', fontsize=12)
ax.set_title('Diffusion LM Training on TinyShakespeare', fontsize=14)
ax.legend(fontsize=12)
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

In [None]:
#@title 🎧 Listen: Accuracy
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/03_06_accuracy.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### 4.5 Evaluating Across Masking Ratios

In [None]:
# 📊 How does accuracy change with masking ratio?
@torch.no_grad()
def evaluate_at_t(model, data, t_val, n_samples=200):
    """Compute accuracy at a specific masking ratio."""
    model.eval()  # disable dropout for evaluation
    indices = torch.randperm(len(data))[:n_samples]
    batch = data[indices]
    t = torch.full((len(batch), 1), t_val, device=device)
    x_t, mask = mask_tokens(batch, t)
    logits = model(x_t, t)
    preds = logits[mask].argmax(dim=-1)
    targets = batch[mask]
    accuracy = (preds == targets).float().mean().item()
    return accuracy

t_vals = np.linspace(0.1, 0.95, 15)
accuracies = [evaluate_at_t(model, val_data, t) * 100 for t in t_vals]

plt.figure(figsize=(10, 5))
plt.plot(t_vals, accuracies, 'o-', color='#1565c0', linewidth=2, markersize=6)
plt.xlabel('Masking ratio t', fontsize=12)
plt.ylabel('Accuracy (%)', fontsize=12)
plt.title('Prediction Accuracy vs Masking Ratio (Shakespeare)', fontsize=14)
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
print("Low masking → high accuracy (lots of context).")
print("High masking → lower accuracy (very little context).")

In [None]:
#@title 🎧 Listen: Predictions
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/03_07_predictions.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### 4.6 Visualizing Predictions on Real Text

In [None]:
# Show predictions on actual Shakespeare text
@torch.no_grad()
def show_predictions(model, text_str, t_val=0.5):
    """Mask a text string and show model predictions."""
    model.eval()  # disable dropout for evaluation
    ids = encode(text_str[:SEQ_LEN])
    if len(ids) < SEQ_LEN:
        # Pad short strings with spaces up to SEQ_LEN
        ids = ids + [char_to_id[' ']] * (SEQ_LEN - len(ids))
    x = torch.tensor([ids], device=device)

    t = torch.tensor([[t_val]], device=device)
    x_t, mask = mask_tokens(x, t)
    logits = model(x_t, t)
    preds = logits[0].argmax(dim=-1)

    print(f"Original:  '{decode(ids)}'")
    masked_str = decode(x_t[0].tolist())
    print(f"Masked:    '{masked_str}'")

    # Show predictions at masked positions
    pred_ids = x_t[0].clone()
    for i in range(SEQ_LEN):
        if mask[0, i]:
            pred_ids[i] = preds[i]
    print(f"Predicted: '{decode(pred_ids.tolist())}'")

    # Count correct
    n_correct = sum(1 for i in range(SEQ_LEN) if mask[0, i] and preds[i] == x[0, i])
    n_total = mask[0].sum().item()
    print(f"Accuracy:  {n_correct}/{n_total} = {n_correct/max(n_total,1)*100:.0f}%")

print("=" * 60)
show_predictions(model, "First Citizen:\nBefore we proceed any further, hear me speak.")
print()
show_predictions(model, "To be, or not to be, that is the question:")
print()
show_predictions(model, "All that glitters is not gold; often have you heard that told")

In [None]:
#@title 🎧 Listen: Todo1
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/03_08_todo1.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 5. 🔧 Your Turn

### TODO 1: Implement a Learning Rate Scheduler Comparison

In [None]:
def train_with_schedule(schedule_type='cosine', n_steps=500, lr=3e-4):
    """Train a small model with different LR schedules and compare.

    Args:
        schedule_type: 'cosine', 'constant', or 'linear'
        n_steps: Number of training steps
        lr: Base learning rate

    Returns:
        List of losses
    """
    small_model = DiffusionLM(VOCAB_SIZE, 64, 2, 2).to(device)
    optimizer = torch.optim.AdamW(small_model.parameters(), lr=lr)

    # ============ TODO ============
    # Create the appropriate scheduler based on schedule_type
    if schedule_type == 'cosine':
        scheduler = ???  # YOUR CODE: CosineAnnealingLR
    elif schedule_type == 'linear':
        scheduler = ???  # YOUR CODE: LinearLR with end_factor=0.01
    else:  # constant
        scheduler = ???  # YOUR CODE: ConstantLR with factor=1.0 (the loop always calls scheduler.step())
    # ==============================

    losses = []
    for step in range(n_steps):
        batch_idx = torch.randint(0, len(train_data), (32,))
        batch = train_data[batch_idx]
        t = torch.rand(32, 1, device=device) * 0.98 + 0.02
        x_t, mask = mask_tokens(batch, t)
        logits = small_model(x_t, t)
        if mask.sum() == 0:
            continue
        loss = F.cross_entropy(logits[mask], batch[mask])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
        losses.append(loss.item())

    return losses

In [None]:
# ✅ Verification: compare schedules
try:
    losses_cos = train_with_schedule('cosine')
    losses_const = train_with_schedule('constant')
    losses_lin = train_with_schedule('linear')

    fig, ax = plt.subplots(figsize=(10, 5))
    w = 20
    for losses, label, color in [
        (losses_cos, 'Cosine', '#1565c0'),
        (losses_const, 'Constant', '#e53935'),
        (losses_lin, 'Linear', '#2e7d32'),
    ]:
        smoothed = np.convolve(losses, np.ones(w)/w, mode='valid')
        ax.plot(smoothed, label=label, color=color, linewidth=2)
    ax.set_xlabel('Step')
    ax.set_ylabel('Loss')
    ax.set_title('LR Schedule Comparison')
    ax.legend()
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()
    print("✅ Compare which schedule converges fastest!")
except NameError:
    print("❌ Replace the ??? placeholders.")

In [None]:
#@title 🎧 Listen: Todo2
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/03_09_todo2.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### TODO 2: Per-Position Confidence Analysis

In [None]:
@torch.no_grad()
def confidence_analysis(model, data, t_val=0.5, n_samples=100):
    """Compute the model's confidence at each position.

    Args:
        model: Trained DiffusionLM
        data: Validation data
        t_val: Masking ratio to evaluate at
        n_samples: Number of samples

    Returns:
        position_confidences: Average confidence per position, shape (SEQ_LEN,)
    """
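    model.eval()  # disable dropout for evaluation
    indices = torch.randperm(len(data))[:n_samples]
    batch = data[indices]
    # Size t from the batch itself in case data has fewer than n_samples rows
    t = torch.full((len(batch), 1), t_val, device=device)
    x_t, mask = mask_tokens(batch, t)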

    # ============ TODO ============
    # Step 1: Get model predictions
    logits = ???  # YOUR CODE

    # Step 2: Convert to probabilities
    probs = ???  # YOUR CODE: softmax over vocab dim

    # Step 3: Get max probability at each position (confidence)
    confidence = ???  # YOUR CODE: max prob at each position, shape (n_samples, SEQ_LEN)

    # Step 4: Average confidence per position (only at masked positions)
    #         Set unmasked positions to NaN so they don't affect the mean
    confidence[~mask] = float('nan')
    position_confidences = torch.nanmean(confidence, dim=0)
    # ==============================

    return position_confidences.cpu().numpy()

In [None]:
# ✅ Verification
try:
    conf = confidence_analysis(model, val_data, t_val=0.5)
    plt.figure(figsize=(12, 4))
    plt.bar(range(SEQ_LEN), conf, color='#1565c0', alpha=0.7)
    plt.xlabel('Position', fontsize=11)
    plt.ylabel('Avg Confidence', fontsize=11)
    plt.title('Model Confidence by Position (t=0.5)', fontsize=13)
    plt.grid(True, alpha=0.3, axis='y')
    plt.tight_layout()
    plt.show()
    print("✅ Do you see any positional patterns in confidence?")
except NameError:
    print("❌ Replace the ??? placeholders.")

In [None]:
#@title 🎧 Listen: Results
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/03_10_results.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 6. Putting It All Together

In [None]:
# Full pipeline demo: take Shakespeare, mask it, predict, show results
@torch.no_grad()
def full_demo(model, text_str, t_val=0.5):
    """Complete pipeline demonstration."""
    model.eval()  # disable dropout for evaluation
    ids = encode(text_str[:SEQ_LEN])
    if len(ids) < SEQ_LEN:
        ids = ids + [char_to_id[' ']] * (SEQ_LEN - len(ids))
    x = torch.tensor([ids], device=device)
    t = torch.tensor([[t_val]], device=device)
    x_t, mask = mask_tokens(x, t)
    logits = model(x_t, t)
    probs = F.softmax(logits, dim=-1)
    confidences = probs.max(dim=-1).values[0]

    return x[0], x_t[0], logits[0].argmax(dim=-1), confidences, mask[0]


sample_text = "ROMEO:\nBut, soft! what light through yonder window breaks?"
orig, masked, predicted, conf, mask = full_demo(model, sample_text)

print("Original: ", decode(orig.tolist()))
print("Masked:   ", decode(masked.tolist()))
print("Predicted:", decode(predicted.tolist()))
print()

n_correct = ((predicted == orig) & mask).sum().item()
n_total = mask.sum().item()
print(f"Accuracy on masked tokens: {n_correct}/{n_total} = {n_correct/max(n_total,1)*100:.0f}%")

## 7. 🎯 Final Output

In [None]:
# 📊 Beautiful visualization: predictions with confidence coloring
fig, axes = plt.subplots(3, 1, figsize=(18, 6))

sample_texts = [
    "First Citizen:\nBefore we proceed any further, hear me speak.",
    "ROMEO:\nBut, soft! what light through yonder window breaks?",
    "To be, or not to be, that is the question: whether 'tis",
]

for ax, sample_text in zip(axes, sample_texts):
    orig, masked, predicted, conf, mask_bool = full_demo(model, sample_text, t_val=0.4)

    for pos in range(SEQ_LEN):
        if mask_bool[pos]:
            correct = predicted[pos] == orig[pos]
            color = '#2e7d32' if correct else '#e53935'
            alpha = min(1.0, conf[pos].item() + 0.3)
            ax.add_patch(plt.Rectangle((pos, 0), 1, 1, color=color, alpha=alpha))
            char = id_to_char.get(predicted[pos].item(), '?')
            ax.text(pos + 0.5, 0.5, char, ha='center', va='center',
                    fontsize=7, fontweight='bold', color='white')
        else:
            ax.add_patch(plt.Rectangle((pos, 0), 1, 1, color='#e3f2fd'))
            char = id_to_char.get(orig[pos].item(), '?')
            ax.text(pos + 0.5, 0.5, char, ha='center', va='center',
                    fontsize=7, color='#333333')

    ax.set_xlim(0, SEQ_LEN)
    ax.set_ylim(0, 1)
    ax.set_xticks([])
    ax.set_yticks([])

plt.suptitle('Predictions at t=0.4 (green=correct, red=wrong, blue=unmasked)',
             fontsize=13, y=1.02)
plt.tight_layout()
plt.show()

In [None]:
#@title 🎧 Listen: Landscape
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/03_11_landscape.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 8. The Landscape of Diffusion LLMs

Our tiny character-level model rests on the same idea as state-of-the-art diffusion LLMs. They apply the same recipe at a much larger scale.

| Model | Year | Approach | Key Result |
|---|---|---|---|
| MDLM | 2024 | Masked diffusion | Within 14% of GPT-2 perplexity |
| SEDD | 2024 | Score-based discrete | ICML Best Paper |
| LLaDA | 2025 | Masked diffusion (8B) | Competitive with LLaMA 3 |
| Mercury | 2025 | Diffusion (commercial) | 1,000+ tok/s |
| Gemini Diffusion | 2025 | Diffusion (commercial) | 1,479 tok/s |

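The standout is **LLaDA**, an 8B-parameter diffusion LLM reported to be competitive with LLaMA 3 8B. It also mitigates the **reversal curse**: a model trained that "A is B" should be able to infer "B is A." The LLaDA paper reports it outperforming GPT-4o on a reversal poem-completion task, a result consistent with the model not being restricted to left-to-right context.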

In [None]:
#@title 🎧 Listen: Closing
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/03_12_closing.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 9. Reflection and Next Steps

### 🤔 Reflection Questions

1. **Why is the ELBO a lower bound?** The gap between $\log p_\theta(x_0)$ and the ELBO is a KL divergence, which is never negative. The bound is tight when the reverse process perfectly matches the forward process. Any imperfection in the model's predictions loosens the bound.

2. **Should we weight masking ratios non-uniformly?** The ELBO prescribes $1/t$ weighting. In practice, uniform works well too. Some papers use cosine schedules over $t$.

3. **How would you handle very long sequences?** The Transformer's attention is $O(L^2)$. For long sequences, you would need efficient attention (sparse, linear, or local+global patterns).
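
To get a feel for the quadratic growth in question 3, here is a rough back-of-the-envelope sketch. It counts attention scores only (one per query-key pair, per head, per layer) and ignores every other cost. The defaults mirror this notebook's `N_HEADS = 4` and `N_LAYERS = 4`; the function name is hypothetical.

```python
def attention_score_entries(seq_len, n_heads=4, n_layers=4):
    """Attention scores computed in one forward pass of full self-attention."""
    return n_layers * n_heads * seq_len ** 2

for seq_len in (64, 512, 4096):
    print(f"L={seq_len:>5}: {attention_score_entries(seq_len):,} attention scores")
```

Doubling the sequence length quadruples the count, which is why long-context models turn to sparse, linear, or local+global attention patterns.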

### 🏆 Optional Challenges

1. Try a larger model (more layers, bigger d_model) and see if accuracy improves
2. Implement weight tying (embedding and output head share parameters)
3. Try word-level tokenization instead of character-level
4. Compute the actual ELBO on the validation set and compare to cross-entropy loss

---

**Up Next, Notebook 4:** *Generation: Iterative Unmasking.* We will use our trained model to generate text from scratch, starting with all [MASK] tokens and progressively revealing them in order of confidence. It is the grand finale!