In [2]:
import gc
import json
from collections import Counter

import cv2
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import precision_score, recall_score, f1_score
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader,Dataset
from transformers import ViTModel, BertModel, BertTokenizer

In [None]:
# Define your model architecture
class CoTransformer(nn.Module):
    """Single cross-attention layer with a residual connection and layer norm.

    The query attends over ``key``/``value`` through an 8-head
    ``nn.MultiheadAttention``; the attended result is passed through a linear
    layer, added back to the query and layer-normalised.

    Args:
        dim: Embedding size shared by query, key and value (must be divisible
            by the 8 attention heads).

    NOTE(review): the attention module is built without ``batch_first``, so
    per the PyTorch docs 3-D inputs are read as (seq, batch, dim) and 2-D
    inputs as an unbatched (seq, dim) sequence. The evaluation loop in this
    notebook feeds 2-D (batch, dim) tensors, which would make the samples of
    a batch attend to each other - confirm this is intended.
    """

    def __init__(self, dim):
        super().__init__()
        # Attribute names double as state_dict keys; keep them stable.
        self.attention = nn.MultiheadAttention(embed_dim=dim, num_heads=8)
        self.linear = nn.Linear(in_features=dim, out_features=dim)
        self.layer_norm = nn.LayerNorm(dim)

    def forward(self, query, key, value):
        """Return ``LayerNorm(Linear(Attention(query, key, value)) + query)``."""
        attended = self.attention(query, key, value)[0]  # drop attention weights
        residual = self.linear(attended) + query
        return self.layer_norm(residual)

In [None]:
# Pretrained backbones used as feature extractors for images (ViT) and
# questions (BERT). Checkpoint ids are named so they are easy to spot/change.
VIT_CHECKPOINT = 'google/vit-base-patch16-224-in21k'
BERT_CHECKPOINT = 'bert-base-uncased'

vit_model = ViTModel.from_pretrained(VIT_CHECKPOINT)
bert_model = BertModel.from_pretrained(BERT_CHECKPOINT)

In [None]:
class MyModel(nn.Module):
    """Answer classifier that fuses an image embedding and a question embedding.

    Each modality is linearly projected into a query, which then attends over
    the *other* modality through its own ``CoTransformer``. The two attended
    representations are concatenated on the last axis and classified by a
    two-hidden-layer MLP (512 -> 256 -> ``num_classes``) with ReLU and
    dropout 0.3.

    Args:
        num_classes: Number of output logits (size of the answer vocabulary).
        dim: Size of both input embeddings.
    """

    def __init__(self, num_classes, dim):
        super().__init__()
        # Attribute names and creation order are kept as-is: they determine
        # the state_dict keys of saved checkpoints and the RNG draw order
        # used for weight initialisation.
        self.query_linear_image = nn.Linear(dim, dim)
        self.query_linear_text = nn.Linear(dim, dim)
        self.co_transformer_vit_to_bert = CoTransformer(dim=dim)
        self.co_transformer_bert_to_vit = CoTransformer(dim=dim)

        self.layer1 = nn.Sequential(nn.Linear(dim * 2, 512), nn.ReLU(), nn.Dropout(0.3))
        self.layer2 = nn.Sequential(nn.Linear(512, 256), nn.ReLU(), nn.Dropout(0.3))

        self.output_layer = nn.Linear(256, num_classes)

    def forward(self, img_embedding, ques_embedding):
        """Return unnormalised class logits for the given embedding pair.

        Args:
            img_embedding: Image features whose last axis has size ``dim``.
            ques_embedding: Question features whose last axis has size ``dim``.

        Returns:
            Tensor of logits whose last axis has size ``num_classes``.
        """
        # Question-derived query attends over the raw image embedding ...
        text_over_image = self.co_transformer_vit_to_bert(
            self.query_linear_text(ques_embedding), img_embedding, img_embedding
        )
        # ... and the image-derived query attends over the raw question embedding.
        image_over_text = self.co_transformer_bert_to_vit(
            self.query_linear_image(img_embedding), ques_embedding, ques_embedding
        )

        fused = torch.cat((text_over_image, image_over_text), dim=-1)
        hidden = self.layer2(self.layer1(fused))
        return self.output_layer(hidden)

In [None]:
class MyDataset(Dataset):
    """Thin ``Dataset`` wrapper around an indexable collection of 3-field samples.

    Args:
        data: Indexable, sized collection whose items unpack into exactly
            three values, used here as (image, question, answer).
    """

    def __init__(self, data):
        self.data = data

    def __len__(self):
        """Number of samples in the wrapped collection."""
        return len(self.data)

    def __getitem__(self, index):
        """Return sample ``index`` as an ``(image, question, answer)`` tuple."""
        # Unpacking (rather than returning the item directly) enforces that
        # every sample has exactly three fields.
        img, question_text, label = self.data[index]
        return (img, question_text, label)

