# Todo

- [ ] training loop
- [ ] metrics computation
- [ ] tensorboard/wandb

In [1]:
import datasets
from transformers import AutoImageProcessor, AutoModel, AutoTokenizer
from huggingface_hub import login

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import StepLR

import math
import numpy as np
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from sklearn.metrics import f1_score, accuracy_score

from tqdm import tqdm
import matplotlib.pyplot as plt
import os

from dataclasses import dataclass

2024-05-14 23:46:03.821093: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-05-14 23:46:03.821165: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-05-14 23:46:03.822616: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [2]:
login("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")

Token has not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /root/.cache/huggingface/token
Login successful


In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [4]:
dataset = datasets.load_dataset("Gapes21/vqa2", split = "train")

In [5]:
labelEncoder = LabelEncoder()
labelEncoder.fit(dataset['answer'])

In [6]:
BERT = "FacebookAI/roberta-base"
VIT = 'facebook/dinov2-base'

In [7]:
processor = AutoImageProcessor.from_pretrained(VIT)
tokenizer = AutoTokenizer.from_pretrained(BERT)

In [8]:
class SastaLoader:
    def __init__(self, dataset, batch_size, collator_fn, train_max = 100000, mode = "train"):
        self.dataset = dataset.shuffle()
        self.collator_fn = collator_fn
        self.len = len(self.dataset)
        self.batch_size = batch_size
        if mode == "train":
            self.index = 0
        else :
            self.index = train_max
        self.train_max = train_max
        self.mode = mode

    def hasNext(self):
        if self.mode == "train":
            return self.index + self.batch_size <= self.train_max
        else :
            return self.index + self.batch.size <= self.len
    
    def reset(self):
        if self.mode == "train":
            self.index = 0
        else:
            self.index = self.train_max
        
    def __iter__(self):
        return self

    def __next__(self):
        if self.mode == "train":
            if self.index >= self.train_max:
                raise StopIteration
        else :
            if self.index >= self.len:
                raise StopIteration
                
        batch = self.dataset[self.index: self.index + self.batch_size]
        batch = self.collator_fn(batch)
        self.index += self.batch_size
        return batch
    
    def __len__(self):
        if self.mode == "train":
            return self.train_max
        return self.len - self.train_max
    
    def train(self):
        self.mode = "train"
        
    def validate(self):
        self.mode = "validation"

In [9]:
def sasta_collator(batch):
    # process images
    images = processor(images = batch['image'], return_tensors="pt")['pixel_values']

    # preprocess questions
    questions = tokenizer(
            text=batch['question'],
            padding='longest',
            max_length=24,
            truncation=True,
            return_tensors='pt',
            return_attention_mask=True,
        )

    # process labels
    labels = torch.Tensor(labelEncoder.transform(batch['answer']))

    return (images, questions, labels)


In [10]:
class VQAModel(nn.Module):
    def __init__(
        self,
        num_labels,
        intermediate_dim,
        pretrained_text_name,
        pretrained_image_name
    ):
        super(VQAModel, self).__init__()
        
        self.num_labels = num_labels
        self.intermediate_dim = intermediate_dim
        self.pretrained_text_name = pretrained_text_name
        self.pretrained_image_name = pretrained_image_name
        
        # Text and image encoders
        
        self.text_encoder = AutoModel.from_pretrained(self.pretrained_text_name)
        self.image_encoder = AutoModel.from_pretrained(self.pretrained_image_name)

        assert(self.text_encoder.config.hidden_size == self.image_encoder.config.hidden_size)

        self.embedd_dim = self.text_encoder.config.hidden_size

        # Co-attentions
        self.textq = nn.MultiheadAttention(self.embedd_dim, 1, 0.1, batch_first=True)
        self.imgq = nn.MultiheadAttention(self.embedd_dim, 1, 0.1, batch_first=True)
        
        # Classifier
        self.classifier = nn.Linear(self.embedd_dim, self.num_labels)

    def forward(
        self,
        input_ids,
        pixel_values,
        attention_mask
    ):
        # Encode text with masking
        encoded_text = self.text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        
        # Encode images
        encoded_image = self.image_encoder(
            pixel_values=pixel_values,
        )
        
        text = encoded_text.last_hidden_state
        img = encoded_image.last_hidden_state

        textcls = self.textq(text, img, img)[0][:, 0, :]
        imgcls = self.imgq(img, text, text)[0][:, 0, :]

        cls = textcls+imgcls
        
        # Make predictions
        logits = self.classifier(cls)
        
        return logits

