In [1]:
import os
# Turn off the Hugging Face tokenizers library's internal parallelism.
# NOTE(review): presumably set to avoid fork-related warnings/deadlocks with the
# multi-worker DataLoaders created later in this notebook — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import warnings
# NOTE(review): this silences every warning for the rest of the session, including
# deprecation warnings from torch/transformers; consider narrowing the filter.
warnings.filterwarnings("ignore")

In [3]:
import numpy as np
import sys
from os import listdir, makedirs, getcwd, remove
from os.path import isfile, join, abspath, exists, isdir, expanduser
from scipy.io import loadmat
from PIL import Image
import pandas as pd

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transforms


import pickle
from tqdm import tqdm
import json
import random
from collections import Counter

from transformers import get_linear_schedule_with_warmup
from transformers import AutoTokenizer, AutoModel

In [4]:
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.autograd as autograd
from torch.utils.data import Dataset, DataLoader
import torch.utils.data as data
import torch.nn.functional as F
import torch.optim as optim
import itertools

import torchvision
from torchvision import transforms, datasets, models
from torch import Tensor

from torch.utils.tensorboard import SummaryWriter

In [5]:
import transformers
import tokenizers
from transformers import BertTokenizer, BertModel

In [19]:
# Hyper-parameters shared by the datasets and the Trainer.
cfg = {
    'max_len': 128,  # every question is truncated/padded to this many token ids
    'lr': 2e-5,  # learning rate handed to the RAdam optimizer in Trainer
    'warmup_steps': 5,  # Trainer multiplies this by len(trainloader), so it acts as a number of warmup *epochs*
    'epochs': 250  # used for the LR schedule length and passed to perform_training
}

In [7]:
class medical_dataset(Dataset):
    """VQA-Med question/answer pairs as (tokenized question, class index) samples.

    The QA file is a '|'-separated text file with the columns
    ``imageid | question | answer`` and no header row.

    Note:
        Only the question text is used. ``image_path`` and ``self.transforms``
        are stored but never applied in ``__getitem__`` (no image is loaded),
        and the ``train`` flag does not currently change any behavior.

    Args:
        config: Dict that provides ``'max_len'`` (token length of every sample).
        answer_map: Mapping from answer string to integer class index.
        image_path: Directory holding the images (stored only).
        qa_file: Path of the '|'-separated QA pairs file.
        train: Stored on the instance as ``self.train``.
    """

    def __init__(
        self,
        config=None,
        answer_map=None,
        image_path='./dataset/VQA-Med-2020-Task1-VQAnswering-TrainVal-Sets/VQAMed2020-VQAnswering-TrainingSet/VQAnswering_2020_Train_images/',
        qa_file="./dataset/VQA-Med-2020-Task1-VQAnswering-TrainVal-Sets/VQAMed2020-VQAnswering-TrainingSet/VQAnswering_2020_Train_QA_pairs.txt",
        train=True,
    ):
        assert answer_map is not None, "answer_map is required"
        assert config is not None, "config is required"

        self.image_path = image_path
        self.qa_file = qa_file
        self.config = config
        self.train = train

        self.answer_map = answer_map
        self.data = pd.read_csv(qa_file, sep='|', names=['imageid', 'question', 'answer'])

        # Image pipeline kept for parity with the original class; it is not
        # used by __getitem__.
        self.transforms = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
        self.tokenizer = AutoTokenizer.from_pretrained("dmis-lab/biobert-v1.1")

    def process_data(self, text, max_len):
        """Tokenize ``text`` and right-pad ids/mask to exactly ``max_len`` entries.

        Args:
            text: Question text; converted with ``str()`` first.
            max_len: Target sequence length (longer inputs are truncated).

        Returns:
            Dict with ``'ids'`` and ``'mask'`` as 1-D ``torch.long`` tensors of
            length ``max_len``.
        """
        text = str(text)

        inputs = self.tokenizer.encode_plus(
            text,
            None,
            add_special_tokens=True,
            max_length=max_len,
            truncation=True
        )
        ids = list(inputs["input_ids"])
        mask = list(inputs["attention_mask"])

        # Use the tokenizer's own pad id; fall back to 0 (the previously
        # hard-coded value) when the tokenizer does not define one.
        pad_id = getattr(self.tokenizer, 'pad_token_id', None)
        if pad_id is None:
            pad_id = 0

        padding_length = max_len - len(ids)
        ids = ids + [pad_id] * padding_length
        mask = mask + [0] * padding_length

        return {
            'ids': torch.tensor(ids, dtype=torch.long),
            'mask': torch.tensor(mask, dtype=torch.long),
        }

    def __getitem__(self, index):
        """Return ``(question_tokens, target)`` for the sample at position ``index``.

        ``question_tokens`` is the dict produced by ``process_data`` and
        ``target`` is a ``torch.long`` tensor of shape ``(1,)`` holding the
        answer's class index.

        Raises:
            KeyError: If the sample's answer is missing from ``answer_map``.
        """
        # Positional access: samplers hand out positions 0..len-1, which only
        # coincide with index labels for a default RangeIndex.
        question = self.data['question'].iloc[index]
        answer = self.data['answer'].iloc[index]

        target = torch.tensor([self.answer_map[answer]], dtype=torch.long)
        question_tokens = self.process_data(question, self.config['max_len'])

        return question_tokens, target

    def __len__(self):
        """Number of QA pairs in the file."""
        return len(self.data)

