In [1]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import warnings
warnings.filterwarnings("ignore")

In [3]:
import numpy as np
import sys
from os import listdir, makedirs, getcwd, remove
from os.path import isfile, join, abspath, exists, isdir, expanduser
from scipy.io import loadmat
from PIL import Image
import pandas as pd

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transforms


import pickle
from tqdm import tqdm
import json
import random
from collections import Counter

from transformers import get_linear_schedule_with_warmup
from transformers import AutoTokenizer, AutoModel

In [4]:
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.autograd as autograd
from torch.utils.data import Dataset, DataLoader
import torch.utils.data as data
import torch.nn.functional as F
import torch.optim as optim
import itertools

import torchvision
from torchvision import transforms, datasets, models
from torch import Tensor

from torch.utils.tensorboard import SummaryWriter

In [5]:
import transformers
import tokenizers
from transformers import BertTokenizer, BertModel

In [6]:
cfg = {
    'max_len': 128,
    'lr': 2e-5,
    'warmup_steps': 5,
    'epochs': 250
}

In [7]:
class medical_dataset(Dataset):
    def __init__(self, config=None, answer_map=None, image_path='./dataset/TrainingDataset/images/', qa_file="./dataset/TrainingDataset/training_qa.txt", train=True):
            
        assert answer_map!=None
        assert config!=None
        
        self.image_path = image_path
        self.qa_file = qa_file
        self.config = config
        self.train = train
        
        self.answer_map = answer_map
        self.data = pd.read_csv(qa_file, sep='|', names=['imageid', 'question', 'answer'])
           
        # print(Counter(self.data.answer.tolist()))
        
        self.transforms = transforms.Compose([
                                transforms.RandomResizedCrop(224),
                                transforms.ToTensor(),
                                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                            ])
        self.tokenizer = tokenizer = AutoTokenizer.from_pretrained("dmis-lab/biobert-v1.1")

    def process_data(self, text, max_len):
        text = str(text)

        inputs = self.tokenizer.encode_plus(
            text,
            None,
            add_special_tokens=True,
            max_length=max_len,
            truncation=True
        )
        ids = inputs["input_ids"]
        mask = inputs["attention_mask"]
        
        padding_length = max_len - len(ids)
        
        ids = ids + ([0] * padding_length)
        mask = mask + ([0] * padding_length)
        
        return {
            'ids': torch.tensor(ids, dtype=torch.long),
            'mask': torch.tensor(mask, dtype=torch.long),
        }
    
    
    def __getitem__(self, index):
        
        question = self.data.question[index]
        answer = self.data.answer[index]
        image_idx = self.data.imageid[index]
        
        ## add visual feature related steps
        visuals = Image.open(join(self.image_path, f'{image_idx}.jpg'))
        visuals = self.transforms(visuals)
        
        target = torch.from_numpy(np.array([self.answer_map[answer]])).long()
        
        tmp = self.process_data(question, self.config['max_len'])
        question_tokens = {
            'ids': tmp['ids'],
            'mask': tmp['mask'],
        }
        
        return visuals, question_tokens, target

    def __len__(self):
        return len(self.data)

In [8]:
# file_path = './dataset/TrainingDataset/training_qa.txt'
# data = pd.read_csv(file_path, sep='|', names=['imageid', 'question', 'answer'])
# data.head()

In [9]:
# answer_map = {}
# ct = 0

# for i in data.answer.unique():
#     answer_map[i] = ct
#     ct+=1

In [10]:
# with open('./answer_map.pickle', 'wb') as handle:
#     pickle.dump(answer_map, handle, protocol=pickle.HIGHEST_PROTOCOL)
    

In [11]:
with open('./answer_map.pickle', 'rb') as handle:
    answer_map = pickle.load(handle)
    
len(answer_map)

333

In [27]:
TrainData = medical_dataset(config=cfg, answer_map=answer_map, image_path='./dataset/VQA-Med-2020-Task1-VQAnswering-TrainVal-Sets/VQAMed2020-VQAnswering-TrainingSet/VQAnswering_2020_Train_images/', qa_file='./dataset/VQA-Med-2020-Task1-VQAnswering-TrainVal-Sets/VQAMed2020-VQAnswering-TrainingSet/VQAnswering_2020_Train_QA_pairs.txt')
ValData = medical_dataset(config=cfg, answer_map=answer_map, image_path='./dataset/VQA-Med-2020-Task1-VQAnswering-TrainVal-Sets/VQAMed2020-VQAnswering-ValidationSet/VQAnswering_2020_Val_images/', qa_file='./dataset/VQA-Med-2020-Task1-VQAnswering-TrainVal-Sets/VQAMed2020-VQAnswering-ValidationSet/VQAnswering_2020_Val_QA_Pairs.txt', train=False)

