Dependencies

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split
from transformers import CLIPProcessor, CLIPModel
from torchvision import transforms
from torchvision.models import efficientnet_b3, EfficientNet_B3_Weights
from tqdm import tqdm
from PIL import Image
import json
import random

# Check GPU Availability

In [None]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Load VizWiz Dataset

In [None]:
# Locations of the VizWiz training images and their annotation file.
image_dir = "VizWiz-2023/images/train"
annotations_file = "VizWiz-2023/annotations/train.json"

# Parse the whole annotation file into memory (read mode is open()'s default).
with open(annotations_file, encoding="utf-8") as f:
    annotations = json.load(f)

print(f"Total annotations loaded: {len(annotations)}")

# Analyze Dataset

In [None]:
# Count how many samples carry the literal answer "unanswerable"; each True
# produced by the comparison contributes 1 to the sum.
unanswerable_count = sum(
    ann["answer"] == "unanswerable" for ann in annotations
)
# Everything else is treated as a valid answer.
valid_count = len(annotations) - unanswerable_count
print(f"Valid answers: {valid_count}, 'Unanswerable' answers: {unanswerable_count}")

# Debug Dataset

In [None]:
# Print the first few annotations for a quick visual sanity check.
print("Sample Debug Data:")
# Slice rather than index 0..2 directly, so a dataset with fewer than three
# entries prints what it has instead of raising IndexError.
# NOTE(review): assumes `annotations` is a list (sliceable) - confirm against
# the annotation file format.
for i, sample in enumerate(annotations[:3]):
    print(f"Sample {i+1}: {sample}")

# Build Answer Vocab

In [None]:
# Map every distinct answer string to a class index. Answers are sorted first
# so the indices do not depend on set iteration order.
answer_vocab = {
    ans: idx
    for idx, ans in enumerate(sorted({ann["answer"] for ann in annotations}))
}
print(f"Answer vocabulary size: {len(answer_vocab)}")


# Preprocessing

In [None]:
# CLIP processor; the dataset class uses its tokenizer to encode the questions.
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Per-channel statistics of ImageNet. The image encoder is an EfficientNet-B3
# initialised with ImageNet weights, which were trained on inputs normalised
# this way, so the same normalisation is applied here.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Image pipeline: resize, light augmentation, tensor conversion, normalisation.
# NOTE(review): the dataset class reads this single global transform, so the
# random flip / colour jitter is also applied to validation samples - consider
# a separate deterministic transform for evaluation.
image_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Balance Unanswerable Answers

In [None]:
# Split the samples by whether their answer is the literal "unanswerable".
unanswerable = [item for item in annotations if item["answer"] == "unanswerable"]
valid = [item for item in annotations if item["answer"] != "unanswerable"]

