In [None]:
import SimpleITK as sitk
import registration_utilities as ru
import registration_callbacks as rc

%matplotlib inline
import matplotlib.pyplot as plt

from ipywidgets import interact, fixed

# utility method that either downloads data from the Girder repository or
# if already downloaded returns the file name for reading from disk (cached data)
%run update_path_to_download_script
from downloaddata import fetch_data as fdata
import numpy as np

In [None]:
%run popi_utilities_setup.py

In [None]:
# Load all ten time points of the POPI sequence (file prefixes "00", "10", ..., "90"),
# each with its label mask and landmark list. Intensity images are read and cast to
# float32, the pixel format required for registration.
images, masks, points = [], [], []
for i in range(10):
    image_file_name, mask_file_name, points_file_name = (
        f"POPI/meta/{i}0-P.mhd",
        f"POPI/masks/{i}0-air-body-lungs.mhd",
        f"POPI/landmarks/{i}0-Landmarks.pts",
    )
    images.append(sitk.ReadImage(fdata(image_file_name), sitk.sitkFloat32))
    masks.append(sitk.ReadImage(fdata(mask_file_name)))
    points.append(read_POPI_points(fdata(points_file_name)))

# Interactive browser: one slider over the time points, one over coronal slices,
# displayed with a fixed [-1024, 976] intensity window.
interact(
    display_coronal_with_overlay,
    temporal_slice=(0, len(images) - 1),
    coronal_slice=(0, images[0].GetSize()[1] - 1),
    images=fixed(images),
    masks=fixed(masks),
    label=fixed(lung_label),
    window_min=fixed(-1024),
    window_max=fixed(976),
);

