In [1]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import warnings
warnings.filterwarnings("ignore")

In [3]:
import numpy as np
import sys
from os import listdir, makedirs, getcwd, remove
from os.path import isfile, join, abspath, exists, isdir, expanduser
from scipy.io import loadmat
from PIL import Image
import pandas as pd

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transforms


import pickle
from tqdm import tqdm
import json
import random
from collections import Counter

from transformers import get_linear_schedule_with_warmup
from transformers import AutoTokenizer, AutoModel

In [4]:
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.autograd as autograd
from torch.utils.data import Dataset, DataLoader
import torch.utils.data as data
import torch.nn.functional as F
import torch.optim as optim
import itertools

import torchvision
from torchvision import transforms, datasets, models
from torch import Tensor

from torch.utils.tensorboard import SummaryWriter

In [5]:
import transformers
import tokenizers
from transformers import BertTokenizer, BertModel

In [6]:
cfg = {
    'max_len': 128,
    'lr': 2e-5,
    'warmup_steps': 5,
    'epochs': 250
}

In [7]:
class medical_dataset(Dataset):
    def __init__(self, config=None, answer_map=None, image_path='./dataset/TrainingDataset/images/', qa_file="./dataset/TrainingDataset/training_qa.txt", train=True):
            
        assert answer_map!=None
        assert config!=None
        
        self.image_path = image_path
        self.qa_file = qa_file
        self.config = config
        self.train = train
        
        self.answer_map = answer_map
        self.data = pd.read_csv(qa_file, sep='|', names=['imageid', 'question', 'answer'])
           
        # print(Counter(self.data.answer.tolist()))
        
        self.transforms = transforms.Compose([
                                transforms.RandomResizedCrop(224),
                                transforms.ToTensor(),
                                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                            ])
        self.tokenizer = tokenizer = AutoTokenizer.from_pretrained("dmis-lab/biobert-v1.1")

    def process_data(self, text, max_len):
        text = str(text)

        inputs = self.tokenizer.encode_plus(
            text,
            None,
            add_special_tokens=True,
            max_length=max_len,
            truncation=True
        )
        ids = inputs["input_ids"]
        mask = inputs["attention_mask"]
        
        padding_length = max_len - len(ids)
        
        ids = ids + ([0] * padding_length)
        mask = mask + ([0] * padding_length)
        
        return {
            'ids': torch.tensor(ids, dtype=torch.long),
            'mask': torch.tensor(mask, dtype=torch.long),
        }
    
    
    def __getitem__(self, index):
        
        question = self.data.question[index]
        answer = self.data.answer[index]
        image_idx = self.data.imageid[index]
        
        ## add visual feature related steps
        visuals = Image.open(join(self.image_path, f'{image_idx}.jpg'))
        visuals = self.transforms(visuals)
        
        target = torch.from_numpy(np.array([self.answer_map[answer]])).long()
        
        tmp = self.process_data(question, self.config['max_len'])
        question_tokens = {
            'ids': tmp['ids'],
            'mask': tmp['mask'],
        }
        
        return visuals, question_tokens, target

    def __len__(self):
        return len(self.data)

In [8]:
# file_path = './dataset/VQA-Med-2020-Task1-VQAnswering-TrainVal-Sets/VQAMed2020-VQAnswering-TrainingSet/VQAnswering_2020_Train_QA_pairs.txt'
# data = pd.read_csv(file_path, sep='|', names=['imageid', 'question', 'answer'])
# data.head()

In [9]:
# answer_map = {}
# ct = 0

# for i in data.answer.unique():
#     answer_map[i] = ct
#     ct+=1

In [10]:
# with open('./answer_map.pickle', 'wb') as handle:
#     pickle.dump(answer_map, handle, protocol=pickle.HIGHEST_PROTOCOL)
    

In [11]:
with open('./answer_map.pickle', 'rb') as handle:
    answer_map = pickle.load(handle)
    
len(answer_map)

333

In [12]:
TrainData = medical_dataset(config=cfg, answer_map=answer_map, image_path='./dataset/TrainingDataset/images/', qa_file='./dataset/TrainingDataset/training_qa.txt')
ValData = medical_dataset(config=cfg, answer_map=answer_map, image_path='./dataset/VQA-Med-2021-Tasks-1-2-NewValidationSets/ImageCLEF-2021-VQA-Med-New-Validation-Images/', qa_file='./dataset/VQA-Med-2021-Tasks-1-2-NewValidationSets/VQA-Med-2021-VQAnswering-Task1-New-ValidationSet.txt', train=False)


