In [1]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import warnings
warnings.filterwarnings("ignore")

In [3]:
import numpy as np
import sys
from os import listdir, makedirs, getcwd, remove
from os.path import isfile, join, abspath, exists, isdir, expanduser
from scipy.io import loadmat
from PIL import Image
import pandas as pd

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transforms


import pickle
from tqdm import tqdm
import json
import random
from collections import Counter

from transformers import get_linear_schedule_with_warmup
from transformers import AutoTokenizer, AutoModel

In [4]:
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.autograd as autograd
from torch.utils.data import Dataset, DataLoader
import torch.utils.data as data
import torch.nn.functional as F
import torch.optim as optim
import itertools

import torchvision
from torchvision import transforms, datasets, models
from torch import Tensor

from torch.utils.tensorboard import SummaryWriter

In [5]:
import transformers
import tokenizers
from transformers import BertTokenizer, BertModel

In [6]:
cfg = {
    'max_len': 128,
    'lr': 2e-5,
    'warmup_steps': 5,
    'epochs': 250
}

In [7]:
class medical_dataset(Dataset):
    def __init__(self, config=None, answer_map=None, image_path='./dataset/TrainingDataset/images/', qa_file="./dataset/TrainingDataset/training_qa.txt", train=True):
            
        assert answer_map!=None
        assert config!=None
        
        self.image_path = image_path
        self.qa_file = qa_file
        self.config = config
        self.train = train
        
        self.answer_map = answer_map
        self.data = pd.read_csv(qa_file, sep='|', names=['imageid', 'question', 'answer'])
           
        # print(Counter(self.data.answer.tolist()))
        
        self.transforms = transforms.Compose([
                                transforms.RandomResizedCrop(224),
                                transforms.ToTensor(),
                                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                            ])
        self.tokenizer = tokenizer = AutoTokenizer.from_pretrained("dmis-lab/biobert-v1.1")

    def process_data(self, text, max_len):
        text = str(text)

        inputs = self.tokenizer.encode_plus(
            text,
            None,
            add_special_tokens=True,
            max_length=max_len,
            truncation=True
        )
        ids = inputs["input_ids"]
        mask = inputs["attention_mask"]
        
        padding_length = max_len - len(ids)
        
        ids = ids + ([0] * padding_length)
        mask = mask + ([0] * padding_length)
        
        return {
            'ids': torch.tensor(ids, dtype=torch.long),
            'mask': torch.tensor(mask, dtype=torch.long),
        }
    
    
    def __getitem__(self, index):
        
        question = self.data.question[index]
        answer = self.data.answer[index]
        image_idx = self.data.imageid[index]
        
        ## add visual feature related steps
        visuals = Image.open(join(self.image_path, f'{image_idx}.jpg'))
        visuals = self.transforms(visuals)
        
        target = torch.from_numpy(np.array([self.answer_map[answer]])).long()
        
        tmp = self.process_data(question, self.config['max_len'])
        question_tokens = {
            'ids': tmp['ids'],
            'mask': tmp['mask'],
        }
        
        return visuals, question_tokens, target

    def __len__(self):
        return len(self.data)

In [8]:
# file_path = './dataset/VQA-Med-2020-Task1-VQAnswering-TrainVal-Sets/VQAMed2020-VQAnswering-TrainingSet/VQAnswering_2020_Train_QA_pairs.txt'
# data = pd.read_csv(file_path, sep='|', names=['imageid', 'question', 'answer'])
# data.head()

In [9]:
# answer_map = {}
# ct = 0

# for i in data.answer.unique():
#     answer_map[i] = ct
#     ct+=1

In [10]:
# with open('./answer_map.pickle', 'wb') as handle:
#     pickle.dump(answer_map, handle, protocol=pickle.HIGHEST_PROTOCOL)
    

In [11]:
with open('./answer_map.pickle', 'rb') as handle:
    answer_map = pickle.load(handle)
    
len(answer_map)

333

