In [2]:
import os

import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from keras.models import load_model
from ipywidgets import interact, widgets
from keras import backend as K

## Data Paths

In [3]:
# Root folder that holds images2/, masks_covid/ and modelos/. Set the TESIS_DATA_PATH environment
# variable to run the notebook on another machine; the author's local folder is the fallback.
data_path = os.environ.get(
    'TESIS_DATA_PATH',
    'C:/Users/crist/OneDrive/Desktop/universidad/octavo/tesis/data_objetivo/',
)
# Paths below are built by plain string concatenation, so the root must end with a separator.
if not data_path.endswith(('/', '\\')):
    data_path += '/'

In [21]:
# Coronacases patients 1-10: display name -> (CT volume path, mask path from masks_covid/).
# File names use the zero-padded patient number (coronacases_001.nii ... coronacases_010.nii).
datasets = {
    f'coronacases patient {patient}': (
        f'{data_path}images2/coronacases_{patient:03d}.nii',
        f'{data_path}masks_covid/coronacases_{patient:03d}.nii',
    )
    for patient in range(1, 11)
}

In [22]:
# Display name -> saved Keras model file. data_path already ends with a separator, so the
# relative parts must not start with '/' (that would put '//' in every model path).
models = {
    'Coronacases Unet': data_path + 'modelos/prueba/mo_coronacases.keras',
    'Coronacases Att U-Net': data_path + 'modelos/prueba/mo_att_coronacases.keras',
    'Coronacases R2U-Net': data_path + 'modelos/prueba/mo_r2_coronacases.keras',
    'Coronacases Att R2U-Net': data_path + 'modelos/prueba/mo_att_r2_coronacases.keras'
}

## Load Data

In [13]:
def _mosmed_mask_path(file_path):
    """Return the companion mask path of a MosMed volume: '_mask' inserted before the extension.

    Only the file name is touched and the suffix goes before its FIRST dot, so
    'dir.v2/study_0255.nii.gz' -> 'dir.v2/study_0255_mask.nii.gz'. Dots in folder names and
    multi-part extensions are left intact.

    Raises:
        ValueError: if the file name has no extension (there is nothing to insert before).
    """
    filename = os.path.split(file_path)[1]
    stem, dot, extension = filename.partition('.')
    if not dot:
        raise ValueError(f'Cannot derive a mask path from {file_path!r}: file name has no extension')
    # Slice the folder prefix off the original string so its separators are kept exactly.
    folder_prefix = file_path[:len(file_path) - len(filename)]
    return f'{folder_prefix}{stem}_mask.{extension}'


def load_nii_mosmed(file_path):
    """Load a MosMed NIfTI volume and its mask, keeping only slices with annotated foreground.

    The mask is read from the sibling file named by ``_mosmed_mask_path`` (same name with
    '_mask' before the extension).

    Args:
        file_path: path to the image volume.

    Returns:
        (img_list, msk_list): parallel lists of 2-D slices taken along the third axis, restricted
        to slices whose mask sums to a positive value.
    """
    img_e = nib.load(file_path).get_fdata()
    ms_e = nib.load(_mosmed_mask_path(file_path)).get_fdata()

    img_list = []
    msk_list = []
    for i in range(img_e.shape[2]):
        ms = ms_e[:, :, i]
        # Skip slices without any annotated pixels.
        if np.sum(ms) > 0:
            img_list.append(img_e[:, :, i])
            msk_list.append(ms)

    return img_list, msk_list

def load_nii_coronacases(filepath, maskpath):
    """Load a Coronacases CT volume and its mask, keeping only slices that contain foreground.

    Args:
        filepath: path to the image volume (NIfTI).
        maskpath: path to the matching mask volume (NIfTI).

    Returns:
        (img_list, msk_list): parallel lists of 2-D slices taken along the third axis, restricted
        to the slices whose mask sums to a positive value.
    """
    volume = nib.load(filepath).get_fdata()
    labels = nib.load(maskpath).get_fdata()

    # Indices of slices with any annotated pixels; the image's slice count drives the scan.
    kept = [k for k in range(volume.shape[2]) if np.sum(labels[:, :, k]) > 0]

    return [volume[:, :, k] for k in kept], [labels[:, :, k] for k in kept]

## Preprocessing