In [13]:
%%time
a,b,c = ValData.__getitem__(0)

CPU times: user 12 ms, sys: 9.53 ms, total: 21.6 ms
Wall time: 38 ms


In [14]:
TrainDataLoader = DataLoader(TrainData, batch_size=16, shuffle=True, num_workers=4)  # num_workers=0 for windows OS
ValDataLoader = DataLoader(ValData, batch_size=128, shuffle=False, num_workers=4)  # num_workers=0 for windows OS

In [15]:
#3rd Fusion Module
class ElementsumFusion(nn.Module):

    def __init__(self, dim_v=None, dim_q=None, dim_h=512):
        super(ElementsumFusion, self).__init__()
        # Modules
        self.dim_v = dim_v
        self.dim_q = dim_q
        self.dim_h = dim_h
        self.dropout = 0
        self.activation = 'relu'
        
        if dim_v:
            self.linear_v = nn.Linear(dim_v, dim_h)
        else:
            print('Warning fusion.py: no visual embedding before fusion')

        if dim_q:
            self.linear_q = nn.Linear(dim_q, dim_h)
        else:
            print('Warning fusion.py: no question embedding before fusion')
        
    def forward(self, input_v, input_q):
        # visual (cnn features)
        if self.dim_v:
            x_v = F.dropout(input_v, p=self.dropout, training=self.training)
            x_v = self.linear_v(x_v)
            x_v = getattr(F, self.activation)(x_v)
        else:
            x_v = input_v
        # question (rnn features)
        if self.dim_q:
            x_q = F.dropout(input_q, p=self.dropout, training=self.training)
            x_q = self.linear_q(x_q)
            x_q = getattr(F, self.activation)(x_q)
        else:
            x_q = input_q
        # hadamard product
        x_mm = torch.add(x_v, 1, x_q)
        return x_mm

In [16]:
class mednet(nn.Module):
    def __init__(self, config, max_labels):
        super(mednet, self).__init__()
        self.vision = models.resnet18(pretrained=True)
        num_ftrs = self.vision.fc.in_features
        
        self.vision = nn.Sequential(*list(self.vision.children())[:-1])
        self.bert = AutoModel.from_pretrained("dmis-lab/biobert-v1.1")
        
        print(f"Number of vision filters: {num_ftrs}")
        self.fc1 = nn.Linear(512, 128)
        self.fc2 = nn.Linear(128, max_labels)
        
        self.fusion = ElementsumFusion(dim_v=num_ftrs, dim_q=768, dim_h=512)
        

    def forward(self, visual=None, ids=None, mask=None):
        vision = self.vision(visual).view((ids.shape[0], -1))
        bert_out = self.bert(ids, mask)
        
        h = self.fusion(vision, bert_out.last_hidden_state[:,0])
        h = F.relu(self.fc1(h))
        h = self.fc2(h)
        return h


In [17]:
class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""
    def __init__(self,
                 patience=7,
                 verbose=False,
                 delta=0,
                 path='./models/resnet18_fusion-elementsum/checkpoint.pt',
                 trace_func=print):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement. 
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): Path for the checkpoint to be saved to.
                            Default: 'checkpoint.pt'
            trace_func (function): trace print function.
                            Default: print            
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.Inf
        self.delta = delta
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):

        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            self.trace_func(
                f'EarlyStopping counter: {self.counter} out of {self.patience}'
            )
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            self.trace_func(
                f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...'
            )
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

