<a href="https://colab.research.google.com/github/Yiwen91/MED-VQA/blob/main/MED_VQA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -U \
  transformers==4.44.2 \
  datasets==2.20.0 \
  accelerate==0.32.1 \
  peft==0.8.2 \
  evaluate==0.4.1 \
  fsspec==2024.5.0




In [None]:
from transformers import BlipProcessor, BlipForConditionalGeneration
import torch

# Run on the GPU when one is visible, otherwise on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")

# Half precision is requested only on the GPU; the CPU path keeps float32.
model = BlipForConditionalGeneration.from_pretrained(
    "Salesforce/blip-image-captioning-base",
    torch_dtype={"cuda": torch.float16}.get(device, torch.float32),
).to(device)

print(f"BLIP loaded on {device}")


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


BLIP loaded on cpu


In [None]:
import torch
import random
import io
import string
import numpy as np
from tqdm import tqdm
from collections import defaultdict
from sklearn.metrics import accuracy_score

# Hugging Face
from datasets import load_dataset
from transformers import (
    BlipProcessor,
    BlipForConditionalGeneration
)

# PyTorch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# Hugging Face Transformers
from transformers import BertTokenizer, Blip2Processor

# TorchVision & PyTorch for CNN-LSTM preprocessing
from torchvision import transforms

# Evaluation
import evaluate

# Reproducibility: feed the same seed to every RNG the notebook relies on.
seed = 42
for _seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    _seed_fn(seed)
# CUDA generators are only seeded when a GPU is actually available.
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)


In [None]:
import hashlib
import io
from collections import defaultdict

In [None]:
# Fetch the VQA-RAD dataset, then show its split layout and one raw training record.
ds = load_dataset(path="flaviagiammarino/vqa-rad")
print(f"Dataset Structure: {ds}")
print(f"Sample Train Item: {ds['train'][0]}")

# Images are fingerprinted so that all questions about the same picture can be grouped.
def hash_image(img):
    """Return the MD5 hex digest of ``img`` serialised as PNG.

    ``img`` only needs a ``save(fileobj, format=...)`` method; whatever bytes it
    writes for ``format="PNG"`` are what gets hashed.
    """
    with io.BytesIO() as png_stream:
        img.save(png_stream, format="PNG")
        png_bytes = png_stream.getvalue()
    return hashlib.md5(png_bytes).hexdigest()

# Group every QA record from both splits under the hash of its image.
# Each record is copied into a plain dict and tagged with the split it came from,
# because later cells filter and validate on qa["split"].
image_question_map = defaultdict(list)
for split in ["train", "test"]:
    for item in ds[split]:
        record = dict(item)
        record["split"] = split
        image_question_map[hash_image(record["image"])].append(record)

print(f"Unique Images: {len(image_question_map)}")

# Body Part Detection (Anatomical Region Grouping, as in Preliminary Results)
# Maps each region label to the keywords that identify it in a question.
# detect_body_part() walks this dict in insertion order and returns the FIRST region
# with a keyword hit, so the order of the entries matters: "ventricle" is listed under
# both "brain" and "heart" and therefore always resolves to "brain".
# Keywords are tested with `in` on the cleaned question, i.e. as plain substrings.
# "other" has no keywords; it is the fallback label.
body_parts = {
    "brain": ["brain", "cerebrum", "cerebellum", "ventricle", "cortex"],
    "lung": ["lung", "lungs", "pulmonary", "pleura", "chest"],
    "heart": ["heart", "cardiac", "ventricle", "atrium", "pericardium"],
    "abdomen": ["abdomen", "liver", "kidney", "pancreas", "stomach", "spleen", "intestine", "gallbladder"],
    "pelvis": ["pelvis", "bladder", "prostate", "uterus", "ovary", "pelvic"],
    "spine": ["spine", "vertebra", "cervical", "thoracic", "lumbar", "sacrum"],
    "eye": ["eye", "ocular", "retina", "cornea", "optic"],
    "other": []
}

