<h1 align="center">SimpleITKv4 Nonrigid Registration</h1>

In our previous notebook we explored the SimpleITKv4 registration framework for rigid registration. Using other unbounded transformation models (e.g. affine) involves minimal changes to the code.

In this notebook we explore the use of bounded transformations: BSpline and DisplacementField.

We will work with a freely available 4D (3D+time) thoracic-abdominal CT, the Point-validated Pixel-based Breathing Thorax Model (POPI). The data consist of a set of temporal CT volumes, a set of masks segmenting each CT into air/body/lung, and a set of corresponding points across the CT volumes.

The POPI model is provided by the Léon Bérard Cancer Center & CREATIS Laboratory, Lyon, France. The relevant publication is:

J. Vandemeulebroucke, D. Sarrut, P. Clarysse, "The POPI-model, a point-validated pixel-based breathing thorax model", Proc. XVth International Conference on the Use of Computers in Radiation Therapy (ICCR), Toronto, Canada, 2007.

The POPI data, and additional 4D CT data sets with reference points, are available from the CREATIS Laboratory <a href="http://www.creatis.insa-lyon.fr/rio/popi-model?action=show&redirect=popi">here</a>.

In [None]:
from __future__ import print_function

import SimpleITK as sitk

# Utility method that either downloads data from the MIDAS repository or,
# if already downloaded, returns the file name for reading from disk (cached data).
from downloaddata import fetch_data as fdata

# Always write output to a separate directory; we don't want to pollute the source directory.
import os
OUTPUT_DIR = 'Output'

## Utility functions 

Callback functions for image display and for plotting the similarity metric during registration.

In [None]:
%matplotlib inline
%run registration_utilities.py

## Read the POPI images, masks, and reference data

In [None]:
fixed_image =  sitk.ReadImage(fdata('POPI/meta/00-P.mhd'), sitk.sitkFloat32)
fixed_mask = sitk.ReadImage(fdata('POPI/masks/00-air-body-lungs.mhd'))
fixed_points = read_POPI_points(fdata('POPI/landmarks/00-Landmarks.pts'))

moving_image =  sitk.ReadImage(fdata('POPI/meta/70-P.mhd'), sitk.sitkFloat32)
moving_mask = sitk.ReadImage(fdata('POPI/masks/70-air-body-lungs.mhd'))
moving_points = read_POPI_points(fdata('POPI/landmarks/70-Landmarks.pts'))

interact(display_coronal_with_overlay, temporal_slice=(0,1), 
         coronal_slice = (0, fixed_image.GetSize()[1]-1), 
         images = fixed([fixed_image,moving_image]), masks = fixed([fixed_mask,moving_mask]), 
         label=fixed(lung_label), window_min = fixed(-1024), window_max=fixed(976));

## Free Form Deformation

This registration approach uses a grid of control points to define a BSpline transformation of the data and was popularized in the late 1990s. Note that the approach itself admits implausible transformations (yes, folding can happen). The current ITK, and hence SimpleITK, implementation does not include regularization.
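To make the folding remark concrete, the following is a hypothetical 1D sketch in plain Python (not SimpleITK code). It assumes a uniform cubic B-spline basis with control points every h = 50 mm and made-up control point displacements. When the displacements are small relative to h, the mapping stays strictly increasing; when they are comparable to h, the mapping folds over itself, which is the kind of implausible transformation an unregularized FFD can reach.

```python
# Hypothetical 1D free-form deformation: T(x) = x + sum_k c_k * B3(x / h - k),
# with control point k located at x = k * h. Plain Python, not SimpleITK code.

def cubic_bspline(t):
    """Uniform cubic B-spline basis function centred at 0 (support [-2, 2])."""
    t = abs(t)
    if t < 1.0:
        return (4.0 - 6.0 * t ** 2 + 3.0 * t ** 3) / 6.0
    if t < 2.0:
        return (2.0 - t) ** 3 / 6.0
    return 0.0


def ffd_1d(x, coefficients, grid_spacing):
    """Map x through the 1D FFD defined by the control point displacements."""
    return x + sum(c * cubic_bspline(x / grid_spacing - k)
                   for k, c in enumerate(coefficients))


def has_folding(coefficients, grid_spacing, x_min, x_max, samples=1000):
    """Return True if the sampled mapping is not strictly increasing (folding)."""
    xs = [x_min + (x_max - x_min) * i / samples for i in range(samples + 1)]
    ys = [ffd_1d(x, coefficients, grid_spacing) for x in xs]
    return any(y1 <= y0 for y0, y1 in zip(ys, ys[1:]))


h = 50.0                               # control point spacing in mm
small = [0.0, 5.0, -5.0, 5.0, 0.0]     # displacements much smaller than h
large = [0.0, 60.0, -60.0, 60.0, 0.0]  # displacements comparable to h

print(has_folding(small, h, 0.0, 200.0))  # False: smooth, one-to-one mapping
print(has_folding(large, h, 0.0, 200.0))  # True: the mapping folds
```

In 3D, a common way to detect folding is to look for negative values of the Jacobian determinant of the transformation.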

### Initial Alignment

As we are aligning two CTs capturing two respiratory phases of the same subject in the same position, we create a control grid that is physically overlaid on the image and represents the identity transform.

In [None]:
# Determine the BSpline mesh size (number of grid cells per dimension) from the physical spacing we want for the control grid.
grid_physical_spacing = [50.0, 50.0, 50.0] # A control point every 50mm
image_physical_size = [size*spacing for size,spacing in zip(fixed_image.GetSize(), fixed_image.GetSpacing())]
mesh_size = [int(image_size/grid_spacing + 0.5)
             for image_size,grid_spacing in zip(image_physical_size,grid_physical_spacing)]

transform = sitk.BSplineTransformInitializer(image1 = fixed_image, 
                                             transformDomainMeshSize = mesh_size, order=3)   

pre_errors_mean, pre_errors_std, _, pre_errors_max, pre_errors = registration_errors(transform, fixed_points, moving_points, display_errors=True)
print('Before registration, errors (TRE) in millimeters, mean(std): {:.2f}({:.2f}), max: {:.2f}'.format(pre_errors_mean, pre_errors_std, pre_errors_max))
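The cell above turns a physical control point spacing into a transform domain mesh size. The sketch below repeats that arithmetic for a hypothetical 512 x 512 x 100 image with 0.8 x 0.8 x 2.0 mm spacing (made-up values, not the POPI header) and counts the resulting transform parameters. The count assumes ITK's convention that a B-spline transform of order 3 has mesh size + 3 control points per dimension, each carrying one coefficient per spatial dimension.

```python
def bspline_mesh_size(image_size, image_spacing, grid_physical_spacing):
    """Mesh size per dimension, computed with the same formula as the cell above."""
    physical_size = [sz * spc for sz, spc in zip(image_size, image_spacing)]
    return [int(p / g + 0.5) for p, g in zip(physical_size, grid_physical_spacing)]


def number_of_bspline_parameters(mesh_size, order=3):
    """Assumes mesh_size + order control points per dimension and one
    coefficient per control point per spatial dimension."""
    n_control_points = 1
    for m in mesh_size:
        n_control_points *= m + order
    return n_control_points * len(mesh_size)


# Hypothetical geometry, not the actual POPI image header.
size = [512, 512, 100]
spacing = [0.8, 0.8, 2.0]
mesh = bspline_mesh_size(size, spacing, [50.0, 50.0, 50.0])
print(mesh, number_of_bspline_parameters(mesh))  # [8, 8, 4] 2541
```

Halving the grid spacing roughly multiplies the parameter count by eight in 3D, which is one reason a sparser grid registers faster.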

### Final Alignment

Perform registration on a low resolution version of the images with our BSpline transform and a mean squares similarity metric (intra-modal registration). 
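The mean squares metric with random sampling can be sketched as follows. This is a simplified, hypothetical version: it assumes the moving image has already been resampled through the current transform (the real method interpolates the moving image at each transformed sample point), and it treats the sampling percentage as a fraction, so 0.1 means 10% of the samples.

```python
import random


def mean_squares_random_sampling(fixed_values, moving_values, sampling_percentage, seed=0):
    """Mean squared difference over a random subset of corresponding samples.

    moving_values is assumed to already be the moving image resampled through
    the current transform, aligned sample-by-sample with fixed_values.
    sampling_percentage is a fraction (0.1 means 10%).
    """
    n = len(fixed_values)
    k = max(1, int(round(sampling_percentage * n)))
    rng = random.Random(seed)  # fixed seed for a reproducible subset
    indices = rng.sample(range(n), k)
    return sum((fixed_values[i] - moving_values[i]) ** 2 for i in indices) / k


fixed_values = [float(i) for i in range(1000)]
shifted_values = [v + 2.0 for v in fixed_values]  # constant intensity offset of 2
print(mean_squares_random_sampling(fixed_values, shifted_values, 0.1))  # 4.0
```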

In [None]:
registration_method = sitk.ImageRegistrationMethod()
registration_method.SetInitialTransform(transform) # Transformation is modified in place.
        
registration_method.SetMetricAsMeanSquares()
registration_method.SetMetricSamplingStrategy(registration_method.RANDOM)
registration_method.SetMetricSamplingPercentage(0.1)

registration_method.SetShrinkFactorsPerLevel(shrinkFactors = [8]) # Aggressively resampled data (speed).
registration_method.SetSmoothingSigmasPerLevel(smoothingSigmas=[4])

registration_method.SetInterpolator(sitk.sitkLinear)
registration_method.SetOptimizerAsLBFGSB(gradientConvergenceTolerance=1e-5, maximumNumberOfIterations=50)

registration_method.AddCommand(sitk.sitkStartEvent, metric_and_reference_start_plot)
registration_method.AddCommand(sitk.sitkEndEvent, metric_and_reference_end_plot)
registration_method.AddCommand(sitk.sitkIterationEvent, lambda: metric_and_reference_plot_values(registration_method, fixed_points, moving_points))

registration_method.Execute(fixed_image, moving_image)
print('Final metric value: {0}'.format(registration_method.GetMetricValue()))
print('Optimizer\'s stopping condition, {0}\n'.format(registration_method.GetOptimizerStopConditionDescription()))

In [None]:
final_errors_mean, final_errors_std, _, final_errors_max, final_errors = registration_errors(transform, fixed_points, moving_points, display_errors=True)
print('After final alignment, errors (TRE) in millimeters, mean(std): {:.2f}({:.2f}), max: {:.2f}'.format(final_errors_mean, final_errors_std, final_errors_max))
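The target registration error (TRE) reported above is, for each landmark pair, the distance between the transformed fixed point and its corresponding moving point; transforms estimated by the SimpleITK registration framework map points from the fixed image domain into the moving image domain. The sketch below assumes this is what `registration_errors` computes and uses a made-up transform and landmarks. Whether the helper reports a population or a sample standard deviation is not shown in this notebook, so the population version used here is an assumption.

```python
import math


def target_registration_errors(transform_point, fixed_points, moving_points):
    """TRE per landmark: distance between T(fixed point) and its moving counterpart."""
    errors = []
    for p_fixed, p_moving in zip(fixed_points, moving_points):
        mapped = transform_point(p_fixed)
        errors.append(math.sqrt(sum((a - b) ** 2 for a, b in zip(mapped, p_moving))))
    mean = sum(errors) / len(errors)
    std = math.sqrt(sum((e - mean) ** 2 for e in errors) / len(errors))  # population std
    return mean, std, max(errors), errors


def shift_x(p):
    """Hypothetical transform: a 1 mm translation along x."""
    return (p[0] + 1.0, p[1], p[2])


fixed_pts = [(0.0, 0.0, 0.0), (10.0, 0.0, 0.0)]   # hypothetical landmarks (mm)
moving_pts = [(1.0, 0.0, 0.0), (14.0, 0.0, 0.0)]
print(target_registration_errors(shift_x, fixed_pts, moving_pts))  # (1.5, 1.5, 3.0, [0.0, 3.0])
```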

## Demons Registration

This registration approach is based on an analogy to Maxwell's demons, combined with differential equations related to optical flow. Like FFD, this framework was popularized at the end of the 1990s and the beginning of the 2000s. The transformation is represented by a displacement field. The basic approach, as implemented in ITK, supports regularization via smoothing. Variants that yield a diffeomorphic transformation, as well as symmetric formulations, are also available.
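The demons force can be illustrated in 1D. The sketch below uses one common form of Thirion's update, u = (f - m) f' / (f'^2 + (f - m)^2), where f is the fixed image, m the warped moving image, and f' the fixed image gradient. Sign and normalization conventions vary between formulations and implementations, and unit pixel spacing is assumed here.

```python
def demons_update_1d(fixed_profile, moving_warped):
    """One Thirion-style demons update per interior sample (unit pixel spacing).

    u = (f - m) * f' / (f'^2 + (f - m)^2), with f' from central differences.
    """
    updates = [0.0] * len(fixed_profile)
    for i in range(1, len(fixed_profile) - 1):
        grad = (fixed_profile[i + 1] - fixed_profile[i - 1]) / 2.0
        diff = fixed_profile[i] - moving_warped[i]
        denominator = grad * grad + diff * diff
        if denominator > 1e-12:  # flat, already-matched regions get no update
            updates[i] = diff * grad / denominator
    return updates


fixed_profile = [0.0, 1.0, 2.0, 3.0, 4.0]    # f(x) = x
moving_profile = [-1.0, 0.0, 1.0, 2.0, 3.0]  # m(x) = x - 1, so u = +1 would align them
print(demons_update_1d(fixed_profile, moving_profile))  # [0.0, 0.5, 0.5, 0.5, 0.0]
```

The update points in the right direction (m(x + 0.5) is closer to f(x) than m(x) is) but is damped by the (f - m)^2 term in the denominator.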

### Initial Alignment

As we are aligning two CTs capturing two respiratory phases of the same subject in the same position, we create a displacement field transform that represents the identity transform, has the same dimensions as the fixed image, and is aligned with it.

In [None]:
transform_to_displacement_field_filter = sitk.TransformToDisplacementFieldFilter()
transform_to_displacement_field_filter.SetReferenceImage(fixed_image)
# The image returned by transform_to_displacement_field_filter is transferred to the transform and cleared out.
transform = sitk.DisplacementFieldTransform(transform_to_displacement_field_filter.Execute(sitk.Transform()))
# Specify how to regularize the transform when updated during registration (update field - viscous, total field - elastic).
transform.SetSmoothingGaussianOnUpdate(varianceForUpdateField=0.0, varianceForTotalField=2.0)
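The two variances passed to `SetSmoothingGaussianOnUpdate` correspond to two places where smoothing can be applied: to each update field before it is added (viscous-like) and to the accumulated total field afterwards (elastic-like). The following plain Python sketch shows that order on a 1D field. The sampled Gaussian kernel, its truncation at 3 sigma, the variance in pixel units, and the clamped borders are our own assumptions, not details of the ITK implementation.

```python
import math


def gaussian_smooth_1d(values, variance):
    """Smooth a 1D field with a normalized, sampled Gaussian (variance in pixels^2)."""
    if variance <= 0.0:
        return list(values)
    radius = max(1, int(math.ceil(3.0 * math.sqrt(variance))))  # truncate at 3 sigma
    offsets = range(-radius, radius + 1)
    weights = [math.exp(-(k * k) / (2.0 * variance)) for k in offsets]
    total_weight = sum(weights)
    n = len(values)
    smoothed = []
    for i in range(n):
        # Clamp indices at the borders (replicate the edge values).
        acc = sum(w * values[min(max(i + k, 0), n - 1)] for k, w in zip(offsets, weights))
        smoothed.append(acc / total_weight)
    return smoothed


def apply_update(field, update, variance_for_update, variance_for_total):
    """Smooth the update (viscous-like), add it, then smooth the total (elastic-like)."""
    smoothed_update = gaussian_smooth_1d(update, variance_for_update)
    total = [f + u for f, u in zip(field, smoothed_update)]
    return gaussian_smooth_1d(total, variance_for_total)


field = [0.0] * 21
update = [0.0] * 21
update[10] = 1.0  # a single-voxel spike in the update field
elastic = apply_update(field, update, variance_for_update=0.0, variance_for_total=2.0)
print([round(v, 3) for v in elastic[7:14]])
```

With a zero update variance and a total variance of 2.0, as in the cell above, the spike is spread over its neighbours while the summed displacement is preserved.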

### Final Alignment

Perform registration on a low resolution version of the images with our DisplacementField transform and a Demons similarity metric (intra-modal registration).

In [None]:
registration_method = sitk.ImageRegistrationMethod()
registration_method.SetInitialTransform(transform)

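# Demons metric with an intensity difference threshold of 10 (image intensity units, here HU).
registration_method.SetMetricAsDemons(intensityDifferenceThreshold=10)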

registration_method.SetShrinkFactorsPerLevel(shrinkFactors = [8]) # Aggressively resampled data (speed).
registration_method.SetSmoothingSigmasPerLevel(smoothingSigmas=[4])

registration_method.SetInterpolator(sitk.sitkLinear)

registration_method.SetOptimizerAsConjugateGradientLineSearch(learningRate=1.0, numberOfIterations=20)

registration_method.AddCommand(sitk.sitkStartEvent, metric_and_reference_start_plot)
registration_method.AddCommand(sitk.sitkEndEvent, metric_and_reference_end_plot)
registration_method.AddCommand(sitk.sitkIterationEvent, lambda: metric_and_reference_plot_values(registration_method, fixed_points, moving_points))

registration_method.Execute(fixed_image, moving_image)
print('Final metric value: {0}'.format(registration_method.GetMetricValue()))
print('Optimizer\'s stopping condition, {0}\n'.format(registration_method.GetOptimizerStopConditionDescription()))

In [None]:
final_errors_mean, final_errors_std, _, final_errors_max, final_errors = registration_errors(transform, fixed_points, moving_points, display_errors=True)
print('After final alignment, errors (TRE) in millimeters, mean(std): {:.2f}({:.2f}), max: {:.2f}'.format(final_errors_mean, final_errors_std, final_errors_max))

## Demons - filters outside the registration framework

ITK and SimpleITK include extensions of the original Demons algorithm which are independent of the registration framework (ImageRegistrationMethod). These include: 
1. DemonsRegistrationFilter
2. DiffeomorphicDemonsRegistrationFilter
3. SymmetricForcesDemonsRegistrationFilter
4. FastSymmetricForcesDemonsRegistrationFilter

Note that unlike the registration framework, these filters return an image representing the displacement field and not a transform.
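Wrapping the filter output in `sitk.DisplacementFieldTransform` turns the displacement field image into a point mapping x -> x + u(x), where u is interpolated from the field samples. The following 1D sketch illustrates that idea with linear interpolation between samples located at origin + i * spacing. It is hypothetical, and treating points outside the field as undisplaced is an assumption of this sketch rather than a statement about ITK's behaviour.

```python
def displacement_at(field_samples, origin, spacing, x):
    """Linearly interpolate a 1D displacement field sampled at origin + i * spacing.

    Points outside the sampled region get zero displacement (an assumption of this sketch).
    """
    t = (x - origin) / spacing  # continuous index of x
    n = len(field_samples)
    if t < 0.0 or t > n - 1:
        return 0.0
    i = min(int(t), n - 2)
    frac = t - i
    return (1.0 - frac) * field_samples[i] + frac * field_samples[i + 1]


def map_point_through_field(field_samples, origin, spacing, x):
    """x -> x + u(x), the mapping a displacement field transform represents."""
    return x + displacement_at(field_samples, origin, spacing, x)


u_samples = [0.0, 2.0, 4.0]  # hypothetical displacements (mm) at x = 10, 15, 20 mm
print(map_point_through_field(u_samples, 10.0, 5.0, 12.5))  # 13.5
```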

In [None]:
# Resample the input images on our own, as we don't have access to the registration framework's multi-resolution support.

def smooth_and_resample(image, shrink_factor, smoothing_sigma):
    """
    Args:
        image: The image we want to resample.
        shrink_factor: A number greater than one, such that the new image's size is original_size/shrink_factor.
        smoothing_sigma: Sigma for Gaussian smoothing, this is in physical (image spacing) units, not pixels.
    Return:
        Image which is a result of smoothing the input and then resampling it using the given sigma and shrink factor.
    """
    smoothed_image = sitk.SmoothingRecursiveGaussian(image, smoothing_sigma)
    
    original_spacing = image.GetSpacing()
    original_size = image.GetSize()
    new_size = [int(sz/float(shrink_factor) + 0.5) for sz in original_size]
    new_spacing = [((original_sz-1)*original_spc)/(new_sz-1) 
                   for original_sz, original_spc, new_sz in zip(original_size, original_spacing, new_size)]
    return sitk.Resample(smoothed_image, new_size, sitk.Transform(), 
                         sitk.sitkLinear, image.GetOrigin(),
                         new_spacing, image.GetDirection(), 0.0, 
                         image.GetPixelIDValue())

shrink = 8
smooth = 4

resampled_fixed_image = smooth_and_resample(image=fixed_image, shrink_factor=shrink, smoothing_sigma=smooth)
resampled_moving_image = smooth_and_resample(image=moving_image, shrink_factor=shrink, smoothing_sigma=smooth)
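In `smooth_and_resample`, the new spacing is computed from (size - 1) rather than size, so the first and last pixel centres stay at the same physical positions and the original origin can be reused unchanged. The sketch below repeats only that geometry computation for a hypothetical 512 x 512 x 100 image with 0.8 x 0.8 x 2.0 mm spacing (made-up values) and a shrink factor of 8.

```python
def resampled_geometry(original_size, original_spacing, shrink_factor):
    """Size and spacing computed with the same formulas as smooth_and_resample."""
    new_size = [int(sz / float(shrink_factor) + 0.5) for sz in original_size]
    new_spacing = [((osz - 1) * ospc) / (nsz - 1)
                   for osz, ospc, nsz in zip(original_size, original_spacing, new_size)]
    return new_size, new_spacing


# Hypothetical geometry (not the POPI header values).
reduced_size, reduced_spacing = resampled_geometry([512, 512, 100], [0.8, 0.8, 2.0], 8)
print(reduced_size, [round(s, 3) for s in reduced_spacing])  # [64, 64, 13] [6.489, 6.489, 16.5]
```

The resulting in-plane spacing (about 6.49 mm) is slightly larger than 8 x 0.8 = 6.4 mm because the extent between the first and last pixel centres is preserved.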

In [None]:
# Define a simple callback which allows us to monitor the Demons filter's progress.
def iteration_callback(registration_filter):
    print('\r{0}: {1:.2f}'.format(registration_filter.GetElapsedIterations(), registration_filter.GetMetric()), end='')

# Our demons filter of choice.
demons_filter = sitk.FastSymmetricForcesDemonsRegistrationFilter()
demons_filter.SetNumberOfIterations(20)
# Regularization: smoothing the total displacement field is elastic-like
# (smoothing the update field instead would be viscous-like).
demons_filter.SetSmoothDisplacementField(True)
demons_filter.SetStandardDeviations(2.0)

# Add our simple callback to the registration filter.
demons_filter.AddCommand(sitk.sitkIterationEvent, lambda: iteration_callback(demons_filter))
displacement_field_image = demons_filter.Execute(resampled_fixed_image, resampled_moving_image)
    
transform = sitk.DisplacementFieldTransform(displacement_field_image)

In [None]:
final_errors_mean, final_errors_std, _, final_errors_max, final_errors = registration_errors(transform, fixed_points, moving_points, display_errors=True)
print('After final alignment, errors (TRE) in millimeters, mean(std): {:.2f}({:.2f}), max: {:.2f}'.format(final_errors_mean, final_errors_std, final_errors_max))

## Coarse evaluation of registration

Another option for evaluating registration is to use segmentation as a reference: transfer the moving image's segmentation to the fixed image using the estimated transformation, and compare it with the fixed image's segmentation.
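The Dice coefficient and Hausdorff distance printed by `evaluate_registration_via_segmentation` can be sketched on small sets of voxel indices: Dice is 2|A ∩ B| / (|A| + |B|), and the Hausdorff distance is the largest distance from a point in one set to the nearest point in the other. This brute-force version is only illustrative; it assumes unit voxel spacing and non-empty sets, whereas the SimpleITK filters operate on label images.

```python
import math


def dice_coefficient(a, b):
    """Dice overlap of two sets of voxel indices: 2|A & B| / (|A| + |B|)."""
    if not a and not b:
        return 1.0  # convention for two empty segmentations (an assumption)
    return 2.0 * len(a & b) / (len(a) + len(b))


def hausdorff_distance(a, b):
    """Symmetric Hausdorff distance between two non-empty point sets (unit spacing)."""
    def directed(p, q):
        return max(min(math.sqrt(sum((x - y) ** 2 for x, y in zip(pt, qt))) for qt in q)
                   for pt in p)
    return max(directed(a, b), directed(b, a))


mask_a = {(0, 0), (0, 1), (1, 0), (1, 1)}  # hypothetical reference label
mask_b = mask_a | {(2, 0), (2, 1)}         # over-segmented by one column
print(dice_coefficient(mask_a, mask_b), hausdorff_distance(mask_a, mask_b))  # 0.8 1.0
```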

In [None]:
def evaluate_registration_via_segmentation(transform, fixed_image, fixed_mask, moving_mask, seg_label):
    # Transfer the segmentation via the estimated transformation. Use Nearest Neighbor interpolation to retain the labels.
    transformed_labels = sitk.Resample(moving_mask,
                                       fixed_image,
                                       transform, 
                                       sitk.sitkNearestNeighbor,
                                       0.0, 
                                       moving_mask.GetPixelIDValue())

    segmentations_before_and_after = [moving_mask, transformed_labels]
    interact(display_coronal_with_label_maps_overlay, coronal_slice = (0, fixed_image.GetSize()[1]-1),
             mask_index=(0,len(segmentations_before_and_after)-1),
             image = fixed(fixed_image), masks = fixed(segmentations_before_and_after), 
             label=fixed(seg_label), window_min = fixed(-1024), window_max=fixed(976));

    # Compute the Dice coefficient and Hausdorff distance between the segmentations before, and after registration.
    ground_truth = fixed_mask == seg_label
    before_registration = moving_mask == seg_label
    after_registration = transformed_labels == seg_label

    label_overlap_measures_filter = sitk.LabelOverlapMeasuresImageFilter()
    label_overlap_measures_filter.Execute(ground_truth, before_registration)
    print("Dice coefficient before registration: {:.2f}".format(label_overlap_measures_filter.GetDiceCoefficient()))
    label_overlap_measures_filter.Execute(ground_truth, after_registration)
    print("Dice coefficient after registration: {:.2f}".format(label_overlap_measures_filter.GetDiceCoefficient()))

    hausdorff_distance_image_filter = sitk.HausdorffDistanceImageFilter()
    hausdorff_distance_image_filter.Execute(ground_truth, before_registration)
    print("Hausdorff distance before registration (mm): {:.2f}".format(hausdorff_distance_image_filter.GetHausdorffDistance()))
    hausdorff_distance_image_filter.Execute(ground_truth, after_registration)
    print("Hausdorff distance after registration (mm): {:.2f}".format(hausdorff_distance_image_filter.GetHausdorffDistance()))

In [None]:
evaluate_registration_via_segmentation(transform, fixed_image, fixed_mask, moving_mask, lung_label)

## Exercises

In this section you will explore the effects of various settings on the registration:
<ol>
<li>
Modify the spacing of the FFD control grid: how do the accuracy and runtime of the registration change when using a sparser grid?
</li>
<li>
Modify the multi-resolution settings (shrink factors and smoothing sigmas) used in the FFD registration.
</li>
<li>
In the Demons filter based registration:
<ol>
    <li>
    Modify the shrink factor and smoothing sigma: how do they affect registration accuracy and runtime?
    </li>
    <li>
    Try the other filters from the Demons family of filters and identify the one that yields the most accurate results.
    </li>
</ol>
</li>   
</ol>   