# Multi-task Rodent TBI Segmentation
## Skull-stripping and ROI Segmentation
Author: Marcello De Salvo <br>

# Table of contents <a class='anchor' id='top'></a>
1. [Importing Libraries](#lib)
2. [Problem Definition](#problem)
3. [Data Description](#data)
4. [Configuration](#conf)
5. [Data Visualization](#visual)
6. [Evaluation metrics](#metrics)  
7. [Data Loader](#load)
8. [Model](#model)
9. [Results](#results)
10. [Evaluation](#eval)

# 1. Libraries <a class='anchor' id='lib'></a> [↑](#top)

In [None]:
import numpy as np
import os
import importlib
import pandas as pd
import time

# neural imaging
import nilearn as nl
import nibabel as nib
import nilearn.plotting as nlplt

# tensorflow
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.callbacks import CSVLogger, ModelCheckpoint, ReduceLROnPlateau, EarlyStopping, TensorBoard
from tensorflow.keras.models import *
from tensorflow.keras.layers import *
from tensorflow.keras.optimizers import *
from tensorflow.keras.layers.experimental import preprocessing
from tensorflow.keras.utils import plot_model

# sklearn
from sklearn.model_selection import train_test_split

# Compact, fixed-precision numpy output keeps the sanity-check prints readable.
np.set_printoptions(suppress=True, precision=3)

# Report the visible accelerators and the TensorFlow build in use.
gpu_devices = tf.config.list_physical_devices('GPU')
print("Num GPUs Available: ", len(gpu_devices))
print(tf.__version__)

# Seed numpy's global RNG (sklearn's train_test_split falls back to it when no
# random_state is given, and np.random.choice draws from it) and TensorFlow's RNG,
# in that order, so runs are repeatable.
seed = 5
for _seed_fn in (np.random.seed, tf.random.set_seed):
    _seed_fn(seed)

# 2. Problem definition <a class='anchor' id='problem'></a> [↑](#top)
Problem: Skull-stripping and ROI semantic segmentation<br>
Each pixel in the image has to be assigned one of the following labels: <br>
- Background (label 0)
- Lesion (red, label 1)
- Ventricle contra (blue, label 3)
- Ventricle ipsi (light green, label 13)
- Third Ventricle (purple, label 21)

# 3. Image data descriptions <a class='anchor' id='data'></a> [↑](#top)

All multimodal scans are available as NIfTI files (.nii.gz), a format commonly used in medical imaging to store brain MRI data. Two MRI contrasts (acquisition settings) are available:
1. **T1w (Flash)**
2. **T2w (Rare)**

Data were acquired directly at this institute.
All the imaging datasets have been segmented manually.

# 4. Configuration <a class='anchor' id='conf'></a> [↑](#top)

In [None]:
# PRE PROCESSING
import time

# Number of training epochs. It is also embedded in the run/model name so the
# results folder reflects the actual training length.
num_epochs = 300

config = {
    # os.path.join avoids the invalid "\d" escape of a '..\dataset_roi' literal and
    # yields a valid relative path on both Windows and POSIX systems.
    'dataset_path': os.path.join('..', 'dataset_roi'),
    'input_shape': (80, 80, 80),           # patch size used by Padder/RandomCropping and the predictor
    'target_resolution': (0.1, 0.1, 0.1),  # voxel size passed to Resample
    'labels': [0, 1, 3, 13, 21],  # 0: background, 1: lesion, 3: contra-ventricle, 13: ipsi-ventricle, 21: third ventricle
    # Raw label -> training class. Contra (3) and ipsi (13) ventricles are merged:
    # 0: background, 1: lesion, 2: ventricles (contra + ipsi), 3: third ventricle -> 4 classes
    'mapping': {0: 0, 1: 1, 3: 2, 13: 2, 21: 3},
    'num_classes': 4,
    'in_channels': 1,
    'batch_size': 8,
    'epochs': num_epochs,
    'lr': 1e-3,
    'model_name': f"mice_roi_unet_ep{num_epochs}" + time.strftime("_%d-%m-%Y_%H-%M"),
    'validation_split': 0.2,
    'test_split': 0.1,
    'mice_sampling_rate': 3,
    'rats_sampling_rate': 0,
}

# 5. Data Visualization <a class='anchor' id='visual'></a> [↑](#top)

In [None]:
from utils.visualization import *
from utils.loader import load_data

# Modalities to load for the example subject: 'N4' image, brain mask and ROI label map.
modalities = ['N4', 'brain_mask', 'Labels']

# Example subject. The scan-type folder is spelled exactly as in the
# rodent_dataset.add_dataset(...) calls of the Data Loader section ('T2w-C52-RARE');
# a lowercase 't2w-...' only resolves on case-insensitive file systems.
scan_type = 'T2w-C52-RARE'
scan_id = 'TBI_fm_19_50'
scan_folder = os.path.join(config['dataset_path'], scan_type, scan_id, 'Anat')
print(scan_folder)

# Load data
img, data, file_paths = load_data(scan_folder, scan_id, modalities)

# Plot data
plot_data(data)

In [None]:
# Check the unique values in the brain mask (a mask should be binary) and in the
# ROI label map (to compare against config['labels']).
print(f"Unique values in the brain mask: {np.unique(data['brain_mask'])}")
print(f"Unique values in the label map: {np.unique(data['Labels'])}")

In [None]:
# Load the N4 image and the label map as nilearn images (in that order).
niimg, nimask = (nl.image.load_img(file_paths[key]) for key in ('N4', 'Labels'))
fig, (ax_img, ax_roi) = plt.subplots(nrows=2, figsize=(30, 40))

# Top panel: the N4 image alone. Bottom panel: the label map drawn over it.
nlplt.plot_img(niimg, title='N4', axes=ax_img)
nlplt.plot_roi(nimask, bg_img=niimg, title='Labels', cmap='jet', axes=ax_roi)
plt.show()

In [None]:
# Report the array shape of the loaded N4 volume.
n4_shape = data['N4'].shape
print(f"Matrix shape: {n4_shape}")

# 6. Evaluation metrics <a class='anchor' id='metrics'></a> [↑](#top)

In [None]:
from evaluation.metrics import *
from evaluation.losses import *

# Aggregate metrics reported for the ROI ('regions') output.
metrics = [
    accuracy_coefficient(),
    precision_coefficient(),
    sensitivity_coefficient(),
    specificity_coefficient(),
    dice_coefficient(),
    iou_coefficient(),
    volume_similarity_coefficient(),
]

# One Dice metric per class (background included) so every ROI class is tracked separately.
metrics += [
    dice_coefficient(class_index=class_idx, exclude_background=False)
    for class_idx in range(config['num_classes'])
]

# 7. Data Loader <a class='anchor' id='load'></a> [↑](#top)
Loading all data into memory at once is not feasible because the dataset is too large to fit in memory.<br>
Instead, a data generator class loads and preprocesses the data on the fly, as explained [here](https://stanford.edu/~shervine/blog/keras-how-to-generate-data-on-the-fly)

In [None]:
from utils.loader import *

# Main mice dataset, filtered to the labels listed in config['labels'].
rodent_dataset = RodentDatasets(labels=config['labels'])

# (scan-type folder, sub-folder holding the images) for every mice dataset.
dataset_specs = [
    ('T1w-C52-FLASH', 'Anat'),
    ('T1w-CD1-FLASH', 'Anat'),
    ('T2w-C52-RARE', 'Anat'),
    ('T2w-Caen\\3 weeks', ''),
]
for dataset_name, dataset_sub_folder in dataset_specs:
    rodent_dataset.add_dataset(config['dataset_path'], dataset_name, sub_folder=dataset_sub_folder)

# All registered subjects
train_and_test_ids = rodent_dataset.get_subjects_list()
print('Size of the dataset: ', len(train_and_test_ids))

# First carve out the validation share; the test share is then taken from the
# remainder, so the effective test fraction is
# test_split * (1 - validation_split) of the whole dataset (0.1 * 0.8 = 8%).
train_test_ids, val_ids = train_test_split(train_and_test_ids, test_size=config['validation_split'])
train_ids, test_ids = train_test_split(train_test_ids, test_size=config['test_split'])

In [None]:
# Load the rats (several time points per animal) as a separate dataset.
rats_dataset = RodentDatasets(labels=config['labels'])
rats_dataset.add_dataset(config['dataset_path'], 'T2w-RATS', sub_folder='Anat')

# Extract ids
rats_ids = rats_dataset.get_subjects_list()

# A rat's identity is the first `id_char_length` characters of the second tuple
# element (e.g. RAT1_5w -> RAT1), so all time points of one animal stay together.
id_char_length = 4
# sorted() rather than list(set(...)): the iteration order of a set of strings
# depends on PYTHONHASHSEED, which made the seeded np.random.choice below pick a
# different validation rat on every interpreter run.
rats_unique_ids = sorted({x[1][:id_char_length] for x in rats_ids})

# Print
print('---- ')
print('Unique rats ids: ', rats_unique_ids)

# Hold out one rat (all of its time points) for validation and train on the rest.
if rats_unique_ids:
    val_rat = np.random.choice(rats_unique_ids)
else:
    # np.random.choice raises ValueError on an empty list; with no rats there is nothing to split.
    val_rat = None
train_rats = [x for x in rats_unique_ids if x != val_rat]

# Extract the time points of the selected rats
val_rat_id = [x for x in rats_ids if x[1][:id_char_length] == val_rat]
train_rat_ids = [x for x in rats_ids if x[1][:id_char_length] in train_rats]

# Print
print('---- ')
print('Validation rat id: ', val_rat, ' - time points: ', val_rat_id)
print('Training rat ids: ', train_rats, ' - time points: ', train_rat_ids)

# Oversample mice ids by mice_sampling_rate and append the rat time points
# repeated rats_sampling_rate times (0 disables the rats entirely).
# NOTE(review): test_ids is also repeated mice_sampling_rate times, so every test
# subject appears several times in model.evaluate and FullVolumeEvaluation — confirm intended.
train_ids = train_ids * config['mice_sampling_rate'] + train_rat_ids * config['rats_sampling_rate']
val_ids = val_ids * config['mice_sampling_rate'] + val_rat_id * config['rats_sampling_rate']
test_ids = test_ids * config['mice_sampling_rate']

# Print
print('---- ')
print('Training ids: ', train_ids)
print('Validation ids: ', val_ids)
print('Test ids: ', test_ids)

In [None]:
from preprocessing.preprocessor import Preprocessor, Resample, Reorient, Normalize, CorrectX10, MapLabels, RandomCropping, RandomAffine, GaussianBlur, Noise, Flip, Padder

# Reference image whose orientation Reorient aligns every scan to.
ref_img = nib.load(os.path.join('../example', 'RARE', 'TBI_fm_19_49', 'Anat', 'TBI_fm_19_49_N4.nii.gz'))


def _head_steps():
    """Return fresh MapLabels / CorrectX10 / Reorient steps shared by both pipelines."""
    return [
        MapLabels(config['labels'], mapping=config['mapping']),
        CorrectX10(),
        Reorient(ref_img),
    ]


def _tail_steps():
    """Return fresh Resample / Normalize / Padder / RandomCropping / Flip steps shared by both pipelines."""
    return [
        Resample(target_resolution=config['target_resolution'], interpolation=0),
        Normalize(),
        Padder(config['input_shape'], 'constant'),
        RandomCropping(config['input_shape'], mode='center', std=None),
        Flip(axis_list=[0], probability=0.5),
    ]


# Training pipeline: shared head, random affine/blur/noise augmentation, shared tail.
augmented = Preprocessor(
    _head_steps()
    + [
        RandomAffine(rotation_range=[-5, 5], scale_range=[0.95, 1.05], probability=0.3),
        GaussianBlur([0, 0.6], probability=0.3),
        Noise([0, 0.05], probability=0.3),
    ]
    + _tail_steps()
)

# Validation/test pipeline: the same steps without the augmentation block.
# NOTE(review): the tail still contains Flip(probability=0.5), so validation and
# test batches are randomly flipped — confirm this is intended.
preprocessor = Preprocessor(_head_steps() + _tail_steps())

In [None]:
from preprocessing.generator import MultiTaskGenerator

# Generator settings shared by every split; only the ids and the pipeline differ.
generator_kwargs = {'batch_size': config['batch_size'], 'config': config}
training_generator = MultiTaskGenerator(train_ids, rodent_dataset, preprocessor=augmented, **generator_kwargs)
valid_generator = MultiTaskGenerator(val_ids, rodent_dataset, preprocessor=preprocessor, **generator_kwargs)
test_generator = MultiTaskGenerator(test_ids, rodent_dataset, preprocessor=preprocessor, **generator_kwargs)

for split_label, split_ids in (("Val IDs: ", val_ids), ("Train IDs: ", train_ids), ("Test IDs: ", test_ids)):
    print(split_label, split_ids)

## Sanity Check

In [None]:
# Fetch the first training batch: X = image, Y = ROI target, M = brain mask.
X, [Y, M] = training_generator.__getitem__(index=0)

print('Img shape: ', X.shape)  # Should be equal to (BATCH_SIZE, IMG_SIZE, IMG_SIZE, NUM_SLICES, IN_CHANNELS)
print('Labels shape: ', Y.shape)  # Should be equal to (BATCH_SIZE, IMG_SIZE, IMG_SIZE, NUM_SLICES, NUM_CLASSES)
print('Brain mask shape:', M.shape)  # Should be equal to (BATCH_SIZE, IMG_SIZE, IMG_SIZE, NUM_SLICES, 1)

# Intensity range of the preprocessed images
print("Max value in X: ", np.max(X))
print("Min value in X: ", np.min(X))

# Y is expected to be one-hot encoded over config['num_classes'] channels
print("Unique values in Y: ", np.unique(Y))
print("Max value in Y: ", np.max(Y))
print("Min value in Y: ", np.min(Y))

# Plot preview. Clamp the sample index so a batch with fewer than 8 items does not raise IndexError.
sample = min(7, X.shape[0] - 1)
layer = config['input_shape'][2] // 2

# Copy first: Y[sample] is a view, and editing it in place would silently alter Y
# for the following cells.
yhat = Y[sample].copy()
yhat[yhat == -1] = 0  # Convert all -1 to 0
yhat = np.argmax(yhat, axis=-1)

print("Unique labels in argmax(Y): ", np.unique(yhat))
plt.figure(figsize=(15, 5))
plt.imshow(np.rot90(X[sample, :, :, layer, 0], k=-1), cmap='gray')
plt.imshow(np.rot90(yhat[:, :, layer], k=-1), cmap='jet', alpha=0.6)
plt.imshow(np.rot90(M[sample, :, :, layer, 0], k=-1), cmap='gray', alpha=0.2)
plt.title("Processed")
plt.show()

In [None]:
# Browse the 3D volumes side by side: image channel vs. target channel 0
# (the background class under config['mapping']).
# NOTE(review): to view all ROI classes, consider passing an argmax label map instead of channel 0.
preview_image = np.rot90(X[sample, :, :, :, 0], k=-1)
preview_target = np.rot90(Y[sample, :, :, :, 0], k=-1)
explore_3D_array_comparison(preview_image, preview_target, axis=-1)

## Show Data Split Distribution

In [None]:
# Number of (oversampled) subject ids in each split.
split_sizes = {"Train": len(train_ids), "Validation": len(val_ids), "Test": len(test_ids)}
plt.bar(list(split_sizes.keys()), list(split_sizes.values()), align='center', color=['green', 'red', 'blue'])

plt.ylabel('Number of images')
plt.title('Data distribution')
plt.show()

# 8. Model | 3D U-Net <a class='anchor' id='model'></a> [↑](#top)

In [None]:
from models.networks import *

# Filter counts per network level: 16, 32, 64, 128, 256.
filters = [16 * 2**level for level in range(5)]
# Spatial input dims are left as None, presumably so the same model accepts
# training patches and larger volumes at inference time.
model = mt_r_net_3d((None, None, None), config['in_channels'], config['num_classes'], filters, attention=True, residual=True)

In [None]:
# To load a pre-trained model for fine-tuning
# The code below is intentionally disabled by wrapping it in a string literal.
# To use it, remove the surrounding quotes and run it instead of the model-building cell.
# NOTE: it references `losses` and `loss_weights`, which are only defined in the
# Training cell — define them before enabling this block.
'''
model = load_model('../results/mice_rnet_pretrained/mice_rnet_pretrained.h5', 
                    custom_objects={'mt_r_net_3d': mt_r_net_3d, 
                                    'diceCELoss': diceCELoss, 'diceBCELoss': diceBCELoss, 'loss':losses,
                                    'mean_accuracy': accuracy_coefficient, 'mean_precision': precision_coefficient, 'mean_sensitivity': sensitivity_coefficient, 
                                    'mean_specificity': specificity_coefficient, 'mean_dice': dice_coefficient, 'mean_iou': iou_coefficient, 
                                    'mean_volume_similarity': volume_similarity_coefficient, 'class_0_dice': dice_coefficient(class_index=0, exclude_background=False),
                                    'class_1_dice': dice_coefficient(class_index=1, exclude_background=False), 'class_2_dice': dice_coefficient(class_index=2, exclude_background=False),
                                    'class_3_dice': dice_coefficient(class_index=3, exclude_background=False)})

model.summary()
    
# compile the model
model.compile(optimizer=Nadam(learning_rate=config['lr']), loss=losses, loss_weights=loss_weights, metrics=metrics)
'''

## Overview

In [None]:
# Print a layer-by-layer summary of the model. No input shape is passed here;
# the model was built with spatial dims (None, None, None).
model.summary()

## Callbacks

In [None]:
import datetime

# All artefacts of this run live under ../results/<model_name>/.
run_dir = '../results/' + config['model_name']
# CSVLogger opens its log file directly when training starts, so the folder must
# exist beforehand; create it (no-op if already present).
os.makedirs(run_dir, exist_ok=True)

csv_logger = CSVLogger(run_dir + '/training.log', separator=',', append=False)
log_dir = run_dir + "/logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
checkpoint_path = run_dir + "/checkpoint/"

callbacks = [
    # Halve the learning rate after 10 epochs without val_loss improvement.
    tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=10, min_lr=1e-7, verbose=1),
    tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1),
    # Stop after 20 epochs without val_loss improvement (restore_best_weights=True).
    tf.keras.callbacks.EarlyStopping(monitor='val_loss', mode='min', verbose=1, patience=20, restore_best_weights=True),
    # Keep only the best model (lowest val_loss) at checkpoint_path.
    tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path, monitor='val_loss', save_best_only=True, mode='min'),
    csv_logger,
]

## Training

In [None]:
# Batches per epoch; at least 1 so a split smaller than one batch does not yield 0 steps.
steps = max(1, len(train_ids) // config['batch_size'])
val_steps = max(1, len(val_ids) // config['batch_size'])

# Dice + CE loss for the ROI ('regions') output, Dice + BCE loss for the brain-mask output.
roi_loss = diceCELoss(smooth=1e-5, batch_wise=True, gamma=0.8)
skullstrip_loss = diceBCELoss(alpha=0.6, smooth=1e-5, batch_wise=True)

# Keys must match the model's output names.
losses = {
    'regions': roi_loss,
    'brain_mask': skullstrip_loss,
}
loss_weights = {
    'regions': 1,
    'brain_mask': 1,
}

tasks_metrics = {
    'regions': metrics,
    'brain_mask': dice_coefficient(class_index=0, exclude_background=False),
}

# The learning rate comes from config so it stays in sync with the saved run info.
model.compile(loss=losses, optimizer=tf.keras.optimizers.Nadam(learning_rate=config['lr']), metrics=tasks_metrics, loss_weights=loss_weights)
history = model.fit(training_generator, epochs=config['epochs'], steps_per_epoch=steps, callbacks=callbacks, validation_data=valid_generator, validation_steps=val_steps)

In [None]:
# Load the best model
model.load_weights(checkpoint_path)

# Save the model
model.save("../results/"+config['model_name']+"/save_" + config['model_name'] + ".h5")

In [None]:
# Reload the per-epoch log written by the CSVLogger callback.
training_log_path = f"../results/{config['model_name']}/training.log"
model_history = pd.read_csv(training_log_path)

# 9. Results <a class='anchor' id='results'></a> [↑](#top)
## History

In [None]:
import utils.visualization
importlib.reload(utils.visualization)  # pick up edits to the module without restarting the kernel
from utils.visualization import plot_history, plot_loss

# Both figures are written into the run folder.
history_dir = f"../results/{config['model_name']}/"
plot_loss(model_history, path=history_dir + 'loss_history.png', log=False)
plot_history(model_history, path=history_dir + 'metrics_history.png', figsize=(50, 20))

## Predictions

In [None]:
from utils.visualization import plot_multitask

# To plot the pre-processed data we can use our custom test generator with batch_size equal to 1
test_plot_generator = MultiTaskGenerator(ids=test_ids, loader=rodent_dataset, batch_size=1, preprocessor=preprocessor, config=config, shuffle=True)
# Plot at most 15 samples. plot_multitask receives the id list alongside the
# generator, so pass the list the generator was built from (test_ids, not val_ids).
for index in range(min(15, len(test_plot_generator))):
    plot_multitask(test_ids, index, test_plot_generator, model)

In [None]:
import numpy as np
from evaluation.inference import RandomCroppingPrediction
from evaluation.postprocessing import ipsi_contra_division_callback

# Inspect one fixed test subject (index 1 of test_ids). Named subject_id so the
# builtin id() is not shadowed for the rest of the notebook.
subject_id = test_ids[1]
print("ID: ", subject_id)

# Inference pipeline: no padding/cropping/augmentation; Flip(probability=1) always mirrors axis 0.
# NOTE(review): MapLabels gets no `mapping` here, unlike the training pipelines —
# presumably so labels keep contra/ipsi ventricles separate for
# ipsi_contra_division_callback; confirm.
# NOTE(review): Resample runs before Reorient here, the reverse of the training
# pipelines — confirm the order does not affect the result.
inferenceProcessor = Preprocessor([
    CorrectX10(),
    MapLabels(labels=config['labels']),
    Resample(target_resolution=config['target_resolution'], interpolation=0),
    Reorient(ref_img),
    Normalize(),
    Flip(axis_list=[0], probability=1),
])

# Get mouse object from loader
mouse = rodent_dataset.get_subject(subject_id)

# Extract images and masks
image, roi_ground_truth, brain_mask_ground_truth = mouse.get_images()

# Preprocess the image
prep_image = inferenceProcessor.preprocess(image)

# Predict regions and brain mask with patch-based inference (stride 32 between patches)
predictor = RandomCroppingPrediction(model, patch_size=config['input_shape'], stride=32, threshold=0.5, num_classes=config['num_classes'])
results = predictor.random_cropping_inference(prep_image, with_brain_mask=True)
Y_pred = results['roi']
y_mask = results['brain_mask']

# Post processing
Y_pred = ipsi_contra_division_callback(visualize_pca=True, use_centroids=False)(Y_pred)

# Explore the final roi mask
explore_3D_overlay(arr_before=np.rot90(prep_image.get_fdata(), k=-1), mask=np.rot90(Y_pred, k=-1), axis=-1)


# 10. Evaluation <a class='anchor' id='eval'></a> [↑](#top)

In [None]:
print("Evaluate on test data")
# batch_size is not passed: Keras ignores it for generator/Sequence inputs, whose
# batches are fixed by the generator. The training callbacks (early stopping,
# checkpointing, LR scheduling, CSV/TensorBoard logging) do not belong in an
# evaluation run either.
results = model.evaluate(test_generator)
# `results` is a list aligned with model.metrics_names (kept as a list for save_metrics).
for metric_name, metric_value in zip(model.metrics_names, results):
    print(f"test {metric_name}: {metric_value}")

In [None]:
from tabulate import tabulate
from utils.utils import save_metrics, save_model_info

# Print test
print(test_ids)

# Persist metrics, model/run info and both preprocessing configurations in the run folder.
print("Model name: ", config['model_name'])
run_folder = f'../results/{config["model_name"]}'
save_metrics(results, model, path=f'{run_folder}/metrics.txt')
save_model_info(model, config, filters, test_ids, path=f'{run_folder}/model_info.txt')
for pipeline, config_file in ((augmented, 'augmentation_config.txt'), (preprocessor, 'preprocessor_config.txt')):
    pipeline.save_configuration(path=f'{run_folder}/{config_file}')

# 11. Full volume evaluation <a class='anchor' id='fulleval'></a> [↑](#top)

In [None]:
from evaluation.inference import FullVolumeEvaluation
from evaluation.postprocessing import ipsi_contra_division_callback, morphology_refinement_callback

# Full-volume inference pipeline: no padding, cropping, flipping or augmentation steps.
inferenceProcessor = Preprocessor([
    CorrectX10(),
    MapLabels(labels=config['labels']),
    Resample(target_resolution=config['target_resolution'], interpolation=0),
    Reorient(ref_img),
    Normalize(),
])

# ROIS: evaluate region predictions for every test id, post-processed by
# ipsi_contra_division_callback, writing outputs into the run folder.
results_folder = '../results/' + config['model_name'] + '/'
full_volume_eval = FullVolumeEvaluation(
    model,
    test_ids,
    config,
    rodent_dataset,
    roi_postprocess_callback=ipsi_contra_division_callback(),
    strides=[20, 40],
    preprocessor=inferenceProcessor,
    path=results_folder,
    verbose=True,
)
resu = full_volume_eval.evaluate()

In [None]:
# BRAIN MASK: same full-volume evaluation, scoring the brain-mask output after
# morphology_refinement_callback post-processing.
full_volume_eval = FullVolumeEvaluation(
    model,
    test_ids,
    config,
    rodent_dataset,
    strides=[20, 40],
    preprocessor=inferenceProcessor,
    brain_mask_postprocess_callback=morphology_refinement_callback(),
    path=f"../results/{config['model_name']}/",
    verbose=True,
)
resu = full_volume_eval.evaluate(evaluate_brain_mask=True)