In [None]:
import os

from transformers import BlipProcessor, BlipForQuestionAnswering
import torch
from PIL import Image
from torch.utils.data import DataLoader
from tqdm import tqdm
import pickle
import json
from torchmetrics.text import BLEUScore
from statistics import mean

In [None]:
# Load the BLIP VQA checkpoint from the directory the training loop below saves its best
# model to (presumably fine-tuned from Salesforce/blip-vqa-base), plus the base processor.
# Other starting points tried: Salesforce/blip2-opt-2.7b (Blip2*), Salesforce/blip-vqa-base.
MODEL_DIR = "/home/jovyan/vqa_project/baselines/finetuning/blip_vqa_base_tune"
PROCESSOR_NAME = "Salesforce/blip-vqa-base"
GPU_INDEX = 1  # preferred GPU; falls back to cuda:0 when fewer GPUs are visible

# Seed before anything that may draw random numbers (e.g. init of weights missing from the checkpoint).
torch.manual_seed(42)

model = BlipForQuestionAnswering.from_pretrained(MODEL_DIR)
processor = BlipProcessor.from_pretrained(PROCESSOR_NAME)

if torch.cuda.is_available():
    # A hard-coded "cuda:1" raises on single-GPU machines, so clamp to an existing device.
    gpu = GPU_INDEX if torch.cuda.device_count() > GPU_INDEX else 0
    device = torch.device(f"cuda:{gpu}")
else:
    device = torch.device("cpu")
model.to(device)

torch.cuda.empty_cache()

In [None]:
class VQADataset(torch.utils.data.Dataset):
    """VQA samples (image + question -> answer) encoded for BlipForQuestionAnswering.

    Args:
        dataset_path: path to a file whose first line is a JSON array of records, each with
            'question', 'answer' and 'image_id' keys (only that first line is parsed).
        processor: a BLIP processor; called as ``processor(image, text, ...)`` and its
            ``.tokenizer`` is used to encode answers.
        imagespath_split: directory prefix (including the trailing separator) that is
            concatenated with ``image_id + ".jpg"`` to locate each image.
    """

    def __init__(self, dataset_path, processor, imagespath_split):
        with open(dataset_path, 'r') as f:
            # The whole dataset is a single JSON array on the first line; read only that
            # line rather than materialising every line of the file into a list.
            self.dataset = json.loads(f.readline())
        self.processor = processor
        self.imagespath_split = imagespath_split

    def __len__(self):
        return len(self.dataset)

    def _image_path(self, idx):
        """Return the image file path for sample `idx`.

        When the directory prefix contains "val", occurrences of "train" in the stored
        image_id are rewritten to "val" — presumably the ids were stored with the train
        naming for both splits; TODO confirm against the data files.
        """
        image_id = self.dataset[idx]['image_id']
        if "val" in self.imagespath_split:
            image_id = image_id.replace("train", "val")
        return self.imagespath_split + image_id + ".jpg"

    def __getitem__(self, idx):
        """Return the processor encoding for sample `idx` plus a `labels` entry.

        The question is padded/truncated to 60 tokens and the answer to 20 tokens; the
        leading batch dimension added by ``return_tensors='pt'`` is squeezed away so the
        DataLoader can stack samples.
        """
        sample = self.dataset[idx]
        # Use a context manager so the underlying file handle is closed promptly
        # (many DataLoader workers open images concurrently).
        with Image.open(self._image_path(idx)) as raw_image:
            image = raw_image.convert("RGB")

        encoding = self.processor(image, sample['question'], padding="max_length", truncation=True,
                                  return_tensors="pt", max_length=60)
        # `pad_to_max_length` is deprecated and redundant with padding='max_length'.
        encoding["labels"] = self.processor.tokenizer.encode(
            sample['answer'], padding='max_length', truncation=True, max_length=20, return_tensors='pt'
        )
        # remove batch dimension
        for key, value in encoding.items():
            encoding[key] = value.squeeze()
        return encoding

In [None]:
# One dataset per split; each split's images live in its own directory under the baselines root.
_baselines_dir = "/home/jovyan/vqa_project/baselines"