In [18]:
class Trainer:
    def __init__(self,
                 trainloader,
                 vallaoder,
                 model_ft,
                 writer=None,
                 testloader=None,
                 checkpoint_path=None,
                 patience=10,
                 feature_extract=True,
                 print_itr=50,
                 config=None):
        self.trainloader = trainloader
        self.valloader = vallaoder
        self.testloader = testloader
        
        self.config=config

        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

        print("==" * 10)
        print("Training will be done on ", self.device)
        print("==" * 10)

        self.model = model_ft        
        if torch.cuda.device_count() > 1:
            print("Let's use", torch.cuda.device_count(), "GPUs!")
            self.model = nn.DataParallel(self.model, device_ids=[0,1])
        
        self.model = self.model.to(self.device)
        
        
        # Observe that all parameters are being optimized
        self.optimizer = optim.RAdam(self.model.parameters(), lr=self.config['lr'])
        
        self.scheduler = get_linear_schedule_with_warmup(self.optimizer, 
                                            num_warmup_steps = len(self.trainloader)*self.config['warmup_steps'], # Default value in run_glue.py
                                            num_training_steps = len(self.trainloader)*self.config['epochs'])

        
        self.criterion = nn.CrossEntropyLoss()
        self.early_stopping = EarlyStopping(patience=patience, verbose=True)
        self.writer = writer
        self.print_itr = print_itr

    def train(self, ep):
        self.model.train()

        running_loss = 0.0

        for en, (visuals, question, target) in tqdm(enumerate(self.trainloader)):
            self.optimizer.zero_grad()
            
            visuals = visuals.to(self.device)
            y = target.squeeze().to(self.device)
            
            ids = question['ids'].to(self.device)
            mask = question['mask'].to(self.device)

            outputs = self.model(visuals, ids, mask)
            loss = self.criterion(outputs, y)
            loss.backward()
            self.optimizer.step()
            self.scheduler.step()

            running_loss += loss.item()
            if self.writer:
                self.writer.add_scalar('Train Loss', running_loss, ep*len(self.trainloader) + en)
            running_loss = 0
            

    def validate(self, ep):
        self.model.eval()
        
        running_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for en, (visuals, question, target) in tqdm(enumerate(self.valloader)):
                visuals = visuals.to(self.device)
                y = target.squeeze().to(self.device)
                
                ids = question['ids'].to(self.device)
                mask = question['mask'].to(self.device)
                
                outputs = self.model(visuals, ids, mask)
                loss = self.criterion(outputs, y)

                y_pred_softmax = torch.log_softmax(outputs, dim = 1)
                _, y_pred_tags = torch.max(y_pred_softmax, dim = 1)
                # self.tmp_sv_ = y_pred_tags
                
                correct += (y_pred_tags.detach().cpu().data.numpy() == y.detach().cpu().data.numpy()).sum()
                total += y_pred_tags.shape[0]
                
                # print statistics
                running_loss += loss.item()
        
        
        return running_loss / len(self.valloader), correct*100/total

    def perform_training(self, total_epoch):
        val_loss, acc = self.validate(0)

        print("[Initial Validation results] Loss: {} \t Acc: {}".format(
            val_loss, acc))

        for i in range(total_epoch):
            self.train(i + 1)
            val_loss, acc = self.validate(i + 1)
            print('[{}/{}] Loss: {} \t Acc: {}'.format(i+1, total_epoch, val_loss, acc))

            if self.writer:
                self.writer.add_scalar('Validation Loss', val_loss, (i + 1))
                self.writer.add_scalar('Validation Acc', acc, (i + 1))

            self.early_stopping(-acc, self.model)

            if self.early_stopping.early_stop:
                print("Early stopping")
                break

        print("=" * 20)
        print("Training finished !!")
        print("=" * 20)


In [20]:
model_ft = mednet(config=cfg, max_labels=len(answer_map))
writer = SummaryWriter('runs/resnet18_fusion-elementsum_run_1')
trainer = Trainer(TrainDataLoader, ValDataLoader, model_ft, writer=writer, config=cfg)

Number of vision filters: 512
Training will be done on  cuda
Let's use 2 GPUs!


In [21]:
trainer.perform_training(cfg['epochs'])

4it [00:04,  1.04s/it]

[Initial Validation results] Loss: 5.8221728801727295 	 Acc: 0.0



282it [00:52,  5.33it/s]
4it [00:01,  2.53it/s]


[1/250] Loss: 5.8135682344436646 	 Acc: 0.2
Validation loss decreased (inf --> -0.200000).  Saving model ...


282it [00:53,  5.32it/s]
4it [00:01,  2.54it/s]


[2/250] Loss: 5.791120648384094 	 Acc: 0.4
Validation loss decreased (-0.200000 --> -0.400000).  Saving model ...


282it [00:52,  5.33it/s]
4it [00:01,  2.56it/s]


[3/250] Loss: 5.766025900840759 	 Acc: 1.0
Validation loss decreased (-0.400000 --> -1.000000).  Saving model ...


282it [00:53,  5.32it/s]
4it [00:01,  2.54it/s]


[4/250] Loss: 5.720999240875244 	 Acc: 1.2
Validation loss decreased (-1.000000 --> -1.200000).  Saving model ...