def detect_body_part(question):
    """Return the first ``body_parts`` region whose keyword occurs in ``question``.

    The question is lower-cased and stripped of ASCII punctuation; keywords are
    matched as plain substrings, regions are tried in dict order, and ``"other"``
    is returned when nothing matches.
    """
    strip_punctuation = str.maketrans("", "", string.punctuation)
    text = question.lower().translate(strip_punctuation)
    matching_regions = (
        region
        for region, terms in body_parts.items()
        if any(term in text for term in terms)
    )
    return next(matching_regions, "other")

# Tag every grouped QA record, in place, with the region detected from its question.
for qa_group in image_question_map.values():
    for record in qa_group:
        record["body_part"] = detect_body_part(record["question"])

Dataset Structure: DatasetDict({
    train: Dataset({
        features: ['image', 'question', 'answer'],
        num_rows: 1793
    })
    test: Dataset({
        features: ['image', 'question', 'answer'],
        num_rows: 451
    })
})
Sample Train Item: {'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=566x555 at 0x7BC50177BC50>, 'question': 'are regions of the brain infarcted?', 'answer': 'yes'}
Unique Images: 314


In [None]:
# ===============================
# Image Hashing & Multi-turn Construction
# ===============================

import io
import hashlib
import string
from collections import defaultdict

# Fingerprint an image so that every question about the same picture shares one key.
def hash_image(img):
    """MD5 hex digest of the bytes ``img.save(..., format="PNG")`` produces."""
    stream = io.BytesIO()
    img.save(stream, format="PNG")
    digest = hashlib.md5()
    digest.update(stream.getvalue())
    return digest.hexdigest()

# -------------------------------
# Group questions by image hash
# -------------------------------
image_question_map = defaultdict(list)

for split_name in ("train", "test"):
    for row in ds[split_name]:
        # Work on a copy that also remembers which split the row came from.
        record = {**row, "split": split_name}
        image_question_map[hash_image(row["image"])].append(record)

print(f"Unique Images: {len(image_question_map)}")

# -------------------------------
# Body Part Detection
# -------------------------------
# Region label -> keywords that identify it in a question.
# detect_body_part() iterates this dict in insertion order and returns the first
# region with a hit, so entry order matters: "ventricle" appears under both "brain"
# and "heart" and therefore always resolves to "brain".
# Keywords are matched as plain substrings; "other" (no keywords) is the fallback.
body_parts = {
    "brain": ["brain", "cerebrum", "cerebellum", "ventricle", "cortex"],
    "lung": ["lung", "lungs", "pulmonary", "pleura", "chest"],
    "heart": ["heart", "cardiac", "ventricle", "atrium", "pericardium"],
    "abdomen": ["abdomen", "liver", "kidney", "pancreas", "stomach", "spleen", "intestine", "gallbladder"],
    "pelvis": ["pelvis", "bladder", "prostate", "uterus", "ovary", "pelvic"],
    "spine": ["spine", "vertebra", "cervical", "thoracic", "lumbar", "sacrum"],
    "eye": ["eye", "ocular", "retina", "cornea", "optic"],
    "other": []
}

def detect_body_part(question):
    """Classify ``question`` into a ``body_parts`` region by substring keyword match.

    Punctuation is removed and the text lower-cased first. The first region (in
    dict order) with a matching keyword wins; ``"other"`` is the fallback.
    """
    normalized = question.lower().translate(str.maketrans("", "", string.punctuation))
    for part in body_parts:
        for keyword in body_parts[part]:
            if keyword in normalized:
                return part
    return "other"

# Attach the detected region to every QA record (records are updated in place).
for grouped_records in image_question_map.values():
    for entry in grouped_records:
        entry["body_part"] = detect_body_part(entry["question"])

# -------------------------------
# Split into single / multi Q
# -------------------------------
# An image with exactly one question contributes that record to single_q;
# every other group is kept whole in multi_q.
single_q = [group[0] for group in image_question_map.values() if len(group) == 1]
multi_q = [group for group in image_question_map.values() if len(group) != 1]

print(f"Single-Question Images: {len(single_q)}")
print(f"Multi-Question Images: {len(multi_q)}")

# -------------------------------
# Keep valid multi-turn samples
# (same body part + same split)
# -------------------------------
valid_multi = []

for group in multi_q:
    regions, origins = set(), set()
    for qa in group:
        regions.add(qa["body_part"])
        origins.add(qa["split"])
    # A group is usable only if all its questions agree on region AND split.
    if len(regions) == 1 == len(origins):
        valid_multi.append(group)

print(f"Valid Multi-Question Cases: {len(valid_multi)}")

# -------------------------------
# Create final samples
# -------------------------------
def create_single_turn_samples():
    """Project every record in ``single_q`` onto the single-turn sample schema.

    Returns a list of dicts with the keys image, question, answer, body_part, split.
    """
    fields = ("image", "question", "answer", "body_part", "split")
    return [{field: qa[field] for field in fields} for qa in single_q]

def create_multi_turn_samples():
    """Turn each group in ``valid_multi`` into one multi-turn sample.

    Turns are ordered by question length (shortest first, ties keep their original
    order); image, body_part and split are taken from the first turn after sorting.
    """
    samples = []
    for group in valid_multi:
        turns = list(group)
        turns.sort(key=lambda qa: len(qa["question"]))
        first = turns[0]
        samples.append({
            "image": first["image"],
            "questions": [turn["question"] for turn in turns],
            "answers": [turn["answer"] for turn in turns],
            "body_part": first["body_part"],
            "split": first["split"],
        })
    return samples

single_turn_samples = create_single_turn_samples()
multi_turn_samples = create_multi_turn_samples()

print(f"Final Single-Turn Samples: {len(single_turn_samples)}")
print(f"Final Multi-Turn Samples: {len(multi_turn_samples)}")

# -------------------------------
# Example multi-turn case
# -------------------------------
print("\nExample Multi-Turn Case:")
# Guarded: indexing [0] on an empty list would abort the cell with IndexError.
if multi_turn_samples:
    example = multi_turn_samples[0]
    for q, a in zip(example["questions"], example["answers"]):
        print(f"Q: {q}")
        print(f"A: {a}\n")
else:
    print("(no multi-turn samples available)")


Unique Images: 314
Single-Question Images: 0
Multi-Question Images: 314
Valid Multi-Question Cases: 37
Final Single-Turn Samples: 0
Final Multi-Turn Samples: 37

Example Multi-Turn Case:
Q: where is the mass?
A: left temporal horn

Q: how would you describe the mass?
A: isointense

Q: is there a fracture of the skull?
A: no

Q: what is the location of the mass?
A: left temporal horn

Q: what are the characteristics of the mass?
A: isointense

Q: are there other abnormalities besides the mass in the temporal horn?
A: yes

Q: besides the mass in the temporal horn, are there other enhancements in the image?
A: yes



In [None]:
# ===============================
# Split into single- and multi-question groups
# ===============================

single_q, multi_q = [], []

for group in image_question_map.values():
    if len(group) != 1:
        # Anything that is not exactly one question stays together as a group.
        multi_q.append(group)
        continue
    single_q.append(group[0])

print(f"Single-Question Images: {len(single_q)}")
print(f"Multi-Question Images: {len(multi_q)}")

# -------------------------------
# Filter valid multi-question cases
# (same body part + same split)
# -------------------------------
valid_multi = []

for qas in multi_q:
    # NOTE: these sets must not be called `body_parts` -- that name is the global
    # keyword dict read by detect_body_part(), and rebinding it to a set here would
    # break every later call to that function.
    body_parts_set = {qa["body_part"] for qa in qas}
    splits_set = {qa["split"] for qa in qas}

    if len(body_parts_set) == 1 and len(splits_set) == 1:
        valid_multi.append(qas)

print(f"Valid Multi-Question Cases: {len(valid_multi)}")

# -------------------------------
# Create structured samples
# -------------------------------
def create_single_turn_samples():
    """Build one single-turn sample dict per record in ``single_q``.

    Only image, question, answer, body_part and split are carried over.
    """
    samples = []
    for record in single_q:
        samples.append(dict(
            image=record["image"],
            question=record["question"],
            answer=record["answer"],
            body_part=record["body_part"],
            split=record["split"],
        ))
    return samples

def create_multi_turn_samples():
    """Convert every group in ``valid_multi`` into a multi-turn sample dict.

    Questions are ordered shortest-first (stable sort); image, body_part and split
    come from the first question in that order.
    """
    def as_sample(group):
        ordered = sorted(group, key=lambda qa: len(qa["question"]))
        head = ordered[0]
        return {
            "image": head["image"],
            "questions": [qa["question"] for qa in ordered],
            "answers": [qa["answer"] for qa in ordered],
            "body_part": head["body_part"],
            "split": head["split"],
        }

    return [as_sample(group) for group in valid_multi]

single_turn_samples = create_single_turn_samples()
multi_turn_samples = create_multi_turn_samples()

print(f"Final Single-Turn Samples: {len(single_turn_samples)}")
print(f"Final Multi-Turn Samples: {len(multi_turn_samples)}")

# -------------------------------
# Example multi-turn case
# -------------------------------
print("\nExample Multi-Turn Case:")
# Guarded: an empty sample list would otherwise raise IndexError on [0].
if multi_turn_samples:
    example = multi_turn_samples[0]
    for q, a in zip(example["questions"], example["answers"]):
        print(f"Q: {q}")
        print(f"A: {a}\n")
else:
    print("(no multi-turn samples available)")


Single-Question Images: 0
Multi-Question Images: 314
Valid Multi-Question Cases: 37
Final Single-Turn Samples: 0
Final Multi-Turn Samples: 37

Example Multi-Turn Case:
Q: where is the mass?
A: left temporal horn

Q: how would you describe the mass?
A: isointense

Q: is there a fracture of the skull?
A: no

Q: what is the location of the mass?
A: left temporal horn

Q: what are the characteristics of the mass?
A: isointense

Q: are there other abnormalities besides the mass in the temporal horn?
A: yes

Q: besides the mass in the temporal horn, are there other enhancements in the image?
A: yes



In [None]:
class SingleTurnDataset(Dataset):
    """One (image, question, answer) sample per item, for either model family.

    The mode is chosen by the constructor arguments:

    * ``tokenizer`` given -> CNN-LSTM mode. ``processor`` must then be an image
      transform (a callable taking the image and returning a tensor) and
      ``tokenizer`` encodes the question. Items have the keys
      ``image``, ``q_ids``, ``q_mask``, ``answer``, ``body_part``.
    * only ``processor`` given -> BLIP mode. ``processor`` must be callable with
      ``images=...`` and expose ``.tokenizer``. Items have the keys
      ``pixel_values``, ``input_ids``, ``attention_mask``, ``labels``, each text
      field being ``max_seq_len`` long.

    In BLIP mode the answer is appended to the question inside ``input_ids`` and
    ``labels`` has the same length, with -100 on every position that is not an
    answer token. Labels of a different length than ``input_ids`` are what trigger
    "Expected input batch_size (...) to match target batch_size (...)".
    """

    # Answer tokens kept in BLIP mode: the former label max_length of 16 minus the
    # two special tokens.
    _MAX_ANSWER_TOKENS = 14

    def __init__(self, samples, processor=None, tokenizer=None, max_seq_len=32):
        self.samples = samples
        self.processor = processor  # BLIP processor, or image transform in CNN-LSTM mode
        self.tokenizer = tokenizer  # question tokenizer, CNN-LSTM mode only
        self.max_seq_len = max_seq_len

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        # The tokenizer has to be checked first: CNN-LSTM datasets pass BOTH a
        # tokenizer and an image transform as `processor`, so testing `processor`
        # first would send them down the BLIP path.
        if self.tokenizer is not None:
            return self._cnn_item(sample)
        if self.processor is not None:
            return self._blip_item(sample)
        raise ValueError(
            "SingleTurnDataset needs either a BLIP processor, or a tokenizer "
            "together with an image transform passed as `processor`."
        )

    def _cnn_item(self, sample):
        """Encode one sample for the CNN-LSTM model."""
        if self.processor is None:
            raise ValueError("CNN-LSTM mode needs an image transform passed as `processor`.")
        q_enc = self.tokenizer(
            sample["question"],
            padding="max_length",
            truncation=True,
            max_length=self.max_seq_len,
            return_tensors="pt"
        )
        return {
            "image": self.processor(sample["image"]),  # image transform
            "q_ids": q_enc["input_ids"].squeeze(0),
            "q_mask": q_enc["attention_mask"].squeeze(0),
            "answer": sample["answer"],
            "body_part": sample["body_part"]
        }

    def _blip_item(self, sample):
        """Encode one sample for BLIP with labels aligned to ``input_ids``."""
        tok = self.processor.tokenizer
        answer_cap = max(1, min(self._MAX_ANSWER_TOKENS, self.max_seq_len - 2))
        answer_ids = list(tok(sample["answer"], add_special_tokens=False)["input_ids"])[:answer_cap]
        # The question gets whatever room is left after the answer and two special tokens.
        question_budget = max(0, self.max_seq_len - 2 - len(answer_ids))
        question_ids = list(tok(sample["question"], add_special_tokens=False)["input_ids"])[:question_budget]

        input_ids = [tok.cls_token_id] + question_ids + answer_ids + [tok.sep_token_id]
        # Only the answer tokens and the closing separator contribute to the loss.
        labels = [-100] * (1 + len(question_ids)) + answer_ids + [tok.sep_token_id]
        attention_mask = [1] * len(input_ids)

        pad = max(0, self.max_seq_len - len(input_ids))
        input_ids = input_ids + [tok.pad_token_id] * pad
        attention_mask = attention_mask + [0] * pad
        labels = labels + [-100] * pad  # padding never counts towards the loss

        pixel_values = self.processor(images=sample["image"], return_tensors="pt")["pixel_values"]
        return {
            "pixel_values": pixel_values.squeeze(0),
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long),
        }


