In [1]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import warnings
warnings.filterwarnings("ignore")

In [3]:
import numpy as np
import sys
from os import listdir, makedirs, getcwd, remove
from os.path import isfile, join, abspath, exists, isdir, expanduser
from scipy.io import loadmat
from PIL import Image
import pandas as pd

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transforms


import pickle
from tqdm import tqdm
import json
import random
from collections import Counter

from transformers import get_linear_schedule_with_warmup
from transformers import AutoTokenizer, AutoModel

In [4]:
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.autograd as autograd
from torch.utils.data import Dataset, DataLoader
import torch.utils.data as data
import torch.nn.functional as F
import torch.optim as optim
import itertools

import torchvision
from torchvision import transforms, datasets, models
from torch import Tensor

from torch.utils.tensorboard import SummaryWriter

In [5]:
import transformers
import tokenizers
from transformers import BertTokenizer, BertModel

In [6]:
cfg = {
    'max_len': 128,
    'lr': 2e-5,
    'warmup_steps': 5,
    'epochs': 250
}

In [7]:
class medical_dataset(Dataset):
    def __init__(self, config=None, answer_map=None, image_path='./dataset/TrainingDataset/images/', qa_file="./dataset/TrainingDataset/training_qa.txt", train=True):
            
        assert answer_map!=None
        assert config!=None
        
        self.image_path = image_path
        self.qa_file = qa_file
        self.config = config
        self.train = train
        
        self.answer_map = answer_map
        self.data = pd.read_csv(qa_file, sep='|', names=['imageid', 'question', 'answer'])
           
        # print(Counter(self.data.answer.tolist()))
        
        self.transforms = transforms.Compose([
                                transforms.RandomResizedCrop(224),
                                transforms.ToTensor(),
                                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                            ])
        self.tokenizer = tokenizer = AutoTokenizer.from_pretrained("dmis-lab/biobert-v1.1")

    def process_data(self, text, max_len):
        text = str(text)

        inputs = self.tokenizer.encode_plus(
            text,
            None,
            add_special_tokens=True,
            max_length=max_len,
            truncation=True
        )
        ids = inputs["input_ids"]
        mask = inputs["attention_mask"]
        
        padding_length = max_len - len(ids)
        
        ids = ids + ([0] * padding_length)
        mask = mask + ([0] * padding_length)
        
        return {
            'ids': torch.tensor(ids, dtype=torch.long),
            'mask': torch.tensor(mask, dtype=torch.long),
        }
    
    
    def __getitem__(self, index):
        
        question = self.data.question[index]
        answer = self.data.answer[index]
        image_idx = self.data.imageid[index]
        
        ## add visual feature related steps
        visuals = Image.open(join(self.image_path, f'{image_idx}.jpg'))
        visuals = self.transforms(visuals)
        
        target = torch.from_numpy(np.array([self.answer_map[answer]])).long()
        
        tmp = self.process_data(question, self.config['max_len'])
        question_tokens = {
            'ids': tmp['ids'],
            'mask': tmp['mask'],
        }
        
        return visuals, question_tokens, target

    def __len__(self):
        return len(self.data)

In [8]:
# file_path = './dataset/TrainingDataset/training_qa.txt'
# data = pd.read_csv(file_path, sep='|', names=['imageid', 'question', 'answer'])
# data.head()

In [9]:
# answer_map = {}
# ct = 0

# for i in data.answer.unique():
#     answer_map[i] = ct
#     ct+=1

In [10]:
# with open('./answer_map.pickle', 'wb') as handle:
#     pickle.dump(answer_map, handle, protocol=pickle.HIGHEST_PROTOCOL)
    

In [11]:
with open('./answer_map.pickle', 'rb') as handle:
    answer_map = pickle.load(handle)
    
len(answer_map)

333