In [8]:
# file_path = './dataset/VQA-Med-2020-Task1-VQAnswering-TrainVal-Sets/VQAMed2020-VQAnswering-TrainingSet/VQAnswering_2020_Train_QA_pairs.txt'
# data = pd.read_csv(file_path, sep='|', names=['imageid', 'question', 'answer'])
# data.head()

In [9]:
# answer_map = {}
# ct = 0

# for i in data.answer.unique():
#     answer_map[i] = ct
#     ct+=1

In [10]:
# with open('./answer_map.pickle', 'wb') as handle:
#     pickle.dump(answer_map, handle, protocol=pickle.HIGHEST_PROTOCOL)
    

In [11]:
# Load the answer-string -> class-index mapping from disk (the cells that
# created this pickle are commented out above).
# NOTE(review): pickle.load can execute arbitrary code; only load a trusted file.
with open('./answer_map.pickle', 'rb') as handle:
    answer_map = pickle.load(handle)
    
# Number of distinct answers, i.e. the number of output classes.
len(answer_map)

332

In [12]:
# Build the training and validation datasets; both share cfg and the same
# answer_map, so class indices are consistent between the two splits.
TrainData = medical_dataset(config=cfg, answer_map=answer_map, image_path='./dataset/VQA-Med-2020-Task1-VQAnswering-TrainVal-Sets/VQAMed2020-VQAnswering-TrainingSet/VQAnswering_2020_Train_images/', qa_file='./dataset/VQA-Med-2020-Task1-VQAnswering-TrainVal-Sets/VQAMed2020-VQAnswering-TrainingSet/VQAnswering_2020_Train_QA_pairs.txt')
ValData = medical_dataset(config=cfg, answer_map=answer_map, image_path='./dataset/VQA-Med-2020-Task1-VQAnswering-TrainVal-Sets/VQAMed2020-VQAnswering-ValidationSet/VQAnswering_2020_Val_images/', qa_file='./dataset/VQA-Med-2020-Task1-VQAnswering-TrainVal-Sets/VQAMed2020-VQAnswering-ValidationSet/VQAnswering_2020_Val_QA_Pairs.txt', train=False)


In [13]:
%%time
# Smoke test: fetch (and time) the first validation sample.
# a is the question-token dict ('ids', 'mask'); b is the target tensor.
a,b = ValData.__getitem__(0)

CPU times: user 0 ns, sys: 1.35 ms, total: 1.35 ms
Wall time: 1.21 ms


In [14]:
# Training batches are shuffled; validation keeps file order and uses a larger
# batch size.
TrainDataLoader = DataLoader(TrainData, batch_size=16, shuffle=True, num_workers=4)  # num_workers=0 for windows OS
ValDataLoader = DataLoader(ValData, batch_size=128, shuffle=False, num_workers=4)  # num_workers=0 for windows OS

In [15]:
class mednet(nn.Module):
    """BioBERT encoder topped with a small two-layer classification head.

    Args:
        config: Accepted by the constructor but not read.
        max_labels: Number of output classes (size of the final layer).
    """

    def __init__(self, config, max_labels):
        super().__init__()

        # Pretrained text encoder; its hidden states feed the head below.
        self.bert = AutoModel.from_pretrained("dmis-lab/biobert-v1.1")

        # Classification head: 768 -> 128 -> max_labels.
        self.fc1 = nn.Linear(768, 128)
        self.fc2 = nn.Linear(128, max_labels)

    def forward(self, ids=None, mask=None):
        """Return class logits for a batch of token ids and attention masks."""
        encoder_out = self.bert(ids, mask)
        # Representation of the first token of every sequence.
        first_token = encoder_out.last_hidden_state[:, 0]
        hidden = F.relu(self.fc1(first_token))
        return self.fc2(hidden)


