In [1]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import warnings
warnings.filterwarnings("ignore")

In [3]:
import numpy as np
import sys
from os import listdir, makedirs, getcwd, remove
from os.path import isfile, join, abspath, exists, isdir, expanduser
from scipy.io import loadmat
from PIL import Image
import pandas as pd

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transforms


import pickle
from tqdm import tqdm
import json
import random
from collections import Counter

from transformers import get_linear_schedule_with_warmup
from transformers import AutoTokenizer, AutoModel

In [4]:
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.autograd as autograd
from torch.utils.data import Dataset, DataLoader
import torch.utils.data as data
import torch.nn.functional as F
import torch.optim as optim
import itertools

import torchvision
from torchvision import transforms, datasets, models
from torch import Tensor

from torch.utils.tensorboard import SummaryWriter

In [5]:
import transformers
import tokenizers
from transformers import BertTokenizer, BertModel

In [6]:
cfg = {
    'max_len': 128,
    'lr': 2e-5,
    'warmup_steps': 5,
    'epochs': 250
}

In [7]:
class medical_dataset(Dataset):
    def __init__(self, config=None, answer_map=None, image_path='./dataset/TrainingDataset/images/', qa_file="./dataset/TrainingDataset/training_qa.txt", train=True):
            
        assert answer_map!=None
        assert config!=None
        
        self.image_path = image_path
        self.qa_file = qa_file
        self.config = config
        self.train = train
        
        self.answer_map = answer_map
        self.data = pd.read_csv(qa_file, sep='|', names=['imageid', 'question', 'answer'])
           
        # print(Counter(self.data.answer.tolist()))
        
        self.transforms = transforms.Compose([
                                transforms.RandomResizedCrop(224),
                                transforms.ToTensor(),
                                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                            ])
        self.tokenizer = tokenizer = AutoTokenizer.from_pretrained("dmis-lab/biobert-v1.1")

    def process_data(self, text, max_len):
        text = str(text)

        inputs = self.tokenizer.encode_plus(
            text,
            None,
            add_special_tokens=True,
            max_length=max_len,
            truncation=True
        )
        ids = inputs["input_ids"]
        mask = inputs["attention_mask"]
        
        padding_length = max_len - len(ids)
        
        ids = ids + ([0] * padding_length)
        mask = mask + ([0] * padding_length)
        
        return {
            'ids': torch.tensor(ids, dtype=torch.long),
            'mask': torch.tensor(mask, dtype=torch.long),
        }
    
    
    def __getitem__(self, index):
        
        question = self.data.question[index]
        answer = self.data.answer[index]
        image_idx = self.data.imageid[index]
        
        ## add visual feature related steps
        visuals = Image.open(join(self.image_path, f'{image_idx}.jpg'))
        visuals = self.transforms(visuals)
        
        target = torch.from_numpy(np.array([self.answer_map[answer]])).long()
        
        tmp = self.process_data(question, self.config['max_len'])
        question_tokens = {
            'ids': tmp['ids'],
            'mask': tmp['mask'],
        }
        
        return visuals, question_tokens, target

    def __len__(self):
        return len(self.data)

In [8]:
# file_path = './dataset/TrainingDataset/training_qa.txt'
# data = pd.read_csv(file_path, sep='|', names=['imageid', 'question', 'answer'])
# data.head()

In [9]:
# answer_map = {}
# ct = 0

# for i in data.answer.unique():
#     answer_map[i] = ct
#     ct+=1

In [10]:
# with open('./answer_map.pickle', 'wb') as handle:
#     pickle.dump(answer_map, handle, protocol=pickle.HIGHEST_PROTOCOL)
    

In [11]:
with open('./answer_map.pickle', 'rb') as handle:
    answer_map = pickle.load(handle)
    
len(answer_map)

333

In [12]:
TrainData = medical_dataset(config=cfg, answer_map=answer_map, image_path='./dataset/VQA-Med-2020-Task1-VQAnswering-TrainVal-Sets/VQAMed2020-VQAnswering-TrainingSet/VQAnswering_2020_Train_images/', qa_file='./dataset/VQA-Med-2020-Task1-VQAnswering-TrainVal-Sets/VQAMed2020-VQAnswering-TrainingSet/VQAnswering_2020_Train_QA_pairs.txt')
ValData = medical_dataset(config=cfg, answer_map=answer_map, image_path='./dataset/VQA-Med-2020-Task1-VQAnswering-TrainVal-Sets/VQAMed2020-VQAnswering-ValidationSet/VQAnswering_2020_Val_images/', qa_file='./dataset/VQA-Med-2020-Task1-VQAnswering-TrainVal-Sets/VQAMed2020-VQAnswering-ValidationSet/VQAnswering_2020_Val_QA_Pairs.txt', train=False)