In [28]:
%%time
a,b,c = ValData.__getitem__(0)

CPU times: user 4.01 ms, sys: 0 ns, total: 4.01 ms
Wall time: 3.06 ms


In [29]:
TrainDataLoader = DataLoader(TrainData, batch_size=16, shuffle=True, num_workers=4)  # num_workers=0 for windows OS
ValDataLoader = DataLoader(ValData, batch_size=128, shuffle=False, num_workers=4)  # num_workers=0 for windows OS

In [15]:
class mednet(nn.Module):
    def __init__(self, config, max_labels):
        super(mednet, self).__init__()
        self.vision = models.resnet18(pretrained=True)
        num_ftrs = self.vision.fc.in_features
        
        self.vision = nn.Sequential(*list(self.vision.children())[:-1])
        self.bert = AutoModel.from_pretrained("dmis-lab/biobert-v1.1")
        
        
        self.fc1 = nn.Linear(num_ftrs+768, 128)
        self.fc2 = nn.Linear(128, max_labels)
        

    def forward(self, visual=None, ids=None, mask=None):
        vision = self.vision(visual).view((ids.shape[0], -1))
        bert_out = self.bert(ids, mask)
        
        h = torch.cat((vision, bert_out.last_hidden_state[:,0]), dim=1)
        
        h = F.relu(self.fc1(h))
        h = self.fc2(h)
        return h


In [16]:
class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""
    def __init__(self,
                 patience=7,
                 verbose=False,
                 delta=0,
                 path='./models/no_fusion/checkpoint.pt',
                 trace_func=print):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement. 
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): Path for the checkpoint to be saved to.
                            Default: 'checkpoint.pt'
            trace_func (function): trace print function.
                            Default: print            
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.Inf
        self.delta = delta
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):

        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            self.trace_func(
                f'EarlyStopping counter: {self.counter} out of {self.patience}'
            )
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            self.trace_func(
                f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...'
            )
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

In [17]:
class Trainer:
    def __init__(self,
                 trainloader,
                 vallaoder,
                 model_ft,
                 writer=None,
                 testloader=None,
                 checkpoint_path=None,
                 patience=10,
                 feature_extract=True,
                 print_itr=50,
                 config=None):
        self.trainloader = trainloader
        self.valloader = vallaoder
        self.testloader = testloader
        
        self.config=config

        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

        print("==" * 10)
        print("Training will be done on ", self.device)
        print("==" * 10)

        self.model = model_ft        
        if torch.cuda.device_count() > 1:
            print("Let's use", torch.cuda.device_count(), "GPUs!")
            self.model = nn.DataParallel(self.model, device_ids=[i for i in range(torch.cuda.device_count())])
        
        self.model = self.model.to(self.device)
        
        
        # Observe that all parameters are being optimized
        self.optimizer = optim.RAdam(self.model.parameters(), lr=self.config['lr'])
        
        self.scheduler = get_linear_schedule_with_warmup(self.optimizer, 
                                            num_warmup_steps = len(self.trainloader)*self.config['warmup_steps'], # Default value in run_glue.py
                                            num_training_steps = len(self.trainloader)*self.config['epochs'])

        
        self.criterion = nn.CrossEntropyLoss()
        self.early_stopping = EarlyStopping(patience=patience, verbose=True)
        self.writer = writer
        self.print_itr = print_itr

    def train(self, ep):
        self.model.train()

        running_loss = 0.0

        for en, (visuals, question, target) in tqdm(enumerate(self.trainloader)):
            self.optimizer.zero_grad()
            
            visuals = visuals.to(self.device)
            y = target.squeeze().to(self.device)
            
            ids = question['ids'].to(self.device)
            mask = question['mask'].to(self.device)

            outputs = self.model(visuals, ids, mask)
            loss = self.criterion(outputs, y)
            loss.backward()
            self.optimizer.step()
            self.scheduler.step()

            running_loss += loss.item()
            if self.writer:
                self.writer.add_scalar('Train Loss', running_loss, ep*len(self.trainloader) + en)
            running_loss = 0
            

    def validate(self, ep):
        self.model.eval()
        
        running_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for en, (visuals, question, target) in tqdm(enumerate(self.valloader)):
                visuals = visuals.to(self.device)
                y = target.squeeze().to(self.device)
                
                ids = question['ids'].to(self.device)
                mask = question['mask'].to(self.device)
                
                outputs = self.model(visuals, ids, mask)
                loss = self.criterion(outputs, y)

                y_pred_softmax = torch.log_softmax(outputs, dim = 1)
                _, y_pred_tags = torch.max(y_pred_softmax, dim = 1)
                # self.tmp_sv_ = y_pred_tags
                
                correct += (y_pred_tags.detach().cpu().data.numpy() == y.detach().cpu().data.numpy()).sum()
                total += y_pred_tags.shape[0]
                
                # print statistics
                running_loss += loss.item()
        
        
        return running_loss / len(self.valloader), correct*100/total

    def perform_training(self, total_epoch):
        val_loss, acc = self.validate(0)

        print("[Initial Validation results] Loss: {} \t Acc: {}".format(
            val_loss, acc))

        for i in range(total_epoch):
            self.train(i + 1)
            val_loss, acc = self.validate(i + 1)
            print('[{}/{}] Loss: {} \t Acc: {}'.format(i+1, total_epoch, val_loss, acc))

            if self.writer:
                self.writer.add_scalar('Validation Loss', val_loss, (i + 1))
                self.writer.add_scalar('Validation Acc', acc, (i + 1))

            self.early_stopping(-acc, self.model)

            if self.early_stopping.early_stop:
                print("Early stopping")
                break

        print("=" * 20)
        print("Training finished !!")
        print("=" * 20)


