In [1]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import warnings
warnings.filterwarnings("ignore")

In [3]:
import numpy as np
import sys
from os import listdir, makedirs, getcwd, remove
from os.path import isfile, join, abspath, exists, isdir, expanduser
from scipy.io import loadmat
from PIL import Image
import pandas as pd

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transforms


import pickle
from tqdm import tqdm
import json
import random
from collections import Counter

from transformers import get_linear_schedule_with_warmup
from transformers import AutoTokenizer, AutoModel

In [4]:
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.autograd as autograd
from torch.utils.data import Dataset, DataLoader
import torch.utils.data as data
import torch.nn.functional as F
import torch.optim as optim
import itertools

import torchvision
from torchvision import transforms, datasets, models
from torch import Tensor

from torch.utils.tensorboard import SummaryWriter

In [5]:
import transformers
import tokenizers
from transformers import BertTokenizer, BertModel

In [6]:
cfg = {
    'max_len': 128,
    'lr': 2e-5,
    'warmup_steps': 5,
    'epochs': 250
}

In [7]:
class medical_dataset(Dataset):
    def __init__(self, config=None, answer_map=None, image_path='./dataset/VQA-Med-2020-Task1-VQAnswering-TrainVal-Sets/VQAMed2020-VQAnswering-TrainingSet/VQAnswering_2020_Train_images/', qa_file="./dataset/VQA-Med-2020-Task1-VQAnswering-TrainVal-Sets/VQAMed2020-VQAnswering-TrainingSet/VQAnswering_2020_Train_QA_pairs.txt", train=True):
            
        assert answer_map!=None
        assert config!=None
        
        self.image_path = image_path
        self.qa_file = qa_file
        self.config = config
        self.train = train
        
        self.answer_map = answer_map
        self.data = pd.read_csv(qa_file, sep='|', names=['imageid', 'question', 'answer'])
           
        # print(Counter(self.data.answer.tolist()))
        
        self.transforms = transforms.Compose([
                                transforms.RandomResizedCrop(224),
                                transforms.ToTensor(),
                                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                            ])
        self.tokenizer = tokenizer = AutoTokenizer.from_pretrained("dmis-lab/biobert-v1.1")

    def process_data(self, text, max_len):
        text = str(text)

        inputs = self.tokenizer.encode_plus(
            text,
            None,
            add_special_tokens=True,
            max_length=max_len,
            truncation=True
        )
        ids = inputs["input_ids"]
        mask = inputs["attention_mask"]
        
        padding_length = max_len - len(ids)
        
        ids = ids + ([0] * padding_length)
        mask = mask + ([0] * padding_length)
        
        return {
            'ids': torch.tensor(ids, dtype=torch.long),
            'mask': torch.tensor(mask, dtype=torch.long),
        }
    
    
    def __getitem__(self, index):
        
        question = self.data.question[index]
        answer = self.data.answer[index]
        image_idx = self.data.imageid[index]
        
        ## add visual feature related steps
        visuals = Image.open(join(self.image_path, f'{image_idx}.jpg'))
        visuals = self.transforms(visuals)
        
        target = torch.from_numpy(np.array([self.answer_map[answer]])).long()
        
        tmp = self.process_data(question, self.config['max_len'])
        question_tokens = {
            'ids': tmp['ids'],
            'mask': tmp['mask'],
        }
        
        return visuals, question_tokens, target

    def __len__(self):
        return len(self.data)

In [8]:
# file_path = './dataset/VQA-Med-2020-Task1-VQAnswering-TrainVal-Sets/VQAMed2020-VQAnswering-TrainingSet/VQAnswering_2020_Train_QA_pairs.txt'
# data = pd.read_csv(file_path, sep='|', names=['imageid', 'question', 'answer'])
# data.head()

In [9]:
# answer_map = {}
# ct = 0

# for i in data.answer.unique():
#     answer_map[i] = ct
#     ct+=1

In [10]:
# with open('./answer_map.pickle', 'wb') as handle:
#     pickle.dump(answer_map, handle, protocol=pickle.HIGHEST_PROTOCOL)
    

In [11]:
with open('./answer_map.pickle', 'rb') as handle:
    answer_map = pickle.load(handle)
    
len(answer_map)

332

In [12]:
TrainData = medical_dataset(config=cfg, answer_map=answer_map, image_path='./dataset/VQA-Med-2020-Task1-VQAnswering-TrainVal-Sets/VQAMed2020-VQAnswering-TrainingSet/VQAnswering_2020_Train_images/', qa_file='./dataset/VQA-Med-2020-Task1-VQAnswering-TrainVal-Sets/VQAMed2020-VQAnswering-TrainingSet/VQAnswering_2020_Train_QA_pairs.txt')
ValData = medical_dataset(config=cfg, answer_map=answer_map, image_path='./dataset/VQA-Med-2020-Task1-VQAnswering-TrainVal-Sets/VQAMed2020-VQAnswering-ValidationSet/VQAnswering_2020_Val_images/', qa_file='./dataset/VQA-Med-2020-Task1-VQAnswering-TrainVal-Sets/VQAMed2020-VQAnswering-ValidationSet/VQAnswering_2020_Val_QA_Pairs.txt', train=False)


