In [None]:
# üîß Setup: Run this cell first!
# Check GPU availability and install dependencies

import torch
import sys

# Check GPU
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"‚úÖ GPU available: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    device = torch.device('cpu')
    print("‚ö†Ô∏è No GPU detected. Some cells may run slowly.")
    print("   Go to Runtime ‚Üí Change runtime type ‚Üí GPU")

print(f"\nüì¶ Python {sys.version.split()[0]}")
print(f"üî• PyTorch {torch.__version__}")

# Set random seeds for reproducibility
import random
import numpy as np

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print(f"üé≤ Random seed set to {SEED}")

%matplotlib inline

In [None]:
#@title üéß Download Narration Audio & Play Introduction
import os as _os
if not _os.path.exists("/content/narration"):
    !pip install -q gdown
    import gdown
    gdown.download(id="1N8Z5j5PJiwMRZO-F3b_vfz0oQi19MNIs", output="/content/narration.zip", quiet=False)
    !unzip -q /content/narration.zip -d /content/narration
    !rm /content/narration.zip
    print(f"Loaded {len(_os.listdir('/content/narration'))} narration segments")
else:
    print("Narration audio already loaded.")

from IPython.display import Audio, display
display(Audio("/content/narration/04_00_intro.mp3"))

# 🚀 Generation: Iterative Unmasking — The Grand Finale

*Part 4 of the Vizuara series on Diffusion LLMs from Scratch*
*Estimated time: 35 minutes*

## 1. Why Does This Matter?

We have trained a model that predicts masked tokens. Now comes the payoff: **generating brand new text from scratch.**

The process is called **iterative unmasking**: start with a fully masked sequence, predict all tokens, keep the most confident ones, re-mask the rest, and repeat. In just a handful of steps, coherent text emerges from pure [MASK] tokens.

The remarkable thing? **Tokens appear in order of confidence, not left-to-right.** The model might fill in "the" at position 1 and "." at the end before it fills in the middle. It generates like an artist — broad strokes first, details last.

**By the end of this notebook, you will:**
- Implement the full generation pipeline
- Watch text materialize step-by-step from pure masks
- Experiment with temperature, number of steps, and sampling strategies
- Demonstrate infilling — the killer feature of diffusion LLMs

In [None]:
#@title üéß Listen: Artist Analogy
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/04_01_artist_analogy.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 2. Building Intuition

### The Artist Analogy — Revisited

An artist does not paint left to right. They:
1. Start with a blank canvas
2. Sketch the broadest composition (big shapes, layout)
3. Add medium details (forms, proportions)
4. Refine fine details (textures, edges)

Our generation process works the same way:
1. Start with all [MASK] tokens
2. Fill in the most confident tokens first (common words, structural elements)
3. Use that context to fill in less obvious tokens
4. The last tokens are the trickiest — subtle word choices that depend on everything else

### 🤔 Think About This

What are the advantages of generating in confidence order vs left-to-right?

- The ending can inform the beginning (bidirectional context)
- Easy tokens settle first, creating scaffolding for harder tokens
- Predictions for still-masked positions are recomputed at every step, so the model can effectively "change its mind" about a position until its token is committed

In [None]:
#@title üéß Listen: Algorithm
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/04_02_algorithm.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 3. The Generation Algorithm

### Step-by-Step Walkthrough

Using the article's example — generating a 6-token sentence in 4 steps:

**Step 1:** Input: [M] [M] [M] [M] [M] [M]
Model is most confident about positions 1 and 6 → unmask "The" and "mat"

**Step 2:** Input: The [M] [M] [M] [M] mat
Now has context from both ends → unmask "sat" and "the"

**Step 3:** Input: The [M] sat [M] the mat
Strong bidirectional signal → unmask "cat"

**Step 4:** Input: The cat sat [M] the mat
Only one mask left → unmask "on"

**Result: "The cat sat on the mat"**

Notice: "mat" at position 6 was filled before "cat" at position 2. The model generated in **confidence order, not positional order.**

### The Key Formula

At each step, unmask the top-$k$ most confident predictions among currently masked positions. A simple schedule, and the one the code below uses: $k = \max\left(1, \lfloor \text{remaining\_masks} / \text{remaining\_steps} \rfloor\right)$.

In [None]:
#@title üéß Listen: Setup Training
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/04_03_setup_training.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 4. Let's Build It — Component by Component

### 4.1 Rebuild and Train the Model

We need a self-contained notebook, so we retrain on TinyShakespeare. This takes about 3-5 minutes on a T4 GPU.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import math
import urllib.request

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
torch.manual_seed(42)
np.random.seed(42)
%matplotlib inline

# Download TinyShakespeare
url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
filepath = "shakespeare.txt"
try:
    with open(filepath, 'r') as f:
        text = f.read()
except FileNotFoundError:
    urllib.request.urlretrieve(url, filepath)
    with open(filepath, 'r') as f:
        text = f.read()

chars = sorted(list(set(text)))
VOCAB_SIZE = len(chars) + 1
MASK_TOKEN = 0
SEQ_LEN = 64
BATCH_SIZE = 64
D_MODEL = 128
N_HEADS = 4
N_LAYERS = 4

char_to_id = {ch: i + 1 for i, ch in enumerate(chars)}
id_to_char = {i + 1: ch for i, ch in enumerate(chars)}
id_to_char[MASK_TOKEN] = '‚ñà'

def encode(s):
    """Convert a string into a list of token ids, one per character.

    Each character is looked up in ``char_to_id``; a character that is not
    in that table raises ``KeyError``.
    """
    return list(map(char_to_id.__getitem__, s))
def decode(ids):
    """Convert an iterable of token ids back into a string.

    Ids missing from ``id_to_char`` are rendered as ``'?'`` rather than
    raising.
    """
    pieces = []
    for token_id in ids:
        pieces.append(id_to_char.get(token_id, '?'))
    return ''.join(pieces)

