In [1]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import warnings
warnings.filterwarnings("ignore")

In [3]:
import numpy as np
import sys
from os import listdir, makedirs, getcwd, remove
from os.path import isfile, join, abspath, exists, isdir, expanduser
from scipy.io import loadmat
from PIL import Image
import pandas as pd

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transforms


import pickle
from tqdm import tqdm
import json
import random
from collections import Counter

from transformers import get_linear_schedule_with_warmup
from transformers import AutoTokenizer, AutoModel

In [4]:
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.autograd as autograd
from torch.utils.data import Dataset, DataLoader
import torch.utils.data as data
import torch.nn.functional as F
import torch.optim as optim
import itertools

import torchvision
from torchvision import transforms, datasets, models
from torch import Tensor

from torch.utils.tensorboard import SummaryWriter

In [5]:
import transformers
import tokenizers
from transformers import BertTokenizer, BertModel

In [6]:
cfg = {
    'max_len': 128,
    'lr': 2e-5,
    'warmup_steps': 5,
    'epochs': 250
}

In [7]:
class medical_dataset(Dataset):
    def __init__(self, config=None, answer_map=None, image_path='./dataset/TrainingDataset/images/', qa_file="./dataset/TrainingDataset/training_qa.txt", train=True):
            
        assert answer_map!=None
        assert config!=None
        
        self.image_path = image_path
        self.qa_file = qa_file
        self.config = config
        self.train = train
        
        self.answer_map = answer_map
        self.data = pd.read_csv(qa_file, sep='|', names=['imageid', 'question', 'answer'])
           
        # print(Counter(self.data.answer.tolist()))
        
        self.transforms = transforms.Compose([
                                transforms.RandomResizedCrop(224),
                                transforms.ToTensor(),
                                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                            ])
        self.tokenizer = tokenizer = AutoTokenizer.from_pretrained("dmis-lab/biobert-v1.1")

    def process_data(self, text, max_len):
        text = str(text)

        inputs = self.tokenizer.encode_plus(
            text,
            None,
            add_special_tokens=True,
            max_length=max_len,
            truncation=True
        )
        ids = inputs["input_ids"]
        mask = inputs["attention_mask"]
        
        padding_length = max_len - len(ids)
        
        ids = ids + ([0] * padding_length)
        mask = mask + ([0] * padding_length)
        
        return {
            'ids': torch.tensor(ids, dtype=torch.long),
            'mask': torch.tensor(mask, dtype=torch.long),
        }
    
    
    def __getitem__(self, index):
        
        question = self.data.question[index]
        answer = self.data.answer[index]
        image_idx = self.data.imageid[index]
        
        ## add visual feature related steps
        visuals = Image.open(join(self.image_path, f'{image_idx}.jpg'))
        visuals = self.transforms(visuals)
        
        target = torch.from_numpy(np.array([self.answer_map[answer]])).long()
        
        tmp = self.process_data(question, self.config['max_len'])
        question_tokens = {
            'ids': tmp['ids'],
            'mask': tmp['mask'],
        }
        
        return visuals, question_tokens, target

    def __len__(self):
        return len(self.data)

In [8]:
# file_path = './dataset/VQA-Med-2020-Task1-VQAnswering-TrainVal-Sets/VQAMed2020-VQAnswering-TrainingSet/VQAnswering_2020_Train_QA_pairs.txt'
# data = pd.read_csv(file_path, sep='|', names=['imageid', 'question', 'answer'])
# data.head()

In [9]:
# answer_map = {}
# ct = 0

# for i in data.answer.unique():
#     answer_map[i] = ct
#     ct+=1

In [10]:
# with open('./answer_map.pickle', 'wb') as handle:
#     pickle.dump(answer_map, handle, protocol=pickle.HIGHEST_PROTOCOL)
    

In [11]:
with open('./answer_map.pickle', 'rb') as handle:
    answer_map = pickle.load(handle)
    
