In [1]:
import os, random
import ntpath
import SimpleITK
from matplotlib import pyplot as plt
import matplotlib
import pandas as pd
import numpy as np
import pickle
import math
from sklearn.model_selection import StratifiedShuffleSplit
from tqdm import tqdm_notebook as tqdm

# Load Data

In [3]:
# Load the training file list (a pickled DataFrame) and preview its first rows.
# NOTE: unpickling can execute arbitrary code; only load files from a trusted source.
data = pd.read_pickle("pickle files/train-data-filelist.pkl")
data.head()

Unnamed: 0,fissuremask,image,label,lungmask,name,completeness
0,/projects/0/ismi2018/FINALPROJECTS/CHESTCT_FIS...,/projects/0/ismi2018/FINALPROJECTS/CHESTCT_FIS...,a,/projects/0/ismi2018/FINALPROJECTS/CHESTCT_FIS...,a06,52.041
1,/projects/0/ismi2018/FINALPROJECTS/CHESTCT_FIS...,/projects/0/ismi2018/FINALPROJECTS/CHESTCT_FIS...,a,/projects/0/ismi2018/FINALPROJECTS/CHESTCT_FIS...,a13,69.0858
2,/projects/0/ismi2018/FINALPROJECTS/CHESTCT_FIS...,/projects/0/ismi2018/FINALPROJECTS/CHESTCT_FIS...,a,/projects/0/ismi2018/FINALPROJECTS/CHESTCT_FIS...,a00,20.27
3,/projects/0/ismi2018/FINALPROJECTS/CHESTCT_FIS...,/projects/0/ismi2018/FINALPROJECTS/CHESTCT_FIS...,a,/projects/0/ismi2018/FINALPROJECTS/CHESTCT_FIS...,a15,70.1147
4,/projects/0/ismi2018/FINALPROJECTS/CHESTCT_FIS...,/projects/0/ismi2018/FINALPROJECTS/CHESTCT_FIS...,a,/projects/0/ismi2018/FINALPROJECTS/CHESTCT_FIS...,a21,76.2946


In [4]:
# Load the precomputed patch indices.
# The file is opened in a `with` block so the handle is closed right after loading.
# NOTE: pickle can execute arbitrary code on load; only use files from a trusted source.
with open("patch_indices.p", "rb") as patch_indices_file:
    patch_indices = pickle.load(patch_indices_file)

In [5]:
# patch_indices unpacks into exactly two collections:
# fc_indices are the fissure complete label indices
# fi_indices are the fissure incomplete label indices
# getFissureSamples reads each entry as a mapping from z-slice to a list of (y, x) coordinates.
fc_indices, fi_indices = patch_indices

In [10]:
# Inspect entry 0 of the fissure-complete indices (a mapping from z-slice to (y, x) coordinates).
fc_indices[0]

