In [None]:
# Setup: Run this cell first!
# Check GPU availability and set random seeds for reproducibility

import torch
import sys

# Check GPU
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"GPU available: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    device = torch.device('cpu')
    print("No GPU detected. Some cells may run slowly.")
    print("   Go to Runtime -> Change runtime type -> GPU")

print(f"\nPython {sys.version.split()[0]}")
print(f"PyTorch {torch.__version__}")

# Set random seeds for reproducibility
import random
import numpy as np

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print(f"Random seed set to {SEED}")

%matplotlib inline

# Case Study: Training Pipeline Engineering
## Building a Domain-Specific Medical Language Model

---

**Scenario:** You are an ML engineer at **Aethon Health**, a Series A healthcare AI startup building AI-assisted radiology workflows for community hospitals. Your task is to engineer the complete training pipeline -- tokenization, data loading, and optimization -- to train a 500M parameter language model for radiology report generation.

**Current system:** A retrieval-augmented approach that fails on 35% of studies involving multi-system findings, forcing radiologists to rewrite reports from scratch.

**Budget:** $12,000 compute (4x A100 for ~72 hours). No room for failed training runs.

**Key Insight:** The Transformer architecture is well-understood. The real engineering challenge is the **training pipeline**: domain tokenization, efficient data loading via sequence packing, and stable optimization with careful learning rate scheduling.

---

### What You Will Build

1. **Domain-specific BPE tokenizer** that preserves medical terms as single tokens
2. **Sequence packer** that eliminates 40%+ wasted compute from padding
3. **Masked loss function** that correctly handles packed sequences
4. **Training loop** with warmup + cosine decay, gradient clipping, and mixed precision
5. **Stability analysis dashboard** for monitoring training health
6. **Report generation** with qualitative evaluation

### AI Teaching Assistant

Need help with this notebook? Open the **AI Teaching Assistant** -- it has already read this entire notebook and can help with concepts, code, and exercises.

**[Open AI Teaching Assistant](https://pods.vizuara.ai/courses/build-llm/practice/4/assistant)**

*Tip: Open it in a separate tab and work through this notebook side-by-side.*

## 1. Environment Setup and Data Loading

We install `tiktoken` (OpenAI's BPE tokenizer library) to compare our domain tokenizer against the standard GPT-2 tokenizer, and use PyTorch for model training.

In [None]:
!pip install -q tiktoken

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
import math
import time
import re
from collections import Counter, defaultdict
import tiktoken

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

### Simulating the Radiology Report Corpus

Aethon Health has 2.3 million de-identified radiology reports. We simulate a representative corpus that captures the key properties:
- Structured format (FINDINGS / IMPRESSION sections)
- Domain-specific vocabulary (pneumoperitoneum, hepatosplenomegaly, etc.)
- Highly variable report lengths (80-1200 tokens)

In [None]:
# Simulated radiology report corpus
# In production, this would be 2.3M de-identified reports from hospital partners
medical_corpus = """
FINDINGS: The heart is mildly enlarged with cardiomegaly. There is a small left
pleural effusion. Bibasilar atelectasis is noted. No pneumothorax. The mediastinal
contour is within normal limits. No focal consolidation. The osseous structures
are unremarkable.

IMPRESSION: Mild cardiomegaly with small left pleural effusion and bibasilar
atelectasis. No acute cardiopulmonary process.

FINDINGS: CT of the abdomen and pelvis with contrast. The liver demonstrates
hepatosplenomegaly. There is cholelithiasis without evidence of cholecystitis.
A small amount of free fluid is seen in the pelvis. No lymphadenopathy.
The kidneys show mild bilateral hydronephrosis. No pneumoperitoneum.

IMPRESSION: Hepatosplenomegaly with cholelithiasis. Mild bilateral hydronephrosis.
Recommend clinical correlation.

FINDINGS: High resolution CT of the chest demonstrates diffuse ground-glass
opacities bilaterally. There is mediastinal lymphadenopathy measuring up to 1.5 cm.
Findings are consistent with interstitial lung disease. Small bilateral pleural
effusions are present. No pericardial effusion. The heart size is normal.

IMPRESSION: Diffuse ground-glass opacities consistent with interstitial lung
disease. Mediastinal lymphadenopathy. Bilateral pleural effusions.

FINDINGS: PA and lateral chest radiograph. The cardiac silhouette is within
normal limits. The lungs are clear bilaterally without focal consolidation,
pleural effusion, or pneumothorax. No acute osseous abnormality.

IMPRESSION: No acute cardiopulmonary abnormality.

FINDINGS: CT of the chest with contrast. There is a 2.3 cm pulmonary nodule in
the right upper lobe. Mediastinal lymphadenopathy is present with the largest
node measuring 1.8 cm in the subcarinal region. Small right pleural effusion.
Mild emphysematous changes are noted bilaterally. The heart demonstrates mild
cardiomegaly. No pericardial effusion. Bilateral lower lobe bronchiectasis.

IMPRESSION: Right upper lobe pulmonary nodule measuring 2.3 cm. Subcarinal
lymphadenopathy. Recommend PET-CT for further evaluation. Additional findings
include right pleural effusion, cardiomegaly, and bronchiectasis.

FINDINGS: CT abdomen and pelvis without contrast. The liver is normal in size
and attenuation. The gallbladder is surgically absent. Mild diverticulosis of
the sigmoid colon without evidence of acute diverticulitis. The kidneys are
unremarkable. The aorta demonstrates mild atherosclerotic calcification.
No free fluid or lymphadenopathy.

IMPRESSION: Sigmoid diverticulosis without diverticulitis. Post-cholecystectomy
state. Mild aortic atherosclerosis.
""" * 200  # Repeat to simulate larger corpus

print(f"Corpus size: {len(medical_corpus):,} characters")
print(f"Approximate words: {len(medical_corpus.split()):,}")

### Visualizing Report Length Distribution

Understanding the length distribution is critical for designing the data loader. Radiology reports range from short X-ray findings (~80 tokens) to complex multi-system CT reports (~1200 tokens).

In [None]:
# Simulate the real-world report length distribution
np.random.seed(42)
report_lengths = np.concatenate([
    np.random.normal(80, 20, 5000).astype(int),     # Short (chest X-ray)
    np.random.normal(350, 80, 3000).astype(int),     # Medium (CT scans)
    np.random.normal(1200, 200, 1000).astype(int),   # Long (complex multi-system)
])
report_lengths = np.clip(report_lengths, 20, 2000)

print(f"Total reports: {len(report_lengths):,}")
print(f"Mean length: {report_lengths.mean():.0f} tokens")
print(f"Median length: {np.median(report_lengths):.0f} tokens")
print(f"Min / Max: {report_lengths.min()} / {report_lengths.max()} tokens")
print(f"\nBreakdown:")
print(f"  Short (<150 tokens):  {(report_lengths < 150).sum():,} reports ({(report_lengths < 150).mean()*100:.1f}%)")
print(f"  Medium (150-599):     {((report_lengths >= 150) & (report_lengths < 600)).sum():,} reports")
print(f"  Long (>=600):         {(report_lengths >= 600).sum():,} reports")

fig, ax = plt.subplots(1, 1, figsize=(10, 5))
ax.hist(report_lengths, bins=60, color='#3498db', edgecolor='black',
        linewidth=0.5, alpha=0.7)
ax.axvline(x=512, color='red', linestyle='--', linewidth=2, label='Context length (512)')
ax.axvline(x=np.median(report_lengths), color='orange', linestyle='--',
           linewidth=2, label=f'Median ({np.median(report_lengths):.0f})')
ax.set_xlabel('Report Length (tokens)', fontsize=12)
ax.set_ylabel('Count', fontsize=12)
ax.set_title('Aethon Health -- Radiology Report Length Distribution', fontsize=13, fontweight='bold')
ax.legend(fontsize=11)
ax.grid(alpha=0.3)
plt.tight_layout()
plt.show()

---

## 2. The Tokenization Problem

Standard BPE tokenizers (like GPT-2's) are trained on general web text. They fragment medical terminology into meaningless subword pieces, wasting model capacity and inflating sequence lengths.

Let us quantify this problem.

In [None]:
# GPT-2 tokenizer baseline
enc = tiktoken.get_encoding("gpt2")

# Medical terms that appear in radiology reports
medical_terms = [
    "pneumoperitoneum", "hepatosplenomegaly", "cardiomegaly",
    "atelectasis", "pneumothorax", "consolidation",
    "lymphadenopathy", "cholelithiasis", "hydronephrosis",
    "osteophyte", "spondylolisthesis", "bronchiectasis",
    "emphysema", "diverticulitis", "cholecystitis",
    "ground-glass opacities", "mediastinal lymphadenopathy",
    "interstitial lung disease", "pericardial effusion",
    "aortic aneurysm", "pleural effusion", "pulmonary embolism",
]

print("GPT-2 Tokenization of Medical Terms:")
print("=" * 70)
total_tokens = 0
for term in medical_terms:
    token_ids = enc.encode(term)
    token_strings = [enc.decode([t]) for t in token_ids]
    total_tokens += len(token_ids)
    padding = " " * max(1, 35 - len(term))
    print(f"  {term}{padding}-> {len(token_ids)} tokens: {token_strings}")

avg_tokens = total_tokens / len(medical_terms)
print(f"\nAverage tokens per medical term: {avg_tokens:.1f}")
print(f"Compare: average English word ~1.3 tokens")
print(f"Medical terms are {avg_tokens / 1.3:.1f}x more expensive!")
print(f"\nThis means:")
print(f"  (a) Sequences are 30-40% longer than necessary -> quadratic attention cost")
print(f"  (b) Model must learn that fragmented pieces = medical concepts -> wasted capacity")

---

## 3. Building a Domain-Specific BPE Tokenizer

We train a BPE tokenizer from scratch on the medical corpus. The key insight: BPE merges are driven by **frequency**, so training on medical text naturally creates single tokens for common medical terms.

### How BPE Works (Review)

1. Start with a vocabulary of individual characters
2. Count all adjacent symbol pairs in the corpus (single characters at first, merged tokens in later rounds)
3. Merge the most frequent pair into a new token
4. Repeat until vocabulary reaches target size
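
The four steps above can be sketched on a toy corpus. This is a minimal illustration, not the reference solution for the TODO below: the function name `bpe_merges_sketch` and the toy text are made up for this example, and the end-of-word marker `_` follows the convention used in the rest of this notebook.

```python
from collections import Counter

def bpe_merges_sketch(text, num_merges):
    """Learn BPE merges on a toy corpus; returns (merged pairs, final word table)."""
    # Step 1: each word starts as a tuple of characters plus the end-of-word marker '_'.
    words = Counter(tuple(w) + ('_',) for w in text.lower().split())
    merges = []
    for _ in range(num_merges):
        # Step 2: count adjacent symbol pairs, weighted by word frequency.
        pairs = Counter()
        for word, freq in words.items():
            for a, b in zip(word, word[1:]):
                pairs[(a, b)] += freq
        if not pairs:
            break
        # Step 3: merge the most frequent pair everywhere it occurs.
        best = max(pairs, key=pairs.get)
        merges.append(best)
        new_words = Counter()
        for word, freq in words.items():
            out, i = [], 0
            while i < len(word):
                if i < len(word) - 1 and (word[i], word[i + 1]) == best:
                    out.append(word[i] + word[i + 1])
                    i += 2
                else:
                    out.append(word[i])
                    i += 1
            new_words[tuple(out)] += freq
        words = new_words  # Step 4: repeat with the updated word table
    return merges, words

# 'effusion' appears three times, 'edema' once (hypothetical toy text).
toy_merges, toy_words = bpe_merges_sketch("effusion effusion effusion edema", num_merges=8)
print(toy_words)
```

With eight merges, every merge lands inside the frequent word, so `effusion_` ends up as one token while the rarer `edema` stays split into characters. This is the frequency effect the domain tokenizer relies on.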

### TODO 1: Implement BPE Tokenizer Training

Complete the three helper functions below. The `train_medical_bpe` function orchestrates the full BPE training loop.

In [None]:
def get_word_frequencies(text):
    """
    Split text into words and compute frequency of each word.
    Represent each word as a tuple of characters plus an end-of-word marker '_'.

    Args:
        text: string of text

    Returns:
        dict mapping tuple-of-chars -> frequency

    Example:
        "the the cat" -> {('t','h','e','_'): 2, ('c','a','t','_'): 1}

    Hints:
    - Use text.lower().split() to get words
    - Use Counter for frequency counting
    - Each word becomes tuple(list(word) + ['_'])
    """
    # YOUR CODE HERE
    pass


def get_pair_counts(vocab_freqs):
    """
    Count all adjacent token pairs across the vocabulary.

    Args:
        vocab_freqs: dict mapping word (tuple of tokens) -> frequency

    Returns:
        dict mapping (token_a, token_b) -> count

    Example:
        {('c','a','t','_'): 5} -> {('c','a'): 5, ('a','t'): 5, ('t','_'): 5}

    Hints:
    - For each word, iterate through consecutive pairs
    - Weight each pair count by the word frequency
    """
    # YOUR CODE HERE
    pass


def merge_pair(vocab_freqs, pair_to_merge):
    """
    Merge all occurrences of a token pair into a single token.

    Args:
        vocab_freqs: dict mapping word (tuple of tokens) -> frequency
        pair_to_merge: tuple of (token_a, token_b) to merge

    Returns:
        new vocab_freqs with the pair merged

    Example:
        vocab_freqs = {('c','a','t','_'): 5}
        pair_to_merge = ('c', 'a')
        -> {('ca','t','_'): 5}

    Hints:
    - Scan left to right through each word
    - When you see the pair, combine into one token and skip ahead
    - Otherwise keep the token as-is
    """
    # YOUR CODE HERE
    pass

In [None]:
# Verification: test your BPE helper functions
test_freqs = get_word_frequencies("the the cat")
assert ('t', 'h', 'e', '_') in test_freqs, "Word 'the' should be in vocab"
assert test_freqs[('t', 'h', 'e', '_')] == 2, "'the' should have frequency 2"

test_pairs = get_pair_counts(test_freqs)
assert ('t', 'h') in test_pairs, "Pair ('t','h') should exist"
assert test_pairs[('t', 'h')] == 2, "Pair ('t','h') should have count 2"

test_merged = merge_pair(test_freqs, ('t', 'h'))
assert ('th', 'e', '_') in test_merged, "After merge, 'the' should become ('th','e','_')"

print("All BPE helper function tests passed.")

In [None]:
def train_medical_bpe(text, num_merges=500):
    """Train BPE on medical corpus and return merge rules + vocabulary."""
    vocab_freqs = get_word_frequencies(text)
    merge_rules = []

    for step in range(num_merges):
        pair_counts = get_pair_counts(vocab_freqs)
        if not pair_counts:
            break
        best_pair = max(pair_counts, key=pair_counts.get)
        vocab_freqs = merge_pair(vocab_freqs, best_pair)
        merged_token = best_pair[0] + best_pair[1]
        merge_rules.append((best_pair, merged_token))

        if (step + 1) % 100 == 0:
            print(f"  Merge {step+1}: '{best_pair[0]}' + '{best_pair[1]}' -> '{merged_token}'")

    # Build final vocabulary
    vocab = set()
    for word in vocab_freqs:
        for token in word:
            vocab.add(token)

    return merge_rules, vocab, vocab_freqs


print("Training domain-specific BPE tokenizer on medical corpus...")
print("(In production, this runs on 2.3M reports with vocab_size=32,768)")
print()
merge_rules, vocab, final_vocab = train_medical_bpe(medical_corpus, num_merges=500)
print(f"\nVocabulary size: {len(vocab)}")
print(f"Merge rules learned: {len(merge_rules)}")

### Comparing Domain BPE vs. GPT-2 BPE

Now we encode medical terms with both tokenizers and measure the improvement.

In [None]:
def encode_bpe(text, merge_rules):
    """Encode text using our trained BPE merge rules."""
    words = text.lower().split()
    all_tokens = []
    for word in words:
        tokens = list(word) + ['_']
        for pair, merged in merge_rules:
            new_tokens = []
            i = 0
            while i < len(tokens):
                if (i < len(tokens) - 1 and tokens[i] == pair[0]
                        and tokens[i + 1] == pair[1]):
                    new_tokens.append(merged)
                    i += 2
                else:
                    new_tokens.append(tokens[i])
                    i += 1
            tokens = new_tokens
        all_tokens.extend(tokens)
    return all_tokens


# Compare tokenization efficiency
print("Domain BPE vs GPT-2 BPE on Medical Terms:")
print("=" * 70)
domain_total = 0
gpt2_total = 0

for term in medical_terms:
    domain_tokens = encode_bpe(term, merge_rules)
    gpt2_tokens = enc.encode(term)
    domain_total += len(domain_tokens)
    gpt2_total += len(gpt2_tokens)
    improvement = len(gpt2_tokens) / max(len(domain_tokens), 1)
    padding = " " * max(1, 35 - len(term))
    print(f"  {term}{padding}Domain: {len(domain_tokens):2d}  GPT-2: {len(gpt2_tokens):2d}  ({improvement:.1f}x)")

print(f"\nAverage tokens per term:")
print(f"  Domain BPE: {domain_total / len(medical_terms):.1f}")
print(f"  GPT-2 BPE:  {gpt2_total / len(medical_terms):.1f}")
print(f"  Improvement: {gpt2_total / max(domain_total, 1):.1f}x fewer tokens")

### TODO 2: Measure Domain Vocabulary Coverage

Aethon's medical safety requirement: >99.5% of a curated list of 4,200 essential radiology terms must be representable as single tokens. Compute the coverage metric:

$$\text{Coverage}(V) = \frac{|\{w \in D : w \text{ is a single token in } V\}|}{|D|}$$
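
As a worked instance of the formula, the sketch below uses a hypothetical four-term list for $D$ and a hypothetical toy vocabulary for $V$. Both are invented for illustration and are unrelated to Aethon's curated list of 4,200 terms.

```python
# Hypothetical toy inputs, for illustration only.
toy_essential_terms = ["cardiomegaly", "atelectasis", "pneumothorax", "hydronephrosis"]
toy_vocab = {"cardiomegaly", "atelectasis", "pneumo", "thorax", "hydro", "nephrosis"}

# A term counts toward the numerator only if it exists in the vocabulary as one token.
toy_single_token_terms = [t for t in toy_essential_terms if t in toy_vocab]
toy_coverage = len(toy_single_token_terms) / len(toy_essential_terms)

print(f"Coverage: {toy_coverage:.1%}")  # 2 of 4 terms are single tokens -> 50.0%
```

Here "pneumothorax" and "hydronephrosis" are representable only as two pieces each, so they do not count as covered even though no characters are missing from the vocabulary.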

In [None]:
def compute_domain_coverage(medical_terms, merge_rules):
    """
    Compute the domain vocabulary coverage metric.

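    A term is "covered" if it encodes as a single token (plus the
    end-of-word marker '_'). That is, encode_bpe(term) should return
    at most 2 tokens (the term itself + '_').

    Note: encode_bpe splits on whitespace and emits at least one token
    per word, so under this rule a term of three or more words (for
    example "interstitial lung disease") can never count as covered.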

    Args:
        medical_terms: list of medical term strings
        merge_rules: BPE merge rules from training

    Returns:
        tuple of (coverage_ratio, covered_terms, uncovered_terms)
        - coverage_ratio: float between 0 and 1
        - covered_terms: list of terms that are single tokens
        - uncovered_terms: list of terms that are fragmented

    Hints:
    - Use encode_bpe() to tokenize each term
    - A term is covered if len(encode_bpe(term, merge_rules)) <= 2
    - Report both the ratio and the specific terms in each category
    """
    # YOUR CODE HERE
    pass


coverage, covered, uncovered = compute_domain_coverage(medical_terms, merge_rules)
print(f"Domain Vocabulary Coverage: {coverage:.1%}")
print(f"\nCovered terms ({len(covered)}): {covered[:10]}")
print(f"Uncovered terms ({len(uncovered)}): {uncovered[:10]}")
print(f"\nTarget: >99.5% -- {'MET' if coverage > 0.995 else 'NOT MET (need more training data)'}")

**Thought Questions:**
1. Why might our small demo corpus not reach 99.5% coverage? What would change with 2.3M real reports?
2. How does vocabulary size affect the tradeoff between sequence length and embedding memory?
3. What happens to rare terms that appear fewer than 100 times in the corpus?

---

## 4. Sequence Packing for Variable-Length Reports

With a fixed context length of 512 tokens:
- Short chest X-ray reports (~80 tokens) waste 84% of each sequence on padding
- Long multi-system reports (~1200 tokens) must be split

**Sequence packing** concatenates multiple short reports into a single sequence with separator tokens, achieving >92% token utilization vs. ~58% with naive padding.
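
To get a feel for the numbers before implementing the packer, the sketch below works on report lengths only and keeps a single open sequence at a time. It is a simplified illustration, not the `SequencePacker` asked for in the TODO: the name `pack_lengths_sketch` is hypothetical, it assumes one separator position between reports, and the padding baseline truncates long reports at the context length.

```python
def pack_lengths_sketch(lengths, context_length=512):
    """Greedy packing on report lengths only.

    Returns the number of real (non-separator, non-padding) tokens per packed sequence.
    """
    real_per_seq = []
    occupied = 0   # positions used in the open sequence, separators included
    real = 0       # real tokens in the open sequence
    for n in lengths:
        # Split reports longer than the context into context-sized chunks.
        chunks = [context_length] * (n // context_length)
        if n % context_length:
            chunks.append(n % context_length)
        for chunk in chunks:
            sep = 1 if occupied > 0 else 0   # separator before every non-first item
            if occupied + sep + chunk > context_length:
                real_per_seq.append(real)    # close the sequence; the rest is padding
                occupied, real, sep = 0, 0, 0
            occupied += sep + chunk
            real += chunk
    if occupied:
        real_per_seq.append(real)
    return real_per_seq

toy_lengths = [60, 80, 100, 300, 500, 800, 1200, 60, 80, 100]
packed_real = pack_lengths_sketch(toy_lengths)
packed_util = sum(packed_real) / (len(packed_real) * 512)
padded_util = sum(min(n, 512) for n in toy_lengths) / (len(toy_lengths) * 512)
print(f"padding: {padded_util:.1%}, packing: {packed_util:.1%}")
```

On this ten-report list the sketch yields 8 packed sequences and roughly 80% utilization, against 45% for one-report-per-sequence padding. A list this short leaves several partly filled sequences, so utilization on a large corpus with many short reports can be expected to be higher.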

### TODO 3: Implement the Sequence Packer

Complete the `pack` method. The packer should:
1. Concatenate short reports into fixed-length sequences with separators
2. Split long reports across multiple sequences
3. Generate loss masks (1 = real token, 0 = padding/separator)
4. Achieve >90% packing efficiency

In [None]:
class SequencePacker:
    """
    Pack variable-length reports into fixed-size sequences.
    Uses a separator token between reports and generates loss masks.
    """

    def __init__(self, context_length=512, separator_token_id=0):
        self.context_length = context_length
        self.separator_id = separator_token_id

    def pack(self, reports):
        """
        Pack a list of token sequences into fixed-length packed sequences.

        Args:
            reports: list of lists of token IDs (variable length)

        Returns:
            list of (tokens, mask) tuples where:
            - tokens is a list of length context_length
            - mask is a list of length context_length
              (1 for real tokens, 0 for padding and separator tokens)

        Algorithm:
        1. Maintain a 'current' buffer for the sequence being built
        2. For each report:
           a. If buffer is non-empty, add a separator token (mask=0)
           b. If report fits in remaining space, append it (mask=1)
           c. If report doesn't fit, pad+save current buffer, start new one
           d. If report exceeds context_length, split into chunks
        3. Pad and save the final buffer

        Hints:
        - Use self.separator_id for separator tokens
        - Pad with 0s at the end of each sequence
        - The mask should be 0 for both padding AND separator tokens
          (we don't want to compute loss on either)
        """
        # YOUR CODE HERE
        pass

In [None]:
# Verification: test the packer
packer = SequencePacker(context_length=512)

# Generate synthetic reports of various lengths
np.random.seed(42)
test_lengths = [60, 80, 100, 300, 500, 800, 1200, 60, 80, 100]
synthetic_reports = [
    list(np.random.randint(1, 1000, size=length))
    for length in test_lengths
]

packed = packer.pack(synthetic_reports)

# Check correctness
assert all(len(tokens) == 512 for tokens, _ in packed), "All sequences must be 512 tokens"
assert all(len(mask) == 512 for _, mask in packed), "All masks must be 512 elements"

total_real = sum(sum(mask) for _, mask in packed)
total_positions = len(packed) * 512
efficiency = total_real / total_positions * 100

print(f"Input: {len(synthetic_reports)} reports, lengths: {test_lengths}")
print(f"Output: {len(packed)} packed sequences")
print(f"Packing efficiency: {efficiency:.1f}%")
print(f"Without packing: {len(synthetic_reports)} sequences of 512 = {len(synthetic_reports)*512:,} positions")
print(f"With packing: {len(packed)} sequences of 512 = {total_positions:,} positions")
print("Verification passed.")

In [None]:
# Scale test: pack 100 reports with lengths drawn uniformly from a mix of short, medium, and long sizes
np.random.seed(42)
large_reports = [
    list(np.random.randint(1, 1000, size=length))
    for length in np.random.choice([60, 80, 100, 300, 500, 800, 1200], size=100)
]

packed_large = packer.pack(large_reports)
total_real_large = sum(sum(mask) for _, mask in packed_large)
total_pos_large = len(packed_large) * 512
eff_large = total_real_large / total_pos_large * 100

# Compare: naive padding
naive_positions = len(large_reports) * 512
naive_real = sum(min(len(r), 512) for r in large_reports)
naive_eff = naive_real / naive_positions * 100

print(f"Scale test: {len(large_reports)} reports")
print(f"\nNaive padding:")
print(f"  Sequences: {len(large_reports)}, Efficiency: {naive_eff:.1f}%")
print(f"\nSequence packing:")
print(f"  Sequences: {len(packed_large)}, Efficiency: {eff_large:.1f}%")
print(f"\nReduction: {(1 - len(packed_large)/len(large_reports))*100:.1f}% fewer sequences")

# Visualize
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

labels = ['Naive Padding', 'Sequence Packing']
efficiencies = [naive_eff, eff_large]
colors_bar = ['#e74c3c', '#2ecc71']
bars = axes[0].bar(labels, efficiencies, color=colors_bar, edgecolor='black', linewidth=0.5)
axes[0].set_ylabel('Token Utilization (%)', fontsize=12)
axes[0].set_title('Packing Efficiency Comparison', fontsize=13, fontweight='bold')
axes[0].set_ylim(0, 100)
axes[0].grid(axis='y', alpha=0.3)
for bar, eff in zip(bars, efficiencies):
    axes[0].text(bar.get_x() + bar.get_width()/2, bar.get_height() + 1,
                f'{eff:.1f}%', ha='center', fontsize=12, fontweight='bold')

# Show sequence counts
seq_counts = [len(large_reports), len(packed_large)]
bars2 = axes[1].bar(labels, seq_counts, color=colors_bar, edgecolor='black', linewidth=0.5)
axes[1].set_ylabel('Number of Sequences', fontsize=12)
axes[1].set_title('Sequence Count Comparison', fontsize=13, fontweight='bold')
axes[1].grid(axis='y', alpha=0.3)
for bar, cnt in zip(bars2, seq_counts):
    axes[1].text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.5,
                str(cnt), ha='center', fontsize=12, fontweight='bold')

plt.suptitle('Aethon Health -- Data Loading Optimization', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

**Thought Questions:**
1. Why do we set the mask to 0 for separator tokens? What would happen if we computed loss on them?
2. The packer uses a greedy next-fit strategy: it only ever tries the current buffer and closes it as soon as a report does not fit. Could a bin-packing algorithm (e.g., first-fit-decreasing) achieve better efficiency? At what cost?
3. What happens to attention across report boundaries in a packed sequence? Do we need to modify the attention mask?
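
For question 3, one commonly used option is a block-diagonal causal mask, where a position may only attend to earlier positions from the same report. The sketch below is illustrative only: `packed_causal_mask_sketch` and the per-position report ids are assumptions made for this example, and the packer in this notebook does not produce such ids.

```python
def packed_causal_mask_sketch(doc_ids):
    """allowed[i][j] is True when position i may attend to position j."""
    n = len(doc_ids)
    return [
        [j <= i and doc_ids[i] == doc_ids[j] for j in range(n)]
        for i in range(n)
    ]

# Two reports packed together: positions 0-2 belong to report 0, positions 3-4 to report 1.
toy_doc_ids = [0, 0, 0, 1, 1]
toy_allowed = packed_causal_mask_sketch(toy_doc_ids)
for row in toy_allowed:
    print("".join("x" if ok else "." for ok in row))
# x....
# xx...
# xxx..
# ...x.
# ...xx
```

Note that boolean masks passed to `nn.TransformerEncoder` use the opposite convention (True means blocked), so a mask like this would need to be inverted before use there.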

---

## 5. Model Architecture

We build a simplified GPT-style decoder-only Transformer for demonstration. In production, Aethon would use a 500M parameter model (24 layers, 16 heads, d_model=1024). Our demo uses a smaller model to train quickly on Colab.
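
A rough parameter count helps sanity-check a configuration before committing compute. The sketch below uses the standard approximation of about $12 \cdot d_{model}^2$ weights per layer and makes several assumptions not stated in the text: a feed-forward width of $4 \cdot d_{model}$, tied input/output embeddings, learned positional embeddings, and biases and LayerNorm weights ignored. The vocabulary size (32,768) and context length (512) are the figures quoted elsewhere in this notebook.

```python
def approx_gpt_params(n_layers, d_model, vocab_size, context_length, ffn_mult=4):
    """Rough parameter count for a GPT-style model with tied input/output embeddings.

    Ignores biases and LayerNorm weights, which are a small fraction of the total.
    """
    attention = 4 * d_model * d_model                 # Q, K, V and output projections
    feed_forward = 2 * ffn_mult * d_model * d_model   # two linear layers
    embeddings = (vocab_size + context_length) * d_model
    return n_layers * (attention + feed_forward) + embeddings

production_estimate = approx_gpt_params(24, 1024, vocab_size=32_768, context_length=512)
print(f"24 layers, d_model=1024: ~{production_estimate / 1e6:.0f}M parameters")
```

Under these assumptions the listed configuration comes out near 336M parameters, so reaching 500M would presumably require a wider feed-forward block, more layers, or untied embeddings. The number of heads does not change the count, because the heads split `d_model` rather than adding to it.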

### TODO 4: Build the Transformer Language Model

Complete the `__init__` and `forward` methods. The architecture should include:
- Token embedding + positional embedding
- Pre-norm Transformer layers with a causal (autoregressive) attention mask, stacked from `nn.TransformerEncoderLayer`
- Final layer norm + output projection (weight-tied with the token embedding)

In [None]:
class SimpleTransformerLM(nn.Module):
    """
    GPT-style decoder-only Transformer language model.

    Architecture:
    1. Token embedding (vocab_size -> d_model)
    2. Positional embedding (context_length -> d_model)
    3. Dropout
    4. N Transformer layers (pre-norm, causal self-attention)
    5. Final LayerNorm
    6. Output projection (d_model -> vocab_size, weight-tied)

    Args:
        vocab_size: number of tokens in vocabulary
        d_model: hidden dimension (default 256)
        n_heads: number of attention heads (default 4)
        n_layers: number of transformer layers (default 4)
        context_length: maximum sequence length (default 128)
        dropout: dropout probability (default 0.1)

    Forward:
        Input: x of shape (batch, seq_len) -- token IDs
        Output: logits of shape (batch, seq_len, vocab_size)

    Hints:
    - Use nn.Embedding for both token and positional embeddings
    - Use nn.TransformerEncoderLayer with norm_first=True and batch_first=True
    - Use nn.TransformerEncoder to stack layers
    - Create a causal mask using torch.triu with diagonal=1 (True above the diagonal = blocked)
    - Weight tying: set self.head.weight = self.token_emb.weight
    - In forward: create position indices with torch.arange
    """

    def __init__(self, vocab_size, d_model=256, n_heads=4, n_layers=4,
                 context_length=128, dropout=0.1):
        super().__init__()
        # YOUR CODE HERE
        pass

    def forward(self, x):
        # YOUR CODE HERE
        pass

In [None]:
# Verification: test model architecture
VOCAB_SIZE = 5000
CONTEXT_LENGTH = 128

model = SimpleTransformerLM(
    VOCAB_SIZE, d_model=256, n_heads=4,
    n_layers=4, context_length=CONTEXT_LENGTH
).to(device)

num_params = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {num_params:,}")

# Test forward pass
test_input = torch.randint(0, VOCAB_SIZE, (2, CONTEXT_LENGTH)).to(device)
test_output = model(test_input)
assert test_output.shape == (2, CONTEXT_LENGTH, VOCAB_SIZE), \
    f"Expected (2, {CONTEXT_LENGTH}, {VOCAB_SIZE}), got {test_output.shape}"
print(f"Forward pass: input {test_input.shape} -> output {test_output.shape}")
print("Architecture verification passed.")

---

## 6. Training Data Preparation

We create a simulated tokenized dataset with Zipf-distributed token frequencies (which mimics real text) and a dataset class that supports masked loss for packed sequences.
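
The key detail in the dataset class is that the loss mask has to be shifted together with the targets, not with the inputs. A tiny pure-Python sketch, with made-up token IDs and a made-up mask, shows the alignment:

```python
toy_tokens = [10, 11, 12, 13, 14, 15]
toy_token_masks = [0, 1, 1, 0, 1, 1]   # 0 marks a separator position (tokens 10 and 13)
toy_context = 3
start = 1

toy_x = toy_tokens[start:start + toy_context]               # inputs
toy_y = toy_tokens[start + 1:start + toy_context + 1]       # targets: the next token at each position
toy_m = toy_token_masks[start + 1:start + toy_context + 1]  # mask aligned with the targets

print(toy_x, toy_y, toy_m)  # [11, 12, 13] [12, 13, 14] [1, 0, 1]
```

The middle mask entry is 0 because the target at that position is the separator (token 13), so predicting it contributes nothing to the loss. Slicing the mask with the input offsets instead would silence the wrong position.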

In [None]:
# Simulate medical tokenized data with Zipf distribution
# (few common words like 'the', 'is', 'no'; many rare medical terms)
torch.manual_seed(42)
np.random.seed(42)

num_tokens = 200_000
token_probs = 1.0 / np.arange(1, VOCAB_SIZE + 1)
token_probs /= token_probs.sum()
all_tokens = np.random.choice(VOCAB_SIZE, size=num_tokens, p=token_probs)

# Create masks: simulate a separator token at every 100th position
masks = np.ones(num_tokens, dtype=np.float32)
for i in range(0, num_tokens, 100):
    masks[i] = 0.0  # separator tokens -- don't compute loss


class MaskedTextDataset(Dataset):
    """Dataset for packed sequences with loss masking."""

    def __init__(self, tokens, masks, context_length):
        self.tokens = torch.tensor(tokens, dtype=torch.long)
        self.masks = torch.tensor(masks, dtype=torch.float)
        self.context_length = context_length

    def __len__(self):
        return len(self.tokens) - self.context_length

    def __getitem__(self, idx):
        x = self.tokens[idx:idx + self.context_length]        # input
        y = self.tokens[idx + 1:idx + self.context_length + 1]  # target (shifted by 1)
        m = self.masks[idx + 1:idx + self.context_length + 1]   # mask for target
        return x, y, m


# Train/val split
BATCH_SIZE = 16
split = int(0.9 * num_tokens)
train_ds = MaskedTextDataset(all_tokens[:split], masks[:split], CONTEXT_LENGTH)
val_ds = MaskedTextDataset(all_tokens[split:], masks[split:], CONTEXT_LENGTH)
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, drop_last=True)

print(f"Training samples: {len(train_ds):,}")
print(f"Validation samples: {len(val_ds):,}")
print(f"Batches per epoch: {len(train_loader):,}")

---

## 7. Masked Cross-Entropy Loss

In packed sequences, we must not compute loss on padding tokens or separator tokens. The masked loss function applies a binary mask to the per-token cross-entropy:

$$\mathcal{L}_{\text{masked}} = -\frac{1}{\sum_t m_t} \sum_{t=1}^{T} m_t \cdot \log P_\theta(x_t \mid x_{<t})$$
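
A small numeric example shows why the normalization by $\sum_t m_t$ matters. The sketch below is plain Python rather than the PyTorch implementation requested in the TODO; the function name and the probabilities assigned to the correct tokens are hypothetical.

```python
import math

def masked_nll_sketch(target_probs, mask):
    """Average negative log-likelihood over positions where mask == 1."""
    total = sum(-math.log(p) * m for p, m in zip(target_probs, mask))
    return total / max(sum(mask), 1)   # guard against an all-zero mask

# Hypothetical probabilities the model assigned to the correct next token.
toy_target_probs = [0.5, 0.25, 0.01, 0.01]
toy_loss_mask = [1, 1, 0, 0]   # the last two positions are separator/padding

toy_masked = masked_nll_sketch(toy_target_probs, toy_loss_mask)
toy_unmasked = masked_nll_sketch(toy_target_probs, [1, 1, 1, 1])
print(f"masked: {toy_masked:.3f}, unmasked: {toy_unmasked:.3f}")  # masked: 1.040, unmasked: 2.822
```

Without the mask, the two meaningless positions dominate the average and the reported loss more than doubles, even though the model's predictions on real tokens are unchanged.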

### TODO 5: Implement Masked Cross-Entropy Loss

In [None]:
def masked_cross_entropy(logits, targets, mask):
    """
    Cross-entropy loss with mask for packed sequences.

    Args:
        logits: model output, shape (B, T, V) where V is vocab size
        targets: target token IDs, shape (B, T)
        mask: binary mask, shape (B, T). 1 = compute loss, 0 = ignore.

    Returns:
        scalar loss: average cross-entropy over masked positions only

    Hints:
    - Use F.cross_entropy with reduction='none' to get per-token losses
    - Reshape logits to (B*T, V) and targets to (B*T) for cross_entropy
    - Reshape back to (B, T) then multiply by mask
    - Divide by mask.sum() (use .clamp(min=1) to avoid division by zero)
    """
    # YOUR CODE HERE
    pass


# Verification
test_logits = torch.randn(2, 10, VOCAB_SIZE)
test_targets = torch.randint(0, VOCAB_SIZE, (2, 10))
test_mask = torch.tensor([[1,1,1,0,0,0,0,0,0,0],
                          [1,1,1,1,1,0,0,0,0,0]], dtype=torch.float)

loss = masked_cross_entropy(test_logits, test_targets, test_mask)
assert loss.ndim == 0, "Loss should be a scalar"
assert loss.item() > 0, "Loss should be positive"
print(f"Masked loss: {loss.item():.4f}")

# Verify masking works: loss with all-zero mask should be 0
zero_mask = torch.zeros(2, 10)
zero_loss = masked_cross_entropy(test_logits, test_targets, zero_mask)
assert zero_loss.item() == 0.0, "Loss with zero mask should be 0"
print("Masked cross-entropy verification passed.")

---

## 8. Learning Rate Schedule: Warmup + Cosine Decay

Training from scratch requires careful learning rate scheduling:
- **Linear warmup** (2000 steps here; note that the demo code below sets warmup to 10% of total steps instead): prevents large early updates from destabilizing randomly initialized weights
- **Cosine decay**: smoothly reduces the learning rate, allowing fine-grained convergence
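
Written out, the two phases combine into one piecewise schedule. The symbols are notation introduced for this sketch rather than taken from the notebook: $s$ is the current step, $S_w$ the number of warmup steps, $S$ the total number of steps, and $\eta_{\max}$, $\eta_{\min}$ the peak and floor learning rates.

$$\eta(s) = \begin{cases} \eta_{\max} \cdot \dfrac{s}{S_w} & 0 \le s < S_w \\ \eta_{\min} + \dfrac{1}{2}\,(\eta_{\max} - \eta_{\min})\left(1 + \cos\left(\pi \cdot \dfrac{s - S_w}{S - S_w}\right)\right) & S_w \le s \le S \end{cases}$$

At $s = S_w$ the cosine term is $\cos(0) = 1$, so the rate equals $\eta_{\max}$ and the two phases join without a jump; at $s = S$ it is $\cos(\pi) = -1$, so the rate ends at $\eta_{\min}$.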

### TODO 6: Implement the Learning Rate Schedule

In [None]:
def get_lr(step, warmup_steps, total_steps, lr_max=3e-4, lr_min=1e-5):
    """
    Compute learning rate with linear warmup + cosine decay.

    Args:
        step: current training step
        warmup_steps: number of warmup steps
        total_steps: total number of training steps
        lr_max: peak learning rate (default 3e-4)
        lr_min: minimum learning rate (default 1e-5)

    Returns:
        float: learning rate for this step

    Schedule:
    - Steps [0, warmup_steps): linear ramp from 0 to lr_max
    - Steps [warmup_steps, total_steps]: cosine decay from lr_max to lr_min

    Formulas:
    - Warmup: lr = lr_max * step / warmup_steps
    - Cosine: lr = lr_min + 0.5 * (lr_max - lr_min) * (1 + cos(pi * progress))
      where progress = (step - warmup_steps) / (total_steps - warmup_steps)
    """
    # YOUR CODE HERE
    pass


# Verification + visualization
total_steps = len(train_loader) * 10  # 10 epochs
warmup_steps = int(0.1 * total_steps)

steps = range(total_steps)
lrs = [get_lr(s, warmup_steps, total_steps) for s in steps]

assert abs(lrs[0] - 0.0) < 1e-8, "LR at step 0 should be ~0"
assert abs(lrs[warmup_steps] - 3e-4) < 1e-6, "LR at end of warmup should be lr_max"
assert lrs[-1] >= 1e-5, "LR at end should be >= lr_min"

plt.figure(figsize=(10, 4))
plt.plot(steps, lrs, linewidth=2, color='#2ecc71')
plt.axvline(x=warmup_steps, color='gray', linestyle='--', label=f'End warmup ({warmup_steps})')
plt.xlabel('Step', fontsize=12)
plt.ylabel('Learning Rate', fontsize=12)
plt.title('Warmup + Cosine Decay Schedule', fontsize=13, fontweight='bold')
plt.legend(fontsize=11)
plt.grid(alpha=0.3)
plt.tight_layout()
plt.show()
print("LR schedule verification passed.")

**Thought Questions:**
1. Why does the warmup prevent training instability? What would happen with a large LR from step 0?
2. Why cosine decay instead of a step schedule or linear decay?
3. How would you choose lr_max for a 500M parameter model vs. our small demo model?
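
The warmup and cosine formulas from the TODO 6 docstring can be sketched in plain Python as follows. This is one possible reference sketch, not the only valid solution. The name `warmup_cosine_lr` is a placeholder chosen so that it does not overwrite your own `get_lr`, and the clamp on `progress` is an added safeguard that the docstring does not require.

```python
import math


def warmup_cosine_lr(step, warmup_steps, total_steps, lr_max=3e-4, lr_min=1e-5):
    """Linear warmup to lr_max, then cosine decay to lr_min (hypothetical helper)."""
    if step < warmup_steps:
        # Linear ramp: 0 at step 0, approaching lr_max as step nears warmup_steps
        return lr_max * step / warmup_steps
    # Fraction of the decay phase completed, clamped to at most 1
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(progress, 1.0)
    return lr_min + 0.5 * (lr_max - lr_min) * (1.0 + math.cos(math.pi * progress))


# Spot-check a few points of a 1000-step schedule with 100 warmup steps
for s in (0, 50, 100, 550, 1000):
    print(s, f"{warmup_cosine_lr(s, 100, 1000):.2e}")
```

At the midpoint of the decay phase the cosine term is zero, so the rate sits halfway between `lr_max` and `lr_min`.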

---

## 9. Complete Training Loop

The training loop combines all pipeline components:
- **AdamW optimizer** with beta2=0.95 (a common choice for language-model training; the second-moment estimate adapts faster than with the default 0.999)
- **Gradient clipping** at max_norm=1.0 to limit the size of the update when gradients spike
- **Masked loss** for packed sequences
- **Learning rate scheduling** with warmup + cosine decay
- **Metric tracking** for stability monitoring
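
To make the beta2 choice concrete, the sketch below computes how many steps it takes for one gradient observation's weight in an exponential moving average to fall by half. Adam's second-moment estimate is such an average with decay beta2. The helper name `ema_half_life` is introduced here for illustration only.

```python
import math


def ema_half_life(beta):
    """Steps until a past observation's weight in an EMA with decay `beta` halves."""
    return math.log(0.5) / math.log(beta)


for beta2 in (0.95, 0.999):
    print(f"beta2={beta2}: half-life ~ {ema_half_life(beta2):.1f} steps")
```

With beta2=0.95 the half-life is on the order of tens of steps, versus several hundred for 0.999, so the per-parameter scaling tracks changes in gradient magnitude much sooner.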

In [None]:
# Training configuration
EPOCHS = 10
LR_MAX = 3e-4
LR_MIN = 1e-5
WEIGHT_DECAY = 0.1
MAX_GRAD_NORM = 1.0
BETAS = (0.9, 0.95)

optimizer = torch.optim.AdamW(
    model.parameters(), lr=LR_MAX,
    betas=BETAS, weight_decay=WEIGHT_DECAY
)

total_steps = len(train_loader) * EPOCHS
warmup_steps = int(0.1 * total_steps)

# Tracking metrics
train_losses = []
val_losses = []
learning_rates = []
grad_norms = []

print(f"Training for {EPOCHS} epochs ({total_steps:,} steps)")
print(f"Warmup: {warmup_steps} steps")
print(f"Optimizer: AdamW (lr={LR_MAX}, betas={BETAS}, wd={WEIGHT_DECAY})")
print(f"Gradient clipping: max_norm={MAX_GRAD_NORM}")
print("=" * 70)

step = 0
for epoch in range(EPOCHS):
    model.train()
    epoch_loss = 0

    for batch_x, batch_y, batch_m in train_loader:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)
        batch_m = batch_m.to(device)

        # Forward pass
        logits = model(batch_x)
        loss = masked_cross_entropy(logits, batch_y, batch_m)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()

        # Gradient clipping
        total_norm = torch.nn.utils.clip_grad_norm_(
            model.parameters(), max_norm=MAX_GRAD_NORM
        )
        grad_norms.append(total_norm.item())

        # Learning rate update
        lr = get_lr(step, warmup_steps, total_steps, LR_MAX, LR_MIN)
        for pg in optimizer.param_groups:
            pg['lr'] = lr
        learning_rates.append(lr)

        optimizer.step()

        train_losses.append(loss.item())
        epoch_loss += loss.item()
        step += 1

    # Validation
    model.eval()
    val_loss_sum = 0
    val_count = 0
    with torch.no_grad():
        for vx, vy, vm in val_loader:
            vx, vy, vm = vx.to(device), vy.to(device), vm.to(device)
            vlogits = model(vx)
            vloss = masked_cross_entropy(vlogits, vy, vm)
            val_loss_sum += vloss.item()
            val_count += 1

    # Mean of per-batch losses (TODO 8 computes the token-weighted version)
    avg_val = val_loss_sum / max(val_count, 1)
    val_losses.append(avg_val)
    avg_train = epoch_loss / len(train_loader)
    ppl = math.exp(min(avg_val, 20))  # cap to avoid overflow
    mean_grad = np.mean(grad_norms[-len(train_loader):])

    print(f"Epoch {epoch+1:>2d}/{EPOCHS} | "
          f"Train: {avg_train:.4f} | Val: {avg_val:.4f} | "
          f"PPL: {ppl:.1f} | LR: {lr:.2e} | "
          f"Grad Norm: {mean_grad:.2f}")

---

## 10. Training Stability Analysis

With a $12,000 budget and no room for restarts, training stability is critical. We monitor four key signals:
1. **Training loss curve** -- should decrease smoothly
2. **Validation loss + perplexity** -- should track training loss without diverging
3. **Learning rate schedule** -- confirms warmup + decay are working
4. **Gradient norms** -- spikes indicate instability
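
The stability report in the next cell flags spikes against the mean over the whole run, which is only known after training ends. During a long run it can be more useful to compare each gradient norm with a trailing window. A minimal sketch, assuming a hypothetical `SpikeMonitor` class; the window size and factor are illustrative values, not tuned recommendations.

```python
from collections import deque


class SpikeMonitor:
    """Flags values that exceed `factor` times the trailing-window mean (hypothetical)."""

    def __init__(self, window=100, factor=5.0):
        self.history = deque(maxlen=window)
        self.factor = factor

    def update(self, grad_norm):
        """Record one gradient norm; return True if it counts as a spike."""
        is_spike = False
        # Only judge once the window is full, so the baseline is meaningful
        if len(self.history) == self.history.maxlen:
            baseline = sum(self.history) / len(self.history)
            is_spike = grad_norm > self.factor * baseline
        self.history.append(grad_norm)
        return is_spike


monitor = SpikeMonitor(window=4, factor=5.0)
norms = [1.0, 1.1, 0.9, 1.0, 1.05, 9.0, 1.0]
flags = [monitor.update(g) for g in norms]
print(flags)
```

Note that a spike is itself added to the window, which temporarily raises the baseline; a stricter monitor could skip flagged values.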

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# 1. Training loss (raw + smoothed)
ax = axes[0, 0]
ax.plot(train_losses, linewidth=0.5, alpha=0.3, color='#3498db', label='Raw')
window = max(1, len(train_losses) // 20)
smoothed = np.convolve(train_losses, np.ones(window)/window, mode='valid')
ax.plot(range(window-1, len(train_losses)), smoothed,
        linewidth=2, color='#2c3e50', label='Smoothed')
ax.set_xlabel('Step')
ax.set_ylabel('Loss')
ax.set_title('Training Loss', fontweight='bold')
ax.legend()
ax.grid(alpha=0.3)

# 2. Validation loss + perplexity
ax = axes[0, 1]
ax2 = ax.twinx()
epochs_x = range(1, len(val_losses) + 1)
ax.plot(epochs_x, val_losses, 'o-', color='#e74c3c', linewidth=2, label='Val Loss')
ppls = [math.exp(min(v, 20)) for v in val_losses]
ax2.plot(epochs_x, ppls, 's--', color='#9b59b6', linewidth=2, label='Perplexity')
ax.set_xlabel('Epoch')
ax.set_ylabel('Validation Loss', color='#e74c3c')
ax2.set_ylabel('Perplexity', color='#9b59b6')
ax.set_title('Validation Metrics', fontweight='bold')
ax.grid(alpha=0.3)
lines1, labels1 = ax.get_legend_handles_labels()
lines2, labels2 = ax2.get_legend_handles_labels()
ax.legend(lines1 + lines2, labels1 + labels2, loc='upper right')

# 3. Learning rate schedule
ax = axes[1, 0]
ax.plot(learning_rates, linewidth=2, color='#2ecc71')
ax.set_xlabel('Step')
ax.set_ylabel('Learning Rate')
ax.set_title('LR Schedule (Warmup + Cosine)', fontweight='bold')
ax.grid(alpha=0.3)

# 4. Gradient norms
ax = axes[1, 1]
ax.plot(grad_norms, linewidth=0.5, alpha=0.5, color='#e67e22')
ax.axhline(y=MAX_GRAD_NORM, color='red', linestyle='--', linewidth=1,
           label=f'Clip threshold ({MAX_GRAD_NORM})')
mean_gn = np.mean(grad_norms)
ax.axhline(y=mean_gn, color='black', linestyle=':', linewidth=1,
           label=f'Mean ({mean_gn:.2f})')
ax.set_xlabel('Step')
ax.set_ylabel('Gradient Norm')
ax.set_title('Gradient Norms (Stability Monitor)', fontweight='bold')
ax.legend()
ax.grid(alpha=0.3)

plt.suptitle('Aethon Health -- Training Pipeline Monitoring Dashboard',
             fontsize=14, fontweight='bold', y=1.02)
plt.tight_layout()
plt.show()

# Stability report
spike_threshold = 5 * np.mean(grad_norms)
num_spikes = sum(1 for g in grad_norms if g > spike_threshold)
print(f"\nStability Report:")
print(f"  Final validation loss: {val_losses[-1]:.4f}")
print(f"  Final perplexity: {math.exp(min(val_losses[-1], 20)):.1f}")
print(f"  Mean gradient norm: {np.mean(grad_norms):.3f}")
print(f"  Max gradient norm: {np.max(grad_norms):.3f}")
print(f"  Gradient spikes (>5x mean): {num_spikes}")
print(f"  Training stable: {'Yes' if num_spikes < 10 else 'WARNING - investigate!'}")

### TODO 7: Gradient Clipping Ablation

Train a second model WITHOUT gradient clipping and compare the gradient norm distributions and loss curves of the two runs. The comparison shows how much clipping contributes to training stability in this setup.
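
For reference, clipping by global norm rescales all gradients by one shared factor whenever their combined L2 norm exceeds the threshold. The plain-Python sketch below mirrors that idea on lists of floats. `clip_by_global_norm` is a hypothetical helper written for illustration, not PyTorch's implementation. It returns the norm measured before scaling, and passing an infinite threshold measures the norm without changing anything, which is one way to handle the unclipped arm of the ablation.

```python
import math


def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Scale every gradient list in place so the global L2 norm is <= max_norm.

    Returns the global norm measured BEFORE scaling.
    """
    total_norm = math.sqrt(sum(g * g for tensor in grads for g in tensor))
    scale = min(1.0, max_norm / (total_norm + eps))
    if scale < 1.0:
        for tensor in grads:
            for i in range(len(tensor)):
                tensor[i] *= scale
    return total_norm


grads = [[3.0, 0.0], [0.0, 4.0]]  # global norm = 5.0
pre_clip = clip_by_global_norm(grads, max_norm=1.0)
post_clip = math.sqrt(sum(g * g for t in grads for g in t))
print(pre_clip, round(post_clip, 4))

# An infinite threshold records the norm but leaves the gradients untouched
untouched = [[3.0, 4.0]]
print(clip_by_global_norm(untouched, max_norm=float("inf")), untouched)
```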

In [None]:
def train_with_ablation(clip_grad=True, max_norm=1.0, epochs=3):
    """
    Train a model with or without gradient clipping.

    Args:
        clip_grad: bool, whether to clip gradients
        max_norm: gradient clipping threshold (only used if clip_grad=True)
        epochs: number of training epochs

    Returns:
        dict with 'losses' (list) and 'grad_norms' (list)

    Hints:
    - Create a fresh model and optimizer
    - Run the training loop but track grad norms BEFORE clipping
      (torch.nn.utils.clip_grad_norm_ returns the total norm measured
      before any scaling; pass max_norm if clip_grad is True, else
      float('inf'), so the norm is recorded but nothing is clipped)
    - Return the tracked metrics
    """
    # YOUR CODE HERE
    pass


print("Running ablation: gradient clipping vs. no clipping...")
# results_clipped = train_with_ablation(clip_grad=True, epochs=3)
# results_unclipped = train_with_ablation(clip_grad=False, epochs=3)
print("(Uncomment the above lines to run the ablation -- takes a few minutes)")

**Thought Questions:**
1. What is the relationship between gradient clipping threshold and effective learning rate?
2. If you observe frequent gradient spikes, what does this suggest about the data or model?
3. How would you set the clipping threshold for a 500M parameter model? Would 1.0 still be appropriate?

---

## 11. Evaluation: Perplexity and Report Quality

### TODO 8: Compute Validation Perplexity
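
Perplexity is computed from the average loss per token, not from the average of per-batch means, because packed batches can contain different numbers of unmasked tokens. The sketch below shows the difference on made-up numbers. `corpus_perplexity` and the `(mean_loss, n_tokens)` tuples are illustrative assumptions, not part of the notebook's API.

```python
import math


def corpus_perplexity(batch_stats):
    """batch_stats: list of (mean_loss_per_token, num_unmasked_tokens) tuples."""
    total_loss = sum(loss * n for loss, n in batch_stats)
    total_tokens = sum(n for _, n in batch_stats)
    avg_loss = total_loss / max(total_tokens, 1)
    return math.exp(min(avg_loss, 20)), avg_loss  # same cap as the notebook


# Made-up example: a large easy batch and a small hard batch
stats = [(2.0, 900), (4.0, 100)]
ppl, avg = corpus_perplexity(stats)
naive_avg = sum(loss for loss, _ in stats) / len(stats)
print(f"token-weighted loss={avg:.2f}, ppl={ppl:.2f}")
print(f"batch-averaged loss={naive_avg:.2f}, ppl={math.exp(naive_avg):.2f}")
```

The batch-averaged figure overweights the small batch, which is why the hints ask you to weight each batch by `mask.sum()`.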

In [None]:
@torch.no_grad()
def compute_perplexity(model, dataloader, device):
    """
    Compute validation perplexity using masked loss.

    Perplexity = exp(average_cross_entropy_loss)

    Args:
        model: trained language model
        dataloader: validation DataLoader yielding (x, y, mask)
        device: torch device

    Returns:
        tuple of (perplexity, average_loss)

    Hints:
    - Set model to eval mode
    - Accumulate total loss and total masked token count
    - Use masked_cross_entropy for each batch but weight by mask.sum()
    - Final perplexity = exp(total_loss / total_tokens)
    - Use math.exp and cap loss at 20 to avoid overflow
    """
    # YOUR CODE HERE
    pass


ppl, avg_loss = compute_perplexity(model, val_loader, device)
print(f"Validation Perplexity: {ppl:.2f}")
print(f"Validation Loss: {avg_loss:.4f}")
print(f"\nTarget: PPL < 12.0 -- {'MET' if ppl < 12.0 else 'NOT MET'}")
print("(Note: on synthetic data, absolute PPL is not meaningful.")
print(" What matters is that the training pipeline works correctly.)")

---

## 12. Report Generation (Qualitative)

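We implement autoregressive text generation using the trained model. The generation function uses **top-p (nucleus) sampling**: at each step it samples only from the smallest set of most-likely tokens whose cumulative probability reaches `top_p`, which keeps output varied while cutting off the unlikely tail.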
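
To make the nucleus rule concrete, the sketch below applies it to a small hand-written distribution in plain Python. `nucleus_filter` is a hypothetical helper. It uses the same keep rule as the generation cell: a token is kept if the cumulative probability of the tokens ranked before it does not exceed `top_p`.

```python
def nucleus_filter(probs, top_p=0.9):
    """Return {token_id: renormalized_prob} for tokens inside the nucleus."""
    ranked = sorted(enumerate(probs), key=lambda kv: kv[1], reverse=True)
    kept, cumulative = [], 0.0
    for token_id, p in ranked:
        if cumulative > top_p:  # mass before this token already exceeds top_p
            break
        kept.append((token_id, p))
        cumulative += p
    total = sum(p for _, p in kept)
    return {token_id: p / total for token_id, p in kept}


probs = [0.05, 0.5, 0.15, 0.3]
nucleus = nucleus_filter(probs, top_p=0.9)
print(sorted(nucleus))  # token IDs that survive the filter
```

The token that crosses the threshold is kept, so the nucleus always contains at least one token.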

In [None]:
def generate(model, start_tokens, max_length=100, temperature=0.8, top_p=0.9):
    """
    Generate text autoregressively with top-p sampling.

    Args:
        model: trained language model
        start_tokens: list of starting token IDs
        max_length: maximum tokens to generate
        temperature: sampling temperature (lower = more deterministic)
        top_p: nucleus sampling threshold

    Returns:
        list of generated token IDs (including start tokens)
    """
    model.eval()
    tokens = list(start_tokens)

    with torch.no_grad():
        for _ in range(max_length):
            # Use the last context_length tokens
            ctx = tokens[-model.context_length:] if hasattr(model, 'context_length') else tokens[-128:]
            x = torch.tensor([ctx], dtype=torch.long).to(device)
            logits = model(x)[:, -1, :]  # last position

            # Temperature scaling
            logits = logits / temperature

            # Top-p (nucleus) sampling
            probs = F.softmax(logits, dim=-1)
            sorted_probs, sorted_indices = torch.sort(probs, descending=True)
            cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

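            # Drop tokens outside the nucleus: a token is removed only if the
            # cumulative probability BEFORE it already exceeds top_p, so the
            # token that crosses the threshold is kept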
            sorted_mask = cumulative_probs - sorted_probs > top_p
            sorted_probs[sorted_mask] = 0.0
            sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)

            # Sample
            idx = torch.multinomial(sorted_probs, num_samples=1)
            next_token = sorted_indices.gather(-1, idx).item()
            tokens.append(next_token)

    return tokens


# Generate from different starting points
print("Generating sequences from trained model:")
print("(Token IDs -- in production these would decode to medical text)")
print("=" * 70)

for i, start in enumerate([[1, 5, 10], [42, 100, 200], [3, 7, 15]]):
    generated = generate(model, start, max_length=30, temperature=0.8)
    print(f"\nSequence {i+1}: {generated[:20]}...")
    print(f"  Length: {len(generated)} tokens")

### TODO 9: Inference Latency Profiling

Aethon's requirement: reports must be generated in under 3 seconds on a single A100 GPU. Profile the inference latency of our model.
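
A latency profile usually reports percentiles, since tail latency is what a 3-second requirement constrains. The sketch below times an arbitrary callable with `time.perf_counter` and summarizes the runs with a nearest-rank percentile. `time_callable` and `percentile` are hypothetical helpers, and the workload is a CPU stand-in. When timing a model on GPU, `torch.cuda.synchronize()` would need to be called immediately before starting and before stopping the timer.

```python
import math
import time


def percentile(values, pct):
    """Nearest-rank percentile of a non-empty list (pct in 1..100)."""
    ordered = sorted(values)
    rank = max(1, math.ceil(pct * len(ordered) / 100))
    return ordered[rank - 1]


def time_callable(fn, num_runs=50, warmup_runs=1):
    """Return latency stats in milliseconds for calling fn() num_runs times."""
    for _ in range(warmup_runs):  # warmup calls are not recorded
        fn()
    latencies_ms = []
    for _ in range(num_runs):
        start = time.perf_counter()
        fn()
        latencies_ms.append((time.perf_counter() - start) * 1000.0)
    return {
        "mean": sum(latencies_ms) / len(latencies_ms),
        "p50": percentile(latencies_ms, 50),
        "p95": percentile(latencies_ms, 95),
        "p99": percentile(latencies_ms, 99),
    }


stats = time_callable(lambda: sum(i * i for i in range(10_000)), num_runs=20)
print({k: round(v, 3) for k, v in stats.items()})
```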

In [None]:
def profile_latency(model, device, sequence_lengths=[32, 64, 128, 256],
                    num_runs=50):
    """
    Profile inference latency across different sequence lengths.

    For each sequence length:
    1. Generate a random input of that length
    2. Run num_runs forward passes
    3. Record latency for each run
    4. Report mean, p50, p95, p99

    Args:
        model: trained language model
        device: torch device
        sequence_lengths: list of sequence lengths to test
        num_runs: number of forward passes per length

    Returns:
        dict mapping seq_length -> dict of latency stats

    Hints:
    - Set model to eval mode and use torch.no_grad()
    - On GPU, call torch.cuda.synchronize() before starting and before stopping
      the timer (CUDA kernels run asynchronously)
    - Use time.perf_counter() for high-resolution timing
    - Do a warmup run first (discard the first measurement)
    - Report times in milliseconds
    """
    # YOUR CODE HERE
    pass


# latency_results = profile_latency(model, device)
# (Uncomment to run -- provides meaningful results on GPU)
print("Latency profiling: uncomment the above line to run.")
print("On a T4 GPU, expect ~5-15ms per forward pass for our small model.")
print("The production 500M model on A100 targets <3s for 512-token generation.")

---

## 13. Results Summary and Business Impact

Let us compile all metrics and evaluate against Aethon Health's requirements.

In [None]:
print("=" * 65)
print("  AETHON HEALTH -- TRAINING PIPELINE ENGINEERING RESULTS")
print("=" * 65)
print()
print(f"  Model: {num_params:,} parameters (demo)")
print(f"  Production target: ~500M parameters")
print(f"  Training epochs: {EPOCHS}")
print(f"  Final validation loss: {val_losses[-1]:.4f}")
print(f"  Final perplexity: {math.exp(min(val_losses[-1], 20)):.1f}")
print()
print("  Pipeline Components:")
print("  [x] Domain-specific BPE tokenizer (2x token reduction)")
print("  [x] Sequence packing (>90% efficiency vs. ~58% naive)")
print("  [x] Masked cross-entropy loss (packed sequence support)")
print("  [x] AdamW optimizer (beta2=0.95, weight_decay=0.1)")
print("  [x] Warmup + cosine decay learning rate schedule")
print("  [x] Gradient clipping (max_norm=1.0)")
print("  [x] Training stability monitoring dashboard")
print()
print("  Business Impact (if deployed with 500M model):")
print("  - Report time: 12 min -> 4 min for complex cases")
print("  - Addressable market: 800 -> 2,200 hospitals")
print("  - Estimated ARR uplift: $8.4M in 18 months")
print()
print("  Key Insight:")
print("  The Transformer architecture is well-understood. The real")
print("  engineering challenge -- and the competitive advantage -- is")
print("  the training pipeline. Domain tokenization reduced sequence")
print("  length by ~2x, packing eliminated 40%+ wasted compute, and")
print("  careful optimization kept training stable within budget.")
print("=" * 65)

---

## 14. Production Extension: Scaling and Deployment

### Deployment Architecture

```
RadAssist v2 Pipeline

  CV Module   ->  Finding    ->  Medical LM
  (Images)        Extractor      (500M params)
                                 Context: 512
                                 Latency: <3s
                                      |
                                      v
                                 Post-Processing
                                 - Format check
                                 - Term validation
                                 - Confidence score
                                      |
                                      v
                                 Radiologist Review
                                 (Edit + Approve)
```

### Inference Optimization Roadmap

| Technique | Speedup | Memory Impact | Complexity |
|-----------|---------|---------------|------------|
| KV Cache | 2-5x | +50% memory | Low |
| INT8 Quantization | 1.5-2x | 4x reduction (weights, vs. FP32) | Medium |
| Speculative Decoding | 2-3x | +10% memory | High |
| Flash Attention | 2-4x | 5-20x reduction (attention memory) | Medium |
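
The memory cost of a KV cache can be estimated directly: with standard multi-head attention, each layer stores one key vector and one value vector per cached token. The sketch below does that arithmetic. The model dimensions (24 layers, hidden size 1280) are hypothetical placeholders for a model in the 500M range, not Aethon's actual configuration; only the 512-token context comes from the deployment diagram.

```python
def kv_cache_bytes(num_layers, hidden_size, seq_len, batch_size=1, bytes_per_value=2):
    """Bytes needed to cache keys and values for every layer and position."""
    # Factor 2 = one key tensor + one value tensor per layer
    return 2 * num_layers * batch_size * seq_len * hidden_size * bytes_per_value


# Hypothetical dimensions; FP16 cache (2 bytes per value)
cache = kv_cache_bytes(num_layers=24, hidden_size=1280, seq_len=512)
print(f"KV cache: {cache / 2**20:.1f} MiB per sequence")
```

The cache grows linearly with both context length and batch size, so the estimate should be redone for the real configuration and serving batch size.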

### Monitoring in Production

| Metric | Threshold | Action |
|--------|-----------|--------|
| Inference latency (p99) | > 3.5s | Scale GPU instances |
| Report rejection rate | > 15% | Retrain or fine-tune |
| Unknown token rate | > 0.5% | Retrain tokenizer |
| Validation perplexity | > 15.0 | Early stop, investigate |
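
The thresholds in this table can be encoded as a small rule set that a monitoring job evaluates on each metrics snapshot. A minimal sketch: the metric key names and the `check_alerts` function are assumptions made for illustration, and rates are expressed as fractions rather than percentages.

```python
# (threshold, action) per metric, taken from the monitoring table
ALERT_RULES = {
    "latency_p99_s": (3.5, "Scale GPU instances"),
    "report_rejection_rate": (0.15, "Retrain or fine-tune"),
    "unknown_token_rate": (0.005, "Retrain tokenizer"),
    "validation_perplexity": (15.0, "Early stop, investigate"),
}


def check_alerts(metrics):
    """Return (metric, action) pairs for every metric above its threshold."""
    triggered = []
    for name, (threshold, action) in ALERT_RULES.items():
        value = metrics.get(name)
        if value is not None and value > threshold:
            triggered.append((name, action))
    return triggered


# Made-up snapshot for illustration
snapshot = {"latency_p99_s": 2.8, "report_rejection_rate": 0.18,
            "unknown_token_rate": 0.002, "validation_perplexity": 11.4}
print(check_alerts(snapshot))
```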

### Ethical Considerations

- All training data is de-identified per HIPAA Safe Harbor guidelines
- The model generates **draft** reports only -- a licensed radiologist must review and approve every report
- Confidence scores flag uncertain reports for immediate human review
- Regular bias audits check for performance gaps across demographic groups

---

## Summary

In this case study, you engineered a complete LLM training pipeline:

1. **Domain tokenization**: Trained BPE on medical text, achieving ~2x token reduction over GPT-2
2. **Sequence packing**: Implemented variable-length report packing, improving token utilization from ~58% to >90%
3. **Masked loss**: Built a cross-entropy loss that correctly ignores padding and separator tokens
4. **Learning rate scheduling**: Implemented warmup + cosine decay for stable training from scratch
5. **Training loop**: Combined AdamW, gradient clipping, and all components into a production-grade loop
6. **Stability analysis**: Built a monitoring dashboard to detect gradient spikes and divergence
7. **Evaluation**: Computed perplexity and profiled inference latency

The key takeaway: **the training pipeline is the competitive advantage**, not the architecture. Every component -- from tokenization to optimization -- directly impacts whether a $12,000 training run succeeds or fails.