In [12]:
TrainData = medical_dataset(config=cfg, answer_map=answer_map, image_path='./dataset/VQA-Med-2020-Task1-VQAnswering-TrainVal-Sets/VQAMed2020-VQAnswering-TrainingSet/VQAnswering_2020_Train_images/', qa_file='./dataset/VQA-Med-2020-Task1-VQAnswering-TrainVal-Sets/VQAMed2020-VQAnswering-TrainingSet/VQAnswering_2020_Train_QA_pairs.txt')
ValData = medical_dataset(config=cfg, answer_map=answer_map, image_path='./dataset/VQA-Med-2020-Task1-VQAnswering-TrainVal-Sets/VQAMed2020-VQAnswering-ValidationSet/VQAnswering_2020_Val_images/', qa_file='./dataset/VQA-Med-2020-Task1-VQAnswering-TrainVal-Sets/VQAMed2020-VQAnswering-ValidationSet/VQAnswering_2020_Val_QA_Pairs.txt', train=False)

In [13]:
%%time
a,b,c = ValData.__getitem__(0)

CPU times: user 33.4 ms, sys: 0 ns, total: 33.4 ms
Wall time: 25.9 ms


In [14]:
TrainDataLoader = DataLoader(TrainData, batch_size=16, shuffle=True, num_workers=4)  # num_workers=0 for windows OS
ValDataLoader = DataLoader(ValData, batch_size=128, shuffle=False, num_workers=4)  # num_workers=0 for windows OS

In [15]:
class mednet(nn.Module):
    def __init__(self, config, max_labels):
        super(mednet, self).__init__()
        self.vision = models.vgg16(pretrained=True)
        num_ftrs = self.vision.classifier[-1].in_features
        
        self.vision.classifier = nn.Sequential(*list(self.vision.classifier.children())[:-1])
        self.bert = AutoModel.from_pretrained("dmis-lab/biobert-v1.1")
        
        
        self.fc1 = nn.Linear(num_ftrs+768, 128)
        self.fc2 = nn.Linear(128, max_labels)
        

    def forward(self, visual=None, ids=None, mask=None):
        vision = self.vision(visual).view((ids.shape[0], -1))
        bert_out = self.bert(ids, mask)
        
        h = torch.cat((vision, bert_out.last_hidden_state[:,0]), dim=1)
        
        h = F.relu(self.fc1(h))
        h = self.fc2(h)
        return h


In [16]:
class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""
    def __init__(self,
                 patience=7,
                 verbose=False,
                 delta=0,
                 path='./models/vgg16_no_fusion/checkpoint.pt',
                 trace_func=print):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement. 
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): Path for the checkpoint to be saved to.
                            Default: 'checkpoint.pt'
            trace_func (function): trace print function.
                            Default: print            
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.Inf
        self.delta = delta
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):

        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            self.trace_func(
                f'EarlyStopping counter: {self.counter} out of {self.patience}'
            )
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            self.trace_func(
                f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...'
            )
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