In [13]:
%%time
a,b,c = ValData.__getitem__(0)

CPU times: user 31.2 ms, sys: 0 ns, total: 31.2 ms
Wall time: 27.3 ms


In [14]:
TrainDataLoader = DataLoader(TrainData, batch_size=16, shuffle=True, num_workers=4)  # num_workers=0 for windows OS
ValDataLoader = DataLoader(ValData, batch_size=128, shuffle=False, num_workers=4)  # num_workers=0 for windows OS

In [15]:
class mednet(nn.Module):
    def __init__(self, config, max_labels):
        super(mednet, self).__init__()
        self.vision = models.resnet18(pretrained=True)
        num_ftrs = self.vision.fc.in_features
        
        self.vision = nn.Sequential(*list(self.vision.children())[:-1])
        self.bert = AutoModel.from_pretrained("dmis-lab/biobert-v1.1")
        
        
        self.fc1 = nn.Linear(num_ftrs+768, 128)
        self.fc2 = nn.Linear(128, max_labels)
        

    def forward(self, visual=None, ids=None, mask=None):
        vision = self.vision(visual).view((ids.shape[0], -1))
        bert_out = self.bert(ids, mask)
        
        h = torch.cat((vision, bert_out.last_hidden_state[:,0]), dim=1)
        
        h = F.relu(self.fc1(h))
        h = self.fc2(h)
        return h


In [16]:
class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""
    def __init__(self,
                 patience=7,
                 verbose=False,
                 delta=0,
                 path='./models/no_fusion/checkpoint.pt',
                 trace_func=print):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement. 
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): Path for the checkpoint to be saved to.
                            Default: 'checkpoint.pt'
            trace_func (function): trace print function.
                            Default: print            
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.Inf
        self.delta = delta
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):

        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            self.trace_func(
                f'EarlyStopping counter: {self.counter} out of {self.patience}'
            )
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            self.trace_func(
                f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...'
            )
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

In [17]:
class Trainer:
    def __init__(self,
                 trainloader,
                 vallaoder,
                 model_ft,
                 writer=None,
                 testloader=None,
                 checkpoint_path=None,
                 patience=5,
                 feature_extract=True,
                 print_itr=50,
                 config=None):
        self.trainloader = trainloader
        self.valloader = vallaoder
        self.testloader = testloader
        
        self.config=config

        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

        print("==" * 10)
        print("Training will be done on ", self.device)
        print("==" * 10)

        self.model = model_ft        
        if torch.cuda.device_count() > 1:
            print("Let's use", torch.cuda.device_count(), "GPUs!")
            self.model = nn.DataParallel(self.model, device_ids=[0,1])
        
        self.model = self.model.to(self.device)
        
        
        # Observe that all parameters are being optimized
        self.optimizer = optim.RAdam(self.model.parameters(), lr=self.config['lr'])
        
        self.scheduler = get_linear_schedule_with_warmup(self.optimizer, 
                                            num_warmup_steps = len(self.trainloader)*self.config['warmup_steps'], # Default value in run_glue.py
                                            num_training_steps = len(self.trainloader)*self.config['epochs'])

        
        self.criterion = nn.CrossEntropyLoss()
        self.early_stopping = EarlyStopping(patience=patience, verbose=True)
        self.writer = writer
        self.print_itr = print_itr

    def train(self, ep):
        self.model.train()

        running_loss = 0.0

        for en, (visuals, question, target) in tqdm(enumerate(self.trainloader)):
            self.optimizer.zero_grad()
            
            visuals = visuals.to(self.device)
            y = target.squeeze().to(self.device)
            
            ids = question['ids'].to(self.device)
            mask = question['mask'].to(self.device)

            outputs = self.model(visuals, ids, mask)
            loss = self.criterion(outputs, y)
            loss.backward()
            self.optimizer.step()
            self.scheduler.step()

            running_loss += loss.item()
            if self.writer:
                self.writer.add_scalar('Train Loss', running_loss, ep*len(self.trainloader) + en)
            running_loss = 0
            

    def validate(self, ep):
        self.model.eval()
        
        running_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for en, (visuals, question, target) in tqdm(enumerate(self.valloader)):
                visuals = visuals.to(self.device)
                y = target.squeeze().to(self.device)
                
                ids = question['ids'].to(self.device)
                mask = question['mask'].to(self.device)
                
                outputs = self.model(visuals, ids, mask)
                loss = self.criterion(outputs, y)

                y_pred_softmax = torch.log_softmax(outputs, dim = 1)
                _, y_pred_tags = torch.max(y_pred_softmax, dim = 1)
                # self.tmp_sv_ = y_pred_tags
                
                correct += (y_pred_tags.detach().cpu().data.numpy() == y.detach().cpu().data.numpy()).sum()
                total += y_pred_tags.shape[0]
                
                # print statistics
                running_loss += loss.item()
        
        
        return running_loss / len(self.valloader), correct*100/total

    def perform_training(self, total_epoch):
        val_loss, acc = self.validate(0)

        print("[Initial Validation results] Loss: {} \t Acc: {}".format(
            val_loss, acc))

        for i in range(total_epoch):
            self.train(i + 1)
            val_loss, acc = self.validate(i + 1)
            print('[{}/{}] Loss: {} \t Acc: {}'.format(i+1, total_epoch, val_loss, acc))

            if self.writer:
                self.writer.add_scalar('Validation Loss', val_loss, (i + 1))
                self.writer.add_scalar('Validation Acc', acc, (i + 1))

            self.early_stopping(val_loss, self.model)

            # if self.early_stopping.early_stop:
            #     print("Early stopping")
            #     break

        print("=" * 20)
        print("Training finished !!")
        print("=" * 20)


