# Import Libraries

In [1]:
import os, random
import ntpath
import SimpleITK
from matplotlib import pyplot as plt
import matplotlib
import pandas as pd
import numpy as np
import pickle

# Load Dataset

In [2]:
# Training file list, loaded as a pandas object. BatchCreator below reads its
# 'image', 'fissuremask', 'lungmask' and 'label' columns.
train_filelist_path = "pickle files/train-data-filelist.pkl"
data = pd.read_pickle(train_filelist_path)

# Load Label Indices

In [3]:
# Precomputed label locations: a (fc_indices, fi_indices) pair (see the next cell).
# NOTE: pickle can execute arbitrary code while loading; only open trusted files.
# The context manager closes the file handle, which open() alone never did.
with open("pickle files/patch_indices.p", "rb") as indices_file:
    patch_indices = pickle.load(indices_file)

In [4]:
# Split the loaded pair into its two halves:
#   fc_indices -> fissure-complete label indices
#   fi_indices -> fissure-incomplete label indices
[fc_indices, fi_indices] = patch_indices

# Helper Functions

In [5]:
def readImg(img_path):
    """Read the image file at ``img_path`` and return its voxels as a numpy array.

    SimpleITK's GetArrayFromImage orders the axes as (z, y, x). That is the
    order the patch code in this notebook indexes with.
    """
    return SimpleITK.GetArrayFromImage(SimpleITK.ReadImage(img_path))

# Patch Generator

In [6]:
class PatchExtractor:
    """Cut fixed-size 3D patches out of image volumes.

    patch_size: (depth, height, width) of every patch, in the same (z, y, x)
    axis order used to index the images.
    """

    def __init__(self, patch_size):
        self.patch_size = patch_size

    def get_patch(self, image, location):
        """Return the patch of size ``self.patch_size`` centred on ``location``.

        image: numpy array with at least 3 dimensions, indexed as (z, y, x)
        location: (z, y, x) centre of the patch

        Along each axis the patch spans [floor(c - p/2), floor(c + p/2)).
        For non-negative starts this is exactly the int() truncation used
        before.

        Raises ValueError if any part of the patch would fall outside the
        image. numpy slicing itself never raises here; it silently clips the
        slice or wraps a negative start, producing a wrongly sized patch.
        """
        z, y, x = location
        c, h, w = self.patch_size
        if image.ndim < 3:
            raise ValueError("image must have at least 3 dimensions (z, y, x)")

        slices = []
        for centre, size, extent in zip((z, y, x), (c, h, w), image.shape[:3]):
            # floor (not int) so a start of e.g. -0.5 is detected as negative
            start = int(np.floor(centre - size / 2))
            stop = int(np.floor(centre + size / 2))
            if start < 0 or stop > extent:
                raise ValueError(
                    f"Patch of size {tuple(self.patch_size)} at {tuple(location)} "
                    f"is out of boundary for image of shape {image.shape}."
                )
            slices.append(slice(start, stop))
        return image[tuple(slices)]
        
        
        

# Batch Creator

