# Run K-Fold for Segmentation Tasks

Run five-fold cross-validation for each architecture on the NFBS and BraTS datasets.

In [None]:
import numpy as np
import pandas as pd
import SimpleITK as sitk
from tqdm import trange
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.model_selection import KFold
import subprocess
import os

##### Tensorflow #####
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, metrics
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.callbacks import ReduceLROnPlateau, ModelCheckpoint
import tensorflow.keras.backend as K
import os

# Set this environment variable to allow ModelCheckpoint to work
os.environ['HDF5_USE_FILE_LOCKING'] = 'FALSE'

# Set this environment variable to only use the first available GPU
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

# For tensorflow 2.x.x allow memory growth on GPU
###################################
gpus = tf.config.list_physical_devices('GPU')
# The list is empty when no GPU is visible; skip the call instead of
# failing with an IndexError so the notebook still runs on CPU
if gpus:
    tf.config.experimental.set_memory_growth(gpus[0], True)
###################################

# Use this to allow memory growth on TensorFlow v1.x.x
# ###################################
# config = tf.ConfigProto()
 
# # Don't pre-allocate memory; allocate as-needed
# config.gpu_options.allow_growth = True
 
# # Only allow a specified percent of the GPU memory to be allocated
# config.gpu_options.per_process_gpu_memory_fraction = 0.75
 
# # Create a session with the above options specified.
# K.tensorflow_backend.set_session(tf.Session(config = config))
# ##################################

### L2 Dice Loss

In [None]:
# L2 Dice loss
def dice_loss_l2(y_true, y_pred):
    '''
    L2 Dice loss
    
    Computes sum((y_true - y_pred)^2) / (sum(y_true^2) + sum(y_pred^2) + smooth)
    over the spatial axes for every sample and channel, then averages the
    ratio over the last (channel) axis.
    
    Inputs:
    y_true   : Ground truth tensor - (batch size, depth, height, width, channels) for 3D
               (batch size, height, width, channels) for 2D
    y_pred   : Predicted tensor with the same layout as y_true
    
    Outputs:
    loss     : One loss value per sample in the batch
    
    Raises:
    ValueError : If y_true is neither a 4D nor a 5D tensor
    '''
    # Small constant that keeps the denominator non-zero
    smooth = 0.0000001
    
    rank = len(y_true.shape)
    if rank not in (4, 5):
        raise ValueError('dice_loss_l2 expects 4D or 5D tensors, got rank ' + str(rank))
    
    # Spatial axes are everything between the batch axis and the channel axis:
    # (1, 2, 3) for 5D inputs and (1, 2) for 4D inputs
    spatial_axes = tuple(range(1, rank - 1))
    
    num = K.sum(K.square(y_true - y_pred), axis = spatial_axes)
    den = K.sum(K.square(y_true), axis = spatial_axes) + K.sum(K.square(y_pred), axis = spatial_axes) + smooth
    
    return K.mean(num/den, axis = -1)

### Architecture Implementation

