## Tools

In [7]:
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from keras.models import load_model
from ipywidgets import interact, widgets
from keras import backend as K

## Data Paths

In [11]:
# Directory prefixes for the images and masks of each dataset.
# File names are appended by plain string concatenation, so every prefix
# keeps its trailing slash.
corona_img_path, corona_mask_path = (
    './data/coronacases/images/',
    './data/coronacases/masks/',
)
mosmed_img_path, mosmed_mask_path = (
    './data/mosmed/images/',
    './data/mosmed/masks/',
)

In [12]:
# (image .nii path, mask .nii path) for each selectable patient.
def _nii_pair(img_dir, mask_dir, filename):
    """Build the (image path, mask path) tuple for a file stored under the same name in both directories."""
    return (img_dir + filename, mask_dir + filename)


datasets = {
    # Coronacases patients slices
    'coronacases patient 1': _nii_pair(corona_img_path, corona_mask_path, 'coronacases_001.nii'),
    'coronacases patient 2': _nii_pair(corona_img_path, corona_mask_path, 'coronacases_002.nii'),
    'coronacases patient 3': _nii_pair(corona_img_path, corona_mask_path, 'coronacases_003.nii'),
    'coronacases patient 4': _nii_pair(corona_img_path, corona_mask_path, 'coronacases_004.nii'),
    'coronacases patient 5': _nii_pair(corona_img_path, corona_mask_path, 'coronacases_005.nii'),

    # Mosmed patients slices
    'mosmed patient 1': _nii_pair(mosmed_img_path, mosmed_mask_path, 'study_0255.nii'),
    'mosmed patient 2': _nii_pair(mosmed_img_path, mosmed_mask_path, 'study_0256.nii'),
    'mosmed patient 3': _nii_pair(mosmed_img_path, mosmed_mask_path, 'study_0257.nii'),
    'mosmed patient 4': _nii_pair(mosmed_img_path, mosmed_mask_path, 'study_0258.nii'),
    'mosmed patient 5': _nii_pair(mosmed_img_path, mosmed_mask_path, 'study_0259.nii'),
}

In [13]:
# Saved-model file for every selectable model, keyed by its display name.
# The first word of each key names the dataset group the model belongs to.
_coronacases_models = {
    'Coronacases Unet': './model/coronacases/mo_coronacases.keras',
    'Coronacases Att U-Net': './model/coronacases/mo_att_coronacases.keras',
    'Coronacases R2U-Net': './model/coronacases/mo_r2_coronacases.keras',
    'Coronacases Att R2U-Net': './model/coronacases/mo_att_r2_coronacases.keras',
}
_mosmed_models = {
    'Mosmed Unet': './model/mosmed/mo_unet.keras',
    'Mosmed Att U-Net': './model/mosmed/mo_att.keras',
    'Mosmed R2U-Net': './model/mosmed/mo_r2.keras',
    'Mosmed Att R2U-Net': './model/mosmed/mo_att_r2.keras',
}

# Coronacases entries first, then Mosmed (insertion order is preserved).
models = {**_coronacases_models, **_mosmed_models}

## Load Data

In [15]:
# Load a pair of .nii volumes and return the annotated slices as two lists
def load_nii(imgpath, maskpath):
    """Read an image volume and its mask volume and keep only the annotated slices.

    Both files are loaded with nibabel and cut along their third axis. A slice
    is kept only when the values of its mask sum to more than zero.

    Args:
        imgpath: Path of the image .nii file.
        maskpath: Path of the mask .nii file.

    Returns:
        Tuple ``(img_list, msk_list)``: two lists of equal length with the kept
        image slices and their masks, in slice order.
    """
    volume = nib.load(imgpath).get_fdata()
    labels = nib.load(maskpath).get_fdata()

    # The number of slices is taken from the image volume.
    kept = [k for k in range(volume.shape[2]) if np.sum(labels[:, :, k]) > 0]

    img_list = [volume[:, :, k] for k in kept]
    msk_list = [labels[:, :, k] for k in kept]
    return img_list, msk_list

## Preprocessing

In [16]:
# Resize an image/mask pair to 128x128x1, the dimensions the models were trained with
def resize_image(image, mask):
    """Add a channel axis to a slice and its mask and resize both to 128x128.

    Both inputs go through ``smart_resize`` with bilinear interpolation. The
    interpolated mask is then thresholded at 0.5 and cast to ``int``.

    Args:
        image: 2-D image slice.
        mask: 2-D mask belonging to ``image``.

    Returns:
        Tuple ``(image, mask)`` of the resized image and the re-binarized mask.
    """
    resized = []
    for plane in (image, mask):
        with_channel = np.expand_dims(plane, axis=-1)
        resized.append(
            tf.keras.preprocessing.image.smart_resize(
                with_channel, (128, 128), interpolation='bilinear'
            )
        )

    image, mask = resized
    # Interpolation produces fractional values; turn them back into 0/1 labels.
    mask = (mask > 0.5).astype(int)
    return image, mask

