Importing Necessary Libraries

In [2]:
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import numpy as np
import pandas as pd
import csv
import cv2

import matplotlib.pyplot as plt

import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torchvision import datasets, transforms, models
from torch.utils.data.dataset import Dataset
import os.path
from os import path

In [3]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

Importing data set and splitting it into training, test and validation data sets

In [3]:
#reference : https://aws.amazon.com/blogs/machine-learning/classifying-high-resolution-chest-x-ray-medical-images-with-amazon-sagemaker/
data_dir = "../../../../../../data/kaggle/nih-chest-xrays/data/"
#splitting thr data set into training, validation and testing data sets
#70% training data
trainper = 0.7
#10% validation data
valper = 0.1
file_name = data_dir + 'Data_Entry_2017.csv'

a = pd.read_csv(file_name)
patient_ids = a['Patient ID']
uniq_pids = np.unique(patient_ids)
np.random.shuffle(uniq_pids)
total_ids = len(uniq_pids)

trainset = int(trainper*total_ids)
valset = trainset+int(valper*total_ids)
#remaining data is used as a test set
testset = trainset+valset

train = uniq_pids[:trainset]
val = uniq_pids[trainset+1:valset]
test = uniq_pids[valset+1:]
print('Number of patient ids: training: %d, validation: %d, testing: %d'%(len(train), len(val), len(test)))

traindata = a.loc[a['Patient ID'].isin(train)]
valdata = a.loc[a['Patient ID'].isin(val)]
testdata = a.loc[a['Patient ID'].isin(test)]

traindata.to_csv('traindata.csv', sep=',', header=False, index=False)
valdata.to_csv('valdata.csv', sep=',', header=False, index=False)
testdata.to_csv('testdata.csv', sep=',', header=False, index=False)

Number of patient ids: training: 21563, validation: 3079, testing: 6161


Loading training data into the data loader 

