# Training Coattention Model

In [1]:
from transformers import BertModel, BertTokenizer, ViTImageProcessor, ViTModel
import torch
from torchinfo import summary
from torch import nn
from torch.nn import Transformer, TransformerDecoder, TransformerDecoderLayer, TransformerEncoder, TransformerEncoderLayer

2024-05-18 08:22:21.398514: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-05-18 08:22:21.398567: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-05-18 08:22:21.399973: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device = torch.device('cpu')

## Initialize Necessary Modules

In [3]:
device

device(type='cuda')

In [4]:
class TextTokenizer(torch.nn.Module):
    def __init__(
        self,
        text_tokenizer=BertTokenizer,
        max_length=25  # Add a max_length parameter
    ):
        super().__init__()
        self.text_tokenizer = text_tokenizer.from_pretrained('bert-base-uncased')
        self.max_length = max_length  # Store the max_length

    def forward(self, input_question, padding='max_length', truncation=True):
        tokens = self.text_tokenizer(input_question, return_tensors='pt', 
                                     padding=padding, truncation=truncation, 
                                     max_length=self.max_length).to(device)  # Use max_length

        return tokens

class ImageProcessor(torch.nn.Module):
    def __init__(
        self,
        image_model_processor=ViTImageProcessor
    ):

        super().__init__()
        self.image_model_processor = image_model_processor.from_pretrained('google/vit-base-patch16-224-in21k')

    def forward(self, image):
        image = self.image_model_processor(image, return_tensors='pt').to(device)

        return image

class TextEmbedding(torch.nn.Module):
    def __init__(
        self,
        text_model=BertModel,
    ):
        super().__init__()
        self.text_model = text_model.from_pretrained('bert-base-uncased').to(device)


    def forward(self, tokens):
        text_output = self.text_model(input_ids=tokens.input_ids, attention_mask=tokens.attention_mask)
        text_output = text_output.last_hidden_state     # CLS token from the last layer

        return text_output


class ImageEmbedding(torch.nn.Module):
    def __init__(
            self, 
            image_model=ViTModel
        ):
        
        super().__init__()
        self.image_model = image_model.from_pretrained('google/vit-base-patch16-224-in21k').to(device)


    def forward(self, image):
        image_output = self.image_model(pixel_values=image.pixel_values).last_hidden_state

        return image_output

In [5]:
from notebook.services.config import ConfigManager
cm = ConfigManager()
cm.update('notebook', {
    'NotebookApp': {
        'iopub_msg_rate_limit': 10000,
        'rate_limit_window': 10.0
    }
})

{'NotebookApp': {'iopub_msg_rate_limit': 10000, 'rate_limit_window': 10.0}}

In [6]:
from datasets import load_dataset

dataset = load_dataset("HuggingFaceM4/VQAv2")

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.
Repo card metadata block was not found. Setting CardData to empty.


In [7]:
train_dataset = dataset['train']
# test_dataset = dataset['test']
# val_dataset = dataset['validation']

## Load Dataset

In [8]:
import pandas as pd

train_df = pd.read_csv('/kaggle/input/vqadataset/vqa_train_dataset.csv')
val_df = pd.read_csv('/kaggle/input/vqadataset/vqa_val_dataset.csv')

train_df = train_df[~train_df['answers'].isna()]
val_df = val_df[~val_df['answers'].isna()]

In [9]:
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor

class VQADataset(Dataset):
    def __init__(self, dataframe, image_dataset):
        self.dataframe = dataframe
        self.image_dataset = image_dataset
        self.text_tokenizer = TextTokenizer()
        self.image_processor = ImageProcessor()

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]
        
        ind = int(row['index'])
        image = self.image_dataset[ind]['image']
        question = row['question']
        answer = row['answers']
        
        # sanity check        
        assert self.image_dataset[ind]['question'] == question, "Mismatching training and Image data"

        
        image = image.convert('RGB')

        tokens = self.text_tokenizer(question, padding='max_length', truncation=True)
        tokens.input_ids = tokens.input_ids.squeeze()
        tokens.attention_mask = tokens.attention_mask.squeeze()
        image = self.image_processor(image)
        return {
            'image': image,
            'questions': tokens,
            'answer': answer
        }

batch_size = 64
    
train_data = VQADataset(train_df, train_dataset)
val_data = VQADataset(val_df, train_dataset)

train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_data, batch_size=batch_size)

In [10]:
import pickle

with open('/kaggle/input/vqadataset/answers_dictionaries.pkl', 'rb') as f:
    data = pickle.load(f)
    id_to_answer = data['id_to_answer']
    answer_to_id = data['answer_to_id']

# print("Dictionaries have been loaded from answers_dictionaries.pkl")
# print("ID to Answer Dictionary:", id_to_answer)
# print("Answer dede bhai: ", answer_to_id)

## Model Creation

