# In [None]:
from torch.utils.data import Dataset, DataLoader
import glob
import os
from PIL import Image
import pandas as pd
import torch
import time
import json
from bert_score import score
from transformers import BlipProcessor, BlipForQuestionAnswering
from torchtext.data.metrics import bleu_score
from torchmetrics.text.rouge import ROUGEScore
from sklearn.metrics import accuracy_score

# Run on the GPU when torch reports one as available, otherwise on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Reference timestamp; the total running time is measured against it.
start_time = time.time()

# Dataset locations, all derived from the dataset root.
path = '/content/mvqa_dataset/'
train_path = f"{path}train"
test_path_rad = f"{path}test/vqa-rad"
test_path_slake = f"{path}test/slake"
test_path_mvqa = f"{path}test/mvqa"

# Run settings.
each_n_epoch, version = 2, 1
im_size = (224, 224)
epochs = 5
should_save_weights = True

# Per-version output root; results and weights each get their own subfolder.
model_path = f"{path}runs/{version}"
for subfolder in ("results", "weights"):
    os.makedirs(f"{model_path}/{subfolder}", exist_ok=True)

# Row limits applied when `sample` is enabled.
sample = True
train_limit = 3000
# NOTE(review): the original author left an open question on this value
# ("what about this?") - confirm the intended test sample size.
test_limit = 500

# Batch size handed to the training DataLoader.
dataset_batchsize = 4

# Image directories of the training split and the three test splits.
train_img_path = f"{train_path}/images"
test_img_rad = f"{test_path_rad}/images"
test_img_slake = f"{test_path_slake}/images"
test_img_mvqa = f"{test_path_mvqa}/images"

# File names found (recursively, bottom-up) under each image directory.
existing_train_img_names = [
    fname
    for _root, _dirs, fnames in os.walk(train_img_path, topdown=False)
    for fname in fnames
]

existing_test_img_rad_names = [
    fname
    for _root, _dirs, fnames in os.walk(test_img_rad, topdown=False)
    for fname in fnames
]

existing_test_img_slake_names = [
    fname
    for _root, _dirs, fnames in os.walk(test_img_slake, topdown=False)
    for fname in fnames
]

existing_test_img_mvqa_names = [
    fname
    for _root, _dirs, fnames in os.walk(test_img_mvqa, topdown=False)
    for fname in fnames
]

# Read the CSV files, keeping only rows whose image file exists on disk.
def _load_split(csv_path, index_col, image_names, limit):
    """Load one split's CSV, restricted to rows indexed by an existing image.

    Args:
        csv_path: Path of the CSV file to read.
        index_col: Column used as the DataFrame index (holds image names).
        image_names: Index labels to keep; all other rows are dropped.
        limit: Maximum number of rows kept when the module-level ``sample``
            flag is set.

    Returns:
        The filtered DataFrame. When ``sample`` is set, it is a random sample
        of at most ``limit`` rows, sorted by "ID".
    """
    frame = pd.read_csv(csv_path, index_col=index_col).filter(
        items=image_names, axis=0
    )
    if sample:
        # Clamp the sample size: DataFrame.sample raises ValueError when asked
        # for more rows than the frame holds (without replacement).
        frame = frame.sample(min(limit, len(frame))).sort_values(by=["ID"])
    return frame


train_df = _load_split(
    train_path + "/train_EN.csv", "name", existing_train_img_names, train_limit
)
print(f"- read {len(train_df)} train images")

test_df_rad = _load_split(
    test_path_rad + "/radiologytestdata.csv",
    "ID",
    existing_test_img_rad_names,
    test_limit,
)
print(f"- read {len(test_df_rad)} rad test images")

test_df_slake = _load_split(
    test_path_slake + "/radiologytestdata.csv",
    "ID",
    existing_test_img_slake_names,
    test_limit,
)
print(f"- read {len(test_df_slake)} slake test images")

test_df_mvqa = _load_split(
    test_path_mvqa + "/radiologytestdata.csv",
    "ID",
    existing_test_img_mvqa_names,
    test_limit,
)
print(f"- read {len(test_df_mvqa)} mvqa test images")

# Image names per split, in the same row order as the corresponding DataFrame.
train_img_names = train_df.index.values.tolist()
test_img_rad_names = test_df_rad.index.values.tolist()
test_img_slake_names = test_df_slake.index.values.tolist()
test_img_mvqa_names = test_df_mvqa.index.values.tolist()