In [17]:
class Trainer:
    def __init__(self,
                 trainloader,
                 vallaoder,
                 model_ft,
                 writer=None,
                 testloader=None,
                 checkpoint_path=None,
                 patience=10,
                 feature_extract=True,
                 print_itr=50,
                 config=None):
        self.trainloader = trainloader
        self.valloader = vallaoder
        self.testloader = testloader
        
        self.config=config

        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

        print("==" * 10)
        print("Training will be done on ", self.device)
        print("==" * 10)

        self.model = model_ft        
        if torch.cuda.device_count() > 1:
            print("Let's use", torch.cuda.device_count(), "GPUs!")
            self.model = nn.DataParallel(self.model, device_ids=[i for i in range(torch.cuda.device_count())])
        
        self.model = self.model.to(self.device)
        
        
        # Observe that all parameters are being optimized
        self.optimizer = optim.RAdam(self.model.parameters(), lr=self.config['lr'])
        
        self.scheduler = get_linear_schedule_with_warmup(self.optimizer, 
                                            num_warmup_steps = len(self.trainloader)*self.config['warmup_steps'], # Default value in run_glue.py
                                            num_training_steps = len(self.trainloader)*self.config['epochs'])

        
        self.criterion = nn.CrossEntropyLoss()
        self.early_stopping = EarlyStopping(patience=patience, verbose=True)
        self.writer = writer
        self.print_itr = print_itr

    def train(self, ep):
        self.model.train()

        running_loss = 0.0

        for en, (visuals, question, target) in tqdm(enumerate(self.trainloader)):
            self.optimizer.zero_grad()
            
            visuals = visuals.to(self.device)
            y = target.squeeze().to(self.device)
            
            ids = question['ids'].to(self.device)
            mask = question['mask'].to(self.device)

            outputs = self.model(visuals, ids, mask)
            loss = self.criterion(outputs, y)
            loss.backward()
            self.optimizer.step()
            self.scheduler.step()

            running_loss += loss.item()
            if self.writer:
                self.writer.add_scalar('Train Loss', running_loss, ep*len(self.trainloader) + en)
            running_loss = 0
            

    def validate(self, ep):
        self.model.eval()
        
        running_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for en, (visuals, question, target) in tqdm(enumerate(self.valloader)):
                visuals = visuals.to(self.device)
                y = target.squeeze().to(self.device)
                
                ids = question['ids'].to(self.device)
                mask = question['mask'].to(self.device)
                
                outputs = self.model(visuals, ids, mask)
                loss = self.criterion(outputs, y)

                y_pred_softmax = torch.log_softmax(outputs, dim = 1)
                _, y_pred_tags = torch.max(y_pred_softmax, dim = 1)
                # self.tmp_sv_ = y_pred_tags
                
                correct += (y_pred_tags.detach().cpu().data.numpy() == y.detach().cpu().data.numpy()).sum()
                total += y_pred_tags.shape[0]
                
                # print statistics
                running_loss += loss.item()
        
        
        return running_loss / len(self.valloader), correct*100/total

    def perform_training(self, total_epoch):
        val_loss, acc = self.validate(0)

        print("[Initial Validation results] Loss: {} \t Acc: {}".format(
            val_loss, acc))

        for i in range(total_epoch):
            self.train(i + 1)
            val_loss, acc = self.validate(i + 1)
            print('[{}/{}] Loss: {} \t Acc: {}'.format(i+1, total_epoch, val_loss, acc))

            if self.writer:
                self.writer.add_scalar('Validation Loss', val_loss, (i + 1))
                self.writer.add_scalar('Validation Acc', acc, (i + 1))

            self.early_stopping(-acc, self.model)

            if self.early_stopping.early_stop:
                print("Early stopping")
                break

        print("=" * 20)
        print("Training finished !!")
        print("=" * 20)


In [18]:
model_ft = mednet(config=cfg, max_labels=len(answer_map))
writer = SummaryWriter('runs/vgg16_no_fusion_run_2')
trainer = Trainer(TrainDataLoader, ValDataLoader, model_ft, writer=writer, config=cfg)

Training will be done on  cuda
Let's use 2 GPUs!


In [19]:
trainer.perform_training(cfg['epochs'])

4it [00:04,  1.17s/it]

[Initial Validation results] Loss: 5.821583151817322 	 Acc: 0.4



250it [01:26,  2.91it/s]
4it [00:02,  1.93it/s]


[1/250] Loss: 5.8128708600997925 	 Acc: 0.2
Validation loss decreased (inf --> -0.200000).  Saving model ...


250it [01:26,  2.90it/s]
4it [00:02,  1.95it/s]

[2/250] Loss: 5.794021725654602 	 Acc: 0.0
EarlyStopping counter: 1 out of 10



250it [01:26,  2.90it/s]
4it [00:02,  1.98it/s]


[3/250] Loss: 5.756929397583008 	 Acc: 0.4
Validation loss decreased (-0.200000 --> -0.400000).  Saving model ...


250it [01:26,  2.90it/s]
4it [00:02,  1.98it/s]


[4/250] Loss: 5.643176436424255 	 Acc: 4.4
Validation loss decreased (-0.400000 --> -4.400000).  Saving model ...


250it [01:26,  2.90it/s]
4it [00:02,  1.96it/s]


[5/250] Loss: 5.428176522254944 	 Acc: 8.2
Validation loss decreased (-4.400000 --> -8.200000).  Saving model ...


250it [01:26,  2.90it/s]
4it [00:02,  2.00it/s]


[6/250] Loss: 5.172154068946838 	 Acc: 10.4
Validation loss decreased (-8.200000 --> -10.400000).  Saving model ...


250it [01:26,  2.90it/s]
4it [00:02,  1.99it/s]


[7/250] Loss: 4.991268873214722 	 Acc: 11.2
Validation loss decreased (-10.400000 --> -11.200000).  Saving model ...