282it [00:53,  5.32it/s]
4it [00:01,  2.53it/s]


[5/250] Loss: 5.644358992576599 	 Acc: 2.4
Validation loss decreased (-1.200000 --> -2.400000).  Saving model ...


282it [00:53,  5.31it/s]
4it [00:01,  2.55it/s]


[6/250] Loss: 5.536285042762756 	 Acc: 4.0
Validation loss decreased (-2.400000 --> -4.000000).  Saving model ...


282it [00:53,  5.31it/s]
4it [00:01,  2.59it/s]


[7/250] Loss: 5.421822905540466 	 Acc: 6.0
Validation loss decreased (-4.000000 --> -6.000000).  Saving model ...


282it [00:53,  5.31it/s]
4it [00:01,  2.54it/s]


[8/250] Loss: 5.314312219619751 	 Acc: 7.2
Validation loss decreased (-6.000000 --> -7.200000).  Saving model ...


282it [00:53,  5.31it/s]
4it [00:01,  2.59it/s]


[9/250] Loss: 5.1628721952438354 	 Acc: 9.2
Validation loss decreased (-7.200000 --> -9.200000).  Saving model ...


282it [00:53,  5.31it/s]
4it [00:01,  2.59it/s]

[10/250] Loss: 5.092556953430176 	 Acc: 7.6
EarlyStopping counter: 1 out of 10



282it [00:53,  5.31it/s]
4it [00:01,  2.58it/s]


[11/250] Loss: 4.985626459121704 	 Acc: 9.4
Validation loss decreased (-9.200000 --> -9.400000).  Saving model ...


282it [00:53,  5.31it/s]
4it [00:01,  2.53it/s]

[12/250] Loss: 4.868507266044617 	 Acc: 8.6
EarlyStopping counter: 1 out of 10



282it [00:53,  5.32it/s]
4it [00:01,  2.55it/s]


[13/250] Loss: 4.771825432777405 	 Acc: 9.4
Validation loss decreased (-9.400000 --> -9.400000).  Saving model ...


282it [00:53,  5.31it/s]
4it [00:01,  2.57it/s]


[14/250] Loss: 4.683165669441223 	 Acc: 11.6
Validation loss decreased (-9.400000 --> -11.600000).  Saving model ...


282it [00:53,  5.31it/s]
4it [00:01,  2.54it/s]


[15/250] Loss: 4.578542709350586 	 Acc: 13.4
Validation loss decreased (-11.600000 --> -13.400000).  Saving model ...


282it [00:53,  5.30it/s]
4it [00:01,  2.52it/s]

[16/250] Loss: 4.501201868057251 	 Acc: 13.2
EarlyStopping counter: 1 out of 10



282it [00:53,  5.32it/s]
4it [00:01,  2.55it/s]


[17/250] Loss: 4.440851211547852 	 Acc: 14.4
Validation loss decreased (-13.400000 --> -14.400000).  Saving model ...


282it [00:53,  5.32it/s]
4it [00:01,  2.52it/s]

[18/250] Loss: 4.38092827796936 	 Acc: 13.2
EarlyStopping counter: 1 out of 10



282it [00:53,  5.31it/s]
4it [00:01,  2.53it/s]


[19/250] Loss: 4.237519264221191 	 Acc: 16.2
Validation loss decreased (-14.400000 --> -16.200000).  Saving model ...


282it [00:53,  5.32it/s]
4it [00:01,  2.50it/s]


[20/250] Loss: 4.214712858200073 	 Acc: 16.4
Validation loss decreased (-16.200000 --> -16.400000).  Saving model ...


282it [00:53,  5.31it/s]
4it [00:01,  2.59it/s]

[21/250] Loss: 4.189817190170288 	 Acc: 16.2
EarlyStopping counter: 1 out of 10



282it [00:53,  5.31it/s]
4it [00:01,  2.57it/s]

[22/250] Loss: 4.146097540855408 	 Acc: 16.2
EarlyStopping counter: 2 out of 10



282it [00:53,  5.32it/s]
4it [00:01,  2.57it/s]


[23/250] Loss: 4.154565811157227 	 Acc: 18.4
Validation loss decreased (-16.400000 --> -18.400000).  Saving model ...


282it [00:53,  5.32it/s]
4it [00:01,  2.57it/s]

[24/250] Loss: 3.9955588579177856 	 Acc: 18.0
EarlyStopping counter: 1 out of 10



