
# 3D CNN | Human Brain Skullstripping with 3D UNet

Notebook example for "Advanced Deep Learning Models and Methods for 3D Spatial Data" course at Politecnico di Milano (Polimi) during the academic year 2023/2024, instructed by Professors Boracchi Giacomo, Magri Luca, Matteucci Matteo, and Melzi Simone.

Author: Marcello De Salvo <br>
Repo: https://github.com/MarcelloDeSalvo/LearningWith3dCnn


# Hackathon Tasks

*   Run the initial part of the script and get acquainted with the visualization of 3D data. Try opening other images or changing the sampling density of the point clouds.
*   Implement the Unet3D as illustrated below
*   Implement a modified variant of UNet3D which leverages residual connections in each block.
*   Assess testing performance and compare the above models. Visualize the estimated masks



# Libraries

In [None]:
# Utils
from tqdm import tqdm
import os
import shutil
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Nibabel visualization and processing
import nibabel as nib
!pip install nilearn
import nilearn
from scipy.ndimage import zoom

# Tensorflow and Keras
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping
from tensorflow.keras.models import *
from tensorflow.keras.layers import *
from tensorflow.keras.optimizers import *
from tensorflow.keras.layers.experimental import preprocessing
from tensorflow.keras.utils import plot_model
from tensorflow.keras.losses import binary_crossentropy

# sklearn
from sklearn.model_selection import train_test_split

import warnings
warnings.simplefilter("ignore")