In [18]:
model_ft = mednet(config=cfg, max_labels=len(answer_map))
writer = SummaryWriter('runs/no_fusion_run_2')
trainer = Trainer(TrainDataLoader, ValDataLoader, model_ft, writer=writer, config=cfg)

Training will be done on  cuda
Let's use 2 GPUs!


In [19]:
trainer.perform_training(cfg['epochs'])

4it [00:04,  1.05s/it]

[Initial Validation results] Loss: 5.837615251541138 	 Acc: 0.2



250it [00:46,  5.34it/s]
4it [00:01,  2.70it/s]


[1/250] Loss: 5.836101531982422 	 Acc: 0.2
Validation loss decreased (inf --> 5.836102).  Saving model ...


250it [00:46,  5.33it/s]
4it [00:01,  2.61it/s]


[2/250] Loss: 5.797498106956482 	 Acc: 0.6
Validation loss decreased (5.836102 --> 5.797498).  Saving model ...


250it [00:46,  5.33it/s]
4it [00:01,  2.62it/s]


[3/250] Loss: 5.7608020305633545 	 Acc: 1.4
Validation loss decreased (5.797498 --> 5.760802).  Saving model ...


250it [00:46,  5.33it/s]
4it [00:01,  2.60it/s]


[4/250] Loss: 5.663802623748779 	 Acc: 5.4
Validation loss decreased (5.760802 --> 5.663803).  Saving model ...


250it [00:46,  5.33it/s]
4it [00:01,  2.67it/s]


[5/250] Loss: 5.524681925773621 	 Acc: 5.4
Validation loss decreased (5.663803 --> 5.524682).  Saving model ...


250it [00:46,  5.34it/s]
4it [00:01,  2.68it/s]


[6/250] Loss: 5.364600658416748 	 Acc: 6.8
Validation loss decreased (5.524682 --> 5.364601).  Saving model ...


250it [00:46,  5.33it/s]
4it [00:01,  2.61it/s]


[7/250] Loss: 5.231878638267517 	 Acc: 6.8
Validation loss decreased (5.364601 --> 5.231879).  Saving model ...


250it [00:46,  5.33it/s]
4it [00:01,  2.69it/s]


[8/250] Loss: 5.118140816688538 	 Acc: 9.8
Validation loss decreased (5.231879 --> 5.118141).  Saving model ...


250it [00:46,  5.34it/s]
4it [00:01,  2.61it/s]


[9/250] Loss: 5.0242778062820435 	 Acc: 8.8
Validation loss decreased (5.118141 --> 5.024278).  Saving model ...


250it [00:46,  5.33it/s]
4it [00:01,  2.63it/s]


[10/250] Loss: 4.892484068870544 	 Acc: 11.2
Validation loss decreased (5.024278 --> 4.892484).  Saving model ...


250it [00:46,  5.33it/s]
4it [00:01,  2.62it/s]


[11/250] Loss: 4.81072461605072 	 Acc: 10.0
Validation loss decreased (4.892484 --> 4.810725).  Saving model ...


250it [00:46,  5.33it/s]
4it [00:01,  2.65it/s]


[12/250] Loss: 4.740372657775879 	 Acc: 11.4
Validation loss decreased (4.810725 --> 4.740373).  Saving model ...


250it [00:46,  5.34it/s]
4it [00:01,  2.62it/s]


[13/250] Loss: 4.640388250350952 	 Acc: 12.6
Validation loss decreased (4.740373 --> 4.640388).  Saving model ...