282it [00:52,  5.32it/s]
4it [00:01,  2.59it/s]

[25/250] Loss: 3.997759521007538 	 Acc: 18.2
EarlyStopping counter: 2 out of 10



282it [00:53,  5.32it/s]
4it [00:01,  2.50it/s]


[26/250] Loss: 3.9107438921928406 	 Acc: 19.0
Validation loss decreased (-18.400000 --> -19.000000).  Saving model ...


282it [00:53,  5.32it/s]
4it [00:01,  2.50it/s]


[27/250] Loss: 3.89912086725235 	 Acc: 19.2
Validation loss decreased (-19.000000 --> -19.200000).  Saving model ...


282it [00:53,  5.32it/s]
4it [00:01,  2.49it/s]


[28/250] Loss: 3.9460229873657227 	 Acc: 20.2
Validation loss decreased (-19.200000 --> -20.200000).  Saving model ...


282it [00:53,  5.31it/s]
4it [00:01,  2.60it/s]


[29/250] Loss: 3.8530113101005554 	 Acc: 21.8
Validation loss decreased (-20.200000 --> -21.800000).  Saving model ...


282it [00:53,  5.31it/s]
4it [00:01,  2.54it/s]


[30/250] Loss: 3.8598002195358276 	 Acc: 21.8
Validation loss decreased (-21.800000 --> -21.800000).  Saving model ...


282it [00:53,  5.31it/s]
4it [00:01,  2.51it/s]


[31/250] Loss: 3.812596082687378 	 Acc: 22.2
Validation loss decreased (-21.800000 --> -22.200000).  Saving model ...


282it [00:53,  5.31it/s]
4it [00:01,  2.54it/s]

[32/250] Loss: 3.751449763774872 	 Acc: 21.4
EarlyStopping counter: 1 out of 10



282it [00:52,  5.32it/s]
4it [00:01,  2.56it/s]

[33/250] Loss: 3.8390117287635803 	 Acc: 21.8
EarlyStopping counter: 2 out of 10



282it [00:53,  5.31it/s]
4it [00:01,  2.53it/s]


[34/250] Loss: 3.734971582889557 	 Acc: 24.4
Validation loss decreased (-22.200000 --> -24.400000).  Saving model ...


282it [00:53,  5.31it/s]
4it [00:01,  2.60it/s]


[35/250] Loss: 3.7073583006858826 	 Acc: 25.0
Validation loss decreased (-24.400000 --> -25.000000).  Saving model ...


282it [00:53,  5.32it/s]
4it [00:01,  2.53it/s]


[36/250] Loss: 3.6851829290390015 	 Acc: 26.0
Validation loss decreased (-25.000000 --> -26.000000).  Saving model ...


282it [00:53,  5.32it/s]
4it [00:01,  2.54it/s]

[37/250] Loss: 3.63015353679657 	 Acc: 23.6
EarlyStopping counter: 1 out of 10



282it [00:53,  5.31it/s]
4it [00:01,  2.53it/s]

[38/250] Loss: 3.630078971385956 	 Acc: 24.4
EarlyStopping counter: 2 out of 10



282it [00:53,  5.32it/s]
4it [00:01,  2.51it/s]

[39/250] Loss: 3.607324182987213 	 Acc: 22.8
EarlyStopping counter: 3 out of 10



282it [00:53,  5.32it/s]
4it [00:01,  2.60it/s]

[40/250] Loss: 3.6329208612442017 	 Acc: 24.8
EarlyStopping counter: 4 out of 10



282it [00:53,  5.32it/s]
4it [00:01,  2.56it/s]

[41/250] Loss: 3.7104918360710144 	 Acc: 23.6
EarlyStopping counter: 5 out of 10



282it [00:53,  5.32it/s]
4it [00:01,  2.51it/s]

[42/250] Loss: 3.610133469104767 	 Acc: 24.6
EarlyStopping counter: 6 out of 10



282it [00:53,  5.32it/s]
4it [00:01,  2.59it/s]

[43/250] Loss: 3.676137864589691 	 Acc: 24.4
EarlyStopping counter: 7 out of 10



282it [00:53,  5.31it/s]
4it [00:01,  2.56it/s]


[44/250] Loss: 3.7011982202529907 	 Acc: 26.0
Validation loss decreased (-26.000000 --> -26.000000).  Saving model ...


282it [00:53,  5.30it/s]
4it [00:01,  2.55it/s]