In [18]:
model_ft = mednet(config=cfg, max_labels=len(answer_map))
writer = SummaryWriter('runs/no_fusion_run_3')
trainer = Trainer(TrainDataLoader, ValDataLoader, model_ft, writer=writer, config=cfg)

Training will be done on  cuda
Let's use 2 GPUs!


In [19]:
trainer.perform_training(cfg['epochs'])

4it [00:04,  1.05s/it]

[Initial Validation results] Loss: 5.825301289558411 	 Acc: 0.0



282it [00:52,  5.37it/s]
4it [00:01,  2.61it/s]


[1/250] Loss: 5.821021914482117 	 Acc: 0.2
Validation loss decreased (inf --> -0.200000).  Saving model ...


282it [00:52,  5.36it/s]
4it [00:01,  2.69it/s]


[2/250] Loss: 5.7818523645401 	 Acc: 0.6
Validation loss decreased (-0.200000 --> -0.600000).  Saving model ...


282it [00:52,  5.35it/s]
4it [00:01,  2.61it/s]


[3/250] Loss: 5.720300793647766 	 Acc: 0.8
Validation loss decreased (-0.600000 --> -0.800000).  Saving model ...


282it [00:52,  5.36it/s]
4it [00:01,  2.60it/s]


[4/250] Loss: 5.648903012275696 	 Acc: 3.4
Validation loss decreased (-0.800000 --> -3.400000).  Saving model ...


282it [00:52,  5.35it/s]
4it [00:01,  2.60it/s]


[5/250] Loss: 5.542846083641052 	 Acc: 3.8
Validation loss decreased (-3.400000 --> -3.800000).  Saving model ...


282it [00:52,  5.36it/s]
4it [00:01,  2.59it/s]


[6/250] Loss: 5.421558856964111 	 Acc: 5.6
Validation loss decreased (-3.800000 --> -5.600000).  Saving model ...


282it [00:52,  5.36it/s]
4it [00:01,  2.63it/s]


[7/250] Loss: 5.2917104959487915 	 Acc: 5.6
Validation loss decreased (-5.600000 --> -5.600000).  Saving model ...


282it [00:52,  5.35it/s]
4it [00:01,  2.63it/s]


[8/250] Loss: 5.177066087722778 	 Acc: 6.0
Validation loss decreased (-5.600000 --> -6.000000).  Saving model ...


282it [00:52,  5.35it/s]
4it [00:01,  2.66it/s]


[9/250] Loss: 5.072428822517395 	 Acc: 6.6
Validation loss decreased (-6.000000 --> -6.600000).  Saving model ...


282it [00:52,  5.35it/s]
4it [00:01,  2.63it/s]


[10/250] Loss: 4.976704478263855 	 Acc: 7.2
Validation loss decreased (-6.600000 --> -7.200000).  Saving model ...


282it [00:52,  5.35it/s]
4it [00:01,  2.68it/s]


[11/250] Loss: 4.883864402770996 	 Acc: 7.6
Validation loss decreased (-7.200000 --> -7.600000).  Saving model ...


282it [00:52,  5.36it/s]
4it [00:01,  2.67it/s]


[12/250] Loss: 4.822434306144714 	 Acc: 8.4
Validation loss decreased (-7.600000 --> -8.400000).  Saving model ...


282it [00:52,  5.36it/s]
4it [00:01,  2.57it/s]


