In [None]:
# 🔧 Setup: run this cell before anything else.
# Reports the available accelerator and library versions, then seeds every RNG.

import random
import sys

import numpy as np
import torch

# Pick the compute device once and describe it to the reader.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
    print(f"✅ GPU available: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    print("⚠️ No GPU detected. Some cells may run slowly.")
    print("   Go to Runtime → Change runtime type → GPU")

print(f"\n📦 Python {sys.version.split()[0]}")
print(f"🔥 PyTorch {torch.__version__}")

# Seed Python, NumPy and PyTorch (CPU and, when present, every CUDA device)
# so that repeated runs of the notebook draw the same random numbers.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if device.type == 'cuda':
    torch.cuda.manual_seed_all(SEED)

print(f"🎲 Random seed set to {SEED}")

%matplotlib inline

# Training Pipeline Case Study: Domain-Specific Medical Language Model

> **Scenario:** You are an ML engineer at Aethon Health, building a domain-specific language model for radiology report generation. Your task is to engineer the training pipeline -- tokenization, data loading, optimization -- to train a 500M parameter model within a \$12,000 compute budget.

## Setup

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
import math
import time
import re
from collections import Counter, defaultdict
import importlib
import importlib.util
import subprocess
import sys

# Install tiktoken only when it is missing, and do it through the kernel's own
# interpreter (`sys.executable -m pip`) so the package lands in the environment
# this notebook is running in rather than whichever `pip` is first on PATH.
if importlib.util.find_spec('tiktoken') is None:
    subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', 'tiktoken'])
    # Make the freshly installed package visible to the running interpreter.
    importlib.invalidate_caches()
import tiktoken

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 1. The Problem: Medical Vocabulary and Standard Tokenizers

Standard tokenizers fragment medical terminology into meaningless pieces. Let us quantify this problem.

In [None]:
# Baseline tokenizer: the GPT-2 byte-pair encoding provided by tiktoken
enc = tiktoken.get_encoding("gpt2")

# Terms (single words and short phrases) that appear in radiology reports
medical_terms = [
    "pneumoperitoneum",
    "hepatosplenomegaly",
    "cardiomegaly",
    "atelectasis",
    "pneumothorax",
    "consolidation",
    "lymphadenopathy",
    "cholelithiasis",
    "hydronephrosis",
    "osteophyte",
    "spondylolisthesis",
    "bronchiectasis",
    "emphysema",
    "pleural effusion",
    "pulmonary embolism",
    "ground-glass opacities",
    "mediastinal lymphadenopathy",
    "interstitial lung disease",
    "pericardial effusion",
    "aortic aneurysm",
    "diverticulitis",
    "cholecystitis",
]

print("GPT-2 Tokenization of Medical Terms:")
print("=" * 65)

# Encode each term, show the pieces it is split into, and tally the token counts.
token_counts = []
for term in medical_terms:
    ids = enc.encode(term)
    pieces = [enc.decode([token_id]) for token_id in ids]
    token_counts.append(len(ids))
    # `<30` left-aligns the term in a 30-character column.
    print(f"  {term:<30} -> {len(ids)} tokens: {pieces}")

total_tokens = sum(token_counts)
avg_tokens = total_tokens / len(medical_terms)
print(f"\nAverage tokens per medical term: {avg_tokens:.1f}")
print("Compare: average English word = ~1.3 tokens")
print(f"Medical terms are {avg_tokens / 1.3:.1f}x more expensive!")

## 2. Building a Domain-Specific BPE Tokenizer

In [None]:
# Simulated medical corpus: three short synthetic radiology reports, each with a
# FINDINGS and an IMPRESSION section. The text is repeated 200 times so that the
# word frequencies used for BPE training below are larger than one.
medical_corpus = """
FINDINGS: The heart is mildly enlarged with cardiomegaly. There is a small left
pleural effusion. Bibasilar atelectasis is noted. No pneumothorax. The mediastinal
contour is within normal limits. No focal consolidation. The osseous structures
are unremarkable.

IMPRESSION: Mild cardiomegaly with small left pleural effusion and bibasilar
atelectasis. No acute cardiopulmonary process.

FINDINGS: CT of the abdomen and pelvis with contrast. The liver demonstrates
hepatosplenomegaly. There is cholelithiasis without evidence of cholecystitis.
A small amount of free fluid is seen in the pelvis. No lymphadenopathy.
The kidneys show mild bilateral hydronephrosis. No pneumoperitoneum.

IMPRESSION: Hepatosplenomegaly with cholelithiasis. Mild bilateral hydronephrosis.
Recommend clinical correlation.

FINDINGS: High resolution CT of the chest demonstrates diffuse ground-glass
opacities bilaterally. There is mediastinal lymphadenopathy measuring up to 1.5 cm.
Findings are consistent with interstitial lung disease. Small bilateral pleural
effusions are present. No pericardial effusion. The heart size is normal.

IMPRESSION: Diffuse ground-glass opacities consistent with interstitial lung
disease. Mediastinal lymphadenopathy. Bilateral pleural effusions.
""" * 200  # Repeat for training

def get_word_frequencies(text):
    """Count whitespace-separated words and express each as a symbol tuple.

    The text is lower-cased and split on whitespace. Every distinct word becomes
    a tuple of its characters followed by the end-of-word marker '_'.

    Args:
        text: Raw corpus string.

    Returns:
        Dict mapping each symbol tuple to the number of times its word occurs.
    """
    counts = Counter(text.lower().split())
    return {tuple(word) + ('_',): n for word, n in counts.items()}

def get_pair_counts(vocab_freqs):
    """Tally how often each adjacent symbol pair occurs, weighted by word frequency.

    Args:
        vocab_freqs: Dict mapping symbol tuples to word frequencies.

    Returns:
        defaultdict(int) mapping (left, right) symbol pairs to their total count.
    """
    pair_counts = defaultdict(int)
    for symbols, freq in vocab_freqs.items():
        # zip against the one-shifted sequence to walk neighbouring symbols.
        for left, right in zip(symbols, symbols[1:]):
            pair_counts[(left, right)] += freq
    return pair_counts

def merge_pair(vocab_freqs, pair_to_merge):
    """Replace every occurrence of an adjacent symbol pair with the merged symbol.

    Each word is scanned left to right; non-overlapping occurrences of
    `pair_to_merge` are replaced by the concatenation of its two symbols.

    Args:
        vocab_freqs: Dict mapping symbol tuples to word frequencies.
        pair_to_merge: Two-element sequence (left, right) of symbols to merge.

    Returns:
        New dict mapping the rewritten symbol tuples to frequencies. When two
        input words collapse to the same tuple, their frequencies are summed
        rather than one silently overwriting the other.
    """
    left, right = pair_to_merge[0], pair_to_merge[1]
    merged_symbol = left + right

    new_vocab = {}
    for word, freq in vocab_freqs.items():
        new_word = []
        i = 0
        n = len(word)
        while i < n:
            if i + 1 < n and word[i] == left and word[i + 1] == right:
                new_word.append(merged_symbol)
                i += 2  # skip both halves of the merged pair
            else:
                new_word.append(word[i])
                i += 1
        key = tuple(new_word)
        # Accumulate so that colliding keys keep their combined frequency.
        new_vocab[key] = new_vocab.get(key, 0) + freq
    return new_vocab

def train_medical_bpe(text, num_merges=500):
    """Learn byte-pair-encoding merge rules from a text corpus.

    Starting from character-level words, repeatedly merges the most frequent
    adjacent symbol pair until `num_merges` merges have been applied or no
    pairs remain.

    Args:
        text: Training corpus.
        num_merges: Maximum number of merge rules to learn.

    Returns:
        Tuple (merge_rules, vocab, vocab_freqs) where merge_rules is a list of
        ((left, right), merged_symbol) entries in the order learned, vocab is
        the set of symbols present after the last merge, and vocab_freqs is the
        final symbol-tuple -> frequency dict.
    """
    vocab_freqs = get_word_frequencies(text)
    merge_rules = []

    for _ in range(num_merges):
        pair_counts = get_pair_counts(vocab_freqs)
        if len(pair_counts) == 0:
            # Every word is already a single symbol; nothing left to merge.
            break
        # max() keeps the first pair encountered among equally frequent ones.
        best_pair, _count = max(pair_counts.items(), key=lambda item: item[1])
        left, right = best_pair
        merge_rules.append((best_pair, left + right))
        vocab_freqs = merge_pair(vocab_freqs, best_pair)

    # Build vocabulary from the symbols that survive in the final segmentation
    vocab = {symbol for word in vocab_freqs for symbol in word}

    return merge_rules, vocab, vocab_freqs

# Train the tokenizer on the simulated corpus. `merge_rules` is the ordered list
# of learned merges, `vocab` the set of symbols left after merging, and
# `final_vocab` the symbol-tuple -> frequency dict after the last merge.
print("Training domain-specific BPE tokenizer...")
merge_rules, vocab, final_vocab = train_medical_bpe(medical_corpus, num_merges=500)
print(f"Vocabulary size: {len(vocab)}")
print(f"Merge rules learned: {len(merge_rules)}")

In [None]:
# Measure domain vocabulary coverage
def encode_bpe(text, merge_rules):
    """Tokenize text by replaying learned BPE merge rules in order.

    The text is lower-cased and split on whitespace. Each word starts as its
    characters plus the end-of-word marker '_', then every merge rule is applied
    left to right across the word, in the order the rules were learned.

    Args:
        text: String to tokenize.
        merge_rules: List of ((left, right), merged_symbol) entries.

    Returns:
        Flat list of token strings for all words in `text`.
    """
    encoded = []
    for word in text.lower().split():
        symbols = [*word, '_']
        for (left, right), merged in merge_rules:
            rewritten = []
            pos = 0
            last = len(symbols) - 1
            while pos <= last:
                is_match = (pos < last and symbols[pos] == left
                            and symbols[pos + 1] == right)
                if is_match:
                    rewritten.append(merged)
                    pos += 2
                else:
                    rewritten.append(symbols[pos])
                    pos += 1
            symbols = rewritten
        encoded += symbols
    return encoded

# Compare tokenization efficiency: token count per term under each tokenizer
print("Domain BPE vs GPT-2 BPE:")
print("=" * 65)

domain_total = 0
gpt2_total = 0
for term in medical_terms:
    n_domain = len(encode_bpe(term, merge_rules))
    n_gpt2 = len(enc.encode(term))
    domain_total += n_domain
    gpt2_total += n_gpt2
    # Ratio > 1 means the domain tokenizer needs fewer tokens for this term.
    print(f"  {term:<30} Domain: {n_domain:2d}  GPT-2: {n_gpt2:2d}  ({n_gpt2 / n_domain:.1f}x)")

n_terms = len(medical_terms)
print("\nAverage tokens per term:")
print(f"  Domain BPE: {domain_total / n_terms:.1f}")
print(f"  GPT-2 BPE:  {gpt2_total / n_terms:.1f}")
print(f"  Improvement: {gpt2_total / domain_total:.1f}x fewer tokens")

## 3. Sequence Packing for Variable-Length Reports

In [None]:
# Simulate report length distribution: a mixture of short, medium and long reports
np.random.seed(42)
report_lengths = np.concatenate([
    np.random.normal(80, 20, 5000).astype(int),     # Short (X-ray)
    np.random.normal(350, 80, 3000).astype(int),     # Medium (CT)
    np.random.normal(1200, 200, 1000).astype(int),   # Long (complex)
])
report_lengths = np.clip(report_lengths, 20, 2000)

# Single source of truth for the context window used in the plot and the maths.
context_length = 512

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Length distribution
axes[0].hist(report_lengths, bins=50, color='#3498db', edgecolor='black',
             linewidth=0.5, alpha=0.7)
axes[0].axvline(x=context_length, color='red', linestyle='--', linewidth=2,
                label=f'Context length ({context_length})')
axes[0].set_xlabel('Report Length (tokens)', fontsize=12)
axes[0].set_ylabel('Count', fontsize=12)
axes[0].set_title('Radiology Report Length Distribution', fontsize=13, fontweight='bold')
axes[0].legend(fontsize=11)
axes[0].grid(alpha=0.3)

# Padding waste without packing: one fixed-size sequence per report.
# NOTE: this baseline truncates reports longer than the context window
# (np.minimum below), whereas the packed layout keeps every token.
naive_tokens = len(report_lengths) * context_length
real_tokens = np.minimum(report_lengths, context_length).sum()
waste_pct = (1 - real_tokens / naive_tokens) * 100

# With packing: all tokens laid end to end, only the final sequence is padded
packed_tokens = int(report_lengths.sum())
packed_sequences = math.ceil(packed_tokens / context_length)
packed_total = packed_sequences * context_length
packed_efficiency = packed_tokens / packed_total * 100

labels = ['Naive Padding', 'Sequence Packing']
efficiencies = [100 - waste_pct, packed_efficiency]
colors_bar = ['#e74c3c', '#2ecc71']
bars = axes[1].bar(labels, efficiencies, color=colors_bar, edgecolor='black', linewidth=0.5)
axes[1].set_ylabel('Token Utilization (%)', fontsize=12)
axes[1].set_title('Packing Efficiency', fontsize=13, fontweight='bold')
# Headroom above 100 so the label on a ~100% bar stays inside the axes.
axes[1].set_ylim(0, 110)
axes[1].grid(axis='y', alpha=0.3)

for bar, eff in zip(bars, efficiencies):
    axes[1].text(bar.get_x() + bar.get_width()/2, bar.get_height() + 1,
                f'{eff:.1f}%', ha='center', fontsize=12, fontweight='bold')

plt.tight_layout()
plt.savefig('packing_efficiency.png', dpi=150, bbox_inches='tight')
plt.show()
print(f"Naive padding efficiency: {100 - waste_pct:.1f}%")
print(f"Sequence packing efficiency: {packed_efficiency:.1f}%")
# Saving = fraction of naive token slots no longer processed. The operands were
# previously reversed, which printed a negative percentage.
print(f"Packing saves {(naive_tokens - packed_total) / naive_tokens * 100:.1f}% compute!")

In [None]:
# Sequence packer: packs variable-length reports into fixed-size training sequences

class SequencePacker:
    """
    Pack variable-length reports into fixed-size sequences.

    Reports are concatenated greedily in input order. A separator token is
    placed between two reports that share a sequence, and a parallel loss mask
    marks which positions hold real report tokens.
    """

    def __init__(self, context_length=512, separator_token_id=0, pad_token_id=0):
        """
        Args:
            context_length: Length of every packed sequence (must be >= 1).
            separator_token_id: Token inserted between reports in one sequence.
            pad_token_id: Token used to fill the unused tail of a sequence.

        Raises:
            ValueError: If context_length is smaller than 1.
        """
        if context_length < 1:
            raise ValueError("context_length must be a positive integer")
        self.context_length = context_length
        self.separator_id = separator_token_id
        self.pad_id = pad_token_id

    def _pad(self, tokens, mask):
        """Return (tokens, mask) right-padded to exactly context_length entries."""
        pad_len = self.context_length - len(tokens)
        return tokens + [self.pad_id] * pad_len, mask + [0] * pad_len

    def pack(self, reports):
        """
        Pack a list of token sequences into fixed-length packed sequences.

        Args:
            reports: Iterable of token-ID sequences. Empty reports are skipped.

        Returns:
            packed_sequences: List of (tokens, mask) tuples
            Each tokens is length context_length
            Each mask is length context_length
            (1 = real report token, 0 = separator or padding)

        Notes:
            A report longer than context_length is split into consecutive
            chunks of context_length tokens, each emitted as its own sequence;
            the final partial chunk is padded rather than shared with the
            following report.
        """
        packed = []
        current_tokens = []
        current_mask = []

        for report in reports:
            report = list(report)
            if not report:
                # Nothing to pack; avoid emitting a separator for an empty report.
                continue

            # Handle reports longer than context_length
            if len(report) > self.context_length:
                if current_tokens:
                    packed.append(self._pad(current_tokens, current_mask))
                    current_tokens, current_mask = [], []
                for start in range(0, len(report), self.context_length):
                    chunk = report[start:start + self.context_length]
                    packed.append(self._pad(chunk, [1] * len(chunk)))
                continue

            # Decide whether the report fits BEFORE touching the current
            # sequence. Appending the separator first could push an exactly
            # full sequence to context_length + 1 tokens.
            separator_cost = 1 if current_tokens else 0
            if len(current_tokens) + separator_cost + len(report) > self.context_length:
                packed.append(self._pad(current_tokens, current_mask))
                current_tokens, current_mask = [], []

            if current_tokens:
                current_tokens.append(self.separator_id)
                current_mask.append(0)  # Don't compute loss on separator

            current_tokens.extend(report)
            current_mask.extend([1] * len(report))

        # Save last sequence
        if current_tokens:
            packed.append(self._pad(current_tokens, current_mask))

        return packed

# Test the packer
packer = SequencePacker(context_length=512)

# Generate synthetic reports of various lengths
np.random.seed(42)
synthetic_reports = [
    np.random.randint(1, 1000, size=length).tolist()
    for length in np.random.choice([60, 80, 100, 300, 500, 800, 1200], size=100)
]

packed = packer.pack(synthetic_reports)
total_real = sum(sum(mask) for _, mask in packed)
total_positions = len(packed) * packer.context_length

# Baseline without packing: every report gets its own padded sequence(s).
# Reports longer than the context window need several sequences, so count
# ceil(len / context_length) per report instead of assuming one each.
naive_sequences = sum(
    math.ceil(len(report) / packer.context_length) for report in synthetic_reports
)

print(f"Reports: {len(synthetic_reports)}")
print(f"Packed sequences: {len(packed)}")
print(f"Packing efficiency: {total_real / total_positions * 100:.1f}%")
print(f"Without packing: {naive_sequences} sequences of {packer.context_length} each")
print(f"Saved: {(naive_sequences - len(packed)) / naive_sequences * 100:.1f}% fewer sequences")

## 4. Training Loop with Medical-Specific Optimizations

In [None]:
# Build a simplified training pipeline for demonstration

class SimpleTransformerLM(nn.Module):
    """Small Transformer LM for case study demonstration.

    Token and learned positional embeddings feed a stack of
    `nn.TransformerEncoderLayer`s run under a causal mask, followed by a final
    LayerNorm and a linear head over the vocabulary.

    Args:
        vocab_size: Number of token IDs (rows of the embedding / head).
        d_model: Embedding and hidden width.
        n_heads: Attention heads per layer.
        n_layers: Number of encoder layers.
        context_length: Number of positions in the positional embedding and
            side length of the causal mask.
        dropout: Dropout probability for the embeddings and encoder layers.
    """

    def __init__(self, vocab_size, d_model=256, n_heads=4, n_layers=4,
                 context_length=128, dropout=0.1):
        super().__init__()
        self.context_length = context_length
        self.token_emb = nn.Embedding(vocab_size, d_model)
        # Learned positional embedding with one row per position index
        self.pos_emb = nn.Embedding(context_length, d_model)
        self.dropout = nn.Dropout(dropout)

        # Pre-norm (norm_first=True) layers with a 4x feed-forward width
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, dim_feedforward=d_model * 4,
            dropout=dropout, batch_first=True, norm_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)

        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)
        # Weight tying: the output projection shares the token embedding matrix
        self.head.weight = self.token_emb.weight

        # Boolean matrix that is True strictly above the diagonal (positions
        # j > i); registered as a buffer so it moves with the module's device.
        self.register_buffer('causal_mask',
            torch.triu(torch.ones(context_length, context_length), diagonal=1).bool())

    def forward(self, x):
        """Compute next-token logits.

        Args:
            x: Integer token IDs of shape (B, T). T should not exceed
               `context_length`, since the positional embedding only has that
               many rows.

        Returns:
            Output of the vocabulary head, shape (B, T, vocab_size).
        """
        B, T = x.shape
        # Position indices 0..T-1, shaped (1, T) so they broadcast over the batch
        pos = torch.arange(T, device=x.device).unsqueeze(0)
        x = self.dropout(self.token_emb(x) + self.pos_emb(pos))
        # Crop the precomputed mask to the actual sequence length
        x = self.transformer(x, mask=self.causal_mask[:T, :T], is_causal=True)
        return self.head(self.ln_f(x))