In [17]:
# Normalize function for mosmed images
def normalize(image, min_value=-1000, max_value=200):
    """Clip intensities to ``[min_value, max_value]`` and rescale them to ``[0, 1]``.

    Args:
        image: Array of intensities.
        min_value: Lower clipping bound; maps to 0.
        max_value: Upper clipping bound; maps to 1.

    Returns:
        Array of the same shape as ``image`` with values between 0 and 1.
    """
    span = max_value - min_value
    return (np.clip(image, min_value, max_value) - min_value) / span

## Predict

In [18]:
# Show an image next to its true-mask and predicted-mask overlays, titled with the dice score
def show_image_masks(image, true_mask, pred_mask, dice_score):
    """Draw a three-panel figure: the image, the true-mask overlay and the predicted-mask overlay.

    Args:
        image: Image shown in grayscale on every panel.
        true_mask: Reference mask; its non-zero pixels are tinted on panel 2.
        pred_mask: Predicted mask; its non-zero pixels are tinted on panel 3.
        dice_score: Score printed as a percentage in the figure title.
    """
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))

    # Panel 1: the image alone
    axes[0].imshow(image, cmap='gray')
    axes[0].set_title('Original Image')

    # Panels 2 and 3: the image with a half-transparent mask on top
    overlays = (
        (axes[1], 'True Mask Overlay', true_mask),
        (axes[2], 'Predicted Mask Overlay', pred_mask),
    )
    for ax, title, mask in overlays:
        # Zero-valued pixels are masked out so only the mask region is tinted.
        tinted = np.ma.masked_where(mask == 0, mask)
        ax.imshow(image, cmap='gray')
        ax.imshow(tinted, cmap='hsv', alpha=0.5)
        ax.set_title(title)

    for ax in axes:
        ax.axis('off')

    # DSC
    fig.suptitle(f'Dice Coefficient: {dice_score:.2%}', fontsize=16)
    plt.show()

In [19]:
# Turn a model output into a binary mask using a threshold (default 0.3375)
def binary_prob(prob_mask, umbral=0.3375):
    """Threshold channel 1 of the first batch item of ``prob_mask``.

    Args:
        prob_mask: Model output indexed as ``[batch, row, column, channel]``.
        umbral: Threshold; values strictly greater than it become 1.

    Returns:
        Integer array of zeros and ones with a trailing axis of length 1.
    """
    channel = prob_mask[0, :, :, 1]
    return (channel > umbral).astype(int)[..., np.newaxis]

# Predict the binary mask of one image with the given model
def predict_image(model, img, umbral=0.3375):
    """Run ``model`` on a single image and binarize its output.

    Args:
        model: Object exposing ``predict(batch, verbose=...)``.
        img: One image; a leading batch axis is added before predicting.
        umbral: Threshold forwarded to ``binary_prob``.

    Returns:
        The binary mask produced by ``binary_prob``.
    """
    batch = np.expand_dims(img, axis=0)
    return binary_prob(model.predict(batch, verbose=0), umbral)

# Predict the mask of an image and show it next to the true mask with the given dice score
def show_predict(model, img, msk, dice, umbral=0.3375):
    """Predict the mask of ``img`` and display it beside the true mask.

    Inference is delegated to ``predict_image`` so that the batching and
    thresholding steps live in one place.

    Args:
        model: Model used for the prediction.
        img: Image to segment (without batch axis).
        msk: True mask of ``img``.
        dice: Dice score shown in the figure title. It is displayed as given,
            not recomputed here.
        umbral: Threshold used to binarize the model output.
    """
    msk_pred = predict_image(model, img, umbral)
    show_image_masks(img, msk, msk_pred, dice)

In [20]:
# Calculate the dice coefficient between two masks (true mask and predicted mask)
def dice_coefficient(image1, image2):
    """Return the Dice coefficient ``2 * |A & B| / (|A| + |B|)`` of two masks.

    Every non-zero element counts as foreground. Both masks are binarized
    before counting, so masks stored with labels other than 1 (for example
    255) are measured in pixels rather than by their intensity, consistently
    with the logical intersection.

    Args:
        image1: First mask (array-like).
        image2: Second mask, broadcast-compatible with ``image1``.

    Returns:
        Dice coefficient in ``[0, 1]``; ``0.0`` when both masks are empty.
    """
    foreground1 = np.asarray(image1).astype(bool)
    foreground2 = np.asarray(image2).astype(bool)

    total = np.count_nonzero(foreground1) + np.count_nonzero(foreground2)
    if total == 0:
        # Two empty masks: keep the existing convention of scoring 0.
        return 0.0

    overlap = np.count_nonzero(np.logical_and(foreground1, foreground2))
    return (2.0 * overlap) / total