len(answer_map)

333

In [12]:
TrainData = medical_dataset(config=cfg, answer_map=answer_map, image_path='./dataset/TrainingDataset/images/', qa_file='./dataset/TrainingDataset/training_qa.txt')
ValData = medical_dataset(config=cfg, answer_map=answer_map, image_path='./dataset/VQA-Med-2021-Tasks-1-2-NewValidationSets/ImageCLEF-2021-VQA-Med-New-Validation-Images/', qa_file='./dataset/VQA-Med-2021-Tasks-1-2-NewValidationSets/VQA-Med-2021-VQAnswering-Task1-New-ValidationSet.txt', train=False)


In [13]:
%%time
a,b,c = ValData.__getitem__(0)

CPU times: user 34.8 ms, sys: 0 ns, total: 34.8 ms
Wall time: 31.5 ms


In [14]:
TrainDataLoader = DataLoader(TrainData, batch_size=16, shuffle=True, num_workers=4)  # num_workers=0 for windows OS
ValDataLoader = DataLoader(ValData, batch_size=128, shuffle=False, num_workers=4)  # num_workers=0 for windows OS

In [15]:
#3rd Fusion Module
class MLBFusion(nn.Module):

    def __init__(self, dim_v=None, dim_q=None, dim_h=512):
        super(MLBFusion, self).__init__()
        # Modules
        self.dim_v = dim_v
        self.dim_q = dim_q
        self.dim_h = dim_h
        self.dropout = 0
        self.activation = 'relu'
        
        if dim_v:
            self.linear_v = nn.Linear(dim_v, dim_h)
        else:
            print('Warning fusion.py: no visual embedding before fusion')

        if dim_q:
            self.linear_q = nn.Linear(dim_q, dim_h)
        else:
            print('Warning fusion.py: no question embedding before fusion')
        
    def forward(self, input_v, input_q):
        # visual (cnn features)
        if self.dim_v:
            x_v = F.dropout(input_v, p=self.dropout, training=self.training)
            x_v = self.linear_v(x_v)
            x_v = getattr(F, self.activation)(x_v)
        else:
            x_v = input_v
        # question (rnn features)
        if self.dim_q:
            x_q = F.dropout(input_q, p=self.dropout, training=self.training)
            x_q = self.linear_q(x_q)
            x_q = getattr(F, self.activation)(x_q)
        else:
            x_q = input_q
        # hadamard product
        x_mm = torch.mul(x_q, x_v)
        return x_mm

In [16]:
class mednet(nn.Module):
    def __init__(self, config, max_labels):
        super(mednet, self).__init__()
        self.vision = models.resnet18(pretrained=True)
        num_ftrs = self.vision.fc.in_features
        
        self.vision = nn.Sequential(*list(self.vision.children())[:-1])
        self.bert = AutoModel.from_pretrained("dmis-lab/biobert-v1.1")
        
        print(f"Number of vision filters: {num_ftrs}")
        self.fc1 = nn.Linear(512, 128)
        self.fc2 = nn.Linear(128, max_labels)
        
        self.fusion = MLBFusion(dim_v=num_ftrs, dim_q=768, dim_h=512)
        

    def forward(self, visual=None, ids=None, mask=None):
        vision = self.vision(visual).view((ids.shape[0], -1))
        bert_out = self.bert(ids, mask)
        
        h = self.fusion(vision, bert_out.last_hidden_state[:,0])
        h = F.relu(self.fc1(h))
        h = self.fc2(h)
        return h


In [17]:
class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""
    def __init__(self,
                 patience=7,
                 verbose=False,
                 delta=0,
                 path='./models/vgg16_fusion-MLB/checkpoint.pt',
                 trace_func=print):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement. 
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): Path for the checkpoint to be saved to.
                            Default: 'checkpoint.pt'
            trace_func (function): trace print function.
                            Default: print            
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.Inf
        self.delta = delta
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):

        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            self.trace_func(
                f'EarlyStopping counter: {self.counter} out of {self.patience}'
            )
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            self.trace_func(
                f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...'
            )
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