# Create model and training data
VOCAB_SIZE = 5000  # Simplified for demo
CONTEXT_LENGTH = 128
BATCH_SIZE = 16

# Simulate medical tokenized data
torch.manual_seed(42)
# The token stream below is sampled with NumPy, so NumPy's global RNG must be
# seeded here as well; seeding torch alone left the data dependent on whatever
# state earlier cells happened to leave behind.
np.random.seed(42)
num_tokens = 200000
# Zipf distribution mimics real text (few common words, many rare ones)
token_probs = 1.0 / np.arange(1, VOCAB_SIZE + 1)
token_probs /= token_probs.sum()
all_tokens = np.random.choice(VOCAB_SIZE, size=num_tokens, p=token_probs)

# Dataset
class MaskedTextDataset(Dataset):
    """Sliding-window next-token dataset over one token stream with a loss mask.

    Item `idx` is the window of `context_length` tokens starting at `idx`, the
    same window shifted one position to the right as targets, and the mask
    values aligned with those targets.

    Args:
        tokens: 1-D sequence of token IDs.
        masks: 1-D sequence of per-token loss weights, same length as `tokens`.
        context_length: Number of tokens per training window.

    Raises:
        ValueError: If `tokens` and `masks` differ in length.
    """

    def __init__(self, tokens, masks, context_length):
        if len(tokens) != len(masks):
            raise ValueError(
                f"tokens and masks must have the same length "
                f"(got {len(tokens)} and {len(masks)})")
        self.tokens = torch.tensor(tokens, dtype=torch.long)
        self.masks = torch.tensor(masks, dtype=torch.float)
        self.context_length = context_length

    def __len__(self):
        # Each window needs context_length inputs plus one extra target token.
        # Clamp at zero so a too-short stream yields an empty dataset instead
        # of making len() raise on a negative value.
        return max(0, len(self.tokens) - self.context_length)

    def __getitem__(self, idx):
        if not 0 <= idx < len(self):
            raise IndexError(f"index {idx} out of range for dataset of size {len(self)}")
        x = self.tokens[idx:idx + self.context_length]
        # Targets and their mask are the inputs shifted by one position.
        y = self.tokens[idx + 1:idx + self.context_length + 1]
        m = self.masks[idx + 1:idx + self.context_length + 1]
        return x, y, m