In [13]:
%%time
a,b,c = ValData.__getitem__(0)

CPU times: user 30.7 ms, sys: 0 ns, total: 30.7 ms
Wall time: 27.3 ms


In [14]:
TrainDataLoader = DataLoader(TrainData, batch_size=16, shuffle=True, num_workers=4)  # num_workers=0 for windows OS
ValDataLoader = DataLoader(ValData, batch_size=128, shuffle=False, num_workers=4)  # num_workers=0 for windows OS

In [17]:
class mednet(nn.Module):
    def __init__(self, config, max_labels):
        super(mednet, self).__init__()
        self.vision = models.vgg16(pretrained=True)
        num_ftrs = self.vision.classifier[-1].in_features
        
        self.vision.classifier = nn.Sequential(*list(self.vision.classifier.children())[:-1])
        
        self.fc1 = nn.Linear(num_ftrs, 128)
        self.fc2 = nn.Linear(128, max_labels)
        

    def forward(self, visual=None, ids=None, mask=None):
        h = self.vision(visual).view((visual.shape[0], -1))
        h = F.relu(self.fc1(h))
        h = self.fc2(h)
        return h



In [18]:
class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""
    def __init__(self,
                 patience=7,
                 verbose=False,
                 delta=0,
                 path='./models/vgg16/checkpoint.pt',
                 trace_func=print):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement. 
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): Path for the checkpoint to be saved to.
                            Default: 'checkpoint.pt'
            trace_func (function): trace print function.
                            Default: print            
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.Inf
        self.delta = delta
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):

        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            self.trace_func(
                f'EarlyStopping counter: {self.counter} out of {self.patience}'
            )
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            self.trace_func(
                f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...'
            )
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

In [19]:
class Trainer:
    def __init__(self,
                 trainloader,
                 vallaoder,
                 model_ft,
                 writer=None,
                 testloader=None,
                 checkpoint_path=None,
                 patience=10,
                 feature_extract=True,
                 print_itr=50,
                 config=None):
        self.trainloader = trainloader
        self.valloader = vallaoder
        self.testloader = testloader
        
        self.config=config

        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

        print("==" * 10)
        print("Training will be done on ", self.device)
        print("==" * 10)

        self.model = model_ft        
        if torch.cuda.device_count() > 1:
            print("Let's use", torch.cuda.device_count(), "GPUs!")
            self.model = nn.DataParallel(self.model, device_ids=[i for i in range(torch.cuda.device_count())])
        
        self.model = self.model.to(self.device)
        
        
        # Observe that all parameters are being optimized
        self.optimizer = optim.RAdam(self.model.parameters(), lr=self.config['lr'])
        
        self.scheduler = get_linear_schedule_with_warmup(self.optimizer, 
                                            num_warmup_steps = len(self.trainloader)*self.config['warmup_steps'], # Default value in run_glue.py
                                            num_training_steps = len(self.trainloader)*self.config['epochs'])

        
        self.criterion = nn.CrossEntropyLoss()
        self.early_stopping = EarlyStopping(patience=patience, verbose=True)
        self.writer = writer
        self.print_itr = print_itr

    def train(self, ep):
        self.model.train()

        running_loss = 0.0

        for en, (visuals, question, target) in tqdm(enumerate(self.trainloader)):
            self.optimizer.zero_grad()
            
            visuals = visuals.to(self.device)
            y = target.squeeze().to(self.device)
            
            outputs = self.model(visuals)
            loss = self.criterion(outputs, y)
            loss.backward()
            self.optimizer.step()
            self.scheduler.step()

            running_loss += loss.item()
            if self.writer:
                self.writer.add_scalar('Train Loss', running_loss, ep*len(self.trainloader) + en)
            running_loss = 0
            

    def validate(self, ep):
        self.model.eval()
        
        running_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for en, (visuals, question, target) in tqdm(enumerate(self.valloader)):
                visuals = visuals.to(self.device)
                y = target.squeeze().to(self.device)
                
                outputs = self.model(visuals)
                loss = self.criterion(outputs, y)

                y_pred_softmax = torch.log_softmax(outputs, dim = 1)
                _, y_pred_tags = torch.max(y_pred_softmax, dim = 1)
                # self.tmp_sv_ = y_pred_tags
                
                correct += (y_pred_tags.detach().cpu().data.numpy() == y.detach().cpu().data.numpy()).sum()
                total += y_pred_tags.shape[0]
                
                # print statistics
                running_loss += loss.item()
        
        
        return running_loss / len(self.valloader), correct*100/total

    def perform_training(self, total_epoch):
        val_loss, acc = self.validate(0)

        print("[Initial Validation results] Loss: {} \t Acc: {}".format(
            val_loss, acc))

        for i in range(total_epoch):
            self.train(i + 1)
            val_loss, acc = self.validate(i + 1)
            print('[{}/{}] Loss: {} \t Acc: {}'.format(i+1, total_epoch, val_loss, acc))

            if self.writer:
                self.writer.add_scalar('Validation Loss', val_loss, (i + 1))
                self.writer.add_scalar('Validation Acc', acc, (i + 1))

            self.early_stopping(-acc, self.model)

            if self.early_stopping.early_stop:
                print("Early stopping")
                break

        print("=" * 20)
        print("Training finished !!")
        print("=" * 20)