250it [01:26,  2.90it/s]
4it [00:01,  2.01it/s]


[8/250] Loss: 4.78477156162262 	 Acc: 12.8
Validation loss decreased (-11.200000 --> -12.800000).  Saving model ...


250it [01:26,  2.90it/s]
4it [00:02,  1.96it/s]


[9/250] Loss: 4.552987217903137 	 Acc: 14.8
Validation loss decreased (-12.800000 --> -14.800000).  Saving model ...


250it [01:26,  2.90it/s]
4it [00:01,  2.00it/s]


[10/250] Loss: 4.433855891227722 	 Acc: 17.0
Validation loss decreased (-14.800000 --> -17.000000).  Saving model ...


250it [01:26,  2.90it/s]
4it [00:02,  1.99it/s]


[11/250] Loss: 4.328107595443726 	 Acc: 19.6
Validation loss decreased (-17.000000 --> -19.600000).  Saving model ...


250it [01:26,  2.90it/s]
4it [00:01,  2.00it/s]


[12/250] Loss: 4.1230034828186035 	 Acc: 20.0
Validation loss decreased (-19.600000 --> -20.000000).  Saving model ...


250it [01:26,  2.90it/s]
4it [00:02,  1.94it/s]


[13/250] Loss: 4.033986508846283 	 Acc: 21.0
Validation loss decreased (-20.000000 --> -21.000000).  Saving model ...


250it [01:26,  2.90it/s]
4it [00:02,  1.99it/s]


[14/250] Loss: 3.919830322265625 	 Acc: 24.2
Validation loss decreased (-21.000000 --> -24.200000).  Saving model ...


250it [01:26,  2.90it/s]
4it [00:02,  1.94it/s]


[15/250] Loss: 3.914784848690033 	 Acc: 24.4
Validation loss decreased (-24.200000 --> -24.400000).  Saving model ...


250it [01:26,  2.90it/s]
4it [00:01,  2.02it/s]

[16/250] Loss: 3.843086361885071 	 Acc: 23.4
EarlyStopping counter: 1 out of 10



250it [01:26,  2.90it/s]
4it [00:01,  2.01it/s]


[17/250] Loss: 3.7368473410606384 	 Acc: 26.0
Validation loss decreased (-24.400000 --> -26.000000).  Saving model ...


250it [01:26,  2.90it/s]
4it [00:02,  1.95it/s]


[18/250] Loss: 3.765804409980774 	 Acc: 26.8
Validation loss decreased (-26.000000 --> -26.800000).  Saving model ...


250it [01:26,  2.90it/s]
4it [00:01,  2.01it/s]


[19/250] Loss: 3.5827478170394897 	 Acc: 28.8
Validation loss decreased (-26.800000 --> -28.800000).  Saving model ...


250it [01:26,  2.90it/s]
4it [00:02,  1.98it/s]

[20/250] Loss: 3.5478389263153076 	 Acc: 28.0
EarlyStopping counter: 1 out of 10



250it [01:26,  2.90it/s]
4it [00:02,  1.99it/s]


[21/250] Loss: 3.537476062774658 	 Acc: 29.4
Validation loss decreased (-28.800000 --> -29.400000).  Saving model ...


250it [01:26,  2.90it/s]
4it [00:02,  1.98it/s]

[22/250] Loss: 3.5179251432418823 	 Acc: 28.6
EarlyStopping counter: 1 out of 10



250it [01:26,  2.90it/s]
4it [00:02,  1.95it/s]

[23/250] Loss: 3.5731672048568726 	 Acc: 28.8
EarlyStopping counter: 2 out of 10



250it [01:26,  2.90it/s]
4it [00:02,  1.95it/s]


[24/250] Loss: 3.633940041065216 	 Acc: 30.0
Validation loss decreased (-29.400000 --> -30.000000).  Saving model ...


250it [01:26,  2.91it/s]
4it [00:02,  1.95it/s]


[25/250] Loss: 3.591951608657837 	 Acc: 31.0
Validation loss decreased (-30.000000 --> -31.000000).  Saving model ...


250it [01:26,  2.91it/s]
4it [00:02,  1.95it/s]