In [16]:
class EarlyStopping:
    """Tracks validation loss, checkpoints on improvement and flags early stop.

    Call the instance once per epoch with the validation loss and the model.
    When the loss improves, the model's ``state_dict`` is written to ``path``;
    otherwise an internal counter is incremented and ``early_stop`` becomes
    ``True`` once it reaches ``patience``.
    """

    def __init__(self,
                 patience=7,
                 verbose=False,
                 delta=0,
                 path='./models/checkpoint.pt',
                 trace_func=print):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): Path for the checkpoint to be saved to.
                            Default: './models/checkpoint.pt'
            trace_func (function): trace print function.
                            Default: print
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # np.Inf was removed in NumPy 2.0; np.inf works on every version.
        self.val_loss_min = np.inf
        self.delta = delta
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):
        """Record one validation result.

        Args:
            val_loss: Validation loss of the current epoch (lower is better).
            model: Model whose ``state_dict`` is saved on improvement.
        """
        score = -val_loss

        if np.isnan(score):
            # A NaN loss compares False against everything and would otherwise
            # be mistaken for an improvement, overwriting the best checkpoint.
            improved = False
        elif self.best_score is None:
            improved = True
        else:
            improved = score >= self.best_score + self.delta

        if improved:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0
        else:
            self.counter += 1
            self.trace_func(
                f'EarlyStopping counter: {self.counter} out of {self.patience}'
            )
            if self.counter >= self.patience:
                self.early_stop = True

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            self.trace_func(
                f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...'
            )
        # torch.save does not create missing parent directories.
        checkpoint_dir = os.path.dirname(self.path)
        if checkpoint_dir:
            os.makedirs(checkpoint_dir, exist_ok=True)
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

