# Registration Example

In [1]:
%matplotlib inline
import matplotlib.pyplot as plt

import SimpleITK as sitk

print(sitk.Version())
from myshow import myshow

# Download data to work on
%run update_path_to_download_script
from downloaddata import fetch_data as fdata

OUTPUT_DIR = "Output"

SimpleITK Version: 2.1.1 (ITK 5.2)
Compiled: Sep  9 2021 19:12:54



This section of the Visible Human Male is about 1.5GB. To speed up processing and registration, we crop it to the region of interest and reduce its resolution. Note that these operations preserve physical space: the cropped, downsampled image occupies the same physical location as the corresponding region of the original.

In [2]:
# Fixed image: the RGB cryosection. Crop to the head region in x/y (all z
# slices kept), then average 3x3 blocks in-plane to reduce the resolution.
# Nesting the calls means the full-size volume is never bound to a name.
fixed_rgb = sitk.BinShrink(
    sitk.ReadImage(fdata("vm_head_rgb.mha"))[735:1330, 204:975, :],
    [3, 3, 1],
)

Fetching vm_head_rgb.mha


In [3]:
# Moving image: the head MRI volume, fetched via `fdata` (downloaddata.fetch_data).
moving = sitk.ReadImage(fdata("vm_head_mri.mha"))

Fetching vm_head_mri.mha


In [4]:
# Quick look at the MRI (moving image) before any registration.
myshow(moving)

