# Vision Transformer (ViT) Fine-tuning for VizWiz

This notebook implements ViT fine-tuning for the VizWiz dataset with a focus on:
1. Image quality classification (answerable vs unanswerable)
2. Question type classification
3. Degradation type detection (blur, darkness, poor framing)

Since ViT cannot perform VQA directly (it takes only an image as input), we'll train it as a multi-task classifier to:
- Predict if an image is answerable
- Classify question types (OCR-like, color, object, etc.)
- Detect image quality issues
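The multi-task setup above can be sketched as three linear heads on a shared image feature. This is a minimal sketch that assumes a backbone returning one 768-dimensional [CLS] vector per image (ViT-Base). `MultiTaskHead` and `multitask_loss` are hypothetical names. The cells below only train the answerability task through `ViTForImageClassification`. The sizes come from this notebook: six question types (section 3) and three degradations (blur, darkness, poor framing). Degradation is treated as multi-label because an image can be both blurry and dark, so it uses a sigmoid/BCE loss. The other two heads use softmax cross-entropy.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiTaskHead(nn.Module):
    """Hypothetical heads on a pooled image feature of size hidden_dim."""

    def __init__(self, hidden_dim=768, num_question_types=6, num_degradations=3):
        super().__init__()
        self.answerable = nn.Linear(hidden_dim, 2)                      # answerable vs unanswerable
        self.question_type = nn.Linear(hidden_dim, num_question_types)  # OCR_LIKE ... OTHER
        self.degradation = nn.Linear(hidden_dim, num_degradations)      # blur, dark, framing (multi-label)

    def forward(self, features):
        return {
            'answerable': self.answerable(features),
            'question_type': self.question_type(features),
            'degradation': self.degradation(features),
        }


def multitask_loss(outputs, answerable, question_type, degradation, weights=(1.0, 1.0, 1.0)):
    """Weighted sum: cross-entropy for the two single-label heads, BCE for degradations."""
    return (weights[0] * F.cross_entropy(outputs['answerable'], answerable)
            + weights[1] * F.cross_entropy(outputs['question_type'], question_type)
            + weights[2] * F.binary_cross_entropy_with_logits(outputs['degradation'],
                                                              degradation.float()))
```

With the model from section 6, the features could come from `model.vit(pixel_values=pixel_values).last_hidden_state[:, 0]`. That is the same [CLS] token its built-in classifier reads.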

## Quick Start
```
jupyter lab --no-browser
http://127.0.0.1:8888/lab 123
```

## 1. Setup and Imports

In [27]:
import json
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
import numpy as np
from PIL import Image
from transformers import ViTForImageClassification, ViTFeatureExtractor
from transformers import TrainingArguments, Trainer
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from sklearn.metrics import accuracy_score, classification_report
import cv2
from collections import Counter
import warnings
warnings.filterwarnings('ignore')
import wandb
wandb.init(project="vit-vizwiz", name="experiment-1")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")



Using device: cuda
GPU: NVIDIA GeForce RTX 5090


## 2. Data Loading and Analysis

In [16]:
# Load annotations
def load_vizwiz_annotations(split='train'):
    """
    Load VizWiz annotations.
    Note: the local files use a COCO-style layout ('info', 'images',
    'annotations') and may not include question/answer fields.
    The full VQA annotations are available from:
    https://vizwiz.org/tasks-and-datasets/vqa/
    """
    with open(f'data/annotations/{split}.json', 'r', encoding='utf-8') as f:
        data = json.load(f)
    return data

# Load all splits
train_data = load_vizwiz_annotations('train')
val_data = load_vizwiz_annotations('val')

print("Train data keys:", train_data.keys())
print(f"Train images: {len(train_data['images'])}")
print(f"Train annotations: {len(train_data['annotations'])}")
print(f"\nVal images: {len(val_data['images'])}")
print(f"Val annotations: {len(val_data['annotations'])}")

Train data keys: dict_keys(['info', 'images', 'annotations'])
Train images: 23431
Train annotations: 117155

Val images: 7750
Val annotations: 38750
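Both splits have exactly five annotations per image (117,155 = 5 × 23,431 for train; 38,750 = 5 × 7,750 for val). The dataset class in section 5 reads `ann.get('is_rejected', False)`. A missing key would therefore silently label every sample as answerable, so it is worth checking which fields the records actually carry. Here is a small sketch; both helper names are hypothetical:

```python
from collections import Counter


def summarize_annotation_fields(data, section='annotations'):
    """Count how many records in data[section] contain each key."""
    counts = Counter()
    for record in data[section]:
        counts.update(record.keys())
    return dict(counts)


def annotations_per_image(data):
    """Map 'annotations per image' -> 'number of images with that count'.

    Images with no annotations do not appear in data['annotations'] and are not counted.
    """
    per_image = Counter(ann['image_id'] for ann in data['annotations'])
    return dict(Counter(per_image.values()))
```

For example, `summarize_annotation_fields(train_data)` shows whether `is_rejected` is present on all 117,155 records.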


## 3. Question Type Classifier (Heuristic-based)

Based on Table 2 of your paper, we classify questions into keyword-based buckets.

In [17]:
def classify_question_type(question):
    """
    Classify question into types based on keywords:
    - OCR_LIKE: text reading questions
    - COLOR: color-related questions
    - COUNT: counting questions
    - DIRECTION: directional questions
    - TIME: time-related questions
    - OTHER: everything else
    """
    question_lower = question.lower()
    
    # OCR-related keywords
    ocr_keywords = ['read', 'say', 'text', 'label', 'written', 'writing', 
                    'words', 'screen', 'display', 'says', 'does this say']
    if any(keyword in question_lower for keyword in ocr_keywords):
        return 'OCR_LIKE'
    
    # Color keywords
    color_keywords = ['color', 'colour', 'what color']
    if any(keyword in question_lower for keyword in color_keywords):
        return 'COLOR'
    
    # Count keywords
    count_keywords = ['how many', 'count', 'number of']
    if any(keyword in question_lower for keyword in count_keywords):
        return 'COUNT'
    
    # Direction keywords
    direction_keywords = ['left', 'right', 'top', 'bottom', 'front', 'back',
                         'above', 'below', 'which side']
    if any(keyword in question_lower for keyword in direction_keywords):
        return 'DIRECTION'
    
    # Time keywords
    time_keywords = ['time', 'clock', 'hour', 'minute']
    if any(keyword in question_lower for keyword in time_keywords):
        return 'TIME'
    
    return 'OTHER'

# Test the classifier
test_questions = [
    "What does this label say?",
    "What color is this shirt?",
    "How many bottles are there?",
    "What is on the left?",
    "What time is it?",
    "What is this?"
]

for q in test_questions:
    print(f"{q:40s} -> {classify_question_type(q)}")

What does this label say?                -> OCR_LIKE
What color is this shirt?                -> COLOR
How many bottles are there?              -> COUNT
What is on the left?                     -> DIRECTION
What time is it?                         -> TIME
What is this?                            -> OTHER
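The classifier above uses plain substring tests, so short keywords also fire inside longer words. For example, "Is this bread?" matches `'read'` and is labeled OCR_LIKE, and `'right'` matches "bright". A minimal sketch of a word-boundary variant follows, using only the standard-library `re` module. `classify_question_type_wb` is a hypothetical name. The keyword lists and priority order mirror the original, minus `'does this say'` and `'what color'`, which `'say'` and `'color'` already cover.

```python
import re

# Same priority order as classify_question_type; first match wins.
QUESTION_TYPE_KEYWORDS = [
    ('OCR_LIKE', ['read', 'say', 'says', 'text', 'label', 'written', 'writing',
                  'words', 'screen', 'display']),
    ('COLOR', ['color', 'colour']),
    ('COUNT', ['how many', 'count', 'number of']),
    ('DIRECTION', ['left', 'right', 'top', 'bottom', 'front', 'back',
                   'above', 'below', 'which side']),
    ('TIME', ['time', 'clock', 'hour', 'minute']),
]

# One compiled pattern per type: \b(?:kw1|kw2|...)\b, so keywords match whole words/phrases only.
_PATTERNS = [
    (qtype, re.compile(r'\b(?:' + '|'.join(map(re.escape, kws)) + r')\b'))
    for qtype, kws in QUESTION_TYPE_KEYWORDS
]


def classify_question_type_wb(question):
    """Word-boundary version of the keyword heuristic; returns 'OTHER' if nothing matches."""
    q = question.lower()
    for qtype, pattern in _PATTERNS:
        if pattern.search(q):
            return qtype
    return 'OTHER'
```

Whole-word matching also means inflected forms such as "labels" no longer match `'label'`, so the lists may need plural forms added.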


## 4. Image Quality Detection

Detect blur, darkness, and poor contrast as mentioned in your paper.

In [18]:
def detect_blur(image_path, threshold=100):
    """Detect if image is blurry using Laplacian variance."""
    img = cv2.imread(str(image_path))
    if img is None:
        return False
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    variance = cv2.Laplacian(gray, cv2.CV_64F).var()
    return variance < threshold

def detect_darkness(image_path, threshold=50):
    """Detect if image is too dark."""
    img = cv2.imread(str(image_path))
    if img is None:
        return False
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    mean_brightness = np.mean(gray)
    return mean_brightness < threshold

def detect_low_contrast(image_path, threshold=30):
    """Detect if image has low contrast."""
    img = cv2.imread(str(image_path))
    if img is None:
        return False
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    contrast = gray.std()
    return contrast < threshold

def get_image_quality_features(image_path):
    """Get all quality features for an image."""
    return {
        'is_blurry': detect_blur(image_path),
        'is_dark': detect_darkness(image_path),
        'is_low_contrast': detect_low_contrast(image_path)
    }
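Each detector above re-reads and converts the image, and it returns only a boolean, which makes the thresholds hard to tune. The sketch below computes all three raw scores from one grayscale array using NumPy only. The 3×3 kernel is the one `cv2.Laplacian` applies with its default `ksize=1`. `np.pad(..., mode='reflect')` corresponds to OpenCV's default reflect-101 border, so `laplacian_var` should match the value `detect_blur` thresholds. The function names are hypothetical, and the default thresholds are copied from the cell above.

```python
import numpy as np

# 3x3 Laplacian aperture used by cv2.Laplacian when ksize=1.
LAPLACIAN_KERNEL = np.array([[0, 1, 0],
                             [1, -4, 1],
                             [0, 1, 0]], dtype=np.float64)


def laplacian_variance(gray):
    """Variance of the Laplacian of a 2-D grayscale array (reflect-101 borders)."""
    g = np.asarray(gray, dtype=np.float64)
    padded = np.pad(g, 1, mode='reflect')
    h, w = g.shape
    out = np.zeros_like(g)
    # Sum shifted copies weighted by the kernel (the kernel is symmetric).
    for dy in range(3):
        for dx in range(3):
            k = LAPLACIAN_KERNEL[dy, dx]
            if k != 0:
                out += k * padded[dy:dy + h, dx:dx + w]
    return out.var()


def quality_scores(gray, blur_threshold=100, dark_threshold=50, contrast_threshold=30):
    """Raw blur/brightness/contrast scores plus the same boolean flags as section 4."""
    g = np.asarray(gray, dtype=np.float64)
    scores = {
        'laplacian_var': float(laplacian_variance(g)),
        'mean_brightness': float(g.mean()),
        'contrast_std': float(g.std()),
    }
    scores['is_blurry'] = scores['laplacian_var'] < blur_threshold
    scores['is_dark'] = scores['mean_brightness'] < dark_threshold
    scores['is_low_contrast'] = scores['contrast_std'] < contrast_threshold
    return scores
```

To score a file, pass `cv2.cvtColor(cv2.imread(str(path)), cv2.COLOR_BGR2GRAY)` to `quality_scores`. Logging the raw scores over a sample of VizWiz images is one way to pick thresholds instead of using fixed defaults.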

## 5. Dataset Preparation


Dataset from: https://vizwiz.org/tasks-and-datasets/vqa/

We start with a dataset for answerability classification.

In [19]:
class VizWizAnswerabilityDataset(Dataset):
    """
    Dataset for predicting whether an image/question pair is answerable.
    Binary classification task for ViT (label 1 = answerable, 0 = unanswerable).
    """
    def __init__(self, annotations, images_dir, feature_extractor, max_samples=None):
        self.annotations = annotations['annotations']
        self.images = {img['id']: img for img in annotations['images']}
        self.images_dir = Path(images_dir)
        self.feature_extractor = feature_extractor
        
        if max_samples:
            self.annotations = self.annotations[:max_samples]
    
    def __len__(self):
        return len(self.annotations)
    
    def __getitem__(self, idx):
        ann = self.annotations[idx]
        image_info = self.images[ann['image_id']]
        
        image_path = self.images_dir / image_info['file_name']
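        try:
            image = Image.open(image_path).convert('RGB')
        except OSError:
            # Missing or unreadable file (PIL's UnidentifiedImageError subclasses OSError):
            # use a gray placeholder; note it still carries this annotation's label
            image = Image.new('RGB', (224, 224), color='gray')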
        
        inputs = self.feature_extractor(images=image, return_tensors="pt")
        pixel_values = inputs['pixel_values'].squeeze(0)
        
        # Label: 1 if answerable (not rejected), 0 if unanswerable
        label = 0 if ann.get('is_rejected', False) else 1
        
        return {
            'pixel_values': pixel_values,
            'labels': torch.tensor(label, dtype=torch.long)
        }

# Feature extractor (recent transformers releases deprecate ViTFeatureExtractor in favor of ViTImageProcessor)
model_name = "google/vit-base-patch16-224"
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)

print("Feature extractor loaded")

Feature extractor loaded
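Before training, you can check the label balance that the dataset class will produce from the annotations alone, without opening any images. This hypothetical helper mirrors the labeling rule and the `max_samples` truncation in `VizWizAnswerabilityDataset`:

```python
from collections import Counter


def answerability_label_counts(data, max_samples=None):
    """Label distribution VizWizAnswerabilityDataset would yield (1 = answerable, 0 = not)."""
    anns = data['annotations']
    if max_samples:  # same truncation rule as the dataset class
        anns = anns[:max_samples]
    return dict(Counter(0 if ann.get('is_rejected', False) else 1 for ann in anns))
```

For example, `answerability_label_counts(val_data, max_samples=50)` shows whether the trial validation subset contains both classes. If it contains only one class, accuracy alone says little about the model.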


## 6. Model Setup

We'll fine-tune ViT for binary classification (answerable vs unanswerable).

In [20]:
model = ViTForImageClassification.from_pretrained(
    model_name,
    num_labels=2,  # Binary classification: answerable vs unanswerable
    ignore_mismatched_sizes=True
)

model.to(device)
print(f"Model loaded on {device}")
print(f"Number of parameters: {sum(p.numel() for p in model.parameters()):,}")

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([2]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([2, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Model loaded on cuda
Number of parameters: 85,800,194


## 7. Create Datasets

**IMPORTANT**: Update the `images_dir` path to where your VizWiz images are stored.

In [21]:
TRAIN_IMAGES_DIR = "data/train"  
VAL_IMAGES_DIR = "data/val"      

# create datasets 
train_dataset = VizWizAnswerabilityDataset(
    train_data,
    TRAIN_IMAGES_DIR,
    feature_extractor,
    max_samples=100  # trial, adjust as needed
)

val_dataset = VizWizAnswerabilityDataset(
    val_data,
    VAL_IMAGES_DIR,
    feature_extractor,
    max_samples=50  # trial, adjust as needed
)

print(f"Train dataset size: {len(train_dataset)}")
print(f"Val dataset size: {len(val_dataset)}")

sample = train_dataset[0]
print(f"\nSample pixel_values shape: {sample['pixel_values'].shape}")
print(f"Sample label: {sample['labels']}")

Train dataset size: 100
Val dataset size: 50

Sample pixel_values shape: torch.Size([3, 224, 224])
Sample label: 1


## 8. Training Configuration

In [None]:
def compute_metrics(eval_pred):
    """Compute accuracy metrics."""
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    accuracy = accuracy_score(labels, predictions)
    return {'accuracy': accuracy}

# Training arguments
training_args = TrainingArguments(
    output_dir="./models/vit_vizwiz_answerability",
    num_train_epochs=5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=32,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=50,
    # Recent transformers releases use `eval_strategy`; older ones used `evaluation_strategy`
    eval_strategy="steps",
    eval_steps=200,
    save_strategy="steps",
    save_steps=200,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    greater_is_better=True,
    learning_rate=5e-5,
    fp16=torch.cuda.is_available(),
    # Pass both integrations as one list; repeating the report_to keyword is a SyntaxError
    report_to=["wandb", "tensorboard"],
)

# Initialize Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
)

print("Trainer initialized")

Trainer initialized
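The trial run is shorter than several of the step-based settings. With `max_samples=100` and `per_device_train_batch_size=16`, one epoch is ceil(100 / 16) = 7 batches, so five epochs give 35 optimizer steps on a single GPU with no gradient accumulation. That is below `logging_steps=50`, `eval_steps=200` and `warmup_steps=500`. This is consistent with the empty loss table in section 9. It also means that, under the default linear warmup, the learning rate only reaches about 35/500 = 7% of `learning_rate`. The sketch below does the arithmetic. `total_optimizer_steps` is a hypothetical helper that assumes the Trainer defaults `dataloader_drop_last=False` and `gradient_accumulation_steps=1`.

```python
import math


def total_optimizer_steps(num_samples, batch_size, epochs):
    """Optimizer steps for a single-device run with no gradient accumulation.

    With dataloader_drop_last=False, a final partial batch still counts as a step.
    """
    steps_per_epoch = math.ceil(num_samples / batch_size)
    return steps_per_epoch * epochs


trial_steps = total_optimizer_steps(100, 16, 5)
for name, value in [('logging_steps', 50), ('eval_steps', 200), ('warmup_steps', 500)]:
    status = 'reached' if value <= trial_steps else 'never reached'
    print(f"{name}={value}: {status} within {trial_steps} steps")
```

For short trial runs, `eval_strategy="epoch"` (with a matching `save_strategy`) or smaller step values are possible alternatives. The full training split would give many more steps per run.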


## 9. Start Training

In [None]:
print("Starting training...")
trainer.train()

print("\nTraining completed!")
wandb.finish()

Starting training...


Step,Training Loss,Validation Loss



Training completed!


## 10. Evaluation

In [24]:
eval_results = trainer.evaluate()
print("\nEvaluation Results:")
for key, value in eval_results.items():
    print(f"  {key}: {value:.4f}")


Evaluation Results:
  eval_loss: nan
  eval_accuracy: 0.0000
  eval_runtime: 0.8668
  eval_samples_per_second: 57.6860
  eval_steps_per_second: 2.3070
  epoch: 5.0000
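An `eval_loss` of nan means the model's outputs (or the loss) became non-finite, so the 0.0000 accuracy says little about answerability. Also, `np.argmax` returns the index of the first NaN in a row. In `compute_metrics`, every all-NaN row therefore becomes class 0 ("Unanswerable"). If the evaluated subset is entirely answerable, that alone yields an accuracy of 0. Here is a sketch of a quick check on the raw `Trainer.predict` output; the helper name is hypothetical:

```python
from collections import Counter

import numpy as np


def diagnose_predictions(logits, labels):
    """Count non-finite logit rows and compare label vs. prediction balance."""
    logits = np.asarray(logits, dtype=np.float64)
    labels = np.asarray(labels)
    nonfinite_rows = ~np.isfinite(logits).all(axis=1)
    preds = np.argmax(logits, axis=1)  # rows containing NaN map to the first NaN's index
    return {
        'n': int(len(labels)),
        'nonfinite_rows': int(nonfinite_rows.sum()),
        'label_counts': dict(Counter(labels.tolist())),
        'pred_counts': dict(Counter(preds.tolist())),
    }
```

Usage: `out = trainer.predict(val_dataset)`, then `diagnose_predictions(out.predictions, out.label_ids)`.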


## 11. Save Model

In [25]:
model.save_pretrained("./models/vit_vizwiz_finetuned")
feature_extractor.save_pretrained("./models/vit_vizwiz_finetuned")

print("Model saved successfully!")

Model saved successfully!


## 12. Detailed Analysis (Per Question Type)

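The cell below prints an overall classification report. A per-question-type breakdown, as in your paper, also needs each sample's question text, which the current annotation files may not include.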

In [None]:
predictions = trainer.predict(val_dataset)
pred_labels = np.argmax(predictions.predictions, axis=1)
true_labels = predictions.label_ids

# overview
print("\n" + "="*60)
print("CLASSIFICATION REPORT")
print("="*60)
print(classification_report(
    true_labels,
    pred_labels,
    labels=[0, 1],  # keep both classes even if one is absent from this subset
    target_names=['Unanswerable', 'Answerable'],
    zero_division=0,
))
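Once each validation sample has a question (for example, from the full VQA annotations), you can get the per-question-type breakdown by grouping predictions with `classify_question_type` from section 3. Below is a minimal sketch. `per_group_accuracy` is hypothetical. The `val_questions` list in the usage line is also hypothetical and must follow the same order as `val_dataset`.

```python
import numpy as np


def per_group_accuracy(pred_labels, true_labels, groups):
    """Accuracy and sample count for each group label (e.g. question type)."""
    pred_labels = np.asarray(pred_labels)
    true_labels = np.asarray(true_labels)
    groups = np.asarray(groups)
    results = {}
    for g in np.unique(groups):
        mask = groups == g
        results[str(g)] = {
            'n': int(mask.sum()),
            'accuracy': float((pred_labels[mask] == true_labels[mask]).mean()),
        }
    return results
```

Usage: `per_group_accuracy(pred_labels, true_labels, [classify_question_type(q) for q in val_questions])`. With only 50 validation samples, some groups may hold just a handful of examples, so their accuracies will be noisy.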