In [20]:
model_ft = mednet(config=cfg, max_labels=len(answer_map))
writer = SummaryWriter('runs/vgg16_run_1')
trainer = Trainer(TrainDataLoader, ValDataLoader, model_ft, writer=writer, config=cfg)

Training will be done on  cuda
Let's use 2 GPUs!


In [21]:
trainer.perform_training(cfg['epochs'])

4it [00:03,  1.00it/s]


[Initial Validation results] Loss: 5.801280856132507 	 Acc: 0.6


250it [00:46,  5.40it/s]
4it [00:01,  2.59it/s]


[1/250] Loss: 5.795615077018738 	 Acc: 0.0
Validation loss decreased (inf --> -0.000000).  Saving model ...


250it [00:46,  5.39it/s]
4it [00:01,  2.71it/s]


[2/250] Loss: 5.782869935035706 	 Acc: 0.6
Validation loss decreased (-0.000000 --> -0.600000).  Saving model ...


250it [00:46,  5.40it/s]
4it [00:01,  2.74it/s]


[3/250] Loss: 5.723097085952759 	 Acc: 2.2
Validation loss decreased (-0.600000 --> -2.200000).  Saving model ...


250it [00:46,  5.40it/s]
4it [00:01,  2.75it/s]


[4/250] Loss: 5.604703307151794 	 Acc: 3.2
Validation loss decreased (-2.200000 --> -3.200000).  Saving model ...


250it [00:46,  5.40it/s]
4it [00:01,  2.73it/s]


[5/250] Loss: 5.459057927131653 	 Acc: 5.2
Validation loss decreased (-3.200000 --> -5.200000).  Saving model ...


250it [00:46,  5.40it/s]
4it [00:01,  2.72it/s]


[6/250] Loss: 5.256566643714905 	 Acc: 7.8
Validation loss decreased (-5.200000 --> -7.800000).  Saving model ...


250it [00:46,  5.41it/s]
4it [00:01,  2.71it/s]


[7/250] Loss: 5.047289133071899 	 Acc: 9.6
Validation loss decreased (-7.800000 --> -9.600000).  Saving model ...


250it [00:46,  5.41it/s]
4it [00:01,  2.75it/s]


[8/250] Loss: 4.845716714859009 	 Acc: 11.8
Validation loss decreased (-9.600000 --> -11.800000).  Saving model ...


250it [00:46,  5.40it/s]
4it [00:01,  2.60it/s]


[9/250] Loss: 4.697169661521912 	 Acc: 12.8
Validation loss decreased (-11.800000 --> -12.800000).  Saving model ...


250it [00:46,  5.41it/s]
4it [00:01,  2.77it/s]


[10/250] Loss: 4.569134950637817 	 Acc: 13.8
Validation loss decreased (-12.800000 --> -13.800000).  Saving model ...


250it [00:46,  5.41it/s]
4it [00:01,  2.77it/s]


[11/250] Loss: 4.4537094831466675 	 Acc: 15.8
Validation loss decreased (-13.800000 --> -15.800000).  Saving model ...


250it [00:46,  5.41it/s]
4it [00:01,  2.67it/s]


