## TODO: For now we only segment the liver
- Added shuffle for data, there might have been some bias in the order.
- Should we use all the training data? (doesn't fit in memory at the moment)
- Should we validate with a "real" Dice score (as is done in the competition evaluation)? With patches or full images?
- Experiment with bigger patches or more patches in batch. 

In [1]:
# imports
import matplotlib
%matplotlib notebook
import numpy as np
import matplotlib.pyplot as plt
import os
from scipy.ndimage import zoom
import json
import warnings
from random import randint
import random
import SimpleITK as sitk
from multi_slice_viewer import multi_slice_viewer
import tensorflow as tf
from keras.models import Model, load_model
from keras.layers import Input, Conv3D, MaxPooling3D, Dropout, Conv3DTranspose, UpSampling3D, concatenate, Cropping3D, Reshape, BatchNormalization
import keras.callbacks
from keras import backend as K
from keras import optimizers
from keras import regularizers
from keras.optimizers import SGD, Adam
from keras.utils.np_utils import to_categorical
from IPython.display import clear_output
import pickle 
from tensorflow.python.client import device_lib

# this part is needed if you run the notebook on Cartesius with multiple cores
# the OpenMP/KMP variables are set BEFORE the session is created: the threading runtime
# reads them when it starts up, so exporting them after the session exists is too late
# NOTE(review): TensorFlow is already imported at this point; if the settings still
# seem to be ignored, move these four lines above the tensorflow import
n_cores = 16
os.environ["OMP_NUM_THREADS"] = str(n_cores-1)
os.environ["KMP_BLOCKTIME"] = "1"
os.environ["KMP_SETTINGS"] = "1"
os.environ["KMP_AFFINITY"]= "granularity=fine,verbose,compact,1,0"

# allocate GPU memory on demand instead of claiming all of it up front
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
session = tf.Session(config=config)
K.set_session(session)

Using TensorFlow backend.


In [2]:
# check if we use gpu or cpu
print(device_lib.list_local_devices())
print(tf.test.is_gpu_available())

[name: "/device:CPU:0"
device_type: "CPU"
memory_limit: 268435456
locality {
}
incarnation: 8270264475881289365
, name: "/device:GPU:0"
device_type: "GPU"
memory_limit: 11321648743
locality {
  bus_id: 1
}
incarnation: 13206431668856760341
physical_device_desc: "device: 0, name: Tesla K40m, pci bus id: 0000:02:00.0, compute capability: 3.5"
, name: "/device:GPU:1"
device_type: "GPU"
memory_limit: 11321648743
locality {
  bus_id: 2
}
incarnation: 10523617580257072059
physical_device_desc: "device: 1, name: Tesla K40m, pci bus id: 0000:82:00.0, compute capability: 3.5"
]
True


In [3]:
# confirm TensorFlow sees the GPU
from tensorflow.python.client import device_lib
assert 'GPU' in str(device_lib.list_local_devices())

# confirm Keras sees the GPU
assert len(K.tensorflow_backend._get_available_gpus()) > 0

In [None]:
# Task03_liver dir in same directory as notebook
data_path = './Task03_Liver/'

In [None]:
# info about dataset in json file (dataset.json lives in the data dir)
with open(data_path + 'dataset.json') as f:
    d = json.load(f)   
    
    # entries of the training set, each entry holds the path of an 'image' and of its 'label'
    train_paths = d['training']
    
    # entries of the test set
    # NOTE(review): presumably these come without labels, confirm against dataset.json
    test_paths = d['test'] 

In [None]:
# change to data dir 
os.chdir(data_path)
print(os.getcwd())

# Load the train set as SITK images

In [None]:
# load images and labels, loading all takes some time, take 50 for now
train_imgs = [sitk.ReadImage(train_instance['image']) for train_instance in train_paths[50:100]]
train_lbls = [sitk.ReadImage(train_instance['label']) for train_instance in train_paths[50:100]]

In [None]:
# train images as numpy
np_train_imgs = [sitk.GetArrayFromImage(i) for i in train_imgs]
np_train_lbls = [sitk.GetArrayFromImage(i) for i in train_lbls]

# Spacing
Images do not have the same spacings, so we will first resample. For this we need the spacings stored in the SITK images. Note that converting SITK to numpy reverses the axis order: spacings are given as (x, y, z), while the numpy image is indexed as (z, y, x).

In [None]:
for image in train_imgs:
    print(image.GetSpacing())

## Resampling
to 1mm x 1mm x 1mm resolution => images should have different sizes (not all 512 x 512 x N anymore). 
For example, when the image has a shape of (512, 512, 74) and a spacing of (0.75, 0.75, 2),
you can calculate how wide the image is along the x-axis: 512 * 0.75 mm = 384 mm. As a tip, look for “scipy zoom”.

In [None]:
def resample(np_imgs, spacings, order):
    """
    Rescale every volume with its own per-axis zoom factors.
    Using the voxel spacing (in mm) as zoom factor gives a 1mm x 1mm x 1mm grid.
    np_imgs : list of images or labels to be resampled, as numpy
    spacings: one tuple of zoom factors per volume, in the axis order of the numpy arrays
    order   : spline interpolation order that is passed on to scipy's zoom
    returns : list with the resampled volumes, in the same order as np_imgs
    """
    # spacings is indexed by position, so volume i is always zoomed with spacings[i]
    return [zoom(volume, spacings[idx], order=order) for idx, volume in enumerate(np_imgs)]

In [None]:
# store shapes before to check
shapes_before = [img.shape for img in np_train_imgs]

# spacings from sitk images, SimpleITK reports them in (x, y, z) order
spacings = [img.GetSpacing() for img in train_imgs]

# sitk.GetArrayFromImage gives arrays indexed as (z, y, x), so the spacing has to be
# reversed to line up with the numpy axes (a (z, x, y) order swaps the in-plane factors)
spacings = [tuple(spacing[::-1]) for spacing in spacings]          # order: (z, y, x)

In [None]:
# resample train images and labels: order 3 (cubic spline) for imgs,
# order 0 (nearest neighbour) for labels so that no in-between label values are invented
np_train_imgs = resample(np_train_imgs, spacings, order=3)
np_train_lbls = resample(np_train_lbls, spacings, order=0)

In [None]:
# lets print what happened.
print("Before\t\t\tSpacings\t\tAfter\t\t\tLabels")
for i in range(len(np_train_imgs)):    
    # round for printing
    spacing_round = [(round(a, 1), round(b, 1), round(c, 1)) for a, b, c in spacings]    
    print("{}\t\t{}\t\t{}\t\t{}".format(shapes_before[i] , spacing_round[i], 
                                        np_train_imgs[i].shape, np_train_lbls[i].shape))

## Save resampled data as pickle and load
I put this in the Task03_liver folder.

In [None]:
with open('./data_liver.pickle', 'wb') as handle:
    pickle.dump({'images': np_train_imgs, 'labels': np_train_lbls}, handle, protocol=pickle.HIGHEST_PROTOCOL)

## Loading from pickle, start here if you saved the pickle

In [4]:
data_path = './Task03_Liver/'

# only change dir when we are not in the data dir yet, so this cell can be re-run
# (or run after the loading cells above) without failing on the relative path
if os.path.basename(os.getcwd()) != os.path.basename(os.path.normpath(data_path)):
    os.chdir(data_path)
print(os.getcwd())

# NOTE: unpickling can execute arbitrary code, only load pickles you created yourself
with open('./data_liver.pickle', 'rb') as handle:
    data = pickle.load(handle)
np_train_imgs = data['images']
np_train_lbls = data['labels']

/nfs/home4/mbotros/ISMI_project/Task03_Liver


## Make the labels binary
For now we will focus only on segmentation of the liver, so the cancer labels are set to the liver label. Remove the line below if you want to segment cancer as well.

In [5]:
np_train_lbls = [np.where(lbl != 2, lbl, 1) for lbl in np_train_lbls]

## Do we have imbalances in our data? 

In [6]:
# count the labels of train images, sums[k] is the number of voxels with label k
# (0: background, 1: liver, 2: cancer)
sums = np.zeros(3)
for lbs in np_train_lbls:
    labels, counts = np.unique(lbs, return_counts=True)
    
    # add every count to the slot of its own label value, this also works when a volume
    # misses a label (e.g. only background, or only labels 0 and 2)
    sums[labels.astype(int)] += counts

In [7]:
# print percentages of voxels.
total = sum(sums)
print("{:.2f}% background, {:.2f}% liver, {:.2f}% cancer.".format(sums[0]/total*100, sums[1]/total*100, sums[2]/total*100))

98.02% background, 1.98% liver, 0.00% cancer.


In [8]:
class DataSet:
    """
    Container that keeps a list of images together with the (optional) list of labels.
    """
    
    def __init__(self, imgs, lbls=None):
        """
        imgs: list of images (as numpy)
        lbls: list of label maps belonging to imgs, or None if there are no labels
        """
        self.imgs = imgs
        self.lbls = lbls
    
    def get_lenght(self):
        """Return the number of images (the misspelled name is kept for existing callers)."""
        return len(self.imgs)
    
    def show_image(self, i):
        """Show image i in the slice viewer, with the labels as overlay if we have them."""
        plt.rcParams['figure.figsize'] = [8, 8]
        
        # 'is not None' on purpose: '!= None' is evaluated element-wise when lbls is a
        # numpy array, which makes the if statement raise a ValueError
        if self.lbls is not None: 
            multi_slice_viewer(self.imgs[i], view='axial', overlay_1=self.lbls[i], overlay_1_thres=1, 
                   overlay_2=self.lbls[i], overlay_2_thres=2, overlay_2_cmap='coolwarm', overlay_2_alpha=0.75)
        else:
            multi_slice_viewer(self.imgs[i], view='axial')  

## Shuffle the np images and np labels
In case there might be some bias in the order in which the images are stored. The images seem already shuffled so lets skip this for now.

In [9]:
# indexes = list(range(len(np_train_imgs)))
# random.shuffle(indexes)

# np_train_imgs = list(np.asarray(np_train_imgs)[indexes])
# np_train_lbls = list(np.asarray(np_train_lbls)[indexes])

# Split for training and validation

In [10]:
# split the images: the first part is held out for validation, the rest is used for training
validation_percent = 0.2 # fraction of the images used for validation (value between 0 and 1)
n_validation_imgs = int(validation_percent * len(np_train_imgs))

# the slices used to be the other way around, which trained on 20% and validated on 80%
val_set   = DataSet(np_train_imgs[:n_validation_imgs], np_train_lbls[:n_validation_imgs])
train_set = DataSet(np_train_imgs[n_validation_imgs:], np_train_lbls[n_validation_imgs:])

# Patch extractor
We re-use the patch extractor from assignment 7, but modify it to get 3D patches from a 3D image.
We can add augmentations later in the patch extractor. Note the extra dimension in the shape of patch_out and target_out. This doesn't work if the patch size doesn't fit in the image.

In [11]:
class PatchExtractor:
    """
    Extracts random 3D patches (with the matching label patch) from a 3D image.
    """

    def __init__(self, patch_size, fromLiver):
        """
        patch_size: size of the patch in each dimension, (z, x, y)
        fromLiver : if True the centre of the patch is a random liver voxel (label == 1),
                    otherwise the centre is a random voxel of the image
        """
        self.patch_size = patch_size 
        self.fromLiver = fromLiver
    
    def get_patch(self, image, label):
        ''' 
        Get a 3D patch of patch_size from 3D input image, along with corresponding 3D label map.
        Pick random location of the patch inside the image. The point is at the center of the patch.
        We first pad the image to not go out of bounds when extracting the patch.
        image: a numpy array representing the input image
        label: a numpy array representing the labels corresponding to input image
        returns: (patch, target), both with shape (pz, px, py, 1)
        '''
        
        # size of patch in each dimension
        pz, px, py = self.patch_size

        # pad with half the patch size on each side: the image with its min value,
        # the labels with 0 (background)
        min_val = np.min(image)
        pad_width = ((pz//2, pz//2), (px//2, px//2), (py//2, py//2))
        padded_img = np.pad(image, pad_width, 'constant', constant_values=min_val)
        padded_lbl = np.pad(label, pad_width, 'constant')

        # the liver labeled points, only needed when we sample from the liver
        liver_ind = np.argwhere(label == 1) if self.fromLiver else []
        
        # note: randint(a, b) includes b, so the upper bound has to be the last valid index
        if len(liver_ind) > 0:
            # centre of the patch: a random point from the liver in the non padded image
            z, x, y = liver_ind[randint(0, len(liver_ind) - 1)]
        else:
            # centre of the patch: a random location in the non padded image
            # (also the fallback when the image does not contain any liver voxel)
            z, x, y = (randint(0, dim - 1) for dim in image.shape)
            
        # z, x, y is the left bottom corner of the patch in the padded image (index shift with pad size)     
        # take a patch, with the random point at the center in the padded img
        patch  = padded_img[z:z+pz, x:x+px, y:y+py].reshape(pz, px, py, 1)
        target = padded_lbl[z:z+pz, x:x+px, y:y+py].reshape(pz, px, py, 1)

        return patch, target

In [12]:
# get an image and a label from our train set
image = train_set.imgs[0]
label = train_set.lbls[0]

# test PatchExtractor
patch_size = (188, 188, 188)
patch_extractor = PatchExtractor(patch_size=patch_size, fromLiver=True)

# lets check some patches
patch, target = patch_extractor.get_patch(image, label)

print(patch.shape)
print(target.shape)

# show patch
plt.rcParams['figure.figsize'] = [8, 8]            
multi_slice_viewer(patch.reshape(patch_size), view='axial', overlay_1=target.reshape(patch_size), overlay_1_thres=1, 
                   overlay_2=target.reshape(patch_size), overlay_2_thres=2, overlay_2_cmap='coolwarm', overlay_2_alpha=0.75)

(188, 188, 188, 1)
(188, 188, 188, 1)


<IPython.core.display.Javascript object>

# Batch creator
Let's also reuse the batch creator from assignment 7. We are going to use valid convolutions, which means the output of our network will be smaller than the input. The purpose of this batch creator is to make batches consisting of patches with their corresponding labels (for the network to train on). Since a UNet with valid convolutions has a smaller output than input, we need to crop the label to the target size as well. The labels should also be one-hot encoded.

In [13]:
class BatchCreator:
    """
    Makes mini-batches of normalized image patches with one-hot labels,
    where the labels are cropped to the (smaller) output size of the network.
    """
    
    def __init__(self, patch_extractor, dataset, target_size, n_classes=2):
        """
        patch_extractor: object with a patch_size and a get_patch(image, label) method
        dataset        : object with the lists imgs and lbls
        target_size    : size of the output of the network in each dimension
        n_classes      : number of classes in the one-hot encoding (2: background, liver)
        """
        self.patch_extractor = patch_extractor
        self.target_size = target_size # size of the output, can be useful when valid convolutions are used        
        self.imgs = dataset.imgs
        self.lbls = dataset.lbls                
        self.n = len(self.imgs)
        self.patch_size = self.patch_extractor.patch_size
        self.n_classes = n_classes
    
    def create_image_batch(self, batch_size):
        '''
        returns a batch of patches (x) with corresponding labels (y) in one-hot structure
        x has shape (batch_size, *patch_size, 1), y has shape (batch_size, *target_size, n_classes)
        '''
        x_data = np.zeros((batch_size, *self.patch_size, 1))                  # 1 channel
        y_data = np.zeros((batch_size, *self.target_size, self.n_classes))    # one-hot encoding
        
        # where the centred target crop starts inside the patch, the same for every sample
        oz, ox, oy = [(p - t) // 2 for p, t in zip(self.patch_size, self.target_size)]
        tz, tx, ty = self.target_size
        
        for i in range(0, batch_size):
        
            random_index = np.random.choice(self.n)                           # pick random image
            img, lbl = self.imgs[random_index], self.lbls[random_index]       # get image and segmentation map
            
            # clip values outside [-1000, 3000] and normalize image intensity to range [0., 1.]      
            img = np.clip(img, -1000, 3000)
            intensity_range = np.ptp(img)
            if intensity_range > 0:
                img = (img - np.min(img)) / intensity_range
            else:
                # constant image: dividing by a range of 0 would fill the batch with NaNs
                img = np.zeros_like(img, dtype=np.float64)
            
            # get a patch with corresponding labels from the patch extractor
            patch_img, patch_lbl = self.patch_extractor.get_patch(img, lbl)   
            
            # crop labels based on target_size and drop the channel axis
            cropped_patch = patch_lbl[oz:oz+tz, ox:ox+tx, oy:oy+ty, 0]
            
            # instead of label values we want categorical/onehot, for 2 classes => 0: [1, 0], 1: [0, 1]
            onehot = to_categorical(cropped_patch, num_classes=self.n_classes)
            
            x_data[i, :, :, :, :] = patch_img
            y_data[i, :, :, :, :] = onehot
        
        return (x_data.astype(np.float32), y_data.astype(np.float32))
    
    def get_image_generator(self, batch_size):
        '''returns a generator that will yield image-batches infinitely'''
        while True:
            yield self.create_image_batch(batch_size)

# 3D UNet Model
Start with this model, we can adapt this later if needed. Build like the net from: 
3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation. Ozgun Cicek et al, 2016.

In [14]:
# make block of two convolve3D's
def unet_block(inputs, n_filters, padding, up_conv=False, batchnorm=False):
    """
    Two 3x3x3 convolutions with ReLU after each other, each one optionally followed by batchnorm.
    inputs   : the input tensor of the block
    n_filters: number of filters of the first convolution
    padding  : padding mode of both convolutions
    up_conv  : True for blocks in the expanding path, the second conv then keeps n_filters filters,
               otherwise the second conv has twice the filters
    batchnorm: add BatchNormalization after each convolution
    returns  : the output tensor of the second convolution (or of its batchnorm)
    """
    second_filters = n_filters if up_conv else n_filters*2

    out = inputs
    for filters in (n_filters, second_filters):
        out = Conv3D(filters, (3, 3, 3), activation='relu', padding=padding)(out)
        if batchnorm:
            out = BatchNormalization()(out)

    return out

In [15]:
# 3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation. Ozgun Cicek et al, 2016.
def _center_crop_like(skip, upsampled):
    """Centre-crop the skip tensor so its 3 spatial dims match those of the upsampled tensor."""
    # assumes channels_last tensors with a static shape: (batch, dim_1, dim_2, dim_3, channels)
    cropping = []
    for skip_dim, up_dim in zip(K.int_shape(skip)[1:4], K.int_shape(upsampled)[1:4]):
        surplus = skip_dim - up_dim
        cropping.append((surplus // 2, surplus - surplus // 2))
    return Cropping3D(cropping=tuple(cropping))(skip)


def build_unet_3d(initial_filters, padding, batchnorm=True, input_shape=(188, 188, 188, 1)):
    """
    Build the 3D UNet (3 pooling steps) and print its summary.
    initial_filters: number of filters of the very first convolution
    padding        : padding mode of the convolutions, 'valid' makes the output smaller than the input
    batchnorm      : use batchnorm after every convolution of the blocks
    input_shape    : (spac_dim_1, space_dim_2, space_dim_3, channels), must be fully specified
                     since the crop of the skip connections is derived from the tensor shapes.
                     With the default shape and 'valid' padding the crops are 4, 16 and 40.
    returns        : the (not yet compiled) Keras model
    """
    
    ## CONTRACTING PATH

    inputs = Input(shape=input_shape)

    # First conv pool, with 32 initial filters: 32 filters and 64 filters    
    block_1    = unet_block(inputs, initial_filters, padding=padding, batchnorm=batchnorm) 
    max_pool_1 = MaxPooling3D(pool_size=(2, 2, 2), strides=2)(block_1)  # 2×2×2 max pooling with strides two
                                                                        # needs even spacial_dimensions as input
    # second conv pool, 64 filters, 128 filters    
    block_2    = unet_block(max_pool_1, initial_filters*2, padding=padding, batchnorm=batchnorm)
    max_pool_2 = MaxPooling3D(pool_size=(2, 2, 2), strides=2)(block_2)
    
    # third conv pool, 128 filters, 256 filters    
    block_3    = unet_block(max_pool_2, initial_filters*4, padding=padding, batchnorm=batchnorm)
    max_pool_3 = MaxPooling3D(pool_size=(2, 2, 2), strides=2)(block_3)
    
    # just a conv block without maxpooling, 256 filters and 512 filters
    conv_4     = unet_block(max_pool_3, initial_filters*8, padding=padding, batchnorm=batchnorm)
    
    ## EXPANDING PATH   
    
    #TODO: check Conv3DTranspose correctly applied
    
    # round 1
    up_conv_3  = Conv3DTranspose(16*initial_filters, (2, 2, 2), strides=(2, 2, 2), padding=padding)(conv_4)
    crop_3     = _center_crop_like(block_3, up_conv_3)
    concat_3   = concatenate([crop_3, up_conv_3])  
    up_block_3 = unet_block(concat_3, 8*initial_filters, padding, up_conv=True, batchnorm=batchnorm)
    
    # round 2
    up_conv_2  = Conv3DTranspose(8*initial_filters, (2, 2, 2), strides=(2, 2, 2), padding=padding)(up_block_3) 
    crop_2     = _center_crop_like(block_2, up_conv_2)
    concat_2   = concatenate([crop_2, up_conv_2])  
    up_block_2 = unet_block(concat_2, 4*initial_filters, padding, up_conv=True, batchnorm=batchnorm)
    
    # round 3
    up_conv_1  = Conv3DTranspose(4*initial_filters, (2, 2, 2), strides=(2, 2, 2), padding=padding)(up_block_2) 
    crop_1     = _center_crop_like(block_1, up_conv_1)
    concat_1   = concatenate([crop_1, up_conv_1])  
    up_block_1 = unet_block(concat_1, 2*initial_filters, padding, up_conv=True, batchnorm=batchnorm)
    
    # finish with 1x1x1 conv, 2 filters (# labels: background and liver), softmax over the labels
    finish = Conv3D(2, (1,1,1), activation='softmax', padding=padding)(up_block_1)
    
    model = Model(inputs, finish) 
    
    # summary() prints itself and returns None, so it should not be wrapped in print
    model.summary(line_length=150)
    
    return model

In [16]:
unet_3d = build_unet_3d(initial_filters=32, padding='valid')

______________________________________________________________________________________________________________________________________________________
Layer (type)                                     Output Shape                     Param #           Connected to                                      
input_1 (InputLayer)                             (None, 188, 188, 188, 1)         0                                                                   
______________________________________________________________________________________________________________________________________________________
conv3d_1 (Conv3D)                                (None, 186, 186, 186, 32)        896               input_1[0][0]                                     
______________________________________________________________________________________________________________________________________________________
batch_normalization_1 (BatchNormalization)       (None, 186, 186, 186, 32)        128         

## Testing the batch generator

In [17]:
# define parameters for the batch creator
patch_size  = (188, 188, 188)  # isotropic patch size
target_size = (100, 100, 100)  # output size, smaller since valid convolutions are used
batch_size  = 1                # number of patches in a mini-batch, for segmentation 1 is fine, since the 
                               # output of the net is many thousands of values per patch, which all contribute to the loss

# initialize patch generator and batch creator
patch_generator       = PatchExtractor(patch_size, fromLiver=True)
batch_generator_train = BatchCreator(patch_generator, train_set, target_size=target_size)
batch_generator_val   = BatchCreator(patch_generator, val_set, target_size=target_size)

# get one minibatch
x_data, y_data = batch_generator_train.create_image_batch(batch_size)

print("(batch, d, h, w, channels)")
print('xdata has shape: {}'.format(x_data.shape))
print('ydata has shape: {}'.format(y_data.shape))
print('Occuring values in true labels: {}'.format(np.unique(y_data)))
print('Min of input: {}'.format(np.min(x_data)))
print('Max of input: {}'.format(np.max(x_data)))

(batch, d, h, w, channels)
xdata has shape: (1, 188, 188, 188, 1)
ydata has shape: (1, 100, 100, 100, 2)
Occuring values in true labels: [0. 1.]
Min of input: 0.0
Max of input: 0.7025898694992065


## Define a logger which saves the losses and saves the best model

In [18]:
class Logger(keras.callbacks.Callback):
    """
    Keras callback which keeps the train and validation loss of every epoch, plots them
    and saves the model to disk each time the validation loss improves.
    """

    # logg losses, add accs and dice later
    def __init__(self, data_dir, model_name):  
        """
        data_dir  : directory to store the best model in
        model_name: file name of the stored model, without the .h5 extension
        """
        super().__init__()
        self.model_filename = os.path.join(data_dir, model_name + '.h5')        
        self.tr_losses = []  
        self.val_losses = []  
        self.best_val_loss = float("inf") 
        self.best_model = None     
       
    def on_epoch_end(self, batch, logs=None):
        """
        batch: the first positional argument, Keras passes the index of the epoch here
        logs : dict with the metrics of the epoch, 'loss' and 'val_loss' are read from it
        """
        # no mutable default argument, the same dict would be shared between calls
        logs = logs or {}
        
        # add validation and train info
        self.val_losses.append(logs.get('val_loss'))
        self.tr_losses.append(logs.get('loss'))
        self.plot()

        # save best model after epoch end, val_loss is None if no validation was run
        # and comparing None with a float raises a TypeError
        val_loss = self.val_losses[-1]
        if val_loss is not None and val_loss < self.best_val_loss:
            self.best_val_loss = val_loss
            self.model.save(self.model_filename) # save best model to disk
            print('Best model saved as {}'.format(self.model_filename))
         
    def plot(self): 
        """Replace the output of the cell by a plot of the losses per epoch so far."""
        clear_output()
        plt.figure(figsize=(8, 4))
        plt.ylim([0, 1])
        n = len(self.val_losses) + 1         
        plt.plot(range(1, n), self.tr_losses, label='train loss')         
        plt.plot(range(1, n), self.val_losses, label='val loss')        
        plt.legend(loc='lower left')
        plt.show()

In [19]:
# make a data dir to store best model
print(os.getcwd())
data_dir = '../data'
if not os.path.exists(data_dir):
    os.makedirs(data_dir)

/nfs/home4/mbotros/ISMI_project/Task03_Liver


In [20]:
# back to inline, I don't know how to plot in notebook mode
%matplotlib inline

## Dice Loss
Dice loss seems to be a good pick for 3D segmentation with class imbalances. We still have to check whether it works the way it is implemented here.
In our case y_pred is output of the softmax.(see https://arxiv.org/pdf/1707.03237.pdf)

$$ \textbf{Dice loss} =  1 - \frac{2 \hspace{0.3em}|X \cap Y|}{|X|+ |Y|} $$


**For binary volumes of N voxels (Milletari et al., 2016, VNet):**

$$ \textbf{Dice loss} = 1 -\frac{2 \sum_{i}^{N} p_{i} g_{i}}{\sum_{i}^{N} p_{i}^{2}+\sum_{i}^{N} g_{i}^{2}} $$

In [21]:
# dice loss as above
def dice_loss_bv(y_true, y_pred, epsilon=1e-6):
    ''' 
    Dice loss for binary volumes, with squared terms in the denominator.
    Assumes the channels_last format.
    y_true : One hot encoding of ground truth
    y_pred : Network output, must sum to 1 over c channel (such as after softmax) 
    epsilon: small constant added to numerator and denominator
    '''
    # weights over the channels: drop the background channel, keep the foreground (liver) channel
    foreground = [0., 1.]
    
    # per voxel: predicted foreground probability and groundtruth label (0: background, 1: foreground)
    pred_fg = K.sum(y_pred * foreground, axis=-1)
    true_fg = K.sum(y_true * foreground, axis=-1)
    
    overlap = 2. * K.sum(pred_fg * true_fg) + epsilon
    total   = K.sum(pred_fg**2.) + K.sum(true_fg**2.) + epsilon
    
    return 1. - overlap / total

**Proposed in Milletari et al. [8] as a loss function, the 2-class variant of the Dice loss, denoted DL2, can be expressed as**

$$
\mathrm{DL}_{2}=1-\frac{\sum_{n=1}^{N} p_{n} r_{n}+\epsilon}{\sum_{n=1}^{N} p_{n}+r_{n}+\epsilon}-\frac{\sum_{n=1}^{N}\left(1-p_{n}\right)\left(1-r_{n}\right)+\epsilon}{\sum_{n=1}^{N} 2-p_{n}-r_{n}+\epsilon}
$$

Sudre, Carole H., et al. "Generalised dice overlap as a deep learning loss function for highly unbalanced segmentations." https://arxiv.org/pdf/1707.03237.pdf


**Let R be the reference foreground segmentation (gold standard) with voxel values $r_n$, and P the predicted probabilistic map for the foreground label over N image elements $p_n$, with the background class probability being 1 − P.**

In [22]:
def dice_loss(y_true, y_pred, epsilon=1e-6):
    ''' 
    2-class Dice loss (DL2): one minus the overlap term of the foreground
    and one minus the overlap term of the background.
    Assumes the channels_last format.
    y_true : One hot encoding of ground truth
    y_pred : Network output, must sum to 1 over c channel (such as after softmax) 
    epsilon: small constant added to every sum
    '''
    # weights over the channels: drop the background channel, keep the foreground (liver) channel
    foreground = [0., 1.]
    
    # per voxel: predicted foreground probability and groundtruth label (0: background, 1: foreground)
    pred_fg = K.sum(y_pred * foreground, axis=-1)
    true_fg = K.sum(y_true * foreground, axis=-1)
    
    # foreground term
    fg_overlap = K.sum(pred_fg * true_fg) + epsilon
    fg_total   = K.sum(pred_fg + true_fg) + epsilon
    
    # background term, the background probability is 1 - foreground
    bg_overlap = K.sum((1 - pred_fg) * (1 - true_fg)) + epsilon 
    bg_total   = K.sum((2 - pred_fg - true_fg)) + epsilon
    
    return 1 - fg_overlap/fg_total - bg_overlap/bg_total

## Now we define parameters, compile the model and train the network 

In [23]:
learning_rate   = 10**-4
optimizer       = Adam(lr=learning_rate, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=0.0)
metrics         = ['accuracy']  # probably not useful (?)
steps_per_epoch = 5
epochs          = 10
logger          = Logger(data_dir, '3D-UNet-21-05')

image_generator_train = batch_generator_train.get_image_generator(batch_size)
image_generator_val   = batch_generator_val.get_image_generator(batch_size)

# compile model
unet_3d.compile(optimizer=optimizer, loss=dice_loss, metrics=metrics)

In [24]:
# train the model
unet_3d.fit_generator(generator=image_generator_train, 
                    steps_per_epoch=steps_per_epoch, 
                    epochs=epochs, 
                    validation_data=image_generator_val,
                    verbose=1,
                    validation_steps=1,
                    callbacks=[logger])

Epoch 1/10


ResourceExhaustedError: OOM when allocating tensor with shape[1,64,184,184,184]
	 [[Node: max_pooling3d_1/MaxPool3D = MaxPool3D[T=DT_FLOAT, data_format="NDHWC", ksize=[1, 2, 2, 2, 1], padding="VALID", strides=[1, 2, 2, 2, 1], _device="/job:localhost/replica:0/task:0/device:GPU:0"](batch_normalization_2/cond/Merge)]]
	 [[Node: loss/mul/_983 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device_incarnation=1, tensor_name="edge_8852_loss/mul", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]

Caused by op 'max_pooling3d_1/MaxPool3D', defined at:
  File "/hpc/sw/python-3.5.2/lib/python3.5/runpy.py", line 184, in _run_module_as_main
    "__main__", mod_spec)
  File "/hpc/sw/python-3.5.2/lib/python3.5/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/hpc/sw/python-3.5.2/lib/python3.5/site-packages/ipykernel_launcher.py", line 16, in <module>
    app.launch_new_instance()
  File "/hpc/sw/python-3.5.2/lib/python3.5/site-packages/traitlets/config/application.py", line 658, in launch_instance
    app.start()
  File "/hpc/sw/python-3.5.2/lib/python3.5/site-packages/ipykernel/kernelapp.py", line 478, in start
    self.io_loop.start()
  File "/hpc/sw/python-3.5.2/lib/python3.5/site-packages/zmq/eventloop/ioloop.py", line 177, in start
    super(ZMQIOLoop, self).start()
  File "/hpc/sw/python-3.5.2/lib/python3.5/site-packages/tornado/ioloop.py", line 888, in start
    handler_func(fd_obj, events)
  File "/hpc/sw/python-3.5.2/lib/python3.5/site-packages/tornado/stack_context.py", line 277, in null_wrapper
    return fn(*args, **kwargs)
  File "/hpc/sw/python-3.5.2/lib/python3.5/site-packages/zmq/eventloop/zmqstream.py", line 440, in _handle_events
    self._handle_recv()
  File "/hpc/sw/python-3.5.2/lib/python3.5/site-packages/zmq/eventloop/zmqstream.py", line 472, in _handle_recv
    self._run_callback(callback, msg)
  File "/hpc/sw/python-3.5.2/lib/python3.5/site-packages/zmq/eventloop/zmqstream.py", line 414, in _run_callback
    callback(*args, **kwargs)
  File "/hpc/sw/python-3.5.2/lib/python3.5/site-packages/tornado/stack_context.py", line 277, in null_wrapper
    return fn(*args, **kwargs)
  File "/hpc/sw/python-3.5.2/lib/python3.5/site-packages/ipykernel/kernelbase.py", line 281, in dispatcher
    return self.dispatch_shell(stream, msg)
  File "/hpc/sw/python-3.5.2/lib/python3.5/site-packages/ipykernel/kernelbase.py", line 232, in dispatch_shell
    handler(stream, idents, msg)
  File "/hpc/sw/python-3.5.2/lib/python3.5/site-packages/ipykernel/kernelbase.py", line 397, in execute_request
    user_expressions, allow_stdin)
  File "/hpc/sw/python-3.5.2/lib/python3.5/site-packages/ipykernel/ipkernel.py", line 208, in do_execute
    res = shell.run_cell(code, store_history=store_history, silent=silent)
  File "/hpc/sw/python-3.5.2/lib/python3.5/site-packages/ipykernel/zmqshell.py", line 533, in run_cell
    return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
  File "/hpc/sw/python-3.5.2/lib/python3.5/site-packages/IPython/core/interactiveshell.py", line 2728, in run_cell
    interactivity=interactivity, compiler=compiler, result=result)
  File "/hpc/sw/python-3.5.2/lib/python3.5/site-packages/IPython/core/interactiveshell.py", line 2850, in run_ast_nodes
    if self.run_code(code, result):
  File "/hpc/sw/python-3.5.2/lib/python3.5/site-packages/IPython/core/interactiveshell.py", line 2910, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-16-26afc44b8a32>", line 1, in <module>
    unet_3d = build_unet_3d(initial_filters=32, padding='valid')
  File "<ipython-input-15-2f92b20ad753>", line 11, in build_unet_3d
    max_pool_1 = MaxPooling3D(pool_size=(2, 2, 2), strides=2)(block_1)  # 2×2×2 max pooling with strides two
  File "/home/mbotros/.local/lib/python3.5/site-packages/keras/engine/base_layer.py", line 457, in __call__
    output = self.call(inputs, **kwargs)
  File "/home/mbotros/.local/lib/python3.5/site-packages/keras/layers/pooling.py", line 374, in call
    data_format=self.data_format)
  File "/home/mbotros/.local/lib/python3.5/site-packages/keras/layers/pooling.py", line 432, in _pooling_function
    padding, data_format, pool_mode='max')
  File "/home/mbotros/.local/lib/python3.5/site-packages/keras/backend/tensorflow_backend.py", line 4024, in pool3d
    data_format=tf_data_format)
  File "/home/mbotros/.local/lib/python3.5/site-packages/tensorflow/python/ops/gen_nn_ops.py", line 2870, in max_pool3d
    padding=padding, data_format=data_format, name=name)
  File "/home/mbotros/.local/lib/python3.5/site-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "/home/mbotros/.local/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 2956, in create_op
    op_def=op_def)
  File "/home/mbotros/.local/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 1470, in __init__
    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

