# Todo

- [ ] training loop
- [ ] metrics computation
- [ ] tensorboard/wandb

In [1]:
import datasets
from transformers import AutoImageProcessor, AutoModel, AutoTokenizer
from huggingface_hub import login

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import StepLR

import math
import numpy as np
from sklearn.preprocessing import LabelEncoder

from tqdm import tqdm
import matplotlib.pyplot as plt

from dataclasses import dataclass

2024-05-12 17:07:45.983321: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-05-12 17:07:45.983443: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-05-12 17:07:46.119669: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [2]:
login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [4]:
dataset = datasets.load_dataset("Gapes21/vqa2", split = "train")

Downloading readme:   0%|          | 0.00/361 [00:00<?, ?B/s]

Downloading data: 100%|██████████| 487M/487M [00:02<00:00, 241MB/s]  
Downloading data: 100%|██████████| 490M/490M [00:01<00:00, 258MB/s]  
Downloading data: 100%|██████████| 487M/487M [00:01<00:00, 263MB/s]  
Downloading data: 100%|██████████| 490M/490M [00:02<00:00, 234MB/s]  
Downloading data: 100%|██████████| 486M/486M [00:02<00:00, 213MB/s]  
Downloading data: 100%|██████████| 490M/490M [00:02<00:00, 233MB/s]  
Downloading data: 100%|██████████| 485M/485M [00:02<00:00, 228MB/s]  
Downloading data: 100%|██████████| 485M/485M [00:02<00:00, 237MB/s]  
Downloading data: 100%|██████████| 487M/487M [00:02<00:00, 216MB/s]  
Downloading data: 100%|██████████| 490M/490M [00:02<00:00, 237MB/s]  
Downloading data: 100%|██████████| 490M/490M [00:02<00:00, 235MB/s]  


Generating train split:   0%|          | 0/109485 [00:00<?, ? examples/s]

In [5]:
labelEncoder = LabelEncoder()
labelEncoder.fit(dataset['answer'])

In [6]:
BERT = "FacebookAI/roberta-base"
VIT = 'facebook/dinov2-base'

In [7]:
processor = AutoImageProcessor.from_pretrained(VIT)
tokenizer = AutoTokenizer.from_pretrained(BERT)

preprocessor_config.json:   0%|          | 0.00/436 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/481 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

In [8]:
class SastaLoader:
    def __init__(self, dataset, batch_size, collator_fn):
        self.dataset = dataset.shuffle()
        self.collator_fn = collator_fn
        self.len = len(self.dataset)
        self.batch_size = batch_size
        self.index = 0

    def hasNext(self):
        return self.index + self.batch_size <= self.len
    
    def reset(self):
        self.index = 0
        
    def __iter__(self):
        return self

    def __next__(self):
        if self.index >= self.len:
            raise StopIteration
        batch = self.dataset[self.index: self.index + self.batch_size]
        batch = self.collator_fn(batch)
        self.index += self.batch_size
        return batch
    
    def __len__(self):
        return self.len

In [9]:
def sasta_collator(batch):
    # process images
    images = processor(images = batch['image'], return_tensors="pt")['pixel_values']

    # preprocess questions
    questions = tokenizer(
            text=batch['question'],
            padding='longest',
            max_length=24,
            truncation=True,
            return_tensors='pt',
            return_attention_mask=True,
        )

    # process labels
    labels = torch.Tensor(labelEncoder.transform(batch['answer']))

    return (images, questions, labels)