class MultiTurnDataset(Dataset):
    """Multi-turn VQA samples for BLIP, padded to a fixed text length.

    Each item uses at most ``max_turns`` question/answer pairs of a sample. The
    earlier pairs form the conversation history, the last question is the one to
    answer, and the last answer is the training target.

    Returned keys: ``pixel_values``, ``input_ids``, ``attention_mask``, ``labels``.
    The text fields all have length ``min(max_seq_len * max_turns, 512)``.
    ``input_ids`` holds history + question + answer, and ``labels`` mirrors it with
    -100 on every position that is not an answer token. Labels of a different
    length than ``input_ids`` are what trigger "Expected input batch_size (...) to
    match target batch_size (...)".
    """

    def __init__(self, samples, processor, max_seq_len=32, max_turns=3):
        self.samples = samples
        self.processor = processor
        self.max_seq_len = max_seq_len
        self.max_turns = max_turns

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        questions = sample["questions"][:self.max_turns]
        answers = sample["answers"][:self.max_turns]

        # Conversation history followed by the open question.
        history = "".join(f"Q: {q} A: {a} " for q, a in zip(questions[:-1], answers[:-1]))
        conversation = f"{history}Q: {questions[-1]} A:"

        tok = self.processor.tokenizer
        total_len = min(self.max_seq_len * self.max_turns, 512)

        # 16 was the former label max_length; two slots go to the special tokens.
        answer_cap = max(1, min(16, total_len) - 2)
        answer_ids = list(tok(answers[-1], add_special_tokens=False)["input_ids"])[:answer_cap]

        # When the prompt is too long, drop its OLDEST tokens so the question being
        # asked (at the end of the prompt) is never cut off.
        prompt_budget = total_len - 2 - len(answer_ids)
        prompt_ids = list(tok(conversation, add_special_tokens=False)["input_ids"])
        prompt_ids = prompt_ids[-prompt_budget:] if prompt_budget > 0 else []

        input_ids = [tok.cls_token_id] + prompt_ids + answer_ids + [tok.sep_token_id]
        # Only the answer tokens and the closing separator contribute to the loss.
        labels = [-100] * (1 + len(prompt_ids)) + answer_ids + [tok.sep_token_id]
        attention_mask = [1] * len(input_ids)

        pad = max(0, total_len - len(input_ids))
        input_ids = input_ids + [tok.pad_token_id] * pad
        attention_mask = attention_mask + [0] * pad
        labels = labels + [-100] * pad  # padding never counts towards the loss

        pixel_values = self.processor(images=sample["image"], return_tensors="pt")["pixel_values"]
        return {
            "pixel_values": pixel_values.squeeze(0),
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long),
        }