[13/250] Loss: 4.750892519950867 	 Acc: 10.2
Validation loss decreased (-8.400000 --> -10.200000).  Saving model ...


282it [00:52,  5.35it/s]
4it [00:01,  2.61it/s]

[14/250] Loss: 4.647339105606079 	 Acc: 9.2
EarlyStopping counter: 1 out of 10



282it [00:52,  5.34it/s]
4it [00:01,  2.58it/s]

[15/250] Loss: 4.587748289108276 	 Acc: 8.4
EarlyStopping counter: 2 out of 10



282it [00:52,  5.34it/s]
4it [00:01,  2.58it/s]


[16/250] Loss: 4.499471426010132 	 Acc: 12.8
Validation loss decreased (-10.200000 --> -12.800000).  Saving model ...


282it [00:52,  5.35it/s]
4it [00:01,  2.66it/s]


[17/250] Loss: 4.435078263282776 	 Acc: 13.2
Validation loss decreased (-12.800000 --> -13.200000).  Saving model ...


282it [00:52,  5.36it/s]
4it [00:01,  2.67it/s]


[18/250] Loss: 4.368112087249756 	 Acc: 13.8
Validation loss decreased (-13.200000 --> -13.800000).  Saving model ...


282it [00:52,  5.34it/s]
4it [00:01,  2.57it/s]


[19/250] Loss: 4.35517144203186 	 Acc: 13.8
Validation loss decreased (-13.800000 --> -13.800000).  Saving model ...


282it [00:52,  5.35it/s]
4it [00:01,  2.65it/s]


[20/250] Loss: 4.327334761619568 	 Acc: 14.4
Validation loss decreased (-13.800000 --> -14.400000).  Saving model ...


282it [00:52,  5.35it/s]
4it [00:01,  2.66it/s]


[21/250] Loss: 4.2474119663238525 	 Acc: 14.4
Validation loss decreased (-14.400000 --> -14.400000).  Saving model ...


282it [00:52,  5.35it/s]
4it [00:01,  2.67it/s]


[22/250] Loss: 4.218833565711975 	 Acc: 14.8
Validation loss decreased (-14.400000 --> -14.800000).  Saving model ...


282it [00:52,  5.35it/s]
4it [00:01,  2.62it/s]


[23/250] Loss: 4.1437376737594604 	 Acc: 15.2
Validation loss decreased (-14.800000 --> -15.200000).  Saving model ...


282it [00:52,  5.34it/s]
4it [00:01,  2.62it/s]


[24/250] Loss: 4.070925951004028 	 Acc: 18.0
Validation loss decreased (-15.200000 --> -18.000000).  Saving model ...


282it [00:52,  5.35it/s]
4it [00:01,  2.60it/s]

[25/250] Loss: 4.05003559589386 	 Acc: 17.0
EarlyStopping counter: 1 out of 10



282it [00:52,  5.34it/s]
4it [00:01,  2.59it/s]

[26/250] Loss: 3.9940956830978394 	 Acc: 17.0
EarlyStopping counter: 2 out of 10



282it [00:52,  5.35it/s]
4it [00:01,  2.68it/s]


[27/250] Loss: 3.9950817823410034 	 Acc: 18.4
Validation loss decreased (-18.000000 --> -18.400000).  Saving model ...


282it [00:52,  5.35it/s]
4it [00:01,  2.56it/s]


[28/250] Loss: 3.94563227891922 	 Acc: 20.2
Validation loss decreased (-18.400000 --> -20.200000).  Saving model ...


282it [00:52,  5.35it/s]
4it [00:01,  2.60it/s]


[29/250] Loss: 3.8327447175979614 	 Acc: 20.4
Validation loss decreased (-20.200000 --> -20.400000).  Saving model ...


282it [00:52,  5.35it/s]
4it [00:01,  2.67it/s]


[30/250] Loss: 3.8184351325035095 	 Acc: 21.0
Validation loss decreased (-20.400000 --> -21.000000).  Saving model ...


282it [00:52,  5.34it/s]
4it [00:01,  2.62it/s]


[31/250] Loss: 3.7767664194107056 	 Acc: 21.0
Validation loss decreased (-21.000000 --> -21.000000).  Saving model ...


282it [00:52,  5.34it/s]
4it [00:01,  2.63it/s]


[32/250] Loss: 3.7114609479904175 	 Acc: 21.8
Validation loss decreased (-21.000000 --> -21.800000).  Saving model ...


282it [00:52,  5.34it/s]
4it [00:01,  2.66it/s]


[33/250] Loss: 3.7511988282203674 	 Acc: 21.8
Validation loss decreased (-21.800000 --> -21.800000).  Saving model ...


