In [1]:
from __future__ import print_function, absolute_import

import os
import warnings

import SimpleITK as sitk
import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import display

%matplotlib notebook

# from atlas_registration_functions import registrate_atlas_patient

warnings.filterwarnings("ignore", category=DeprecationWarning)

# Paths
paths = open(os.getcwd() + r"\\paths.txt").read().splitlines()
ELASTIX_PATH, TRANSFORMIX_PATH, DATA_PATH ,OUTPUT_DIR = paths


if not os.path.exists(ELASTIX_PATH):
    raise IOError('Elastix cannot be found, please set the correct ELASTIX_PATH.')
if not os.path.exists(TRANSFORMIX_PATH):
    raise IOError('Transformix cannot be found, please set the correct TRANSFORMIX_PATH.')


def calc_dice(true_del, est_del):
    # Ensure the arrays are binary (0s and 1s)
    true_del = (true_del > 0).astype(np.uint8)
    est_del = (est_del > 0).astype(np.uint8)

    intersection = np.sum(true_del * est_del)
    size1 = np.sum(true_del)
    size2 = np.sum(est_del)

    if size1 + size2 == 0:
        return 1.0  # If both are empty, define DICE as 1.0 (perfect match)

    return 2.0 * intersection / (size1 + size2)

In [2]:
# Get patient names and select atlas patients
patient_list = [patient for patient in os.listdir(DATA_PATH) if os.path.isdir(os.path.join(DATA_PATH, patient))]
print(os.listdir(OUTPUT_DIR))
# atlas_patients = patient_list[:5]
atlas_patients = ["p102", "p108", "p109"]
register_patients = [patient[8:-4] for patient in os.listdir(OUTPUT_DIR) if
                     patient.startswith("reg_maj_") and patient.endswith(".mhd")]
# print(register_patients)
true_delineations = []
est_delineations = []

for patient in register_patients:
    true_delineation = os.path.join(DATA_PATH, patient, 'prostaat.mhd')
    est_delineation = os.path.join(OUTPUT_DIR, f'reg_maj_{patient}.mhd')

    true_delineation_image = sitk.ReadImage(true_delineation)
    est_delineation_image = sitk.ReadImage(est_delineation)

    true_delineations.append(sitk.GetArrayFromImage(true_delineation_image))
    est_delineations.append(sitk.GetArrayFromImage(est_delineation_image))

dice_scores = []
for i in range(len(true_delineations)):
    dice_score = calc_dice(true_delineations[i], est_delineations[i])
    print(f"Patient {register_patients[i]} registration reached a dice score of: {dice_score:.3f}")
    dice_scores.append(dice_score)

[]


# Show results

In [3]:
# Get patient names and select registered patients
patient_list = [patient for patient in os.listdir(DATA_PATH) if os.path.isdir(os.path.join(DATA_PATH, patient))]
register_patients = [patient[8:-4] for patient in os.listdir(OUTPUT_DIR) if
                     patient.startswith("reg_maj_") and patient.endswith(".mhd")]

# Load images
true_delineations = []
est_delineations = []

for patient in register_patients[:5]:  # Limit to 5 patients
    true_delineation_path = os.path.join(DATA_PATH, patient, 'prostaat.mhd')
    est_delineation_path = os.path.join(OUTPUT_DIR, f'reg_maj_{patient}.mhd')

    true_delineation_image = sitk.ReadImage(true_delineation_path)
    est_delineation_image = sitk.ReadImage(est_delineation_path)

    true_delineations.append(sitk.GetArrayFromImage(true_delineation_image))
    est_delineations.append(sitk.GetArrayFromImage(est_delineation_image))

# Get image depth (number of slices)
num_slices = true_delineations[0].shape[0]  # Assumes all images have the same depth

# Create the figure and axes
fig, axes = plt.subplots(2, len(true_delineations), figsize=(2.5 * len(true_delineations), 5))

# Initial display with the middle slice
initial_slice = num_slices // 2
image_plots = []

for patient_id in range(len(true_delineations)):
    img1 = axes[0, patient_id].imshow(true_delineations[patient_id][initial_slice, :, :], cmap='gray')
    axes[0, patient_id].set_title(f'True ({register_patients[patient_id]})', fontsize=10)
    axes[0, patient_id].axis('off')

    img2 = axes[1, patient_id].imshow(est_delineations[patient_id][initial_slice, :, :], cmap='gray')
    axes[1, patient_id].set_title(f'Estimated ({register_patients[patient_id]})', fontsize=10)
    axes[1, patient_id].axis('off')

    image_plots.append((img1, img2))

plt.tight_layout()

# Define slider
slice_slider = widgets.IntSlider(min=0, max=num_slices - 1, step=1, value=initial_slice, description="Slice")

def update(slice_idx):
    """Update function for interactive slider."""
    for patient_id in range(len(true_delineations)):
        image_plots[patient_id][0].set_data(true_delineations[patient_id][slice_idx, :, :])
        image_plots[patient_id][1].set_data(est_delineations[patient_id][slice_idx, :, :])
    fig.canvas.draw_idle()
    plt.pause(0.1)  # Forces an update

# Create interactive widget
interactive_plot = widgets.interactive(update, slice_idx=slice_slider)

# Display the interactive widget
display(interactive_plot)

plt.show()

IndexError: list index out of range