# Environment setup

## Install necessary tools, libraries, etc.

In [None]:
!gdown 1--CmXocqkrcPR-bbSDztnxBM9dBC8uUJ # https://drive.google.com/file/d/1--CmXocqkrcPR-bbSDztnxBM9dBC8uUJ/view

In [None]:
!unzip /content/IconDomainVQAData.zip

In [None]:
!pip install -q peft transformers bitsandbytes datasets

## Import important libraries

In [None]:
import os, json
from PIL import Image

from tqdm import tqdm

import torch
from torch.utils.data import Dataset, DataLoader
from torch import optim

from peft import LoraConfig, get_peft_model

from transformers import BlipProcessor, BlipForQuestionAnswering

from datasets import load_dataset, DatasetDict

## Config

In [None]:
# LoRA adapter settings handed to `get_peft_model` further down.
lora_config = LoraConfig(
    target_modules=["query", "key"],  # module names that receive adapters
    r=16,                             # adapter rank
    lora_alpha=32,                    # LoRA scaling parameter
    lora_dropout=0.05,                # dropout applied inside the adapters
    bias="none",                      # bias handling mode
)

In [None]:
# Hugging Face Hub identifier of the pretrained BLIP VQA checkpoint.
model_id = "Salesforce/blip-vqa-base"

# Root folder of the extracted dataset.
data_path = '/content/IconDomainVQAData'

# Experiment Setup

## Dataset

In [None]:
import pandas as pd

# Quick look at the training annotations (one JSON object per line).
# The path is derived from `data_path` so it follows the Config cell instead of
# repeating the dataset location as a second hard-coded string.
file_path = os.path.join(data_path, 'train.jsonl')
jsonObj = pd.read_json(path_or_buf=file_path, lines=True)

# Show only the first rows rather than dumping the whole frame.
jsonObj.head()

In [None]:
class VQADataset(Dataset):
    """Torch dataset that turns VQA annotation rows into model-ready tensors.

    Each row of ``dataset`` must provide the keys ``'question'``, ``'answer'``
    and ``'pid'``. The image for a row is read from
    ``<data_path>/train_fill_in_blank/train_fill_in_blank/<pid>/image.png``.

    Args:
        dataset: Indexable collection of annotation rows (supports ``len`` and
            integer indexing).
        processor: Callable ``processor(image, text, ...)`` that also exposes a
            ``tokenizer`` attribute with an ``encode`` method.
        data_path: Root directory of the extracted dataset.
    """

    def __init__(self, dataset, processor, data_path):
        self.dataset = dataset
        self.processor = processor
        self.data_path = data_path

    def __len__(self):
        """Return the number of annotation rows."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Build the encoded sample for row ``idx``.

        Returns:
            The processor's encoding for the image/question pair with an extra
            ``'labels'`` entry holding the answer token ids (fixed length 8).
            The leading batch dimension of size 1 is removed from every tensor.
        """
        # Fetch the row once; indexing the dataset repeatedly re-reads the row.
        sample = self.dataset[idx]
        question = sample['question']
        answer = sample['answer']
        image_id = sample['pid']

        image_path = os.path.join(self.data_path, f"train_fill_in_blank/train_fill_in_blank/{image_id}/image.png")
        # `convert` returns a new image, so the file handle can be released
        # as soon as the conversion is done.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")

        encoding = self.processor(image, question, padding="max_length", truncation=True, return_tensors="pt")
        # `padding="max_length"` replaces the deprecated `pad_to_max_length=True`;
        # explicit truncation guarantees every label tensor has exactly 8 tokens,
        # which the default batch collation requires.
        labels = self.processor.tokenizer.encode(answer,
                                                 max_length=8,
                                                 padding="max_length",
                                                 truncation=True,
                                                 return_tensors="pt")
        encoding['labels'] = labels

        # Drop only the leading batch dimension; a bare squeeze() would also
        # collapse any other dimension that happens to have size 1.
        for k, v in encoding.items():
            encoding[k] = v.squeeze(0)

        return encoding

In [None]:
# Read the JSON-lines training annotations with the `datasets` JSON loader.
train_jsonl = f"{data_path}/train.jsonl"
ds = load_dataset("json", data_files=train_jsonl)
ds

## Model

In [None]:
# `model_id` is defined once in the Config section; it is deliberately not
# reassigned here so that changing the checkpoint there is not silently undone.
processor = BlipProcessor.from_pretrained(model_id)
model = BlipForQuestionAnswering.from_pretrained(model_id)

In [None]:
# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Wrap the base model with the adapters described by `lora_config`
# and move it to the selected device.
model = get_peft_model(model, lora_config)
model.to(device)

# Report how many parameters are trainable after wrapping.
model.print_trainable_parameters()

In [None]:
# Hold out part of the annotated data for validation. Validating on the very
# rows the model is trained on cannot reveal over-fitting and makes the
# "best validation loss" checkpoint selection meaningless.
# The seed keeps the split reproducible across kernel restarts.
split_ds = ds["train"].train_test_split(test_size=0.1, seed=42)

# Both subsets come from the training annotations, so their images live under
# the same folder that VQADataset reads from.
train_dataset = VQADataset(dataset=split_ds["train"],
                           processor=processor,
                           data_path=data_path)
