In [None]:
# 🔧 Setup: Run this cell first!
# Check GPU availability, report Python/PyTorch versions, and seed the random number generators

import torch
import sys

# Pick the compute device once, then report what was found.
gpu_available = torch.cuda.is_available()
device = torch.device('cuda' if gpu_available else 'cpu')

if gpu_available:
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"✅ GPU available: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {gpu_props.total_memory / 1e9:.1f} GB")
else:
    print("⚠️ No GPU detected. Some cells may run slowly.")
    print("   Go to Runtime → Change runtime type → GPU")

# Interpreter and library versions, for reproducibility of the run.
python_version = sys.version.split()[0]
print(f"\n📦 Python {python_version}")
print(f"🔥 PyTorch {torch.__version__}")

# Set random seeds for reproducibility
import random
import numpy as np

SEED = 42

# Seed every RNG the notebook draws from: stdlib, NumPy, then PyTorch.
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(SEED)

# Also seed all CUDA devices explicitly when a GPU is present.
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print(f"🎲 Random seed set to {SEED}")

%matplotlib inline

# Radiology Report Generation with Multimodal Instruction Tuning -- Implementation Notebook

*Meridian Diagnostic Intelligence Case Study*

In [None]:
# Setup and installations
!pip install -q torch torchvision matplotlib numpy scikit-learn Pillow

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import f1_score, classification_report, confusion_matrix
import random
import time

# Seed PyTorch, NumPy and the stdlib RNG (in that order) with the same value.
for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    seed_fn(42)

# Prefer the GPU when one is visible to PyTorch.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

## 1. Dataset Creation

We create a synthetic radiology dataset that mimics the structure of MIMIC-CXR.