In [12]:
TrainData = medical_dataset(config=cfg, answer_map=answer_map, image_path='./dataset/TrainingDataset/images/', qa_file='./dataset/TrainingDataset/training_qa.txt')
ValData = medical_dataset(config=cfg, answer_map=answer_map, image_path='./dataset/VQA-Med-2021-Tasks-1-2-NewValidationSets/ImageCLEF-2021-VQA-Med-New-Validation-Images/', qa_file='./dataset/VQA-Med-2021-Tasks-1-2-NewValidationSets/VQA-Med-2021-VQAnswering-Task1-New-ValidationSet.txt', train=False)


In [13]:
%%time
a,b,c = ValData.__getitem__(0)

CPU times: user 11.1 ms, sys: 7.9 ms, total: 19 ms
Wall time: 34.5 ms


In [14]:
TrainDataLoader = DataLoader(TrainData, batch_size=16, shuffle=True, num_workers=4)  # num_workers=0 for windows OS
ValDataLoader = DataLoader(ValData, batch_size=128, shuffle=False, num_workers=4)  # num_workers=0 for windows OS

In [20]:
#3rd Fusion Module
class MLBFusion(nn.Module):

    def __init__(self, dim_v=None, dim_q=None, dim_h=512):
        super(MLBFusion, self).__init__()
        # Modules
        self.dim_v = dim_v
        self.dim_q = dim_q
        self.dim_h = dim_h
        self.dropout = 0
        self.activation = 'relu'
        
        if dim_v:
            self.linear_v = nn.Linear(dim_v, dim_h)
        else:
            print('Warning fusion.py: no visual embedding before fusion')

        if dim_q:
            self.linear_q = nn.Linear(dim_q, dim_h)
        else:
            print('Warning fusion.py: no question embedding before fusion')
        
    def forward(self, input_v, input_q):
        # visual (cnn features)
        if self.dim_v:
            x_v = F.dropout(input_v, p=self.dropout, training=self.training)
            x_v = self.linear_v(x_v)
            x_v = getattr(F, self.activation)(x_v)
        else:
            x_v = input_v
        # question (rnn features)
        if self.dim_q:
            x_q = F.dropout(input_q, p=self.dropout, training=self.training)
            x_q = self.linear_q(x_q)
            x_q = getattr(F, self.activation)(x_q)
        else:
            x_q = input_q
        # hadamard product
        x_mm = torch.mul(x_q, x_v)
        return x_mm

In [21]:
class mednet(nn.Module):
    def __init__(self, config, max_labels):
        super(mednet, self).__init__()
        self.vision = models.resnet18(pretrained=True)
        num_ftrs = self.vision.fc.in_features
        
        self.vision = nn.Sequential(*list(self.vision.children())[:-1])
        self.bert = AutoModel.from_pretrained("dmis-lab/biobert-v1.1")
        
        print(f"Number of vision filters: {num_ftrs}")
        self.fc1 = nn.Linear(512, 128)
        self.fc2 = nn.Linear(128, max_labels)
        
        self.fusion = MLBFusion(dim_v=num_ftrs, dim_q=768, dim_h=512)
        

    def forward(self, visual=None, ids=None, mask=None):
        vision = self.vision(visual).view((ids.shape[0], -1))
        bert_out = self.bert(ids, mask)
        
        h = self.fusion(vision, bert_out.last_hidden_state[:,0])
        h = F.relu(self.fc1(h))
        h = self.fc2(h)
        return h


In [22]:
class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""
    def __init__(self,
                 patience=7,
                 verbose=False,
                 delta=0,
                 path='./models/checkpoint.pt',
                 trace_func=print):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement. 
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): Path for the checkpoint to be saved to.
                            Default: 'checkpoint.pt'
            trace_func (function): trace print function.
                            Default: print            
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.Inf
        self.delta = delta
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):

        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            self.trace_func(
                f'EarlyStopping counter: {self.counter} out of {self.patience}'
            )
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            self.trace_func(
                f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...'
            )
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

