In [1]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import warnings
warnings.filterwarnings("ignore")

In [3]:
import numpy as np
import sys
from os import listdir, makedirs, getcwd, remove
from os.path import isfile, join, abspath, exists, isdir, expanduser
from scipy.io import loadmat
from PIL import Image
import pandas as pd

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transforms


import pickle
from tqdm import tqdm
import json
import random
from collections import Counter

from transformers import get_linear_schedule_with_warmup
from transformers import AutoTokenizer, AutoModel

In [4]:
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.autograd as autograd
from torch.utils.data import Dataset, DataLoader
import torch.utils.data as data
import torch.nn.functional as F
import torch.optim as optim
import itertools

import torchvision
from torchvision import transforms, datasets, models
from torch import Tensor

from torch.utils.tensorboard import SummaryWriter

In [5]:
import transformers
import tokenizers
from transformers import BertTokenizer, BertModel

In [6]:
cfg = {
    'max_len': 128,
    'lr': 2e-5,
    'warmup_steps': 5,
    'epochs': 250
}

In [7]:
class medical_dataset(Dataset):
    def __init__(self, config=None, answer_map=None, image_path='./dataset/VQA-Med-2020-Task1-VQAnswering-TrainVal-Sets/VQAMed2020-VQAnswering-TrainingSet/VQAnswering_2020_Train_images/', qa_file="./dataset/VQA-Med-2020-Task1-VQAnswering-TrainVal-Sets/VQAMed2020-VQAnswering-TrainingSet/VQAnswering_2020_Train_QA_pairs.txt", train=True):
            
        assert answer_map!=None
        assert config!=None
        
        self.image_path = image_path
        self.qa_file = qa_file
        self.config = config
        self.train = train
        
        self.answer_map = answer_map
        self.data = pd.read_csv(qa_file, sep='|', names=['imageid', 'question', 'answer'])
           
        # print(Counter(self.data.answer.tolist()))
        
        self.transforms = transforms.Compose([
                                transforms.RandomResizedCrop(224),
                                transforms.ToTensor(),
                                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                            ])
        self.tokenizer = tokenizer = AutoTokenizer.from_pretrained("dmis-lab/biobert-v1.1")

    def process_data(self, text, max_len):
        text = str(text)

        inputs = self.tokenizer.encode_plus(
            text,
            None,
            add_special_tokens=True,
            max_length=max_len,
            truncation=True
        )
        ids = inputs["input_ids"]
        mask = inputs["attention_mask"]
        
        padding_length = max_len - len(ids)
        
        ids = ids + ([0] * padding_length)
        mask = mask + ([0] * padding_length)
        
        return {
            'ids': torch.tensor(ids, dtype=torch.long),
            'mask': torch.tensor(mask, dtype=torch.long),
        }
    
    
    def __getitem__(self, index):
        
        question = self.data.question[index]
        answer = self.data.answer[index]
        image_idx = self.data.imageid[index]
        
        ## add visual feature related steps
        visuals = Image.open(join(self.image_path, f'{image_idx}.jpg'))
        visuals = self.transforms(visuals)
        
        target = torch.from_numpy(np.array([self.answer_map[answer]])).long()
        
        tmp = self.process_data(question, self.config['max_len'])
        question_tokens = {
            'ids': tmp['ids'],
            'mask': tmp['mask'],
        }
        
        return visuals, question_tokens, target

    def __len__(self):
        return len(self.data)

In [8]:
# file_path = './dataset/VQA-Med-2020-Task1-VQAnswering-TrainVal-Sets/VQAMed2020-VQAnswering-TrainingSet/VQAnswering_2020_Train_QA_pairs.txt'
# data = pd.read_csv(file_path, sep='|', names=['imageid', 'question', 'answer'])
# data.head()

In [9]:
# answer_map = {}
# ct = 0

# for i in data.answer.unique():
#     answer_map[i] = ct
#     ct+=1

In [10]:
# with open('./answer_map.pickle', 'wb') as handle:
#     pickle.dump(answer_map, handle, protocol=pickle.HIGHEST_PROTOCOL)
    

In [11]:
with open('./answer_map.pickle', 'rb') as handle:
    answer_map = pickle.load(handle)
    
len(answer_map)

332

In [12]:
TrainData = medical_dataset(config=cfg, answer_map=answer_map, image_path='./dataset/VQA-Med-2020-Task1-VQAnswering-TrainVal-Sets/VQAMed2020-VQAnswering-TrainingSet/VQAnswering_2020_Train_images/', qa_file='./dataset/VQA-Med-2020-Task1-VQAnswering-TrainVal-Sets/VQAMed2020-VQAnswering-TrainingSet/VQAnswering_2020_Train_QA_pairs.txt')
ValData = medical_dataset(config=cfg, answer_map=answer_map, image_path='./dataset/VQA-Med-2020-Task1-VQAnswering-TrainVal-Sets/VQAMed2020-VQAnswering-ValidationSet/VQAnswering_2020_Val_images/', qa_file='./dataset/VQA-Med-2020-Task1-VQAnswering-TrainVal-Sets/VQAMed2020-VQAnswering-ValidationSet/VQAnswering_2020_Val_QA_Pairs.txt', train=False)


