In [1]:
import warnings
warnings.filterwarnings('ignore')
import torch
from torchvision import transforms, models, datasets
import numpy as np
import torch.optim as optim
import torch.nn as nn
import pandas as pd
import os
from tqdm.auto import tqdm
from torch.utils.data import DataLoader, Dataset
from torch.nn import functional as F
from collections import OrderedDict
import timm
import matplotlib.pyplot as plt
from PIL import Image
from collections import OrderedDict
from torch import nn

In [61]:
from sklearn.metrics import roc_curve
from sklearn.metrics import roc_auc_score, auc

In [25]:
data_dir = {'train': "/scratch/scratch6/akansh12/DeepEXrays/data/data_256/train/",
           'test': "/scratch/scratch6/akansh12/DeepEXrays/data/data_256/test/"}
labels = ['Aortic enlargement', 'Atelectasis','Calcification', 'Cardiomegaly', 'Consolidation','ILD', 'Infiltration','Lung Opacity','Nodule/Mass','Other lesion', 'Pleural effusion', 'Pleural thickening', 'Pneumothorax',
       'Pulmonary fibrosis', 'No finding']

In [3]:
class fourteen_class(Dataset):
    def __init__(self, label_loc, img_location, transform,  data_type= 'train'):
        label_dataframe = pd.read_csv(label_loc)
        label_dataframe.set_index("image_id", inplace = True)
        filenames = label_dataframe.index.values
        self.full_filename = [os.path.join(img_location,i+'.png') for i in filenames]
        self.labels = label_dataframe.iloc[:].values
        self.transform = transform
    def __len__(self):
        return len(self.full_filename)
    
    def __getitem__(self, idx):
        
        image = Image.open(self.full_filename[idx])
        image = self.transform(image)
        return image, self.labels[idx]        

data_transforms = { 
    "train": transforms.Compose([
        transforms.RandomHorizontalFlip(p = 0.5), 
        transforms.RandomPerspective(distortion_scale=0.3),
        transforms.RandomRotation((-30,30)),
        transforms.ToTensor(),
        transforms.Normalize(mean =  [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225])
    ]),
    
    "test": transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean =  [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225])        
    ])
    
}

In [74]:
train_data = fourteen_class("/storage/home/akansh12/Vin-ChestXR-Abnormality-detection/Notebooks/14_class/labels/exp_5.csv",
                                       img_location = "/scratch/scratch6/akansh12/DeepEXrays/data/data_256/train/", transform =data_transforms['train'])
test_data = fourteen_class("/storage/home/akansh12/Vin-ChestXR-Abnormality-detection/Notebooks/14_class/labels/test.csv",
                                       img_location = "/scratch/scratch6/akansh12/DeepEXrays/data/data_256/test/", transform =data_transforms['test'])


In [77]:
trainloader = DataLoader(train_data,batch_size = 16,shuffle = False, num_workers = 2)
testloader = DataLoader(test_data,batch_size = 16,shuffle = False, num_workers = 2)

In [7]:
def exp_model(path):
    model = timm.models.efficientnet_b0(pretrained=False)
    model.classifier = nn.Sequential(OrderedDict([
        ('fcl1', nn.Linear(1280,15)),
        ('out', nn.Sigmoid()),
    ]))    
    state_dict = torch.load(path, map_location = 'cpu')['state_dict']
    for keyA, keyB in zip(state_dict, model.state_dict()):
        state_dict = OrderedDict((keyB if k == keyA else k, v) for k, v in state_dict.items())

    model.load_state_dict(state_dict)
    
    return model

device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [29]:
exp_5_model = exp_model("/scratch/scratch6/akansh12/DeepEXrays/radiologist_selection/exp_5/exp_5_eff_b00.168909_1_.pth")

### Threashold selection

In [53]:
def calc_plot_auc_curve(model):
    model.eval()
    predicted = []
    for images,labels in tqdm(trainloader):
        ps = model(images)
        predicted.extend(ps.tolist())
    plt.figure(figsize=(15,15))
    lw = 2
    labels = ['Aortic enlargement', 'Atelectasis','Calcification', 'Cardiomegaly', 'Consolidation','ILD', 'Infiltration','Lung Opacity','Nodule/Mass','Other lesion', 'Pleural effusion', 'Pleural thickening', 'Pneumothorax',
       'Pulmonary fibrosis', 'No finding']
    AUC = []
    for i,j in enumerate(labels):
        fpr, tpr, thresholds =  roc_curve(np.array(trainloader.dataset.labels)[:,i], np.array(predicted)[:,i])
        plt.plot(fpr,tpr,lw=lw,label=f"{labels[i]} (area = {auc(fpr,tpr)})")
        AUC.append(auc(fpr,tpr))
    plt.plot([0, 1], [0, 1], color="navy", lw=lw, linestyle="--")
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("Receiver operating characteristic")
    plt.legend(loc="lower right")
    plt.show()
    
    return AUC

In [81]:
exp_5_model.eval()
predicted = []
for images,labels in tqdm(testloader):
    ps = exp_5_model(images)
    predicted.extend(ps.tolist())
    
labels = ['Aortic enlargement', 'Atelectasis','Calcification', 'Cardiomegaly', 'Consolidation','ILD', 'Infiltration','Lung Opacity','Nodule/Mass','Other lesion', 'Pleural effusion', 'Pleural thickening', 'Pneumothorax',
       'Pulmonary fibrosis', 'No finding']



  0%|          | 0/188 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f130c412dd0>
Traceback (most recent call last):
  File "/scratch/scratch6/akansh12/env/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
    self._shutdown_workers()
  File "/scratch/scratch6/akansh12/env/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1320, in _shutdown_workers
    if w.is_alive():
  File "/tools/anaconda3/envs/torch37/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f130c412dd0>