In [None]:
class SyntheticRadiologyDataset:
    """Synthetic dataset mimicking chest X-ray report generation.

    Each sample is a dict with:
    - "image": synthetic grayscale image tensor of shape
      (1, image_size, image_size); one disc per finding plus Gaussian noise
    - "labels": multi-hot float tensor with one entry per name in ``FINDINGS``
    - "report_text": structured report assembled from ``TEMPLATES``
    - "report_tokens": ``report_text`` mapped to ``VOCAB`` ids
    - "findings": names of the positive findings

    ``VOCAB`` is a class-level dict that every instance extends in place, so
    all splits built from this class share one word-to-id mapping.

    Args:
        n_samples: number of studies to generate.
        image_size: side length of the square image. Disc centres are drawn
            from [20, image_size - 20], so values below 40 raise ValueError
            as soon as a study has a finding.
    """

    FINDINGS = [
        "cardiomegaly", "edema", "consolidation", "atelectasis",
        "pleural_effusion", "pneumothorax", "pneumonia", "nodule",
        "mass", "fracture", "emphysema", "fibrosis",
        "thickening", "hernia"
    ]

    TEMPLATES = {
        "cardiomegaly": "The cardiac silhouette is enlarged, suggesting cardiomegaly.",
        "edema": "There is pulmonary edema with bilateral perihilar haziness.",
        "consolidation": "Consolidation is present in the {location}.",
        "atelectasis": "Linear atelectasis is noted in the {location}.",
        "pleural_effusion": "There is a {side} pleural effusion.",
        "pneumothorax": "A {side} pneumothorax is identified.",
        "pneumonia": "Findings consistent with pneumonia in the {location}.",
        "nodule": "A {size} nodule is noted in the {location}.",
        "mass": "A mass is present in the {location}.",
        "fracture": "A {side} rib fracture is identified.",
        "emphysema": "Hyperinflation consistent with emphysema.",
        "fibrosis": "Interstitial changes suggesting fibrosis.",
        "thickening": "Pleural thickening noted on the {side}.",
        "hernia": "A hiatal hernia is present.",
    }

    LOCATIONS = ["right lower lobe", "left lower lobe", "right upper lobe",
                 "left upper lobe", "right middle lobe", "bilateral bases"]
    SIDES = ["right", "left", "bilateral"]
    SIZES = ["5mm", "8mm", "12mm", "15mm"]

    # Words that only occur in the "normal study" report.
    _NORMAL_WORDS = ("no", "acute", "cardiopulmonary", "abnormality",
                     "identified", "normal", "study")

    VOCAB = {"<pad>": 0, "<start>": 1, "<end>": 2}
    # Kept for backward compatibility; new ids are derived from VOCAB itself
    # (see _build_vocab) so that they can never collide across instances.
    _next_id = 3

    def __init__(self, n_samples=500, image_size=128):
        self.image_size = image_size
        self.data = []

        self._build_vocab()
        self.INV_VOCAB = {v: k for k, v in self.VOCAB.items()}

        n_labels = len(self.FINDINGS)
        for _ in range(n_samples):
            # Random number of findings (0-3)
            n_findings = int(np.random.choice([0, 1, 2, 3], p=[0.3, 0.35, 0.25, 0.1]))
            finding_indices = sorted(random.sample(range(n_labels), n_findings))
            labels = torch.zeros(n_labels)
            labels[finding_indices] = 1.0

            # Image first, then report: keeps the RNG draw order stable.
            img = self._create_image(finding_indices)
            report = self._generate_report(finding_indices)

            self.data.append({
                "image": img,
                "labels": labels,
                "report_text": report,
                "report_tokens": self._tokenize(report),
                "findings": [self.FINDINGS[i] for i in finding_indices],
            })

    @staticmethod
    def _split_words(text):
        """Lower-case ``text``, drop commas and periods, split on whitespace."""
        return text.lower().replace(",", "").replace(".", "").split()

    def _build_vocab(self):
        """Add every template/location/side/size/normal word to the shared VOCAB.

        New ids start one past the largest id already present, so a later
        instance that meets an unseen word can never reuse an existing id.
        """
        vocab = self.VOCAB
        next_id = max(vocab.values(), default=-1) + 1

        candidates = []
        for template in self.TEMPLATES.values():
            candidates.extend(self._split_words(template))
        for phrase in self.LOCATIONS + self.SIDES + self.SIZES:
            candidates.extend(phrase.lower().split())
        candidates.extend(self._NORMAL_WORDS)

        for word in candidates:
            if word not in vocab:
                vocab[word] = next_id
                next_id += 1
        self._next_id = next_id

    def _create_image(self, finding_indices):
        """Return a (1, H, W) image in [0, 1] with one random disc per finding."""
        img = torch.ones(1, self.image_size, self.image_size) * 0.7
        y, x = torch.meshgrid(
            torch.arange(self.image_size),
            torch.arange(self.image_size), indexing='ij'
        )
        # The disc does not depend on which finding it stands for; only the
        # number of findings matters.
        for _ in finding_indices:
            cx = random.randint(20, self.image_size - 20)
            cy = random.randint(20, self.image_size - 20)
            r = random.randint(8, 20)
            mask = ((x - cx)**2 + (y - cy)**2) < r**2
            intensity = 0.3 + random.random() * 0.4
            img[0][mask] = intensity
        img += torch.randn_like(img) * 0.05
        return img.clamp(0, 1)

    def _generate_report(self, finding_indices):
        """Return the report text: one filled-in template sentence per finding."""
        if not finding_indices:
            return "No acute cardiopulmonary abnormality identified. Normal study."
        sentences = []
        for idx in finding_indices:
            template = self.TEMPLATES[self.FINDINGS[idx]]
            # All three placeholders are drawn for every sentence, whether or
            # not the template contains them.
            template = template.replace("{location}", random.choice(self.LOCATIONS))
            template = template.replace("{side}", random.choice(self.SIDES))
            template = template.replace("{size}", random.choice(self.SIZES))
            sentences.append(template)
        return " ".join(sentences)

    def _tokenize(self, text):
        """Map ``text`` to VOCAB ids; words missing from VOCAB map to 0 (<pad>)."""
        return [self.VOCAB.get(w, 0) for w in self._split_words(text)]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

# Build the three splits in train -> val -> test order (the order matters for
# the seeded RNG stream).
train_data, val_data, test_data = (
    SyntheticRadiologyDataset(n_samples=split_size)
    for split_size in (600, 100, 200)
)

