Importing Necessary Libraries

In [2]:
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import numpy as np
import pandas as pd
import csv
import cv2

import matplotlib.pyplot as plt

import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torchvision import datasets, transforms, models
from torch.utils.data.dataset import Dataset
import os.path
from os import path

In [8]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

Importing data set and splitting it into training, test and validation data sets

In [9]:
#reference : https://aws.amazon.com/blogs/machine-learning/classifying-high-resolution-chest-x-ray-medical-images-with-amazon-sagemaker/
data_dir = "../../../../../../data/kaggle/nih-chest-xrays/data/"
#splitting thr data set into training, validation and testing data sets
#70% training data
trainper = 0.7
#10% validation data
valper = 0.1
file_name = data_dir + 'Data_Entry_2017.csv'

a = pd.read_csv(file_name)
patient_ids = a['Patient ID']
uniq_pids = np.unique(patient_ids)
np.random.shuffle(uniq_pids)
total_ids = len(uniq_pids)

trainset = int(trainper*total_ids)
valset = trainset+int(valper*total_ids)
#remaining data is used as a test set
testset = trainset+valset

train = uniq_pids[:trainset]
val = uniq_pids[trainset+1:valset]
test = uniq_pids[valset+1:]
print('Number of patient ids: training: %d, validation: %d, testing: %d'%(len(train), len(val), len(test)))

traindata = a.loc[a['Patient ID'].isin(train)]
valdata = a.loc[a['Patient ID'].isin(val)]
testdata = a.loc[a['Patient ID'].isin(test)]

traindata.to_csv('traindata.csv', sep=',', header=False, index=False)
valdata.to_csv('valdata.csv', sep=',', header=False, index=False)
testdata.to_csv('testdata.csv', sep=',', header=False, index=False)

Number of patient ids: training: 21563, validation: 3079, testing: 6161


Loading training data into the data loader 