250it [00:46,  5.33it/s]
4it [00:01,  2.63it/s]


[14/250] Loss: 4.580736994743347 	 Acc: 11.6
Validation loss decreased (4.640388 --> 4.580737).  Saving model ...


250it [00:46,  5.34it/s]
4it [00:01,  2.61it/s]


[15/250] Loss: 4.495085120201111 	 Acc: 14.0
Validation loss decreased (4.580737 --> 4.495085).  Saving model ...


250it [00:46,  5.33it/s]
4it [00:01,  2.63it/s]


[16/250] Loss: 4.457839906215668 	 Acc: 13.6
Validation loss decreased (4.495085 --> 4.457840).  Saving model ...


250it [00:46,  5.33it/s]
4it [00:01,  2.66it/s]


[17/250] Loss: 4.4025352001190186 	 Acc: 13.8
Validation loss decreased (4.457840 --> 4.402535).  Saving model ...


250it [00:46,  5.33it/s]
4it [00:01,  2.58it/s]


[18/250] Loss: 4.331705212593079 	 Acc: 14.8
Validation loss decreased (4.402535 --> 4.331705).  Saving model ...


250it [00:46,  5.33it/s]
4it [00:01,  2.61it/s]


[19/250] Loss: 4.2975914478302 	 Acc: 15.2
Validation loss decreased (4.331705 --> 4.297591).  Saving model ...


250it [00:46,  5.33it/s]
4it [00:01,  2.60it/s]


[20/250] Loss: 4.215506553649902 	 Acc: 14.2
Validation loss decreased (4.297591 --> 4.215507).  Saving model ...


250it [00:46,  5.34it/s]
4it [00:01,  2.56it/s]


[21/250] Loss: 4.161027014255524 	 Acc: 16.2
Validation loss decreased (4.215507 --> 4.161027).  Saving model ...


250it [00:46,  5.34it/s]
4it [00:01,  2.64it/s]

[22/250] Loss: 4.17311704158783 	 Acc: 15.6
EarlyStopping counter: 1 out of 5



250it [00:46,  5.33it/s]
4it [00:01,  2.67it/s]


[23/250] Loss: 4.107385694980621 	 Acc: 18.0
Validation loss decreased (4.161027 --> 4.107386).  Saving model ...


250it [00:46,  5.33it/s]
4it [00:01,  2.64it/s]


[24/250] Loss: 4.055293619632721 	 Acc: 18.2
Validation loss decreased (4.107386 --> 4.055294).  Saving model ...


250it [00:46,  5.34it/s]
4it [00:01,  2.61it/s]

[25/250] Loss: 4.061221420764923 	 Acc: 16.4
EarlyStopping counter: 1 out of 5



250it [00:46,  5.34it/s]
4it [00:01,  2.62it/s]


[26/250] Loss: 3.969112277030945 	 Acc: 17.8
Validation loss decreased (4.055294 --> 3.969112).  Saving model ...


250it [00:46,  5.34it/s]
4it [00:01,  2.63it/s]


[27/250] Loss: 3.9382843375205994 	 Acc: 18.6
Validation loss decreased (3.969112 --> 3.938284).  Saving model ...


250it [00:46,  5.33it/s]
4it [00:01,  2.70it/s]


[28/250] Loss: 3.927073657512665 	 Acc: 19.0
Validation loss decreased (3.938284 --> 3.927074).  Saving model ...


250it [00:46,  5.34it/s]
4it [00:01,  2.60it/s]


[29/250] Loss: 3.898852825164795 	 Acc: 18.0
Validation loss decreased (3.927074 --> 3.898853).  Saving model ...


250it [00:46,  5.34it/s]
4it [00:01,  2.60it/s]


[30/250] Loss: 3.866658926010132 	 Acc: 18.6
Validation loss decreased (3.898853 --> 3.866659).  Saving model ...


250it [00:46,  5.33it/s]
4it [00:01,  2.62it/s]


[31/250] Loss: 3.861752688884735 	 Acc: 21.6
Validation loss decreased (3.866659 --> 3.861753).  Saving model ...


250it [00:46,  5.33it/s]
4it [00:01,  2.57it/s]


[32/250] Loss: 3.7956483364105225 	 Acc: 21.4
Validation loss decreased (3.861753 --> 3.795648).  Saving model ...


250it [00:46,  5.33it/s]
4it [00:01,  2.58it/s]

[33/250] Loss: 3.7963501811027527 	 Acc: 19.8
EarlyStopping counter: 1 out of 5



250it [00:46,  5.33it/s]
4it [00:01,  2.60it/s]


[34/250] Loss: 3.7830586433410645 	 Acc: 20.8
Validation loss decreased (3.795648 --> 3.783059).  Saving model ...


250it [00:46,  5.33it/s]
4it [00:01,  2.60it/s]


