In [1]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import warnings
warnings.filterwarnings("ignore")

In [3]:
import numpy as np
import sys
from os import listdir, makedirs, getcwd, remove
from os.path import isfile, join, abspath, exists, isdir, expanduser
from scipy.io import loadmat
from PIL import Image
import pandas as pd

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transforms


import pickle
from tqdm import tqdm
import json
import random
from collections import Counter

from transformers import get_linear_schedule_with_warmup
from transformers import AutoTokenizer, AutoModel

In [4]:
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.autograd as autograd
from torch.utils.data import Dataset, DataLoader
import torch.utils.data as data
import torch.nn.functional as F
import torch.optim as optim
import itertools

import torchvision
from torchvision import transforms, datasets, models
from torch import Tensor

from torch.utils.tensorboard import SummaryWriter

In [5]:
import transformers
import tokenizers
from transformers import BertTokenizer, BertModel

In [6]:
cfg = {
    'max_len': 128,
    'lr': 2e-5,
    'warmup_steps': 5,
    'epochs': 250
}

In [7]:
class medical_dataset(Dataset):
    def __init__(self, config=None, answer_map=None, image_path='./dataset/TrainingDataset/images/', qa_file="./dataset/TrainingDataset/training_qa.txt", train=True):
            
        assert answer_map!=None
        assert config!=None
        
        self.image_path = image_path
        self.qa_file = qa_file
        self.config = config
        self.train = train
        
        self.answer_map = answer_map
        self.data = pd.read_csv(qa_file, sep='|', names=['imageid', 'question', 'answer'])
           
        # print(Counter(self.data.answer.tolist()))
        
        self.transforms = transforms.Compose([
                                transforms.RandomResizedCrop(224),
                                transforms.ToTensor(),
                                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                            ])
        self.tokenizer = tokenizer = AutoTokenizer.from_pretrained("dmis-lab/biobert-v1.1")

    def process_data(self, text, max_len):
        text = str(text)

        inputs = self.tokenizer.encode_plus(
            text,
            None,
            add_special_tokens=True,
            max_length=max_len,
            truncation=True
        )
        ids = inputs["input_ids"]
        mask = inputs["attention_mask"]
        
        padding_length = max_len - len(ids)
        
        ids = ids + ([0] * padding_length)
        mask = mask + ([0] * padding_length)
        
        return {
            'ids': torch.tensor(ids, dtype=torch.long),
            'mask': torch.tensor(mask, dtype=torch.long),
        }
    
    
    def __getitem__(self, index):
        
        question = self.data.question[index]
        answer = self.data.answer[index]
        image_idx = self.data.imageid[index]
        
        ## add visual feature related steps
        visuals = Image.open(join(self.image_path, f'{image_idx}.jpg'))
        visuals = self.transforms(visuals)
        
        target = torch.from_numpy(np.array([self.answer_map[answer]])).long()
        
        tmp = self.process_data(question, self.config['max_len'])
        question_tokens = {
            'ids': tmp['ids'],
            'mask': tmp['mask'],
        }
        
        return visuals, question_tokens, target

    def __len__(self):
        return len(self.data)

In [8]:
# file_path = './dataset/TrainingDataset/training_qa.txt'
# data = pd.read_csv(file_path, sep='|', names=['imageid', 'question', 'answer'])
# data.head()

In [9]:
# answer_map = {}
# ct = 0

# for i in data.answer.unique():
#     answer_map[i] = ct
#     ct+=1

In [10]:
# with open('./answer_map.pickle', 'wb') as handle:
#     pickle.dump(answer_map, handle, protocol=pickle.HIGHEST_PROTOCOL)
    

In [11]:
with open('./answer_map.pickle', 'rb') as handle:
    answer_map = pickle.load(handle)
    
len(answer_map)

333

In [12]:
TrainData = medical_dataset(config=cfg, answer_map=answer_map, image_path='./dataset/VQA-Med-2020-Task1-VQAnswering-TrainVal-Sets/VQAMed2020-VQAnswering-TrainingSet/VQAnswering_2020_Train_images/', qa_file='./dataset/VQA-Med-2020-Task1-VQAnswering-TrainVal-Sets/VQAMed2020-VQAnswering-TrainingSet/VQAnswering_2020_Train_QA_pairs.txt')
ValData = medical_dataset(config=cfg, answer_map=answer_map, image_path='./dataset/VQA-Med-2020-Task1-VQAnswering-TrainVal-Sets/VQAMed2020-VQAnswering-ValidationSet/VQAnswering_2020_Val_images/', qa_file='./dataset/VQA-Med-2020-Task1-VQAnswering-TrainVal-Sets/VQAMed2020-VQAnswering-ValidationSet/VQAnswering_2020_Val_QA_Pairs.txt', train=False)