In [21]:
# Number of usable slices in a dataset, used to bound the slice slider
def get_img_list_length(dataset_name):
    """Return how many slices ``load_nii`` yields for the dataset ``dataset_name``.

    Args:
        dataset_name: Key of the ``datasets`` dictionary.

    Returns:
        Length of the image list returned by ``load_nii``.
    """
    image_file, mask_file = datasets[dataset_name]
    return len(load_nii(image_file, mask_file)[0])

## Result

In [23]:
import keras.backend as K

# Interactive plot to show the prediction of a given model with a given dataset and slice
# The plot shows the original image, the mask overlay and the predicted mask overlay with the dice score
# The plot is interactive with the model, dataset and slice

# Most recently loaded model, kept so that moving the slice slider does not
# re-read the same model file from disk on every update.
_loaded_model = {'name': None, 'model': None}


def _get_model(model_name):
    """Return the model registered under ``model_name``, reusing the last one loaded."""
    if _loaded_model['name'] != model_name:
        _loaded_model['model'] = load_model(models[model_name], compile=False)
        _loaded_model['name'] = model_name
    return _loaded_model['model']


def interactive_plot(model_name, dataset_name, slice_idx):
    """Predict one slice of a dataset with one model and plot the result.

    Args:
        model_name: Key of the ``models`` dictionary.
        dataset_name: Key of the ``datasets`` dictionary.
        slice_idx: Index into the list of slices returned by ``load_nii``.

    If ``slice_idx`` is outside the available slices a message is printed and
    nothing is plotted.
    """
    model = _get_model(model_name)
    # The first word of the model name selects the preprocessing: slices fed
    # to a 'Mosmed ...' model are normalized first.
    trained_on = model_name.split(' ')[0].lower()

    img_path, mask_path = datasets[dataset_name]
    img_list, msk_list = load_nii(img_path, mask_path)

    n_slices = len(img_list)
    if not -n_slices <= slice_idx < n_slices:
        # The slider range may not match this dataset; report instead of
        # failing with an IndexError traceback.
        print(f'Slice {slice_idx} is not available: "{dataset_name}" has {n_slices} slice(s) to show.')
        return

    # Only the requested slice is preprocessed (normalization is element-wise).
    image = img_list[slice_idx]
    if trained_on == 'mosmed':
        image = normalize(image)

    img, msk = resize_image(image, msk_list[slice_idx])

    # Run inference once and reuse the mask for both the score and the figure.
    msk_pred = predict_image(model, img)
    dice = dice_coefficient(msk, msk_pred)
    show_image_masks(img, msk, msk_pred, dice)

# Update the slice slider range when the dataset is changed
def update_slice_idx_range(dataset_name):
    """Set the upper bound of ``slice_idx_slider`` to the last slice index of a dataset.

    Args:
        dataset_name: Key of the ``datasets`` dictionary.
    """
    # A dataset without usable slices has length 0; keep the bound at 0 so it
    # never drops below the slider's minimum.
    slice_idx_slider.max = max(get_img_list_length(dataset_name) - 1, 0)

# Controls of the interactive plot: model selector, dataset selector and slice slider
model_name_dropdown = widgets.Dropdown(
    description='Model:',
    options=models.keys(),
)
dataset_name_dropdown = widgets.Dropdown(
    description='Dataset:',
    options=datasets.keys(),
)
slice_idx_slider = widgets.IntSlider(
    description='Slice:',
    min=0,
    max=50,
    step=1,
)

# Observer callback: resize the slice slider for the newly selected dataset
def on_dataset_change(change):
    """Forward the new value of a change notification to ``update_slice_idx_range``.

    Args:
        change: Change notification; its ``'new'`` entry is the selected dataset name.
    """
    selected_dataset = change['new']
    update_slice_idx_range(selected_dataset)

# Observe the dataset change to update the slice slider range
dataset_name_dropdown.observe(on_dataset_change, names='value')

# The observer only fires on later changes, so size the slider once for the
# dataset that is selected at start-up.
update_slice_idx_range(dataset_name_dropdown.value)

# Show the interactive plot (the trailing ';' hides the function repr that
# interact returns)
interact(
    interactive_plot,
    model_name=model_name_dropdown,
    dataset_name=dataset_name_dropdown,
    slice_idx=slice_idx_slider
);

interactive(children=(Dropdown(description='Model:', options=('Coronacases Unet', 'Coronacases Att U-Net', 'Co…

<function __main__.interactive_plot(model_name, dataset_name, slice_idx)>