# 🩺 Explainable Vision-Language Model for Radiology (LLaVA-style, Colab-ready)

**End-to-end, runnable Google Colab notebook**  
GPU: **NVIDIA T4**  
Frameworks: **PyTorch + Hugging Face Transformers**  
Dataset: **ROCOv2-radiology (Hugging Face)**

---

## 1️⃣ Environment Setup


In [None]:
!pip install -q torch torchvision torchaudio
!pip install -q transformers accelerate datasets sentencepiece
!pip install -q scikit-learn pillow matplotlib opencv-python
!pip install -q pytorch-grad-cam

import torch
import torchvision
import numpy as np
import random
import os

# Report the runtime: library version, CUDA availability and the active device name.
_cuda_ok = torch.cuda.is_available()
if _cuda_ok:
    _device_label = torch.cuda.get_device_name(0)
else:
    _device_label = "CPU"

print("Torch version:", torch.__version__)
print("CUDA available:", _cuda_ok)
print("GPU:", _device_label)

# Seed every random number generator used by the notebook with the same value.
seed = 42
for _seed_fn in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed_all,
):
    _seed_fn(seed)

## 2️⃣ Dataset Loading

In [None]:
from datasets import load_dataset

# Load ROCOv2-radiology from the Hugging Face Hub. The notebook header names
# this dataset; "roco" is not its hub id.
# NOTE(review): later cells read the "image" and "caption" columns and the
# "train" / "validation" / "test" splits — confirm against the printed schema.
dataset = load_dataset("eltorio/ROCOv2-radiology")

print(dataset)
print(dataset["train"][0].keys())

## 3️⃣ Model Initialization

In [None]:
from transformers import AutoProcessor, AutoModelForVision2Seq

# Checkpoint shared by the processor (image + text preprocessing) and the model.
model_id = "llava-hf/llava-1.5-7b-hf"

processor = AutoProcessor.from_pretrained(model_id)

# Half-precision weights; `device_map="auto"` delegates device placement to the loader.
# NOTE(review): a 7B checkpoint in float16 may leave little headroom on a T4 for
# full fine-tuning — confirm memory before running the training cell.
model = AutoModelForVision2Seq.from_pretrained(
    model_id, torch_dtype=torch.float16, device_map="auto"
)

# Enable gradient checkpointing on the loaded model before training.
model.gradient_checkpointing_enable()

print("Vision tower:", model.vision_tower.__class__)

## 4️⃣ Data Preprocessing & Prompting

In [None]:
from torchvision import transforms

# Torchvision pipeline consisting of a single ToTensor step.
# NOTE(review): not referenced by any other cell shown here — preprocess()
# sends images through `processor` instead.
image_transform = transforms.Compose([
    transforms.ToTensor()
])

# Keyword vocabulary for weak concept labels; list order fixes the label index.
MEDICAL_CONCEPTS = [
    "pneumonia", "edema", "fracture", "tumor", "hemorrhage",
    "infection", "lesion", "effusion", "nodule", "cardiomegaly",
]

# One label per concept keyword.
NUM_LABELS = len(MEDICAL_CONCEPTS)

def extract_concepts(caption):
    """Build a weak multi-hot concept vector from a caption.

    The caption is lower-cased and each entry of ``MEDICAL_CONCEPTS`` is
    checked as a plain substring, so matching is case-insensitive and does not
    respect word boundaries or negation.

    Args:
        caption: Caption text.

    Returns:
        A float32 tensor with one element per concept, 1.0 where the keyword
        occurs in the caption and 0.0 otherwise.
    """
    lowered = caption.lower()
    flags = []
    for concept in MEDICAL_CONCEPTS:
        flags.append(int(concept in lowered))
    return torch.tensor(flags, dtype=torch.float32)

def build_prompt():
    """Return the instruction prompt in the LLaVA-1.5 chat format.

    The prompt carries exactly one ``<image>`` placeholder, which the LLaVA
    processor/model pair uses to splice in the image features, and ends with
    the ``ASSISTANT:`` turn marker so generated text is the model's answer.

    Returns:
        The prompt string.
    """
    return (
        "USER: <image>\n"
        "You are an expert radiologist. "
        "Given the medical image, generate a concise and accurate radiology report. "
        "ASSISTANT:"
    )