Collecting nilearn
  Downloading nilearn-0.10.3-py3-none-any.whl (10.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.4/10.4 MB[0m [31m24.6 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: nilearn
Successfully installed nilearn-0.10.3


In [None]:
# Set random seed
seed = 42
tf.keras.utils.set_random_seed(seed)

# Dataset
We'll use this public dataset: http://preprocessed-connectomes-project.org/NFB_skullstripped/


The dataset comprises information from 125 individuals aged between 21 and 45 years, encompassing a diverse range of clinical and subclinical psychiatric symptoms. Each participant's data includes:

- Anonymized (de-faced for privacy) Structural T1-weighted image
- Skull-stripped image
- Brain mask

The images are in NIfTI format (.nii.gz) with a resolution of 1 mm³.

**What is a T1-weighted image?**

A T1-weighted image is a type of magnetic resonance imaging (MRI) that highlights differences in tissue density, providing detailed anatomical information. It is particularly useful for visualizing brain structures since it provides crisp images, and shows fluids as dark.

**What is the NIfTI format?**

NIfTI (Neuroimaging Informatics Technology Initiative) is a data format for the storage of Functional Magnetic Resonance Imaging (fMRI) and other medical images.

### Download and extract the dataset

In [None]:
# Local folder that is expected to hold the dataset once the archive is extracted
dataset_folder = './NFBS_Dataset'

# NOTE: the lines starting with `!` are IPython/Jupyter shell escapes, so this
# cell only runs inside a notebook environment.
# Check if the folder already exists
# NOTE(review): only the folder's existence is tested, so a partially extracted
# folder left by an interrupted run also skips the download.
if not os.path.exists(dataset_folder):
    # Step 1: Download the dataset
    !curl -O https://fcp-indi.s3.amazonaws.com/data/Projects/RocklandSample/NFBS_Dataset.tar.gz

    # Step 2: Extract the dataset
    # presumably the archive unpacks into ./NFBS_Dataset (matching dataset_folder) - confirm
    !tar -xf NFBS_Dataset.tar.gz

    # Optionally, you can remove the tar file after extracting to save space
    !rm NFBS_Dataset.tar.gz
else:
    print(f"The folder '{dataset_folder}' already exists. Skipping download and extraction.")


  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 1670M  100 1670M    0     0  11.8M      0  0:02:21  0:02:21 --:--:-- 13.4M


In [None]:
# Everything stored directly inside the dataset folder
items = os.listdir(dataset_folder)

# Keep only the entries that are folders: each of them is one patient
patients = list(filter(lambda name: os.path.isdir(os.path.join(dataset_folder, name)), items))
num_patients = len(patients)

print("Number of patients (folders):", num_patients)

Number of patients (folders): 125


## Data exploration

We can load our NIfTI files using the **nibabel** library and utilize the **nilearn** library to process the MRIs and visualize them

In [None]:
# Paths of one example subject: its T1-weighted scan and the matching brain mask
sample_filename = f"{dataset_folder}/A00028185/sub-A00028185_ses-NFB3_T1w.nii.gz"
sample_filename_mask = f"{dataset_folder}/A00028185/sub-A00028185_ses-NFB3_T1w_brainmask.nii.gz"

# Load both volumes and keep only their voxel data as NumPy arrays
sample_img = nib.load(sample_filename).get_fdata()
sample_mask = nib.load(sample_filename_mask).get_fdata()

# The two shapes are expected to match (one mask value per image voxel)
input_shape, mask_shape = sample_img.shape, sample_mask.shape
print("img shape ->", input_shape)
print("mask shape ->", mask_shape)

img shape -> (256, 256, 192)
mask shape -> (256, 256, 192)


### 2D slices view

In [None]:
from nilearn.plotting import view_img, plot_glass_brain, plot_anat, plot_epi
from nilearn.image import load_img
nilearn_img = load_img(sample_filename)
nilearn_mask = load_img(sample_filename_mask)

In [None]:
plot_anat(nilearn_img)

In [None]:
plot_anat(nilearn_img, draw_cross=False, display_mode='z')

In [None]:
plot_anat(nilearn_mask, draw_cross=False, display_mode='z')

In [None]:
view_img(nilearn_mask, nilearn_img)

### 3D Pointcloud view

In [None]:
import plotly.graph_objects as go
def plot_3d_overlap(sample_img, sample_mask, density, alpha=0.2):
    """Show a head volume and its brain mask as two overlaid 3D point clouds.

    Args:
        sample_img: 3D array with the image intensities. Only voxels brighter
            than 200 are drawn, which removes the background noise.
        sample_mask: 3D array with the brain mask; values above 0.5 count as brain.
        density: maximum number of points drawn for each cloud. It is capped at
            the number of available voxels, so asking for more points than
            exist no longer fails.
        alpha: currently unused; kept so that existing calls keep working.

    Raises:
        ValueError: if no image voxel is brighter than 200 or the mask is empty.
    """
    # adjust the rotation (just for visual purposes)
    sample_img = np.swapaxes(sample_img, 1, 2)
    sample_mask = np.swapaxes(sample_mask, 1, 2)

    # Invert Z-axis (just for visual purposes)
    sample_img = sample_img[::-1, :, ::-1]
    sample_mask = sample_mask[::-1, :, ::-1]

    # Brain and Gt
    _sample_mask = np.where(sample_mask > 0.5, 1, 0)
    gt_indices = np.argwhere(_sample_mask > 0)
    brain_indices = np.argwhere(sample_img > 200)  # remove noise

    if len(brain_indices) == 0 or len(gt_indices) == 0:
        raise ValueError("Nothing to plot: no voxel above the intensity threshold or empty mask")

    # Randomly sample points based on the specified density. Sampling without
    # replacement cannot return more points than exist, hence the cap.
    n_brain = min(density, len(brain_indices))
    n_gt = min(density, len(gt_indices))
    brain_indices = brain_indices[np.random.choice(len(brain_indices), n_brain, replace=False)]
    gt_indices = gt_indices[np.random.choice(len(gt_indices), n_gt, replace=False)]

    # Get matrix values for color mapping
    brain_colors = sample_img[brain_indices[:, 0], brain_indices[:, 1], brain_indices[:, 2]]
    gt_colors = sample_mask[gt_indices[:, 0], gt_indices[:, 1], gt_indices[:, 2]]

    # Create 3D scatter plot for the head voxels
    brain_scatter = go.Scatter3d(
        x=brain_indices[:, 0],
        y=brain_indices[:, 1],
        z=brain_indices[:, 2],
        mode='markers',
        marker=dict(
            size=1,
            opacity=0.3,
            color=brain_colors,
            colorscale='Viridis',
            cmin=np.min(brain_colors),
            cmax=np.max(brain_colors),
        ),
        name='Head'
    )

    # Create 3D scatter plot for the gt
    gt_scatter = go.Scatter3d(
        x=gt_indices[:, 0],
        y=gt_indices[:, 1],
        z=gt_indices[:, 2],
        mode='markers',
        marker=dict(
            size=1.5,
            color=gt_colors,
            colorscale='YlOrRd',
            cmin=np.min(gt_colors),
            cmax=np.max(gt_colors),
        ),
        name='Ground truth'
    )

    # Create figure
    fig = go.Figure(data=[brain_scatter, gt_scatter])

    # Set layout properties
    fig.update_layout(
        scene=dict(
            xaxis_title='X',
            yaxis_title='Y',
            zaxis_title='Z',
        ),
        title='3D Plot - Brain Mask Point Cloud'
    )

    # Show figure
    fig.show()

plot_3d_overlap(sample_img, sample_mask, density=50000)

# Problem definition

The objective of this task is to create a 3D model for binary semantic segmentation, where each voxel in the volumetric data is assigned one of the following labels:
0. Background (other anatomical tissues such as bones and flesh)
1. Brain tissue

This procedure, called "skull stripping", identifies the brain in the MRI data and separates it from the surrounding tissues. It is an essential step for many applications, such as medical image analysis and neuroscience research.

In [None]:
# Global settings shared by the data pipeline and the models
config = dict(
    input_shape=input_shape,
    dataset_folder=dataset_folder,
    batch_size=3,
    num_classes=1,  # Binary segmentation
    num_channels=1,  # Just one input channel for the single t1w modality
    validation_split=0.2,
    test_split=0.2,
    target_resolution=(4, 4, 4),  # We'll downsample the data from 1 mm³ to 4 mm³ (You can also try with 2 mm³ but it takes longer to train!)
)

# Shape of a volume after downsampling: each axis shrinks by the target
# resolution (this treats the original voxels as 1 mm wide)
config['input_shape'] = tuple(
    int(size / step)
    for size, step in zip(config['input_shape'], config['target_resolution'])
)
print('Resampled input shape: ', config['input_shape'])

Resampled input shape:  (64, 64, 48)


# Loss function
## Dice coefficient <br>
The Dice coefficient (also known as the Sørensen-Dice coefficient) is used to measure the similarity or overlap between two sets.

In this context, it's used to measure the similarity between the ground truth binary mask (y_true) and the predicted binary mask (y_pred) for a specific class. This measure ranges from 0 to 1 where a Dice coefficient of 1 denotes perfect and complete overlap. <br>
The Dice coefficient can be calculated as: <br>
$$
\Large
\text{Dice Coefficient} = \frac{2 \times |X \cap Y|}{|X| + |Y|}
$$
In our case this translates into:
- Intersection: This is the sum of the element-wise multiplication of the ground truth mask and the predicted mask.<br>
- X or Area of y_true: This is the sum of all non-zero values in the ground truth mask, representing the total number of voxels that belong to the class in the ground truth.
- Y or Area of y_pred: This is the sum of all non-zero values in the predicted mask, representing the total number of voxels that the model predicted as belonging to the class.

<img src="https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcSrnfpehrVZMJLjRDVUWxEZ9_pW0RYUlkdhlw&usqp=CAU" alt="Image 1" style="width:400px; height:200px"/>

In [None]:
def dice_loss(smooth=1e-5):
    """Build a Soft Dice loss function.

    "Soft" means that the predicted probabilities are used as they are,
    without thresholding them into a binary mask first.

    Args:
        smooth: smoothing term forwarded to `dice`.

    Returns:
        A function `loss(y_true, y_pred)` that evaluates `1 - dice(...)`.
    """
    def loss(y_true, y_pred):
        dice_score = dice(y_true, y_pred, smooth)
        return 1 - dice_score
    return loss

def dice(y_true, y_pred, smooth=1e-5):
    """Soft Dice coefficient between a ground-truth mask and a prediction.

    Both inputs are flattened first, so a single score is computed over all
    the elements passed in.

    Args:
        y_true: ground-truth mask.
        y_pred: predicted probabilities.
        smooth: term added to numerator and denominator; it keeps the ratio
            defined when both inputs sum to zero.

    Returns:
        `(2 * sum(y_true * y_pred) + smooth) / (sum(y_pred) + sum(y_true) + smooth)`.
    """
    truth_flat = K.flatten(y_true)
    pred_flat = K.flatten(y_pred)

    overlap = K.sum(truth_flat * pred_flat)
    # |X| + |Y| of the Dice formula (sum of the two areas, not a set union)
    total = K.sum(pred_flat) + K.sum(truth_flat)

    numerator = 2. * overlap + smooth
    return numerator / (total + smooth)

def dice_bce_loss(alpha=0.5, smooth=1e-5):
    """Build a loss that blends Soft Dice loss and binary cross-entropy.

    Mixing the two terms is a frequently employed strategy.

    Args:
        alpha: weight of the Dice term; the cross-entropy term gets `1 - alpha`.
        smooth: smoothing term forwarded to `dice_loss`.

    Returns:
        A function `loss(y_true, y_pred)` computing
        `alpha * dice_loss + (1 - alpha) * binary_crossentropy`.
    """
    soft_dice = dice_loss(smooth)

    def loss(y_true, y_pred):
        dice_term = soft_dice(y_true, y_pred)
        bce_term = binary_crossentropy(y_true, y_pred)
        return alpha * dice_term + (1 - alpha) * bce_term
    return loss

## Other useful metrics for 3D data

In [None]:
def precision(y_true, y_pred):
    """Precision: true positives divided by all predicted positives.

    Values are clipped to [0, 1] and rounded before counting, so the
    probabilities in `y_pred` are binarized. `K.epsilon()` in the denominator
    avoids a division by zero when nothing is predicted as positive.
    """
    hits = K.round(K.clip(y_true * y_pred, 0, 1))
    true_positives = K.sum(hits)
    predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1)))
    return true_positives / (predicted_positives + K.epsilon())