282it [00:52,  5.34it/s]
4it [00:01,  2.60it/s]


[34/250] Loss: 3.643182933330536 	 Acc: 24.0
Validation loss decreased (-21.800000 --> -24.000000).  Saving model ...


282it [00:52,  5.34it/s]
4it [00:01,  2.66it/s]


[35/250] Loss: 3.6692989468574524 	 Acc: 24.4
Validation loss decreased (-24.000000 --> -24.400000).  Saving model ...


282it [00:52,  5.34it/s]
4it [00:01,  2.60it/s]

[36/250] Loss: 3.572270154953003 	 Acc: 22.0
EarlyStopping counter: 1 out of 10



282it [00:52,  5.35it/s]
4it [00:01,  2.64it/s]

[37/250] Loss: 3.572862207889557 	 Acc: 23.0
EarlyStopping counter: 2 out of 10



282it [00:52,  5.35it/s]
4it [00:01,  2.56it/s]

[38/250] Loss: 3.5794169306755066 	 Acc: 22.6
EarlyStopping counter: 3 out of 10



282it [00:52,  5.34it/s]
4it [00:01,  2.61it/s]

[39/250] Loss: 3.5962395071983337 	 Acc: 21.2
EarlyStopping counter: 4 out of 10



282it [00:52,  5.34it/s]
4it [00:01,  2.58it/s]


[40/250] Loss: 3.476062595844269 	 Acc: 24.4
Validation loss decreased (-24.400000 --> -24.400000).  Saving model ...


282it [00:52,  5.34it/s]
4it [00:01,  2.56it/s]


[41/250] Loss: 3.4114068746566772 	 Acc: 27.0
Validation loss decreased (-24.400000 --> -27.000000).  Saving model ...


282it [00:52,  5.35it/s]
4it [00:01,  2.59it/s]

[42/250] Loss: 3.452574372291565 	 Acc: 26.0
EarlyStopping counter: 1 out of 10



282it [00:52,  5.35it/s]
4it [00:01,  2.58it/s]

[43/250] Loss: 3.4338163137435913 	 Acc: 26.2
EarlyStopping counter: 2 out of 10



282it [00:52,  5.35it/s]
4it [00:01,  2.63it/s]

[44/250] Loss: 3.4227702021598816 	 Acc: 25.6
EarlyStopping counter: 3 out of 10



282it [00:52,  5.35it/s]
4it [00:01,  2.69it/s]


[45/250] Loss: 3.3769617080688477 	 Acc: 27.6
Validation loss decreased (-27.000000 --> -27.600000).  Saving model ...


282it [00:52,  5.34it/s]
4it [00:01,  2.58it/s]


[46/250] Loss: 3.3631256222724915 	 Acc: 29.2
Validation loss decreased (-27.600000 --> -29.200000).  Saving model ...


282it [00:52,  5.34it/s]
4it [00:01,  2.66it/s]

[47/250] Loss: 3.3749396800994873 	 Acc: 27.2
EarlyStopping counter: 1 out of 10



282it [00:52,  5.34it/s]
4it [00:01,  2.58it/s]

[48/250] Loss: 3.374020040035248 	 Acc: 27.6
EarlyStopping counter: 2 out of 10



282it [00:52,  5.34it/s]
4it [00:01,  2.67it/s]


[49/250] Loss: 3.35868376493454 	 Acc: 29.2
Validation loss decreased (-29.200000 --> -29.200000).  Saving model ...


282it [00:52,  5.35it/s]
4it [00:01,  2.59it/s]

[50/250] Loss: 3.3768895268440247 	 Acc: 26.2
EarlyStopping counter: 1 out of 10



282it [00:52,  5.34it/s]
4it [00:01,  2.61it/s]

[51/250] Loss: 3.2904568314552307 	 Acc: 29.0
EarlyStopping counter: 2 out of 10



282it [00:52,  5.35it/s]
4it [00:01,  2.64it/s]

[52/250] Loss: 3.325326144695282 	 Acc: 28.4
EarlyStopping counter: 3 out of 10



282it [00:52,  5.35it/s]
4it [00:01,  2.57it/s]


[53/250] Loss: 3.2904970049858093 	 Acc: 29.2
Validation loss decreased (-29.200000 --> -29.200000).  Saving model ...


282it [00:52,  5.35it/s]
4it [00:01,  2.66it/s]


[54/250] Loss: 3.2374303340911865 	 Acc: 29.4
Validation loss decreased (-29.200000 --> -29.400000).  Saving model ...


282it [00:52,  5.33it/s]
4it [00:01,  2.66it/s]