In [13]:
%%time
a,b,c = ValData.__getitem__(0)

CPU times: user 29.8 ms, sys: 0 ns, total: 29.8 ms
Wall time: 27.6 ms


In [14]:
TrainDataLoader = DataLoader(TrainData, batch_size=16, shuffle=True, num_workers=4)  # num_workers=0 for windows OS
ValDataLoader = DataLoader(ValData, batch_size=128, shuffle=False, num_workers=4)  # num_workers=0 for windows OS

In [15]:
class mednet(nn.Module):
    def __init__(self, config, max_labels):
        super(mednet, self).__init__()
        self.vision = models.resnet18(pretrained=True)
        num_ftrs = self.vision.fc.in_features
        
        self.vision = nn.Sequential(*list(self.vision.children())[:-1])        
        
        self.fc1 = nn.Linear(num_ftrs, 128)
        self.fc2 = nn.Linear(128, max_labels)
        

    def forward(self, visual=None, ids=None, mask=None):
        h = self.vision(visual).view((visual.shape[0], -1))        
        h = F.relu(self.fc1(h))
        h = self.fc2(h)
        return h


In [16]:
class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""
    def __init__(self,
                 patience=7,
                 verbose=False,
                 delta=0,
                 path='./models/checkpoint.pt',
                 trace_func=print):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement. 
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): Path for the checkpoint to be saved to.
                            Default: 'checkpoint.pt'
            trace_func (function): trace print function.
                            Default: print            
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.Inf
        self.delta = delta
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):

        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            self.trace_func(
                f'EarlyStopping counter: {self.counter} out of {self.patience}'
            )
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            self.trace_func(
                f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...'
            )
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

In [17]:
class Trainer:
    def __init__(self,
                 trainloader,
                 vallaoder,
                 model_ft,
                 writer=None,
                 testloader=None,
                 checkpoint_path=None,
                 patience=10,
                 feature_extract=True,
                 print_itr=50,
                 config=None):
        self.trainloader = trainloader
        self.valloader = vallaoder
        self.testloader = testloader
        
        self.config=config

        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

        print("==" * 10)
        print("Training will be done on ", self.device)
        print("==" * 10)

        self.model = model_ft        
        if torch.cuda.device_count() > 1:
            print("Let's use", torch.cuda.device_count(), "GPUs!")
            self.model = nn.DataParallel(self.model, device_ids=[i for i in range(torch.cuda.device_count())])
        
        self.model = self.model.to(self.device)
        
        
        # Observe that all parameters are being optimized
        self.optimizer = optim.RAdam(self.model.parameters(), lr=self.config['lr'])
        
        self.scheduler = get_linear_schedule_with_warmup(self.optimizer, 
                                            num_warmup_steps = len(self.trainloader)*self.config['warmup_steps'], # Default value in run_glue.py
                                            num_training_steps = len(self.trainloader)*self.config['epochs'])

        
        self.criterion = nn.CrossEntropyLoss()
        self.early_stopping = EarlyStopping(patience=patience, verbose=True)
        self.writer = writer
        self.print_itr = print_itr

    def train(self, ep):
        self.model.train()

        running_loss = 0.0

        for en, (visuals, question, target) in tqdm(enumerate(self.trainloader)):
            self.optimizer.zero_grad()
            
            visuals = visuals.to(self.device)
            y = target.squeeze().to(self.device)
            
            outputs = self.model(visuals)
            loss = self.criterion(outputs, y)
            loss.backward()
            self.optimizer.step()
            self.scheduler.step()

            running_loss += loss.item()
            if self.writer:
                self.writer.add_scalar('Train Loss', running_loss, ep*len(self.trainloader) + en)
            running_loss = 0
            

    def validate(self, ep):
        self.model.eval()
        
        running_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for en, (visuals, question, target) in tqdm(enumerate(self.valloader)):
                visuals = visuals.to(self.device)
                y = target.squeeze().to(self.device)
                
                outputs = self.model(visuals)
                loss = self.criterion(outputs, y)

                y_pred_softmax = torch.log_softmax(outputs, dim = 1)
                _, y_pred_tags = torch.max(y_pred_softmax, dim = 1)
                # self.tmp_sv_ = y_pred_tags
                
                correct += (y_pred_tags.detach().cpu().data.numpy() == y.detach().cpu().data.numpy()).sum()
                total += y_pred_tags.shape[0]
                
                # print statistics
                running_loss += loss.item()
        
        
        return running_loss / len(self.valloader), correct*100/total

    def perform_training(self, total_epoch):
        val_loss, acc = self.validate(0)

        print("[Initial Validation results] Loss: {} \t Acc: {}".format(
            val_loss, acc))

        for i in range(total_epoch):
            self.train(i + 1)
            val_loss, acc = self.validate(i + 1)
            print('[{}/{}] Loss: {} \t Acc: {}'.format(i+1, total_epoch, val_loss, acc))

            if self.writer:
                self.writer.add_scalar('Validation Loss', val_loss, (i + 1))
                self.writer.add_scalar('Validation Acc', acc, (i + 1))

            self.early_stopping(-acc, self.model)

            if self.early_stopping.early_stop:
                print("Early stopping")
                break

        print("=" * 20)
        print("Training finished !!")
        print("=" * 20)