In [None]:
# BLIP-1 preprocessor
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")

# CNN-LSTM preprocessors: BERT tokenizer for the question text, and an image
# pipeline that resizes to 224x224, converts to a tensor and normalises per channel.
bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
cnn_image_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
        ),
    ]
)

# Partition both sample lists by the split recorded on each sample.
train_single = list(filter(lambda s: s["split"] == "train", single_turn_samples))
test_single = list(filter(lambda s: s["split"] == "test", single_turn_samples))
train_multi = list(filter(lambda s: s["split"] == "train", multi_turn_samples))
test_multi = list(filter(lambda s: s["split"] == "test", multi_turn_samples))

print(f"Single-Turn: Train={len(train_single)}, Test={len(test_single)}")
print(f"Multi-Turn: Train={len(train_multi)}, Test={len(test_multi)}")

# Closed-ended subset used by the CNN-LSTM classifier
def filter_closed_ended(samples):
    """Return the samples whose answer is one of the fixed closed-ended labels.

    The answer is compared after stripping surrounding whitespace and lower-casing.
    Order of the input is preserved.
    """
    closed_answers = {"yes", "no", "present", "absent", "normal", "abnormal"}
    kept = []
    for sample in samples:
        if sample["answer"].strip().lower() in closed_answers:
            kept.append(sample)
    return kept