In [23]:
class Trainer:
    def __init__(self,
                 trainloader,
                 vallaoder,
                 model_ft,
                 writer=None,
                 testloader=None,
                 checkpoint_path=None,
                 patience=10,
                 feature_extract=True,
                 print_itr=50,
                 config=None):
        self.trainloader = trainloader
        self.valloader = vallaoder
        self.testloader = testloader
        
        self.config=config

        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

        print("==" * 10)
        print("Training will be done on ", self.device)
        print("==" * 10)

        self.model = model_ft        
        if torch.cuda.device_count() > 1:
            print("Let's use", torch.cuda.device_count(), "GPUs!")
            self.model = nn.DataParallel(self.model, device_ids=[0,1])
        
        self.model = self.model.to(self.device)
        
        
        # Observe that all parameters are being optimized
        self.optimizer = optim.RAdam(self.model.parameters(), lr=self.config['lr'])
        
        self.scheduler = get_linear_schedule_with_warmup(self.optimizer, 
                                            num_warmup_steps = len(self.trainloader)*self.config['warmup_steps'], # Default value in run_glue.py
                                            num_training_steps = len(self.trainloader)*self.config['epochs'])

        
        self.criterion = nn.CrossEntropyLoss()
        self.early_stopping = EarlyStopping(patience=patience, verbose=True)
        self.writer = writer
        self.print_itr = print_itr

    def train(self, ep):
        self.model.train()

        running_loss = 0.0

        for en, (visuals, question, target) in tqdm(enumerate(self.trainloader)):
            self.optimizer.zero_grad()
            
            visuals = visuals.to(self.device)
            y = target.squeeze().to(self.device)
            
            ids = question['ids'].to(self.device)
            mask = question['mask'].to(self.device)

            outputs = self.model(visuals, ids, mask)
            loss = self.criterion(outputs, y)
            loss.backward()
            self.optimizer.step()
            self.scheduler.step()

            running_loss += loss.item()
            if self.writer:
                self.writer.add_scalar('Train Loss', running_loss, ep*len(self.trainloader) + en)
            running_loss = 0
            

    def validate(self, ep):
        self.model.eval()
        
        running_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for en, (visuals, question, target) in tqdm(enumerate(self.valloader)):
                visuals = visuals.to(self.device)
                y = target.squeeze().to(self.device)
                
                ids = question['ids'].to(self.device)
                mask = question['mask'].to(self.device)
                
                outputs = self.model(visuals, ids, mask)
                loss = self.criterion(outputs, y)

                y_pred_softmax = torch.log_softmax(outputs, dim = 1)
                _, y_pred_tags = torch.max(y_pred_softmax, dim = 1)
                # self.tmp_sv_ = y_pred_tags
                
                correct += (y_pred_tags.detach().cpu().data.numpy() == y.detach().cpu().data.numpy()).sum()
                total += y_pred_tags.shape[0]
                
                # print statistics
                running_loss += loss.item()
        
        
        return running_loss / len(self.valloader), correct*100/total

    def perform_training(self, total_epoch):
        val_loss, acc = self.validate(0)

        print("[Initial Validation results] Loss: {} \t Acc: {}".format(
            val_loss, acc))

        for i in range(total_epoch):
            self.train(i + 1)
            val_loss, acc = self.validate(i + 1)
            print('[{}/{}] Loss: {} \t Acc: {}'.format(i+1, total_epoch, val_loss, acc))

            if self.writer:
                self.writer.add_scalar('Validation Loss', val_loss, (i + 1))
                self.writer.add_scalar('Validation Acc', acc, (i + 1))

            self.early_stopping(-acc, self.model)

            if self.early_stopping.early_stop:
                print("Early stopping")
                break

        print("=" * 20)
        print("Training finished !!")
        print("=" * 20)


In [24]:
model_ft = mednet(config=cfg, max_labels=len(answer_map))
writer = SummaryWriter('runs/resnet18_fusion-MLB_run_1')
trainer = Trainer(TrainDataLoader, ValDataLoader, model_ft, writer=writer, config=cfg)