In [18]:
model_ft = mednet(config=cfg, max_labels=len(answer_map))
writer = SummaryWriter('runs/resnet18_run_1')
trainer = Trainer(TrainDataLoader, ValDataLoader, model_ft, writer=writer, config=cfg)

Training will be done on  cuda
Let's use 2 GPUs!


In [19]:
trainer.perform_training(cfg['epochs'])

4it [00:03,  1.13it/s]

[Initial Validation results] Loss: 5.813371181488037 	 Acc: 0.6



250it [00:08, 30.68it/s]
4it [00:01,  3.63it/s]


[1/250] Loss: 5.801670432090759 	 Acc: 0.6
Validation loss decreased (inf --> -0.600000).  Saving model ...


250it [00:07, 31.72it/s]
4it [00:01,  3.77it/s]

[2/250] Loss: 5.7813321352005005 	 Acc: 0.2
EarlyStopping counter: 1 out of 10



250it [00:07, 31.39it/s]
4it [00:01,  3.67it/s]


[3/250] Loss: 5.748850226402283 	 Acc: 0.6
Validation loss decreased (-0.600000 --> -0.600000).  Saving model ...


250it [00:08, 31.22it/s]
4it [00:01,  3.67it/s]


[4/250] Loss: 5.704672455787659 	 Acc: 2.0
Validation loss decreased (-0.600000 --> -2.000000).  Saving model ...


250it [00:07, 31.53it/s]
4it [00:01,  3.77it/s]


[5/250] Loss: 5.635384559631348 	 Acc: 3.4
Validation loss decreased (-2.000000 --> -3.400000).  Saving model ...


250it [00:07, 31.39it/s]
4it [00:01,  3.63it/s]


[6/250] Loss: 5.528140306472778 	 Acc: 5.2
Validation loss decreased (-3.400000 --> -5.200000).  Saving model ...


250it [00:07, 31.82it/s]
4it [00:01,  3.75it/s]


[7/250] Loss: 5.426574110984802 	 Acc: 5.6
Validation loss decreased (-5.200000 --> -5.600000).  Saving model ...


250it [00:07, 31.32it/s]
4it [00:01,  3.64it/s]


[8/250] Loss: 5.252291202545166 	 Acc: 6.2
Validation loss decreased (-5.600000 --> -6.200000).  Saving model ...


250it [00:08, 31.11it/s]
4it [00:01,  3.81it/s]

[9/250] Loss: 5.149539589881897 	 Acc: 7.6
Validation loss decreased (-6.200000 --> -7.600000).  Saving model ...



250it [00:08, 31.21it/s]
4it [00:01,  3.78it/s]

[10/250] Loss: 5.029486179351807 	 Acc: 9.0
Validation loss decreased (-7.600000 --> -9.000000).  Saving model ...



250it [00:07, 31.26it/s]
4it [00:01,  3.64it/s]


[11/250] Loss: 4.9453946352005005 	 Acc: 9.2
Validation loss decreased (-9.000000 --> -9.200000).  Saving model ...


250it [00:08, 31.23it/s]
4it [00:01,  3.76it/s]


[12/250] Loss: 4.851119756698608 	 Acc: 9.6
Validation loss decreased (-9.200000 --> -9.600000).  Saving model ...


250it [00:07, 31.41it/s]
4it [00:01,  3.74it/s]

