In [1]:
import torch
import torch.nn as nn
import os
import numpy as np
import pandas as pd
import timm
import matplotlib.pyplot as plt

from PIL import Image
from torch.utils.data import Dataset, DataLoader
from transformers import ViTModel, ViTImageProcessor
from transformers import AutoTokenizer, RobertaModel
import torch.nn.functional as F


  from .autonotebook import tqdm as notebook_tqdm
2024-01-07 01:52:37.789590: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-01-07 01:52:37.789645: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-01-07 01:52:37.790894: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-01-07 01:52:37.799151: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler fl

In [2]:
# Load train data
train_data = []
train_set_path = 'data/vaq2.0.TrainImages.txt'

with open(train_set_path, "r") as f:
    lines = f.readlines()
    for line in lines:
        temp = line.split('\t')
        qa = temp[1].split('?')

        if len(qa) == 3:
            answer = qa[2].strip()
        else:
            answer = qa[1].strip()

        data_sample = {
            'image_path': temp[0][:-2],
            'question': qa[0] + '?',
            'answer': answer
        }
        train_data.append(data_sample)

# Load val data
val_data = []
val_set_path = 'data/vaq2.0.DevImages.txt'

with open(val_set_path, "r") as f:
    lines = f.readlines()
    for line in lines:
        temp = line.split('\t')
        qa = temp[1].split('?')

        if len(qa) == 3:
            answer = qa[2].strip()
        else:
            answer = qa[1].strip()

        data_sample = {
            'image_path': temp[0][:-2],
            'question': qa[0] + '?',
            'answer': answer
        }
        val_data.append(data_sample)

# Load test data
test_data = []
test_set_path = 'data/vaq2.0.TestImages.txt'

with open(test_set_path, "r") as f:
    lines = f.readlines()
    for line in lines:
        temp = line.split('\t')
        qa = temp[1].split('?')

        if len(qa) == 3:
            answer = qa[2].strip()
        else:
            answer = qa[1].strip()

        data_sample = {
            'image_path': temp[0][:-2],
            'question': qa[0] + '?',
            'answer': answer
        }
        test_data.append(data_sample)


In [3]:
classes = set([sample['answer'] for sample in train_data])

classes_to_idx = {cls_name: idx for idx, cls_name in enumerate(classes)}

idx_to_classes = {idx: cls_name for idx, cls_name in enumerate(classes)}


In [4]:
class VQADataset(Dataset):
    def __init__(
        self,
        data,
        classes_to_idx,
        img_feature_extractor,
        text_tokenizer,
        device,
        root_dir='/space/hotel/bachn/VQA/data/val2014-resised'
    ):
        self.data = data
        self.root_dir = root_dir
        self.classes_to_idx = classes_to_idx
        self.img_feature_extractor = img_feature_extractor
        self.text_tokenizer = text_tokenizer
        self.device = device

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        img_path = os.path.join(self.root_dir, self.data[index]['image_path'])
        img = Image.open(img_path).convert('RGB')

        if self.img_feature_extractor:
            img = self.img_feature_extractor(images=img, return_tensors="pt")
            img = {k: v.to(self.device).squeeze(0) for k, v in img.items()}

        question = self.data[index]['question']
        if self.text_tokenizer:
            question = self.text_tokenizer(
                question,
                padding="max_length",
                max_length=20,
                truncation=True,
                return_tensors="pt"
            )
            question = {k: v.to(self.device).squeeze(0) for k, v in question.items()}

        label = self.data[index]['answer']
        label = torch.tensor(
            classes_to_idx[label],
            dtype=torch.long
        ).to(self.device)

        sample = {
            'image': img,
            'question': question,
            'label': label
        }

        return sample


In [5]:
img_feature_extractor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
text_tokenizer = AutoTokenizer.from_pretrained("roberta-base")
device = 'cuda' if torch.cuda.is_available() else 'cpu'

train_dataset = VQADataset(
    train_data,
    classes_to_idx=classes_to_idx,
    img_feature_extractor=img_feature_extractor,
    text_tokenizer=text_tokenizer,
    device=device
)

val_dataset = VQADataset(
    val_data,
    classes_to_idx=classes_to_idx,
    img_feature_extractor=img_feature_extractor,
    text_tokenizer=text_tokenizer,
    device=device
)