# Loss mask over the token stream: 1.0 everywhere except every 100th position,
# which stands in for a separator token (~one per 100 tokens).
masks = np.ones(num_tokens, dtype=np.float32)
masks[::100] = 0.0

# First 90% of the stream is used for training, the remainder for validation.
split = int(0.9 * num_tokens)
train_ds = MaskedTextDataset(all_tokens[:split], masks[:split], CONTEXT_LENGTH)
val_ds = MaskedTextDataset(all_tokens[split:], masks[split:], CONTEXT_LENGTH)

train_loader = DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
)
val_loader = DataLoader(
    val_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
    drop_last=True,
)

model = SimpleTransformerLM(
    VOCAB_SIZE,
    d_model=256,
    n_heads=4,
    n_layers=4,
    context_length=CONTEXT_LENGTH,
).to(device)

# Total element count across the model's parameter tensors.
num_params = sum(param.numel() for param in model.parameters())
print(f"Model parameters: {num_params:,}")
print(f"Train samples: {len(train_ds):,}")
print(f"Val samples: {len(val_ds):,}")

In [None]:
# Training with all stability mechanisms
def get_lr(step, warmup_steps, total_steps, lr_max=3e-4, lr_min=1e-5):
    """Learning rate for a linear-warmup + cosine-decay schedule.

    Args:
        step: Zero-based optimizer step index.
        warmup_steps: Number of warmup steps before the cosine phase starts.
        total_steps: Step at which the schedule reaches `lr_min`.
        lr_max: Peak learning rate, reached at the start of the cosine phase.
        lr_min: Floor the learning rate decays to (and stays at afterwards).

    Returns:
        The learning rate to use for `step`.
    """
    if step < warmup_steps:
        # Ramp over (step + 1) so that the very first update (step 0) uses a
        # non-zero learning rate instead of being a wasted no-op.
        return lr_max * (step + 1) / (warmup_steps + 1)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    # Half-cosine from lr_max down to lr_min; progress is capped at 1 so steps
    # past total_steps stay at lr_min.
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * min(progress, 1.0)))