print(f"Vocabulary size: {len(train_data.VOCAB)}")
for split_label, split in (("Training:", train_data),
                           ("Validation:", val_data),
                           ("Test:", test_data)):
    print(f"{split_label:<12}{len(split)} studies")

## 2. Exploratory Data Analysis

In [None]:
# Per-finding totals: stack the multi-hot label vectors and sum over studies.
label_matrix = torch.stack([study["labels"] for study in train_data.data])
finding_counts = label_matrix.sum(dim=0)

fig, axes = plt.subplots(1, 2, figsize=(16, 5))
ax_freq, ax_hist = axes

# Left panel: how often each finding occurs.
label_positions = range(14)
bars = ax_freq.barh(label_positions, finding_counts.numpy(), color='steelblue')
ax_freq.set_yticks(label_positions)
ax_freq.set_yticklabels(SyntheticRadiologyDataset.FINDINGS, fontsize=9)
ax_freq.set_xlabel("Count")
ax_freq.set_title("Finding Frequency in Training Set")
ax_freq.invert_yaxis()

# Right panel: distribution of the number of findings per study (0-3).
n_findings = [len(study["findings"]) for study in train_data.data]
ax_hist.hist(n_findings, bins=range(5), color='coral', edgecolor='black', align='left')
ax_hist.set_xlabel("Number of Findings")
ax_hist.set_ylabel("Count")
ax_hist.set_title("Findings per Study")
ax_hist.set_xticks(range(4))

plt.tight_layout()
plt.show()

In [None]:
# Show the first eight training studies in a 2x4 grid.
fig, axes = plt.subplots(2, 4, figsize=(16, 8))
axes = axes.flatten()

for panel, study_idx in zip(axes, range(8)):
    sample = train_data[study_idx]
    panel.imshow(sample["image"][0].numpy(), cmap='gray')
    if sample["findings"]:
        findings = ", ".join(sample["findings"])
    else:
        findings = "Normal"
    panel.set_title(f"Findings: {findings}", fontsize=8, wrap=True)
    panel.axis('off')