[26/250] Loss: 3.5046640038490295 	 Acc: 32.4
Validation loss decreased (-31.000000 --> -32.400000).  Saving model ...


250it [01:26,  2.90it/s]
4it [00:01,  2.01it/s]

[27/250] Loss: 3.4443984031677246 	 Acc: 31.6
EarlyStopping counter: 1 out of 10



250it [01:26,  2.90it/s]
4it [00:02,  1.98it/s]

[28/250] Loss: 3.4721547961235046 	 Acc: 31.2
EarlyStopping counter: 2 out of 10



250it [01:26,  2.91it/s]
4it [00:02,  1.97it/s]


[29/250] Loss: 3.4993982911109924 	 Acc: 34.2
Validation loss decreased (-32.400000 --> -34.200000).  Saving model ...


250it [01:26,  2.91it/s]
4it [00:02,  1.99it/s]

[30/250] Loss: 3.5990793108940125 	 Acc: 31.6
EarlyStopping counter: 1 out of 10



250it [01:26,  2.91it/s]
4it [00:02,  1.96it/s]

[31/250] Loss: 3.6945667266845703 	 Acc: 34.0
EarlyStopping counter: 2 out of 10



250it [01:26,  2.91it/s]
4it [00:02,  1.99it/s]


[32/250] Loss: 3.4540164470672607 	 Acc: 34.8
Validation loss decreased (-34.200000 --> -34.800000).  Saving model ...


250it [01:25,  2.91it/s]
4it [00:02,  1.99it/s]

[33/250] Loss: 3.463267147541046 	 Acc: 34.6
EarlyStopping counter: 1 out of 10



250it [01:26,  2.91it/s]
4it [00:02,  1.99it/s]


[34/250] Loss: 3.5164600610733032 	 Acc: 35.0
Validation loss decreased (-34.800000 --> -35.000000).  Saving model ...


250it [01:25,  2.91it/s]
4it [00:02,  1.96it/s]


[35/250] Loss: 3.6119596362113953 	 Acc: 36.8
Validation loss decreased (-35.000000 --> -36.800000).  Saving model ...


250it [01:25,  2.91it/s]
4it [00:02,  1.97it/s]

[36/250] Loss: 3.5207518935203552 	 Acc: 35.6
EarlyStopping counter: 1 out of 10



250it [01:26,  2.91it/s]
4it [00:02,  1.95it/s]


[37/250] Loss: 3.525190591812134 	 Acc: 38.0
Validation loss decreased (-36.800000 --> -38.000000).  Saving model ...


250it [01:25,  2.91it/s]
4it [00:02,  1.96it/s]

[38/250] Loss: 3.437589645385742 	 Acc: 35.2
EarlyStopping counter: 1 out of 10



250it [01:26,  2.91it/s]
4it [00:02,  2.00it/s]

[39/250] Loss: 3.6659525632858276 	 Acc: 34.2
EarlyStopping counter: 2 out of 10



250it [01:26,  2.91it/s]
4it [00:02,  1.98it/s]

[40/250] Loss: 3.5529873967170715 	 Acc: 37.6
EarlyStopping counter: 3 out of 10



250it [01:26,  2.90it/s]
4it [00:02,  1.96it/s]

[41/250] Loss: 3.3842508792877197 	 Acc: 37.6
EarlyStopping counter: 4 out of 10



250it [01:26,  2.91it/s]
4it [00:02,  1.95it/s]

[42/250] Loss: 3.3607234358787537 	 Acc: 37.4
EarlyStopping counter: 5 out of 10



250it [01:26,  2.91it/s]
4it [00:02,  1.98it/s]


[43/250] Loss: 3.495938241481781 	 Acc: 38.0
Validation loss decreased (-38.000000 --> -38.000000).  Saving model ...


250it [01:26,  2.91it/s]
4it [00:01,  2.02it/s]

[44/250] Loss: 3.4855316281318665 	 Acc: 36.6
EarlyStopping counter: 1 out of 10



250it [01:25,  2.91it/s]
4it [00:02,  1.97it/s]

[45/250] Loss: 3.3168144822120667 	 Acc: 36.8
EarlyStopping counter: 2 out of 10



250it [01:25,  2.91it/s]
4it [00:01,  2.00it/s]