[35/250] Loss: 3.7026565074920654 	 Acc: 21.8
Validation loss decreased (3.783059 --> 3.702657).  Saving model ...


250it [00:46,  5.33it/s]
4it [00:01,  2.60it/s]

[36/250] Loss: 3.7331173419952393 	 Acc: 21.2
EarlyStopping counter: 1 out of 5



250it [00:46,  5.34it/s]
4it [00:01,  2.67it/s]

[37/250] Loss: 3.7678977251052856 	 Acc: 19.8
EarlyStopping counter: 2 out of 5



250it [00:46,  5.35it/s]
4it [00:01,  2.61it/s]


[38/250] Loss: 3.6464545130729675 	 Acc: 22.2
Validation loss decreased (3.702657 --> 3.646455).  Saving model ...


250it [00:46,  5.32it/s]
4it [00:01,  2.63it/s]

[39/250] Loss: 3.6909760236740112 	 Acc: 23.4
EarlyStopping counter: 1 out of 5



250it [00:46,  5.34it/s]
4it [00:01,  2.59it/s]


[40/250] Loss: 3.620068073272705 	 Acc: 23.4
Validation loss decreased (3.646455 --> 3.620068).  Saving model ...


250it [00:46,  5.33it/s]
4it [00:01,  2.59it/s]

[41/250] Loss: 3.6352208256721497 	 Acc: 21.6
EarlyStopping counter: 1 out of 5



250it [00:46,  5.33it/s]
4it [00:01,  2.62it/s]


[42/250] Loss: 3.6112032532691956 	 Acc: 24.4
Validation loss decreased (3.620068 --> 3.611203).  Saving model ...


250it [00:46,  5.34it/s]
4it [00:01,  2.62it/s]


[43/250] Loss: 3.5258325934410095 	 Acc: 25.4
Validation loss decreased (3.611203 --> 3.525833).  Saving model ...


250it [00:46,  5.34it/s]
4it [00:01,  2.66it/s]

[44/250] Loss: 3.592952787876129 	 Acc: 25.0
EarlyStopping counter: 1 out of 5



250it [00:46,  5.33it/s]
4it [00:01,  2.61it/s]


[45/250] Loss: 3.5132226943969727 	 Acc: 24.2
Validation loss decreased (3.525833 --> 3.513223).  Saving model ...


250it [00:46,  5.33it/s]
4it [00:01,  2.61it/s]

[46/250] Loss: 3.52799254655838 	 Acc: 24.8
EarlyStopping counter: 1 out of 5



250it [00:46,  5.34it/s]
4it [00:01,  2.64it/s]

[47/250] Loss: 3.5883543491363525 	 Acc: 24.4
EarlyStopping counter: 2 out of 5



250it [00:46,  5.34it/s]
4it [00:01,  2.64it/s]


[48/250] Loss: 3.4854727387428284 	 Acc: 24.2
Validation loss decreased (3.513223 --> 3.485473).  Saving model ...


250it [00:46,  5.34it/s]
4it [00:01,  2.64it/s]

[49/250] Loss: 3.516741156578064 	 Acc: 26.0
EarlyStopping counter: 1 out of 5



250it [00:46,  5.33it/s]
4it [00:01,  2.62it/s]

[50/250] Loss: 3.5203317999839783 	 Acc: 25.2
EarlyStopping counter: 2 out of 5



250it [00:46,  5.34it/s]
4it [00:01,  2.65it/s]


[51/250] Loss: 3.434432327747345 	 Acc: 25.6
Validation loss decreased (3.485473 --> 3.434432).  Saving model ...


250it [00:46,  5.34it/s]
4it [00:01,  2.60it/s]


[52/250] Loss: 3.408302664756775 	 Acc: 27.4
Validation loss decreased (3.434432 --> 3.408303).  Saving model ...


250it [00:46,  5.34it/s]
4it [00:01,  2.59it/s]


[53/250] Loss: 3.406317949295044 	 Acc: 26.4
Validation loss decreased (3.408303 --> 3.406318).  Saving model ...


250it [00:46,  5.34it/s]
4it [00:01,  2.62it/s]


[54/250] Loss: 3.3542001843452454 	 Acc: 27.8
Validation loss decreased (3.406318 --> 3.354200).  Saving model ...


250it [00:46,  5.34it/s]
4it [00:01,  2.65it/s]

[55/250] Loss: 3.3700140714645386 	 Acc: 29.0
EarlyStopping counter: 1 out of 5



250it [00:46,  5.33it/s]
4it [00:01,  2.60it/s]

[56/250] Loss: 3.3648563623428345 	 Acc: 29.4
EarlyStopping counter: 2 out of 5



250it [00:46,  5.34it/s]
4it [00:01,  2.57it/s]

[57/250] Loss: 3.4806090593338013 	 Acc: 27.2
EarlyStopping counter: 3 out of 5



