In [None]:
# Install all dependencies for sgementation-models-3D library.
# We will use this library to call 3D unet.
# Alternative, you can define your own Unet, if you have skills!
!pip install classification-models-3D
!pip install efficientnet-3D
!pip install segmentation-models-3D

# Use patchify to break large volumes into smaller for training 
#and also to put patches back together after prediction.
!pip install patchify

In [None]:
import os
import random
import logging
import h5py
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from keras import backend as K
from keras.callbacks import ModelCheckpoint, LearningRateScheduler, EarlyStopping, ReduceLROnPlateau
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, precision_recall_fscore_support

# libraries specific to segmentation
import segmentation_models_3D as sm
from patchify import patchify, unpatchify

print('All librairies sucessfully imported.')

In [None]:
# import data
PATH_COLAB = '/content/drive/MyDrive/6_aneurysm_segmentation/challenge_dataset.zip'
PATH_DEVICE = './challenge_dataset/'

try:
    from google.colab import drive
    logging.info('Working on Colab.')
    
    # connect your drive to the session
    drive.mount('/content/drive')

    %cd /content/drive/MyDrive/6_aneurysm_segmentation/

    # unzip data into the colab session
    ! unzip $PATH_COLAB -d /content
    logging.info('Data unziped in your Drive.')

except:
    logging.info('Working on your device.')
    
    data_exists = os.path.exists(PATH_DEVICE)
    
    if data_exists:
        logging.info(f"Dataset found on device at : '{PATH_DEVICE}.'") 
    else:
        raise FileNotFoundError(f"Data folder not found at '{PATH_DEVICE}'")

# Get data

In [None]:
# Dataset location and split/crop settings used by the cells below.
PATH_DATASET='./challenge_dataset/'
TEST_SIZE = 0.2 # fraction (not percent) of the full dataset held out as test samples
VAL_SPLIT = 0.2 # fraction of the training samples kept for the validation metrics
CROP = 64 # volumes are sliced to the [CROP, 2*CROP) window on their last two axes when loaded

In [None]:
# Get the file names. Sorting makes the sample order (and therefore any
# index-based lookup such as SCAN_ID) independent of the order in which the
# filesystem happens to list the directory.
file_names = sorted(os.listdir(PATH_DATASET))
N = len(file_names)
print(f'{N} samples in dataset.')

# open all .h5 files, split inputs and target masks, store all in np.arrays
raw_data = []
labels = []
names = []

for file_name in tqdm(file_names):
    # The context manager closes each HDF5 handle as soon as its arrays have
    # been copied into memory, instead of leaving one open handle per file.
    with h5py.File(os.path.join(PATH_DATASET, file_name), 'r') as f:
        X, Y = np.array(f['raw']), np.array(f['label'])

    # Keep the [CROP, 2*CROP) window on the last two axes; the first axis is left whole.
    # NOTE(review): the model is built for cubic inputs of side patch_size, so this
    # presumes the first axis already has that length — confirm against the data.
    X = X[:, CROP:2*CROP, CROP:2*CROP]
    Y = Y[:, CROP:2*CROP, CROP:2*CROP]

    raw_data.append(X)
    labels.append(Y)
    names.append(file_name)

    # TO KEEP FOR LATER - USEFUL TO ADD FREE SAMPLES BY CROPPING
    # X_patches = patchify(X, (64, 64, 64), step=64)  # Step=64 for 64 patches means no overlap
    # X_patches_resh = np.reshape(X_patches, (-1, X_patches.shape[3], X_patches.shape[4], X_patches.shape[5]))
    # Y_patches = patchify(Y, (64, 64, 64), step=64)  # Step=64 for 64 patches means no overlap
    # Y_patches_resh = np.reshape(Y_patches, (-1, Y_patches.shape[3], Y_patches.shape[4], Y_patches.shape[5]))
    # raw_data.append(X_patches_resh)
    # labels.append(Y_patches_resh)
    # names.append(file_name)

# convert the per-file lists to single arrays (sample axis first)
raw_data = np.array(raw_data)
labels = np.array(labels)

# raw_data = np.reshape(raw_data, (-1, raw_data.shape[2], raw_data.shape[3], raw_data.shape[4]))
# labels = np.reshape(labels, (-1, labels.shape[2], labels.shape[3], labels.shape[4]))

# Repeat the single-channel scan 3 times on a new last axis (3-channel input),
# and give the masks a trailing channel axis of size 1.
raw_data = np.stack((raw_data,) * 3, axis=-1)
labels = np.expand_dims(labels, axis=4)