# Create datasets: each family is built once for train and once for test.

# CNN-LSTM: closed-ended questions only, image transform + BERT tokenizer.
train_cnn, test_cnn = (
    SingleTurnDataset(
        filter_closed_ended(split_samples),
        processor=cnn_image_transform,
        tokenizer=bert_tokenizer,
        max_seq_len=32,
    )
    for split_samples in (train_single, test_single)
)

# BLIP, single-turn.
train_blip_single, test_blip_single = (
    SingleTurnDataset(split_samples, processor=blip_processor, max_seq_len=32)
    for split_samples in (train_single, test_single)
)

# BLIP, multi-turn (up to three turns per sample).
train_blip_multi, test_blip_multi = (
    MultiTurnDataset(split_samples, processor=blip_processor, max_seq_len=32, max_turns=3)
    for split_samples in (train_multi, test_multi)
)


Single-Turn: Train=0, Test=0
Multi-Turn: Train=36, Test=1


In [None]:
# CNN-LSTM preprocessors
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from transformers import BertTokenizer

# BERT tokenizer for questions
bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Image transform for CNN: resize to 224x224, convert to tensor, normalise per channel
_cnn_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
cnn_image_transform = transforms.Compose(_cnn_steps)

# Partition the single-turn samples by the split recorded on each sample.
train_single, test_single = [], []
for s in single_turn_samples:
    if s["split"] == "train":
        train_single.append(s)
    elif s["split"] == "test":
        test_single.append(s)

