In [None]:
# 🔧 Setup: Run this cell first!
# Check GPU availability and install dependencies

import torch
import sys

# Check GPU
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"✅ GPU available: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    device = torch.device('cpu')
    print("⚠️ No GPU detected. Some cells may run slowly.")
    print("   Go to Runtime → Change runtime type → GPU")

print(f"\n📦 Python {sys.version.split()[0]}")
print(f"🔥 PyTorch {torch.__version__}")

# Set random seeds for reproducibility
import random
import numpy as np

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print(f"🎲 Random seed set to {SEED}")

%matplotlib inline

# Grounded Radiology Report Generation — Implementation Notebook

**MedSight AI Case Study** | Cross-Attention for Medical Image Captioning

This notebook implements a grounded radiology report generation system using cross-attention. We build a simplified version of the MedSight AI pipeline. Instead of real chest X-rays, it takes synthetic patch embeddings that stand in for a vision encoder's (ViT) output. It generates short text reports whose cross-attention weights can be visualized as grounding maps.

In [None]:
# Setup and imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import math
from torch.utils.data import Dataset, DataLoader

%matplotlib inline
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
torch.manual_seed(42)
np.random.seed(42)

## 1. Synthetic Medical Image Dataset

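We create a synthetic dataset that mimics chest X-ray + report pairs. Each "image" is a 4×4 grid of 16 patch embeddings. Abnormal studies add a strong signal to a few specific patches representing one finding (nodule, pleural effusion, cardiomegaly, or pneumothorax); normal studies have no such patches. Each "report" is a fixed sentence describing the finding. A per-token grounding mask records which patches the report's words should attend to.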

In [None]:
class SyntheticRadiologyDataset(Dataset):
    """
    Synthetic dataset for grounded report generation.

    Each sample is a dict with:
    - "patches": (n_patches, d_vision) float tensor - simulated ViT output
    - "token_ids": (max_len,) long tensor - <bos> report <eos>, right-padded
      with <pad> (id 0)
    - "finding_mask": (max_len, n_patches) float tensor - row t is 1.0 on the
      abnormal patches for every content token t (all zeros for "normal")
    - "finding_idx": int key into FINDINGS
    - "finding_name": str label of the finding

    The hard-coded abnormal patch indices (up to 14) assume n_patches >= 15,
    e.g. the default 4x4 grid of 16 patches. Samples are drawn from the global
    NumPy and PyTorch RNGs, so seed those first for reproducibility.
    """

    FINDINGS = {
        0: ("normal", "No acute findings. Heart size is normal. Lungs are clear."),
        1: ("nodule", "There is a small nodule in the right upper lobe."),
        2: ("effusion", "There is a moderate left-sided pleural effusion."),
        3: ("cardiomegaly", "The heart is enlarged, consistent with cardiomegaly."),
        4: ("pneumothorax", "Small right-sided pneumothorax is present."),
    }

    # Simplified vocabulary; ids 0-2 are the special tokens <pad>, <bos>, <eos>
    VOCAB = ["<pad>", "<bos>", "<eos>", "no", "acute", "findings", "heart",
             "size", "is", "normal", "lungs", "are", "clear", "there", "a",
             "small", "nodule", "in", "the", "right", "upper", "lobe",
             "moderate", "left", "sided", "pleural", "effusion", "enlarged",
             "consistent", "with", "cardiomegaly", "pneumothorax", "present"]

    def __init__(self, n_samples=1000, n_patches=16, d_vision=64, max_len=20):
        self.n_samples = n_samples
        self.n_patches = n_patches
        self.d_vision = d_vision
        self.max_len = max_len
        self.vocab_size = len(self.VOCAB)

        self.word2idx = {w: i for i, w in enumerate(self.VOCAB)}

        # Pre-generate all samples; the finding type is drawn uniformly
        self.data = []
        for _ in range(n_samples):
            finding_idx = np.random.randint(0, len(self.FINDINGS))
            self.data.append(self._generate_sample(finding_idx))

    def _tokenize(self, text):
        """Map `text` to a list of exactly `max_len` token ids.

        Hyphenated words are split ("left-sided" -> "left", "sided") so that
        they hit the vocabulary; trailing "." and "," are stripped. A word that
        is still missing from VOCAB falls back to <pad> (0), which a loss with
        ignore_index=0 would silently skip - keep VOCAB in sync with FINDINGS.
        Sequences longer than `max_len` are truncated (which drops <eos>).
        """
        pad_id = self.word2idx["<pad>"]
        words = text.replace("-", " ").split()
        tokens = [self.word2idx.get(w.lower().rstrip(".,"), pad_id) for w in words]
        tokens = [self.word2idx["<bos>"]] + tokens + [self.word2idx["<eos>"]]
        # Pad (a negative repeat count yields no padding) then truncate
        tokens += [pad_id] * (self.max_len - len(tokens))
        return tokens[:self.max_len]

    def _generate_sample(self, finding_idx):
        """Build one sample dict for `finding_idx` (a key of FINDINGS)."""
        # Background patches: low-amplitude noise
        patches = torch.randn(self.n_patches, self.d_vision) * 0.3

        # Raises KeyError for an unknown finding_idx
        finding_name, report_text = self.FINDINGS[finding_idx]

        # Add a strong signal to the patches that carry the finding
        if finding_idx == 0:  # normal - no abnormal patches
            finding_patches = []
        elif finding_idx == 1:  # nodule - right upper (patches 0, 1)
            finding_patches = [0, 1]
            patches[0] += torch.randn(self.d_vision) * 2 + 3.0
            patches[1] += torch.randn(self.d_vision) * 1.5 + 2.0
        elif finding_idx == 2:  # effusion - left lower (patches 12, 13, 14)
            finding_patches = [12, 13, 14]
            for p in finding_patches:
                patches[p] += torch.randn(self.d_vision) * 2 + 2.5
        elif finding_idx == 3:  # cardiomegaly - center (patches 5, 6, 9, 10)
            finding_patches = [5, 6, 9, 10]
            for p in finding_patches:
                patches[p] += torch.randn(self.d_vision) * 1.5 + 2.0
        elif finding_idx == 4:  # pneumothorax - right (patches 3, 7)
            finding_patches = [3, 7]
            for p in finding_patches:
                patches[p] += torch.randn(self.d_vision) * 2.5 + 3.5

        token_ids = torch.tensor(self._tokenize(report_text), dtype=torch.long)

        # Grounding mask: every content token (id > 2, i.e. not <pad>/<bos>/<eos>)
        # should attend to all of the finding's patches.
        finding_mask = torch.zeros(self.max_len, self.n_patches)
        if finding_patches:
            patch_row = torch.zeros(self.n_patches)
            patch_row[finding_patches] = 1.0
            finding_mask[token_ids > 2] = patch_row

        return {
            "patches": patches,
            "token_ids": token_ids,
            "finding_mask": finding_mask,
            "finding_idx": finding_idx,
            "finding_name": finding_name,
        }

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        return self.data[idx]

# Create the dataset and an 80/20 train/validation split. Samples are drawn
# i.i.d., so a contiguous split is unbiased.
dataset = SyntheticRadiologyDataset(n_samples=2000)
n_train = len(dataset) * 4 // 5
train_set = torch.utils.data.Subset(dataset, range(n_train))
val_set = torch.utils.data.Subset(dataset, range(n_train, len(dataset)))


def collate_batch(batch):
    """Stack tensor fields into batch tensors; collect all other fields into lists.

    Keys are taken from the first sample. Named (rather than a lambda) so it
    can be reused by other loaders over the same dataset.
    """
    return {
        key: (torch.stack([item[key] for item in batch])
              if isinstance(batch[0][key], torch.Tensor)
              else [item[key] for item in batch])
        for key in batch[0]
    }


train_loader = DataLoader(train_set, batch_size=32, shuffle=True,
                          collate_fn=collate_batch)

print(f"Dataset: {len(dataset)} samples")
print(f"  Train: {len(train_set)}, Val: {len(val_set)}")
print(f"  Patches per image: {dataset.n_patches}")
print(f"  Vocab size: {dataset.vocab_size}")
print(f"  Max report length: {dataset.max_len}")

## 2. Exploratory Data Analysis

In [None]:
# Analyze the finding distribution. Bars follow FINDINGS order, so each colour
# always denotes the same finding regardless of the random sampling order.
finding_labels = [name for name, _ in dataset.FINDINGS.values()]
finding_counts = {name: 0 for name in finding_labels}
for sample in dataset.data:
    finding_counts[sample["finding_name"]] += 1

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Finding distribution
axes[0].bar(finding_labels, [finding_counts[name] for name in finding_labels],
            color=['#2ecc71', '#e74c3c', '#3498db', '#f39c12', '#9b59b6'])
axes[0].set_title('Finding Distribution', fontsize=13, fontweight='bold')
axes[0].set_ylabel('Count')
axes[0].tick_params(axis='x', rotation=30)

# Patch activation patterns: per-patch L2 norm of the first example of each finding
finding_examples = {}
for s in dataset.data:
    if s["finding_idx"] not in finding_examples:
        finding_examples[s["finding_idx"]] = s["patches"].norm(dim=-1).numpy()

for idx, (label, _) in dataset.FINDINGS.items():
    if idx in finding_examples:  # a finding may be absent from a tiny dataset
        axes[1].plot(finding_examples[idx], label=label, alpha=0.7)

axes[1].set_title('Patch Activation by Finding Type', fontsize=13, fontweight='bold')
axes[1].set_xlabel('Patch Index')
axes[1].set_ylabel('Patch Norm')
axes[1].legend(fontsize=9)

plt.tight_layout()
plt.show()

## 3. Model Architecture

In [None]:
class GroundedReportVLM(nn.Module):
    """
    Simplified VLM for grounded radiology report generation.

    Image patches are linearly projected to `d_model` and used as keys/values
    of a hand-written multi-head cross-attention in every decoder block, so
    each layer's text->patch attention weights can be returned for grounding.
    Blocks are pre-norm: self-attention, cross-attention and FFN, each with a
    residual connection.

    Self-attention is causally masked (position t attends only to tokens <= t).
    Next-token training with teacher forcing requires this: without the mask,
    position t could read token t+1 directly and the CE loss would be trivially
    minimised by copying.
    """
    def __init__(self, d_vision, d_model, num_heads, num_layers,
                 vocab_size, max_len):
        super().__init__()

        # Token alignment: map vision features into the text model width
        self.image_proj = nn.Linear(d_vision, d_model)

        # Text embeddings (learned absolute positions, up to max_len)
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.pos_embedding = nn.Embedding(max_len, d_model)

        # Transformer decoder blocks
        self.blocks = nn.ModuleList()
        for _ in range(num_layers):
            self.blocks.append(nn.ModuleDict({
                'self_attn': nn.MultiheadAttention(d_model, num_heads, batch_first=True),
                'cross_attn_q': nn.Linear(d_model, d_model),
                'cross_attn_k': nn.Linear(d_model, d_model),
                'cross_attn_v': nn.Linear(d_model, d_model),
                'cross_attn_out': nn.Linear(d_model, d_model),
                'ffn': nn.Sequential(
                    nn.Linear(d_model, d_model * 4),
                    nn.GELU(),
                    nn.Linear(d_model * 4, d_model),
                ),
                'norm1': nn.LayerNorm(d_model),
                'norm2': nn.LayerNorm(d_model),
                'norm3': nn.LayerNorm(d_model),
            }))

        self.output_head = nn.Linear(d_model, vocab_size)
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.d_model = d_model

    def cross_attention(self, block, text, image):
        """Manual multi-head cross-attention that also returns its weights.

        Args:
            block: decoder block holding the cross_attn_{q,k,v,out} projections
            text: (B, T, d_model) queries
            image: (B, N, d_model) keys/values
        Returns:
            output: (B, T, d_model)
            attn_weights: (B, H, T, N), softmax over the N image patches
        """
        B, T, D = text.shape
        _, N, _ = image.shape

        Q = block['cross_attn_q'](text)   # (B, T, D)
        K = block['cross_attn_k'](image)  # (B, N, D)
        V = block['cross_attn_v'](image)  # (B, N, D)

        # Split into heads: (B, H, seq, d_k)
        Q = Q.view(B, T, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(B, N, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(B, N, self.num_heads, self.d_k).transpose(1, 2)

        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        attn_weights = F.softmax(scores, dim=-1)  # (B, H, T, N)

        context = torch.matmul(attn_weights, V)
        context = context.transpose(1, 2).contiguous().view(B, T, self.d_model)
        output = block['cross_attn_out'](context)

        return output, attn_weights

    def forward(self, image_patches, token_ids):
        """
        Args:
            image_patches: (B, N, d_vision)
            token_ids: (B, T), T <= max_len
        Returns:
            logits: (B, T, vocab_size); logits[:, t] predicts token t+1
            all_attn: list of (B, H, T, N) cross-attention weights, one per layer
        """
        B, T = token_ids.shape

        # Embed and project
        image_tokens = self.image_proj(image_patches)  # (B, N, d_model)
        positions = torch.arange(T, device=token_ids.device).unsqueeze(0)
        text_tokens = self.token_embedding(token_ids) + self.pos_embedding(positions)

        # Boolean causal mask: True (strictly above the diagonal) = may not attend,
        # so position t never sees tokens t+1, t+2, ...
        causal_mask = torch.triu(
            torch.ones(T, T, dtype=torch.bool, device=token_ids.device), diagonal=1)

        all_attn = []
        for block in self.blocks:
            # Causal self-attention
            normed = block['norm1'](text_tokens)
            sa_out, _ = block['self_attn'](normed, normed, normed,
                                           attn_mask=causal_mask, need_weights=False)
            text_tokens = text_tokens + sa_out

            # Cross-attention to image patches
            normed = block['norm2'](text_tokens)
            ca_out, attn = self.cross_attention(block, normed, image_tokens)
            text_tokens = text_tokens + ca_out
            all_attn.append(attn)

            # FFN
            normed = block['norm3'](text_tokens)
            text_tokens = text_tokens + block['ffn'](normed)

        logits = self.output_head(text_tokens)
        return logits, all_attn

# Initialize the model. The vision width and sequence limits are read from
# the dataset so the two cannot drift apart.
model = GroundedReportVLM(
    d_vision=dataset.d_vision,
    d_model=64,
    num_heads=4,
    num_layers=2,
    vocab_size=dataset.vocab_size,
    max_len=dataset.max_len,
)
model = model.to(device)

print(f"Model parameters: {sum(param.numel() for param in model.parameters()):,}")

## 4. Training with Grounding Loss

In [None]:
# Combined loss function (reference solution for the exercise)
def compute_loss(logits, token_ids, attn_weights, finding_masks,
                 lambda_ground=0.1, lambda_entropy=0.01, pad_idx=0):
    """
    Compute combined training loss.

    Args:
        logits: (B, T, vocab_size)
        token_ids: (B, T) ground truth tokens
        attn_weights: list of (B, H, T, N) per layer; only the last layer is used
        finding_masks: (B, T, N) ground truth grounding
        lambda_ground: weight for grounding loss
        lambda_entropy: weight for entropy regularization
        pad_idx: padding token id; padded positions are excluded from every term

    Returns:
        total_loss, ce_loss, ground_loss, entropy_loss

    Components:
    1. Cross-entropy loss (shifted by 1 for next-token prediction)
    2. Grounding loss (BCE between head-averaged attention and the masks
       normalized to sum to 1 per token)
    3. Mean entropy of the head-averaged attention over real (non-pad) token
       positions; a positive weight therefore pushes attention to be sharper
    """
    # ============ YOUR CODE HERE ============
    # Cross-entropy (teacher forcing): logits at position t predict token t+1
    vocab_size = logits.size(-1)
    ce_loss = F.cross_entropy(logits[:, :-1].reshape(-1, vocab_size),
                              token_ids[:, 1:].reshape(-1),
                              ignore_index=pad_idx)

    # Grounding loss on the last layer, averaged across heads: (B, T, N)
    avg_attn = attn_weights[-1].mean(dim=1)

    mask_exists = finding_masks.sum(dim=-1) > 0  # (B, T)
    if mask_exists.any():
        attn_for_ground = avg_attn[mask_exists]
        masks_for_ground = finding_masks[mask_exists]
        ground_loss = F.binary_cross_entropy(
            attn_for_ground.clamp(1e-6, 1 - 1e-6),
            masks_for_ground / (masks_for_ground.sum(dim=-1, keepdim=True) + 1e-6)
        )
    else:
        ground_loss = torch.tensor(0.0, device=logits.device)

    # Entropy regularization, only over real tokens: padded positions take no
    # part in CE or grounding and would otherwise dilute this term.
    token_entropy = -(avg_attn * (avg_attn + 1e-8).log()).sum(dim=-1)  # (B, T)
    real_tokens = token_ids != pad_idx
    if real_tokens.any():
        entropy = token_entropy[real_tokens].mean()
    else:
        entropy = torch.tensor(0.0, device=logits.device)

    total = ce_loss + lambda_ground * ground_loss + lambda_entropy * entropy
    # ========================================

    return total, ce_loss, ground_loss, entropy

In [None]:
# Training loop: per-epoch means of every loss term, plus a validation pass.
NUM_EPOCHS = 15
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

# Validation batches reuse the training collate function; no shuffling needed.
val_loader = DataLoader(val_set, batch_size=64, shuffle=False,
                        collate_fn=train_loader.collate_fn)

train_losses = []
val_losses = []


def run_epoch(loader, train):
    """Run one pass over `loader` and return the per-batch mean of
    (total, ce, ground, entropy). Parameters are updated only when `train`.
    """
    model.train(train)
    sums = [0.0, 0.0, 0.0, 0.0]
    n_batches = 0
    with torch.set_grad_enabled(train):
        for batch in loader:
            patches = batch["patches"].to(device)
            token_ids = batch["token_ids"].to(device)
            masks = batch["finding_mask"].to(device)

            logits, attn = model(patches, token_ids)
            losses = compute_loss(logits, token_ids, attn, masks)

            if train:
                optimizer.zero_grad()
                losses[0].backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()

            sums = [acc + term.item() for acc, term in zip(sums, losses)]
            n_batches += 1
    return [acc / max(n_batches, 1) for acc in sums]


for epoch in range(NUM_EPOCHS):
    train_total, train_ce, train_gnd, train_ent = run_epoch(train_loader, train=True)
    val_total = run_epoch(val_loader, train=False)[0]
    train_losses.append(train_total)
    val_losses.append(val_total)

    if (epoch + 1) % 5 == 0:
        print(f"Epoch {epoch+1}: loss={train_total:.4f}, CE={train_ce:.4f}, "
              f"Ground={train_gnd:.4f}, Entropy={train_ent:.4f} | "
              f"val_loss={val_total:.4f}")

# Plot training and validation curves
fig, ax = plt.subplots(figsize=(8, 4))
epochs = range(1, NUM_EPOCHS + 1)
ax.plot(epochs, train_losses, 'b-', linewidth=2, label='train')
ax.plot(epochs, val_losses, 'r--', linewidth=2, label='validation')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training Loss', fontsize=13, fontweight='bold')
ax.legend()
ax.grid(alpha=0.3)
plt.tight_layout()
plt.show()

## 5. Evaluation and Attention Visualization

In [None]:
# Visualize grounding on the validation set: one example per finding type.
model.eval()

n_findings = len(dataset.FINDINGS)
side = math.isqrt(dataset.n_patches)  # patches form a side x side grid (4x4 by default)
fig, axes = plt.subplots(n_findings, 3, figsize=(15, 4 * n_findings), squeeze=False)

for finding_idx in range(n_findings):
    # First validation sample with this finding (None if the split has none)
    sample = next((s for s in val_set if s["finding_idx"] == finding_idx), None)
    if sample is None:
        continue

    patches = sample["patches"].unsqueeze(0).to(device)
    token_ids = sample["token_ids"].unsqueeze(0).to(device)

    with torch.no_grad():
        _, attn = model(patches, token_ids)

    # Keep positions up to and including the last non-pad token (<eos>);
    # trailing padding carries no meaning.
    real_positions = torch.nonzero(sample["token_ids"] != 0).flatten()
    n_tokens = int(real_positions[-1]) + 1
    token_labels = [dataset.VOCAB[i] for i in sample["token_ids"][:n_tokens].tolist()]

    # Last-layer cross-attention averaged across heads: (n_tokens, N)
    avg_attn = attn[-1][0].mean(dim=0)[:n_tokens].cpu().numpy()
    row = axes[finding_idx]

    # Plot 1: Patch norms (image representation)
    patch_norms = sample["patches"].norm(dim=-1).numpy().reshape(side, side)
    row[0].imshow(patch_norms, cmap='gray')
    row[0].set_title(f'{sample["finding_name"].upper()}', fontsize=11, fontweight='bold')
    row[0].set_ylabel('Image Patches')

    # Plot 2: Attention heatmap, one labelled row per report token
    im = row[1].imshow(avg_attn, cmap='YlOrRd', aspect='auto')
    row[1].set_yticks(range(n_tokens))
    row[1].set_yticklabels(token_labels, fontsize=8)
    row[1].set_title('Cross-Attention Weights', fontsize=11)
    row[1].set_xlabel('Image Patch')
    row[1].set_ylabel('Text Token')
    fig.colorbar(im, ax=row[1])

    # Plot 3: Mean attention per patch over the report's tokens (grounding map)
    row[2].imshow(avg_attn.mean(axis=0).reshape(side, side), cmap='hot')
    row[2].set_title('Grounding Map', fontsize=11)

plt.suptitle('Grounded Report Generation: Attention Visualization by Finding Type',
             fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

In [None]:
# Final summary statistics
finding_names = [name for name, _ in dataset.FINDINGS.values()]

print("=" * 60)
print("MedSight AI — Grounded Report Generator — Summary")
print("=" * 60)
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Training epochs: {len(train_losses)}")
print(f"Final training loss: {train_losses[-1]:.4f}")
if val_losses:  # only populated when the training cell runs validation
    print(f"Final validation loss: {val_losses[-1]:.4f}")
print(f"Findings covered: {finding_names}")
print("\nThe cross-attention mechanism provides built-in grounding:")
print("- Each text token's attention weights show which image patches it queries")
print("- The grounding loss supervises these weights to match the (synthetic) grounding masks")
print("- Attention maps are an explainability aid only; on their own they do not "
      "satisfy regulatory (e.g. FDA) requirements for clinical AI")