plt.suptitle("Sample Chest X-rays with Findings", fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

## 3. Baseline Model

In [None]:
# Template-based baseline
def template_baseline(labels):
    """Turn a label vector into a minimal report.

    Every position whose value exceeds 0.5 contributes "<finding> detected";
    the phrases are joined with "; " and terminated with a period. When no
    position exceeds 0.5 a fixed "no abnormality" sentence is returned.
    """
    positives = [
        SyntheticRadiologyDataset.FINDINGS[pos]
        for pos, score in enumerate(labels)
        if score > 0.5
    ]
    if len(positives) == 0:
        return "No acute cardiopulmonary abnormality identified."
    return "; ".join(f"{name} detected" for name in positives) + "."

# Evaluate baseline (using ground truth labels as perfect detection)
from collections import Counter
def _report_word_set(text):
    """Return the set of lower-cased words in ``text`` with , . ; removed.

    Punctuation is stripped the same way the dataset tokenizer does it, so
    that "cardiomegaly." in a reference report matches "cardiomegaly" in a
    prediction instead of counting as a different word.
    """
    for mark in (",", ".", ";"):
        text = text.replace(mark, "")
    return set(text.lower().split())


# Fraction of distinct ground-truth words that also appear in the prediction.
baseline_bleu_proxy = []
for sample in test_data:
    pred_report = template_baseline(sample["labels"])
    gt_report = sample["report_text"]
    pred_words = _report_word_set(pred_report)
    gt_words = _report_word_set(gt_report)
    # max(..., 1) guards against an empty reference report.
    overlap = len(pred_words & gt_words) / max(len(gt_words), 1)
    baseline_bleu_proxy.append(overlap)

print(f"Baseline word overlap score: {np.mean(baseline_bleu_proxy):.3f}")

## 4. Model Architecture

In [None]:
class RadiologyMultimodalModel(nn.Module):
    """LLaVA-style model for radiology report generation.

    A patch-embedding convolution produces one feature vector per image patch.
    Those features feed (a) a linear multi-label classification head, via mean
    pooling, and (b) an MLP projector whose outputs are prepended to the text
    embeddings and decoded by a causal transformer decoder.
    """

    def __init__(self, image_size=128, patch_size=16, vision_dim=256,
                 llm_dim=128, vocab_size=100, num_labels=14,
                 num_heads=4, num_layers=2):
        super().__init__()
        patches_per_side = image_size // patch_size
        self.num_patches = patches_per_side * patches_per_side

        # Vision encoder (frozen): non-overlapping patches -> vision_dim channels.
        patch_conv = nn.Conv2d(1, vision_dim, kernel_size=patch_size, stride=patch_size)
        self.vision_encoder = nn.Sequential(patch_conv, nn.Flatten(2))

        # Projection layer (always trainable): vision space -> language space.
        self.projector = nn.Sequential(
            nn.Linear(vision_dim, llm_dim),
            nn.GELU(),
            nn.Linear(llm_dim, llm_dim),
        )

        # Classification head for finding detection.
        self.cls_head = nn.Linear(vision_dim, num_labels)

        # Language model: token + learned position embeddings, decoder, LM head.
        self.text_embed = nn.Embedding(vocab_size, llm_dim)
        self.pos_embed = nn.Embedding(512, llm_dim)
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(
                d_model=llm_dim,
                nhead=num_heads,
                dim_feedforward=llm_dim * 4,
                batch_first=True,
            ),
            num_layers=num_layers,
        )
        self.output_head = nn.Linear(llm_dim, vocab_size)

    def set_stage(self, stage):
        """Configure which parameters train and print the trainable share.

        The vision encoder is always frozen; the projector and classification
        head are always trainable; the language-model parts are trainable only
        when ``stage == 2``.
        """
        def _set_trainable(module, flag):
            for param in module.parameters():
                param.requires_grad = flag

        _set_trainable(self.vision_encoder, False)
        _set_trainable(self.projector, True)
        _set_trainable(self.cls_head, True)

        llm_trainable = (stage == 2)
        for llm_part in (self.text_embed, self.pos_embed, self.decoder, self.output_head):
            _set_trainable(llm_part, llm_trainable)

        total = 0
        trainable = 0
        for param in self.parameters():
            count = param.numel()
            total += count
            if param.requires_grad:
                trainable += count
        print(f"Stage {stage}: {trainable:,}/{total:,} params trainable ({trainable/total:.1%})")

    def forward(self, images, token_ids):
        """Return ``(logits, label_logits)``.

        ``logits`` covers the whole combined sequence (vision tokens first,
        then text tokens); ``label_logits`` has one score per finding label.
        """
        batch_size = images.shape[0]

        # Vision features are computed without gradients: (B, patches, vision_dim).
        with torch.no_grad():
            patch_feats = self.vision_encoder(images).transpose(1, 2)

        # Finding detection from mean-pooled patch features.
        label_logits = self.cls_head(patch_feats.mean(dim=1))

        # Vision tokens followed by text tokens, plus position embeddings.
        vis_tokens = self.projector(patch_feats)
        text_tokens = self.text_embed(token_ids)
        sequence = torch.cat([vis_tokens, text_tokens], dim=1)
        seq_len = sequence.shape[1]
        positions = torch.arange(seq_len, device=sequence.device).unsqueeze(0)
        sequence = sequence + self.pos_embed(positions)

        # Causal self-attention; the cross-attention memory is a single zero vector.
        causal_mask = nn.Transformer.generate_square_subsequent_mask(seq_len).to(sequence.device)
        memory = torch.zeros(batch_size, 1, sequence.shape[-1], device=sequence.device)
        decoded = self.decoder(sequence, memory, tgt_mask=causal_mask)

        return self.output_head(decoded), label_logits


# Size the embedding/output layers to the dataset vocabulary, then move the
# model to the selected device.
vocab_size = len(train_data.VOCAB)
model = RadiologyMultimodalModel(vocab_size=vocab_size)
model = model.to(device)

total_params = sum(param.numel() for param in model.parameters())
print(f"Total parameters: {total_params:,}")

## 5. Two-Stage Training

