In [None]:
# 🔧 Setup: Run this cell first!
# Check GPU availability and install dependencies

import torch
import sys

# Pick the compute device and report what this runtime offers.
_cuda_ok = torch.cuda.is_available()
device = torch.device('cuda' if _cuda_ok else 'cpu')
if _cuda_ok:
    print(f"✅ GPU available: {torch.cuda.get_device_name(0)}")
    gpu_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"   Memory: {gpu_mem_gb:.1f} GB")
else:
    print("⚠️ No GPU detected. Some cells may run slowly.")
    print("   Go to Runtime → Change runtime type → GPU")

print(f"\n📦 Python {sys.version.split()[0]}")
print(f"🔥 PyTorch {torch.__version__}")

# Seed every RNG the notebook touches so reruns are reproducible.
import random
import numpy as np

SEED = 42
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(SEED)
if _cuda_ok:
    torch.cuda.manual_seed_all(SEED)

print(f"🎲 Random seed set to {SEED}")

%matplotlib inline

# Clinical Note Classification with Self-Attention -- Implementation Notebook

*Vizuara Case Study: MedScribe Analytics*

In this notebook, we implement the self-attention-based clinical note classifier described in the case study. We will build a Transformer encoder from scratch, train it on synthetic clinical notes, and analyze what the attention heads learn.

## Setup

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
import matplotlib.pyplot as plt
import random
from torch.utils.data import DataLoader, TensorDataset
from collections import Counter

# Reproducibility: re-seed every RNG so this cell is a self-contained
# starting point regardless of what already ran in the kernel.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

use_cuda = torch.cuda.is_available()
if use_cuda:
    torch.cuda.manual_seed_all(SEED)

device = torch.device('cuda') if use_cuda else torch.device('cpu')
print(f"Device: {device}")

## 1. Synthetic Clinical Notes Dataset

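We create a synthetic dataset that mimics the structure of real clinical notes with diagnostic codes. Each note starts with a `<CLS>` token and mentions 1–3 distinct diagnoses (Cardiac, Respiratory, Renal, Diabetes). Each mention is negated with probability 0.2 (e.g. *"denies … chest pain"*), and negated diagnoses are **not** labelled. The classifier therefore has to use the negation context, not just the presence of keywords.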

In [None]:
# Medical vocabulary. Special tokens come first so their ids are fixed:
# <PAD> = 0, <UNK> = 1, <CLS> = 2.
VOCAB = ['<PAD>', '<UNK>', '<CLS>']

# Symptom/diagnosis words grouped by diagnostic category
CARDIAC_TERMS = [
    'chest', 'pain', 'cardiac', 'heart', 'failure', 'arrhythmia',
    'palpitations', 'hypertension', 'blood', 'pressure', 'elevated',
    'ecg', 'abnormal', 'murmur', 'edema', 'dyspnea',
]
RESPIRATORY_TERMS = [
    'cough', 'breath', 'shortness', 'pneumonia', 'lung',
    'respiratory', 'wheezing', 'sputum', 'oxygen', 'saturation',
    'bronchitis', 'pleurisy', 'crackles', 'consolidation',
]
RENAL_TERMS = [
    'kidney', 'renal', 'creatinine', 'dialysis', 'urine',
    'proteinuria', 'gfr', 'nephropathy', 'bladder', 'urinary',
]
DIABETES_TERMS = [
    'diabetes', 'glucose', 'insulin', 'hba1c', 'hyperglycemia',
    'diabetic', 'neuropathy', 'retinopathy', 'metformin', 'sugar',
]
# Negation cues; the note generator marks a diagnosis as absent with these.
NEGATION_TERMS = [
    'no', 'not', 'denies', 'absent', 'negative', 'without',
    'ruled', 'out', 'unlikely', 'excluded',
]
# Category-neutral filler words
GENERAL_TERMS = [
    'patient', 'presents', 'with', 'history', 'of', 'the',
    'and', 'is', 'was', 'has', 'reports', 'examination',
    'shows', 'diagnosis', 'treatment', 'assessment', 'plan',
    'follow', 'up', 'prescribed', 'admitted', 'discharged',
    'stable', 'condition', 'chronic', 'acute', 'mild',
    'moderate', 'severe', 'bilateral', 'left', 'right',
]

# Concatenate all groups and drop repeats, keeping each word's first position.
VOCAB = list(dict.fromkeys(
    VOCAB
    + CARDIAC_TERMS + RESPIRATORY_TERMS + RENAL_TERMS
    + DIABETES_TERMS + NEGATION_TERMS + GENERAL_TERMS
))
word2idx = dict(zip(VOCAB, range(len(VOCAB))))
vocab_size = len(VOCAB)

# Diagnostic code labels (column order of the multi-hot label vectors)
CODE_NAMES = ['Cardiac', 'Respiratory', 'Renal', 'Diabetes']
num_classes = len(CODE_NAMES)

print(f"Vocabulary size: {vocab_size}")
print(f"Number of diagnostic codes: {num_classes}")

In [None]:
def generate_clinical_note(max_len=64):
    """Sample one synthetic clinical note together with its multi-hot labels.

    The note starts with <CLS> and mentions 1-3 distinct diagnoses. Each
    mention is laid out as:
        [negation cue, only if negated] + 2-4 general words
        + 3-6 category terms + 1-3 general words.
    Each mention is negated with probability 0.2; a negated diagnosis keeps
    label 0.

    Args:
        max_len: Output length. Longer notes are truncated and shorter ones
            are right-padded with id 0 (<PAD>).

    Returns:
        (tokens, labels): a list of ``max_len`` token ids and a float32 array
        of length ``num_classes`` with 1.0 for every non-negated diagnosis.
    """
    unk_id = 1  # id of '<UNK>'; fallback for words missing from word2idx
    group_for_code = [CARDIAC_TERMS, RESPIRATORY_TERMS, RENAL_TERMS, DIABETES_TERMS]

    def draw_words(pool, low, high):
        # Draw a count in [low, high], then that many random words from `pool`.
        count = random.randint(low, high)
        return [word2idx.get(random.choice(pool), unk_id) for _ in range(count)]

    labels = np.zeros(num_classes, dtype=np.float32)
    n_codes = random.randint(1, 3)
    chosen_codes = random.sample(range(num_classes), n_codes)

    tokens = [word2idx['<CLS>']]
    for code in chosen_codes:
        if random.random() < 0.2:
            # Negated mention: emit a cue from the first six negation words
            # ('no', 'not', 'denies', ...) and leave the label at 0.
            cue = random.choice(NEGATION_TERMS[:6])
            tokens.append(word2idx.get(cue, unk_id))
        else:
            labels[code] = 1.0
        tokens += draw_words(GENERAL_TERMS, 2, 4)
        tokens += draw_words(group_for_code[code], 3, 6)
        tokens += draw_words(GENERAL_TERMS, 1, 3)

    clipped = tokens[:max_len]
    padded = clipped + [0] * (max_len - len(clipped))
    return padded, labels

# Dataset sizes and the fixed (padded) sequence length
n_train, n_val, n_test = 2000, 300, 300
max_len = 64

# Sample every note up front, then split contiguously as [train | val | test].
all_data = [generate_clinical_note(max_len) for _ in range(n_train + n_val + n_test)]
all_tokens = torch.tensor([note_ids for note_ids, _ in all_data], dtype=torch.long)
all_labels = torch.tensor(np.array([note_labels for _, note_labels in all_data]), dtype=torch.float32)

val_start = n_train
test_start = n_train + n_val
train_X, train_y = all_tokens[:val_start], all_labels[:val_start]
val_X, val_y = all_tokens[val_start:test_start], all_labels[val_start:test_start]
test_X, test_y = all_tokens[test_start:], all_labels[test_start:]

# Only the training loader shuffles; val/test keep a fixed order.
train_loader = DataLoader(TensorDataset(train_X, train_y), batch_size=32, shuffle=True)
val_loader = DataLoader(TensorDataset(val_X, val_y), batch_size=32)
test_loader = DataLoader(TensorDataset(test_X, test_y), batch_size=32)

print(f"Train: {len(train_X)}, Val: {len(val_X)}, Test: {len(test_X)}")
print(f"Sequence length: {max_len}")
print("Label distribution (train):")
for code_idx, code_name in enumerate(CODE_NAMES):
    count = train_y[:, code_idx].sum().item()
    print(f"  {code_name}: {count:.0f} ({count/len(train_y)*100:.1f}%)")

## 2. Model Architecture

### 2.1 Sinusoidal Positional Encoding

In [None]:
class SinusoidalPositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch of embeddings.

    PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
    PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))

    Args:
        d_model: Embedding width. Odd widths are supported: the last (even)
            column gets a sine term with no cosine partner.
        max_len: Largest sequence length that can be encoded.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model)
        )  # (ceil(d_model / 2),)
        angles = position * div_term  # (max_len, ceil(d_model / 2))
        pe[:, 0::2] = torch.sin(angles)
        # There are only floor(d_model / 2) odd columns. For odd d_model the last
        # frequency has no cosine slot, so drop it instead of failing on shape.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        # Registered as a buffer: follows .to(device) and is saved in state_dict,
        # but is not a trainable parameter.
        self.register_buffer('pe', pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        Assumes ``x`` is (batch, seq_len, d_model) with seq_len <= max_len.
        """
        return x + self.pe[:, :x.size(1), :]

### 2.2 Multi-Head Self-Attention

In [None]:
class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        d_model: Model width; must be divisible by ``num_heads``.
        num_heads: Number of heads, each of width ``d_model // num_heads``.
        dropout: Dropout probability applied to the attention weights before
            they mix the values.
    """

    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Attend over the sequence dimension of ``x``.

        Args:
            x: (B, T, d_model) input.
            mask: Optional tensor broadcastable to (B, num_heads, T, T).
                Positions where it equals 0/False are excluded. For example,
                the classifier passes a (B, 1, 1, T) key-padding mask.

        Returns:
            (output, attn_weights): ``output`` is (B, T, d_model).
            ``attn_weights`` is (B, num_heads, T, T) and holds the softmax
            probabilities *before* dropout, so each row sums to 1 in both
            train and eval mode and can be inspected or plotted directly.
        """
        B, T, _ = x.size()

        def split_heads(t):
            # (B, T, d_model) -> (B, num_heads, T, d_k)
            return t.view(B, T, self.num_heads, self.d_k).transpose(1, 2)

        Q = split_heads(self.W_q(x))
        K = split_heads(self.W_k(x))
        V = split_heads(self.W_v(x))

        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            # Most-negative finite value instead of -inf: masked entries still get
            # zero probability, but a fully masked row gives a uniform
            # distribution rather than NaNs.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
        attn_weights = F.softmax(scores, dim=-1)
        # Dropout regularises only the weights that mix the values; the
        # returned weights stay proper probability distributions.
        context = torch.matmul(self.dropout(attn_weights), V)

        context = context.transpose(1, 2).contiguous().view(B, T, self.d_model)
        return self.W_o(context), attn_weights

### 2.3 Transformer Block

In [None]:
class TransformerBlock(nn.Module):
    """Post-norm Transformer encoder layer: self-attention, then a position-wise FFN.

    Each sub-layer is applied as ``LayerNorm(x + Dropout(sublayer(x)))``.

    Args:
        d_model: Model width.
        num_heads: Number of attention heads.
        d_ff: Hidden width of the feed-forward network.
        dropout: Dropout probability used inside the block.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        # Submodules are created in a fixed order so seeded weight initialisation
        # (and state_dict key order) stays reproducible; keep this order.
        self.attention = MultiHeadSelfAttention(d_model, num_heads, dropout)
        ffn_layers = [
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
        ]
        self.ffn = nn.Sequential(*ffn_layers)
        self.norm1, self.norm2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)
        self.dropout1, self.dropout2 = nn.Dropout(dropout), nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Return ``(output, attn_weights)``; ``mask`` is passed straight to attention."""
        attended, weights = self.attention(x, mask)
        hidden = self.norm1(x + self.dropout1(attended))
        output = self.norm2(hidden + self.dropout2(self.ffn(hidden)))
        return output, weights

### 2.4 Clinical Note Classifier

In [None]:
class ClinicalNoteClassifier(nn.Module):
    """
    Transformer-based multi-label classifier for clinical notes.

    Pipeline: token embedding (scaled by sqrt(d_model)) + sinusoidal positions
    -> ``num_layers`` TransformerBlocks -> mean-pool over non-padding tokens
    -> linear layer with one logit per diagnostic code.

    Token id 0 is treated as padding, both by the embedding and by the masks.
    """

    def __init__(self, vocab_size, d_model=128, num_heads=4, d_ff=512,
                 num_layers=2, num_classes=4, max_len=128, dropout=0.15):
        super().__init__()
        # Submodule creation order is kept fixed for reproducible seeded init.
        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=0)
        self.pos_encoding = SinusoidalPositionalEncoding(d_model, max_len)
        self.blocks = nn.ModuleList(
            [TransformerBlock(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]
        )
        self.classifier = nn.Linear(d_model, num_classes)
        self.dropout = nn.Dropout(dropout)
        self.d_model = d_model

    def forward(self, x):
        """Classify a batch of token-id sequences.

        Args:
            x: (B, T) tensor of token ids; 0 = padding.

        Returns:
            (logits, all_attn): ``logits`` is (B, num_classes). ``all_attn``
            holds one attention tensor per block, first block first.
        """
        not_pad = x != 0                          # (B, T) bool
        attn_mask = not_pad[:, None, None, :]     # (B, 1, 1, T): masks keys only

        hidden = self.pos_encoding(self.embedding(x) * math.sqrt(self.d_model))
        hidden = self.dropout(hidden)

        all_attn = []
        for block in self.blocks:
            hidden, attn = block(hidden, attn_mask)
            all_attn.append(attn)

        # Average over real tokens only; the clamp guards an all-padding row.
        token_weights = not_pad.unsqueeze(-1).float()   # (B, T, 1)
        pooled = (hidden * token_weights).sum(dim=1) / token_weights.sum(dim=1).clamp(min=1)
        return self.classifier(pooled), all_attn

# Instantiate the classifier
model = ClinicalNoteClassifier(
    vocab_size=vocab_size, d_model=128, num_heads=4,
    d_ff=512, num_layers=2, num_classes=num_classes,
    max_len=max_len, dropout=0.15
).to(device)

total_params = sum(p.numel() for p in model.parameters())

# Report the architecture read back from the model itself. Hard-coded numbers
# in the message would silently go stale when the constructor arguments change.
n_layers = len(model.blocks)
n_heads = model.blocks[0].attention.num_heads if n_layers else 0
print(f"Model parameters: {total_params:,}")
print(f"Architecture: {n_layers} layers, {n_heads} heads, d_model={model.d_model}")

## 3. Training

In [None]:
def f1_score(y_true, y_pred, average='binary', zero_division=0):
    """NumPy F1 score, a drop-in for the ``sklearn.metrics.f1_score`` calls in this notebook.

    This notebook calls ``f1_score`` in several cells but never imports it, so
    the first call raised ``NameError``. This implementation covers the modes
    used here:

    * ``average='binary'``: 1-D {0, 1} arrays, positive label 1.
    * ``average='micro'`` / ``'macro'``: 2-D (n_samples, n_labels) indicator arrays.
    * ``average=None``: per-label scores as a NumPy array.

    F1 is computed as ``2*TP / (2*TP + FP + FN)``. Where that denominator is 0
    (no true and no predicted positives), the score is ``zero_division``.

    Raises:
        ValueError: on shape mismatch, or an unsupported ``average``.
    """
    truth = np.asarray(y_true) > 0.5
    guess = np.asarray(y_pred) > 0.5
    if truth.shape != guess.shape:
        raise ValueError(f"shape mismatch: {truth.shape} vs {guess.shape}")
    if truth.ndim == 1:
        truth, guess = truth[:, None], guess[:, None]

    # Per-label confusion counts (summed over samples).
    tp = (truth & guess).sum(axis=0).astype(np.float64)
    fp = (~truth & guess).sum(axis=0).astype(np.float64)
    fn = (truth & ~guess).sum(axis=0).astype(np.float64)

    def _f1(tp_, fp_, fn_):
        denom = 2 * tp_ + fp_ + fn_
        safe_denom = np.where(denom > 0, denom, 1.0)  # avoid 0/0 warnings
        return np.where(denom > 0, 2 * tp_ / safe_denom, float(zero_division))

    if average == 'micro':
        return float(_f1(tp.sum(), fp.sum(), fn.sum()))
    per_label = _f1(tp, fp, fn)
    if average == 'macro':
        return float(per_label.mean())
    if average is None:
        return per_label
    if average == 'binary':
        if per_label.size != 1:
            raise ValueError("average='binary' needs 1-D inputs; use 'micro' or 'macro'")
        return float(per_label[0])
    raise ValueError(f"unsupported average={average!r}")


optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
criterion = nn.BCEWithLogitsLoss()  # multi-label: one independent sigmoid per code

train_losses = []
val_f1_scores = []

num_epochs = 25

for epoch in range(num_epochs):
    # Training
    model.train()
    epoch_loss = 0.0
    n_batches = 0
    for batch_X, batch_y in train_loader:
        batch_X, batch_y = batch_X.to(device), batch_y.to(device)
        logits, _ = model(batch_X)
        loss = criterion(logits, batch_y)

        optimizer.zero_grad()
        loss.backward()
        # Cap the global gradient norm at 1.0 before the optimizer step.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        epoch_loss += loss.item()
        n_batches += 1

    train_losses.append(epoch_loss / n_batches)

    # Validation: threshold sigmoid outputs at 0.5 and score with micro-F1.
    model.eval()
    all_preds = []
    all_true = []
    with torch.no_grad():
        for batch_X, batch_y in val_loader:
            batch_X, batch_y = batch_X.to(device), batch_y.to(device)
            logits, _ = model(batch_X)
            preds = (torch.sigmoid(logits) > 0.5).float()
            all_preds.append(preds.cpu())
            all_true.append(batch_y.cpu())

    all_preds = torch.cat(all_preds).numpy()
    all_true = torch.cat(all_true).numpy()
    micro_f1 = f1_score(all_true, all_preds, average='micro', zero_division=0)
    val_f1_scores.append(micro_f1)

    if (epoch + 1) % 5 == 0:
        print(f"Epoch {epoch+1:3d}: loss={train_losses[-1]:.4f}  val_F1={micro_f1:.4f}")

In [None]:
# Training curves. Epochs on the x-axis are 1-based, matching the
# "Epoch N" lines printed by the training loop.
loss_epochs = range(1, len(train_losses) + 1)
f1_epochs = range(1, len(val_f1_scores) + 1)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

ax1.plot(loss_epochs, train_losses, color='#e74c3c', linewidth=2)
ax1.set_xlabel('Epoch', fontsize=12)
ax1.set_ylabel('BCE Loss', fontsize=12)
ax1.set_title('Training Loss', fontsize=14, fontweight='bold')
ax1.grid(True, alpha=0.3)

ax2.plot(f1_epochs, val_f1_scores, color='#2ecc71', linewidth=2)
ax2.set_xlabel('Epoch', fontsize=12)
ax2.set_ylabel('Micro F1 Score', fontsize=12)
ax2.set_title('Validation F1 Score', fontsize=14, fontweight='bold')
ax2.set_ylim(0, 1.05)
ax2.grid(True, alpha=0.3)

plt.suptitle('Clinical Note Classifier Training', fontsize=15, fontweight='bold')
plt.tight_layout()
plt.show()

## 4. Evaluation

In [None]:
# Test set evaluation: gather probabilities, 0.5-thresholded predictions and
# targets for every test note, then report per-class and averaged F1.
model.eval()
pred_chunks, true_chunks, prob_chunks = [], [], []

with torch.no_grad():
    for batch_X, batch_y in test_loader:
        batch_X, batch_y = batch_X.to(device), batch_y.to(device)
        probs = torch.sigmoid(model(batch_X)[0])
        preds = (probs > 0.5).float()
        pred_chunks.append(preds.cpu())
        true_chunks.append(batch_y.cpu())
        prob_chunks.append(probs.cpu())

# Stack the per-batch results into NumPy arrays, one row per test note.
all_preds, all_true, all_probs = (
    torch.cat(chunks).numpy() for chunks in (pred_chunks, true_chunks, prob_chunks)
)

# Per-class metrics
print("Per-class F1 Scores:")
print("-" * 40)
for class_idx, class_name in enumerate(CODE_NAMES):
    class_f1 = f1_score(all_true[:, class_idx], all_preds[:, class_idx], zero_division=0)
    print(f"  {class_name:15s}: {class_f1:.4f}")

micro_f1 = f1_score(all_true, all_preds, average='micro', zero_division=0)
macro_f1 = f1_score(all_true, all_preds, average='macro', zero_division=0)
print(f"\n  Micro F1: {micro_f1:.4f}")
print(f"  Macro F1: {macro_f1:.4f}")

In [None]:
# Per-class F1 bar chart (test set), one annotated bar per diagnostic code.
per_class_f1 = [
    f1_score(all_true[:, c], all_preds[:, c], zero_division=0)
    for c in range(num_classes)
]
colors = ['#e74c3c', '#3498db', '#2ecc71', '#f39c12']

fig, ax = plt.subplots(figsize=(8, 5))
bars = ax.bar(CODE_NAMES, per_class_f1, color=colors, edgecolor='black', linewidth=1)
ax.set_ylabel('F1 Score', fontsize=12)
ax.set_title('Per-Class F1 Scores on Test Set', fontsize=14, fontweight='bold')
ax.set_ylim(0, 1.05)

# Write each score just above its bar.
for rect, score in zip(bars, per_class_f1):
    x_mid = rect.get_x() + rect.get_width() / 2
    ax.text(x_mid, rect.get_height() + 0.02, f'{score:.3f}',
            ha='center', fontsize=11, fontweight='bold')

ax.grid(True, alpha=0.3, axis='y')
plt.tight_layout()
plt.show()

## 5. Attention Visualization

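This is the most important section for clinical interpretability. For a few test notes, we plot each head's attention map from the **last** Transformer layer, with padding positions omitted, to see which tokens the heads attend to.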

In [None]:
# Reverse lookup (token id -> word) used to label notes and attention axes.
idx2word = {idx: word for word, idx in word2idx.items()}

def visualize_attention(model, tokens, sample_idx=0, layer=-1):
    """Plot each head's attention map for a single clinical note.

    Args:
        model: Classifier whose forward returns ``(logits, all_attn)``, with
            ``all_attn[l][0]`` shaped (num_heads, seq, seq) for a batch of one.
        tokens: 1-D tensor of token ids for one note; id 0 is padding.
        sample_idx: Not used by the plot; kept so existing calls still work.
        layer: Index into ``all_attn`` selecting which layer to plot
            (default ``-1``, the last layer).
    """
    model.eval()
    with torch.no_grad():
        logits, all_attn = model(tokens.unsqueeze(0).to(device))

    # Positions of real (non-padding) tokens. Selecting by position, rather
    # than slicing [:n_words], stays correct even if padding is not a trailing run.
    token_list = tokens.tolist()
    keep = np.array([pos for pos, t in enumerate(token_list) if t != 0], dtype=int)
    words = [idx2word.get(token_list[pos], '<UNK>') for pos in keep]
    n_words = len(words)

    # Codes whose sigmoid probability exceeds 0.5
    probs = torch.sigmoid(logits[0]).cpu().numpy()
    pred_codes = [CODE_NAMES[i] for i in range(num_classes) if probs[i] > 0.5]

    layer_attn = all_attn[layer][0].cpu().numpy()  # (num_heads, seq, seq)
    num_heads = layer_attn.shape[0]

    # squeeze=False keeps `axes` 2-D, so indexing also works with a single head
    # (otherwise plt.subplots returns one bare Axes object that cannot be indexed).
    fig, axes = plt.subplots(1, num_heads, figsize=(5 * num_heads, 5), squeeze=False)
    for h in range(num_heads):
        ax = axes[0, h]
        attn_sub = layer_attn[h][np.ix_(keep, keep)]
        vmax = attn_sub.max() if attn_sub.size else 1.0
        ax.imshow(attn_sub, cmap='Blues', vmin=0, vmax=vmax)
        ax.set_xticks(range(n_words))
        ax.set_xticklabels(words, rotation=90, fontsize=7)
        ax.set_yticks(range(n_words))
        ax.set_yticklabels(words, fontsize=7)
        ax.set_title(f'Head {h+1}', fontsize=11, fontweight='bold')

    plt.suptitle(f'Attention Patterns\nPredicted: {", ".join(pred_codes) if pred_codes else "None"}',
                 fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()

# Show the first three test notes: true codes, leading text, then attention maps.
for sample_no in range(3):
    note_ids = test_X[sample_no]
    true_codes = [code for code, flag in zip(CODE_NAMES, test_y[sample_no]) if flag > 0.5]
    print(f"\n--- Sample {sample_no+1} (True: {', '.join(true_codes)}) ---")
    note_words = [idx2word.get(t, '?') for t in note_ids.tolist() if t != 0]
    print(f"Text: {' '.join(note_words[:30])}...")
    visualize_attention(model, note_ids, sample_no)

## 6. Ablation Study: Impact of $\sqrt{d_k}$ Scaling

In [None]:
# TODO: Compare model with and without sqrt(d_k) scaling
# Train two models:
# 1. Standard attention with scaling
# 2. Attention WITHOUT scaling (remove / math.sqrt(self.d_k))
# Compare convergence speed and final accuracy

ablation_notes = [
    "Ablation Study: Effect of sqrt(d_k) Scaling",
    "=" * 50,
    "The standard model already includes scaling.",
    "To test without scaling, modify the attention scores line:",
    "  scores = torch.matmul(Q, K.transpose(-2, -1))  # No scaling",
    "Then retrain and compare the validation F1 curves.",
]
for note_line in ablation_notes:
    print(note_line)

## 7. Results Summary

In [None]:
# Read the architecture back from the model so this summary cannot drift
# from the constructor arguments actually used.
summary_layers = len(model.blocks)
summary_heads = model.blocks[0].attention.num_heads if summary_layers else 0

print("=" * 60)
print("CASE STUDY RESULTS: Clinical Note Classification")
print("=" * 60)
print(f"\nModel: Transformer Encoder ({summary_layers} layers, {summary_heads} heads)")
print(f"Parameters: {total_params:,}")
print(f"Training samples: {n_train}")
print(f"Test Micro F1: {micro_f1:.4f}")
print(f"Test Macro F1: {macro_f1:.4f}")
print("\nPer-class performance:")
for i, name in enumerate(CODE_NAMES):
    f1 = f1_score(all_true[:, i], all_preds[:, i], zero_division=0)
    print(f"  {name:15s}: F1 = {f1:.4f}")

# NOTE(review): this notebook trains no sequential (RNN/LSTM) baseline, so the
# "outperforming sequential models" claim below is not measured here.
print("\nKey takeaway: Self-attention enables direct modeling of")
print("long-range dependencies and negation patterns in clinical text,")
print("outperforming sequential models on multi-label classification.")
print("=" * 60)