[46/250] Loss: 3.6242886781692505 	 Acc: 35.6
EarlyStopping counter: 3 out of 10



250it [01:26,  2.91it/s]
4it [00:02,  1.96it/s]

[47/250] Loss: 3.701004445552826 	 Acc: 36.4
EarlyStopping counter: 4 out of 10



250it [01:26,  2.91it/s]
4it [00:02,  1.99it/s]

[48/250] Loss: 3.5695695281028748 	 Acc: 37.0
EarlyStopping counter: 5 out of 10



250it [01:26,  2.91it/s]
4it [00:02,  1.99it/s]

[49/250] Loss: 3.5949472188949585 	 Acc: 35.8
EarlyStopping counter: 6 out of 10



250it [01:26,  2.91it/s]
4it [00:02,  2.00it/s]

[50/250] Loss: 3.6537726521492004 	 Acc: 35.8
EarlyStopping counter: 7 out of 10



250it [01:26,  2.91it/s]
4it [00:01,  2.01it/s]


[51/250] Loss: 3.5838794112205505 	 Acc: 39.0
Validation loss decreased (-38.000000 --> -39.000000).  Saving model ...


250it [01:25,  2.91it/s]
4it [00:02,  2.00it/s]

[52/250] Loss: 3.5274888277053833 	 Acc: 38.8
EarlyStopping counter: 1 out of 10



250it [01:25,  2.91it/s]
4it [00:01,  2.02it/s]


[53/250] Loss: 3.5711973309516907 	 Acc: 39.4
Validation loss decreased (-39.000000 --> -39.400000).  Saving model ...


250it [01:25,  2.91it/s]
4it [00:02,  1.99it/s]

[54/250] Loss: 3.7550923824310303 	 Acc: 36.2
EarlyStopping counter: 1 out of 10



250it [01:25,  2.91it/s]
4it [00:02,  2.00it/s]

[55/250] Loss: 3.58802992105484 	 Acc: 38.8
EarlyStopping counter: 2 out of 10



250it [01:25,  2.91it/s]
4it [00:02,  1.97it/s]

[56/250] Loss: 3.4929869174957275 	 Acc: 38.6
EarlyStopping counter: 3 out of 10



250it [01:25,  2.91it/s]
4it [00:01,  2.00it/s]


[57/250] Loss: 3.698054254055023 	 Acc: 39.6
Validation loss decreased (-39.400000 --> -39.600000).  Saving model ...


250it [01:26,  2.90it/s]
4it [00:01,  2.02it/s]


[58/250] Loss: 3.5017515420913696 	 Acc: 41.2
Validation loss decreased (-39.600000 --> -41.200000).  Saving model ...


250it [01:26,  2.91it/s]
4it [00:02,  1.99it/s]

[59/250] Loss: 3.5573434233665466 	 Acc: 38.6
EarlyStopping counter: 1 out of 10



250it [01:26,  2.91it/s]
4it [00:02,  1.94it/s]

[60/250] Loss: 3.7585140466690063 	 Acc: 38.0
EarlyStopping counter: 2 out of 10



250it [01:26,  2.90it/s]
4it [00:01,  2.00it/s]

[61/250] Loss: 3.775320529937744 	 Acc: 37.8
EarlyStopping counter: 3 out of 10



250it [01:25,  2.91it/s]
4it [00:01,  2.01it/s]


[62/250] Loss: 3.767213761806488 	 Acc: 42.6
Validation loss decreased (-41.200000 --> -42.600000).  Saving model ...


250it [01:26,  2.90it/s]
4it [00:02,  1.95it/s]

[63/250] Loss: 3.5945538878440857 	 Acc: 39.8
EarlyStopping counter: 1 out of 10



250it [01:26,  2.90it/s]
4it [00:01,  2.01it/s]

[64/250] Loss: 3.6432072520256042 	 Acc: 39.0
EarlyStopping counter: 2 out of 10



250it [01:25,  2.91it/s]
4it [00:02,  1.98it/s]

[65/250] Loss: 3.5696439146995544 	 Acc: 38.8
EarlyStopping counter: 3 out of 10



250it [01:26,  2.91it/s]
4it [00:02,  1.97it/s]

