# Import Libraries

In [52]:
import os, random
import ntpath
import SimpleITK
from matplotlib import pyplot as plt
import matplotlib
import pandas as pd
import numpy as np
import pickle
import math

# Load Dataset

In [2]:
# Loading data from pickle:
# The loaded object is later passed to BatchCreator as `dataset`, which reads
# its 'image', 'fissuremask', 'lungmask' and 'label' columns and its index.
# NOTE(review): unpickling can execute arbitrary code; only load trusted files.
data = pd.read_pickle("pickle files/train-data-filelist.pkl")

# Load Label Indices

In [3]:
# Load the precomputed patch label indices. The file is opened in a `with`
# block so the handle is closed deterministically instead of being leaked.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open("pickle files/patch_indices.p", "rb") as patch_indices_file:
    patch_indices = pickle.load(patch_indices_file)

In [4]:
# fc_indices are the fissure complete label indices
# fi_indices are the fissure incomplete label indices
# patch_indices must hold exactly two entries, in that order (the unpacking
# below raises otherwise). BatchCreator indexes each entry by image index.
fc_indices, fi_indices = patch_indices

# Help Functions

In [5]:
def readImg(img_path):
    """Read the image file at img_path with SimpleITK and return it as a numpy array."""
    return SimpleITK.GetArrayFromImage(SimpleITK.ReadImage(img_path))

# Patch Generator

In [6]:
'''
PatchExtractor: class used to extract and possibly augment patches from images.
'''

class PatchExtractor:
    """Extracts fixed-size 3D patches from an image volume.

    patch_size: tuple with the patch extent along the z, y and x axes.
    """

    def __init__(self, patch_size):
        self.patch_size = patch_size

    def get_patch(self, image, location):
        '''
        image: a numpy array representing the input image (indexed as z, y, x)
        location: a tuple with a z, y, and x coordinate

        return a 3D patch from the image at 'location', representing the center of the patch

        raises ValueError when the patch would extend beyond the image
        '''

        z, y, x = location
        c, h, w = self.patch_size

        bounds = []
        for center, size, limit in zip((z, y, x), (c, h, w), image.shape):
            half = size / 2
            start = int(center - half)
            stop = int(center + half)
            # numpy slicing never raises for out-of-range bounds: a negative
            # start wraps around and an oversized stop is clipped, which
            # silently yields an empty or undersized patch. Check explicitly.
            if center - half < 0 or stop > limit:
                raise ValueError(
                    "Patch out of boundary, please make sure that the patch "
                    "location is not out of boundary."
                )
            bounds.append(slice(start, stop))

        return image[tuple(bounds)]

# Batch Creator