In [18]:
class Trainer:
    def __init__(self,
                 trainloader,
                 vallaoder,
                 model_ft,
                 writer=None,
                 testloader=None,
                 checkpoint_path=None,
                 patience=10,
                 feature_extract=True,
                 print_itr=50,
                 config=None):
        self.trainloader = trainloader
        self.valloader = vallaoder
        self.testloader = testloader
        
        self.config=config

        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

        print("==" * 10)
        print("Training will be done on ", self.device)
        print("==" * 10)

        self.model = model_ft        
        if torch.cuda.device_count() > 1:
            print("Let's use", torch.cuda.device_count(), "GPUs!")
            self.model = nn.DataParallel(self.model, device_ids=[0,1])
        
        self.model = self.model.to(self.device)
        
        
        # Observe that all parameters are being optimized
        self.optimizer = optim.RAdam(self.model.parameters(), lr=self.config['lr'])
        
        self.scheduler = get_linear_schedule_with_warmup(self.optimizer, 
                                            num_warmup_steps = len(self.trainloader)*self.config['warmup_steps'], # Default value in run_glue.py
                                            num_training_steps = len(self.trainloader)*self.config['epochs'])

        
        self.criterion = nn.CrossEntropyLoss()
        self.early_stopping = EarlyStopping(patience=patience, verbose=True)
        self.writer = writer
        self.print_itr = print_itr

    def train(self, ep):
        self.model.train()

        running_loss = 0.0

        for en, (visuals, question, target) in tqdm(enumerate(self.trainloader)):
            self.optimizer.zero_grad()
            
            visuals = visuals.to(self.device)
            y = target.squeeze().to(self.device)
            
            ids = question['ids'].to(self.device)
            mask = question['mask'].to(self.device)

            outputs = self.model(visuals, ids, mask)
            loss = self.criterion(outputs, y)
            loss.backward()
            self.optimizer.step()
            self.scheduler.step()

            running_loss += loss.item()
            if self.writer:
                self.writer.add_scalar('Train Loss', running_loss, ep*len(self.trainloader) + en)
            running_loss = 0
            

    def validate(self, ep):
        self.model.eval()
        
        running_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for en, (visuals, question, target) in tqdm(enumerate(self.valloader)):
                visuals = visuals.to(self.device)
                y = target.squeeze().to(self.device)
                
                ids = question['ids'].to(self.device)
                mask = question['mask'].to(self.device)
                
                outputs = self.model(visuals, ids, mask)
                loss = self.criterion(outputs, y)

                y_pred_softmax = torch.log_softmax(outputs, dim = 1)
                _, y_pred_tags = torch.max(y_pred_softmax, dim = 1)
                # self.tmp_sv_ = y_pred_tags
                
                correct += (y_pred_tags.detach().cpu().data.numpy() == y.detach().cpu().data.numpy()).sum()
                total += y_pred_tags.shape[0]
                
                # print statistics
                running_loss += loss.item()
        
        
        return running_loss / len(self.valloader), correct*100/total

    def perform_training(self, total_epoch):
        val_loss, acc = self.validate(0)

        print("[Initial Validation results] Loss: {} \t Acc: {}".format(
            val_loss, acc))

        for i in range(total_epoch):
            self.train(i + 1)
            val_loss, acc = self.validate(i + 1)
            print('[{}/{}] Loss: {} \t Acc: {}'.format(i+1, total_epoch, val_loss, acc))

            if self.writer:
                self.writer.add_scalar('Validation Loss', val_loss, (i + 1))
                self.writer.add_scalar('Validation Acc', acc, (i + 1))

            self.early_stopping(-acc, self.model)

            if self.early_stopping.early_stop:
                print("Early stopping")
                break

        print("=" * 20)
        print("Training finished !!")
        print("=" * 20)