In [None]:
def bspline_intra_modal_registration(
    fixed_image,
    moving_image,
    fixed_image_mask=None,
    fixed_points=None,
    moving_points=None,
    metric_sampling_seed=None,
):
    """Deformably register ``moving_image`` to ``fixed_image`` with a cubic BSpline transform.

    Uses a three-level image pyramid (shrink factors 4, 2, 1; smoothing sigmas
    2, 1, 0 in physical units) and random metric sampling of 1% of the fixed
    image points, optionally restricted to a mask.

    Args:
        fixed_image: SimpleITK image defining the reference space; the BSpline
            control grid is laid over its physical extent.
        moving_image: SimpleITK image to align to ``fixed_image``.
        fixed_image_mask: Optional SimpleITK image; when given, metric sample
            points are drawn only inside it.
        fixed_points: Optional landmark list for the fixed image.
        moving_points: Optional landmark list for the moving image. When both
            point lists are non-empty, the metric value and landmark errors are
            plotted during registration via ``registration_callbacks``.
        metric_sampling_seed: Optional integer seed for the random metric
            sampling. ``None`` (default) keeps SimpleITK's default wall-clock
            seed, so repeated runs draw different samples; pass a fixed value
            for reproducible results and comparable timings.

    Returns:
        The transform returned by ``ImageRegistrationMethod.Execute``.
    """
    registration_method = sitk.ImageRegistrationMethod()

    # Number of BSpline control points per axis for the requested physical grid spacing.
    grid_physical_spacing = [50.0, 50.0, 50.0]  # requested: a control point every 50mm
    image_physical_size = [
        size * spacing
        for size, spacing in zip(fixed_image.GetSize(), fixed_image.GetSpacing())
    ]
    mesh_size = [
        int(image_size / grid_spacing + 0.5)
        for image_size, grid_spacing in zip(image_physical_size, grid_physical_spacing)
    ]
    # The mesh is then reduced to 1/4 per axis. Nothing below refines it again
    # (plain SetInitialTransform is used), so the grid actually optimized has
    # roughly 4x the requested spacing (~200mm), not 50mm.
    # NOTE(review): SetInitialTransformAsBSpline(initial_transform, inPlace=False,
    # scaleFactors=[1, 2, 4]) would refine this mesh per pyramid level back to
    # ~50mm. Confirm which is intended, and that the TRE plotting callback copes
    # with a changing parameter count, before switching.
    mesh_size = [int(sz / 4 + 0.5) for sz in mesh_size]

    initial_transform = sitk.BSplineTransformInitializer(
        image1=fixed_image, transformDomainMeshSize=mesh_size, order=3
    )
    registration_method.SetInitialTransform(initial_transform)

    # Similarity metric. Alternatives tried: SetMetricAsMeanSquares,
    # SetMetricAsJointHistogramMutualInformation, SetMetricAsMattesMutualInformation.
    # NOTE(review): SetMetricAsMutualInformationEfficientEntropy does not appear in
    # the upstream SimpleITK API documentation; presumably it comes from a custom
    # build. Confirm.
    registration_method.SetMetricAsMutualInformationEfficientEntropy()

    # Settings for metric sampling. Usage of a mask is optional. When a mask is
    # given, the sample points are generated inside that region, which also speeds
    # things up because the mask is smaller than the whole image.
    registration_method.SetMetricSamplingStrategy(registration_method.RANDOM)
    if metric_sampling_seed is None:
        registration_method.SetMetricSamplingPercentage(0.01)
    else:
        registration_method.SetMetricSamplingPercentage(0.01, metric_sampling_seed)
    if fixed_image_mask is not None:
        registration_method.SetMetricFixedMask(fixed_image_mask)

    # Multi-resolution framework.
    registration_method.SetShrinkFactorsPerLevel(shrinkFactors=[4, 2, 1])
    registration_method.SetSmoothingSigmasPerLevel(smoothingSigmas=[2, 1, 0])
    registration_method.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()

    registration_method.SetInterpolator(sitk.sitkLinear)

    # Optimizer. Alternatives tried:
    #   SetOptimizerAsConjugateGradientLineSearch (same parameters as below),
    #   SetOptimizerAsOnePlusOneEvolutionary(numberOfIterations=1000),
    #   SetOptimizerAsPowell(numberOfIterations=5),
    #   SetOptimizerAsLBFGS2(solutionAccuracy=1e-2, numberOfIterations=100,
    #                        deltaConvergenceTolerance=0.01).
    registration_method.SetOptimizerAsGradientDescent(
        learningRate=1.0,
        numberOfIterations=100,
        convergenceMinimumValue=1e-6,
        convergenceWindowSize=10,
    )

    # If corresponding points in the fixed and moving image are given, display the
    # similarity metric and the TRE during the registration.
    if fixed_points and moving_points:
        registration_method.AddCommand(
            sitk.sitkStartEvent, rc.metric_and_reference_start_plot
        )
        registration_method.AddCommand(
            sitk.sitkEndEvent, rc.metric_and_reference_end_plot
        )
        registration_method.AddCommand(
            sitk.sitkIterationEvent,
            lambda: rc.metric_and_reference_plot_values(
                registration_method, fixed_points, moving_points
            ),
        )

    return registration_method.Execute(fixed_image, moving_image)

In [None]:
%%timeit -r1 -n1

global tx

# Select the fixed and moving images, valid entries are in [0,9].

fixed_image_index = 0
moving_image_index = 7

# fixed_image_index = 1
# moving_image_index = 8

# fixed_image_index = 2
# moving_image_index = 9

# fixed_image_index = 0
# moving_image_index = 9

# fixed_image_index = 0
# moving_image_index = 8

# fixed_image_index = 3
# moving_image_index = 6

# fixed_image_index = 1
# moving_image_index = 6

# fixed_image_index = 1
# moving_image_index = 7

# fixed_image_index = 1
# moving_image_index = 5

# fixed_image_index = 5
# moving_image_index = 9



tx = bspline_intra_modal_registration(
    fixed_image=images[fixed_image_index],
    moving_image=images[moving_image_index],
    fixed_image_mask=(masks[fixed_image_index] == lung_label),
    fixed_points=points[fixed_image_index],
    moving_points=points[moving_image_index],
)


In [None]:
# Keep these indices in sync with the pair registered above.
# Other pairs tried: (1, 8), (2, 9), (0, 9), (0, 8), (3, 6), (1, 6), (1, 7), (1, 5), (5, 9).
fixed_image_index, moving_image_index = 0, 7

