# Visual Question Answering with ViLT


## Part 1: Imports and Setup

In [1]:
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
import json
import os
from collections import Counter
from tqdm import tqdm

# PyTorch imports
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
from torchvision import transforms
device = torch.device("cpu")

# Transformers imports
from transformers import ViltProcessor, ViltForQuestionAnswering
from transformers import get_linear_schedule_with_warmup

print(f"PyTorch version: {torch.__version__}")
print(f"MPS available: {torch.backends.mps.is_available()}")
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")

PyTorch version: 2.9.1
MPS available: True
Using device: mps


## Part 2: Load MS-COCO Data

In [2]:
# Load questions

with open('v2_OpenEnded_mscoco_val2014_questions.json', 'r') as f:
    questions_data = json.load(f)

# Load annotations
with open('v2_mscoco_val2014_annotations.json', 'r') as f:
    annotations_data = json.load(f)

print(f"Loaded {len(questions_data['questions'])} questions")
print(f"Loaded {len(annotations_data['annotations'])} annotations")

Loaded 214354 questions
Loaded 214354 annotations


## Part 3: Prepare Dataset for ViLT

In [3]:
# Prepare image-question-answer triplets
N = 30000  
vilt_data = []

for i in tqdm(range(min(N, len(questions_data['questions'])))):
    question_obj = questions_data['questions'][i]
    image_id = question_obj['image_id']
    question_id = question_obj['question_id']
    img_filename = f"COCO_val2014_{image_id:012d}.jpg"
    img_path = f"val2014/{img_filename}"
    
    # Check if image exists
    if os.path.exists(img_path):
        # Get annotation
        ann = next((a for a in annotations_data['annotations'] 
                   if a['question_id'] == question_id), None)
        if ann:
            # Get most common answer
            answers = [ans['answer'] for ans in ann['answers']]
            most_common = max(set(answers), key=answers.count)
            
            vilt_data.append({
                'image_path': img_path,
                'question': question_obj['question'],
                'answer': most_common,
                'all_answers': answers
            })

print(f"\n Collected {len(vilt_data)} valid samples")

# Split into train/val (80/20)
split_idx = int(0.8 * len(vilt_data))
train_data = vilt_data[:split_idx]
val_data = vilt_data[split_idx:]

print(f"  Training: {len(train_data)} samples")
print(f"  Validation: {len(val_data)} samples")

100%|███████████████████████████████████| 30000/30000 [00:12<00:00, 2321.37it/s]


 Collected 30000 valid samples
  Training: 24000 samples
  Validation: 6000 samples





## Part 4: Build Answer Vocabulary

In [4]:
# Build answer vocabulary from training data
train_answers = [d['answer'] for d in train_data]
answer_freq = Counter(train_answers)

# Use top 3000 answers for better coverage
top_k = 3000
top_answers = [ans for ans, _ in answer_freq.most_common(top_k)]
answer_to_id = {ans: idx for idx, ans in enumerate(top_answers)}
id_to_answer = {idx: ans for ans, idx in answer_to_id.items()}

print(f"Answer vocabulary: {len(answer_to_id)} classes")
coverage = sum(1 for d in train_data if d['answer'] in answer_to_id) / len(train_data)
print(f"  Coverage: {coverage*100:.1f}%")

# Show distribution
print(f"\nTop 10 most common answers:")
for ans, count in answer_freq.most_common(10):
    print(f"  {ans}: {count}")

Answer vocabulary: 3000 classes
  Coverage: 98.5%

Top 10 most common answers:
  no: 4574
  yes: 4431
  1: 671
  2: 646
  white: 496
  blue: 346
  3: 321
  red: 301
  black: 281
  0: 265


## Part 5: PyTorch Dataset for ViLT