data = torch.tensor(encode(text), dtype=torch.long)
n_seq = len(data) // SEQ_LEN
sequences = data[:n_seq * SEQ_LEN].reshape(n_seq, SEQ_LEN).to(device)
n_train = int(0.9 * len(sequences))
train_data = sequences[:n_train]

def mask_tokens(x_0, t):
    """Forward (noising) process: replace a random subset of tokens with MASK.

    Every position is masked independently when a uniform draw in [0, 1)
    falls below ``t``.

    Args:
        x_0: Integer tensor of clean token ids (the training loop passes a
            ``(BATCH_SIZE, SEQ_LEN)`` batch).
        t: Masking probability, broadcast against ``x_0`` (the training loop
            passes one value per row, shape ``(BATCH_SIZE, 1)``).

    Returns:
        ``(x_t, mask)``: the corrupted copy of ``x_0`` and the boolean tensor
        marking which positions were replaced by ``MASK_TOKEN``.
    """
    noise = torch.rand_like(x_0, dtype=torch.float32)
    mask = noise < t
    # masked_fill returns a new tensor, so x_0 itself is left untouched.
    x_t = x_0.masked_fill(mask, MASK_TOKEN)
    return x_t, mask

print(f"Vocab: {VOCAB_SIZE} | Sequences: {len(sequences):,} | Training: {n_train:,}")

In [None]:
class PositionalEncoding(nn.Module):
    """Fixed (non-learned) sinusoidal positional encoding.

    Precomputes a ``(1, max_len, d_model)`` table in which even feature
    columns hold sines and odd feature columns hold cosines of the position
    scaled by geometrically spaced frequencies. The table is stored as a
    buffer, so it follows the module across devices but is not trained.

    Args:
        d_model: Feature width of the embeddings (even or odd).
        max_len: Longest sequence length the table covers.
    """
    def __init__(self, d_model, max_len=512):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len).unsqueeze(1).float()
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div)
        # With an odd d_model there is one fewer odd (cosine) column than even
        # (sine) column, so trim the frequency vector to match.
        pe[:, 1::2] = torch.cos(pos * div[: d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Add the positional table to ``x``, indexed along dimension 1.

        Raises:
            ValueError: If ``x.size(1)`` exceeds the ``max_len`` the table
                was built with.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_len={self.pe.size(1)} "
                "of the positional encoding table"
            )
        return x + self.pe[:, :seq_len]

class DiffusionLM(nn.Module):
    """Transformer encoder that predicts a token at every sequence position.

    The encoder is called without any attention mask, so each position
    attends to the whole sequence. The noise level ``t`` is embedded by a
    small MLP and added to every position's embedding.

    Args:
        vocab_size: Number of token ids (size of both the embedding table and
            the output logits).
        d_model: Embedding / hidden width.
        n_heads: Attention heads per encoder layer.
        n_layers: Number of stacked encoder layers.
    """
    def __init__(self, vocab_size, d_model, n_heads, n_layers):
        super().__init__()
        self.token_embed = nn.Embedding(vocab_size, d_model)
        self.pos_enc = PositionalEncoding(d_model)
        # Lifts the scalar noise level to a d_model-wide conditioning vector.
        self.time_mlp = nn.Sequential(
            nn.Linear(1, d_model),
            nn.SiLU(),
            nn.Linear(d_model, d_model),
        )
        encoder_block = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=d_model * 4,
            dropout=0.1,
            batch_first=True,
            norm_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_block, num_layers=n_layers)
        self.output_head = nn.Linear(d_model, vocab_size)

    def forward(self, x_t, t):
        """Return per-position logits over the vocabulary.

        Args:
            x_t: Token ids; this notebook passes shape (batch, seq_len).
            t: Noise level with a trailing dimension of size 1; this notebook
                passes shape (batch, 1).
        """
        embedded = self.pos_enc(self.token_embed(x_t))
        # unsqueeze(1) inserts an axis so the time embedding broadcasts over
        # the sequence positions.
        conditioned = embedded + self.time_mlp(t).unsqueeze(1)
        return self.output_head(self.transformer(conditioned))

model = DiffusionLM(VOCAB_SIZE, D_MODEL, N_HEADS, N_LAYERS).to(device)
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")

In [None]:
# Quick training (~3-5 min on T4)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)
NUM_EPOCHS = 15

print("Training...")
# The generation helpers in this notebook call model.eval(); switch back to
# training mode explicitly so re-running this cell keeps dropout active.
model.train()
for epoch in range(NUM_EPOCHS):
    perm = torch.randperm(len(train_data))
    epoch_loss = []
    # "+ 1" keeps the last full batch when len(train_data) is an exact
    # multiple of BATCH_SIZE. A trailing partial batch is still dropped
    # because t below is built with exactly BATCH_SIZE rows.
    for i in range(0, len(train_data) - BATCH_SIZE + 1, BATCH_SIZE):
        batch = train_data[perm[i:i+BATCH_SIZE]]
        # One masking ratio per sequence, drawn uniformly from [0.02, 1.0).
        t = torch.rand(BATCH_SIZE, 1, device=device) * 0.98 + 0.02
        x_t, mask = mask_tokens(batch, t)
        # Nothing was masked: there is no loss to compute, so skip the
        # forward pass entirely.
        if not mask.any():
            continue
        logits = model(x_t, t)
        # The loss is computed on masked positions only.
        loss = F.cross_entropy(logits[mask], batch[mask])
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        epoch_loss.append(loss.item())
    if (epoch+1) % 5 == 0:
        print(f"  Epoch {epoch+1}/{NUM_EPOCHS} | Loss: {np.mean(epoch_loss):.3f}")

print("Training complete!")

In [None]:
#@title üéß Listen: Generate Function
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/04_04_generate_function.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### 4.2 The Basic Generation Function

