In [None]:
import json
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from transformers import ViltProcessor, ViltForQuestionAnswering, ViltConfig
from peft import get_peft_model, LoraConfig, TaskType
from PIL import Image
from tqdm import tqdm

# ============================
# 1. Load JSON and Preprocess
# ============================

# Read the question/answer records. JSON text is UTF-8, so the encoding is
# pinned here instead of depending on the platform's locale default.
with open("/kaggle/input/json-train/qna_train.json", encoding="utf-8") as f:
    raw_data = json.load(f)

# Build the answer vocabulary. Sorting the distinct answers keeps the
# label <-> id assignment identical from run to run.
unique_answers = sorted({d["answer"] for d in raw_data})
label2id = {label: i for i, label in enumerate(unique_answers)}
id2label = {i: label for label, i in label2id.items()}

# Processor that turns an (image, question) pair into model input tensors.
processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")

# ============================
# 2. Custom Dataset
# ============================

from torchvision import transforms

class CustomVQADataset(Dataset):
    """Map-style dataset that yields processor-encoded VQA samples.

    Every record in ``data`` is read through the keys ``"image_path"``,
    ``"question"`` and ``"answer"``.

    Args:
        data: Indexable collection of records (supports ``len`` and ``[]``).
        processor: Callable invoked with ``images=``, ``text=``,
            ``return_tensors="pt"``, ``padding="max_length"`` and
            ``truncation=True``; must return a mapping of tensors that carry
            a leading batch dimension of size one.
        label2id: Mapping from answer string to integer class index.
        image_root: Prefix joined to each record's ``"image_path"`` by plain
            string concatenation (no separator is inserted).
    """

    def __init__(
        self,
        data,
        processor,
        label2id,
        image_root="/kaggle/input/train-data",
    ):
        self.data = data
        self.processor = processor
        self.label2id = label2id
        self.image_root = image_root
        self.resize = transforms.Resize((384, 384))  # ViLT expects 384x384

    def __len__(self):
        """Return the number of records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load, resize and encode the sample at ``idx``.

        Returns:
            dict: The processor's tensors with the batch dimension removed,
            plus ``"labels"`` holding the answer's class index as a scalar
            ``long`` tensor.

        Raises:
            KeyError: If the record's answer is missing from ``label2id``.
        """
        item = self.data[idx]
        image_path = self.image_root + item["image_path"]

        # Close the underlying file as soon as the pixels are decoded;
        # ``convert`` returns an image that no longer depends on the handle.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        image = self.resize(image)  # Resize image to fixed size

        question = item["question"]
        answer = item["answer"]

        encoding = self.processor(
            images=image,
            text=question,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
        )
        # Drop the size-one batch dimension so the DataLoader can re-batch.
        encoding = {k: v.squeeze(0) for k, v in encoding.items()}

        # The target is a class index, not a one-hot vector.
        label_id = self.label2id[answer]
        encoding["labels"] = torch.tensor(label_id).long()

        return encoding

from transformers import ViltForQuestionAnswering
import torch.nn as nn

class ViltForSingleLabelClassification(ViltForQuestionAnswering):
    """ViLT question-answering model trained with single-label cross-entropy.

    ``labels`` is removed from the keyword arguments before the parent's
    forward pass runs, so the parent never receives it; the loss is computed
    here with ``nn.CrossEntropyLoss`` on the returned logits instead.
    """

    def forward(self, *args, **kwargs):
        """Run the parent forward pass and attach a cross-entropy loss.

        Args:
            *args: Forwarded unchanged to the parent ``forward``.
            **kwargs: Forwarded to the parent ``forward`` after ``labels``
                (optional) has been removed. ``labels`` is used as the
                target of ``nn.CrossEntropyLoss`` against the logits.

        Returns:
            The parent's output unchanged when ``labels`` is absent or
            ``None``. Otherwise an output of the same kind with the loss
            added: for a tuple output the loss is prepended, for a
            mapping-style output the ``loss`` field is set.
        """
        labels = kwargs.pop("labels", None)
        output = super().forward(*args, **kwargs)
        if labels is None:
            return output

        if isinstance(output, tuple):
            # Tuple outputs (return_dict=False) have no ``.logits`` attribute.
            # With labels withheld from the parent, the logits come first
            # per the transformers tuple-output convention.
            loss = nn.CrossEntropyLoss()(output[0], labels)
            return (loss,) + output

        loss = nn.CrossEntropyLoss()(output.logits, labels)
        # Rebuild the same output type, replacing only the loss field.
        remaining = {k: v for k, v in output.items() if k != "loss"}
        return type(output)(loss=loss, **remaining)



# ============================
# 3. Model + LoRA
# ============================

# Start from the checkpoint's configuration, but size the label space to the
# answer vocabulary derived from the training data.
config = ViltConfig.from_pretrained(
    "dandelin/vilt-b32-finetuned-vqa",
    num_labels=len(label2id),
    id2label=id2label,
    label2id=label2id,
)

# ignore_mismatched_sizes=True is needed because the label count configured
# above generally differs from the one the checkpoint's head was saved with.
base_model = ViltForSingleLabelClassification.from_pretrained(
    "dandelin/vilt-b32-finetuned-vqa",
    config=config,
    ignore_mismatched_sizes=True,
)

# LoRA adapters are attached to the modules named "query" and "value".
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["query", "value"],
    lora_dropout=0.1,
    bias="none",
    task_type=TaskType.SEQ_CLS,
)

model = get_peft_model(base_model, lora_config)

# ============================
# 4. Training
# ============================

# Train on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model.to(device)

dataset = CustomVQADataset(raw_data, processor, label2id)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

model.train()
for epoch in range(15):
    # Running sum of the per-batch losses for this epoch (not an average).
    total_loss = 0
    progress = tqdm(dataloader)
    for batch in progress:
        # Move every tensor of the batch onto the training device.
        batch = {name: tensor.to(device) for name, tensor in batch.items()}
        loss = model(**batch).loss
        loss.backward()
        optimizer.step()
        # Clear gradients so the next batch does not accumulate onto them.
        optimizer.zero_grad()
        total_loss += loss.item()
    print(f"Epoch {epoch+1} Loss: {total_loss:.4f}")


In [None]:
# Write the model, the processor and the model config into one directory,
# in that order.
output_dir = "vilt-finetuned-vqa"
for artifact in (model, processor, model.config):
    artifact.save_pretrained(output_dir)