In [None]:
def _build_text_batch(batch, max_len, device):
    """Build teacher-forcing tensors for one mini-batch.

    Returns ``(input_ids, target_ids)``, both of shape (len(batch), max_len).
    Each input row is ``<start>`` (id 1) followed by the report tokens and is
    zero-padded; each target row is the report tokens followed by ``<end>``
    (id 2) and is padded with -100, which the loss ignores. Reports are
    truncated to ``max_len - 1`` tokens.
    """
    n_rows = len(batch)
    input_ids = torch.zeros(n_rows, max_len, dtype=torch.long, device=device)
    target_ids = torch.full((n_rows, max_len), -100, dtype=torch.long, device=device)
    for row, sample in enumerate(batch):
        tokens = sample["report_tokens"][:max_len - 1]
        input_ids[row, 0] = 1  # <start>
        for col, token in enumerate(tokens):
            input_ids[row, col + 1] = token
            target_ids[row, col] = token
        target_ids[row, len(tokens)] = 2  # <end>
    return input_ids, target_ids


def _run_epoch(model, train_data, opt, device, batch_size, max_len_fn):
    """Run one shuffled training epoch.

    ``max_len_fn`` maps the longest report length in a batch to the sequence
    length used for that batch. Returns the per-batch means of
    ``(total_loss, gen_loss, detect_loss)``, averaged over the number of
    batches actually processed (zeros when ``train_data`` is empty).
    """
    model.train()
    order = list(range(len(train_data)))
    random.shuffle(order)

    total_sum = gen_sum = detect_sum = 0.0
    n_batches = 0
    for start in range(0, len(order), batch_size):
        batch = [train_data[j] for j in order[start:start + batch_size]]
        images = torch.stack([b["image"] for b in batch]).to(device)
        labels = torch.stack([b["labels"] for b in batch]).to(device)

        longest = max(len(b["report_tokens"]) for b in batch)
        input_ids, target_ids = _build_text_batch(batch, max_len_fn(longest), device)

        logits, label_logits = model(images, input_ids)
        # Only the positions after the vision tokens predict text.
        text_logits = logits[:, model.num_patches:, :]
        gen_loss = F.cross_entropy(text_logits.reshape(-1, text_logits.shape[-1]),
                                   target_ids.reshape(-1), ignore_index=-100)
        detect_loss = F.binary_cross_entropy_with_logits(label_logits, labels)
        loss = gen_loss + 0.5 * detect_loss

        opt.zero_grad()
        loss.backward()
        opt.step()

        total_sum += loss.item()
        gen_sum += gen_loss.item()
        detect_sum += detect_loss.item()
        n_batches += 1

    # Divide by the real batch count; the last, smaller batch counts too.
    denom = max(n_batches, 1)
    return total_sum / denom, gen_sum / denom, detect_sum / denom


def train_model(model, train_data, val_data, stage1_epochs=20, stage2_epochs=30,
                batch_size=32):
    """Two-stage training pipeline.

    Stage 1 (feature alignment) calls ``model.set_stage(1)`` and trains with
    Adam at lr=1e-3 on sequences capped at 15 positions. Stage 2 (instruction
    tuning) calls ``model.set_stage(2)`` and trains with Adam at lr=5e-4 on
    sequences capped at 20 positions. Both stages minimise
    ``gen_loss + 0.5 * detect_loss``.

    Args:
        model: module exposing ``set_stage``, ``num_patches`` and a forward
            returning ``(logits, label_logits)``.
        train_data: indexable collection of samples with "image", "labels"
            and "report_tokens" entries.
        val_data: accepted for interface compatibility; not used.
        stage1_epochs: number of stage-1 epochs.
        stage2_epochs: number of stage-2 epochs.
        batch_size: mini-batch size used in both stages.

    Returns:
        dict with per-epoch mean losses: "stage1_loss" (combined loss),
        "stage2_loss" (generation loss only) and "stage2_detect_loss".
    """
    history = {"stage1_loss": [], "stage2_loss": [], "stage2_detect_loss": []}

    # Place batches on the model's own device rather than relying on a
    # notebook-global variable.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    # Stage 1: Feature Alignment
    model.set_stage(1)
    opt = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)

    for epoch in range(stage1_epochs):
        # Simple captioning: predict only the first few report tokens.
        total, _, _ = _run_epoch(model, train_data, opt, run_device, batch_size,
                                 lambda longest: min(longest, 15))
        history["stage1_loss"].append(total)
        if (epoch + 1) % 5 == 0:
            print(f"  Stage 1 Epoch {epoch+1}: loss={history['stage1_loss'][-1]:.4f}")

    # Stage 2: Instruction Tuning
    model.set_stage(2)
    opt = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=5e-4)

    for epoch in range(stage2_epochs):
        _, gen, detect = _run_epoch(model, train_data, opt, run_device, batch_size,
                                    lambda longest: min(longest + 1, 20))
        history["stage2_loss"].append(gen)
        history["stage2_detect_loss"].append(detect)
        if (epoch + 1) % 10 == 0:
            print(f"  Stage 2 Epoch {epoch+1}: gen={history['stage2_loss'][-1]:.4f}, "
                  f"detect={history['stage2_detect_loss'][-1]:.4f}")

    return history