Number of vision filters: 512
Training will be done on  cuda
Let's use 2 GPUs!


In [25]:
trainer.perform_training(cfg['epochs'])

4it [00:04,  1.01s/it]

[Initial Validation results] Loss: 5.812502026557922 	 Acc: 0.0



282it [00:52,  5.33it/s]
4it [00:01,  2.56it/s]


[1/250] Loss: 5.811530590057373 	 Acc: 0.0
Validation loss decreased (inf --> -0.000000).  Saving model ...


282it [00:53,  5.32it/s]
4it [00:01,  2.63it/s]


[2/250] Loss: 5.801392555236816 	 Acc: 0.2
Validation loss decreased (-0.000000 --> -0.200000).  Saving model ...


282it [00:53,  5.31it/s]
4it [00:01,  2.54it/s]

[3/250] Loss: 5.770436763763428 	 Acc: 0.0
EarlyStopping counter: 1 out of 10



282it [00:53,  5.31it/s]
4it [00:01,  2.52it/s]


[4/250] Loss: 5.693037986755371 	 Acc: 1.8
Validation loss decreased (-0.200000 --> -1.800000).  Saving model ...


282it [00:52,  5.32it/s]
4it [00:01,  2.50it/s]


[5/250] Loss: 5.548788070678711 	 Acc: 4.2
Validation loss decreased (-1.800000 --> -4.200000).  Saving model ...


282it [00:53,  5.32it/s]
4it [00:01,  2.61it/s]


[6/250] Loss: 5.393372893333435 	 Acc: 4.6
Validation loss decreased (-4.200000 --> -4.600000).  Saving model ...


282it [00:53,  5.31it/s]
4it [00:01,  2.58it/s]


[7/250] Loss: 5.202782034873962 	 Acc: 7.4
Validation loss decreased (-4.600000 --> -7.400000).  Saving model ...


282it [00:53,  5.31it/s]
4it [00:01,  2.55it/s]


[8/250] Loss: 4.987839221954346 	 Acc: 10.4
Validation loss decreased (-7.400000 --> -10.400000).  Saving model ...


282it [00:53,  5.31it/s]
4it [00:01,  2.62it/s]


[9/250] Loss: 4.813996434211731 	 Acc: 10.4
Validation loss decreased (-10.400000 --> -10.400000).  Saving model ...


282it [00:53,  5.31it/s]
4it [00:01,  2.56it/s]


[10/250] Loss: 4.713526368141174 	 Acc: 11.2
Validation loss decreased (-10.400000 --> -11.200000).  Saving model ...


282it [00:53,  5.31it/s]
4it [00:01,  2.54it/s]


[11/250] Loss: 4.552898406982422 	 Acc: 13.8
Validation loss decreased (-11.200000 --> -13.800000).  Saving model ...


282it [00:53,  5.32it/s]
4it [00:01,  2.57it/s]


[12/250] Loss: 4.415682315826416 	 Acc: 15.4
Validation loss decreased (-13.800000 --> -15.400000).  Saving model ...


282it [00:53,  5.32it/s]
4it [00:01,  2.57it/s]

[13/250] Loss: 4.307988882064819 	 Acc: 12.8
EarlyStopping counter: 1 out of 10



282it [00:53,  5.31it/s]
4it [00:01,  2.60it/s]


[14/250] Loss: 4.242518663406372 	 Acc: 15.4
Validation loss decreased (-15.400000 --> -15.400000).  Saving model ...


282it [00:53,  5.31it/s]
4it [00:01,  2.56it/s]


[15/250] Loss: 4.198721766471863 	 Acc: 17.4
Validation loss decreased (-15.400000 --> -17.400000).  Saving model ...


282it [00:53,  5.31it/s]
4it [00:01,  2.61it/s]


[16/250] Loss: 4.040081977844238 	 Acc: 21.0
Validation loss decreased (-17.400000 --> -21.000000).  Saving model ...


