In [1]:
import glob
import numpy as np
import nibabel as nib

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

In [2]:
"""3D UNET model for Tensorflow-Keras."""

def conv_block(x, filters, maxpool=False):
    """Two stacked Conv3D(3x3x3, 'same') -> BatchNormalization -> ReLU units.

    Args:
        x: input Keras tensor.
        filters (int): number of output channels of both convolutions.
        maxpool (bool): if True, the second unit's activation is passed through
            a 2x2x2 MaxPooling3D before being returned.

    Returns:
        tuple: (first, second) — `first` is the activation of the first conv
        unit; `second` is the activation of the second conv unit, max-pooled
        when `maxpool` is True.
    """
    # NOTE(review): unet_3d uses `first` as the skip connection, i.e. the output
    # of the *first* conv unit rather than the un-pooled second one — confirm
    # this is intended.
    activations = []
    out = x
    for _ in range(2):
        out = layers.Conv3D(filters=filters, kernel_size=(3, 3, 3), padding='same')(out)
        out = layers.BatchNormalization()(out)
        out = layers.Activation('relu')(out)
        activations.append(out)
    first, second = activations
    if maxpool:
        second = layers.MaxPooling3D(pool_size=(2, 2, 2))(second)
    return first, second
    

def compress_block(x):
    """Bottleneck unit used between encoder and decoder.

    Applies Conv3D(128, 3x3x3) -> BatchNormalization -> ReLU, then
    Conv3D(32, 1x1x1) -> BatchNormalization -> ReLU. Both convolutions use
    'same' padding, so the output always has 32 channels regardless of the
    input's channel count.
    """
    for n_filters, kernel in ((128, (3, 3, 3)), (32, (1, 1, 1))):
        x = layers.Conv3D(filters=n_filters, kernel_size=kernel, padding='same')(x)
        x = layers.BatchNormalization()(x)
        x = layers.Activation('relu')(x)
    return x


def decode_block(x, filters, down_conn):
    """Decoder stage: upsample 2x, convolve, then fuse with the skip connection.

    Args:
        x: input Keras tensor from the previous (coarser) stage.
        filters (int): channel count for the transposed conv, the conv_block,
            and the final 1x1x1 fusion conv.
        down_conn: encoder skip tensor; its spatial shape must be exactly twice
            that of `x` so it can be concatenated with the upsampled features.

    Returns:
        Keras tensor with `filters` channels at the upsampled resolution.
    """
    # Kernel 2 / stride 2 transposed conv doubles every spatial dimension.
    upsampled = layers.Conv3DTranspose(filters=filters, kernel_size=(2, 2, 2), strides=(2, 2, 2))(x)
    upsampled = layers.BatchNormalization()(upsampled)
    upsampled = layers.Activation('relu')(upsampled)
    # maxpool is off, so the second activation keeps the upsampled resolution.
    _, conv_1 = conv_block(upsampled, filters=filters)
    # Skip features first on the channel axis, then the decoded features.
    concat = layers.Concatenate(axis=-1)([down_conn, conv_1])
    # 1x1x1 conv fuses the concatenated channels back down to `filters`.
    fused = layers.Conv3D(filters=filters, kernel_size=(1, 1, 1), padding='same')(concat)
    fused = layers.BatchNormalization()(fused)
    fused = layers.Activation('relu')(fused)
    return fused