In [4]:
# Define transforms
train_transform = transforms.Compose([transforms.Resize(256),
                                        transforms.RandomResizedCrop(224),
                                        transforms.RandomHorizontalFlip(), # randomly flip and rotate
                                        transforms.RandomRotation(10),
                                        transforms.ToTensor(),
                                        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

test_transform = transforms.Compose([transforms.Resize(255),
                                      transforms.CenterCrop(224),
                                      transforms.ToTensor(),
                                      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])



Creating a bag of words of labels

In [5]:
def gen_set(csvfile, outputfile):
    disease_list = ['Atelectasis', 'Consolidation', 'Infiltration', 'Pneumothorax', 'Edema', 'Emphysema', \
                   'Fibrosis', 'Effusion', 'Pneumonia', 'Pleural_Thickening', 'Cardiomegaly', 'Nodule', 'Mass', \
                   'Hernia']
    alldiseases = {disease:i for i,disease in enumerate(disease_list)}
    with open(outputfile, 'w') as fp:
        with open(csvfile, 'r') as cfile:
            line = csv.reader(cfile, delimiter=',')
            index = 0
            for element in line:
                # the first column is the image filename, while the second
                # column has the list of diseases separated by |
                diseases = element[1].split('|')
                #fp.write('%d\t'%index)
                for d in alldiseases:
                    if ((d in diseases) and (d == 'Atelectasis')):
                        fp.write('%d\t'%1)
                    elif((d in diseases) and (d == 'Consolidation')):
                        fp.write('%d\t'%2)
                    elif((d in diseases) and (d == 'Infiltration')):
                        fp.write('%d\t'%3)
                    elif((d in diseases) and (d == 'Pneumothorax')):
                        fp.write('%d\t'%4)
                    elif((d in diseases) and (d == 'Edema')):
                        fp.write('%d\t'%5)
                    elif((d in diseases) and (d == 'Emphysema')):
                        fp.write('%d\t'%6)
                    elif((d in diseases) and (d == 'Fibrosis')):
                        fp.write('%d\t'%7)
                    elif((d in diseases) and (d == 'Effusion')):
                        fp.write('%d\t'%8)
                    elif((d in diseases) and (d == 'Pneumonia')):
                        fp.write('%d\t'%9)
                    elif((d in diseases) and (d == 'Pleural_Thickening')):
                        fp.write('%d\t'%10)
                    elif((d in diseases) and (d == 'Cardiomegaly')):
                        fp.write('%d\t'%11)
                    elif((d in diseases) and (d == 'Nodule')):
                        fp.write('%d\t'%12)
                    elif((d in diseases) and (d == 'Mass')):
                        fp.write('%d\t'%13)
                    elif((d in diseases) and (d == 'Hernia')):
                        fp.write('%d\t'%14)
                    else:
                        fp.write('%d\t'%0)
                fp.write('images/%s\n' % element[0])
                index += 1
#when used in local machine, to be commented out otherwise
#path = 'D:/EIT_AUS_TUB/SoSe2020_MLInMIP/MedicalImageProcessing/'    
#os.chdir(path)                 
gen_set('traindata.csv', 'chestxraytrain.txt')
gen_set('valdata.csv', 'chestxrayval.txt')
gen_set('testdata.csv', 'chestxraytest.txt')    

Function to identify the diagnosis label associated with an image

In [6]:
def Diagnosis(word):
    #print("word", word)
    if(word == 1):
        diag = 'Atelectasis'
    elif(word == 2):
        diag = 'Consolidation'
    elif(word == 3):
        diag = 'Infiltration'
    elif(word == 4):
        diag = 'Pneumothorax'
    elif(word == 5):
        diag = 'Edema'
    elif(word == 6):
        diag = 'Emphysema'
    elif(word == 7):
        diag = 'Fibrosis'
    elif(word == 8):
        diag = 'Effusion'
    elif(word == 9):
        diag = 'Pneumonia'
    elif(word == 10):
        diag = 'Pleural_Thickening'
    elif(word == 11):
        diag = 'Cardiomegaly'
    elif(word == 12):
        diag = 'Nodule'
    elif(word == 13):
        diag = 'Mass'
    elif(word == 14):
        diag = 'Hernia'
    else:
        diag = "Undiagnosed"

    #print(diag)
    return diag

Create a custom dataloader - __to do__
The idea is, in __getitem__:
A. to read the lst file which contains the labels associated with images
B. For each image, check which ailment is '1'
C. Append those to labels for that image
D. Return image and labels

In [7]:
class XRaysTrainDataset(Dataset):
    
    def __init__(self, csv_name, transform=None):
        file1 = open(csv_name, "r")
        self.data =  file1.readlines()           
        self.data_len = len(self.data)         
        self.transform = transform
             
    def __len__(self):
        return self.data_len
    
    def __getitem__(self, index):       
        #read labels in each line in the txt file
        cnt = 0 
        imageName = []            
        imgLab = [] 
        image_Label = []
        for word in self.data[index].split():          
            #diagnose until the last index which is the image name
            if(cnt < 14):
                #print(word)
                diag = Diagnosis(int(word))
                #print(diag)
                if(diag != 'Undiagnosed'):
                    imgLab.append(diag)
                    #image_Label.append(imgLab)
                    image_Label = imgLab.copy()

            if(cnt == 14):
                imageName.append(word)
            cnt+=1 
        
        print(imageName)  
        if not image_Label:
            image_Label.append("Undiagnosed")
            #image_Label = imgLab.copy()
                                     
        print(image_Label)
        return imageName, image_Label  
    

In [8]:
traindataLoader = XRaysTrainDataset('chestxraytrain.txt', transform = train_transform)
trainLoader = torch.utils.data.DataLoader(traindataLoader, batch_size = 5, shuffle = True)

In [9]:
imageName, label = next(iter(trainLoader))         
print("in trainloader")
print(imageName)
print(label)

['images/00013088_001.png']
['Undiagnosed']
['images/00030260_010.png']
['Emphysema']
['images/00014329_002.png']
['Undiagnosed']
['images/00022278_001.png']
['Atelectasis', 'Infiltration']
['images/00008183_000.png']
['Undiagnosed']
in trainloader
[('images/00013088_001.png', 'images/00030260_010.png', 'images/00014329_002.png', 'images/00022278_001.png', 'images/00008183_000.png')]
[('Undiagnosed', 'Emphysema', 'Undiagnosed', 'Atelectasis', 'Undiagnosed')]