In [None]:
def PocketNet(inputShape, 
              numClasses, 
              mode, 
              net, 
              pocket, 
              initFilters, 
              depth):
    
    '''
    PocketNet - Smaller CNN for medical image segmentation
    
    Inputs:
    inputShape   : Size of network input - (depth, height, width, channels) for 3D
                   (height, width, channels) for 2D
    numClasses   : Number of output classes
    mode         : 'seg' or 'class' for segmentation or classification network
    net          : 'unet', 'resnet', or 'densenet' for U-Net, ResNet or DenseNet blocks
    pocket       : True/False for pocket architectures. If True, every resolution
                   level uses initFilters filters; otherwise the number of filters
                   doubles at each level
    initFilters  : Number of starting filters at input level
    depth        : Number of max-pooling layers
    
    Outputs:
    model        : Keras model for training/predicting
    
    Raises:
    ValueError   : If inputShape does not have 3 or 4 entries, or if mode or net
                   is not one of the supported values
    
    Author: Adrian Celaya
    Last modified: 4.20.2021
    '''
    
    # Reject unsupported options up front instead of failing later on an
    # undefined variable or silently building a network without convolutions
    if mode not in ('seg', 'class'):
        raise ValueError("mode must be 'seg' or 'class', got " + repr(mode))
    if net not in ('unet', 'resnet', 'densenet'):
        raise ValueError("net must be 'unet', 'resnet', or 'densenet', got " + repr(net))
    
    # Pick the layer classes and kernel sizes that match the input rank once,
    # so the building blocks below do not need separate 2D/3D branches
    if len(inputShape) == 4:
        # 3D inputs are (depth, height, width, channels)
        Conv = layers.Conv3D
        ConvTranspose = layers.Conv3DTranspose
        MaxPool = layers.MaxPooling3D
        GlobalMaxPool = layers.GlobalMaxPooling3D
        conv_kernel = (3, 3, 3)
        point_kernel = (1, 1, 1)
        # Pool and upsample in-plane only; the depth axis keeps its size
        pool = (1, 2, 2)
    elif len(inputShape) == 3:
        # 2D inputs are (height, width, channels)
        Conv = layers.Conv2D
        ConvTranspose = layers.Conv2DTranspose
        MaxPool = layers.MaxPooling2D
        GlobalMaxPool = layers.GlobalMaxPooling2D
        conv_kernel = (3, 3)
        point_kernel = (1, 1)
        pool = (2, 2)
    else:
        raise ValueError('inputShape must have 3 (2D) or 4 (3D) entries, got ' + str(len(inputShape)))
    
    # Parameters for each convolution operation
    conv_params = dict(kernel_size = conv_kernel, activation = 'relu', padding = 'same')
    point_params = dict(kernel_size = point_kernel, activation = 'relu', padding = 'same')
    up_params = dict(kernel_size = pool, strides = pool, padding = 'same')
    
    # Convolution block operator
    def Block(x, filters):
        ### DenseNet block ###
        if net == 'densenet':
            # Each convolution sees the concatenation of all previous outputs
            for _ in range(2):
                y = Conv(filters, **conv_params)(x)
                x = layers.concatenate([x, y])
            # Pointwise convolution brings the channel count back to filters
            x = Conv(filters, **point_params)(x)
        
        ### ResNet block ###
        elif net == 'resnet':
            y = Conv(filters, **conv_params)(x)
            y = Conv(filters, **conv_params)(y)
            # The skip connection is a concatenation followed by a pointwise
            # convolution rather than an element-wise addition
            x = layers.concatenate([x, y])
            x = Conv(filters, **point_params)(x)
        
        ### U-Net block ###
        else:
            x = Conv(filters, **conv_params)(x)
            x = Conv(filters, **conv_params)(x)
        
        return x

    # Downsampling block - Convolution + maxpooling
    def TransitionDown(x, filters):
        skip = Block(x, filters)
        x = MaxPool(pool_size = pool, strides = pool)(skip)
        return skip, x

    # Upsampling block - Transposed convolution + concatenation + convolution
    def TransitionUp(x, skip, filters):
        x = ConvTranspose(filters, **up_params)(x)
        x = layers.concatenate([x, skip])
        return Block(x, filters)
    
    # Keep filters constant for PocketNet
    if pocket:
        filters = [initFilters for _ in range(depth + 1)]
    else:
        filters = [initFilters * 2 ** i for i in range(depth + 1)]
    
    # Input to network
    inputs = layers.Input(inputShape)
 
    # Encoder path
    x = inputs
    skips = list()
    for i in range(depth):
        skip, x = TransitionDown(x, filters[i])
        skips.append(skip)
        
    # Bottleneck
    x = Block(x, filters[-1])

    # Apply global max-pooling to output of bottleneck if classification
    if mode == 'class':
        # Use the pooling layer that matches the input rank (2D or 3D)
        x = GlobalMaxPool()(x)
        output = layers.Dense(numClasses, activation = 'softmax')(x)
    
    # Continue with decoder path if segmentation
    else:
        # Walk back up the encoder levels, deepest skip connection first
        for i in range(depth - 1, -1, -1):
            x = TransitionUp(x, skips[i], filters[i])
        
        output = Conv(numClasses, point_kernel, activation = 'softmax')(x)
            
    model = Model(inputs = [inputs], outputs = [output])
    return model

### Data Generator 

Stream data from disk to model while training.

In [None]:
class data_generator(keras.utils.Sequence):
    '''
    Keras Sequence that loads batches of images and masks from .npy files
    
    Inputs:
    dataframe    : DataFrame with 'image' and 'mask' columns holding paths to .npy files
    batch_size   : Number of samples per batch
    dim          : Spatial dimensions of each sample (without the channel axis)
    n_channels   : Number of channels in each image
    n_classes    : Number of channels in each mask
    shuffle      : True/False for shuffling the rows of the dataframe when the
                   generator is created and at the end of every epoch
    '''
    def __init__(self, dataframe, batch_size, dim, n_channels, n_classes, shuffle):
        super().__init__()
        self.dim = dim
        self.dataframe = dataframe
        self.batch_size = batch_size
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.shuffle = shuffle
        self.on_epoch_end()

    def __len__(self):
        # Number of full batches; leftover rows that do not fill a batch are dropped
        return len(self.dataframe) // self.batch_size

    def __getitem__(self, index):
        # index is the batch number, not a row number
        return self.__data_generation(index)

    def on_epoch_end(self):
        if self.shuffle:
            self.dataframe = self.dataframe.sample(frac = 1).reset_index(drop = True)
        
    def __data_generation(self, index):
        X = np.empty((self.batch_size, *self.dim, self.n_channels))
        y = np.empty((self.batch_size, *self.dim, self.n_classes))

        # Batch number index covers rows [index * batch_size, (index + 1) * batch_size).
        # Using index itself as the first row would make consecutive batches overlap
        # and leave most rows unvisited.
        start = index * self.batch_size
        for offset in range(self.batch_size):
            row = self.dataframe.iloc[start + offset]
            X[offset] = np.load(row['image'])
            y[offset] = np.load(row['mask'])
        return X, y

