In [None]:
#@title üéß Download Narration Audio & Play Introduction
import os as _os
if not _os.path.exists("/content/narration"):
    !pip install -q gdown
    import gdown
    gdown.download(id="1b5CC6G-W_RszA4l3kZ1pQHVSrFdbh7aI", output="/content/narration.zip", quiet=False)
    !unzip -q /content/narration.zip -d /content/narration
    !rm /content/narration.zip
    print(f"Loaded {len(_os.listdir('/content/narration'))} narration segments")
else:
    print("Narration audio already loaded.")

from IPython.display import Audio, display
display(Audio("/content/narration/00_intro.mp3"))

In [None]:
# 🔧 Setup: Run this cell first!
# Check GPU availability and install dependencies

import torch
import sys

# Check GPU
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"✅ GPU available: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    device = torch.device('cpu')
    print("⚠️ No GPU detected. Some cells may run slowly.")
    print("   Go to Runtime → Change runtime type → GPU")

print(f"\n📦 Python {sys.version.split()[0]}")
print(f"🔥 PyTorch {torch.__version__}")

# Set random seeds for reproducibility
import random
import numpy as np

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print(f"🎲 Random seed set to {SEED}")

%matplotlib inline

In [None]:
#@title 🎧 Listen: Motivation
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/01_motivation.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

# Building a Diffusion LLM on TinyStories

*Part 4 of the Vizuara series on Diffusion Language Models*
*Estimated time: 50 minutes (includes ~10 min training)*

# 🤖 AI Teaching Assistant

Need help with this notebook? Open the **AI Teaching Assistant**: it has already read this entire notebook and can help with concepts, code, and exercises.

**[👉 Open AI Teaching Assistant](https://pods.vizuara.ai/courses/diffusion-llms/practice/4/assistant)**

*Tip: Open it in a separate tab and work through this notebook side-by-side.*


## 1. Why Does This Matter?

In the previous notebooks, we built diffusion language models on synthetic patterns. That was great for understanding the mechanics. But the real test of any language model is: **can it generate coherent, readable text?**

In this notebook, we will train a masked diffusion language model on the **TinyStories** dataset, a collection of short stories written in simple English that a 3- to 4-year-old could understand. By the end, our model will generate short stories through iterative unmasking, and you will watch words materialize from a sea of [MASK] tokens into a readable narrative.

**Teaser (illustrative; your exact output will differ):**

```
Step 1:  [M] [M] [M] [M] [M] [M] [M] [M] [M] [M] [M] [M] ...
Step 3:  [M] was [M] [M] [M] . [M] [M] a [M] [M] the [M] ...
Step 6:  she was [M] happy [M] . [M] had a [M] big the dog ...
Step 10: she was very happy today . she had a new big the dog .
```

Not Shakespeare, but a real diffusion model generating real English, trained in under 10 minutes on a single GPU.

In [None]:
# 🔧 Setup
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import math
import time
from collections import Counter

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if device.type == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")

torch.manual_seed(42)
np.random.seed(42)

%matplotlib inline

In [None]:
# Install datasets library if needed
try:
    from datasets import load_dataset
    print("datasets library ready")
except ImportError:
    !pip install datasets -q
    from datasets import load_dataset
    print("datasets library installed")

In [None]:
#@title 🎧 Listen: Intuition
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/02_intuition.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 2. Building Intuition

### Why TinyStories?

Training a language model that generates coherent text usually requires billions of parameters and weeks of compute. But the TinyStories dataset (Eldan & Li, 2023) is special:

- Stories use only vocabulary a **3- to 4-year-old** would understand
- Each story is short (50-200 words)
- Grammar is simple but correct
- Stories have narrative structure (beginning, middle, end)

This means a **small model** (a few million parameters) can learn meaningful patterns in **minutes**. Perfect for our Colab notebook.

### How Diffusion Generation Differs from Autoregressive

When GPT generates a story, it writes one word at a time, left to right. It commits to "Once" before knowing the story is about a dog. It commits to "Once upon" before knowing the ending.

Our diffusion model will generate stories differently:
1. Start: `[M] [M] [M] [M] [M] [M] [M] [M] [M] [M]`
2. Structure words appear first: `[M] was [M] [M] . [M] had [M] [M] .`
3. Content fills in: `She was very happy . She had a dog .`

The model "plans" the whole story simultaneously, committing to easy structural tokens first and content words later.
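What lets the model plan the whole sequence at once is that its attention is not causal. The sketch below is a standalone check with toy sizes (not the notebook's model) using PyTorch's `nn.TransformerEncoder`: it replaces only the last position of the input and measures how much the output at position 0 moves, once with full bidirectional attention and once with a GPT-style causal mask.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, seq_len = 32, 6
layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=4, dim_feedforward=64,
                                   dropout=0.0, batch_first=True, norm_first=True)
encoder = nn.TransformerEncoder(layer, num_layers=2).eval()

x = torch.randn(1, seq_len, d_model)
x_changed = x.clone()
x_changed[0, -1] = torch.randn(d_model)  # replace only the LAST position

# Causal (GPT-style) mask: position i may attend only to positions <= i
causal = torch.triu(torch.full((seq_len, seq_len), float('-inf')), diagonal=1)

with torch.no_grad():
    bi_delta = (encoder(x)[0, 0] - encoder(x_changed)[0, 0]).abs().max().item()
    causal_delta = (encoder(x, mask=causal)[0, 0]
                    - encoder(x_changed, mask=causal)[0, 0]).abs().max().item()

print(f"bidirectional: change at position 0 = {bi_delta:.2e}")     # > 0
print(f"causal:        change at position 0 = {causal_delta:.2e}")  # ~0: position 0 never sees position 5
```

With full attention, position 0 reacts to a change at the end of the sequence; with the causal mask it cannot, which is exactly the constraint a left-to-right model works under.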

### 🤔 Think About This

Consider a TinyStory such as "Once upon a time, there was a little girl named Lily. She liked to play in the park."

- Which words would be **easiest** for the model to predict (appear first during generation)?
- Which words would be **hardest** (appear last)?

*Hint: Think about which words are most predictable from context vs which carry the most unique information.*

In [None]:
#@title 🎧 Listen: Math
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/03_math.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 3. The Mathematics: Quick Recap

We use the same masked diffusion framework from Notebook 1:

**Forward process:** Mask each token independently with probability $t$:
$$q(x_t^i \mid x_0^i) = (1 - t) \cdot \mathbb{1}[x_t^i = x_0^i] + t \cdot \mathbb{1}[x_t^i = \texttt{[MASK]}]$$

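**Training loss:** Cross-entropy at masked positions, weighted by $1/(t \cdot L)$, where $L$ is the sequence length and the expectation runs over $t \sim \mathcal{U}(0, 1)$, a clean story $x_0$, and its masked version $x_t \sim q(x_t \mid x_0)$:
$$\mathcal{L} = -\mathbb{E}_{t,\, x_0,\, x_t} \left[ \frac{1}{t \cdot L} \sum_{i:\, x_t^i = \texttt{[MASK]}} \log p_\theta(x_0^i \mid x_t) \right]$$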

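**What this says computationally:** At each training step, we randomly mask some fraction of the story, then ask the Transformer to predict the missing words. The model sees the whole partially masked story bidirectionally and must fill in the blanks. In the training loop (Section 4.6), we approximate the $1/(t \cdot L)$ weighting by averaging the cross-entropy over all masked positions in the batch: each story has roughly $t$ times its number of real (non-PAD) tokens masked, so this average plays a similar normalizing role.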
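To make the weighting concrete, here is a minimal sketch of the per-sequence $1/(t \cdot L)$ loss, assuming logits of shape `(B, L, V)` and a boolean mask of masked positions. `weighted_masked_loss` is a hypothetical helper, not part of the notebook's training code; it uses the full padded length as $L$, whereas a PAD-aware version would divide by each story's number of real tokens.

```python
import torch
import torch.nn.functional as F

def weighted_masked_loss(logits, x_0, mask, t):
    """Per-sequence (1 / (t * L))-weighted cross-entropy, averaged over the batch (sketch).

    logits: (B, L, V) model outputs
    x_0:    (B, L) clean token IDs
    mask:   (B, L) bool, True where x_t holds [MASK]
    t:      (B, 1) masking ratio used for each sequence
    """
    B, L, V = logits.shape
    # Token-level cross-entropy without reduction, reshaped back to (B, L)
    ce = F.cross_entropy(logits.reshape(-1, V), x_0.reshape(-1), reduction="none").view(B, L)
    # Sum over masked positions only, then apply the 1/(t*L) weight per sequence
    per_seq = (ce * mask.float()).sum(dim=1) / (t.squeeze(1) * L)
    return per_seq.mean()

torch.manual_seed(0)
B, L, V = 4, 16, 50
logits = torch.randn(B, L, V)
x_0 = torch.randint(0, V, (B, L))
t = torch.rand(B, 1) * 0.98 + 0.02
mask = torch.rand(B, L) < t

exact = weighted_masked_loss(logits, x_0, mask, t)
approx = F.cross_entropy(logits[mask], x_0[mask])  # what the training loop computes
print(f"1/(tL)-weighted: {exact.item():.4f} | mean over masked: {approx.item():.4f}")
```

The two numbers are close but not identical: the batch-mean version pools all masked tokens, so stories with more masked tokens (larger $t$) contribute more terms, while the weighted version divides each story's sum by its own $t \cdot L$. Swapping the helper into `train_story_model` is a one-line change if you want to compare the two.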

In [None]:
#@title 🎧 Listen: Data Loading
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/04_data_loading.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 4. Let's Build It: Component by Component

### 4.1 Load and Prepare TinyStories

In [None]:
print("Loading TinyStories dataset...")
dataset = load_dataset("roneneldan/TinyStories", split="train")
print(f"Total stories: {len(dataset):,}")

# Peek at a few stories
for i in range(3):
    story = dataset[i]['text']
    print(f"\n--- Story {i+1} (first 200 chars) ---")
    print(story[:200])

In [None]:
#@title 🎧 Listen: Tokenizer
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/05_tokenizer.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### 4.2 Build a Word-Level Tokenizer

We build a simple word-level tokenizer from the data: lowercase each story, split it on whitespace, and keep the most frequent words. This keeps the vocabulary small and each token meaningful.
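Because `encode` splits on whitespace, punctuation stays glued to the preceding word, so "lily." and "lily" occupy separate vocabulary slots. The sketch below shows this, plus a hypothetical regex-based splitter (`split_punct`, not part of the notebook) you could swap in if you want punctuation as separate tokens.

```python
import re

story = "Once upon a time, there was a little girl named Lily. She liked to play in the park."

# What SimpleTokenizer.encode does: lowercase + whitespace split
ws_tokens = story.lower().split()
print(ws_tokens[-3:])   # ['in', 'the', 'park.']

# Hypothetical alternative: words (with apostrophes) or single punctuation marks
def split_punct(text):
    return re.findall(r"[a-z0-9']+|[^\sa-z0-9']", text.lower())

punct_tokens = split_punct(story)
print(punct_tokens[-4:])  # ['in', 'the', 'park', '.']
```

With the whitespace split, "park" and "park." compete for separate slots in the 2,000-word vocabulary; splitting punctuation frees those slots at the cost of slightly longer sequences.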

In [None]:
class SimpleTokenizer:
    """Word-level tokenizer with a fixed vocabulary."""

    def __init__(self, vocab_size=2000):
        self.vocab_size = vocab_size
        self.word2idx = {}
        self.idx2word = {}
        # Reserve special tokens
        self.mask_token = "[MASK]"
        self.pad_token = "[PAD]"
        self.unk_token = "[UNK]"

    def build_vocab(self, texts, max_texts=50000):
        """Build vocabulary from the most common words."""
        word_counts = Counter()
        for i, text in enumerate(texts):
            if i >= max_texts:
                break
            words = text.lower().split()
            word_counts.update(words)

        # Special tokens get indices 0, 1, 2
        self.word2idx = {
            self.mask_token: 0,
            self.pad_token: 1,
            self.unk_token: 2,
        }

        # Add most common words
        for word, _ in word_counts.most_common(self.vocab_size - 3):
            self.word2idx[word] = len(self.word2idx)

        self.idx2word = {v: k for k, v in self.word2idx.items()}
        print(f"Vocabulary size: {len(self.word2idx)}")
        print(f"Top 20 words: {[w for w, _ in word_counts.most_common(20)]}")

    def encode(self, text, max_len=64):
        """Convert text to token IDs, with padding/truncation."""
        words = text.lower().split()[:max_len]
        ids = [self.word2idx.get(w, 2) for w in words]  # 2 = UNK
        # Pad to max_len
        ids = ids + [1] * (max_len - len(ids))  # 1 = PAD
        return ids

    def decode(self, ids):
        """Convert token IDs back to text."""
        words = []
        for idx in ids:
            if idx == 1:  # PAD
                break
            if idx == 0:  # MASK
                words.append("[M]")
            else:
                words.append(self.idx2word.get(idx, "[?]"))
        return " ".join(words)


# Build tokenizer
VOCAB_SIZE = 2000
SEQ_LEN = 64
tokenizer = SimpleTokenizer(vocab_size=VOCAB_SIZE)
tokenizer.build_vocab(d['text'] for d in dataset)  # generator: build_vocab stops after max_texts stories

In [None]:
# 📊 Test the tokenizer
test_story = dataset[0]['text']
encoded = tokenizer.encode(test_story, max_len=SEQ_LEN)
decoded = tokenizer.decode(encoded)
print(f"Original: {test_story[:200]}")
print(f"\nEncoded (first 20): {encoded[:20]}")
print(f"\nDecoded: {decoded[:200]}")

In [None]:
#@title 🎧 Listen: Data Prep
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/06_data_prep.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### 4.3 Prepare the Training Data

In [None]:
def prepare_dataset(dataset, tokenizer, max_len, n_samples=30000):
    """Convert stories to tensor of token IDs."""
    all_ids = []
    for i in range(min(n_samples, len(dataset))):
        ids = tokenizer.encode(dataset[i]['text'], max_len=max_len)
        # Skip very short stories (mostly padding)
        if sum(1 for x in ids if x > 1) >= 10:  # at least 10 real tokens
            all_ids.append(ids)

    data = torch.tensor(all_ids, dtype=torch.long)
    print(f"Prepared {len(data):,} stories, shape: {data.shape}")
    return data


train_data = prepare_dataset(dataset, tokenizer, SEQ_LEN, n_samples=30000)

In [None]:
# 📊 Dataset statistics
real_token_counts = (train_data > 1).sum(dim=1).float()
print(f"Average tokens per story: {real_token_counts.mean():.1f}")
print(f"Min / Max tokens: {real_token_counts.min():.0f} / {real_token_counts.max():.0f}")

plt.figure(figsize=(8, 3))
plt.hist(real_token_counts.numpy(), bins=30, color='#1565c0', alpha=0.7, edgecolor='white')
plt.xlabel('Number of Real Tokens per Story')
plt.ylabel('Count')
plt.title('Story Length Distribution')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

In [None]:
#@title 🎧 Listen: Model
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/07_model.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### 4.4 The Diffusion Transformer

This is a more complete implementation with pre-norm layer normalization, sinusoidal positional encoding, and time conditioning, scaled up from our toy model.

In [None]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=512):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(max_len).unsqueeze(1).float()
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000) / d_model))
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]


class StoryDiffusionLM(nn.Module):
    """Bidirectional Transformer for story generation via masked diffusion."""

    def __init__(self, vocab_size, d_model=256, n_heads=4, n_layers=6,
                 max_len=512, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.token_embed = nn.Embedding(vocab_size, d_model)
        self.pos_enc = PositionalEncoding(d_model, max_len)
        self.time_mlp = nn.Sequential(
            nn.Linear(1, d_model),
            nn.SiLU(),
            nn.Linear(d_model, d_model),
        )
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=d_model * 4,
            dropout=dropout,
            batch_first=True,
            norm_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, n_layers)
        self.ln_final = nn.LayerNorm(d_model)
        self.output_head = nn.Linear(d_model, vocab_size)

        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, x_t, t):
        """
        x_t: (B, L) masked token IDs
        t: (B, 1) masking ratio
        """
        h = self.token_embed(x_t) * math.sqrt(self.d_model)
        h = self.pos_enc(h)
        h = h + self.time_mlp(t).unsqueeze(1)  # broadcast time to all positions
        h = self.transformer(h)  # bidirectional!
        h = self.ln_final(h)
        return self.output_head(h)  # (B, L, V)


# Create model
D_MODEL = 256
N_HEADS = 4
N_LAYERS = 6
model = StoryDiffusionLM(
    vocab_size=VOCAB_SIZE, d_model=D_MODEL,
    n_heads=N_HEADS, n_layers=N_LAYERS, max_len=SEQ_LEN
).to(device)

n_params = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {n_params:,}")
print(f"Architecture: {N_LAYERS} layers, d_model={D_MODEL}, {N_HEADS} heads")

### 💡 Key Design Choices

Why these hyperparameters?
- **d_model=256, 6 layers:** about 5.8M parameters (the cell above prints the exact count), small enough to train in minutes on a T4 and large enough to learn meaningful patterns in simple stories.
- **vocab_size=2000:** Covers the most common words in TinyStories. Rare words become [UNK], but that is fine for learning structure.
- **seq_len=64:** Enough for the first few sentences of a story (longer stories are truncated to their first 64 words), and short enough for fast training.
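As a sanity check on the model size, the parameter count can be reproduced by hand. This sketch assumes PyTorch's standard `nn.TransformerEncoderLayer` layout (a fused QKV input projection, biases on every linear layer, two LayerNorms per layer) and that the positional encoding is a buffer rather than a parameter; the helper names are ours, not the notebook's.

```python
import torch.nn as nn

def count_params_analytic(vocab_size=2000, d=256, n_layers=6, ff_mult=4):
    ff = d * ff_mult
    attn = (3 * d * d + 3 * d) + (d * d + d)   # fused QKV projection + output projection
    ffn = (d * ff + ff) + (ff * d + d)         # two feed-forward linear layers
    norms = 2 * (2 * d)                        # two LayerNorms (weight + bias)
    per_layer = attn + ffn + norms
    embed = vocab_size * d                     # token embedding
    time_mlp = (1 * d + d) + (d * d + d)       # Linear(1, d) + Linear(d, d)
    ln_final = 2 * d
    head = d * vocab_size + vocab_size         # output projection
    return n_layers * per_layer + embed + time_mlp + ln_final + head

def count_params_torch(vocab_size=2000, d=256, n_layers=6, ff_mult=4):
    # Same building blocks as StoryDiffusionLM (positional encoding has no parameters)
    layer = nn.TransformerEncoderLayer(d_model=d, nhead=4, dim_feedforward=d * ff_mult,
                                       batch_first=True, norm_first=True)
    modules = nn.ModuleList([
        nn.Embedding(vocab_size, d),
        nn.Sequential(nn.Linear(1, d), nn.SiLU(), nn.Linear(d, d)),
        nn.TransformerEncoder(layer, n_layers),
        nn.LayerNorm(d),
        nn.Linear(d, vocab_size),
    ])
    return sum(p.numel() for p in modules.parameters())

print(f"{count_params_analytic():,}")  # 5,831,376
```

Roughly 81% of the parameters sit in the six Transformer layers; the embedding and output head together add about 1M.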

In [None]:
#@title 🎧 Listen: Masking
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/08_masking.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### 4.5 Forward Masking Process
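`mask_tokens` below masks each real token independently with probability $t$ and leaves PAD untouched. Here is a standalone sanity check of both properties, using a re-implementation (`mask_tokens_sketch`, hypothetical) on fake data rather than TinyStories.

```python
import torch

MASK_ID, PAD_ID = 0, 1  # same IDs as the notebook's tokenizer

def mask_tokens_sketch(x_0, t):
    """Mask each non-PAD token independently with probability t (per-row t of shape (B, 1))."""
    mask = (torch.rand(x_0.shape) < t) & (x_0 != PAD_ID)
    return x_0.masked_fill(mask, MASK_ID), mask

torch.manual_seed(0)
# 256 fake stories: 32 real tokens (IDs >= 3) followed by 32 PAD tokens each
real = torch.randint(3, 2000, (256, 32))
x_0 = torch.cat([real, torch.full((256, 32), PAD_ID, dtype=torch.long)], dim=1)
t = torch.full((256, 1), 0.3)
x_t, mask = mask_tokens_sketch(x_0, t)

frac_real_masked = mask[:, :32].float().mean().item()
pad_masked = int(mask[:, 32:].sum())
print(f"fraction of real tokens masked: {frac_real_masked:.3f} (t = 0.3)")
print(f"PAD positions masked: {pad_masked}")
```

At $t = 1$ every real token is masked, which matches the all-[MASK] state that generation starts from, except that generation has no PAD tokens at the end.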

In [None]:
MASK_TOKEN_ID = 0
PAD_TOKEN_ID = 1

def mask_tokens(x_0, t):
    """Apply forward masking, but NEVER mask PAD tokens."""
    mask = torch.rand_like(x_0.float()) < t
    # Don't mask padding tokens (they stay as PAD)
    is_pad = (x_0 == PAD_TOKEN_ID)
    mask = mask & ~is_pad

    x_t = x_0.clone()
    x_t[mask] = MASK_TOKEN_ID
    return x_t, mask

In [None]:
# 📊 Visualize masking on a real story
sample_idx = 0
sample = train_data[sample_idx:sample_idx+1].to(device)
print(f"Original: {tokenizer.decode(sample[0].tolist())}\n")

for t_val in [0.2, 0.5, 0.8, 1.0]:
    t = torch.tensor([[t_val]], device=device)
    masked, _ = mask_tokens(sample, t)
    print(f"t={t_val}: {tokenizer.decode(masked[0].tolist())}")

In [None]:
#@title 🎧 Listen: Training
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/09_training.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### 4.6 Training Loop
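`train_story_model` decays the learning rate with `torch.optim.lr_scheduler.CosineAnnealingLR` over `total_steps`, stepping the scheduler once per batch. With the default `eta_min=0`, that schedule follows $\eta_k = \tfrac{1}{2}\,\eta_{\max}\,\bigl(1 + \cos(\pi k / T)\bigr)$. A small sketch comparing this closed form against the scheduler on a dummy parameter (toy `T_max=100`, not the notebook's step count):

```python
import math
import torch

def cosine_lr(step, total_steps, lr_max=3e-4):
    """Closed-form cosine annealing with eta_min = 0."""
    return 0.5 * lr_max * (1 + math.cos(math.pi * step / total_steps))

param = torch.nn.Parameter(torch.zeros(1))  # dummy parameter, never updated (no grad)
opt = torch.optim.AdamW([param], lr=3e-4)
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=100)

lrs = []
for _ in range(100):
    lrs.append(opt.param_groups[0]['lr'])  # learning rate used at this step
    opt.step()
    sched.step()

print(f"start: {lrs[0]:.2e}, halfway: {lrs[50]:.2e}, last step: {lrs[-1]:.2e}")
```

Halfway through training the learning rate is half its peak, and it approaches zero on the final batch.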

In [None]:
def train_story_model(model, train_data, n_epochs=3, batch_size=64, lr=3e-4):
    """Train the story diffusion model."""
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
    n_batches = len(train_data) // batch_size
    total_steps = n_epochs * n_batches
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, total_steps)

    losses = []
    step = 0

    for epoch in range(n_epochs):
        # Shuffle data
        perm = torch.randperm(len(train_data))
        train_data_shuffled = train_data[perm]

        epoch_losses = []
        for i in range(0, len(train_data_shuffled) - batch_size + 1, batch_size):  # full batches only
            x_0 = train_data_shuffled[i:i+batch_size].to(device)

            # Sample masking ratio
            t = torch.rand(batch_size, 1, device=device) * 0.98 + 0.02

            # Forward masking
            x_t, mask = mask_tokens(x_0, t)

            # Predict original tokens
            logits = model(x_t, t)

            if mask.sum() == 0:
                continue

            # Loss at masked positions only
            loss = F.cross_entropy(logits[mask], x_0[mask])

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()

            losses.append(loss.item())
            epoch_losses.append(loss.item())
            step += 1

            if step % 100 == 0:
                avg_loss = np.mean(epoch_losses[-100:])
                print(f"  Epoch {epoch+1}/{n_epochs} | Step {step}/{total_steps} | "
                      f"Loss: {avg_loss:.4f}")

        print(f"Epoch {epoch+1} complete | Avg loss: {np.mean(epoch_losses):.4f}\n")

    return losses


print("Training the story diffusion model...")
print("(This will take ~5-10 minutes on a T4 GPU)\n")
start_time = time.time()
losses = train_story_model(model, train_data, n_epochs=3, batch_size=64)
elapsed = time.time() - start_time
print(f"\nTraining complete in {elapsed/60:.1f} minutes!")

In [None]:
# 📊 Training curve
plt.figure(figsize=(10, 4))
window = 50
smoothed = np.convolve(losses, np.ones(window)/window, mode='valid')
plt.plot(smoothed, color='#1565c0', linewidth=1.5)
plt.xlabel('Training Step', fontsize=12)
plt.ylabel('Cross-Entropy Loss', fontsize=12)
plt.title('Story Diffusion LM Training', fontsize=14)
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

In [None]:
#@title 🎧 Listen: Todo
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/10_todo.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 5. 🔧 Your Turn: Implement Story Generation

### TODO: Complete the story generation function

In [None]:
@torch.no_grad()
def generate_story(model, tokenizer, seq_len=SEQ_LEN, n_steps=15,
                   temperature=0.8):
    """Generate a story using iterative confidence-based unmasking.

    Args:
        model: Trained StoryDiffusionLM
        tokenizer: SimpleTokenizer
        seq_len: Length of sequence to generate
        n_steps: Number of denoising steps
        temperature: Sampling temperature (lower = more conservative)

    Returns:
        story_text: The generated story as a string
        history: List of (text, step_label) tuples for visualization
    """
    model.eval()
    x = torch.full((1, seq_len), MASK_TOKEN_ID, dtype=torch.long, device=device)
    history = [(tokenizer.decode(x[0].tolist()), 'Start')]

    for s in range(n_steps, 0, -1):
        t = torch.tensor([[s / n_steps]], device=device, dtype=torch.float)
        logits = model(x, t)

        # ============ TODO ============
        # Step 1: Apply temperature scaling to logits
        #         Divide logits by the temperature parameter
        scaled_logits = ???  # YOUR CODE

        # Step 2: Compute probabilities
        probs = ???  # YOUR CODE: softmax of scaled_logits

        # Step 3: Sample tokens from the distributions
        sampled = torch.multinomial(
            probs.view(-1, VOCAB_SIZE), num_samples=1
        ).view(1, -1)

        # Step 4: Compute confidence for each sampled token
        confidence = ???  # YOUR CODE: probability of the sampled token

        # Step 5: Only consider positions that are still masked
        is_masked = (x == MASK_TOKEN_ID)

        # Step 6: Compute how many tokens to unmask
        n_unmask = max(1, int(is_masked.sum().item() / s))

        # Step 7: Select the most confident masked positions
        conf = confidence.clone()
        conf[~is_masked] = -float('inf')
        _, top_idx = conf.topk(min(n_unmask, max(1, is_masked.sum().item())), dim=-1)

        # Step 8: Unmask those positions
        x.scatter_(1, top_idx, sampled.gather(1, top_idx))
        # ==============================

        if s % max(1, n_steps // 5) == 0 or s == 1:
            history.append((tokenizer.decode(x[0].tolist()), f'Step {n_steps - s + 1}'))

    story = tokenizer.decode(x[0].tolist())
    return story, history

In [None]:
# ✅ Verification
try:
    story, history = generate_story(model, tokenizer, n_steps=15)
    assert isinstance(story, str), "Should return a string"
    assert "[M]" not in story, "All masks should be resolved"
    print("✅ Story generation works!")
    print(f"\nGenerated story:\n{story}")
except Exception as e:
    print(f"❌ Error: {e}")

In [None]:
#@title 🎧 Listen: Post Todo
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/11_post_todo.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

---
### ✋ Stop and Think
Before seeing the solution, consider:
1. What effect does temperature have? (Low = safe/repetitive, High = creative/chaotic)
2. Why do we divide logits by temperature before softmax?

---

### Solution
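Dividing the logits by $T$ before the softmax scales every gap between logits by $1/T$: $T < 1$ widens the gaps and sharpens the distribution, $T > 1$ shrinks them and flattens it, and the ranking of tokens never changes. A toy illustration with made-up logits:

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([2.0, 1.0, 0.5, -1.0])  # toy scores for a 4-word vocabulary

top_probs, entropies = [], []
for T in [0.5, 0.8, 1.0, 1.5]:
    p = F.softmax(logits / T, dim=-1)
    top_probs.append(p.max().item())
    entropies.append(-(p * p.log()).sum().item())  # entropy in nats
    print(f"T={T}: argmax={p.argmax().item()}, top prob={p.max().item():.3f}, "
          f"entropy={entropies[-1]:.3f}")
```

In `generate_story`, the confidence used to pick which positions to reveal is read from the same tempered probabilities, so temperature also changes which positions get revealed first, not just which words are sampled.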

In [None]:
@torch.no_grad()
def generate_story(model, tokenizer, seq_len=SEQ_LEN, n_steps=15,
                   temperature=0.8):
    """Generate a story using iterative confidence-based unmasking."""
    model.eval()
    x = torch.full((1, seq_len), MASK_TOKEN_ID, dtype=torch.long, device=device)
    history = [(tokenizer.decode(x[0].tolist()), 'Start')]

    for s in range(n_steps, 0, -1):
        t = torch.tensor([[s / n_steps]], device=device, dtype=torch.float)
        logits = model(x, t)

        # Temperature scaling
        scaled_logits = logits / temperature
        probs = F.softmax(scaled_logits, dim=-1)

        # Sample tokens
        sampled = torch.multinomial(
            probs.view(-1, VOCAB_SIZE), num_samples=1
        ).view(1, -1)

        # Confidence = probability of sampled token
        confidence = probs.gather(-1, sampled.unsqueeze(-1)).squeeze(-1)

        # Only unmask among currently masked positions
        is_masked = (x == MASK_TOKEN_ID)
        n_unmask = max(1, int(is_masked.sum().item() / s))

        # Pick most confident masked positions
        conf = confidence.clone()
        conf[~is_masked] = -float('inf')
        k = min(n_unmask, max(1, is_masked.sum().item()))
        _, top_idx = conf.topk(k, dim=-1)
        x.scatter_(1, top_idx, sampled.gather(1, top_idx))

        if s % max(1, n_steps // 5) == 0 or s == 1:
            history.append((tokenizer.decode(x[0].tolist()), f'Step {n_steps - s + 1}'))

    story = tokenizer.decode(x[0].tolist())
    return story, history

In [None]:
#@title 🎧 Listen: Stories
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/12_stories.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 6. Putting It All Together
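Each call to `generate_story` reveals `max(1, int(masked / s))` positions at step `s` (counting down from `n_steps`), so the reveal rate is nearly constant and the last step fills whatever is left. A sketch that reproduces that count without running the model (`unmask_schedule` is a hypothetical helper; unlike `generate_story`, it stops revealing once nothing is masked):

```python
def unmask_schedule(n_masked, n_steps):
    """Number of positions revealed at each denoising step (sketch of generate_story's rule)."""
    counts = []
    for s in range(n_steps, 0, -1):
        if n_masked == 0:
            counts.append(0)
            continue
        k = min(max(1, n_masked // s), n_masked)  # same as max(1, int(n_masked / s)), capped
        counts.append(k)
        n_masked -= k
    return counts

print(unmask_schedule(64, 15))  # [4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5]
```

Raising `n_steps` to 20, as the cells below do, lowers the rate to about 3 tokens per step, giving the model more chances to condition on what it has already revealed.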

In [None]:
# Generate and display multiple stories
print("=" * 60)
print("GENERATED STORIES")
print("=" * 60)

for i in range(5):
    story, _ = generate_story(model, tokenizer, n_steps=20, temperature=0.8)
    print(f"\nStory {i+1}:")
    print(f"  {story}")
    print(f"  {'─' * 50}")

In [None]:
#@title 🎧 Listen: Visualization
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/13_visualization.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 7. 📊 Visualizing the Unmasking Process

In [None]:
def visualize_story_generation(history):
    """Show how a story materializes from masks step by step."""
    fig, ax = plt.subplots(figsize=(16, len(history) * 0.6 + 1))

    for row, (text, label) in enumerate(history):
        words = text.split()
        x_offset = 0
        for word in words:
            if word == "[M]":
                color = '#e0e0e0'
                text_color = '#999999'
            else:
                # Color real words by their first character
                hue = (ord(word[0]) % 10) / 10
                color = plt.cm.Pastel1(hue)
                text_color = '#333333'

            ax.text(x_offset, -row, word + " ",
                    fontsize=10, fontfamily='monospace',
                    color=text_color,
                    bbox=dict(boxstyle='round,pad=0.15', facecolor=color,
                              edgecolor='none', alpha=0.7))
            x_offset += len(word) + 1.5

        ax.text(-3, -row, label, fontsize=10, fontweight='bold',
                ha='right', va='center', color='#666666')

    ax.set_xlim(-5, 80)
    ax.set_ylim(-len(history) + 0.5, 1)
    ax.axis('off')
    ax.set_title('Story Generation: Watch Words Materialize',
                 fontsize=15, pad=20)
    plt.tight_layout()
    plt.show()


# Generate one story with full history
story, history = generate_story(model, tokenizer, n_steps=15, temperature=0.8)
visualize_story_generation(history)

In [None]:
#@title 🎧 Listen: Word Order
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/14_word_order.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### Which Words Appear First?
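The aggregation in the next cell labels each unmasking step as "early" or "late" using integer arithmetic on the number of steps. A small sketch of the same rule (`bucket` is a hypothetical helper) shows exactly which of the 20 steps land in each bucket:

```python
def bucket(step, n_steps_used):
    """Same early/late split as the aggregation loop (sketch)."""
    mid = n_steps_used // 2
    if step <= mid // 2:
        return "early"
    if step >= n_steps_used - mid // 2:
        return "late"
    return None  # middle steps are not counted

n = 20
early_steps = [s for s in range(n) if bucket(s, n) == "early"]
late_steps = [s for s in range(n) if bucket(s, n) == "late"]
print("early:", early_steps)  # [0, 1, 2, 3, 4, 5]
print("late: ", late_steps)   # [15, 16, 17, 18, 19]
```

So the early bucket spans six steps, the late bucket five, and the nine middle steps (6 to 14) are ignored. The final step, which reveals every remaining token, always falls in the late bucket.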

In [None]:
@torch.no_grad()
def track_unmasking_order(model, tokenizer, n_steps=20):
    """Track the order in which tokens are unmasked during generation."""
    model.eval()
    x = torch.full((1, SEQ_LEN), MASK_TOKEN_ID, dtype=torch.long, device=device)
    unmask_order = [None] * SEQ_LEN  # When each position was unmasked

    for step_num, s in enumerate(range(n_steps, 0, -1)):
        t = torch.tensor([[s / n_steps]], device=device, dtype=torch.float)
        logits = model(x, t)
        probs = F.softmax(logits / 0.8, dim=-1)
        sampled = torch.multinomial(probs.view(-1, VOCAB_SIZE), 1).view(1, -1)
        confidence = probs.gather(-1, sampled.unsqueeze(-1)).squeeze(-1)

        is_masked = (x == MASK_TOKEN_ID)
        n_unmask = max(1, int(is_masked.sum().item() / s))

        conf = confidence.clone()
        conf[~is_masked] = -float('inf')
        k = min(n_unmask, max(1, is_masked.sum().item()))
        _, top_idx = conf.topk(k, dim=-1)

        # Record which step each position was unmasked
        for idx in top_idx[0].tolist():
            if unmask_order[idx] is None:
                unmask_order[idx] = step_num

        x.scatter_(1, top_idx, sampled.gather(1, top_idx))

    final_tokens = [tokenizer.idx2word.get(x[0, i].item(), '?') for i in range(SEQ_LEN)]
    return final_tokens, unmask_order


# Run multiple times and aggregate
early_words = Counter()
late_words = Counter()

for _ in range(50):
    tokens, order = track_unmasking_order(model, tokenizer)
    n_steps_used = max(o for o in order if o is not None) + 1
    mid = n_steps_used // 2

    for tok, step in zip(tokens, order):
        if tok in ('[PAD]', '[MASK]', '[UNK]'):
            continue
        if step is not None:
            if step <= mid // 2:
                early_words[tok] += 1
            elif step >= n_steps_used - mid // 2:
                late_words[tok] += 1

# Show results
print("Words that appear FIRST (most confident, easiest):")
for word, count in early_words.most_common(15):
    print(f"  '{word}': {count} times")

print("\nWords that appear LAST (least confident, hardest):")
for word, count in late_words.most_common(15):
    print(f"  '{word}': {count} times")

In [None]:
# 📊 Visualize early vs late words
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

top_early = early_words.most_common(12)
top_late = late_words.most_common(12)

ax1.barh([w for w, _ in top_early], [c for _, c in top_early],
         color='#66bb6a', edgecolor='white')
ax1.set_title('Appear FIRST (Structure Words)', fontsize=13, color='#2e7d32')
ax1.set_xlabel('Frequency')
ax1.invert_yaxis()

ax2.barh([w for w, _ in top_late], [c for _, c in top_late],
         color='#ef5350', edgecolor='white')
ax2.set_title('Appear LAST (Content Words)', fontsize=13, color='#c62828')
ax2.set_xlabel('Frequency')
ax2.invert_yaxis()

plt.suptitle('Unmasking Order: Easy Words First, Hard Words Last', fontsize=15)
plt.tight_layout()
plt.show()

print("\n💡 Look for the pattern: function words (the, a, was, is) tend to appear first,")
print("   while content words (names, adjectives, specific nouns) tend to appear last.")
print("   The model plans structure before content, like a painter sketching before adding detail!")

In [None]:
#@title 🎧 Listen: Gallery
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/15_gallery.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 8. 🎯 Final Output: Story Generation Gallery

In [None]:
def story_gallery(model, tokenizer, n_stories=8, n_steps=20):
    """Generate a gallery of stories with different temperatures."""
    temps = [0.5, 0.7, 0.9, 1.1]

    print("=" * 70)
    print("STORY GALLERY: Diffusion LLM on TinyStories")
    print("=" * 70)

    for temp in temps:
        print(f"\n{'─' * 70}")
        print(f"Temperature = {temp}")
        print(f"{'─' * 70}")
        for i in range(n_stories // len(temps)):
            story, _ = generate_story(model, tokenizer,
                                       n_steps=n_steps, temperature=temp)
            print(f"  {story}")
        print()

story_gallery(model, tokenizer)

In [None]:
#@title 🎧 Listen: Final Viz
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/16_final_viz.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

In [None]:
# 📊 One final beautiful visualization: generation grid
fig, axes = plt.subplots(4, 1, figsize=(18, 8))
step_snapshots = [1, 5, 10, 20]

# Generate one story and capture it at specific steps (inference only, so no gradients)
model.eval()
x = torch.full((1, SEQ_LEN), MASK_TOKEN_ID, dtype=torch.long, device=device)
snapshots = {}
n_total_steps = 20

with torch.no_grad():
    for s in range(n_total_steps, 0, -1):
        t = torch.tensor([[s / n_total_steps]], device=device, dtype=torch.float)
        logits = model(x, t)
        probs = F.softmax(logits / 0.8, dim=-1)
        sampled = torch.multinomial(probs.view(-1, VOCAB_SIZE), 1).view(1, -1)
        confidence = probs.gather(-1, sampled.unsqueeze(-1)).squeeze(-1)
        is_masked = (x == MASK_TOKEN_ID)
        n_unmask = max(1, int(is_masked.sum().item() / s))
        conf = confidence.clone()
        conf[~is_masked] = -float('inf')
        k = min(n_unmask, max(1, is_masked.sum().item()))
        _, top_idx = conf.topk(k, dim=-1)
        x.scatter_(1, top_idx, sampled.gather(1, top_idx))

        step_num = n_total_steps - s + 1
        if step_num in step_snapshots:
            snapshots[step_num] = x[0].cpu().clone()

for ax_idx, step_num in enumerate(step_snapshots):
    ax = axes[ax_idx]
    seq = snapshots[step_num]
    text = tokenizer.decode(seq.tolist())
    words = text.split()

    x_pos = 0
    for word in words[:20]:  # Show first 20 words
        if word == "[M]":
            color = '#bdbdbd'
            ax.text(x_pos, 0.5, word, fontsize=11, fontfamily='monospace',
                    color='#888888', va='center',
                    bbox=dict(boxstyle='round', facecolor=color, alpha=0.5,
                              edgecolor='none'))
        else:
            hue = (hash(word) % 8) / 8
            color = plt.cm.Set3(hue)
            ax.text(x_pos, 0.5, word, fontsize=11, fontfamily='monospace',
                    color='#222222', va='center', fontweight='bold',
                    bbox=dict(boxstyle='round', facecolor=color, alpha=0.7,
                              edgecolor='none'))
        x_pos += len(word) + 1.2

    ax.set_xlim(-1, 65)
    ax.set_ylim(0, 1)
    ax.set_ylabel(f'Step {step_num}', fontsize=12, fontweight='bold',
                  rotation=0, ha='right', va='center')
    ax.axis('off')

plt.suptitle('A Story Materializing Through Diffusion\n'
             '(Gray = [MASK], Colored = revealed words)',
             fontsize=15, y=1.02)
plt.tight_layout()
plt.show()

print("\n🎉 Congratulations! You have built a diffusion language model from scratch")
print("   that generates simple short stories through iterative unmasking!")
print("   The model picked up English word order, basic grammar, and some narrative flow")
print("   from the TinyStories dataset, all in under 10 minutes of training.")

In [None]:
#@title 🎧 Listen: Closing
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/17_closing.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 9. Reflection and Next Steps

### 🤔 Reflection Questions

1. **Quality observations:** What kinds of grammatical patterns did the model learn well? Where does it struggle? (Compare to what a GPT-style model of the same size would produce.)

2. **Temperature effects:** How did temperature affect the stories? At $T = 0.5$, stories are repetitive but grammatical. At $T = 1.1$, stories are more creative but sometimes nonsensical. Why?

3. **Scaling intuition:** LLaDA used 8B parameters and 2.3T tokens. We used ~5.8M parameters and ~30K stories. What do you think would happen if we scaled our data by 10x? Our model by 10x?

### 🏆 Optional Challenges

1. **Prompted generation:** Modify `generate_story` to accept a text prompt. Encode the prompt, set those positions as fixed (never mask them), and let the model fill in the rest.

2. **Longer stories:** Increase `SEQ_LEN` to 128. You may need to reduce `batch_size` to fit in GPU memory. Do longer stories maintain coherence?

3. **Character-level model:** Replace the word tokenizer with a character-level tokenizer. How does this change what the model learns? (Hint: the model now has to learn spelling as well as word order.)

4. **Infilling:** Given a story with a gap in the middle, let the model fill it in. This is something autoregressive models cannot do natively!
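For challenges 1 and 4, one possible starting point: build the initial sequence with the known tokens already in place and [MASK] everywhere else, and make sure each step writes only into masked positions. The helpers below (`make_canvas`, `reveal_step`) are hypothetical, not part of the notebook. They add one guard that `generate_story` lacks: when nothing is left masked, `generate_story` still selects one position (`k` is at least 1) and overwrites it. That cannot happen when generation starts from an all-[MASK] sequence with at least as many positions as steps, but it can once prompt tokens are fixed.

```python
import torch

MASK_TOKEN_ID = 0  # same ID as the notebook's tokenizer

def make_canvas(seq_len, fixed):
    """Initial sequence for prompted generation or infilling (hypothetical helper).

    fixed: dict {position: token_id} of tokens that must never be masked.
    """
    x = torch.full((1, seq_len), MASK_TOKEN_ID, dtype=torch.long)
    for pos, tok in fixed.items():
        x[0, pos] = tok
    return x

def reveal_step(x, sampled, confidence, n_unmask):
    """One confidence-based unmasking step that only writes into [MASK] positions."""
    is_masked = x == MASK_TOKEN_ID
    n_masked = int(is_masked.sum())
    if n_masked == 0:
        return x  # nothing left to fill: leave fixed and revealed tokens alone
    conf = confidence.masked_fill(~is_masked, float("-inf"))
    _, top_idx = conf.topk(min(n_unmask, n_masked), dim=-1)
    return x.scatter(1, top_idx, sampled.gather(1, top_idx))

# Toy run: positions 0, 1 and 4 are fixed; the "model" proposes token 9 everywhere
x = make_canvas(6, {0: 5, 1: 6, 4: 7})
sampled = torch.full((1, 6), 9, dtype=torch.long)
confidence = torch.arange(6, dtype=torch.float).unsqueeze(0)  # later positions more confident

step1 = reveal_step(x, sampled, confidence, n_unmask=1)
step2 = reveal_step(step1, sampled, confidence, n_unmask=10)
step3 = reveal_step(step2, sampled, confidence, n_unmask=1)
print(step1.tolist(), step2.tolist(), step3.tolist())
```

To adapt `generate_story`, you could replace its all-[MASK] initialization with `make_canvas(...).to(device)` and its top-k/scatter lines with a call like `reveal_step`.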