In [13]:
%%time
a,b,c = ValData.__getitem__(0)

CPU times: user 30.7 ms, sys: 0 ns, total: 30.7 ms
Wall time: 27.4 ms


In [14]:
TrainDataLoader = DataLoader(TrainData, batch_size=16, shuffle=True, num_workers=4)  # num_workers=0 for windows OS
ValDataLoader = DataLoader(ValData, batch_size=128, shuffle=False, num_workers=4)  # num_workers=0 for windows OS

In [15]:
class MinhmulFusion(nn.Module):

    def __init__(self, dim_v=None, dim_q=None, dim_h=512):
        super(MinhmulFusion, self).__init__()
        # Modules
        self.dim_v = dim_v
        self.dim_q = dim_q
        self.dim_h = dim_h
        self.dropout = 0
        self.activation = 'relu'
        
        if dim_v:
            self.linear_v = nn.Linear(dim_v, dim_h)
        else:
            print('Warning fusion.py: no visual embedding before fusion')

        if dim_q:
            self.linear_q = nn.Linear(dim_q, dim_h)
        else:
            print('Warning fusion.py: no question embedding before fusion')
        
    def forward(self, input_v, input_q):
        # visual (cnn features)
        if self.dim_v:
            x_v = F.dropout(input_v, p=self.dropout, training=self.training)
            x_v = self.linear_v(x_v)
            x_v = getattr(F, self.activation)(x_v)
        else:
            x_v = input_v
        # question (rnn features)
        if self.dim_q:
            x_q = F.dropout(input_q, p=self.dropout, training=self.training)
            x_q = self.linear_q(x_q)
            x_q = getattr(F, self.activation)(x_q)
        else:
            x_q = input_q
        # hadamard product
        # x_mm = torch.mul(x_q, x_q)
        x_q = torch.pow(x_q, 2)
        x_v = torch.pow(x_v, 1)

        x_mm = torch.mul(x_q, x_v)
        # x_mm = torch.mul(x_q, x_q)
        # x_mm = torch.mul(x_mm, x_v)
        return x_mm


In [16]:
class mednet(nn.Module):
    def __init__(self, config, max_labels):
        super(mednet, self).__init__()
        self.vision = models.resnet18(pretrained=True)
        num_ftrs = self.vision.fc.in_features
        
        self.vision = nn.Sequential(*list(self.vision.children())[:-1])
        self.bert = AutoModel.from_pretrained("dmis-lab/biobert-v1.1")
        
        print(f"Number of vision filters: {num_ftrs}")
        self.fc1 = nn.Linear(512, 128)
        self.fc2 = nn.Linear(128, max_labels)
        
        self.fusion = MinhmulFusion(dim_v=num_ftrs, dim_q=768, dim_h=512)
        

    def forward(self, visual=None, ids=None, mask=None):
        vision = self.vision(visual).view((ids.shape[0], -1))
        bert_out = self.bert(ids, mask)
        
        h = self.fusion(vision, bert_out.last_hidden_state[:,0])
        h = F.relu(self.fc1(h))
        h = self.fc2(h)
        return h


In [17]:
class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""
    def __init__(self,
                 patience=7,
                 verbose=False,
                 delta=0,
                 path='./models/fusion/checkpoint.pt',
                 trace_func=print):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement. 
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): Path for the checkpoint to be saved to.
                            Default: 'checkpoint.pt'
            trace_func (function): trace print function.
                            Default: print            
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.Inf
        self.delta = delta
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):

        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            self.trace_func(
                f'EarlyStopping counter: {self.counter} out of {self.patience}'
            )
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            self.trace_func(
                f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...'
            )
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