In [None]:
def getImageName(image_path, image_id, split="train2014"):
    """Build the file path of a COCO-style image from its numeric id.

    The file name has the form ``COCO_<split>_<id zero-padded to 12>.jpg`` and
    is appended directly to ``image_path``.

    Args:
        image_path: Directory prefix. It is concatenated as-is, so it must
            already end with a path separator.
        image_id: Image id (anything ``str()`` can render, e.g. an int).
            Ids longer than 12 characters are used unpadded.
        split: Split tag embedded in the file name. Defaults to
            ``"train2014"``, which reproduces the original hard-coded prefix.

    Returns:
        The full image path as a string.
    """
    padded_id = str(image_id).rjust(12, "0")
    return f"{image_path}COCO_{split}_{padded_id}.jpg"

In [None]:
def filterMajoritySingleWord(answer_list):
    """Pick the most frequent answer, preferring single-word answers.

    Only answers consisting of exactly one whitespace-separated word are
    counted; if there are none, all answers are counted instead. Ties are
    resolved in favour of the answer that appears first in ``answer_list``.

    Args:
        answer_list: Iterable of mappings, each with an ``"answer"`` string.

    Returns:
        The winning answer string.

    Raises:
        ValueError: If ``answer_list`` contains no entries.
        KeyError: If an entry has no ``"answer"`` key.
    """
    # Materialise once so one-shot iterables are not exhausted before the
    # fallback to "all answers" is needed.
    all_answers = [entry["answer"] for entry in answer_list]
    if not all_answers:
        raise ValueError("answer_list must contain at least one answer entry")

    single_word_answers = [answer for answer in all_answers if len(answer.split()) == 1]
    answer_counts = Counter(single_word_answers or all_answers)

    # max() returns the first maximal key in insertion order, i.e. the
    # earliest-seen answer among those with the highest count.
    return max(answer_counts, key=answer_counts.get)

def load_image(image_path):
    """Read an image file and return it as a channels-first 224x224 RGB array.

    Args:
        image_path: Path of the image file to read.

    Returns:
        Array of shape ``(3, 224, 224)`` (channels, height, width). Pixel
        values are left exactly as decoded - no scaling or normalisation is
        applied here.

    Raises:
        OSError: If OpenCV cannot read the file (missing, unreadable or not a
            decodable image). ``cv2.imread`` signals this by returning
            ``None`` instead of raising, which would otherwise surface as an
            opaque error in the colour conversion below.
    """
    image = cv2.imread(image_path)
    if image is None:
        raise OSError(f"Could not read image (missing or undecodable file): {image_path!r}")

    # OpenCV decodes to BGR channel order; convert to RGB.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = cv2.resize(image, (224, 224))
    # HxWxC -> CxHxW
    return np.transpose(image, (2, 0, 1))

In [None]:
# Read the sample table; each row is unpacked below as
# (image_id, question, answer).
df = pd.read_csv("/kaggle/input/vr-vqa-final-public/Final.csv")
image_path = "/kaggle/input/vr-vqa-final-public/Subset_train2014/Subset_train2014/"

dummy_data = [tuple(row) for row in df.values]

# Several rows can reference the same image, so decode each image only once.
unique_ids = {image_id for image_id, _, _ in dummy_data}

images_dict = {image_id : load_image(getImageName(image_path, image_id)) for image_id in unique_ids}

# Sanity check on one loaded image. Use whichever id is available instead of
# a hard-coded one, which raises KeyError when that id is not in the subset.
if images_dict:
    sample_image_id = next(iter(images_dict))
    print(images_dict[sample_image_id].shape)

In [None]:

# Every distinct answer in the table is one target class.
answer_vocabulary = set([answer for _, _, answer in dummy_data])

# Number the classes in sorted order. Enumerating the set directly yields an
# order that can change between kernel sessions (string hashing is randomised
# per process), so the same answer could get a different label on every run.
# key=str keeps the sort well-defined even if the column holds non-string values.
# NOTE(review): a loaded checkpoint is only meaningful if this mapping matches
# the one used when it was trained - confirm against the training notebook.
answer_to_label = {answer: label for label, answer in enumerate(sorted(answer_vocabulary, key=str))}