In [19]:
model_ft = mednet(config=cfg, max_labels=len(answer_map))
writer = SummaryWriter('runs/vgg16_fusion-MLB_run_2')
trainer = Trainer(TrainDataLoader, ValDataLoader, model_ft, writer=writer, config=cfg)

Number of vision filters: 512
Training will be done on  cuda
Let's use 2 GPUs!


In [20]:
trainer.perform_training(cfg['epochs'])

4it [00:04,  1.04s/it]

[Initial Validation results] Loss: 5.811639308929443 	 Acc: 0.2



282it [00:52,  5.33it/s]
4it [00:01,  2.64it/s]


[1/250] Loss: 5.809322834014893 	 Acc: 0.4
Validation loss decreased (inf --> -0.400000).  Saving model ...


282it [00:53,  5.31it/s]
4it [00:01,  2.65it/s]

[2/250] Loss: 5.801398038864136 	 Acc: 0.2
EarlyStopping counter: 1 out of 10



282it [00:53,  5.31it/s]
4it [00:01,  2.66it/s]


[3/250] Loss: 5.759962797164917 	 Acc: 1.8
Validation loss decreased (-0.400000 --> -1.800000).  Saving model ...


282it [00:52,  5.32it/s]
4it [00:01,  2.63it/s]

[4/250] Loss: 5.681699633598328 	 Acc: 1.6
EarlyStopping counter: 1 out of 10



282it [00:53,  5.32it/s]
4it [00:01,  2.64it/s]


[5/250] Loss: 5.554576277732849 	 Acc: 2.8
Validation loss decreased (-1.800000 --> -2.800000).  Saving model ...


282it [00:53,  5.30it/s]
4it [00:01,  2.62it/s]


[6/250] Loss: 5.412758827209473 	 Acc: 4.2
Validation loss decreased (-2.800000 --> -4.200000).  Saving model ...


282it [00:53,  5.31it/s]
4it [00:01,  2.63it/s]


[7/250] Loss: 5.264324307441711 	 Acc: 6.4
Validation loss decreased (-4.200000 --> -6.400000).  Saving model ...


282it [00:53,  5.31it/s]
4it [00:01,  2.69it/s]


[8/250] Loss: 5.035277724266052 	 Acc: 8.0
Validation loss decreased (-6.400000 --> -8.000000).  Saving model ...


282it [00:53,  5.31it/s]
4it [00:01,  2.66it/s]


[9/250] Loss: 4.916660904884338 	 Acc: 9.2
Validation loss decreased (-8.000000 --> -9.200000).  Saving model ...


282it [00:53,  5.31it/s]
4it [00:01,  2.63it/s]


[10/250] Loss: 4.71006453037262 	 Acc: 11.2
Validation loss decreased (-9.200000 --> -11.200000).  Saving model ...


282it [00:53,  5.31it/s]
4it [00:01,  2.65it/s]

[11/250] Loss: 4.602700710296631 	 Acc: 10.8
EarlyStopping counter: 1 out of 10



282it [00:53,  5.32it/s]
4it [00:01,  2.63it/s]


[12/250] Loss: 4.424753904342651 	 Acc: 11.8
Validation loss decreased (-11.200000 --> -11.800000).  Saving model ...


282it [00:53,  5.31it/s]
4it [00:01,  2.55it/s]


[13/250] Loss: 4.3568501472473145 	 Acc: 14.2
Validation loss decreased (-11.800000 --> -14.200000).  Saving model ...


282it [00:53,  5.31it/s]
4it [00:01,  2.65it/s]


[14/250] Loss: 4.247000098228455 	 Acc: 15.4
Validation loss decreased (-14.200000 --> -15.400000).  Saving model ...


282it [00:53,  5.31it/s]
4it [00:01,  2.61it/s]


[15/250] Loss: 4.10185307264328 	 Acc: 17.0
Validation loss decreased (-15.400000 --> -17.000000).  Saving model ...