In [18]:
class Trainer:
    def __init__(self,
                 trainloader,
                 vallaoder,
                 model_ft,
                 writer=None,
                 testloader=None,
                 checkpoint_path=None,
                 patience=5,
                 feature_extract=True,
                 print_itr=50,
                 config=None):
        self.trainloader = trainloader
        self.valloader = vallaoder
        self.testloader = testloader
        
        self.config=config

        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

        print("==" * 10)
        print("Training will be done on ", self.device)
        print("==" * 10)

        self.model = model_ft        
        if torch.cuda.device_count() > 1:
            print("Let's use", torch.cuda.device_count(), "GPUs!")
            self.model = nn.DataParallel(self.model, device_ids=[0,1])
        
        self.model = self.model.to(self.device)
        
        
        # Observe that all parameters are being optimized
        self.optimizer = optim.RAdam(self.model.parameters(), lr=self.config['lr'])
        
        self.scheduler = get_linear_schedule_with_warmup(self.optimizer, 
                                            num_warmup_steps = len(self.trainloader)*self.config['warmup_steps'], # Default value in run_glue.py
                                            num_training_steps = len(self.trainloader)*self.config['epochs'])

        
        self.criterion = nn.CrossEntropyLoss()
        self.early_stopping = EarlyStopping(patience=patience, verbose=True)
        self.writer = writer
        self.print_itr = print_itr

    def train(self, ep):
        self.model.train()

        running_loss = 0.0

        for en, (visuals, question, target) in tqdm(enumerate(self.trainloader)):
            self.optimizer.zero_grad()
            
            visuals = visuals.to(self.device)
            y = target.squeeze().to(self.device)
            
            ids = question['ids'].to(self.device)
            mask = question['mask'].to(self.device)

            outputs = self.model(visuals, ids, mask)
            loss = self.criterion(outputs, y)
            loss.backward()
            self.optimizer.step()
            self.scheduler.step()

            running_loss += loss.item()
            if self.writer:
                self.writer.add_scalar('Train Loss', running_loss, ep*len(self.trainloader) + en)
            running_loss = 0
            

    def validate(self, ep):
        self.model.eval()
        
        running_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for en, (visuals, question, target) in tqdm(enumerate(self.valloader)):
                visuals = visuals.to(self.device)
                y = target.squeeze().to(self.device)
                
                ids = question['ids'].to(self.device)
                mask = question['mask'].to(self.device)
                
                outputs = self.model(visuals, ids, mask)
                loss = self.criterion(outputs, y)

                y_pred_softmax = torch.log_softmax(outputs, dim = 1)
                _, y_pred_tags = torch.max(y_pred_softmax, dim = 1)
                # self.tmp_sv_ = y_pred_tags
                
                correct += (y_pred_tags.detach().cpu().data.numpy() == y.detach().cpu().data.numpy()).sum()
                total += y_pred_tags.shape[0]
                
                # print statistics
                running_loss += loss.item()
        
        
        return running_loss / len(self.valloader), correct*100/total

    def perform_training(self, total_epoch):
        val_loss, acc = self.validate(0)

        print("[Initial Validation results] Loss: {} \t Acc: {}".format(
            val_loss, acc))

        for i in range(total_epoch):
            self.train(i + 1)
            val_loss, acc = self.validate(i + 1)
            print('[{}/{}] Loss: {} \t Acc: {}'.format(i+1, total_epoch, val_loss, acc))

            if self.writer:
                self.writer.add_scalar('Validation Loss', val_loss, (i + 1))
                self.writer.add_scalar('Validation Acc', acc, (i + 1))

            self.early_stopping(val_loss, self.model)

            # if self.early_stopping.early_stop:
            #     print("Early stopping")
            #     break

        print("=" * 20)
        print("Training finished !!")
        print("=" * 20)


In [19]:
model_ft = mednet(config=cfg, max_labels=len(answer_map))
writer = SummaryWriter('runs/fusion_run_1')
trainer = Trainer(TrainDataLoader, ValDataLoader, model_ft, writer=writer, config=cfg)

Number of vision filters: 512
Training will be done on  cuda
Let's use 2 GPUs!


In [20]:
trainer.perform_training(cfg['epochs'])

4it [00:04,  1.07s/it]

[Initial Validation results] Loss: 5.804837584495544 	 Acc: 0.6



250it [00:46,  5.34it/s]
4it [00:01,  2.66it/s]


[1/250] Loss: 5.804639458656311 	 Acc: 0.4
Validation loss decreased (inf --> 5.804639).  Saving model ...


250it [00:46,  5.36it/s]
4it [00:01,  2.67it/s]


[2/250] Loss: 5.7992247343063354 	 Acc: 0.0
Validation loss decreased (5.804639 --> 5.799225).  Saving model ...


250it [00:46,  5.34it/s]
4it [00:01,  2.65it/s]


[3/250] Loss: 5.758097171783447 	 Acc: 0.2
Validation loss decreased (5.799225 --> 5.758097).  Saving model ...


250it [00:46,  5.34it/s]
4it [00:01,  2.63it/s]


[4/250] Loss: 5.689926624298096 	 Acc: 1.6
Validation loss decreased (5.758097 --> 5.689927).  Saving model ...


250it [00:46,  5.33it/s]
4it [00:01,  2.62it/s]


[5/250] Loss: 5.375650405883789 	 Acc: 6.0
Validation loss decreased (5.689927 --> 5.375650).  Saving model ...


250it [00:46,  5.33it/s]
4it [00:01,  2.64it/s]


[6/250] Loss: 5.245293974876404 	 Acc: 7.6
Validation loss decreased (5.375650 --> 5.245294).  Saving model ...


250it [00:46,  5.33it/s]
4it [00:01,  2.64it/s]


[7/250] Loss: 5.125733017921448 	 Acc: 7.8
Validation loss decreased (5.245294 --> 5.125733).  Saving model ...


250it [00:46,  5.34it/s]
4it [00:01,  2.61it/s]


[8/250] Loss: 4.959262013435364 	 Acc: 8.6
Validation loss decreased (5.125733 --> 4.959262).  Saving model ...


250it [00:46,  5.34it/s]
4it [00:01,  2.66it/s]


[9/250] Loss: 4.871185541152954 	 Acc: 10.6
Validation loss decreased (4.959262 --> 4.871186).  Saving model ...


250it [00:46,  5.32it/s]
4it [00:01,  2.63it/s]


[10/250] Loss: 4.774176836013794 	 Acc: 11.0
Validation loss decreased (4.871186 --> 4.774177).  Saving model ...


250it [00:46,  5.34it/s]
4it [00:01,  2.63it/s]


[11/250] Loss: 4.593083143234253 	 Acc: 13.6
Validation loss decreased (4.774177 --> 4.593083).  Saving model ...


250it [00:46,  5.33it/s]
4it [00:01,  2.63it/s]