val_dataset = VQADataset(dataset=split_ds["test"],
                         processor=processor,
                         data_path=data_path)

In [None]:
batch_size = 8

# Options shared by the training and validation loaders.
loader_kwargs = dict(batch_size=batch_size, pin_memory=True)

# Training batches are reshuffled every epoch; validation keeps a fixed order.
train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_dataloader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

## Training


In [None]:
# AdamW over the model parameters.
optimizer = optim.AdamW(model.parameters(),
                        lr=4e-5)

# Multiply the learning rate by 0.9 after every epoch. The `verbose` argument
# is deprecated in recent PyTorch releases and is not needed here: the training
# loop already prints the current learning rate each epoch.
scheduler = optim.lr_scheduler.ExponentialLR(optimizer,
                                             gamma=0.9)

n_epochs = 3

# Lowest validation loss seen so far; used to decide when to checkpoint.
min_val_loss = float("inf")

# Gradient scaler paired with the float16 autocast used in the training loop.
scaler = torch.cuda.amp.GradScaler()

In [None]:
# Fine-tuning loop: one training pass and one validation pass per epoch.
# The adapter is checkpointed to ./save_model whenever the validation loss improves.
for epoch in range(n_epochs):

    # ---- Training ----
    train_loss = []
    model.train()
    for batch in tqdm(train_dataloader, desc=f"Training Batch {epoch+1}"):
        input_ids = batch.pop("input_ids").to(device)
        pixel_values = batch.pop("pixel_values").to(device)
        attention_mask = batch.pop('attention_mask').to(device)
        labels = batch.pop('labels').to(device)

        with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
            # The attention mask is passed here exactly as in validation, so the
            # padded question tokens are ignored in both phases.
            outputs = model(input_ids=input_ids,
                            pixel_values=pixel_values,
                            attention_mask=attention_mask,
                            labels=labels)

        loss = outputs.loss
        train_loss.append(loss.item())

        # Backward propagation with loss scaling for float16 autocast.
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

    # ---- Validation ----
    val_loss = []
    model.eval()
    # No gradients are needed for evaluation; skipping them saves memory and time.
    with torch.no_grad():
        for batch in tqdm(val_dataloader, desc=f"Validating Batch {epoch+1}"):
            input_ids = batch.pop("input_ids").to(device)
            pixel_values = batch.pop("pixel_values").to(device)
            attention_mask = batch.pop('attention_mask').to(device)
            labels = batch.pop('labels').to(device)

            with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
                outputs = model(input_ids=input_ids,
                                pixel_values=pixel_values,
                                attention_mask=attention_mask,
                                labels=labels)

            val_loss.append(outputs.loss.item())

    # Epoch summary, computed once after both passes have finished.
    avg_train_loss = sum(train_loss)/len(train_loss)
    avg_val_loss = sum(val_loss)/len(val_loss)
    lr_per_epoch = optimizer.param_groups[0]["lr"]
    print(f"Epoch: {epoch + 1} - Training Loss: {avg_train_loss} - Eval Loss: {avg_val_loss} - LR: {lr_per_epoch}")

    scheduler.step()

    # Checkpoint only when the validation loss improves. The running minimum
    # must be written back to `min_val_loss`; otherwise every epoch would
    # overwrite the checkpoint regardless of quality.
    if avg_val_loss < min_val_loss:
        model.save_pretrained('./save_model')
        print ("Saved model to ./save_model")
        min_val_loss = avg_val_loss

# Store the processor next to the model so both can be reloaded from one folder.
processor.save_pretrained("./save_model")

## Predict

In [None]:
# Reload the processor and the fine-tuned model from the checkpoint folder.
retrain_model_dir = './save_model'

processor = BlipProcessor.from_pretrained(retrain_model_dir)

model = BlipForQuestionAnswering.from_pretrained(retrain_model_dir)
model = model.to(device)

In [None]:
test_data_dir = f"{data_path}/test_data/test_data"

# os.listdir returns entries in arbitrary order; sorting makes "the first
# sample" the same one on every run and every machine.
samples = sorted(os.listdir(test_data_dir))
sample_path = os.path.join(test_data_dir, samples[0])
json_path = os.path.join(sample_path, "data.json")

# Read the sample's metadata; the encoding is stated explicitly so the result
# does not depend on the platform's default locale.
with open(json_path, "r", encoding="utf-8") as json_file:
    data = json.load(json_file)

question = data["question"]
image_id = data["id"]

In [None]:
# Display the question text read from the selected sample's data.json.
question

In [None]:
# Locate the sample's image by its id, open it and normalise it to RGB.
image_path = os.path.join(test_data_dir, str(image_id), "image.png")
image = Image.open(image_path)
image = image.convert("RGB")
image

In [None]:
# Encode the image/question pair and move it to the model's device.
# Floating-point inputs are cast to the model's own dtype: the model above is
# loaded without a half-precision dtype, so forcing the inputs to float16
# would make the forward pass fail with a dtype mismatch.
encoding = processor(image,
                     question,
                     return_tensors="pt").to(device, model.dtype)

# Generate the answer token ids for the single sample.
outputs = model.generate(**encoding)

# Turn the generated ids back into text, dropping special tokens.
generated_text = processor.decode(outputs[0], skip_special_tokens=True)

generated_text