## Import Libraries

In [None]:
import numpy as np
import pandas as pd
from PIL import Image
import torch
from torch import optim, nn
import torch.nn.functional as F
from torchvision import transforms
import torchvision
from torch.utils.data.dataset import Dataset
import matplotlib.pyplot as plt
import os.path
from os import path
from collections import OrderedDict
import time


from torch.utils.tensorboard import SummaryWriter  # to print to tensorboard


## Custom Dataset Loading Class

In [None]:
# Class names for the multi-hot label vectors: entry k of a label vector
# corresponds to label_list[k] (14 findings plus 'No Finding').
label_list = [
    'Cardiomegaly',
    'Emphysema',
    'Effusion',
    'Hernia',
    'Nodule',
    'Pneumothorax',
    'Atelectasis',
    'Pleural_Thickening',
    'Mass',
    'Edema',
    'Consolidation',
    'Infiltration',
    'Fibrosis',
    'Pneumonia',
    'No Finding',
]

In [None]:
def resolve_full_path(img_name, data_root="~/data/kaggle/nih-chest-xrays/data", folder_count=13):
    """Locate an image file inside the ``images_NNN/images/`` sub-folders.

    The folders ``images_000`` ... ``images_<folder_count - 1>`` below
    ``data_root`` are probed in ascending order and the first path that
    exists on disk is returned.

    Args:
        img_name: File name of the image (e.g. taken from the CSV).
        data_root: Directory that contains the ``images_NNN`` folders.
            A leading ``~`` is expanded. Defaults to the location the
            notebook was originally written for.
        folder_count: Number of ``images_NNN`` folders to probe, starting
            at index 0.

    Returns:
        The full path of the first match.

    Raises:
        FileNotFoundError: If the image is in none of the probed folders.
            (This is a subclass of ``Exception``, so existing
            ``except Exception`` handlers keep working.)
    """
    # Expanding "~" does not depend on the folder index, so do it once.
    root = os.path.expanduser(data_root)

    img_path = ''
    for folder_idx in range(folder_count):
        folder_name = 'images_' + str(folder_idx).zfill(3)
        img_path = os.path.join(root, folder_name, 'images', img_name)
        if os.path.exists(img_path):
            return img_path

    # Report the last candidate tried to make path problems easy to spot.
    raise FileNotFoundError('Couldn\'t find: {} last:{}'.format(img_name, img_path))
        
class DatasetFromCSV(Dataset):
    """Multi-label image dataset described by a CSV file or a DataFrame.

    The table must provide the image file name in its first column and a
    ``'Finding Labels'`` column holding ``'|'``-separated class names.
    Each row's findings are encoded as a multi-hot vector whose positions
    follow ``label_list``.

    Args:
        csv_path: Path of the CSV file to read; takes precedence over
            ``data_frame`` when both are given.
        data_frame: Already-loaded table, used when ``csv_path`` is None.
        transform: Optional callable applied to the RGB PIL image in
            ``__getitem__`` (typically a torchvision transform pipeline).

    Raises:
        ValueError: If neither ``csv_path`` nor ``data_frame`` is given.
    """

    def __init__(self, csv_path=None, data_frame=None, transform=None):
        if csv_path is not None:
            self.data = pd.read_csv(csv_path)
        elif data_frame is not None:
            self.data = data_frame
        else:
            # ValueError is still an Exception, so broad handlers keep working.
            raise ValueError('No csv path or data frame provided')

        self.data_len = len(self.data.index)  # number of rows / samples

        # Columns are addressed by 0-based position, so the table layout matters.
        self.image_names = np.array(self.data.iloc[:, 0])  # image file names
        # The original author's notes give position 8 as image height and
        # position 7 as image width.
        self.heights = np.asarray(self.data.iloc[:, 8])
        self.widths = np.asarray(self.data.iloc[:, 7])

        # One multi-hot row per sample, one column per entry of label_list.
        self.labels = torch.zeros(self.data_len, len(label_list))
        labels = self.data.loc[:, 'Finding Labels'].map(lambda x: x.split('|'))
        self.multi_hot_encoding_label(labels)

        self.transform = transform

    def __len__(self):
        """Return the number of samples (rows of the table)."""
        return self.data_len

    def __getitem__(self, index):
        """Return ``(image, label)`` for the sample at ``index``.

        The image is loaded from disk, converted to RGB and passed through
        ``self.transform`` when one was supplied (otherwise the PIL image
        is returned unchanged). ``label`` is the sample's multi-hot float
        tensor of length ``len(label_list)``.
        """
        img_path = resolve_full_path(self.image_names[index])

        # Image.open is lazy and keeps the file open; convert() loads the
        # pixels into a new image, so the handle can be closed right away.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert("RGB")

        if self.transform is not None:
            img = self.transform(img)

        return img, self.labels[index]

    def multi_hot_encoding_label(self, labels):
        """Fill ``self.labels`` in place from per-sample finding lists.

        Args:
            labels: Iterable with one collection of class names per sample,
                in the same order as the rows of ``self.labels``. Names
                that are not in ``label_list`` are ignored.
        """
        for row, findings in enumerate(labels):
            for col, class_name in enumerate(label_list):
                if class_name in findings:
                    self.labels[row][col] = 1