[12/250] Loss: 4.5334166288375854 	 Acc: 13.4
Validation loss decreased (4.593083 --> 4.533417).  Saving model ...


250it [00:46,  5.34it/s]
4it [00:01,  2.61it/s]


[13/250] Loss: 4.498239278793335 	 Acc: 13.8
Validation loss decreased (4.533417 --> 4.498239).  Saving model ...


250it [00:47,  5.32it/s]
4it [00:01,  2.61it/s]


[14/250] Loss: 4.308689832687378 	 Acc: 16.2
Validation loss decreased (4.498239 --> 4.308690).  Saving model ...


250it [00:46,  5.32it/s]
4it [00:01,  2.63it/s]


[15/250] Loss: 4.198996126651764 	 Acc: 17.6
Validation loss decreased (4.308690 --> 4.198996).  Saving model ...


250it [00:46,  5.33it/s]
4it [00:01,  2.64it/s]

[16/250] Loss: 4.3623206615448 	 Acc: 14.8
EarlyStopping counter: 1 out of 5



250it [00:46,  5.32it/s]
4it [00:01,  2.60it/s]


[17/250] Loss: 4.186459302902222 	 Acc: 19.2
Validation loss decreased (4.198996 --> 4.186459).  Saving model ...


250it [00:46,  5.32it/s]
4it [00:01,  2.63it/s]

[18/250] Loss: 4.194683253765106 	 Acc: 21.6
EarlyStopping counter: 1 out of 5



250it [00:46,  5.33it/s]
4it [00:01,  2.61it/s]


[19/250] Loss: 4.04733681678772 	 Acc: 22.4
Validation loss decreased (4.186459 --> 4.047337).  Saving model ...


250it [00:46,  5.32it/s]
4it [00:01,  2.62it/s]

[20/250] Loss: 4.150575578212738 	 Acc: 19.6
EarlyStopping counter: 1 out of 5



250it [00:46,  5.33it/s]
4it [00:01,  2.61it/s]

[21/250] Loss: 4.203032195568085 	 Acc: 21.8
EarlyStopping counter: 2 out of 5



250it [00:47,  5.32it/s]
4it [00:01,  2.62it/s]

[22/250] Loss: 4.050514578819275 	 Acc: 24.0
EarlyStopping counter: 3 out of 5



250it [00:46,  5.33it/s]
4it [00:01,  2.64it/s]

[23/250] Loss: 4.078871369361877 	 Acc: 21.6
EarlyStopping counter: 4 out of 5



250it [00:47,  5.31it/s]
4it [00:01,  2.62it/s]

[24/250] Loss: 4.0815399289131165 	 Acc: 24.8
EarlyStopping counter: 5 out of 5



250it [00:47,  5.32it/s]
4it [00:01,  2.62it/s]

[25/250] Loss: 4.065466821193695 	 Acc: 23.4
EarlyStopping counter: 6 out of 5



250it [00:47,  5.31it/s]
4it [00:01,  2.63it/s]

[26/250] Loss: 4.355287432670593 	 Acc: 22.8
EarlyStopping counter: 7 out of 5



250it [00:46,  5.33it/s]
4it [00:01,  2.56it/s]

[27/250] Loss: 4.158205926418304 	 Acc: 26.6
EarlyStopping counter: 8 out of 5



250it [00:46,  5.34it/s]
4it [00:01,  2.64it/s]

[28/250] Loss: 4.113175272941589 	 Acc: 25.4
EarlyStopping counter: 9 out of 5



250it [00:46,  5.33it/s]
4it [00:01,  2.62it/s]

[29/250] Loss: 4.065544247627258 	 Acc: 26.2
EarlyStopping counter: 10 out of 5



250it [00:46,  5.33it/s]
4it [00:01,  2.60it/s]

[30/250] Loss: 4.133573353290558 	 Acc: 25.2
EarlyStopping counter: 11 out of 5



250it [00:46,  5.33it/s]
4it [00:01,  2.60it/s]

[31/250] Loss: 4.226832985877991 	 Acc: 25.6
EarlyStopping counter: 12 out of 5



250it [00:46,  5.32it/s]
4it [00:01,  2.62it/s]

[32/250] Loss: 4.255941927433014 	 Acc: 26.4
EarlyStopping counter: 13 out of 5



250it [00:47,  5.32it/s]
4it [00:01,  2.62it/s]

[33/250] Loss: 4.229967951774597 	 Acc: 25.6
EarlyStopping counter: 14 out of 5



250it [00:46,  5.33it/s]
4it [00:01,  2.60it/s]

[34/250] Loss: 4.305858731269836 	 Acc: 28.4
EarlyStopping counter: 15 out of 5



250it [00:46,  5.32it/s]
4it [00:01,  2.60it/s]

[35/250] Loss: 4.085360646247864 	 Acc: 26.8
EarlyStopping counter: 16 out of 5



250it [00:46,  5.32it/s]
4it [00:01,  2.60it/s]

[36/250] Loss: 4.284018397331238 	 Acc: 30.2
EarlyStopping counter: 17 out of 5



250it [00:46,  5.32it/s]
4it [00:01,  2.59it/s]

[37/250] Loss: 4.16567599773407 	 Acc: 31.0
EarlyStopping counter: 18 out of 5



