# Registration Example

In [None]:
# IPython magic: render matplotlib figures inline in the notebook.
%matplotlib inline
import matplotlib.pyplot as plt

import SimpleITK as sitk

# Report which SimpleITK build is in use.
print(sitk.Version())
# myshow: project-local image display helper (defined outside this notebook).
from myshow import myshow

# Download data to work on.
# NOTE(review): presumably this script puts the `downloaddata` module on the
# import path, so it must run before the import below — confirm.
%run update_path_to_download_script
from downloaddata import fetch_data as fdata

# Directory that the final cell writes its result images into.
OUTPUT_DIR = "Output"

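This section of the Visible Human Male is about 1.5 GB. To expedite processing and registration, we crop the region of interest and reduce the resolution. Note that the physical space is preserved through these operations: cropping and shrinking update the image origin and spacing, so the reduced image still occupies the same physical coordinates.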

In [None]:
# Load the RGB head volume, crop the region of interest by index ranges
# (all slices kept along the third axis), then bin-shrink by factors
# [3, 3, 1] so only the first two axes are reduced.
fixed_rgb = sitk.BinShrink(
    sitk.ReadImage(fdata("vm_head_rgb.mha"))[735:1330, 204:975, :],
    [3, 3, 1],
)

In [None]:
# Moving image: the head volume in vm_head_mri.mha (an MRI, per the file name).
moving = sitk.ReadImage(fdata("vm_head_mri.mha"))

In [None]:
# Visual check of the moving volume (myshow is the project display helper).
myshow(moving)

In [None]:
# Segment the blue ice with vector (RGB) confidence-connected region growing,
# starting from one seed voxel. NOTE(review): index (10, 10, 10) is assumed
# to lie inside the ice.
seeds = [[10, 10, 10]]
fixed_mask = sitk.VectorConfidenceConnected(
    fixed_rgb,
    seedList=seeds,
    multiplier=8,
    numberOfIterations=4,
    initialNeighborhoodRadius=5,
)

In [None]:
# Invert the segment and choose the largest component:
#   `fixed_mask == 0`     -> voxels NOT in the ice segment
#   ConnectedComponent    -> label each connected region of those voxels
#   RelabelComponent      -> relabel regions by size, so label 1 is the largest
#   `== 1`                -> binary mask of that largest region only
fixed_mask = sitk.RelabelComponent(sitk.ConnectedComponent(fixed_mask == 0)) == 1

In [None]:
# Show the RGB volume with voxels outside fixed_mask masked out. The trailing
# ';' suppresses the notebook's display of the cell's return value.
myshow(sitk.Mask(fixed_rgb, fixed_mask));

In [None]:
# Scalar fixed image = red channel (component 0) of the RGB volume. Both the
# fixed and moving images are converted to 32-bit float so they share a
# floating-point pixel type for registration.
fixed = sitk.Cast(sitk.VectorIndexSelectionCast(fixed_rgb, 0), sitk.sitkFloat32)
moving = sitk.Cast(moving, sitk.sitkFloat32)

In [None]:
# Initial rigid (Euler 3-D) alignment from image moments. The fixed mask,
# cast to the moving image's pixel type, stands in for the fixed image so
# only the segmented specimen drives the moment computation.
initialTransform = sitk.CenteredTransformInitializer(
    sitk.Cast(fixed_mask, moving.GetPixelID()),
    moving,
    sitk.Euler3DTransform(),
    sitk.CenteredTransformInitializerFilter.MOMENTS,
)
print(initialTransform)

In [None]:
def command_iteration(method):
    """Print one progress line for a registration iteration.

    Used as an ``sitk.sitkIterationEvent`` callback. Prints
    ``"<iteration> = <metric value> : <optimizer position>"`` and flushes
    stdout immediately, so progress appears live in the notebook.

    Parameters
    ----------
    method
        Object providing ``GetOptimizerIteration()``, ``GetMetricValue()``
        and ``GetOptimizerPosition()``, e.g. ``sitk.ImageRegistrationMethod``.
    """
    # flush=True replaces the former sys.stdout.flush() call. `sys` was only
    # imported in a later cell, so the callback raised NameError unless that
    # cell had already run.
    print(
        f"{method.GetOptimizerIteration()} = {method.GetMetricValue()} : "
        f"{method.GetOptimizerPosition()}",
        flush=True,
    )

In [None]:
# Configure the first registration stage; the transform being optimized is
# initialTransform (the moment-initialized Euler 3-D transform).
tx = initialTransform
R = sitk.ImageRegistrationMethod()
# Mattes mutual information (50 histogram bins): a multi-modal metric, used
# here because the fixed (red channel) and moving intensities differ in kind.
R.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)
R.SetOptimizerAsGradientDescentLineSearch(learningRate=1, numberOfIterations=100)
# Estimate per-parameter optimizer scales so parameters with different units
# (rotation angles vs. translations) are balanced.
R.SetOptimizerScalesFromIndexShift()
# Three-level multi-resolution pyramid: shrink factors 4, 2, 1 with smoothing
# sigmas 8, 4, 2, specified in physical units rather than voxels.
R.SetShrinkFactorsPerLevel([4, 2, 1])
R.SetSmoothingSigmasPerLevel([8, 4, 2])
R.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()
# Evaluate the metric on a random subset (fraction 0.1) of the voxels.
R.SetMetricSamplingStrategy(R.RANDOM)
# specifying a seed eliminates registration variability associated with the
# random sampling
R.SetMetricSamplingPercentage(percentage=0.1, seed=42)
R.SetInitialTransform(tx)
R.SetInterpolator(sitk.sitkLinear)

In [None]:
# NOTE(review): `sys` may be referenced by callbacks defined in other cells
# (stdout flushing); kept imported here.
import sys

# Run the first registration stage. Remove any callbacks left from a previous
# run of this cell, so re-running does not print duplicate lines, then print
# one progress line per optimizer iteration.
R.RemoveAllCommands()
R.AddCommand(sitk.sitkIterationEvent, lambda: command_iteration(R))
# fixed and moving were already cast to sitkFloat32 in an earlier cell, so
# these casts are redundant but harmless.
outTx = R.Execute(
    sitk.Cast(fixed, sitk.sitkFloat32), sitk.Cast(moving, sitk.sitkFloat32)
)

print("-------")
# NOTE(review): prints tx (the initial transform object), not outTx. This
# relies on Execute updating the initial transform in place — confirm.
print(tx)
print(f"Optimizer stop condition: {R.GetOptimizerStopConditionDescription()}")
print(f" Iteration: {R.GetOptimizerIteration()}")
print(f" Metric value: {R.GetMetricValue()}")

In [None]:
# Second registration stage: append a new, default-constructed 3-D affine
# transform after the stage-one transform.
# NOTE(review): this assumes initialTransform now holds the stage-one result
# (in-place update) and that only the last transform of the composite is
# optimized — confirm against the SimpleITK documentation.
tx = sitk.CompositeTransform([initialTransform, sitk.AffineTransform(3)])

# Same optimizer settings as stage one; request index-shift scales again so
# they are estimated for the new affine parameters.
R.SetOptimizerAsGradientDescentLineSearch(learningRate=1, numberOfIterations=100)
R.SetOptimizerScalesFromIndexShift()
# Finer two-level pyramid (shrink 2, 1; sigmas 4, 1 in physical units). The
# metric, sampling and interpolator set for stage one are not changed here.
R.SetShrinkFactorsPerLevel([2, 1])
R.SetSmoothingSigmasPerLevel([4, 1])
R.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()
R.SetInitialTransform(tx)

In [None]:
# Run the second (affine) stage. Its returned transform, outTx, is the one
# used by the resampling cell. The iteration callback registered for stage
# one is still attached to R, so progress is printed again.
outTx = R.Execute(
    sitk.Cast(fixed, sitk.sitkFloat32), sitk.Cast(moving, sitk.sitkFloat32)
)
# Bare expression: the notebook displays the optimizer's stop reason.
R.GetOptimizerStopConditionDescription()

In [None]:
# Resample the moving image onto the fixed RGB image's grid (origin, spacing,
# direction, size) using the final registration transform and B-spline
# interpolation, with a live progress readout.
resample = sitk.ResampleImageFilter()
resample.SetReferenceImage(fixed_rgb)
resample.SetInterpolator(sitk.sitkBSpline3)
resample.SetTransform(outTx)
# Print and flush in a single callback. The previous separate flush callback
# depended on `sys` having been imported by another cell.
resample.AddCommand(
    sitk.sitkProgressEvent,
    lambda: print(
        f"\rProgress: {100*resample.GetProgress():03.1f}%...", end="", flush=True
    ),
)
resample.AddCommand(sitk.sitkEndEvent, lambda: print("Done"))
out = resample.Execute(moving)

In [None]:
# Turn the resampled image into an 8-bit, 3-channel gray "RGB" image (the
# same rescaled channel repeated three times) so it can be interleaved with
# the fixed RGB image.
out_rgb = sitk.Compose([sitk.RescaleIntensity(out)] * 3)
out_rgb = sitk.Cast(out_rgb, sitk.sitkVectorUInt8)

# Checkerboard with tiles alternating along the first two axes.
vis_xy = sitk.CheckerBoard(fixed_rgb, out_rgb, checkerPattern=[8, 8, 1])
# Checkerboard alternating along the first and third axes, then swap axes 1
# and 2 so the third axis becomes the second (displayed) one.
vis_xz = sitk.PermuteAxes(
    sitk.CheckerBoard(fixed_rgb, out_rgb, checkerPattern=[8, 1, 8]), [0, 2, 1]
)

In [None]:
# Display the axis-permuted checkerboard at low dpi (presumably to keep the
# large rendered figure small).
myshow(vis_xz, dpi=30)

In [None]:
import os

# Save the resampled image and both checkerboard visualizations as .mha files
# under OUTPUT_DIR. Create the directory first, because WriteImage is not
# expected to create missing directories and would otherwise fail.
os.makedirs(OUTPUT_DIR, exist_ok=True)
sitk.WriteImage(out, os.path.join(OUTPUT_DIR, "example_registration.mha"))
sitk.WriteImage(vis_xy, os.path.join(OUTPUT_DIR, "example_registration_xy.mha"))
sitk.WriteImage(vis_xz, os.path.join(OUTPUT_DIR, "example_registration_xz.mha"))