[45/250] Loss: 3.5649701952934265 	 Acc: 28.4
Validation loss decreased (-26.000000 --> -28.400000).  Saving model ...


282it [00:53,  5.31it/s]
4it [00:01,  2.54it/s]


[46/250] Loss: 3.5545979142189026 	 Acc: 29.8
Validation loss decreased (-28.400000 --> -29.800000).  Saving model ...


282it [00:53,  5.30it/s]
4it [00:01,  2.53it/s]


[47/250] Loss: 3.5117136240005493 	 Acc: 30.0
Validation loss decreased (-29.800000 --> -30.000000).  Saving model ...


282it [00:53,  5.31it/s]
4it [00:01,  2.50it/s]


[48/250] Loss: 3.4580332040786743 	 Acc: 30.6
Validation loss decreased (-30.000000 --> -30.600000).  Saving model ...


282it [00:53,  5.31it/s]
4it [00:01,  2.51it/s]

[49/250] Loss: 3.5396506786346436 	 Acc: 27.4
EarlyStopping counter: 1 out of 10



282it [00:53,  5.30it/s]
4it [00:01,  2.57it/s]

[50/250] Loss: 3.624606192111969 	 Acc: 27.8
EarlyStopping counter: 2 out of 10



282it [00:53,  5.31it/s]
4it [00:01,  2.55it/s]

[51/250] Loss: 3.517681062221527 	 Acc: 27.8
EarlyStopping counter: 3 out of 10



282it [00:53,  5.31it/s]
4it [00:01,  2.55it/s]


[52/250] Loss: 3.4448540806770325 	 Acc: 30.8
Validation loss decreased (-30.600000 --> -30.800000).  Saving model ...


282it [00:53,  5.31it/s]
4it [00:01,  2.56it/s]

[53/250] Loss: 3.5209118127822876 	 Acc: 28.4
EarlyStopping counter: 1 out of 10



282it [00:53,  5.31it/s]
4it [00:01,  2.50it/s]


[54/250] Loss: 3.5797603130340576 	 Acc: 30.8
Validation loss decreased (-30.800000 --> -30.800000).  Saving model ...


282it [00:53,  5.31it/s]
4it [00:01,  2.56it/s]


[55/250] Loss: 3.5360729098320007 	 Acc: 32.6
Validation loss decreased (-30.800000 --> -32.600000).  Saving model ...


282it [00:53,  5.31it/s]
4it [00:01,  2.58it/s]

[56/250] Loss: 3.510233998298645 	 Acc: 29.6
EarlyStopping counter: 1 out of 10



282it [00:53,  5.31it/s]
4it [00:01,  2.57it/s]

[57/250] Loss: 3.51551616191864 	 Acc: 29.8
EarlyStopping counter: 2 out of 10



282it [00:53,  5.30it/s]
4it [00:01,  2.58it/s]

[58/250] Loss: 3.5220212936401367 	 Acc: 32.4
EarlyStopping counter: 3 out of 10



282it [00:53,  5.30it/s]
4it [00:01,  2.54it/s]

[59/250] Loss: 3.6147786378860474 	 Acc: 30.4
EarlyStopping counter: 4 out of 10



282it [00:53,  5.31it/s]
4it [00:01,  2.55it/s]

[60/250] Loss: 3.5430420637130737 	 Acc: 28.6
EarlyStopping counter: 5 out of 10



282it [00:53,  5.30it/s]
4it [00:01,  2.52it/s]

[61/250] Loss: 3.644425928592682 	 Acc: 30.6
EarlyStopping counter: 6 out of 10



282it [00:53,  5.32it/s]
4it [00:01,  2.58it/s]

[62/250] Loss: 3.4798313975334167 	 Acc: 32.0
EarlyStopping counter: 7 out of 10



282it [00:53,  5.31it/s]
4it [00:01,  2.51it/s]


[63/250] Loss: 3.455235242843628 	 Acc: 33.6
Validation loss decreased (-32.600000 --> -33.600000).  Saving model ...


282it [00:53,  5.31it/s]
4it [00:01,  2.52it/s]

[64/250] Loss: 3.5509286522865295 	 Acc: 31.4
EarlyStopping counter: 1 out of 10



282it [00:53,  5.31it/s]
4it [00:01,  2.55it/s]

[65/250] Loss: 3.4899246096611023 	 Acc: 32.2
EarlyStopping counter: 2 out of 10



282it [00:53,  5.31it/s]
4it [00:01,  2.56it/s]