In [10]:
class VQAModel(nn.Module):
    def __init__(
        self,
        num_labels,
        intermediate_dim,
        pretrained_text_name,
        pretrained_image_name
    ):
        super(VQAModel, self).__init__()
        
        self.num_labels = num_labels
        self.intermediate_dim = intermediate_dim
        self.pretrained_text_name = pretrained_text_name
        self.pretrained_image_name = pretrained_image_name
        
        # Text and image encoders
        
        self.text_encoder = AutoModel.from_pretrained(self.pretrained_text_name)
        self.image_encoder = AutoModel.from_pretrained(self.pretrained_image_name)

        assert(self.text_encoder.config.hidden_size == self.image_encoder.config.hidden_size)

        self.embedd_dim = self.text_encoder.config.hidden_size

        # Cross attentions
        self.textq = nn.MultiheadAttention(self.embedd_dim, 1, 0.1, batch_first=True)
        self.imgq = nn.MultiheadAttention(self.embedd_dim, 1, 0.1, batch_first=True)
        
        # Classifier
        self.classifier = nn.Linear(self.embedd_dim, self.num_labels)

    def forward(
        self,
        input_ids,
        pixel_values,
        attention_mask
    ):
        # Encode text with masking
        encoded_text = self.text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        
        # Encode images
        encoded_image = self.image_encoder(
            pixel_values=pixel_values,
        )
        
        text = encoded_text.last_hidden_state
        img = encoded_image.last_hidden_state

        textcls = self.textq(text, img, img)[0][:, 0, :]
        imgcls = self.imgq(img, text, text)[0][:, 0, :]

        cls = textcls+imgcls
        
        # Make predictions
        logits = self.classifier(cls)
        
        return logits

## Training

#### Model, optimizer and loss

In [11]:
def save_model(model, name):
    torch.save(model.state_dict(), name)

def initVQA():
    model = VQAModel(len(labelEncoder.classes_), 512, BERT, VIT).to(device)
    return model

def load_model(name, backup = initVQA):
    model = backup()
    try : 
        model.load_state_dict(torch.load(f"/kaggle/working/{name}"))
        print("Loaded model successfully.")
    except:
        print("Couldn't find model. Initializing from scratch.")
    return model

In [12]:
model = load_model("vqa_dr.pth")
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

model.safetensors:   0%|          | 0.00/499M [00:00<?, ?B/s]

Some weights of RobertaModel were not initialized from the model checkpoint at FacebookAI/roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


config.json:   0%|          | 0.00/548 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

Loaded model successfully.


#### Hyperparams

In [14]:
collator_fn = sasta_collator
loader = SastaLoader(dataset, 16, sasta_collator)
num_epochs = 2

In [17]:
def train(model, optimizer, criterion, loader, num_epochs, device):
    scheduler = StepLR(optimizer, step_size=10, gamma=0.1)
    loss_plot, accuracy_plot = [], []
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        correct = 0
        total_samples = 0
        with tqdm(total=len(loader), desc="Processing batches", dynamic_ncols=True) as pbar:
            for batchidx, batch in enumerate(loader):
                ids = batch[1]['input_ids'].to(device)
                pxlvalues = batch[0].to(device)
                masks = batch[1]['attention_mask'].to(device)
                labels = batch[2].to(device)

                optimizer.zero_grad()
                outputs = model(ids, pxlvalues, masks)
                loss = criterion(outputs, labels.long())
                loss.backward()
                optimizer.step()

                total_loss += loss.item() * loader.batch_size
                _, predicted = torch.max(outputs, 1)
                correct += (predicted == labels).sum().item()
                total_samples += labels.size(0)
                pbar.update(loader.batch_size)
                if batchidx % 16000 <= 1:
                    save_model(model, 'vqa_dr.pth')
                
        epoch_loss = total_loss / total_samples
        accuracy = correct / total_samples
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Accuracy: {accuracy:.4f}")
        accuracy_plot.append(accuracy * 100)
        loss_plot.append(epoch_loss)
        save_model(model, "vqa_dr.pth")
        scheduler.step()
        loader.reset()
    plt.plot(accuracy_plot)
    plt.plot(loss_plot)

In [None]:
train(model, optimizer, criterion, loader, num_epochs, device)

Processing batches: 100%|█████████▉| 109472/109485 [1:09:54<00:00, 26.10it/s]


Epoch 1/2, Loss: 5.5683, Accuracy: 0.1905


Processing batches:  54%|█████▍    | 58896/109485 [37:41<32:31, 25.92it/s]  

In [None]:
save_model(model, "vqr_dr2.pth")