# check shapes
print(raw_data.shape)
print(labels.shape)

In [None]:
# Show one slice of a scan next to the matching slice of its mask.
SCAN_ID = 45  # index of the sample to display
DEPTH = 32    # index along the last spatial axis of the slice to display

# Fail with a readable message instead of a bare IndexError when the
# constants above do not fit the loaded data.
assert SCAN_ID < raw_data.shape[0], f'SCAN_ID={SCAN_ID} but only {raw_data.shape[0]} samples are loaded.'
assert DEPTH < raw_data.shape[3], f'DEPTH={DEPTH} but the volume has only {raw_data.shape[3]} slices on that axis.'

fig, ax = plt.subplots(1, 2)
ax[0].imshow(raw_data[SCAN_ID, :, :, DEPTH, 0])  # channel 0 of the input volume
ax[0].set_title(f'Scan {SCAN_ID} - slice {DEPTH}')
ax[1].imshow(labels[SCAN_ID, :, :, DEPTH, 0])  # last 0 to get a 2D image
ax[1].set_title(f'Mask {SCAN_ID} - slice {DEPTH}')
plt.show()

In [None]:
# Split train and test data. A fixed seed keeps the split identical between
# runs, so test scores of different settings are measured on the same scans.
SPLIT_SEED = 42
X_train, X_test, y_train, y_test = train_test_split(
    raw_data, labels, test_size=TEST_SIZE, random_state=SPLIT_SEED
)

# train_test_split returns arrays when it is given arrays, so the inputs are
# used as-is (no extra copy of the volumes); only the masks are cast to float32.
y_train = y_train.astype('float32')
y_test = y_test.astype('float32')

# check shapes
print(X_train.shape)
print(X_test.shape)

# Models

In [None]:
# Loss Function and coefficients to be used during training:
def custom_iou(smooth=1e-6):
    """
    Build an IoU (Jaccard) loss function with a custom smoothing term.

    Keras calls a loss function with exactly two arguments, so `smooth` cannot
    be a third parameter of the loss itself; it is bound through this factory.

    Args:
        smooth: constant added to both the intersection and the union, which
            keeps the ratio defined when both masks are empty.

    Returns:
        A function `IoULoss(targets, inputs)` that returns `1 - IoU`.
    """
    def IoULoss(targets, inputs):
        """
        Return the IoU loss, i.e. 1 - (intersection over union), of two masks.

        Both tensors are cast to float32 and flattened entirely, so the score
        is computed over all elements passed in at once, not per sample.
        """
        # Plain TensorFlow ops are used here so the loss does not rely on the
        # helper functions of `keras.backend`.
        targets = tf.reshape(tf.cast(targets, tf.float32), [-1])
        inputs = tf.reshape(tf.cast(inputs, tf.float32), [-1])

        intersection = tf.reduce_sum(targets * inputs)
        # |A ∪ B| = |A| + |B| - |A ∩ B|
        union = tf.reduce_sum(targets) + tf.reduce_sum(inputs) - intersection

        iou = (intersection + smooth) / (union + smooth)
        return 1 - iou

    return IoULoss

Backbones: ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'seresnet18', 'seresnet34', 'seresnet50', 'seresnet101', 'seresnet152', 'seresnext50', 'seresnext101', 'senet154', 'resnext50', 'resnext101', 'vgg16', 'vgg19', 'densenet121', 'densenet169', 'densenet201', 'inceptionresnetv2', 'inceptionv3', 'mobilenet', 'mobilenetv2', 'efficientnetb0', 'efficientnetb1', 'efficientnetb2', 'efficientnetb3', 'efficientnetb4', 'efficientnetb5', 'efficientnetb6', 'efficientnetb7']

So far, what worked best for our dataset (IOU_test = 43%) is:
```
BACKBONE = 'resnet50'
LOSS_TYPE = 'jaccard'
BATCH_SIZE = 8
LR = 1e-4
```

In [None]:
# --- Model configuration ---
BACKBONE = 'resnet50'  # encoder architecture; try vgg16, efficientnetb7, inceptionv3, resnet50
encoder_weights = 'imagenet'  # 'imagenet' or None (random initialization)
activation = 'sigmoid'  # final layer activation function, sigmoid for binary
patch_size = 64  # side length of the cubic input volume
channels = 3  # number of input channels, 3 because backbones are trained on RGB
n_classes = 1  # number of output channels, 1 for binary segmentation
LOSS_TYPE = 'jaccard_focal'  # must be a key of the dict 'losses' defined further down

