# Med-VQA Final Project (VQA-RAD)

This notebook trains and evaluates **two methods** on VQA-RAD:

1) **Baseline**: ResNet-50 (image) + LSTM (question) + concat classifier  
2) **VLM**: **ViLT** + classification head

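All outputs are saved to `runs/<timestamp>/`, including:
- `answer2id.json`, `id2answer.json` (answer vocabulary: Top-N train answers + `other`)
- `baseline_best.pt`, `vilt_best.pt` (best-on-validation checkpoints)
- `baseline_history.json`, `vilt_history.json` (per-epoch training loss and validation metrics)
- `baseline_predictions.csv`, `vilt_predictions.csv` (test predictions) and `baseline_errors.csv`, `vilt_errors.csv` (misclassified test rows)
- `baseline_test_metrics.json`, `vilt_test_metrics.json`
- `comparison.csv` (final table for report)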


In [None]:
# Mount Google Drive (Colab only).
# NOTE(review): nothing below reads from or writes to /content/drive — OUT_DIR is the
# relative path runs/<timestamp>/ and the results zip is written under /content/.
# Either drop this cell or point OUT_DIR / the archive at a Drive folder if the results
# should outlive the Colab session.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# (Optional) Install dependencies
# If you already installed them in your env, you can skip this cell.
!pip -q install -U datasets transformers accelerate evaluate scikit-learn torchvision


[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.0/44.0 kB[0m [31m3.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m512.3/512.3 kB[0m [31m22.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.0/12.0 MB[0m [31m123.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m8.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.9/8.9 MB[0m [31m118.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.0/8.0 MB[0m [31m90.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m899.7/899.7 MB[0m [31m613.0 kB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m594.3/594.3 MB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
# Force-reinstall a pinned Pillow.
# NOTE(review): presumably works around a Pillow version conflict introduced by the upgrade
# cell above — confirm. Restart the runtime afterwards so an already-imported PIL is not reused.
!pip -q install -U --force-reinstall "pillow==11.3.0"


[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/6.6 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.7/6.6 MB[0m [31m47.7 MB/s[0m eta [36m0:00:01[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━[0m [32m4.4/6.6 MB[0m [31m58.7 MB/s[0m eta [36m0:00:01[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m6.6/6.6 MB[0m [31m61.6 MB/s[0m eta [36m0:00:01[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m6.6/6.6 MB[0m [31m61.6 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.6/6.6 MB[0m [31m37.2 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
# NOTE(review): redundant — the previous cell already force-reinstalled pillow==11.3.0,
# so this is expected to report "Requirement already satisfied"; it can likely be removed.
!pip install -U "pillow==11.3.0"




In [None]:
import os, json, time, random
from collections import Counter
from typing import Dict

import numpy as np
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import torchvision
import torchvision.transforms as T

from datasets import load_dataset, DatasetDict
from transformers import AutoTokenizer, ViltProcessor, ViltModel, get_cosine_schedule_with_warmup
from sklearn.metrics import f1_score

def seed_everything(seed: int = 42, deterministic: bool = False) -> None:
    """Seed every random number generator this notebook draws from, for repeatable runs.

    Seeds Python's `random`, NumPy's legacy global RNG, and PyTorch's CPU generator and
    the generators of all CUDA devices.

    Args:
        seed: value applied to every generator.
        deterministic: if True, also put cuDNN into deterministic mode and disable its
            autotuner. Off by default (the original behaviour) because it can slow GPU
            training. Enable it when bit-for-bit GPU reproducibility matters more than speed.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

seed_everything(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)


ValueError: pyarrow.lib.IpcReadOptions size changed, may indicate binary incompatibility. Expected 112 from C header, got 104 from PyObject

In [None]:
# ======================
# Config
# ======================
RUN_NAME = time.strftime("%Y%m%d_%H%M%S")
OUT_DIR = f"runs/{RUN_NAME}"
os.makedirs(OUT_DIR, exist_ok=True)
print("OUT_DIR:", OUT_DIR)

TOP_N = 200
MAX_Q_LEN = 32
BATCH_SIZE = 32
NUM_WORKERS = 2

BASE_EPOCHS = 8
BASE_LR = 2e-4

VILT_EPOCHS = 6
VILT_LR = 2e-5
WARMUP_RATIO = 0.1

LSTM_HIDDEN = 512
TXT_EMB_DIM = 300
DROPOUT = 0.2


OUT_DIR: runs/20260110_103856


In [None]:
# ======================
# Load VQA-RAD (auto fallback)
# ======================
def try_load_vqarad(candidates=("HongyiPeng/VQA-RAD", "flaviagiammarino/vqa-rad")) -> "DatasetDict":
    """Load VQA-RAD from the first Hugging Face Hub repo in `candidates` that loads.

    Args:
        candidates: dataset repo ids, tried in order with `load_dataset`. The default is
            the two public VQA-RAD mirrors this notebook was written against.

    Returns:
        The DatasetDict returned by `load_dataset` for the first candidate that succeeds.

    Raises:
        RuntimeError: if every candidate fails. The last failure is chained as `__cause__`,
            so its full traceback is still shown instead of being reduced to a string.
    """
    last_err = None
    for name in candidates:
        try:
            ds = load_dataset(name)
        except Exception as e:  # any load failure: report it and try the next mirror
            print(f"Failed: {name} -> {repr(e)}")
            last_err = e
            continue
        print(f"Loaded dataset: {name}")
        return ds
    raise RuntimeError(f"Could not load VQA-RAD. Last error: {last_err}") from last_err

# Load the raw VQA-RAD DatasetDict (first mirror that works) and show its splits and columns.
ds_raw = try_load_vqarad()
print(ds_raw)


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/474 [00:00<?, ?B/s]

data/train-00000-of-00001.parquet:   0%|          | 0.00/42.2M [00:00<?, ?B/s]

data/test-00000-of-00001.parquet:   0%|          | 0.00/12.8M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1793 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/451 [00:00<?, ? examples/s]

Loaded dataset: HongyiPeng/VQA-RAD
DatasetDict({
    train: Dataset({
        features: ['image', 'question', 'answer'],
        num_rows: 1793
    })
    test: Dataset({
        features: ['image', 'question', 'answer'],
        num_rows: 451
    })
})


In [None]:
# ======================
# Standardize fields + split train/val/test
# (No stratify_by_column to avoid ClassLabel constraint)
# ======================
def normalize_answer(a: str) -> str:
    """Canonicalise an answer: stringify, trim, lowercase, and fold yes/no synonyms.

    "y"/"true" become "yes" and "n"/"false" become "no"; any other answer is returned
    trimmed and lowercased.
    """
    cleaned = str(a).strip().lower()
    synonyms = {"y": "yes", "yes": "yes", "true": "yes", "n": "no", "no": "no", "false": "no"}
    return synonyms.get(cleaned, cleaned)

def answer_type(a: str) -> str:
    """Tag an answer (raw or normalised) as "closed" if it is yes/no, otherwise "open"."""
    return "closed" if normalize_answer(a) in ("yes", "no") else "open"

def standardize_split(split):
    """Map one raw HF split onto the canonical columns image/question/answer/answer_type.

    Each required column is located by trying a few common spellings (first match wins).
    All original columns are removed. The answer is normalised with `normalize_answer`
    and tagged with `answer_type`.

    Raises:
        ValueError: if any of the image, question or answer columns cannot be found.
    """
    keys = split.column_names
    spellings = {
        "image": ["image", "img", "Image", "ImagePath", "path"],
        "question": ["question", "Question", "query", "text"],
        "answer": ["answer", "Answer", "label", "ans"],
    }
    found = {role: next((c for c in names if c in keys), None) for role, names in spellings.items()}
    if None in found.values():
        raise ValueError(f"Cannot infer columns from keys={keys}")
    img_col, q_col, a_col = found["image"], found["question"], found["answer"]

    def _map(ex):
        raw_answer = ex[a_col]
        return {
            "image": ex[img_col],
            "question": str(ex[q_col]),
            "answer": normalize_answer(raw_answer),
            "answer_type": answer_type(raw_answer),
        }

    return split.map(_map, remove_columns=keys)

# Standardise every raw split to image/question/answer/answer_type.
std = {name: standardize_split(split_ds) for name, split_ds in ds_raw.items()}
ds = DatasetDict(std)

if "validation" not in ds:
    if "test" in ds:
        # An official test split exists: hold out 15% of train for model selection.
        split = ds["train"].train_test_split(test_size=0.15, seed=42)
        ds = DatasetDict({
            "train": split["train"],
            "validation": split["test"],
            "test": ds["test"],
        })
    else:
        # No test split: carve test and validation as two *disjoint* 15% hold-outs.
        # Reusing one hold-out for both would let checkpoint selection (on validation)
        # see the very examples the final test score is computed on.
        holdout = ds["train"].train_test_split(test_size=0.15, seed=42)
        split = holdout["train"].train_test_split(test_size=0.15, seed=42)
        ds = DatasetDict({
            "train": split["train"],
            "validation": split["test"],
            "test": holdout["test"],
        })

print(ds)
print("Sizes:", {k: len(ds[k]) for k in ds.keys()})


Map:   0%|          | 0/1793 [00:00<?, ? examples/s]

Map:   0%|          | 0/451 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['image', 'question', 'answer', 'answer_type'],
        num_rows: 1524
    })
    validation: Dataset({
        features: ['image', 'question', 'answer', 'answer_type'],
        num_rows: 269
    })
    test: Dataset({
        features: ['image', 'question', 'answer', 'answer_type'],
        num_rows: 451
    })
})
Sizes: {'train': 1524, 'validation': 269, 'test': 451}


In [None]:
# ======================
# Build answer space: Top-N answers + "other"
# ======================
# Classes 0..len(top_answers)-1 are the most frequent *train* answers. Every other answer
# (including val/test answers never seen in train) maps to the single OTHER_ID class.
answer_counts = Counter(list(ds["train"]["answer"]))
top_answers = [ans for ans, _count in answer_counts.most_common(TOP_N)]
answer2id = {ans: idx for idx, ans in enumerate(top_answers)}
OTHER_ID = len(answer2id)
id2answer = dict(enumerate(top_answers))
id2answer[OTHER_ID] = "other"
num_classes = OTHER_ID + 1

def encode_label(ex):
    """Attach the integer class id of `ex["answer"]` (OTHER_ID when outside the vocabulary)."""
    ex["label"] = answer2id.get(ex["answer"], OTHER_ID)
    return ex

ds = ds.map(encode_label)

# Persist the vocabulary so saved checkpoints can be decoded later.
vocab_files = {
    "answer2id.json": answer2id,
    "id2answer.json": {str(k): v for k, v in id2answer.items()},
}
for fname, payload in vocab_files.items():
    with open(os.path.join(OUT_DIR, fname), "w", encoding="utf-8") as f:
        json.dump(payload, f, ensure_ascii=False, indent=2)

print("num_classes:", num_classes, "OTHER_ID:", OTHER_ID)
print("Top answers:", top_answers[:10])
print("Top answers:", top_answers[:10])


Map:   0%|          | 0/1524 [00:00<?, ? examples/s]

Map:   0%|          | 0/269 [00:00<?, ? examples/s]

Map:   0%|          | 0/451 [00:00<?, ? examples/s]

num_classes: 201 OTHER_ID: 200
Top answers: ['no', 'yes', 'axial', 'right', 'left', 'pa', 'ct', 'fat', 'diffuse', 'left kidney']


In [None]:
# ======================
# Metrics: overall / closed / open / yes-no F1
# ======================
def compute_metrics(df: pd.DataFrame) -> Dict[str, float]:
    """Score a prediction frame produced by `predict_baseline` / `predict_vilt`.

    Expects columns "answer" (normalised gold answer), "pred_answer" and "answer_type"
    ("closed" / "open"). Accuracy is exact string match between gold and predicted answer.

    Returns:
        overall_acc / closed_acc / open_acc: exact-match accuracy over all / closed / open
            rows (NaN when the subset is empty).
        yesno_f1: binary F1 with "yes" as the positive class, over rows whose gold answer is
            yes/no. Any prediction other than "yes" counts as negative. NaN if there are no such rows.
        n / n_closed / n_open: row counts.
    """
    nan = float("nan")
    if len(df) == 0:
        # An empty loader gives pd.DataFrame([]) with no columns at all, so indexing
        # "answer" would raise KeyError; every rate is undefined.
        return {
            "overall_acc": nan,
            "closed_acc": nan,
            "open_acc": nan,
            "yesno_f1": nan,
            "n": 0,
            "n_closed": 0,
            "n_open": 0,
        }

    def _acc(frame: pd.DataFrame) -> float:
        # Exact-match accuracy; NaN (not a mean-of-empty warning) for an empty subset.
        return float((frame["answer"] == frame["pred_answer"]).mean()) if len(frame) else nan

    closed_df = df[df["answer_type"] == "closed"]
    open_df = df[df["answer_type"] == "open"]

    yn_df = df[df["answer"].isin(["yes", "no"])]
    if len(yn_df):
        y_true = (yn_df["answer"] == "yes").astype(int).values
        y_pred = (yn_df["pred_answer"] == "yes").astype(int).values
        # zero_division=0 yields the same 0.0 sklearn's default returns for a degenerate
        # confusion matrix, minus the UndefinedMetricWarning.
        yn_f1 = float(f1_score(y_true, y_pred, zero_division=0))
    else:
        yn_f1 = nan

    return {
        "overall_acc": _acc(df),
        "closed_acc": _acc(closed_df),
        "open_acc": _acc(open_df),
        "yesno_f1": yn_f1,
        "n": int(len(df)),
        "n_closed": int(len(closed_df)),
        "n_open": int(len(open_df)),
    }


In [None]:
# ======================
# Baseline Dataset & DataLoaders (ResNet+LSTM)
# ======================
# Questions are tokenised with the BERT-uncased WordPiece vocabulary; the baseline learns
# its own embedding table over this vocabulary (no BERT weights are used).
tokenizer_base = AutoTokenizer.from_pretrained("bert-base-uncased")

# ImageNet channel statistics, matching the ImageNet-pretrained ResNet-50 backbone.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
img_tf = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

def to_pil(img):
    """Convert any image representation found in the VQA-RAD splits to an RGB PIL image.

    Accepts, in order of preference:
      * a PIL image;
      * an HF image-feature dict carrying "bytes" or "path";
      * a filesystem path string;
      * anything with `.convert`, or anything `np.array` can turn into pixel data.
    Output is always 3-channel RGB. The ResNet normalisation and ViLT's patch embedding
    both expect 3 channels.

    Raises:
        TypeError: for a dict that carries neither "bytes" nor "path".
    """
    import io

    if isinstance(img, Image.Image):
        return img.convert("RGB")
    if isinstance(img, dict):
        if img.get("bytes") is not None:
            with Image.open(io.BytesIO(img["bytes"])) as im:
                return im.convert("RGB")
        if img.get("path") is not None:
            # Image.open keeps the file handle open lazily; the context manager closes it
            # once convert() has loaded the pixels into a new in-memory image.
            with Image.open(img["path"]) as im:
                return im.convert("RGB")
        raise TypeError(f"Image dict has neither 'bytes' nor 'path': keys={list(img)}")
    if isinstance(img, str):
        with Image.open(img) as im:
            return im.convert("RGB")
    try:
        return img.convert("RGB")
    except Exception:
        return Image.fromarray(np.array(img)).convert("RGB")

class VQABaselineTorch(Dataset):
    """Map-style torch Dataset over one HF split that decodes each image to RGB on access.

    Items are plain dicts (image, question, answer, answer_type, label); batching and
    tensorisation are left to the DataLoader's collate_fn.
    """

    def __init__(self, hf_split):
        # Only a reference is kept; rows are fetched lazily in __getitem__.
        self.data = hf_split

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data[idx]
        return dict(
            image=to_pil(row["image"]),
            question=row["question"],
            answer=row["answer"],
            answer_type=row["answer_type"],
            label=int(row["label"]),
        )

def collate_baseline(batch):
    """Collate VQABaselineTorch items into baseline model inputs.

    Returns:
        (images, input_ids, attention_mask, labels, meta):
          * images: stack of `img_tf`-transformed images;
          * input_ids / attention_mask: from `tokenizer_base`, padded to the longest
            question in the batch and truncated to MAX_Q_LEN;
          * labels: int64 class ids;
          * meta: raw answer / answer_type / question strings, kept for evaluation.
    """
    pixel_batch = torch.stack([img_tf(item["image"]) for item in batch], dim=0)
    questions = [item["question"] for item in batch]
    encoded = tokenizer_base(
        questions,
        padding=True,
        truncation=True,
        max_length=MAX_Q_LEN,
        return_tensors="pt",
    )
    targets = torch.tensor([item["label"] for item in batch], dtype=torch.long)
    meta = {
        "answer": [item["answer"] for item in batch],
        "answer_type": [item["answer_type"] for item in batch],
        "question": questions,
    }
    return pixel_batch, encoded["input_ids"], encoded["attention_mask"], targets, meta

def _baseline_loader(split_name, shuffle):
    # The three splits share batch size, workers and collate_fn; only shuffling differs.
    return DataLoader(
        VQABaselineTorch(ds[split_name]),
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        num_workers=NUM_WORKERS,
        collate_fn=collate_baseline,
    )

train_loader = _baseline_loader("train", shuffle=True)
val_loader = _baseline_loader("validation", shuffle=False)
test_loader = _baseline_loader("test", shuffle=False)

print("Batches:", len(train_loader), len(val_loader), len(test_loader))


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

Batches: 48 9 15


In [None]:
# ======================
# Baseline Model: ResNet50 + LSTM + concat
# ======================
class ResNetEncoder(nn.Module):
    """Image branch: pretrained ResNet-50 trunk followed by a linear projection.

    The final fc classifier is dropped. The remaining trunk's 2048-d output is flattened
    and projected to `out_dim`.
    """

    def __init__(self, out_dim=512):
        super().__init__()
        resnet = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
        trunk = list(resnet.children())
        trunk.pop()  # drop the ImageNet classification layer (fc)
        self.backbone = nn.Sequential(*trunk)
        self.proj = nn.Linear(2048, out_dim)

    def forward(self, x):
        pooled = torch.flatten(self.backbone(x), start_dim=1)
        return self.proj(pooled)

class BaselineCNNLSTM(nn.Module):
    """Late-fusion VQA baseline: ResNet-50 image vector ++ LSTM question vector -> MLP.

    Args:
        vocab_size: size of the token vocabulary for the (randomly initialised) embedding.
        num_classes: number of answer classes.
        txt_emb_dim: word-embedding width.
        lstm_hidden: LSTM hidden size (the question vector width).
        img_dim: width the image feature is projected to.
        dropout: dropout rate applied before each classifier layer.
        pad_id: token id treated as padding by the embedding.
    """

    def __init__(self, vocab_size, num_classes, txt_emb_dim=300, lstm_hidden=512, img_dim=512, dropout=0.2, pad_id=0):
        super().__init__()
        self.img_enc = ResNetEncoder(out_dim=img_dim)
        self.emb = nn.Embedding(vocab_size, txt_emb_dim, padding_idx=pad_id)
        self.lstm = nn.LSTM(txt_emb_dim, lstm_hidden, batch_first=True)
        fused_dim = img_dim + lstm_hidden
        self.fc = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(fused_dim, 512),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(512, num_classes),
        )

    def forward(self, images, input_ids, attn_mask):
        image_vec = self.img_enc(images)

        # Pack by true question length so the LSTM never steps over padding, then use
        # the final hidden state of the top (only) layer as the question vector.
        seq_lens = attn_mask.sum(dim=1).cpu()
        packed = nn.utils.rnn.pack_padded_sequence(
            self.emb(input_ids), seq_lens, batch_first=True, enforce_sorted=False
        )
        _, (h_n, _c_n) = self.lstm(packed)
        question_vec = h_n[-1]

        fused = torch.cat([image_vec, question_vec], dim=1)
        return self.fc(fused)

# Instantiate the baseline. The embedding vocabulary is the BERT-uncased tokenizer used in
# collate_baseline, and the number of outputs is the answer-space size.
# NOTE(review): img_dim reuses LSTM_HIDDEN, so the image projection width silently follows
# the LSTM size (both 512 here). Give it its own config constant if they should differ.
baseline = BaselineCNNLSTM(
    vocab_size=tokenizer_base.vocab_size,
    num_classes=num_classes,
    txt_emb_dim=TXT_EMB_DIM,
    lstm_hidden=LSTM_HIDDEN,
    img_dim=LSTM_HIDDEN,
    dropout=DROPOUT,
    pad_id=tokenizer_base.pad_token_id,
).to(device)

# Display the module tree (long, because it includes the full ResNet-50).
baseline


Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth


100%|██████████| 97.8M/97.8M [00:01<00:00, 67.1MB/s]


BaselineCNNLSTM(
  (img_enc): ResNetEncoder(
    (backbone): Sequential(
      (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (4): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=T

In [None]:
# ======================
# Baseline Train/Eval
# ======================
@torch.no_grad()
def predict_baseline(model, loader) -> pd.DataFrame:
    """Run the baseline over `loader` in eval mode and return one row per example.

    Columns: question, answer, answer_type (raw strings from the batch meta), label
    (gold class id), pred_label (argmax class id) and pred_answer (the predicted id mapped
    through id2answer, "other" if unmapped).
    """
    model.eval()
    records = []
    for images, input_ids, attn_mask, labels, meta in loader:
        logits = model(images.to(device), input_ids.to(device), attn_mask.to(device))
        predicted = logits.argmax(dim=1).cpu().tolist()
        gold = labels.tolist()
        for question, answer, a_type, gold_id, pred_id in zip(
            meta["question"], meta["answer"], meta["answer_type"], gold, predicted
        ):
            records.append({
                "question": question,
                "answer": answer,
                "answer_type": a_type,
                "label": int(gold_id),
                "pred_label": int(pred_id),
                "pred_answer": id2answer.get(pred_id, "other"),
            })
    return pd.DataFrame(records)

def train_baseline(model):
    """Train the ResNet+LSTM baseline, checkpointing the best epoch by validation accuracy.

    Training setup:
      * AdamW at BASE_LR over all parameters, including the CNN backbone.
      * Cosine LR schedule whose warmup covers the first WARMUP_RATIO of optimizer steps.
        This is the same setting `train_vilt` uses, so one knob controls both models.
      * Gradient clipping at norm 1.0.

    After every epoch the model is evaluated on `val_loader`. Whenever overall validation
    accuracy improves, the state_dict is saved to OUT_DIR/baseline_best.pt. The per-epoch
    loss and validation metrics are written to OUT_DIR/baseline_history.json at the end.

    On return `model` holds the *last* epoch's weights. Reload baseline_best.pt for the
    selected checkpoint.
    """
    opt = torch.optim.AdamW(model.parameters(), lr=BASE_LR)
    total_steps = BASE_EPOCHS * len(train_loader)
    # Warmup length comes from the shared WARMUP_RATIO config rather than a literal 0.1.
    warmup_steps = int(WARMUP_RATIO * total_steps)
    sched = get_cosine_schedule_with_warmup(opt, warmup_steps, total_steps)

    best_val = -1.0
    history = []

    for epoch in range(1, BASE_EPOCHS + 1):
        model.train()
        total_loss = 0.0
        for images, input_ids, attn_mask, labels, _ in train_loader:
            images = images.to(device)
            input_ids = input_ids.to(device)
            attn_mask = attn_mask.to(device)
            labels = labels.to(device)

            opt.zero_grad(set_to_none=True)
            logits = model(images, input_ids, attn_mask)
            loss = F.cross_entropy(logits, labels)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            sched.step()
            total_loss += loss.item()

        train_loss = total_loss / max(1, len(train_loader))
        val_df = predict_baseline(model, val_loader)
        val_m = compute_metrics(val_df)

        history.append({"epoch": epoch, "train_loss": float(train_loss), **{f"val_{k}": v for k, v in val_m.items()}})

        print(f"[Baseline] epoch {epoch}/{BASE_EPOCHS} loss={train_loss:.4f} "
              f"val_acc={val_m['overall_acc']:.4f} closed={val_m['closed_acc']:.4f} "
              f"open={val_m['open_acc']:.4f} f1={val_m['yesno_f1']:.4f}")

        if val_m["overall_acc"] > best_val:
            best_val = val_m["overall_acc"]
            torch.save(model.state_dict(), os.path.join(OUT_DIR, "baseline_best.pt"))

    with open(os.path.join(OUT_DIR, "baseline_history.json"), "w") as f:
        json.dump(history, f, indent=2)

train_baseline(baseline)


[Baseline] epoch 1/8 loss=4.4396 val_acc=0.2565 closed=0.5074 open=0.0000 f1=0.2947
[Baseline] epoch 2/8 loss=3.0911 val_acc=0.2602 closed=0.5147 open=0.0000 f1=0.3125
[Baseline] epoch 3/8 loss=2.8145 val_acc=0.3086 closed=0.6103 open=0.0000 f1=0.6905
[Baseline] epoch 4/8 loss=2.3443 val_acc=0.3123 closed=0.6176 open=0.0000 f1=0.6795
[Baseline] epoch 5/8 loss=2.1475 val_acc=0.3309 closed=0.6471 open=0.0075 f1=0.6887
[Baseline] epoch 6/8 loss=1.9863 val_acc=0.3346 closed=0.6618 open=0.0000 f1=0.7143
[Baseline] epoch 7/8 loss=1.8751 val_acc=0.3346 closed=0.6544 open=0.0075 f1=0.6939
[Baseline] epoch 8/8 loss=1.8273 val_acc=0.3457 closed=0.6765 open=0.0075 f1=0.7075


In [None]:
# ======================
# Baseline Test + Export
# ======================
# Reload the checkpoint that won on validation (the in-memory model holds the last epoch),
# then score the test split and persist predictions, errors and metrics.
# weights_only=True: the file is a plain state_dict, so refuse arbitrary pickled objects.
# This is already the default from PyTorch 2.6; passing it keeps older versions safe too.
baseline.load_state_dict(
    torch.load(os.path.join(OUT_DIR, "baseline_best.pt"), map_location=device, weights_only=True)
)
test_df_base = predict_baseline(baseline, test_loader)
metrics_base = compute_metrics(test_df_base)
print("Baseline test metrics:", metrics_base)

test_df_base.to_csv(os.path.join(OUT_DIR, "baseline_predictions.csv"), index=False)
# Misclassified rows only, for error analysis.
test_df_base[test_df_base["answer"] != test_df_base["pred_answer"]].to_csv(
    os.path.join(OUT_DIR, "baseline_errors.csv"), index=False
)
with open(os.path.join(OUT_DIR, "baseline_test_metrics.json"), "w") as f:
    json.dump(metrics_base, f, indent=2)


Baseline test metrics: {'overall_acc': 0.36807095343680707, 'closed_acc': 0.6454183266932271, 'open_acc': 0.02, 'yesno_f1': 0.6451612903225806, 'n': 451, 'n_closed': 251, 'n_open': 200}


In [None]:
# ======================
# ViLT DataLoaders
# ======================
processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-mlm")

class VQAViltTorch(VQABaselineTorch):
    """Torch dataset for the ViLT pipeline.

    It yields exactly the items `VQABaselineTorch` yields (RGB PIL image, question,
    normalised answer, answer_type, int label); only the collate function differs
    (`collate_vilt` runs the ViLT processor). Subclassing instead of duplicating the body
    keeps the two pipelines from drifting apart.
    """

def collate_vilt(batch):
    """Collate dataset items into ViLT inputs.

    Returns:
        (enc, labels, meta):
          * enc: the BatchFeature produced by `processor` for the batch's images and
            questions (text padded to the longest question and truncated);
          * labels: int64 class ids;
          * meta: raw answer / answer_type / question strings for evaluation.

    NOTE(review): unlike collate_baseline, no max_length is passed. Truncation therefore
    relies on the tokenizer's model_max_length, which must not exceed ViLT's 40 text
    positions — confirm for the checkpoint in use.
    """
    questions = [item["question"] for item in batch]
    enc = processor(
        images=[item["image"] for item in batch],
        text=questions,
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    targets = torch.tensor([item["label"] for item in batch], dtype=torch.long)
    meta = {
        "answer": [item["answer"] for item in batch],
        "answer_type": [item["answer_type"] for item in batch],
        "question": questions,
    }
    return enc, targets, meta

def _vilt_loader(split_name, shuffle):
    # The three splits share batch size, workers and collate_fn; only shuffling differs.
    return DataLoader(
        VQAViltTorch(ds[split_name]),
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        num_workers=NUM_WORKERS,
        collate_fn=collate_vilt,
    )

train_loader_v = _vilt_loader("train", shuffle=True)
val_loader_v = _vilt_loader("validation", shuffle=False)
test_loader_v = _vilt_loader("test", shuffle=False)

print("Batches:", len(train_loader_v), len(val_loader_v), len(test_loader_v))


preprocessor_config.json:   0%|          | 0.00/251 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/320 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

Batches: 48 9 15


In [None]:
# ======================
# ViLT Model + head
# ======================
class ViltClassifier(nn.Module):
    """Pretrained ViLT encoder (dandelin/vilt-b32-mlm) with a dropout + linear answer head.

    The head classifies ViLT's `pooler_output` into `num_classes` answers.
    """

    def __init__(self, num_classes: int, dropout=0.2):
        super().__init__()
        self.vilt = ViltModel.from_pretrained("dandelin/vilt-b32-mlm")
        width = self.vilt.config.hidden_size
        self.head = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(width, num_classes),
        )

    def forward(self, enc):
        # `enc` is the processor output (a mapping of tensors), forwarded as keyword args.
        pooled = self.vilt(**enc).pooler_output
        return self.head(pooled)

# Build the ViLT classifier with one output per answer id (Top-N answers + "other").
vilt = ViltClassifier(num_classes=num_classes, dropout=DROPOUT).to(device)
# Display the module tree.
vilt


config.json:   0%|          | 0.00/653 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/543M [00:00<?, ?B/s]

ViltClassifier(
  (vilt): ViltModel(
    (embeddings): ViltEmbeddings(
      (text_embeddings): TextEmbeddings(
        (word_embeddings): Embedding(30522, 768)
        (position_embeddings): Embedding(40, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (patch_embeddings): ViltPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32))
      )
      (token_type_embeddings): Embedding(2, 768)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViltEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViltLayer(
          (attention): ViltAttention(
            (attention): ViltSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_fea

In [None]:
# ======================
# ViLT Train/Eval
# ======================
@torch.no_grad()
def predict_vilt(model, loader) -> pd.DataFrame:
    """Run the ViLT classifier over `loader` in eval mode and return one row per example.

    Columns: question, answer, answer_type (raw strings from the batch meta), label
    (gold class id), pred_label (argmax class id) and pred_answer (the predicted id mapped
    through id2answer, "other" if unmapped).
    """
    model.eval()
    records = []
    for enc, labels, meta in loader:
        batch = {name: tensor.to(device) for name, tensor in enc.items()}
        predicted = model(batch).argmax(dim=1).cpu().tolist()
        gold = labels.tolist()
        for question, answer, a_type, gold_id, pred_id in zip(
            meta["question"], meta["answer"], meta["answer_type"], gold, predicted
        ):
            records.append({
                "question": question,
                "answer": answer,
                "answer_type": a_type,
                "label": int(gold_id),
                "pred_label": int(pred_id),
                "pred_answer": id2answer.get(pred_id, "other"),
            })
    return pd.DataFrame(records)

def train_vilt(model):
    """Fine-tune the ViLT classifier end-to-end, checkpointing the best epoch on validation.

    Training setup:
      * AdamW at VILT_LR over every parameter (encoder and head).
      * Cosine schedule with warmup over the first WARMUP_RATIO of optimizer steps.
      * Gradient clipping at norm 1.0.

    After each epoch the model is scored on `val_loader_v`. Whenever overall validation
    accuracy improves, the state_dict is written to OUT_DIR/vilt_best.pt. The per-epoch
    history goes to OUT_DIR/vilt_history.json at the end.
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=VILT_LR)
    steps_per_epoch = len(train_loader_v)
    total_steps = VILT_EPOCHS * steps_per_epoch
    scheduler = get_cosine_schedule_with_warmup(optimizer, int(WARMUP_RATIO * total_steps), total_steps)
    ckpt_path = os.path.join(OUT_DIR, "vilt_best.pt")

    best_acc = -1.0
    epoch_log = []

    for epoch in range(1, VILT_EPOCHS + 1):
        model.train()
        running = 0.0
        for enc, labels, _meta in train_loader_v:
            batch = {name: tensor.to(device) for name, tensor in enc.items()}
            targets = labels.to(device)

            optimizer.zero_grad(set_to_none=True)
            loss = F.cross_entropy(model(batch), targets)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            running += loss.item()

        train_loss = running / max(1, steps_per_epoch)
        val_m = compute_metrics(predict_vilt(model, val_loader_v))

        record = {"epoch": epoch, "train_loss": float(train_loss)}
        record.update({f"val_{k}": v for k, v in val_m.items()})
        epoch_log.append(record)

        print(f"[ViLT] epoch {epoch}/{VILT_EPOCHS} loss={train_loss:.4f} "
              f"val_acc={val_m['overall_acc']:.4f} closed={val_m['closed_acc']:.4f} "
              f"open={val_m['open_acc']:.4f} f1={val_m['yesno_f1']:.4f}")

        if val_m["overall_acc"] > best_acc:
            best_acc = val_m["overall_acc"]
            torch.save(model.state_dict(), ckpt_path)

    with open(os.path.join(OUT_DIR, "vilt_history.json"), "w") as fh:
        json.dump(epoch_log, fh, indent=2)

train_vilt(vilt)


[ViLT] epoch 1/6 loss=5.0494 val_acc=0.2788 closed=0.5515 open=0.0000 f1=0.7109
[ViLT] epoch 2/6 loss=3.4391 val_acc=0.3086 closed=0.6103 open=0.0000 f1=0.6667
[ViLT] epoch 3/6 loss=2.6989 val_acc=0.3086 closed=0.6103 open=0.0000 f1=0.6581
[ViLT] epoch 4/6 loss=2.4913 val_acc=0.3160 closed=0.6250 open=0.0000 f1=0.6623
[ViLT] epoch 5/6 loss=2.4019 val_acc=0.3309 closed=0.6544 open=0.0000 f1=0.6759
[ViLT] epoch 6/6 loss=2.3567 val_acc=0.3234 closed=0.6397 open=0.0000 f1=0.6525


In [None]:
# ======================
# ViLT Test + Export
# ======================
# Reload the checkpoint that won on validation (the in-memory model holds the last epoch),
# then score the test split and persist predictions, errors and metrics.
# weights_only=True: the file is a plain state_dict, so refuse arbitrary pickled objects.
# This is already the default from PyTorch 2.6; passing it keeps older versions safe too.
vilt.load_state_dict(
    torch.load(os.path.join(OUT_DIR, "vilt_best.pt"), map_location=device, weights_only=True)
)
test_df_vilt = predict_vilt(vilt, test_loader_v)
metrics_vilt = compute_metrics(test_df_vilt)
print("ViLT test metrics:", metrics_vilt)

test_df_vilt.to_csv(os.path.join(OUT_DIR, "vilt_predictions.csv"), index=False)
# Misclassified rows only, for error analysis.
test_df_vilt[test_df_vilt["answer"] != test_df_vilt["pred_answer"]].to_csv(
    os.path.join(OUT_DIR, "vilt_errors.csv"), index=False
)
with open(os.path.join(OUT_DIR, "vilt_test_metrics.json"), "w") as f:
    json.dump(metrics_vilt, f, indent=2)


ViLT test metrics: {'overall_acc': 0.2926829268292683, 'closed_acc': 0.5258964143426295, 'open_acc': 0.0, 'yesno_f1': 0.4824561403508772, 'n': 451, 'n_closed': 251, 'n_open': 200}


In [None]:
# ======================
# Final comparison table (for report)
# ======================
# One row per model with its test-split metrics; comparison.csv is the report's results table.
results_by_model = {
    "Baseline (ResNet+LSTM)": metrics_base,
    "ViLT (Transformer VLM)": metrics_vilt,
}
comp = pd.DataFrame([{"model": model_name, **scores} for model_name, scores in results_by_model.items()])
comp.to_csv(os.path.join(OUT_DIR, "comparison.csv"), index=False)
print(comp)
print("All outputs saved to:", OUT_DIR)


                    model  overall_acc  closed_acc  open_acc  yesno_f1    n  \
0  Baseline (ResNet+LSTM)     0.368071    0.645418      0.02  0.645161  451   
1  ViLT (Transformer VLM)     0.292683    0.525896      0.00  0.482456  451   

   n_closed  n_open  
0       251     200  
1       251     200  
All outputs saved to: runs/20260110_103856


In [None]:
import shutil

# Zip *this* run's output directory. Deriving the names from RUN_NAME / OUT_DIR (config
# cell) instead of a hard-coded timestamp means a re-run archives its own fresh results,
# not a stale or non-existent folder from an earlier session. Paths are resolved against
# the working directory, which is where OUT_DIR was created.
archive_path = shutil.make_archive(
    os.path.abspath(f"vqa_results_{RUN_NAME}"),
    "zip",
    os.path.abspath(OUT_DIR),
)
archive_path


'/content/vqa_results_20260110_103856.zip'

In [1]:
import sys
import torch
import transformers

# Record the software/hardware stack the results were produced on (for the report).
has_cuda = torch.cuda.is_available()
env_report = [
    ("Python:", sys.version),
    ("PyTorch:", torch.__version__),
    ("CUDA available:", has_cuda),
    ("CUDA version:", torch.version.cuda),
    ("GPU:", torch.cuda.get_device_name(0) if has_cuda else "CPU"),
    ("Transformers:", transformers.__version__),
]
for label, value in env_report:
    print(label, value)


Python: 3.12.12 (main, Oct 10 2025, 08:52:57) [GCC 11.4.0]
PyTorch: 2.9.0+cu126
CUDA available: True
CUDA version: 12.6
GPU: Tesla T4
Transformers: 4.57.3
