# 🧠 Visual Question Answering with ViLT on MEDPIX-ClinQA

This notebook demonstrates a Visual Question Answering (VQA) pipeline using the [MEDPIX-ClinQA](https://huggingface.co/datasets/adishourya/MEDPIX-ClinQA) dataset, where the model learns to answer clinically relevant questions based on medical images.

We fine-tune the [`dandelin/vilt-b32-mlm`](https://huggingface.co/dandelin/vilt-b32-mlm) vision-language model on this task, treating VQA as a **classification** problem over the set of answers seen in training. Each distinct answer string is mapped to an integer label, which is then encoded as a one-hot target vector in the multi-label format expected by ViLT's question-answering head.

**🔍 Key Information:**
- **Model:** `dandelin/vilt-b32-mlm` (ViLT - Vision-and-Language Transformer)
- **Task:** Visual Question Answering (VQA)
- **Dataset:** `adishourya/MEDPIX-ClinQA` from Hugging Face 🤗
- **Preprocessing:** Image normalization, question tokenization, label ID mapping
- **Objective:** Predict medically accurate answers from paired image-question inputs

This project serves as a practical application of ViLT for multimodal clinical reasoning and can be adapted for other VQA datasets or tasks.


In [None]:
# Log in to the Hugging Face Hub (needed later because push_to_hub=True)
from huggingface_hub import notebook_login

notebook_login()

### Import dataset and preprocess

In [None]:
# load dataset
from datasets import load_dataset

ds = load_dataset("adishourya/MEDPIX-ClinQA")

In [None]:
# Take a subset of 5000 examples
ds = ds["train"].select(range(5000))

In [None]:
ds

In [None]:
ds[0]

In [None]:
# Drop columns that are not needed for training
dataset = ds.remove_columns(['mode', 'case_id'])
dataset[0]

In [None]:
# Visualize an image
from IPython.display import display

image = dataset[0]['image_id']
display(image)

In [None]:
# Step 1: Collect all unique answers
all_answers = set(example["answer"] for example in dataset)

# Step 2: Create label mappings
label2id = {label: idx for idx, label in enumerate(sorted(all_answers))}
id2label = {idx: label for label, idx in label2id.items()}

# # Optional: use a shortened answer (first sentence) as the label instead
# def summarize_answer(answer):
#     return answer.split('.')[0].strip()  # take first sentence or phrase

# short_labels = set(summarize_answer(example["answer"]) for example in dataset)
# label2id = {label: idx for idx, label in enumerate(sorted(short_labels))}
# id2label = {idx: label for label, idx in label2id.items()}
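As a quick sanity check of the mapping logic above, here is a minimal sketch on a toy list of answers. The `toy_answers` list and the `toy_` prefixed names are hypothetical and not taken from the dataset; they only illustrate that sorting the unique answers gives reproducible ids and that the two dictionaries invert each other.

```python
# Hypothetical toy answers, standing in for the dataset's "answer" column
toy_answers = ["pneumonia", "fracture", "pneumonia", "normal"]

# Sorting the unique answers makes the label ids reproducible across runs
toy_label2id = {label: idx for idx, label in enumerate(sorted(set(toy_answers)))}
toy_id2label = {idx: label for label, idx in toy_label2id.items()}

print(toy_label2id)  # {'fracture': 0, 'normal': 1, 'pneumonia': 2}

# Round trip: every answer maps to an id and back to itself
assert all(toy_id2label[toy_label2id[a]] == a for a in toy_answers)
```

Note that any answer string not present when the mapping is built has no id, so the mapping should be created from the same data that is used for training.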



In [None]:
# id2label

In [None]:
# Replace each answer string with its integer label id for training
def replace_answers_with_ids(example):
    example["label"] = label2id[example["answer"]]
    return example

dataset = dataset.map(replace_answers_with_ids)
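Because every distinct answer string becomes its own class, it may be worth checking how often each answer occurs before training. The helper below is a minimal sketch; `answer_frequencies` and the `toy` list are hypothetical names introduced for illustration, and in this notebook the input could plausibly be `dataset["answer"]`.

```python
from collections import Counter

def answer_frequencies(answers):
    """Return (answer, count, share) tuples sorted by descending count."""
    counts = Counter(answers)
    total = sum(counts.values())
    return [(ans, n, n / total) for ans, n in counts.most_common()]

# Hypothetical toy answers, not real dataset content
toy = ["normal", "pneumonia", "normal", "fracture"]
for ans, n, share in answer_frequencies(toy):
    print(f"{ans!r}: {n} ({share:.0%})")
```

If most answers turn out to appear only once, the classifier has very little signal per class, which is one reason the shortened-answer variant above might be preferable.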


In [None]:
dataset.features

### Import the ViLT processor


In [None]:
# Load the ViLT processor (tokenizer + image processor)
from transformers import ViltProcessor

model_name = "dandelin/vilt-b32-mlm"
processor = ViltProcessor.from_pretrained(model_name)

To preprocess the data, we encode the images and questions with `ViltProcessor`. It uses `BertTokenizerFast` to tokenize the text and produce `input_ids`, `attention_mask`, and `token_type_ids`. For images, it uses `ViltImageProcessor` to resize and normalize each image and produce `pixel_values` and `pixel_mask`.
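To make the role of `attention_mask` under `padding="max_length"` more concrete, here is a simplified pure-Python sketch of padding a tokenized question. The function `pad_to_max_length` and the token ids are hypothetical and only illustrative; the real tokenizer does this internally and handles special tokens more carefully when truncating.

```python
def pad_to_max_length(token_ids, max_length, pad_id=0):
    """Truncate or right-pad token ids and build the matching attention mask."""
    ids = token_ids[:max_length]
    mask = [1] * len(ids)          # 1 marks a real token
    n_pad = max_length - len(ids)
    return ids + [pad_id] * n_pad, mask + [0] * n_pad  # 0 marks padding

# Illustrative token ids for a short question
ids, mask = pad_to_max_length([101, 2054, 2003, 102], max_length=6)
print(ids)   # [101, 2054, 2003, 102, 0, 0]
print(mask)  # [1, 1, 1, 1, 0, 0]
```

The model uses the mask to ignore the padded positions, so questions of different lengths can share one fixed-size tensor.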

In [None]:
import torch


def preprocess_data(batch):
    # Convert each image to RGB
    images = [img.convert("RGB") for img in batch['image_id']]
    texts = batch['question']

    # Tokenize with the processor
    encoding = processor(images, texts, padding="max_length", truncation=True, return_tensors="pt")

    # No reshaping is needed: with batched=True, `map` expects one row per example

    # Create one-hot target vectors (one entry per answer class)
    targets = []
    for label in batch['label']:
        target = torch.zeros(len(id2label))
        target[label] = 1.0
        targets.append(target)

    encoding["labels"] = targets
    return encoding
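The target vectors built in `preprocess_data` are one-hot: a 1.0 at the correct answer's index and 0.0 elsewhere. As far as I can tell, ViLT's question-answering head scores each class independently with a sigmoid-based binary cross-entropy, which is why the labels are vectors rather than single integers. The sketch below reimplements that idea in plain Python under this assumption; `one_hot` and `bce_with_logits` are hypothetical helpers, not the library's code.

```python
import math

def one_hot(label, num_labels):
    """Build a one-hot target vector like the one in preprocess_data."""
    target = [0.0] * num_labels
    target[label] = 1.0
    return target

def bce_with_logits(logits, target):
    """Mean binary cross-entropy over all answer classes."""
    total = 0.0
    for x, y in zip(logits, target):
        p = 1.0 / (1.0 + math.exp(-x))  # sigmoid turns a logit into a probability
        total += -(y * math.log(p) + (1.0 - y) * math.log(1.0 - p))
    return total / len(logits)

target = one_hot(2, num_labels=4)
good = bce_with_logits([-3.0, -3.0, 3.0, -3.0], target)  # confident and correct
bad = bce_with_logits([3.0, -3.0, -3.0, -3.0], target)   # confident and wrong
print(good < bad)  # True
```

With many answer classes, almost all target entries are zero, so the loss is dominated by pushing wrong answers down; this is worth keeping in mind when reading the training loss.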


In [None]:
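# Apply the preprocessing function to the whole dataset in small batches.
# The integer 'label' column is dropped too, since 'labels' now holds the targets.
processed_dataset = dataset.map(
    preprocess_data,
    batched=True,
    batch_size=5,
    remove_columns=['image_id', 'question', 'answer', 'label']
)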

processed_dataset

#### Add a data collator

In [None]:
from transformers import DefaultDataCollator

data_collator = DefaultDataCollator()

## Import Model

In [None]:
from transformers import ViltForQuestionAnswering

model = ViltForQuestionAnswering.from_pretrained(model_name, num_labels=len(id2label), id2label=id2label, label2id=label2id)

#### Define training arguments

In [None]:
from transformers import TrainingArguments

repo_id = "Tamal/vilt_finetuned_5000"

training_args = TrainingArguments(
    output_dir=repo_id,
    per_device_train_batch_size=5,
    num_train_epochs=20,
    save_steps=200,
    logging_steps=50,
    learning_rate=5e-5,
    save_total_limit=2,
    remove_unused_columns=False,
    push_to_hub=True,
)
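To get a feel for what these settings imply, the following back-of-the-envelope calculation estimates the number of optimizer steps, checkpoints, and log events. It assumes a single device and no gradient accumulation, which the notebook does not state explicitly.

```python
import math

num_examples = 5000          # subset size selected earlier
batch_size = 5               # per_device_train_batch_size
num_epochs = 20              # num_train_epochs
save_steps, logging_steps = 200, 50

# Assumption: one device and no gradient accumulation
steps_per_epoch = math.ceil(num_examples / batch_size)
total_steps = steps_per_epoch * num_epochs

print(steps_per_epoch)               # 1000
print(total_steps)                   # 20000
print(total_steps // save_steps)     # 100 checkpoints written over the run
print(total_steps // logging_steps)  # 400 logging events
```

Since `save_total_limit=2`, only the two most recent checkpoints are kept on disk at any time, even though about 100 are written over the run.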

In [None]:
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=processed_dataset,
    processing_class=processor,
)

In [None]:
trainer.train()

## Inference

In [None]:
# Pick a sample to test on; `image_id` already holds a PIL image, so Image.open is not needed
example = dataset[0]
image = example['image_id'].convert("RGB")
question = example['question']

# Load the fine-tuned processor and model from the Hub
processor = ViltProcessor.from_pretrained("Tamal/vilt_finetuned_5000")
model = ViltForQuestionAnswering.from_pretrained("Tamal/vilt_finetuned_5000")

# prepare inputs
inputs = processor(image, question, return_tensors="pt")

# forward pass
with torch.no_grad():
    outputs = model(**inputs)

# pick the highest-scoring answer class
logits = outputs.logits
idx = logits.argmax(-1).item()
print("Predicted answer:", model.config.id2label[idx])
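Looking only at the argmax hides how close the runner-up answers are. The sketch below ranks the top candidates by sigmoid score; `top_k_answers` is a hypothetical helper, and the toy logits and labels are made up. In the notebook, something like `logits[0].tolist()` and `model.config.id2label` could be passed instead.

```python
import math

def top_k_answers(logits, id2label, k=3):
    """Return the k highest-scoring (answer, sigmoid score) pairs."""
    scored = [(id2label[i], 1.0 / (1.0 + math.exp(-x))) for i, x in enumerate(logits)]
    return sorted(scored, key=lambda pair: pair[1], reverse=True)[:k]

# Hypothetical labels and logits for illustration only
toy_id2label = {0: "fracture", 1: "normal", 2: "pneumonia"}
toy_logits = [-1.0, 0.5, 2.0]
for answer, score in top_k_answers(toy_logits, toy_id2label, k=2):
    print(f"{answer}: {score:.2f}")
```

Because the scores come from independent sigmoids, they do not have to sum to one; they are better read as per-answer confidences than as a probability distribution.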

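## Zero-shot question answering
Classification-style models such as ViLT can only return answers from the fixed label set seen during training. Newer models such as BLIP-2 instead treat VQA as a generative task: BLIP-2 connects a frozen pretrained image encoder to a frozen LLM through a lightweight Querying Transformer (Q-Former), so it can answer questions zero-shot, without fine-tuning on this dataset. Its authors reported state-of-the-art results on several vision-language tasks, including zero-shot VQA, at the time of publication.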

In [None]:
from transformers import AutoProcessor, Blip2ForConditionalGeneration
import torch
from accelerate.test_utils.testing import get_backend

processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16)
device, _, _ = get_backend() # automatically detects the underlying device type (CUDA, CPU, XPU, MPS, etc.)
model.to(device)

In [None]:
# Reuse one sample from the dataset; `image_id` already holds a PIL image
example = dataset[0]
image = example['image_id'].convert("RGB")
question = example['question']


In [None]:
prompt = f"Question: {question} Answer:"

In [None]:
inputs = processor(image, text=prompt, return_tensors="pt").to(device, torch.float16)

generated_ids = model.generate(**inputs, max_new_tokens=10)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
print(generated_text)
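Depending on the library version, the decoded text may or may not echo the prompt before the answer. The small helper below is a sketch that handles both cases; `extract_answer` and `toy_prompt` are hypothetical names introduced here, and the question is made up.

```python
def extract_answer(generated_text, prompt):
    """Drop the prompt if the decoded text echoes it, then trim whitespace."""
    text = generated_text.strip()
    if text.startswith(prompt):
        text = text[len(prompt):]
    return text.strip()

toy_prompt = "Question: What modality is shown? Answer:"  # hypothetical question
print(extract_answer(toy_prompt + " CT scan", toy_prompt))  # CT scan
print(extract_answer("CT scan\n", toy_prompt))              # CT scan
```

In the notebook this could be called as `extract_answer(generated_text, prompt)` right after decoding.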