In [17]:
class Trainer:
    """Fine-tunes a question classifier and reports validation loss/accuracy.

    Wraps ``model_ft`` with an RAdam optimizer, a linear warmup/decay learning
    rate schedule, a cross-entropy criterion and an ``EarlyStopping`` helper
    that checkpoints the model whenever the validation loss improves.

    Args:
        trainloader: Loader yielding ``(question, target)`` batches, where
            ``question`` is a dict with ``'ids'`` and ``'mask'`` tensors.
        vallaoder: Validation loader with the same batch structure (the
            misspelled name is kept so existing keyword callers keep working).
        model_ft: Model called as ``model(ids, mask)`` that returns class logits.
        writer: Optional object exposing ``add_scalar`` (e.g. ``SummaryWriter``).
        testloader: Stored on the instance; not used by this class.
        checkpoint_path: Optional checkpoint file handed to ``EarlyStopping``;
            when ``None`` the helper's default path is used.
        patience: Patience forwarded to ``EarlyStopping``.
        feature_extract: Accepted for backward compatibility; unused.
        print_itr: Stored on the instance; not used by this class.
        config: Dict providing ``'lr'``, ``'warmup_steps'`` and ``'epochs'``.

    Raises:
        ValueError: If ``config`` is ``None``.
    """

    def __init__(self,
                 trainloader,
                 vallaoder,
                 model_ft,
                 writer=None,
                 testloader=None,
                 checkpoint_path=None,
                 patience=10,
                 feature_extract=True,
                 print_itr=50,
                 config=None):
        if config is None:
            raise ValueError(
                "config is required and must provide 'lr', 'warmup_steps' and 'epochs'")

        self.trainloader = trainloader
        self.valloader = vallaoder
        self.testloader = testloader

        self.config = config

        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

        print("==" * 10)
        print("Training will be done on ", self.device)
        print("==" * 10)

        self.model = model_ft
        if torch.cuda.device_count() > 1:
            print("Let's use", torch.cuda.device_count(), "GPUs!")
            # No explicit device_ids: DataParallel then uses every visible GPU,
            # which matches the message printed above.
            self.model = nn.DataParallel(self.model)

        self.model = self.model.to(self.device)

        # All parameters (encoder and head) are optimized.
        self.optimizer = optim.RAdam(self.model.parameters(), lr=self.config['lr'])

        # The schedule is stepped once per batch, so both lengths are expressed
        # in batches: 'warmup_steps' and 'epochs' are multiplied by the number
        # of batches per epoch.
        steps_per_epoch = len(self.trainloader)
        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=steps_per_epoch * self.config['warmup_steps'],
            num_training_steps=steps_per_epoch * self.config['epochs'])

        self.criterion = nn.CrossEntropyLoss()

        stopper_kwargs = {'patience': patience, 'verbose': True}
        if checkpoint_path is not None:
            stopper_kwargs['path'] = checkpoint_path
        self.early_stopping = EarlyStopping(**stopper_kwargs)

        self.writer = writer
        self.print_itr = print_itr

    def _forward_batch(self, question, target):
        """Move one batch to the device and return ``(logits, labels)``."""
        # reshape(-1) instead of squeeze(): squeeze() turns a batch of one
        # sample into a 0-dim tensor, which no longer matches the logits.
        y = target.reshape(-1).to(self.device)
        ids = question['ids'].to(self.device)
        mask = question['mask'].to(self.device)
        return self.model(ids, mask), y

    def train(self, ep):
        """Run one training epoch.

        Args:
            ep: Epoch number; only used to compute the global step under which
                the per-batch loss is logged to ``writer``.
        """
        self.model.train()
        steps_per_epoch = len(self.trainloader)

        # Wrapping the loader (not the enumerate) lets tqdm know the total.
        for en, (question, target) in enumerate(tqdm(self.trainloader)):
            self.optimizer.zero_grad()

            outputs, y = self._forward_batch(question, target)
            loss = self.criterion(outputs, y)
            loss.backward()
            self.optimizer.step()
            self.scheduler.step()

            if self.writer:
                self.writer.add_scalar('Train Loss', loss.item(), ep * steps_per_epoch + en)

    def validate(self, ep):
        """Evaluate on the validation loader.

        Args:
            ep: Epoch number (unused; kept for interface compatibility).

        Returns:
            Tuple ``(mean_loss, accuracy_percent)`` where ``mean_loss`` is the
            average of the per-batch losses and ``accuracy_percent`` is the
            share of correctly classified samples in percent.
        """
        self.model.eval()

        running_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for question, target in tqdm(self.valloader):
                outputs, y = self._forward_batch(question, target)
                running_loss += self.criterion(outputs, y).item()

                # argmax of the logits equals argmax of their log-softmax.
                y_pred_tags = outputs.argmax(dim=1)
                correct += (y_pred_tags == y).sum().item()
                total += y.shape[0]

        return running_loss / len(self.valloader), correct * 100 / total

    def perform_training(self, total_epoch, stop_early=False):
        """Validate once, then train and validate for up to ``total_epoch`` epochs.

        After every epoch the validation loss is passed to ``EarlyStopping``,
        which checkpoints the model on improvement.

        Args:
            total_epoch: Maximum number of epochs to run.
            stop_early: When ``True``, stop as soon as ``EarlyStopping`` reports
                that patience is exhausted. Defaults to ``False``, which keeps
                the previous behavior of always running every epoch.
        """
        val_loss, acc = self.validate(0)

        print("[Initial Validation results] Loss: {} \t Acc: {}".format(
            val_loss, acc))

        for i in range(total_epoch):
            self.train(i + 1)
            val_loss, acc = self.validate(i + 1)
            print('[{}/{}] Loss: {} \t Acc: {}'.format(i+1, total_epoch, val_loss, acc))

            if self.writer:
                self.writer.add_scalar('Validation Loss', val_loss, (i + 1))
                self.writer.add_scalar('Validation Acc', acc, (i + 1))

            self.early_stopping(val_loss, self.model)

            if stop_early and self.early_stopping.early_stop:
                print("Early stopping")
                break

        print("=" * 20)
        print("Training finished !!")
        print("=" * 20)


In [18]:
# One output logit per entry of answer_map.
model_ft = mednet(config=cfg, max_labels=len(answer_map))
# TensorBoard event files go to runs/biobert_run_1.
writer = SummaryWriter('runs/biobert_run_1')
# checkpoint_path and patience are left at the Trainer defaults.
trainer = Trainer(TrainDataLoader, ValDataLoader, model_ft, writer=writer, config=cfg)

Training will be done on  cuda
Let's use 2 GPUs!


