In [None]:
# 🔧 Setup: Run this cell first!
# Check GPU availability and set random seeds

import torch
import sys

# Check GPU
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"✅ GPU available: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    device = torch.device('cpu')
    print("⚠️ No GPU detected. Some cells may run slowly.")
    print("   Go to Runtime → Change runtime type → GPU")

print(f"\n📦 Python {sys.version.split()[0]}")
print(f"🔥 PyTorch {torch.__version__}")

# Set random seeds for reproducibility
import random
import numpy as np

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print(f"🎲 Random seed set to {SEED}")

%matplotlib inline

# Case Study: Real-Time Clinical Documentation with Diffusion Language Models
## Implementation Notebook

*MedScribe AI -- Building a masked diffusion language model for clinical note generation*
*Estimated time: 60-75 minutes (including ~10 minutes of training)*

# 🤖 AI Teaching Assistant

Need help with this notebook? Open the **AI Teaching Assistant**. It has already read this entire notebook and can help with concepts, code, and exercises.

**[👉 Open AI Teaching Assistant](https://pods.vizuara.ai/courses/diffusion-llms/practice/0/assistant)**

*Tip: Open it in a separate tab and work through this notebook side-by-side.*


## 1. Setup and Environment

Before we begin, let us set up our environment and install the necessary dependencies. This notebook is designed to run on Google Colab with a T4 GPU.

In [None]:
# Check GPU availability
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Install required packages
!pip install -q datasets rouge-score numpy matplotlib seaborn tqdm

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import json
import re
import time
from collections import Counter
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# Reproducibility
torch.manual_seed(42)
np.random.seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

# Plot styling
plt.rcParams['figure.figsize'] = (10, 6)
plt.rcParams['font.size'] = 12
sns.set_style("whitegrid")
print("All imports ready.")

## 2. Industry Context: The Clinical Documentation Problem

MedScribe AI is a health-tech startup building an AI-powered clinical documentation assistant. Their current autoregressive model generates structured SOAP notes (Subjective, Objective, Assessment, Plan) from patient-physician conversations, but it suffers from three critical limitations:

1. **Latency**: 12-20 seconds per note, too slow for real-time use during encounters
2. **No bidirectional editing**: Edits in one section cannot propagate to other sections
3. **Brittle infilling**: Completing partially-filled templates yields poor quality (52% vs 68% acceptance rate)

In this notebook, we will build a **masked diffusion language model** that addresses all three problems by generating clinical notes through iterative unmasking rather than left-to-right token generation.

Our goal: generate a complete SOAP note in under 500ms by predicting all tokens in parallel across 15-25 denoising steps, while maintaining physician acceptance quality.

## 3. Data Acquisition and Preprocessing

We use a synthetic clinical notes dataset that mimics the structure of MIMIC-III discharge summaries. Each note follows the SOAP format with clearly delimited sections.

In a production setting, MedScribe would train on de-identified MIMIC-III data under a PhysioNet credential. For this notebook, we generate realistic synthetic data that captures the key structural properties: section headers, clinical vocabulary, and variable note lengths.

In [None]:
# Synthetic clinical note generator
# In production, this would be replaced by MIMIC-III data loading

SECTION_HEADERS = ["HPI:", "ROS:", "EXAM:", "ASSESSMENT:", "PLAN:"]
SECTION_LABELS = {"HPI": 0, "ROS": 1, "EXAM": 2, "ASSESSMENT": 3, "PLAN": 4, "OTHER": 5}

# Clinical vocabulary pools for synthetic data generation
HPI_PHRASES = [
    "patient presents with", "chief complaint of", "reports onset of",
    "denies any", "history of", "states that symptoms began",
    "worsening over the past", "associated with", "no prior episodes of",
    "previously treated with", "medication compliance has been",
    "symptoms include", "pain described as", "located in the",
    "radiating to", "aggravated by", "relieved by", "duration of",
    "frequency of episodes", "last seen by physician on",
]

ROS_PHRASES = [
    "negative for fever", "denies chest pain", "no shortness of breath",
    "reports mild fatigue", "denies nausea or vomiting", "no weight changes",
    "positive for headache", "denies dizziness", "no vision changes",
    "reports occasional", "constitutional symptoms absent",
    "cardiovascular review unremarkable", "respiratory review negative",
    "gastrointestinal symptoms denied", "musculoskeletal pain noted",
]

EXAM_PHRASES = [
    "vitals BP", "HR", "RR", "temp", "SpO2", "BMI",
    "general appearance alert and oriented", "no acute distress",
    "lungs clear to auscultation bilaterally", "heart regular rate and rhythm",
    "abdomen soft nontender", "extremities no edema",
    "neurological exam grossly intact", "skin warm and dry",
    "HEENT normocephalic atraumatic", "neck supple no lymphadenopathy",
]

ASSESSMENT_PHRASES = [
    "type 2 diabetes mellitus", "essential hypertension",
    "hyperlipidemia", "chronic kidney disease stage",
    "osteoarthritis of the", "major depressive disorder",
    "generalized anxiety disorder", "chronic low back pain",
    "well controlled", "poorly controlled", "stable",
    "acute exacerbation of", "new diagnosis of", "suspected",
    "differential includes", "consistent with", "likely secondary to",
]

PLAN_PHRASES = [
    "continue current medications", "increase dose of", "add",
    "refer to specialist", "follow up in", "weeks",
    "obtain labs including", "CBC CMP lipid panel",
    "schedule imaging", "CT scan", "MRI of",
    "lifestyle modifications discussed", "diet and exercise counseling",
    "return if symptoms worsen", "patient educated on",
    "prescription sent to pharmacy", "prior authorization submitted",
]


def generate_section(phrases, min_phrases=3, max_phrases=8):
    """Generate a synthetic clinical section from phrase pools."""
    n = np.random.randint(min_phrases, max_phrases + 1)
    selected = np.random.choice(phrases, size=n, replace=False)
    text = " ".join(selected)
    # Fill in random numeric values for vitals and disease staging
    text = text.replace("BP", f"BP {np.random.randint(110,160)}/{np.random.randint(60,95)}")
    text = text.replace("HR", f"HR {np.random.randint(60,100)}")
    text = text.replace("temp", f"temp {np.random.uniform(97.5, 99.5):.1f}")
    text = text.replace("SpO2", f"SpO2 {np.random.randint(94,100)}%")
    text = text.replace("stage", f"stage {np.random.randint(1,5)}")
    return text


def generate_synthetic_note():
    """Generate a complete synthetic SOAP note."""
    sections = {
        "HPI": generate_section(HPI_PHRASES, 4, 8),
        "ROS": generate_section(ROS_PHRASES, 3, 6),
        "EXAM": generate_section(EXAM_PHRASES, 4, 7),
        "ASSESSMENT": generate_section(ASSESSMENT_PHRASES, 2, 5),
        "PLAN": generate_section(PLAN_PHRASES, 3, 7),
    }

    note = ""
    for header in SECTION_HEADERS:
        section_name = header.replace(":", "")
        note += f"{header} {sections[section_name]} "
    return note.strip(), sections


# Generate dataset
NUM_NOTES = 5000
notes_data = []
for i in range(NUM_NOTES):
    note_text, sections = generate_synthetic_note()
    notes_data.append({"text": note_text, "sections": sections})

print(f"Generated {len(notes_data)} synthetic clinical notes")
print(f"\nSample note:\n{notes_data[0]['text'][:500]}...")

### 3.1 Section Parsing

A critical preprocessing step is identifying where each SOAP section begins and ends in the tokenized sequence. This is needed for: (a) the section structure loss term, and (b) evaluating whether the model maintains correct note structure.
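
To make the idea concrete, here is a minimal sketch of per-token section labelling on a toy note. It assumes whitespace tokenization, and the names `toy_headers`, `OTHER`, and `label_tokens` are illustrative only; they are not part of the notebook's code.

```python
# Toy sketch: each whitespace token takes the label of the most recent header.
toy_headers = {"HPI:": 0, "ROS:": 1, "EXAM:": 2, "ASSESSMENT:": 3, "PLAN:": 4}
OTHER = 5  # label for tokens that appear before any header


def label_tokens(note_text):
    """Return (tokens, labels) with one section label per whitespace token."""
    tokens = note_text.split()
    labels = []
    current = OTHER
    for tok in tokens:
        if tok in toy_headers:  # a header opens a new section
            current = toy_headers[tok]
        labels.append(current)  # the header token belongs to its own section
    return tokens, labels


tokens, labels = label_tokens("HPI: chest pain PLAN: follow up")
for tok, lab in zip(tokens, labels):
    print(f"{tok:>8} -> {lab}")
```

The label sequence has the same length as the token sequence, which is what a per-token structure loss or a structure check at evaluation time would need.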

In [None]:
def parse_section_labels(note_text, tokenizer):
    """
    Parse a clinical note and return per-token section labels.

    Given a note with section headers (HPI:, ROS:, EXAM:, ASSESSMENT:, PLAN:),
    return a list of integer labels the same length as the tokenized note,
    where each label indicates which SOAP section that token belongs to.

    Args:
        note_text: Raw clinical note text with section headers
        tokenizer: Tokenizer with encode() method

    Returns:
        tokens: List[int], token IDs from tokenizer.encode()
        labels: List[int], section label for each token (0=HPI, 1=ROS,
                2=EXAM, 3=ASSESSMENT, 4=PLAN, 5=OTHER)

    Note:
        Token labels are derived from whitespace-delimited words, so this
        assumes a whitespace tokenizer such as SimpleTokenizer.
    """
    tokens = tokenizer.encode(note_text)

    # Build a character-to-section mapping
    char_sections = [SECTION_LABELS["OTHER"]] * len(note_text)
    for header in SECTION_HEADERS:
        section_name = header.replace(":", "")
        # Find all occurrences of this header
        start = 0
        while True:
            idx = note_text.find(header, start)
            if idx == -1:
                break
            # Find the next header or end of string
            next_header_pos = len(note_text)
            for other_header in SECTION_HEADERS:
                pos = note_text.find(other_header, idx + len(header))
                if pos != -1 and pos < next_header_pos:
                    next_header_pos = pos
            # Label all characters in this section
            for i in range(idx, next_header_pos):
                char_sections[i] = SECTION_LABELS[section_name]
            start = idx + len(header)

    # Each token takes the label of its first character
    labels = [char_sections[m.start()] for m in re.finditer(r"\S+", note_text)]
    if len(labels) != len(tokens):
        raise ValueError(
            "parse_section_labels assumes one token per whitespace-delimited word"
        )

    return tokens, labels

### TODO: Implement Section Filtering

Students must implement the filtering function that ensures each note in the dataset meets quality requirements for training.

In [None]:
def filter_notes(notes_data, tokenizer, min_tokens=50, max_tokens=256,
                 min_sections=3):
    """
    Filter clinical notes based on quality criteria.

    A note passes the filter if:
    1. Its tokenized length is between min_tokens and max_tokens
    2. It contains at least min_sections distinct SOAP sections
    3. No single section exceeds 60% of the total note length

    Args:
        notes_data: List of dicts with 'text' and 'sections' keys
        tokenizer: Tokenizer with encode() method
        min_tokens: Minimum number of tokens per note
        max_tokens: Maximum number of tokens per note
        min_sections: Minimum number of distinct sections required

    Returns:
        filtered: List of dicts that pass all criteria
        stats: Dict with counts of notes filtered by each criterion

    Hints:
        - Use tokenizer.encode() to get the token count
        - Count distinct sections by checking which SECTION_HEADERS
          appear in the note text
        - For the 60% rule, tokenize each section individually and
          compare to total length
    """
    # ============ TODO ============
    # Step 1: Initialize filtered list and stats counters
    # Step 2: For each note, check the three criteria
    # Step 3: Track why notes were filtered (too short, too long,
    #         too few sections, section imbalance)
    # Step 4: Return the filtered list and stats dict
    # ==============================

    filtered = ???  # YOUR CODE HERE
    stats = ???     # YOUR CODE HERE

    return filtered, stats

In [None]:
# Verification cell -- run after implementing filter_notes

# Simple tokenizer for testing
class SimpleTokenizer:
    def __init__(self, vocab_size=2000):
        self.vocab_size = vocab_size
        self.word2idx = {"[PAD]": 0, "[MASK]": 1, "[UNK]": 2}
        self.idx2word = {0: "[PAD]", 1: "[MASK]", 2: "[UNK]"}
        self._next_idx = 3

    def build_vocab(self, texts, max_vocab=2000):
        word_counts = Counter()
        for text in texts:
            word_counts.update(text.lower().split())
        for word, _ in word_counts.most_common(max_vocab - 3):
            if word not in self.word2idx:
                self.word2idx[word] = self._next_idx
                self.idx2word[self._next_idx] = word
                self._next_idx += 1
        self.vocab_size = len(self.word2idx)

    def encode(self, text):
        return [self.word2idx.get(w, 2) for w in text.lower().split()]

    def decode(self, ids):
        return " ".join(self.idx2word.get(i, "[UNK]") for i in ids)


tokenizer = SimpleTokenizer()
tokenizer.build_vocab([n["text"] for n in notes_data])
print(f"Vocabulary size: {tokenizer.vocab_size}")

filtered, stats = filter_notes(notes_data, tokenizer)
print(f"\nFiltering results:")
print(f"  Original notes: {len(notes_data)}")
print(f"  Filtered notes: {len(filtered)}")
print(f"  Filter stats: {stats}")
assert len(filtered) > 0, "No notes passed filtering -- check your implementation"
assert len(filtered) < len(notes_data), "All notes passed -- filter is too lenient"
print("\nFilter implementation looks correct!")

### Thought Questions: Data Preprocessing

1. Why do we set a maximum token length? What happens to a diffusion model if training sequences have highly variable lengths?
2. Why is the "no single section exceeds 60% of the note" rule important for training a balanced model?
3. In production, MedScribe would use MIMIC-III data. What additional preprocessing challenges would real clinical text introduce compared to our synthetic data?

## 4. Exploratory Data Analysis

Before building the model, we need to understand the structure and distribution of our clinical notes dataset.

In [None]:
# Note length distribution
note_lengths = [len(tokenizer.encode(n["text"])) for n in filtered]

fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Token length distribution
axes[0].hist(note_lengths, bins=30, edgecolor='black', alpha=0.7, color='steelblue')
axes[0].set_xlabel("Tokens per Note")
axes[0].set_ylabel("Count")
axes[0].set_title("Distribution of Note Lengths")
axes[0].axvline(np.mean(note_lengths), color='red', linestyle='--',
                label=f'Mean: {np.mean(note_lengths):.0f}')
axes[0].legend()

# Section length distribution
section_lengths = {s: [] for s in SECTION_HEADERS}
for note in filtered:
    for header in SECTION_HEADERS:
        section_name = header.replace(":", "")
        if section_name in note["sections"]:
            section_lengths[header].append(
                len(tokenizer.encode(note["sections"][section_name]))
            )

section_means = [np.mean(section_lengths[h]) for h in SECTION_HEADERS]
section_stds = [np.std(section_lengths[h]) for h in SECTION_HEADERS]
axes[1].bar(range(len(SECTION_HEADERS)), section_means, yerr=section_stds,
            capsize=5, color=['#2196F3', '#4CAF50', '#FF9800', '#F44336', '#9C27B0'],
            edgecolor='black', alpha=0.8)
axes[1].set_xticks(range(len(SECTION_HEADERS)))
axes[1].set_xticklabels([h.replace(":", "") for h in SECTION_HEADERS])
axes[1].set_ylabel("Tokens per Section")
axes[1].set_title("Section Length Distribution")

# Top vocabulary
all_words = []
for note in filtered:
    all_words.extend(note["text"].lower().split())
word_counts = Counter(all_words)
top_20 = word_counts.most_common(20)
words, counts = zip(*top_20)
axes[2].barh(range(len(words)), counts, color='steelblue', edgecolor='black', alpha=0.7)
axes[2].set_yticks(range(len(words)))
axes[2].set_yticklabels(words)
axes[2].invert_yaxis()
axes[2].set_xlabel("Frequency")
axes[2].set_title("Top 20 Words in Clinical Notes")

plt.tight_layout()
plt.show()

print(f"\nDataset summary:")
print(f"  Total notes: {len(filtered)}")
print(f"  Mean note length: {np.mean(note_lengths):.1f} tokens")
print(f"  Min/Max length: {min(note_lengths)}/{max(note_lengths)} tokens")

### TODO: EDA -- Section Ordering Analysis

Analyze how consistent the SOAP section ordering is across the dataset. In real clinical notes, sections sometimes appear out of order or are missing entirely.

In [None]:
def analyze_section_ordering(notes_data):
    """
    Analyze the section ordering patterns in the clinical notes dataset.

    For each note, determine:
    1. Which sections are present
    2. Whether sections appear in the standard SOAP order
    3. Which sections are most frequently missing

    Args:
        notes_data: List of dicts with 'text' key

    Returns:
        analysis: Dict with keys:
            - 'correct_order_pct': float, percentage of notes with correct ordering
            - 'missing_sections': Dict[str, int], count of missing sections by name
            - 'section_presence': Dict[str, float], fraction of notes containing each section

    Hints:
        - Standard order is: HPI, ROS, EXAM, ASSESSMENT, PLAN
        - Use str.find() to locate each section header in the note text
        - A section is in correct order if its position is after all
          preceding sections' positions
    """
    # ============ TODO ============
    # Step 1: For each note, find the character position of each section header
    # Step 2: Check if positions are in ascending order (standard SOAP order)
    # Step 3: Count which sections are missing from each note
    # Step 4: Compute summary statistics
    # ==============================

    analysis = ???  # YOUR CODE HERE

    return analysis

In [None]:
# Verification cell
analysis = analyze_section_ordering(filtered)
print("Section ordering analysis:")
print(f"  Notes with correct SOAP order: {analysis['correct_order_pct']:.1f}%")
print(f"\n  Section presence rates:")
for section, rate in analysis['section_presence'].items():
    print(f"    {section}: {rate:.1%}")
print(f"\n  Missing section counts:")
for section, count in analysis['missing_sections'].items():
    print(f"    {section}: {count}")
assert 'correct_order_pct' in analysis, "Missing 'correct_order_pct' key"
assert 'section_presence' in analysis, "Missing 'section_presence' key"
print("\nEDA implementation looks correct!")

### Thought Questions: EDA

1. Which SOAP sections tend to be longest? Why might the HPI section typically be longer than the PLAN section?
2. If 15% of notes are missing the ROS section, how might this affect the diffusion model's ability to generate complete notes?
3. Looking at the top vocabulary words, what do you notice about the ratio of clinical terms to function words? What does this imply about token difficulty during unmasking?

## 5. Baseline: Autoregressive Clinical Note Generator

Before building our diffusion model, we implement a simple autoregressive baseline. This establishes the quality floor and latency ceiling that the diffusion approach must beat.

In [None]:
class CausalTransformer(nn.Module):
    """Small causal (left-to-right) Transformer for autoregressive note generation."""

    def __init__(self, vocab_size, d_model=256, nhead=4, num_layers=4,
                 max_len=256, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = nn.Embedding(max_len, d_model)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=d_model * 4,
            dropout=dropout, batch_first=True, activation='gelu'
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.output_head = nn.Linear(d_model, vocab_size)
        self.max_len = max_len

    def forward(self, x):
        B, L = x.shape
        positions = torch.arange(L, device=x.device).unsqueeze(0)
        h = self.embedding(x) + self.pos_encoding(positions)
        # Causal mask: each position can only attend to itself and earlier positions
        causal_mask = torch.triu(
            torch.ones(L, L, device=x.device) * float('-inf'), diagonal=1
        )
        h = self.transformer(h, mask=causal_mask)
        return self.output_head(h)

In [None]:
class ClinicalNoteDataset(Dataset):
    """Dataset for clinical notes."""

    def __init__(self, notes, tokenizer, max_len=128):
        self.data = []
        for note in notes:
            ids = tokenizer.encode(note["text"])[:max_len]
            # Pad to max_len
            ids = ids + [tokenizer.word2idx["[PAD]"]] * (max_len - len(ids))
            self.data.append(torch.tensor(ids, dtype=torch.long))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


MAX_LEN = 128
PAD_ID = tokenizer.word2idx["[PAD]"]
MASK_ID = tokenizer.word2idx["[MASK]"]

dataset = ClinicalNoteDataset(filtered, tokenizer, max_len=MAX_LEN)
train_size = int(0.8 * len(dataset))
val_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size - val_size
train_ds, val_ds, test_ds = torch.utils.data.random_split(
    dataset, [train_size, val_size, test_size]
)

train_loader = DataLoader(train_ds, batch_size=64, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=64)
test_loader = DataLoader(test_ds, batch_size=64)

print(f"Train: {len(train_ds)}, Val: {len(val_ds)}, Test: {len(test_ds)}")
print(f"Sequence length: {MAX_LEN}")
print(f"Vocabulary size: {tokenizer.vocab_size}")

In [None]:
# Train the autoregressive baseline
ar_model = CausalTransformer(tokenizer.vocab_size, d_model=256, nhead=4,
                              num_layers=4, max_len=MAX_LEN).to(device)
optimizer = torch.optim.AdamW(ar_model.parameters(), lr=3e-4, weight_decay=0.01)

num_params = sum(p.numel() for p in ar_model.parameters())
print(f"Autoregressive model parameters: {num_params:,}")

ar_losses = []
ar_model.train()
for epoch in range(5):
    epoch_loss = 0
    for batch in train_loader:
        batch = batch.to(device)
        # Input: tokens 0..L-2, Target: tokens 1..L-1
        logits = ar_model(batch[:, :-1])
        target = batch[:, 1:]
        # Mask out padding from loss
        loss_mask = (target != PAD_ID).float()
        loss = F.cross_entropy(logits.reshape(-1, tokenizer.vocab_size),
                                target.reshape(-1), reduction='none')
        loss = (loss * loss_mask.reshape(-1)).sum() / loss_mask.sum()
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(ar_model.parameters(), 1.0)
        optimizer.step()
        epoch_loss += loss.item()
    epoch_loss /= len(train_loader)
    ar_losses.append(epoch_loss)
    print(f"Epoch {epoch+1}: loss = {epoch_loss:.4f}")

print("\nAutoregressive baseline training complete!")

In [None]:
# Evaluate baseline: generation latency
ar_model.eval()

def generate_ar(model, tokenizer, max_len=128, temperature=0.8):
    """Generate a clinical note autoregressively."""
    # Seed generation with "hpi:", the first token of every training note
    # (the tokenizer lowercases text); fall back to ID 3, the most frequent word
    start_id = tokenizer.word2idx.get("hpi:", 3)
    ids = [start_id]

    start_time = time.time()
    with torch.no_grad():
        for _ in range(max_len - 1):
            x = torch.tensor([ids], device=device)
            logits = model(x)[:, -1, :] / temperature
            probs = F.softmax(logits, dim=-1)
            next_id = torch.multinomial(probs, 1).item()
            if next_id == PAD_ID:
                break
            ids.append(next_id)
    elapsed = time.time() - start_time
    return tokenizer.decode(ids), elapsed

# Generate 10 notes and measure latency
latencies = []
for i in range(10):
    text, elapsed = generate_ar(ar_model, tokenizer)
    latencies.append(elapsed)
    if i == 0:
        print(f"Sample generated note:\n{text[:300]}...\n")

print(f"Autoregressive generation latency:")
print(f"  Mean: {np.mean(latencies)*1000:.1f} ms")
print(f"  Std:  {np.std(latencies)*1000:.1f} ms")
print(f"  This is the latency ceiling our diffusion model must beat.")

### Thought Questions: Baseline

1. Why does autoregressive generation latency scale linearly with sequence length?
2. If MedScribe generates a 400-token SOAP note, how many forward passes does the autoregressive model require? How does this compare to a 20-step diffusion model?
3. What is the fundamental architectural difference that prevents the autoregressive model from doing bidirectional edits?

## 6. Diffusion Transformer Model Design

Now we build the core model: a bidirectional Transformer with time conditioning for masked diffusion generation. The key difference from the autoregressive model is the absence of a causal mask -- every position attends to every other position.
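
The difference in attention patterns can be illustrated with plain Python lists. This is a toy sketch of which positions may attend to which, not the mask tensors used by the models in this notebook, and the helper names are hypothetical.

```python
def causal_visibility(seq_len):
    """Row i marks the positions j that token i may attend to (j <= i)."""
    return [[j <= i for j in range(seq_len)] for i in range(seq_len)]


def bidirectional_visibility(seq_len):
    """Without a causal mask, every position attends to every position."""
    return [[True] * seq_len for _ in range(seq_len)]


def show(matrix):
    """Render a visibility matrix with 'x' for visible and '.' for hidden."""
    return "\n".join(" ".join("x" if v else "." for v in row) for row in matrix)


print("causal:")
print(show(causal_visibility(4)))
print("bidirectional:")
print(show(bidirectional_visibility(4)))
```

In the causal pattern, a masked token near the start of a note cannot see anything written after it. In the bidirectional pattern it can, which is what makes infilling and edit propagation possible.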

### 6.1 Time Conditioning

The masking ratio $t$ is a critical input to the model. It tells the model how much of the note is currently masked, which calibrates its prediction confidence. We convert the scalar $t$ into a $d$-dimensional vector using a small MLP with sinusoidal features.
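
As a rough illustration of the sinusoidal features, the sketch below maps a scalar $t$ to a small feature vector using only the standard library. The function name `sinusoidal_features` and the small `dim=8` are assumptions chosen for readability; in the notebook's module, features of this kind are then passed through the MLP.

```python
import math


def sinusoidal_features(t, dim=8, max_period=10000):
    """Map scalar t to `dim` features: sines, then cosines, at geometrically spaced frequencies."""
    half = dim // 2
    # Frequencies decay from 1 toward 1/max_period
    freqs = [math.exp(-math.log(max_period) * k / half) for k in range(half)]
    angles = [t * f for f in freqs]
    return [math.sin(a) for a in angles] + [math.cos(a) for a in angles]


for t in (0.1, 0.5, 0.9):
    feats = sinusoidal_features(t)
    print(t, [round(v, 3) for v in feats])
```

Nearby values of $t$ produce nearby feature vectors, so the model can interpolate smoothly between masking ratios instead of treating each one as an unrelated category.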

In [None]:
class TimeConditioningMLP(nn.Module):
    """Convert scalar masking ratio t into a d-dimensional conditioning vector."""

    def __init__(self, d_model, max_period=10000):
        super().__init__()
        self.d_model = d_model
        self.mlp = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Linear(d_model * 4, d_model),
        )
        # Precompute sinusoidal frequency bands
        half = d_model // 2
        freqs = torch.exp(
            -np.log(max_period) * torch.arange(half, dtype=torch.float32) / half
        )
        self.register_buffer("freqs", freqs)

    def forward(self, t):
        """
        Args:
            t: (B,) tensor of masking ratios in [0, 1]
        Returns:
            (B, d_model) time conditioning vector
        """
        # Sinusoidal embedding of t
        t_emb = t[:, None] * self.freqs[None, :]  # (B, d_model//2)
        t_emb = torch.cat([torch.sin(t_emb), torch.cos(t_emb)], dim=-1)  # (B, d_model)
        return self.mlp(t_emb)

### 6.2 The Full Diffusion Transformer

In [None]:
class ClinicalDiffusionLM(nn.Module):
    """
    Bidirectional Transformer for masked diffusion clinical note generation.

    Architecture:
    - Token embedding + learned positional embedding
    - Time conditioning via addition
    - Full (non-causal) Transformer encoder
    - Linear output head over vocabulary

    The model sees all positions bidirectionally, which enables:
    1. Parallel prediction of all masked tokens
    2. Bidirectional context propagation for edits
    3. Native infilling of partially-masked notes
    """

    def __init__(self, vocab_size, d_model=256, nhead=4, num_layers=6,
                 max_len=256, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size

        # Token embedding
        self.embedding = nn.Embedding(vocab_size, d_model)

        # Learned positional embedding
        self.pos_encoding = nn.Embedding(max_len, d_model)

        # Time conditioning
        self.time_mlp = TimeConditioningMLP(d_model)

        # Bidirectional Transformer (NO causal mask)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=d_model * 4,
            dropout=dropout, batch_first=True, activation='gelu',
            norm_first=True  # Pre-norm for training stability
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Output head
        self.output_head = nn.Linear(d_model, vocab_size)

        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, x, t, pad_mask=None):
        """
        Args:
            x: (B, L) token IDs (with [MASK] at masked positions)
            t: (B,) masking ratios in [0, 1]
            pad_mask: (B, L) boolean mask, True for padding positions

        Returns:
            logits: (B, L, V) unnormalized log-probabilities over vocabulary
        """
        B, L = x.shape
        positions = torch.arange(L, device=x.device).unsqueeze(0)

        # Embed tokens and add positional encoding
        h = self.embedding(x) + self.pos_encoding(positions)

        # Add time conditioning (broadcast across sequence positions)
        t_emb = self.time_mlp(t)  # (B, d_model)
        h = h + t_emb.unsqueeze(1)  # (B, L, d_model)

        # Bidirectional Transformer (no causal mask)
        if pad_mask is not None:
            h = self.transformer(h, src_key_padding_mask=pad_mask)
        else:
            h = self.transformer(h)

        # Project to vocabulary
        logits = self.output_head(h)
        return logits


# Instantiate model
diff_model = ClinicalDiffusionLM(
    vocab_size=tokenizer.vocab_size,
    d_model=256,
    nhead=4,
    num_layers=6,
    max_len=MAX_LEN,
).to(device)

num_params = sum(p.numel() for p in diff_model.parameters())
print(f"Diffusion model parameters: {num_params:,}")
print(f"Autoregressive baseline had: {sum(p.numel() for p in ar_model.parameters()):,}")

### 6.3 Forward Masking Process

The forward process takes a clean note and masks each token independently with probability $t$. Padding tokens are never masked -- they stay as padding regardless of $t$.
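
The rule can be sketched on a plain list of tokens before looking at the batched tensor version. The names `mask_tokens`, `PAD`, and `MASK` below are illustrative assumptions, and the seeded `random.Random` is only there to make the sketch repeatable.

```python
import random

PAD, MASK = "[PAD]", "[MASK]"


def mask_tokens(tokens, t, rng):
    """Replace each non-padding token with [MASK] independently with probability t."""
    return [MASK if tok != PAD and rng.random() < t else tok for tok in tokens]


rng = random.Random(0)
note = ["hpi:", "patient", "presents", "with", "fatigue", PAD, PAD]
for t in (0.0, 0.5, 1.0):
    print(t, mask_tokens(note, t, rng))
```

At $t = 0$ the note is returned unchanged, at $t = 1$ every real token becomes `[MASK]`, and the trailing padding is untouched in all cases.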

In [None]:
def forward_mask(x, t, pad_id=PAD_ID, mask_id=MASK_ID):
    """
    Apply the forward masking process to a batch of notes.

    Each non-padding token is independently replaced with [MASK]
    with probability t.

    Args:
        x: (B, L) clean token IDs
        t: (B,) masking ratios in [0, 1]
        pad_id: token ID for padding
        mask_id: token ID for [MASK]

    Returns:
        x_t: (B, L) masked token IDs
        mask: (B, L) boolean tensor, True at positions that were masked
    """
    B, L = x.shape
    # Sample masking decisions: Bernoulli with probability t per batch element
    rand = torch.rand(B, L, device=x.device)
    t_expanded = t[:, None].expand(B, L)
    should_mask = rand < t_expanded

    # Never mask padding tokens
    is_pad = (x == pad_id)
    should_mask = should_mask & ~is_pad

    # Apply masking
    x_t = x.clone()
    x_t[should_mask] = mask_id

    return x_t, should_mask

In [None]:
# Visualize the masking process on a sample note
sample = dataset[0].unsqueeze(0).to(device)
sample_text = tokenizer.decode(sample[0].cpu().tolist())
print(f"Original note (first 20 tokens):\n{' '.join(sample_text.split()[:20])}\n")

fig, axes = plt.subplots(1, 4, figsize=(20, 3))
for i, t_val in enumerate([0.2, 0.4, 0.6, 0.9]):
    t = torch.tensor([t_val], device=device)
    x_t, mask = forward_mask(sample, t)
    masked_text = tokenizer.decode(x_t[0].cpu().tolist())
    words = masked_text.split()[:20]

    colors = ['red' if w == '[MASK]' else 'black' for w in words]
    axes[i].set_xlim(0, 1)
    axes[i].set_ylim(0, len(words))
    for j, (word, color) in enumerate(zip(words, colors)):
        axes[i].text(0.05, len(words) - j - 0.5, word, fontsize=9,
                     color=color, family='monospace')
    axes[i].set_title(f"t = {t_val}")
    axes[i].axis('off')

plt.suptitle("Forward Masking Process on a Clinical Note", fontsize=14)
plt.tight_layout()
plt.show()

## 7. Training the Diffusion Model

### 7.1 Diffusion Training Loss

The training objective is the weighted cross-entropy at masked positions:

$$\mathcal{L} = -\mathbb{E}_{t} \left[ \frac{1}{t \cdot L} \sum_{i: x_t^i = \texttt{[MASK]}} \log p_\theta(x_0^i \mid x_t) \right]$$
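
A small worked example may help with the weighting. The sketch below evaluates a single-sample estimate of the objective for one toy note; the probabilities are made up for illustration, and `masked_diffusion_loss` is a hypothetical helper, not the training function used in this notebook.

```python
import math


def masked_diffusion_loss(probs_at_masked, t, seq_len):
    """Single-sample estimate: -(1 / (t * L)) * sum of log p(true token) over masked positions."""
    return -sum(math.log(p) for p in probs_at_masked) / (t * seq_len)


# Toy note of L = 10 tokens at masking ratio t = 0.3, with 3 masked positions.
# Hypothetical model probabilities for the true token at each masked position:
probs = [0.9, 0.5, 0.1]
loss = masked_diffusion_loss(probs, t=0.3, seq_len=10)
print(f"loss = {loss:.4f}")
```

Because roughly $t \cdot L$ tokens are masked on average, dividing by $t \cdot L$ keeps the loss close to a mean negative log-likelihood per masked token, so heavily masked and lightly masked samples contribute on a comparable scale.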

In [None]:
def diffusion_loss(model, x, pad_id=PAD_ID, mask_id=MASK_ID):
    """
    Compute the masked diffusion training loss.

    Steps:
    1. Sample t ~ U(0.02, 1.0) for each batch element
    2. Apply forward masking
    3. Get model predictions at masked positions
    4. Compute weighted cross-entropy loss

    Args:
        model: ClinicalDiffusionLM
        x: (B, L) clean token IDs

    Returns:
        loss: scalar loss value
    """
    B, L = x.shape

    # Sample masking ratios
    t = torch.rand(B, device=x.device) * 0.98 + 0.02  # U(0.02, 1.0)

    # Forward mask
    x_t, mask = forward_mask(x, t, pad_id, mask_id)

    # Padding mask for the Transformer
    pad_mask = (x == pad_id)

    # Get predictions
    logits = model(x_t, t, pad_mask=pad_mask)

    # Cross-entropy at masked positions only
    loss_per_token = F.cross_entropy(
        logits.reshape(-1, model.vocab_size),
        x.reshape(-1),
        reduction='none'
    ).reshape(B, L)

    # Zero out loss at non-masked and padding positions
    loss_per_token = loss_per_token * mask.float()

    # Weight by 1/(t * L) -- the ELBO-derived importance weight.
    # t >= 0.02 by construction, so the denominator is never zero.
    loss_per_sample = loss_per_token.sum(dim=1) / (t * L)

    return loss_per_sample.mean()

### 7.2 Training Loop

In [None]:
optimizer = torch.optim.AdamW(diff_model.parameters(), lr=3e-4, weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=15, eta_min=1e-5)

diff_losses = []
val_losses = []

print("Training diffusion model...")
for epoch in range(15):
    # Train
    diff_model.train()
    epoch_loss = 0
    for batch in train_loader:
        batch = batch.to(device)
        loss = diffusion_loss(diff_model, batch)
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(diff_model.parameters(), 1.0)
        optimizer.step()
        epoch_loss += loss.item()
    epoch_loss /= len(train_loader)
    diff_losses.append(epoch_loss)

    # Validate
    diff_model.eval()
    v_loss = 0
    with torch.no_grad():
        for batch in val_loader:
            batch = batch.to(device)
            v_loss += diffusion_loss(diff_model, batch).item()
    v_loss /= len(val_loader)
    val_losses.append(v_loss)

    scheduler.step()

    print(f"Epoch {epoch+1:2d}: train_loss={epoch_loss:.4f}  val_loss={v_loss:.4f}  "
          f"lr={scheduler.get_last_lr()[0]:.2e}")

print("\nDiffusion model training complete!")

In [None]:
# Training curves
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(range(1, len(diff_losses)+1), diff_losses, 'b-o', label='Train Loss', markersize=4)
ax.plot(range(1, len(val_losses)+1), val_losses, 'r-o', label='Val Loss', markersize=4)
ax.set_xlabel("Epoch")
ax.set_ylabel("Diffusion Loss")
ax.set_title("Diffusion Model Training Progress")
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

## 8. Your Turn: Implement Confidence-Based Generation

This is the core generation algorithm. Starting from a fully masked sequence, the model iteratively unmasks tokens by selecting the most confident predictions at each step.
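
Before implementing the full loop, it can help to see how many tokens get revealed at each step. The sketch below applies the rule `k = max(1, n_masked // (n_steps - s + 1))` from the function's docstring and TODO to a 128-token sequence with 20 steps. It is plain Python with no model involved, and `unmask_schedule` is a hypothetical helper name.

```python
def unmask_schedule(n_tokens, n_steps):
    """Return the number of tokens revealed at each denoising step."""
    remaining = n_tokens
    schedule = []
    for s in range(1, n_steps + 1):
        if remaining == 0:
            break  # nothing left to unmask
        k = min(remaining, max(1, remaining // (n_steps - s + 1)))
        schedule.append(k)
        remaining -= k
    return schedule


schedule = unmask_schedule(n_tokens=128, n_steps=20)
print(schedule)
print("total revealed:", sum(schedule))
```

The schedule spreads the tokens almost evenly across steps, and the final step always clears whatever is still masked because the divisor becomes 1.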

In [None]:
@torch.no_grad()
def generate_diffusion(model, tokenizer, n_steps=20, max_len=128,
                       temperature=0.8, device='cuda'):
    """
    Generate a clinical note using confidence-based iterative unmasking.

    Algorithm:
    1. Start with a sequence of all [MASK] tokens
    2. For each step s = 1, ..., n_steps:
       a. Compute the current masking ratio t = 1 - s/n_steps
       b. Forward pass: get logits at all positions
       c. Divide logits by temperature
       d. Apply softmax to get probabilities
       e. Sample tokens from the distribution
       f. Compute confidence = probability of the sampled token
       g. Determine how many tokens to unmask at this step:
          k = (number of currently masked tokens) // (n_steps - s + 1)
       h. Among currently masked positions, keep the top-k most confident
       i. Remask all other positions
    3. Return the final sequence

    Args:
        model: ClinicalDiffusionLM
        tokenizer: SimpleTokenizer
        n_steps: Number of denoising steps
        max_len: Length of generated sequence
        temperature: Sampling temperature (lower = more conservative)
        device: torch device

    Returns:
        generated_ids: List[int], the generated token IDs
        steps_history: List[List[int]], token IDs at each step (for visualization)

    Hints:
        - Use torch.multinomial(probs, 1) for sampling
        - Use torch.gather to get the probability of each sampled token
        - Use torch.topk to find the most confident positions
        - Masked positions are where x == MASK_ID
    """
    MASK_ID = tokenizer.word2idx["[MASK]"]

    # Start fully masked
    x = torch.full((1, max_len), MASK_ID, dtype=torch.long, device=device)
    steps_history = [x[0].cpu().tolist()]

    for s in range(1, n_steps + 1):
        t = torch.tensor([1.0 - s / n_steps], device=device)

        # ============ TODO ============
        # Step 1: Forward pass to get logits
        #         logits = model(x, t)
        #
        # Step 2: Apply temperature scaling
        #         logits = logits / temperature
        #
        # Step 3: Convert to probabilities with softmax
        #         probs = F.softmax(logits, dim=-1)  (over the vocab dimension)
        #
        # Step 4: Sample tokens at every position
        #         sampled = torch.multinomial(probs[0], 1).squeeze(-1)
        #
        # Step 5: Compute confidence for each sampled token
        #         confidence = torch.gather(probs[0], 1, sampled.unsqueeze(-1)).squeeze(-1)
        #
        # Step 6: Determine which positions are currently masked
        #         is_masked = (x[0] == MASK_ID)
        #
        # Step 7: Calculate how many tokens to unmask this step
        #         n_masked = is_masked.sum().item()
        #         k = max(1, n_masked // (n_steps - s + 1))
        #         (if n_masked == 0 there is nothing left to unmask; leave x unchanged)
        #
        # Step 8: Among masked positions, find the top-k most confident
        #         Set confidence of non-masked positions to -1 so they are not selected
        #         Use torch.topk to get the indices of the k most confident masked positions
        #
        # Step 9: Unmask the top-k positions by assigning their sampled tokens
        #         Keep everything else unchanged
        # ==============================

        pass  # YOUR CODE HERE -- replace this with the steps above

        steps_history.append(x[0].cpu().tolist())

    return x[0].cpu().tolist(), steps_history

In [None]:
# Verification cell -- run after implementing generate_diffusion
generated_ids, history = generate_diffusion(diff_model, tokenizer, n_steps=20,
                                             max_len=MAX_LEN, device=device)
generated_text = tokenizer.decode(generated_ids)
print("Generated clinical note:")
print(generated_text[:500])
print(f"\nTotal tokens: {len(generated_ids)}")
n_masks = sum(1 for t in generated_ids if t == MASK_ID)
print(f"Remaining [MASK] tokens: {n_masks}")
assert n_masks == 0, f"Generation incomplete: {n_masks} masks remaining"
print("\nGeneration function working correctly!")

### Stop and Think

Before moving on, consider these questions:

1. Why do we divide the logits by temperature *before* softmax rather than adjusting the probabilities after softmax?
2. What happens if temperature is very low (e.g., 0.1)? Very high (e.g., 2.0)?
3. Why is confidence-based unmasking better than random unmasking? Think about what information the easy-to-predict tokens give the model when predicting harder tokens.
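
One way to explore question 2 empirically is to apply temperature scaling to a small set of made-up logits and watch how the distribution changes. This is a plain-Python sketch; `softmax_with_temperature` is a hypothetical helper and the logits are invented for the example.

```python
import math


def softmax_with_temperature(logits, temperature):
    """Divide logits by the temperature, then apply a numerically stable softmax."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max so exp() cannot overflow
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.0]  # made-up logits for three candidate tokens
for T in (0.1, 0.8, 2.0):
    probs = softmax_with_temperature(logits, T)
    print(f"T={T}: " + ", ".join(f"{p:.3f}" for p in probs))
```

Try a few other logit vectors and temperatures, and compare what you observe with your answers above.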

## 9. Evaluation: Diffusion vs. Autoregressive

Now we compare our diffusion model against the autoregressive baseline on the key metrics.
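
One of those metrics is ROUGE-L, which scores the longest common subsequence (LCS) of words shared by a generated note and a reference. The worked example below uses two invented sentences and a compact, hypothetical `lcs_length` helper; it is meant only to make the precision, recall, and F1 arithmetic concrete.

```python
def lcs_length(a, b):
    """Length of the longest common subsequence of two word lists."""
    prev = [0] * (len(b) + 1)
    for x in a:
        curr = [0]
        for j, y in enumerate(b, start=1):
            curr.append(prev[j - 1] + 1 if x == y else max(prev[j], curr[j - 1]))
        prev = curr
    return prev[-1]


generated = "patient denies chest pain".split()
reference = "patient reports chest pain today".split()

lcs = lcs_length(generated, reference)   # shared subsequence: patient, chest, pain
precision = lcs / len(generated)         # 3 / 4
recall = lcs / len(reference)            # 3 / 5
f1 = 2 * precision * recall / (precision + recall)
print(lcs, precision, recall, round(f1, 3))  # → 3 0.75 0.6 0.667
```

Note that the words do not have to be adjacent, only in the same order, so ROUGE-L rewards preserved structure even when individual words differ.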

In [None]:
# Generate notes with both models and measure quality + latency

def compute_rouge_l(generated, reference):
    """Compute ROUGE-L F1 score between two strings."""
    gen_words = generated.lower().split()
    ref_words = reference.lower().split()
    if len(gen_words) == 0 or len(ref_words) == 0:
        return 0.0

    # LCS computation
    m, n = len(gen_words), len(ref_words)
    dp = [[0] * (n + 1) for _ in range(m + 1)]
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            if gen_words[i-1] == ref_words[j-1]:
                dp[i][j] = dp[i-1][j-1] + 1
            else:
                dp[i][j] = max(dp[i-1][j], dp[i][j-1])
    lcs = dp[m][n]

    precision = lcs / m if m > 0 else 0
    recall = lcs / n if n > 0 else 0
    if precision + recall == 0:
        return 0.0
    f1 = 2 * precision * recall / (precision + recall)
    return f1


def check_section_structure(text):
    """Check if a generated note has correct SOAP section ordering."""
    positions = []
    for header in SECTION_HEADERS:
        pos = text.lower().find(header.lower())
        positions.append(pos if pos >= 0 else float('inf'))

    # Count present sections
    present = sum(1 for p in positions if p < float('inf'))
    # Check ordering of present sections
    filtered_pos = [p for p in positions if p < float('inf')]
    correct_order = all(a < b for a, b in zip(filtered_pos, filtered_pos[1:]))

    return present, correct_order


# Evaluate both models
n_eval = min(50, len(filtered))  # cap at the number of available reference notes
ref_notes = [filtered[i]["text"] for i in range(n_eval)]

# Diffusion evaluation
diff_rouges = []
diff_latencies = []
diff_sections = []
print("Evaluating diffusion model...")
for i in tqdm(range(n_eval)):
    start = time.time()
    gen_ids, _ = generate_diffusion(diff_model, tokenizer, n_steps=20,
                                     max_len=MAX_LEN, device=device)
    elapsed = time.time() - start
    diff_latencies.append(elapsed)

    gen_text = tokenizer.decode(gen_ids)
    diff_rouges.append(compute_rouge_l(gen_text, ref_notes[i]))
    n_sec, _ = check_section_structure(gen_text)
    diff_sections.append(n_sec)

# Autoregressive evaluation
ar_rouges = []
ar_latencies = []
ar_sections = []
print("Evaluating autoregressive model...")
for i in tqdm(range(n_eval)):
    gen_text, elapsed = generate_ar(ar_model, tokenizer, max_len=MAX_LEN)
    ar_latencies.append(elapsed)
    ar_rouges.append(compute_rouge_l(gen_text, ref_notes[i]))
    n_sec, _ = check_section_structure(gen_text)
    ar_sections.append(n_sec)

# Results table
print("\n" + "="*60)
print("EVALUATION RESULTS: Diffusion vs. Autoregressive")
print("="*60)
print(f"{'Metric':<25} {'Diffusion':>15} {'Autoregressive':>15}")
print("-"*60)
print(f"{'ROUGE-L (mean)':<25} {np.mean(diff_rouges):>15.3f} {np.mean(ar_rouges):>15.3f}")
print(f"{'Latency mean (ms)':<25} {np.mean(diff_latencies)*1000:>15.1f} {np.mean(ar_latencies)*1000:>15.1f}")
print(f"{'Latency P99 (ms)':<25} {np.percentile(diff_latencies, 99)*1000:>15.1f} {np.percentile(ar_latencies, 99)*1000:>15.1f}")
print(f"{'Sections found (mean)':<25} {np.mean(diff_sections):>15.1f} {np.mean(ar_sections):>15.1f}")
print(f"{'Speedup':<25} {np.mean(ar_latencies)/np.mean(diff_latencies):>15.1f}x {1.0:>15.1f}x")
print("="*60)

In [None]:
# Visualization: latency comparison
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Latency distribution
axes[0].hist(np.array(ar_latencies)*1000, bins=20, alpha=0.6, label='Autoregressive',
             color='#F44336', edgecolor='black')
axes[0].hist(np.array(diff_latencies)*1000, bins=20, alpha=0.6, label='Diffusion',
             color='#2196F3', edgecolor='black')
axes[0].set_xlabel("Latency (ms)")
axes[0].set_ylabel("Count")
axes[0].set_title("Generation Latency Distribution")
axes[0].legend()

# Quality vs Speed scatter
axes[1].scatter(np.array(ar_latencies)*1000, ar_rouges, alpha=0.5,
                label='Autoregressive', color='#F44336', s=40)
axes[1].scatter(np.array(diff_latencies)*1000, diff_rouges, alpha=0.5,
                label='Diffusion', color='#2196F3', s=40)
axes[1].set_xlabel("Latency (ms)")
axes[1].set_ylabel("ROUGE-L")
axes[1].set_title("Quality vs. Speed Tradeoff")
axes[1].legend()

plt.tight_layout()
plt.show()

## 10. Visualizing the Unmasking Process

Watch a clinical note crystallize from a sea of [MASK] tokens into structured text. This is the visual payoff of diffusion-based generation.

In [None]:
# Generate a note and visualize the unmasking trajectory
gen_ids, history = generate_diffusion(diff_model, tokenizer, n_steps=20,
                                       max_len=MAX_LEN, device=device)

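# Show the sequence after steps 1, 5, 10, 15, 20.
# history[0] is the initial all-[MASK] state and history[s] is the state after step s,
# so each 0-based step index below is shifted by one when reading from history.
steps_to_show = [0, 4, 9, 14, 19]  # 0-based step indices
fig, axes = plt.subplots(len(steps_to_show), 1, figsize=(16, len(steps_to_show)*2))

for ax_idx, step_idx in enumerate(steps_to_show):
    step_ids = history[step_idx + 1]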
    words = tokenizer.decode(step_ids).split()[:40]  # Show first 40 tokens

    ax = axes[ax_idx]
    ax.set_xlim(0, len(words))
    ax.set_ylim(0, 1)

    for i, word in enumerate(words):
        if word == '[mask]':
            color = '#E0E0E0'
            text_color = '#999999'
        elif word in ['hpi:', 'ros:', 'exam:', 'assessment:', 'plan:']:
            color = '#BBDEFB'
            text_color = '#1565C0'
        else:
            color = '#C8E6C9'
            text_color = '#2E7D32'

        ax.add_patch(plt.Rectangle((i, 0.1), 0.9, 0.8, facecolor=color,
                                    edgecolor='#666666', linewidth=0.5))
        ax.text(i + 0.45, 0.5, word[:8], ha='center', va='center',
                fontsize=7, color=text_color, family='monospace')

    ax.set_ylabel(f"Step {step_idx + 1}", fontsize=11, fontweight='bold')
    ax.set_xticks([])
    ax.set_yticks([])

plt.suptitle("Clinical Note Unmasking Trajectory", fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

## 11. Error Analysis

### TODO: Implement Error Categorization

Analyze the failure modes of the diffusion model by categorizing errors in generated notes.
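
As a starting point for the repetition check, here is one possible way to find word sequences that occur more than once in a piece of text. It is a sketch under simple assumptions (whitespace tokenization, case-insensitive matching), and `find_repeated_ngrams` is a hypothetical helper rather than part of the notebook.

```python
from collections import Counter


def find_repeated_ngrams(text, n=3):
    """Return the n-word phrases that appear more than once in the text."""
    words = text.lower().split()
    ngrams = [" ".join(words[i:i + n]) for i in range(len(words) - n + 1)]
    counts = Counter(ngrams)
    return sorted(phrase for phrase, c in counts.items() if c > 1)


sample = "take aspirin daily and take aspirin daily with food"
print(find_repeated_ngrams(sample))  # → ['take aspirin daily']
```

You would still need to split the note into sections first and apply the check to each section separately.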

In [None]:
def categorize_errors(generated_text, reference_text=None):
    """
    Categorize errors in a generated clinical note.

    Error types:
    1. Section confusion: Content appears under the wrong SOAP section
    2. Repetition: Repeated phrases within a section (a sequence of
       3+ words that occurs more than once)
    3. Missing sections: Standard SOAP sections that are absent
    4. Truncation: Incomplete sentences (ending without punctuation
       or standard clinical phrasing)

    Args:
        generated_text: The generated clinical note text
        reference_text: Optional reference note for comparison

    Returns:
        errors: Dict with keys:
            - 'section_confusion': List of (expected_section, actual_section, content) tuples
            - 'repetitions': List of repeated phrases found
            - 'missing_sections': List of section names that are absent
            - 'truncations': int, count of apparently truncated sentences
            - 'total_errors': int, total error count

    Hints:
        - Use SECTION_HEADERS to check which sections are present
        - To detect repetition, look for 3+ word sequences that appear
          more than once in the same section
        - A truncated sentence is one that ends with a common word rather
          than a period, comma, or clinical term
    """
    # ============ TODO ============
    # Step 1: Check which SOAP sections are present/missing
    # Step 2: For each section, check for repeated phrases
    # Step 3: Check for truncated sentences (heuristic: last word
    #         is not punctuation or a known clinical end-word)
    # Step 4: If reference is provided, check for section confusion
    #         (content from one section appearing in another)
    # Step 5: Return the error dict
    # ==============================

    errors = ???  # YOUR CODE HERE

    return errors

In [None]:
# Verification cell
print("Analyzing errors in 20 generated notes...")
all_errors = {"section_confusion": 0, "repetitions": 0,
              "missing_sections": 0, "truncations": 0}
for i in range(20):
    gen_ids, _ = generate_diffusion(diff_model, tokenizer, n_steps=20,
                                     max_len=MAX_LEN, device=device)
    gen_text = tokenizer.decode(gen_ids)
    errs = categorize_errors(gen_text)
    for k in all_errors:
        if isinstance(errs.get(k), list):
            all_errors[k] += len(errs[k])
        elif isinstance(errs.get(k), int):
            all_errors[k] += errs[k]

print("\nError summary (across 20 notes):")
for error_type, count in all_errors.items():
    print(f"  {error_type}: {count}")
assert isinstance(errs, dict), "categorize_errors should return a dict"
print("\nError analysis implementation complete!")

### Thought Questions: Error Analysis

1. Which error type is most common? Why might diffusion models be prone to this particular failure mode?
2. How might the section structure loss ($\mathcal{L}_{\text{section}}$) from the technical formulation help reduce section confusion errors?
3. If you could add one post-processing rule to fix the most common error, what would it be?

## 12. Scalability and Deployment Benchmarking

### TODO: Inference Benchmarking

Profile the diffusion model's inference performance across different step counts to map the quality-speed tradeoff.
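
The timing part of the benchmark can be sketched independently of the model. The helper below is a minimal, hypothetical example: it uses `time.perf_counter()` (a monotonic, high-resolution clock) in place of `time.time()`, runs a few untimed warm-up calls, and times a dummy workload that stands in for `generate_diffusion`.

```python
import statistics
import time


def time_call(fn, n_warmup=2, n_runs=5):
    """Return (mean_ms, max_ms) over n_runs timed calls to fn()."""
    for _ in range(n_warmup):
        fn()  # warm-up runs are not timed
    timings_ms = []
    for _ in range(n_runs):
        start = time.perf_counter()
        fn()
        # On a GPU, call torch.cuda.synchronize() here before reading the clock,
        # unless fn() already copies its result back to the CPU.
        timings_ms.append((time.perf_counter() - start) * 1000)
    return statistics.mean(timings_ms), max(timings_ms)


def dummy_generate():
    # Stand-in for a call such as generate_diffusion(...).
    return sum(i * i for i in range(50_000))


mean_ms, max_ms = time_call(dummy_generate)
tokens_per_sec = 128 / (mean_ms / 1000)  # throughput for a 128-token sequence
print(f"mean={mean_ms:.2f} ms  max={max_ms:.2f} ms  tok/s={tokens_per_sec:.0f}")
```

In your implementation you would pass a small lambda that calls `generate_diffusion` with the step count under test.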

In [None]:
def benchmark_inference(model, tokenizer, step_counts, n_samples=10,
                        max_len=128, device='cuda'):
    """
    Benchmark diffusion generation across different step counts.

    For each step count, measure:
    - Mean generation latency
    - Mean ROUGE-L quality (against random reference notes)
    - Tokens per second throughput

    Args:
        model: ClinicalDiffusionLM
        tokenizer: SimpleTokenizer
        step_counts: List[int], denoising step counts to benchmark
        n_samples: Number of samples per step count
        max_len: Sequence length
        device: torch device

    Returns:
        results: Dict with keys:
            - 'step_counts': List[int]
            - 'latencies_ms': List[float], mean latency per step count
            - 'rouge_scores': List[float], mean ROUGE-L per step count
            - 'tokens_per_sec': List[float], throughput per step count

    Hints:
        - Use time.time() for latency measurement
        - Warm up with 2 dummy generations before timing
        - Tokens per second = max_len / (latency in seconds)
    """
    # ============ TODO ============
    # Step 1: For each step count, generate n_samples notes
    # Step 2: Measure latency for each generation
    # Step 3: Compute ROUGE-L against reference notes
    # Step 4: Calculate tokens/second throughput
    # Step 5: Return results dict
    # ==============================

    results = ???  # YOUR CODE HERE

    return results

In [None]:
# Verification cell
step_counts = [1, 3, 5, 10, 15, 20]
results = benchmark_inference(diff_model, tokenizer, step_counts,
                               n_samples=5, max_len=MAX_LEN, device=device)

print("Inference benchmark results:")
print(f"{'Steps':>6} {'Latency (ms)':>14} {'ROUGE-L':>10} {'Tok/sec':>10}")
print("-" * 45)
for i, steps in enumerate(results['step_counts']):
    print(f"{steps:>6} {results['latencies_ms'][i]:>14.1f} "
          f"{results['rouge_scores'][i]:>10.3f} {results['tokens_per_sec'][i]:>10.0f}")

assert len(results['step_counts']) == len(step_counts), "Missing step counts"
assert results['latencies_ms'][0] < results['latencies_ms'][-1], \
    "Latency should increase with more steps"
print("\nBenchmark implementation correct!")

In [None]:
# Plot the quality-speed tradeoff
fig, ax1 = plt.subplots(figsize=(10, 6))
ax2 = ax1.twinx()

ax1.plot(results['step_counts'], results['rouge_scores'], 'b-o',
         linewidth=2, markersize=8, label='ROUGE-L Quality')
ax2.plot(results['step_counts'], results['latencies_ms'], 'g-s',
         linewidth=2, markersize=8, label='Latency (ms)')

ax1.set_xlabel("Denoising Steps", fontsize=12)
ax1.set_ylabel("ROUGE-L Score", color='blue', fontsize=12)
ax2.set_ylabel("Latency (ms)", color='green', fontsize=12)
ax1.tick_params(axis='y', labelcolor='blue')
ax2.tick_params(axis='y', labelcolor='green')

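# Mark a candidate sweet spot (a placeholder: the middle of the tested step counts,
# not a value derived from the measurements)
sweet_spot_idx = len(results['step_counts']) // 2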
ax1.axvline(x=results['step_counts'][sweet_spot_idx], color='red',
            linestyle='--', alpha=0.5, label='Sweet Spot')

lines1, labels1 = ax1.get_legend_handles_labels()
lines2, labels2 = ax2.get_legend_handles_labels()
ax1.legend(lines1 + lines2, labels1 + labels2, loc='center right')

plt.title("Quality vs. Speed Tradeoff: Choosing the Right Step Count",
          fontsize=13, fontweight='bold')
plt.tight_layout()
plt.show()
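
If you want a data-driven choice of step count, one simple rule (an assumption made here for illustration, not something the case study prescribes) is to take the smallest step count whose ROUGE-L reaches a fixed fraction of the best score. The numbers below are invented; `pick_sweet_spot` is a hypothetical helper.

```python
def pick_sweet_spot(step_counts, rouge_scores, fraction=0.95):
    """Smallest step count whose ROUGE-L is at least `fraction` of the best score."""
    best = max(rouge_scores)
    for steps, score in sorted(zip(step_counts, rouge_scores)):
        if score >= fraction * best:
            return steps
    return max(step_counts)


# Made-up numbers purely for illustration; substitute your own benchmark results.
toy_steps = [1, 3, 5, 10, 15, 20]
toy_rouge = [0.10, 0.20, 0.28, 0.30, 0.305, 0.31]
print(pick_sweet_spot(toy_steps, toy_rouge))  # → 10
```

With only a handful of samples per step count, the ROUGE-L estimates are noisy, so treat any such choice as provisional.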

## 13. Infilling Demonstration

One of the key advantages of diffusion models over autoregressive models: native infilling. We can fix some sections of a note and let the model generate the rest.
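
Infilling starts by deciding which token positions to re-mask. The sketch below maps a section name to a contiguous range of word indices, assuming whitespace tokenization in which word `i` corresponds to token position `i` (the same approximation the infill cell relies on). `section_word_span`, `toy_headers`, and `toy_note` are hypothetical names used only for this example.

```python
def section_word_span(note_text, section, all_headers):
    """Return (first, last) word indices of a section, header included; last is exclusive."""
    text_lower = note_text.lower()
    header = f"{section.lower()}:"
    start = text_lower.find(header)
    if start < 0:
        return None  # section not present in the note
    # The section ends at the next header, or at the end of the text.
    end = len(note_text)
    for other in all_headers:
        pos = text_lower.find(other.lower(), start + len(header))
        if 0 <= pos < end:
            end = pos
    first = len(note_text[:start].split())
    last = first + len(note_text[start:end].split())
    return first, last


toy_headers = ["HPI:", "PLAN:"]  # hypothetical header list for this example
toy_note = "HPI: cough for three days PLAN: rest and fluids"
print(section_word_span(toy_note, "HPI", toy_headers))   # → (0, 5)
print(section_word_span(toy_note, "PLAN", toy_headers))  # → (5, 9)
```

Positions inside the returned range are set to `[MASK]`; everything outside it stays fixed and serves as context for the model.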

In [None]:
def infill_note(model, tokenizer, partial_note_text, sections_to_generate,
                n_steps=20, max_len=128, temperature=0.8, device='cuda'):
    """
    Generate missing sections of a partially completed clinical note.

    Args:
        model: ClinicalDiffusionLM
        tokenizer: SimpleTokenizer
        partial_note_text: Text with some sections filled in
        sections_to_generate: List of section names to generate (e.g., ["PLAN"])
        n_steps: Denoising steps
        max_len: Sequence length
        temperature: Sampling temperature
        device: torch device

    Returns:
        completed_text: The full note with generated sections
    """
    MASK_ID = tokenizer.word2idx["[MASK]"]
    PAD_ID = tokenizer.word2idx["[PAD]"]

    # Encode the partial note
    tokens = tokenizer.encode(partial_note_text)[:max_len]
    tokens = tokens + [PAD_ID] * (max_len - len(tokens))
    x = torch.tensor([tokens], dtype=torch.long, device=device)

    # Identify which positions to mask (positions belonging to sections_to_generate).
    # Approximation: assumes the tokenizer emits one token per whitespace-separated
    # word, so word index i in the text corresponds to token position i in x.
    text_lower = partial_note_text.lower()
    for section in sections_to_generate:
        header = f"{section.lower()}:"
        start = text_lower.find(header)
        if start >= 0:
            # Find the end of this section (next header or end of text)
            end = len(partial_note_text)
            for other_header in SECTION_HEADERS:
                pos = text_lower.find(other_header.lower(), start + len(header))
                if pos >= 0 and pos < end:
                    end = pos
            # Mask only the contiguous token range covered by this section.
            # (Matching words by value would also mask identical words in other sections.)
            first = len(partial_note_text[:start].split())
            last = min(first + len(partial_note_text[start:end].split()), max_len)
            x[0, first:last] = MASK_ID

    # Run diffusion generation on masked positions only
    for s in range(1, n_steps + 1):
        t = torch.tensor([1.0 - s / n_steps], device=device)
        logits = model(x, t) / temperature
        probs = F.softmax(logits, dim=-1)
        sampled = torch.multinomial(probs[0], 1).squeeze(-1)
        confidence = torch.gather(probs[0], 1, sampled.unsqueeze(-1)).squeeze(-1)

        is_masked = (x[0] == MASK_ID)
        n_masked = is_masked.sum().item()
        if n_masked == 0:
            break

        k = max(1, n_masked // (n_steps - s + 1))
        conf_masked = confidence.clone()
        conf_masked[~is_masked] = -1.0
        _, top_indices = torch.topk(conf_masked, min(k, n_masked))
        x[0, top_indices] = sampled[top_indices]

    return tokenizer.decode(x[0].cpu().tolist())


# Demonstration: infill the PLAN section
partial = filtered[0]["text"]
# Keep everything except the PLAN section
plan_start = partial.lower().find("plan:")
if plan_start >= 0:
    partial_no_plan = partial[:plan_start] + "PLAN: [to be generated]"
else:
    partial_no_plan = partial

print("Partial note (PLAN section removed):")
print(partial_no_plan[:300])
print("\n" + "-"*50 + "\n")

completed = infill_note(diff_model, tokenizer, partial_no_plan,
                        sections_to_generate=["PLAN"],
                        n_steps=20, device=device)
print("Completed note (PLAN generated by diffusion):")
print(completed[:500])

## 14. Ethical and Regulatory Analysis

### TODO: Ethical Impact Assessment

Write a brief ethical impact assessment for MedScribe's diffusion-based clinical documentation system.

In [None]:
def ethical_impact_assessment():
    """
    Generate a structured ethical impact assessment for deploying
    a diffusion language model in clinical documentation.

    Returns:
        assessment: Dict with keys:
            - 'bias_risks': List of identified bias risks with mitigation strategies
            - 'privacy_concerns': List of privacy/HIPAA concerns with safeguards
            - 'liability_framework': Description of who is responsible for errors
            - 'transparency_measures': How physicians can understand/verify outputs
            - 'overall_risk_level': 'low', 'medium', or 'high' with justification

    Consider:
        - The model trains on historical clinical notes, which may reflect
          demographic disparities in care quality
        - Diffusion models generate all tokens simultaneously, making it
          harder to trace which input influenced which output
        - Clinical notes are legal documents; errors have real patient consequences
        - HIPAA requires that patient data be protected at all stages

    This is an open-ended assessment. There are no wrong answers, but
    your analysis should be specific to the clinical documentation use case
    and reference concrete risks and mitigations.
    """
    # ============ TODO ============
    # Write your ethical impact assessment here.
    # Return a dict with the keys described above.
    # Each value should contain substantive analysis, not placeholder text.
    #
    # Example structure for bias_risks:
    # [
    #     {
    #         "risk": "Demographic documentation quality disparity",
    #         "description": "Notes for certain patient populations may be...",
    #         "mitigation": "Stratified evaluation across demographics..."
    #     },
    #     ...
    # ]
    # ==============================

    assessment = ???  # YOUR CODE HERE

    return assessment

In [None]:
# Verification cell
assessment = ethical_impact_assessment()
required_keys = ['bias_risks', 'privacy_concerns', 'liability_framework',
                 'transparency_measures', 'overall_risk_level']
for key in required_keys:
    assert key in assessment, f"Missing key: {key}"
    assert assessment[key] is not None and assessment[key] != "???", \
        f"Key '{key}' must contain substantive content"
print("Ethical impact assessment structure is complete!")
print(f"\nOverall risk level: {assessment['overall_risk_level']}")
print(f"Bias risks identified: {len(assessment['bias_risks'])}")
print(f"Privacy concerns identified: {len(assessment['privacy_concerns'])}")

### Thought Questions: Ethics

1. If the diffusion model generates a note with an incorrect medication dosage that leads to patient harm, who bears legal responsibility -- the physician who approved the note, the AI company, or both?
2. How would you design an audit trail for AI-generated clinical notes that satisfies both HIPAA requirements and malpractice liability concerns?
3. What specific fairness metrics would you track to ensure the model performs equally well across patient demographics (age, race, gender, primary language)?

## 15. Summary and Next Steps

Congratulations! In this notebook, you have:

1. **Built a synthetic clinical notes dataset** mimicking MIMIC-III structure
2. **Implemented and trained an autoregressive baseline** to establish quality and latency benchmarks
3. **Built a masked diffusion Transformer** with time conditioning and bidirectional attention
4. **Trained the diffusion model** using the ELBO-derived weighted cross-entropy loss
5. **Implemented confidence-based generation** -- the core iterative unmasking algorithm
6. **Compared diffusion vs. autoregressive** on quality, latency, and structural metrics
7. **Visualized the unmasking process** showing how clinical notes crystallize from masks
8. **Benchmarked the quality-speed tradeoff** across different step counts
9. **Demonstrated native infilling** -- completing partial notes without prompt engineering
10. **Conducted an ethical impact assessment** for healthcare AI deployment

For further reading, see **Section 4 of the case study document**, which covers the full production system design: API endpoints, serving infrastructure, latency budgets, monitoring, A/B testing, CI/CD pipelines, and cost analysis for deploying this system at MedScribe's scale.

### Key Takeaways

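- Diffusion LLMs trade sequential generation for parallel prediction: a note is produced in a fixed number of denoising passes (20 in this notebook) rather than one pass per token, which is where the latency gains come from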
- Confidence-based unmasking naturally prioritizes easy predictions first, creating a "structure before content" generation pattern
- Bidirectional attention enables capabilities that left-to-right autoregressive models do not support natively: infilling, edit propagation, and section-level regeneration
- The quality-speed tradeoff curve has diminishing returns: 10-20 steps capture most of the quality, making sub-second generation feasible
- Healthcare AI requires careful ethical analysis: bias, privacy, liability, and transparency are not optional considerations