def sensitivity(y_true, y_pred):
    """Sensitivity (recall): true positives divided by all actual positives.

    Values are clipped to [0, 1] and rounded before counting. `K.epsilon()`
    in the denominator avoids a division by zero when there are no positives.
    """
    hits = K.round(K.clip(y_true * y_pred, 0, 1))
    actual = K.round(K.clip(y_true, 0, 1))
    return K.sum(hits) / (K.sum(actual) + K.epsilon())

def specificity(y_true, y_pred):
    """Specificity: true negatives divided by all actual negatives.

    Negatives are obtained as `1 - value`; results are clipped to [0, 1] and
    rounded before counting. `K.epsilon()` in the denominator avoids a
    division by zero when there are no negatives.
    """
    correct_rejections = K.round(K.clip((1 - y_true) * (1 - y_pred), 0, 1))
    actual_negatives = K.round(K.clip(1 - y_true, 0, 1))
    return K.sum(correct_rejections) / (K.sum(actual_negatives) + K.epsilon())

def iou_coeff(y_true, y_pred, threshold=0.5):
    """IoU (Intersection over Union) between ground truth and binarized prediction.

    Args:
        y_true: ground-truth mask.
        y_pred: predicted probabilities.
        threshold: predictions strictly greater than this value become 1,
            all the others become 0.

    Returns:
        `(intersection + eps) / (union + eps)` with `eps = K.epsilon()`.
    """
    truth = K.flatten(y_true)
    # Binarize the prediction and turn the boolean result back into floats
    binary_pred = K.cast(K.greater(K.flatten(y_pred), threshold), dtype='float32')

    # Intersection: elements set in both masks; union: elements set in at least one
    overlap = K.sum(truth * binary_pred)
    combined = K.sum(K.maximum(truth, binary_pred))

    return (overlap + K.epsilon()) / (combined + K.epsilon())

In [None]:
_metrics = ['accuracy', precision, sensitivity, specificity, dice, iou_coeff]

