In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader,Dataset
from transformers import ViTModel, BertModel, BertTokenizer
from sklearn.model_selection import train_test_split

import json
import pandas as pd
from collections import Counter
import cv2
import gc

In [2]:
class CoTransformer(nn.Module):
    """Cross-attention block: ``query`` attends over ``key``/``value``, then residual + LayerNorm.

    The inner ``nn.MultiheadAttention`` is built without ``batch_first``, so inputs use its
    default ``(seq_len, batch, dim)`` layout. The output has the same shape as ``query``.

    Args:
        dim: embedding width of query, key and value; must be divisible by ``num_heads``.
        num_heads: number of attention heads (defaults to 8).
    """

    def __init__(self, dim, num_heads=8):
        super().__init__()
        self.attention = nn.MultiheadAttention(dim, num_heads=num_heads)
        self.linear = nn.Linear(dim, dim)
        self.layer_norm = nn.LayerNorm(dim)

    def forward(self, query, key, value):
        """Attend from ``query`` to ``key``/``value`` and return the normalised residual output."""
        attn_output, _ = self.attention(query, key, value)
        # Residual connection back to the query stream, then layer normalisation.
        return self.layer_norm(self.linear(attn_output) + query)

In [3]:
vit_model = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
bert_model = BertModel.from_pretrained('bert-base-uncased')

# The pretrained backbones act as fixed feature extractors: only MyModel's parameters are
# handed to the optimizer. Freezing them stops autograd from keeping their activations alive
# and from accumulating gradients in their .grad buffers that are never applied or zeroed.
for backbone in (vit_model, bert_model):
    backbone.eval()
    backbone.requires_grad_(False)