250it [00:47,  5.31it/s]
4it [00:01,  2.56it/s]

[38/250] Loss: 4.146484315395355 	 Acc: 30.6
EarlyStopping counter: 19 out of 5



250it [00:46,  5.33it/s]
4it [00:01,  2.62it/s]

[39/250] Loss: 4.189395427703857 	 Acc: 29.2
EarlyStopping counter: 20 out of 5



250it [00:46,  5.33it/s]
4it [00:01,  2.63it/s]

[40/250] Loss: 4.115708529949188 	 Acc: 29.4
EarlyStopping counter: 21 out of 5



250it [00:46,  5.33it/s]
4it [00:01,  2.63it/s]

[41/250] Loss: 4.518850803375244 	 Acc: 27.8
EarlyStopping counter: 22 out of 5



250it [00:46,  5.32it/s]
4it [00:01,  2.61it/s]

[42/250] Loss: 4.3720585107803345 	 Acc: 28.0
EarlyStopping counter: 23 out of 5



250it [00:46,  5.33it/s]
4it [00:01,  2.62it/s]

[43/250] Loss: 4.2547866106033325 	 Acc: 29.8
EarlyStopping counter: 24 out of 5



250it [00:46,  5.32it/s]
4it [00:01,  2.61it/s]

[44/250] Loss: 4.101820886135101 	 Acc: 30.8
EarlyStopping counter: 25 out of 5



250it [00:46,  5.33it/s]
4it [00:01,  2.62it/s]

[45/250] Loss: 4.163883566856384 	 Acc: 31.6
EarlyStopping counter: 26 out of 5



250it [00:46,  5.32it/s]
4it [00:01,  2.65it/s]

[46/250] Loss: 4.255007863044739 	 Acc: 29.4
EarlyStopping counter: 27 out of 5



250it [00:46,  5.33it/s]
4it [00:01,  2.58it/s]

[47/250] Loss: 4.335599422454834 	 Acc: 31.8
EarlyStopping counter: 28 out of 5



250it [00:46,  5.32it/s]
4it [00:01,  2.65it/s]

[48/250] Loss: 4.274009108543396 	 Acc: 31.8
EarlyStopping counter: 29 out of 5



250it [00:46,  5.32it/s]
4it [00:01,  2.59it/s]

[49/250] Loss: 4.594709992408752 	 Acc: 30.4
EarlyStopping counter: 30 out of 5



250it [00:47,  5.32it/s]
4it [00:01,  2.63it/s]

[50/250] Loss: 4.155714631080627 	 Acc: 33.2
EarlyStopping counter: 31 out of 5



250it [00:47,  5.32it/s]
4it [00:01,  2.62it/s]

[51/250] Loss: 4.38055145740509 	 Acc: 30.8
EarlyStopping counter: 32 out of 5



250it [00:46,  5.33it/s]
4it [00:01,  2.63it/s]

[52/250] Loss: 4.613361358642578 	 Acc: 30.0
EarlyStopping counter: 33 out of 5



250it [00:46,  5.33it/s]
4it [00:01,  2.62it/s]

[53/250] Loss: 4.625774264335632 	 Acc: 31.2
EarlyStopping counter: 34 out of 5



250it [00:46,  5.33it/s]
4it [00:01,  2.62it/s]

[54/250] Loss: 4.5729042291641235 	 Acc: 33.8
EarlyStopping counter: 35 out of 5



250it [00:46,  5.33it/s]
4it [00:01,  2.63it/s]

[55/250] Loss: 4.390235185623169 	 Acc: 31.4
EarlyStopping counter: 36 out of 5



250it [00:46,  5.33it/s]
4it [00:01,  2.64it/s]

[56/250] Loss: 4.449095249176025 	 Acc: 32.8
EarlyStopping counter: 37 out of 5



250it [00:47,  5.31it/s]
4it [00:01,  2.63it/s]

[57/250] Loss: 4.291743755340576 	 Acc: 35.2
EarlyStopping counter: 38 out of 5



250it [00:46,  5.33it/s]
4it [00:01,  2.59it/s]

[58/250] Loss: 4.443251371383667 	 Acc: 32.8
EarlyStopping counter: 39 out of 5



250it [00:46,  5.33it/s]
4it [00:01,  2.60it/s]

[59/250] Loss: 4.419125437736511 	 Acc: 31.8
EarlyStopping counter: 40 out of 5



250it [00:46,  5.32it/s]
4it [00:01,  2.60it/s]

[60/250] Loss: 4.4474098682403564 	 Acc: 30.8
EarlyStopping counter: 41 out of 5



250it [00:46,  5.32it/s]
4it [00:01,  2.60it/s]

[61/250] Loss: 4.376547813415527 	 Acc: 33.4
EarlyStopping counter: 42 out of 5



250it [00:47,  5.32it/s]
4it [00:01,  2.63it/s]

[62/250] Loss: 4.434884071350098 	 Acc: 33.0
EarlyStopping counter: 43 out of 5



250it [00:46,  5.32it/s]
4it [00:01,  2.60it/s]

[63/250] Loss: 4.304305195808411 	 Acc: 35.4
EarlyStopping counter: 44 out of 5