[13/250] Loss: 4.751638174057007 	 Acc: 12.8
Validation loss decreased (-9.600000 --> -12.800000).  Saving model ...



250it [00:07, 31.70it/s]
4it [00:01,  3.77it/s]

[14/250] Loss: 4.726839303970337 	 Acc: 11.6
EarlyStopping counter: 1 out of 10



250it [00:08, 31.15it/s]
4it [00:01,  3.70it/s]

[15/250] Loss: 4.624054193496704 	 Acc: 11.0
EarlyStopping counter: 2 out of 10



250it [00:07, 31.39it/s]
4it [00:01,  3.57it/s]


[16/250] Loss: 4.481132984161377 	 Acc: 13.8
Validation loss decreased (-12.800000 --> -13.800000).  Saving model ...


250it [00:08, 31.09it/s]
4it [00:01,  3.73it/s]


[17/250] Loss: 4.495366096496582 	 Acc: 13.8
Validation loss decreased (-13.800000 --> -13.800000).  Saving model ...


250it [00:08, 30.80it/s]
4it [00:01,  3.56it/s]

[18/250] Loss: 4.421104192733765 	 Acc: 14.6
Validation loss decreased (-13.800000 --> -14.600000).  Saving model ...



250it [00:08, 30.75it/s]
4it [00:01,  3.61it/s]


[19/250] Loss: 4.384242057800293 	 Acc: 15.8
Validation loss decreased (-14.600000 --> -15.800000).  Saving model ...


250it [00:08, 30.84it/s]
4it [00:01,  3.74it/s]

[20/250] Loss: 4.301176428794861 	 Acc: 15.0
EarlyStopping counter: 1 out of 10



250it [00:07, 31.28it/s]
4it [00:01,  3.77it/s]

[21/250] Loss: 4.241750478744507 	 Acc: 17.2
Validation loss decreased (-15.800000 --> -17.200000).  Saving model ...



250it [00:07, 31.46it/s]
4it [00:01,  3.55it/s]


[22/250] Loss: 4.185140073299408 	 Acc: 17.2
Validation loss decreased (-17.200000 --> -17.200000).  Saving model ...


250it [00:07, 31.36it/s]
4it [00:01,  3.75it/s]

[23/250] Loss: 4.143294036388397 	 Acc: 18.4
Validation loss decreased (-17.200000 --> -18.400000).  Saving model ...



250it [00:07, 31.37it/s]
4it [00:01,  3.70it/s]


[24/250] Loss: 4.088014543056488 	 Acc: 19.0
Validation loss decreased (-18.400000 --> -19.000000).  Saving model ...


250it [00:08, 31.16it/s]
4it [00:01,  3.74it/s]

[25/250] Loss: 4.070830702781677 	 Acc: 19.6
Validation loss decreased (-19.000000 --> -19.600000).  Saving model ...



250it [00:07, 31.59it/s]
4it [00:01,  3.77it/s]

[26/250] Loss: 4.036144554615021 	 Acc: 19.4
EarlyStopping counter: 1 out of 10



250it [00:08, 31.18it/s]
4it [00:01,  3.73it/s]

[27/250] Loss: 4.0021931529045105 	 Acc: 19.2
EarlyStopping counter: 2 out of 10



250it [00:08, 31.02it/s]
4it [00:01,  3.76it/s]

[28/250] Loss: 3.9284478425979614 	 Acc: 19.6
Validation loss decreased (-19.600000 --> -19.600000).  Saving model ...



250it [00:08, 30.91it/s]
4it [00:01,  3.60it/s]


[29/250] Loss: 3.9365466237068176 	 Acc: 20.6
Validation loss decreased (-19.600000 --> -20.600000).  Saving model ...


250it [00:07, 31.39it/s]
4it [00:01,  3.82it/s]

[30/250] Loss: 3.9243555068969727 	 Acc: 20.4
EarlyStopping counter: 1 out of 10



250it [00:08, 31.19it/s]
4it [00:01,  3.24it/s]

[31/250] Loss: 3.781720757484436 	 Acc: 22.2
Validation loss decreased (-20.600000 --> -22.200000).  Saving model ...



250it [00:08, 31.18it/s]
4it [00:01,  3.59it/s]

[32/250] Loss: 3.8512572646141052 	 Acc: 19.8
EarlyStopping counter: 1 out of 10



250it [00:07, 31.58it/s]
4it [00:01,  3.60it/s]