train_dataset = VQADataset(
    dataset_path=f"{_baselines_dir}/VQAv2_train_translation.jsonl",
    processor=processor,
    imagespath_split=f"{_baselines_dir}/trainvqa/train2014/",
)
valid_dataset = VQADataset(
    dataset_path=f"{_baselines_dir}/VQAv2_val_translation.jsonl",
    processor=processor,
    imagespath_split=f"{_baselines_dir}/valvqa/val2014/",
)

In [None]:
# Training shuffles each epoch; validation keeps a fixed order and a smaller batch.
batch_size = 36
valid_batch_size = 24
loader_options = {"pin_memory": True, "num_workers": 20}

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, **loader_options)
valid_dataloader = DataLoader(valid_dataset, batch_size=valid_batch_size, shuffle=False, **loader_options)

In [None]:
import wandb

# Never hard-code credentials in a notebook: read the key from the environment instead.
# NOTE(review): the key that used to be hard-coded here is exposed and should be rotated.
# If WANDB_API_KEY is unset, key=None presumably makes wandb fall back to stored
# credentials or an interactive prompt — confirm for the installed wandb version.
wandb.login(key=os.environ.get("WANDB_API_KEY"))
wandb.init(project="blip_finetuning", sync_tensorboard=True, name="")

In [None]:
# Restore the per-epoch history pickled by an earlier fine-tuning run (the training loop
# below writes tuples of (mean train loss, mean eval loss, lr)). Only unpickle trusted files.
tracking_path = '/home/jovyan/vqa_project/baselines/tracking_information.pkl'
with open(tracking_path, 'rb') as tracking_file:
    tracking = pickle.load(tracking_file)

In [None]:
# Unpack the most recent epoch's record, then replace eval_loss with the lowest
# eval loss (index 1 of each record) seen across the whole loaded history.
last_record = tracking[-1]
epoch_loss, eval_loss, lr = last_record
eval_loss = min([record[1] for record in tracking])

In [None]:
# Optimiser, LR schedule, metrics and AMP scaler for (resumed) fine-tuning.
# 3.24e-5 == 4e-5 * 0.9**2, i.e. the earlier 4e-5 after two ExponentialLR(gamma=0.9) steps —
# presumably resuming a previous run; TODO confirm.
optimizer = torch.optim.AdamW(model.parameters(), lr=3.24e-05)
# `verbose` is deprecated in recent PyTorch (passing it emits a warning); last_epoch=-1 is the default.
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
bleu_scorers = [BLEUScore(n_gram=n) for n in (1, 2, 3)]

num_epochs = 50
patience = 10  # early-stopping patience, in epochs

# The training loop compares against the *summed* per-batch eval loss, so the best mean
# eval loss (presumably from the previous run) is rescaled by the number of val batches.
# Use float("inf") here when training from scratch.
best_prev_mean_eval_loss = 0.18760925092543795
min_eval_loss = best_prev_mean_eval_loss * len(valid_dataloader)
early_stopping_hook = 0
tracking_information = []
scaler = torch.cuda.amp.GradScaler()

# Per-epoch BLEU-1/2/3 values.
bl1, bl2, bl3 = [], [], []

In [None]:
# Fine-tune with fp16 autocast + GradScaler, validate every epoch, checkpoint when the
# summed validation loss improves, and stop after `patience` epochs without improvement.
SAVE_DIR = "/home/jovyan/vqa_project/baselines/finetuning/blip_vqa_base_tune"


def _batch_to_device(batch):
    """Return (input_ids, pixel_values, attention_mask, labels) of a collated batch on `device`."""
    return tuple(batch[key].to(device)
                 for key in ("input_ids", "pixel_values", "attention_mask", "labels"))


def _forward(input_ids, pixel_values, attention_mask, labels):
    """Mixed-precision forward pass; the returned model output carries `.loss`."""
    with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
        return model(input_ids=input_ids,
                     pixel_values=pixel_values,
                     attention_mask=attention_mask,
                     labels=labels)