282it [00:53,  5.31it/s]
4it [00:01,  2.63it/s]


[16/250] Loss: 4.098338961601257 	 Acc: 17.4
Validation loss decreased (-17.000000 --> -17.400000).  Saving model ...


282it [00:53,  5.30it/s]
4it [00:01,  2.63it/s]


[17/250] Loss: 3.98421847820282 	 Acc: 20.8
Validation loss decreased (-17.400000 --> -20.800000).  Saving model ...


282it [00:53,  5.31it/s]
4it [00:01,  2.63it/s]


[18/250] Loss: 3.9383188486099243 	 Acc: 23.4
Validation loss decreased (-20.800000 --> -23.400000).  Saving model ...


282it [00:53,  5.30it/s]
4it [00:01,  2.60it/s]

[19/250] Loss: 3.9216880798339844 	 Acc: 22.8
EarlyStopping counter: 1 out of 10



282it [00:53,  5.31it/s]
4it [00:01,  2.65it/s]


[20/250] Loss: 3.795299708843231 	 Acc: 23.4
Validation loss decreased (-23.400000 --> -23.400000).  Saving model ...


282it [00:53,  5.31it/s]
4it [00:01,  2.61it/s]


[21/250] Loss: 3.8401907682418823 	 Acc: 24.0
Validation loss decreased (-23.400000 --> -24.000000).  Saving model ...


282it [00:53,  5.31it/s]
4it [00:01,  2.65it/s]


[22/250] Loss: 3.641884982585907 	 Acc: 26.2
Validation loss decreased (-24.000000 --> -26.200000).  Saving model ...


282it [00:53,  5.31it/s]
4it [00:01,  2.61it/s]

[23/250] Loss: 3.8403669595718384 	 Acc: 23.6
EarlyStopping counter: 1 out of 10



282it [00:53,  5.30it/s]
4it [00:01,  2.62it/s]

[24/250] Loss: 3.771163046360016 	 Acc: 26.0
EarlyStopping counter: 2 out of 10



282it [00:53,  5.31it/s]
4it [00:01,  2.61it/s]


[25/250] Loss: 3.7951148748397827 	 Acc: 26.6
Validation loss decreased (-26.200000 --> -26.600000).  Saving model ...


282it [00:53,  5.30it/s]
4it [00:01,  2.61it/s]


[26/250] Loss: 3.710276782512665 	 Acc: 27.8
Validation loss decreased (-26.600000 --> -27.800000).  Saving model ...


282it [00:53,  5.31it/s]
4it [00:01,  2.62it/s]


[27/250] Loss: 3.606103241443634 	 Acc: 28.0
Validation loss decreased (-27.800000 --> -28.000000).  Saving model ...


282it [00:53,  5.31it/s]
4it [00:01,  2.63it/s]


[28/250] Loss: 3.505462884902954 	 Acc: 30.8
Validation loss decreased (-28.000000 --> -30.800000).  Saving model ...


282it [00:53,  5.31it/s]
4it [00:01,  2.61it/s]

[29/250] Loss: 3.555931568145752 	 Acc: 29.6
EarlyStopping counter: 1 out of 10



282it [00:53,  5.31it/s]
4it [00:01,  2.63it/s]

[30/250] Loss: 3.6270081400871277 	 Acc: 28.4
EarlyStopping counter: 2 out of 10



282it [00:53,  5.30it/s]
4it [00:01,  2.61it/s]

[31/250] Loss: 3.726452112197876 	 Acc: 29.2
EarlyStopping counter: 3 out of 10



282it [00:53,  5.30it/s]
4it [00:01,  2.61it/s]

[32/250] Loss: 3.661178410053253 	 Acc: 29.8
EarlyStopping counter: 4 out of 10



282it [00:53,  5.27it/s]
4it [00:01,  2.59it/s]


[33/250] Loss: 3.5833767652511597 	 Acc: 30.8
Validation loss decreased (-30.800000 --> -30.800000).  Saving model ...