def unet_3d(input_shape=(128, 128, 96, 1), num_classes=2):
    """Build the 3D U-Net used for skull stripping.

    Args:
        input_shape (tuple): (x_dim, y_dim, z_dim, channels) of one input
            volume, batch dimension excluded. Defaults to the original
            (128, 128, 96, 1). Every known spatial dim must be divisible by 16:
            the encoder applies four 2x2x2 max-pools and the decoder doubles the
            resolution four times, so other sizes cannot be concatenated with
            their skip tensors.
        num_classes (int): number of channels of the final softmax output.

    Returns:
        keras.Model mapping input volumes to per-voxel class probabilities of
        shape (x_dim, y_dim, z_dim, num_classes).

    Raises:
        ValueError: if a known spatial dimension is not divisible by 16.
    """
    for dim in input_shape[:-1]:
        if dim is not None and dim % 16 != 0:
            raise ValueError(
                'spatial dims of input_shape must be divisible by 16, got %r' % (input_shape,))

    input_layer = keras.Input(shape=input_shape)

    # Encoder: conn* are skip tensors, down* feed the next stage.
    conn1, down1 = conv_block(input_layer, filters=16)
    conn2, down2 = conv_block(down1, filters=16, maxpool=True)
    conn3, down3 = conv_block(down2, filters=32, maxpool=True)
    conn4, down4 = conv_block(down3, filters=64, maxpool=True)
    conn5, down5 = conv_block(down4, filters=128, maxpool=True)

    # Bottleneck at 1/16 of the input resolution.
    compress1 = compress_block(down5)
    compress2 = compress_block(compress1)
    compress3 = compress_block(compress2)

    # Decoder: each stage doubles the resolution and fuses the matching skip.
    decode1 = decode_block(compress3, filters=128, down_conn=conn5)
    decode2 = decode_block(decode1, filters=64, down_conn=conn4)
    decode3 = decode_block(decode2, filters=32, down_conn=conn3)
    decode4 = decode_block(decode3, filters=16, down_conn=conn2)

    # Output head at full resolution.
    output = layers.Conv3D(filters=16, kernel_size=(1, 1, 1), padding='same')(decode4)
    output = layers.BatchNormalization()(output)
    output = layers.Activation('relu')(output)
    # NOTE(review): with the default stride of 1 and 'same' padding this
    # transposed conv does not upsample; kept unchanged so existing
    # checkpoints keep the same layer structure.
    output = layers.Conv3DTranspose(filters=16, kernel_size=(2, 2, 2), padding='same')(output)
    output = layers.BatchNormalization()(output)
    output = layers.Activation('relu')(output)
    output = layers.Concatenate(axis=-1)([conn1, output])
    output = layers.Conv3D(filters=num_classes, kernel_size=(1, 1, 1), padding='same', activation='softmax')(output)

    model = keras.Model(inputs=input_layer, outputs=output)
    return model

In [3]:
# Build the (not yet compiled) 3D U-Net; it is compiled in the training cell.
model = unet_3d()
# To inspect the graph, evaluate `model.output` or call `model.summary()`.

In [4]:
# File the ModelCheckpoint callback writes the best model to.
checkpoint_path = "skull_stripping_unet_3d_soft_dice_loss.h5"
# Starting Adam step size; decayed over training by `lr_schedule`.
initial_learning_rate = 3e-3

In [5]:
# Staircase exponential decay: the learning rate is multiplied by 0.96 once
# every 1000 optimizer steps, starting from `initial_learning_rate`.
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate,
    decay_steps=1000,
    decay_rate=0.96,
    staircase=True,
)

# Save the model to `checkpoint_path` whenever val_loss reaches a new minimum.
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    monitor='val_loss',
    verbose=1,
    save_best_only=True,
    mode='min',
)

# Stop after 5 epochs without val_loss improvement. With
# restore_best_weights=True, when early stopping triggers the in-memory
# `model` is rolled back to its best epoch instead of keeping the weights of
# the last (worse) epoch, so it agrees with the saved checkpoint.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    mode='min',
    patience=5,
    verbose=1,
    restore_best_weights=True,
)