class VQADataset(Dataset):
    """Dataset pairing each image with its question and answer encoding.

    Args:
        folder_path: Directory containing the image files.
        image_list: Image file names; entry ``i`` belongs to row ``i`` of
            ``dataset_df``.
        dataset_df: DataFrame with 'question' and 'answer' columns, accessed
            positionally.
        processor: Callable invoked with ``images=``, ``text=``, ``padding=``
            and ``return_tensors=`` that returns a mapping of tensors.
    """

    def __init__(self, folder_path, image_list, dataset_df, processor):
        self.folder_path = folder_path
        self.image_list = image_list
        self.processor = processor
        self.dataset_df = dataset_df

    def __len__(self):
        """Return the number of images (and therefore samples)."""
        return len(self.image_list)

    def __getitem__(self, idx):
        """Return the encoding for sample ``idx``.

        The question and the answer are encoded together as a two-element
        text batch; the question ids become 'input_ids' and the answer ids
        become 'labels', each truncated to 512 entries.
        """
        img_name = self.image_list[idx]
        # The directory is stored as ``folder_path``; the previous code read a
        # non-existent ``image_path`` attribute and raised AttributeError.
        img_path = self.folder_path + '/' + img_name
        img = Image.open(img_path).convert("RGB")
        row = self.dataset_df.iloc[idx]
        q = row['question']
        ans = row['answer']
        encoding = self.processor(images=img, text=[q, str(ans)], padding="max_length", return_tensors="pt")
        encoding = {k: v.squeeze() for k, v in encoding.items()}
        # Index 0 is the question, index 1 the answer (order of the text list).
        encoding['labels'] = encoding['input_ids'][1][:512]
        encoding['input_ids'] = encoding['input_ids'][0][:512]
        # Kept as a list with one truncated mask per text entry.
        encoding['attention_mask'] = [entry[:512] for entry in encoding['attention_mask']]
        return encoding

# Load the pretrained BLIP VQA processor and model.
processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
# Move the model to the detected device instead of a hard-coded "cuda", so the
# script also starts on hosts without a GPU.
model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base").to(device)

# Build the training dataset and a shuffling dataloader over it.
train_dataset = VQADataset(folder_path=train_img_path, image_list=train_img_names, dataset_df=train_df, processor=processor)
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=dataset_batchsize)

"""Predicting"""
def pred_vqa(image, question):
    """Answer one question about one image with the module-level model.

    The image and question are encoded by the module-level ``processor``,
    moved to ``device``, and passed to ``model.generate`` with a maximum
    length of 50. The first decoded sequence is returned with special tokens
    removed.
    """
    encoded = processor(images=image, text=question, return_tensors="pt").to(device)
    output_ids = model.generate(
        pixel_values=encoded.pixel_values,
        input_ids=encoded.input_ids,
        max_length=50,
    )
    decoded = processor.batch_decode(output_ids, skip_special_tokens=True)
    return decoded[0]

def predict_test_dataframe(image_path, image_list, df, dataset_name, epoch):
    """Predict an answer for every row of a test split and save the results.

    Args:
        image_path: Directory containing the split's images.
        image_list: Image file names, aligned row-by-row with ``df``.
        df: Test DataFrame with 'question' and 'answer_type' columns.
        dataset_name: Short split name used in the output file name.
        epoch: Epoch number used in the output file name.

    Returns:
        DataFrame with columns 'ID', 'question', 'answer' (the predictions)
        and 'answer_type'. The same frame is written, without index or
        header, to the results folder.
    """
    questions = df['question'].tolist()
    preds = []
    for name, q in zip(image_list, questions):
        # Close each file handle promptly and convert to RGB, as the training
        # dataset does before calling the processor.
        with Image.open(image_path + "/" + name) as img:
            preds.append(pred_vqa(img.convert("RGB"), q))

    filename = f"e-{epoch}-test-{dataset_name}-caption-v{version}.csv"
    df_pred = pd.DataFrame(
        data={
            "ID": image_list,
            "question": questions,
            "answer": preds,
            "answer_type": df['answer_type'].tolist(),
        }
    )

    df_pred.to_csv(
        model_path + "/results/" + filename, index=False, header=None
    )

    # Return the predictions. The previous code returned the ground-truth
    # ``df``, so accuracy was computed by comparing the truth with itself.
    return df_pred

def _normalized_answers(frame, answer_type=None):
    """Return lowercased string answers, optionally limited to one answer_type."""
    if answer_type is not None:
        frame = frame.loc[frame['answer_type'] == answer_type]
    # str() guards against non-string answers (numbers, NaN) in either frame.
    return [str(ans).lower() for ans in frame['answer'].tolist()]


def _exact_match_accuracy(true_ans, pred_ans):
    """Return accuracy_score for the two lists, or NaN when there is nothing to score."""
    if not true_ans and not pred_ans:
        return float("nan")
    return accuracy_score(true_ans, pred_ans)