[12/250] Loss: 4.390807271003723 	 Acc: 17.2
Validation loss decreased (-15.800000 --> -17.200000).  Saving model ...


250it [00:46,  5.41it/s]
4it [00:01,  2.72it/s]


[13/250] Loss: 4.309332489967346 	 Acc: 17.8
Validation loss decreased (-17.200000 --> -17.800000).  Saving model ...


250it [00:46,  5.41it/s]
4it [00:01,  2.72it/s]


[14/250] Loss: 4.163159549236298 	 Acc: 20.4
Validation loss decreased (-17.800000 --> -20.400000).  Saving model ...


250it [00:46,  5.41it/s]
4it [00:01,  2.67it/s]


[15/250] Loss: 4.089449048042297 	 Acc: 21.4
Validation loss decreased (-20.400000 --> -21.400000).  Saving model ...


250it [00:46,  5.41it/s]
4it [00:01,  2.73it/s]


[16/250] Loss: 4.010577499866486 	 Acc: 22.0
Validation loss decreased (-21.400000 --> -22.000000).  Saving model ...


250it [00:46,  5.41it/s]
4it [00:01,  2.61it/s]

[17/250] Loss: 3.9328706860542297 	 Acc: 21.8
EarlyStopping counter: 1 out of 10



250it [00:46,  5.41it/s]
4it [00:01,  2.69it/s]


[18/250] Loss: 3.997576951980591 	 Acc: 22.4
Validation loss decreased (-22.000000 --> -22.400000).  Saving model ...


250it [00:46,  5.42it/s]
4it [00:01,  2.73it/s]


[19/250] Loss: 3.9676690101623535 	 Acc: 24.2
Validation loss decreased (-22.400000 --> -24.200000).  Saving model ...


250it [00:46,  5.41it/s]
4it [00:01,  2.72it/s]

[20/250] Loss: 3.861342966556549 	 Acc: 23.4
EarlyStopping counter: 1 out of 10



250it [00:46,  5.41it/s]
4it [00:01,  2.70it/s]


[21/250] Loss: 3.860013961791992 	 Acc: 25.8
Validation loss decreased (-24.200000 --> -25.800000).  Saving model ...


250it [00:46,  5.41it/s]
4it [00:01,  2.72it/s]


[22/250] Loss: 3.704376995563507 	 Acc: 27.2
Validation loss decreased (-25.800000 --> -27.200000).  Saving model ...


250it [00:46,  5.41it/s]
4it [00:01,  2.73it/s]


[23/250] Loss: 3.6215683817863464 	 Acc: 27.8
Validation loss decreased (-27.200000 --> -27.800000).  Saving model ...


250it [00:46,  5.41it/s]
4it [00:01,  2.74it/s]

[24/250] Loss: 3.734498083591461 	 Acc: 27.4
EarlyStopping counter: 1 out of 10



250it [00:46,  5.42it/s]
4it [00:01,  2.66it/s]


[25/250] Loss: 3.7959346175193787 	 Acc: 29.0
Validation loss decreased (-27.800000 --> -29.000000).  Saving model ...


250it [00:46,  5.42it/s]
4it [00:01,  2.74it/s]


[26/250] Loss: 3.66642165184021 	 Acc: 30.6
Validation loss decreased (-29.000000 --> -30.600000).  Saving model ...


250it [00:46,  5.42it/s]
4it [00:01,  2.71it/s]

[27/250] Loss: 3.7510868906974792 	 Acc: 27.6
EarlyStopping counter: 1 out of 10



250it [00:46,  5.41it/s]
4it [00:01,  2.72it/s]

[28/250] Loss: 3.7046894431114197 	 Acc: 28.0
EarlyStopping counter: 2 out of 10



250it [00:46,  5.42it/s]
4it [00:01,  2.61it/s]


[29/250] Loss: 3.6873714923858643 	 Acc: 32.8
Validation loss decreased (-30.600000 --> -32.800000).  Saving model ...


250it [00:46,  5.41it/s]
4it [00:01,  2.64it/s]

[30/250] Loss: 3.742626190185547 	 Acc: 29.4
EarlyStopping counter: 1 out of 10



250it [00:46,  5.42it/s]
4it [00:01,  2.71it/s]

[31/250] Loss: 3.7021802067756653 	 Acc: 31.4
EarlyStopping counter: 2 out of 10



250it [00:46,  5.42it/s]
4it [00:01,  2.71it/s]