print(f"Single-Turn: Train={len(train_single)}, Test={len(test_single)}")

# Closed-ended subset used by the CNN-LSTM classifier
def filter_closed_ended(samples):
    """Keep, in order, the samples whose normalised answer is a closed-ended label.

    Normalisation is ``strip()`` followed by ``lower()``.
    """
    closed_answers = frozenset(("yes", "no", "present", "absent", "normal", "abnormal"))

    def is_closed(sample):
        return sample["answer"].strip().lower() in closed_answers

    return list(filter(is_closed, samples))

# Closed-ended subsets of both splits, in (train, test) order.
train_cnn_samples, test_cnn_samples = map(filter_closed_ended, (train_single, test_single))

print(f"Filtered Train CNN samples: {len(train_cnn_samples)}")
print(f"Filtered Test CNN samples: {len(test_cnn_samples)}")

if len(train_cnn_samples) == 0 or len(test_cnn_samples) == 0:
    print("Warning: No closed-ended answers found in your dataset. CNN-LSTM training will be skipped.")
else:
    # Create SingleTurnDataset
    train_cnn = SingleTurnDataset(
        train_cnn_samples,
        processor=cnn_image_transform,
        tokenizer=bert_tokenizer,
        max_seq_len=32
    )
    test_cnn = SingleTurnDataset(
        test_cnn_samples,
        processor=cnn_image_transform,
        tokenizer=bert_tokenizer,
        max_seq_len=32
    )

    # Prepare answer vocabulary.
    # sorted() keeps the class indices identical from run to run; iterating a set of
    # strings directly gives an order that can change between interpreter sessions.
    answer_vocab = sorted({s["answer"].strip().lower() for s in train_cnn_samples})
    num_classes = len(answer_vocab)
    ans_to_idx = {a: i for i, a in enumerate(answer_vocab)}
    idx_to_ans = {i: a for i, a in enumerate(answer_vocab)}

    print(f"Closed-Ended Answer Vocab: {answer_vocab}")
    print(f"Number of Classes: {num_classes}")

    def collate_cnn_batch(batch):
        """Stack CNN-LSTM samples into one batch and add the ``ans_idx`` targets.

        Answers that do not occur in the training vocabulary get index -100, the
        default ``ignore_index`` of ``torch.nn.CrossEntropyLoss``.
        """
        answers = [b["answer"] for b in batch]
        return {
            "image": torch.stack([b["image"] for b in batch]),
            "q_ids": torch.stack([b["q_ids"] for b in batch]),
            "q_mask": torch.stack([b["q_mask"] for b in batch]),
            "answer": answers,
            "body_part": [b["body_part"] for b in batch],
            "ans_idx": torch.tensor([ans_to_idx.get(a.strip().lower(), -100) for a in answers]),
        }

    # Custom DataLoader for CNN-LSTM
    class CNNDataLoader(DataLoader):
        """DataLoader that yields real batches carrying an ``ans_idx`` tensor.

        Batching and shuffling are left to the stock DataLoader machinery, so
        ``batch_size`` and ``shuffle`` are honoured; only the collate function is
        replaced by default.
        """

        def __init__(self, *args, **kwargs):
            # collate_fn is the 7th positional parameter of DataLoader; only supply
            # the default when the caller has not passed one either way.
            if len(args) < 7:
                kwargs.setdefault("collate_fn", collate_cnn_batch)
            super().__init__(*args, **kwargs)

    # Create DataLoaders
    train_cnn_loader = CNNDataLoader(train_cnn, batch_size=8, shuffle=True)
    test_cnn_loader = CNNDataLoader(test_cnn, batch_size=8, shuffle=False)