test_dataset = VQADataset(
    test_data,
    classes_to_idx=classes_to_idx,
    img_feature_extractor=img_feature_extractor,
    text_tokenizer=text_tokenizer,
    device=device
)
train_batch_size = 128
test_batch_size = 32

train_loader = DataLoader(
    train_dataset,
    batch_size=train_batch_size,
    shuffle=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=test_batch_size,
    shuffle=False
)

test_loader = DataLoader(
    test_dataset,
    batch_size=test_batch_size,
    shuffle=False
)

In [6]:
# for k in train_loader:
#     print(k)

In [7]:
# Contrastive Loss Function (e.g., NT-Xent)
class NTXentLoss(nn.Module):
    def __init__(self, temperature, device):
        super(NTXentLoss, self).__init__()
        self.temperature = temperature
        self.device = device
        self.criterion = nn.CrossEntropyLoss().to(device)

    def forward(self, z_i, z_j):
        N, Z = z_i.shape  # batch size and feature dimension
        z = torch.cat((z_i, z_j), dim=0)
        sim = torch.mm(z, z.T) / self.temperature  # cosine similarity
        sim_i_j = torch.diag(sim, N)
        sim_j_i = torch.diag(sim, -N)
        positives = torch.cat((sim_i_j, sim_j_i), dim=0).view(2 * N, 1)
        negatives = sim[~torch.eye(2 * N, dtype=bool, device=self.device)].view(2 * N, -1)

        labels = torch.zeros(2 * N).to(self.device).long()
        logits = torch.cat((positives, negatives), dim=1)
        loss = self.criterion(logits, labels)
        return loss


In [8]:
class TextEncoder(nn.Module):
    def __init__(self):
        super(TextEncoder, self).__init__()
        self.model = RobertaModel.from_pretrained("roberta-base")

    def forward(self, inputs):
        outputs = self.model(**inputs)
        return outputs.pooler_output
    
    
class VisualEncoder(nn.Module):
    def __init__(self):
        super(VisualEncoder, self).__init__()
        self.model = ViTModel.from_pretrained("google/vit-base-patch16-224")

    def forward(self, inputs):
        outputs = self.model(**inputs)
        return outputs.pooler_output
    
    
class Classifier(nn.Module):
    def __init__(
        self,
        input_size=768*2,
        hidden_size=512,
        n_layers=1,
        dropout_prob=0.2,
        n_classes=2
    ):
        super(Classifier, self).__init__()
        self.lstm = nn.LSTM(
            input_size,
            hidden_size,
            num_layers=n_layers,
            batch_first=True,
            bidirectional=True
        )
        self.dropout = nn.Dropout(dropout_prob)
        self.fc1 = nn.Linear(hidden_size*2, n_classes)

    def forward(self, x):
        x, _ = self.lstm(x)
        x = self.dropout(x)
        x = self.fc1(x)
        return x

In [9]:
class VQAModel(nn.Module):
    def __init__(
        self,
        visual_encoder,
        text_encoder,
        classifier,
        temperature=0.07  # Temperature parameter for contrastive loss
    ):
        super(VQAModel, self).__init__()
        self.visual_encoder = visual_encoder
        self.text_encoder = text_encoder
        self.classifier = classifier
        self.temperature = temperature

    def forward(self, image, question, return_embedding=False):
        text_out = self.text_encoder(question)
        image_out = self.visual_encoder(image)
        if return_embedding:
            return text_out, image_out
        combined = torch.cat((text_out, image_out), dim=1)
        output = self.classifier(combined)
        return output

    def freeze(self, visual=True, textual=True, clas=False):
        if visual:
            for n, p in self.visual_encoder.named_parameters():
                p.requires_grad = False
        if textual:
            for n, p in self.text_encoder.named_parameters():
                p.requires_grad = False
        if clas:
            for n, p in self.classifier.named_parameters():
                p.requires_grad = False


In [10]:
n_classes = len(classes)
hidden_size = 1024
n_layers = 1
dropout_prob = 0.2