## Define the model with Transfer learning

In [None]:
from torchvision import models

# Run on CUDA when a usable GPU is present; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device: ", device)

def create_model(num_classes=15):
    """Build a pre-trained ResNet-50 with a fresh multi-label output head.

    Only parameters whose names contain ``layer2``, ``layer3``, ``layer4``
    or ``fc`` stay trainable; all remaining backbone parameters are frozen.
    The original classifier is then replaced by a single linear layer that
    emits ``num_classes`` raw logits (no activation is applied here).

    Args:
        num_classes: Size of the output layer. The default of 15 covers
            the 14 conditions plus "No Finding".

    Returns:
        The configured ``torch.nn.Module``.
    """
    # NOTE(review): newer torchvision releases deprecate `pretrained=True`
    # in favour of the `weights=` argument; left unchanged to stay
    # compatible with the installed version - confirm before upgrading.
    model = models.resnet50(pretrained=True)

    print('----- STAGE 1 -----')  # only 'layer2', 'layer3', 'layer4' and 'fc' are trained
    trainable_tags = ('layer2', 'layer3', 'layer4', 'fc')
    for name, param in model.named_parameters():
        param.requires_grad = any(tag in name for tag in trainable_tags)

    # Replace the classifier. The layer is created after the freezing loop,
    # so its parameters keep the default requires_grad=True. Keeping the
    # nn.Sequential wrapper preserves the "fc.0.*" state-dict key names.
    model.fc = nn.Sequential(
        nn.Linear(model.fc.in_features, num_classes),
    )

    return model


model = create_model()
# print(model)

In [None]:
def sigmoid(x):
    """Return the element-wise logistic function ``1 / (1 + exp(-x))``.

    Uses a numerically stable formulation: the exponential is only ever
    taken of a non-positive number, so large negative inputs no longer
    overflow (and no longer emit a RuntimeWarning).

    Args:
        x: Scalar or array-like of real numbers. Non-float input is
            converted to float; float arrays keep their dtype.

    Returns:
        A NumPy scalar for scalar input, otherwise an array with the same
        shape as ``x``, with values in ``[0, 1]``.
    """
    x = np.asarray(x)
    if not np.issubdtype(x.dtype, np.floating):
        x = x.astype(float)

    decay = np.exp(-np.abs(x))  # always in (0, 1], cannot overflow
    # x >= 0: 1 / (1 + e^-x);  x < 0: e^x / (1 + e^x)  (same value, stable form)
    result = np.where(x >= 0, 1 / (1 + decay), decay / (1 + decay))
    return result[()]  # unwrap 0-d arrays so scalar input yields a scalar

def calc_roc(y_test, y_score):
    """Compute one ROC curve and its area per class column.

    Args:
        y_test: 2-D array of ground-truth labels, indexed [sample, class];
            the value 1 marks the positive class.
        y_score: 2-D array of predicted scores, indexed [sample, class].

    Returns:
        A tuple ``(fprs, tprs, roc_auc)``: two lists holding the
        false-positive and true-positive rate arrays of every class, and a
        dict mapping the class index to its area under the curve.
    """
    fprs, tprs, roc_auc = [], [], {}

    for class_idx in range(y_score.shape[1]):
        truth_col = y_test[:, class_idx]
        score_col = y_score[:, class_idx]
        fpr, tpr, _ = roc_curve(truth_col, score_col, pos_label=1)

        fprs.append(fpr)
        tprs.append(tpr)
        roc_auc[class_idx] = auc(fpr, tpr)

    return fprs, tprs, roc_auc

## Calculate ROC on the trained model

In [None]:


from torch.autograd import Variable
from sklearn.metrics import roc_curve, auc
# Mini-batch size shared by both loaders below; the evaluation cell also uses
# it to size its prediction buffers.
batch_size_ = 10