[55/250] Loss: 3.2869619727134705 	 Acc: 28.2
EarlyStopping counter: 1 out of 10



282it [00:52,  5.34it/s]
4it [00:01,  2.57it/s]


[56/250] Loss: 3.260919153690338 	 Acc: 30.4
Validation loss decreased (-29.400000 --> -30.400000).  Saving model ...


282it [00:52,  5.35it/s]
4it [00:01,  2.68it/s]


[57/250] Loss: 3.2286633253097534 	 Acc: 30.8
Validation loss decreased (-30.400000 --> -30.800000).  Saving model ...


282it [00:52,  5.34it/s]
4it [00:01,  2.68it/s]

[58/250] Loss: 3.229787588119507 	 Acc: 29.2
EarlyStopping counter: 1 out of 10



282it [00:52,  5.35it/s]
4it [00:01,  2.63it/s]


[59/250] Loss: 3.2167298197746277 	 Acc: 30.8
Validation loss decreased (-30.800000 --> -30.800000).  Saving model ...


282it [00:52,  5.34it/s]
4it [00:01,  2.65it/s]

[60/250] Loss: 3.246437966823578 	 Acc: 30.0
EarlyStopping counter: 1 out of 10



282it [00:52,  5.34it/s]
4it [00:01,  2.65it/s]

[61/250] Loss: 3.2769806385040283 	 Acc: 30.6
EarlyStopping counter: 2 out of 10



282it [00:52,  5.33it/s]
4it [00:01,  2.68it/s]


[62/250] Loss: 3.1864458322525024 	 Acc: 31.8
Validation loss decreased (-30.800000 --> -31.800000).  Saving model ...


282it [00:52,  5.34it/s]
4it [00:01,  2.65it/s]

[63/250] Loss: 3.2317771315574646 	 Acc: 31.0
EarlyStopping counter: 1 out of 10



282it [00:52,  5.34it/s]
4it [00:01,  2.62it/s]


[64/250] Loss: 3.2085373401641846 	 Acc: 32.4
Validation loss decreased (-31.800000 --> -32.400000).  Saving model ...


282it [00:52,  5.34it/s]
4it [00:01,  2.68it/s]

[65/250] Loss: 3.1708869338035583 	 Acc: 32.0
EarlyStopping counter: 1 out of 10



282it [00:52,  5.33it/s]
4it [00:01,  2.63it/s]

[66/250] Loss: 3.2512751817703247 	 Acc: 31.6
EarlyStopping counter: 2 out of 10



282it [00:52,  5.35it/s]
4it [00:01,  2.66it/s]

[67/250] Loss: 3.2462050318717957 	 Acc: 31.0
EarlyStopping counter: 3 out of 10



282it [00:52,  5.34it/s]
4it [00:01,  2.56it/s]

[68/250] Loss: 3.2519484162330627 	 Acc: 31.8
EarlyStopping counter: 4 out of 10



282it [00:52,  5.34it/s]
4it [00:01,  2.59it/s]


[69/250] Loss: 3.1621213555336 	 Acc: 32.8
Validation loss decreased (-32.400000 --> -32.800000).  Saving model ...


282it [00:52,  5.35it/s]
4it [00:01,  2.57it/s]


[70/250] Loss: 3.2184041142463684 	 Acc: 33.4
Validation loss decreased (-32.800000 --> -33.400000).  Saving model ...


282it [00:52,  5.34it/s]
4it [00:01,  2.63it/s]


[71/250] Loss: 3.1667208671569824 	 Acc: 35.4
Validation loss decreased (-33.400000 --> -35.400000).  Saving model ...


282it [00:52,  5.35it/s]
4it [00:01,  2.62it/s]

[72/250] Loss: 3.1725669503211975 	 Acc: 33.8
EarlyStopping counter: 1 out of 10



282it [00:52,  5.34it/s]
4it [00:01,  2.69it/s]

[73/250] Loss: 3.1508122086524963 	 Acc: 30.4
EarlyStopping counter: 2 out of 10



282it [00:52,  5.35it/s]
4it [00:01,  2.62it/s]

[74/250] Loss: 3.1426669359207153 	 Acc: 32.4
EarlyStopping counter: 3 out of 10



282it [00:52,  5.35it/s]
4it [00:01,  2.66it/s]


[75/250] Loss: 3.154318332672119 	 Acc: 36.0
Validation loss decreased (-35.400000 --> -36.000000).  Saving model ...


282it [00:52,  5.34it/s]
4it [00:01,  2.56it/s]

[76/250] Loss: 3.1968222856521606 	 Acc: 34.6
EarlyStopping counter: 1 out of 10