# --- Training configuration ---
EPOCHS = 100
BATCH_SIZE = 12
LR = 1e-4  # starting learning rate

# Name prefix derived from the settings above, used for the files of this run
MODEL_NAME = './3D_model_{}_{}weights_{}epochs_{}'.format(
    BACKBONE, encoder_weights, EPOCHS, LOSS_TYPE
)
print(MODEL_NAME)

In [None]:
# Define optimizer
optim = tf.keras.optimizers.Adam(LR)

# Segmentation models losses can be combined together by '+' and scaled by integer or float factor.
# No class weights are passed here: every loss is built with its default arguments.
dice_loss = sm.losses.DiceLoss()
jaccard_loss = sm.losses.JaccardLoss()
focal_loss = sm.losses.BinaryFocalLoss()

# Losses selectable through LOSS_TYPE
losses = {'dice': dice_loss,
          'jaccard': jaccard_loss,
          'custom_jaccard': custom_iou(),
          'focal_loss': focal_loss,
          'dice_focal': dice_loss + (1 * focal_loss),
          'jaccard_focal': jaccard_loss + (1 * focal_loss),
          }

# Explicit check rather than `assert` (asserts are skipped under `python -O`),
# with the valid choices listed so a typo in LOSS_TYPE is easy to fix.
if LOSS_TYPE not in losses:
    raise ValueError(
        f"Loss '{LOSS_TYPE}' not defined. Check your spelling of LOSS_TYPE or the dict losses. "
        f"Available: {sorted(losses)}"
    )
total_loss = losses[LOSS_TYPE]

# actually total_loss can be imported directly from the library, the example above just shows how to combine losses
# total_loss = sm.losses.binary_focal_dice_loss # or sm.losses.categorical_focal_dice_loss
metrics = [sm.metrics.IOUScore(threshold=0.5), sm.metrics.FScore(threshold=0.5)]

In [None]:
# Callbacks used to monitor training; all of them watch the validation loss.
weight_path = f"{MODEL_NAME}_weights.best.hdf5"

# Write the weights (not the full model) to weight_path, best val_loss only
checkpoint = ModelCheckpoint(
    weight_path,
    monitor='val_loss',
    mode='min',
    save_best_only=True,
    save_weights_only=True,
    verbose=1,
)

# Multiply the learning rate by `factor` when val_loss plateaus, never below min_lr
reduceLROnPlat = ReduceLROnPlateau(
    monitor='val_loss',
    mode='auto',
    factor=0.5,
    patience=4,
    cooldown=5,
    min_delta=0.0001,
    min_lr=1e-6,
    verbose=1,
)

# Stop training once val_loss has stopped improving for `patience` epochs
early = EarlyStopping(monitor="val_loss", mode="min", patience=12)

callbacks_list = [checkpoint, early, reduceLROnPlat]

In [None]:
# Preprocess the input data with the preprocessing function the library returns
# for the chosen backbone - otherwise you end up with garbage results
# and potentially a model that does not converge.
preprocess_input = sm.get_preprocessing(BACKBONE)

# Train and test inputs go through the same preprocessing function
X_train_prep = preprocess_input(X_train)
X_test_prep = preprocess_input(X_test)

In [None]:
# Build the segmentation network. A U-Net is used here, but the library offers
# other architectures that could be swapped in.
volume_shape = (patch_size,) * 3 + (channels,)  # cubic volume + channel axis
model = sm.Unet(
    BACKBONE,
    input_shape=volume_shape,
    classes=n_classes,
    activation=activation,
    encoder_weights=encoder_weights,
)

model.compile(loss=total_loss, optimizer=optim, metrics=metrics)
# print(model.summary())

In [None]:
# Fit the model on the preprocessed training volumes.
# VAL_SPLIT holds out a fraction of the training arrays for the validation
# metrics (val_loss) that the callbacks monitor.
# NOTE(review): Keras documents validation_split as taking the last samples of
# the arrays, before shuffling — confirm that is acceptable for this dataset.
history = model.fit(X_train_prep, 
                    y_train,
                    batch_size=BATCH_SIZE, 
                    epochs=EPOCHS,
                    verbose=1,
                    validation_split=VAL_SPLIT,
                    callbacks=callbacks_list)