282it [00:53,  5.32it/s]
4it [00:01,  2.57it/s]

[17/250] Loss: 4.051970601081848 	 Acc: 19.2
EarlyStopping counter: 1 out of 10



282it [00:52,  5.32it/s]
4it [00:01,  2.54it/s]


[18/250] Loss: 3.875373899936676 	 Acc: 23.8
Validation loss decreased (-21.000000 --> -23.800000).  Saving model ...


282it [00:53,  5.31it/s]
4it [00:01,  2.55it/s]

[19/250] Loss: 3.9861786365509033 	 Acc: 19.0
EarlyStopping counter: 1 out of 10



282it [00:53,  5.31it/s]
4it [00:01,  2.52it/s]

[20/250] Loss: 3.8456026315689087 	 Acc: 21.8
EarlyStopping counter: 2 out of 10



282it [00:53,  5.31it/s]
4it [00:01,  2.56it/s]

[21/250] Loss: 3.7877159118652344 	 Acc: 22.8
EarlyStopping counter: 3 out of 10



282it [00:53,  5.31it/s]
4it [00:01,  2.54it/s]

[22/250] Loss: 3.8751221895217896 	 Acc: 22.8
EarlyStopping counter: 4 out of 10



282it [00:53,  5.31it/s]
4it [00:01,  2.60it/s]


[23/250] Loss: 3.7947794795036316 	 Acc: 23.8
Validation loss decreased (-23.800000 --> -23.800000).  Saving model ...


282it [00:53,  5.32it/s]
4it [00:01,  2.53it/s]

[24/250] Loss: 3.720829427242279 	 Acc: 23.4
EarlyStopping counter: 1 out of 10



282it [00:53,  5.31it/s]
4it [00:01,  2.52it/s]


[25/250] Loss: 3.686511814594269 	 Acc: 27.6
Validation loss decreased (-23.800000 --> -27.600000).  Saving model ...


282it [00:53,  5.32it/s]
4it [00:01,  2.58it/s]

[26/250] Loss: 3.749365210533142 	 Acc: 25.2
EarlyStopping counter: 1 out of 10



282it [00:53,  5.31it/s]
4it [00:01,  2.58it/s]

[27/250] Loss: 3.691597104072571 	 Acc: 27.0
EarlyStopping counter: 2 out of 10



282it [00:53,  5.31it/s]
4it [00:01,  2.60it/s]


[28/250] Loss: 3.6353889107704163 	 Acc: 30.2
Validation loss decreased (-27.600000 --> -30.200000).  Saving model ...


282it [00:53,  5.32it/s]
4it [00:01,  2.60it/s]

[29/250] Loss: 3.584589421749115 	 Acc: 28.4
EarlyStopping counter: 1 out of 10



282it [00:53,  5.32it/s]
4it [00:01,  2.57it/s]

[30/250] Loss: 3.6709401607513428 	 Acc: 28.2
EarlyStopping counter: 2 out of 10



282it [00:53,  5.31it/s]
4it [00:01,  2.55it/s]


[31/250] Loss: 3.5841389298439026 	 Acc: 32.4
Validation loss decreased (-30.200000 --> -32.400000).  Saving model ...


282it [00:53,  5.32it/s]
4it [00:01,  2.56it/s]

[32/250] Loss: 3.809987485408783 	 Acc: 30.0
EarlyStopping counter: 1 out of 10



282it [00:53,  5.32it/s]
4it [00:01,  2.55it/s]

[33/250] Loss: 3.6738269329071045 	 Acc: 31.4
EarlyStopping counter: 2 out of 10



282it [00:53,  5.31it/s]
4it [00:01,  2.57it/s]

[34/250] Loss: 3.709216356277466 	 Acc: 28.6
EarlyStopping counter: 3 out of 10



282it [00:53,  5.31it/s]
4it [00:01,  2.56it/s]