282it [00:53,  5.31it/s]
4it [00:01,  2.61it/s]


[34/250] Loss: 3.6845592856407166 	 Acc: 30.8
Validation loss decreased (-30.800000 --> -30.800000).  Saving model ...


282it [00:53,  5.30it/s]
4it [00:01,  2.61it/s]


[35/250] Loss: 3.5860692858695984 	 Acc: 31.8
Validation loss decreased (-30.800000 --> -31.800000).  Saving model ...


282it [00:53,  5.29it/s]
4it [00:01,  2.61it/s]


[36/250] Loss: 3.641350507736206 	 Acc: 32.2
Validation loss decreased (-31.800000 --> -32.200000).  Saving model ...


282it [00:53,  5.31it/s]
4it [00:01,  2.65it/s]


[37/250] Loss: 3.622658669948578 	 Acc: 32.8
Validation loss decreased (-32.200000 --> -32.800000).  Saving model ...


282it [00:53,  5.30it/s]
4it [00:01,  2.64it/s]

[38/250] Loss: 3.666078507900238 	 Acc: 32.0
EarlyStopping counter: 1 out of 10



282it [00:53,  5.32it/s]
4it [00:01,  2.62it/s]


[39/250] Loss: 3.5475263595581055 	 Acc: 34.2
Validation loss decreased (-32.800000 --> -34.200000).  Saving model ...


282it [00:53,  5.31it/s]
4it [00:01,  2.61it/s]


[40/250] Loss: 3.5096710324287415 	 Acc: 34.8
Validation loss decreased (-34.200000 --> -34.800000).  Saving model ...


282it [00:53,  5.32it/s]
4it [00:01,  2.62it/s]

[41/250] Loss: 3.6001172065734863 	 Acc: 33.0
EarlyStopping counter: 1 out of 10



282it [00:53,  5.31it/s]
4it [00:01,  2.61it/s]


[42/250] Loss: 3.6327435970306396 	 Acc: 35.8
Validation loss decreased (-34.800000 --> -35.800000).  Saving model ...


282it [00:53,  5.31it/s]
4it [00:01,  2.63it/s]

[43/250] Loss: 3.661168873310089 	 Acc: 35.4
EarlyStopping counter: 1 out of 10



282it [00:53,  5.31it/s]
4it [00:01,  2.62it/s]

[44/250] Loss: 3.725752294063568 	 Acc: 35.0
EarlyStopping counter: 2 out of 10



282it [00:53,  5.31it/s]
4it [00:01,  2.66it/s]

[45/250] Loss: 3.693035840988159 	 Acc: 33.8
EarlyStopping counter: 3 out of 10



282it [00:53,  5.31it/s]
4it [00:01,  2.61it/s]

[46/250] Loss: 3.701202869415283 	 Acc: 35.4
EarlyStopping counter: 4 out of 10



282it [00:53,  5.30it/s]
4it [00:01,  2.63it/s]

[47/250] Loss: 3.591647744178772 	 Acc: 35.4
EarlyStopping counter: 5 out of 10



282it [00:53,  5.32it/s]
4it [00:01,  2.61it/s]

[48/250] Loss: 3.579277455806732 	 Acc: 35.0
EarlyStopping counter: 6 out of 10



282it [00:53,  5.30it/s]
4it [00:01,  2.61it/s]

[49/250] Loss: 3.6533467173576355 	 Acc: 35.4
EarlyStopping counter: 7 out of 10



282it [00:53,  5.30it/s]
4it [00:01,  2.64it/s]

[50/250] Loss: 3.703054130077362 	 Acc: 33.4
EarlyStopping counter: 8 out of 10



282it [00:53,  5.30it/s]
4it [00:01,  2.55it/s]

[51/250] Loss: 3.575757682323456 	 Acc: 35.2
EarlyStopping counter: 9 out of 10