Single-Turn: Train=0, Test=0
Filtered Train CNN samples: 0
Filtered Test CNN samples: 0


In [None]:
import torch
from torch.utils.data import DataLoader
from transformers import BlipProcessor, BlipForConditionalGeneration
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# --- 1. Load BLIP-1 processor and model ---
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
# The weights stay in float32 on every device. This cell fine-tunes the model with
# AdamW and moves batches to the device without changing their dtype, so float16
# parameters would clash with the float32 pixel_values on GPU and make the
# optimizer updates numerically fragile.
model = BlipForConditionalGeneration.from_pretrained(
    "Salesforce/blip-image-captioning-base",
    torch_dtype=torch.float32
)
model.to(device)

# --- 2. Dataset classes ---
class MultiTurnDataset(torch.utils.data.Dataset):
    """Multi-turn VQA samples for BLIP fine-tuning (variable-length, unpadded).

    Each item uses at most ``max_turns`` question/answer pairs of a sample. The
    earlier pairs form the conversation history, the last question is the one to
    answer, and the last answer is the training target.

    Returned keys: ``pixel_values``, ``input_ids``, ``attention_mask``, ``labels``.
    ``input_ids`` holds history + question + answer and is at most
    ``max_seq_len * max_turns`` tokens long. ``labels`` has exactly the same length,
    with -100 on every position that is not an answer token. Labels of a different
    length than ``input_ids`` are what trigger "Expected input batch_size (...) to
    match target batch_size (...)". Padding is left to the collate function.
    """

    def __init__(self, samples, processor, max_seq_len=32, max_turns=3):
        self.samples = samples
        self.processor = processor
        self.max_seq_len = max_seq_len
        self.max_turns = max_turns

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        s = self.samples[idx]
        questions = s["questions"][:self.max_turns]
        answers = s["answers"][:self.max_turns]

        conversation = ""
        for q, a in zip(questions[:-1], answers[:-1]):
            conversation += f"Q: {q} A: {a} "
        conversation += f"Q: {questions[-1]} A:"

        tokenizer = self.processor.tokenizer
        max_total = self.max_seq_len * self.max_turns

        # The answer keeps at most max_seq_len tokens including the two special tokens.
        answer_cap = max(1, min(self.max_seq_len, max_total) - 2)
        answer_ids = list(tokenizer(answers[-1], add_special_tokens=False)["input_ids"])[:answer_cap]

        # When the prompt is too long, drop its OLDEST tokens so the question being
        # asked (at the end of the prompt) is never cut off.
        prompt_budget = max_total - 2 - len(answer_ids)
        prompt_ids = list(tokenizer(conversation, add_special_tokens=False)["input_ids"])
        prompt_ids = prompt_ids[-prompt_budget:] if prompt_budget > 0 else []

        input_ids = [tokenizer.cls_token_id] + prompt_ids + answer_ids + [tokenizer.sep_token_id]
        # Only the answer tokens and the closing separator contribute to the loss.
        labels = [-100] * (1 + len(prompt_ids)) + answer_ids + [tokenizer.sep_token_id]

        # The image goes to the processor unmodified; the previous manual
        # resize((128, 128)) result was never used.
        pixel_values = self.processor(images=s["image"], return_tensors="pt")["pixel_values"]

        return {
            "pixel_values": pixel_values.squeeze(0),
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "attention_mask": torch.ones(len(input_ids), dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long),
        }