In [6]:
def dice_coefficient(y_true,
                     y_pred,
                     axis=(1, 2, 3),
                     epsilon=0.00001):
    """
    Compute the mean (smoothed) Dice coefficient over the batch and all classes.

    For every sample and class the score is
    (2 * sum(y_true * y_pred) + epsilon) / (sum(y_true) + sum(y_pred) + epsilon),
    with the sums taken over `axis`; the result is the mean of those scores.

    Args:
        y_true (Tensorflow tensor): one-hot ground truth for all classes.
            Assumed layout: (batch, x_dim, y_dim, z_dim, num_classes), which is
            what the default `axis` expects. It is cast to y_pred's dtype, so
            integer labels are accepted.
        y_pred (Tensorflow tensor): predicted class probabilities, same shape
            as y_true.
        axis (tuple): spatial axes summed over when computing the numerator
            and denominator.
        epsilon (float): small constant added to numerator and denominator to
            avoid division by zero; a class absent from both tensors scores 1.
    Returns:
        dice_coefficient (scalar tensor): mean Dice coefficient.
    """
    # tf does not promote dtypes implicitly; align y_true with the predictions.
    y_true = tf.cast(y_true, y_pred.dtype)
    dice_numerator = 2. * tf.reduce_sum(y_true * y_pred, axis=axis) + epsilon
    dice_denominator = tf.reduce_sum(y_true, axis=axis) + tf.reduce_sum(y_pred, axis=axis) + epsilon
    dice_coefficient = tf.reduce_mean(dice_numerator / dice_denominator)

    return dice_coefficient

In [7]:
def soft_dice_loss(y_true,
                   y_pred,
                   axis=(1, 2, 3),
                   epsilon=0.00001):
    """
    Compute the mean soft Dice loss over the batch and all classes.

    For every sample and class the soft Dice score is
    (2 * sum(y_true * y_pred) + epsilon) / (sum(y_true**2) + sum(y_pred**2) + epsilon),
    with the sums taken over `axis`; the loss is 1 minus the mean of those
    scores. The denominator uses squared terms, unlike `dice_coefficient`, so
    for soft predictions this loss is not exactly 1 - dice_coefficient.

    Args:
        y_true (Tensorflow tensor): one-hot ground truth for all classes.
            Assumed layout: (batch, x_dim, y_dim, z_dim, num_classes), which is
            what the default `axis` expects. It is cast to y_pred's dtype, so
            integer labels are accepted.
        y_pred (Tensorflow tensor): soft predictions (class probabilities),
            same shape as y_true.
        axis (tuple): spatial axes summed over when computing the numerator
            and denominator.
        epsilon (float): small constant added to numerator and denominator to
            avoid division by zero.
    Returns:
        dice_loss (scalar tensor): 1 - mean soft Dice score.
    """
    # Keras does not cast integer labels to float for custom losses; do it here.
    y_true = tf.cast(y_true, y_pred.dtype)
    dice_numerator = 2. * tf.reduce_sum(y_true * y_pred, axis=axis) + epsilon
    dice_denominator = tf.reduce_sum(y_true**2, axis=axis) + tf.reduce_sum(y_pred**2, axis=axis) + epsilon
    dice_loss = 1 - tf.reduce_mean(dice_numerator / dice_denominator)

    return dice_loss

In [14]:
# NOTE(review): `train_data` and `valid_data` are created in the In [12] cell
# further down this notebook (after In [8] and In [9]). On a fresh kernel this
# cell raises NameError; move it below the dataset cells so Restart & Run All works.
adam_optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
model.compile(
    optimizer=adam_optimizer,
    loss=soft_dice_loss,  # alternative: tf.keras.losses.CategoricalCrossentropy()
    metrics=['accuracy', dice_coefficient],
)

epochs = 30
history = model.fit(
    train_data,
    validation_data=valid_data,
    epochs=epochs,
    callbacks=[checkpoint_callback, early_stopping],
)

Epoch 1/30

Epoch 00001: val_loss improved from inf to 0.10563, saving model to skull_stripping_unet_3d_soft_dice_loss.h5
Epoch 2/30

Epoch 00002: val_loss improved from 0.10563 to 0.07731, saving model to skull_stripping_unet_3d_soft_dice_loss.h5
Epoch 3/30

Epoch 00003: val_loss did not improve from 0.07731
Epoch 4/30

Epoch 00004: val_loss did not improve from 0.07731
Epoch 5/30

