In [None]:
# Import required libraries

import os
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, Subset
from transformers import AutoTokenizer, AutoModel, AutoImageProcessor, ViTModel
from PIL import Image
from sklearn.preprocessing import LabelEncoder
from datasets import load_dataset
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
import matplotlib.pyplot as plt
from torchvision.transforms import Compose, Resize, ToTensor, Normalize


from IPython.display import display
import random

In [None]:
# Load the VQAv2 dataset from the Hugging Face Hub.
# The returned object is indexed by split name ('train', 'validation', 'test').
# Tip: to time this cell, put the `%%time` cell magic on the cell's first line.
dataset = load_dataset("HuggingFaceM4/VQAv2")

In [None]:
# Access train, validation, and test sets
train_dataset = dataset['train']
test_dataset = dataset['test']
val_dataset = dataset['validation']

# Preview the first training example. The row is fetched once and reused,
# instead of re-reading train_dataset[0] for every field.
first_example = train_dataset[0]

image = first_example['image']
print(image)      # textual repr of the image object
display(image)    # rich (rendered) view of the image

answer = first_example['answers']
print(answer)

# Ensure the preview image is in RGB mode
image = image.convert("RGB")

def load_image(image):
    """Return ``image`` converted to RGB mode.

    Args:
        image: Object exposing a ``convert(mode)`` method (e.g. a PIL image).

    Returns:
        Whatever ``image.convert("RGB")`` returns.
    """
    rgb_image = image.convert("RGB")
    return rgb_image

def display_image(image):
    """Show ``image`` in the notebook output by handing it to ``display``.

    Args:
        image: The object to render.

    Returns:
        None.
    """
    display(image)
    return None

In [None]:
# Environment configuration
# NOTE(review): presumably disables the Hugging Face tokenizers' internal
# parallelism so it does not clash with DataLoader worker processes — confirm.
os.environ.update({"TOKENIZERS_PARALLELISM": "false"})

In [None]:
# Tokenizer and Image Processor
# Pretrained preprocessing objects: one for the question text, one for the image.
# The checkpoint names are the same strings used as the default
# `text_model_name` / `image_model_name` of VQAModel.
bert_tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")

In [None]:
class VQADataset(Dataset):
    """Torch ``Dataset`` that turns raw VQA rows into model-ready tensors.

    Every row of ``dataset`` is read through the keys ``'question'``,
    ``'answers'`` (an iterable of mappings that each carry an ``'answer'``
    key) and ``'image'`` (an object exposing ``convert("RGB")``).

    Args:
        dataset: Indexable, sized collection of rows with the keys above.
        tokenizer: Callable applied to the question text with
            ``padding='max_length'``, ``truncation=True`` and
            ``return_tensors="pt"``; must return a mapping of tensors that
            have a leading batch axis of size 1.
        image_processor: Callable invoked as
            ``image_processor(images=[image], return_tensors="pt")``; must
            return a mapping of tensors with a leading batch axis of size 1.
        label_encoder: Fitted encoder exposing ``transform(list_of_str)`` and
            ``classes_``. ``transform`` decides what happens for answers it
            has never seen (scikit-learn's ``LabelEncoder`` raises).
    """

    def __init__(self, dataset, tokenizer, image_processor, label_encoder):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.image_processor = image_processor
        self.label_encoder = label_encoder

    def __len__(self):
        """Return the number of rows in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Build the model inputs and soft target for row ``idx``.

        Returns:
            dict with
            ``'text_inputs'``: tokenizer output with the batch axis removed,
            ``'image_inputs'``: image-processor output with the batch axis removed,
            ``'labels'``: float tensor of length ``len(label_encoder.classes_)``
            holding the fraction of annotators that gave each answer.
        """
        item = self.dataset[idx]
        question = item['question']
        answers = item['answers']
        image = item['image'].convert("RGB")

        text_inputs = self.tokenizer(question, padding='max_length', truncation=True, return_tensors="pt")
        image_inputs = self.image_processor(images=[image], return_tensors="pt")

        # Both processors return batched tensors; drop the singleton batch axis
        # so the DataLoader can stack samples itself.
        text_inputs = {k: v.squeeze(0) for k, v in text_inputs.items()}
        image_inputs = {k: v.squeeze(0) for k, v in image_inputs.items()}

        return {
            'text_inputs': text_inputs,
            'image_inputs': image_inputs,
            'labels': self._soft_target(answers),
        }

    def _soft_target(self, answers):
        """Average the one-hot encodings of ``answers`` into one float vector.

        Built with torch only: the file never imports numpy, so the previous
        ``np.zeros`` call raised ``NameError``. Rows without any answers
        (``None`` or empty) yield an all-zero target instead of dividing by zero.
        """
        num_classes = len(self.label_encoder.classes_)
        answer_texts = [answer['answer'] for answer in (answers or [])]
        if not answer_texts:
            return torch.zeros(num_classes, dtype=torch.float)

        encoded_labels = torch.as_tensor(
            self.label_encoder.transform(answer_texts), dtype=torch.long
        )
        # Count how many annotators chose each class, then normalise to fractions.
        counts = torch.bincount(encoded_labels, minlength=num_classes)
        return counts.to(torch.float) / len(answer_texts)


