In [None]:
# 🔧 Setup: Run this cell first!
# Reports the available accelerator and library versions, then seeds every RNG.
# Nothing is installed here: the pip install lives in the next code cell.

import random
import sys

import numpy as np
import torch

# Select the compute device once and tell the user what was found.
has_gpu = torch.cuda.is_available()
device = torch.device('cuda' if has_gpu else 'cpu')
if has_gpu:
    print(f"✅ GPU available: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    print("⚠️ No GPU detected. Some cells may run slowly.")
    print("   Go to Runtime → Change runtime type → GPU")

print(f"\n📦 Python {sys.version.split()[0]}")
print(f"🔥 PyTorch {torch.__version__}")

# Seed Python, NumPy and PyTorch (CPU, plus all CUDA devices when present)
# so that repeated runs draw the same random numbers.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if has_gpu:
    torch.cuda.manual_seed_all(SEED)

print(f"🎲 Random seed set to {SEED}")

%matplotlib inline

# Zero-Shot Radiological Screening with Contrastive Pretraining -- Implementation Notebook

*MedVista Diagnostics Case Study*

## Setup and Data Preparation

In [None]:
# Install dependencies
!pip install -q torch torchvision transformers datasets scikit-learn matplotlib

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms, models
from torch.utils.data import DataLoader, Dataset
%matplotlib inline

# Choose the compute device (CUDA when present, otherwise CPU), report it,
# and fix PyTorch's global seed so the models below initialise reproducibly.
cuda_ready = torch.cuda.is_available()
device = torch.device('cuda') if cuda_ready else torch.device('cpu')
print(f"Using device: {device}")
torch.manual_seed(42)

## 1. Data Loading

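For this implementation, we simulate a medical dataset using CIFAR-10 as a proxy: each of the ten CIFAR-10 classes is arbitrarily relabeled with a pathology name, so the accuracies reported below demonstrate the pipeline only and say nothing about clinical performance. In production, this would be replaced with actual chest X-ray data and radiology reports.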

In [None]:
from torchvision import datasets

# Simulated medical dataset using CIFAR-10 as proxy
PATHOLOGY_NAMES = [
    'cardiomegaly', 'pleural_effusion', 'pneumonia', 'atelectasis',
    'edema', 'consolidation', 'pneumothorax', 'mass',
    'nodule', 'hernia'
]

# Map CIFAR-10 class indices (0-9) to simulated pathology names.
# Built from the list itself so the mapping can never drift from its length.
CIFAR_TO_PATHOLOGY = dict(enumerate(PATHOLOGY_NAMES))

# CIFAR-10 images are RGB, so give Normalize one mean/std per channel instead of
# relying on a single-element tuple being broadcast across the three channels.
# The resulting tensors are unchanged: every channel is mapped from [0, 1] to [-1, 1].
CHANNEL_MEAN = (0.5, 0.5, 0.5)
CHANNEL_STD = (0.5, 0.5, 0.5)

transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(CHANNEL_MEAN, CHANNEL_STD),
])

# Downloads CIFAR-10 into ./data on first use.
train_data = datasets.CIFAR10('./data', train=True, download=True, transform=transform)
test_data = datasets.CIFAR10('./data', train=False, download=True, transform=transform)

# Only the training loader is shuffled; the test order stays fixed for evaluation.
train_loader = DataLoader(train_data, batch_size=256, shuffle=True, num_workers=2)
test_loader = DataLoader(test_data, batch_size=256, shuffle=False, num_workers=2)

print(f"Training samples: {len(train_data)}")
print(f"Test samples: {len(test_data)}")
print(f"Pathologies: {PATHOLOGY_NAMES}")

## 2. Exploratory Data Analysis

In [None]:
# ============ TODO ============
# Analyze the distribution of pathology labels in the training set
# Step 1: Count the frequency of each pathology
# Step 2: Plot a horizontal bar chart
# Step 3: Identify any class imbalance
# ==============================

from collections import Counter

# Step 1: count labels straight from the dataset's label list. Iterating over
# `train_data` itself would decode and transform every image just to read its label.
label_counts = dict(Counter(CIFAR_TO_PATHOLOGY[label] for label in train_data.targets))

# Step 2: horizontal bar chart of the counts.
fig, ax = plt.subplots(figsize=(10, 6))
pathologies = list(label_counts.keys())
counts = list(label_counts.values())
ax.barh(pathologies, counts, color='steelblue')
ax.set_xlabel('Count')
ax.set_title('Pathology Distribution in Training Set')
plt.tight_layout()
plt.show()

# Step 3: summarise imbalance as the ratio between the largest and smallest class.
if counts:
    print(f"Largest class: {max(counts)} samples | Smallest class: {min(counts)} samples")
    print(f"Imbalance ratio (largest / smallest): {max(counts) / min(counts):.2f}")

In [None]:
# Show a 2x5 grid of training images (every 500th sample), each titled with its
# simulated pathology label.
fig, axes = plt.subplots(2, 5, figsize=(15, 6))
for slot, ax in enumerate(axes.flat):
    img, label = train_data[slot * 500]
    # Channels-last layout for imshow, and undo the (x - 0.5) / 0.5 normalisation.
    img_np = img.permute(1, 2, 0).numpy() * 0.5 + 0.5
    ax.imshow(img_np)
    ax.set_title(CIFAR_TO_PATHOLOGY[label], fontsize=10)
    ax.axis('off')
fig.suptitle('Sample Images with Pathology Labels', fontsize=14)
plt.tight_layout()
plt.show()

## 3. Baseline: Supervised Classifier

In [None]:
# ============ TODO ============
# Train a supervised baseline classifier
# Step 1: Use a simple CNN as the baseline
# Step 2: Train with cross-entropy loss
# Step 3: Evaluate per-pathology accuracy
# ==============================

class BaselineCNN(nn.Module):
    """Small three-stage CNN classifier used as the supervised reference model.

    Each stage is Conv(3x3, padding=1) -> ReLU -> pooling. The first two stages
    halve the spatial size with max pooling; the last one collapses it to 1x1
    with adaptive average pooling, so the classifier always sees 128 features.

    Args:
        num_classes: number of output logits.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        channel_plan = [3, 32, 64, 128]
        last_stage = len(channel_plan) - 2
        stages = []
        for stage_idx, (c_in, c_out) in enumerate(zip(channel_plan[:-1], channel_plan[1:])):
            pool = nn.AdaptiveAvgPool2d(1) if stage_idx == last_stage else nn.MaxPool2d(2)
            stages.extend([nn.Conv2d(c_in, c_out, 3, padding=1), nn.ReLU(), pool])
        self.features = nn.Sequential(*stages)
        self.classifier = nn.Linear(128, num_classes)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for an image batch."""
        pooled = self.features(x)
        return self.classifier(pooled.flatten(start_dim=1))

baseline = BaselineCNN().to(device)
optimizer = torch.optim.Adam(baseline.parameters(), lr=1e-3)

# Train baseline: five passes over the training loader with cross-entropy loss,
# reporting the mean batch loss after each epoch.
for epoch in range(5):
    baseline.train()
    total_loss = 0
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        logits = baseline(images)
        loss = F.cross_entropy(logits, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1}/5 | Loss: {total_loss/len(train_loader):.4f}")

# Evaluate baseline: top-1 accuracy over the whole test loader, no gradients needed.
baseline.eval()
correct, total = 0, 0
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        preds = baseline(images).argmax(dim=1)
        correct += preds.eq(labels).sum().item()
        total += labels.size(0)
baseline_acc = correct / total
print(f"\nBaseline supervised accuracy: {baseline_acc:.1%}")

## 4. MedCLIP Model Architecture

In [None]:
class ImageEncoder(nn.Module):
    """Convolutional backbone plus a linear projection to the embedding size.

    The backbone has the same layer layout as the baseline CNN's feature
    extractor; its 128 pooled features are projected to `embed_dim`.

    Args:
        embed_dim: size of the returned embedding vectors.
    """

    def __init__(self, embed_dim=128):
        super().__init__()
        conv_stack = [
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(output_size=1),
        ]
        self.features = nn.Sequential(*conv_stack)
        self.projection = nn.Linear(in_features=128, out_features=embed_dim)

    def forward(self, x):
        """Return un-normalised embeddings of shape (batch, embed_dim)."""
        feature_map = self.features(x)
        pooled = feature_map.flatten(start_dim=1)
        return self.projection(pooled)

class TextEncoder(nn.Module):
    """Tiny Transformer text encoder producing one vector per token sequence.

    Token embeddings plus learned positional embeddings go through a two-layer
    Transformer encoder; the outputs are mean-pooled over the sequence axis and
    linearly projected.

    Note: no padding mask is applied, so every position (padding included) is
    attended to and contributes to the mean.

    Args:
        vocab_size: number of rows in the token embedding table.
        embed_dim: model width, also the output embedding size.
        max_len: number of positions in the positional table; inputs longer
            than this cannot be encoded.
    """

    def __init__(self, vocab_size=200, embed_dim=128, max_len=16):
        super().__init__()
        self.token_embed = nn.Embedding(vocab_size, embed_dim)
        # Small-variance initialisation for the learned positional table.
        self.pos_embed = nn.Parameter(0.02 * torch.randn(1, max_len, embed_dim))
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=4,
            dim_feedforward=256,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=2)
        self.projection = nn.Linear(embed_dim, embed_dim)

    def forward(self, tokens):
        """Encode token ids of shape (batch, seq_len) into (batch, embed_dim)."""
        seq_len = tokens.size(1)
        hidden = self.token_embed(tokens) + self.pos_embed[:, :seq_len]
        hidden = self.transformer(hidden)
        sentence_vec = hidden.mean(dim=1)
        return self.projection(sentence_vec)

In [None]:
# ============ TODO ============
# Build the full MedCLIP model
# Step 1: Combine image and text encoders
# Step 2: Add learnable temperature parameter
# Step 3: Implement symmetric contrastive loss
# ==============================

class MedCLIP(nn.Module):
    """CLIP-style dual encoder with a learnable logit scale.

    Args:
        embed_dim: size of the shared image/text embedding space.
        temperature: initial softmax temperature; the model stores
            log(1 / temperature) as the trainable `logit_scale`.
    """

    def __init__(self, embed_dim=128, temperature=0.07):
        super().__init__()
        self.image_encoder = ImageEncoder(embed_dim)
        self.text_encoder = TextEncoder(embed_dim=embed_dim)
        # Kept in log space; forward() exponentiates it.
        self.logit_scale = nn.Parameter(torch.tensor(1.0 / temperature).log())

    def forward(self, images, tokens):
        """Return (logits, image embeddings, text embeddings).

        Both embedding sets are L2-normalised, so `logits[i, j]` is the scaled
        cosine similarity between image i and text j.
        """
        img_emb = F.normalize(self.image_encoder(images), dim=-1)
        txt_emb = F.normalize(self.text_encoder(tokens), dim=-1)
        # Cap the scale at 100 so the logits cannot grow without bound.
        scale = torch.clamp(torch.exp(self.logit_scale), max=100)
        similarity = torch.matmul(img_emb, txt_emb.T)
        return scale * similarity, img_emb, txt_emb

    def contrastive_loss(self, logits):
        """Symmetric cross-entropy that treats the diagonal as the correct pairs."""
        targets = torch.arange(logits.size(0), device=logits.device)
        image_to_text = F.cross_entropy(logits, targets)
        text_to_image = F.cross_entropy(logits.T, targets)
        return (image_to_text + text_to_image) / 2

# Build the model on the selected device and report its total parameter count.
model = MedCLIP().to(device)
n_params = sum(param.numel() for param in model.parameters())
print(f"MedCLIP parameters: {n_params:,}")

## 5. Training the MedCLIP Model

In [None]:
# Simple tokenizer for medical prompts
class MedTokenizer:
    """Whitespace tokenizer whose vocabulary comes from the prompt templates.

    The vocabulary holds every lower-cased word that appears when each pathology
    name is inserted into each template; index 0 is reserved for "<PAD>". Words
    outside the vocabulary are also encoded as 0.

    Args:
        max_len: fixed length every encoded sequence is truncated/padded to.
    """

    def __init__(self, max_len=12):
        self.max_len = max_len
        self.word_to_idx = {"<PAD>": 0}
        templates = ["a radiograph showing {}", "chest x-ray with {}", "findings of {}"]
        for pathology in PATHOLOGY_NAMES:
            for template in templates:
                caption = template.format(pathology).lower()
                for word in caption.split():
                    # New words get the next free index; known words are left alone.
                    self.word_to_idx.setdefault(word, len(self.word_to_idx))
        self.vocab_size = len(self.word_to_idx)

    def encode(self, text):
        """Return a 1-D tensor of exactly `max_len` token ids for `text`."""
        words = text.lower().split()
        ids = [self.word_to_idx.get(word, 0) for word in words][:self.max_len]
        ids.extend([0] * (self.max_len - len(ids)))
        return torch.tensor(ids)

    def batch_encode(self, texts):
        """Encode several texts and stack them into a (len(texts), max_len) tensor."""
        encoded = [self.encode(text) for text in texts]
        return torch.stack(encoded)

tokenizer = MedTokenizer()

# The positional table must cover the tokenizer's fixed sequence length.
assert tokenizer.max_len <= model.text_encoder.pos_embed.size(1), \
    "tokenizer.max_len exceeds the text encoder's positional table"

# Replace the token embedding with a table sized to the tokenizer's vocabulary.
# The width is read from the existing layer rather than hard-coded, so it stays
# correct if the model is built with a different embed_dim. This must run before
# the MedCLIP optimizer is created so the new weights are among its parameters.
embed_width = model.text_encoder.token_embed.embedding_dim
model.text_encoder.token_embed = nn.Embedding(tokenizer.vocab_size, embed_width).to(device)

In [None]:
# ============ TODO ============
# Train the MedCLIP model
# Step 1: Set up optimizer and scheduler
# Step 2: Generate captions for each batch
# Step 3: Compute contrastive loss
# Step 4: Track training metrics
# ==============================

# Prompt templates; "{}" is replaced by the pathology name.
TEMPLATES = [
    "a radiograph showing {}",
    "chest x-ray with {}",
    "findings of {}",
]

def get_medical_caption(label_idx):
    """Return a caption for class `label_idx` using a randomly chosen template.

    The template is drawn from NumPy's global RNG, so repeated calls with the
    same label can return different captions.
    """
    pathology_name = CIFAR_TO_PATHOLOGY[label_idx]
    chosen = np.random.randint(len(TEMPLATES))
    return TEMPLATES[chosen].format(pathology_name)

def class_aware_contrastive_loss(logits, class_ids):
    """Symmetric contrastive loss in which every same-class pair is a positive.

    Captions here are generated from the class label, so a batch contains many
    images whose captions describe the same pathology. Using only the diagonal
    as the target would penalise an image for matching another sample's caption
    of the same class. Instead, each row's target is spread uniformly over all
    columns that share its class. When every class in the batch is distinct this
    equals the usual diagonal-target loss.

    Args:
        logits: (batch, batch) image-to-text similarity logits.
        class_ids: (batch,) integer class id per sample, on the same device.

    Returns:
        Scalar tensor: mean of the image-to-text and text-to-image losses.
    """
    positives = (class_ids[:, None] == class_ids[None, :]).float()
    # The mask is symmetric, so the same row-normalised targets serve both directions.
    targets = positives / positives.sum(dim=1, keepdim=True)
    image_to_text = -(targets * F.log_softmax(logits, dim=1)).sum(dim=1).mean()
    text_to_image = -(targets * F.log_softmax(logits.T, dim=1)).sum(dim=1).mean()
    return (image_to_text + text_to_image) / 2


NUM_EPOCHS = 10

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)
losses = []  # mean training loss per epoch, used for the plot below

for epoch in range(NUM_EPOCHS):
    model.train()
    epoch_loss = 0
    n = 0
    for images, labels in train_loader:
        images = images.to(device)
        # tolist() converts the whole batch at once instead of one .item() per label.
        captions = [get_medical_caption(label_idx) for label_idx in labels.tolist()]
        tokens = tokenizer.batch_encode(captions).to(device)

        logits, _, _ = model(images, tokens)
        loss = class_aware_contrastive_loss(logits, labels.to(device))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        n += 1

    avg_loss = epoch_loss / n
    losses.append(avg_loss)
    print(f"Epoch {epoch+1}/{NUM_EPOCHS} | Loss: {avg_loss:.4f}")

In [None]:
# Plot the mean contrastive loss recorded for each epoch.
loss_fig, loss_ax = plt.subplots(figsize=(8, 4))
loss_ax.plot(losses, 'b-o', linewidth=2)
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Contrastive Loss')
loss_ax.set_title('MedCLIP Training Progress')
loss_ax.grid(True, alpha=0.3)
loss_fig.tight_layout()
plt.show()

## 6. Zero-Shot Evaluation

In [None]:
# Zero-shot classification: label each test image with the pathology whose
# text embedding is most similar to the image embedding.
model.eval()

# Prompt ensembling: embed every template for a pathology, average the
# normalised prompt embeddings, then re-normalise the per-class means.
with torch.no_grad():
    prompt_means = []
    for name in PATHOLOGY_NAMES:
        tokens = tokenizer.batch_encode([tpl.format(name) for tpl in TEMPLATES]).to(device)
        prompt_embs = F.normalize(model.text_encoder(tokens), dim=-1)
        prompt_means.append(prompt_embs.mean(dim=0))
    class_embeddings = F.normalize(torch.stack(prompt_means), dim=-1)

# Evaluate on the test loader, keeping every prediction and label for later analysis.
all_preds, all_labels = [], []
correct_zs = 0
total_zs = 0

with torch.no_grad():
    for images, labels in test_loader:
        img_embs = F.normalize(model.image_encoder(images.to(device)), dim=-1)
        sims = img_embs @ class_embeddings.T
        # Predictions go back to the CPU because `labels` was never moved off it.
        preds = sims.argmax(dim=1).cpu()
        correct_zs += preds.eq(labels).sum().item()
        total_zs += labels.numel()
        all_preds.extend(preds.numpy())
        all_labels.extend(labels.numpy())

zs_acc = correct_zs / total_zs
print(f"Zero-shot accuracy: {zs_acc:.1%}")
print(f"Baseline supervised accuracy: {baseline_acc:.1%}")
print(f"Random chance: {1/10:.1%}")

## 7. Error Analysis

In [None]:
# ============ TODO ============
# Analyze which pathologies are confused with each other
# Step 1: Build confusion matrix
# Step 2: Identify top confusion pairs
# Step 3: Visualize misclassified examples
# ==============================

from sklearn.metrics import confusion_matrix
import itertools

# Step 1: confusion matrix. Passing `labels` pins the matrix to one row/column per
# pathology even if some class never appears in the labels or the predictions;
# without it the matrix could be smaller and the annotation loop would index
# out of range.
class_ids = list(range(len(PATHOLOGY_NAMES)))
n_classes = len(class_ids)
cm = confusion_matrix(all_labels, all_preds, labels=class_ids)

plt.figure(figsize=(10, 8))
plt.imshow(cm, cmap='Blues')
plt.colorbar()
# White text on dark cells, black text on light ones.
text_threshold = cm.max() / 2
for i, j in itertools.product(range(n_classes), repeat=2):
    plt.text(j, i, cm[i, j], ha='center', va='center',
             color='white' if cm[i, j] > text_threshold else 'black', fontsize=8)
plt.xticks(range(n_classes), PATHOLOGY_NAMES, rotation=45, ha='right')
plt.yticks(range(n_classes), PATHOLOGY_NAMES)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Zero-Shot Confusion Matrix')
plt.tight_layout()
plt.show()

# Step 2: list the largest off-diagonal cells, i.e. the most frequent mistakes.
TOP_K_CONFUSIONS = 5
off_diagonal = cm.copy()
np.fill_diagonal(off_diagonal, 0)
ranked_cells = np.argsort(off_diagonal, axis=None)[::-1][:TOP_K_CONFUSIONS]
print("Top confusion pairs (true -> predicted):")
for flat_idx in ranked_cells:
    true_idx, pred_idx = np.unravel_index(flat_idx, off_diagonal.shape)
    n_errors = off_diagonal[true_idx, pred_idx]
    if n_errors == 0:
        break  # remaining cells contain no errors
    print(f"  {PATHOLOGY_NAMES[true_idx]} -> {PATHOLOGY_NAMES[pred_idx]}: {n_errors}")

## 8. Deployment Optimization

In [None]:
# ============ TODO ============
# Benchmark inference latency
# Step 1: Time image encoding for single image
# Step 2: Time image encoding for batch of 16
# Step 3: Compare with/without GPU
# ==============================

import time


def measure_latency_ms(encoder, inputs, n_warmup=10, n_runs=100):
    """Time repeated forward passes of `encoder` on `inputs`.

    Args:
        encoder: callable module to benchmark.
        inputs: tensor passed to `encoder` on every call.
        n_warmup: untimed calls made first, so one-off start-up costs are excluded.
        n_runs: number of timed calls.

    Returns:
        List of `n_runs` wall-clock latencies in milliseconds.
    """
    on_gpu = inputs.is_cuda

    def wait_for_gpu():
        # CUDA kernels run asynchronously; block until they finish so the
        # timer measures the computation rather than just the kernel launch.
        if on_gpu:
            torch.cuda.synchronize()

    latencies = []
    with torch.no_grad():
        for _ in range(n_warmup):
            encoder(inputs)
        # Make sure no warm-up work is still pending when timing starts.
        wait_for_gpu()
        for _ in range(n_runs):
            start = time.perf_counter()  # monotonic, high-resolution clock
            encoder(inputs)
            wait_for_gpu()
            latencies.append((time.perf_counter() - start) * 1000)
    return latencies


model.eval()
test_image = torch.randn(1, 3, 32, 32).to(device)
batch_images = torch.randn(16, 3, 32, 32).to(device)

# Step 1: single image. Step 2: batch of 16, measured directly rather than
# extrapolated from the single-image figure.
times = measure_latency_ms(model.image_encoder, test_image)
batch_times = measure_latency_ms(model.image_encoder, batch_images)
batch_size = batch_images.size(0)

print(f"Single image latency: {np.mean(times):.1f} +/- {np.std(times):.1f} ms")
print(f"Batch of {batch_size} latency: {np.mean(batch_times):.1f} +/- {np.std(batch_times):.1f} ms "
      f"({np.mean(batch_times) / batch_size:.2f} ms per image)")
print(f"Target: <200ms per image")

## 9. Ethics and Fairness

In [None]:
# ============ TODO ============
# Document model limitations and fairness considerations
# Step 1: List known biases in training data
# Step 2: Identify pathologies with low zero-shot accuracy
# Step 3: Recommend safeguards for clinical deployment
# ==============================

print("=== MedCLIP Fairness Report ===\n")

# Per-pathology accuracy: the share of each class's test samples predicted correctly.
# Classes with no test samples are skipped. A class scoring 50% or lower is
# marked REVIEW, anything above is marked PASS.
labels_arr = np.array(all_labels)
preds_arr = np.array(all_preds)
per_class_acc = {}
for i, pathology in enumerate(PATHOLOGY_NAMES):
    mask = labels_arr == i
    if not mask.any():
        continue
    acc = (preds_arr[mask] == i).mean()
    per_class_acc[pathology] = acc
    status = "REVIEW" if acc <= 0.5 else "PASS"
    print(f"  {pathology}: {acc:.1%} [{status}]")

# Deployment safeguards, printed one per line.
recommendations = [
    "\n=== Recommendations ===",
    "1. Model should be used for screening triage, not diagnosis",
    "2. All flagged findings must be reviewed by a radiologist",
    "3. Performance should be monitored across demographic groups",
    "4. Regular re-evaluation against held-out validation set required",
]
for line in recommendations:
    print(line)