# Replace image ids by the decoded image and answers by their class label.
# Every answer is in the vocabulary by construction, so index directly
# (no silent -1 label, which the loss could not handle anyway).
dummy_data_new = [(images_dict[image_id], question, answer_to_label[answer]) for (image_id, question, answer) in dummy_data]

# NOTE: this rebinding makes the cell non-idempotent - re-running it without
# re-running the loading cell fails, because dummy_data no longer holds ids.
dummy_data = dummy_data_new

In [None]:
# Hold out 20% of the samples for validation. The split is seeded so the same
# samples land in the validation set on every run.
# NOTE(review): if the loaded checkpoint was trained on a different split,
# validation samples may overlap its training data - confirm.
train_data, val_data = train_test_split(dummy_data, test_size=0.2, random_state=42)

# Define data loaders
train_loader = DataLoader(MyDataset(train_data), batch_size=32, shuffle=True)
# Shuffling is kept for validation: the evaluation loop only consumes the
# first few batches of each pass, so each pass scores a different random subset.
val_loader = DataLoader(MyDataset(val_data), batch_size=32, shuffle=True)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

num_classes = len(answer_vocabulary)

In [None]:
model = MyModel(num_classes=num_classes, dim=768)  # Assuming dim=768 for ViT and BERT

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Load the state dictionary.
# map_location='cpu': the model and all inputs in this notebook stay on the
# CPU, and a checkpoint holding CUDA tensors cannot be deserialised on a
# CPU-only kernel without it.
# NOTE(review): torch.load unpickles the file - only load checkpoints from a
# trusted source (consider weights_only=True if the installed torch supports it).
model_path = '/kaggle/input/cotrm-32-30/pytorch/1/1/cotrm-32-30-novitbert.pth'
state_dict = torch.load(model_path, map_location='cpu')
model.load_state_dict(state_dict)

model.eval()


In [None]:
# Evaluate the loaded checkpoint on the validation data (no training happens
# in this cell). Each pass scores at most `num_batches_per_epoch` batches of
# the validation loader and prints its metrics.
num_epochs = 10
num_batches_per_epoch = 30

for epoch in range(num_epochs):
    model.eval()
    val_loss = 0.0
    val_correct = 0
    val_samples = 0
    val_batches = 0  # batches actually scored in this pass

    val_predicted = []
    val_answers = []

    with torch.no_grad():
        for batch_idx, (images, questions, answers) in enumerate(val_loader):

            if batch_idx >= num_batches_per_epoch:
                break

            # Image features: hidden state of the first token of ViT's output.
            # NOTE(review): images arrive as raw 0-255 pixel values (no
            # normalisation is applied upstream) - confirm this matches how
            # the checkpoint was trained.
            img_output = vit_model(pixel_values=images).last_hidden_state[:, 0, :]

            # Tokenize questions
            question_tokens = tokenizer(questions, return_tensors='pt', padding=True, truncation=True)
            ques_ids = question_tokens['input_ids']

            # Question features: BERT's pooled output (second element of the result).
            # NOTE(review): the tokenizer's attention_mask is not passed, so
            # padding tokens are attended to - confirm this matches training
            # before changing it.
            ques_output = bert_model(ques_ids)[1]

            # Forward pass
            output = model(img_output, ques_output)

            # Compute loss
            loss = criterion(output, answers)

            # Update metrics
            val_loss += loss.item()
            val_batches += 1
            _, predicted = torch.max(output, 1)
            val_correct += (predicted == answers).sum().item()
            val_samples += images.size(0)

            val_predicted.extend(predicted.cpu().numpy())
            val_answers.extend(answers.cpu().numpy())

    if val_batches == 0:
        # Nothing was scored; avoid dividing by zero below.
        print("Validation loader produced no batches; nothing to evaluate.")
        break

    # Calculate precision, recall, and F1 score for validation.
    # zero_division=0 gives classes that were never predicted a score of 0
    # without emitting a warning for each of them.
    val_precision = precision_score(val_answers, val_predicted, average='weighted', zero_division=0)
    val_recall = recall_score(val_answers, val_predicted, average='weighted', zero_division=0)
    val_f1 = f1_score(val_answers, val_predicted, average='weighted', zero_division=0)

    # Average over the batches actually scored: the loop stops early, so
    # dividing by len(val_loader) would understate the loss.
    val_avg_loss = val_loss / val_batches
    val_accuracy = (val_correct / val_samples) * 100

    # Print the metrics of this validation pass
    print(f"Validation Loss: {val_avg_loss}, Validation Accuracy: {val_accuracy}%, "
          f"Validation Precision: {val_precision}, Validation Recall: {val_recall}, Validation F1: {val_f1}")