for epoch in range(num_epochs):
    # ---- training ----
    model.train()
    epoch_loss = 0
    for batch in tqdm(train_dataloader, desc='Training batch: ...'):
        input_ids, pixel_values, attention_masked, labels = _batch_to_device(batch)
        optimizer.zero_grad()
        loss = _forward(input_ids, pixel_values, attention_masked, labels).loss

        # Scale the loss so fp16 gradients do not underflow; scaler.step unscales before stepping.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        batch_loss = loss.item()
        epoch_loss += batch_loss
        wandb.log({"loss": batch_loss})

    # ---- validation: no backward pass, so don't build autograd graphs ----
    model.eval()
    eval_loss = 0
    with torch.no_grad():
        for batch in tqdm(valid_dataloader, desc='Validating batch: ...'):
            input_ids, pixel_values, attention_masked, labels = _batch_to_device(batch)
            batch_loss = _forward(input_ids, pixel_values, attention_masked, labels).loss.item()
            eval_loss += batch_loss
            # Log the per-batch loss (previously the running cumulative sum was logged).
            wandb.log({"val_loss": batch_loss})

        # BLEU is computed on the last validation batch only, as a cheap per-epoch sample.
        real = processor.batch_decode(labels, skip_special_tokens=True)
        out = model.generate(input_ids, pixel_values, attention_masked)
        pred = processor.batch_decode(out, skip_special_tokens=True)

    for history, scorer in zip((bl1, bl2, bl3), bleu_scorers):
        history.append(scorer(pred, real))

    # Logged BLEU values are the mean over all epochs so far.
    wandb.log({
        "bleu_1": mean([tensor.item() for tensor in bl1]),
        "bleu_2": mean([tensor.item() for tensor in bl2]),
        "bleu_3": mean([tensor.item() for tensor in bl3])
    })
    print(real[0], pred[0])

    mean_train_loss = epoch_loss / len(train_dataloader)
    mean_eval_loss = eval_loss / len(valid_dataloader)
    current_lr = optimizer.param_groups[0]["lr"]
    tracking_information.append((mean_train_loss, mean_eval_loss, current_lr))
    print("Epoch: {} - Training loss: {} - Eval Loss: {} - LR: {}".format(
        epoch + 1, mean_train_loss, mean_eval_loss, current_lr))
    scheduler.step()

    # min_eval_loss is on the same summed scale as eval_loss.
    if eval_loss < min_eval_loss:
        # `from_pt` is a from_pretrained() flag, not a save_pretrained() option, so it is omitted.
        model.save_pretrained(SAVE_DIR)
        print(SAVE_DIR)
        min_eval_loss = eval_loss
        early_stopping_hook = 0
    else:
        early_stopping_hook += 1
        # Stop after exactly `patience` consecutive non-improving epochs.
        if early_stopping_hook >= patience:
            break

# Close the file deterministically so the history is flushed to disk.
with open("tracking_information.pkl", "wb") as tracking_file:
    pickle.dump(tracking_information, tracking_file)
print("The finetuning process has done!")

In [None]:
# Save the in-memory weights as they stand after the training loop (the last epoch run,
# not necessarily the best-validation checkpoint). `from_pt` is a from_pretrained() flag
# that save_pretrained() does not use, so it is dropped.
model.save_pretrained("/home/jovyan/vqa_project/baselines/saved_models/finetune_blip2/another")

In [None]:
# Forward-only pass over the training set without an attention mask and without any
# optimizer step — looks like a sanity check; TODO confirm intent or delete this cell.
# No gradients are needed, so skip building the autograd graph for each batch.
with torch.no_grad():
    for batch in tqdm(train_dataloader, desc='Training batch: ...'):
        input_ids = batch['input_ids'].to(device)
        pixel_values = batch['pixel_values'].to(device)
        labels = batch['labels'].to(device)

        with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
            outputs = model(input_ids=input_ids,
                            pixel_values=pixel_values,
                            labels=labels)
            

In [None]:
# Pick one training sample (index 1), presumably for a quick inference check.
with open("/home/jovyan/vqa_project/baselines/VQAv2_train_translation.jsonl", 'r') as f:
    # Only the first line is parsed: it holds the dataset as a single JSON array, so there is
    # no need to read every line of the file into a list.
    infdataset = json.loads(f.readline())
sample = infdataset[1]
question = sample['question']
image_path = "/home/jovyan/vqa_project/baselines/trainvqa/train2014/" + sample['image_id'] + ".jpg"