# Visual Question Answering (VQA) Fine-Tuning with LoRA

This notebook implements fine-tuning of the **ViLT** model for Visual Question Answering (VQA) using **Low-Rank Adaptation (LoRA)** on a dataset split into training and test sets. The code includes data loading, model setup with LoRA, training, evaluation, and inference, with results saved for analysis.

---

## Setup and Dependencies

The following steps outline the initial setup:

- **Import libraries**  
  Includes `torch`, `transformers`, `peft`, `PIL`, `pandas`, `sklearn`, and others for model fine-tuning and data processing.

- **Define file paths**  
  Specifies locations for images, VQA dataset, metadata, model output, and results.

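- **Set image size**  
  Uses a fixed `384x384` input so that every image tensor has the same shape for batching; 384 is the shorter-edge size ViLT was pretrained with.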


In [None]:
# Fine-tune ViLT with LoRA on the entire dataset, using a train/test split.

import os
import json
import csv
import torch
import numpy as np
from PIL import Image
import pandas as pd
from tqdm import tqdm
from transformers import ViltProcessor, ViltForQuestionAnswering
from peft import LoraConfig, get_peft_model
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from sklearn.model_selection import train_test_split

# Define paths
IMAGE_BASE_DIR = r"/kaggle/input/vr-project/images/small"
VQA_DATA_FILE = "/kaggle/input/vqa-training-complete/vqa_training_data_complete.json"
INPUT_IMAGES_FILE = "/kaggle/input/vr-project/images/metadata/images.csv"
OUTPUT_MODEL_DIR = "/kaggle/working/vilt-lora-finetuned"
OUTPUT_RESULTS_FILE = "/kaggle/working/vqa_finetune_results.csv"

# Fixed image size for ViLT
IMAGE_SIZE = (384, 384)  # ViLT expects 384x384 images

## Custom VQA Dataset

The `VQADataset` class handles the VQA dataset:

- **Initialization**  
  Takes VQA data, image metadata, processor, and answer-to-index mapping.

- **Data loading**  
  Loads and preprocesses images and questions.

- **Output**  
  Returns encoded inputs (image, text, labels) for model training.

The `custom_collate_fn` ensures proper batching of tensor inputs.
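Each sample corresponds to one (image, question) pair, so a flat sample index has to be mapped back to an item and one of its questions. One way to do this, sketched here with a hypothetical helper `build_flat_index` and made-up toy data in the same shape as the VQA JSON, is to precompute the pairs once so that each lookup is a single list access:

```python
def build_flat_index(vqa_data):
    """Return one (item_position, question_position) pair per question."""
    return [
        (item_pos, q_pos)
        for item_pos, item in enumerate(vqa_data)
        for q_pos, _ in enumerate(item["questions"])
    ]

# Toy data (values are invented for illustration only).
toy_data = [
    {"image_id": "img_a", "questions": [{"question": "q1", "answer": "Yes"},
                                        {"question": "q2", "answer": "Two"}]},
    {"image_id": "img_b", "questions": [{"question": "q3", "answer": "Blue"}]},
]

flat_index = build_flat_index(toy_data)

# Sample 2 is the first question of the second item.
item_pos, q_pos = flat_index[2]
sample = toy_data[item_pos]["questions"][q_pos]
```

The dataset length is then simply `len(flat_index)`, and the lookup cost no longer grows with the sample index.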


In [None]:
# Custom VQA Dataset
class VQADataset(Dataset):
    def __init__(self, vqa_data, image_map, processor, answer_to_idx):
        self.vqa_data = vqa_data
        self.image_map = image_map
        self.processor = processor
        self.answer_to_idx = answer_to_idx
        self.transform = transforms.Compose([
            transforms.Resize(IMAGE_SIZE),
            transforms.ToTensor()  # Outputs (C, H, W) with values in [0, 1]
        ])

    def __len__(self):
        return sum(len(item["questions"]) for item in self.vqa_data)

    def __getitem__(self, idx):
        # Find the correct item and question
        count = 0
        for item in self.vqa_data:
            for q_item in item["questions"]:
                if count == idx:
                    image_id = item["image_id"]
                    question = q_item["question"]
                    answer = q_item["answer"]
                    break
                count += 1
            else:
                continue
            break

        # Load image
        image_path = os.path.join(IMAGE_BASE_DIR, self.image_map[image_id]["path"])
        if not os.path.exists(image_path):
            raise FileNotFoundError(f"Image not found at {image_path}")
        image = Image.open(image_path).convert('RGB')

        # Apply resizing transform
        image_tensor = self.transform(image)  # Shape: (C, H, W)

        # Process inputs with ViLT processor
        encoding = self.processor(
            images=image_tensor,
            text=question,
            padding="max_length",
            max_length=40,
            truncation=True,
            return_tensors="pt"
        )

        # Prepare labels - one-hot encoding over the answer vocabulary.
        # Try an exact lookup first, then a lowercase one: the dataset stores
        # answers such as 'Yes', while the pretrained ViLT vocabulary is
        # assumed to be lowercase, so an exact match alone misses most answers.
        answer_idx = self.answer_to_idx.get(answer, -1)
        if answer_idx == -1:
            answer_idx = self.answer_to_idx.get(str(answer).strip().lower(), -1)

        # The label vector must match the model's output size (3129 in this case)
        num_answers = len(self.answer_to_idx)
        labels = torch.zeros(num_answers)
        if answer_idx != -1:
            labels[answer_idx] = 1.0
        # Answers outside the vocabulary keep an all-zero label vector
        encoding["labels"] = labels

        # Remove batch dimension
        encoding = {k: v.squeeze(0) for k, v in encoding.items()}
        return encoding

def custom_collate_fn(batch):
    """Custom collate function to handle ViLT inputs, only stacking tensors."""
    keys = batch[0].keys()
    collated = {}

    for key in keys:
        items = [item[key] for item in batch]
        if all(isinstance(item, torch.Tensor) for item in items):
            try:
                collated[key] = torch.stack(items)
            except RuntimeError as e:
                print(f"Error stacking key '{key}': {e}")
                print(f"Shapes: {[item.shape for item in items]}")
                raise
        else:
            print(f"Skipping key '{key}' as it contains non-tensor items: {type(items[0])}")
            continue  # Skip non-tensor keys

    return collated

## Data Loading and Preprocessing

These functions prepare the dataset:

- **`load_image_metadata`**  
  Loads image metadata from a CSV into a dictionary.

- **`load_vqa_data`**  
  Loads the VQA dataset from a JSON file.

- **`create_answer_to_idx`**  
  Creates a mapping from answers to indices, ordered by descending frequency (index 0 is the most frequent answer).
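The training log lists answers such as 'Yes' and 'Black' with an initial capital. If the vocabulary used for lookup stores lowercase labels (an assumption here, since the pretrained label set is not printed in this notebook), a case-sensitive `dict.get` returns `-1` for those answers. A minimal sketch of a case-insensitive lookup, using made-up labels and the hypothetical helpers `build_answer_lookup` and `lookup_answer`:

```python
def build_answer_lookup(id2label):
    """Map each lowercased, stripped label to its integer class index."""
    return {label.strip().lower(): int(idx) for idx, label in id2label.items()}

def lookup_answer(answer, answer_to_idx):
    """Return the class index for an answer, or -1 if it is out of vocabulary."""
    return answer_to_idx.get(answer.strip().lower(), -1)

# Made-up vocabulary standing in for a model's id2label mapping.
toy_id2label = {0: "yes", 1: "no", 2: "black"}
answer_to_idx = build_answer_lookup(toy_id2label)

raw_hit = answer_to_idx.get("Yes", -1)                  # case-sensitive lookup
normalized_hit = lookup_answer("Yes", answer_to_idx)    # case-insensitive lookup
missing = lookup_answer("Purple", answer_to_idx)        # genuinely unknown answer
```

Lowercasing does not reconcile different spellings of the same answer (for example a number written as a word versus a digit), so some answers may still fall outside the vocabulary.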


In [None]:
def load_image_metadata():
    """Load image metadata from CSV into a dictionary."""
    image_map = {}
    with open(INPUT_IMAGES_FILE, "r", encoding="utf-8") as f:
        reader = csv.DictReader(f)
        for row in reader:
            image_map[row["image_id"]] = {
                "height": int(row["height"]),
                "width": int(row["width"]),
                "path": row["path"]
            }
    return image_map

def load_vqa_data():
    """Load the VQA dataset from JSON."""
    with open(VQA_DATA_FILE, "r", encoding="utf-8") as f:
        vqa_data = json.load(f)
    return vqa_data

def create_answer_to_idx(vqa_data):
    """Create a mapping from answers to indices based on frequency."""
    answer_freq = {}
    for item in vqa_data:
        for q_item in item["questions"]:
            answer = q_item["answer"]
            answer_freq[answer] = answer_freq.get(answer, 0) + 1

    # Print diagnostic info
    print(f"Total unique answers in dataset: {len(answer_freq)}")
    
    # Use model's vocabulary size - we'll get this later
    # For now, keep most frequent answers
    sorted_answers = sorted(answer_freq.items(), key=lambda x: x[1], reverse=True)
    
    # Print top 5 answers for diagnostic purposes
    print("Top 5 answers by frequency:")
    for i, (answer, count) in enumerate(sorted_answers[:5]):
        print(f"  {i}. '{answer}': {count} occurrences")
        
    return {answer: idx for idx, (answer, _) in enumerate(sorted_answers)}

## Model Setup with LoRA

The `setup_lora_model` function initializes the ViLT model with LoRA:

- Loads the pre-trained ViLT model and processor.
- Applies LoRA configuration with specified parameters.
- Raises an error if the number of trainable parameters exceeds 7 million.
- Moves the model to GPU if available.
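The trainable-parameter count can be checked by hand. LoRA adds two matrices to each adapted linear layer, of shapes `r x in_features` and `out_features x r`. The sketch below assumes the ViLT base encoder has 12 transformer layers with 768-dimensional `query` and `value` projections (these sizes are assumptions, not read from the model config); `lora_param_count` is a hypothetical helper:

```python
def lora_param_count(num_layers, modules_per_layer, in_features, out_features, rank):
    """Count the parameters LoRA adds: rank * (in + out) per adapted linear layer."""
    per_module = rank * (in_features + out_features)
    return num_layers * modules_per_layer * per_module

# Assumed architecture: 12 layers, query + value adapted, 768 -> 768 projections, r = 16.
count = lora_param_count(
    num_layers=12,
    modules_per_layer=2,
    in_features=768,
    out_features=768,
    rank=16,
)
```

Under these assumptions the result is 589,824, which agrees with the figure printed in the training log and is far below the 7M limit. It also indicates that only the LoRA matrices are trainable; the classifier head stays frozen.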


In [None]:
def setup_lora_model():
    """Initialize ViLT model with LoRA configuration."""
    processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
    # ToTensor() already scales pixel values to [0, 1], so turn off the
    # processor's own rescaling to avoid scaling the image twice
    processor.image_processor.do_rescale = False
    model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
    
    # Store the original classifier size
    original_num_labels = model.config.num_labels
    print(f"Original model has {original_num_labels} answer classes")

    # Define LoRA configuration
    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules=["query", "value"],
        lora_dropout=0.1,
        bias="none"
    )

    # Apply LoRA
    model = get_peft_model(model, lora_config)

    # Count trainable parameters
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Number of trainable parameters: {trainable_params}")
    if trainable_params > 7_000_000:
        raise ValueError("Trainable parameters exceed 7M limit")

    # Move model to GPU if available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    return processor, model, device

## Training and Evaluation

These functions handle model training and evaluation:

- **`train_model`**  
  Fine-tunes the LoRA-adapted model with the AdamW optimizer and `CrossEntropyLoss` on one-hot answer targets.

- **`evaluate_model`**  
  Evaluates the model on the test set, computing loss and accuracy.
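Because the dataset returns one-hot float vectors, the loss is a cross-entropy against class probabilities: for each sample, the negative sum over classes of `target * log_softmax(logits)`. The pure-Python sketch below (a hypothetical `soft_label_cross_entropy`, not the PyTorch implementation) shows the per-sample computation and one consequence worth keeping in mind: an all-zero target, which the dataset emits for unknown answers, contributes zero loss and therefore no training signal.

```python
import math

def soft_label_cross_entropy(logits, target_probs):
    """Per-sample cross-entropy: -sum_c target[c] * log_softmax(logits)[c]."""
    max_logit = max(logits)  # subtract the max for numerical stability
    log_sum_exp = max_logit + math.log(sum(math.exp(x - max_logit) for x in logits))
    log_probs = [x - log_sum_exp for x in logits]
    return -sum(t * lp for t, lp in zip(target_probs, log_probs))

logits = [2.0, 0.5, -1.0]

# One-hot target for class 1: the loss is -log_softmax(logits)[1].
known_answer_loss = soft_label_cross_entropy(logits, [0.0, 1.0, 0.0])

# All-zero target (unknown answer): every term is multiplied by zero.
unknown_answer_loss = soft_label_cross_entropy(logits, [0.0, 0.0, 0.0])
```

If a large share of training answers are out of vocabulary, the average loss can look very low even though the model is learning from only a few samples.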


In [None]:
def train_model(model, processor, train_dataloader, device, num_epochs=3):
    """Train the LoRA-finetuned ViLT model."""
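    # Optimize only the parameters that LoRA left trainable
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable_params, lr=5e-5)
    # Targets are one-hot float vectors, which CrossEntropyLoss treats as class probabilities
    criterion = torch.nn.CrossEntropyLoss()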

    model.train()

    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")
        total_loss = 0
        batch_count = 0
        for batch in tqdm(train_dataloader, desc="Training"):
            try:
                # Verify batch contains only tensors
                for key, value in batch.items():
                    if not isinstance(value, torch.Tensor):
                        raise ValueError(f"Batch key '{key}' is not a tensor: {type(value)}")

                # Move batch to device
                batch = {k: v.to(device) for k, v in batch.items()}
                
                # Extract labels before passing to model
                labels = batch.pop("labels")
                
                # Forward pass - compute outputs without labels
                outputs = model(**batch)
                logits = outputs.logits  # Shape [batch_size, num_answers]
                
                # Compute the loss manually; labels are one-hot float vectors, which
                # CrossEntropyLoss accepts as class probabilities (PyTorch 1.10+)
                loss = criterion(logits, labels)

                # Backward pass
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                batch_count += 1
            except Exception as e:
                print(f"Error in batch: {str(e)}")
                print(f"Batch keys: {list(batch.keys())}")
                continue

        if batch_count > 0:
            avg_loss = total_loss / batch_count
            print(f"Average loss for epoch {epoch + 1}: {avg_loss:.4f}")
        else:
            print(f"No valid batches processed in epoch {epoch + 1}")

def evaluate_model(model, test_dataloader, device):
    """Evaluate the model on test data."""
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    criterion = torch.nn.CrossEntropyLoss()
    
    with torch.no_grad():
        for batch in tqdm(test_dataloader, desc="Evaluating"):
            try:
                # Move batch to device
                batch = {k: v.to(device) for k, v in batch.items()}
                
                # Extract labels before passing to model
                labels = batch.pop("labels")
                
                # Forward pass
                outputs = model(**batch)
                logits = outputs.logits
                
                # Compute loss
                loss = criterion(logits, labels)
                test_loss += loss.item()
                
                # Get predictions
                _, predicted = torch.max(logits, 1)
                _, target = torch.max(labels, 1)
                
                total += labels.size(0)
                correct += (predicted == target).sum().item()
                
            except Exception as e:
                print(f"Error in evaluation batch: {str(e)}")
                continue
    
    # Calculate accuracy
    accuracy = correct / total if total > 0 else 0
    avg_loss = test_loss / len(test_dataloader) if len(test_dataloader) > 0 else 0
    
    print(f"Test Loss: {avg_loss:.4f}")
    print(f"Test Accuracy: {accuracy:.4f} ({correct}/{total})")
    return accuracy

## Inference

The `run_inference` function performs inference on the test set:

- Generates predictions for each question-image pair.
- Compares predictions to ground truth.
- Saves results to a CSV file.
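Turning the model output into an answer string amounts to an argmax over the logits followed by a reverse lookup in the answer vocabulary. A small standalone sketch, with a made-up three-answer vocabulary and a hypothetical helper `decode_prediction`:

```python
def decode_prediction(logits, idx_to_answer):
    """Return the answer text for the highest-scoring class index."""
    best_idx = max(range(len(logits)), key=lambda i: logits[i])
    # Indices without a known answer get a placeholder string.
    return idx_to_answer.get(best_idx, f"unknown_{best_idx}")

# Made-up vocabulary; the reverse mapping is built once and reused.
toy_answer_to_idx = {"yes": 0, "no": 1, "blue": 2}
idx_to_answer = {idx: answer for answer, idx in toy_answer_to_idx.items()}

prediction = decode_prediction([0.1, -2.0, 3.5], idx_to_answer)
fallback = decode_prediction([0.1, -2.0, 3.5, 9.0], idx_to_answer)
```

The decoded string is compared to the ground truth with plain `==`, so differences in case or punctuation count as errors at this stage.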


In [None]:
def run_inference(model, processor, test_data, image_map, answer_to_idx, device):
    """Run inference on the test dataset and save results."""
    model.eval()
    results = []

    # Build the transform and the index-to-answer lookup once, not per question
    transform = transforms.Compose([
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor()
    ])
    idx_to_answer = {idx: answer for answer, idx in answer_to_idx.items()}

    with torch.no_grad():
        for item in tqdm(test_data, desc="Processing test VQA items"):
            image_id = item["image_id"]
            if image_id not in image_map:
                print(f"Warning: Image ID {image_id} not found in metadata. Skipping.")
                continue

            image_path = os.path.join(IMAGE_BASE_DIR, image_map[image_id]["path"])
            if not os.path.exists(image_path):
                print(f"Warning: Image file not found at {image_path}. Skipping.")
                continue

            try:
                image = Image.open(image_path).convert('RGB')
                # Preprocess the image once; it is shared by all of its questions
                image_tensor = transform(image)

                for q_item in item["questions"]:
                    question = q_item["question"]
                    ground_truth = q_item["answer"]

                    # Prepare inputs
                    inputs = processor(
                        images=image_tensor,
                        text=question,
                        return_tensors="pt",
                        padding="max_length",
                        max_length=40,
                        truncation=True
                    ).to(device)

                    # Generate answer
                    outputs = model(**inputs)
                    predicted_answer_idx = outputs.logits.argmax(-1).item()

                    # Convert index back to answer text
                    predicted_answer = idx_to_answer.get(predicted_answer_idx, f"unknown_{predicted_answer_idx}")

                    # Store result
                    results.append({
                        "image_id": image_id,
                        "question": question,
                        "ground_truth": ground_truth,
                        "predicted_answer": predicted_answer,
                        "correct": predicted_answer == ground_truth
                    })

            except Exception as e:
                print(f"Error processing image {image_id}: {str(e)}")
                continue

    # Calculate accuracy
    accuracy = sum(1 for r in results if r["correct"]) / len(results) if results else 0
    print(f"Inference accuracy: {accuracy:.4f} ({sum(1 for r in results if r['correct'])}/{len(results)})")
    
    # Save results
    results_df = pd.DataFrame(results)
    results_df.to_csv(OUTPUT_RESULTS_FILE, index=False)
    print(f"Results saved to {OUTPUT_RESULTS_FILE}")
    return results_df

## Main Execution

The `main` function orchestrates the entire process:

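- Loads the image metadata and VQA data, then splits the VQA items 80/20 into training and test sets.
- Sets up the model with LoRA.
- Builds the answer-to-index mapping, falling back to the model's own `id2label` vocabulary when the sizes differ.
- Creates data loaders for training and testing.
- Trains and evaluates the model.
- Saves the model and runs inference.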


In [None]:
def main():
    # Load data
    image_map = load_image_metadata()
    vqa_data = load_vqa_data()
    
    # Split data into train and test sets
    train_data, test_data = train_test_split(vqa_data, test_size=0.2, random_state=42)
    print(f"Training data: {len(train_data)} items")
    print(f"Test data: {len(test_data)} items")
    
    # Setup model first to get the proper answer vocabulary size
    processor, model, device = setup_lora_model()
    
    # Now get the answer-to-index mapping from the full dataset
    # (We need the full vocabulary, even if we're only training on a subset)
    answer_to_idx = create_answer_to_idx(vqa_data)
    
    # Get the vocabulary size from the model
    num_labels = model.config.num_labels
    print(f"Model expects {num_labels} possible answers")
    
    # Ensure our answer mapping matches the model's expected vocabulary size
    if len(answer_to_idx) != num_labels:
        print(f"Warning: Answer mapping size ({len(answer_to_idx)}) doesn't match model's vocabulary ({num_labels})")
        print("Using model's id2label mapping instead")
        answer_to_idx = {v: int(k) for k, v in model.config.id2label.items()}
        print(f"Model's answer vocabulary size: {len(answer_to_idx)}")
    
    # Create training dataset and dataloader
    train_dataset = VQADataset(train_data, image_map, processor, answer_to_idx)
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=8,
        shuffle=True,
        num_workers=0,  # Keep at 0 for debugging
        collate_fn=custom_collate_fn
    )
    
    # Create test dataset and dataloader for evaluation
    test_dataset = VQADataset(test_data, image_map, processor, answer_to_idx)
    test_dataloader = DataLoader(
        test_dataset,
        batch_size=8,
        shuffle=False,
        num_workers=0,
        collate_fn=custom_collate_fn
    )

    # Train model
    print("Training model on training data...")
    train_model(model, processor, train_dataloader, device)

    # Evaluate model on test data
    print("Evaluating model on test data...")
    evaluate_model(model, test_dataloader, device)

    # Save model
    model.save_pretrained(OUTPUT_MODEL_DIR)
    processor.save_pretrained(OUTPUT_MODEL_DIR)
    print(f"Model saved to {OUTPUT_MODEL_DIR}")

    # Run inference on test data
    print("Running inference on test data...")
    results_df = run_inference(model, processor, test_data, image_map, answer_to_idx, device)

if __name__ == "__main__":
    main()

Training data: 7094 items
Test data: 1774 items
Original model has 3129 answer classes
Number of trainable parameters: 589824
Total unique answers in dataset: 2093
Top 5 answers by frequency:
  0. 'Yes': 3060 occurrences
  1. 'Black': 1391 occurrences
  2. 'Two': 1234 occurrences
  3. 'Blue': 968 occurrences
  4. 'Hard': 739 occurrences
Model expects 3129 possible answers
Using model's id2label mapping instead
Model's answer vocabulary size: 3129
Training model on training data...
Epoch 1/3


Training: 100%|██████████| 2433/2433 [13:38<00:00,  2.97it/s]


Average loss for epoch 1: 0.0899
Epoch 2/3


Training: 100%|██████████| 2433/2433 [12:49<00:00,  3.16it/s]


Average loss for epoch 2: 0.0711
Epoch 3/3


Training: 100%|██████████| 2433/2433 [12:33<00:00,  3.23it/s]


Average loss for epoch 3: 0.0611
Evaluating model on test data...


Evaluating: 100%|██████████| 606/606 [02:35<00:00,  3.90it/s]


Test Loss: 0.0658
Test Accuracy: 0.0076 (37/4848)
Model saved to /kaggle/working/vilt-lora-finetuned
Running inference on test data...


Processing test VQA items: 100%|██████████| 1774/1774 [03:34<00:00,  8.26it/s]


Inference accuracy: 0.0076 (37/4848)
Results saved to /kaggle/working/vqa_finetune_results.csv


# Visual Question Answering (VQA) Evaluation Part

This part evaluates the predictions saved by the fine-tuning run (`vqa_finetune_results.csv`) by computing several metrics: Exact Match, Token Match, Wu-Palmer (WUP) score, F1 score, and BERTScore. It also breaks the results down by question type and analyzes yes/no questions separately.

---

## Setup and Dependencies

The following steps outline the initial setup:

- **Import libraries**  
  Includes `pandas`, `nltk`, `transformers`, `torch`, `sklearn`, and others for data processing and metric computation.

- **Download NLTK resources**  
  Ensures WordNet, Punkt, and POS tagger are available for text processing.

- **Define output path**  
  Specifies the location for saving metrics.

- **Suppress warnings**  
  Ignores non-critical warnings to keep output clean.


In [None]:
import pandas as pd
import json
import nltk
import re
from collections import Counter
from transformers import BertTokenizer, BertModel
import torch
from sklearn.metrics import f1_score, precision_score, recall_score
from nltk.corpus import wordnet as wn
import warnings

# Suppress warnings
warnings.filterwarnings('ignore')

nltk.download('wordnet', quiet=True)
nltk.download('punkt', quiet=True)
nltk.download('averaged_perceptron_tagger', quiet=True)
nltk.download('averaged_perceptron_tagger_eng', quiet=True)

OUTPUT_METRICS_FILE = "/kaggle/working/vqa_metrics.json"

## Text Normalization and Scoring Functions

These functions handle text processing and metric calculations:

- **`normalize_answer`**  
  Converts answers to a standardized format by mapping digits to words, removing articles, punctuation, and extra whitespace.

- **`calculate_bertscore`**  
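  Computes a BERTScore-style similarity: the cosine similarity between the `[CLS]` embeddings of the normalized prediction and reference. This is a single-vector approximation, not the token-level matching used by the standard BERTScore metric.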

- **`exact_match`**  
  Checks if the predicted answer exactly matches the ground truth after normalization.

- **`token_match`**  
  Checks whether the predicted and ground-truth answers contain the same tokens with the same counts, regardless of order.

- **`get_wordnet_pos`**  
  Maps NLTK POS tags to WordNet POS for WUP similarity.

- **`calculate_wup_score`**  
  Calculates Wu-Palmer similarity between predicted and ground truth answers.

- **`calculate_f1_score`**  
  Computes F1 score based on token overlap.
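To make the normalization and token-overlap F1 concrete, here is a condensed, standalone restatement of the two steps (the names `normalize` and `token_f1` are illustrative, not the notebook's own functions). It lowercases, maps the numerals 0 to 10 to words, drops articles and punctuation, and then computes F1 over the sets of remaining tokens:

```python
import re

NUMBER_WORDS = {
    "0": "zero", "1": "one", "2": "two", "3": "three", "4": "four", "5": "five",
    "6": "six", "7": "seven", "8": "eight", "9": "nine", "10": "ten",
}

def normalize(text):
    """Lowercase, spell out 0-10, drop articles and punctuation, squeeze spaces."""
    text = text.lower()
    text = re.sub(r"\b(\d+)\b", lambda m: NUMBER_WORDS.get(m.group(1), m.group(1)), text)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    text = re.sub(r"[^\w\s]", "", text)
    return re.sub(r"\s+", " ", text).strip()

def token_f1(pred, ref):
    """F1 over the sets of normalized tokens in prediction and reference."""
    pred_tokens = set(normalize(pred).split())
    ref_tokens = set(normalize(ref).split())
    if not pred_tokens and not ref_tokens:
        return 1.0
    common = pred_tokens & ref_tokens
    if not common:
        return 0.0
    precision = len(common) / len(pred_tokens)
    recall = len(common) / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)

normalized = normalize("The 2 dogs!")       # articles, digits, punctuation handled
partial = token_f1("a red car", "Red")      # precision 1/2, recall 1/1
same = normalize("Yes") == normalize("yes.")
```

In the second call the prediction has two tokens, one of which matches the single reference token, so precision is 0.5, recall is 1.0, and F1 is 2/3.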


In [None]:
def normalize_answer(s):
    if pd.isna(s) or not isinstance(s, str):
        s = ""
    number_map = {
        '0': 'zero', '1': 'one', '2': 'two', '3': 'three', '4': 'four',
        '5': 'five', '6': 'six', '7': 'seven', '8': 'eight', '9': 'nine', '10': 'ten'
    }
    s = str(s).lower()
    for digit, word in number_map.items():
        s = re.sub(r'\b' + digit + r'\b', word, s)
    s = re.sub(r'\b(a|an|the)\b', ' ', s)
    s = re.sub(r'[^\w\s]', '', s)
    s = re.sub(r'\s+', ' ', s).strip()
    return s

def calculate_bertscore(pred, ref, tokenizer, model, device):
    pred = normalize_answer(pred)
    ref = normalize_answer(ref)
    if not pred or not ref:
        return 0.0
    pred_tokens = tokenizer(pred, return_tensors='pt', padding=True, truncation=True).to(device)
    ref_tokens = tokenizer(ref, return_tensors='pt', padding=True, truncation=True).to(device)
    with torch.no_grad():
        pred_outputs = model(**pred_tokens)
        ref_outputs = model(**ref_tokens)
    pred_embedding = pred_outputs.last_hidden_state[:, 0, :]
    ref_embedding = ref_outputs.last_hidden_state[:, 0, :]
    pred_embedding = pred_embedding / pred_embedding.norm(dim=1, keepdim=True)
    ref_embedding = ref_embedding / ref_embedding.norm(dim=1, keepdim=True)
    similarity = torch.matmul(pred_embedding, ref_embedding.transpose(0, 1)).item()
    return similarity

def exact_match(pred, ref):
    return normalize_answer(pred) == normalize_answer(ref)

def token_match(pred, ref):
    pred_tokens = normalize_answer(pred).split()
    ref_tokens = normalize_answer(ref).split()
    return Counter(pred_tokens) == Counter(ref_tokens)

def get_wordnet_pos(word):
    tag = nltk.pos_tag([word])[0][1][0].upper()
    tag_dict = {"J": wn.ADJ, "N": wn.NOUN, "V": wn.VERB, "R": wn.ADV}
    return tag_dict.get(tag, wn.NOUN)

def calculate_wup_score(pred, ref):
    pred_tokens = normalize_answer(pred).split()
    ref_tokens = normalize_answer(ref).split()
    if not pred_tokens or not ref_tokens:
        return 0.0
    max_similarities = []
    for p_token in pred_tokens:
        token_max_sim = 0.0
        p_synsets = wn.synsets(p_token, pos=get_wordnet_pos(p_token))
        if not p_synsets:
            p_synsets = wn.synsets(p_token)
        if not p_synsets:
            continue
        for r_token in ref_tokens:
            r_synsets = wn.synsets(r_token, pos=get_wordnet_pos(r_token))
            if not r_synsets:
                r_synsets = wn.synsets(r_token)
            if not r_synsets:
                continue
            token_sims = []
            for p_syn in p_synsets:
                for r_syn in r_synsets:
                    try:
                        sim = wn.wup_similarity(p_syn, r_syn)
                        if sim is not None:
                            token_sims.append(sim)
                    except Exception:
                        continue
            if token_sims:
                token_max_sim = max(token_max_sim, max(token_sims))
        if token_max_sim > 0:
            max_similarities.append(token_max_sim)
    return sum(max_similarities) / len(max_similarities) if max_similarities else 0.0

def calculate_f1_score(pred, ref):
    pred_tokens = set(normalize_answer(pred).split())
    ref_tokens = set(normalize_answer(ref).split())
    if not pred_tokens and not ref_tokens:
        return 1.0
    if not pred_tokens or not ref_tokens:
        return 0.0
    common_tokens = pred_tokens.intersection(ref_tokens)
    precision = len(common_tokens) / len(pred_tokens) if pred_tokens else 0.0
    recall = len(common_tokens) / len(ref_tokens) if ref_tokens else 0.0
    if precision + recall == 0:
        return 0.0
    f1 = 2 * precision * recall / (precision + recall)
    return f1

## BERTScore Model Setup

The `setup_bertscore_model` function initializes the BERT model for BERTScore calculation:

- Loads the BERT tokenizer and model (`bert-base-uncased`).
- Moves the model to GPU if available.


In [None]:
def setup_bertscore_model():
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = BertModel.from_pretrained('bert-base-uncased')
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    return tokenizer, model, device

## Evaluation Function

The `evaluate_results` function computes and analyzes metrics:

- Calculates Exact Match, Token Match, WUP Score, F1 Score, and BERTScore for each prediction.
- Categorizes questions into types (counting, color, yes/no, other).
- Computes metrics by question type.
- Analyzes yes/no questions with binary classification metrics (accuracy, precision, recall, F1).
- Saves metrics to a JSON file and prints a summary.
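The yes/no analysis reduces each answer to a binary label and then applies the usual classification metrics. The sketch below reproduces that logic in plain Python on made-up answers; `to_binary` and `binary_metrics` are hypothetical names, and the notebook itself relies on scikit-learn for precision, recall, and F1:

```python
POSITIVE_ANSWERS = {"yes", "yeah", "true"}

def to_binary(answer):
    """Map an answer to 1 if it is an affirmative, otherwise 0."""
    return 1 if str(answer).lower() in POSITIVE_ANSWERS else 0

def binary_metrics(gt, pred):
    """Accuracy, precision, recall and F1 for binary labels (positive class = 1)."""
    tp = sum(1 for g, p in zip(gt, pred) if g == 1 and p == 1)
    fp = sum(1 for g, p in zip(gt, pred) if g == 0 and p == 1)
    fn = sum(1 for g, p in zip(gt, pred) if g == 1 and p == 0)
    accuracy = sum(1 for g, p in zip(gt, pred) if g == p) / len(gt) if gt else 0.0
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}

# Made-up ground truths and predictions.
gt_binary = [to_binary(a) for a in ["Yes", "No", "yes", "No"]]
pred_binary = [to_binary(a) for a in ["yes", "yes", "no", "no"]]
metrics = binary_metrics(gt_binary, pred_binary)
```

Note that any answer other than "yes", "yeah", or "true" maps to 0, so a prediction unrelated to the question (for example a color) is counted as a "no".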


In [None]:
def evaluate_results(results_df):
    bert_tokenizer, bert_model, bert_device = setup_bertscore_model()
    
    # Compute metrics for each row
    results_df["exact_match"] = results_df.apply(lambda row: exact_match(row["predicted_answer"], row["ground_truth"]), axis=1)
    results_df["token_match"] = results_df.apply(lambda row: token_match(row["predicted_answer"], row["ground_truth"]), axis=1)
    results_df["wup_score"] = results_df.apply(lambda row: calculate_wup_score(row["predicted_answer"], row["ground_truth"]), axis=1)
    results_df["f1_score"] = results_df.apply(lambda row: calculate_f1_score(row["predicted_answer"], row["ground_truth"]), axis=1)
    results_df["bertscore"] = results_df.apply(lambda row: calculate_bertscore(row["predicted_answer"], row["ground_truth"], bert_tokenizer, bert_model, bert_device), axis=1)
    
    metrics = {
        "overall": {
            "exact_match": results_df["exact_match"].mean(),
            "token_match": results_df["token_match"].mean(),
            "wup_score": results_df["wup_score"].mean(),
            "f1_score": results_df["f1_score"].mean(),
            "bertscore": results_df["bertscore"].mean()
        }
    }
    results_df["question_type"] = "other"
    results_df.loc[results_df["question"].str.contains("how many|number|count", case=False), "question_type"] = "counting"
    results_df.loc[results_df["question"].str.contains("color|colour", case=False), "question_type"] = "color"
    results_df.loc[results_df["question"].str.startswith(("Is ", "Are ", "Does ", "Do ", "Can ", "Could ", "Has ", "Have ")), "question_type"] = "yes/no"
    
    question_types = results_df["question_type"].unique()
    metrics["by_question_type"] = {}
    for qtype in question_types:
        subset = results_df[results_df["question_type"] == qtype]
        metrics["by_question_type"][qtype] = {
            "count": len(subset),
            "exact_match": subset["exact_match"].mean(),
            "token_match": subset["token_match"].mean(),
            "wup_score": subset["wup_score"].mean(),
            "f1_score": subset["f1_score"].mean(),
            "bertscore": subset["bertscore"].mean()
        }
    
    # Work on a copy so the new binary columns are not written to a slice of results_df
    yes_no_df = results_df[results_df["question_type"] == "yes/no"].copy()
    if len(yes_no_df) > 0:
        yes_no_df["gt_binary"] = yes_no_df["ground_truth"].str.lower().apply(
            lambda x: 1 if x in ["yes", "yeah", "true"] else 0)
        yes_no_df["pred_binary"] = yes_no_df["predicted_answer"].str.lower().apply(
            lambda x: 1 if x in ["yes", "yeah", "true"] else 0)
        metrics["yes_no_analysis"] = {
            "accuracy": (yes_no_df["gt_binary"] == yes_no_df["pred_binary"]).mean(),
            "precision": precision_score(yes_no_df["gt_binary"], yes_no_df["pred_binary"], zero_division=0),
            "recall": recall_score(yes_no_df["gt_binary"], yes_no_df["pred_binary"], zero_division=0),
            "f1": f1_score(yes_no_df["gt_binary"], yes_no_df["pred_binary"], zero_division=0)
        }
    
    with open(OUTPUT_METRICS_FILE, 'w') as f:
        json.dump(metrics, f, indent=2)
    
    print("\n===== VQA Evaluation Results =====")
    print(f"Total questions evaluated: {len(results_df)}")
    print(f"Exact match accuracy: {metrics['overall']['exact_match']:.4f}")
    print(f"Token match accuracy: {metrics['overall']['token_match']:.4f}")
    print(f"Average WUP score: {metrics['overall']['wup_score']:.4f}")
    print(f"Average F1 score: {metrics['overall']['f1_score']:.4f}")
    print(f"Average BERTScore: {metrics['overall']['bertscore']:.4f}")
    
    print("\n===== Results by Question Type =====")
    for qtype, qmetrics in metrics["by_question_type"].items():
        print(f"\n{qtype.upper()} Questions ({qmetrics['count']} questions):")
        print(f"  Exact match: {qmetrics['exact_match']:.4f}")
        print(f"  Token match: {qmetrics['token_match']:.4f}")
        print(f"  WUP score: {qmetrics['wup_score']:.4f}")
        print(f"  F1 score: {qmetrics['f1_score']:.4f}")
        print(f"  BERTScore: {qmetrics['bertscore']:.4f}")
    
    if "yes_no_analysis" in metrics:
        print("\n===== Yes/No Question Analysis =====")
        print(f"  Accuracy: {metrics['yes_no_analysis']['accuracy']:.4f}")
        print(f"  Precision: {metrics['yes_no_analysis']['precision']:.4f}")
        print(f"  Recall: {metrics['yes_no_analysis']['recall']:.4f}")
        print(f"  F1 score: {metrics['yes_no_analysis']['f1']:.4f}")
    
    return metrics


## Main Execution

The main block loads the results and runs the evaluation:

- Reads the VQA results from a CSV file.
- Calls the `evaluate_results` function to compute and save metrics.


In [None]:

if __name__ == "__main__":
    results_df = pd.read_csv("/kaggle/working/vqa_finetune_results.csv")
    evaluate_results(results_df)


===== VQA Evaluation Results =====
Total questions evaluated: 4848
Exact match accuracy: 0.4099
Token match accuracy: 0.4099
Average WUP score: 0.6925
Average F1 score: 0.4137
Average BERTScore: 0.9590

===== Results by Question Type =====

COLOR Questions (1577 questions):
  Exact match: 0.5568
  Token match: 0.5568
  WUP score: 0.7692
  F1 score: 0.5637
  BERTScore: 0.9762

YES/NO Questions (1195 questions):
  Exact match: 0.6126
  Token match: 0.6126
  WUP score: 0.7268
  F1 score: 0.6126
  BERTScore: 0.9761

OTHER Questions (1441 questions):
  Exact match: 0.1443
  Token match: 0.1443
  WUP score: 0.4962
  F1 score: 0.1498
  BERTScore: 0.9249

COUNTING Questions (635 questions):
  Exact match: 0.2661
  Token match: 0.2661
  WUP score: 0.8833
  F1 score: 0.2661
  BERTScore: 0.9612

===== Yes/No Question Analysis =====
  Accuracy: 0.7941
  Precision: 0.8127
  Recall: 0.7754
  F1 score: 0.7936