[33/250] Loss: 3.7770437002182007 	 Acc: 23.6
Validation loss decreased (-22.200000 --> -23.600000).  Saving model ...


250it [00:08, 30.84it/s]
4it [00:01,  3.61it/s]

[34/250] Loss: 3.7708261013031006 	 Acc: 22.8
EarlyStopping counter: 1 out of 10



250it [00:08, 31.16it/s]
4it [00:01,  3.70it/s]

[35/250] Loss: 3.706676661968231 	 Acc: 23.4
EarlyStopping counter: 2 out of 10



250it [00:08, 31.13it/s]
4it [00:01,  3.76it/s]

[36/250] Loss: 3.7423490285873413 	 Acc: 24.0
Validation loss decreased (-23.600000 --> -24.000000).  Saving model ...



250it [00:08, 31.21it/s]
4it [00:01,  3.61it/s]


[37/250] Loss: 3.6279128789901733 	 Acc: 25.2
Validation loss decreased (-24.000000 --> -25.200000).  Saving model ...


250it [00:07, 31.38it/s]
4it [00:01,  3.66it/s]

[38/250] Loss: 3.641744613647461 	 Acc: 24.2
EarlyStopping counter: 1 out of 10



250it [00:07, 31.59it/s]
4it [00:01,  3.56it/s]


[39/250] Loss: 3.6555274724960327 	 Acc: 23.6
EarlyStopping counter: 2 out of 10


250it [00:08, 31.22it/s]
4it [00:01,  3.77it/s]


[40/250] Loss: 3.626066744327545 	 Acc: 25.2
Validation loss decreased (-25.200000 --> -25.200000).  Saving model ...


250it [00:08, 31.07it/s]
4it [00:01,  3.55it/s]


[41/250] Loss: 3.5805962681770325 	 Acc: 26.4
Validation loss decreased (-25.200000 --> -26.400000).  Saving model ...


250it [00:07, 31.26it/s]
4it [00:01,  3.72it/s]

[42/250] Loss: 3.642753303050995 	 Acc: 25.4
EarlyStopping counter: 1 out of 10



250it [00:08, 30.93it/s]
4it [00:01,  3.74it/s]

[43/250] Loss: 3.640055477619171 	 Acc: 26.0
EarlyStopping counter: 2 out of 10



250it [00:07, 31.27it/s]
4it [00:01,  3.59it/s]

[44/250] Loss: 3.5360246896743774 	 Acc: 26.2
EarlyStopping counter: 3 out of 10



250it [00:08, 31.14it/s]
4it [00:01,  3.75it/s]

[45/250] Loss: 3.5841487646102905 	 Acc: 25.4
EarlyStopping counter: 4 out of 10



250it [00:07, 31.46it/s]
4it [00:01,  3.69it/s]

[46/250] Loss: 3.557334005832672 	 Acc: 27.6
Validation loss decreased (-26.400000 --> -27.600000).  Saving model ...



250it [00:08, 30.98it/s]
4it [00:01,  3.54it/s]


[47/250] Loss: 3.497935116291046 	 Acc: 28.0
Validation loss decreased (-27.600000 --> -28.000000).  Saving model ...


250it [00:08, 30.85it/s]
4it [00:01,  3.60it/s]


[48/250] Loss: 3.4093093872070312 	 Acc: 31.4
Validation loss decreased (-28.000000 --> -31.400000).  Saving model ...


250it [00:08, 30.86it/s]
4it [00:01,  3.55it/s]


[49/250] Loss: 3.440618872642517 	 Acc: 29.2
EarlyStopping counter: 1 out of 10


250it [00:08, 31.25it/s]
4it [00:01,  3.64it/s]

[50/250] Loss: 3.423649489879608 	 Acc: 28.6
EarlyStopping counter: 2 out of 10



250it [00:08, 30.77it/s]
4it [00:01,  3.55it/s]

[51/250] Loss: 3.4509515166282654 	 Acc: 29.6
EarlyStopping counter: 3 out of 10



250it [00:07, 31.43it/s]
4it [00:01,  3.71it/s]

[52/250] Loss: 3.441505551338196 	 Acc: 28.4
EarlyStopping counter: 4 out of 10



250it [00:07, 31.44it/s]
4it [00:01,  3.59it/s]


[53/250] Loss: 3.461735725402832 	 Acc: 28.6
EarlyStopping counter: 5 out of 10


250it [00:07, 31.39it/s]
4it [00:01,  3.54it/s]