In [11]:
class VQAModel(nn.Module):
    def __init__(
        self,
        dim_model = 768,      # image and text embeddings concatenated
        nhead = 12,                    # No. of Attention heads
        num_layers = 1,               # No. of encoder layers
        num_classes = 8000
    ):
        super().__init__()
        self.text_embedder = TextEmbedding()
        self.image_embedder = ImageEmbedding()

        self.mlp = nn.Sequential(
            nn.Linear(2*dim_model, dim_model),
            nn.ReLU(),
            nn.Dropout(0.4),
        ).to(device)

        self.classifier = nn.Linear(dim_model, num_classes).to(device)
        self.softmax = nn.Softmax(dim=1)


    def forward(self, questions, images):
        question_embedding = self.text_embedder(questions)
        image_embedding = self.image_embedder(images)
        question_embedding = question_embedding[:, 0, :]
        image_embedding = image_embedding[:, 0, :]

        # Concatenate embeddings
        embeddings = torch.cat((question_embedding, image_embedding), dim=1)  # (batch_size, 2*dim_model)

        output = self.mlp(embeddings)  # (batch_size, dim_model)
 
        logits = self.classifier(output)  # (batch_size, num_classes)

        output = self.softmax(logits)  # (batch_size, num_classes)
        


        return output


In [13]:
model = VQAModel(num_classes=len(answer_to_id))

In [14]:
# print([(n, type(m)) for n, m in model().named_modules()])
# !pip install peft

In [15]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.003)

In [16]:
from torch.utils.tensorboard import SummaryWriter

# Create a SummaryWriter object
writer = SummaryWriter('runs/experiment_mlp')
num_epochs = 3
model.train()

VQAModel(
  (text_embedder): TextEmbedding(
    (text_model): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(30522, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0-11): 12 x BertLayer(
            (attention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=768, out_features=768, bias=Tr

In [17]:
checkpoint_ref = 100

## Evaluation Metrics

In [18]:
from torchmetrics.classification import Precision, Recall, Accuracy, F1Score, AUROC

precision_metric = Precision(task="multiclass", num_classes=len(answer_to_id)).to(device)
recall_metric = Recall(task="multiclass", num_classes=len(answer_to_id)).to(device)
accuracy_metric = Accuracy(task="multiclass", num_classes=len(answer_to_id)).to(device)
f1_metric = F1Score(task="multiclass", num_classes=len(answer_to_id)).to(device)
# auroc_metric = AUROC(task="multiclass", num_classes=len(answer_to_id)).to(device)

def evaluate(preds, true):
    p = precision_metric(preds, true)
    r = recall_metric(preds, true)
    a = accuracy_metric(preds, true)
    f = f1_metric(preds, true)
#     am = auroc_metric(preds, true)
    
    return {
        "precision": p,
        "recall": r,
        "accuracy": a,
        "f1": f,
#         "auroc": auroc_metric
    }

## Training Loop

In [19]:
# !mkdir ./checkpoints

In [20]:
import torch.nn.functional as F

batch_no = 0
avg_accuracy = 0

for epoch in range(num_epochs):
    batch_no = 0
    avg_accuracy = 0
    for batch in train_dataloader:
        # Get the inputs and targets from the batch
        images = batch['image']
        questions = batch['questions']
        answers = batch['answer']

        questions.input_ids = questions.input_ids.squeeze()
        questions.attention_mask = questions.attention_mask.squeeze()
        images.pixel_values = images.pixel_values.squeeze()
        
        # Forward pass
        outputs = model(questions, images)

        
        answers_ids = [answer_to_id[key] for key in answers]
        answers_ids_tensor = torch.tensor(answers_ids)
        answers = F.one_hot(answers_ids_tensor, num_classes=len(answer_to_id)).to(outputs.dtype)
        answers = answers.to(device)
    
        
        # Sanity check
#         print(answers.shape, outputs.shape)
        assert answers.shape == outputs.shape, "Target and Predicted shapes don't match"
    
        
        # Compute the loss
        loss = criterion(outputs, answers)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        
        _, preds = torch.max(outputs, 1)
        _, answer_inds = torch.max(answers, 1)
        
        assert preds.shape == answer_inds.shape, f"Preds_shape: {preds.shape}, Answers_shape: {answer_inds.shape}"
        eval_met = evaluate(preds, answer_inds)
        
        iter_val = epoch * len(train_dataloader) + batch_no
        
        
        writer.add_scalar('Training Loss', loss.item(), iter_val)
        writer.add_pr_curve('PR Curve', answers, outputs, iter_val)
        writer.add_scalar('Accuracy', eval_met['accuracy'], iter_val)
        writer.add_scalar('Precision', eval_met['precision'], iter_val)
        writer.add_scalar('Recall', eval_met['recall'], iter_val)
        writer.add_scalar('F1', eval_met['f1'], iter_val)
         
        
        if batch_no % checkpoint_ref == 0:
            torch.save(model.state_dict(), f"./checkpoints/mlp.pth")
            
        batch_no += 1
        print(f"Batch -> {batch_no} done -> Accu: {eval_met['accuracy']:.2f} -> Prec: {eval_met['precision']:.2f}\r", end="")
        avg_accuracy += eval_met['accuracy']
        

    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}, Avg. Epoch Acc: {avg_accuracy / len(train_dataloader)}')
# writer.close()

Epoch 1/3, Loss: 8.815025329589844, Avg. Epoch Acc: 0.19511090219020844
Epoch 2/3, Loss: 8.757333755493164, Avg. Epoch Acc: 0.19554166495800018
Epoch 3/3, Loss: 8.699640274047852, Avg. Epoch Acc: 0.1955474317073822


In [21]:
# model.transformer_encoder_text_1.state_dict()['layers.0.self_attn.in_proj_weight']