In [5]:
class VQAViLTDataset(Dataset):
    """Dataset for ViLT model"""
    def __init__(self, data, processor, answer_to_id):
        self.data = data
        self.processor = processor
        self.answer_to_id = answer_to_id
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        item = self.data[idx]
        
        try:
            # Load image and RESIZE to fixed size
            image = Image.open(item['image_path']).convert('RGB')
            image = image.resize((384, 384))  # <-- ADD THIS LINE
        except:
            # Fallback to blank image
            image = Image.new('RGB', (384, 384), color='white')  # <-- Also update this
        
        question = item['question']
        answer = item['answer']
        
        # Process with ViLT processor
        encoding = self.processor(
            image, 
            question,
            padding="max_length",
            truncation=True,
            max_length=40,
            return_tensors="pt"
        )
        
        # Remove batch dimension
        encoding = {k: v.squeeze(0) for k, v in encoding.items()}
        
        # Add label
        # Add label as one-hot vector (ViLT expects this format)
        num_labels = len(self.answer_to_id)
        labels = torch.zeros(num_labels)
        if answer in self.answer_to_id:
            labels[self.answer_to_id[answer]] = 1.0
        
        encoding['labels'] = labels
        
        return encoding
        
        

# Load ViLT processor

processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")

# Create datasets
train_dataset = VQAViLTDataset(train_data, processor, answer_to_id)
val_dataset = VQAViLTDataset(val_data, processor, answer_to_id)

print(f"Train dataset: {len(train_dataset)} samples")
print(f"Val dataset: {len(val_dataset)} samples")

Train dataset: 24000 samples
Val dataset: 6000 samples


## Part 6: DataLoaders

In [6]:
# Collate function to handle batching
def collate_fn(batch):
    return {
        'pixel_values': torch.stack([item['pixel_values'] for item in batch]),
        'input_ids': torch.stack([item['input_ids'] for item in batch]),
        'attention_mask': torch.stack([item['attention_mask'] for item in batch]),
        'labels': torch.stack([item['labels'] for item in batch])  # Now [batch, 3000]
    }

# Create dataloaders
batch_size = 16  

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
    collate_fn=collate_fn 
)

val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=0,
    collate_fn=collate_fn  
)

print(f" Train batches: {len(train_loader)}")
print(f" Val batches: {len(val_loader)}")

 Train batches: 1500
 Val batches: 375


## Part 7: Load ViLT Model and Setup Training

In [7]:
# Load pre-trained ViLT model

model = ViltForQuestionAnswering.from_pretrained(
    "dandelin/vilt-b32-finetuned-vqa",
    num_labels=len(answer_to_id),
    ignore_mismatched_sizes=True
)

model = model.to(device)
print(f"Model loaded on {device}")

# Training configuration
epochs = 5
learning_rate = 5e-5

# Optimizer and scheduler
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

num_training_steps = len(train_loader) * epochs
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(0.1 * num_training_steps),
    num_training_steps=num_training_steps
)

print(f"\nTraining Configuration:")
print(f"  Epochs: {epochs}")
print(f"  Learning rate: {learning_rate}")
print(f"  Training steps: {num_training_steps}")
print(f"  Warmup steps: {int(0.1 * num_training_steps)}")