def preprocess(example):
    """Convert one dataset record into model inputs plus supervision targets.

    The text passed to the processor is the prompt followed by the caption
    (and the tokenizer's EOS token when available), so that a language-model
    loss computed over ``input_ids`` actually covers the report text rather
    than only the fixed prompt.

    Args:
        example: Mapping with an ``"image"`` entry and a ``"caption"`` string.

    Returns:
        A dict holding the processor outputs with the leading batch dimension
        removed, plus ``"labels_text"`` (the raw caption) and
        ``"labels_concepts"`` (the weak multi-hot concept vector).
    """
    image = example["image"]
    caption = example["caption"]

    prompt = build_prompt()
    # The image placeholder must be present in the text; add the chat template
    # here if the prompt builder did not already supply it.
    if "<image>" not in prompt:
        prompt = f"USER: <image>\n{prompt} ASSISTANT:"
    labels = extract_concepts(caption)

    # Terminate the target with EOS (when the tokenizer defines one) so the
    # model can learn where a report ends.
    eos_token = getattr(getattr(processor, "tokenizer", None), "eos_token", None) or ""

    inputs = processor(
        images=image,
        text=f"{prompt} {caption.strip()}{eos_token}",
        return_tensors="pt"
    )

    # The processor returns batched tensors; drop the batch axis of size one.
    inputs = {k: v.squeeze(0) for k, v in inputs.items()}
    inputs["labels_text"] = caption
    inputs["labels_concepts"] = labels
    return inputs

def _prepare_split(split, limit):
    """Shuffle a split, keep at most ``limit`` records and preprocess them.

    The raw columns are dropped after mapping and the result is returned in
    torch format, so indexing yields tensors (which ``collate_fn`` stacks)
    instead of the nested Python lists that ``map`` stores by default.
    """
    # Clamp the subset size so a small split does not raise in `select`.
    subset = split.shuffle(seed=seed).select(range(min(limit, len(split))))
    processed = subset.map(preprocess, remove_columns=subset.column_names)
    return processed.with_format("torch")


train_ds = _prepare_split(dataset["train"], 2000)
val_ds   = _prepare_split(dataset["validation"], 500)
test_ds  = _prepare_split(dataset["test"], 500)


## 5️⃣ Training Loop (Report + Concept Loss)

In [None]:
from torch.utils.data import DataLoader
from torch import nn
from transformers import get_scheduler

# Device that input batches and the concept classifier are moved to: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class ConceptClassifier(nn.Module):
    """Single linear layer mapping pooled features to per-concept logits.

    Args:
        hidden_dim: Size of the input feature vector.
        num_labels: Number of concept logits to produce.
    """

    def __init__(self, hidden_dim, num_labels):
        super().__init__()
        self.fc = nn.Linear(hidden_dim, num_labels)

    def forward(self, x):
        """Return raw (pre-sigmoid) logits of shape ``(..., num_labels)``.

        The input is cast to the layer's parameter dtype first, so features
        produced by a half-precision encoder can be fed to this float32 head
        without a dtype-mismatch error.
        """
        return self.fc(x.to(self.fc.weight.dtype))

# The concept head's input width is the vision encoder's hidden size.
vision_hidden = model.vision_tower.config.hidden_size
classifier = ConceptClassifier(vision_hidden, NUM_LABELS).to(device)

# A single AdamW instance updates both the vision-language model and the concept head.
optimizer = torch.optim.AdamW(
    [*model.parameters(), *classifier.parameters()],
    lr=2e-5,
)

# Multi-label objective applied to raw logits.
bce_loss = nn.BCEWithLogitsLoss()

def collate_fn(batch):
    """Stack the numeric fields of a list of examples into batched tensors.

    Fields are selected by looking at the first example. Tensors are used
    as-is; NumPy arrays, lists, tuples and Python numbers are converted with
    ``torch.as_tensor`` (datasets return nested lists unless a torch format is
    set). Anything else, such as strings, is left out of the batch.

    Args:
        batch: List of per-example dicts sharing the same keys.

    Returns:
        Dict mapping each numeric key to a tensor with a new leading batch
        dimension. An empty batch yields an empty dict.

    Raises:
        RuntimeError: If the values of one key have different shapes.
    """

    def _as_tensor(value):
        # Returns None for values that cannot be represented as a tensor.
        if torch.is_tensor(value):
            return value
        if isinstance(value, (np.ndarray, list, tuple, int, float)):
            try:
                return torch.as_tensor(value)
            except (TypeError, ValueError):
                return None
        return None

    if not batch:
        return {}

    out = {}
    for key, first_value in batch[0].items():
        if _as_tensor(first_value) is None:
            continue
        out[key] = torch.stack([_as_tensor(item[key]) for item in batch])
    return out