### Inference

Create predictions on images after training each model.

In [None]:
def inference(model, df, num_classes, dest):
    '''
    Predict a segmentation for every patient in df and write it as a .nii.gz file
    
    Each volume is zero-padded along the first axis, the model is applied to
    overlapping stacks of slices, and the overlapping predictions are averaged
    before taking the argmax over classes.
    
    Inputs:
    model        : Trained model with a predict method that maps an input of shape
                   (1, 5, height, width, channels) to (1, 5, height, width, num_classes)
    df           : DataFrame with an 'id' column and a 'mask' column. Every column
                   after the first two is treated as a path to an image channel
                   (assumes the column order is id, mask, images - confirm against
                   the csv layout)
    num_classes  : Number of classes predicted by the model
    dest         : Prefix prepended to '<id>_prediction.nii.gz' for each output file
    
    Outputs:
    None - predictions are written to disk
    '''
    
    # Shape of the first mask; every volume in df is assumed to share it
    dims = sitk.GetArrayFromImage(sitk.ReadImage(df.iloc[0]['mask'])).shape
    
    def read_images(image_list, dims):
        # Stack the normalized images along a trailing channel axis
        image = np.empty((*dims, len(image_list)))
        for channel, path in enumerate(image_list):
            arr = sitk.Normalize(sitk.ReadImage(path))
            image[..., channel] = sitk.GetArrayFromImage(arr)
        return image
    
    # Define parameters 
    slice_thickness = 5
    # Pad by slice_thickness on both ends so every real slice is covered by
    # exactly slice_thickness overlapping windows
    pred_img_depth = dims[0] + (2 * slice_thickness)
    
    for i in trange(len(df)):
        
        patient = df.iloc[i].to_dict()
        
        # Image paths are all values after the first two columns
        image_list = list(patient.values())[2:]
        
        # The first image provides the header information for the prediction
        original = sitk.ReadImage(image_list[0])
        
        # Load test patient image into the middle of the zero-padded volume
        image = np.zeros((pred_img_depth, dims[1], dims[2], len(image_list)))
        image[slice_thickness:(pred_img_depth - slice_thickness), ...] = read_images(image_list, dims)

        # Predict on overlapping tiles of test image
        prediction = np.zeros((pred_img_depth, dims[1], dims[2], num_classes))
        for k in range(pred_img_depth - slice_thickness + 1):
            temp = image[k:(k + slice_thickness), ...]
            temp = temp.reshape((1, slice_thickness, dims[1], dims[2], len(image_list)))
            temp = model.predict(temp)
            temp = temp.reshape((slice_thickness, dims[1], dims[2], num_classes))
            prediction[k:(k + slice_thickness), ...] += temp

        # Take average prediction from overlap strategy and apply argmax to get final array
        prediction /= slice_thickness
        # Drop the padding slices before picking the most likely class
        prediction = prediction[slice_thickness:(pred_img_depth - slice_thickness), ...]
        prediction = np.argmax(prediction, axis = -1)
        prediction = prediction.reshape(dims)

        # Convert to a float64 array and then to a SITK image
        pred_sitk = sitk.GetImageFromArray(prediction.astype(np.float64))
        
        # Copy header information from the first image
        pred_sitk.CopyInformation(original)

        # Write prediction as nifti file
        pred_file = dest + patient['id'] + '_prediction.nii.gz'
        sitk.WriteImage(pred_sitk, pred_file)

    ##### END OF FUNCTION #####

### K-Folds

Create folds and train model on each fold.