# --- 3. Manual collate_fn for variable-length batching ---
def collate_fn(batch):
    """Merge variable-length samples into one batch.

    ``pixel_values`` are stacked as they are. ``input_ids`` and ``attention_mask``
    are right-padded with 0 and ``labels`` with -100, each up to the longest
    sequence of that field in the batch.
    """
    pad = torch.nn.utils.rnn.pad_sequence

    def column(key):
        return [item[key] for item in batch]

    return {
        "pixel_values": torch.stack(column("pixel_values")),
        "input_ids": pad(column("input_ids"), batch_first=True, padding_value=0),
        "attention_mask": pad(column("attention_mask"), batch_first=True, padding_value=0),
        "labels": pad(column("labels"), batch_first=True, padding_value=-100),
    }

# --- 4. Prepare datasets and loaders ---
# Samples are partitioned by their recorded split; only the training loader shuffles.
multi_turn_train_samples = list(filter(lambda s: s["split"] == "train", multi_turn_samples))
multi_turn_test_samples = list(filter(lambda s: s["split"] == "test", multi_turn_samples))

train_multi_ds = MultiTurnDataset(samples=multi_turn_train_samples, processor=processor)
test_multi_ds = MultiTurnDataset(samples=multi_turn_test_samples, processor=processor)

train_multi_loader = DataLoader(
    dataset=train_multi_ds,
    batch_size=4,
    shuffle=True,
    collate_fn=collate_fn,
)
test_multi_loader = DataLoader(
    dataset=test_multi_ds,
    batch_size=4,
    shuffle=False,
    collate_fn=collate_fn,
)

# --- 5. Training loop ---
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
epochs = 3

for epoch in range(epochs):
    model.train()
    train_loss = 0.0

    for batch in tqdm(train_multi_loader, desc=f"Epoch {epoch+1} Multi-Turn Train"):
        # Move every tensor of the batch to the model's device.
        batch = {k:v.to(device) for k,v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    # max(..., 1) avoids a ZeroDivisionError when the training loader has no batches.
    avg_train_loss = train_loss / max(len(train_multi_loader), 1)
    print(f"Epoch {epoch+1} average training loss: {avg_train_loss:.4f}")

# --- 6. Save model ---
# Only the parameters (state_dict) are written, not the full model object.
torch.save(model.state_dict(), "blip_medvqa.pth")
print("\nBLIP model saved as blip_medvqa.pth")


Using device: cpu


Epoch 1 Multi-Turn Train:   0%|          | 0/9 [00:26<?, ?it/s]


ValueError: Expected input batch_size (156) to match target batch_size (8).