text_encoder = TextEncoder().to(device)
visual_encoder = VisualEncoder().to(device)
classifier = Classifier(
    hidden_size=hidden_size,
    n_layers=n_layers,
    dropout_prob=dropout_prob,
    n_classes=n_classes
).to(device)

model = VQAModel(
    visual_encoder=visual_encoder,
    text_encoder=text_encoder,
    classifier=classifier
).to(device)
model.freeze()


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.weight', 'vit.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [11]:
# Training function with contrastive learning
def fit_with_contrastive_learning(model, train_loader, val_loader, criterion, optimizer, scheduler, epochs, contrastive_loss_fn):
    for epoch in range(epochs):
        model.train()
        total_loss, total_contrastive_loss, total_vqa_loss = 0, 0, 0

        for batch in train_loader:
            optimizer.zero_grad()

            images = batch['image']
            questions = batch['question']
            labels = batch['label']

            # Forward pass for VQA task
            outputs = model(images, questions)
            vqa_loss = criterion(outputs, labels)

            # Forward pass for contrastive learning
            text_embedding, image_embedding = model(images, questions, return_embedding=True)
            contrastive_loss = contrastive_loss_fn(text_embedding, image_embedding)

            # Combine losses and backpropagate
            loss = vqa_loss + contrastive_loss
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            total_contrastive_loss += contrastive_loss.item()
            total_vqa_loss += vqa_loss.item()

        scheduler.step()
        avg_loss = total_loss / len(train_loader)
        avg_contrastive_loss = total_contrastive_loss / len(train_loader)
        avg_vqa_loss = total_vqa_loss / len(train_loader)

        print(f"Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}, VQA Loss: {avg_vqa_loss:.4f}, Contrastive Loss: {avg_contrastive_loss:.4f}")

        # Evaluate on validation set
        val_loss, val_acc = evaluate(model, val_loader, criterion)
        print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc:.4f}")

# Evaluation function
def evaluate(model, dataloader, criterion):
    model.eval()
    total_loss, total_correct, total_samples = 0, 0, 0

    with torch.no_grad():
        for batch in dataloader:
            images = batch['image']
            questions = batch['question']
            labels = batch['label']

            outputs = model(images, questions)
            loss = criterion(outputs, labels)

            total_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total_correct += (predicted == labels).sum().item()
            total_samples += labels.size(0)

    avg_loss = total_loss / len(dataloader)
    accuracy = total_correct / total_samples
    return avg_loss, accuracy

# Instantiate the model, loss functions, optimizer, and scheduler
contrastive_loss_fn = NTXentLoss(temperature=0.5, device=device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

# Train the model
fit_with_contrastive_learning(model, train_loader, val_loader, criterion, optimizer, scheduler, epochs=50, contrastive_loss_fn=contrastive_loss_fn)

# Evaluate the model
val_loss, val_acc = evaluate(model, val_loader, criterion)
print(f'Validation Loss: {val_loss}, Validation Accuracy: {val_acc}')

Epoch [1/50], Loss: 149.0297, VQA Loss: 1.2482, Contrastive Loss: 147.7815
Validation Loss: 0.6972, Validation Accuracy: 0.4621
Epoch [2/50], Loss: 147.7884, VQA Loss: 0.7416, Contrastive Loss: 147.0468
Validation Loss: 0.6983, Validation Accuracy: 0.5359
Epoch [3/50], Loss: 147.8898, VQA Loss: 0.7478, Contrastive Loss: 147.1420
Validation Loss: 0.7987, Validation Accuracy: 0.5359
Epoch [4/50], Loss: 147.9100, VQA Loss: 0.7639, Contrastive Loss: 147.1460
Validation Loss: 0.7976, Validation Accuracy: 0.4641
Epoch [5/50], Loss: 147.7115, VQA Loss: 0.7486, Contrastive Loss: 146.9629
Validation Loss: 0.8079, Validation Accuracy: 0.4641
Epoch [6/50], Loss: 148.1384, VQA Loss: 0.7380, Contrastive Loss: 147.4004
Validation Loss: 0.6898, Validation Accuracy: 0.5364
Epoch [7/50], Loss: 148.0665, VQA Loss: 0.7341, Contrastive Loss: 147.3323
Validation Loss: 0.7016, Validation Accuracy: 0.5369