ResourceExhaustedError (see above for traceback): OOM when allocating tensor with shape[1,64,184,184,184]
	 [[Node: max_pooling3d_1/MaxPool3D = MaxPool3D[T=DT_FLOAT, data_format="NDHWC", ksize=[1, 2, 2, 2, 1], padding="VALID", strides=[1, 2, 2, 2, 1], _device="/job:localhost/replica:0/task:0/device:GPU:0"](batch_normalization_2/cond/Merge)]]
	 [[Node: loss/mul/_983 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device_incarnation=1, tensor_name="edge_8852_loss/mul", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]


## Getting the full segmentation map
Like this but then 3D:

![seg_diagram.png](seg_diagram.png)


In [None]:
def padding(image, patch_size, target_size):
    """
    Add the red border (see example above) around the image. The border supplies the context
    that the model needs for voxels close to the edge of the image.
    The border is filled with the lowest value that occurs in the image.
    image       : the input image (as numpy)
    patch_size  : patch size of the input for the UNet, one entry per image axis
    target_size : output size of the model, one entry per image axis
    returns     : the image with (patch_size - target_size) // 2 voxels added on both sides of every axis
    """
    # pad with min value from image, always safe
    min_val = np.min(image)

    # Border width is computed axis by axis, so entry i of patch_size is always paired with
    # entry i of target_size (non-cubic patches get the correct border on every axis).
    pad_widths = [((patch - target) // 2,) * 2 for patch, target in zip(patch_size, target_size)]

    # one (before, after) tuple per axis
    return np.pad(image, pad_widths, 'constant', constant_values=min_val)

In [None]:
def _tile_starts(last_start, step):
    """
    Start offsets of the tiles along one axis, in the order in which they have to be processed.
    last_start : largest valid start offset (the tile that ends flush with the volume)
    step       : distance between two regular tiles (the target size along this axis)
    """
    if last_start < 0:
        raise ValueError("padded image is smaller than the patch size, check patch_size and target_size")

    starts = list(range(0, last_start + 1, step))
    # add the flush tile only when the regular grid does not already end there,
    # otherwise the same tile would be predicted twice
    if starts[-1] != last_start:
        starts.append(last_start)

    # flush tile first: the regular tiles that are processed later overwrite the overlap
    return starts[::-1]


def predict_image_segmentation(model, image, target_size, patch_size):
    """
    Give a full segmentation map (same size as the input image) using the model.
    The image is covered with tiles of target_size; for every tile a patch of patch_size
    (tile plus surrounding context) is fed to the model.
    model       : the model to do the prediction, model.predict has to return an array of
                  shape (1, target_size[0], target_size[1], target_size[2], n_classes)
    image       : the input image (as 3D numpy array)
    target_size : output size of the model (since we use valid convolutions the output gets smaller)
    patch_size  : the size of the patch that is put into the model
    returns     : array with the shape of image holding the index of the highest scoring class per voxel
    Note: patch_size - target_size is expected to be even on every axis, otherwise the border
    cannot be split evenly over both sides.
    """

    # clip values outside [-1000, 3000] and rescale with the min / max of the clipped image to [0., 1.]
    image = np.clip(image, -1000, 3000)
    intensity_range = np.ptp(image)
    if intensity_range > 0:
        image = (image - np.min(image)) / intensity_range
    else:
        # constant image: dividing by zero would feed NaNs to the model
        image = np.zeros(image.shape)

    # pad the input image:
    pad_img = padding(image, patch_size, target_size)

    # An axis that is shorter than the target size cannot hold a single tile.
    # Extend such an axis at the far end; the surplus is cropped from the result again.
    shortfall = [max(0, target - dim) for dim, target in zip(image.shape, target_size)]
    if any(shortfall):
        pad_img = np.pad(pad_img, [(0, extra) for extra in shortfall],
                         'constant', constant_values=np.min(image))

    print("Image size: {}".format(image.shape))
    print("Padded image size: {}".format(pad_img.shape))

    # segmentation map, same size as input image (plus the temporary extension, if any)
    segmentation = np.zeros(tuple(dim + extra for dim, extra in zip(image.shape, shortfall)))

    # tile offsets per axis: regular grid with step target_size plus one tile flush with the end
    starts_z, starts_x, starts_y = [
        _tile_starts(pad_img.shape[axis] - patch_size[axis], target_size[axis]) for axis in range(3)
    ]

    for start_z in starts_z:
        for start_x in starts_x:
            for start_y in starts_y:

                # Get patch: shift with target_size, take patch_size
                patch = pad_img[start_z:start_z + patch_size[0],
                                start_x:start_x + patch_size[1],
                                start_y:start_y + patch_size[2]]

                # Reshape for u-net and make prediction:
                patch = np.reshape(patch, (1, patch_size[0], patch_size[1], patch_size[2], 1))
                prediction = model.predict(patch)

                # Put the prediction in segmentation map, shift with target_size, take target_size.
                # Only the batch axis is dropped, so target axes of length 1 keep working.
                segmentation[start_z:start_z + target_size[0],
                             start_x:start_x + target_size[1],
                             start_y:start_y + target_size[2]] = np.argmax(prediction[0], axis=-1)

    if any(shortfall):
        # remove the temporary extension
        segmentation = segmentation[:image.shape[0], :image.shape[1], :image.shape[2]]

    return segmentation

## Inspecting the prediction

In [None]:
# take entry 2 of the validation set and cut the same sub-volume out of the image and its label
roi = np.s_[200:400, :, -200:]
image = val_set.imgs[2][roi]
label = val_set.lbls[2][roi]

In [None]:
# Load the best model. The training data is shuffled now, so do not use
# older models that were trained on unshuffled data.
model_file = os.path.join(data_dir, '3D_UNet_DL2' + '.h5')
unet_3d = load_model(model_file, custom_objects={'dice_loss': dice_loss})

In [None]:
# run the tiled prediction on the selected sub-volume to get its segmentation map
segmentation = predict_image_segmentation(
    model=unet_3d,
    image=image,
    target_size=target_size,
    patch_size=patch_size,
)

In [None]:
# occurring values and their voxel counts: prediction first, ground truth second
for volume in (segmentation, label):
    print(np.unique(volume, return_counts=True))

In [None]:
# plot slices
%matplotlib notebook
s = 100
slice_img = image[s, :, :]
slice_lbl = label[s, :, :]
slice_seg = segmentation[s, :, :]

masked_lbl = np.ma.masked_where(slice_lbl < 1, slice_lbl)
masked_seg = np.ma.masked_where(slice_seg < 1, slice_seg)

plt.figure()
plt.subplot(1,2,1).set_title('Prediction')
plt.imshow(slice_img, cmap='gray')
plt.imshow(masked_seg, cmap='coolwarm', alpha = 0.75)

plt.subplot(1,2,2).set_title('Ground truth')
plt.imshow(slice_img, cmap='gray')
plt.imshow(masked_lbl, cmap='coolwarm', alpha = 0.75)
plt.show()

In [None]:
# browse the predicted segmentation on top of the image
plt.rcParams['figure.figsize'] = [8, 8]
# NOTE(review): presumably overlay_1 / overlay_2 highlight values from 1 and from 2 upwards,
# confirm against multi_slice_viewer
segmentation_overlays = dict(
    overlay_1=segmentation, overlay_1_thres=1,
    overlay_2=segmentation, overlay_2_thres=2,
    overlay_2_cmap='coolwarm', overlay_2_alpha=0.75,
)
multi_slice_viewer(image, view='axial', **segmentation_overlays)

In [None]:
# browse the ground truth labels on top of the image
plt.rcParams['figure.figsize'] = [8, 8]
# NOTE(review): presumably overlay_1 / overlay_2 highlight values from 1 and from 2 upwards,
# confirm against multi_slice_viewer
ground_truth_overlays = dict(
    overlay_1=label, overlay_1_thres=1,
    overlay_2=label, overlay_2_thres=2,
    overlay_2_cmap='coolwarm', overlay_2_alpha=0.75,
)
multi_slice_viewer(image, view='axial', **ground_truth_overlays)