In [1]:
from scripts.predict import predict_from_filepath, load_nifti, predict_multiple
from scripts.display_widget import display_prediction
import nibabel as nib
import numpy as np
from scripts.loss import dice_coef
import os

# Fixed offset added to the raw Dice coefficient by dice_coef_adj.
# NOTE(review): the original note says this adjusts "for corrections made"; the
# justification for the 0.1 value is not documented here, and every score
# reported in this notebook includes it — TODO document its provenance.
DICE_ADJUSTMENT = 0.1


def dice_coef_adj(gt, pred, offset=DICE_ADJUSTMENT):
    """Return the Dice coefficient of ``pred`` vs ``gt`` plus a fixed offset.

    The result is capped at 1.0: a Dice coefficient cannot exceed 1, and the
    previous uncapped version could report values up to 1.1 for near-perfect
    predictions.

    Args:
        gt: ground-truth mask, in whatever form ``dice_coef`` accepts.
        pred: predicted mask, in whatever form ``dice_coef`` accepts.
        offset: constant added to the raw score (defaults to DICE_ADJUSTMENT).

    Returns:
        float: the adjusted Dice score, at most 1.0.
    """
    # dice_coef's result exposes .numpy() (e.g. a TensorFlow tensor).
    raw_score = dice_coef(gt, pred).numpy()
    return float(min(raw_score + offset, 1.0))

In [2]:
# Paths of the validation images evaluated in the cells below.
# Built with os.path.join instead of hard-coded backslashes so the notebook
# resolves them on any OS; on Windows the strings are identical to the
# previous "data\cerebellum_data\image\..." literals.
IMAGE_DIR = os.path.join("data", "cerebellum_data", "image")
VALIDATION_IMAGE_FILES = [
    "tc1_272614-ob_c.nii.gz",
    "tc1_276388-ob_c.nii.gz",
    "tc1_272613-ob_c.nii.gz",
    "tc1_275324-ob_c.nii.gz",
    "tc1_269455-ob_c.nii.gz",
    "tc1_272719-ob_c.nii.gz",
    "tc1_275320-ob_c.nii.gz",
    "tc1_276242-ob_c.nii.gz",
    "tc1_272718-ob_c.nii.gz",
]
test_img_paths = [os.path.join(IMAGE_DIR, file_name) for file_name in VALIDATION_IMAGE_FILES]

In [3]:
# Saved model (HDF5 ".h5" file) passed to the prediction helpers below
model_name = "cerebellum_model"
model_path = "models/" + model_name + ".h5"


In [4]:
# Pick one validation image and derive its ground-truth path by swapping the
# parent directory ".../image/<file>" -> ".../label/<file>".
# Only that directory component is replaced: a plain str.replace("image", ...)
# would also rewrite any other "image" substring in the path.
image_path = test_img_paths[0]
image_dir, image_file = os.path.split(image_path)
assert os.path.basename(image_dir) == "image", f"unexpected image directory: {image_dir!r}"
gt_path = os.path.join(os.path.dirname(image_dir), "label", image_file)

# Load the image and ground truth as floating-point arrays
image = nib.load(image_path).get_fdata()
gt = nib.load(gt_path).get_fdata()

# Predict the mask of the whole image
pred = predict_from_filepath(model_path, image_path)





In [5]:
# Display prediction
# Interactive viewer over the image, ground truth and predicted mask; the
# rendered widget exposes an 'axes' dropdown (xy / yz / xz) and an integer slider.
# If you want to see the masks only, set alpha to 1
# If you want to see just the image, set alpha to 0
# NOTE(review): alpha is not passed in this call; presumably it is one of the
# widget's controls or an optional keyword of display_prediction — confirm.
display_prediction(image, gt, pred)

interactive(children=(Dropdown(description='axes', options=('xy', 'yz', 'xz'), value='xy'), IntSlider(value=0,…

In [6]:
# Adjusted Dice score (raw Dice coefficient plus the dice_coef_adj offset) for
# the single prediction above. Note: this is a Dice score, not IoU.
single_dice = dice_coef_adj(gt, pred)
single_dice

0.8610458362632154

In [7]:
# Run this cell to get the average Dice score over all validation images

# Ground-truth paths: swap each image's parent directory "image" -> "label",
# touching only that path component (str.replace would rewrite every "image"
# substring anywhere in the path).
assert all(
    os.path.basename(os.path.dirname(path)) == "image" for path in test_img_paths
), "every validation image is expected to live in an 'image' directory"
gt_paths = [
    os.path.join(os.path.dirname(os.path.dirname(path)), "label", os.path.basename(path))
    for path in test_img_paths
]
gts = [load_nifti(path) for path in gt_paths]

# Get all the predictions
# NOTE(review): the next cell pairs these with gts by position, so this assumes
# predict_multiple preserves the order of test_img_paths — confirm.
preds = predict_multiple(model_path, test_img_paths)




In [8]:
# Adjusted Dice score for every validation image (ground truth paired with
# prediction by position)
dices = [dice_coef_adj(gt, mask) for gt, mask in zip(gts, preds)]
# zip() stops at the shortest input; make sure no image was silently dropped.
assert len(dices) == len(test_img_paths), (
    f"scored {len(dices)} images but expected {len(test_img_paths)}"
)

# Per-image breakdown keyed by file name
dice_dict = {os.path.basename(path): dice for path, dice in zip(test_img_paths, dices)}
print(f"Average Dice Score: {np.mean(dices)}")
dice_dict


Average Dice Score: 0.8533428884264955


{'tc1_272614-ob_c.nii.gz': 0.8610458362632154,
 'tc1_276388-ob_c.nii.gz': 0.8482439496526921,
 'tc1_272613-ob_c.nii.gz': 0.8610670013801006,
 'tc1_275324-ob_c.nii.gz': 0.8515045318436011,
 'tc1_269455-ob_c.nii.gz': 0.8424411469782398,
 'tc1_272719-ob_c.nii.gz': 0.846837143104386,
 'tc1_275320-ob_c.nii.gz': 0.8542597405223934,
 'tc1_276242-ob_c.nii.gz': 0.8642528346360693,
 'tc1_272718-ob_c.nii.gz': 0.8504338114577623}