# Todo

- [ ] training loop
- [ ] metrics computation
- [ ] experiment logging (TensorBoard / Weights & Biases)

In [1]:
import datasets
from transformers import AutoImageProcessor, AutoModel, AutoTokenizer

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math
import numpy as np
from torch.utils.data import DataLoader, Dataset
from sklearn.preprocessing import LabelEncoder
from tqdm import tqdm
import matplotlib.pyplot as plt
from dataclasses import dataclass


In [2]:
# Run on a CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [3]:
# Training split of the `Gapes21/vqa2` Hub dataset (presumably a VQA v2 mirror);
# downloaded on first use, read from the local Hugging Face cache afterwards.
dataset = datasets.load_dataset(path="Gapes21/vqa2", split="train")

In [4]:
# Map every answer string in the training split to an integer class id.
# `LabelEncoder.fit` returns the encoder itself, so construction and fitting chain.
labelEncoder = LabelEncoder().fit(dataset['answer'])

In [5]:
# Hugging Face checkpoint ids for the two encoder towers.
# NOTE: despite its name, BERT points at a RoBERTa checkpoint.
VIT = "facebook/dinov2-base"        # image encoder
BERT = "FacebookAI/roberta-base"    # text encoder

In [6]:
# Preprocessing matched to each checkpoint: a tokenizer for the text model
# and an image processor for the vision model.
tokenizer = AutoTokenizer.from_pretrained(BERT)
processor = AutoImageProcessor.from_pretrained(VIT)

In [7]:
class SastaLoader:
    """Minimal sequential batch loader over a shuffled dataset.

    The dataset is shuffled once at construction. Batches are taken as
    contiguous slices ``dataset[i:i + batch_size]`` and passed through
    ``collator_fn``.

    Args:
        dataset: object exposing ``shuffle()``, ``len()`` and slice indexing
            (e.g. a Hugging Face ``datasets.Dataset``).
        batch_size: number of rows per batch; must be positive.
        collator_fn: callable turning a sliced batch into model inputs.
        seed: optional seed forwarded to ``dataset.shuffle`` for a reproducible
            order. ``None`` (the default) shuffles exactly as before.
        drop_last: when True (the default, matching the original behaviour) a
            trailing batch smaller than ``batch_size`` is skipped; when False it
            is produced as a final, shorter batch.

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """

    def __init__(self, dataset, batch_size, collator_fn, seed=None, drop_last=True):
        if batch_size <= 0:
            # A zero batch size would keep hasNext() true forever without advancing.
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        self.seed = seed
        self.dataset = self._shuffled(dataset)
        self.collator_fn = collator_fn
        self.len = len(self.dataset)
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.index = 0

    def _shuffled(self, dataset):
        """Return ``dataset.shuffle()``, passing the seed only when one was given."""
        if self.seed is None:
            return dataset.shuffle()
        return dataset.shuffle(seed=self.seed)

    def hasNext(self):
        """Return True while another batch can be produced in the current pass."""
        if self.drop_last:
            return self.index + self.batch_size <= self.len
        return self.index < self.len

    def reset(self, reshuffle=False):
        """Rewind to the first batch.

        Args:
            reshuffle: if True, draw a new shuffle first (e.g. at the start of an
                epoch) so successive passes do not repeat the same order.
        """
        if reshuffle:
            self.dataset = self._shuffled(self.dataset)
        self.index = 0

    def next(self):
        """Collate and return the next batch.

        Raises:
            StopIteration: if the pass is exhausted. Previously an empty or
                partial slice was handed to ``collator_fn`` instead.
        """
        if not self.hasNext():
            raise StopIteration("SastaLoader exhausted; call reset() to start a new pass")
        batch = self.dataset[self.index: self.index + self.batch_size]
        batch = self.collator_fn(batch)
        # Advance only after collation succeeded, as before.
        self.index += self.batch_size
        return batch

    def __len__(self):
        """Number of batches in one full pass."""
        if self.drop_last:
            return self.len // self.batch_size
        return (self.len + self.batch_size - 1) // self.batch_size

    def __iter__(self):
        """Yield every batch of one pass, starting from the beginning (calls ``reset()``)."""
        self.reset()
        while self.hasNext():
            yield self.next()


In [8]:
def sasta_collator(batch, *, image_processor=None, text_tokenizer=None,
                   label_encoder=None, max_length=24):
    """Turn a sliced dataset batch into model-ready tensors.

    Args:
        batch: mapping with list-valued ``'image'``, ``'question'`` and
            ``'answer'`` entries (the shape of a Hugging Face ``Dataset`` slice).
        image_processor: image preprocessor; defaults to the notebook-level
            ``processor``.
        text_tokenizer: tokenizer; defaults to the notebook-level ``tokenizer``.
        label_encoder: fitted encoder exposing ``transform``; defaults to the
            notebook-level ``labelEncoder``.
        max_length: maximum question length in tokens. Longer questions are
            truncated.

    Returns:
        ``(pixel_values, questions, labels)``:
        - ``questions`` is the tokenizer output (``input_ids``,
          ``attention_mask``) padded to the longest question in the batch.
        - ``labels`` is an int64 tensor of answer class ids.
    """
    # Fall back to the notebook globals so existing `collator_fn(batch)` calls keep working.
    if image_processor is None:
        image_processor = processor
    if text_tokenizer is None:
        text_tokenizer = tokenizer
    if label_encoder is None:
        label_encoder = labelEncoder

    # process images
    images = image_processor(images=batch['image'], return_tensors="pt")['pixel_values']

    # preprocess questions
    questions = text_tokenizer(
        text=batch['question'],
        padding='longest',
        max_length=max_length,
        truncation=True,
        return_tensors='pt',
        return_attention_mask=True,
    )

    # Class indices must be int64 for F.cross_entropy / nn.CrossEntropyLoss;
    # the previous torch.Tensor(...) produced float32 targets.
    labels = torch.as_tensor(label_encoder.transform(batch['answer']), dtype=torch.long)

    return (images, questions, labels)