Some weights of ViltForQuestionAnswering were not initialized from the model checkpoint at dandelin/vilt-b32-finetuned-vqa and are newly initialized because the shapes did not match:
- classifier.3.weight: found shape torch.Size([3129, 1536]) in the checkpoint and torch.Size([3000, 1536]) in the model instantiated
- classifier.3.bias: found shape torch.Size([3129]) in the checkpoint and torch.Size([3000]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Model loaded on mps

Training Configuration:
  Epochs: 5
  Learning rate: 5e-05
  Training steps: 7500
  Warmup steps: 750


## Part 8: Training Loop

In [None]:
def train_epoch(model, loader, optimizer, scheduler, device):
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    
    progress_bar = tqdm(loader, desc="Training")
    
    for batch in progress_bar:
        # Move to device
        batch = {k: v.to(device) for k, v in batch.items()}
        
        # Forward pass
        outputs = model(**batch)
        loss = outputs.loss
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()
        
        # Calculate accuracy
        # Calculate accuracy (for one-hot labels)
        predictions = outputs.logits.argmax(dim=-1)
        targets = batch['labels'].argmax(dim=-1)
        correct += (predictions == targets).sum().item()
        total += predictions.size(0)
        
        # Update progress
        progress_bar.set_postfix({
            'loss': f'{loss.item():.4f}',
            'acc': f'{100*correct/total:.2f}%'
        })
    
    return total_loss / len(loader), 100 * correct / total


def validate(model, loader, device):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    
    progress_bar = tqdm(loader, desc="Validation")
    
    with torch.no_grad():
        for batch in progress_bar:
            batch = {k: v.to(device) for k, v in batch.items()}
            
            outputs = model(**batch)
            loss = outputs.loss
            
            # Calculate accuracy (for one-hot labels)
            predictions = outputs.logits.argmax(dim=-1)
            targets = batch['labels'].argmax(dim=-1)
            correct += (predictions == targets).sum().item()
            total += predictions.size(0)
            
            progress_bar.set_postfix({
                'loss': f'{loss.item():.4f}',
                'acc': f'{100*correct/total:.2f}%'
            })
    
    return total_loss / len(loader), 100 * correct / total


# Training loop

best_val_acc = 0
history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}

for epoch in range(epochs):
    print(f"\nEpoch {epoch+1}/{epochs}")
    
    
    # Train
    train_loss, train_acc = train_epoch(model, train_loader, optimizer, scheduler, device)
    print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
    
    # Validate
    val_loss, val_acc = validate(model, val_loader, device)
    print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")
    
    # Save history
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)
    
    # Save best model
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_acc': val_acc,
        }, 'best_vilt_vqa_model.pth')
        print(f"Saved best model (Val Acc: {val_acc:.2f}%)")

print(f"Best Validation Accuracy: {best_val_acc:.2f}%")


Epoch 1/5


Training:   0%|  | 2/1500 [00:53<11:22:00, 27.32s/it, loss=2098.5903, acc=0.00%]

## Part 9: Visualize Results

In [None]:
# Plot training history
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

# Loss
ax1.plot(history['train_loss'], label='Train Loss', marker='o')
ax1.plot(history['val_loss'], label='Val Loss', marker='s')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.set_title('Training and Validation Loss')
ax1.legend()
ax1.grid(True)

# Accuracy
ax2.plot(history['train_acc'], label='Train Acc', marker='o')
ax2.plot(history['val_acc'], label='Val Acc', marker='s')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Accuracy (%)')
ax2.set_title('Training and Validation Accuracy')
ax2.legend()
ax2.grid(True)

plt.tight_layout()
plt.show()

print(f"\nFinal Results:")
print(f"  Best Validation Accuracy: {best_val_acc:.2f}%")
print(f"  Final Train Accuracy: {history['train_acc'][-1]:.2f}%")

## Part 10: Test Predictions

In [None]:
# Load best model
checkpoint = torch.load('best_vilt_vqa_model.pth')
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

print("Sample Predictions:\n")

# Test on validation samples
for i in range(min(15, len(val_data))):
    item = val_data[i]
    
    try:
        image = Image.open(item['image_path']).convert('RGB')
        encoding = processor(image, item['question'], return_tensors="pt")
        encoding = {k: v.to(device) for k, v in encoding.items()}
        
        with torch.no_grad():
            outputs = model(**encoding)
            pred_idx = outputs.logits.argmax(dim=-1).item()
        
        pred_answer = id_to_answer.get(pred_idx, "unknown")
        is_correct = pred_answer.lower() in [a.lower() for a in item['all_answers']]
        
        print(f"{i+1}. Q: {item['question']}")
        print(f"   Predicted: {pred_answer}")
        print(f"   Ground Truth: {item['answer']}")
        print(f"   {'✓ CORRECT' if is_correct else '✗ WRONG'}\n")
    except Exception as e:
        print(f"{i+1}. Error: {e}\n")
        continue

print(f"\nModel saved as: best_vilt_vqa_model.pth")
print(f"Best Validation Accuracy: {best_val_acc:.2f}%")