In [20]:
# NOTE(review): the log recorded under this cell reports "/35" epochs and
# "out of 5" patience, which does not match cfg['epochs'] = 250 and the Trainer
# default patience of 10 in the current cells — the output looks stale; re-run
# to refresh it.
trainer.perform_training(cfg['epochs'])

4it [00:03,  1.25it/s]

[Initial Validation results] Loss: 5.810192346572876 	 Acc: 0.0



250it [00:40,  6.22it/s]
4it [00:00,  6.69it/s]


[1/35] Loss: 5.803069353103638 	 Acc: 0.0
Validation loss decreased (inf --> 5.803069).  Saving model ...


250it [00:40,  6.19it/s]
4it [00:00,  6.51it/s]


[2/35] Loss: 5.77190375328064 	 Acc: 0.8
Validation loss decreased (5.803069 --> 5.771904).  Saving model ...


250it [00:40,  6.16it/s]
4it [00:00,  6.73it/s]


[3/35] Loss: 5.719835162162781 	 Acc: 1.2
Validation loss decreased (5.771904 --> 5.719835).  Saving model ...


250it [00:40,  6.19it/s]
4it [00:00,  6.75it/s]


[4/35] Loss: 5.690270900726318 	 Acc: 2.6
Validation loss decreased (5.719835 --> 5.690271).  Saving model ...


250it [00:40,  6.17it/s]
4it [00:00,  6.75it/s]


[5/35] Loss: 5.58155632019043 	 Acc: 4.0
Validation loss decreased (5.690271 --> 5.581556).  Saving model ...


250it [00:40,  6.16it/s]
4it [00:00,  6.78it/s]


[6/35] Loss: 5.491603851318359 	 Acc: 5.6
Validation loss decreased (5.581556 --> 5.491604).  Saving model ...


250it [00:40,  6.19it/s]
4it [00:00,  6.35it/s]


[7/35] Loss: 5.432030916213989 	 Acc: 5.2
Validation loss decreased (5.491604 --> 5.432031).  Saving model ...


250it [00:40,  6.18it/s]
4it [00:00,  7.16it/s]


[8/35] Loss: 5.374152421951294 	 Acc: 5.8
Validation loss decreased (5.432031 --> 5.374152).  Saving model ...


250it [00:40,  6.16it/s]
4it [00:00,  6.36it/s]


[9/35] Loss: 5.33577287197113 	 Acc: 5.4
Validation loss decreased (5.374152 --> 5.335773).  Saving model ...


250it [00:40,  6.18it/s]
4it [00:00,  6.44it/s]


[10/35] Loss: 5.315036416053772 	 Acc: 3.6
Validation loss decreased (5.335773 --> 5.315036).  Saving model ...


250it [00:40,  6.19it/s]
4it [00:00,  6.82it/s]


[11/35] Loss: 5.283109426498413 	 Acc: 5.6
Validation loss decreased (5.315036 --> 5.283109).  Saving model ...


250it [00:40,  6.15it/s]
4it [00:00,  6.94it/s]


[12/35] Loss: 5.2554720640182495 	 Acc: 5.6
Validation loss decreased (5.283109 --> 5.255472).  Saving model ...


250it [00:40,  6.17it/s]
4it [00:00,  6.85it/s]


[13/35] Loss: 5.243133306503296 	 Acc: 5.0
Validation loss decreased (5.255472 --> 5.243133).  Saving model ...


250it [00:40,  6.19it/s]
4it [00:00,  6.75it/s]


[14/35] Loss: 5.232126712799072 	 Acc: 5.6
Validation loss decreased (5.243133 --> 5.232127).  Saving model ...


250it [00:40,  6.17it/s]
4it [00:00,  6.83it/s]


[15/35] Loss: 5.221943020820618 	 Acc: 4.2
Validation loss decreased (5.232127 --> 5.221943).  Saving model ...


250it [00:40,  6.15it/s]
4it [00:00,  7.06it/s]


[16/35] Loss: 5.212420582771301 	 Acc: 4.8
Validation loss decreased (5.221943 --> 5.212421).  Saving model ...


250it [00:40,  6.19it/s]
4it [00:00,  6.56it/s]

[17/35] Loss: 5.217351794242859 	 Acc: 4.4
EarlyStopping counter: 1 out of 5



250it [00:40,  6.17it/s]
4it [00:00,  6.66it/s]


[18/35] Loss: 5.198567986488342 	 Acc: 5.2
Validation loss decreased (5.212421 --> 5.198568).  Saving model ...