# Data Generator
## Data Loader
Loading the whole dataset into memory is not a good idea, since it is too large to fit.<br>
So we will create a `DataGenerator` class to load the data on the fly, as explained [here](https://stanford.edu/~shervine/blog/keras-how-to-generate-data-on-the-fly)

## Preprocessing
1. First, we'll reduce the resolution of each image by resampling it (e.g. from 1 mm$^3$ to 4 mm$^3$ voxels).

2. Second, we will apply z-score normalization to standardize the intensity values of each MRI volume.

In [None]:
# Collect the patient IDs. Only sub-folders of the dataset folder count as
# patients, so stray files (archives, hidden files, ...) never end up in a split.
dataset_ids = sorted(
    item for item in os.listdir(dataset_folder)
    if os.path.isdir(os.path.join(dataset_folder, item))
)
print('Size of the dataset: ', len(dataset_ids))

# Splitting
# 1) `validation_split` of all the patients is held out for validation;
# 2) `test_split` of the *remaining* patients is held out for testing.
train_test_ids, val_ids = train_test_split(dataset_ids, test_size=config['validation_split'], random_state=seed)
train_ids, test_ids = train_test_split(train_test_ids, test_size=config['test_split'], random_state=seed)

print('Size of the training set: ', len(train_ids))
print('Size of the validation set: ', len(val_ids))
print('Size of the test set: ', len(test_ids))
print(test_ids)

In [None]:
class DataGenerator(tf.keras.utils.Sequence):
    """Keras `Sequence` that loads and preprocesses NFBS volumes batch by batch.

    Each item is a tuple `(X, Y)`:
      * `X` has shape `(batch, *dim, n_channels)` and holds the resampled,
        normalized T1w images;
      * `Y` has shape `(batch, *dim, n_class)` and holds the brain masks.
    """

    def __init__(self, list_IDs, dim=config['input_shape'], dataset_folder=config['dataset_folder'], batch_size=config['batch_size'],
                 n_channels=config['num_channels'], n_class=config['num_classes'], target_resolution=config['target_resolution'], shuffle=True):
        """Store the settings and build the initial sample order.

        Args:
            list_IDs: patient IDs (sub-folder names of `dataset_folder`).
            dim: spatial shape of the volumes after resampling.
            dataset_folder: folder containing one sub-folder per patient.
            batch_size: number of patients per batch.
            n_channels: number of input channels; the image is repeated on each.
            n_class: number of mask channels (1 keeps a single binary mask).
            target_resolution: voxel size, per axis, the volumes are resampled to.
            shuffle: whether the sample order is reshuffled by `on_epoch_end`.
        """
        # Let the Keras base class initialize its own state
        super().__init__()
        self.dim = dim
        self.dataset_folder = dataset_folder
        self.batch_size = batch_size
        self.list_IDs = list_IDs
        self.n_channels = n_channels
        self.shuffle = shuffle
        self.n_class = n_class
        self.target_resolution = target_resolution
        self.on_epoch_end()

    def __len__(self):
        """Return the number of full batches per epoch.

        A trailing incomplete batch is dropped, so up to `batch_size - 1`
        patients are left out of each pass.
        """
        return len(self.list_IDs) // self.batch_size

    def __getitem__(self, index):
        """Generate the batch number `index` as a tuple `(X, Y)`."""
        # Generate indexes of the batch
        indexes = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]

        # Find list of IDs
        batch_ids = [self.list_IDs[k] for k in indexes]

        # Generate data
        X, Y = self.__data_generation(batch_ids)
        return X, Y

    def on_epoch_end(self):
        """Reset the sample order and reshuffle it when `shuffle` is enabled."""
        self.indexes = np.arange(len(self.list_IDs))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def res_nifti(self, image):
        """Resample a loaded NIfTI image to `self.target_resolution`.

        The current voxel size is read from the image header; every axis is
        scaled by `current / target` using nearest-neighbour interpolation
        (`order=0`), which keeps mask values unchanged.
        """
        voxel_dims = image.header.get_zooms()  # Automatically retrieves the original resolution from the nifti file
        scale_factors = [current_dim / target_dim for current_dim, target_dim in zip(voxel_dims, self.target_resolution)]
        resampled_data = zoom(image.get_fdata(), scale_factors, order=0, mode='nearest')

        return resampled_data

    def normalize(self, arr, mode='zscore'):
        """Normalize the intensities of `arr`.

        Args:
            arr: array to normalize.
            mode: `'zscore'` (zero mean, unit standard deviation) or
                `'minmax'` (rescale to [0, 1]).

        Returns:
            The normalized array. A constant array has no spread to divide by,
            so it is mapped to zeros instead of NaN/inf values.

        Raises:
            ValueError: if `mode` is not one of the supported modes.
        """
        # min max normalization
        if mode == 'minmax':
            low, high = arr.min(), arr.max()
            if high == low:
                return np.zeros_like(arr, dtype=float)
            return (arr - low) / (high - low)
        # zscore normalization
        if mode == 'zscore':
            centered = arr - arr.mean()
            std = arr.std()
            return centered / std if std > 0 else centered
        raise ValueError(f"Unknown normalization mode: {mode!r}")

    def __data_generation(self, batch_ids):
        """Load and preprocess the image and mask of every patient in `batch_ids`."""
        # Initialization: one row per requested patient; float32 halves the
        # memory of the default float64 buffers
        n_samples = len(batch_ids)
        X = np.zeros((n_samples, *self.dim, self.n_channels), dtype=np.float32)
        Y = np.zeros((n_samples, *self.dim, self.n_class), dtype=np.float32)

        # Generate data
        for i, patient_id in enumerate(batch_ids):
            case_path = os.path.join(self.dataset_folder, patient_id)

            # Original Image Preprocessing
            image_path = os.path.join(case_path, f'sub-{patient_id}_ses-NFB3_T1w.nii.gz')
            image = nib.load(image_path)
            image_data = self.res_nifti(image)
            image_data = self.normalize(image_data)

            # Mask preprocessing: binarize again after the resampling
            mask_path = os.path.join(case_path, f'sub-{patient_id}_ses-NFB3_T1w_brainmask.nii.gz')
            mask = nib.load(mask_path)
            mask_data = self.res_nifti(mask)
            mask_data = np.where(mask_data > 0.5, 1, 0)

            # A single class keeps one binary channel, otherwise one-hot encode
            if self.n_class == 1:
                mask_data = np.expand_dims(mask_data, axis=-1)
            else:
                mask_data = tf.keras.utils.to_categorical(mask_data, self.n_class)

            # Stacking: the same image is repeated on every input channel
            X[i] = np.stack((image_data,) * self.n_channels, axis=-1)
            Y[i] = mask_data

        return X, Y

