In [None]:
!pip install peft

In [None]:
import datasets
from transformers import AutoImageProcessor, AutoModel, AutoTokenizer
from huggingface_hub import login
from peft import LoraConfig, get_peft_model

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import StepLR

import math
import numpy as np
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from sklearn.metrics import f1_score, accuracy_score

from tqdm import tqdm
import matplotlib.pyplot as plt
import os

from dataclasses import dataclass
import re

In [None]:
# Authenticate with the Hugging Face Hub using a token taken from the HF_TOKEN
# environment variable. When the variable is unset, `token` is None and
# `login` falls back to its interactive prompt.
# SECURITY: never paste an access token into the notebook; any token that has
# ever been committed in a notebook must be treated as leaked and revoked.
login(token=os.environ.get("HF_TOKEN"))

In [None]:
# Pick the compute device: CUDA when PyTorch can see a GPU, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

In [None]:
# Load the "train" split of the "Gapes21/aug_vqa" dataset via `datasets.load_dataset`.
# Later cells read its 'image', 'question' and 'answer' columns.
dataset = datasets.load_dataset("Gapes21/aug_vqa", split = "train")

In [None]:
# Number of examples in the loaded split (displayed as the cell output).
len(dataset)

In [None]:
# Map every distinct answer string to an integer class id.
# The encoder is fitted on the 'answer' column of the whole loaded split, so
# its number of classes is what sizes the classifier head later on.
labelEncoder = LabelEncoder()
labelEncoder.fit(dataset['answer'])

In [None]:
# Pretrained checkpoint identifiers passed to the `from_pretrained` calls below.
# BERT: text side (note that the checkpoint named here is RoBERTa-base).
# VIT: image side (DINOv2-base).
BERT = "FacebookAI/roberta-base"
VIT = 'facebook/dinov2-base'

In [None]:
# Preprocessing objects matching the two checkpoints above; both are used as
# globals by `sasta_collator`.
processor = AutoImageProcessor.from_pretrained(VIT)  # image preprocessing for VIT
tokenizer = AutoTokenizer.from_pretrained(BERT)  # question tokenization for BERT