250it [00:46,  5.34it/s]
4it [00:01,  2.66it/s]

[58/250] Loss: 3.3637436628341675 	 Acc: 28.6
EarlyStopping counter: 4 out of 5



250it [00:46,  5.33it/s]
4it [00:01,  2.64it/s]


[59/250] Loss: 3.3100754022598267 	 Acc: 29.2
Validation loss decreased (3.354200 --> 3.310075).  Saving model ...


250it [00:46,  5.33it/s]
4it [00:01,  2.63it/s]

[60/250] Loss: 3.36012464761734 	 Acc: 29.8
EarlyStopping counter: 1 out of 5



250it [00:46,  5.33it/s]
4it [00:01,  2.61it/s]

[61/250] Loss: 3.3694770336151123 	 Acc: 30.4
EarlyStopping counter: 2 out of 5



250it [00:46,  5.34it/s]
4it [00:01,  2.63it/s]


[62/250] Loss: 3.2813875675201416 	 Acc: 30.2
Validation loss decreased (3.310075 --> 3.281388).  Saving model ...


250it [00:46,  5.33it/s]
4it [00:01,  2.62it/s]

[63/250] Loss: 3.382315993309021 	 Acc: 29.0
EarlyStopping counter: 1 out of 5



250it [00:46,  5.33it/s]
4it [00:01,  2.61it/s]

[64/250] Loss: 3.343909740447998 	 Acc: 28.6
EarlyStopping counter: 2 out of 5



250it [00:46,  5.33it/s]
4it [00:01,  2.66it/s]

[65/250] Loss: 3.359048366546631 	 Acc: 29.8
EarlyStopping counter: 3 out of 5



250it [00:46,  5.33it/s]
4it [00:01,  2.62it/s]

[66/250] Loss: 3.3229116797447205 	 Acc: 29.2
EarlyStopping counter: 4 out of 5



250it [00:46,  5.34it/s]
4it [00:01,  2.61it/s]

[67/250] Loss: 3.2831071615219116 	 Acc: 31.4
EarlyStopping counter: 5 out of 5



250it [00:46,  5.34it/s]
4it [00:01,  2.64it/s]


[68/250] Loss: 3.268081247806549 	 Acc: 32.4
Validation loss decreased (3.281388 --> 3.268081).  Saving model ...


250it [00:46,  5.33it/s]
4it [00:01,  2.60it/s]

[69/250] Loss: 3.2716123461723328 	 Acc: 35.2
EarlyStopping counter: 1 out of 5



250it [00:46,  5.34it/s]
4it [00:01,  2.64it/s]

[70/250] Loss: 3.2979135513305664 	 Acc: 30.6
EarlyStopping counter: 2 out of 5



250it [00:46,  5.33it/s]
4it [00:01,  2.61it/s]

[71/250] Loss: 3.2898788452148438 	 Acc: 31.4
EarlyStopping counter: 3 out of 5



250it [00:46,  5.33it/s]
4it [00:01,  2.68it/s]


[72/250] Loss: 3.244566261768341 	 Acc: 34.0
Validation loss decreased (3.268081 --> 3.244566).  Saving model ...


250it [00:46,  5.33it/s]
4it [00:01,  2.64it/s]

[73/250] Loss: 3.3067516684532166 	 Acc: 32.0
EarlyStopping counter: 1 out of 5



250it [00:46,  5.33it/s]
4it [00:01,  2.64it/s]

[74/250] Loss: 3.3184532523155212 	 Acc: 31.0
EarlyStopping counter: 2 out of 5



250it [00:46,  5.33it/s]
4it [00:01,  2.61it/s]


[75/250] Loss: 3.2438436150550842 	 Acc: 33.6
Validation loss decreased (3.244566 --> 3.243844).  Saving model ...


250it [00:46,  5.33it/s]
4it [00:01,  2.59it/s]

[76/250] Loss: 3.308763861656189 	 Acc: 32.0
EarlyStopping counter: 1 out of 5



250it [00:46,  5.32it/s]
4it [00:01,  2.58it/s]


[77/250] Loss: 3.210522174835205 	 Acc: 33.8
Validation loss decreased (3.243844 --> 3.210522).  Saving model ...


250it [00:46,  5.34it/s]
4it [00:01,  2.62it/s]

[78/250] Loss: 3.3308228850364685 	 Acc: 31.8
EarlyStopping counter: 1 out of 5



250it [00:46,  5.33it/s]
4it [00:01,  2.61it/s]

[79/250] Loss: 3.2909353971481323 	 Acc: 34.2
EarlyStopping counter: 2 out of 5



250it [00:46,  5.33it/s]
4it [00:01,  2.57it/s]

[80/250] Loss: 3.3019753098487854 	 Acc: 31.4
EarlyStopping counter: 3 out of 5