In [None]:
class VQAModel(nn.Module):
    """Two-branch VQA classifier: text encoder + image encoder, fused by concatenation.

    The position-0 ([CLS]) hidden state of each encoder is projected to 512
    features; the two projections are concatenated (1024 features) and mapped
    to ``num_answers`` logits by a single linear layer.

    Args:
        text_model_name: Checkpoint name passed to ``AutoModel.from_pretrained``.
        image_model_name: Checkpoint name passed to ``ViTModel.from_pretrained``.
        num_answers: Size of the output (answer) dimension.
    """

    def __init__(self, text_model_name="bert-base-cased", image_model_name="google/vit-base-patch16-224", num_answers=1000):
        super().__init__()
        projection_dim = 512

        # Pretrained backbones
        self.text_model = AutoModel.from_pretrained(text_model_name)
        self.image_model = ViTModel.from_pretrained(image_model_name)

        # Per-modality projections into a common width
        text_hidden_size = self.text_model.config.hidden_size
        image_hidden_size = self.image_model.config.hidden_size
        self.text_fc = nn.Linear(text_hidden_size, projection_dim)
        self.image_fc = nn.Linear(image_hidden_size, projection_dim)

        # Classifier over the concatenated text + image projections
        self.classifier = nn.Linear(2 * projection_dim, num_answers)

    def forward(self, text_inputs, image_inputs):
        """Return answer logits for a batch.

        Args:
            text_inputs: Mapping of keyword arguments for the text encoder.
            image_inputs: Mapping of keyword arguments for the image encoder.

        Returns:
            The output of ``self.classifier`` on the fused features.
        """
        text_hidden_states = self.text_model(**text_inputs).last_hidden_state
        image_hidden_states = self.image_model(**image_inputs).last_hidden_state

        # Keep only the first (CLS) position of each sequence
        text_cls = text_hidden_states[:, 0, :]
        image_cls = image_hidden_states[:, 0, :]

        fused = torch.cat(
            [self.text_fc(text_cls), self.image_fc(image_cls)],
            dim=1,
        )
        return self.classifier(fused)

In [None]:
# Build the answer vocabulary from every annotated answer in the training split.
# Only the 'answers' column is read: iterating over whole rows would also
# materialise each example's image just to collect the answer strings.
label_encoder = LabelEncoder()
all_answer_texts = [
    answer['answer']
    for answers in train_dataset['answers']
    for answer in answers
]
label_encoder.fit(all_answer_texts)

# Train on a reproducible random quarter of the training split.
random.seed(42)
num_train_examples = len(train_dataset)
subset_indices = random.sample(range(num_train_examples), num_train_examples // 4)
subset_train_dataset = Subset(
    VQADataset(train_dataset, bert_tokenizer, image_processor, label_encoder),
    subset_indices,
)