# Landmark alignment errors before registration (a default-constructed Euler3DTransform,
# i.e. the identity) and after it (the estimated transform `tx`).
initial_errors_mean, initial_errors_std, _, initial_errors_max, initial_errors = (
    ru.registration_errors(
        sitk.Euler3DTransform(), points[fixed_image_index], points[moving_image_index]
    )
)
final_errors_mean, final_errors_std, _, final_errors_max, final_errors = (
    ru.registration_errors(tx, points[fixed_image_index], points[moving_image_index])
)

print(
    "Initial alignment errors in millimeters, mean(std): "
    f"{initial_errors_mean:.2f}({initial_errors_std:.2f}), max: {initial_errors_max:.2f}, "
    f"median: {np.median(initial_errors):.2f}"
)
print(
    "Final alignment errors in millimeters, mean(std): "
    f"{final_errors_mean:.2f}({final_errors_std:.2f}), max: {final_errors_max:.2f}, "
    f"median: {np.median(final_errors):.2f}"
)


In [None]:
# Keep these indices in sync with the pair registered above.
# Other pairs tried: (1, 8), (2, 9), (0, 9), (0, 8), (3, 6), (1, 6), (1, 7), (1, 5), (5, 9).
fixed_image_index = 0
moving_image_index = 7

# Transfer the segmentation via the estimated transformation onto the fixed image grid.
# Use Nearest Neighbor interpolation to retain the labels.
transformed_labels = sitk.Resample(
    masks[moving_image_index],
    images[fixed_image_index],
    tx,
    sitk.sitkNearestNeighbor,
    0.0,
    masks[moving_image_index].GetPixelID(),
)

segmentations_before_and_after = [masks[moving_image_index], transformed_labels]
interact(
    display_coronal_with_label_maps_overlay,
    # Slider range comes from the image actually displayed, not from images[0].
    coronal_slice=(0, images[fixed_image_index].GetSize()[1] - 1),
    mask_index=(0, len(segmentations_before_and_after) - 1),
    image=fixed(images[fixed_image_index]),
    masks=fixed(segmentations_before_and_after),
    label=fixed(lung_label),
    window_min=fixed(-1024),
    window_max=fixed(976),
)

# Compare the lung label of the moving segmentation, before and after transfer, with
# the fixed image's own lung label using Dice, Jaccard and Hausdorff distance.
ground_truth = masks[fixed_image_index] == lung_label
before_registration = masks[moving_image_index] == lung_label
after_registration = transformed_labels == lung_label

# A single Execute per comparison provides both the Dice and the Jaccard coefficient.
label_overlap_measures_filter = sitk.LabelOverlapMeasuresImageFilter()
label_overlap_measures_filter.Execute(ground_truth, before_registration)
dice_before = label_overlap_measures_filter.GetDiceCoefficient()
jaccard_before = label_overlap_measures_filter.GetJaccardCoefficient()
label_overlap_measures_filter.Execute(ground_truth, after_registration)
dice_after = label_overlap_measures_filter.GetDiceCoefficient()
jaccard_after = label_overlap_measures_filter.GetJaccardCoefficient()

print(f"Dice coefficient before registration: {dice_before:.2f}")
print(f"Dice coefficient after registration: {dice_after:.2f}")
print(f"Jaccard coefficient before registration: {jaccard_before:.2f}")
print(f"Jaccard coefficient after registration: {jaccard_after:.2f}")

hausdorff_distance_image_filter = sitk.HausdorffDistanceImageFilter()
hausdorff_distance_image_filter.Execute(ground_truth, before_registration)
print(
    f"Hausdorff distance before registration: {hausdorff_distance_image_filter.GetHausdorffDistance():.2f}"
)
hausdorff_distance_image_filter.Execute(ground_truth, after_registration)
print(
    f"Hausdorff distance after registration: {hausdorff_distance_image_filter.GetHausdorffDistance():.2f}"
)