In [None]:
class SastaLoader:
    """Minimal batch iterator over a shuffled dataset with a positional train/validation split.

    The dataset is shuffled once at construction time. Rows ``[0, train_max)``
    of the shuffled view are the training range, rows ``[train_max, len)`` the
    validation range. Iteration yields ``collator_fn(dataset[start:stop])``.

    Args:
        dataset: object providing ``shuffle(seed=...)``, ``len()`` and slicing.
        batch_size: number of rows requested per batch.
        collator_fn: callable turning a sliced batch into model inputs.
        train_max: size of the training range; clamped to the dataset length.
        mode: ``"train"`` selects the training range, any other value selects
            the validation range.
        seed: seed handed to ``dataset.shuffle``. A fixed default means that
            separately constructed loaders (e.g. one for training and one for
            validation) see the same permutation, so their ranges really are
            disjoint. Pass ``None`` for a different order on every construction.
    """

    def __init__(self, dataset, batch_size, collator_fn, train_max = 100000, mode = "train", seed = 42):
        self.dataset = dataset.shuffle(seed = seed)
        self.collator_fn = collator_fn
        self.len = len(self.dataset)
        self.batch_size = batch_size
        # A split point beyond the end of the data would make the validation
        # range negative and the training range run off the dataset.
        self.train_max = min(train_max, self.len)
        self.mode = mode
        self.index = 0
        self.reset()

    def _start(self):
        """First row index of the range selected by the current mode."""
        return 0 if self.mode == "train" else self.train_max

    def _end(self):
        """One-past-the-last row index of the range selected by the current mode."""
        return self.train_max if self.mode == "train" else self.len

    def hasNext(self):
        """Return True when a full batch still fits in the current range."""
        return self.index + self.batch_size <= self._end()

    def reset(self):
        """Rewind the cursor to the start of the current range."""
        self.index = self._start()

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next collated batch; the last one may be shorter than ``batch_size``.

        Raises:
            StopIteration: when the cursor has reached the end of the range.
        """
        end = self._end()
        if self.index >= end:
            raise StopIteration

        # Clamp the slice so a final partial batch never crosses the split
        # boundary (training must not read validation rows).
        stop = min(self.index + self.batch_size, end)
        batch = self.collator_fn(self.dataset[self.index: stop])
        self.index += self.batch_size
        return batch

    def __len__(self):
        """Number of rows (not batches) in the range selected by the current mode."""
        return max(0, self._end() - self._start())

    def train(self):
        """Switch to the training range (does not move the cursor; call ``reset``)."""
        self.mode = "train"

    def validate(self):
        """Switch to the validation range (does not move the cursor; call ``reset``)."""
        self.mode = "validation"

In [None]:
def sasta_collator(batch):
    """Turn a sliced dataset batch into model-ready tensors.

    Args:
        batch: mapping with 'image', 'question' and 'answer' entries, each a
            sequence with one item per example.

    Returns:
        Tuple ``(pixel_values, questions, labels)``: the image processor's
        'pixel_values', the tokenizer output for the questions, and the
        encoded answers wrapped in ``torch.Tensor`` (default float dtype, so
        callers cast to integer before using them as class indices).
    """
    # Images -> pixel tensors via the module-level image processor.
    image_features = processor(images = batch['image'], return_tensors="pt")
    pixel_values = image_features['pixel_values']

    # Questions -> token ids and attention mask, padded to the longest
    # question in the batch and truncated at 24 tokens.
    tokenizer_options = dict(
        padding='longest',
        max_length=24,
        truncation=True,
        return_tensors='pt',
        return_attention_mask=True,
    )
    encoded_questions = tokenizer(text=batch['question'], **tokenizer_options)

    # Answers -> integer class ids from the fitted label encoder.
    answer_ids = labelEncoder.transform(batch['answer'])
    return (pixel_values, encoded_questions, torch.Tensor(answer_ids))


In [None]:
class DinoBertDeep(nn.Module):
    """Two-encoder VQA classifier with cross-attention between text and image tokens.

    A pretrained text encoder and a pretrained image encoder (both loaded with
    ``AutoModel.from_pretrained``) produce token sequences of equal width.
    Each modality attends over the other, position 0 of both attended
    sequences is concatenated, the same positions of the raw encoder outputs
    are added as a skip connection, and a two-layer head produces the logits.

    Args:
        num_labels: number of answer classes (output width of the head).
        intermediate_dim: stored on the instance; not otherwise used here.
        pretrained_text_name: checkpoint name for the text encoder.
        pretrained_image_name: checkpoint name for the image encoder.
        classifier_dim: stored on the instance; not otherwise used here.
    """

    def __init__(
        self,
        num_labels,
        intermediate_dim,
        pretrained_text_name,
        pretrained_image_name,
        classifier_dim = 9024,
    ):
        super().__init__()

        # Remember the constructor arguments on the instance.
        self.num_labels = num_labels
        self.intermediate_dim = intermediate_dim
        self.pretrained_text_name = pretrained_text_name
        self.pretrained_image_name = pretrained_image_name
        self.classifier_dim = classifier_dim

        # Pretrained backbones.
        self.text_encoder = AutoModel.from_pretrained(pretrained_text_name)
        self.image_encoder = AutoModel.from_pretrained(pretrained_image_name)

        # Cross-attention needs both encoders to share one hidden width.
        text_width = self.text_encoder.config.hidden_size
        image_width = self.image_encoder.config.hidden_size
        assert text_width == image_width
        self.embedd_dim = text_width

        # Cross-attention: 8 heads, dropout 0.2, batch-first tensors.
        # textq uses text as the query; imgq uses the image as the query.
        self.textq = nn.MultiheadAttention(self.embedd_dim, 8, 0.2, batch_first=True)
        self.imgq = nn.MultiheadAttention(self.embedd_dim, 8, 0.2, batch_first=True)

        # Classification head over the concatenated text+image vector.
        self.initdim = 2 * self.embedd_dim
        head_layers = [
            nn.Linear(self.initdim, num_labels),
            nn.GELU(),
            nn.Linear(num_labels, num_labels),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(
        self,
        input_ids,
        pixel_values,
        attention_mask
    ):
        """Return class logits for a batch of (question, image) pairs.

        Args:
            input_ids: token ids forwarded to the text encoder.
            pixel_values: image tensor forwarded to the image encoder.
            attention_mask: text mask forwarded to the text encoder.

        Returns:
            Output of the classifier head, one row per example.
        """
        # Run both backbones (text first, then image) and keep the token sequences.
        text = self.text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
        ).last_hidden_state
        img = self.image_encoder(pixel_values=pixel_values).last_hidden_state

        # Each modality queries the other; attention weights are discarded.
        # NOTE(review): attention_mask is not passed on as a key padding mask,
        # so image queries can attend to padded text positions — confirm intended.
        textatt, _ = self.textq(text, img, img)
        imgatt, _ = self.imgq(img, text, text)

        def first_positions(text_seq, image_seq):
            # Position 0 of each sequence (presumably the CLS token — confirm),
            # joined along the feature axis into one vector per example.
            joined = torch.cat((text_seq[:, 0, :], image_seq[:, 0, :]), dim = 1)
            return joined.view(joined.shape[0], -1)

        attended = first_positions(textatt, imgatt)
        # Skip connection from the un-attended encoder outputs.
        skip = first_positions(text, img)

        return self.classifier(attended + skip)

In [None]:
def save_model(model, name):
    """Write the model's ``state_dict`` to the file ``name`` with ``torch.save``.

    Only the weights are stored, not the module object itself.
    """
    weights = model.state_dict()
    torch.save(weights, name)

def initVQA():
    """Build a fresh ``DinoBertDeep`` on the global ``device``.

    The head is sized to the number of classes known to the fitted
    ``labelEncoder``; the encoders come from the ``BERT`` and ``VIT``
    checkpoint names.
    """
    num_answers = len(labelEncoder.classes_)
    fresh_model = DinoBertDeep(num_answers, 512, BERT, VIT)
    return fresh_model.to(device)

def load_model(name, backup = initVQA, frommem = True):
    """Build a model with ``backup()`` and, optionally, load saved weights into it.

    Loading is best-effort: if the checkpoint is missing or cannot be loaded,
    the reason is printed and the freshly built model is returned.

    Args:
        name: path of a ``state_dict`` checkpoint written with ``torch.save``.
        backup: zero-argument factory that returns the model to fill.
        frommem: when falsy, skip the checkpoint and return the fresh model.

    Returns:
        The model produced by ``backup()``, with the checkpoint weights loaded
        when that succeeded.
    """
    model = backup()
    if not frommem:
        print("Initializing from scratch.")
        return model

    # Deserialize onto the device the fresh model already lives on, so a
    # checkpoint saved on a GPU can still be opened on a CPU-only machine.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else None

    try:
        state = torch.load(str(name), map_location=target_device)
        model.load_state_dict(state)
    except FileNotFoundError:
        print("Couldn't find model. Initializing from scratch.")
    except Exception as exc:
        # Still best-effort, but say what went wrong (e.g. mismatching keys or
        # shapes) instead of reporting every failure as a missing file.
        # KeyboardInterrupt/SystemExit are deliberately not caught.
        print(f"Couldn't load model ({type(exc).__name__}: {exc}). Initializing from scratch.")
    else:
        print("Loaded model successfully.")
    return model

def load_peft_model(basemodel, peftmodel, config):
    """Rebuild a LoRA-wrapped model from a base checkpoint and a LoRA checkpoint.

    Args:
        basemodel: path of the base model ``state_dict`` (loaded with ``load_model``).
        peftmodel: path of the ``state_dict`` saved from the PEFT-wrapped model.
        config: ``LoraConfig`` used to wrap the base model; it has to produce
            the same parameter names as the wrapper the checkpoint came from.

    Returns:
        The PEFT-wrapped model with the ``peftmodel`` weights loaded.
    """
    model = load_model(basemodel, frommem = True)
    lora_model = get_peft_model(model, config)

    # Map the checkpoint onto the wrapped model's device so weights saved on a
    # GPU can be restored on a CPU-only machine.
    first_param = next(lora_model.parameters(), None)
    target_device = first_param.device if first_param is not None else None
    lora_model.load_state_dict(torch.load(peftmodel, map_location=target_device))
    return lora_model

def print_trainable_parameters(model):
    """Print how many of the model's parameters require gradients.

    Reports the trainable element count, the total element count and the
    trainable share as a percentage with two decimals.
    """
    # (element count, requires_grad) for every named parameter.
    sizes = [(param.numel(), param.requires_grad) for _, param in model.named_parameters()]
    all_param = sum(count for count, _ in sizes)
    trainable_params = sum(count for count, needs_grad in sizes if needs_grad)
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param:.2f}"
    )

### Adding LoRA to the mix

In [None]:
# Load the base checkpoint and list the attribute names of its Linear layers
# as candidate LoRA targets.
model = load_model("/kaggle/input/dinobert-adv/vqa_dinobert_adv.pth", frommem = True)

# The module repr prints each child as "(name): ClassName(...)"; capture the
# names whose class is printed as exactly "Linear".
pattern = r'\((\w+)\): Linear'
# Use the model's own repr rather than the repr of the bound `modules` method.
linear_layers = re.findall(pattern, str(model))
# Sorted so the list is the same on every run (set order is not stable).
target_modules = sorted(set(linear_layers))

In [None]:
# LoRA adapter settings: rank 16, alpha 8, dropout 0.05, bias terms untouched.
# Adapters are attached to modules with the names listed in target_modules.
# (An earlier variant targeted "dense", "word_embeddings", "query", "key",
# "value", "position_embeddings" and "out_proj" instead.)
config = LoraConfig(
    r = 16,
    lora_alpha = 8,
    target_modules = ["fc1", "fc2", "query", "key", "value", "dense"],
    lora_dropout = 0.05,
    bias = "none"
)

base_checkpoint = "/kaggle/input/dinobert-adv/vqa_dinobert_adv.pth"
lora_checkpoint = "/kaggle/input/dinobert-adv/vqa_dinobert_lora_onpretrained.pth"

# loraFromMem = True  -> resume from previously saved LoRA weights.
# loraFromMem = False -> wrap the base checkpoint with fresh adapters.
loraFromMem = True
lora_model = None
if loraFromMem:
    lora_model = load_peft_model(base_checkpoint, lora_checkpoint, config)
else:
    model = load_model(base_checkpoint, frommem = True)
    lora_model = get_peft_model(model, config)
print_trainable_parameters(lora_model)

## Training

#### Hyperparams

In [None]:
# Data: batches of 32; the first 210000 shuffled rows form the training range.
collator_fn = sasta_collator
loader = SastaLoader(dataset, 32, collator_fn, train_max = 210000)

# Schedule: a single epoch; the learning rate is multiplied by 0.9 after each epoch.
num_epochs = 1
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(lora_model.parameters(), lr = 0.001)
scheduler = StepLR(optimizer, step_size = 1, gamma = 0.9)

In [None]:
def train(model, optimizer, criterion, scheduler, loader, num_epochs, device):
    """Train ``model`` for ``num_epochs`` passes over ``loader`` and plot the curves.

    Args:
        model: callable as ``model(input_ids, pixel_values, attention_mask)``.
        optimizer: optimizer stepped once per batch.
        criterion: loss called as ``criterion(logits, integer_labels)``.
        scheduler: learning-rate scheduler stepped once per epoch.
        loader: iterable of ``(pixel_values, questions, labels)`` batches that
            also provides ``len()``, ``batch_size`` and ``reset()``.
        num_epochs: number of passes over the loader.
        device: device the batch tensors are moved to.

    Side effects:
        Saves the weights to 'vqa_dr.pth' during and after every epoch, prints
        per-epoch loss/accuracy and plots both histories.
    """
    loss_plot, accuracy_plot = [], []
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        correct = 0
        total_samples = 0
        with tqdm(total=len(loader), desc="Processing batches", dynamic_ncols=True) as pbar:
            for batchidx, batch in enumerate(loader):
                pxlvalues, questions, labels = batch
                ids = questions['input_ids'].to(device)
                masks = questions['attention_mask'].to(device)
                pxlvalues = pxlvalues.to(device)
                # The collator returns float labels; the loss needs class indices.
                labels = labels.to(device).long()

                optimizer.zero_grad()
                outputs = model(ids, pxlvalues, masks)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                # Weight by the real batch size: the final batch can be
                # shorter than loader.batch_size.
                batch_samples = labels.size(0)
                total_loss += loss.item() * batch_samples
                predicted = outputs.argmax(dim=1)
                correct += (predicted == labels).sum().item()
                total_samples += batch_samples
                pbar.update(batch_samples)
                # Checkpoint on batches 0 and 1 and then on every 16000th batch
                # and the one after it.
                # NOTE(review): `<= 1` saves twice in a row — confirm `== 0` was not meant.
                if batchidx % 16000 <= 1:
                    save_model(model, 'vqa_dr.pth')

        # An empty loader yields no samples; report NaN instead of dividing by zero.
        epoch_loss = total_loss / total_samples if total_samples else float("nan")
        accuracy = correct / total_samples if total_samples else float("nan")
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Accuracy: {accuracy:.4f}")
        accuracy_plot.append(accuracy * 100)
        loss_plot.append(epoch_loss)
        save_model(model, "vqa_dr.pth")
        scheduler.step()
        loader.reset()

    # Both curves share one axis, so label them.
    plt.plot(accuracy_plot, label="accuracy (%)")
    plt.plot(loss_plot, label="loss")
    plt.xlabel("epoch")
    plt.legend()

In [None]:
# Fine-tune the LoRA-wrapped model with the loader, optimizer, loss and
# scheduler configured in the hyperparameter cell.
train(lora_model, optimizer, criterion, scheduler, loader, num_epochs, device)

In [None]:
# Validation iterates the rows after the training range, using the same split
# point as the training loader.
validation_loader = SastaLoader(dataset, 8, sasta_collator, mode = "validation", train_max = loader.train_max)
# Every SastaLoader shuffles the dataset itself, so a second instance would
# otherwise hold a different permutation and its "validation" rows would
# overlap the rows used for training. Share the training loader's shuffled view.
validation_loader.dataset = loader.dataset

In [None]:
def evaluate_model(model, loader, device):
    """Run ``model`` over ``loader`` and print weighted F1 and accuracy.

    Args:
        model: callable as ``model(input_ids, pixel_values, attention_mask)``.
        loader: iterable of ``(pixel_values, questions, labels)`` batches that
            also provides ``len()`` and ``reset()``; it is reset before use.
        device: device the model inputs are moved to.

    Returns:
        ``(y_pred, y_true)``: two lists with one zero-dimensional integer CPU
        tensor per evaluated example (predicted and true class id).
    """
    y_true, y_pred = [], []
    model.eval()
    loader.reset()
    # Inference only: disable autograd so no graph is built or kept in memory.
    with torch.no_grad(), tqdm(total=len(loader), desc="Processing batches", dynamic_ncols=True) as pbar:
        for batch in loader:
            ids = batch[1]['input_ids'].to(device)
            pxlvalues = batch[0].to(device)
            masks = batch[1]['attention_mask'].to(device)
            # The collator returns float labels; keep class ids as integers so
            # they have the same type as the predictions.
            labels = batch[2].to("cpu").long()
            outputs = model(ids, pxlvalues, masks)
            predicted = outputs.argmax(dim=1).to("cpu")
            y_true.extend(labels)
            y_pred.extend(predicted)
            # Advance by the real batch size; the last batch can be shorter.
            pbar.update(labels.size(0))
    f1 = f1_score(y_true, y_pred, average = "weighted")
    accuracy = accuracy_score(y_true, y_pred)
    print(f"F1-score: {f1 : 0.2f}")
    print(f"Accuracy: {accuracy * 100 : 0.2f}%")
    return y_pred, y_true

In [None]:
# Evaluate the fine-tuned model on the validation range; keeps the per-example
# predictions and ground truth for the report cells below.
y_pred, y_true = evaluate_model(lora_model, validation_loader, device)

In [None]:
# Per-class precision/recall/F1 table; classes appear as encoded label ids.
print(classification_report(y_true, y_pred))

In [None]:
# Persist the state_dict of the LoRA-wrapped model after training.
save_model(lora_model, "dinobert_lora.pth")

In [None]:
# Histogram of predictions: how many times each encoded class id was predicted.
label_dict = dict()
for label in y_pred:
    class_id = label.item()
    label_dict[class_id] = label_dict.get(class_id, 0) + 1

# Print the decoded answer next to its count, in first-seen order.
for label in label_dict:
    print(f"{labelEncoder.inverse_transform([label])} : {label_dict[label]}")