[66/250] Loss: 3.716792583465576 	 Acc: 39.6
EarlyStopping counter: 4 out of 10



250it [01:25,  2.91it/s]
4it [00:01,  2.01it/s]

[67/250] Loss: 3.3661288619041443 	 Acc: 41.2
EarlyStopping counter: 5 out of 10



250it [01:26,  2.91it/s]
4it [00:02,  1.97it/s]

[68/250] Loss: 3.6482794880867004 	 Acc: 39.2
EarlyStopping counter: 6 out of 10



250it [01:25,  2.91it/s]
4it [00:02,  1.94it/s]

[69/250] Loss: 3.77021723985672 	 Acc: 40.4
EarlyStopping counter: 7 out of 10



250it [01:26,  2.90it/s]
4it [00:02,  1.97it/s]


[70/250] Loss: 3.5657076835632324 	 Acc: 42.6
Validation loss decreased (-42.600000 --> -42.600000).  Saving model ...


250it [01:26,  2.91it/s]
4it [00:02,  1.97it/s]

[71/250] Loss: 3.7159006595611572 	 Acc: 37.4
EarlyStopping counter: 1 out of 10



250it [01:26,  2.90it/s]
4it [00:01,  2.01it/s]

[72/250] Loss: 3.779899001121521 	 Acc: 38.8
EarlyStopping counter: 2 out of 10



250it [01:26,  2.91it/s]
4it [00:02,  1.99it/s]

[73/250] Loss: 4.025850296020508 	 Acc: 39.0
EarlyStopping counter: 3 out of 10



250it [01:26,  2.90it/s]
4it [00:01,  2.01it/s]

[74/250] Loss: 3.742052972316742 	 Acc: 40.4
EarlyStopping counter: 4 out of 10



250it [01:26,  2.91it/s]
4it [00:02,  1.95it/s]

[75/250] Loss: 3.718210995197296 	 Acc: 40.2
EarlyStopping counter: 5 out of 10



250it [01:26,  2.91it/s]
4it [00:01,  2.00it/s]

[76/250] Loss: 3.6407808661460876 	 Acc: 41.2
EarlyStopping counter: 6 out of 10



250it [01:26,  2.91it/s]
4it [00:02,  1.99it/s]

[77/250] Loss: 3.603325128555298 	 Acc: 40.6
EarlyStopping counter: 7 out of 10



250it [01:26,  2.91it/s]
4it [00:02,  1.93it/s]


[78/250] Loss: 3.8527262806892395 	 Acc: 42.8
Validation loss decreased (-42.600000 --> -42.800000).  Saving model ...


250it [01:26,  2.91it/s]
4it [00:02,  1.97it/s]

[79/250] Loss: 3.7266843914985657 	 Acc: 40.6
EarlyStopping counter: 1 out of 10



250it [01:26,  2.91it/s]
4it [00:02,  1.99it/s]

[80/250] Loss: 3.543361246585846 	 Acc: 42.2
EarlyStopping counter: 2 out of 10



250it [01:26,  2.91it/s]
4it [00:02,  1.99it/s]

[81/250] Loss: 3.8218082785606384 	 Acc: 40.4
EarlyStopping counter: 3 out of 10



250it [01:26,  2.90it/s]
4it [00:02,  2.00it/s]

[82/250] Loss: 3.6882490515708923 	 Acc: 39.4
EarlyStopping counter: 4 out of 10



250it [01:26,  2.90it/s]
4it [00:02,  1.98it/s]

[83/250] Loss: 3.966365337371826 	 Acc: 39.0
EarlyStopping counter: 5 out of 10



250it [01:26,  2.91it/s]
4it [00:02,  1.95it/s]

[84/250] Loss: 3.8190112709999084 	 Acc: 39.8
EarlyStopping counter: 6 out of 10



250it [01:26,  2.91it/s]
4it [00:02,  1.94it/s]

[85/250] Loss: 3.725303292274475 	 Acc: 38.4
EarlyStopping counter: 7 out of 10



250it [01:26,  2.91it/s]
4it [00:01,  2.00it/s]

[86/250] Loss: 3.888101577758789 	 Acc: 38.2
EarlyStopping counter: 8 out of 10



250it [01:26,  2.91it/s]
4it [00:02,  2.00it/s]