250it [00:46,  5.32it/s]
4it [00:01,  2.59it/s]

[81/250] Loss: 3.292394995689392 	 Acc: 34.8
EarlyStopping counter: 4 out of 5



250it [00:46,  5.32it/s]
4it [00:01,  2.60it/s]

[82/250] Loss: 3.312225580215454 	 Acc: 33.2
EarlyStopping counter: 5 out of 5



250it [00:46,  5.33it/s]
4it [00:01,  2.69it/s]

[83/250] Loss: 3.3470104932785034 	 Acc: 32.8
EarlyStopping counter: 6 out of 5



250it [00:46,  5.32it/s]
4it [00:01,  2.59it/s]


[84/250] Loss: 3.1846446990966797 	 Acc: 33.4
Validation loss decreased (3.210522 --> 3.184645).  Saving model ...


250it [00:46,  5.33it/s]
4it [00:01,  2.62it/s]

[85/250] Loss: 3.2632964849472046 	 Acc: 32.8
EarlyStopping counter: 1 out of 5



250it [00:46,  5.33it/s]
4it [00:01,  2.59it/s]

[86/250] Loss: 3.202887177467346 	 Acc: 34.2
EarlyStopping counter: 2 out of 5



250it [00:46,  5.33it/s]
4it [00:01,  2.61it/s]

[87/250] Loss: 3.3135664463043213 	 Acc: 33.8
EarlyStopping counter: 3 out of 5



250it [00:46,  5.33it/s]
4it [00:01,  2.60it/s]

[88/250] Loss: 3.2068715691566467 	 Acc: 34.2
EarlyStopping counter: 4 out of 5



250it [00:46,  5.33it/s]
4it [00:01,  2.66it/s]

[89/250] Loss: 3.3237234354019165 	 Acc: 31.8
EarlyStopping counter: 5 out of 5



250it [00:46,  5.32it/s]
4it [00:01,  2.61it/s]


[90/250] Loss: 3.159863770008087 	 Acc: 35.6
Validation loss decreased (3.184645 --> 3.159864).  Saving model ...


250it [00:46,  5.33it/s]
4it [00:01,  2.53it/s]

[91/250] Loss: 3.301831901073456 	 Acc: 32.4
EarlyStopping counter: 1 out of 5



250it [00:46,  5.33it/s]
4it [00:01,  2.60it/s]

[92/250] Loss: 3.215398907661438 	 Acc: 35.6
EarlyStopping counter: 2 out of 5



250it [00:46,  5.33it/s]
4it [00:01,  2.57it/s]

[93/250] Loss: 3.2453460693359375 	 Acc: 33.8
EarlyStopping counter: 3 out of 5



250it [00:46,  5.33it/s]
4it [00:01,  2.62it/s]

[94/250] Loss: 3.238484799861908 	 Acc: 35.2
EarlyStopping counter: 4 out of 5



250it [00:46,  5.33it/s]
4it [00:01,  2.62it/s]

[95/250] Loss: 3.264693796634674 	 Acc: 35.0
EarlyStopping counter: 5 out of 5



250it [00:46,  5.33it/s]
4it [00:01,  2.60it/s]

[96/250] Loss: 3.2440231442451477 	 Acc: 37.2
EarlyStopping counter: 6 out of 5



250it [00:46,  5.33it/s]
4it [00:01,  2.65it/s]

[97/250] Loss: 3.3067097663879395 	 Acc: 34.0
EarlyStopping counter: 7 out of 5



250it [00:46,  5.33it/s]
4it [00:01,  2.61it/s]

[98/250] Loss: 3.2965638041496277 	 Acc: 34.2
EarlyStopping counter: 8 out of 5



250it [00:46,  5.33it/s]
4it [00:01,  2.60it/s]

[99/250] Loss: 3.284667670726776 	 Acc: 33.0
EarlyStopping counter: 9 out of 5



250it [00:46,  5.33it/s]
4it [00:01,  2.61it/s]


[100/250] Loss: 3.076765477657318 	 Acc: 36.4
Validation loss decreased (3.159864 --> 3.076765).  Saving model ...


118it [00:22,  5.30it/s]


KeyboardInterrupt: 