In [None]:
@torch.no_grad()
def generate(model, seq_len=SEQ_LEN, num_steps=10, temperature=1.0):
    """Generate text via iterative confidence-based unmasking.

    Starts from a fully masked sequence. At every step a token is sampled
    for each position, and the samples are committed only at the masked
    positions where the model is most confident.

    Args:
        model: Trained DiffusionLM, called as ``model(x, t)``
        seq_len: Length of sequence to generate
        num_steps: Number of unmasking steps
        temperature: Sampling temperature (lower = more deterministic);
            must be > 0 because the logits are divided by it

    Returns:
        Generated token IDs, shape (1, seq_len)
    """
    model.eval()
    x = torch.full((1, seq_len), MASK_TOKEN, dtype=torch.long, device=device)

    for step in range(num_steps):
        is_masked = (x == MASK_TOKEN)
        remaining = int(is_masked.sum().item())
        if remaining == 0:
            # Everything is already revealed; further forward passes would
            # change nothing.
            break

        # Current noise level (decreasing from 1 towards 0)
        t = torch.tensor([[1.0 - step / num_steps]], device=device)

        # Model predicts all tokens
        logits = model(x, t)
        # The output head also has a logit for MASK itself. Sampling it would
        # "reveal" a position as still masked, so exclude it.
        logits[..., MASK_TOKEN] = -float('inf')
        probs = F.softmax(logits / temperature, dim=-1)

        # Sample one token per position from the predicted distribution
        predicted = torch.multinomial(
            probs.view(-1, probs.size(-1)), 1
        ).view(1, seq_len)

        # Confidence = max probability at each position
        confidence = probs.max(dim=-1).values

        # How many to unmask this step: spread what is left evenly over the
        # remaining steps, but always reveal at least one token.
        remaining_steps = max(1, num_steps - step)
        n_to_unmask = min(remaining, max(1, remaining // remaining_steps))

        # Unmask the most confident predictions among masked positions only
        masked_confidence = confidence.masked_fill(~is_masked, -float('inf'))
        _, top_idx = masked_confidence.view(-1).topk(n_to_unmask)
        x.view(-1)[top_idx] = predicted.view(-1)[top_idx]

    return x


# Generate some text!
print("Generated samples:")
print("=" * 65)
for i in range(5):
    gen = generate(model, num_steps=12)
    print(f"  {decode(gen[0].tolist())}")
print("=" * 65)

In [None]:
#@title üéß Listen: Visualization
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/04_05_visualization.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### 4.3 Visualizing the Generation Process

This is the centerpiece visualization — watching tokens appear step by step.

In [None]:
@torch.no_grad()
def generate_with_history(model, seq_len=SEQ_LEN, num_steps=10, temperature=0.8):
    """Generate text and record the state at every step.

    Runs the same confidence-ordered unmasking loop as ``generate`` while
    keeping a snapshot after each step for visualization.

    Args:
        model: Trained DiffusionLM, called as ``model(x, t)``
        seq_len: Length of sequence to generate
        num_steps: Maximum number of unmasking steps
        temperature: Sampling temperature; must be > 0

    Returns:
        ``(x, history, confidence_history)`` where ``x`` has shape
        (1, seq_len), ``history`` is a list of ``(tokens_on_cpu, label)``
        pairs starting with the all-masked state, and ``confidence_history``
        holds one per-position confidence tensor per executed step.
    """
    model.eval()
    x = torch.full((1, seq_len), MASK_TOKEN, dtype=torch.long, device=device)
    history = [(x[0].cpu().clone(), 'Start (all masked)')]
    confidence_history = []

    for step in range(num_steps):
        is_masked = (x == MASK_TOKEN)
        remaining = int(is_masked.sum().item())
        if remaining == 0:
            # Nothing left to reveal, so skip the forward pass as well.
            break

        t = torch.tensor([[1.0 - step / num_steps]], device=device)
        logits = model(x, t)
        # The output head also has a logit for MASK itself. Sampling it would
        # "reveal" a position as still masked, so exclude it.
        logits[..., MASK_TOKEN] = -float('inf')
        probs = F.softmax(logits / temperature, dim=-1)
        predicted = torch.multinomial(probs.view(-1, probs.size(-1)), 1).view(1, seq_len)
        confidence = probs.max(dim=-1).values

        # Spread the remaining masks evenly over the remaining steps,
        # revealing at least one token per step.
        n_to_unmask = min(remaining, max(1, remaining // max(1, num_steps - step)))

        # Only currently masked positions compete for being revealed.
        masked_conf = confidence.masked_fill(~is_masked, -float('inf'))
        _, top_idx = masked_conf.view(-1).topk(n_to_unmask)
        x.view(-1)[top_idx] = predicted.view(-1)[top_idx]

        history.append((x[0].cpu().clone(), f'Step {step+1}'))
        confidence_history.append(confidence[0].cpu().clone())

    return x, history, confidence_history

In [None]:
# üìä Step-by-step generation visualization
gen, history, conf_hist = generate_with_history(model, num_steps=10)

fig, axes = plt.subplots(len(history), 1, figsize=(18, len(history) * 0.9))

for ax, (seq, label) in zip(axes, history):
    tokens = seq.numpy()
    for pos in range(SEQ_LEN):
        if tokens[pos] == MASK_TOKEN:
            ax.add_patch(plt.Rectangle((pos, 0), 1, 1, color='#333333', alpha=0.85))
            ax.text(pos + 0.5, 0.5, '‚ñà', ha='center', va='center',
                    color='#666', fontsize=7)
        else:
            ax.add_patch(plt.Rectangle((pos, 0), 1, 1, color='#e8f5e9', alpha=0.9))
            char = id_to_char.get(tokens[pos], '?')
            ax.text(pos + 0.5, 0.5, char, ha='center', va='center',
                    fontsize=7, color='#1b5e20', fontweight='bold')
    ax.set_xlim(0, SEQ_LEN)
    ax.set_ylim(0, 1)
    ax.set_ylabel(label, fontsize=9, rotation=0, ha='right', va='center',
                  labelpad=80)
    ax.set_xticks([])
    ax.set_yticks([])

plt.suptitle('Iterative Unmasking: Text Emerging from Pure Masks',
             fontsize=14, y=1.02)
plt.tight_layout()
plt.show()
print("Dark = still masked. Green = revealed token.")
print("Notice: tokens appear in confidence order, not left-to-right!")

In [None]:
#@title üéß Listen: Temperature
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/04_06_temperature.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### 4.4 Temperature and Sampling

In [None]:
# üìä Effect of temperature on generation diversity
print("Temperature Comparison:")
print("=" * 65)
for temp in [0.3, 0.7, 1.0, 1.5]:
    print(f"\nTemperature = {temp}:")
    for _ in range(3):
        gen = generate(model, num_steps=12, temperature=temp)
        print(f"  {decode(gen[0].tolist())}")
print("=" * 65)
print("\nLow temperature ‚Üí repetitive but coherent")
print("High temperature ‚Üí diverse but potentially noisy")

In [None]:
#@title üéß Listen: Num Steps
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/04_07_num_steps.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### 4.5 Number of Steps

In [None]:
# üìä Quality vs number of unmasking steps
print("Effect of Number of Steps:")
print("=" * 65)
for n_steps in [1, 2, 5, 10, 20]:
    torch.manual_seed(42)
    gen = generate(model, num_steps=n_steps, temperature=0.7)
    text_out = decode(gen[0].tolist())
    print(f"  steps={n_steps:2d}: {text_out}")
print("=" * 65)
print("\n1 step = unmask everything at once (worst quality)")
print("More steps = iterative refinement (better quality)")
print("This is why diffusion LLMs are fast ‚Äî even a few steps work!")

In [None]:
#@title üéß Listen: Infilling
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/04_08_infilling.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### 4.6 Infilling — The Killer Feature

This is something autoregressive models **cannot** do naturally. We fix some tokens and let the model fill in the blanks.

In [None]:
@torch.no_grad()
def infill(model, template_str, num_steps=15, temperature=0.7):
    """Fill in the blank positions of a template while keeping the given text.

    A position is a blank when the template holds the mask glyph
    (``id_to_char[MASK_TOKEN]``) there. Characters that are not in the
    vocabulary cannot be represented as tokens, so they are treated as
    blanks too instead of being frozen as permanent masks. Templates shorter
    than ``SEQ_LEN`` are padded with blanks; longer ones are truncated.

    Args:
        model: Trained DiffusionLM, called as ``model(x, t)``
        template_str: Text with blanks to fill
        num_steps: Maximum number of unmasking steps
        temperature: Sampling temperature; must be > 0

    Returns:
        The filled-in string, cut to at most ``len(template_str)`` characters.
    """
    model.eval()
    template = template_str[:SEQ_LEN]
    mask_glyph = id_to_char[MASK_TOKEN]

    ids = [
        MASK_TOKEN if ch == mask_glyph else char_to_id.get(ch, MASK_TOKEN)
        for ch in template
    ]
    # Pad at the token level so the sequence is exactly SEQ_LEN long.
    ids.extend([MASK_TOKEN] * (SEQ_LEN - len(ids)))

    x = torch.tensor([ids], dtype=torch.long, device=device)
    # Every position that starts with a real token is fixed for the whole run.
    fixed = x != MASK_TOKEN

    for step in range(num_steps):
        is_masked = (x == MASK_TOKEN) & ~fixed
        remaining = int(is_masked.sum().item())
        if remaining == 0:
            break

        t = torch.tensor([[1.0 - step / num_steps]], device=device)
        logits = model(x, t)
        # The output head also has a logit for MASK itself. Sampling it would
        # leave a blank unfilled, so exclude it.
        logits[..., MASK_TOKEN] = -float('inf')
        probs = F.softmax(logits / temperature, dim=-1)
        predicted = torch.multinomial(probs.view(-1, probs.size(-1)), 1).view(1, SEQ_LEN)
        confidence = probs.max(dim=-1).values

        # Spread the remaining blanks evenly over the remaining steps,
        # revealing at least one token per step.
        n_to_unmask = min(remaining, max(1, remaining // max(1, num_steps - step)))

        # Only blanks compete for being revealed; fixed text is never touched.
        masked_conf = confidence.masked_fill(~is_masked, -float('inf'))
        _, top_idx = masked_conf.view(-1).topk(n_to_unmask)
        x.view(-1)[top_idx] = predicted.view(-1)[top_idx]

    return decode(x[0].tolist())[:len(template_str)]

In [None]:
# üìä Infilling demonstrations
print("Infilling Demonstrations:")
print("=" * 65)

templates = [
    "To be‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñàthat is the question",
    "ROMEO:‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà",
    "‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñàlight‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà",
    "First Citizen:‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà",
]

for template in templates:
    result = infill(model, template, num_steps=20, temperature=0.7)
    print(f"  Template:  {template[:60]}")
    print(f"  Filled:    {result[:60]}")
    print()

print("The model fills in the blanks using BIDIRECTIONAL context!")
print("It sees text on BOTH sides of each blank ‚Äî AR models cannot do this.")

In [None]:
#@title üéß Listen: Lr Comparison
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/04_09_lr_comparison.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### 4.7 Comparison with Left-to-Right Generation

In [None]:
@torch.no_grad()
def generate_left_to_right(model, seq_len=SEQ_LEN, temperature=0.8):
    """Generate by unmasking strictly left-to-right.

    Baseline for comparison with confidence-ordered generation: one forward
    pass per position, each time sampling only the token at ``pos``.

    Args:
        model: Trained DiffusionLM, called as ``model(x, t)``
        seq_len: Length of sequence to generate
        temperature: Sampling temperature; must be > 0

    Returns:
        Generated token IDs, shape (1, seq_len)
    """
    model.eval()
    x = torch.full((1, seq_len), MASK_TOKEN, dtype=torch.long, device=device)

    for pos in range(seq_len):
        # Noise level falls linearly as more of the sequence is revealed.
        t = torch.tensor([[1.0 - pos / seq_len]], device=device)
        pos_logits = model(x, t)[0, pos]
        # The output head also has a logit for MASK itself. Sampling it would
        # leave this position masked for good, since the loop never revisits it.
        pos_logits[MASK_TOKEN] = -float('inf')
        probs = F.softmax(pos_logits / temperature, dim=-1)
        x[0, pos] = torch.multinomial(probs, 1)[0]

    return x


print("Generation Comparison:")
print("=" * 65)
print("Confidence-order (diffusion):")
for _ in range(3):
    gen = generate(model, num_steps=12, temperature=0.7)
    print(f"  {decode(gen[0].tolist())}")

print("\nLeft-to-right (forced):")
for _ in range(3):
    gen = generate_left_to_right(model, temperature=0.7)
    print(f"  {decode(gen[0].tolist())}")

print("=" * 65)
print("Diffusion can use bidirectional context; L‚ÜíR generation cannot.")

In [None]:
#@title üéß Listen: Todo Topp
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/04_10_todo_topp.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 5. 🔧 Your Turn

### TODO 1: Implement Top-p (Nucleus) Sampling

In [None]:
def top_p_sample(probs, p=0.9):
    """Sample from the top-p (nucleus) of the distribution.

    Args:
        probs: Probability distribution, shape (V,)
        p: Cumulative probability threshold

    Returns:
        Sampled token ID
    """
    # ============ TODO ============
    # Step 1: Sort probabilities descending
    sorted_probs, sorted_indices = ???  # YOUR CODE

    # Step 2: Compute cumulative sum
    cumulative = ???  # YOUR CODE

    # Step 3: Create mask ‚Äî keep tokens up to cumulative prob p
    mask = cumulative - sorted_probs > p  # True for tokens past threshold

    # Step 4: Zero out and renormalize
    sorted_probs[mask] = 0.0
    sorted_probs = sorted_probs / sorted_probs.sum()

    # Step 5: Sample
    sampled_idx = torch.multinomial(sorted_probs, 1)
    token = sorted_indices[sampled_idx]
    # ==============================
    return token.item()

In [None]:
# ‚úÖ Verification
try:
    test_probs = torch.tensor([0.4, 0.3, 0.15, 0.1, 0.05])
    samples = [top_p_sample(test_probs, p=0.9) for _ in range(100)]
    from collections import Counter
    counts = Counter(samples)
    print("‚úÖ Top-p sampling works!")
    print(f"   Distribution: {dict(sorted(counts.items()))}")
    print(f"   Token 4 (prob=0.05) appeared {counts.get(4, 0)} times")
except NameError:
    print("‚ùå Replace the ??? placeholders.")

In [None]:
#@title üéß Listen: Todo Prompted
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/04_11_todo_prompted.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### TODO 2: Implement Prompted Generation

In [None]:
@torch.no_grad()
def generate_with_prompt(model, prompt_str, total_len=SEQ_LEN,
                          num_steps=15, temperature=0.7):
    """Generate text that continues from a given prompt.

    Args:
        model: Trained DiffusionLM
        prompt_str: Starting text
        total_len: Total output length
        num_steps: Unmasking steps
        temperature: Sampling temperature

    Returns:
        Generated string (prompt + completion)
    """
    model.eval()

    # ============ TODO ============
    # Step 1: Encode the prompt
    prompt_ids = ???  # YOUR CODE: encode(prompt_str)

    # Step 2: Create sequence ‚Äî prompt tokens + MASK for the rest
    x = torch.full((1, total_len), MASK_TOKEN, dtype=torch.long, device=device)
    prompt_len = min(len(prompt_ids), total_len)
    x[0, :prompt_len] = ???  # YOUR CODE: fill in prompt

    # Step 3: Create fixed mask ‚Äî True for prompt positions
    fixed = torch.zeros(1, total_len, dtype=torch.bool, device=device)
    fixed[0, :prompt_len] = ???  # YOUR CODE: True

    # Step 4: Iterative unmasking (skip fixed positions)
    for step in range(num_steps):
        t = torch.tensor([[1.0 - step / num_steps]], device=device)
        logits = model(x, t)
        probs = F.softmax(logits / temperature, dim=-1)
        predicted = torch.multinomial(probs.view(-1, VOCAB_SIZE), 1).view(1, total_len)
        confidence = probs.max(dim=-1).values

        is_masked = (x == MASK_TOKEN) & ~fixed
        remaining = is_masked.sum().item()
        if remaining == 0: break
        n_to_unmask = max(1, int(remaining / max(1, num_steps - step)))

        masked_conf = confidence.clone()
        masked_conf[~is_masked] = -float('inf')
        _, top_idx = masked_conf.view(-1).topk(min(n_to_unmask, remaining))
        x.view(-1)[top_idx] = predicted.view(-1)[top_idx]
    # ==============================

    return decode(x[0].tolist())

In [None]:
# ‚úÖ Verification
try:
    prompts = ["ROMEO:\n", "To be, ", "First C"]
    for p in prompts:
        result = generate_with_prompt(model, p, num_steps=15, temperature=0.7)
        print(f"  Prompt: '{p}'")
        print(f"  Output: '{result[:60]}'")
        print()
    print("‚úÖ Prompted generation works!")
except NameError:
    print("‚ùå Replace the ??? placeholders.")

In [None]:
#@title üéß Listen: Grand Viz
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/04_12_grand_viz.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 6. Putting It All Together

In [None]:
# Generate a batch of text samples
print("Final Generation Showcase:")
print("=" * 65)
for i in range(8):
    gen = generate(model, num_steps=15, temperature=0.7)
    text_out = decode(gen[0].tolist())
    print(f"  Sample {i+1}: {text_out}")
print("=" * 65)

## 7. 🎯 Final Output

In [None]:
# üìä The Grand Visualization: step-by-step unmasking with highlights
torch.manual_seed(7)
gen_final, history_final, _ = generate_with_history(model, num_steps=12, temperature=0.7)

fig, axes = plt.subplots(len(history_final), 1, figsize=(20, len(history_final) * 1.0))

prev_seq = history_final[0][0].numpy()

for idx, (ax, (seq, label)) in enumerate(zip(axes, history_final)):
    tokens = seq.numpy()
    for pos in range(SEQ_LEN):
        if tokens[pos] == MASK_TOKEN:
            ax.add_patch(plt.Rectangle((pos, 0), 1, 1, color='#424242'))
            ax.text(pos + 0.5, 0.5, '‚ñà', ha='center', va='center',
                    color='#616161', fontsize=8)
        elif idx > 0 and prev_seq[pos] == MASK_TOKEN:
            # NEWLY revealed ‚Äî highlight in yellow!
            ax.add_patch(plt.Rectangle((pos, 0), 1, 1, color='#ffeb3b', alpha=0.9))
            char = id_to_char.get(tokens[pos], '?')
            ax.text(pos + 0.5, 0.5, char, ha='center', va='center',
                    fontsize=8, fontweight='bold', color='#e65100')
        else:
            # Previously revealed
            ax.add_patch(plt.Rectangle((pos, 0), 1, 1, color='#c8e6c9'))
            char = id_to_char.get(tokens[pos], '?')
            ax.text(pos + 0.5, 0.5, char, ha='center', va='center',
                    fontsize=8, color='#2e7d32')

    ax.set_xlim(0, SEQ_LEN)
    ax.set_ylim(0, 1)
    ax.set_ylabel(label, fontsize=9, rotation=0, ha='right', va='center', labelpad=100)
    ax.set_xticks([])
    ax.set_yticks([])
    prev_seq = tokens.copy()

plt.suptitle('üéØ Text Generation via Iterative Unmasking\n'
             '(dark=masked, yellow=newly revealed, green=previously revealed)',
             fontsize=13, y=1.04)
plt.tight_layout()
plt.show()

In [None]:
#@title üéß Listen: Congrats Recap
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/04_13_congrats_recap.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

In [None]:
print("üéâ Congratulations! You've Built a Complete Diffusion LLM from Scratch!")
print("=" * 65)
print()
print("Over these 4 notebooks, you built:")
print("  1. Image diffusion foundations (forward + reverse process)")
print("  2. Masked diffusion for text (masking = noise for tokens)")
print("  3. A trained bidirectional Transformer on Shakespeare")
print("  4. Generation via iterative unmasking (this notebook!)")
print()
print("Key insights:")
print("  ‚Ä¢ Masking replaces Gaussian noise for discrete tokens")
print("  ‚Ä¢ Training = BERT at all masking ratios (ELBO confirms it)")
print("  ‚Ä¢ Generation = iterative unmasking by confidence")
print("  ‚Ä¢ Bidirectional context ‚Üí better than left-to-right")
print("  ‚Ä¢ Infilling is natural (impossible for autoregressive models)")
print()
print("This exact architecture, scaled to 8B parameters, gives LLaDA ‚Äî")
print("competitive with LLaMA 3. Mercury generates 1,000+ tok/s.")
print()
print("The era of one-token-at-a-time may be coming to an end.")

In [None]:
#@title üéß Listen: Bigger Picture
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/04_14_bigger_picture.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 8. The Bigger Picture

Our tiny character-level model is conceptually identical to the largest diffusion LLMs. The architecture, training, and generation are the same — only the scale differs:

| | Our Model | LLaDA | Mercury |
|---|---|---|---|
| Parameters | ~500K | 8B | Undisclosed |
| Vocabulary | ~66 chars | 32K tokens | 32K+ tokens |
| Training data | 1MB Shakespeare | Trillions of tokens | Trillions |
| Speed | Instant | Fast | 1,000+ tok/s |

### Why Diffusion LLMs Might Be the Future

1. **Speed** — Parallel prediction means 5-10x faster generation
2. **Bidirectional context** — Every token sees past AND future
3. **Natural infilling** — Just mask the middle, no special tricks
4. **Error correction** — Iterative refinement lets the model "change its mind"

### Open Challenges

- Variable-length generation (how long should the output be?)
- Sequential reasoning (counting, arithmetic)
- Training compute (currently needs more than AR models of same size)

In [None]:
#@title üéß Listen: Closing
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/04_15_closing.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 9. Reflection and Next Steps

### 🤔 Reflection Questions

1. **Why does confidence-order generation work better than left-to-right?** Easy tokens settle first, creating scaffolding for harder predictions. Bidirectional context means the ending informs the beginning.

2. **What is the tradeoff between steps and quality?** More steps = more refinement = higher quality, but with diminishing returns. Even 5-10 steps work — this is why diffusion LLMs are fast.

3. **How would you generate variable-length text?** Include an [END] token and let the model place it. Or generate to max length and truncate.

4. **Could you combine autoregressive and diffusion?** Yes — generate blocks via diffusion, process blocks left-to-right. Best of both worlds.

### 🏆 Optional Challenges

1. Implement remasking — occasionally re-mask some tokens and re-predict
2. Try word-level (BPE) tokenization
3. Scale up to a larger model and dataset
4. Implement classifier-free guidance for conditional generation
5. Compare generation quality between AR and diffusion approaches

---

**🎉 You have completed the Diffusion LLMs from Scratch series!**

You now understand how diffusion models work for language — from the math through training to generation. These ideas power commercial systems generating 1,000+ tokens per second. The future of language generation may not be one token at a time.

In [None]:
#@title üí¨ AI Teaching Assistant ‚Äî Click ‚ñ∂ to start
#@markdown This AI chatbot reads your notebook and can answer questions about any concept, code, or exercise.

import json as _json
import requests as _requests
from google.colab import output as _output
from IPython.display import display, HTML as _HTML, Markdown as _Markdown

# --- Read notebook content for context ---
def _get_notebook_context():
    """Return the notebook's cell sources joined into one context string.

    Requests the current notebook through Colab's ``get_ipynb`` message,
    skips cells tagged ``chatbot`` and cells with blank source, and prefixes
    every remaining cell with its upper-cased cell type. Any failure
    (including running outside Colab) yields a fixed fallback string instead
    of raising.
    """
    try:
        from google.colab import _message
        response = _message.blocking_request("get_ipynb", request="", timeout_sec=10)
        sections = []
        for cell in response.get("ipynb", {}).get("cells", []):
            source = "".join(cell.get("source", []))
            if "chatbot" in cell.get("metadata", {}).get("tags", []):
                continue
            if not source.strip():
                continue
            kind = cell.get("cell_type", "unknown")
            sections.append(f"[{kind.upper()}]\n{source}")
        return "\n\n---\n\n".join(sections)
    except Exception:
        # Best effort: the assistant still works without notebook context.
        return "Notebook content unavailable."

_NOTEBOOK_CONTEXT = _get_notebook_context()
_CHAT_HISTORY = []
_API_URL = "https://course-creator-brown.vercel.app/api/chat"

def _notebook_chat(question):
    global _CHAT_HISTORY
    try:
        resp = _requests.post(_API_URL, json={
            'question': question,
            'context': _NOTEBOOK_CONTEXT[:100000],
            'history': _CHAT_HISTORY[-10:],
        }, timeout=60)
        data = resp.json()
        answer = data.get('answer', 'Sorry, I could not generate a response.')
        _CHAT_HISTORY.append({'role': 'user', 'content': question})
        _CHAT_HISTORY.append({'role': 'assistant', 'content': answer})
        return answer
    except Exception as e:
        return f'Error connecting to teaching assistant: {str(e)}'

# Make _notebook_chat callable from the cell output's JavaScript under the name
# 'notebook_chat'; the chat widget below invokes it through
# google.colab.kernel.invokeFunction('notebook_chat', [question], {}).
_output.register_callback('notebook_chat', _notebook_chat)

def ask(question):
    """Put ``question`` to the teaching assistant and render its reply as Markdown.

    The reply comes from ``_notebook_chat`` and is shown in the calling cell's
    output; nothing is returned.
    """
    display(_Markdown(_notebook_chat(question)))

# Two-line readiness banner, emitted as a single print call.
print(
    "\u2705 AI Teaching Assistant is ready!",
    "\U0001f4a1 Use the chat below, or call ask(\'your question\') in any cell.",
    sep="\n",
)

# --- Display chat widget ---
# Renders the chat UI (CSS + HTML + JavaScript) in this cell's output. The
# JavaScript sends each question to the kernel callback registered as
# 'notebook_chat' and renders the reply with a small Markdown subset.
#
# The literal is a RAW string on purpose: the script contains regex escapes
# (\n, \w, \s, \d, \*, \.). In a normal string Python would turn "\n" into a
# real line break inside the JS regex literals, which is a JavaScript syntax
# error that stops the whole script from loading.
display(_HTML(r'''<style>
  .vc-wrap{font-family:-apple-system,BlinkMacSystemFont,'Segoe UI',Roboto,sans-serif;max-width:100%;border-radius:16px;overflow:hidden;box-shadow:0 4px 24px rgba(0,0,0,.12);background:#fff;border:1px solid #e5e7eb}
  .vc-hdr{background:linear-gradient(135deg,#667eea 0%,#764ba2 100%);color:#fff;padding:16px 20px;display:flex;align-items:center;gap:12px}
  .vc-avatar{width:42px;height:42px;background:rgba(255,255,255,.2);border-radius:50%;display:flex;align-items:center;justify-content:center;font-size:22px}
  .vc-hdr h3{font-size:16px;font-weight:600;margin:0}
  .vc-hdr p{font-size:12px;opacity:.85;margin:2px 0 0}
  .vc-msgs{height:420px;overflow-y:auto;padding:16px;background:#f8f9fb;display:flex;flex-direction:column;gap:10px}
  .vc-msg{display:flex;flex-direction:column;animation:vc-fade .25s ease}
  .vc-msg.user{align-items:flex-end}
  .vc-msg.bot{align-items:flex-start}
  .vc-bbl{max-width:85%;padding:10px 14px;border-radius:16px;font-size:14px;line-height:1.55;word-wrap:break-word}
  .vc-msg.user .vc-bbl{background:linear-gradient(135deg,#667eea 0%,#764ba2 100%);color:#fff;border-bottom-right-radius:4px}
  .vc-msg.bot .vc-bbl{background:#fff;color:#1a1a2e;border:1px solid #e8e8e8;border-bottom-left-radius:4px}
  .vc-bbl code{background:rgba(0,0,0,.07);padding:2px 6px;border-radius:4px;font-size:13px;font-family:'Fira Code',monospace}
  .vc-bbl pre{background:#1e1e2e;color:#cdd6f4;padding:12px;border-radius:8px;overflow-x:auto;margin:8px 0;font-size:13px}
  .vc-bbl pre code{background:none;padding:0;color:inherit}
  .vc-bbl h3,.vc-bbl h4{margin:10px 0 4px;font-size:15px}
  .vc-bbl ul,.vc-bbl ol{margin:4px 0;padding-left:20px}
  .vc-bbl li{margin:2px 0}
  .vc-chips{display:flex;flex-wrap:wrap;gap:8px;padding:0 16px 12px;background:#f8f9fb}
  .vc-chip{background:#fff;border:1px solid #d1d5db;border-radius:20px;padding:6px 14px;font-size:12px;cursor:pointer;transition:all .15s;color:#4b5563}
  .vc-chip:hover{border-color:#667eea;color:#667eea;background:#f0f0ff}
  .vc-input{display:flex;padding:12px 16px;background:#fff;border-top:1px solid #eee;gap:8px}
  .vc-input input{flex:1;padding:10px 16px;border:2px solid #e8e8e8;border-radius:24px;font-size:14px;outline:none;transition:border-color .2s}
  .vc-input input:focus{border-color:#667eea}
  .vc-input button{background:linear-gradient(135deg,#667eea 0%,#764ba2 100%);color:#fff;border:none;border-radius:50%;width:42px;height:42px;cursor:pointer;display:flex;align-items:center;justify-content:center;font-size:18px;transition:transform .1s}
  .vc-input button:hover{transform:scale(1.05)}
  .vc-input button:disabled{opacity:.5;cursor:not-allowed;transform:none}
  .vc-typing{display:flex;gap:5px;padding:4px 0}
  .vc-typing span{width:8px;height:8px;background:#667eea;border-radius:50%;animation:vc-bounce 1.4s infinite ease-in-out}
  .vc-typing span:nth-child(2){animation-delay:.2s}
  .vc-typing span:nth-child(3){animation-delay:.4s}
  @keyframes vc-bounce{0%,80%,100%{transform:scale(0)}40%{transform:scale(1)}}
  @keyframes vc-fade{from{opacity:0;transform:translateY(8px)}to{opacity:1;transform:translateY(0)}}
  .vc-note{text-align:center;font-size:11px;color:#9ca3af;padding:8px 16px 12px;background:#fff}
</style>
<div class="vc-wrap">
  <div class="vc-hdr">
    <div class="vc-avatar">&#129302;</div>
    <div>
      <h3>Vizuara Teaching Assistant</h3>
      <p>Ask me anything about this notebook</p>
    </div>
  </div>
  <div class="vc-msgs" id="vcMsgs">
    <div class="vc-msg bot">
      <div class="vc-bbl">&#128075; Hi! I've read through this entire notebook. Ask me about any concept, code block, or exercise &mdash; I'm here to help you learn!</div>
    </div>
  </div>
  <div class="vc-chips" id="vcChips">
    <span class="vc-chip" onclick="vcAsk(this.textContent)">Explain the main concept</span>
    <span class="vc-chip" onclick="vcAsk(this.textContent)">Help with the TODO exercise</span>
    <span class="vc-chip" onclick="vcAsk(this.textContent)">Summarize what I learned</span>
  </div>
  <div class="vc-input">
    <input type="text" id="vcIn" placeholder="Ask about concepts, code, exercises..." />
    <button id="vcSend" onclick="vcSendMsg()">&#10148;</button>
  </div>
  <div class="vc-note">AI-generated &middot; Verify important information &middot; <a href="#" onclick="vcClear();return false" style="color:#667eea">Clear chat</a></div>
</div>
<script>
(function(){
  var msgs=document.getElementById('vcMsgs'),inp=document.getElementById('vcIn'),
      btn=document.getElementById('vcSend'),chips=document.getElementById('vcChips');

  function esc(s){var d=document.createElement('div');d.textContent=s;return d.innerHTML}

  // Escape the whole answer first so any markup in it is displayed as text,
  // then apply the Markdown subset to the escaped string.
  function md(t){
    return esc(t)
      .replace(/```(\w*)\n([\s\S]*?)```/g,function(_,l,c){return '<pre><code>'+c+'</code></pre>'})
      .replace(/`([^`]+)`/g,'<code>$1</code>')
      .replace(/\*\*([^*]+)\*\*/g,'<strong>$1</strong>')
      .replace(/\*([^*]+)\*/g,'<em>$1</em>')
      .replace(/^#### (.+)$/gm,'<h4>$1</h4>')
      .replace(/^### (.+)$/gm,'<h4>$1</h4>')
      .replace(/^## (.+)$/gm,'<h3>$1</h3>')
      .replace(/^\d+\. (.+)$/gm,'<li>$1</li>')
      .replace(/^- (.+)$/gm,'<li>$1</li>')
      .replace(/\n\n/g,'<br><br>')
      .replace(/\n/g,'<br>');
  }

  function addMsg(text,isUser){
    var m=document.createElement('div');m.className='vc-msg '+(isUser?'user':'bot');
    var b=document.createElement('div');b.className='vc-bbl';
    b.innerHTML=isUser?esc(text):md(text);
    m.appendChild(b);msgs.appendChild(m);msgs.scrollTop=msgs.scrollHeight;
  }

  function showTyping(){
    var m=document.createElement('div');m.className='vc-msg bot';m.id='vcTyping';
    m.innerHTML='<div class="vc-bbl"><div class="vc-typing"><span></span><span></span><span></span></div></div>';
    msgs.appendChild(m);msgs.scrollTop=msgs.scrollHeight;
  }

  function hideTyping(){var e=document.getElementById('vcTyping');if(e)e.remove()}

  window.vcSendMsg=function(){
    var q=inp.value.trim();if(!q)return;
    inp.value='';chips.style.display='none';
    addMsg(q,true);showTyping();btn.disabled=true;
    google.colab.kernel.invokeFunction('notebook_chat',[q],{})
      .then(function(r){
        hideTyping();
        var a=r.data['application/json'];
        addMsg(typeof a==='string'?a:JSON.stringify(a),false);
      })
      .catch(function(){
        hideTyping();
        addMsg('Sorry, I encountered an error. Please check your internet connection and try again.',false);
      })
      .finally(function(){btn.disabled=false;inp.focus()});
  };

  window.vcAsk=function(q){inp.value=q;vcSendMsg()};
  window.vcClear=function(){
    msgs.innerHTML='<div class="vc-msg bot"><div class="vc-bbl">&#128075; Chat cleared. Ask me anything!</div></div>';
    chips.style.display='flex';
  };

  inp.addEventListener('keypress',function(e){if(e.key==='Enter')vcSendMsg()});
  inp.focus();
})();
</script>'''))