250it [00:40,  6.15it/s]
4it [00:00,  6.61it/s]

[19/35] Loss: 5.201871275901794 	 Acc: 4.4
EarlyStopping counter: 1 out of 5



250it [00:40,  6.18it/s]
4it [00:00,  6.56it/s]

[20/35] Loss: 5.199782013893127 	 Acc: 4.8
EarlyStopping counter: 2 out of 5



250it [00:40,  6.18it/s]
4it [00:00,  6.58it/s]


[21/35] Loss: 5.194813370704651 	 Acc: 4.8
Validation loss decreased (5.198568 --> 5.194813).  Saving model ...


250it [00:40,  6.16it/s]
4it [00:00,  6.53it/s]

[22/35] Loss: 5.199813485145569 	 Acc: 4.6
EarlyStopping counter: 1 out of 5



250it [00:40,  6.17it/s]
4it [00:00,  6.67it/s]


[23/35] Loss: 5.190837860107422 	 Acc: 4.8
Validation loss decreased (5.194813 --> 5.190838).  Saving model ...


250it [00:40,  6.18it/s]
4it [00:00,  6.99it/s]


[24/35] Loss: 5.186531186103821 	 Acc: 4.8
Validation loss decreased (5.190838 --> 5.186531).  Saving model ...


250it [00:40,  6.16it/s]
4it [00:00,  6.53it/s]

[25/35] Loss: 5.19906222820282 	 Acc: 4.2
EarlyStopping counter: 1 out of 5



250it [00:40,  6.16it/s]
4it [00:00,  6.37it/s]

[26/35] Loss: 5.189875364303589 	 Acc: 4.2
EarlyStopping counter: 2 out of 5



250it [00:40,  6.17it/s]
4it [00:00,  6.43it/s]

[27/35] Loss: 5.18710470199585 	 Acc: 5.0
EarlyStopping counter: 3 out of 5



250it [00:40,  6.17it/s]
4it [00:00,  6.86it/s]

[28/35] Loss: 5.201194405555725 	 Acc: 4.0
EarlyStopping counter: 4 out of 5



250it [00:40,  6.16it/s]
4it [00:00,  6.95it/s]

[29/35] Loss: 5.186747789382935 	 Acc: 5.4
EarlyStopping counter: 5 out of 5



250it [00:40,  6.17it/s]
4it [00:00,  6.54it/s]

[30/35] Loss: 5.193841814994812 	 Acc: 5.0
EarlyStopping counter: 6 out of 5



250it [00:40,  6.17it/s]
4it [00:00,  6.89it/s]

[31/35] Loss: 5.192580699920654 	 Acc: 5.0
EarlyStopping counter: 7 out of 5



250it [00:40,  6.14it/s]
4it [00:00,  6.41it/s]

[32/35] Loss: 5.191875219345093 	 Acc: 4.8
EarlyStopping counter: 8 out of 5



250it [00:40,  6.17it/s]
4it [00:00,  6.39it/s]

[33/35] Loss: 5.1894800662994385 	 Acc: 4.2
EarlyStopping counter: 9 out of 5



250it [00:40,  6.18it/s]
4it [00:00,  6.51it/s]

[34/35] Loss: 5.200999855995178 	 Acc: 4.0
EarlyStopping counter: 10 out of 5



250it [00:40,  6.14it/s]
4it [00:00,  6.90it/s]

[35/35] Loss: 5.19563364982605 	 Acc: 5.0
EarlyStopping counter: 11 out of 5
Training finished !!