Epoch 00005: val_loss improved from 0.07731 to 0.07171, saving model to skull_stripping_unet_3d_soft_dice_loss.h5
Epoch 6/30

Epoch 00006: val_loss improved from 0.07171 to 0.06264, saving model to skull_stripping_unet_3d_soft_dice_loss.h5
Epoch 7/30

KeyboardInterrupt: 

In [8]:
import math
# Here, `x_set` is a list of paths to the image volumes
# and `y_set` is the list of paths to the matching masks, in the same order.

class BrainMRIDataset(tf.keras.utils.Sequence):
    """Keras Sequence yielding (image, one-hot mask) batches read from NIfTI files.

    Args:
        x_set (list[str]): paths to the input image volumes.
        y_set (list[str]): paths to the matching mask volumes, in the same
            order as `x_set` (item i of both lists must describe the same scan).
        batch_size (int): number of volumes per batch; the last batch may be
            smaller.

    Raises:
        ValueError: if `x_set` and `y_set` differ in length, or `batch_size`
            is smaller than 1.
    """

    def __init__(self, x_set, y_set, batch_size):
        if len(x_set) != len(y_set):
            # Images and masks are paired purely by position, so a count
            # mismatch would silently pair scans with the wrong masks.
            raise ValueError(
                'x_set and y_set must have the same length, got %d images and %d masks'
                % (len(x_set), len(y_set)))
        if batch_size < 1:
            raise ValueError('batch_size must be >= 1, got %r' % (batch_size,))
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size

    def __len__(self):
        # Number of batches per epoch, counting a trailing partial batch.
        return math.ceil(len(self.x) / self.batch_size)

    def read_nifti_file(self, filepath, dtype=np.float64):
        """Load a NIfTI file and return its voxel data as a numpy array of `dtype`."""
        return nib.load(filepath).get_fdata(dtype=dtype)

    def __getitem__(self, idx):
        """Return batch `idx` as (image, label).

        image: float32 tensor with a trailing channel axis of size 1.
        label: masks one-hot encoded into 2 classes on a trailing axis.
        """
        start = idx * self.batch_size
        stop = start + self.batch_size
        batch_x = self.x[start:stop]
        batch_y = self.y[start:stop]

        # float32 matches the model's input dtype and halves memory vs float64.
        image = np.stack([self.read_nifti_file(image_file, dtype=np.float32)
                          for image_file in batch_x])
        image = tf.expand_dims(image, axis=-1)
        label = [self.read_nifti_file(mask_file) for mask_file in batch_y]
        # to_categorical converts mask values to integer class ids; assumes the
        # masks only contain 0 and 1 — TODO confirm for the dataset.
        label = tf.keras.utils.to_categorical(label, 2)

        return image, label

In [9]:
# Sorted file lists per split. BrainMRIDataset pairs images with masks by
# position, so image and mask filenames must sort into the same order.
(train_image_paths, train_mask_paths,
 valid_image_paths, valid_mask_paths) = [
    sorted(glob.glob(pattern))
    for pattern in (
        '../input/neuroscience/train/images/*',
        '../input/neuroscience/train/masks/*',
        '../input/neuroscience/valid/images/*',
        '../input/neuroscience/valid/masks/*',
    )
]

In [10]:
# !cp -r /content/drive/MyDrive/Patchs .

In [11]:
# train_image_paths = sorted(glob.glob('/content/Patchs/train/images/*'))
# train_mask_paths = sorted(glob.glob('/content/Patchs/train/masks/*'))

# valid_image_paths = sorted(glob.glob('/content/Patchs/valid/images/*'))
# valid_mask_paths = sorted(glob.glob('/content/Patchs/valid/masks/*'))

In [12]:
# One volume per batch for both splits (train is built first, then valid).
BATCH_SIZE = 1
train_data, valid_data = (
    BrainMRIDataset(images, masks, BATCH_SIZE)
    for images, masks in (
        (train_image_paths, train_mask_paths),
        (valid_image_paths, valid_mask_paths),
    )
)