[66/250] Loss: 3.522518217563629 	 Acc: 32.8
EarlyStopping counter: 3 out of 10



282it [00:53,  5.30it/s]
4it [00:01,  2.56it/s]


[67/250] Loss: 3.4503977298736572 	 Acc: 33.8
Validation loss decreased (-33.600000 --> -33.800000).  Saving model ...


282it [00:53,  5.30it/s]
4it [00:01,  2.51it/s]

[68/250] Loss: 3.4587578177452087 	 Acc: 32.2
EarlyStopping counter: 1 out of 10



282it [00:53,  5.30it/s]
4it [00:01,  2.53it/s]

[69/250] Loss: 3.5794461965560913 	 Acc: 30.4
EarlyStopping counter: 2 out of 10



282it [00:53,  5.30it/s]
4it [00:01,  2.53it/s]

[70/250] Loss: 3.5695130825042725 	 Acc: 32.4
EarlyStopping counter: 3 out of 10



282it [00:53,  5.31it/s]
4it [00:01,  2.60it/s]

[71/250] Loss: 3.4550111889839172 	 Acc: 33.6
EarlyStopping counter: 4 out of 10



282it [00:53,  5.31it/s]
4it [00:01,  2.58it/s]

[72/250] Loss: 3.4877008199691772 	 Acc: 32.2
EarlyStopping counter: 5 out of 10



282it [00:53,  5.31it/s]
4it [00:01,  2.59it/s]


[73/250] Loss: 3.5596903562545776 	 Acc: 35.2
Validation loss decreased (-33.800000 --> -35.200000).  Saving model ...


282it [00:53,  5.31it/s]
4it [00:01,  2.57it/s]

[74/250] Loss: 3.715454876422882 	 Acc: 29.0
EarlyStopping counter: 1 out of 10



282it [00:53,  5.30it/s]
4it [00:01,  2.61it/s]


[75/250] Loss: 3.5673781633377075 	 Acc: 35.4
Validation loss decreased (-35.200000 --> -35.400000).  Saving model ...


282it [00:53,  5.31it/s]
4it [00:01,  2.55it/s]

[76/250] Loss: 3.5713347792625427 	 Acc: 33.2
EarlyStopping counter: 1 out of 10



282it [00:53,  5.31it/s]
4it [00:01,  2.50it/s]

[77/250] Loss: 3.5321478247642517 	 Acc: 34.0
EarlyStopping counter: 2 out of 10



282it [00:53,  5.30it/s]
4it [00:01,  2.54it/s]

[78/250] Loss: 3.4876806139945984 	 Acc: 35.2
EarlyStopping counter: 3 out of 10



282it [00:53,  5.30it/s]
4it [00:01,  2.58it/s]

[79/250] Loss: 3.7247402667999268 	 Acc: 32.2
EarlyStopping counter: 4 out of 10



282it [00:53,  5.30it/s]
4it [00:01,  2.58it/s]


[80/250] Loss: 3.4828875064849854 	 Acc: 36.4
Validation loss decreased (-35.400000 --> -36.400000).  Saving model ...


282it [00:53,  5.31it/s]
4it [00:01,  2.53it/s]

[81/250] Loss: 3.4792022109031677 	 Acc: 35.4
EarlyStopping counter: 1 out of 10



282it [00:53,  5.30it/s]
4it [00:01,  2.57it/s]

[82/250] Loss: 3.508804440498352 	 Acc: 35.0
EarlyStopping counter: 2 out of 10



282it [00:53,  5.30it/s]
4it [00:01,  2.56it/s]

[83/250] Loss: 3.519869029521942 	 Acc: 33.8
EarlyStopping counter: 3 out of 10



282it [00:53,  5.31it/s]
4it [00:01,  2.51it/s]

[84/250] Loss: 3.5421931743621826 	 Acc: 31.8
EarlyStopping counter: 4 out of 10



282it [00:53,  5.30it/s]
4it [00:01,  2.53it/s]

[85/250] Loss: 3.5046932101249695 	 Acc: 36.0
EarlyStopping counter: 5 out of 10



282it [00:53,  5.30it/s]
4it [00:01,  2.51it/s]

[86/250] Loss: 3.559581220149994 	 Acc: 33.2
EarlyStopping counter: 6 out of 10



282it [00:53,  5.31it/s]
4it [00:01,  2.52it/s]

