In [1]:
%load_ext autotime

import os
import sys
import math
import copy
import random
import warnings
import pandas as pd
import numpy as np

random.seed(1)
warnings.filterwarnings('ignore')
from tqdm import tqdm, trange

In [2]:
os.environ['CUDA_VISIBLE_DEVICES'] = '4, 5, 6, 7'

time: 320 µs


In [3]:
import logging

log_format = '%(levelname)s %(asctime)s - %(message)s'
logging.basicConfig(filename = '../logs/original_fsl.logs',
                    level = logging.INFO,
                    format = log_format,
                    filemode = 'w')
logger = logging.getLogger()

time: 701 µs


In [4]:
train_data = pd.read_csv('../data/train_data.csv')
print (train_data.shape)
print (train_data.columns)

(368945, 19)
Index(['subject_id', 'image_path', 'image_name', 'study_id', 'split',
       'Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema',
       'Enlarged Cardiomediastinum', 'Fracture', 'Lung Lesion', 'Lung Opacity',
       'No Finding', 'Pleural Effusion', 'Pleural Other', 'Pneumonia',
       'Pneumothorax', 'Support Devices'],
      dtype='object')
time: 928 ms


In [5]:
train_data.head(3)

Unnamed: 0,subject_id,image_path,image_name,study_id,split,Atelectasis,Cardiomegaly,Consolidation,Edema,Enlarged Cardiomediastinum,Fracture,Lung Lesion,Lung Opacity,No Finding,Pleural Effusion,Pleural Other,Pneumonia,Pneumothorax,Support Devices
0,s52769454,../../DataCenter/MIMIC-CXR/files/p18/p18190098...,8bf006fa-7169cd83-c30e7055-b109468a-2223477d,52769454,train,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0
1,s50754262,../../DataCenter/MIMIC-CXR/files/p18/p18190098...,48fbe534-50750d68-5afd36c2-e7aaf316-a31fcb66,50754262,train,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0
2,s51845898,../../DataCenter/MIMIC-CXR/files/p18/p18190098...,fc2e907b-c19b5434-cc3d4205-8d66ab01-a0dcc274,51845898,train,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


time: 22.2 ms


In [6]:
import cv2
import csv
import time
import random
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optin
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as tfunc
import torch.multiprocessing as mp

from torch.autograd import Variable
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.nn.parallel import DistributedDataParallel as DDP
from PIL import Image

from sklearn.metrics import roc_auc_score, confusion_matrix
from tqdm import tqdm, trange
import sklearn.metrics as metrics

use_gpu = torch.cuda.is_available()

time: 1.25 s


In [7]:
sample_train_data = train_data[0:10000]
sample_train_data.to_csv('../data/sample_train_data.csv', index = False)

time: 231 ms


In [8]:
pathFileSampleTrain = '../data/sample_train_data.csv'
pathFileTrain = '../data/train_data.csv'
pathFileValid = '../data/val_data.csv'

nnIsTrained = False
nnClassCount = 14

trBatchSize = 64
trMaxEpoch = 3

imgtransResize = (320, 320)
imgtransCrop = 224

class_names = ['No Finding', 'Enlarged Cardiomediastinum', 'Cardiomegaly', 'Lung Opacity', 
               'Lung Lesion', 'Edema', 'Consolidation', 'Pneumonia', 'Atelectasis', 'Pneumothorax', 
               'Pleural Effusion', 'Pleural Other', 'Fracture', 'Support Devices']

time: 1.53 ms


In [9]:
class CheXpertDataSet(Dataset):
    def __init__(self, image_list_file, transform=None, policy="ones"):
        """
        image_list_file: path to the file containing images with corresponding labels.
        transform: optional transform to be applied on a sample.
        policy: name the policy with regard to the uncertain labels
        """
        image_names = []
        labels = []

        with open(image_list_file, "r") as f:
            csvReader = csv.reader(f)
            next(csvReader, None)
            k=0
            for line in csvReader:
                k+=1
                image_name= line[1]
                label = line[5:]
                
                for i in range(14):
                    label[i] = float(label[i])
                
                image_names.append(image_name)
                labels.append(label)

        self.image_names = image_names
        self.labels = labels
        self.transform = transform

    def __getitem__(self, index):
        """Take the index of item and returns the image and its labels"""
        
        image_name = self.image_names[index]
        image = Image.open(image_name).convert('RGB')
        label = self.labels[index]
        if self.transform is not None:
            image = self.transform(image)
        return (image, torch.FloatTensor(label))

    def __len__(self):
        return (len(self.image_names))

time: 5.97 ms


In [10]:
#TRANSFORM DATA

normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
transformList = []
transformList.append(transforms.RandomResizedCrop(imgtransCrop))
transformList.append(transforms.RandomHorizontalFlip())
transformList.append(transforms.ToTensor())
transformList.append(normalize)      
transformSequence=transforms.Compose(transformList)

time: 2.3 ms


In [11]:
#LOAD DATASET

datasetSampleTrain = CheXpertDataSet(pathFileSampleTrain, transformSequence, policy = 'ones')
datasetTrain = CheXpertDataSet(pathFileTrain ,transformSequence, policy = 'ones')
datasetValid = CheXpertDataSet(pathFileValid, transformSequence)

print ('Size of Simple Train Dataset - {}'.format(len(datasetSampleTrain)))
print ('Size of Train Dataset - {}'.format(len(datasetTrain)))