interactive(children=(IntSlider(value=16, description='z', max=32), Output()), _dom_classes=('widget-interact'…

In [5]:
# Segment the blue ice by region growing from a single seed voxel at index
# (10, 10, 10), assumed to lie in the ice (confirm visually).
seeds = [[10] * 3]
fixed_mask = sitk.VectorConfidenceConnected(
    fixed_rgb, seedList=seeds, initialNeighborhoodRadius=5,
    numberOfIterations=4, multiplier=8,
)

In [6]:
# Invert the ice segment (== 0), split it into connected components, and keep
# the largest one: RelabelComponent orders labels by component size, so
# label 1 is the largest (presumably the specimen).
# NOTE(review): this cell overwrites its own input `fixed_mask`; re-running it
# without first re-running the segmentation cell inverts the specimen mask.
fixed_mask = sitk.RelabelComponent(sitk.ConnectedComponent(fixed_mask == 0)) == 1

In [7]:
# Visual check of the segmentation: the cryosection restricted to the mask.
# The trailing ';' suppresses display of the call's return value.
myshow(sitk.Mask(fixed_rgb, fixed_mask));

interactive(children=(IntSlider(value=109, description='z', max=219), Output()), _dom_classes=('widget-interac…

In [8]:
# Registration needs scalar images: take the red channel (index 0) of the
# cryosection, and bring both fixed and moving images to 32-bit float.
fixed = sitk.Cast(sitk.VectorIndexSelectionCast(fixed_rgb, 0), sitk.sitkFloat32)
moving = sitk.Cast(moving, sitk.sitkFloat32)

In [9]:
# Rigid initialisation: align the moments of the specimen mask (cast to the
# moving image's pixel type) with those of the MRI, starting from a fresh
# Euler3DTransform.
initialTransform = sitk.CenteredTransformInitializer(
    sitk.Cast(fixed_mask, moving.GetPixelID()),
    moving,
    sitk.Euler3DTransform(),
    sitk.CenteredTransformInitializerFilter.MOMENTS,
)
print(initialTransform)

itk::simple::Transform
 Euler3DTransform (0x7ff26613f310)
   RTTI typeinfo:   itk::Euler3DTransform<double>
   Reference Count: 1
   Modified Time: 1659
   Debug: Off
   Object Name: 
   Observers: 
     none
   Matrix: 
     1 0 0 
     0 1 0 
     0 0 1 
   Offset: [-0.305777, 2.55523, 135.031]
   Center: [0.853078, -32.3853, -144.027]
   Translation: [-0.305777, 2.55523, 135.031]
   Inverse: 
     1 0 0 
     0 1 0 
     0 0 1 
   Singular: 0
   Euler's angles: AngleX=0 AngleY=0 AngleZ=0
   m_ComputeZYX = 0



In [10]:
def command_iteration(method):
    """Print one line of optimizer progress for a registration iteration.

    Intended as an iteration-event callback of a SimpleITK
    ImageRegistrationMethod.

    Args:
        method: object exposing GetOptimizerIteration(), GetMetricValue()
            and GetOptimizerPosition().

    Output format: ``"<iteration> = <metric value> : <optimizer position>"``.
    """
    # flush=True pushes each line out immediately during a long-running
    # Execute() without relying on `sys` having been imported in another cell.
    print(
        f"{method.GetOptimizerIteration()} = {method.GetMetricValue()} : "
        f"{method.GetOptimizerPosition()}",
        flush=True,
    )

In [11]:
# Stage 1: multi-resolution rigid registration with Mattes mutual information.
# NOTE: SetInitialTransform works in place by default, so R.Execute updates
# `tx`, which is the same object as `initialTransform`. Re-running this cell
# without re-running the initialisation cell starts from the previous result.
METRIC_SAMPLING_SEED = 42  # fixed seed so the random metric sampling is repeatable

tx = initialTransform
R = sitk.ImageRegistrationMethod()
R.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)
R.SetOptimizerAsGradientDescentLineSearch(learningRate=1, numberOfIterations=100)
R.SetOptimizerScalesFromIndexShift()
R.SetShrinkFactorsPerLevel([4, 2, 1])
R.SetSmoothingSigmasPerLevel([8, 4, 2])
R.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()
R.SetMetricSamplingStrategy(R.RANDOM)
# Sample 10% of voxels; without an explicit seed each run samples differently.
R.SetMetricSamplingPercentage(0.1, METRIC_SAMPLING_SEED)
R.SetInitialTransform(tx)
R.SetInterpolator(sitk.sitkLinear)

In [12]:
import sys

# Replace any callbacks from a previous run of this cell, then report
# optimizer progress on every iteration.
R.RemoveAllCommands()
R.AddCommand(sitk.sitkIterationEvent, lambda: command_iteration(R))
# `fixed` and `moving` were already cast to sitkFloat32 when the scalar images
# were prepared, so they are passed directly instead of being copied again.
outTx = R.Execute(fixed, moving)

print("-------")
print(tx)
print(f"Optimizer stop condition: {R.GetOptimizerStopConditionDescription()}")
print(f" Iteration: {R.GetOptimizerIteration()}")
print(f" Metric value: {R.GetMetricValue()}")

0 = -0.7783034043387447 : (-0.0006885196551277491, -0.0007332778688395892, -0.00021269100301044553, 0.056480819798521265, 2.7329889661154665, 109.83377112217079)
1 = -1.1826197821785007 : (-0.0009949760749885673, -0.0010399690008088237, -0.00033242995448211544, -0.16572634118754762, 2.880283610168554, 108.55338426165348)
2 = -1.1878192999069033 : (-0.002529719360354595, -0.0024548897956290557, -0.000844741780918456, -1.0315656340592119, 3.498520164268452, 107.83271247151318)
3 = -1.2029384293406742 : (-0.0040712802708309515, -0.003094330506255131, -0.0013878104018923481, -1.0991143254944442, 3.8334347913674613, 109.09343200328605)
4 = -1.2095974294124598 : (-0.004122191968468614, -0.0031157358434911945, -0.0014089194366907812, -1.0942264871291258, 3.837315380269295, 108.9175905069227)
5 = -1.2096294962896301 : (-0.00420471777257255, -0.0031503435765434006, -0.0014426759363930703, -1.0876801411965875, 3.8437761904234313, 108.68582918486115)
6 = -1.2091535125919437 : (-0.0043403814991171

In [13]:
# Stage 2: compose the rigid result with a 3D affine transform and refine.
# `initialTransform` was updated in place by the first registration, so the
# composite starts from the rigid solution.
tx = sitk.CompositeTransform([initialTransform, sitk.AffineTransform(3)])

# The same method object R is reused: metric, sampling and interpolator
# settings from stage 1 carry over; only the optimizer and a shorter two-level
# pyramid are reconfigured here.
R.SetOptimizerAsGradientDescentLineSearch(learningRate=1, numberOfIterations=100)
R.SetOptimizerScalesFromIndexShift()
R.SetShrinkFactorsPerLevel([2, 1])
R.SetSmoothingSigmasPerLevel([4, 1])
R.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()
R.SetInitialTransform(tx)

In [14]:
# Run the refinement stage; the iteration callback attached for stage 1 is
# still registered on R and prints progress. `fixed` and `moving` are already
# sitkFloat32, so no re-cast (and image copy) is needed.
outTx = R.Execute(fixed, moving)
R.GetOptimizerStopConditionDescription()

0 = -0.8866994855511292 : (0.9931044010262785, 0.000998094254173154, -0.0004640862447812494, 0.0006808424565345121, 0.9965596185365826, 3.3463936853381874e-05, -0.005014089391081763, 0.023717847949674024, 0.9846092913127786, 0.12779131994826565, 0.3115370278689434, -1.6381147946307035)
1 = -0.897749326891908 : (0.991499319578755, 0.001400770694239334, -0.0005209383289389828, 0.0009547721770138657, 0.9957861516208987, 0.0002811282237854528, -0.007752520338168291, 0.021786094549257937, 0.9774482711506816, 0.12016184251311304, 0.2659411731044579, 0.09654919888530311)
2 = -0.9147524992601568 : (0.991175130643152, 0.0014821393397608108, -0.0005477662600994489, 0.001005409095812999, 0.9956021519886501, 0.00034632785536775515, -0.008119583811711132, 0.02293398506497117, 0.9770470451187708, 0.12559046267913535, 0.25563125884930665, -0.4483488785278744)
3 = -0.9155608022653271 : (0.9906709476496084, 0.0016129661322483043, -0.0005869543509221713, 0.0010867748377016249, 0.9953165187824713, 0.0004

'GradientDescentLineSearchOptimizerv4Template: Convergence checker passed at iteration 31.'

In [15]:
# Map the MRI onto the cropped cryosection grid using the final transform.
resample = sitk.ResampleImageFilter()
resample.SetReferenceImage(fixed_rgb)
resample.SetInterpolator(sitk.sitkBSpline)
resample.SetTransform(outTx)
# A single progress callback that prints and flushes; flush=True removes the
# need for a second callback that depends on `sys` imported in another cell.
resample.AddCommand(
    sitk.sitkProgressEvent,
    lambda: print(
        f"\rProgress: {100*resample.GetProgress():03.1f}%...", end="", flush=True
    ),
)
resample.AddCommand(sitk.sitkEndEvent, lambda: print("Done"))
out = resample.Execute(moving)

Progress: 100.0%...Done


In [16]:
# Turn the resampled MRI into a gray RGB image (same intensity in all three
# channels, rescaled and cast to 8-bit) so it can be interleaved with the RGB
# cryosection in checkerboard composites.
out_rgb = sitk.Cast(sitk.Compose([sitk.RescaleIntensity(out)] * 3), sitk.sitkVectorUInt8)

vis_xy = sitk.CheckerBoard(fixed_rgb, out_rgb, checkerPattern=[8, 8, 1])
# The xz composite is permuted so its y axis becomes the displayed slice axis.
vis_xz = sitk.PermuteAxes(
    sitk.CheckerBoard(fixed_rgb, out_rgb, checkerPattern=[8, 1, 8]),
    [0, 2, 1],
)

In [17]:
# Inspect registration quality: checkerboard of cryosection vs. resampled MRI.
# NOTE(review): dpi=30 presumably keeps the displayed figure small — confirm in myshow.
myshow(vis_xz, dpi=30)

interactive(children=(IntSlider(value=128, description='z', max=256), Output()), _dom_classes=('widget-interac…

In [18]:
import os

# Create the output directory if needed; the image writer is not expected to
# create missing directories, so a fresh checkout would otherwise fail here.
os.makedirs(OUTPUT_DIR, exist_ok=True)

sitk.WriteImage(out, os.path.join(OUTPUT_DIR, "example_registration.mha"))
sitk.WriteImage(vis_xy, os.path.join(OUTPUT_DIR, "example_registration_xy.mha"))
sitk.WriteImage(vis_xz, os.path.join(OUTPUT_DIR, "example_registration_xz.mha"))