In [21]:
class medical_dataset_test(Dataset):
    """Unlabelled VQA-Med test questions as (tokenized question, image id) samples.

    The questions file is a '|'-separated text file with the columns
    ``imageid | question`` and no header row.

    Note:
        Only the question text is used. ``image_path`` and ``self.transforms``
        are stored but never applied in ``__getitem__`` (no image is loaded).
        ``answer_map`` is required and stored, but not read by this class.

    Args:
        config: Dict that provides ``'max_len'`` (token length of every sample).
        answer_map: Mapping from answer string to integer class index.
        image_path: Directory holding the images (stored only).
        qa_file: Path of the '|'-separated questions file.
        train: Stored on the instance as ``self.train``.
    """

    def __init__(
        self,
        config=None,
        answer_map=None,
        image_path='./dataset/Task1-VQA-2021-TestSet-w-GroundTruth/VQA-500-Images/',
        qa_file="./dataset/Task1-VQA-2021-TestSet-w-GroundTruth/Task1-VQA-2021-TestSet-Questions.txt",
        train=True,
    ):
        assert answer_map is not None, "answer_map is required"
        assert config is not None, "config is required"

        self.image_path = image_path
        self.qa_file = qa_file
        self.config = config
        self.train = train

        self.answer_map = answer_map
        self.data = pd.read_csv(qa_file, sep='|', names=['imageid', 'question'])

        # Image pipeline kept for parity with the original class; it is not
        # used by __getitem__.
        self.transforms = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
        self.tokenizer = AutoTokenizer.from_pretrained("dmis-lab/biobert-v1.1")

    def process_data(self, text, max_len):
        """Tokenize ``text`` and right-pad ids/mask to exactly ``max_len`` entries.

        Args:
            text: Question text; converted with ``str()`` first.
            max_len: Target sequence length (longer inputs are truncated).

        Returns:
            Dict with ``'ids'`` and ``'mask'`` as 1-D ``torch.long`` tensors of
            length ``max_len``.
        """
        text = str(text)

        inputs = self.tokenizer.encode_plus(
            text,
            None,
            add_special_tokens=True,
            max_length=max_len,
            truncation=True
        )
        ids = list(inputs["input_ids"])
        mask = list(inputs["attention_mask"])

        # Use the tokenizer's own pad id; fall back to 0 (the previously
        # hard-coded value) when the tokenizer does not define one.
        pad_id = getattr(self.tokenizer, 'pad_token_id', None)
        if pad_id is None:
            pad_id = 0

        padding_length = max_len - len(ids)
        ids = ids + [pad_id] * padding_length
        mask = mask + [0] * padding_length

        return {
            'ids': torch.tensor(ids, dtype=torch.long),
            'mask': torch.tensor(mask, dtype=torch.long),
        }

    def __getitem__(self, index):
        """Return ``(question_tokens, image_id)`` for the sample at position ``index``.

        ``question_tokens`` is the dict produced by ``process_data``;
        ``image_id`` is the raw value of the ``imageid`` column.
        """
        # Positional access: samplers hand out positions 0..len-1, which only
        # coincide with index labels for a default RangeIndex.
        question = self.data['question'].iloc[index]
        image_idx = self.data['imageid'].iloc[index]

        question_tokens = self.process_data(question, self.config['max_len'])

        return question_tokens, image_idx

    def __len__(self):
        """Number of questions in the file."""
        return len(self.data)

In [22]:
# Test split uses the default paths of medical_dataset_test; order is preserved
# (shuffle=False) so predictions line up with the questions file.
TestData = medical_dataset_test(config=cfg, answer_map=answer_map)
TestDataLoader = DataLoader(TestData, batch_size=128, shuffle=False, num_workers=4)  # num_workers=0 for windows OS

In [23]:
# Reverse lookup: class index -> answer string, used to decode predictions.
inv_map = {v: k for k, v in answer_map.items()}

In [25]:
# Restore the checkpoint written by EarlyStopping (its default path) and predict
# an answer for every test question.
# map_location lets the checkpoint load on whatever device the trainer uses,
# e.g. a CPU-only session loading weights saved from a GPU run.
state_dict = torch.load('./models/checkpoint.pt', map_location=trainer.device)
trainer.model.load_state_dict(state_dict)
trainer.model.eval()

imageids = []
answers = []

with torch.no_grad():
    # The test dataset yields (question_tokens, image_id): the second element is
    # the image id, not a label.
    for question, batch_image_ids in tqdm(TestDataLoader):

        ids = question['ids'].to(trainer.device)
        mask = question['mask'].to(trainer.device)

        outputs = trainer.model(ids, mask)

        # argmax of the logits equals argmax of their log-softmax.
        y_pred_tags = outputs.argmax(dim=1)

        imageids.extend(batch_image_ids)
        answers.extend(inv_map[int(tag)] for tag in y_pred_tags)


4it [00:00,  6.31it/s]


In [26]:
# Peek at the first five decoded predictions.
answers[:5]

['pulmonary embolism',
 'acute appendicitis',
 'acute appendicitis',
 'pulmonary embolism',
 'osteomyelitis']

In [27]:
# Write the predictions as "imageid|answer" lines, without header or index column.
pd.DataFrame({'imageids':imageids, 'answers':answers}).to_csv('biobert.txt', sep='|', index=False, header=False)