In [41]:
class medical_dataset_test(Dataset):
    def __init__(self, config=None, answer_map=None, image_path='./dataset/Task1-VQA-2021-TestSet-w-GroundTruth/VQA-500-Images/', qa_file="./dataset/Task1-VQA-2021-TestSet-w-GroundTruth/Task1-VQA-2021-TestSet-Questions.txt", train=True):
            
        assert answer_map!=None
        assert config!=None
        
        self.image_path = image_path
        self.qa_file = qa_file
        self.config = config
        self.train = train
        
        self.answer_map = answer_map
        self.data = pd.read_csv(qa_file, sep='|', names=['imageid', 'question'])
           
        # print(Counter(self.data.answer.tolist()))
        
        self.transforms = transforms.Compose([
                                transforms.RandomResizedCrop(224),
                                transforms.ToTensor(),
                                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                            ])
        self.tokenizer = tokenizer = AutoTokenizer.from_pretrained("dmis-lab/biobert-v1.1")

    def process_data(self, text, max_len):
        text = str(text)

        inputs = self.tokenizer.encode_plus(
            text,
            None,
            add_special_tokens=True,
            max_length=max_len,
            truncation=True
        )
        ids = inputs["input_ids"]
        mask = inputs["attention_mask"]
        
        padding_length = max_len - len(ids)
        
        ids = ids + ([0] * padding_length)
        mask = mask + ([0] * padding_length)
        
        return {
            'ids': torch.tensor(ids, dtype=torch.long),
            'mask': torch.tensor(mask, dtype=torch.long),
        }
    
    
    def __getitem__(self, index):
        
        question = self.data.question[index]
        image_idx = self.data.imageid[index]
        
        ## add visual feature related steps
        visuals = Image.open(join(self.image_path, f'{image_idx}.jpg'))
        visuals = self.transforms(visuals)
        
        
        tmp = self.process_data(question, self.config['max_len'])
        question_tokens = {
            'ids': tmp['ids'],
            'mask': tmp['mask'],
        }
        
        return visuals, question_tokens, image_idx

    def __len__(self):
        return len(self.data)

In [42]:
TestData = medical_dataset_test(config=cfg, answer_map=answer_map)
TestDataLoader = DataLoader(TestData, batch_size=128, shuffle=False, num_workers=4)  # num_workers=0 for windows OS

In [43]:
inv_map = {v: k for k, v in answer_map.items()}

In [46]:
trainer.model.eval()
        
imageids = []
answers = []

running_loss = 0.0
correct = 0
total = 0
with torch.no_grad():
    for en, (visuals, question, target) in tqdm(enumerate(TestDataLoader)):
        visuals = visuals.to(trainer.device)

        ids = question['ids'].to(trainer.device)
        mask = question['mask'].to(trainer.device)

        outputs = trainer.model(visuals, ids, mask)

        y_pred_softmax = torch.log_softmax(outputs, dim = 1)
        _, y_pred_tags = torch.max(y_pred_softmax, dim = 1)
        # self.tmp_sv_ = y_pred_tags
        
        for i in range(visuals.shape[0]):
            imageids.append(target[i])
            answers.append(inv_map[int(y_pred_tags[i])])


4it [00:01,  2.45it/s]


In [61]:
answers[:5]

['femoral neck stress fractures',
 'pulmonary embolism',
 'aortic dissection',
 'tophaceous gout',
 'traumatic transient lateral patellar dislocation.']

In [48]:
# pd.DataFrame({'imageids':imageids, 'answers':answers}).to_csv('no_fusion.txt', sep='|', index=False, header=False)

In [52]:
reference_data = pd.read_csv('./VQA-Med-2021/Task1-VQA-2021-TestSet-w-GroundTruth/Task1-VQA-2021-TestSet-ReferenceAnswers.txt',sep='|', names=['imageid', 'ans1', 'ans2', 'ans3'])

In [53]:
reference_data.head()

Unnamed: 0,imageid,ans1,ans2,ans3
0,synpic42072,avascular necrosis,extensive degenerative changes on the bilatera...,Avascular Necrosis of femoral heads bilaterall...
1,synpic37231,pulmonary embolism,,
2,synpic51484,osteopoikilosis,osteosclerotic lesions,
3,synpic15699,takayasu arteritis,multifocal stenosis of the great vessels and b...,
4,synpic33852,acl injury,,


In [72]:
correct = 0
for en, idx in enumerate(imageids):
    tmp_ = reference_data[reference_data.imageid==idx].reset_index()
    assert len(tmp_)==1
    
    if answers[en].strip()==tmp_.ans1[0] or answers[en]==tmp_.ans2[0] or answers[en]==tmp_.ans3[0]:
        correct+=1

print(correct*100/len(imageids))

20.4


In [71]:
bleu_score = []

from nltk.translate.bleu_score import sentence_bleu
from nltk.tokenize import word_tokenize

for en, idx in tqdm(enumerate(imageids)):
    tmp_ = reference_data[reference_data.imageid==idx].reset_index()
    assert len(tmp_)==1
    
    candidate = word_tokenize(str(answers[en]).strip().lower())
    w_ = [tmp_.ans1[0], tmp_.ans2[0], tmp_.ans3[0]]
    
    reference = [word_tokenize(str(w).strip().lower()) for w in w_]
    bleu_score.append(sentence_bleu(reference, candidate))
        
print(np.mean(bleu_score))

500it [00:00, 1726.64it/s]

0.05659260141742961