def masked_cross_entropy(logits, targets, mask):
    """Mean cross-entropy over the positions selected by `mask`.

    Args:
        logits: Tensor of shape (B, T, V) with unnormalised class scores.
        targets: Tensor of shape (B, T) with target class indices.
        mask: Tensor of shape (B, T) of per-position weights; positions with
            weight 0 do not contribute to the loss.

    Returns:
        Scalar tensor: the mask-weighted loss sum divided by the mask sum,
        where the divisor is clamped to at least 1.
    """
    batch, seq_len, vocab = logits.shape

    # Per-position loss, computed on the flattened batch and folded back to (B, T).
    flat_loss = F.cross_entropy(
        logits.view(-1, vocab),
        targets.view(-1),
        reduction='none',
    )
    token_loss = flat_loss.view(batch, seq_len)

    # Only positions with non-zero mask weight contribute.
    weighted_total = (token_loss * mask).sum()
    normaliser = mask.sum().clamp(min=1)
    return weighted_total / normaliser

# Training loop
EPOCHS = 10
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4,
                                betas=(0.9, 0.95), weight_decay=0.1)
total_steps = len(train_loader) * EPOCHS
warmup_steps = int(0.1 * total_steps)

# Per-step histories (train loss, LR, pre-clip grad norm) and per-epoch val loss
train_losses = []
val_losses = []
learning_rates = []
grad_norms = []