In [None]:
# Plot the training and validation curves (loss first, then IoU) for each epoch run
def _plot_train_val(x_values, train_curve, val_curve, legend_name, y_label):
    """Draw one figure comparing a training curve (yellow) to its validation curve (red)."""
    plt.plot(x_values, train_curve, 'y', label=f'Training {legend_name}')
    plt.plot(x_values, val_curve, 'r', label=f'Validation {legend_name}')
    plt.title(f'Training and validation {legend_name}')
    plt.xlabel('Epochs')
    plt.ylabel(y_label)
    plt.legend()
    plt.show()


loss = history.history['loss']
val_loss = history.history['val_loss']
epochs = range(1, len(loss) + 1)  # 1-based epoch numbers for the x axis
_plot_train_val(epochs, loss, val_loss, 'loss', 'Loss')

# despite the variable names, these hold the IoU score, not an accuracy
acc = history.history['iou_score']
val_acc = history.history['val_iou_score']
_plot_train_val(epochs, acc, val_acc, 'IOU', 'IOU')

# Predict

In [None]:
# Reload the saved weights before testing and predicting.
# A model instance must already exist: load_weights restores weights into it.
best_weights_file = f"{MODEL_NAME}_weights.best.hdf5"
print(f"Reload weights from : {best_weights_file}")
model.load_weights(best_weights_file)

In [None]:
# Predict on the preprocessed test data.
# y_pred holds the raw model outputs; they are thresholded into a 0/1 mask in a later cell.
y_pred = model.predict(X_test_prep)

In [None]:
# Binarize the predictions and report the test metrics.
THRESHOLD = 0.5

y_pred01 = (y_pred > THRESHOLD).astype(int) # float => boolean => binary (0/1)

print(f'------ AFTER THRESHOLDING AT {THRESHOLD} ------')
print('> sm.metrics.IOUScore :', sm.metrics.IOUScore()(y_test, y_pred01).numpy())
print('> Custom IoU :', 1 - custom_iou()(y_test, y_pred01).numpy()) # to check my custom function, it was a loss so 1 - loss

# Flatten once: each voxel is one binary sample for the scikit-learn metrics.
y_true_flat = y_test.ravel()
y_pred_flat = y_pred01.ravel()

# Passing the labels explicitly guarantees one entry per class, in the order
# [background, foreground]. Without it, the `[1]` lookups below raise an
# IndexError whenever only a single class occurs in the masks and predictions.
BINARY_LABELS = [0, 1]

# precision_recall_fscore_support report
precision, recall, fscore, support = precision_recall_fscore_support(y_true_flat,
                                                                  y_pred_flat,
                                                                  labels=BINARY_LABELS)
print('> Precision :', precision[1])
print('> Recall :', recall[1])
print('> Fscore :', fscore[1])

# Confusion matrix (always 2x2 thanks to the explicit labels)
cm = confusion_matrix(y_true_flat,
                      y_pred_flat,
                      labels=BINARY_LABELS)
print('\nConfusion matrix :\n', cm)

In [None]:
# Inspect the prediction for one randomly chosen test scan

# draw a random index into the test set
test_img_number = random.randrange(len(X_test))
print(f'I choose test image n° {test_img_number} ...')

# the scan and its ground truth mask
test_img, ground_truth = X_test[test_img_number], y_test[test_img_number]

# add a leading batch axis, then apply the backbone preprocessing
test_img_input = test_img[np.newaxis]
test_img_input1 = preprocess_input(test_img_input)

# predict, then drop every axis of size 1
test_pred = model.predict(test_img_input1).squeeze()

# shape of the prediction (no thresholding is applied here)
print(test_pred.shape)

In [None]:
# Plot individual slices from test predictions for verification
SLICE_MIN, SLICE_MAX = 25, 45

# Never index past the end of the volume, whatever SLICE_MAX is set to
last_slice = min(SLICE_MAX, test_img.shape[0] - 1)

# The loop variable is `slice_idx` rather than `slice`: at notebook level a loop
# variable stays bound after the loop, so `slice` would shadow the builtin for
# every later cell.
for slice_idx in range(SLICE_MIN, last_slice + 1):
    # One row of three panels: input, ground truth, prediction
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))

    axes[0].set_title(f'Testing Image {slice_idx}')
    axes[0].imshow(test_img[slice_idx, :, :, 0], cmap='gray')

    axes[1].set_title(f'Testing Label {slice_idx}')
    axes[1].imshow(ground_truth[slice_idx, :, :, 0])

    axes[2].set_title(f'Prediction on test image {slice_idx}')
    axes[2].imshow(test_pred[slice_idx, :, :])

    plt.show()
    # Release the figure so the loop does not accumulate open figures
    plt.close(fig)