# CNN-based Brain Tumour Segmentation Network
## Import packages
Please make sure you have all the required packages installed. If a GPU is available but you want to train your model on the CPU, make sure you add `os.environ['CUDA_VISIBLE_DEVICES'] = '-1'`.
Package 'SimpleITK' is for loading the MR images, so you need to install it first.

In [None]:
import os
import random
import shutil

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import backend as K
from tensorflow.keras.layers import *
from tensorflow.keras.models import *
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import load_model
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession


# Expose only GPU 0 to this process (use '-1' instead to hide all GPUs and run on the CPU).
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
config = ConfigProto()
# Upper bound on the share of GPU memory this process may use.
config.gpu_options.per_process_gpu_memory_fraction = 0.6
# Grow the GPU memory allocation on demand instead of reserving it all at start-up.
config.gpu_options.allow_growth = True
# Create the TF1-compat session with the options above; it is kept in `session`
# so that it stays alive for the rest of the notebook.
session = InteractiveSession(config = config)

## Visualise MRI Volume Slices and Segmentation Maps
Each MRI image contains information about a three-dimensional (3D) volume of space. An MRI image is composed of a number of voxels, which are the 3D counterpart of pixels in 2D images. Here we visualise the transverse plane (which usually has a higher resolution) of some of the volumes and the corresponding segmentation maps.

In [None]:
def plot_samples(x, n=10):
    """Show randomly chosen image slices next to their ground-truth masks.

    The figure has ``n`` rows and two columns: the first channel of the image
    on the left ("Input") and the segmentation mask on the right
    ("Ground Truth").

    Args:
        x: Sequence of paths to image ``.npy`` files. The mask of an image is
            read from the same folder, from the file named
            ``<prefix>_seg.npy`` where ``<prefix>`` is the image file name up
            to its first underscore.
        n: Number of samples (rows) to draw. Samples are drawn with
            replacement, so the same image can appear more than once.

    Raises:
        ValueError: If ``x`` is empty.
    """
    if len(x) == 0:
        raise ValueError("plot_samples needs at least one image path")

    plt.figure(figsize=(15, 20))
    sample_indices = np.random.randint(len(x), size=n)
    for row, idx in enumerate(sample_indices):
        img_file = x[idx]
        # Split off the folder first so that an underscore in a directory
        # name cannot corrupt the derived mask path.
        folder, file_name = os.path.split(img_file)
        seg_file = os.path.join(folder, file_name.split('_')[0] + '_seg.npy')

        # Subplot positions are 1-based: odd cells hold inputs, even cells masks.
        plt.subplot(n, 2, 2 * row + 1)
        plt.imshow(np.load(img_file)[:, :, 0], cmap='gray')
        plt.xlabel("Input")

        plt.subplot(n, 2, 2 * row + 2)
        plt.imshow(np.load(seg_file)[:, :], cmap='gray')
        plt.xlabel("Ground Truth")
    plt.tight_layout()
    plt.show()

# Gather the image files of one class folder and preview a few of them.
img_path = 'Dataset/'
CLASS = 'Yes'
img_list = []

all_files = os.listdir(os.path.join(img_path, CLASS))
# Image files are recognised by "img" appearing in their name.
files = list(filter(lambda item: "img" in item, all_files))
random.shuffle(files)
img_num = len(files)

for n, file_name in enumerate(files):
    # The matching mask path is derived here but only the image path is collected.
    seg = os.path.join(img_path, CLASS, file_name.split('_')[0] + '_seg.npy')
    img = os.path.join(img_path, CLASS, file_name)
    img_list.append(img)

plot_samples(img_list, n=5)

## Data preprocessing (Optional)

Images in the original dataset are usually of different sizes, so sometimes we need to resize and normalise them (z-score normalisation is commonly used when preprocessing MRI images) to fit the CNN model. Depending on the images you choose to use for training your model, some other preprocessing methods may also be needed. If preprocessing methods like cropping are applied, remember to convert the segmentation result back to its original size.

In [None]:
# Start from empty Train/ and Val/ folders. This is done in Python rather than
# with shell commands so the cell behaves the same on every operating system.
for split_dir in ('Train', 'Val'):
    if os.path.isdir(split_dir):
        shutil.rmtree(split_dir)
    os.makedirs(split_dir)


img_path = 'Dataset/'
train_list = []
val_list = []
# Share of each class that goes to the training set; change the number here
# to use a different number of images for training your model.
train_fraction = 0.8
for CLASS in os.listdir(img_path):
    class_dir = os.path.join(img_path, CLASS)
    # Skip hidden entries and anything that is not a class folder.
    if CLASS.startswith('.') or not os.path.isdir(class_dir):
        continue
    # Create the destination folders for whatever classes the dataset contains.
    for split_dir in ('Train', 'Val'):
        os.makedirs(os.path.join(split_dir, CLASS), exist_ok=True)

    all_files = os.listdir(class_dir)
    files = [item for item in all_files if "img" in item]
    # Shuffle so that the split is random within each class.
    random.shuffle(files)
    img_num = len(files)
    for (n, file_name) in enumerate(files):
        seg_name = file_name.split('_')[0]+'_seg.npy'
        img = os.path.join(img_path,CLASS,file_name)
        seg = os.path.join(img_path,CLASS,seg_name)
        if n < train_fraction*img_num:
            shutil.copy(img, os.path.join('Train/',CLASS,file_name))
            train_list.append(os.path.join('Train/',CLASS,file_name))
            shutil.copy(seg, os.path.join('Train/',CLASS,seg_name))
        else:
            shutil.copy(img, os.path.join('Val/',CLASS,file_name))
            val_list.append(os.path.join('Val/',CLASS,file_name))
            shutil.copy(seg, os.path.join('Val/',CLASS,seg_name))

## Train-time data augmentation
Generalizability is crucial to a deep learning model and it refers to the performance difference of a model when evaluated on the seen data (training data) versus the unseen data (testing data). Improving the generalizability of these models has always been a difficult challenge. 

**Data Augmentation** is an effective way of improving generalizability, because the augmented data represent a more comprehensive set of possible data samples, which reduces the distance between the training and validation/testing sets.

There are many data augmentation methods you can choose from in this project, including rotation, shifting, flipping, etc.

You are encouraged to try different augmentation methods to get the best segmentation result.


## Get the data generator ready

In [None]:
class DataGenerator(tf.keras.utils.Sequence):
    """Keras ``Sequence`` yielding ``(X, y)`` batches loaded from ``.npy`` files.

    ``X`` has shape ``(batch, *dim, n_channels)`` and ``y`` has shape
    ``(batch, *dim, 1)``. The mask of an image is read from the same folder,
    from ``<prefix>_seg.npy`` where ``<prefix>`` is the image file name up to
    its first underscore. Negative mask values are clamped to zero.

    Only full batches are produced: trailing samples that do not fill a batch
    are left out of the epoch.

    Args:
        list_IDs: Paths of the image ``.npy`` files.
        batch_size: Number of samples per batch.
        dim: Spatial size of one sample. Rotation by 90 or 270 degrees swaps
            the two axes, so augmentation presumably expects a square ``dim``.
        n_channels: Number of image channels.
        flips: Probabilities of flipping along axis 0 and along axis 1.
        rotates: Probabilities of rotating by 90, 180 and 270 degrees; the
            remaining probability mass means "no rotation".
        augmentation: Whether random flips/rotations are applied.
        shuffle: Whether the sample order is reshuffled after every epoch.
    """

    def __init__(self, list_IDs, batch_size=4, dim=(240,240), n_channels=3, flips=(0.2,0.2), rotates=(0.1,0.2,0.1),
                 augmentation=False, shuffle=True):
        # Let the Keras base class set up its own state (newer Keras versions
        # warn when this call is missing).
        super().__init__()
        self.dim = dim
        self.batch_size = batch_size
        self.list_IDs = list_IDs
        self.n_channels = n_channels
        self.shuffle = shuffle
        self.flips = flips
        self.rotates = rotates
        self.augmentation = augmentation
        self.on_epoch_end()

    def __len__(self):
        """Return the number of full batches per epoch."""
        return len(self.list_IDs) // self.batch_size

    def __getitem__(self, index):
        """Return batch number ``index`` as an ``(X, y)`` tuple."""
        indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
        list_IDs_temp = [self.list_IDs[k] for k in indexes]

        if self.augmentation:
            return self.__data_augmentation(list_IDs_temp)
        return self.__data_generation(list_IDs_temp)

    def on_epoch_end(self):
        """Reset the sample order, reshuffling it when ``shuffle`` is set."""
        self.indexes = np.arange(len(self.list_IDs))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def _empty_batch(self, n_samples):
        """Allocate the image and mask buffers for ``n_samples`` samples."""
        X = np.empty((n_samples, *self.dim, self.n_channels))
        y = np.empty((n_samples, *self.dim, 1))
        return X, y

    def _load_pair(self, ID):
        """Load the image at ``ID`` and its mask (with a trailing channel axis)."""
        image = np.load(ID)
        # Split off the folder first so that an underscore in a directory
        # name cannot corrupt the derived mask path.
        folder, file_name = os.path.split(ID)
        seg_file = os.path.join(folder, file_name.split('_')[0] + '_seg.npy')
        seg_0 = np.expand_dims(np.load(seg_file), axis=2)
        # Element-wise maximum with zeros clamps negative labels to 0.
        mask = np.maximum(np.zeros((*self.dim, 1)), seg_0)
        return image, mask

    def _augment(self, image, mask):
        """Apply the same random flips and rotation to ``image`` and ``mask``."""
        for axis in range(2):
            if np.random.binomial(1, self.flips[axis]):
                image = np.flip(image, axis=axis)
                mask = np.flip(mask, axis=axis)

        # Entry 0 is the probability of leaving the sample unrotated.
        rotation_probs = np.insert(self.rotates, 0, 1 - np.sum(self.rotates))
        quarter_turns = np.random.choice(np.arange(4), p=rotation_probs)
        image = np.rot90(image, k=quarter_turns)
        mask = np.rot90(mask, k=quarter_turns)
        return image, mask

    def __data_generation(self, list_IDs_temp):
        """Build an un-augmented ``(X, y)`` batch from the given IDs."""
        # Size the buffers by the actual number of samples so that no
        # uninitialised rows can ever be returned.
        X, y = self._empty_batch(len(list_IDs_temp))
        for i, ID in enumerate(list_IDs_temp):
            X[i], y[i] = self._load_pair(ID)
        return X, y

    def __data_augmentation(self, list_IDs_temp):
        """Build a randomly augmented ``(X, y)`` batch from the given IDs."""
        X, y = self._empty_batch(len(list_IDs_temp))
        for i, ID in enumerate(list_IDs_temp):
            image, mask = self._load_pair(ID)
            X[i], y[i] = self._augment(image, mask)
        return X, y

In [None]:
class TestGenerator(tf.keras.utils.Sequence):
    """Keras ``Sequence`` yielding one image per batch, without labels.

    Every batch ``X`` has shape ``(1, *dim, n_channels)``.

    Args:
        list_IDs: Paths of the image ``.npy`` files.
        dim: Spatial size of one sample.
        n_channels: Number of image channels.
        shuffle: Whether the sample order is shuffled. Use ``get_item_id`` to
            learn which file a batch came from.
    """

    def __init__(self, list_IDs, dim=(240,240), n_channels=3, shuffle=True):
        # Let the Keras base class set up its own state (newer Keras versions
        # warn when this call is missing).
        super().__init__()
        self.dim = dim
        self.list_IDs = list_IDs
        self.n_channels = n_channels
        self.shuffle = shuffle
        self.on_epoch_end()

    def __len__(self):
        """Return the number of batches, i.e. the number of images."""
        return len(self.list_IDs)

    def __getitem__(self, index):
        """Return the image batch at position ``index``."""
        return self.__data_generation(self._ids_at(index))

    def get_item_id(self, index):
        """Return ``(ids, X)``: the one-element ID list and its image batch."""
        list_IDs_temp = self._ids_at(index)
        return list_IDs_temp, self.__data_generation(list_IDs_temp)

    def on_epoch_end(self):
        """Reset the sample order, reshuffling it when ``shuffle`` is set."""
        self.indexes = np.arange(len(self.list_IDs))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def _ids_at(self, index):
        """Return the one-element list of IDs for batch ``index``.

        Raises:
            IndexError: If no sample exists at ``index``; without this check
                the caller would receive an uninitialised array.
        """
        indexes = self.indexes[index:index+1]
        if len(indexes) == 0:
            raise IndexError("batch index out of range")
        return [self.list_IDs[k] for k in indexes]

    def __data_generation(self, list_IDs_temp):
        """Load the given IDs into a ``(1, *dim, n_channels)`` array."""
        X = np.empty((1, *self.dim, self.n_channels))
        for i, ID in enumerate(list_IDs_temp):
            X[i,] = np.load(ID)
        return X

## Define a metric for the performance of the model
Dice score is used here to evaluate the performance of your model.
More details about the Dice score and other metrics can be found at 
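https://towardsdatascience.com/metrics-to-evaluate-your-semantic-segmentation-model-6bcb99639aa2. The Dice score can also be used as the loss function for training your model; note that the metric defined below thresholds the prediction at 0.5, so a loss would need a version without that threshold to remain differentiable.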

In [None]:
def dice_coef(y_true, y_pred, smooth=1e-2):
    """Return the mean Dice score of a batch of thresholded predictions.

    Predictions are binarised at 0.5 before the score is computed. The sums
    run over axes 1, 2 and 3, so both inputs are expected to be 4-D with the
    batch on axis 0; the per-sample scores are then averaged over the batch.

    Args:
        y_true: Ground-truth masks.
        y_pred: Predicted values, compared against 0.5.
        smooth: Constant added to numerator and denominator, which keeps the
            ratio defined when both masks are empty.
    """
    per_sample_axes = [1, 2, 3]
    truth = K.cast_to_floatx(y_true)
    prediction = K.cast_to_floatx(y_pred > 0.5)

    overlap = K.sum(truth * prediction, axis=per_sample_axes)
    # Sum of the two mask sizes (the Dice denominator).
    total_size = K.sum(truth, axis=per_sample_axes) + K.sum(prediction, axis=per_sample_axes)

    per_sample_dice = (2. * overlap + smooth)/(total_size + smooth)
    return K.mean(per_sample_dice, axis=0)

## Build your own model here
The U-Net (https://link.springer.com/chapter/10.1007/978-3-319-24574-4_28) structure is widely used for the medical image segmentation task. You can build your own model or modify the UNet by changing the hyperparameters for our task. If you choose to use Keras, more information about the Keras layers including Conv2D, MaxPooling and Dropout can be found at https://keras.io/api/layers/.

In [None]:
# Fix TensorFlow's global random seed.
tf.random.set_seed(28)
# input image tile
def InputBlock(input, filters, kernel_size=3, padding='same'):
    """Apply two Conv2D (ReLU, he_normal) + BatchNormalization stages to ``input``.

    Args:
        input: Tensor fed into the first convolution.
        filters: Number of filters of both convolutions.
        kernel_size: Kernel size of both convolutions.
        padding: Padding mode of both convolutions.

    Returns:
        The output tensor of the second BatchNormalization layer.
    """
    features = input
    for _ in range(2):
        conv = Conv2D(filters=filters, kernel_size=kernel_size, padding=padding,
                      kernel_initializer='he_normal', activation='relu')
        features = BatchNormalization()(conv(features))
    return features

# contracting path
def ContractingPathBlock(input, filters, kernel_size=3, padding='same'):
    """Downsample ``input`` by 2x2 max pooling, then apply two Conv2D + BatchNormalization stages.

    Args:
        input: Tensor to downsample.
        filters: Number of filters of both convolutions.
        kernel_size: Kernel size of both convolutions.
        padding: Padding mode of both convolutions.

    Returns:
        The output tensor of the second BatchNormalization layer.
    """
    features = MaxPool2D((2, 2))(input)
    for _ in range(2):
        conv = Conv2D(filters=filters, kernel_size=kernel_size, padding=padding,
                      kernel_initializer='he_normal', activation='relu')
        features = BatchNormalization()(conv(features))
    return features

# expansive path
def ExpansivePathBlock(input, con_feature, filters, tran_filters, kernel_size=3, tran_kernel_size=2, strides=1,
                       tran_strides=2, padding='same', tran_padding='same'):
    """Upsample ``input``, merge it with a skip connection and convolve twice.

    Args:
        input: Tensor to upsample with a transposed convolution.
        con_feature: Skip-connection tensor; it is resized (nearest neighbour)
            to the spatial size of the upsampled tensor before concatenation.
        filters: Number of filters of the two convolutions.
        tran_filters: Number of filters of the transposed convolution.
        kernel_size: Kernel size of the two convolutions.
        tran_kernel_size: Kernel size of the transposed convolution.
        strides: Strides of the two convolutions.
        tran_strides: Strides of the transposed convolution.
        padding: Padding mode of the two convolutions.
        tran_padding: Padding mode of the transposed convolution.

    Returns:
        The output tensor of the second BatchNormalization layer.
    """
    upsampled = Conv2DTranspose(filters=tran_filters, kernel_size=tran_kernel_size,
                                strides=tran_strides, padding=tran_padding)(input)

    target_size = (upsampled.shape[1], upsampled.shape[2])
    skip = tf.image.resize(con_feature, target_size,
                           method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    # Skip features come first along the channel axis, then the upsampled ones.
    features = tf.concat([skip, upsampled], axis=3)

    for _ in range(2):
        conv = Conv2D(filters=filters, kernel_size=kernel_size, strides=strides, padding=padding,
                      kernel_initializer='he_normal', activation='relu')
        features = BatchNormalization()(conv(features))
    return features

# U-Net
def UNet(input_shape):
    """Build the U-Net model for inputs of shape ``input_shape``.

    The encoder is an input block (64 filters) followed by four contracting
    blocks (128, 256, 512, 1024 filters), with dropout of rate 0.5 after the
    last one. The decoder is four expansive blocks (512, 256, 128, 64
    filters), each fed the matching encoder output as skip connection. Two
    1x1 convolutions (2 filters with ReLU, then 1 filter with sigmoid)
    produce the output.

    Args:
        input_shape: Shape of one input sample, without the batch axis.

    Returns:
        An uncompiled ``tf.keras.Model``.
    """
    inputs = Input(input_shape)

    # Encoder: keep every stage's output, the deepest one ends up last.
    encoder_outputs = [InputBlock(inputs, 64)]
    for width in (128, 256, 512, 1024):
        encoder_outputs.append(ContractingPathBlock(encoder_outputs[-1], width))

    features = Dropout(rate = 0.5)(encoder_outputs.pop())

    # Decoder: consume the remaining encoder outputs from deepest to shallowest.
    for width in (512, 256, 128, 64):
        features = ExpansivePathBlock(features, encoder_outputs.pop(), width, width)

    features = Conv2D(2, 1, activation='relu', padding='same', kernel_initializer='he_normal')(features)
    outputs = Conv2D(1, 1, activation = 'sigmoid')(features)

    return tf.keras.Model(inputs=[inputs], outputs=[outputs])

# Build the network for 240x240 inputs with 3 channels and print its layer summary.
model = UNet(input_shape=(240, 240, 3))
model.summary()
# `metrics` is given as a list, the form documented for Model.compile; a bare
# callable is not accepted by every Keras version. The metric is still
# reported under the function's name, "dice_coef".
model.compile(optimizer=Adam(learning_rate=1e-4), loss='binary_crossentropy', metrics=[dice_coef])

## Train your model here
Once you have defined the model and the data generator, you can start training your model.

In [None]:
train_generator = DataGenerator(train_list, augmentation=True)
# Validation data is not shuffled: the generator only yields full batches, so a
# fixed order keeps the evaluated subset identical from epoch to epoch and the
# monitored score comparable.
val_generator = DataGenerator(val_list, shuffle=False)
# Stop once the validation Dice score has not improved for 5 epochs and roll
# the model back to its best epoch, so that the model saved afterwards is the
# best one rather than the last one.
earlystopping = EarlyStopping(monitor='val_dice_coef', mode='max', patience=5, restore_best_weights=True)
results = model.fit(train_generator, validation_data=val_generator, epochs=30, callbacks=[earlystopping])

## Save the model
Once your model is trained, remember to save it for testing.

In [None]:
# Write the trained model to an HDF5 file in the working directory; the test
# cell loads it back from this same file name.
model.save('trained_segmentation_model.h5')

## Run the model on the test set
After your last Q&A session, you will be given the test set. Run your model on the test set to get the segmentation results and submit your results in a .zip file. If the MRI image is named '100_img.npy', save your segmentation result as '100_seg.npy'. 

In [None]:
test_dir = 'Val/'
output_dir = 'Output/'
# np.save does not create missing folders, so make sure the output folder exists.
os.makedirs(output_dir, exist_ok=True)
#load your model here
model_load = load_model('trained_segmentation_model.h5', custom_objects={'dice_coef':dice_coef})

# Collect every image file found in the class sub-folders of the test folder.
test_list = []
for CLASS in os.listdir(test_dir):
    if not CLASS.startswith('.'):
        all_files = os.listdir(test_dir + CLASS)
        files = [item for item in all_files if "img" in item]
        for file_name in files:
            test_list.append(test_dir + CLASS + '/' + file_name)
test_generator = TestGenerator(test_list)

predictions = []
y_test = []
for i in range(len(test_generator)):
    ID, x_test = test_generator.get_item_id(i)
    prediction = model_load.predict(np.array(x_test))
    predictions.append(prediction[0])

    # '<folder>/100_img.npy' -> case id '100'
    folder, file_name = os.path.split(ID[0])
    case_id = file_name.split('_')[0]

    # Save the binarised prediction as '<case id>_seg.npy'.
    result=np.reshape(np.uint8(prediction[0]>0.5), (240, 240))
    np.save(output_dir+case_id+'_seg.npy', result)

    # Collect the ground truth when it is stored next to the image, prepared
    # the same way as in DataGenerator (channel axis added, negatives clamped).
    seg_file = os.path.join(folder, case_id+'_seg.npy')
    if os.path.exists(seg_file):
        y_test.append(np.maximum(0, np.expand_dims(np.load(seg_file), axis=2)))

# The Dice score can only be computed when every prediction has a ground truth.
if predictions and len(y_test) == len(predictions):
    accuracy = dice_coef(np.array(y_test), np.array(predictions))
    print('Test Accuracy = %.5f' % accuracy)
else:
    print('Ground-truth masks are not available for every test image; Dice score skipped.')