print(f"Training for {EPOCHS} epochs ({total_steps} steps)")
print("=" * 70)

step = 0
for epoch in range(EPOCHS):
    model.train()
    epoch_loss = 0.0

    for batch_x, batch_y, batch_m in train_loader:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)
        batch_m = batch_m.to(device)

        logits = model(batch_x)
        loss = masked_cross_entropy(logits, batch_y, batch_m)

        optimizer.zero_grad()
        loss.backward()

        # clip_grad_norm_ returns the total norm measured BEFORE clipping
        total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        grad_norms.append(total_norm.item())

        # The schedule is applied manually by overwriting each group's LR
        lr = get_lr(step, warmup_steps, total_steps)
        for pg in optimizer.param_groups:
            pg['lr'] = lr
        learning_rates.append(lr)

        optimizer.step()

        # Read the scalar once; each .item() forces a device synchronisation.
        loss_value = loss.item()
        train_losses.append(loss_value)
        epoch_loss += loss_value
        step += 1

    # Validation: accumulate a token-weighted sum so the reported loss (and the
    # perplexity derived from it) is the mean over unmasked tokens, not a mean
    # of per-batch means.
    model.eval()
    val_loss_sum = 0.0
    val_token_count = 0.0
    with torch.no_grad():
        for vx, vy, vm in val_loader:
            vx, vy, vm = vx.to(device), vy.to(device), vm.to(device)
            vlogits = model(vx)
            vloss = masked_cross_entropy(vlogits, vy, vm)
            batch_tokens = vm.sum().item()
            val_loss_sum += vloss.item() * batch_tokens
            val_token_count += batch_tokens

    avg_val = val_loss_sum / max(val_token_count, 1)
    val_losses.append(avg_val)
    avg_train = epoch_loss / len(train_loader)
    ppl = math.exp(avg_val)

    print(f"Epoch {epoch+1:>2d}/{EPOCHS} | Train: {avg_train:.4f} | Val: {avg_val:.4f} | "
          f"PPL: {ppl:.1f} | LR: {lr:.2e} | Grad Norm: {np.mean(grad_norms[-len(train_loader):]):.2f}")