[32/250] Loss: 3.6712191700935364 	 Acc: 33.2
Validation loss decreased (-32.800000 --> -33.200000).  Saving model ...


250it [00:46,  5.41it/s]
4it [00:01,  2.73it/s]


[33/250] Loss: 3.676538825035095 	 Acc: 31.8
EarlyStopping counter: 1 out of 10


250it [00:46,  5.41it/s]
4it [00:01,  2.68it/s]


[34/250] Loss: 3.5738664865493774 	 Acc: 33.8
Validation loss decreased (-33.200000 --> -33.800000).  Saving model ...


250it [00:46,  5.41it/s]
4it [00:01,  2.70it/s]

[35/250] Loss: 3.856591522693634 	 Acc: 32.8
EarlyStopping counter: 1 out of 10



250it [00:46,  5.42it/s]
4it [00:01,  2.74it/s]


[36/250] Loss: 3.473080635070801 	 Acc: 34.6
Validation loss decreased (-33.800000 --> -34.600000).  Saving model ...


250it [00:46,  5.43it/s]
4it [00:01,  2.66it/s]

[37/250] Loss: 3.674593210220337 	 Acc: 33.2
EarlyStopping counter: 1 out of 10



250it [00:46,  5.42it/s]
4it [00:01,  2.73it/s]


[38/250] Loss: 3.5076833963394165 	 Acc: 36.4
Validation loss decreased (-34.600000 --> -36.400000).  Saving model ...


250it [00:46,  5.42it/s]
4it [00:01,  2.62it/s]

[39/250] Loss: 3.537635922431946 	 Acc: 35.0
EarlyStopping counter: 1 out of 10



250it [00:46,  5.42it/s]
4it [00:01,  2.66it/s]

[40/250] Loss: 3.7994144558906555 	 Acc: 34.4
EarlyStopping counter: 2 out of 10



250it [00:46,  5.42it/s]
4it [00:01,  2.71it/s]

[41/250] Loss: 3.8572824597358704 	 Acc: 32.8
EarlyStopping counter: 3 out of 10



250it [00:46,  5.43it/s]
4it [00:01,  2.72it/s]

[42/250] Loss: 3.7248722910881042 	 Acc: 34.0
EarlyStopping counter: 4 out of 10



250it [00:46,  5.43it/s]
4it [00:01,  2.76it/s]

[43/250] Loss: 3.6137933135032654 	 Acc: 35.2
EarlyStopping counter: 5 out of 10



250it [00:46,  5.43it/s]
4it [00:01,  2.70it/s]

[44/250] Loss: 3.7930745482444763 	 Acc: 35.0
EarlyStopping counter: 6 out of 10



250it [00:46,  5.43it/s]
4it [00:01,  2.72it/s]

[45/250] Loss: 3.6609623432159424 	 Acc: 35.6
EarlyStopping counter: 7 out of 10



250it [00:46,  5.42it/s]
4it [00:01,  2.69it/s]

[46/250] Loss: 3.9165297150611877 	 Acc: 34.8
EarlyStopping counter: 8 out of 10



250it [00:46,  5.42it/s]
4it [00:01,  2.70it/s]


[47/250] Loss: 3.8146116733551025 	 Acc: 38.2
Validation loss decreased (-36.400000 --> -38.200000).  Saving model ...


250it [00:46,  5.42it/s]
4it [00:01,  2.72it/s]

[48/250] Loss: 3.937797248363495 	 Acc: 33.8
EarlyStopping counter: 1 out of 10



250it [00:46,  5.43it/s]
4it [00:01,  2.72it/s]

[49/250] Loss: 3.845806658267975 	 Acc: 36.0
EarlyStopping counter: 2 out of 10



250it [00:46,  5.42it/s]
4it [00:01,  2.74it/s]

[50/250] Loss: 3.720928430557251 	 Acc: 36.2
EarlyStopping counter: 3 out of 10



250it [00:46,  5.43it/s]
4it [00:01,  2.73it/s]

[51/250] Loss: 3.8145565390586853 	 Acc: 34.8
EarlyStopping counter: 4 out of 10



250it [00:46,  5.43it/s]
4it [00:01,  2.73it/s]

[52/250] Loss: 3.7569339275360107 	 Acc: 36.0
EarlyStopping counter: 5 out of 10



250it [00:46,  5.43it/s]
4it [00:01,  2.70it/s]

[53/250] Loss: 3.9076655507087708 	 Acc: 36.8
EarlyStopping counter: 6 out of 10