## Training

#### Model, optimizer and loss

In [11]:
def save_model(model, name):
    torch.save(model.state_dict(), name)

def initVQA():
    model = VQAModel(len(labelEncoder.classes_), 512, BERT, VIT).to(device)
    return model

def load_model(name, backup = initVQA):
    model = backup()
    try : 
        model.load_state_dict(torch.load(f"/kaggle/working/{name}"))
        print("Loaded model successfully.")
    except:
        print("Couldn't find model. Initializing from scratch.")
    return model

In [25]:
model = load_model("vqa_dr.pth")
optimizer = optim.Adam(model.parameters(), lr=0.05)
criterion = nn.CrossEntropyLoss()

Some weights of RobertaModel were not initialized from the model checkpoint at FacebookAI/roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Loaded model successfully.


#### Hyperparams

In [26]:
collator_fn = sasta_collator
loader = SastaLoader(dataset, 16, sasta_collator)
num_epochs = 1

In [20]:
def train(model, optimizer, criterion, loader, num_epochs, device):
    scheduler = StepLR(optimizer, step_size=10, gamma=0.1)
    loss_plot, accuracy_plot = [], []
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        correct = 0
        total_samples = 0
        with tqdm(total=len(loader), desc="Processing batches", dynamic_ncols=True) as pbar:
            for batchidx, batch in enumerate(loader):
                ids = batch[1]['input_ids'].to(device)
                pxlvalues = batch[0].to(device)
                masks = batch[1]['attention_mask'].to(device)
                labels = batch[2].to(device)

                optimizer.zero_grad()
                outputs = model(ids, pxlvalues, masks)
                loss = criterion(outputs, labels.long())
                loss.backward()
                optimizer.step()

                total_loss += loss.item() * loader.batch_size
                _, predicted = torch.max(outputs, 1)
                correct += (predicted == labels).sum().item()
                total_samples += labels.size(0)
                pbar.update(loader.batch_size)
                if batchidx % 16000 <= 1:
                    save_model(model, 'vqa_dr.pth')
                
        epoch_loss = total_loss / total_samples
        accuracy = correct / total_samples
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Accuracy: {accuracy:.4f}")
        accuracy_plot.append(accuracy * 100)
        loss_plot.append(epoch_loss)
        save_model(model, "vqa_dr.pth")
        scheduler.step()
        loader.reset()
    plt.plot(accuracy_plot)
    plt.plot(loss_plot)

In [None]:
train(model, optimizer, criterion, loader, num_epochs, device)

Processing batches:  45%|████▍     | 44976/100000 [28:53<35:10, 26.07it/s] 

In [22]:
validation_loader = SastaLoader(dataset, 16, sasta_collator, mode = "validation")

In [27]:
def evaluate_model(model, loader, device):
    y_true, y_pred = [], []
    model.eval()
    loader.reset()
    with tqdm(total=len(loader), desc="Processing batches", dynamic_ncols=True) as pbar:
        for batchidx, batch in enumerate(loader):
            ids = batch[1]['input_ids'].to(device)
            pxlvalues = batch[0].to(device)
            masks = batch[1]['attention_mask'].to(device)
            labels = batch[2].to("cpu")
            outputs = model(ids, pxlvalues, masks)
            _, predicted = torch.max(outputs, 1)
            predicted = predicted.to("cpu")
            y_true.extend(labels)
            y_pred.extend(predicted)
            pbar.update(loader.batch_size)
    print(len(y_true), len(y_pred))
    f1 = f1_score(y_true, y_pred, average = "weighted")
    accuracy = accuracy_score(y_true, y_pred)
    print(f"F1-score: {f1 : 0.2f}")
    print(f"Accuracy: {accuracy * 100 : 0.2f}%")
    return y_pred

In [28]:
y_pred = evaluate_model(model, validation_loader, device)

Processing batches: 9488it [03:01, 52.33it/s]                          


9485 9485
F1-score:  0.11
Accuracy:  19.33%


In [35]:
label_dict = dict()
for label in y_pred:
    if label.item() in label_dict:
        label_dict[label.item()] += 1
    else:
        label_dict[label.item()] = 1
    
for label in label_dict.keys():
    print(f"{labelEncoder.inverse_transform([label])} : {label_dict[label]}")

['no'] : 4912
['yes'] : 4571
['surfing'] : 2