282it [00:52,  5.34it/s]
4it [00:01,  2.60it/s]

[77/250] Loss: 3.0797927379608154 	 Acc: 35.6
EarlyStopping counter: 2 out of 10



282it [00:52,  5.34it/s]
4it [00:01,  2.63it/s]

[78/250] Loss: 3.156620144844055 	 Acc: 33.0
EarlyStopping counter: 3 out of 10



282it [00:52,  5.34it/s]
4it [00:01,  2.60it/s]

[79/250] Loss: 3.165237307548523 	 Acc: 32.8
EarlyStopping counter: 4 out of 10



282it [00:52,  5.35it/s]
4it [00:01,  2.62it/s]

[80/250] Loss: 3.117622494697571 	 Acc: 35.0
EarlyStopping counter: 5 out of 10



282it [00:52,  5.34it/s]
4it [00:01,  2.59it/s]


[81/250] Loss: 3.0812231302261353 	 Acc: 37.4
Validation loss decreased (-36.000000 --> -37.400000).  Saving model ...


282it [00:52,  5.34it/s]
4it [00:01,  2.60it/s]

[82/250] Loss: 3.229897975921631 	 Acc: 33.4
EarlyStopping counter: 1 out of 10



282it [00:52,  5.35it/s]
4it [00:01,  2.53it/s]

[83/250] Loss: 3.118796944618225 	 Acc: 36.0
EarlyStopping counter: 2 out of 10



282it [00:52,  5.34it/s]
4it [00:01,  2.57it/s]

[84/250] Loss: 3.1215485334396362 	 Acc: 36.2
EarlyStopping counter: 3 out of 10



282it [00:52,  5.34it/s]
4it [00:01,  2.55it/s]

[85/250] Loss: 3.1352834701538086 	 Acc: 35.2
EarlyStopping counter: 4 out of 10



282it [00:52,  5.34it/s]
4it [00:01,  2.64it/s]

[86/250] Loss: 3.0586055517196655 	 Acc: 36.4
EarlyStopping counter: 5 out of 10



282it [00:52,  5.34it/s]
4it [00:01,  2.66it/s]

[87/250] Loss: 3.175398349761963 	 Acc: 36.8
EarlyStopping counter: 6 out of 10



282it [00:52,  5.34it/s]
4it [00:01,  2.65it/s]

[88/250] Loss: 3.1128631830215454 	 Acc: 35.6
EarlyStopping counter: 7 out of 10



282it [00:52,  5.34it/s]
4it [00:01,  2.68it/s]


[89/250] Loss: 3.0295092463493347 	 Acc: 39.6
Validation loss decreased (-37.400000 --> -39.600000).  Saving model ...


282it [00:52,  5.34it/s]
4it [00:01,  2.69it/s]

[90/250] Loss: 3.1590051651000977 	 Acc: 34.8
EarlyStopping counter: 1 out of 10



282it [00:52,  5.34it/s]
4it [00:01,  2.65it/s]

[91/250] Loss: 3.1574828028678894 	 Acc: 36.0
EarlyStopping counter: 2 out of 10



282it [00:52,  5.34it/s]
4it [00:01,  2.67it/s]

[92/250] Loss: 3.1662628650665283 	 Acc: 37.2
EarlyStopping counter: 3 out of 10



282it [00:52,  5.35it/s]
4it [00:01,  2.64it/s]


[93/250] Loss: 3.059181332588196 	 Acc: 39.6
Validation loss decreased (-39.600000 --> -39.600000).  Saving model ...


282it [00:52,  5.34it/s]
4it [00:01,  2.59it/s]

[94/250] Loss: 3.0813928842544556 	 Acc: 38.4
EarlyStopping counter: 1 out of 10



282it [00:52,  5.34it/s]
4it [00:01,  2.55it/s]

[95/250] Loss: 3.0703771710395813 	 Acc: 38.8
EarlyStopping counter: 2 out of 10



282it [00:52,  5.34it/s]
4it [00:01,  2.62it/s]

[96/250] Loss: 3.1295506358146667 	 Acc: 34.6
EarlyStopping counter: 3 out of 10



282it [00:52,  5.34it/s]
4it [00:01,  2.63it/s]

[97/250] Loss: 3.244458794593811 	 Acc: 34.6
EarlyStopping counter: 4 out of 10



282it [00:52,  5.35it/s]
4it [00:01,  2.61it/s]

[98/250] Loss: 3.067845404148102 	 Acc: 37.6
EarlyStopping counter: 5 out of 10



282it [00:52,  5.35it/s]
4it [00:01,  2.57it/s]

[99/250] Loss: 3.076171100139618 	 Acc: 38.2
EarlyStopping counter: 6 out of 10