In [29]:
# Define transforms
train_transform = transforms.Compose([transforms.Resize(256),
                                        transforms.RandomResizedCrop(224),
                                        transforms.RandomHorizontalFlip(), # randomly flip and rotate
                                        transforms.RandomRotation(10),
                                        transforms.ToTensor(),
                                        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

test_transform = transforms.Compose([transforms.Resize(255),
                                      transforms.CenterCrop(224),
                                      transforms.ToTensor(),
                                      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])



Creating a bag of words of labels

In [3]:
def gen_set(csvfile, outputfile):
    disease_list = ['Atelectasis', 'Consolidation', 'Infiltration', 'Pneumothorax', 'Edema', 'Emphysema', \
                   'Fibrosis', 'Effusion', 'Pneumonia', 'Pleural_Thickening', 'Cardiomegaly', 'Nodule', 'Mass', \
                   'Hernia']
    alldiseases = {disease:i for i,disease in enumerate(disease_list)}
    with open(outputfile, 'w') as fp:
        with open(csvfile, 'r') as cfile:
            line = csv.reader(cfile, delimiter=',')
            index = 0
            for element in line:
                # the first column is the image filename, while the second
                # column has the list of diseases separated by |
                diseases = element[1].split('|')
                fp.write('%d\t'%index)
                for d in alldiseases:
                    if d in diseases:
                        fp.write('%d\t'%1)
                    else:
                        fp.write('%d\t'%0)
                fp.write('images/%s\n' % element[0])
                index += 1
                 
gen_set('traindata.csv', 'chestxraytrain.lst')
gen_set('valdata.csv', 'chestxrayval.lst')
gen_set('testdata.csv', 'chestxraytest.lst')    

Create a custom dataloader - __to do__
The idea is, in __getitem__:
A. to read the lst file which contains the labels associated with images
B. For each image, check which ailment is '1'
C. Append those to labels for that image
D. Return image and labels

In [None]:
class XRaysTrainDataset(Dataset):
    
    def __init__(self, csv_path, transform=None):
        self.data = pd.read_csv(csv_path)
        self.data_len = len(self.data.index)            # csv data length
        
        self.image_names = np.array(self.data.iloc[:,0])  # image names
        self.heights = np.asarray(self.data.iloc[:,8])    # heights are at 8th column 
        self.widths =  np.asarray(self.data.iloc[:,7])    # widths are at  7th column
        
        # Only use 'Peumonia' for single-label classification
        # createa a new col
        #self.data.loc[:,'Infiltration'] = 0
        labels = self.data.loc[:,'Finding Labels'].map(lambda x: x.split('|'))
        for i,label in enumerate(labels):
            if 'Infiltration' in label:
                self.data.loc[i,'Infiltration'] = 1
        self.labels = torch.LongTensor(self.data.loc[:,'Infiltration'])     # all labels at  1st column
        
        self.transform = transform
        
    def __len__(self):
        return self.data_len
    
    def __getitem__(self, index):
        # Read 1 image name
        img_name = self.image_names[index]
        
        # Read 1 image file
        folder_idx_range = 13
        for folder_idx in range(folder_idx_range):
            path_prefix = "/home/tu-weihengxia/data/kaggle/nih-chest-xrays/data/images_"
            path_suffix = "/images/"
            img_folder_path = path_prefix + str(folder_idx).zfill(3) + path_suffix
            img_path = os.path.join(img_folder_path, img_name)
            if(path.exists(os.path.join(img_folder_path, img_name))):
                # print("image path: ", img_path)
                img_as_img = Image.open(os.path.join(img_folder_path, img_name))
                break
        
        img_as_img = img_as_img.convert("RGB")
        # Transform image to tensor
        img_as_tensor = self.transform(img_as_img)
        
        # Read 1 label:

        image_label = self.labels[index]
        
        return img_as_tensor, image_label
    

In [49]:
traindataLoader = XRaysTrainDataset('traindata.csv', transform = train_transform)
trainLoader = torch.utils.data.DataLoader(traindataLoader, batch_size = 10, num_workers = 5, shuffle = True)

In [51]:
#image, label = next(iter(trainLoader))

#print(label)

KeyError: Caught KeyError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/group/donut/Anupama/mlmip-2020-notebooks/pytorch-venv/lib/python3.6/site-packages/pandas/core/indexes/base.py", line 4411, in get_value
    return libindex.get_value_at(s, key)
  File "pandas/_libs/index.pyx", line 44, in pandas._libs.index.get_value_at
  File "pandas/_libs/index.pyx", line 45, in pandas._libs.index.get_value_at
  File "pandas/_libs/util.pxd", line 98, in pandas._libs.util.get_value_at
  File "pandas/_libs/util.pxd", line 83, in pandas._libs.util.validate_indexer
TypeError: 'str' object cannot be interpreted as an integer

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/group/donut/Anupama/mlmip-2020-notebooks/pytorch-venv/lib/python3.6/site-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
    data = fetcher.fetch(index)
  File "/group/donut/Anupama/mlmip-2020-notebooks/pytorch-venv/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/group/donut/Anupama/mlmip-2020-notebooks/pytorch-venv/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "<ipython-input-48-9f3a40d1d291>", line 12, in __getitem__
    img = cv2.imread(row['Image Index'])
  File "/group/donut/Anupama/mlmip-2020-notebooks/pytorch-venv/lib/python3.6/site-packages/pandas/core/series.py", line 871, in __getitem__
    result = self.index.get_value(self, key)
  File "/group/donut/Anupama/mlmip-2020-notebooks/pytorch-venv/lib/python3.6/site-packages/pandas/core/indexes/base.py", line 4419, in get_value
    raise e1
  File "/group/donut/Anupama/mlmip-2020-notebooks/pytorch-venv/lib/python3.6/site-packages/pandas/core/indexes/base.py", line 4405, in get_value
    return self._engine.get_value(s, k, tz=getattr(series.dtype, "tz", None))
  File "pandas/_libs/index.pyx", line 80, in pandas._libs.index.IndexEngine.get_value
  File "pandas/_libs/index.pyx", line 90, in pandas._libs.index.IndexEngine.get_value
  File "pandas/_libs/index.pyx", line 138, in pandas._libs.index.IndexEngine.get_loc
  File "pandas/_libs/hashtable_class_helper.pxi", line 1619, in pandas._libs.hashtable.PyObjectHashTable.get_item
  File "pandas/_libs/hashtable_class_helper.pxi", line 1627, in pandas._libs.hashtable.PyObjectHashTable.get_item
KeyError: 'Image Index'