# Keep every valid sample, plus a random subset of the unanswerable ones capped
# at three quarters of the number of valid samples (or all of them, if fewer).
annotations = valid + random.sample(
    unanswerable,
    k=min(len(unanswerable), 3 * len(valid) // 4),
)

# Dataset Class

In [None]:
class VizWizDataset(Dataset):
    """Map-style dataset yielding (image, tokenized question, answer index).

    Each annotation is expected to be a mapping with the keys ``"image_id"``,
    ``"question"`` and ``"answer"``. Images are loaded from ``image_dir`` and
    passed through the module-level ``image_transform``; questions are encoded
    with the tokenizer of the module-level ``processor``.
    """

    def __init__(self, annotations, image_dir, answer_vocab):
        """Store the annotation list, the image folder and the answer->index map."""
        self.answer_vocab = answer_vocab
        self.image_dir = image_dir
        self.annotations = annotations

    def __len__(self):
        """Return the number of annotated samples."""
        return len(self.annotations)

    def __getitem__(self, idx):
        """Load and preprocess the sample at position ``idx``.

        Returns a dict with the transformed image under ``"image"``, the
        question's ``"input_ids"`` and ``"attention_mask"`` (padded/truncated
        to 77 tokens), and the answer's vocabulary index under ``"answer"``
        (``-1`` when the answer is not in the vocabulary).
        """
        record = self.annotations[idx]

        # The "image_id" field is used directly as the file name inside image_dir.
        picture = Image.open(os.path.join(self.image_dir, record["image_id"]))
        pixels = image_transform(picture.convert("RGB"))

        encoded = processor.tokenizer(
            record["question"],
            padding="max_length",
            truncation=True,
            max_length=77,
            return_tensors="pt",
        )

        label = self.answer_vocab.get(record["answer"], -1)

        # squeeze(0) removes the leading dimension of the tokenizer output so
        # that batching can stack the per-sample tensors.
        return {
            "image": pixels,
            "input_ids": encoded["input_ids"].squeeze(0),
            "attention_mask": encoded["attention_mask"].squeeze(0),
            "answer": label,
        }

# Create Train/Val Split

In [None]:
# Wrap the balanced annotations in a dataset and split it 80% / 20% at random.
dataset = VizWizDataset(annotations, image_dir, answer_vocab)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size  # remainder, so the two sizes always sum to len(dataset)
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# Batch loaders consumed by the training and evaluation cells, which iterate
# `train_dataloader` and `val_dataloader`.
# NOTE(review): batch size is a starting point - lower it if the GPU runs out
# of memory.
batch_size = 32
# Reshuffle the training samples every epoch; keep validation order fixed.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)


# Define Model

In [None]:
class VQAModel(nn.Module):
    """Answer classifier over CLIP text features and EfficientNet-B3 image features.

    The question is embedded with the text branch of a pretrained CLIP model,
    the image with an ImageNet-pretrained EfficientNet-B3 whose classification
    head is replaced by a 512-dimensional linear layer. Both vectors are
    concatenated and fed to a two-layer MLP that outputs one logit per answer.
    """

    def __init__(self, num_classes):
        """Build the encoders and the classification head.

        Args:
            num_classes: number of answer classes, i.e. size of the output layer.
        """
        super().__init__()
        image_dim = 512

        self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")

        self.image_encoder = efficientnet_b3(weights=EfficientNet_B3_Weights.IMAGENET1K_V1)
        # Swap the stock classification head for a projection to image_dim,
        # reusing the input width of its final linear layer.
        head_in_features = self.image_encoder.classifier[1].in_features
        self.image_encoder.classifier = nn.Linear(head_in_features, image_dim)

        fused_dim = image_dim + self.clip_model.config.projection_dim
        self.classifier = nn.Sequential(
            nn.Linear(fused_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes),
        )

    def forward(self, images, input_ids, attention_mask):
        """Return the answer logits for a batch of images and tokenized questions."""
        question_embedding = self.clip_model.get_text_features(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        visual_embedding = self.image_encoder(images)
        # Text features come first, image features second, along the feature axis.
        fused = torch.cat((question_embedding, visual_embedding), dim=1)
        return self.classifier(fused)

# Initialize Model

In [None]:
# One output logit per entry of the answer vocabulary.
num_classes = len(answer_vocab)

# Build the network and move its parameters to the selected device before the
# optimizer is created over them.
model = VQAModel(num_classes=num_classes)
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-5)
# Halve the learning rate once the monitored value has not improved for more
# than 4 consecutive scheduler steps.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="min",
    factor=0.5,
    patience=4,
)

# Training Loop

In [None]:
# Train for a fixed number of epochs, overwriting one checkpoint file each epoch.
epochs = 15
for epoch in range(epochs):
    model.train()  # training mode (dropout active)
    running_loss = 0.0
    for batch in tqdm(train_dataloader, desc=f"Epoch {epoch + 1}/{epochs}"):
        # Move the four batch tensors to the training device.
        images, input_ids, attention_mask, answers = (
            batch[key].to(device)
            for key in ("image", "input_ids", "attention_mask", "answer")
        )

        # Standard step: clear old gradients, forward, backward, update.
        optimizer.zero_grad()
        outputs = model(images, input_ids, attention_mask)
        loss = criterion(outputs, answers)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # The plateau scheduler monitors the epoch's summed training loss.
    scheduler.step(running_loss)
    print(f"Epoch {epoch + 1}/{epochs}, Loss: {running_loss / len(train_dataloader)}")
    torch.save(model.state_dict(), "vizwiz_checkpoint.pth")

# Evaluation

In [None]:
# Validation pass: mean batch loss and top-1 accuracy over val_dataloader.
model.eval()  # evaluation mode (dropout disabled)
total_loss = 0.0
correct = 0
total = 0
with torch.no_grad():  # no gradients are needed for evaluation
    for batch in tqdm(val_dataloader, desc="Evaluating"):
        images, input_ids, attention_mask, answers = (
            batch[key].to(device)
            for key in ("image", "input_ids", "attention_mask", "answer")
        )

        outputs = model(images, input_ids, attention_mask)
        loss = criterion(outputs, answers)
        total_loss += loss.item()

        # Predicted class = index of the largest logit in each row.
        predicted = torch.max(outputs, dim=1).indices
        correct += (predicted == answers).sum().item()
        total += answers.shape[0]

accuracy = correct / total * 100
print(f"Validation Loss: {total_loss / len(val_dataloader):.4f}, Accuracy: {accuracy:.2f}%")