250it [00:46,  5.32it/s]
4it [00:01,  2.63it/s]

[64/250] Loss: 4.51214873790741 	 Acc: 34.6
EarlyStopping counter: 45 out of 5



250it [00:46,  5.32it/s]
4it [00:01,  2.59it/s]

[65/250] Loss: 4.448162078857422 	 Acc: 35.4
EarlyStopping counter: 46 out of 5



250it [00:46,  5.33it/s]
4it [00:01,  2.61it/s]

[66/250] Loss: 4.6145758628845215 	 Acc: 35.2
EarlyStopping counter: 47 out of 5



250it [00:46,  5.32it/s]
4it [00:01,  2.65it/s]

[67/250] Loss: 4.489890933036804 	 Acc: 32.6
EarlyStopping counter: 48 out of 5



250it [00:46,  5.32it/s]
4it [00:01,  2.59it/s]

[68/250] Loss: 4.4329187870025635 	 Acc: 35.4
EarlyStopping counter: 49 out of 5



250it [00:46,  5.32it/s]
4it [00:01,  2.60it/s]

[69/250] Loss: 4.592087626457214 	 Acc: 33.4
EarlyStopping counter: 50 out of 5



250it [00:46,  5.32it/s]
4it [00:01,  2.63it/s]

[70/250] Loss: 4.370912790298462 	 Acc: 34.4
EarlyStopping counter: 51 out of 5



250it [00:47,  5.32it/s]
4it [00:01,  2.61it/s]

[71/250] Loss: 4.505208492279053 	 Acc: 31.8
EarlyStopping counter: 52 out of 5



250it [00:46,  5.32it/s]
4it [00:01,  2.59it/s]

[72/250] Loss: 4.47956120967865 	 Acc: 33.6
EarlyStopping counter: 53 out of 5



250it [00:47,  5.32it/s]
4it [00:01,  2.57it/s]

[73/250] Loss: 4.189254879951477 	 Acc: 36.4
EarlyStopping counter: 54 out of 5



250it [00:47,  5.31it/s]
4it [00:01,  2.65it/s]

[74/250] Loss: 4.821986198425293 	 Acc: 31.8
EarlyStopping counter: 55 out of 5



250it [00:47,  5.31it/s]
4it [00:01,  2.59it/s]

[75/250] Loss: 4.639125108718872 	 Acc: 34.6
EarlyStopping counter: 56 out of 5



250it [00:47,  5.31it/s]
4it [00:01,  2.60it/s]

[76/250] Loss: 4.74618923664093 	 Acc: 31.4
EarlyStopping counter: 57 out of 5



250it [00:46,  5.33it/s]
4it [00:01,  2.53it/s]

[77/250] Loss: 4.5526362657547 	 Acc: 35.4
EarlyStopping counter: 58 out of 5



250it [00:46,  5.32it/s]
4it [00:01,  2.62it/s]

[78/250] Loss: 4.544452548027039 	 Acc: 32.4
EarlyStopping counter: 59 out of 5



250it [00:47,  5.32it/s]
4it [00:01,  2.59it/s]

[79/250] Loss: 4.570830345153809 	 Acc: 34.6
EarlyStopping counter: 60 out of 5



250it [00:47,  5.32it/s]
4it [00:01,  2.62it/s]

[80/250] Loss: 4.651182293891907 	 Acc: 34.8
EarlyStopping counter: 61 out of 5



250it [00:46,  5.32it/s]
4it [00:01,  2.59it/s]

[81/250] Loss: 4.881680130958557 	 Acc: 32.6
EarlyStopping counter: 62 out of 5



250it [00:46,  5.32it/s]
4it [00:01,  2.66it/s]

[82/250] Loss: 4.573618531227112 	 Acc: 36.8
EarlyStopping counter: 63 out of 5



250it [00:47,  5.32it/s]
4it [00:01,  2.62it/s]

[83/250] Loss: 4.645077228546143 	 Acc: 36.4
EarlyStopping counter: 64 out of 5



250it [00:47,  5.31it/s]
4it [00:01,  2.61it/s]

[84/250] Loss: 4.594783425331116 	 Acc: 37.2
EarlyStopping counter: 65 out of 5



250it [00:47,  5.31it/s]
4it [00:01,  2.60it/s]

[85/250] Loss: 4.835582375526428 	 Acc: 35.8
EarlyStopping counter: 66 out of 5



250it [00:47,  5.32it/s]
4it [00:01,  2.63it/s]

[86/250] Loss: 4.78715443611145 	 Acc: 35.4
EarlyStopping counter: 67 out of 5



250it [00:47,  5.31it/s]
4it [00:01,  2.62it/s]

[87/250] Loss: 4.465767860412598 	 Acc: 38.2
EarlyStopping counter: 68 out of 5



250it [00:46,  5.32it/s]
4it [00:01,  2.61it/s]

[88/250] Loss: 4.642544746398926 	 Acc: 35.4
EarlyStopping counter: 69 out of 5



250it [00:47,  5.32it/s]
4it [00:01,  2.62it/s]

[89/250] Loss: 4.725299835205078 	 Acc: 32.6
EarlyStopping counter: 70 out of 5



250it [00:47,  5.30it/s]
4it [00:01,  2.62it/s]