In [None]:
# Datasets Initialization
# Only the training data is shuffled. Validation and test generators keep a
# fixed order: since incomplete trailing batches are dropped, shuffling would
# change which patients are evaluated from one run to the next.
training_generator = DataGenerator(train_ids)
valid_generator = DataGenerator(val_ids, shuffle=False)
test_generator = DataGenerator(test_ids, shuffle=False)

## Sanity check
Here we verify that the images have been resampled correctly and that the preprocessing in the DataGenerator is free of errors, by checking the output shapes and by visualizing a preprocessed sample image.

In [None]:
# Fetch the first batch of the training generator
X, Y = training_generator[0]

# The batch must have the shapes the model will be built for
expected_prefix = (config['batch_size'], *config['input_shape'])
assert X.shape == (*expected_prefix, config['num_channels'])
assert Y.shape == (*expected_prefix, config['num_classes'])

# The mask should only contain the binary labels 0 and 1
print("Unique values in Y: ", np.unique(Y))

# Take the slice in the middle
slice_idx = config['input_shape'][-1] // 2

# Plotting: first sample of the batch, image next to its ground truth
plt.figure(figsize=(20, 20))
for position, (title, volume) in enumerate((('T1 image', X), ('Ground truth', Y)), start=1):
    plt.subplot(1, 5, position)
    plt.imshow(np.rot90(volume[0, :, :, slice_idx, 0], k=-1), cmap='gray')
    plt.title(title)
plt.show()