In [None]:
def run_kfold(input_csv, original_csv, net, pocket, pred_dest, model_dir):
    '''
    Run a five-fold cross validation for one architecture on one dataset
    
    For every fold a new model is trained, the checkpoint with the lowest
    validation loss is reloaded, and predictions for the validation patients
    are written as .nii.gz files.
    
    Inputs:
    input_csv    : csv with 'id', 'image', and 'mask' columns, where 'image' and
                   'mask' are paths to .npy files used for training
    original_csv : csv with an 'id' column that is passed to inference for
                   predicting on the validation patients
    net          : 'unet', 'resnet', or 'densenet'
    pocket       : True/False for pocket architectures
    pred_dest    : Prefix for the prediction files written by inference
    model_dir    : Prefix for the saved .h5 model files
    
    Outputs:
    None - models and predictions are written to disk
    '''
    
    n_splits = 5
    batch_size = 4
    
    # Read in csv with paths to images and masks
    df = pd.read_csv(input_csv)
    df = df.sample(frac = 1).reset_index(drop = True)
    
    # The last axis of the first mask/image gives the number of classes/channels
    dims_mask = np.load(df.iloc[0]['mask']).shape
    n_classes = dims_mask[-1]
    
    dims_image = np.load(df.iloc[0]['image']).shape
    n_channels = dims_image[-1]
    
    original_data = pd.read_csv(original_csv)
    
    # Get unique patient IDs and split them into training and validation with Kfold split
    patients = np.unique(df['id'])
    kfold = KFold(n_splits = n_splits)
    splits = kfold.split(patients)
    
    # For each split, train a model and predict on validation data. Write validation predictions as .nii.gz files.
    for split_id, (train_idx, val_idx) in enumerate(splits, start = 1):

        train_pats = list(patients[train_idx])
        val_pats = list(patients[val_idx])

        print('Starting split ' + str(split_id) + ' of ' + str(n_splits))
        # Create DataFrames with only training and validation patients for this split
        train_df = df[df['id'].isin(train_pats)]
        val_df = df[df['id'].isin(val_pats)]
        val_images = original_data[original_data['id'].isin(val_pats)]
        
        num_train = len(train_df)
        num_val = len(val_df)

        # Create training and validation data generators
        train_gen = data_generator(train_df, batch_size, dims_image[0:-1], n_channels, n_classes, True)
        val_gen = data_generator(val_df, batch_size, dims_image[0:-1], n_channels, n_classes, True)

        # Create model, compile it, and set up callbacks
        model = PocketNet(inputShape = dims_image, 
                          numClasses = n_classes, 
                          mode = 'seg', 
                          net = net, 
                          pocket = pocket, 
                          initFilters = 16, 
                          depth = 4)
        model.compile(optimizer = 'adam', loss = [dice_loss_l2])

        # Reduce learning rate by 0.5 if validation loss does not improve after 5 epochs
        reduce_lr = ReduceLROnPlateau(monitor = 'val_loss', 
                                      mode = 'min',
                                      factor = 0.5, 
                                      patience = 5, 
                                      min_lr = 0.000001, 
                                      verbose = 1)

        prefix = 'pocket_' if pocket else 'full_'
        model_name = model_dir + prefix + net + '_split_' + str(split_id) + '.h5'
            
        # Keep only the weights with the lowest validation loss
        best_model = ModelCheckpoint(filepath = model_name, 
                                     monitor = 'val_loss', 
                                     verbose = 1, 
                                     save_best_only = True)

        # Train model. Each epoch only runs a quarter of the available batches.
        # fit accepts Sequence objects directly; fit_generator is deprecated in TF 2.x.
        model.fit(train_gen, 
                  epochs = 50, 
                  steps_per_epoch = (num_train // (4 * batch_size)), 
                  validation_data = val_gen, 
                  validation_steps = (num_val // (4 * batch_size)), 
                  callbacks = [reduce_lr, best_model], 
                  verbose = 1)

        # Use the best checkpoint to get 3D predictions
        model = load_model(model_name, custom_objects = {'dice_loss_l2': dice_loss_l2})
        inference(model, val_images, n_classes, pred_dest)

Run K-fold code on the BraTS and NFBS datasets.

In [None]:
nets = ['unet', 'resnet', 'densenet']
pockets = [True, False]

# One entry per dataset: (slice-level csv, volume-level csv, prediction path, model path)
datasets = [
    ('brats_slices_paths.csv', 'brats_paths.csv', '/path/to/brats/predictions/', '/path/to/brats/models/'),
    ('nfbs_slices_paths.csv', 'nfbs_paths.csv', '/path/to/nfbs/predictions/', '/path/to/nfbs/models/'),
]

# Train every architecture in its pocket and full variant, BraTS first and then NFBS
for net in nets:
    for pocket in pockets:
        for input_csv, original_csv, pred_dest, model_dir in datasets:
            run_kfold(input_csv, original_csv, net, pocket, pred_dest, model_dir)