250it [00:46,  5.43it/s]
4it [00:01,  2.71it/s]

[54/250] Loss: 3.8507161140441895 	 Acc: 34.4
EarlyStopping counter: 7 out of 10



250it [00:46,  5.43it/s]
4it [00:01,  2.72it/s]

[55/250] Loss: 3.8025211095809937 	 Acc: 36.8
EarlyStopping counter: 8 out of 10



250it [00:46,  5.43it/s]
4it [00:01,  2.72it/s]

[56/250] Loss: 3.997760832309723 	 Acc: 35.8
EarlyStopping counter: 9 out of 10



250it [00:46,  5.42it/s]
4it [00:01,  2.70it/s]

[57/250] Loss: 4.10494065284729 	 Acc: 35.2
EarlyStopping counter: 10 out of 10
Early stopping
Training finished !!





In [22]:
class medical_dataset_test(Dataset):
    def __init__(self, config=None, answer_map=None, image_path='./dataset/Task1-VQA-2021-TestSet-w-GroundTruth/VQA-500-Images/', qa_file="./dataset/Task1-VQA-2021-TestSet-w-GroundTruth/Task1-VQA-2021-TestSet-Questions.txt", train=True):
            
        assert answer_map!=None
        assert config!=None
        
        self.image_path = image_path
        self.qa_file = qa_file
        self.config = config
        self.train = train
        
        self.answer_map = answer_map
        self.data = pd.read_csv(qa_file, sep='|', names=['imageid', 'question'])
           
        # print(Counter(self.data.answer.tolist()))
        
        self.transforms = transforms.Compose([
                                transforms.RandomResizedCrop(224),
                                transforms.ToTensor(),
                                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                            ])
        self.tokenizer = tokenizer = AutoTokenizer.from_pretrained("dmis-lab/biobert-v1.1")

    def process_data(self, text, max_len):
        text = str(text)

        inputs = self.tokenizer.encode_plus(
            text,
            None,
            add_special_tokens=True,
            max_length=max_len,
            truncation=True
        )
        ids = inputs["input_ids"]
        mask = inputs["attention_mask"]
        
        padding_length = max_len - len(ids)
        
        ids = ids + ([0] * padding_length)
        mask = mask + ([0] * padding_length)
        
        return {
            'ids': torch.tensor(ids, dtype=torch.long),
            'mask': torch.tensor(mask, dtype=torch.long),
        }
    
    
    def __getitem__(self, index):
        
        question = self.data.question[index]
        image_idx = self.data.imageid[index]
        
        ## add visual feature related steps
        visuals = Image.open(join(self.image_path, f'{image_idx}.jpg'))
        visuals = self.transforms(visuals)
        
        
        tmp = self.process_data(question, self.config['max_len'])
        question_tokens = {
            'ids': tmp['ids'],
            'mask': tmp['mask'],
        }
        
        return visuals, question_tokens, image_idx

    def __len__(self):
        return len(self.data)

In [23]:
TestData = medical_dataset_test(config=cfg, answer_map=answer_map)
TestDataLoader = DataLoader(TestData, batch_size=128, shuffle=False, num_workers=4)  # num_workers=0 for windows OS

In [24]:
inv_map = {v: k for k, v in answer_map.items()}

In [25]:
trainer.model.load_state_dict(torch.load('./models/vgg16/checkpoint.pt'))

<All keys matched successfully>

In [26]:
trainer.model.eval()
        
imageids = []
answers = []

running_loss = 0.0
correct = 0
total = 0
with torch.no_grad():
    for en, (visuals, question, target) in tqdm(enumerate(TestDataLoader)):
        visuals = visuals.to(trainer.device)

        outputs = trainer.model(visuals)

        y_pred_softmax = torch.log_softmax(outputs, dim = 1)
        _, y_pred_tags = torch.max(y_pred_softmax, dim = 1)
        # self.tmp_sv_ = y_pred_tags
        
        for i in range(visuals.shape[0]):
            imageids.append(target[i])
            answers.append(inv_map[int(y_pred_tags[i])])


4it [00:01,  2.46it/s]


In [27]:
answers[:5]

['avascular necrosis of the femoral head',
 'pulmonary embolism',
 'rickets',
 'carotid artery dissection',
 'bucket handle meniscal tear of the knee']

In [28]:
pd.DataFrame({'imageids':imageids, 'answers':answers}).to_csv('vgg16.txt', sep='|', index=False, header=False)