[35/250] Loss: 3.7636728286743164 	 Acc: 32.0
EarlyStopping counter: 4 out of 10



282it [00:52,  5.32it/s]
4it [00:01,  2.55it/s]


[36/250] Loss: 3.6823253631591797 	 Acc: 33.2
Validation loss decreased (-32.400000 --> -33.200000).  Saving model ...


282it [00:53,  5.31it/s]
4it [00:01,  2.53it/s]


[37/250] Loss: 3.5620210766792297 	 Acc: 33.4
Validation loss decreased (-33.200000 --> -33.400000).  Saving model ...


282it [00:53,  5.32it/s]
4it [00:01,  2.61it/s]

[38/250] Loss: 3.7807061672210693 	 Acc: 31.6
EarlyStopping counter: 1 out of 10



282it [00:53,  5.31it/s]
4it [00:01,  2.59it/s]


[39/250] Loss: 3.6551392674446106 	 Acc: 34.2
Validation loss decreased (-33.400000 --> -34.200000).  Saving model ...


282it [00:53,  5.31it/s]
4it [00:01,  2.56it/s]


[40/250] Loss: 3.7373404502868652 	 Acc: 34.8
Validation loss decreased (-34.200000 --> -34.800000).  Saving model ...


282it [00:53,  5.30it/s]
4it [00:01,  2.58it/s]

[41/250] Loss: 3.8876466751098633 	 Acc: 31.2
EarlyStopping counter: 1 out of 10



282it [00:53,  5.31it/s]
4it [00:01,  2.56it/s]

[42/250] Loss: 3.796594500541687 	 Acc: 31.2
EarlyStopping counter: 2 out of 10



282it [00:53,  5.32it/s]
4it [00:01,  2.55it/s]

[43/250] Loss: 3.8053457736968994 	 Acc: 31.8
EarlyStopping counter: 3 out of 10



282it [00:53,  5.30it/s]
4it [00:01,  2.58it/s]

[44/250] Loss: 3.856473743915558 	 Acc: 32.8
EarlyStopping counter: 4 out of 10



282it [00:53,  5.31it/s]
4it [00:01,  2.53it/s]

[45/250] Loss: 3.6343867778778076 	 Acc: 32.4
EarlyStopping counter: 5 out of 10



282it [00:53,  5.31it/s]
4it [00:01,  2.55it/s]

[46/250] Loss: 3.7829599380493164 	 Acc: 33.0
EarlyStopping counter: 6 out of 10



282it [00:53,  5.32it/s]
4it [00:01,  2.58it/s]


[47/250] Loss: 3.7185477018356323 	 Acc: 35.0
Validation loss decreased (-34.800000 --> -35.000000).  Saving model ...


282it [00:53,  5.32it/s]
4it [00:01,  2.59it/s]

[48/250] Loss: 3.895490348339081 	 Acc: 31.4
EarlyStopping counter: 1 out of 10



282it [00:53,  5.31it/s]
4it [00:01,  2.57it/s]


[49/250] Loss: 3.728913366794586 	 Acc: 36.6
Validation loss decreased (-35.000000 --> -36.600000).  Saving model ...


282it [00:53,  5.31it/s]
4it [00:01,  2.59it/s]

[50/250] Loss: 3.7310253381729126 	 Acc: 34.2
EarlyStopping counter: 1 out of 10



282it [00:53,  5.31it/s]
4it [00:01,  2.62it/s]

[51/250] Loss: 3.775515556335449 	 Acc: 35.0
EarlyStopping counter: 2 out of 10



282it [00:53,  5.32it/s]
4it [00:01,  2.55it/s]

[52/250] Loss: 3.929898262023926 	 Acc: 34.8
EarlyStopping counter: 3 out of 10



282it [00:53,  5.32it/s]
4it [00:01,  2.52it/s]

[53/250] Loss: 3.8607380390167236 	 Acc: 35.2
EarlyStopping counter: 4 out of 10



282it [00:53,  5.31it/s]
4it [00:01,  2.58it/s]