Traceback (most recent call last):
  File "/scratch/scratch6/akansh12/env/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
    self._shutdown_workers()
  File "/scratch/scratch6/akansh12/env/lib/python3.7/site-packages/tor

In [83]:
AUC = []
for i,j in enumerate(labels): 
    fpr, tpr, thresholds =  roc_curve(np.array(testloader.dataset.labels)[:,i], np.array(predicted)[:,i])
    gmeans = np.sqrt(tpr * (1-fpr))
    ix = np.argmax(gmeans)
    print('Best Threshold=%f, G-Mean=%.3f' % (thresholds[ix], gmeans[ix]))
    AUC.append(auc(fpr,tpr))


Best Threshold=0.125292, G-Mean=0.814
Best Threshold=0.002689, G-Mean=0.791
Best Threshold=0.020505, G-Mean=0.733
Best Threshold=0.083901, G-Mean=0.854
Best Threshold=0.008590, G-Mean=0.846
Best Threshold=0.030781, G-Mean=0.783
Best Threshold=0.052850, G-Mean=0.834
Best Threshold=0.018769, G-Mean=0.773
Best Threshold=0.032329, G-Mean=0.780
Best Threshold=0.024871, G-Mean=0.775
Best Threshold=0.168697, G-Mean=0.900
Best Threshold=0.076900, G-Mean=0.808
Best Threshold=0.005970, G-Mean=0.829
Best Threshold=0.086214, G-Mean=0.789
Best Threshold=0.774439, G-Mean=0.846


In [94]:
B = [0.125292, 0.002689, 0.020505, 0.083901, 0.00859, 0.030781, 0.05285, 0.018769, 0.032329, 0.024871, 0.168697, 0.0769,
 0.00597, 0.086214, 0.774439]

In [71]:
gmeans = np.sqrt(tpr * (1-fpr))
ix = np.argmax(gmeans)
print('Best Threshold=%f, G-Mean=%.3f' % (thresholds[ix], gmeans[ix]))

Best Threshold=1.693290, G-Mean=nan


In [10]:
test_label_csv = pd.read_csv("/storage/home/akansh12/Vin-ChestXR-Abnormality-detection/Notebooks/14_class/labels/test.csv")

In [11]:
test_label_csv

Unnamed: 0,image_id,Aortic enlargement,Atelectasis,Calcification,Cardiomegaly,Consolidation,ILD,Infiltration,Lung Opacity,Nodule/Mass,Other lesion,Pleural effusion,Pleural thickening,Pneumothorax,Pulmonary fibrosis,No finding
0,e0dc2e79105ad93532484e956ef8a71a,0,1,1,1,0,1,0,0,0,0,1,0,1,0,0
1,0aed23e64ebdea798486056b4f174424,0,0,0,0,1,0,1,0,0,0,1,0,0,0,0
2,aa15cfcfca7605465ca0513902738b95,0,0,0,0,0,0,0,0,0,0,0,1,0,1,0
3,665c4a6d2693dc0286d65ab479c9b169,0,0,0,0,0,1,0,0,1,0,0,0,0,0,0
4,42da2c134b53cb5594774d3d29faac59,1,0,1,1,0,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2995,a039af299f86007d0d77da077a6def9a,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1
2996,aba3d1f5b1c04236f52a8980929b2cfa,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1
2997,6d3d6b53f358a983b486e9e03144eb62,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1
2998,d6678cb7ae39f575d35ab9da6d7cb171,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1


In [30]:
exp_5_model.eval()

EfficientNet(
  (conv_stem): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (act1): SiLU(inplace=True)
  (blocks): Sequential(
    (0): Sequential(
      (0): DepthwiseSeparableConv(
        (conv_dw): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
        (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act1): SiLU(inplace=True)
        (se): SqueezeExcite(
          (conv_reduce): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1))
          (act1): SiLU(inplace=True)
          (conv_expand): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1))
          (gate): Sigmoid()
        )
        (conv_pw): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act2): Identity()
   

In [43]:
with torch.no_grad():
    exp_5_model.eval()
    img_path = os.path.join('/scratch/scratch6/akansh12/DeepEXrays/data/data_256/test/','e0dc2e79105ad93532484e956ef8a71a' + '.png')
    print(exp_5_model(torch.unsqueeze(data_transforms['test'](Image.open(img_path)), dim = 0)))
    print(exp_5_model(torch.unsqueeze(data_transforms['test'](Image.open(img_path)), dim = 0)) > 0.5)

tensor([[0.3119, 0.0994, 0.0327, 0.5195, 0.2425, 0.1442, 0.2721, 0.7003, 0.0727,
         0.1539, 0.9409, 0.4394, 0.0285, 0.3474, 0.0223]])
tensor([[False, False, False,  True, False, False, False,  True, False, False,
          True, False, False, False, False]])


In [44]:
with torch.no_grad():
    exp_5_model.eval()
    img_path = os.path.join('/scratch/scratch6/akansh12/DeepEXrays/data/data_256/test/','42da2c134b53cb5594774d3d29faac59' + '.png')
    print(exp_5_model(torch.unsqueeze(data_transforms['test'](Image.open(img_path)), dim = 0)))
    print(exp_5_model(torch.unsqueeze(data_transforms['test'](Image.open(img_path)), dim = 0)) > 0.5)

tensor([[0.6930, 0.0013, 0.0089, 0.8981, 0.0046, 0.0043, 0.0060, 0.0471, 0.0353,
         0.0549, 0.0031, 0.0373, 0.0013, 0.0269, 0.0564]])
tensor([[ True, False, False,  True, False, False, False, False, False, False,
         False, False, False, False, False]])