# Run both training stages with the default epoch counts.
history = train_model(model=model, train_data=train_data, val_data=val_data)

## 6. Evaluation

In [None]:
# Training curves: one panel per recorded loss series.
curve_specs = [
    ("stage1_loss", "Stage 1: Alignment Loss", 'steelblue'),
    ("stage2_loss", "Stage 2: Report Generation Loss", 'coral'),
    ("stage2_detect_loss", "Stage 2: Finding Detection Loss", 'seagreen'),
]

fig, axes = plt.subplots(1, 3, figsize=(18, 5))

for panel, (history_key, panel_title, line_color) in zip(axes, curve_specs):
    panel.plot(history[history_key], color=line_color, linewidth=2)
    panel.set_title(panel_title)
    panel.set_xlabel("Epoch")
    panel.set_ylabel("Loss")
    panel.grid(True, alpha=0.3)

plt.suptitle("Two-Stage Training Curves", fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

In [None]:
# Evaluate finding detection on the test split, one study at a time.
model.eval()
all_preds, all_labels = [], []

# The classification head only needs the image; a single <start> token is
# passed to satisfy the forward signature.
dummy_ids = torch.ones(1, 1, dtype=torch.long, device=device)

with torch.no_grad():
    for sample in test_data:
        img = sample["image"].unsqueeze(0).to(device)
        _, label_logits = model(img, dummy_ids)
        probs = torch.sigmoid(label_logits)
        preds = (probs > 0.5).float().cpu()
        all_preds.append(preds[0])
        all_labels.append(sample["labels"])

preds_tensor = torch.stack(all_preds).numpy()
labels_tensor = torch.stack(all_labels).numpy()

# Per-finding F1
print("Per-Finding F1 Scores:")
print("-" * 40)
for col, finding in enumerate(SyntheticRadiologyDataset.FINDINGS):
    finding_f1 = f1_score(labels_tensor[:, col], preds_tensor[:, col], zero_division=0)
    print(f"  {finding:20s}: {finding_f1:.3f}")

# Aggregate scores across all labels.
macro_f1, micro_f1 = (
    f1_score(labels_tensor, preds_tensor, average=averaging, zero_division=0)
    for averaging in ('macro', 'micro')
)
print(f"\nMacro F1: {macro_f1:.3f}")
print(f"Micro F1: {micro_f1:.3f}")

## 7. Error Analysis

In [None]:
# Confusion analysis for critical findings
critical_findings = ["pneumothorax", "fracture", "mass", "pleural_effusion"]
critical_indices = [SyntheticRadiologyDataset.FINDINGS.index(f) for f in critical_findings]

fig, axes = plt.subplots(1, len(critical_findings), figsize=(20, 4))

for i, (idx, finding) in enumerate(zip(critical_indices, critical_findings)):
    # labels=[0, 1] forces a 2x2 matrix even when a finding has no positive
    # (or no negative) examples and predictions; without it the matrix can be
    # 1x1 and the annotation loop below would index out of bounds.
    cm = confusion_matrix(labels_tensor[:, idx], preds_tensor[:, idx], labels=[0, 1])
    im = axes[i].imshow(cm, cmap='Blues')
    axes[i].set_title(finding, fontsize=11, fontweight='bold')
    axes[i].set_xlabel("Predicted")
    axes[i].set_ylabel("True")
    axes[i].set_xticks([0, 1])
    axes[i].set_xticklabels(["Neg", "Pos"])
    axes[i].set_yticks([0, 1])
    axes[i].set_yticklabels(["Neg", "Pos"])
    # Annotate each cell; switch to white text on dark cells for legibility.
    for r in range(2):
        for c in range(2):
            axes[i].text(c, r, str(cm[r, c]), ha='center', va='center',
                        fontsize=14, fontweight='bold',
                        color='white' if cm[r, c] > cm.max()/2 else 'black')

plt.suptitle("Confusion Matrices for Critical Findings", fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

# Critical finding sensitivity: TP / (TP + FN), reported as 0 when the test
# set contains no positive example for the finding.
print("\nCritical Finding Sensitivity (must be > 98% for deployment):")
for idx, finding in zip(critical_indices, critical_findings):
    tp = ((preds_tensor[:, idx] == 1) & (labels_tensor[:, idx] == 1)).sum()
    fn = ((preds_tensor[:, idx] == 0) & (labels_tensor[:, idx] == 1)).sum()
    sensitivity = tp / max(tp + fn, 1)
    status = "PASS" if sensitivity >= 0.98 else "FAIL"
    print(f"  {finding:20s}: {sensitivity:.1%} [{status}]")

## 8. Deployment Readiness

In [None]:
# Inference latency benchmark
model.eval()
latencies = []

# CUDA kernels launch asynchronously, so the GPU must be synchronised before
# reading the clock; otherwise only the launch overhead is measured.
on_gpu = device.type == 'cuda'
N_WARMUP = 5

with torch.no_grad():
    ids = torch.ones(1, 1, dtype=torch.long, device=device)

    # Warm-up passes (not timed) absorb one-off costs such as lazy
    # initialisation. A zero image is used so the RNG stream is untouched.
    warmup_img = torch.zeros(1, 1, 128, 128, device=device)
    for _ in range(N_WARMUP):
        _ = model(warmup_img, ids)

    for _ in range(50):
        img = torch.randn(1, 1, 128, 128).to(device)
        if on_gpu:
            torch.cuda.synchronize()
        # perf_counter is a monotonic, high-resolution clock meant for timing.
        start = time.perf_counter()
        _ = model(img, ids)
        if on_gpu:
            torch.cuda.synchronize()
        latencies.append((time.perf_counter() - start) * 1000)

print(f"Inference Latency:")
print(f"  P50: {np.percentile(latencies, 50):.1f} ms")
print(f"  P95: {np.percentile(latencies, 95):.1f} ms")
print(f"  P99: {np.percentile(latencies, 99):.1f} ms")
print(f"  Target: < 30,000 ms (30 seconds)")

# Model size: bytes occupied by the parameters alone.
param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
print(f"\nModel Size: {param_bytes / 1024 / 1024:.1f} MB")

## 9. Summary

In [None]:
# Assemble the summary as a list of lines and emit it in one go.
banner = "=" * 60
p50_latency_ms = np.percentile(latencies, 50)

summary_lines = [
    banner,
    "CASE STUDY SUMMARY: Radiology Report Generation",
    banner,
    "\nCompany: Meridian Diagnostic Intelligence",
    "Industry: Healthcare -- Diagnostic Radiology",
    "\nApproach: Multimodal Instruction Tuning (LLaVA-style)",
    "  - Two-stage training: alignment + instruction tuning",
    "  - MLP projection bridging vision and language",
    "  - Multi-label detection as auxiliary objective",
    "\nResults:",
    f"  - Macro F1: {macro_f1:.3f}",
    f"  - Micro F1: {micro_f1:.3f}",
    f"  - Inference latency (P50): {p50_latency_ms:.1f} ms",
    "\nKey Takeaway: Multimodal instruction tuning provides a flexible,",
    "end-to-end approach to radiology report generation that can handle",
    "both structured detection and free-text generation in a single model.",
]
print("\n".join(summary_lines))