[87/250] Loss: 3.968602478504181 	 Acc: 40.4
EarlyStopping counter: 9 out of 10



250it [01:26,  2.90it/s]
4it [00:02,  2.00it/s]

[88/250] Loss: 3.9875511527061462 	 Acc: 39.0
EarlyStopping counter: 10 out of 10
Early stopping
Training finished !!





In [20]:
class medical_dataset_test(Dataset):
    def __init__(self, config=None, answer_map=None, image_path='./dataset/Task1-VQA-2021-TestSet-w-GroundTruth/VQA-500-Images/', qa_file="./dataset/Task1-VQA-2021-TestSet-w-GroundTruth/Task1-VQA-2021-TestSet-Questions.txt", train=True):
            
        assert answer_map!=None
        assert config!=None
        
        self.image_path = image_path
        self.qa_file = qa_file
        self.config = config
        self.train = train
        
        self.answer_map = answer_map
        self.data = pd.read_csv(qa_file, sep='|', names=['imageid', 'question'])
           
        # print(Counter(self.data.answer.tolist()))
        
        self.transforms = transforms.Compose([
                                transforms.RandomResizedCrop(224),
                                transforms.ToTensor(),
                                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                            ])
        self.tokenizer = tokenizer = AutoTokenizer.from_pretrained("dmis-lab/biobert-v1.1")

    def process_data(self, text, max_len):
        text = str(text)

        inputs = self.tokenizer.encode_plus(
            text,
            None,
            add_special_tokens=True,
            max_length=max_len,
            truncation=True
        )
        ids = inputs["input_ids"]
        mask = inputs["attention_mask"]
        
        padding_length = max_len - len(ids)
        
        ids = ids + ([0] * padding_length)
        mask = mask + ([0] * padding_length)
        
        return {
            'ids': torch.tensor(ids, dtype=torch.long),
            'mask': torch.tensor(mask, dtype=torch.long),
        }
    
    
    def __getitem__(self, index):
        
        question = self.data.question[index]
        image_idx = self.data.imageid[index]
        
        ## add visual feature related steps
        visuals = Image.open(join(self.image_path, f'{image_idx}.jpg'))
        visuals = self.transforms(visuals)
        
        
        tmp = self.process_data(question, self.config['max_len'])
        question_tokens = {
            'ids': tmp['ids'],
            'mask': tmp['mask'],
        }
        
        return visuals, question_tokens, image_idx

    def __len__(self):
        return len(self.data)

In [21]:
TestData = medical_dataset_test(config=cfg, answer_map=answer_map)
TestDataLoader = DataLoader(TestData, batch_size=128, shuffle=False, num_workers=4)  # num_workers=0 for windows OS

In [22]:
inv_map = {v: k for k, v in answer_map.items()}

In [23]:
trainer.model.load_state_dict(torch.load('./models/vgg16_no_fusion/checkpoint.pt'))

<All keys matched successfully>

In [24]:
trainer.model.eval()
        
imageids = []
answers = []

running_loss = 0.0
correct = 0
total = 0
with torch.no_grad():
    for en, (visuals, question, target) in tqdm(enumerate(TestDataLoader)):
        visuals = visuals.to(trainer.device)

        ids = question['ids'].to(trainer.device)
        mask = question['mask'].to(trainer.device)

        outputs = trainer.model(visuals, ids, mask)

        y_pred_softmax = torch.log_softmax(outputs, dim = 1)
        _, y_pred_tags = torch.max(y_pred_softmax, dim = 1)
        # self.tmp_sv_ = y_pred_tags
        
        for i in range(visuals.shape[0]):
            imageids.append(target[i])
            answers.append(inv_map[int(y_pred_tags[i])])


4it [00:02,  1.92it/s]


In [25]:
answers[:5]

['avascular necrosis of the femoral head',
 'pulmonary embolism',
 'torus fracture/buckle fracture of the distal radius',
 'common carotid occlusion, collateral reconstitution of ipsilateral ica from vertebral artery muscular branches',
 'traumatic transient lateral patellar dislocation.']

In [26]:
pd.DataFrame({'imageids':imageids, 'answers':answers}).to_csv('vgg16_no_fusion.txt', sep='|', index=False, header=False)