[90/250] Loss: 4.683391094207764 	 Acc: 35.2
EarlyStopping counter: 71 out of 5



250it [00:46,  5.33it/s]
4it [00:01,  2.62it/s]

[91/250] Loss: 4.713443040847778 	 Acc: 34.6
EarlyStopping counter: 72 out of 5



250it [00:46,  5.32it/s]
4it [00:01,  2.58it/s]

[92/250] Loss: 4.628618121147156 	 Acc: 34.2
EarlyStopping counter: 73 out of 5



250it [00:47,  5.32it/s]
4it [00:01,  2.60it/s]

[93/250] Loss: 4.898150086402893 	 Acc: 31.8
EarlyStopping counter: 74 out of 5



250it [00:47,  5.32it/s]
4it [00:01,  2.60it/s]

[94/250] Loss: 4.809429407119751 	 Acc: 37.4
EarlyStopping counter: 75 out of 5



250it [00:46,  5.32it/s]
4it [00:01,  2.64it/s]

[95/250] Loss: 4.663732051849365 	 Acc: 35.4
EarlyStopping counter: 76 out of 5



250it [00:47,  5.32it/s]
4it [00:01,  2.61it/s]

[96/250] Loss: 4.731742262840271 	 Acc: 37.6
EarlyStopping counter: 77 out of 5



250it [00:47,  5.31it/s]
4it [00:01,  2.60it/s]

[97/250] Loss: 4.789201259613037 	 Acc: 33.4
EarlyStopping counter: 78 out of 5



250it [00:47,  5.31it/s]
4it [00:01,  2.61it/s]

[98/250] Loss: 4.8450340032577515 	 Acc: 33.8
EarlyStopping counter: 79 out of 5



250it [00:47,  5.32it/s]
4it [00:01,  2.62it/s]

[99/250] Loss: 4.770956039428711 	 Acc: 37.4
EarlyStopping counter: 80 out of 5



250it [00:46,  5.32it/s]
4it [00:01,  2.60it/s]

[100/250] Loss: 4.572443842887878 	 Acc: 38.2
EarlyStopping counter: 81 out of 5



250it [00:47,  5.31it/s]
4it [00:01,  2.60it/s]

[101/250] Loss: 4.605502724647522 	 Acc: 36.8
EarlyStopping counter: 82 out of 5



250it [00:47,  5.31it/s]
4it [00:01,  2.64it/s]

[102/250] Loss: 4.653134107589722 	 Acc: 37.6
EarlyStopping counter: 83 out of 5



250it [00:47,  5.31it/s]
4it [00:01,  2.60it/s]

[103/250] Loss: 4.673865795135498 	 Acc: 37.6
EarlyStopping counter: 84 out of 5



250it [00:47,  5.32it/s]
4it [00:01,  2.63it/s]

[104/250] Loss: 4.856234312057495 	 Acc: 33.6
EarlyStopping counter: 85 out of 5



250it [00:47,  5.32it/s]
4it [00:01,  2.60it/s]

[105/250] Loss: 4.5644084215164185 	 Acc: 36.2
EarlyStopping counter: 86 out of 5



250it [00:47,  5.31it/s]
4it [00:01,  2.63it/s]

[106/250] Loss: 4.711540579795837 	 Acc: 38.2
EarlyStopping counter: 87 out of 5



250it [00:47,  5.31it/s]
4it [00:01,  2.62it/s]

[107/250] Loss: 4.666320323944092 	 Acc: 40.6
EarlyStopping counter: 88 out of 5



250it [00:47,  5.31it/s]
4it [00:01,  2.63it/s]

[108/250] Loss: 4.647049188613892 	 Acc: 36.2
EarlyStopping counter: 89 out of 5



250it [00:47,  5.32it/s]
4it [00:01,  2.62it/s]

[109/250] Loss: 4.719738721847534 	 Acc: 38.0
EarlyStopping counter: 90 out of 5



250it [00:47,  5.31it/s]
4it [00:01,  2.62it/s]

[110/250] Loss: 4.799131989479065 	 Acc: 35.4
EarlyStopping counter: 91 out of 5



250it [00:47,  5.30it/s]
4it [00:01,  2.61it/s]

[111/250] Loss: 4.731251001358032 	 Acc: 35.8
EarlyStopping counter: 92 out of 5



250it [00:47,  5.31it/s]
4it [00:01,  2.56it/s]

[112/250] Loss: 4.7016438245773315 	 Acc: 35.8
EarlyStopping counter: 93 out of 5



250it [00:47,  5.32it/s]
4it [00:01,  2.59it/s]

[113/250] Loss: 4.617602109909058 	 Acc: 37.4
EarlyStopping counter: 94 out of 5



250it [00:47,  5.31it/s]
4it [00:01,  2.58it/s]

[114/250] Loss: 4.597397804260254 	 Acc: 38.2
EarlyStopping counter: 95 out of 5



250it [00:47,  5.31it/s]
4it [00:01,  2.58it/s]

[115/250] Loss: 4.600773096084595 	 Acc: 36.8
EarlyStopping counter: 96 out of 5



250it [00:47,  5.31it/s]
4it [00:01,  2.57it/s]

[116/250] Loss: 4.5728349685668945 	 Acc: 38.4
EarlyStopping counter: 97 out of 5