[54/250] Loss: 3.4368790984153748 	 Acc: 29.0
EarlyStopping counter: 6 out of 10


250it [00:08, 30.49it/s]
4it [00:01,  3.67it/s]

[55/250] Loss: 3.433984637260437 	 Acc: 29.8
EarlyStopping counter: 7 out of 10



250it [00:07, 31.72it/s]
4it [00:01,  3.59it/s]

[56/250] Loss: 3.3580992221832275 	 Acc: 31.0
EarlyStopping counter: 8 out of 10



250it [00:07, 31.38it/s]
4it [00:01,  3.70it/s]

[57/250] Loss: 3.4080103039741516 	 Acc: 30.8
EarlyStopping counter: 9 out of 10



250it [00:08, 31.05it/s]
4it [00:01,  3.67it/s]

[58/250] Loss: 3.3122188448905945 	 Acc: 31.2
EarlyStopping counter: 10 out of 10
Early stopping
Training finished !!





In [20]:
class medical_dataset_test(Dataset):
    def __init__(self, config=None, answer_map=None, image_path='./dataset/Task1-VQA-2021-TestSet-w-GroundTruth/VQA-500-Images/', qa_file="./dataset/Task1-VQA-2021-TestSet-w-GroundTruth/Task1-VQA-2021-TestSet-Questions.txt", train=True):
            
        assert answer_map!=None
        assert config!=None
        
        self.image_path = image_path
        self.qa_file = qa_file
        self.config = config
        self.train = train
        
        self.answer_map = answer_map
        self.data = pd.read_csv(qa_file, sep='|', names=['imageid', 'question'])
           
        # print(Counter(self.data.answer.tolist()))
        
        self.transforms = transforms.Compose([
                                transforms.RandomResizedCrop(224),
                                transforms.ToTensor(),
                                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                            ])
        self.tokenizer = tokenizer = AutoTokenizer.from_pretrained("dmis-lab/biobert-v1.1")

    def process_data(self, text, max_len):
        text = str(text)

        inputs = self.tokenizer.encode_plus(
            text,
            None,
            add_special_tokens=True,
            max_length=max_len,
            truncation=True
        )
        ids = inputs["input_ids"]
        mask = inputs["attention_mask"]
        
        padding_length = max_len - len(ids)
        
        ids = ids + ([0] * padding_length)
        mask = mask + ([0] * padding_length)
        
        return {
            'ids': torch.tensor(ids, dtype=torch.long),
            'mask': torch.tensor(mask, dtype=torch.long),
        }
    
    
    def __getitem__(self, index):
        
        question = self.data.question[index]
        image_idx = self.data.imageid[index]
        
        ## add visual feature related steps
        visuals = Image.open(join(self.image_path, f'{image_idx}.jpg'))
        visuals = self.transforms(visuals)
        
        
        tmp = self.process_data(question, self.config['max_len'])
        question_tokens = {
            'ids': tmp['ids'],
            'mask': tmp['mask'],
        }
        
        return visuals, question_tokens, image_idx

    def __len__(self):
        return len(self.data)

In [21]:
TestData = medical_dataset_test(config=cfg, answer_map=answer_map)
TestDataLoader = DataLoader(TestData, batch_size=128, shuffle=False, num_workers=4)  # num_workers=0 for windows OS

In [22]:
inv_map = {v: k for k, v in answer_map.items()}

In [23]:
trainer.model.load_state_dict(torch.load('./models/checkpoint.pt'))

<All keys matched successfully>

In [24]:
trainer.model.eval()
        
imageids = []
answers = []

running_loss = 0.0
correct = 0
total = 0
with torch.no_grad():
    for en, (visuals, question, target) in tqdm(enumerate(TestDataLoader)):
        visuals = visuals.to(trainer.device)

        outputs = trainer.model(visuals)

        y_pred_softmax = torch.log_softmax(outputs, dim = 1)
        _, y_pred_tags = torch.max(y_pred_softmax, dim = 1)
        # self.tmp_sv_ = y_pred_tags
        
        for i in range(visuals.shape[0]):
            imageids.append(target[i])
            answers.append(inv_map[int(y_pred_tags[i])])


4it [00:01,  3.39it/s]


In [25]:
answers[:5]

['avascular necrosis of the femoral head',
 'pulmonary embolism',
 'rickets',
 'achalasia',
 'tibial plateau fracture']

In [26]:
pd.DataFrame({'imageids':imageids, 'answers':answers}).to_csv('resnet18.txt', sep='|', index=False, header=False)