# Define transforms
# NOTE(review): the pipeline contains random augmentation (random crop, flip,
# rotation) and is applied to the validation and test sets, so evaluation
# inputs differ from run to run - confirm this is intended.
# Normalize with mean 0.5 / std 0.5 per channel maps ToTensor's [0, 1] range
# to [-1, 1].
transform = transforms.Compose([transforms.Resize(256),
                                transforms.RandomResizedCrop(224),
                                transforms.RandomHorizontalFlip(), # randomly flip and rotate
                                transforms.RandomRotation(10),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# Define custom data loader
# Test split, wrapped in the CSV-backed dataset class.
test_dataset = DatasetFromCSV('~/group/donut/medical_ip/Multi_Label_Dataloader_and_Classifier/testdata_paul.csv',transform=transform)

# Validation split, same transform pipeline.
valid_dataset = DatasetFromCSV('~/group/donut/medical_ip/Multi_Label_Dataloader_and_Classifier/valdata_paul.csv',transform=transform)
# The test CSV loaded a second time as a plain DataFrame (not used by the
# loaders below).
test_df = pd.read_csv('~/group/donut/medical_ip/Multi_Label_Dataloader_and_Classifier/testdata_paul.csv')

# Both loaders shuffle and use 6 worker processes.
valid_loader = torch.utils.data.DataLoader(dataset=valid_dataset,
                                                    batch_size=batch_size_,
                                                    num_workers=6,
                                                    shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                                batch_size=batch_size_,
                                                num_workers=6,
                                                shuffle=True)


In [None]:
# Preview the first five images of one batch drawn from the validation loader.
# NOTE(review): the figure title says "Training Images" although the batch
# comes from `valid_loader`.
import torchvision.utils as vutils

iterat = iter(valid_loader)
real_batch = next(iterat)

# real_batch[0] holds the images; tile the first five into a single grid image.
preview_grid = vutils.make_grid(real_batch[0][:5], padding=2, normalize=True).cpu()

plt.figure(figsize=(10, 10))
plt.axis("off")
plt.title("Training Images")
# Move the channel axis last, as matplotlib expects (H, W, C).
plt.imshow(np.transpose(preview_grid, (1, 2, 0)))

In [None]:
# Advance the existing iterator and preview the first five images of the
# next batch.
real_batch = next(iterat)

preview_grid = vutils.make_grid(real_batch[0][:5], padding=2, normalize=True).cpu()

plt.figure(figsize=(10, 10))
plt.axis("off")
plt.title("Training Images")
# Move the channel axis last, as matplotlib expects (H, W, C).
plt.imshow(np.transpose(preview_grid, (1, 2, 0)))

In [None]:
# Directory holding the trained checkpoints and the known checkpoint names.
models_dir = os.path.expanduser('~/group/donut/medical_ip/trained_models')
resnet_lr_0pt001_model_name = 'model_resenet18_multi_class_multi_label_k_fold_0pt001_lr.pt'
resnet_model_name = 'model_resenet18_multi_class_multi_label_k_fold.pt'
dc_gan_model_name = 'resnet18_dcgan_kfold.pt'

# Checkpoint evaluated in this run.
# NOTE(review): the checkpoint names mention "resnet18" while create_model()
# builds a ResNet-50; load_state_dict will fail if the file really holds
# ResNet-18 weights - confirm.
model_name = dc_gan_model_name
model_path = os.path.join(models_dir, model_name)
print('Model save/load location: {}'.format(model_path))

# Load the weights on the CPU first, then move the model to the target device.
model = create_model()
model.load_state_dict(torch.load(model_path, map_location='cpu'))
model.to(device)
model.eval()

n_classes = 15

# Size the buffers by the real number of samples. Sizing them by
# len(test_loader) * batch_size_ leaves all-zero rows behind whenever the
# last batch is short, and those rows would distort the ROC statistics.
n_samples = len(test_loader.dataset)
outputs = np.zeros((n_samples, n_classes))
labels = np.zeros((n_samples, n_classes))

i = 0  # write position in the buffers
print(len(test_loader))  # number of batches about to be processed
with torch.no_grad():  # inference only: skip autograd bookkeeping
    for batch_idx, (data, label) in enumerate(test_loader):
        # forward pass: raw logits for the batch, brought back to the CPU
        output = model(data.to(device)).cpu().numpy()
        n_batch = output.shape[0]

        outputs[i:i + n_batch, :] = sigmoid(output)  # logits -> probabilities
        labels[i:i + n_batch, :] = label.cpu().numpy()
        i += n_batch

print('loaded predictions')
fprs, tprs, roc_auc = calc_roc(labels, outputs)
print(roc_auc)
print('fpr: ', fprs[0])

import matplotlib.pyplot as plt

# One row of axes per class.
# NOTE(review): both columns of a row show the same class's curve.
fig, ax = plt.subplots(nrows=n_classes, ncols=2, figsize=(10,40))
for i, row in enumerate(ax):
    for col in row:
        col.plot(fprs[i], tprs[i], label="data 1, auc=" + str(roc_auc[i]))
        col.legend(loc=4)

plt.show()

# Stand-alone plot of the ROC curve for class 0.
plt.plot(fprs[0], tprs[0], label="data 1, auc=" + str(roc_auc[0]))
plt.legend(loc=4)
plt.show()


In [None]:
# IPython shell escape: print the current working directory.
# NOTE(review): presumably a check of where the relative savefig path used
# in the next cell resolves to - confirm.
!pwd

In [None]:
# Recompute the per-class ROC data from the prediction buffers filled by the
# evaluation cell, then plot and export one ROC curve per class.
print('loaded predictions')
fprs, tprs, roc_auc = calc_roc(labels, outputs)
print(roc_auc)

import matplotlib.pyplot as plt


n_classes = 15

fig, ax = plt.subplots(nrows=n_classes, ncols=1, figsize=(10,80))
for i, row in enumerate(ax):
    # The return value of plot() is deliberately not stored: binding it to
    # `ax` would shadow the array of axes being iterated.
    row.plot(fprs[i], tprs[i], label="auc=" + str(roc_auc[i]))
    row.title.set_text("ROC For class {} of {}".format(str(i), model_name.split('.')[0]))
    row.legend(loc=4)

    # Crop the saved image to this subplot: its extent in inches, slightly
    # enlarged so the title and tick labels are included.
    extent = row.get_window_extent().transformed(fig.dpi_scale_trans.inverted())
    # NOTE(review): the file name uses the full `model_name` (including its
    # ".pt" suffix) while the title strips it - left unchanged on purpose.
    fig.savefig('../../mlmip_trained_model/{}_roc_class_{}.png'.format(model_name, i), bbox_inches=extent.expanded(1.1, 1.2))


plt.show()



In [None]:

# Doing inference on cpu as it doesn't take much effort; feel free to change.
# Had some trouble loading it on GPU.
iterat = iter(valid_loader)
image, label = next(iterat)
image, label = next(iterat)  # the first batch is discarded; the second one is inspected


model = create_model()
model.load_state_dict(torch.load(model_path, map_location='cpu'))
model.eval()

data = image.to('cpu')
# forward pass: compute predicted outputs by passing inputs to the model
with torch.no_grad():  # inference only: no autograd graph needed
    output = model(data)
output = output.to('cpu').detach().numpy()
target = label.to('cpu').detach().numpy()
cur = 7  # index inside the batch of the sample to inspect
print(np.shape(target))
print(np.shape(output))

print("Predicted: {}".format(output[cur]))
print("Predicted sigmoid: {}".format(sigmoid(output[cur])))

print("Actual: {}".format(target[cur]))

print("Predicted Max : {}".format(output[cur].max()))
print("Actual Max : {}".format(target[cur].max()))

# Index of the most probable class. Sigmoid has to be applied to the logits
# before argmax; applying it to the argmax index itself is meaningless.
print("Predicted Sigmoid Arg Max : {}".format(sigmoid(output[cur]).argmax()))
print("Actual Arg Max : {}".format(target[cur].argmax()))


print(output[cur].max())
print("\nimage batch shape: ", image.shape)
print("single image shape: ", image[cur].shape)




# 1 channel image (channel index 1 of the selected sample)
img_1_channel = image.numpy()[cur][1]
print("img_1channel shape: ", img_1_channel.shape)
plt.figure()
plt.imshow(img_1_channel)

# 3 channel image, channel axis moved last for matplotlib
# NOTE(review): the tensor is still normalized to [-1, 1], so imshow clips
# the negative values, and `cmap` has no effect on RGB data - confirm
# whether the image should be de-normalized before display.
plt.figure()
img_3_channel = image[cur].permute(1, 2, 0)
plt.imshow(img_3_channel, cmap='cool')
print("img_3channel shape:", img_3_channel.shape)

# print label
print("labels:",label)


# Test classification on a single image.