In [7]:
def resize_image(image, mask, size=(128, 128)):
    """Resize an image slice and its mask to the network's input size.

    A trailing channel axis is added to both arrays before they are resized bilinearly with
    Keras ``smart_resize``. The interpolated mask is then re-binarised at 0.5, so it stays a
    0/1 integer mask.

    Args:
        image: 2-D image slice.
        mask: 2-D mask slice matching ``image``.
        size: (height, width) target; defaults to the 128x128 used by the models here.

    Returns:
        (image, mask), each shaped (size[0], size[1], 1).
    """
    image = np.expand_dims(image, axis=-1)
    mask = np.expand_dims(mask, axis=-1)
    image = tf.keras.preprocessing.image.smart_resize(image, size, interpolation='bilinear')
    mask = tf.keras.preprocessing.image.smart_resize(mask, size, interpolation='bilinear')
    # Bilinear resizing blurs mask edges into fractions; snap them back to {0, 1}.
    mask = (mask > 0.5).astype(int)
    return image, mask

## Predict

In [8]:
def show_image_masks(image, true_mask, pred_mask, dice_score):
    """Plot a slice three ways: bare, with the ground-truth overlay, and with the predicted overlay.

    Zero-valued mask pixels are masked out, so only the foreground is tinted (hsv colormap at
    50% opacity) over the grayscale slice. The Dice score is shown as a percentage in the
    figure title.
    """
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))

    def draw(ax, title, *overlays):
        # Grayscale slice first, then any overlay with its background hidden.
        ax.imshow(image, cmap='gray')
        for layer in overlays:
            ax.imshow(np.ma.masked_where(layer == 0, layer), cmap='hsv', alpha=0.5)
        ax.set_title(title)
        ax.axis('off')

    draw(axes[0], 'Original Image')
    draw(axes[1], 'True Mask Overlay', true_mask)
    draw(axes[2], 'Predicted Mask Overlay', pred_mask)

    fig.suptitle(f'Dice Coefficient: {dice_score:.2%}', fontsize=16)
    plt.show()

In [9]:
def binary_prob(prob_mask, umbral=0.3375):
    """Threshold a model output into a 0/1 integer mask.

    Reads channel 1 of the first batch element (``prob_mask[0, :, :, 1]``; presumably the
    foreground-class probability, to be confirmed against the model's output layer). Pixels
    strictly above ``umbral`` (Spanish for "threshold") become 1. The result has a trailing
    channel axis, i.e. shape (H, W, 1).
    """
    foreground = prob_mask[0, :, :, 1]
    binary = (foreground > umbral).astype(int)
    return binary[:, :, np.newaxis]

def predict_image(model, img, umbral=0.3375):
    """Run ``model`` on a single image and return its thresholded mask.

    The image is wrapped in a batch of one and passed to ``model.predict`` with output
    silenced (verbose=0). The raw output is turned into a 0/1 mask by ``binary_prob`` using
    threshold ``umbral``.
    """
    batch = np.expand_dims(img, 0)
    probabilities = model.predict(batch, verbose=0)
    return binary_prob(probabilities, umbral)

def show_predict(model, img, msk, dice, umbral=0.3375):
    """Predict a mask for ``img`` and plot it next to the ground-truth mask ``msk``.

    Delegates the batch/predict/threshold steps to ``predict_image``, so both entry points
    share one prediction path and threshold default.

    Args:
        model: Keras model used for prediction.
        img: single (unbatched) image.
        msk: ground-truth mask for ``img``.
        dice: Dice score shown in the figure title (computed by the caller).
        umbral: probability threshold forwarded to the binarisation.
    """
    msk_pred = predict_image(model, img, umbral)
    show_image_masks(img, msk, msk_pred, dice)

In [10]:
def dice_coefficient(image1, image2, empty_value=0.0):
    """Sørensen–Dice overlap of two masks: 2·|A ∩ B| / (|A| + |B|).

    Every non-zero element counts as foreground in both the intersection and the set sizes.
    Label masks (e.g. values 1/2) therefore stay in [0, 1], exactly like 0/1 or boolean masks.

    Args:
        image1, image2: masks of broadcast-compatible shape.
        empty_value: score returned when neither mask has any foreground (0.0 by default;
            some evaluations prefer 1.0 for "both correctly empty").

    Returns:
        The Dice coefficient as a float.
    """
    intersection = np.count_nonzero(np.logical_and(image1, image2))
    total = np.count_nonzero(image1) + np.count_nonzero(image2)
    if total == 0:
        return empty_value
    return 2.0 * intersection / total