282it [00:52,  5.35it/s]
4it [00:01,  2.59it/s]

[100/250] Loss: 3.1869722604751587 	 Acc: 37.2
EarlyStopping counter: 7 out of 10



282it [00:52,  5.36it/s]
4it [00:01,  2.54it/s]

[101/250] Loss: 3.157418727874756 	 Acc: 35.6
EarlyStopping counter: 8 out of 10



282it [00:52,  5.35it/s]
4it [00:01,  2.61it/s]

[102/250] Loss: 3.263918876647949 	 Acc: 35.4
EarlyStopping counter: 9 out of 10



282it [00:52,  5.35it/s]
4it [00:01,  2.69it/s]

[103/250] Loss: 3.1934134364128113 	 Acc: 37.4
EarlyStopping counter: 10 out of 10
Early stopping
Training finished !!





In [20]:
class medical_dataset_test(Dataset):
    def __init__(self, config=None, answer_map=None, image_path='./dataset/Task1-VQA-2021-TestSet-w-GroundTruth/VQA-500-Images/', qa_file="./dataset/Task1-VQA-2021-TestSet-w-GroundTruth/Task1-VQA-2021-TestSet-Questions.txt", train=True):
            
        assert answer_map!=None
        assert config!=None
        
        self.image_path = image_path
        self.qa_file = qa_file
        self.config = config
        self.train = train
        
        self.answer_map = answer_map
        self.data = pd.read_csv(qa_file, sep='|', names=['imageid', 'question'])
           
        # print(Counter(self.data.answer.tolist()))
        
        self.transforms = transforms.Compose([
                                transforms.RandomResizedCrop(224),
                                transforms.ToTensor(),
                                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                            ])
        self.tokenizer = tokenizer = AutoTokenizer.from_pretrained("dmis-lab/biobert-v1.1")

    def process_data(self, text, max_len):
        text = str(text)

        inputs = self.tokenizer.encode_plus(
            text,
            None,
            add_special_tokens=True,
            max_length=max_len,
            truncation=True
        )
        ids = inputs["input_ids"]
        mask = inputs["attention_mask"]
        
        padding_length = max_len - len(ids)
        
        ids = ids + ([0] * padding_length)
        mask = mask + ([0] * padding_length)
        
        return {
            'ids': torch.tensor(ids, dtype=torch.long),
            'mask': torch.tensor(mask, dtype=torch.long),
        }
    
    
    def __getitem__(self, index):
        
        question = self.data.question[index]
        image_idx = self.data.imageid[index]
        
        ## add visual feature related steps
        visuals = Image.open(join(self.image_path, f'{image_idx}.jpg'))
        visuals = self.transforms(visuals)
        
        
        tmp = self.process_data(question, self.config['max_len'])
        question_tokens = {
            'ids': tmp['ids'],
            'mask': tmp['mask'],
        }
        
        return visuals, question_tokens, image_idx

    def __len__(self):
        return len(self.data)

In [21]:
TestData = medical_dataset_test(config=cfg, answer_map=answer_map)
TestDataLoader = DataLoader(TestData, batch_size=128, shuffle=False, num_workers=4)  # num_workers=0 for windows OS

In [22]:
inv_map = {v: k for k, v in answer_map.items()}

In [23]:
trainer.model.load_state_dict(torch.load('./models/no_fusion/checkpoint.pt'))

<All keys matched successfully>

In [24]:
trainer.model.eval()
        
imageids = []
answers = []

running_loss = 0.0
correct = 0
total = 0
with torch.no_grad():
    for en, (visuals, question, target) in tqdm(enumerate(TestDataLoader)):
        visuals = visuals.to(trainer.device)

        ids = question['ids'].to(trainer.device)
        mask = question['mask'].to(trainer.device)

        outputs = trainer.model(visuals, ids, mask)

        y_pred_softmax = torch.log_softmax(outputs, dim = 1)
        _, y_pred_tags = torch.max(y_pred_softmax, dim = 1)
        # self.tmp_sv_ = y_pred_tags
        
        for i in range(visuals.shape[0]):
            imageids.append(target[i])
            answers.append(inv_map[int(y_pred_tags[i])])


4it [00:01,  2.39it/s]


In [25]:
answers[:5]

['femoral neck stress fractures',
 'pulmonary embolism',
 'cavernous hemangioma',
 'ovarian dermoid (cystic teratoma)',
 'traumatic transient lateral patellar dislocation.']

In [26]:
pd.DataFrame({'imageids':imageids, 'answers':answers}).to_csv('no_fusion.txt', sep='|', index=False, header=False)