def calculate_accuracy(true_df, pred_df):
    """Compute exact-match accuracy over all, OPEN and CLOSED questions.

    Answers are compared case-insensitively after conversion to ``str``.
    Both frames need 'answer' and 'answer_type' columns; each frame is
    filtered by its own 'answer_type' column.

    Args:
        true_df: DataFrame holding the reference answers.
        pred_df: DataFrame holding the predicted answers, row-aligned with
            ``true_df``.

    Returns:
        Tuple ``(acc_all, acc_open, acc_closed)``. A subset with no rows
        yields NaN.
    """
    acc_all = _exact_match_accuracy(
        _normalized_answers(true_df), _normalized_answers(pred_df)
    )
    acc_open = _exact_match_accuracy(
        _normalized_answers(true_df, "OPEN"), _normalized_answers(pred_df, "OPEN")
    )
    acc_closed = _exact_match_accuracy(
        _normalized_answers(true_df, "CLOSED"), _normalized_answers(pred_df, "CLOSED")
    )

    return acc_all, acc_open, acc_closed


def predict_and_save(model, epoch):
    """Predict on the RAD, MVQA and SLAKE test splits and save accuracy scores.

    Args:
        model: Model switched to evaluation mode while predictions are made.
            Prediction itself goes through ``pred_vqa``, which uses the
            module-level ``model``; the only call site in this file passes
            that same object.
        epoch: Epoch number used in the output file names.

    Side effects:
        Writes one prediction CSV per split (via ``predict_test_dataframe``)
        and one score CSV into the results folder.
    """
    # Predict in evaluation mode, then restore whatever mode the model was in
    # so an enclosing training loop continues unchanged.
    was_training = model.training
    model.eval()
    try:
        pred_df_rad = predict_test_dataframe(test_img_rad, test_img_rad_names, test_df_rad, 'rad', epoch)
        pred_df_mvqa = predict_test_dataframe(test_img_mvqa, test_img_mvqa_names, test_df_mvqa, 'mvqa', epoch)
        pred_df_slake = predict_test_dataframe(test_img_slake, test_img_slake_names, test_df_slake, 'slake', epoch)
    finally:
        model.train(was_training)

    # Accuracy per split: overall, OPEN questions, CLOSED questions.
    rad_all, rad_open, rad_closed = calculate_accuracy(test_df_rad, pred_df_rad)
    slake_all, slake_open, slake_closed = calculate_accuracy(test_df_slake, pred_df_slake)
    mvqa_all, mvqa_open, mvqa_closed = calculate_accuracy(test_df_mvqa, pred_df_mvqa)
    score_df = pd.DataFrame(data={'All': [rad_all, slake_all, mvqa_all],
                                  'Open': [rad_open, slake_open, mvqa_open],
                                  "Closed": [rad_closed, slake_closed, mvqa_closed]},
                            index=['RAD', 'SLAKE', 'MVQA'])

    filename = f"e-{epoch}-test-scores-v{version}.csv"
    score_df.to_csv(
        model_path + "/results/" + filename
    )

"""Training"""
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

model.to(device)

model.train()

for epoch in range(epochs):
    print("Epoch:", epoch+1)
    train_loss = 0.0
    for batch_num, batch in enumerate(train_dataloader, start=1):
        print("Batch: ", str(batch_num))
        torch.cuda.empty_cache()
        input_ids = batch.pop("input_ids").to(device)
        pixel_values = batch.pop("pixel_values").to(device)
        labels = batch.pop("labels").to(device)

        # Supervise with the answer token ids. The previous code passed the
        # question ids (``input_ids``) as labels and never used ``labels``.
        # NOTE(review): input_ids are padded to max_length by the dataset but
        # no attention_mask is forwarded - confirm whether it should be.
        outputs = model(input_ids=input_ids,
                        pixel_values=pixel_values,
                        labels=labels)

        loss = outputs.loss
        # Weight by batch size so the epoch average is per sample.
        train_loss += loss.item() * input_ids.size(0)

        loss.backward()

        optimizer.step()
        optimizer.zero_grad()

    # Compute the average loss for the training dataset
    train_loss /= len(train_dataloader.dataset)

    # Print the training loss for each epoch
    print('Epoch [{}/{}], Train Loss: {:.4f}'
          .format(epoch + 1, epochs, train_loss))

    # Evaluate after the first epoch and then every `each_n_epoch` epochs.
    if epoch == 0 or (epoch + 1) % each_n_epoch == 0:
        predict_and_save(model, epoch + 1)
        # Honour the configuration flag; weights used to be saved always.
        if should_save_weights:
            os.makedirs(model_path + "/weights/epoch" + str(epoch+1), exist_ok=True)
            model.save_pretrained(model_path + "/weights/epoch" + str(epoch+1))




# Persist the total running time (in minutes) together with the run version.
with open(model_path + "/results/running_time.json", "w+") as out_file:
    run_summary = {
        "task": "caption",
        "total_running_time_mins": (time.time() - start_time) / 60,
        "version": version,
    }
    json.dump(run_summary, out_file)