# Model | 3D U-Net
The success of U-Net has led to the development of derivative architectures by other researchers, including those employing 3D convolutions: [Çiçek et al. in 2016](https://arxiv.org/abs/1606.06650) showed that 3D architectures can achieve comprehensive 3D segmentation with minimal annotated slices from the same volume. Building on this progress, in the same year, [Milletari et al.](https://arxiv.org/abs/1606.04797) proposed a 3D version of the U-Net, trained using the Dice Coefficient.

## 3D convolutions
3D convolutions expand the convolutional operation into an additional dimension, making them a perfect fit for volumetric and temporal data (e.g. [with videos](https://www.tensorflow.org/tutorials/video/video_classification)).

Within the field of medical imaging, specifically in applications such as this one, we can use **voxels** to represent three-dimensional datasets like those from CT and MRI scans. With each voxel holding the signal value, these provide an intricate depiction of the underlying anatomy. For these reasons, 3D U-Net topologies have demonstrated excellent performance in a variety of biomedical applications.
![](https://i.imgur.com/TYMETaw.gif)

## 3D Unet architecture
To accomplish this task we could design something like this:

![](https://i.imgur.com/gbzLYgG.png)

Keras provides already some useful layers that can handle 3D images such as [Conv3D](https://keras.io/api/layers/convolution_layers/convolution3d/), [MaxPooling3D](https://keras.io/api/layers/pooling_layers/max_pooling3d/) and [UpSampling3D](https://keras.io/api/layers/reshaping_layers/up_sampling3d/) layers, and you can also use [BatchNormalization](https://keras.io/api/layers/normalization_layers/batch_normalization/). <br>
*Note that with 'UpConv3D' we are referring to an operation involving an UpSampling3D layer followed by a Conv3D layer.*

In [None]:
# TODO: implement these blocks according to the architecture shown above

# This is the C block shown in the above architecture
def conv_block(x, filters, kernel_size=3, padding='same', activation='relu', kernel_initializer='he_normal'):
    """Placeholder for the 'C' block of the architecture diagram.

    Not implemented yet: until the TODO below is filled in, `x` is returned
    unchanged and every other argument is ignored.

    NOTE(review): the keyword arguments are presumably meant to be forwarded
    to a Conv3D layer - confirm against the diagram when implementing.
    """
    # TODO
    return x

# This is the CDC block shown above
def double_conv_block(x, filters, kernel_size=3, padding='same', activation='relu', kernel_initializer='he_normal'):
    """Placeholder for the 'CDC' block of the architecture diagram.

    Not implemented yet: until the TODO below is filled in, `x` is returned
    unchanged and every other argument is ignored.
    """
    # TODO
    return x

# This is the encoder block (CDC + MaxPooling3D) that returns the output and a skip connection
def encoder_block(input, num_filter):
    """Placeholder for the encoder block (CDC block followed by MaxPooling3D).

    Intended to return the pair `(out, cdc)`: the block output and the tensor
    kept for the skip connection.

    NOTE: `out` and `cdc` are not defined until the TODO below is implemented,
    so calling the placeholder as-is raises NameError.
    """
    # TODO
    return out, cdc

# This is the decoder block (UpConv3D + Skip connection concatenation + CDC)
# Simply concatenate the skip connections with the UpConv3D layer output
def decoder_block(input, skip, num_filter):
    """Placeholder for the decoder block (UpConv3D, concatenation with `skip`, CDC).

    NOTE: `cdc` is not defined until the TODO below is implemented, so calling
    the placeholder as-is raises NameError.
    """
    # TODO
    return cdc

In [None]:
# Assemble the encoder, decoder and all the skip connections in a UNet3D
def UNet3D(in_shape, in_channels, num_classes):
    """Assemble the 3D U-Net (encoder, bottleneck, decoder and skip connections).

    Args:
        in_shape: spatial shape of the input volumes (three sizes).
        in_channels: number of input channels.
        num_classes: number of output channels. With 1 the network ends with a
            sigmoid (binary mask); with more classes it ends with a softmax
            over the class channels.

    Returns:
        The Keras `Model` mapping the input volume to the per-voxel prediction.

    NOTE: the encoder, bottleneck and decoder are still TODOs; the decoder has
    to define `d3`, the tensor consumed by the output layer.
    """
    inputs = Input(shape=(in_shape[0], in_shape[1], in_shape[2], in_channels))

    ## -------------Encoder--------------
    # TODO

    ## -----------Bottleneck-------------
    # TODO

    ## -------------Decoder--------------
    # TODO

    # One output channel per class, so that `num_classes` is honoured instead
    # of being silently ignored (the binary case is unchanged)
    final_activation = 'sigmoid' if num_classes == 1 else 'softmax'
    outputs = Conv3D(num_classes, kernel_size=1, activation=final_activation)(d3)
    return Model(inputs=inputs, outputs=outputs)

model = UNet3D(config['input_shape'], config['num_channels'], config['num_classes'])

In [None]:
model.summary()

In [None]:
# Stop the training when the validation loss has not improved for 15 epochs in a row
early_stopping = EarlyStopping(monitor='val_loss', mode='min', verbose=1, patience=15)
callbacks = [early_stopping]

In [None]:
# Number of full batches per epoch for training and validation
steps, val_steps = (len(ids) // config['batch_size'] for ids in (train_ids, val_ids))

# Training will take a while! (In the solution: 15min for 4 mm³, 30 for 2 mm³)
model.compile(
    loss=dice_loss(),
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
    metrics=_metrics,
)
history = model.fit(
    training_generator,
    epochs=15,
    steps_per_epoch=steps,
    callbacks=callbacks,
    validation_data=valid_generator,
    validation_steps=val_steps,
)

In [None]:
model.save("skull_stripping_net.h5")

## Results

In [None]:
_history = model.history.history
epoch = range(len(_history['loss']))

In [None]:
def plot_metrics(ax, metric, val_metric, label, val_label):
    """Plot a training curve (blue) and its validation curve (red) on `ax`.

    The x axis is the epoch index derived from the curves themselves, so the
    function no longer depends on a global `epoch` range that other cells
    overwrite.

    Args:
        ax: axes-like object providing `plot` and `legend`.
        metric: training values, one per epoch.
        val_metric: validation values, one per epoch.
        label: legend entry of the training curve.
        val_label: legend entry of the validation curve.
    """
    ax.plot(range(len(metric)), metric, 'b', label=label)
    ax.plot(range(len(val_metric)), val_metric, 'r', label=val_label)
    ax.legend()

# History keys to plot and the matching human-readable names. The names carry
# no prefix: 'Training'/'Validation' is added per curve, so every legend is
# consistent (it used to read e.g. 'Validation Training Accuracy').
metrics = ['accuracy', 'loss', 'precision', 'dice', 'iou_coeff']
labels = ['Accuracy', 'Loss', 'Precision', 'Dice', 'IoU']

f, ax = plt.subplots(1, len(metrics), figsize=(16, 5))

for axis, metric_name, label in zip(ax, metrics, labels):
    plot_metrics(axis, _history[metric_name], _history[f'val_{metric_name}'], f'Training {label}', f'Validation {label}')

plt.show()

## Evaluation

In [None]:
print("Evaluate on test data")
results = model.evaluate(test_generator, callbacks=callbacks)
print("test evaluation metrics:", results)

In [None]:
from tabulate import tabulate

headers = ['Metric', 'Value']

# Row names, listed in the order of the values held by `results`
# (the loss first, then the metrics passed to `compile`)
metric_names = ['Loss', 'Accuracy', 'Precision', 'Sensitivity', 'Specificity', 'Dice', 'IoU']
data = [[name, results[position]] for position, name in enumerate(metric_names)]

# Print the table
table = tabulate(data, headers, tablefmt='pretty')
print(table)

In [None]:
def showPredictsById(generator, model, image_id, alpha=0.4, axis=1):
    """Show image, ground truth, prediction and their overlap for one batch.

    The middle slice along `axis` of the first sample of the batch is shown.

    Args:
        generator: indexable batch source; `generator[image_id]` must return
            `(images, masks)` arrays indexed as `[sample, x, y, z, channel]`.
        model: object whose `predict(images)` returns values with the same layout.
        image_id: index of the batch to fetch from `generator`.
        alpha: opacity of the prediction drawn on top of the image.
        axis: spatial axis (0, 1 or 2) the slice is taken along.
    """
    # Get data for the specified image_id from the generator
    processed_image, ground_truth = generator[image_id]

    # Make predictions using the model
    prediction = model.predict(processed_image)

    # First sample of the batch, first channel
    image_volume = processed_image[0, :, :, :, 0]
    truth_volume = ground_truth[0, :, :, :, 0]
    prediction_volume = prediction[0, :, :, :, 0]

    # Middle slice of the chosen axis. A single global index computed for the
    # last axis was used before, which is off-centre on the other axes.
    middle = image_volume.shape[axis] // 2

    # Slices taken along the last axis are rotated for display
    rotation = -1 if axis == 2 else 0

    def central_slice(volume):
        """Return the displayed 2D slice of a 3D volume."""
        return np.rot90(np.take(volume, middle, axis=axis), k=rotation)

    # Display the images
    plt.figure(figsize=(12, 4))

    panels = (
        ("Original Image", image_volume),
        ("Ground Truth", truth_volume),
        ("Prediction", prediction_volume),
    )
    for position, (title, volume) in enumerate(panels, start=1):
        plt.subplot(1, 4, position)
        plt.imshow(central_slice(volume), cmap='gray')
        plt.title(title)

    # Last panel: the prediction drawn semi-transparently over the image
    plt.subplot(1, 4, 4)
    plt.imshow(central_slice(image_volume), cmap='gray')
    plt.imshow(central_slice(prediction_volume), cmap='jet', alpha=alpha)
    plt.title("Overlap")

    plt.show()


# A test generator with batch_size equal to 1 yields one pre-processed patient per index
test_generator = DataGenerator(list_IDs=test_ids, batch_size=1)

# Show every test patient, each time slicing along a randomly chosen axis
n_test_batches = len(test_generator)
for index in range(n_test_batches):
    showPredictsById(test_generator, model, index, axis=np.random.randint(0, 3))

In [None]:
import plotly.graph_objects as go
def plot_3d_overlap(generator, model, image_id, alpha=0.1, density=10000, threshold=200):
    """Plot image, ground truth and predicted mask of one batch as 3D point clouds.

    Only the first sample of the batch is shown.

    Args:
        generator: indexable batch source; `generator.__getitem__(image_id)`
            must return `(images, masks)` arrays indexed as
            `[sample, x, y, z, channel]`.
        model: object whose `predict(images)` returns values with the same
            layout; predictions above 0.5 count as brain.
        image_id: index of the batch to fetch from `generator`.
        alpha: marker opacity of the ground-truth and prediction clouds.
        density: maximum number of image voxels drawn. It is capped at the
            number of voxels above `threshold`, so asking for more points than
            exist no longer fails.
        threshold: image voxels must exceed this value to be drawn. It is
            compared against the values returned by `generator`, so for
            normalized images it must be given on the normalized scale.

    Raises:
        ValueError: if no image voxel exceeds `threshold`.
    """
    processed_image, ground_truth = generator.__getitem__(image_id)

    # Brain and Gt
    gt_volume = ground_truth[0, :, :, :, 0]
    brain_volume = processed_image[0, :, :, :, 0]

    # Prediction
    prediction = model.predict(processed_image)
    prediction_volume = prediction[0, :, :, :, 0]
    prediction_volume = np.where(prediction_volume > 0.5, 1, 0)

    # Fix rotation and swap axis (just for visual purposes)
    brain_volume = np.swapaxes(brain_volume, 1, 2)
    gt_volume = np.swapaxes(gt_volume, 1, 2)
    prediction_volume = np.swapaxes(prediction_volume, 1, 2)

    # Invert axis (just for visual purposes)
    brain_volume = brain_volume[::-1, :, ::-1]
    gt_volume = gt_volume[::-1, :, ::-1]
    prediction_volume = prediction_volume[::-1, :, ::-1]

    # Get Indices
    gt_indices = np.argwhere(gt_volume > 0)
    brain_indices = np.argwhere(brain_volume > threshold)  # remove noise
    prediction_indices = np.argwhere(prediction_volume > 0)

    if len(brain_indices) == 0:
        raise ValueError(f"Nothing to plot: no image voxel is above threshold={threshold}")

    # Randomly sample points based on the specified density. Sampling without
    # replacement cannot return more points than exist, hence the cap.
    n_points = min(density, len(brain_indices))
    brain_indices = brain_indices[np.random.choice(len(brain_indices), n_points, replace=False)]

    # Get matrix values for color mapping
    brain_colors = brain_volume[brain_indices[:, 0], brain_indices[:, 1], brain_indices[:, 2]]

    # Create 3D scatter plot for the image voxels
    brain_scatter = go.Scatter3d(
        x=brain_indices[:, 0],
        y=brain_indices[:, 1],
        z=brain_indices[:, 2],
        mode='markers',
        marker=dict(
            size=1,
            opacity=0.2,
            color=brain_colors,
            colorscale='Viridis',
            cmin=np.min(brain_colors),
            cmax=np.max(brain_colors),
        ),
        name='Brain Mask Point Cloud'
    )

    # Create 3D scatter plot for the gt
    gt_scatter = go.Scatter3d(
        x=gt_indices[:, 0],
        y=gt_indices[:, 1],
        z=gt_indices[:, 2],
        mode='markers',
        marker=dict(
            size=2,
            color='yellow',
            opacity=alpha
        ),
        name='Ground truth'
    )

    # Create 3D scatter plot for the prediction
    prediction_scatter = go.Scatter3d(
        x=prediction_indices[:, 0],
        y=prediction_indices[:, 1],
        z=prediction_indices[:, 2],
        mode='markers',
        marker=dict(
            size=1,
            color='red',
            opacity=alpha
        ),
        name='Prediction'
    )

    # Create figure
    fig = go.Figure(data=[brain_scatter, gt_scatter, prediction_scatter])

    # Set layout properties
    fig.update_layout(
        scene=dict(
            xaxis_title='X',
            yaxis_title='Y',
            zaxis_title='Z',
        ),
        title='3D Plot - Brain Mask Point Cloud',
        annotations=[
        dict(
            text='You can switch on and off the masks by clicking',
            x=1.05,
            y=1,
            align='right',
            font=dict(family='Arial', size=12, color='black'),
        ),
    ],
    )

    # Show figure
    fig.show()

# np.random.randint excludes its upper bound: passing len(test_generator) makes
# every batch index eligible (the last one was never drawn before) and also
# works when the generator holds a single batch.
random_index = np.random.randint(0, len(test_generator))
plot_3d_overlap(test_generator, model, random_index, density=30000, threshold=0.2)

# Model Variant | 3D U-Net with Residual Skip connections

For the residual block, you can substitute every CDC block with something like this:

![](https://i.imgur.com/6n0QNCI.png)

where a shortcut connection (identity mapping) is processed in parallel with a 1x1x1 3d convolutional block that adjusts shapes and is then added ([Add layer](https://keras.io/api/layers/merging_layers/add/)) before applying the final ReLU.

In [None]:
# TODO: Implement a modified variant of UNet3D which leverages residual connections in each block.

# This is the residual block shown in the diagram
def residual_block(res, filters, kernel_size=3, padding='same', activation='relu', kernel_initializer='he_normal'):
    """Placeholder for the residual block that replaces each CDC block.

    NOTE: `x` is not defined until the TODO below is implemented, so calling
    the placeholder as-is raises NameError.

    NOTE(review): `res` is presumably the block input that also feeds the
    shortcut branch - confirm against the diagram when implementing.
    """
    # TODO
    return x

# This is the encoder block (residual block in place of the CDC block + MaxPooling3D)
# that returns the output and a skip connection
def res_encoder_block(input, num_filter):
    """Placeholder for the residual encoder block.

    Intended to return the pair `(out, cdc)`: the block output and the tensor
    kept for the skip connection.

    NOTE: `out` and `cdc` are not defined until the TODO below is implemented,
    so calling the placeholder as-is raises NameError.
    """
    # TODO
    return out, cdc

# This is the decoder block (UpConv3D + Skip connection concatenation + residual block in place of the CDC block)
# Simply concatenate the skip connections with the UpConv3D layer output
def res_decoder_block(input, skip, num_filter):
    """Placeholder for the residual decoder block.

    NOTE: `cdc` is not defined until the TODO below is implemented, so calling
    the placeholder as-is raises NameError.
    """
    # TODO
    return cdc

In [None]:
# Assemble the encoder, decoder and all the skip connections in a ResUNet3D
def ResUNet3D(in_shape, in_channels, num_classes):
    """Assemble the 3D U-Net variant built from residual blocks.

    Args:
        in_shape: spatial shape of the input volumes (three sizes).
        in_channels: number of input channels.
        num_classes: number of output channels. With 1 the network ends with a
            sigmoid (binary mask); with more classes it ends with a softmax
            over the class channels.

    Returns:
        The Keras `Model` mapping the input volume to the per-voxel prediction.

    NOTE: the encoder, bottleneck and decoder are still TODOs; the decoder has
    to define `d3`, the tensor consumed by the output layer.
    """
    inputs = Input(shape=(in_shape[0], in_shape[1], in_shape[2], in_channels))

    ## -------------Encoder--------------
    # TODO

    ## -----------Bottleneck-------------
    # TODO

    ## -------------Decoder--------------
    # TODO

    # One output channel per class, so that `num_classes` is honoured instead
    # of being silently ignored (the binary case is unchanged)
    final_activation = 'sigmoid' if num_classes == 1 else 'softmax'
    outputs = Conv3D(num_classes, kernel_size=1, activation=final_activation)(d3)
    return Model(inputs=inputs, outputs=outputs)

res_model = ResUNet3D(config['input_shape'], config['num_channels'], config['num_classes'])

In [None]:
res_model.summary()

In [None]:
# Same loss, optimizer settings, metrics and schedule as the plain U-Net
res_model.compile(
    loss=dice_loss(),
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
    metrics=_metrics,
)
history = res_model.fit(
    training_generator,
    epochs=15,
    steps_per_epoch=steps,
    callbacks=callbacks,
    validation_data=valid_generator,
    validation_steps=val_steps,
)

In [None]:
res_model.save("residual_skull_stripping_net.h5")

## Results

In [None]:
_res_history = res_model.history.history
epoch = range(len(_res_history['loss']))

In [None]:
f, ax = plt.subplots(1, 5, figsize=(16, 5))

for i in range(5):
    # Strip any 'Training ' prefix from the shared `labels` entry, then add the
    # right prefix per curve; this avoids legends such as
    # 'Validation Training Accuracy'.
    base_label = labels[i].replace('Training ', '')
    plot_metrics(ax[i], _res_history[metrics[i]], _res_history[f'val_{metrics[i]}'], f'Training {base_label}', f'Validation {base_label}')

plt.show()

## Evaluation

In [None]:
print("Evaluate on test data")
results = res_model.evaluate(test_generator, callbacks=callbacks)
print("test evaluation metrics:", results)

In [None]:
from tabulate import tabulate

headers = ['Metric', 'Value']

# Row names, listed in the order of the values held by `results`
# (the loss first, then the metrics passed to `compile`)
metric_names = ['Loss', 'Accuracy', 'Precision', 'Sensitivity', 'Specificity', 'Dice', 'IoU']
data = [[name, results[position]] for position, name in enumerate(metric_names)]

# Print the table
table = tabulate(data, headers, tablefmt='pretty')
print(table)

In [None]:
# A test generator with batch_size equal to 1 yields one pre-processed patient per index
test_generator = DataGenerator(list_IDs=test_ids, batch_size=1)

# Show every test patient, each time slicing along a randomly chosen axis
n_test_batches = len(test_generator)
for index in range(n_test_batches):
    showPredictsById(test_generator, res_model, index, axis=np.random.randint(0, 3))

In [None]:
# np.random.randint excludes its upper bound: passing len(test_generator) makes
# every batch index eligible (the last one was never drawn before) and also
# works when the generator holds a single batch.
random_index = np.random.randint(0, len(test_generator))
plot_3d_overlap(test_generator, res_model, random_index, density=30000, threshold=0.2)