250it [00:47,  5.32it/s]
4it [00:01,  2.59it/s]

[117/250] Loss: 4.863190531730652 	 Acc: 36.4
EarlyStopping counter: 98 out of 5



250it [00:47,  5.31it/s]
4it [00:01,  2.61it/s]

[118/250] Loss: 4.682114124298096 	 Acc: 37.4
EarlyStopping counter: 99 out of 5



250it [00:47,  5.31it/s]
4it [00:01,  2.54it/s]

[119/250] Loss: 4.682757139205933 	 Acc: 37.4
EarlyStopping counter: 100 out of 5



250it [00:47,  5.31it/s]
4it [00:01,  2.62it/s]

[120/250] Loss: 4.673215627670288 	 Acc: 37.0
EarlyStopping counter: 101 out of 5



250it [00:47,  5.31it/s]
4it [00:01,  2.64it/s]

[121/250] Loss: 4.769526958465576 	 Acc: 36.2
EarlyStopping counter: 102 out of 5



250it [00:47,  5.31it/s]
4it [00:01,  2.60it/s]

[122/250] Loss: 4.671425700187683 	 Acc: 36.0
EarlyStopping counter: 103 out of 5



250it [00:47,  5.31it/s]
4it [00:01,  2.61it/s]

[123/250] Loss: 4.716428518295288 	 Acc: 37.8
EarlyStopping counter: 104 out of 5



250it [00:47,  5.31it/s]
4it [00:01,  2.62it/s]

[124/250] Loss: 4.693289041519165 	 Acc: 37.6
EarlyStopping counter: 105 out of 5



18it [00:03,  4.91it/s]


KeyboardInterrupt: 

In [21]:
class medical_dataset_test(Dataset):
    def __init__(self, config=None, answer_map=None, image_path='./dataset/Task1-VQA-2021-TestSet-w-GroundTruth/VQA-500-Images/', qa_file="./dataset/Task1-VQA-2021-TestSet-w-GroundTruth/Task1-VQA-2021-TestSet-Questions.txt", train=True):
            
        assert answer_map!=None
        assert config!=None
        
        self.image_path = image_path
        self.qa_file = qa_file
        self.config = config
        self.train = train
        
        self.answer_map = answer_map
        self.data = pd.read_csv(qa_file, sep='|', names=['imageid', 'question'])
           
        # print(Counter(self.data.answer.tolist()))
        
        self.transforms = transforms.Compose([
                                transforms.RandomResizedCrop(224),
                                transforms.ToTensor(),
                                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                            ])
        self.tokenizer = tokenizer = AutoTokenizer.from_pretrained("dmis-lab/biobert-v1.1")

    def process_data(self, text, max_len):
        text = str(text)

        inputs = self.tokenizer.encode_plus(
            text,
            None,
            add_special_tokens=True,
            max_length=max_len,
            truncation=True
        )
        ids = inputs["input_ids"]
        mask = inputs["attention_mask"]
        
        padding_length = max_len - len(ids)
        
        ids = ids + ([0] * padding_length)
        mask = mask + ([0] * padding_length)
        
        return {
            'ids': torch.tensor(ids, dtype=torch.long),
            'mask': torch.tensor(mask, dtype=torch.long),
        }
    
    
    def __getitem__(self, index):
        
        question = self.data.question[index]
        image_idx = self.data.imageid[index]
        
        ## add visual feature related steps
        visuals = Image.open(join(self.image_path, f'{image_idx}.jpg'))
        visuals = self.transforms(visuals)
        
        
        tmp = self.process_data(question, self.config['max_len'])
        question_tokens = {
            'ids': tmp['ids'],
            'mask': tmp['mask'],
        }
        
        return visuals, question_tokens, image_idx

    def __len__(self):
        return len(self.data)

In [22]:
TestData = medical_dataset_test(config=cfg, answer_map=answer_map)
TestDataLoader = DataLoader(TestData, batch_size=128, shuffle=False, num_workers=4)  # num_workers=0 for windows OS

In [23]:
inv_map = {v: k for k, v in answer_map.items()}

In [24]:
trainer.model.eval()
        
imageids = []
answers = []

running_loss = 0.0
correct = 0
total = 0
with torch.no_grad():
    for en, (visuals, question, target) in tqdm(enumerate(TestDataLoader)):
        visuals = visuals.to(trainer.device)

        ids = question['ids'].to(trainer.device)
        mask = question['mask'].to(trainer.device)

        outputs = trainer.model(visuals, ids, mask)

        y_pred_softmax = torch.log_softmax(outputs, dim = 1)
        _, y_pred_tags = torch.max(y_pred_softmax, dim = 1)
        # self.tmp_sv_ = y_pred_tags
        
        for i in range(visuals.shape[0]):
            imageids.append(target[i])
            answers.append(inv_map[int(y_pred_tags[i])])


4it [00:01,  2.35it/s]


In [25]:
answers[:5]

['avascular necrosis',
 'pulmonary embolism',
 'osteopoikilosis',
 'adrenal myelolipoma',
 'osteochondritis dissecans']

In [26]:
pd.DataFrame({'imageids':imageids, 'answers':answers}).to_csv('fusion.txt', sep='|', index=False, header=False)