In [4]:
class MyModel(nn.Module):
    """Two-stream co-attention answer classifier over image and question token features.

    Args:
        num_classes: number of answer classes (width of the output logits).
        dim: feature width shared by both token streams and by the ``CoTransformer`` blocks.
    """

    def __init__(self, num_classes, dim):
        super().__init__()
        # Question tokens attend to image tokens, and image tokens attend to question tokens.
        self.co_transformer_vit_to_bert = CoTransformer(dim=dim)
        self.co_transformer_bert_to_vit = CoTransformer(dim=dim)

        self.layer1 = nn.Sequential(
            nn.Linear(in_features=dim * 2, out_features=512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Dropout(0.3),
        )

        self.layer2 = nn.Sequential(
            nn.Linear(in_features=512, out_features=256),
            nn.ReLU(),
            # Must match the 256 features produced by the Linear layer directly above.
            nn.BatchNorm1d(256),
            nn.Dropout(0.3),
        )

        self.output_layer = nn.Linear(in_features=256, out_features=num_classes)

    def forward(self, img_query, img_key, img_value, ques_query, ques_key, ques_value):
        """Return ``(batch, num_classes)`` logits.

        All six inputs use CoTransformer's seq-first ``(seq_len, batch, dim)`` layout; image and
        question sequence lengths may differ.
        """
        joint_repr_vit_to_bert = self.co_transformer_vit_to_bert(ques_query, img_key, img_value)
        joint_repr_bert_to_vit = self.co_transformer_bert_to_vit(img_query, ques_key, ques_value)
        # Each co-attention output keeps its query's sequence length (question tokens vs. image
        # patches), so the streams cannot be concatenated token-wise. Mean-pool each over the
        # sequence axis (dim 0 in the seq-first layout) to get (batch, dim) vectors; this also
        # gives BatchNorm1d the 2-D (batch, features) input it needs.
        # NOTE(review): padded question positions are included in the mean.
        pooled_ques = joint_repr_vit_to_bert.mean(dim=0)
        pooled_img = joint_repr_bert_to_vit.mean(dim=0)
        combined_repr = torch.cat((pooled_ques, pooled_img), dim=-1)
        x = self.layer1(combined_repr)
        x = self.layer2(x)
        return self.output_layer(x)

In [5]:
class MyDataset(Dataset):
    """Thin map-style ``Dataset`` over a sequence of ``(image, question, answer)`` records."""

    def __init__(self, data):
        # The sequence is stored by reference; nothing is copied.
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        record = self.data[index]
        # Unpacking enforces exactly three fields per record and always yields a tuple.
        image, question, answer = record
        return (image, question, answer)

In [23]:
def getImageName(image_path, image_id):
    """Build the COCO train2014 file path for ``image_id`` under ``image_path``.

    The id is left-padded with zeros to 12 characters (longer ids are used unchanged).
    ``image_path`` is prefixed verbatim, so it should already end with a path separator.
    """
    padded_id = str(image_id).rjust(12, "0")
    return image_path + "COCO_train2014_" + padded_id + ".jpg"


In [7]:
def filterMajoritySingleWord(answer_list):
    """Return the most frequent answer, preferring single-word answers.

    Args:
        answer_list: iterable of dicts, each with an ``"answer"`` string.

    Returns:
        The most common single-word answer. If no answer is a single word, the most common
        answer overall. Ties go to the answer encountered first. Returns ``""`` when
        ``answer_list`` is empty instead of raising.
    """
    # Materialise once so generators are not exhausted by a second pass.
    answers = [entry["answer"] for entry in answer_list]
    single_word_answers = [answer for answer in answers if len(answer.split()) == 1]
    candidates = single_word_answers or answers
    if not candidates:
        return ""
    # most_common orders equal counts by first occurrence, so ties resolve to the earliest answer.
    return Counter(candidates).most_common(1)[0][0]

def load_image(image_path, size=(224, 224)):
    """Read an image from disk as an RGB array resized to ``size``.

    Args:
        image_path: path of the image file.
        size: target ``(width, height)`` passed to ``cv2.resize``; defaults to 224x224.

    Returns:
        The resized image with channels in RGB order.

    Raises:
        FileNotFoundError: if OpenCV cannot read the file.
    """
    image = cv2.imread(image_path)
    # cv2.imread signals failure by returning None rather than raising. Fail loudly with the
    # path instead of letting cv2.resize error out on an empty input.
    if image is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    # OpenCV decodes to BGR; the pretrained ViT backbone was trained on RGB images.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    return cv2.resize(image, size)

In [8]:
# image_path = "./Subset_train2014"

# with open('output_image_questions.json') as questions_json:
#     images_questions = json.load(questions_json)

# with open('output_questions_answers.json') as questions_json:
#     questions_answers = json.load(questions_json)

# # print(list(images_questions.keys())[:5])
# # print(list(questions_answers.values())[:5])

# # print(filterMajoritySingleWord(questions_answers['36000']))

# l = [(image_id, question['question'], str(question['question_id'])) for image_id, questions in images_questions.items() for question in questions]

# dummy_data = [(tup[0], tup[1], filterMajoritySingleWord(questions_answers[tup[2]])) for tup in l]
# dummy_data = [data for data in dummy_data if data[2] != ""]

# df = pd.DataFrame(dummy_data, columns=['image', 'question_id', 'answer'])
# df.to_csv("Final.csv", index=False)

# del l
# gc.collect()

In [25]:
# keep_default_na=False: answers/questions are free text. Default NA parsing would turn answers
# such as "NA", "null" or "None" into float NaN, which also breaks the sorted() call below.
df = pd.read_csv('Final.csv', keep_default_na=False)
image_path = "./Subset_train2014/"

qa_records = [tuple(row) for row in df.values]
print(qa_records[:5])

answer_vocabulary = set(answer for _, _, answer in qa_records)

# Sort before enumerating. The iteration order of a set of strings depends on the per-process
# hash seed, so enumerating the raw set would assign different label ids on every kernel start.
answer_to_label = {answer: label for label, answer in enumerate(sorted(answer_vocabulary))}

# Every answer is in the vocabulary by construction, so index the mapping directly.
dummy_data = [
    (load_image(getImageName(image_path, image_id)), question, answer_to_label[answer])
    for image_id, question, answer in qa_records
]

# Compact preview: printing the raw pixel arrays would flood the notebook output.
print([(image.shape, question, label) for image, question, label in dummy_data[:5]])

[(36, 'Is she wearing a bathing suit?', 'yes'), (36, 'What color is the umbrella?', 'pink'), (36, 'Why is the girl holding an umbrella?', 'sun'), (64, 'Who made the cock?', 'rolex'), (64, 'Are there numbers on the clock face?', 'no')]


: 

In [10]:
# Split dataset into training and validation sets. A fixed seed keeps the split reproducible
# across kernel restarts.
SPLIT_SEED = 42
train_data, val_data = train_test_split(dummy_data, test_size=0.2, random_state=SPLIT_SEED)

# Define data loaders. drop_last on the training loader: MyModel uses BatchNorm1d, which raises
# in train mode when the trailing batch holds a single sample.
train_loader = DataLoader(MyDataset(train_data), batch_size=32, shuffle=True, drop_last=True)
val_loader = DataLoader(MyDataset(val_data), batch_size=32, shuffle=False)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [11]:
# One output logit per distinct answer label (the mapping has one entry per vocabulary answer).
num_classes = len(answer_to_label)

In [12]:
# Co-attention width; assumed to equal the hidden size of both the ViT and BERT backbones.
FEATURE_DIM = 768
LEARNING_RATE = 1e-4

model = MyModel(num_classes=num_classes, dim=FEATURE_DIM)

# Cross-entropy over answer-class logits. Adam optimizes MyModel's parameters only; the
# pretrained backbones are not passed to it.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [13]:
def _preprocess_images(images):
    """Convert a collated uint8 image batch into the float ``pixel_values`` ViTModel expects.

    The default DataLoader collate stacks the (224, 224, channels) arrays from ``load_image``
    into a (batch, 224, 224, channels) uint8 tensor. ViT wants channels-first floats.
    """
    pixels = images.permute(0, 3, 1, 2).float() / 255.0
    # NOTE(review): assumes the checkpoint's image-processor normalisation (mean=std=0.5 per
    # channel); confirm against ViTImageProcessor for 'google/vit-base-patch16-224-in21k'.
    return (pixels - 0.5) / 0.5


def _encode_batch(images, questions):
    """Run the backbones and return seq-first (q, k, v) tensors for the image and question streams."""
    # The backbones are never optimized, so don't build an autograd graph through them.
    with torch.no_grad():
        img_hidden = vit_model(pixel_values=_preprocess_images(images)).last_hidden_state
        question_tokens = tokenizer(list(questions), return_tensors='pt', padding=True, truncation=True)
        ques_hidden = bert_model(**question_tokens).last_hidden_state
    # Hugging Face models return (batch, seq, dim), while CoTransformer's nn.MultiheadAttention
    # is seq-first by default, so transpose to (seq, batch, dim). Each modality's hidden state
    # serves as its own query, key and value.
    img_repr = img_hidden.transpose(0, 1)
    ques_repr = ques_hidden.transpose(0, 1)
    return img_repr, img_repr, img_repr, ques_repr, ques_repr, ques_repr


num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0
    for images, questions, answers in train_loader:
        optimizer.zero_grad()
        output = model(*_encode_batch(images, questions))
        loss = criterion(output, answers)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    # Validate the model
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, questions, answers in val_loader:
            output = model(*_encode_batch(images, questions))
            val_loss += criterion(output, answers).item()
            predicted = output.argmax(dim=1)
            total += answers.size(0)
            correct += (predicted == answers).sum().item()

    # max(..., 1) and the total check guard against empty loaders on tiny datasets.
    avg_train_loss = train_loss / max(len(train_loader), 1)
    avg_val_loss = val_loss / max(len(val_loader), 1)
    accuracy = 100.0 * correct / total if total else float('nan')
    print(f"Epoch {epoch+1}/{num_epochs}, Train loss: {avg_train_loss:.4f}, "
          f"Val loss: {avg_val_loss:.4f}, Val accuracy: {accuracy:.2f}%")