[54/250] Loss: 4.006890952587128 	 Acc: 32.2
EarlyStopping counter: 5 out of 10



282it [00:53,  5.31it/s]
4it [00:01,  2.59it/s]


[55/250] Loss: 3.9344884753227234 	 Acc: 36.8
Validation loss decreased (-36.600000 --> -36.800000).  Saving model ...


282it [00:53,  5.31it/s]
4it [00:01,  2.54it/s]


[56/250] Loss: 3.8711518049240112 	 Acc: 38.0
Validation loss decreased (-36.800000 --> -38.000000).  Saving model ...


282it [00:53,  5.31it/s]
4it [00:01,  2.58it/s]

[57/250] Loss: 3.9980027079582214 	 Acc: 35.6
EarlyStopping counter: 1 out of 10



282it [00:53,  5.30it/s]
4it [00:01,  2.56it/s]

[58/250] Loss: 3.6581377387046814 	 Acc: 37.6
EarlyStopping counter: 2 out of 10



282it [00:53,  5.31it/s]
4it [00:01,  2.54it/s]

[59/250] Loss: 3.86177796125412 	 Acc: 34.8
EarlyStopping counter: 3 out of 10



282it [00:53,  5.32it/s]
4it [00:01,  2.61it/s]

[60/250] Loss: 3.834803819656372 	 Acc: 34.0
EarlyStopping counter: 4 out of 10



282it [00:53,  5.31it/s]
4it [00:01,  2.55it/s]

[61/250] Loss: 3.8515923619270325 	 Acc: 36.4
EarlyStopping counter: 5 out of 10



282it [00:53,  5.31it/s]
4it [00:01,  2.60it/s]

[62/250] Loss: 3.8175195455551147 	 Acc: 37.4
EarlyStopping counter: 6 out of 10



282it [00:53,  5.30it/s]
4it [00:01,  2.57it/s]

[63/250] Loss: 3.763061761856079 	 Acc: 36.4
EarlyStopping counter: 7 out of 10



282it [00:53,  5.30it/s]
4it [00:01,  2.60it/s]


[64/250] Loss: 3.6883451342582703 	 Acc: 40.6
Validation loss decreased (-38.000000 --> -40.600000).  Saving model ...


282it [00:53,  5.30it/s]
4it [00:01,  2.55it/s]

[65/250] Loss: 3.918591022491455 	 Acc: 35.6
EarlyStopping counter: 1 out of 10



282it [00:53,  5.31it/s]
4it [00:01,  2.59it/s]

[66/250] Loss: 3.950996696949005 	 Acc: 34.8
EarlyStopping counter: 2 out of 10



282it [00:53,  5.31it/s]
4it [00:01,  2.52it/s]

[67/250] Loss: 3.828134775161743 	 Acc: 36.4
EarlyStopping counter: 3 out of 10



282it [00:53,  5.31it/s]
4it [00:01,  2.57it/s]

[68/250] Loss: 3.7998949885368347 	 Acc: 37.2
EarlyStopping counter: 4 out of 10



282it [00:53,  5.30it/s]
4it [00:01,  2.57it/s]

[69/250] Loss: 3.8688319325447083 	 Acc: 39.4
EarlyStopping counter: 5 out of 10



282it [00:53,  5.31it/s]
4it [00:01,  2.55it/s]

[70/250] Loss: 3.921475887298584 	 Acc: 36.2
EarlyStopping counter: 6 out of 10



282it [00:53,  5.31it/s]
4it [00:01,  2.54it/s]

[71/250] Loss: 3.9739432334899902 	 Acc: 36.8
EarlyStopping counter: 7 out of 10



282it [00:53,  5.31it/s]
4it [00:01,  2.55it/s]

[72/250] Loss: 3.9036930799484253 	 Acc: 36.6
EarlyStopping counter: 8 out of 10



282it [00:53,  5.30it/s]
4it [00:01,  2.54it/s]