In [14]:
class VQAModel(nn.Module):
    """Two-tower VQA classifier with bidirectional cross-attention.

    Data flow:
    - A pretrained text encoder and a pretrained image encoder produce token
      states of equal width.
    - Text tokens attend to image tokens; image tokens attend to the
      non-padding text tokens.
    - The position-0 outputs of the two streams are summed and fed to a
      linear classifier over ``num_labels`` answers.

    Args:
        num_labels: number of answer classes.
        intermediate_dim: stored on the instance but not used by any layer here.
            NOTE(review): confirm whether a projection layer was intended.
        pretrained_text_name: Hugging Face id of the text encoder.
        pretrained_image_name: Hugging Face id of the image encoder.
        num_heads: attention heads in each cross-attention layer (default 1).
        attn_dropout: dropout inside the cross-attention layers (default 0.1).
    """

    def __init__(
        self,
        num_labels,
        intermediate_dim,
        pretrained_text_name,
        pretrained_image_name,
        num_heads=1,
        attn_dropout=0.1,
    ):
        super().__init__()

        self.num_labels = num_labels
        self.intermediate_dim = intermediate_dim
        self.pretrained_text_name = pretrained_text_name
        self.pretrained_image_name = pretrained_image_name

        # Text and image encoders
        self.text_encoder = AutoModel.from_pretrained(self.pretrained_text_name)
        self.image_encoder = AutoModel.from_pretrained(self.pretrained_image_name)

        # Cross-attention mixes the two token streams directly, so their widths must agree.
        text_dim = self.text_encoder.config.hidden_size
        image_dim = self.image_encoder.config.hidden_size
        assert text_dim == image_dim, (
            f"text hidden_size ({text_dim}) != image hidden_size ({image_dim})"
        )
        self.embedd_dim = text_dim

        # Cross attentions: textq = text queries over image keys/values,
        # imgq = image queries over text keys/values.
        self.textq = nn.MultiheadAttention(self.embedd_dim, num_heads, attn_dropout, batch_first=True)
        self.imgq = nn.MultiheadAttention(self.embedd_dim, num_heads, attn_dropout, batch_first=True)

        # Classifier
        self.classifier = nn.Linear(self.embedd_dim, self.num_labels)

    def forward(
        self,
        input_ids,
        pixel_values,
        attention_mask,
        labels=None,
    ):
        """Return answer logits of shape (batch, num_labels).

        Args:
            input_ids: token ids for the text encoder.
            pixel_values: preprocessed images for the image encoder.
            attention_mask: text mask, 1 for real tokens and 0 for padding.
            labels: accepted for call-site symmetry but unused; no loss is computed here.
        """
        # Encode text with masking
        encoded_text = self.text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )

        # Encode images
        encoded_image = self.image_encoder(pixel_values=pixel_values)

        text = encoded_text.last_hidden_state
        img = encoded_image.last_hidden_state

        # Padding tokens must not serve as keys for the image queries.
        # MultiheadAttention expects True where a key is to be ignored.
        text_padding = None if attention_mask is None else attention_mask == 0

        # Only the attended outputs are used, so skip computing attention weights.
        textcls = self.textq(text, img, img, need_weights=False)[0][:, 0, :]
        imgcls = self.imgq(
            img, text, text, key_padding_mask=text_padding, need_weights=False
        )[0][:, 0, :]

        cls = textcls + imgcls

        # Make predictions
        return self.classifier(cls)

In [15]:
# Batches of 4 rows from the shuffled training split, collated into model inputs.
loader = SastaLoader(dataset=dataset, batch_size=4, collator_fn=sasta_collator)

In [16]:
# One output per answer class seen by the label encoder; encoders from the checkpoint ids.
model = VQAModel(
    num_labels=len(labelEncoder.classes_),
    intermediate_dim=512,
    pretrained_text_name=BERT,
    pretrained_image_name=VIT,
).to(device)

Some weights of RobertaModel were not initialized from the model checkpoint at FacebookAI/roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [17]:
# Smoke-test batch. Rewind first so re-running this cell always yields the same
# (first) batch instead of silently advancing through, and eventually exhausting,
# the loader.
loader.reset()
batch = loader.next()

In [18]:
# Forward-only smoke test on one batch.
# - Unpack the collated tuple and move each tensor to the model's device.
# - Disable autograd: this cell only checks the forward pass, so building a
#   graph over both encoders would just waste memory.
images, questions, labels = batch
with torch.no_grad():
    tz = model(
        input_ids=questions['input_ids'].to(device),
        pixel_values=images.to(device),
        attention_mask=questions['attention_mask'].to(device),
        labels=labels.to(device),
    )
# Expect one row of logits per example and one column per answer class.
assert tz.shape == (images.shape[0], model.num_labels)