282it [00:53,  5.31it/s]
4it [00:01,  2.64it/s]

[52/250] Loss: 4.020175993442535 	 Acc: 32.4
EarlyStopping counter: 10 out of 10
Early stopping
Training finished !!





In [21]:
class medical_dataset_test(Dataset):
    def __init__(self, config=None, answer_map=None, image_path='./dataset/Task1-VQA-2021-TestSet-w-GroundTruth/VQA-500-Images/', qa_file="./dataset/Task1-VQA-2021-TestSet-w-GroundTruth/Task1-VQA-2021-TestSet-Questions.txt", train=True):
            
        assert answer_map!=None
        assert config!=None
        
        self.image_path = image_path
        self.qa_file = qa_file
        self.config = config
        self.train = train
        
        self.answer_map = answer_map
        self.data = pd.read_csv(qa_file, sep='|', names=['imageid', 'question'])
           
        # print(Counter(self.data.answer.tolist()))
        
        self.transforms = transforms.Compose([
                                transforms.RandomResizedCrop(224),
                                transforms.ToTensor(),
                                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                            ])
        self.tokenizer = tokenizer = AutoTokenizer.from_pretrained("dmis-lab/biobert-v1.1")

    def process_data(self, text, max_len):
        text = str(text)

        inputs = self.tokenizer.encode_plus(
            text,
            None,
            add_special_tokens=True,
            max_length=max_len,
            truncation=True
        )
        ids = inputs["input_ids"]
        mask = inputs["attention_mask"]
        
        padding_length = max_len - len(ids)
        
        ids = ids + ([0] * padding_length)
        mask = mask + ([0] * padding_length)
        
        return {
            'ids': torch.tensor(ids, dtype=torch.long),
            'mask': torch.tensor(mask, dtype=torch.long),
        }
    
    
    def __getitem__(self, index):
        
        question = self.data.question[index]
        image_idx = self.data.imageid[index]
        
        ## add visual feature related steps
        visuals = Image.open(join(self.image_path, f'{image_idx}.jpg'))
        visuals = self.transforms(visuals)
        
        
        tmp = self.process_data(question, self.config['max_len'])
        question_tokens = {
            'ids': tmp['ids'],
            'mask': tmp['mask'],
        }
        
        return visuals, question_tokens, image_idx

    def __len__(self):
        return len(self.data)

In [22]:
TestData = medical_dataset_test(config=cfg, answer_map=answer_map)
TestDataLoader = DataLoader(TestData, batch_size=128, shuffle=False, num_workers=4)  # num_workers=0 for windows OS

In [23]:
inv_map = {v: k for k, v in answer_map.items()}

In [24]:
trainer.model.load_state_dict(torch.load('./models/vgg16_fusion-MLB/checkpoint.pt'))
trainer.model.eval()
        
imageids = []
answers = []

running_loss = 0.0
correct = 0
total = 0
with torch.no_grad():
    for en, (visuals, question, target) in tqdm(enumerate(TestDataLoader)):
        visuals = visuals.to(trainer.device)

        ids = question['ids'].to(trainer.device)
        mask = question['mask'].to(trainer.device)

        outputs = trainer.model(visuals, ids, mask)

        y_pred_softmax = torch.log_softmax(outputs, dim = 1)
        _, y_pred_tags = torch.max(y_pred_softmax, dim = 1)
        # self.tmp_sv_ = y_pred_tags
        
        for i in range(visuals.shape[0]):
            imageids.append(target[i])
            answers.append(inv_map[int(y_pred_tags[i])])


4it [00:01,  2.48it/s]


In [25]:
answers[:5]

['simple bone cyst',
 'pulmonary embolus',
 'rickets',
 'carotid artery dissection',
 'bucket handle meniscal tear of the knee']

In [26]:
pd.DataFrame({'imageids':imageids, 'answers':answers}).to_csv('vgg16_fusion-MLB.txt', sep='|', index=False, header=False)