[73/250] Loss: 3.884511351585388 	 Acc: 38.6
EarlyStopping counter: 9 out of 10



282it [00:53,  5.30it/s]
4it [00:01,  2.57it/s]

[74/250] Loss: 4.0045188665390015 	 Acc: 35.4
EarlyStopping counter: 10 out of 10
Early stopping
Training finished !!





In [26]:
class medical_dataset_test(Dataset):
    def __init__(self, config=None, answer_map=None, image_path='./dataset/Task1-VQA-2021-TestSet-w-GroundTruth/VQA-500-Images/', qa_file="./dataset/Task1-VQA-2021-TestSet-w-GroundTruth/Task1-VQA-2021-TestSet-Questions.txt", train=True):
            
        assert answer_map!=None
        assert config!=None
        
        self.image_path = image_path
        self.qa_file = qa_file
        self.config = config
        self.train = train
        
        self.answer_map = answer_map
        self.data = pd.read_csv(qa_file, sep='|', names=['imageid', 'question'])
           
        # print(Counter(self.data.answer.tolist()))
        
        self.transforms = transforms.Compose([
                                transforms.RandomResizedCrop(224),
                                transforms.ToTensor(),
                                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                            ])
        self.tokenizer = tokenizer = AutoTokenizer.from_pretrained("dmis-lab/biobert-v1.1")

    def process_data(self, text, max_len):
        text = str(text)

        inputs = self.tokenizer.encode_plus(
            text,
            None,
            add_special_tokens=True,
            max_length=max_len,
            truncation=True
        )
        ids = inputs["input_ids"]
        mask = inputs["attention_mask"]
        
        padding_length = max_len - len(ids)
        
        ids = ids + ([0] * padding_length)
        mask = mask + ([0] * padding_length)
        
        return {
            'ids': torch.tensor(ids, dtype=torch.long),
            'mask': torch.tensor(mask, dtype=torch.long),
        }
    
    
    def __getitem__(self, index):
        
        question = self.data.question[index]
        image_idx = self.data.imageid[index]
        
        ## add visual feature related steps
        visuals = Image.open(join(self.image_path, f'{image_idx}.jpg'))
        visuals = self.transforms(visuals)
        
        
        tmp = self.process_data(question, self.config['max_len'])
        question_tokens = {
            'ids': tmp['ids'],
            'mask': tmp['mask'],
        }
        
        return visuals, question_tokens, image_idx

    def __len__(self):
        return len(self.data)

In [27]:
TestData = medical_dataset_test(config=cfg, answer_map=answer_map)
TestDataLoader = DataLoader(TestData, batch_size=128, shuffle=False, num_workers=4)  # num_workers=0 for windows OS

In [28]:
inv_map = {v: k for k, v in answer_map.items()}

In [29]:
trainer.model.load_state_dict(torch.load('./models/checkpoint.pt'))
trainer.model.eval()
        
imageids = []
answers = []

running_loss = 0.0
correct = 0
total = 0
with torch.no_grad():
    for en, (visuals, question, target) in tqdm(enumerate(TestDataLoader)):
        visuals = visuals.to(trainer.device)

        ids = question['ids'].to(trainer.device)
        mask = question['mask'].to(trainer.device)

        outputs = trainer.model(visuals, ids, mask)

        y_pred_softmax = torch.log_softmax(outputs, dim = 1)
        _, y_pred_tags = torch.max(y_pred_softmax, dim = 1)
        # self.tmp_sv_ = y_pred_tags
        
        for i in range(visuals.shape[0]):
            imageids.append(target[i])
            answers.append(inv_map[int(y_pred_tags[i])])


4it [00:01,  2.42it/s]


In [30]:
answers[:5]

['sickle cell disease',
 'pulmonary embolism',
 'nonossifying fibroma',
 'choledocholithiasis',
 'lateral femoral notch sign']

In [31]:
pd.DataFrame({'imageids':imageids, 'answers':answers}).to_csv('resnet18_fusion-MLB.txt', sep='|', index=False, header=False)