In [175]:
class BatchCreator:
    """Builds training batches of 3D patches drawn from one image per batch.

    Each batch mixes three kinds of patches, labelled one-hot in the last axis
    of the label array:
      index 2 -> patch centred on a "fissure complete" location,
      index 1 -> patch centred on a "fissure incomplete" location,
      index 0 -> background patch.

    patch_extractor: object with a `patch_size` (z, y, x) attribute and a
        `get_patch(image, location)` method.
    dataset: table with 'image', 'fissuremask', 'lungmask' and 'label' columns.
    patch_indices: pair (fc_indices, fi_indices); each is indexed by image
        index and yields a dict mapping a slice z to a list of (y, x) tuples.
    batch_division: pair (fc_nr, fi_nr) with the number of fissure complete
        and fissure incomplete patches per batch; the rest is background.
    """

    def __init__(self,patch_extractor,dataset,patch_indices,batch_division):
        self.patch_extractor = patch_extractor
        self.patch_size = self.patch_extractor.patch_size

        self.img_list = dataset['image'].values
        self.lbl_list = dataset['fissuremask'].values
        self.msk_list = dataset['lungmask'].values

        self.a_indices = dataset.index[dataset['label'] == "a"].tolist()
        self.b_indices = dataset.index[dataset['label'] == "b"].tolist()
        self.c_indices = dataset.index[dataset['label'] == "c"].tolist()

        # images that have not been used yet in the current pass
        self.img_indices = self.a_indices + self.b_indices + self.c_indices

        self.fc_indices = patch_indices[0]
        self.fi_indices = patch_indices[1]

        self.batch_division = batch_division

        # images already used in the current pass
        self.examined_images = []

    def create_batch(self, batch_size):
        """Create one batch of patches from a randomly picked, not yet used image.

        batch_size: total number of patches in the batch.

        return (x_data, y_data): the image patches and their one-hot labels
        with shape (batch_size, 1, 1, 1, 3).
        """
        total_images = len(self.a_indices) + len(self.b_indices) + len(self.c_indices)

        # Start a new pass once every image has been used.
        if len(self.img_indices) == 0 or len(self.examined_images) >= total_images:
            self.examined_images = []
            self.img_indices = self.a_indices + self.b_indices + self.c_indices

        # pickImage records the image in examined_images; it must not be
        # recorded a second time here or the pass bookkeeping gets out of sync.
        img_index = self.pickImage()

        x_data, y_data, fissure_data = self.initializeOutputArrays(batch_size)

        fc_slices_dict = self.fc_indices[img_index]
        fi_slices_dict = self.fi_indices[img_index]

        img_array, lbl_array, msk_array = self.img2array(img_index)

        (fc_nr, fi_nr) = self.batch_division
        b_nr = batch_size - (fc_nr + fi_nr)

        # No incomplete-fissure locations in this image: give their share of
        # the batch to the complete-fissure patches.
        if len(fi_slices_dict) == 0:
            fc_nr = fc_nr + fi_nr
            fi_nr = 0

        fc_grid, fc_grid_size = self.fissureGrid(fc_slices_dict)
        fi_grid, fi_grid_size = self.fissureGrid(fi_slices_dict)
        # Four background patches (one per quadrant) are taken per z position.
        b_grid, b_grid_dict, b_grid_size = self.backgroundGrid(img_array.shape, int(b_nr / 4))

        if fc_nr > 0 and fc_grid_size == 0:
            raise ValueError("Error: no fissure complete locations available for the selected image")

        background_counter = 0
        background_index = 0

        for i in range(batch_size):
            if i < fc_nr:
                z = fc_grid[i % fc_grid_size]
                (z, y, x) = self.getCoordinates(fc_slices_dict, z, img_array)
                label_index = 2
            elif i < (fc_nr + fi_nr):
                z = fi_grid[i % fi_grid_size]
                (z, y, x) = self.getCoordinates(fi_slices_dict, z, img_array)
                label_index = 1
            else:
                if background_counter == 4:
                    background_index += 1
                    background_counter = 0
                # Wrap around so a background count that is not a multiple of
                # four (or fewer z positions than requested) cannot index past
                # the end of the grid.
                z = b_grid[background_index % b_grid_size]
                grid = b_grid_dict[z][background_counter]
                (z, y, x) = self.getBackground(grid, msk_array, z)
                label_index = 0
                background_counter += 1

            x_data[i] = self.patch_extractor.get_patch(img_array, (z, y, x))
            y_data[i, 0, 0, 0, label_index] = 1
            # fissure mask patch for the same location (currently not returned)
            fissure_data[i] = self.patch_extractor.get_patch(lbl_array, (z, y, x))

        return x_data, y_data

    def pickImage(self):
        """Pick a random unused image index, mark it as examined and return it.

        raises ValueError when there is no image left to pick.
        """
        remaining = len(self.img_indices)
        if remaining == 0:
            raise ValueError("Error: no images available to pick from")
        # randint's upper bound is exclusive, so this covers every position
        # and also works when a single image is left.
        index = np.random.randint(remaining)
        img_index = self.img_indices[index]
        self.examined_images.append(img_index)
        self.img_indices = np.delete(self.img_indices, index)
        return img_index

    def initializeOutputArrays(self, batch_size):
        """Return zero-filled (x_data, y_data, fissure_data) arrays for one batch."""
        # patch array
        x_data = np.zeros((batch_size, *self.patch_extractor.patch_size))
        # label array (one-hot structure)
        y_data = np.zeros((batch_size, 1, 1, 1, 3))
        # fissure mask patch array
        fissure_data = np.zeros((batch_size, *self.patch_extractor.patch_size))

        return x_data, y_data, fissure_data

    def img2array(self, img_index):
        """Read the image, fissure mask and lung mask of img_index as numpy arrays."""
        # compute numpy array from image
        img_array = readImg(self.img_list[img_index])

        # compute numpy array from fissure mask
        lbl_array = readImg(self.lbl_list[img_index])

        # compute numpy array from lung mask
        msk_array = readImg(self.msk_list[img_index])
        return img_array, lbl_array, msk_array

    def fissureGrid(self,slicesDict):
        """Spread patch centres over the slices that contain fissure locations.

        The sorted slice keys are split into chunks of 1.5 times the patch
        depth and one representative slice is taken per chunk.

        return (z_values, grid_size)
        """
        z_size, _, _ = self.patch_extractor.patch_size
        slices = sorted(slicesDict.keys())
        z_grid = self.chunks(slices, max(1, int(z_size * 1.5)))
        # Take the (lower) middle element of each chunk rather than the
        # numeric median: for consecutive slices both are equal, but the
        # median of non-consecutive slices may not be a key of slicesDict.
        z_medians = [chunk[(len(chunk) - 1) // 2] for chunk in z_grid]
        grid_size = len(z_medians)
        return z_medians, grid_size

    def backgroundGrid(self,img_shape,b_nr):
        """Build the sampling grid for background patches.

        img_shape: (z, y, x) shape of the image.
        b_nr: requested number of z positions.

        return (z_medians, z_grid_dict, grid_size) where z_grid_dict maps each
        z position to four (y_min, y_max, x_min, x_max) quadrants.
        """
        z_max, y_max, x_max = img_shape
        _, y_size, x_size = self.patch_extractor.patch_size
        slices = list(range(z_max))
        # Guard against a zero chunk length (b_nr == 0 or b_nr > z_max).
        chunk_len = max(1, len(slices) // max(1, b_nr))
        z_grid = self.chunks(slices, chunk_len)
        z_medians = [int(np.median(chunk)) for chunk in z_grid]
        grid_size = len(z_medians)

        y_lo = math.ceil(y_size / 2)
        x_lo = math.ceil(x_size / 2)
        y_mid = int(y_max / 2)
        x_mid = int(x_max / 2)
        y_hi = (y_max - 1) - math.floor(y_size / 2)
        # the x limit depends on the patch width, not on its height
        x_hi = (x_max - 1) - math.floor(x_size / 2)

        quadrants = (
            (y_lo, y_mid - 1, x_lo, x_mid - 1),
            (y_lo, y_mid - 1, x_mid, x_hi),
            (y_mid, y_hi, x_lo, x_mid - 1),
            (y_mid, y_hi, x_mid, x_hi),
        )
        z_grid_dict = {z_median: quadrants for z_median in z_medians}
        return z_medians, z_grid_dict, grid_size

    def chunks(self,l,n):
        """Split list l into consecutive chunks of at most n elements."""
        return [l[i:i + n] for i in range(0, len(l), n)]

    def getCoordinates(self,slices_dict,z,img_array):
        """Pick a random (y, x) of slice z whose patch fits inside the image.

        return (z, y, x)
        raises ValueError when no listed coordinate of the slice fits.
        """
        coordinates = slices_dict[z]
        filtered_coordinates = [(y, x) for (y, x) in coordinates if self.inBoundary((z, y, x), img_array.shape)]
        if len(filtered_coordinates) == 0:
            raise ValueError("Error: no x,y coordinates found that are in boundary of the image taking the patch size in mind")
        random_coords_index = np.random.choice(len(filtered_coordinates))
        y, x = filtered_coordinates[random_coords_index]
        return (z, y, x)

    def getBackground(self,grid,msk_array,z):
        """Pick a random background location on slice z.

        Candidates are the positions where the lung mask equals 3
        (NOTE(review): presumably the lung label of interest; confirm) and
        where a patch fits inside the image. Candidates inside the quadrant
        `grid` are preferred; if the quadrant holds none, the whole slice is
        used instead.

        return (z, y, x)
        raises ValueError when the slice has no candidate at all.
        """
        (y_min, y_max, x_min, x_max) = grid
        y_indices, x_indices = np.where(msk_array[z, :, :] == 3)
        coords = self.getCoords(z, y_indices, x_indices, msk_array)
        if len(coords) == 0:
            raise ValueError("Error: no background coordinates found in slice %d" % z)
        in_grid = [(y, x) for (y, x) in coords
                   if y_min <= y <= y_max and x_min <= x <= x_max]
        candidates = in_grid if in_grid else coords
        # upper bound is exclusive: covers all candidates, works for a single one
        i = np.random.randint(len(candidates))
        (y, x) = candidates[i]
        return (z, y, x)

    def getCoords(self,z,y_indices,x_indices,msk_array):
        """Return the (y, x) pairs for which a patch fits inside msk_array."""
        return [(y, x) for y, x in zip(y_indices, x_indices)
                if self.inBoundary((z, y, x), msk_array.shape)]

    def inBoundary(self,location, img_shape):
        """Tell whether a patch centred at location fits in y and x.

        Only the y and x coordinates are checked; z is ignored.
        """
        _, y_size, x_size = img_shape
        _, y_patch, x_patch = self.patch_extractor.patch_size

        y_min = math.ceil(y_patch / 2)
        y_max = math.floor(y_size - (y_patch / 2))

        x_min = math.ceil(x_patch / 2)
        # floor, like y_max: rounding up would accept a centre whose patch
        # sticks out of the image for odd patch sizes
        x_max = math.floor(x_size - (x_patch / 2))

        _, y, x = location

        return bool((y_min <= y <= y_max) and (x_min <= x <= x_max))

    def checkBackground(self,location,lbl_array,msk_array):
        """Tell whether location has fissure label 0 and lung mask value 3."""
        z, y, x = location
        return bool((lbl_array[z, y, x] == 0) and (msk_array[z, y, x] == 3))

    def get_generator(self, batch_size):
        '''returns a generator that will yield batches infinitely'''
        while True:
            yield self.create_batch(batch_size)

# Example Setup

In [176]:
# Patch extent along the z, y and x axes (the order in which
# PatchExtractor.get_patch unpacks it).
patch_size = (10,32,32)
patch_extractor = PatchExtractor(patch_size)
# Total number of patches per batch.
batch_size = 32
# (fissure complete, fissure incomplete) patches per batch; the remaining
# batch_size - 16 = 16 patches are background.
batch_division = (8,8)

In [177]:
# Arguments: patch extractor, dataset table, (fc, fi) patch indices and the
# (fc_nr, fi_nr) batch division.
batch_creator = BatchCreator(patch_extractor, data, patch_indices, batch_division)

In [178]:
# x holds the image patches, y the one-hot labels of shape (batch_size, 1, 1, 1, 3).
x, y = batch_creator.create_batch(batch_size)