dataLoaderSampleTrain = DataLoader(dataset=datasetSampleTrain, batch_size=trBatchSize, shuffle=True, num_workers=0, pin_memory = True)
dataLoaderTrain = DataLoader(dataset=datasetTrain, batch_size=trBatchSize, shuffle=True,  num_workers=0, pin_memory = True)
dataLoaderVal = DataLoader(dataset=datasetValid, batch_size=trBatchSize, shuffle=False, num_workers=0, pin_memory = True)

Size of Simple Train Dataset - 10000
Size of Train Dataset - 368945
time: 2.92 s


In [12]:
class CheXpertTrainer():

    def train (model, dataLoaderTrain, dataLoaderVal, nnClassCount, trMaxEpoch, launchTimestamp, checkpoint):
        
        #SETTINGS: OPTIMIZER & SCHEDULER
        optimizer = optim.Adam (model.parameters(), lr=0.0001, betas=(0.9, 0.999), eps=1e-08, weight_decay=1e-5)
                
        #SETTINGS: LOSS
        loss = torch.nn.BCELoss(size_average = True)
        
        #LOAD CHECKPOINT 
        if checkpoint != None and use_gpu:
            modelCheckpoint = torch.load(checkpoint)
            model.load_state_dict(modelCheckpoint['state_dict'])
            optimizer.load_state_dict(modelCheckpoint['optimizer'])

        
        #TRAIN THE NETWORK
        lossMIN = 100000
        
        for epochID in range(0, trMaxEpoch):
            
            timestampTime = time.strftime("%H%M%S")
            timestampDate = time.strftime("%d%m%Y")
            timestampSTART = timestampDate + '-' + timestampTime
            
            batchs, losst, losse = CheXpertTrainer.epochTrain(model, dataLoaderTrain, optimizer, trMaxEpoch, nnClassCount, loss)
            lossVal = CheXpertTrainer.epochVal(model, dataLoaderVal, optimizer, trMaxEpoch, nnClassCount, loss)


            timestampTime = time.strftime("%H%M%S")
            timestampDate = time.strftime("%d%m%Y")
            timestampEND = timestampDate + '-' + timestampTime
            
            if lossVal < lossMIN:
                lossMIN = lossVal    
                torch.save({'epoch': epochID + 1, 'state_dict': model.state_dict(), 'best_loss': lossMIN, 'optimizer' : optimizer.state_dict()}, 'm-epoch'+str(epochID)+'-' + launchTimestamp + '.pth.tar')
                print ('Epoch [' + str(epochID + 1) + '] [save] [' + timestampEND + '] loss= ' + str(lossVal))
            else:
                print ('Epoch [' + str(epochID + 1) + '] [----] [' + timestampEND + '] loss= ' + str(lossVal))
        
        return (batchs, losst, losse)
    #-------------------------------------------------------------------------------- 
       
    def epochTrain(model, dataLoader, optimizer, epochMax, classCount, loss):
        
        batch = []
        losstrain = []
        losseval = []
        
        model.train()

        for batchID, (varInput, target) in enumerate(dataLoaderTrain):
            
            varTarget = target.cuda(non_blocking = True)
            
            #varTarget = target.cuda()         
            
            print (varInput.shape)
            varOutput = model(varInput)
            lossvalue = loss(varOutput, varTarget)
                       
            optimizer.zero_grad()
            lossvalue.backward()
            optimizer.step()
            
            l = lossvalue.item()
            losstrain.append(l)
            
            if batchID%35==0:
                logger.info('Batches Computed - {}'.format(batchID//35))
                print(batchID//35, "% batches computed")
                #Fill three arrays to see the evolution of the loss


                batch.append(batchID)
                
                le = CheXpertTrainer.epochVal(model, dataLoaderVal, optimizer, trMaxEpoch, nnClassCount, loss).item()
                losseval.append(le)
                
                logger.info('Batch ID - {}'.format(batchID))
                logger.info('Training Loss - {}'.format(l))
                logger.info('Validation Loss - {}'.format(le))
                
                print('Batch ID - {}'.format(batchID))
                print('Training Loss - {}'.format(l))
                print('Validation Loss - {}'.format(le))
                
        return (batch, losstrain, losseval)
    
    #-------------------------------------------------------------------------------- 
    
    def epochVal(model, dataLoader, optimizer, epochMax, classCount, loss):
        
        model.eval()
        
        lossVal = 0
        lossValNorm = 0

        with torch.no_grad():
            for i, (varInput, target) in enumerate(dataLoaderVal):
                
                target = target.cuda(non_blocking = True)
                varOutput = model(varInput)
                
                losstensor = loss(varOutput, target)
                lossVal += losstensor
                lossValNorm += 1
                
        outLoss = lossVal / lossValNorm
        return (outLoss)
    
    
    #--------------------------------------------------------------------------------     
     
    #---- Computes area under ROC curve 
    #---- dataGT - ground truth data
    #---- dataPRED - predicted data
    #---- classCount - number of classes
    
    def computeAUROC (dataGT, dataPRED, classCount):
        
        outAUROC = []
        
        datanpGT = dataGT.cpu().numpy()
        datanpPRED = dataPRED.cpu().numpy()
        
        for i in range(classCount):
            try:
                outAUROC.append(roc_auc_score(datanpGT[:, i], datanpPRED[:, i]))
            except ValueError:
                pass
        return (outAUROC)
        
        
    #-------------------------------------------------------------------------------- 
    
    
    def test(model, dataLoaderTest, nnClassCount, checkpoint, class_names):   
        
        cudnn.benchmark = True
        
        if checkpoint != None and use_gpu:
            modelCheckpoint = torch.load(checkpoint)
            model.load_state_dict(modelCheckpoint['state_dict'])

        if use_gpu:
            outGT = torch.FloatTensor().cuda()
            outPRED = torch.FloatTensor().cuda()
        else:
            outGT = torch.FloatTensor()
            outPRED = torch.FloatTensor()
       
        model.eval()
        
        with torch.no_grad():
            for i, (input, target) in enumerate(dataLoaderTest):

                target = target.cuda()
                outGT = torch.cat((outGT, target), 0).cuda()
                
                bs, c, h, w = input.size()
                varInput = input.view(-1, c, h, w)
            
                out = model(varInput)
                outPRED = torch.cat((outPRED, out), 0)
        aurocIndividual = CheXpertTrainer.computeAUROC(outGT, outPRED, nnClassCount)
        aurocMean = np.array(aurocIndividual).mean()
        
        logger.info('AUROC mean ', aurocMean)
        print ('AUROC mean ', aurocMean)
        
        for i in range (0, len(aurocIndividual)):
            print (class_names[i], ' ', aurocIndividual[i])
        
        return (outGT, outPRED)

time: 6.97 ms


In [13]:
class DenseNet121(nn.Module):
    """Model modified.
    The architecture of our model is the same as standard DenseNet121
    except the classifier layer which has an additional sigmoid function.
    """
    def __init__(self, out_size):
        super(DenseNet121, self).__init__()
        self.densenet121 = torchvision.models.densenet121(pretrained=True)
        num_ftrs = self.densenet121.classifier.in_features
        self.densenet121.classifier = nn.Sequential(
            nn.Linear(num_ftrs, out_size),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.densenet121(x)
        return (x)

time: 1.51 ms


In [14]:
# initialize and load the model

model = DenseNet121(nnClassCount).cuda()
model = torch.nn.DataParallel(model).cuda()
logger.info('Initial model is loaded on to GPU.')

time: 2.47 s


In [15]:
pretrained_model = torch.load('../model/model_ones_3epoch_densenet.tar')
model_state_dict = pretrained_model['state_dict']
torch.save(model_state_dict, '../model/model_state_dict.pth')

time: 161 ms


In [16]:
model.load_state_dict(torch.load('../model/model_state_dict.pth'))
model.eval()

DataParallel(
  (module): DenseNet121(
    (densenet121): DenseNet(
      (features): Sequential(
        (conv0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        (norm0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu0): ReLU(inplace=True)
        (pool0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        (denseblock1): _DenseBlock(
          (denselayer1): _DenseLayer(
            (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (relu1): ReLU(inplace=True)
            (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (relu2): ReLU(inplace=True)
            (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          )
          (denselayer2): _Dens

time: 96.4 ms


In [17]:
outGT, outPRED = CheXpertTrainer.test(model = model, dataLoaderTest = dataLoaderSampleTrain, nnClassCount = nnClassCount, class_names = class_names, checkpoint = None)

AUROC mean  0.5288434268265898
No Finding   0.2638449532763988
Enlarged Cardiomediastinum   0.4880634397311175
Cardiomegaly   0.5929519331243469
Lung Opacity   0.7889307467665185
Lung Lesion   0.41107512509907096
Edema   0.4936832092934497
Consolidation   0.5983564084521852
Pneumonia   0.6994563493421155
Atelectasis   0.22334262878895578
Pneumothorax   0.4560934485868078
Pleural Effusion   0.6071574696458462
Pleural Other   0.4993409835527319
Fracture   0.42284137185648385
Support Devices   0.8586699080562272
time: 10min 25s


--- Logging error ---
Traceback (most recent call last):
  File "/opt/tljh/user/lib/python3.7/logging/__init__.py", line 1034, in emit
    msg = self.format(record)
  File "/opt/tljh/user/lib/python3.7/logging/__init__.py", line 880, in format
    return fmt.format(record)
  File "/opt/tljh/user/lib/python3.7/logging/__init__.py", line 619, in format
    record.message = record.getMessage()
  File "/opt/tljh/user/lib/python3.7/logging/__init__.py", line 380, in getMessage
    msg = msg % self.args
TypeError: not all arguments converted during string formatting
Call stack:
  File "/opt/tljh/user/lib/python3.7/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/opt/tljh/user/lib/python3.7/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/opt/tljh/user/lib/python3.7/site-packages/ipykernel_launcher.py", line 16, in <module>
    app.launch_new_instance()
  File "/opt/tljh/user/lib/python3.7/site-packages/traitlets/config/application.py", line 

In [18]:
outPRED = outPRED.cpu().detach().numpy()
outPRED = np.array(outPRED)

sample_train_data['pred_No Finding'] = np.round(outPRED[:,0])
sample_train_data['pred_Enlarged Cardiomediastinum'] = np.round(outPRED[:,1])
sample_train_data['pred_Cardiomegaly'] = np.round(outPRED[:,2])
sample_train_data['pred_Lung Opacity'] = np.round(outPRED[:,3])
sample_train_data['pred_Lung Lesion'] = np.round(outPRED[:,4])
sample_train_data['pred_Edema'] = np.round(outPRED[:,5])
sample_train_data['pred_Consolidation'] = np.round(outPRED[:,6])
sample_train_data['pred_Pneumonia'] = np.round(outPRED[:,7])
sample_train_data['pred_Atelectasis'] = np.round(outPRED[:,8])
sample_train_data['pred_Pneumothorax'] = np.round(outPRED[:,9])
sample_train_data['pred_Pleural Effusion'] = np.round(outPRED[:,10])
sample_train_data['pred_Pleural Other'] = np.round(outPRED[:,11])
sample_train_data['pred_Fracture'] = np.round(outPRED[:,12])
sample_train_data['pred_Support Devices'] = np.round(outPRED[:,13])

time: 18.8 ms


In [19]:
len(sample_train_data)

10000

time: 2.69 ms


In [20]:
'pred_No Finding' in sample_train_data.columns

True

time: 3.22 ms


## Todo List

- Train few shot learning model on every class with 100 images.
    - Incrementally build the few shot learning model on all the 14 conditions.
- Validate the results on the all images (except the 10 images).
- Compare the results with the prior results.

In [21]:
args = {
    'train_size' : 100,
    'test_size' : 100,
    'out_size' : 128, # should be less than 512
    'loss_margin' : 0.2,
    'loss_p' : 2,
    'batch_size' : 4 # do not change this
}

time: 750 µs


In [22]:
class DenseNet121ToDense(nn.Module):
    '''
    The architecture of the densenet121 is copied and the final classifier layer is removed.
    In the place of final classifier layer a dense layer is attached with PReLU activation.
    '''
    def __init__(self, out_size):
        super(DenseNet121ToDense, self).__init__()
        self.densenet121todense = model.module.densenet121
        num_ftrs = self.densenet121todense.classifier[0].in_features
        self.densenet121todense.classifier = nn.Sequential(
            nn.Linear(num_ftrs, 512),
            nn.PReLU(),
            nn.Linear(512, out_size)
        )
        
    def forward(self, x_1, x_2, x_3):
        x_1 = self.densenet121todense(x_1)
        x_2 = self.densenet121todense(x_2)
        x_3 = self.densenet121todense(x_3)
        return (x_1, x_2, x_3)
    
ref_model = DenseNet121ToDense(args['out_size']).cuda()
ref_model = torch.nn.DataParallel(ref_model).cuda()
logger.info('FSL model is loaded on to GPU.')

time: 20.9 ms


In [23]:
from torch.utils.data import DataLoader, RandomSampler, TensorDataset
import torch.optim as optim
from sklearn.metrics import roc_auc_score

def make_train_triplets(pathology, sample_train_data, n):
    '''
    Returns 'n' triplet images for training data for a given pathology.
    First Image - Any random image (with atleast 40 images which has the pathology)
    Second Image - An image which is positive for the pathology.
    Third Image - An image which is negative for the pathology.
    '''
    
    actual_path = pathology
    pred_path = 'pred_'+pathology
    false_negative = sample_train_data[(sample_train_data[actual_path] == 1.0) & (sample_train_data[pred_path] == 0.0)]
    false_positive = sample_train_data[(sample_train_data[actual_path] == 0.0) & (sample_train_data[pred_path] == 1.0)]
    print ('False Negatives - {} False Positives - {}'.format(len(false_negative), len(false_positive)))

    positive = sample_train_data[sample_train_data[actual_path] == 1.0]
    negative = sample_train_data[sample_train_data[actual_path] == 0.0]
    
    images = []
    checking_label = []
    
    for i in range(n):
        if (len(false_negative.index) > 0 and len(false_positive.index) > 0):
            choose = random.choice(['a','b'])    
            if (choose == 'a'):
                first_image_index = random.choice(false_negative.index)
                checking_label.append(-1)
            elif (choose == 'b'):
                first_image_index = random.choice(false_positive.index)
                checking_label.append(1)
        else:
            if (len(false_negative.index) > 0):
                first_image_index = random.choice(false_negative.index)
                checking_label.append(-1)
            elif (len(false_positive.index) > 0):
                first_image_index = random.choice(false_positive.index)
                checking_label.append(1)
            
        second_image_index = random.choice(positive.index)
        third_image_index = random.choice(negative.index)

        # Convert the indices to images
        first_image = sample_train_data['image_path'][first_image_index]
        first_image = Image.open(first_image).convert('RGB')
        first_image = transformSequence(first_image)
        
        second_image = sample_train_data['image_path'][second_image_index]
        second_image = Image.open(second_image).convert('RGB')
        second_image = transformSequence(second_image)
        
        third_image = sample_train_data['image_path'][third_image_index]
        third_image = Image.open(third_image).convert('RGB')
        third_image = transformSequence(third_image)
        
        images += [[first_image, second_image, third_image]]
        
    return (images, checking_label)

time: 5.09 ms


In [24]:
def make_test_triplets(pathology, sample_train_data):
    '''
    - Returns three images (so the name triplets)
    - First Image : An Inference failed image. (Inference failures can appear in the form of false negatives and 
    false postives)
    - Second Image: An image which is positive for the pathology.
    - Third Image: An image which is negative for the pathology.
    '''
    
    actual_path = pathology
    pred_path = 'pred_'+pathology
    false_negative = sample_train_data[(sample_train_data[actual_path] == 1.0) & (sample_train_data[pred_path] == 0.0)]
    false_positive = sample_train_data[(sample_train_data[actual_path] == 0.0) & (sample_train_data[pred_path] == 1.0)]
    print ('False Negatives - {} False Positives - {}'.format(len(false_negative), len(false_positive)))

    positive = sample_train_data[sample_train_data[actual_path] == 1.0]
    negative = sample_train_data[sample_train_data[actual_path] == 0.0]
    
    images = []
    checking_label = []
    
    false_inference_index = false_negative.index.union(false_positive.index)
    for index in false_inference_index:
        if index in false_negative.index:
            checking_label.append(-1)
        elif index in false_positive.index:
            checking_label.append(1)
            
        first_image_index  = index    
        second_image_index = random.choice(positive.index)
        third_image_index = random.choice(negative.index)

        # Convert the indices to images
        first_image = sample_train_data['image_path'][first_image_index]
        first_image = Image.open(first_image).convert('RGB')
        first_image = transformSequence(first_image)

        second_image = sample_train_data['image_path'][second_image_index]
        second_image = Image.open(second_image).convert('RGB')
        second_image = transformSequence(second_image)

        third_image = sample_train_data['image_path'][third_image_index]
        third_image = Image.open(third_image).convert('RGB')
        third_image = transformSequence(third_image)

        images += [[first_image, second_image, third_image]]
        
    return (images, checking_label, false_inference_index)

time: 3.96 ms


In [25]:
def ref_train_image_triplets(images, checking_label):
    first_images = [i[0] for i in images]
    second_images = [i[1] for i in images]
    third_images = [i[2] for i in images]

    first_images = torch.stack(first_images)
    second_images = torch.stack(second_images)
    third_images = torch.stack(third_images)
    
    target = torch.Tensor(checking_label).int()    
    return (first_images, second_images, third_images, target)


def ref_test_image_triplets(test_images, test_checking_label):
    test_first_images = [i[0] for i in test_images]
    test_second_images = [i[1] for i in test_images]
    test_third_images = [i[2] for i in test_images]
    
    test_first_images = torch.stack(test_first_images)
    test_second_images = torch.stack(test_second_images)
    test_third_images = torch.stack(test_third_images)
    
    test_target = torch.Tensor(test_checking_label).int()
    return (test_first_images, test_second_images, test_third_images, test_target)

time: 2.68 ms


In [26]:
def triplet_distance(anchor, positive, negative):
    '''
    The function calculates the distance between anchor, positive and anchor, negative images. The difference 
    between the pos_distance, neg_distance was calculated and tuned according to the checking_label.
    '''
    pos_distance = F.pairwise_distance(anchor, positive, 2)
    neg_distance = F.pairwise_distance(anchor, negative, 2)
    return (pos_distance, neg_distance)

def predictive_value(conf_matrix):
    '''
    This function returns the positive predictive value and negative predictive value of a classifier.
    Parameters:
    
    1. conf_matrix : A Confusion matrix from the result of a classifier.
    '''
    ppv = (conf_matrix[0][0]/(conf_matrix[0][0] + conf_matrix[0][1]))*100
    npv = (conf_matrix[1][0]/(conf_matrix[1][0] + conf_matrix[1][1]))*100
    return (ppv, npv)

def train_and_eval(ref_model, train_dataloader, test_dataloader):
    for epoch in trange(5):
    
        tr_loss = 0
        nb_tr_steps = 0
        ref_model.train()
        for step, batch in enumerate(train_dataloader):
            anchor, positive, negative, target = batch[0], batch[1], batch[2], batch[3]
            anchor, positive, negative, target = Variable(anchor), Variable(positive), Variable(negative), Variable(target)

            bs, c, h, w = anchor.size()
            anchor_input = anchor.view(-1, c, h, w)
            bs, c, h, w = positive.size()
            positive_input = positive.view(-1, c, h, w)
            bs, c, h, w = negative.size()
            negative_input = negative.view(-1, c, h, w)

            E1, E2, E3 = ref_model(anchor_input, positive_input, negative_input)
            dist_E1_E2, dist_E1_E3 = triplet_distance(E1, E2, E3)

            target = target.cuda()
            loss = criterion(dist_E1_E2, dist_E1_E3, target)
            tr_loss += loss
            nb_tr_steps += 1

            torch.autograd.set_detect_anomaly(True)
            optimizer.zero_grad()
            loss.backward(retain_graph = True)
            optimizer.step()

        print("Train loss: {}".format(tr_loss/nb_tr_steps))

        pred_list = []
        with torch.no_grad():
            ref_model.eval()
            for step, batch in enumerate(test_dataloader):
                anchor, positive, negative, target = batch[0], batch[1], batch[2], batch[3]
                anchor, positive, negative, target = Variable(anchor), Variable(positive), Variable(negative), Variable(target)

                bs, c, h, w = anchor.size()
                anchor_input = anchor.view(-1, c, h, w)
                bs, c, h, w = positive.size()
                positive_input = positive.view(-1, c, h, w)
                bs, c, h, w = negative.size()
                negative_input = negative.view(-1, c, h, w)

                E1, E2, E3 = ref_model(anchor_input, positive_input, negative_input)
                dist_E1_E2, dist_E1_E3 = triplet_distance(E1, E2, E3)

                for i in range(len(dist_E1_E2)):
                    if (dist_E1_E2[i] > dist_E1_E3[i]):
                        pred_list.append(1)
                    else:
                        pred_list.append(0)

        mod_test_target = []
        for i in test_target:
            if i == -1:
                mod_test_target.append(1)
            else:
                mod_test_target.append(0)
    
    return (mod_test_target, pred_list)

time: 7.55 ms


In [27]:
for _class in class_names:
    
    actual_class = _class
    pred_class = 'pred_'+_class
    print ('Class Name : {}'.format(actual_class))
    logger.info('Class Name : {}'.format(actual_class))
    images, checking_label = make_train_triplets(_class, sample_train_data, 150)
    test_images, test_checking_label, false_inference_index = make_test_triplets(_class, sample_train_data)
    
    first_images, second_images, third_images, target = ref_train_image_triplets(images, checking_label)
    test_first_images, test_second_images, test_third_images, test_target = ref_test_image_triplets(test_images, test_checking_label)
    
    ref_train_data = TensorDataset(first_images, second_images, third_images, target)
    train_dataloader = DataLoader(ref_train_data, batch_size = args['batch_size'], shuffle = True, num_workers = 8, pin_memory = True)

    test_data = TensorDataset(test_first_images, test_second_images, test_third_images, test_target)
    test_dataloader = DataLoader(test_data, batch_size = args['batch_size'], shuffle = False, num_workers = 8, pin_memory = True)

    criterion = torch.nn.MarginRankingLoss(margin = args['loss_margin'])
    optimizer = optim.Adam (model.parameters(), lr=0.0001, betas=(0.9, 0.999), eps=1e-08, weight_decay=1e-5)
    
    mod_test_target, pred_list = train_and_eval(ref_model, train_dataloader, test_dataloader)
    
    # Calculating the ROC-AUC Scores before and after the few shot learning algorithm.
    sample_train_data_copy_pred_class = copy.deepcopy(sample_train_data[pred_class])
    i = 0
    for index in false_inference_index:
        sample_train_data_copy_pred_class[index] = pred_list[i]
        i += 1
        
    sample_train_data['FSL_pred_'+_class] = sample_train_data_copy_pred_class
    
    print ('Before FSL - ') 
    logger.info('Before FSL - ')
    ppv, npv = predictive_value(confusion_matrix(sample_train_data[actual_class], sample_train_data[pred_class]))
    print ('Positive Predictive Value : {}'.format(ppv))
    logger.info('Positive Predictive Value : {}'.format(ppv))
    print ('Negative Predictive Value : {}'.format(npv))
    logger.info('Negative Predictive Value : {}'.format(npv))
    
    print ('After FSL - ')
    logger.info('After FSL - ')
    ppv, npv = predictive_value(confusion_matrix(sample_train_data[actual_class], sample_train_data_copy_pred_class))
    print ('Positive Predictive Value : {}'.format(ppv))
    logger.info('Positive Predictive Value : {}'.format(ppv))
    print ('Negative Predictive Value : {}'.format(npv))
    logger.info('Negative Predictive Value : {}'.format(npv))
    print ('--------------------------------------')
    logger.info('--------------------------------------')

Class Name : No Finding
False Negatives - 3797 False Positives - 607
False Negatives - 3797 False Positives - 607


  0%|          | 0/5 [00:00<?, ?it/s]

Train loss: 0.1981429159641266


 20%|██        | 1/5 [07:07<28:30, 427.55s/it]

Train loss: 0.10084587335586548


 40%|████      | 2/5 [14:15<21:22, 427.55s/it]

Train loss: 0.018712429329752922


 60%|██████    | 3/5 [21:15<14:11, 425.50s/it]

Train loss: 0.007266358472406864


 80%|████████  | 4/5 [28:13<07:03, 423.06s/it]

Train loss: 0.0023711128160357475


100%|██████████| 5/5 [35:08<00:00, 420.84s/it]


Before FSL - 
Positive Predictive Value : 89.49463482173763
Negative Predictive Value : 89.9336807200379
After FSL - 
Positive Predictive Value : 93.907926618207
Negative Predictive Value : 50.33159639981052
--------------------------------------
Class Name : Enlarged Cardiomediastinum
False Negatives - 565 False Positives - 0
False Negatives - 565 False Positives - 0


  0%|          | 0/5 [00:00<?, ?it/s]

Train loss: 0.17425817251205444


 20%|██        | 1/5 [02:27<09:48, 147.05s/it]

Train loss: 0.027121998369693756


 40%|████      | 2/5 [04:55<07:22, 147.42s/it]

Train loss: 0.0049117980524897575


 60%|██████    | 3/5 [07:24<04:56, 148.07s/it]

Train loss: 0.001371822552755475


 80%|████████  | 4/5 [09:52<02:28, 148.04s/it]

Train loss: 0.0010520974174141884


100%|██████████| 5/5 [12:22<00:00, 148.49s/it]


Before FSL - 
Positive Predictive Value : 100.0
Negative Predictive Value : 100.0
After FSL - 
Positive Predictive Value : 100.0
Negative Predictive Value : 58.93805309734513
--------------------------------------
Class Name : Cardiomegaly
False Negatives - 1676 False Positives - 911
False Negatives - 1676 False Positives - 911


  0%|          | 0/5 [00:00<?, ?it/s]

Train loss: 0.1989714801311493


 20%|██        | 1/5 [04:50<19:21, 290.41s/it]

Train loss: 0.026790782809257507


 40%|████      | 2/5 [09:40<14:31, 290.43s/it]

Train loss: 0.004064400680363178


 60%|██████    | 3/5 [14:33<09:42, 291.11s/it]

Train loss: 0.0013526572147384286


 80%|████████  | 4/5 [19:22<04:50, 290.34s/it]

Train loss: 4.8074871301651e-06


100%|██████████| 5/5 [24:14<00:00, 290.96s/it]


Before FSL - 
Positive Predictive Value : 88.7752587481518
Negative Predictive Value : 88.95966029723992
After FSL - 
Positive Predictive Value : 94.12272055199605
Negative Predictive Value : 53.07855626326964
--------------------------------------
Class Name : Lung Opacity
False Negatives - 1511 False Positives - 2279
False Negatives - 1511 False Positives - 2279


  0%|          | 0/5 [00:00<?, ?it/s]

Train loss: 0.20865677297115326


 20%|██        | 1/5 [06:11<24:47, 371.98s/it]

Train loss: 0.03487783297896385


 40%|████      | 2/5 [12:26<18:37, 372.64s/it]

Train loss: 0.003476259298622608


 60%|██████    | 3/5 [18:43<12:28, 374.17s/it]

Train loss: 0.0007229932816699147


 80%|████████  | 4/5 [24:55<06:13, 373.47s/it]

Train loss: 0.0012889286736026406


100%|██████████| 5/5 [30:38<00:00, 364.14s/it]


Before FSL - 
Positive Predictive Value : 71.1628495508035
Negative Predictive Value : 72.0553171196948
After FSL - 
Positive Predictive Value : 86.03062128305706
Negative Predictive Value : 39.77110157367668
--------------------------------------
Class Name : Lung Lesion
False Negatives - 299 False Positives - 6
False Negatives - 299 False Positives - 6


  0%|          | 0/5 [00:00<?, ?it/s]

Train loss: 0.20720255374908447


 20%|██        | 1/5 [02:07<08:30, 127.70s/it]

Train loss: 0.04196728393435478


 40%|████      | 2/5 [04:18<06:25, 128.61s/it]

Train loss: 0.003958416171371937


 60%|██████    | 3/5 [06:29<04:18, 129.24s/it]

Train loss: 0.0011955222580581903


 80%|████████  | 4/5 [08:39<02:09, 129.53s/it]

Train loss: 0.0007965067634359002


100%|██████████| 5/5 [10:50<00:00, 129.89s/it]


Before FSL - 
Positive Predictive Value : 99.93815070611276
Negative Predictive Value : 100.0
After FSL - 
Positive Predictive Value : 99.9690753530564
Negative Predictive Value : 50.50167224080268
--------------------------------------
Class Name : Edema
False Negatives - 1158 False Positives - 1720
False Negatives - 1158 False Positives - 1720


  0%|          | 0/5 [00:00<?, ?it/s]

Train loss: 0.1989484280347824


 20%|██        | 1/5 [04:48<19:15, 288.83s/it]

Train loss: 0.024662181735038757


 40%|████      | 2/5 [09:35<14:24, 288.09s/it]

Train loss: 0.002335724188014865


 60%|██████    | 3/5 [14:46<09:49, 294.95s/it]

Train loss: 0.003430748824030161


 80%|████████  | 4/5 [19:47<04:56, 296.76s/it]

Train loss: 0.0012464869068935513


100%|██████████| 5/5 [24:47<00:00, 297.85s/it]


Before FSL - 
Positive Predictive Value : 79.94402985074626
Negative Predictive Value : 81.32022471910112
After FSL - 
Positive Predictive Value : 89.50559701492537
Negative Predictive Value : 50.28089887640449
--------------------------------------
Class Name : Consolidation
False Negatives - 430 False Positives - 1
False Negatives - 430 False Positives - 1


  0%|          | 0/5 [00:00<?, ?it/s]

Train loss: 0.19818906486034393


 20%|██        | 1/5 [02:15<09:01, 135.49s/it]

Train loss: 0.020759744569659233


 40%|████      | 2/5 [04:27<06:43, 134.49s/it]

Train loss: 0.0076851570047438145


 60%|██████    | 3/5 [06:39<04:27, 133.61s/it]

Train loss: 0.005488424561917782


 80%|████████  | 4/5 [08:51<02:13, 133.24s/it]

Train loss: 0.0027426297310739756


100%|██████████| 5/5 [11:07<00:00, 133.90s/it]


Before FSL - 
Positive Predictive Value : 99.98955067920585
Negative Predictive Value : 100.0
After FSL - 
Positive Predictive Value : 99.98955067920585
Negative Predictive Value : 57.906976744186046
--------------------------------------
Class Name : Pneumonia
False Negatives - 1376 False Positives - 52
False Negatives - 1376 False Positives - 52


  0%|          | 0/5 [00:00<?, ?it/s]

Train loss: 0.2002343088388443


 20%|██        | 1/5 [03:20<13:21, 200.33s/it]

Train loss: 0.026331735774874687


 40%|████      | 2/5 [06:39<10:00, 200.04s/it]

Train loss: 0.0032423229422420263


 60%|██████    | 3/5 [09:53<06:36, 198.04s/it]

Train loss: 0.0


 80%|████████  | 4/5 [13:16<03:19, 199.70s/it]

Train loss: 0.0002341357758268714


100%|██████████| 5/5 [16:41<00:00, 201.25s/it]


Before FSL - 
Positive Predictive Value : 99.39633155328535
Negative Predictive Value : 99.27849927849928
After FSL - 
Positive Predictive Value : 99.686556768052
Negative Predictive Value : 52.308802308802306
--------------------------------------
Class Name : Atelectasis
False Negatives - 1956 False Positives - 417
False Negatives - 1956 False Positives - 417


  0%|          | 0/5 [00:00<?, ?it/s]

Train loss: 0.22322429716587067


 20%|██        | 1/5 [04:20<17:20, 260.17s/it]

Train loss: 0.0227182786911726


 40%|████      | 2/5 [08:50<13:10, 263.37s/it]

Train loss: 0.0029674137476831675


 60%|██████    | 3/5 [13:20<08:50, 265.17s/it]

Train loss: 0.004232614301145077


 80%|████████  | 4/5 [17:53<04:27, 267.63s/it]

Train loss: 0.00038035301258787513


100%|██████████| 5/5 [22:24<00:00, 268.69s/it]


Before FSL - 
Positive Predictive Value : 94.7401614530777
Negative Predictive Value : 94.4015444015444
After FSL - 
Positive Predictive Value : 97.46468213925328
Negative Predictive Value : 55.067567567567565
--------------------------------------
Class Name : Pneumothorax
False Negatives - 405 False Positives - 12
False Negatives - 405 False Positives - 12


  0%|          | 0/5 [00:00<?, ?it/s]

Train loss: 0.2052149474620819


 20%|██        | 1/5 [02:14<08:59, 134.76s/it]

Train loss: 0.02107487991452217


 40%|████      | 2/5 [04:32<06:47, 135.80s/it]

Train loss: 0.003960717469453812


 60%|██████    | 3/5 [06:52<04:33, 136.97s/it]

Train loss: 0.002325511770322919


 80%|████████  | 4/5 [09:11<02:17, 137.40s/it]

Train loss: 0.0


100%|██████████| 5/5 [11:24<00:00, 136.13s/it]


Before FSL - 
Positive Predictive Value : 99.87493486190724
Negative Predictive Value : 100.0
After FSL - 
Positive Predictive Value : 99.92704533611257
Negative Predictive Value : 50.123456790123456
--------------------------------------
Class Name : Pleural Effusion
False Negatives - 1755 False Positives - 1541
False Negatives - 1755 False Positives - 1541


  0%|          | 0/5 [00:00<?, ?it/s]

Train loss: 0.1847667694091797


 20%|██        | 1/5 [05:28<21:53, 328.27s/it]

Train loss: 0.025365864858031273


 40%|████      | 2/5 [10:54<16:23, 327.69s/it]

Train loss: 0.005649204831570387


 60%|██████    | 3/5 [16:28<10:58, 329.50s/it]

Train loss: 0.0009084025514312088


 80%|████████  | 4/5 [22:02<05:30, 330.91s/it]

Train loss: 0.002496852306649089


100%|██████████| 5/5 [27:25<00:00, 328.58s/it]


Before FSL - 
Positive Predictive Value : 80.23852269812772
Negative Predictive Value : 79.70027247956402
After FSL - 
Positive Predictive Value : 89.06129776865863
Negative Predictive Value : 50.45413260672116
--------------------------------------
Class Name : Pleural Other
False Negatives - 95 False Positives - 0
False Negatives - 95 False Positives - 0


  0%|          | 0/5 [00:00<?, ?it/s]

Train loss: 0.21752911806106567


 20%|██        | 1/5 [01:56<07:45, 116.46s/it]

Train loss: 0.040139179676771164


 40%|████      | 2/5 [03:53<05:49, 116.48s/it]

Train loss: 0.0030097211711108685


 60%|██████    | 3/5 [05:50<03:53, 116.77s/it]

Train loss: 0.0035380502231419086


 80%|████████  | 4/5 [07:48<01:57, 117.14s/it]

Train loss: 0.0


100%|██████████| 5/5 [09:46<00:00, 117.27s/it]


Before FSL - 
Positive Predictive Value : 100.0
Negative Predictive Value : 100.0
After FSL - 
Positive Predictive Value : 100.0
Negative Predictive Value : 44.21052631578947
--------------------------------------
Class Name : Fracture
False Negatives - 136 False Positives - 0
False Negatives - 136 False Positives - 0


  0%|          | 0/5 [00:00<?, ?it/s]

Train loss: 0.1877029985189438


 20%|██        | 1/5 [01:55<07:42, 115.63s/it]

Train loss: 0.022608136758208275


 40%|████      | 2/5 [03:52<05:47, 115.92s/it]

Train loss: 0.0032189879566431046


 60%|██████    | 3/5 [05:50<03:52, 116.50s/it]

Train loss: 0.00013588133151642978


 80%|████████  | 4/5 [07:50<01:57, 117.53s/it]

Train loss: 7.381957090046853e-08


100%|██████████| 5/5 [09:50<00:00, 118.45s/it]


Before FSL - 
Positive Predictive Value : 100.0
Negative Predictive Value : 100.0
After FSL - 
Positive Predictive Value : 100.0
Negative Predictive Value : 48.529411764705884
--------------------------------------
Class Name : Support Devices
False Negatives - 1676 False Positives - 1919
False Negatives - 1676 False Positives - 1919


  0%|          | 0/5 [00:00<?, ?it/s]

Train loss: 0.21427179872989655


 20%|██        | 1/5 [05:46<23:07, 346.89s/it]

Train loss: 0.026318073272705078


 40%|████      | 2/5 [11:41<17:27, 349.26s/it]

Train loss: 0.002913641044870019


 60%|██████    | 3/5 [17:30<11:38, 349.12s/it]

Train loss: 0.0016794210532680154


 80%|████████  | 4/5 [23:25<05:50, 350.95s/it]

Train loss: 0.00016746150504332036


100%|██████████| 5/5 [29:19<00:00, 351.78s/it]


Before FSL - 
Positive Predictive Value : 75.31515307435039
Negative Predictive Value : 75.29200359389039
After FSL - 
Positive Predictive Value : 86.66066375096474
Negative Predictive Value : 41.01527403414196
--------------------------------------
time: 5h 57min 40s