[87/250] Loss: 3.570066809654236 	 Acc: 33.8
EarlyStopping counter: 7 out of 10



282it [00:53,  5.30it/s]
4it [00:01,  2.52it/s]

[88/250] Loss: 3.520843803882599 	 Acc: 34.2
EarlyStopping counter: 8 out of 10



282it [00:53,  5.31it/s]
4it [00:01,  2.49it/s]

[89/250] Loss: 3.5724109411239624 	 Acc: 35.2
EarlyStopping counter: 9 out of 10



282it [00:53,  5.31it/s]
4it [00:01,  2.58it/s]

[90/250] Loss: 3.5846240520477295 	 Acc: 34.8
EarlyStopping counter: 10 out of 10
Early stopping
Training finished !!





In [25]:
class medical_dataset_test(Dataset):
    def __init__(self, config=None, answer_map=None, image_path='./dataset/Task1-VQA-2021-TestSet-w-GroundTruth/VQA-500-Images/', qa_file="./dataset/Task1-VQA-2021-TestSet-w-GroundTruth/Task1-VQA-2021-TestSet-Questions.txt", train=True):
            
        assert answer_map!=None
        assert config!=None
        
        self.image_path = image_path
        self.qa_file = qa_file
        self.config = config
        self.train = train
        
        self.answer_map = answer_map
        self.data = pd.read_csv(qa_file, sep='|', names=['imageid', 'question'])
           
        # print(Counter(self.data.answer.tolist()))
        
        self.transforms = transforms.Compose([
                                transforms.RandomResizedCrop(224),
                                transforms.ToTensor(),
                                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                            ])
        self.tokenizer = tokenizer = AutoTokenizer.from_pretrained("dmis-lab/biobert-v1.1")

    def process_data(self, text, max_len):
        text = str(text)

        inputs = self.tokenizer.encode_plus(
            text,
            None,
            add_special_tokens=True,
            max_length=max_len,
            truncation=True
        )
        ids = inputs["input_ids"]
        mask = inputs["attention_mask"]
        
        padding_length = max_len - len(ids)
        
        ids = ids + ([0] * padding_length)
        mask = mask + ([0] * padding_length)
        
        return {
            'ids': torch.tensor(ids, dtype=torch.long),
            'mask': torch.tensor(mask, dtype=torch.long),
        }
    
    
    def __getitem__(self, index):
        
        question = self.data.question[index]
        image_idx = self.data.imageid[index]
        
        ## add visual feature related steps
        visuals = Image.open(join(self.image_path, f'{image_idx}.jpg'))
        visuals = self.transforms(visuals)
        
        
        tmp = self.process_data(question, self.config['max_len'])
        question_tokens = {
            'ids': tmp['ids'],
            'mask': tmp['mask'],
        }
        
        return visuals, question_tokens, image_idx

    def __len__(self):
        return len(self.data)

In [26]:
TestData = medical_dataset_test(config=cfg, answer_map=answer_map)
TestDataLoader = DataLoader(TestData, batch_size=128, shuffle=False, num_workers=4)  # num_workers=0 for windows OS

In [27]:
inv_map = {v: k for k, v in answer_map.items()}

In [28]:
trainer.model.load_state_dict(torch.load('./models/resnet18_fusion-elementsum/checkpoint.pt'))
trainer.model.eval()
        
imageids = []
answers = []

running_loss = 0.0
correct = 0
total = 0
with torch.no_grad():
    for en, (visuals, question, target) in tqdm(enumerate(TestDataLoader)):
        visuals = visuals.to(trainer.device)

        ids = question['ids'].to(trainer.device)
        mask = question['mask'].to(trainer.device)

        outputs = trainer.model(visuals, ids, mask)

        y_pred_softmax = torch.log_softmax(outputs, dim = 1)
        _, y_pred_tags = torch.max(y_pred_softmax, dim = 1)
        # self.tmp_sv_ = y_pred_tags
        
        for i in range(visuals.shape[0]):
            imageids.append(target[i])
            answers.append(inv_map[int(y_pred_tags[i])])


4it [00:01,  2.36it/s]


In [29]:
answers[:5]

['spine, intradural lipoma',
 'pulmonary embolism',
 'calcific tendinitis',
 'transitional cell carcinoma, bladder',
 'simple bone cyst']

In [30]:
pd.DataFrame({'imageids':imageids, 'answers':answers}).to_csv('resnet18_fusion-elementsum.txt', sep='|', index=False, header=False)