In [15]:
def get_img_list_length(dataset_name):
    """Return how many mask-bearing slices the dataset registered as ``dataset_name`` has.

    The first word of the name (lower-cased) selects the loader: 'coronacases' entries in
    ``datasets`` are (image, mask) path pairs, and 'mosmed' entries are a single image path.

    Raises:
        ValueError: if the name's prefix matches neither known source.
    """
    source = dataset_name.split(' ')[0].lower()
    if source == 'coronacases':
        img_path, msk_path = datasets[dataset_name]
        img_list, _ = load_nii_coronacases(img_path, msk_path)
    elif source == 'mosmed':
        img_list, _ = load_nii_mosmed(datasets[dataset_name])
    else:
        # Fail loudly instead of an UnboundLocalError on img_list below.
        raise ValueError(f'Unknown dataset source {source!r} in dataset name {dataset_name!r}')
    return len(img_list)

## Result

In [23]:
import keras.backend as K
# Models already loaded by interactive_plot, keyed by display name. Re-run this cell to pick up
# a model file that was re-saved on disk.
_model_cache = {}


def _get_model(model_name):
    """Load the Keras model registered under ``model_name`` once and reuse it afterwards."""
    if model_name not in _model_cache:
        _model_cache[model_name] = load_model(models[model_name], compile=False)
    return _model_cache[model_name]


def interactive_plot(model_name, dataset_name, slice_idx):
    """Widget callback: predict one slice and plot it against its ground truth with the Dice score.

    Args:
        model_name: key of ``models``.
        dataset_name: key of ``datasets``; its first word selects the loader.
        slice_idx: index into the dataset's mask-bearing slices.

    Raises:
        ValueError: if the dataset name's prefix matches no known loader.
    """
    model = _get_model(model_name)

    # The DATASET name decides how its files are laid out (the model name says nothing about that).
    source = dataset_name.split(' ')[0].lower()
    if source == 'coronacases':
        img_path, msk_path = datasets[dataset_name]
        img_list, msk_list = load_nii_coronacases(img_path, msk_path)
    elif source == 'mosmed':
        img_list, msk_list = load_nii_mosmed(datasets[dataset_name])
    else:
        raise ValueError(f'Unknown dataset source {source!r} in dataset name {dataset_name!r}')

    img, msk = resize_image(img_list[slice_idx], msk_list[slice_idx])
    # Predict once and reuse the mask for both the score and the figure
    # (show_predict would run model.predict a second time).
    img_pred = predict_image(model, img)
    dice = dice_coefficient(msk, img_pred)
    show_image_masks(img, msk, img_pred, dice)

def update_slice_idx_range(dataset_name):
    """Fit the slice slider's upper bound to the mask-bearing slices of ``dataset_name``."""
    n_slices = get_img_list_length(dataset_name)
    # ipywidgets' IntSlider does not accept max < min (min is 0 here), so a dataset with no
    # annotated slices collapses the range to [0, 0] instead of setting max = -1.
    slice_idx_slider.max = max(n_slices - 1, 0)

dataset_name_dropdown = widgets.Dropdown(options=list(datasets.keys()), description='Dataset:')
# max=50 is only a placeholder; it is replaced right below with the real slice count.
slice_idx_slider = widgets.IntSlider(min=0, max=50, step=1, description='Slice:')
model_name_dropdown = widgets.Dropdown(options=list(models.keys()), description='Model:')


def on_dataset_change(change):
    """Observer: re-fit the slice slider whenever another dataset is selected."""
    update_slice_idx_range(change['new'])


dataset_name_dropdown.observe(on_dataset_change, names='value')

# The observer only fires on *changes*, so size the slider for the initially selected dataset
# as well; otherwise the placeholder max can exceed its slice list and raise IndexError.
update_slice_idx_range(dataset_name_dropdown.value)

interact(
    interactive_plot,
    model_name=model_name_dropdown,
    dataset_name=dataset_name_dropdown,
    slice_idx=slice_idx_slider
);

interactive(children=(Dropdown(description='Model:', options=('Coronacases Unet', 'Coronacases Att U-Net', 'Co…

<function __main__.interactive_plot(model_name, dataset_name, slice_idx)>