# Batches of one example each; only the training loader reshuffles.
train_loader = DataLoader(
    train_ds,
    batch_size=1,
    shuffle=True,
    collate_fn=collate_fn,
)
val_loader = DataLoader(
    val_ds,
    batch_size=1,
    shuffle=False,
    collate_fn=collate_fn,
)

# One scheduler step is taken per batch, so the schedule spans batches x epochs.
num_epochs = 2
num_training_steps = len(train_loader) * num_epochs

# Linear schedule with 100 warmup steps.
lr_scheduler = get_scheduler(
    name="linear",
    optimizer=optimizer,
    num_warmup_steps=100,
    num_training_steps=num_training_steps,
)

# Joint fine-tuning: language-model loss on the token sequence plus a BCE loss
# on concept logits predicted from mean-pooled vision-tower features.
# NOTE(review): the model weights are float16 and are updated directly by
# AdamW without a gradient scaler — watch for NaN losses.
model.train()
classifier.train()

# Id of the image placeholder token, if the config exposes it. Those positions
# (and padding) are excluded from the LM loss via the ignore index -100.
image_token_id = getattr(model.config, "image_token_index", None)

for epoch in range(num_epochs):
    total_loss = 0.0
    for step, batch in enumerate(train_loader):
        optimizer.zero_grad()

        pixel_values = batch["pixel_values"].to(device)
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels_concepts = batch["labels_concepts"].to(device)

        # LM targets are the input tokens, minus padding and image placeholders.
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100
        if image_token_id is not None:
            labels[labels == image_token_id] = -100

        outputs = model(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )

        # Second pass through the vision tower only, to obtain pooled features
        # for the concept head. Cast to float32 because the head is float32
        # while the encoder runs in float16.
        vision_outputs = model.vision_tower(pixel_values)
        vision_feat = vision_outputs.last_hidden_state.mean(dim=1).float()
        logits = classifier(vision_feat)

        loss_cls = bce_loss(logits, labels_concepts.float())
        loss = outputs.loss + loss_cls

        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        total_loss += loss.item()

        if step % 200 == 0:
            print(f"Epoch {epoch} | Step {step} | Loss {loss.item():.4f}")

    print(f"Epoch {epoch} Average Loss: {total_loss/len(train_loader):.4f}")


## 6️⃣ AUROC Evaluation

In [None]:
from sklearn.metrics import roc_auc_score

# Evaluate the concept head on the validation loader with macro/micro AUROC.
model.eval()
classifier.eval()

all_labels = []
all_probs = []

with torch.no_grad():
    for batch in val_loader:
        pixel_values = batch["pixel_values"].to(device)
        labels_concepts = batch["labels_concepts"].cpu().numpy()

        vision_outputs = model.vision_tower(pixel_values)
        # Cast to float32: the concept head is float32, the encoder is float16.
        vision_feat = vision_outputs.last_hidden_state.mean(dim=1).float()
        logits = classifier(vision_feat)
        probs = torch.sigmoid(logits).cpu().numpy()

        all_labels.append(labels_concepts)
        all_probs.append(probs)

all_labels = np.vstack(all_labels)
all_probs = np.vstack(all_probs)

# AUROC is undefined for a concept whose validation labels are all 0 or all 1
# (roc_auc_score raises), so score only the columns containing both classes.
scorable = [
    col for col in range(all_labels.shape[1])
    if np.unique(all_labels[:, col]).size > 1
]
skipped = [
    MEDICAL_CONCEPTS[col] for col in range(all_labels.shape[1])
    if col not in scorable
]
if skipped:
    print("Concepts skipped (single class in validation labels):", skipped)

if scorable:
    y_true = all_labels[:, scorable]
    y_prob = all_probs[:, scorable]
    if len(scorable) == 1:
        # A single remaining column is an ordinary binary problem.
        y_true, y_prob = y_true.ravel(), y_prob.ravel()
    macro_auc = roc_auc_score(y_true, y_prob, average="macro")
    micro_auc = roc_auc_score(y_true, y_prob, average="micro")
else:
    macro_auc = micro_auc = float("nan")

print("Macro AUROC:", macro_auc)
print("Micro AUROC:", micro_auc)


## 7️⃣ Grad-CAM Implementation

In [None]:
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image