{14: [(76, 118), (76, 119)],
 15: [(76, 118), (76, 119)],
 16: [(76, 118), (76, 119)],
 17: [(76, 118), (76, 119)],
 18: [(74, 120), (75, 120), (76, 118), (76, 119)],
 19: [(74, 120), (75, 120), (76, 118), (76, 119)],
 20: [(74, 120), (74, 121), (75, 120), (75, 121), (76, 118), (76, 119)],
 21: [(76, 116), (76, 117), (76, 118), (76, 119), (76, 120), (76, 121)],
 22: [(76, 116), (76, 117), (76, 118), (76, 119), (76, 120), (76, 121)],
 23: [(76, 116), (76, 117), (76, 118), (76, 119), (76, 120), (76, 121)],
 24: [(76, 116), (76, 117), (76, 118), (76, 119), (76, 120), (76, 121)],
 25: [(76, 116), (76, 117), (76, 118), (76, 119), (76, 120), (76, 121)],
 26: [(76, 116), (76, 117), (76, 118), (76, 119), (76, 120), (76, 121)],
 27: [(76, 116), (76, 117), (76, 118), (76, 119), (76, 120), (76, 121)],
 28: [(76, 116), (76, 117), (76, 118), (76, 119), (76, 120), (76, 121)],
 29: [(76, 116), (76, 117), (76, 118), (76, 119), (76, 120), (76, 121)],
 30: [(76, 116), (76, 117), (76, 118), (76, 119), (7

# Split Data

In [5]:
# Values of the 'label' column, one per row of the file list.
# NOTE(review): `labels` is not referenced by the cells visible below; the split
# re-reads data['label'].values directly.
labels = data['label'].values

In [6]:
# A single stratified shuffle split that holds out 10% of the rows.
# random_state is fixed so that Restart & Run All reproduces the same train/validation split.
splitter = StratifiedShuffleSplit(n_splits=1, test_size=0.1, random_state=42)

In [7]:
# splitter.split yields *positional* row numbers, so the rows are selected with .iloc.
# (.loc would interpret them as index labels, which only coincides for a default 0..n-1 index.)
# The splitter is configured for a single split, so the first (only) split is taken.
train_index, test_index = next(splitter.split(data, data['label'].values))
train_set = data.iloc[train_index]
validation_set = data.iloc[test_index]

# Patch Generator

In [8]:
class PatchExtractor:
    '''
    PatchExtractor: class used to extract fixed-size 3D patches from images.

    patch_size: (z, y, x) extent of an input patch
    output_shape: (z, y, x) extent of an output patch
    '''

    def __init__(self, patch_size, output_shape):
        self.patch_size = patch_size
        self.output_shape = output_shape

    def get_patch(self, image, location, isOutput):
        '''
        image: a numpy array representing the input image
        location: a tuple with an z, y, and x coordinate
        isOutput: if True the patch has size 'output_shape', otherwise 'patch_size'

        return a patch of shape (z, y, x, 1) from the image, centered at 'location'.
        If the patch does not fit completely inside the image, a message is printed
        and an all-zero patch of the requested shape is returned.
        '''

        z, y, x = location
        c, h, w = tuple(self.output_shape if isOutput else self.patch_size)
        patch = np.zeros((c, h, w) + (1,))

        # Start/stop per axis, using the same rounding as the slicing expression always used.
        bounds = [
            (int(center - (extent / 2)), int(center + (extent / 2)))
            for center, extent in zip((z, y, x), (c, h, w))
        ]

        # Explicit boundary check instead of a bare `except`: a negative start would
        # otherwise wrap around (numpy negative indexing) and could silently return a
        # patch from the wrong place, and unrelated errors would be hidden.
        fits = all(
            start >= 0 and stop <= dim and (stop - start) == extent
            for (start, stop), dim, extent in zip(bounds, image.shape, (c, h, w))
        )
        if not fits:
            print("Patch out of boundary, please make sure that the patch location is not out of boundary.")
            return patch

        (z0, z1), (y0, y1), (x0, x1) = bounds
        patch[:, :, :, 0] = image[z0:z1, y0:y1, x0:x1]
        return patch

# Additional functions

In [29]:
def computeBackgroundSamples(img_array,lbl_array,patch_size,output_shape):
    """Collect (z, y, x) grid locations whose label value is 0.

    img_array: 3D image array; only its shape is used
    lbl_array: 3D label array, indexed as lbl_array[z, y, x]
    patch_size: (z, y, x) extent of an input patch
    output_shape: (z, y, x) extent of an output patch

    Candidate locations start half a patch away from the lower border and stop
    before half a patch away from the upper border. The stride per axis is half
    the output extent, so neighbouring output patches overlap.
    """
    depth, height, width = img_array.shape
    pz, py, px = patch_size
    oz, oy, ox = output_shape

    # lower (inclusive) and upper (exclusive) limit per axis
    z_lo, z_hi = int(pz / 2), int(depth - (pz / 2))
    y_lo, y_hi = int(py / 2), int(height - (py / 2))
    x_lo, x_hi = int(px / 2), int(width - (px / 2))

    # stride per axis: half of the output extent
    stride_z, stride_y, stride_x = int(oz / 2), int(oy / 2), int(ox / 2)

    # keep only the grid points that are labelled as background
    return [
        (z, y, x)
        for z in range(z_lo, z_hi, stride_z)
        for y in range(y_lo, y_hi, stride_y)
        for x in range(x_lo, x_hi, stride_x)
        if lbl_array[z, y, x] == 0
    ]

In [30]:
def getBackgroundSamples(patch_size,output_shape,dataset):
    """Compute the background sample locations for every image in 'dataset'.

    patch_size: (z, y, x) extent of an input patch
    output_shape: (z, y, x) extent of an output patch
    dataset: DataFrame with an 'image' and a 'fissuremask' column holding file paths

    return a dict that maps each index label of 'dataset' to the list of (z, y, x)
    locations returned by computeBackgroundSamples for that image.
    """
    # Walk index labels and paths in lockstep; this replaces a per-row
    # `index.tolist().index(label)` search that made the loop quadratic.
    img_indices = dataset.index.values.tolist()
    img_paths = dataset['image'].values
    lbl_paths = dataset['fissuremask'].values

    samplesDict = {}
    for img_index, img_path, lbl_path in zip(img_indices, img_paths, lbl_paths):
        # load the image and its fissure mask
        img_array = readImg(img_path)
        lbl_array = readImg(lbl_path)

        # store the samples with the image index as key
        samplesDict[img_index] = computeBackgroundSamples(img_array,lbl_array,patch_size,output_shape)

    return samplesDict

In [31]:
def getBoundaries(img_array, patch_size):
    """Return the (min, max) patch-center limits per axis.

    img_array: 3D array; only its shape is used
    patch_size: (z, y, x) extent of a patch

    return ((min_z, max_z), (min_y, max_y), (min_x, max_x)), where each minimum is
    half the patch extent and each maximum is the image extent minus half the
    patch extent, both truncated to int.
    """
    depth, height, width = img_array.shape
    half_z, half_y, half_x = (extent / 2 for extent in patch_size)

    z_bounds = (int(half_z), int(depth - half_z))
    y_bounds = (int(half_y), int(height - half_y))
    x_bounds = (int(half_x), int(width - half_x))

    return (z_bounds, y_bounds, x_bounds)

In [12]:
def getFissureSamples(patch_size,indices,dataset):
    """Filter the fissure locations of every image down to those inside the patch boundaries.

    patch_size: (z, y, x) extent of an input patch
    indices: sequence with one entry per image; each entry maps a z-slice to a
             list of (y, x) coordinates
    dataset: DataFrame with an 'image' column; entry i of 'indices' is matched to
             the row whose index label equals i

    return a dict that maps i to the list of (z, y, x) locations lying within the
    limits computed by getBoundaries (limits are inclusive on both sides).
    """
    samplesDict = {}

    for img_nr in range(len(indices)):
        # locate the row whose index label equals img_nr and load its image
        row_position = dataset.index.values.tolist().index(img_nr)
        volume = readImg(dataset.iloc[row_position]['image'])

        # the image shape determines the allowed range per axis
        (min_z, max_z), (min_y, max_y), (min_x, max_x) = getBoundaries(volume, patch_size)

        # fissure (in)complete coordinates of this image, grouped per z-slice
        per_slice = indices[img_nr]

        # keep only coordinates that are inside the boundaries on every axis
        samplesDict[img_nr] = [
            (z, y, x)
            for z in list(per_slice.keys())
            if z >= min_z and z <= max_z
            for (y, x) in per_slice[z]
            if y >= min_y and y <= max_y and x >= min_x and x <= max_x
        ]

    return samplesDict

In [13]:
def readImg(img_path):
    """Read the image file at 'img_path' with SimpleITK and return it as a numpy array."""
    return SimpleITK.GetArrayFromImage(SimpleITK.ReadImage(img_path))

# Batch Creator

In [65]:
class BatchCreator:
    '''
    BatchCreator: builds batches of image patches and matching fissure-mask patches.

    patch_extractor: object with 'patch_size', 'output_shape' and 'get_patch'
    dataset: DataFrame with 'image', 'fissuremask' and 'lungmask' path columns
    sampleLocations: (background, fissure complete, fissure incomplete) dicts that
                     map an image index to a list of (z, y, x) locations
    batch_division: (nr of fissure complete, nr of fissure incomplete) patches per
                    batch; the remainder of the batch is filled with background
    nr_samples: threshold for the sample counter (see updateCounter)
    '''

    def __init__(self,patch_extractor,dataset,sampleLocations,batch_division, nr_samples):
        self.patch_extractor = patch_extractor
        self.patch_size = self.patch_extractor.patch_size

        self.dataset = dataset

        self.img_list = dataset['image'].values
        self.lbl_list = dataset['fissuremask'].values
        self.msk_list = dataset['lungmask'].values

        self.img_indices = dataset.index.values.tolist()

        self.bSamples, self.fcSamples, self.fiSamples = sampleLocations

        self.batch_division = batch_division

        self.nr_samples = nr_samples

        # number of samples drawn since the counter was last reset
        self.counter = 0

        # images handed out by pickImage in the current pass over img_indices
        self.examined_images = []

    def create_batch(self, batch_size, img_index):
        '''
        Draw 'batch_size' random patches from the image with index 'img_index'.

        return (x_data, fissure_data): the image patches and the fissure-mask
        patches taken at the same locations.
        '''
        x_data, y_data, fissure_data = self.initializeOutputArrays(batch_size)

        b_samples = self.bSamples[img_index]
        fc_samples = self.fcSamples[img_index]
        fi_samples = self.fiSamples[img_index]

        img_array, lbl_array, msk_array = self.img2array(img_index)

        fc_nr, fi_nr = self.checkEmpty(b_samples,fc_samples,fi_samples,self.batch_division)

        for i in range(batch_size):
            # the first fc_nr patches are fissure complete (class 2), the next fi_nr
            # are fissure incomplete (class 1), the remainder is background (class 0)
            if i < fc_nr:
                candidates, class_id = fc_samples, 2
            elif i < (fc_nr + fi_nr):
                candidates, class_id = fi_samples, 1
            else:
                candidates, class_id = b_samples, 0

            location = random.choice(candidates)
            x_data[i] = self.patch_extractor.get_patch(img_array,location,False)
            # NOTE: the one-hot labels are filled in but are not part of the return value
            y_data[i,0,0,0,class_id] = 1
            fissure_data[i] = self.patch_extractor.get_patch(lbl_array,location,True)

        self.updateCounter(batch_size)

        # Image bookkeeping is done by pickImage only; appending here as well made
        # pickImage skip images and run past the end of img_indices.
        return x_data, fissure_data

    def pickImage(self):
        '''return the next image index, starting over once every image has been handed out'''
        if len(self.examined_images) >= len(self.img_indices):
            self.examined_images = []
        img_index = self.img_indices[len(self.examined_images)]
        self.examined_images.append(img_index)
        return img_index

    def initializeOutputArrays(self, batch_size):
        '''return zero-filled arrays for the patches, the one-hot labels and the fissure-mask patches'''
        # patch array
        x_data = np.zeros((batch_size, *self.patch_extractor.patch_size,1))
        # label array (one-hot structure)
        y_data = np.zeros((batch_size, 1, 1, 1, 3))
        # fissure mask patch array
        fissure_data = np.zeros((batch_size, *self.patch_extractor.output_shape,1))

        return x_data, y_data, fissure_data

    def img2array(self, img_index):
        '''return the image, fissure mask and lung mask of 'img_index' as numpy arrays'''
        # look the row up once instead of once per column
        row = self.dataset.iloc[self.img_indices.index(img_index)]

        img_array = readImg(row['image'])
        lbl_array = readImg(row['fissuremask'])
        msk_array = readImg(row['lungmask'])
        return img_array, lbl_array, msk_array

    def checkEmpty(self,b_samples,fc_samples,fi_samples,batch_division):
        '''
        Shift the batch division when a fissure class has no sample locations.

        return (fc_nr, fi_nr): if one class is empty, its share goes to the other
        class; if both are empty, both shares are 0. 'b_samples' is not inspected.
        '''
        fc_nr, fi_nr = batch_division

        has_fc = len(fc_samples) != 0
        has_fi = len(fi_samples) != 0

        if has_fc and has_fi:
            return fc_nr, fi_nr
        if has_fc:
            return fc_nr + fi_nr, 0
        if has_fi:
            return 0, fi_nr + fc_nr
        return 0, 0

    def updateCounter(self,batch_size):
        '''add 'batch_size' to the counter and reset it to 0 once it exceeds nr_samples'''
        self.counter += batch_size
        # NOTE(review): with '>' the counter resets only after nr_samples is exceeded,
        # so one more batch is drawn when the counter equals nr_samples exactly --
        # confirm whether '>=' was intended.
        if self.counter > self.nr_samples:
            self.counter = 0

    def counterReset(self):
        '''return True when the counter is at 0, i.e. it has just been reset'''
        return self.counter == 0

    def get_generator(self, batch_size):
        '''returns a generator that will yield batches infinitely'''
        img_index = self.pickImage()
        while True:
            yield self.create_batch(batch_size,img_index)
            # counterReset has to be *called*; the bare method object is always truthy,
            # which switched to a new image on every batch. Checking after the batch
            # also keeps the first picked image from being skipped.
            if self.counterReset():
                img_index = self.pickImage()

In [54]:
# Patch and batch configuration.
# patch_size / output_shape: (z, y, x) extents of the input patch and of the output patch.
patch_size = (132,132,132)
output_shape = (44,44,44)
patch_extractor = PatchExtractor(patch_size,output_shape)
batch_size = 4
# (fissure complete, fissure incomplete) patches per batch; the rest of the batch is background.
batch_division = (1,1)
# BatchCreator resets its sample counter once the counter exceeds this value.
nr_samples = 16

# Notes

After choosing `patch_size` and `output_shape`, collect the sample locations for background, fissure-complete and fissure-incomplete patches. Each dict maps an image index to the list of (z, y, x) coordinates around which a patch can be extracted. Because these locations are filtered against the image boundaries up front, we do not have to check at batch time whether a patch lies inside the image, and we can draw multiple samples per image.

In [27]:
# Background sample locations per image index, computed over the full file list.
# This reads every image and fissure mask from disk.
samplesBackgroundDict = getBackgroundSamples(patch_size,output_shape,data)

In [28]:
# Number of background sample locations found for each image.
# A dedicated name is used here: rebinding `nr_samples` to this list would overwrite
# the integer setting that is later passed to BatchCreator when the notebook is run top to bottom.
background_sample_counts = [len(samples) for samples in samplesBackgroundDict.values()]

print(np.unique(background_sample_counts))
# np.min gives the smallest per-image count, i.e. the number of samples every image can supply.
print("The maximum amount of samples of background per image is: %s"%(np.min(background_sample_counts)))

[  48   59   75   96  103  104  111  148  168  197  209  214  216  226  228
  251  256  271  278  285  286  291  305  306  312  313  324  334  345  354
  357  359  363  375  388  390  402  412  414  416  420  427  438  449  470
  487  488  499  501  502  506  509  527  528  529  530  538  539  541  571
  572  573  574  579  585  586  605  628  635  638  639  641  661  670  678
  698  714  717  718  750  757  758  808  820  837  852  899  906  915  945
 1048 1129 1435 1448]
The maximum amount of samples of background per image is: 48


In [24]:
# In-boundary fissure-complete sample locations, keyed by image index.
samplesFissureCompleteDict = getFissureSamples(patch_size,fc_indices,data)

In [32]:
# In-boundary fissure-incomplete sample locations, keyed by image index.
# NOTE: the name is spelled "sampels"; the cell that builds sampleLocations uses the same spelling.
sampelsFissureIncompleteDict = getFissureSamples(patch_size,fi_indices,data)

In [34]:
# Order matters: BatchCreator unpacks this as (background, fissure complete, fissure incomplete).
sampleLocations = (samplesBackgroundDict,samplesFissureCompleteDict,sampelsFissureIncompleteDict)

# Updated Batch_Creator

In [66]:
# Batch creator over the training split. The sample-location dicts were built from the
# full file list and are looked up by image index, so they also cover the rows of train_set.
# nr_samples must be the integer setting here, since BatchCreator compares its counter against it.
batch_creator = BatchCreator(patch_extractor, train_set, sampleLocations, batch_division, nr_samples)