In [None]:
# Imports for the whole notebook: standard library first, then third-party.
import os

import matplotlib.pyplot as plt
import numpy as np  # previously aliased to `nn`, which rebound torch's `nn` imported alongside it
import torch
from datasets import load_dataset
from PIL import Image
from torch import nn
from torch.utils.data import DataLoader , Dataset
from tqdm.auto import tqdm
from transformers import BlipConfig , BlipProcessor , BlipForQuestionAnswering , BlipImageProcessor , AutoProcessor

# Number of examples per optimisation step / evaluation batch.
BATCH_SIZE = 32
# os.cpu_count() may return None when the count cannot be determined;
# fall back to 0 so the DataLoader loads data in the main process instead of failing.
NUM_WORKERS = os.cpu_count() or 0

# Use the GPU when one is visible to torch, otherwise stay on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Downloads (or reads from the local cache) every split of the dataset.
dataset = load_dataset("flaviagiammarino/path-vqa")

In [None]:
# Configuration of the pretrained checkpoint, kept for inspection.
config = BlipConfig.from_pretrained("Salesforce/blip-vqa-base")

# Upper bounds on how many rows of each split are used.
TRAIN_SUBSET_SIZE = 10000
VAL_SUBSET_SIZE = 1000

# Clamp to the split length so a shorter split does not make `select` raise.
train_data = dataset["train"].select(range(min(TRAIN_SUBSET_SIZE , len(dataset["train"]))))
val_data = dataset["validation"].select(range(min(VAL_SUBSET_SIZE , len(dataset["validation"]))))

class VQA(Dataset):
    """Map-style dataset that turns visual-question-answering rows into model inputs.

    Each item is the text processor's encoding of the question, extended with
    the processed image under ``"pixel_values"`` and the tokenized answer under
    ``"labels"``.

    Args:
        data: Collection of rows. It must support column access
            (``data["question"]``, ``data["answer"]``), row access
            (``data[idx]["image"]``, an object with ``.convert('RGB')``) and ``len()``.
        segment: Name of the split (e.g. ``"train"``). Stored for reference
            only; it does not influence how items are built.
        text_processor: Callable invoked as ``text_processor(None, question, ...)``
            that returns a mutable mapping of tensors with a leading batch axis.
            It must also expose ``.tokenizer.encode`` for encoding answers.
        image_processor: Callable that returns a mapping containing
            ``"pixel_values"`` with a leading batch axis.
        max_length: Length questions and answers are padded/truncated to.
        image_height: Target image height passed to the image processor.
        image_width: Target image width passed to the image processor.
    """

    def __init__ (self , data , segment , text_processor , image_processor ,
                  max_length = 32 , image_height = 224 , image_width = 224):
        self.data = data
        self.segment = segment
        self.questions = data["question"]
        self.answers = data["answer"]
        self.text_processor = text_processor
        self.image_processor = image_processor
        self.max_length = max_length
        self.image_height = image_height
        self.image_width = image_width

    def __len__(self):
        """Return the number of rows in the wrapped data."""
        return len(self.data)

    def __getitem__(self , idx):
        """Build the model inputs for row ``idx``.

        Returns:
            The mapping produced by the text processor for the question, with
            the batch axis removed from every entry and two extra keys:
            ``"pixel_values"`` (processed image) and ``"labels"`` (answer ids).
        """
        question = self.questions[idx]
        answer = self.answers[idx]
        image = self.data[idx]["image"].convert('RGB')

        image_encoding = self.image_processor(image , do_resize = True ,
                                              size = (self.image_height , self.image_width),
                                              return_tensors = "pt")
        # First positional argument is the image slot; only the question is encoded here.
        encoding = self.text_processor(None , question , padding = "max_length" , truncation = True ,
                                       max_length = self.max_length , return_tensors = "pt")
        # Drop only the leading batch axis. A bare squeeze() would also collapse
        # any other size-1 axis (e.g. the sequence axis when max_length == 1).
        for key in list(encoding.keys()):
            encoding[key] = encoding[key].squeeze(0)
        encoding["pixel_values"] = image_encoding["pixel_values"][0]
        # Answers are padded to the same fixed length as the questions.
        encoding["labels"] = self.text_processor.tokenizer.encode(answer , max_length = self.max_length,
                                                                  padding = "max_length",
                                                                  truncation = True,
                                                                  return_tensors = "pt")[0]
        return encoding

In [None]:
# Processors from the same checkpoint as the model: one is used for the
# question/answer text, the other for the images.
text_processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
image_processor = BlipImageProcessor.from_pretrained("Salesforce/blip-vqa-base")

train_vqa = VQA(data = train_data , segment = "train" ,
               text_processor = text_processor , 
               image_processor = image_processor)
# Built from the validation subset; used as the held-out set.
test_vqa = VQA(data = val_data , segment = "validation" ,
              text_processor = text_processor,
              image_processor = image_processor)

# Reshuffle the training examples every epoch so batches are not always drawn
# in the stored row order. `or 0` guards against a None worker count.
train_dataloader = DataLoader(train_vqa , batch_size = BATCH_SIZE ,
                             num_workers = NUM_WORKERS or 0,
                             shuffle = True)
# Evaluation order does not matter, so keep it deterministic.
test_dataloader = DataLoader(test_vqa , batch_size = BATCH_SIZE , 
                            num_workers = NUM_WORKERS or 0 , shuffle = False)

In [None]:
# Load the pretrained BLIP VQA weights and place the module on the selected device.
model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base").to(device)
# Last expression of the cell: shows the module structure as the cell output.
model

In [None]:
# AdamW over every parameter of the model.
optimizer = torch.optim.AdamW(params = model.parameters() , lr = 3e-4)

# Switch to training mode once, before the epoch loop.
model.train()
for epoch in tqdm(range(10)):
    print(f"Epoch : {epoch + 1}")
    # Scalar loss of every batch seen in this epoch.
    total_loss = []
    for batch in tqdm(train_dataloader):
        # Move each tensor of the collated batch onto the training device.
        batch = {name : tensor.to(device) for name , tensor in batch.items()}

        optimizer.zero_grad()
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()
        optimizer.step()

        # Record the value after the update; the step does not change it.
        total_loss.append(loss.item())
    # Sum (not mean) of the per-batch losses for the epoch.
    print("Loss", sum(total_loss))