In [22]:
from __future__ import print_function, absolute_import

import ipywidgets as widgets
from IPython.display import display
import elastix
import matplotlib.pyplot as plt
import os
import SimpleITK as sitk
import warnings

%matplotlib notebook

warnings.filterwarnings("ignore", category=DeprecationWarning)

# Tool and data locations are read from ``paths.txt`` in the current working
# directory: one path per line, in the order
#   ELASTIX_PATH, TRANSFORMIX_PATH, DATA_PATH, OUTPUT_DIR
# Make sure the elastix folder etc. is reachable from that working directory.
# os.path.join keeps this portable (a hard-coded backslash breaks on POSIX).
PATHS_FILE = os.path.join(os.getcwd(), "paths.txt")
with open(PATHS_FILE) as paths_file:
    # Skip blank lines (e.g. a trailing empty line) so an otherwise valid
    # file does not break the 4-way unpacking below.
    paths = [line.strip() for line in paths_file if line.strip()]

if len(paths) != 4:
    raise ValueError(
        f"{PATHS_FILE} must contain exactly 4 non-empty lines "
        f"(ELASTIX_PATH, TRANSFORMIX_PATH, DATA_PATH, OUTPUT_DIR); found {len(paths)}."
    )
ELASTIX_PATH, TRANSFORMIX_PATH, DATA_PATH, OUTPUT_DIR = paths

if not os.path.exists(ELASTIX_PATH):
    raise IOError('Elastix cannot be found, please set the correct ELASTIX_PATH.')
if not os.path.exists(TRANSFORMIX_PATH):
    raise IOError('Transformix cannot be found, please set the correct TRANSFORMIX_PATH.')
# Fail early with a clear message; later cells list this directory.
if not os.path.isdir(DATA_PATH):
    raise IOError('Data directory cannot be found, please set the correct DATA_PATH.')

# Visualize the atlas images

In [None]:
# Get patient names and select atlas patients.
# os.listdir returns entries in arbitrary order; sort so the atlas selection
# is the same on every machine and every run.
patient_list = sorted(
    patient for patient in os.listdir(DATA_PATH)
    if os.path.isdir(os.path.join(DATA_PATH, patient))
)
atlas_patients = patient_list[:5]
if not atlas_patients:
    raise ValueError(f'No patient directories found in {DATA_PATH}.')

# Load the MR image and the prostate delineation of every atlas patient as arrays
# (the first array axis is the slice axis, as indexed below).
atlas_images = []
delineation_images = []

for patient in atlas_patients:
    fixed_image_path = os.path.join(DATA_PATH, patient, 'mr_bffe.mhd')
    delineation_image_path = os.path.join(DATA_PATH, patient, 'prostaat.mhd')

    fixed_image = sitk.ReadImage(fixed_image_path)
    delineation_image = sitk.ReadImage(delineation_image_path)

    atlas_images.append(sitk.GetArrayFromImage(fixed_image))
    delineation_images.append(sitk.GetArrayFromImage(delineation_image))

# One slider drives every volume, so it may only reach slices that exist in all
# of them: use the smallest depth instead of assuming the depths are equal.
num_slices = min(volume.shape[0] for volume in atlas_images + delineation_images)

# Create the figure and axes; squeeze=False keeps `axes` 2-D even for one atlas.
n_atlases = len(atlas_patients)
fig, axes = plt.subplots(2, n_atlases, figsize=(2.5 * n_atlases, 5), squeeze=False)

# Initial display with middle slice
initial_slice = num_slices // 2
image_plots = []

for patient_id, patient in enumerate(atlas_patients):
    atlas_volume = atlas_images[patient_id]
    mask_volume = delineation_images[patient_id]

    # Pin the colour limits to the whole volume. imshow would otherwise derive
    # them from the initial slice only, and set_data() never rescales -- a
    # delineation that is empty on the initial slice would stay black everywhere.
    img1 = axes[0, patient_id].imshow(atlas_volume[initial_slice, :, :], cmap='gray',
                                      vmin=atlas_volume.min(), vmax=atlas_volume.max())
    axes[0, patient_id].set_title(f'Fixed atlas Image ({patient})')
    axes[0, patient_id].axis('off')

    img2 = axes[1, patient_id].imshow(mask_volume[initial_slice, :, :], cmap='gray',
                                      vmin=mask_volume.min(), vmax=mask_volume.max())
    axes[1, patient_id].set_title(f'Prostate delineation ({patient})')
    axes[1, patient_id].axis('off')

    image_plots.append((img1, img2))

plt.tight_layout()

# Define slider
slice_slider = widgets.IntSlider(min=0, max=num_slices - 1, step=1, value=initial_slice, description="Slice")


def update(slice_idx):
    """Show slice ``slice_idx`` of every atlas image and its delineation."""
    for (atlas_plot, mask_plot), atlas_volume, mask_volume in zip(
            image_plots, atlas_images, delineation_images):
        atlas_plot.set_data(atlas_volume[slice_idx, :, :])
        mask_plot.set_data(mask_volume[slice_idx, :, :])
    fig.canvas.draw_idle()
    plt.pause(0.1)  # Forces an update


# Create interactive widget
interactive_plot = widgets.interactive(update, slice_idx=slice_slider)

# Display the interactive widget (this ensures the function updates properly)
display(interactive_plot)

plt.show()

# Register the remaining patients to the atlases

In [None]:
from atlas_registration_functions import register_transform

# Development limits on how many patients / atlases are processed.
# Set either to None to process all of them (list[:None] is the full list).
MAX_REGISTER_PATIENTS = 2
MAX_ATLASES = 2

# Every patient that is not used as an atlas is registered.
register_patients = [patient for patient in patient_list if patient not in atlas_patients]

# Maps each registered patient to the list of paths returned by
# register_transform, one entry per atlas, in atlas order.
# NOTE(review): presumably each path is the atlas delineation transformed onto
# the patient -- confirm against atlas_registration_functions.
transformed_delineation_paths = {}

for patient in register_patients[:MAX_REGISTER_PATIENTS]:
    transform_paths = []

    # Register to all atlases
    for atlas in atlas_patients[:MAX_ATLASES]:
        transformed_delineation_path = register_transform(
            patient, atlas, DATA_PATH, ELASTIX_PATH, TRANSFORMIX_PATH
        )
        print(transformed_delineation_path)
        transform_paths.append(transformed_delineation_path)

    transformed_delineation_paths[patient] = transform_paths