def reshape_transform(tensor):
    """Turn a ViT token sequence into a 2-D feature map for Grad-CAM.

    The first token (index 0) is dropped and the remaining ``N - 1`` tokens
    are laid out row-major on a square grid.

    Args:
        tensor: Tensor of shape ``(B, N, C)``, or a tuple/list whose first
            element is such a tensor (some layers return their hidden states
            wrapped in a tuple).

    Returns:
        Tensor of shape ``(B, C, H, W)`` with ``H == W`` and ``H * W == N - 1``.

    Raises:
        ValueError: If ``N - 1`` is not a perfect square.
    """
    import math

    if isinstance(tensor, (tuple, list)):
        tensor = tensor[0]

    B, N, C = tensor.shape
    # Integer square root avoids float rounding on the grid side length.
    side = math.isqrt(N - 1)
    if side * side != N - 1:
        raise ValueError(
            f"Cannot arrange {N - 1} patch tokens on a square grid"
        )
    return tensor[:, 1:, :].permute(0, 2, 1).reshape(B, C, side, side)

class _ConceptScorer(torch.nn.Module):
    """Expose the vision tower + concept head as a plain tensor-returning model.

    Grad-CAM needs a model whose forward pass returns a tensor of class
    scores; the vision tower alone returns a structured output object. This
    wrapper reproduces the feature path used for the concept head
    (mean-pooled ``last_hidden_state``) and returns the concept logits.
    """

    def __init__(self, vision_tower, concept_head):
        super().__init__()
        self.vision_tower = vision_tower
        self.concept_head = concept_head

    def forward(self, pixel_values):
        tokens = self.vision_tower(pixel_values).last_hidden_state
        # Cast to float32: the concept head is float32, the encoder is float16.
        return self.concept_head(tokens.mean(dim=1).float())


# Last transformer block of the vision encoder supplies activations and gradients.
target_layer = model.vision_tower.vision_model.encoder.layers[-1]

cam = GradCAM(
    model=_ConceptScorer(model.vision_tower, classifier),
    target_layers=[target_layer],
    reshape_transform=reshape_transform
)


## 8️⃣ Visualization Examples

In [None]:
import matplotlib.pyplot as plt

# Overlay the Grad-CAM heatmap for the first test example on its input image.
example = test_ds[0]

# `pixel_values` may come back as a tensor or as nested lists depending on the
# dataset format; normalise to a float32 tensor either way.
pixel_tensor = torch.as_tensor(example["pixel_values"], dtype=torch.float32)
image = pixel_tensor.unsqueeze(0).to(device)

grayscale_cam = cam(input_tensor=image)[0]

# Channels-last copy rescaled to [0, 1] for display. The floor on the range
# avoids dividing by zero for a constant image.
orig_img = pixel_tensor.permute(1, 2, 0).cpu().numpy()
img_min = orig_img.min()
img_range = max(float(orig_img.max() - img_min), 1e-8)
orig_img = ((orig_img - img_min) / img_range).astype(np.float32)

cam_image = show_cam_on_image(orig_img, grayscale_cam, use_rgb=True)

plt.figure(figsize=(6,6))
plt.imshow(cam_image)
plt.axis("off")
plt.title("Grad-CAM (Vision Encoder)")
plt.show()


## 9️⃣ Inference Demo

In [None]:
# Generate a report for one test image from the raw image and the prompt only.
model.eval()

# Same record as test_ds[1]: test_ds is built from this identically seeded
# shuffle. The raw image is used because the stored `pixel_values` are already
# preprocessed and must not be passed through the processor a second time.
example = dataset["test"].shuffle(seed=seed)[1]

prompt = build_prompt()
# The image placeholder must be present in the text; add the chat template
# here if the prompt builder did not already supply it.
if "<image>" not in prompt:
    prompt = f"USER: <image>\n{prompt} ASSISTANT:"

inputs = processor(
    images=example["image"],
    text=prompt,
    return_tensors="pt"
).to(device)

with torch.no_grad():
    generated_ids = model.generate(
        **inputs,
        max_new_tokens=128
    )

# Decode only the newly generated tokens, not the echoed prompt.
prompt_length = inputs["input_ids"].shape[1]
report = processor.decode(generated_ids[0][prompt_length:], skip_special_tokens=True)
print("Generated Radiology Report:\n")
print(report)


## 🔟 Notes on Limitations & Extensions

- Weak supervision for concepts via keyword matching (can be replaced with expert labels)
- ROCO captions are not full clinical reports
- Larger batch sizes require stronger GPUs
- Extend with:
  • CheXpert / MIMIC-CXR labels
  • LoRA fine-tuning
  • Attention rollout explainability
  • Clinical prompt engineering
