MISA 3D Brain MRI Segmentation using 2D UNet

1. Preprocessing - Bias Correction

2. Method - Patch Based

3. Data Augmentation - Yes

## Importing the Libraries

In [None]:
import cv2
import glob
import warnings
import scipy.misc
import numpy as np
import nibabel as nib
!pip install simpleitk
import SimpleITK as sitk
from scipy import ndimage
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from keras.preprocessing.image import ImageDataGenerator

from google.colab import drive
# Public Colab API (drive._mount is private); pass force_remount=True to remount.
drive.mount("/content/drive")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


## Defining the Parameters

In [None]:
# Image Parameters
# Spatial shape of every volume. 2D slices are taken along the first axis,
# so each slice is IMAGE_SIZE[1] x IMAGE_SIZE[2].
IMAGE_SIZE = (256, 128, 256)

# Training, Testing and Validation Parameters
# Indices into the volumes returned by load_data_bias(..., 'Training_Set')
TRAINING_VOLUMES = [0, 1, 2, 3, 4, 5, 6, 7, 8]
VALIDATION_VOLUMES = [9]

# Hyperparameters
N_CLASSES = 4
N_INPUT_CHANNELS = 1
PATCH_SIZE = (32, 32)
PATCH_STRIDE = (32, 32)  # equal to PATCH_SIZE -> non-overlapping patches

# Data Preparation Parameters
CONTENT_THRESHOLD = 0.3  # minimum fraction of labelled (non-background) pixels for a patch to be kept

# Training Parameters
N_EPOCHS = 200
BATCH_SIZE = 64
PATIENCE = 20
MODEL_FNAME_PATTERN = 'model.h5'
OPTIMISER = 'Adam'
LOSS = 'categorical_crossentropy'
dropout_rate = 0.40

# Reproducibility: seed NumPy and TensorFlow so weight initialisation, dropout
# masks and shuffling repeat across kernel restarts (GPU kernels may still
# introduce small non-deterministic differences).
SEED = 1
np.random.seed(SEED)
tf.random.set_seed(SEED)

## Define the UNet Architecture

In [None]:
# Disabled: an earlier 3-level UNet (32-64-128 filters, 256-filter bottleneck),
# kept as a string literal for reference only. Evaluating this cell just echoes
# the string; the get_unet actually used is defined in a later cell.
"""def get_unet(img_size=PATCH_SIZE, n_classes=N_CLASSES, n_input_channels=N_INPUT_CHANNELS, scale=1):
    inputs = keras.Input(shape=img_size + (n_input_channels, ))

    # Encoding Path of the UNet
    conv1 = layers.Conv2D(32*scale, (3, 3), padding="same", activation='relu')(inputs)
    drop1 = layers.Dropout(rate=dropout_rate)(conv1, training=True)
    max1 = layers.MaxPooling2D((2, 2))(drop1)
    # max1 = layers.MaxPooling2D((2, 2))(conv1)

    conv2 = layers.Conv2D(64*scale, (3, 3), padding="same", activation='relu')(max1)
    drop2 = layers.Dropout(rate=dropout_rate)(conv2, training=True)
    # max2 = layers.MaxPooling2D((2, 2))(conv2)
    max2 = layers.MaxPooling2D((2, 2))(drop2)

    conv3 = layers.Conv2D(128*scale, (3, 3), padding="same", activation='relu')(max2)
    drop3 = layers.Dropout(rate=dropout_rate)(conv3, training=True)
    # max3 = layers.MaxPooling2D((2, 2))(conv3)
    max3 = layers.MaxPooling2D((2, 2))(drop3)

    lat = layers.Conv2D(256*scale, (3, 3), padding="same", activation='relu')(max3)
    drop4 = layers.Dropout(rate=dropout_rate)(lat, training=True)

    # Decoding Path of the UNet
    #up1 = layers.UpSampling2D((2, 2))(lat)
    up1 = layers.UpSampling2D((2, 2))(drop4)
    concat1 = layers.concatenate([conv3, up1], axis=-1)
    conv4 = layers.Conv2D(128*scale, (3, 3), padding="same", activation='relu')(concat1)
    drop5 = layers.Dropout(rate=dropout_rate)(conv4, training=True)
    
    #up2 = layers.UpSampling2D((2, 2))(conv4)
    up2 = layers.UpSampling2D((2, 2))(drop5)
    concat2 = layers.concatenate([conv2, up2], axis=-1)
    conv5 = layers.Conv2D(64*scale, (3, 3), padding="same", activation='relu')(concat2)
    drop6 = layers.Dropout(rate=dropout_rate)(conv5, training=True)
    
    #up3 = layers.UpSampling2D((2, 2))(conv5)
    up3 = layers.UpSampling2D((2, 2))(drop6)
    concat3 = layers.concatenate([conv1, up3], axis=-1)
    conv6 = layers.Conv2D(32*scale, (3, 3), padding="same", activation='relu')(concat3)
    drop7 = layers.Dropout(rate=dropout_rate)(conv6, training=True)
    
    #outputs = layers.Conv2D(n_classes, (1, 1), activation="softmax")(conv6)
    outputs = layers.Conv2D(n_classes, (1, 1), activation="softmax")(drop7)

    model = keras.Model(inputs, outputs)

    return model"""

'def get_unet(img_size=PATCH_SIZE, n_classes=N_CLASSES, n_input_channels=N_INPUT_CHANNELS, scale=1):\n    inputs = keras.Input(shape=img_size + (n_input_channels, ))\n\n    # Encoding Path of the UNet\n    conv1 = layers.Conv2D(32*scale, (3, 3), padding="same", activation=\'relu\')(inputs)\n    drop1 = layers.Dropout(rate=dropout_rate)(conv1, training=True)\n    max1 = layers.MaxPooling2D((2, 2))(drop1)\n    # max1 = layers.MaxPooling2D((2, 2))(conv1)\n\n    conv2 = layers.Conv2D(64*scale, (3, 3), padding="same", activation=\'relu\')(max1)\n    drop2 = layers.Dropout(rate=dropout_rate)(conv2, training=True)\n    # max2 = layers.MaxPooling2D((2, 2))(conv2)\n    max2 = layers.MaxPooling2D((2, 2))(drop2)\n\n    conv3 = layers.Conv2D(128*scale, (3, 3), padding="same", activation=\'relu\')(max2)\n    drop3 = layers.Dropout(rate=dropout_rate)(conv3, training=True)\n    # max3 = layers.MaxPooling2D((2, 2))(conv3)\n    max3 = layers.MaxPooling2D((2, 2))(drop3)\n\n    lat = layers.Conv2D(256*sc

In [None]:
from keras.layers import BatchNormalization, Activation

In [None]:
def get_unet(img_size=PATCH_SIZE, n_classes=N_CLASSES, n_input_channels=N_INPUT_CHANNELS, scale=1, mc_dropout=True):
    """Build a 2D UNet with encoder filters 32-64-128-256 and a 512-filter bottleneck.

    Every 3x3 ReLU convolution is followed by Dropout(dropout_rate). Each encoder
    level max-pools its dropout output, while the skip connection to the decoder
    uses the convolution output taken before dropout. Each decoder level upsamples
    by 2, concatenates the matching skip, and applies conv + dropout. A 1x1 softmax
    convolution produces the class probabilities.

    Args:
        img_size: (height, width) of the input. Both must be divisible by 16
            because the encoder halves the resolution four times.
        n_classes: number of output classes (softmax channels).
        n_input_channels: number of input image channels.
        scale: multiplier applied to every convolution's filter count.
        mc_dropout: if True (default), dropout is called with training=True and
            therefore stays active in predict() as well (Monte-Carlo dropout).
            If False, Keras applies it only during fit().

    Returns:
        An uncompiled keras.Model mapping (*img_size, n_input_channels) inputs to
        per-pixel class probabilities of shape (*img_size, n_classes).

    Raises:
        ValueError: if an img_size dimension is not divisible by 16.
    """
    img_size = tuple(img_size)
    if any(int(dim) % 16 for dim in img_size):
        raise ValueError("img_size {} must be divisible by 16 (4 pooling levels)".format(img_size))

    # training=None lets Keras decide from the call context (fit vs. predict).
    dropout_training = True if mc_dropout else None

    def conv_dropout(x, filters):
        """3x3 ReLU conv followed by dropout; returns (conv_output, dropout_output)."""
        conv = layers.Conv2D(filters * scale, (3, 3), padding="same", activation='relu')(x)
        drop = layers.Dropout(rate=dropout_rate)(conv, training=dropout_training)
        return conv, drop

    inputs = keras.Input(shape=img_size + (n_input_channels,))

    # Encoding path: keep each level's pre-dropout conv output as the skip connection.
    skips = []
    x = inputs
    for filters in (32, 64, 128, 256):
        conv, drop = conv_dropout(x, filters)
        skips.append(conv)
        x = layers.MaxPooling2D((2, 2))(drop)

    # Bottleneck
    _, x = conv_dropout(x, 512)

    # Decoding path: upsample, merge with the skip of the same resolution, convolve.
    for filters, skip in zip((256, 128, 64, 32), reversed(skips)):
        up = layers.UpSampling2D((2, 2))(x)
        merged = layers.concatenate([skip, up], axis=-1)
        _, x = conv_dropout(merged, filters)

    outputs = layers.Conv2D(n_classes, (1, 1), activation="softmax")(x)

    return keras.Model(inputs, outputs)

## Generate Bias-Corrected Images

In [None]:
# Disabled: an interactive N4 helper kept as a string literal for reference.
# It overwrites its maskImagePath/outputPath arguments with values read via
# input(), so it cannot run unattended; data_bias_correction below replaces it.
"""
def N4(inputImagePath, maskImagePath, outputPath):
  # inputImagePath = input('Enter the path of the image : ')
  inputImage = sitk.ReadImage(inputImagePath)

  print("N4 bias correction runs.")

  # maskImage = sitk.ReadImage("06-t1c_mask.nii.gz")
  maskImage = sitk.OtsuThreshold(inputImage,0,1,200)
  maskImagePath = input('Enter the name of the mask image to be saved : ')
  sitk.WriteImage(maskImage, maskImagePath)
  print("Mask image is saved.")

  inputImage = sitk.Cast(inputImage,sitk.sitkFloat32)

  corrector = sitk.N4BiasFieldCorrectionImageFilter();

  output = corrector.Execute(inputImage,maskImage)

  outputPath = input("Enter the name of the Bias Field Corrected Image :")
  sitk.WriteImage(output,outputPath)
  print("Finished N4 Bias Field Correction.....")
"""

'\ndef N4(inputImagePath, maskImagePath, outputPath):\n  # inputImagePath = input(\'Enter the path of the image : \')\n  inputImage = sitk.ReadImage(inputImagePath)\n\n  print("N4 bias correction runs.")\n\n  # maskImage = sitk.ReadImage("06-t1c_mask.nii.gz")\n  maskImage = sitk.OtsuThreshold(inputImage,0,1,200)\n  maskImagePath = input(\'Enter the name of the mask image to be saved : \')\n  sitk.WriteImage(maskImage, maskImagePath)\n  print("Mask image is saved.")\n\n  inputImage = sitk.Cast(inputImage,sitk.sitkFloat32)\n\n  corrector = sitk.N4BiasFieldCorrectionImageFilter();\n\n  output = corrector.Execute(inputImage,maskImage)\n\n  outputPath = input("Enter the name of the Bias Field Corrected Image :")\n  sitk.WriteImage(output,outputPath)\n  print("Finished N4 Bias Field Correction.....")\n'

In [None]:
def data_bias_correction(setName, data_root='/content/drive/My Drive/MISA/Normal Segmentations/data') :
  """Apply N4 bias-field correction to every subject folder of one data split.

  Subject folders `<data_root>/<setName>/<name>/` are processed in sorted order.
  For each folder, this reads `<name>.nii.gz`, writes an Otsu-threshold mask to
  `<name>_mask.nii.gz`, and writes the N4-corrected float32 image to
  `<name>_bias_corrected.nii.gz` (the file load_data_bias reads).

  Args:
    setName: split folder name, e.g. 'Training_Set', 'Validation_Set', 'Test_Set'.
    data_root: directory containing the split folders (defaults to the original
      Google Drive location).

  Raises:
    FileNotFoundError: if the split folder has no entries.
  """
  import os

  set_dir = os.path.join(data_root, setName)
  folders = sorted(glob.glob(os.path.join(set_dir, '*')))
  if not folders:
    raise FileNotFoundError('No subject folders found in {}'.format(set_dir))

  for folder in folders:
    # Use the whole folder name (e.g. 'IBSR_05') rather than assuming 7 characters.
    name = os.path.basename(folder)
    print("Working on image {0}".format(name))

    inputImage = sitk.ReadImage(os.path.join(folder, '{}.nii.gz'.format(name)))

    # OtsuThreshold(image, insideValue=0, outsideValue=1, numberOfHistogramBins=200).
    # Presumably this marks the brighter (head) voxels with 1; confirm visually.
    maskImage = sitk.OtsuThreshold(inputImage, 0, 1, 200)
    sitk.WriteImage(maskImage, os.path.join(folder, '{}_mask.nii.gz'.format(name)))
    print("Mask image is saved.")

    # N4 works on a floating-point image.
    inputImage = sitk.Cast(inputImage, sitk.sitkFloat32)
    corrector = sitk.N4BiasFieldCorrectionImageFilter()
    output = corrector.Execute(inputImage, maskImage)

    sitk.WriteImage(output, os.path.join(folder, '{}_bias_corrected.nii.gz'.format(name)))
    print("Finished N4 Bias Field Correction.....")

    #bias_corrected_path = '/content/drive/My Drive/MISA/Normal Segmentations/data/{}/{}/{}_bias_corrected.nii.gz'.format(setName, name, name)
    #N4(img_path, mask_path, bias_corrected_path)
    #N4('/content/drive/MyDrive/MISA/Normal Segmentations/data/Training_Set/IBSR_01/IBSR_01.nii.gz', '/content/hello/', '/content/hello/')

In [None]:
# Disabled (kept as a string): run once to create the <name>_mask.nii.gz and
# <name>_bias_corrected.nii.gz files that load_data_bias reads.
"""# Calling the Bias Removal Function (N4)
data_bias_correction('Training_Set')
data_bias_correction('Validation_Set')
data_bias_correction('Test_Set')"""

"# Calling the Bias Removal Function (N4)\ndata_bias_correction('Training_Set')\ndata_bias_correction('Validation_Set')\ndata_bias_correction('Test_Set')"

## Loading the Training and Validation Data

In [None]:
def load_data_bias(image_size, setName, data_root='/content/drive/My Drive/MISA/Normal Segmentations/data') :
  """Load the N4-corrected volumes and their segmentations for one data split.

  Args:
    image_size: spatial shape every volume is reshaped to, e.g. IMAGE_SIZE.
    setName: split folder name under `data_root` (e.g. 'Training_Set').
    data_root: directory containing the split folders (defaults to the original
      Google Drive location).

  Returns:
    (volumes, labels): float arrays of shape (n_volumes, *image_size, 1). The
    subject folders are read in sorted order, so an index such as those in
    TRAINING_VOLUMES always refers to the same subject.

  Raises:
    FileNotFoundError: if the split folder has no entries.
  """
  import os

  set_dir = os.path.join(data_root, setName)
  # Sorted so that volume indices do not depend on filesystem listing order.
  folders = sorted(glob.glob(os.path.join(set_dir, '*')))
  if not folders:
    raise FileNotFoundError('No subject folders found in {}'.format(set_dir))

  volumes = np.zeros((len(folders), *image_size, 1))
  labels = np.zeros((len(folders), *image_size, 1))

  for i, folder in enumerate(folders):
    name = os.path.basename(folder)

    img_data = nib.load(os.path.join(folder, '{}_bias_corrected.nii.gz'.format(name)))
    volumes[i] = img_data.get_fdata().reshape((*image_size, 1))

    seg_data = nib.load(os.path.join(folder, '{}_seg.nii.gz'.format(name)))
    # Reshape like the volume so label files with or without a trailing
    # singleton axis both fit.
    labels[i] = seg_data.get_fdata().reshape((*image_size, 1))

    print("Working on image {0}".format(name))

  return (volumes, labels)

In [None]:
# N4-corrected volumes plus labels for the training split and the held-out validation split
t_volumes, t_labels = load_data_bias(IMAGE_SIZE, 'Training_Set')
v_volumes, v_labels = load_data_bias(IMAGE_SIZE, 'Validation_Set')

Working on image IBSR_05
Working on image IBSR_03
Working on image IBSR_08
Working on image IBSR_04
Working on image IBSR_01
Working on image IBSR_16
Working on image IBSR_18
Working on image IBSR_07
Working on image IBSR_09
Working on image IBSR_06
Working on image IBSR_14
Working on image IBSR_17
Working on image IBSR_12
Working on image IBSR_13


Visualising the Training Images

In [None]:
# Disabled visualisation (kept as a string): shows index 150 along the last axis
# of training volume 1 after ndimage.rotate(..., 90), which rotates in the plane
# of the first two axes.
"""check_vol = t_volumes[1,:,:,:,:]
check_vol = check_vol.reshape((256, 128, 256))
rotated_vol = ndimage.rotate(check_vol, 90)
plt.axis('off')
plt.imshow(rotated_vol[:, :, 150], cmap='gray')
plt.show()"""

"check_vol = t_volumes[1,:,:,:,:]\ncheck_vol = check_vol.reshape((256, 128, 256))\nrotated_vol = ndimage.rotate(check_vol, 90)\nplt.axis('off')\nplt.imshow(rotated_vol[:, :, 150], cmap='gray')\nplt.show()"

## Denoising the Volumes

In [None]:
# Disabled (kept as a string): 3D anisotropic diffusion denoising. option=1 and
# option=2 select exp(-(d/kappa)^2) and 1/(1+(d/kappa)^2) conduction functions.
# It is only used by the disabled denoising cells that follow.
"""def anisodiff3(stack,niter=1,kappa=50,gamma=0.1,step=(1.,1.,1.),option=1,ploton=False):

    # ...you could always diffuse each color channel independently if you
    # really want
    if stack.ndim == 4:
        warnings.warn("Only grayscale stacks allowed, converting to 3D matrix")
        stack = stack.mean(3)

    # initialize output array
    stack = stack.astype('float32')
    stackout = stack.copy()

    # initialize some internal variables
    deltaS = np.zeros_like(stackout)
    deltaE = deltaS.copy()
    deltaD = deltaS.copy()
    NS = deltaS.copy()
    EW = deltaS.copy()
    UD = deltaS.copy()
    gS = np.ones_like(stackout)
    gE = gS.copy()
    gD = gS.copy()

    # create the plot figure, if requested
    if ploton:
        import pylab as pl
        from time import sleep

        showplane = stack.shape[0]//2

        fig = pl.figure(figsize=(20,5.5),num="Anisotropic diffusion")
        ax1,ax2 = fig.add_subplot(1,2,1),fig.add_subplot(1,2,2)

        ax1.imshow(stack[showplane,...].squeeze(),interpolation='nearest')
        ih = ax2.imshow(stackout[showplane,...].squeeze(),interpolation='nearest',animated=True)
        ax1.set_title("Original stack (Z = %i)" %showplane)
        ax2.set_title("Iteration 0")

        fig.canvas.draw()

    for ii in range(niter):

        # calculate the diffs
        deltaD[:-1,: ,:  ] = np.diff(stackout,axis=0)
        deltaS[:  ,:-1,: ] = np.diff(stackout,axis=1)
        deltaE[:  ,: ,:-1] = np.diff(stackout,axis=2)

        # conduction gradients (only need to compute one per dim!)
        if option == 1:
            gD = np.exp(-(deltaD/kappa)**2.)/step[0]
            gS = np.exp(-(deltaS/kappa)**2.)/step[1]
            gE = np.exp(-(deltaE/kappa)**2.)/step[2]
        elif option == 2:
            gD = 1./(1.+(deltaD/kappa)**2.)/step[0]
            gS = 1./(1.+(deltaS/kappa)**2.)/step[1]
            gE = 1./(1.+(deltaE/kappa)**2.)/step[2]

        # update matrices
        D = gD*deltaD
        E = gE*deltaE
        S = gS*deltaS

        # subtract a copy that has been shifted 'Up/North/West' by one
        # pixel. don't as questions. just do it. trust me.
        UD[:] = D
        NS[:] = S
        EW[:] = E
        UD[1:,: ,: ] -= D[:-1,:  ,:  ]
        NS[: ,1:,: ] -= S[:  ,:-1,:  ]
        EW[: ,: ,1:] -= E[:  ,:  ,:-1]

        # update the image
        stackout += gamma*(UD+NS+EW)

        if ploton:
            iterstring = "Iteration %i" %(ii+1)
            ih.set_data(stackout[showplane,...].squeeze())
            ax2.set_title(iterstring)
            fig.canvas.draw()
            # sleep(0.01)

    return stackout"""



In [None]:
# Disabled (kept as a string): runs anisodiff3 (10 iterations) on every volume.
# anisodiff3 averages away the trailing channel axis, and the reshape re-adds it.
# Requires anisodiff3, which is also disabled above.
"""def denoise_volumes(in_volumes) :

  n_loop = in_volumes.shape[0]

  out_volumes = np.zeros(in_volumes.shape)
  #print(out_volumes.shape)

  for i in range(0,n_loop,1):
    temp = in_volumes[i,:,:,:,:]
    temp = anisodiff3(temp,niter=10)
    temp = temp.reshape((*temp.shape, 1))
    out_volumes[i] = temp

    #print(temp.shape)

  return out_volumes"""

'def denoise_volumes(in_volumes) :\n\n  n_loop = in_volumes.shape[0]\n\n  out_volumes = np.zeros(in_volumes.shape)\n  #print(out_volumes.shape)\n\n  for i in range(0,n_loop,1):\n    temp = in_volumes[i,:,:,:,:]\n    temp = anisodiff3(temp,niter=10)\n    temp = temp.reshape((*temp.shape, 1))\n    out_volumes[i] = temp\n\n    #print(temp.shape)\n\n  return out_volumes'

In [None]:
# Disabled (kept as a string). NOTE(review): t_volumes_clean and v_volumes_clean
# are never defined, so the two assignments below would raise NameError if this
# were re-enabled; the denoise_volumes calls already update t_volumes/v_volumes.
"""t_volumes = denoise_volumes(t_volumes)
v_volumes = denoise_volumes(v_volumes)

t_volumes = t_volumes_clean
v_volumes = v_volumes_clean

print(t_volumes.shape)
print(v_volumes.shape)"""

't_volumes = denoise_volumes(t_volumes)\nv_volumes = denoise_volumes(v_volumes)\n\nt_volumes = t_volumes_clean\nv_volumes = v_volumes_clean\n\nprint(t_volumes.shape)\nprint(v_volumes.shape)'

In [None]:
# check_vol_clean = anisodiff3(check_vol)

In [None]:
# Disabled (kept as a string): check_vol_clean is only produced by the
# commented-out anisodiff3 call in the previous cell.
"""rotated_vol_clean = ndimage.rotate(check_vol_clean, 90)
plt.axis('off')
plt.imshow(rotated_vol_clean[:, :, 150], cmap='gray')
plt.show()"""

"rotated_vol_clean = ndimage.rotate(check_vol_clean, 90)\nplt.axis('off')\nplt.imshow(rotated_vol_clean[:, :, 150], cmap='gray')\nplt.show()"

Splitting the Dataset

In [None]:
# Split the training data into training and validation.
# The indices refer to the order of the volumes returned by load_data_bias.
# An index present in both lists would leak validation data into training.
assert not set(TRAINING_VOLUMES) & set(VALIDATION_VOLUMES), \
    "TRAINING_VOLUMES and VALIDATION_VOLUMES overlap"

training_volumes = t_volumes[TRAINING_VOLUMES]
training_labels = t_labels[TRAINING_VOLUMES]

validation_volumes = t_volumes[VALIDATION_VOLUMES]
validation_labels = t_labels[VALIDATION_VOLUMES]

print(training_volumes.shape)
print(validation_volumes.shape)

(9, 256, 128, 256, 1)
(1, 256, 128, 256, 1)


## Extracting Patches

In [None]:
# def z_score_standardisation(x, avg, std):
#   return (x-avg)/std

**Extract *useful* patches**

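This step is fundamental: we want to feed the network only patches that carry useful information, i.e. patches in which more than `CONTENT_THRESHOLD` of the pixels have a non-background label.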

In [None]:
def extract_patches(x, patch_size, patch_stride) :
  """Tile a batch of 2D images into patches using tf.image.extract_patches.

  Args:
    x: 4D batch (N, H, W, C).
    patch_size: (height, width) of each patch.
    patch_stride: (row step, column step) between patch origins.

  Returns:
    Tensor of shape (N, rows, cols, patch_h * patch_w * C). 'SAME' padding pads
    the borders so that the image edges are also covered by patches.
  """
  window = [1, *patch_size, 1]
  step = [1, *patch_stride, 1]
  no_dilation = [1, 1, 1, 1]
  return tf.image.extract_patches(
      x,
      sizes=window,
      strides=step,
      rates=no_dilation,
      padding='SAME',
      name=None,
  )

In [None]:
def extract_useful_patches(
    volumes, labels,
    image_size=IMAGE_SIZE,
    patch_size=PATCH_SIZE,
    stride=PATCH_STRIDE,
    threshold=CONTENT_THRESHOLD,
    num_classes=N_CLASSES) :
  """Cut volumes into 2D patches and keep only those with enough foreground.

  The volumes and labels are viewed as stacks of 2D slices of shape
  (image_size[1], image_size[2]) along their first spatial axis. Each slice is
  split into `patch_size` patches with the given stride. A patch is kept when
  more than `threshold * prod(patch_size)` of its pixels have a non-zero label.

  Args:
    volumes, labels: arrays of matching shape (n_volumes, *image_size, 1).
    image_size: spatial volume shape (only the last two entries are used here).
    patch_size, stride: patch height/width and step between patches.
    threshold: minimum foreground fraction for a patch to be kept.
    num_classes: number of classes for the one-hot label encoding.

  Returns:
    (vol_patches, seg_patches): intensity patches (K, *patch_size, 1) and
    one-hot label patches (K, *patch_size, num_classes).
  """
  volumes = volumes.reshape([-1, image_size[1], image_size[2], 1])
  labels = labels.reshape([-1, image_size[1], image_size[2], 1])

  vol_patches = extract_patches(volumes, patch_size, stride).numpy()
  seg_patches = extract_patches(labels, patch_size, stride).numpy()

  # extract_patches flattens each patch into the last axis; restore 2D patches.
  vol_patches = vol_patches.reshape([-1, *patch_size, 1])
  seg_patches = seg_patches.reshape([-1, *patch_size])

  # Non-zero labels are foreground. Too small a threshold keeps near-empty
  # patches; too large a threshold drops useful boundary patches.
  foreground_mask = seg_patches != 0
  useful_patches = foreground_mask.sum(axis=(1, 2)) > threshold * np.prod(patch_size)

  vol_patches = vol_patches[useful_patches]
  seg_patches = seg_patches[useful_patches]

  # Honour the num_classes argument (the global N_CLASSES was used before).
  seg_patches = tf.keras.utils.to_categorical(
    seg_patches, num_classes=num_classes, dtype='float32')

  return (vol_patches, seg_patches)

In [None]:
# Foreground-rich patches from the training volumes and from the held-out validation volume(s)
training_patches, training_patches_seg = extract_useful_patches(training_volumes, training_labels)
validation_patches, validation_patches_seg = extract_useful_patches(validation_volumes, validation_labels)

print(training_patches.shape)

(11546, 32, 32, 1)


## Data Augmentation

In [None]:
# Degree of Augmentation: one magnitude shared by the shift, shear and zoom ranges
deg = 0.2

# Random geometric augmentation: flips, up to 40 degree rotations, shifts,
# shear and zoom. Pixels uncovered by a transform copy their nearest neighbour.
datagen = ImageDataGenerator(
    horizontal_flip=True,
    vertical_flip=True,
    rotation_range=40,
    width_shift_range=deg,
    height_shift_range=deg,
    shear_range=deg,
    zoom_range=deg,
    fill_mode='nearest',
)

In [None]:
# Each generator is read once below, yielding a single augmented batch of
# n_patches // BATCH_SIZE samples (so roughly 1/BATCH_SIZE extra patches).
# Image and label flows share seed=1 so that, given equal sizes, they are
# intended to draw the same shuffle and transforms and stay pixel-aligned.
# max(1, ...) avoids batch_size=0 when a set has fewer than BATCH_SIZE patches.
# NOTE(review): the shared datagen interpolates (Keras default interpolation_order=1),
# so augmented one-hot labels may become soft; validation data is augmented too.
train_generator = datagen.flow(training_patches, batch_size=max(1, training_patches.shape[0] // BATCH_SIZE), seed=1)
train_label_generator = datagen.flow(training_patches_seg, batch_size=max(1, training_patches.shape[0] // BATCH_SIZE), seed=1)

val_generator = datagen.flow(validation_patches, batch_size=max(1, validation_patches.shape[0] // BATCH_SIZE), seed=1)
val_label_generator = datagen.flow(validation_patches_seg, batch_size=max(1, validation_patches.shape[0] // BATCH_SIZE), seed=1)

In [None]:
# Draw one augmented batch per generator; image and label flows were both created with seed=1
X_train, y_train = next(train_generator), next(train_label_generator)
X_val, y_val = next(val_generator), next(val_label_generator)

In [None]:
# Patch tensors: intensities (K, h, w, 1) and one-hot labels (K, h, w, N_CLASSES)
print(training_patches.shape, training_patches_seg.shape, sep="\n")
print("----------------")
print(validation_patches.shape, validation_patches_seg.shape, sep="\n")

(11546, 32, 32, 1)
(11546, 32, 32, 4)
----------------
(1201, 32, 32, 1)
(1201, 32, 32, 4)


In [None]:
# Append the augmented batch to the original patches (images and one-hot labels)
full_train = np.concatenate([training_patches, X_train])
full_train_label = np.concatenate([training_patches_seg, y_train])
full_val = np.concatenate([validation_patches, X_val])
full_val_label = np.concatenate([validation_patches_seg, y_val])

for_display = (full_train, full_train_label, full_val, full_val_label)
print(*(arr.shape for arr in for_display), sep="\n")
del for_display

(11726, 32, 32, 1)
(11726, 32, 32, 4)
(1219, 32, 32, 1)
(1219, 32, 32, 4)


## Train the Model



----

---




**Instantiate UNet model and train it**


Using callbacks to stop training and avoid overfitting


*   Early stopping with a certain patience
*   Save (and load!) best model



In [None]:
# Patch-sized model; call unet.summary() to inspect the architecture
unet = get_unet(img_size=PATCH_SIZE)

In [None]:
# Both callbacks watch val_loss: stop after PATIENCE epochs without improvement,
# and save only the best model so far to MODEL_FNAME_PATTERN.
my_callbacks = [
    tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=PATIENCE),
    tf.keras.callbacks.ModelCheckpoint(filepath=MODEL_FNAME_PATTERN, monitor='val_loss', save_best_only=True),
]

unet = get_unet()
unet.compile(loss=LOSS, optimizer=OPTIMISER)
unet.fit(
    full_train,
    full_train_label,
    batch_size=BATCH_SIZE,
    epochs=N_EPOCHS,
    verbose=1,
    callbacks=my_callbacks,
    validation_data=(full_val, full_val_label),
)

Epoch 1/200
Epoch 2/200
Epoch 3/200
Epoch 4/200
Epoch 5/200
Epoch 6/200
Epoch 7/200
Epoch 8/200
Epoch 9/200
Epoch 10/200
Epoch 11/200
Epoch 12/200
Epoch 13/200
Epoch 14/200
Epoch 15/200
Epoch 16/200
Epoch 17/200
Epoch 18/200
Epoch 19/200
Epoch 20/200
Epoch 21/200
Epoch 22/200
Epoch 23/200
Epoch 24/200
Epoch 25/200
Epoch 26/200
Epoch 27/200
Epoch 28/200
Epoch 29/200
Epoch 30/200
Epoch 31/200


<keras.callbacks.History at 0x7f8f2d36d450>

## Load the Best Model

In [None]:
# The UNet is fully convolutional, so the weights trained on patches can be
# loaded into a model sized for whole slices (IMAGE_SIZE[1] x IMAGE_SIZE[2]).
# Load from the same file the ModelCheckpoint callback writes to.
unet = get_unet(
    img_size=(IMAGE_SIZE[1], IMAGE_SIZE[2]),
    n_classes=N_CLASSES,
    n_input_channels=N_INPUT_CHANNELS)
unet.compile(optimizer=OPTIMISER, loss=LOSS)
unet.load_weights(MODEL_FNAME_PATTERN)

## Prepare Test Data from the Validation Volumes

In [None]:
def prepare_val_data(the_volumes, the_labels, num_classes=N_CLASSES):
  """Turn one volume and its labels into a stack of 2D slices for prediction.

  Args:
    the_volumes: a single volume whose size is a multiple of
      IMAGE_SIZE[1] * IMAGE_SIZE[2].
    the_labels: the matching label volume.
    num_classes: number of classes for the one-hot labels (was hard-coded to 4).

  Returns:
    (testing_volumes_processed, testing_labels_processed): slices of shape
    (-1, IMAGE_SIZE[1], IMAGE_SIZE[2], 1) and their one-hot encoded labels.
  """
  testing_volumes_processed = the_volumes.reshape([-1, IMAGE_SIZE[1], IMAGE_SIZE[2], 1])
  testing_labels_processed = the_labels.reshape([-1, IMAGE_SIZE[1], IMAGE_SIZE[2], 1])

  testing_labels_processed = tf.keras.utils.to_categorical(
    testing_labels_processed, num_classes=num_classes, dtype='float32')

  return (testing_volumes_processed, testing_labels_processed)

### Predict Labels for the Test Data

In [None]:
def pred_val_data(testing_volumes_processed, model=None)  :
  """Predict a hard label map for a stack of 2D slices.

  Args:
    testing_volumes_processed: model input, e.g. the slices from prepare_val_data.
    model: any object with a Keras-style predict(x=...) returning per-class
      probabilities on the last axis. Defaults to the global `unet`.

  Returns:
    Array with the class axis removed, holding the arg-max class per pixel.
    If the model keeps dropout active at inference, results vary between calls.
  """
  if model is None:
    model = unet

  # Probability map for every class; the class axis is last.
  probabilities = model.predict(x=testing_volumes_processed)
  return np.argmax(probabilities, axis=-1)

In [None]:
# Disabled debug prints (kept as a string). NOTE(review): testing_volumes_T1_processed
# is not defined anywhere in this notebook, so re-enabling this would raise NameError.
"""
print(prediction.shape)
print(testing_labels_processed.shape)
print(testing_volumes_T1_processed.shape)
"""

'\nprint(prediction.shape)\nprint(testing_labels_processed.shape)\nprint(testing_volumes_T1_processed.shape)\n'

## Computing Dice, AVD and HD (Final)



In [None]:
def compute_hausdorff_distance(in1, in2, label = 'all'):
    """Hausdorff distance between the masks of `label` in two label maps.

    Both inputs are squeezed and converted to binary uint16 SimpleITK images.
    Foreground is `arr > 0` when label == 'all', otherwise `arr == label`.
    (Previously the 'all' branch passed raw numpy arrays to SimpleITK, which fails.)
    The arrays carry no spacing information, so the distance is in voxel units.

    NOTE(review): behaviour when one mask is empty (e.g. a class the model never
    predicts) depends on SimpleITK and is unverified; confirm.

    Returns:
        float: the Hausdorff distance reported by HausdorffDistanceImageFilter.
    """
    def to_binary_image(arr):
        # Binary mask -> SimpleITK image with the default unit spacing.
        arr = np.asarray(arr).squeeze()
        mask = (arr > 0) if label == 'all' else (arr == label)
        return sitk.GetImageFromArray(mask.astype('uint16'))

    hausdorff_distance_filter = sitk.HausdorffDistanceImageFilter()
    hausdorff_distance_filter.Execute(to_binary_image(in1), to_binary_image(in2))
    return hausdorff_distance_filter.GetHausdorffDistance()

def compute_dice_coefficient(in1, in2, label  = 'all'):
    """Dice similarity coefficient 2|A ∩ B| / (|A| + |B|) between two label maps.

    With label == 'all', A and B are the non-zero voxels, and a voxel counts as
    overlap only if both maps give it the same non-zero label. Otherwise A and B
    are the voxels equal to `label`.

    Returns:
        The Dice coefficient in [0, 1]. Returns 1.0 when neither map contains
        the requested label, instead of dividing zero by zero.
    """
    in1 = np.asarray(in1).squeeze()
    in2 = np.asarray(in2).squeeze()
    if label == 'all':
        fg1, fg2 = in1 > 0, in2 > 0
        overlap = np.sum(fg1 & fg2 & (in1 == in2))
    else:
        fg1, fg2 = in1 == label, in2 == label
        overlap = np.sum(fg1 & fg2)

    denominator = np.sum(fg1) + np.sum(fg2)
    if denominator == 0:
        # Label absent from both maps: perfect agreement.
        return 1.0
    return 2 * overlap / denominator

def compute_volumentric_difference(in1, in2, label  = 'all'):
    """Absolute volume difference (AVD) of `in1` relative to the reference `in2`.

    AVD = |V1 - V2| / V2, where V is the voxel count of the structure: non-zero
    voxels when label == 'all', otherwise voxels equal to `label`. The previous
    formula, sum(in1 != in2) / (|A| + |B|), equals 1 - Dice for a single label,
    not a volume difference.

    Returns:
        AVD as a fraction (multiply by 100 for percent). Returns 0.0 when both
        volumes are empty and inf when only the reference is empty.
    """
    in1 = np.asarray(in1).squeeze()
    in2 = np.asarray(in2).squeeze()
    if label == 'all':
        vol1, vol2 = np.count_nonzero(in1 > 0), np.count_nonzero(in2 > 0)
    else:
        vol1, vol2 = np.count_nonzero(in1 == label), np.count_nonzero(in2 == label)

    if vol2 == 0:
        return 0.0 if vol1 == 0 else float('inf')
    return abs(vol1 - vol2) / vol2

In [None]:
# Evaluate the reloaded model on every volume of the held-out Validation_Set.
# Each volume is predicted once. Dice, Hausdorff distance and AVD are then
# accumulated per class and averaged over the volumes.
# (The loop used to run validation_volumes.shape[0] times while indexing
# v_volumes, and it printed entries by volume index instead of by class.)
n_eval_volumes = v_volumes.shape[0]
overallDSC = np.zeros(N_CLASSES)
overall_Hausdorff = np.zeros(N_CLASSES)
overall_vol = np.zeros(N_CLASSES)

for i in range(n_eval_volumes):
  testing_volumes_processed, testing_labels_processed = prepare_val_data(v_volumes[i], v_labels[i])
  prediction = pred_val_data(testing_volumes_processed)

  for cl in range(N_CLASSES):
    overallDSC[cl] += compute_dice_coefficient(prediction, v_labels[i], label=cl)
    overall_Hausdorff[cl] += compute_hausdorff_distance(prediction, v_labels[i], label=cl)
    overall_vol[cl] += compute_volumentric_difference(prediction, v_labels[i], label=cl)

overallDSC /= n_eval_volumes
overall_Hausdorff /= n_eval_volumes
overall_vol /= n_eval_volumes

for cl in range(N_CLASSES):
  print("Class {}".format(cl))
  print("\tDice Coefficient = {:.4f}".format(overallDSC[cl]))
  print("\tHD = {:.4f}".format(overall_Hausdorff[cl]))
  print("\tAVD = {:.4f}".format(overall_vol[cl]))

Class 0
	Dice Coefficient = 0.9972
	HD = 20.3224
	AVD = 0.0028
Class 1
	Dice Coefficient = 0.7940
	HD = 145.8869
	AVD = 0.2060
Class 2
	Dice Coefficient = 0.9106
	HD = 112.4189
	AVD = 0.0894
Class 3
	Dice Coefficient = 0.8646
	HD = 113.6574
	AVD = 0.1354


In [None]:
# Dice results of earlier hyper-parameter runs, kept for reference.
# Each triple-quoted string below is a bare expression; only the last one is
# echoed as the cell output.
# batch size = 32, patience = 5, dropout=0.15, epoch = 50
"""
Class 0 - Dice Coefficient 0.9976
Class 1 - Dice Coefficient 0.8288
Class 2 - Dice Coefficient 0.9186
Class 3 - Dice Coefficient 0.8765
"""

# batch size = 40, patience = 5, dropout=0.15, epoch = 50
"""
Class 0 - Dice Coefficient 0.9977
Class 1 - Dice Coefficient 0.8261
Class 2 - Dice Coefficient 0.9202
Class 3 - Dice Coefficient 0.8790
"""

# batch size = 50, patience = 20, dropout=0.15, epoch = 200
# NOTE(review): these numbers are identical to the batch-size-40 run above;
# possibly copied by mistake, so confirm before relying on them.
"""
Class 0 - Dice Coefficient 0.9977
Class 1 - Dice Coefficient 0.8261
Class 2 - Dice Coefficient 0.9202
Class 3 - Dice Coefficient 0.8790
"""

# batch size = 64, patience = 20, dropout=0.40, epoch = 200
"""
Class 0 - Dice Coefficient 0.9975
Class 1 - Dice Coefficient 0.8342
Class 2 - Dice Coefficient 0.9209
Class 3 - Dice Coefficient 0.8825
"""

'\nClass 0 - Dice Coefficient 0.9975\nClass 1 - Dice Coefficient 0.8342\nClass 2 - Dice Coefficient 0.9209\nClass 3 - Dice Coefficient 0.8825\n'