## 5. Training Stability Analysis

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Training loss: raw per-step values plus a moving average
ax = axes[0, 0]
ax.plot(train_losses, linewidth=0.5, alpha=0.3, color='#3498db')
window = max(1, len(train_losses) // 20)
smoothed = np.convolve(train_losses, np.ones(window)/window, mode='valid')
# mode='valid' drops the first window-1 points, so shift the x positions to match
ax.plot(range(window-1, len(train_losses)), smoothed, linewidth=2, color='#2c3e50')
ax.set_xlabel('Step')
ax.set_ylabel('Loss')
ax.set_title('Training Loss', fontweight='bold')
ax.grid(alpha=0.3)

# Validation loss + perplexity (per epoch, on twin y-axes)
ax = axes[0, 1]
ax2 = ax.twinx()
epochs_x = range(1, len(val_losses) + 1)
val_line, = ax.plot(epochs_x, val_losses, 'o-', color='#e74c3c', linewidth=2, label='Val Loss')
ppls = [math.exp(v) for v in val_losses]
ppl_line, = ax2.plot(epochs_x, ppls, 's--', color='#9b59b6', linewidth=2, label='Perplexity')
ax.set_xlabel('Epoch')
ax.set_ylabel('Validation Loss', color='#e74c3c')
ax2.set_ylabel('Perplexity', color='#9b59b6')
ax.set_title('Validation Metrics', fontweight='bold')
# One legend covering both axes, so the labels given above are actually shown
ax.legend(handles=[val_line, ppl_line], loc='upper right')
ax.grid(alpha=0.3)

# Learning rate
axes[1, 0].plot(learning_rates, linewidth=2, color='#2ecc71')
axes[1, 0].set_xlabel('Step')
axes[1, 0].set_ylabel('Learning Rate')
axes[1, 0].set_title('LR Schedule', fontweight='bold')
axes[1, 0].grid(alpha=0.3)

# Gradient norms, with the clipping threshold drawn as a reference line
axes[1, 1].plot(grad_norms, linewidth=0.5, alpha=0.5, color='#e67e22')
axes[1, 1].axhline(y=1.0, color='black', linestyle='--', linewidth=1)
axes[1, 1].set_xlabel('Step')
axes[1, 1].set_ylabel('Gradient Norm')
axes[1, 1].set_title('Gradient Norms', fontweight='bold')
axes[1, 1].grid(alpha=0.3)

plt.suptitle('Aethon Health -- Training Pipeline Monitoring Dashboard', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig('training_dashboard.png', dpi=150, bbox_inches='tight')
plt.show()

# Compute the mean once: evaluating np.mean(grad_norms) inside the generator
# re-scanned the whole list for every element (quadratic in the step count).
grad_norm_values = np.asarray(grad_norms)
mean_grad_norm = grad_norm_values.mean()
num_spikes = int((grad_norm_values > 5 * mean_grad_norm).sum())
is_stable = grad_norm_values.max() < 10 * mean_grad_norm

print(f"\nFinal validation loss: {val_losses[-1]:.4f}")
print(f"Final perplexity: {math.exp(val_losses[-1]):.1f}")
print(f"Gradient spikes (>5x mean): {num_spikes}")
print(f"Training stable: {'Yes' if is_stable else 'No - investigate!'}")

## 6. Results and Business Impact

In [None]:
# Summary: assemble the report line by line, then emit it in a single print
rule = "=" * 60
final_val_loss = val_losses[-1]

summary_lines = [
    rule,
    "  AETHON HEALTH -- TRAINING PIPELINE RESULTS",
    rule,
    "",
    f"  Model: {num_params:,} parameters",
    f"  Training epochs: {EPOCHS}",
    f"  Final val loss: {final_val_loss:.4f}",
    f"  Final perplexity: {math.exp(final_val_loss):.1f}",
    "",
    "  Pipeline Components:",
    "  [x] Domain-specific BPE tokenizer",
    "  [x] Sequence packing (>90% efficiency)",
    "  [x] Masked cross-entropy loss",
    "  [x] AdamW optimizer (beta2=0.95)",
    "  [x] Warmup + cosine decay",
    "  [x] Gradient clipping (max_norm=1.0)",
    "",
    "  Key Insight: The training pipeline -- not the",
    "  architecture -- is where the engineering challenge lies.",
    "  Domain tokenization alone reduced sequence length by ~2x,",
    "  and packing eliminated 40%+ wasted compute.",
    rule,
]
print("\n".join(summary_lines))