In [None]:
#@title 🎧 Download Narration Audio & Play Introduction
import os as _os
if not _os.path.exists("/content/narration"):
    !pip install -q gdown
    import gdown
    gdown.download(id="1Szlw2uIb0f2EpTdm_4p4cFRzB3zgMemA", output="/content/narration.zip", quiet=False)
    !unzip -q /content/narration.zip -d /content/narration
    !rm /content/narration.zip
    print(f"Loaded {len(_os.listdir('/content/narration'))} narration segments")
else:
    print("Narration audio already loaded.")

from IPython.display import Audio, display
display(Audio("/content/narration/02_00_intro.mp3"))

In [None]:
# 🔧 Setup: Run this cell first!
# Check GPU availability and install dependencies

import torch
import sys

# Check GPU
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"✅ GPU available: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    device = torch.device('cpu')
    print("⚠️ No GPU detected. Some cells may run slowly.")
    print("   Go to Runtime → Change runtime type → GPU")

print(f"\n📦 Python {sys.version.split()[0]}")
print(f"🔥 PyTorch {torch.__version__}")

# Set random seeds for reproducibility
import random
import numpy as np

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print(f"🎲 Random seed set to {SEED}")

%matplotlib inline

# 🚀 Masked Diffusion for Text: The Forward Process

*Part 2 of the Vizuara series on Diffusion LLMs from Scratch*
*Estimated time: 35 minutes*

## 1. Why Does This Matter?

In Notebook 1, we built a working image diffusion model. But we hit a wall: **Gaussian noise is meaningless for discrete tokens.** You cannot add 0.3 units of noise to the word "cat."

In this notebook, we solve that problem with a beautifully simple idea: **replace Gaussian noise with masking.** Instead of corrupting images with static, we corrupt text by replacing tokens with [MASK].

This turns out to be BERT's masked language modeling objective, generalized across all masking ratios, and it gives us a fully valid diffusion model for text.

**By the end of this notebook, you will:**
- Implement the masked forward process for text
- Build and train a bidirectional Transformer that predicts masked tokens
- See how this is mathematically connected to BERT
- Visualize the model's predictions at different masking levels

In [None]:
#@title 🎧 Listen: Intuition
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/02_01_intuition.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 2. Building Intuition

### Masking = Erasing Words from a Page

In image diffusion, the forward process *destroys* an image by adding noise. After enough noise, you have pure static: no information remains.

For text, the equivalent of "destroying" a sentence is **masking out the words.** If every word is masked, no information remains, just like pure noise for an image.

| Timestep $t$ | Masking Probability | What the Model Sees |
|---|---|---|
| $t = 0.0$ | 0% masked | The cat sat on the mat |
| $t = 0.2$ | 20% masked | The cat sat [M] the mat |
| $t = 0.5$ | 50% masked | [M] cat [M] on [M] mat |
| $t = 0.8$ | 80% masked | [M] [M] [M] on [M] [M] |
| $t = 1.0$ | 100% masked | [M] [M] [M] [M] [M] [M] |

### The BERT Connection

If you have studied BERT, this should look very familiar. BERT is trained by hiding about 15% of the tokens (most of them replaced with [MASK]) and predicting what they should be. Diffusion LLMs use the same prediction task, but instead of always masking 15%, they train with **every possible masking ratio from 0% to 100%.**

This single change transforms BERT from a language *understanding* model into a full *generative* model.

### 🤔 Think About This

If you saw "[M] cat [M] on [M] mat", what would you guess the masked words are? What clues helped you? Notice that you use **both left and right context**: the first word has nothing to its left at all, yet you can guess it is probably "The" because of "cat" to its right.

In [None]:
#@title 🎧 Listen: Math
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/02_02_math.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 3. The Mathematics

### The Forward Process

At timestep $t$, each token is independently masked with probability $t$:

$$q(x_t^{(i)} \mid x_0^{(i)}) = \begin{cases} x_0^{(i)} & \text{with probability } 1 - t \\ [\text{MASK}] & \text{with probability } t \end{cases}$$

**What this says computationally:** For each token, flip a coin with bias $t$. Heads → replace with [MASK]. Tails → keep the original token. Each token is treated independently.
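
The coin-flip rule can be sketched without any tensor library. The snippet below is a minimal illustration, not the notebook's implementation: the helper name `mask_sequence`, the `"[MASK]"` string, and the fixed seed are assumptions made for the example.

```python
import random

MASK = "[MASK]"  # placeholder string chosen for this illustration


def mask_sequence(tokens, t, rng):
    """Independently replace each token with MASK with probability t."""
    return [MASK if rng.random() < t else tok for tok in tokens]


rng = random.Random(0)  # fixed seed so reruns give the same draw
sentence = "The cat sat on the mat".split()
for t in (0.0, 0.5, 1.0):
    print(t, mask_sequence(sentence, t, rng))
```

At $t = 0$ nothing is masked and at $t = 1$ everything is, because `rng.random()` returns values in $[0, 1)$.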

### Worked Example

Suppose we have 10 tokens and $t = 0.4$. Each token has a 40% chance of being masked. The expected number of masked tokens is $10 \times 0.4 = 4$. On average, 4 out of 10 tokens are replaced with [MASK].
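
Because each token is masked independently, the number of masked tokens follows a Binomial($n$, $t$) distribution, so "4 on average" still leaves noticeable spread around 4. The sketch below computes the mean, the standard deviation, and two exact probabilities for the numbers in this example; `binom_pmf` is a helper name introduced here for illustration.

```python
import math

n, t = 10, 0.4  # values from the worked example


def binom_pmf(k, n, p):
    """P(K = k) for K ~ Binomial(n, p)."""
    return math.comb(n, k) * p**k * (1 - p) ** (n - k)


mean = n * t
std = math.sqrt(n * t * (1 - t))
p_exactly_4 = binom_pmf(4, n, t)
p_none = binom_pmf(0, n, t)

print(f"expected masked: {mean:.1f}, std: {std:.2f}")
print(f"P(exactly 4 masked) = {p_exactly_4:.3f}")
print(f"P(no token masked)  = {p_none:.4f}")
```

Only about one draw in four masks exactly 4 tokens; the rest mask a few more or a few fewer.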

### The Training Objective

The loss is cross-entropy on the masked positions, the same form as BERT's masked-language-modeling loss:

$$\mathcal{L}(\theta) = \mathbb{E}_{t \sim U(0,1)} \, \mathbb{E}_{x_t \sim q(x_t \mid x_0)} \left[ -\sum_{i : x_t^{(i)} = [\text{MASK}]} \log p_\theta\!\left(x_0^{(i)} \mid x_t\right) \right]$$

Breaking this down term by term:
- $t \sim U(0,1)$: randomly sample a masking ratio between 0% and 100%
- $x_t \sim q(x_t \mid x_0)$: create a corrupted version by masking tokens with probability $t$
- The sum runs over all masked positions $i$
- $\log p_\theta(x_0^{(i)} \mid x_t)$: log probability the model assigns to the correct token
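
One way to read the objective is as a recipe for a Monte Carlo estimate: sample $t$, sample a mask, and add up the negative log-probabilities at the masked positions. The sketch below follows that recipe with a stand-in "model" that guesses uniformly over $V$ tokens, so the expected loss has a closed form, $\tfrac{1}{2} L \log V$, to compare against. The names `uniform_log_prob` and `sample_loss` and the values of `V` and `L` are assumptions for illustration only; no trained network is involved.

```python
import math
import random

V = 15   # number of non-mask token values (hypothetical, for illustration)
L = 32   # sequence length
rng = random.Random(0)


def uniform_log_prob(_x_t, _position):
    """Stand-in for log p_theta: a guesser that is uniform over V tokens."""
    return math.log(1.0 / V)


def sample_loss(x_0):
    t = rng.random()                                # t ~ U(0, 1)
    masked = [rng.random() < t for _ in x_0]        # x_t ~ q(x_t | x_0)
    x_t = [None if m else tok for tok, m in zip(x_0, masked)]
    # Negative log-likelihood summed over masked positions only
    return -sum(uniform_log_prob(x_t, i) for i, m in enumerate(masked) if m)


x_0 = [rng.randrange(1, V + 1) for _ in range(L)]
n_samples = 20000
estimate = sum(sample_loss(x_0) for _ in range(n_samples)) / n_samples
closed_form = 0.5 * L * math.log(V)   # E[t] * L * log V for the uniform guesser
print(f"Monte Carlo: {estimate:.2f}   closed form: {closed_form:.2f}")
```

A trained model should land well below this uniform-guess baseline, since it assigns more than $1/V$ probability to the correct token.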

### Worked Loss Example

Sentence: "The cat sat" (3 tokens). Sample $t = 0.67$, and suppose this draw masks 2 of the 3 tokens: "[M] cat [M]".

Model predicts:
- Position 1: P("The") = 0.6
- Position 3: P("sat") = 0.8

Loss: $\mathcal{L} = -\log(0.6) - \log(0.8) = 0.511 + 0.223 = 0.734$

As training progresses, these probabilities increase and the loss decreases.
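
The arithmetic can be checked in a few lines. The sketch below recomputes the two terms and also reports the per-token average, which is a different normalization of the same quantity; the variable names are chosen for this example.

```python
import math

probs_of_correct_token = [0.6, 0.8]  # P("The"), P("sat") from the example

nll_terms = [-math.log(p) for p in probs_of_correct_token]
loss_sum = sum(nll_terms)                 # the sum that appears in the formula
loss_mean = loss_sum / len(nll_terms)     # per-masked-token average

print([round(x, 3) for x in nll_terms])   # [0.511, 0.223]
print(round(loss_sum, 3), round(loss_mean, 3))   # 0.734 0.367
```

The formula sums over masked positions, giving 0.734. PyTorch's `F.cross_entropy` averages over its inputs by default (`reduction='mean'`), so a loss computed that way over these two positions would read 0.367 instead.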

In [None]:
#@title 🎧 Listen: Setup Data
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/02_03_setup_data.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 4. Let's Build It: Component by Component

### 4.1 Setup and Dataset

We will use a character-level dataset with simple repeating patterns. This is small enough to train in minutes but rich enough to show the model learning real structure.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import math

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

torch.manual_seed(42)
np.random.seed(42)

# --- Configuration ---
VOCAB_SIZE = 16       # Token IDs 0-15, where 0 is [MASK]
SEQ_LEN = 32          # Sequence length
MASK_TOKEN = 0        # Token ID for [MASK]
BATCH_SIZE = 64
D_MODEL = 64
N_HEADS = 4
N_LAYERS = 3

%matplotlib inline

In [None]:
def generate_pattern_data(batch_size, seq_len, vocab_size):
    """Generate sequences with repeating patterns.

    Each sequence picks a random short pattern (length 2-4) and
    tiles it to fill the sequence. The model must learn these patterns.
    """
    sequences = []
    for _ in range(batch_size):
        pattern_len = np.random.randint(2, 5)
        # Use tokens 1 to vocab_size-1 (avoid 0 = MASK)
        pattern = np.random.randint(1, vocab_size, size=pattern_len)
        seq = np.tile(pattern, seq_len // pattern_len + 1)[:seq_len]
        sequences.append(seq)
    return torch.tensor(np.array(sequences), dtype=torch.long, device=device)

# Show some examples
examples = generate_pattern_data(5, SEQ_LEN, VOCAB_SIZE)
for i, seq in enumerate(examples):
    print(f"Pattern {i+1}: {seq.tolist()}")

In [None]:
#@title 🎧 Listen: Masking Function
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/02_04_masking_function.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### 4.2 The Forward Process (Masking)

This is the heart of masked diffusion: randomly masking each token with probability $t$.

In [None]:
def mask_tokens(x_0, t):
    """Apply the forward masking process.

    Args:
        x_0: Clean token sequences, shape (B, L)
        t: Masking probability for each sample, shape (B, 1)

    Returns:
        x_t: Masked sequences, shape (B, L)
        mask: Boolean mask showing which positions were masked
    """
    # For each token, independently mask with probability t
    random_vals = torch.rand_like(x_0.float())
    mask = random_vals < t                  # True where masked
    x_t = x_0.clone()
    x_t[mask] = MASK_TOKEN
    return x_t, mask

In [None]:
# 📊 Visualize the forward process at different timesteps
fig, axes = plt.subplots(1, 6, figsize=(20, 2.5))
sample = generate_pattern_data(1, SEQ_LEN, VOCAB_SIZE)

for ax, t_val in zip(axes, [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]):
    t = torch.tensor([[t_val]], device=device)
    masked, _ = mask_tokens(sample, t)
    tokens = masked[0].cpu().numpy()  # named `tokens` so IPython's display() is not shadowed

    for pos in range(SEQ_LEN):
        if tokens[pos] == MASK_TOKEN:
            ax.add_patch(plt.Rectangle((pos, 0), 1, 1, color='#333333'))
            ax.text(pos + 0.5, 0.5, 'M', ha='center', va='center',
                    color='white', fontsize=7, fontweight='bold')
        else:
            color = plt.cm.Set2(tokens[pos] / VOCAB_SIZE)
            ax.add_patch(plt.Rectangle((pos, 0), 1, 1, color=color))
            ax.text(pos + 0.5, 0.5, str(tokens[pos]), ha='center',
                    va='center', fontsize=7)

    ax.set_xlim(0, SEQ_LEN)
    ax.set_ylim(0, 1)
    ax.set_title(f't = {t_val}', fontsize=11)
    ax.set_xticks([])
    ax.set_yticks([])

plt.suptitle('Forward Process: Gradually Masking Tokens', fontsize=14, y=1.08)
plt.tight_layout()
plt.show()
print("Dark cells = [MASK]. As t increases, more tokens are masked.")

In [None]:
#@title 🎧 Listen: Stats BERT
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/02_05_stats_bert.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### 4.3 Masking Statistics

Let us verify that the implementation matches the math.
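
Before running the check, it helps to estimate how much sampling noise to expect. Each trial masks its tokens independently, so the mean masked fraction over many trials has standard deviation $\sqrt{t(1-t)/(L \cdot N)}$. The sketch below evaluates this, assuming $L = 32$ tokens per sequence and $N = 1000$ trials per value of $t$; the function name `standard_error` is hypothetical.

```python
import math

seq_len, n_trials = 32, 1000  # assumed: sequence length and trials per value of t


def standard_error(t):
    """Std. dev. of the mean masked fraction over n_trials sequences."""
    return math.sqrt(t * (1 - t) / (seq_len * n_trials))


for t in (0.05, 0.5, 0.95):
    print(f"t = {t:.2f}: expected deviation ~ {standard_error(t):.4f}")
```

Deviations of a few thousandths from the line $y = t$ are therefore expected and do not indicate a bug.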

In [None]:
# 📊 Run the forward process many times and check statistics
n_trials = 1000
t_values = torch.linspace(0.05, 0.95, 20)
actual_fractions = []

for t_val in t_values:
    fracs = []
    for _ in range(n_trials):
        sample = generate_pattern_data(1, SEQ_LEN, VOCAB_SIZE)
        t = torch.tensor([[t_val.item()]], device=device)
        _, mask = mask_tokens(sample, t)
        fracs.append(mask.float().mean().item())
    actual_fractions.append(np.mean(fracs))

plt.figure(figsize=(8, 5))
plt.plot(t_values.numpy(), actual_fractions, 'o-', color='#1565c0',
         label='Actual masked fraction', markersize=5)
plt.plot([0, 1], [0, 1], '--', color='#e53935', label='Expected (y = t)')
plt.xlabel('Masking probability t', fontsize=12)
plt.ylabel('Fraction of tokens masked', fontsize=12)
plt.title('Forward Process Verification', fontsize=13)
plt.legend(fontsize=11)
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
print("✅ The measured masking fraction closely tracks the expected fraction y = t.")

### 4.4 BERT vs Diffusion: Masking Ratio Comparison

In [None]:
# 📊 BERT uses a fixed 15% masking rate. Diffusion trains at ALL rates.
fig, axes = plt.subplots(1, 2, figsize=(14, 4))

# BERT: a single spike at 0.15
axes[0].bar([0.15], [1.0], width=0.03, color='#e53935', label='BERT (15% only)')
axes[0].set_xlabel('Masking ratio t', fontsize=11)
axes[0].set_ylabel('Training frequency', fontsize=11)
axes[0].set_title('BERT: Fixed Masking Ratio', fontsize=13)
axes[0].set_xlim(0, 1)
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Diffusion: uniform over [0, 1]
axes[1].fill_between([0, 1], [1, 1], color='#1565c0', alpha=0.5,
                      label='Diffusion LLM (all ratios)')
axes[1].set_xlabel('Masking ratio t', fontsize=11)
axes[1].set_ylabel('Training frequency', fontsize=11)
axes[1].set_title('Diffusion LLM: All Masking Ratios', fontsize=13)
axes[1].set_xlim(0, 1)
axes[1].set_ylim(0, 1.5)
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()
print("BERT trains at one fixed ratio. Diffusion trains at EVERY ratio.")
print("This is what makes diffusion a full generative model.")

In [None]:
#@title 🎧 Listen: Transformer
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/02_06_transformer.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### 4.5 The Bidirectional Transformer

Our model is a standard Transformer **encoder** with **no causal mask**. Every token can attend to every other token, both left and right.

In [None]:
class PositionalEncoding(nn.Module):
    """Standard sinusoidal positional encoding."""
    def __init__(self, d_model, max_len=512):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]
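
To see what the buffer `pe` contains, the same formula can be evaluated for a single position in plain Python. This is a sketch for inspection only, assuming the usual sinusoidal layout (sine on even indices, cosine on odd ones); the function name `sinusoidal_encoding` is hypothetical.

```python
import math


def sinusoidal_encoding(pos, d_model):
    """Encoding vector for one position: sin on even indices, cos on odd ones."""
    vec = []
    for i in range(0, d_model, 2):
        # Same frequency term as div_term in the class above
        freq = math.exp(-math.log(10000.0) * i / d_model)
        vec.append(math.sin(pos * freq))
        vec.append(math.cos(pos * freq))
    return vec[:d_model]


print([round(v, 3) for v in sinusoidal_encoding(0, 8)])
print([round(v, 3) for v in sinusoidal_encoding(1, 8)])
```

Position 0 yields alternating 0 and 1 (since $\sin 0 = 0$ and $\cos 0 = 1$); later positions rotate each sine/cosine pair at its own frequency, which is what lets the model tell positions apart.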

In [None]:
class DiffusionLM(nn.Module):
    """Bidirectional Transformer for masked diffusion language modeling."""

    def __init__(self, vocab_size, d_model, n_heads, n_layers, max_len=512):
        super().__init__()
        self.token_embed = nn.Embedding(vocab_size, d_model)
        self.pos_enc = PositionalEncoding(d_model, max_len)
        self.time_mlp = nn.Sequential(
            nn.Linear(1, d_model),
            nn.SiLU(),
            nn.Linear(d_model, d_model),
        )
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads,
            dim_feedforward=d_model * 4,
            dropout=0.1, batch_first=True, norm_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, n_layers)
        self.output_head = nn.Linear(d_model, vocab_size)

    def forward(self, x_t, t):
        """
        Args:
            x_t: Masked token IDs, shape (B, L)
            t: Masking ratio, shape (B, 1)
        Returns:
            Logits, shape (B, L, V)
        """
        h = self.token_embed(x_t)
        h = self.pos_enc(h)
        t_emb = self.time_mlp(t).unsqueeze(1)
        h = h + t_emb
        # Bidirectional: NO causal mask!
        h = self.transformer(h)
        return self.output_head(h)


model = DiffusionLM(VOCAB_SIZE, D_MODEL, N_HEADS, N_LAYERS).to(device)
n_params = sum(p.numel() for p in model.parameters())
print(f"DiffusionLM parameters: {n_params:,}")

### 💡 Key Insight: Why No Causal Mask?

In GPT, a causal mask prevents each position from seeing future tokens. This enforces left-to-right generation.

In our model, we **want** every position to see every other position. This bidirectional attention lets the model:
- Use tokens on the *right* to predict masked tokens on the *left*
- Fill in tokens in any order, not just left-to-right
- Avoid the left-to-right asymmetry behind the "reversal curse" observed in autoregressive models
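
The difference between the two schemes comes down to which positions a query is allowed to look at. The sketch below lists them explicitly for a 6-token sequence; `visible_positions` is a hypothetical helper written for this comparison and is not how PyTorch represents attention masks internally.

```python
def visible_positions(query, seq_len, causal):
    """Positions that `query` may attend to under each masking scheme."""
    if causal:
        return list(range(query + 1))   # itself and everything to the left
    return list(range(seq_len))         # every position, left and right


seq_len = 6
for query in (0, 3):
    print(f"query {query}: causal {visible_positions(query, seq_len, True)}"
          f" | bidirectional {visible_positions(query, seq_len, False)}")
```

Under a causal mask, position 0 sees only itself, so a masked first token could never be inferred from what follows it. With bidirectional attention, it sees all six positions.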

In [None]:
#@title 🎧 Listen: Training
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/02_07_training.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### 4.6 The Training Loop

In [None]:
def train_diffusion_lm(model, n_steps=3000, lr=3e-4):
    """Train the masked diffusion language model."""
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, n_steps)
    losses = []

    for step in range(n_steps):
        x_0 = generate_pattern_data(BATCH_SIZE, SEQ_LEN, VOCAB_SIZE)

        # Random masking ratio per sample, drawn from [0.02, 1.0)
        t = torch.rand(BATCH_SIZE, 1, device=device) * 0.98 + 0.02

        # Mask tokens
        x_t, mask = mask_tokens(x_0, t)

        # Predict original tokens
        logits = model(x_t, t)

        # Cross-entropy loss ONLY at masked positions
        logits_masked = logits[mask]
        targets_masked = x_0[mask]

        if logits_masked.shape[0] == 0:
            continue

        loss = F.cross_entropy(logits_masked, targets_masked)
        losses.append(loss.item())

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

        if (step + 1) % 500 == 0:
            print(f"Step {step+1}/{n_steps} | Loss: {loss.item():.4f}")

    return losses

print("Training...")
losses = train_diffusion_lm(model, n_steps=3000)
print("Done!")

In [None]:
# 📊 Training loss curve
plt.figure(figsize=(10, 4))
window = 50
smoothed = np.convolve(losses, np.ones(window)/window, mode='valid')
plt.plot(smoothed, color='#1565c0', linewidth=2)
plt.xlabel('Training Step', fontsize=11)
plt.ylabel('Cross-Entropy Loss', fontsize=11)
plt.title('Diffusion LM Training Loss', fontsize=13)
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

In [None]:
#@title 🎧 Listen: Todo1
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/02_08_todo1.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 5. 🔧 Your Turn

### TODO 1: Compute the Loss for a Single Example

In [None]:
def compute_single_loss(model, x_0, t_val):
    """Compute the masked diffusion loss for one sequence.

    Args:
        model: Trained DiffusionLM
        x_0: Clean tokens, shape (1, L)
        t_val: Float masking probability

    Returns:
        loss: Scalar loss value
        n_masked: Number of masked tokens
    """
    t = torch.tensor([[t_val]], device=device)
    x_t, mask = mask_tokens(x_0, t)

    # ============ TODO ============
    # Step 1: Get model predictions (logits)
    logits = ???  # YOUR CODE HERE

    # Step 2: Extract logits and targets at masked positions
    logits_masked = ???  # YOUR CODE HERE
    targets_masked = ???  # YOUR CODE HERE

    # Step 3: Compute cross-entropy loss
    loss = ???  # YOUR CODE HERE
    # ==============================

    return loss.item(), mask.sum().item()

In [None]:
# ✅ Verification
try:
    test_seq = generate_pattern_data(1, SEQ_LEN, VOCAB_SIZE)
    loss_val, n_masked = compute_single_loss(model, test_seq, 0.5)
    assert isinstance(loss_val, float), "Loss should be a float"
    assert n_masked > 0, "Should have masked some tokens"
    print(f"✅ Loss = {loss_val:.4f}, masked {n_masked}/{SEQ_LEN} tokens")
except NameError:
    print("❌ Replace the ??? placeholders.")
except Exception as e:
    print(f"❌ Error: {e}")

In [None]:
#@title 🎧 Listen: Todo2
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/02_09_todo2.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### TODO 2: Predict Masked Tokens

In [None]:
@torch.no_grad()
def predict_masked(model, x_0, t_val, top_k=3):
    """Mask a sequence and show the model's top-k predictions.

    Args:
        model: Trained DiffusionLM
        x_0: Clean tokens, shape (1, L)
        t_val: Masking probability
        top_k: Number of top predictions to show

    Returns:
        x_t: Masked sequence
        predictions: List of (position, true_token, [(pred_token, prob), ...])
    """
    model.eval()
    t = torch.tensor([[t_val]], device=device)
    x_t, mask = mask_tokens(x_0, t)

    # ============ TODO ============
    # Step 1: Get model logits
    logits = ???  # YOUR CODE HERE

    # Step 2: Convert to probabilities
    probs = ???  # YOUR CODE HERE: softmax over vocabulary dimension

    # Step 3: For each masked position, get top-k predictions
    predictions = []
    for pos in range(SEQ_LEN):
        if mask[0, pos]:
            top_probs, top_tokens = ???  # YOUR CODE: topk on probs[0, pos]
            preds = [(t.item(), p.item()) for t, p in zip(top_tokens, top_probs)]
            predictions.append((pos, x_0[0, pos].item(), preds))
    # ==============================

    model.train()
    return x_t, predictions

In [None]:
# ✅ Verification
try:
    test_seq = generate_pattern_data(1, SEQ_LEN, VOCAB_SIZE)
    masked_seq, preds = predict_masked(model, test_seq, 0.5, top_k=3)
    assert len(preds) > 0, "Should have predictions"
    print("✅ Model predictions for masked positions:")
    for pos, true_tok, top_preds in preds[:5]:
        pred_str = ", ".join([f"tok {t}({p:.2f})" for t, p in top_preds])
        correct = "✅" if top_preds[0][0] == true_tok else "❌"
        print(f"  Pos {pos:2d}: true={true_tok:2d} | preds: {pred_str} {correct}")
except NameError:
    print("❌ Replace the ??? placeholders.")

In [None]:
#@title 🎧 Listen: Results
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/02_10_results.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 6. Putting It All Together

In [None]:
# Show model predictions at different masking ratios
test_seq = generate_pattern_data(1, SEQ_LEN, VOCAB_SIZE)
print(f"Original sequence: {test_seq[0].tolist()}\n")

for t_val in [0.2, 0.5, 0.8]:
    masked_seq, preds = predict_masked(model, test_seq, t_val)
    n_correct = sum(1 for _, true, top in preds if top[0][0] == true)
    n_total = len(preds)
    acc = n_correct / max(n_total, 1) * 100

    masked_display = masked_seq[0].cpu().tolist()
    display_str = " ".join(
        f"[M]" if tok == MASK_TOKEN else f" {tok:2d}" for tok in masked_display
    )
    print(f"t={t_val}: {display_str}")
    print(f"  Accuracy: {n_correct}/{n_total} = {acc:.0f}%\n")

## 7. 🎯 Final Output

In [None]:
# 📊 Visualization: masked sequence with the model's predictions
torch.manual_seed(99)   # fixes the random mask
np.random.seed(99)      # fixes the sampled pattern (the data generator uses NumPy)
test_seq = generate_pattern_data(1, SEQ_LEN, VOCAB_SIZE)
masked_seq, preds = predict_masked(model, test_seq, 0.5, top_k=3)

fig, ax = plt.subplots(figsize=(18, 4))
tokens = masked_seq[0].cpu().numpy()  # named `tokens` so IPython's display() is not shadowed
original = test_seq[0].cpu().numpy()

pred_dict = {pos: (true, top) for pos, true, top in preds}

for pos in range(SEQ_LEN):
    if tokens[pos] == MASK_TOKEN:
        # Masked position: show the model's predictions below it
        ax.add_patch(plt.Rectangle((pos, 0.5), 1, 0.5, color='#333333'))
        ax.text(pos + 0.5, 0.75, '[M]', ha='center', va='center',
                color='white', fontsize=8, fontweight='bold')

        if pos in pred_dict:
            true_tok, top_preds = pred_dict[pos]
            for j, (tok, prob) in enumerate(top_preds[:3]):
                color = '#2e7d32' if tok == true_tok else '#e53935'
                ax.text(pos + 0.5, 0.35 - j * 0.12,
                        f'{tok}({prob:.1f})', ha='center', fontsize=7,
                        color=color, fontweight='bold' if j == 0 else 'normal')
    else:
        color = plt.cm.Set2(tokens[pos] / VOCAB_SIZE)
        ax.add_patch(plt.Rectangle((pos, 0.5), 1, 0.5, color=color))
        ax.text(pos + 0.5, 0.75, str(tokens[pos]), ha='center',
                va='center', fontsize=9)

ax.set_xlim(0, SEQ_LEN)
ax.set_ylim(-0.1, 1.1)
ax.set_xticks([])
ax.set_yticks([])
ax.set_title('Model Predictions for Masked Tokens (green = correct, red = wrong)',
             fontsize=13)
plt.tight_layout()
plt.show()
print("Top row: the masked sequence. Below each [M]: model's top-3 predictions.")

In [None]:
#@title 🎧 Listen: Accuracy
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/02_11_accuracy.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

In [None]:
# 📊 Accuracy vs masking ratio
t_values = np.linspace(0.1, 0.95, 15)
accuracies = []

for t_val in t_values:
    correct = 0
    total = 0
    for _ in range(50):
        test_seq = generate_pattern_data(1, SEQ_LEN, VOCAB_SIZE)
        _, preds = predict_masked(model, test_seq, t_val, top_k=1)
        correct += sum(1 for _, true, top in preds if top[0][0] == true)
        total += len(preds)
    accuracies.append(correct / max(total, 1) * 100)

plt.figure(figsize=(10, 5))
plt.plot(t_values, accuracies, 'o-', color='#1565c0', linewidth=2, markersize=6)
plt.xlabel('Masking ratio t', fontsize=12)
plt.ylabel('Prediction accuracy (%)', fontsize=12)
plt.title('Model Accuracy vs Masking Ratio', fontsize=14)
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
print("Higher masking = harder task (less context). Accuracy decreases as expected.")

In [None]:
#@title 🎧 Listen: Closing
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/02_12_closing.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 8. Reflection and Next Steps

### 🤔 Reflection Questions

1. **Why is bidirectional attention crucial?** With causal (left-to-right) attention, the model could not use "mat" at position 6 to predict "[M]" at position 1. Bidirectional attention gives full context from both directions.

2. **Why train with all masking ratios instead of just 15% like BERT?** BERT at 15% learns to fill in occasional blanks. But for *generation*, we start from 100% masked and progressively unmask. The model needs to handle every ratio from 0% to 100%.

3. **At what masking ratio is the task hardest?** At very high $t$ (e.g., 90%), the model sees very few tokens. At very low $t$, most tokens are visible and the few masked ones are easy to infer. The hardest regime is high masking with limited context.
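
For the tiled-pattern data used in this notebook, the difficulty at high $t$ can be made concrete. A pattern of length $p$ repeats each of its slots several times across the sequence, and a slot becomes unrecoverable only when every one of its copies is masked. The sketch below computes that probability, assuming a sequence length of 32 and independent masking; `prob_slot_fully_masked` is a hypothetical helper, not part of the notebook's code.

```python
def prob_slot_fully_masked(t, pattern_len, seq_len=32, slot=0):
    """Chance that every copy of one pattern slot is masked at ratio t."""
    copies = len(range(slot, seq_len, pattern_len))  # positions slot, slot+p, ...
    return t ** copies


for t in (0.5, 0.8, 0.9, 0.95):
    print(f"t = {t:.2f}: expected visible tokens = {32 * (1 - t):4.1f}, "
          f"P(slot hidden, pattern length 4) = {prob_slot_fully_masked(t, 4):.3f}")
```

With pattern length 4 each slot has 8 copies, so at $t = 0.9$ a given slot is completely hidden about 43% of the time ($0.9^8 \approx 0.43$), while at $t = 0.5$ this almost never happens. No model can do better than guess for a slot it never sees, which is one reason accuracy falls off at high masking ratios.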

### 🏆 Optional Challenges

1. Try different timestep embedding strategies (learned vs sinusoidal)
2. Compare model performance at different masking ratios by plotting a loss breakdown
3. Visualize the attention patterns. Are they different from BERT's?

---

**Up Next: Notebook 3.** *Training a Diffusion LLM.* We will dive into the ELBO, train a complete masked diffusion language model on real text, and see how the mathematical theory confirms that "train BERT at all masking ratios" is a rigorous diffusion process.

In [None]:
#@title 💬 AI Teaching Assistant: Click ▶ to start
#@markdown This AI chatbot reads your notebook and can answer questions about any concept, code, or exercise.

import json as _json
import requests as _requests
from google.colab import output as _output
from IPython.display import display, HTML as _HTML, Markdown as _Markdown

# --- Read notebook content for context ---
def _get_notebook_context():
    try:
        from google.colab import _message
        nb = _message.blocking_request("get_ipynb", request="", timeout_sec=10)
        cells = nb.get("ipynb", {}).get("cells", [])
        parts = []
        for cell in cells:
            src = "".join(cell.get("source", []))
            tags = cell.get("metadata", {}).get("tags", [])
            if "chatbot" in tags:
                continue
            if src.strip():
                ct = cell.get("cell_type", "unknown")
                parts.append(f"[{ct.upper()}]\n{src}")
        return "\n\n---\n\n".join(parts)
    except Exception:
        return "Notebook content unavailable."

_NOTEBOOK_CONTEXT = _get_notebook_context()
_CHAT_HISTORY = []
_API_URL = "https://course-creator-brown.vercel.app/api/chat"

def _notebook_chat(question):
    global _CHAT_HISTORY
    try:
        resp = _requests.post(_API_URL, json={
            'question': question,
            'context': _NOTEBOOK_CONTEXT[:100000],
            'history': _CHAT_HISTORY[-10:],
        }, timeout=60)
        data = resp.json()
        answer = data.get('answer', 'Sorry, I could not generate a response.')
        _CHAT_HISTORY.append({'role': 'user', 'content': question})
        _CHAT_HISTORY.append({'role': 'assistant', 'content': answer})
        return answer
    except Exception as e:
        return f'Error connecting to teaching assistant: {str(e)}'

_output.register_callback('notebook_chat', _notebook_chat)

def ask(question):
    """Ask the AI teaching assistant a question about this notebook."""
    answer = _notebook_chat(question)
    display(_Markdown(answer))

print("\u2705 AI Teaching Assistant is ready!")
print("\U0001f4a1 Use the chat below, or call ask(\'your question\') in any cell.")

# --- Display chat widget ---
display(_HTML('''<style>
  .vc-wrap{font-family:-apple-system,BlinkMacSystemFont,'Segoe UI',Roboto,sans-serif;max-width:100%;border-radius:16px;overflow:hidden;box-shadow:0 4px 24px rgba(0,0,0,.12);background:#fff;border:1px solid #e5e7eb}
  .vc-hdr{background:linear-gradient(135deg,#667eea 0%,#764ba2 100%);color:#fff;padding:16px 20px;display:flex;align-items:center;gap:12px}
  .vc-avatar{width:42px;height:42px;background:rgba(255,255,255,.2);border-radius:50%;display:flex;align-items:center;justify-content:center;font-size:22px}
  .vc-hdr h3{font-size:16px;font-weight:600;margin:0}
  .vc-hdr p{font-size:12px;opacity:.85;margin:2px 0 0}
  .vc-msgs{height:420px;overflow-y:auto;padding:16px;background:#f8f9fb;display:flex;flex-direction:column;gap:10px}
  .vc-msg{display:flex;flex-direction:column;animation:vc-fade .25s ease}
  .vc-msg.user{align-items:flex-end}
  .vc-msg.bot{align-items:flex-start}
  .vc-bbl{max-width:85%;padding:10px 14px;border-radius:16px;font-size:14px;line-height:1.55;word-wrap:break-word}
  .vc-msg.user .vc-bbl{background:linear-gradient(135deg,#667eea 0%,#764ba2 100%);color:#fff;border-bottom-right-radius:4px}
  .vc-msg.bot .vc-bbl{background:#fff;color:#1a1a2e;border:1px solid #e8e8e8;border-bottom-left-radius:4px}
  .vc-bbl code{background:rgba(0,0,0,.07);padding:2px 6px;border-radius:4px;font-size:13px;font-family:'Fira Code',monospace}
  .vc-bbl pre{background:#1e1e2e;color:#cdd6f4;padding:12px;border-radius:8px;overflow-x:auto;margin:8px 0;font-size:13px}
  .vc-bbl pre code{background:none;padding:0;color:inherit}
  .vc-bbl h3,.vc-bbl h4{margin:10px 0 4px;font-size:15px}
  .vc-bbl ul,.vc-bbl ol{margin:4px 0;padding-left:20px}
  .vc-bbl li{margin:2px 0}
  .vc-chips{display:flex;flex-wrap:wrap;gap:8px;padding:0 16px 12px;background:#f8f9fb}
  .vc-chip{background:#fff;border:1px solid #d1d5db;border-radius:20px;padding:6px 14px;font-size:12px;cursor:pointer;transition:all .15s;color:#4b5563}
  .vc-chip:hover{border-color:#667eea;color:#667eea;background:#f0f0ff}
  .vc-input{display:flex;padding:12px 16px;background:#fff;border-top:1px solid #eee;gap:8px}
  .vc-input input{flex:1;padding:10px 16px;border:2px solid #e8e8e8;border-radius:24px;font-size:14px;outline:none;transition:border-color .2s}
  .vc-input input:focus{border-color:#667eea}
  .vc-input button{background:linear-gradient(135deg,#667eea 0%,#764ba2 100%);color:#fff;border:none;border-radius:50%;width:42px;height:42px;cursor:pointer;display:flex;align-items:center;justify-content:center;font-size:18px;transition:transform .1s}
  .vc-input button:hover{transform:scale(1.05)}
  .vc-input button:disabled{opacity:.5;cursor:not-allowed;transform:none}
  .vc-typing{display:flex;gap:5px;padding:4px 0}
  .vc-typing span{width:8px;height:8px;background:#667eea;border-radius:50%;animation:vc-bounce 1.4s infinite ease-in-out}
  .vc-typing span:nth-child(2){animation-delay:.2s}
  .vc-typing span:nth-child(3){animation-delay:.4s}
  @keyframes vc-bounce{0%,80%,100%{transform:scale(0)}40%{transform:scale(1)}}
  @keyframes vc-fade{from{opacity:0;transform:translateY(8px)}to{opacity:1;transform:translateY(0)}}
  .vc-note{text-align:center;font-size:11px;color:#9ca3af;padding:8px 16px 12px;background:#fff}
</style>
<div class="vc-wrap">
  <div class="vc-hdr">
    <div class="vc-avatar">&#129302;</div>
    <div>
      <h3>Vizuara Teaching Assistant</h3>
      <p>Ask me anything about this notebook</p>
    </div>
  </div>
  <div class="vc-msgs" id="vcMsgs">
    <div class="vc-msg bot">
      <div class="vc-bbl">&#128075; Hi! I've read through this entire notebook. Ask me about any concept, code block, or exercise &mdash; I'm here to help you learn!</div>
    </div>
  </div>
  <div class="vc-chips" id="vcChips">
    <span class="vc-chip" onclick="vcAsk(this.textContent)">Explain the main concept</span>
    <span class="vc-chip" onclick="vcAsk(this.textContent)">Help with the TODO exercise</span>
    <span class="vc-chip" onclick="vcAsk(this.textContent)">Summarize what I learned</span>
  </div>
  <div class="vc-input">
    <input type="text" id="vcIn" placeholder="Ask about concepts, code, exercises..." />
    <button id="vcSend" onclick="vcSendMsg()">&#10148;</button>
  </div>
  <div class="vc-note">AI-generated &middot; Verify important information &middot; <a href="#" onclick="vcClear();return false" style="color:#667eea">Clear chat</a></div>
</div>
<script>
(function(){
  var msgs=document.getElementById('vcMsgs'),inp=document.getElementById('vcIn'),
      btn=document.getElementById('vcSend'),chips=document.getElementById('vcChips');

  function esc(s){var d=document.createElement('div');d.textContent=s;return d.innerHTML}

  function md(t){
    return t
      .replace(/```(\w*)\n([\s\S]*?)```/g,function(_,l,c){return '<pre><code>'+esc(c)+'</code></pre>'})
      .replace(/`([^`]+)`/g,'<code>$1</code>')
      .replace(/\*\*([^*]+)\*\*/g,'<strong>$1</strong>')
      .replace(/\*([^*]+)\*/g,'<em>$1</em>')
      .replace(/^#### (.+)$/gm,'<h4>$1</h4>')
      .replace(/^### (.+)$/gm,'<h4>$1</h4>')
      .replace(/^## (.+)$/gm,'<h3>$1</h3>')
      .replace(/^\d+\. (.+)$/gm,'<li>$1</li>')
      .replace(/^- (.+)$/gm,'<li>$1</li>')
      .replace(/\n\n/g,'<br><br>')
      .replace(/\n/g,'<br>');
  }

  function addMsg(text,isUser){
    var m=document.createElement('div');m.className='vc-msg '+(isUser?'user':'bot');
    var b=document.createElement('div');b.className='vc-bbl';
    b.innerHTML=isUser?esc(text):md(text);
    m.appendChild(b);msgs.appendChild(m);msgs.scrollTop=msgs.scrollHeight;
  }

  function showTyping(){
    var m=document.createElement('div');m.className='vc-msg bot';m.id='vcTyping';
    m.innerHTML='<div class="vc-bbl"><div class="vc-typing"><span></span><span></span><span></span></div></div>';
    msgs.appendChild(m);msgs.scrollTop=msgs.scrollHeight;
  }

  function hideTyping(){var e=document.getElementById('vcTyping');if(e)e.remove()}

  window.vcSendMsg=function(){
    var q=inp.value.trim();if(!q)return;
    inp.value='';chips.style.display='none';
    addMsg(q,true);showTyping();btn.disabled=true;
    google.colab.kernel.invokeFunction('notebook_chat',[q],{})
      .then(function(r){
        hideTyping();
        var a=r.data['application/json'];
        addMsg(typeof a==='string'?a:JSON.stringify(a),false);
      })
      .catch(function(){
        hideTyping();
        addMsg('Sorry, I encountered an error. Please check your internet connection and try again.',false);
      })
      .finally(function(){btn.disabled=false;inp.focus()});
  };

  window.vcAsk=function(q){inp.value=q;vcSendMsg()};
  window.vcClear=function(){
    msgs.innerHTML='<div class="vc-msg bot"><div class="vc-bbl">&#128075; Chat cleared. Ask me anything!</div></div>';
    chips.style.display='flex';
  };

  inp.addEventListener('keypress',function(e){if(e.key==='Enter')vcSendMsg()});
  inp.focus();
})();
</script>'''))