In [11]:
class BatchCreator:
    """Create batches of 3D patches with one-hot class labels.

    Each batch is drawn from a single image:
    1) pick an image at random among those not yet used in the current pass;
       once every image has been used, a new pass starts;
    2) for every fissure-complete / fissure-incomplete patch, pick a random
       slice (z) among the slices that contain a labelled (y, x) coordinate
       whose patch fits inside the image, then a random such coordinate;
    3) for every background patch, pick a random voxel whose whole patch fits
       inside the image, whose fissure label is 0 and whose lung-mask value
       is 3 (the "inside the lung mask" test used by checkBackground).

    The (y, x) coordinates were computed per slice. Slices without any
    labelled coordinate were skipped, so they never appear as keys.

    Still to be implemented:
    - generate patches from multiple images
    - only pick background at random
    - make sure to generate patches from large part of fissure
    """

    def __init__(self, patch_extractor, dataset, patch_indices, batch_division):
        """
        patch_extractor: PatchExtractor used to cut the patches
        dataset: DataFrame whose 'image', 'fissuremask' and 'lungmask' columns
            hold file paths (passed to readImg) and whose 'label' column
            selects the usable images ("a", "b" or "c")
        patch_indices: (fc_indices, fi_indices). Each is a list with one dict
            per image; a dict maps a slice index z to a list of (y, x)
            coordinates of fissure-complete (fc) or fissure-incomplete (fi)
            voxels in that slice
        batch_division: (fc_nr, fi_nr), the number of fissure-complete and
            fissure-incomplete patches per batch; the rest are background
        """
        self.patch_extractor = patch_extractor
        self.patch_size = self.patch_extractor.patch_size

        self.img_list = dataset['image'].values
        self.lbl_list = dataset['fissuremask'].values
        self.msk_list = dataset['lungmask'].values

        # NOTE(review): these index *labels* are later used as positions into
        # the arrays above and into fc/fi_indices. That assumes the DataFrame
        # has the default 0..n-1 index; confirm.
        self.a_indices = dataset.index[dataset['label'] == "a"].tolist()
        self.b_indices = dataset.index[dataset['label'] == "b"].tolist()
        self.c_indices = dataset.index[dataset['label'] == "c"].tolist()

        self.fc_indices = patch_indices[0]
        self.fi_indices = patch_indices[1]

        self.batch_division = batch_division

        # image indices already used in the current pass
        self.examined_images = []

    def create_batch(self, batch_size):
        """Return (x_data, y_data) for one batch drawn from a single image.

        x_data: array of shape (batch_size, *patch_size) with the image patches
        y_data: array of shape (batch_size, 1, 1, 1, 3) with one-hot labels,
            index 0 = background, 1 = fissure incomplete, 2 = fissure complete

        The first fc_nr patches are fissure complete, the next fi_nr are
        fissure incomplete and the remainder are background. If the image has
        no usable location for one fissure class, that class's quota moves to
        the other fissure class. If neither class is usable, the quota moves
        to background.
        """
        img_index = self._pick_image()

        img_array = readImg(self.img_list[img_index])
        lbl_array = readImg(self.lbl_list[img_index])
        msk_array = readImg(self.msk_list[img_index])

        # Keep only labelled locations whose patch fits inside the image, so
        # sampling can never spin forever on an unusable slice.
        fc_locations = self._valid_locations(self.fc_indices[img_index], img_array.shape)
        fi_locations = self._valid_locations(self.fi_indices[img_index], img_array.shape)

        fc_nr, fi_nr = self.batch_division
        if not fi_locations:
            fc_nr, fi_nr = fc_nr + fi_nr, 0
        if not fc_locations:
            fi_nr = fi_nr + fc_nr if fi_locations else 0
            fc_nr = 0

        x_data = np.zeros((batch_size, *self.patch_size))
        y_data = np.zeros((batch_size, 1, 1, 1, 3))

        for i in range(batch_size):
            if i < fc_nr:
                location, class_idx = self._sample_location(fc_locations), 2
            elif i < fc_nr + fi_nr:
                location, class_idx = self._sample_location(fi_locations), 1
            else:
                location, class_idx = self._sample_background(lbl_array, msk_array), 0
            x_data[i] = self.patch_extractor.get_patch(img_array, location)
            y_data[i, 0, 0, 0, class_idx] = 1

        self.examined_images.append(img_index)
        return x_data, y_data

    def _pick_image(self):
        """Return a random image index not yet used in the current pass.

        Starts a new pass (clears examined_images) once every image has been used.
        """
        all_images = self.a_indices + self.b_indices + self.c_indices
        if not all_images:
            raise ValueError("dataset contains no image labelled 'a', 'b' or 'c'")
        used = set(self.examined_images)
        remaining = [idx for idx in all_images if idx not in used]
        if not remaining:
            # Every image has been processed: start over.
            self.examined_images = []
            remaining = all_images
        return remaining[np.random.randint(len(remaining))]

    def _valid_locations(self, slices_dict, img_shape):
        """Return {z: [(y, x), ...]} restricted to coordinates whose patch fits.

        Slices left without any coordinate are dropped.
        """
        valid = {}
        for z, coordinates in slices_dict.items():
            kept = [(y, x) for (y, x) in coordinates
                    if self.inBoundary((z, y, x), self.patch_size, img_shape)]
            if kept:
                valid[z] = kept
        return valid

    def _sample_location(self, locations):
        """Pick a random slice, then a random (y, x) in it; return (z, y, x)."""
        slice_nrs = list(locations)
        z = slice_nrs[np.random.choice(len(slice_nrs))]
        return self.getCoordinates(locations, z)

    def _center_ranges(self, img_shape):
        """Return per-axis (low, high) bounds, high exclusive, for patch centres.

        Any centre drawn from these ranges keeps the whole patch inside the image.
        """
        return [(int(p / 2), int(s - p / 2)) for s, p in zip(img_shape, self.patch_size)]

    def _sample_background(self, lbl_array, msk_array, max_attempts=10000):
        """Return a random (z, y, x) that passes checkBackground and keeps the patch inside.

        Uses rejection sampling first. If that keeps missing, falls back to an
        exhaustive search of the admissible region.

        Raises ValueError if no such voxel exists.
        """
        (z_lo, z_hi), (y_lo, y_hi), (x_lo, x_hi) = self._center_ranges(lbl_array.shape)
        for _ in range(max_attempts):
            location = (np.random.randint(z_lo, z_hi),
                        np.random.randint(y_lo, y_hi),
                        np.random.randint(x_lo, x_hi))
            # Sampled centres are always in bounds, so only the background test
            # can reject. The old `and` condition never enforced it.
            if self.checkBackground(location, lbl_array, msk_array):
                return location

        region = ((lbl_array[z_lo:z_hi, y_lo:y_hi, x_lo:x_hi] == 0)
                  & (msk_array[z_lo:z_hi, y_lo:y_hi, x_lo:x_hi] == 3))
        hits = np.argwhere(region)
        if len(hits) == 0:
            raise ValueError(
                "no background voxel (fissure label 0, lung mask 3) leaves room for a full patch"
            )
        dz, dy, dx = hits[np.random.randint(len(hits))]
        return (z_lo + int(dz), y_lo + int(dy), x_lo + int(dx))

    def imageProcessed(self, img_index):
        """Return True if img_index was already used in the current pass."""
        return img_index in self.examined_images

    def checkBackground(self, location, lbl_array, msk_array):
        """Return True if location has fissure label 0 and lung-mask value 3."""
        z, y, x = location
        return bool(lbl_array[z, y, x] == 0 and msk_array[z, y, x] == 3)

    def getCoordinates(self, slices_dict, z):
        """Return (z, y, x) with (y, x) drawn uniformly from slices_dict[z]."""
        coordinates = slices_dict[z]
        y, x = coordinates[np.random.choice(len(coordinates))]
        return (z, y, x)

    def inBoundary(self, location, patch_size, img_shape):
        """Return True if a patch of patch_size centred at location fits in img_shape.

        Each coordinate must lie in [int(p / 2), int(s - p / 2)], inclusive.
        """
        return all(int(p / 2) <= coord <= int(s - p / 2)
                   for coord, p, s in zip(location, patch_size, img_shape))

    def get_generator(self, batch_size):
        '''returns a generator that will yield batches infinitely'''
        while True:
            yield self.create_batch(batch_size)

# Example Setup

In [8]:
# Patch geometry in (z, y, x) voxels, and the extractor that uses it.
patch_size = (10, 32, 32)
patch_extractor = PatchExtractor(patch_size=patch_size)

# Per batch: 8 fissure-complete, 8 fissure-incomplete and the remaining
# batch_size - 16 = 12 background patches.
batch_size, batch_division = 28, (8, 8)

In [12]:
batch_creator = BatchCreator(
    patch_extractor=patch_extractor,
    dataset=data,
    patch_indices=patch_indices,
    batch_division=batch_division,
)

In [13]:
# x holds the patches and y the one-hot encoding of each patch's class.
# Use the configured batch_size instead of repeating the literal 28.
x, y = batch_creator.create_batch(batch_size)

In [14]:
# Show which image indices have been used so far in this pass.
examined = batch_creator.examined_images
print(examined)

[6]
