# 3D Image Registration

## Import required packages

In [2]:
import itk
from itkwidgets import view 
from itkwidgets import compare, checkerboard

## Load fixed and moving volume

In [3]:

fixed_image_filename="VF-MRT1-1014-1174.vtk"
moving_image_filename="VF-MRT2-1014-1174.vtk"
fixed_image = itk.imread(fixed_image_filename, itk.F)
moving_image = itk.imread(moving_image_filename, itk.F)

## Visualize input images 

In [156]:
view(fixed_image)

Viewer(geometries=[], gradient_opacity=0.22, point_sets=[], rendered_image=<itk.itkImagePython.itkImageF3; pro…

In [157]:
view(moving_image)

Viewer(geometries=[], gradient_opacity=0.22, point_sets=[], rendered_image=<itk.itkImagePython.itkImageF3; pro…

In [158]:
checkerboard(fixed_image, moving_image)

VBox(children=(Viewer(annotations=False, interpolation=False, rendered_image=<itk.itkImagePython.itkImageF3; p…

In [159]:
compare(fixed_image, moving_image)

AppLayout(children=(HBox(children=(Label(value='Link:'), Checkbox(value=False, description='cmap'), Checkbox(v…

In [4]:
%matplotlib inline
import matplotlib.pyplot as plt

from ipywidgets import interact, fixed
from IPython.display import clear_output

# Callback invoked by the interact IPython method for scrolling through the image stacks of
# the two images (moving and fixed).
def display_images(fixed_image_z, moving_image_z, fixed_npa, moving_npa):
    # Create a figure with two subplots and the specified size.
    plt.subplots(1, 2, figsize=(10, 8))

    # Draw the fixed image in the first subplot.
    plt.subplot(1, 2, 1)
    plt.imshow(fixed_npa[fixed_image_z, :, :], cmap=plt.cm.Greys_r)
    plt.title("fixed image")
    plt.axis("off")

    # Draw the moving image in the second subplot.
    plt.subplot(1, 2, 2)
    plt.imshow(moving_npa[moving_image_z, :, :], cmap=plt.cm.Greys_r)
    plt.title("moving image resampled")
    plt.axis("off")

    plt.show()


# Callback invoked by the IPython interact method for scrolling and modifying the alpha blending
# of an image stack of two images that occupy the same physical space.
def display_images_with_alpha(image_z, alpha, fixed, moving):
    img = (1.0 - alpha) * fixed[:, :, image_z] + alpha * moving[:, :, image_z]
    
    plt.imshow(itk.GetImageFromArray(img), cmap=plt.cm.Greys_r)
    plt.axis("off")
    plt.show()


# Callback invoked when the StartEvent happens, sets up our new data.
def start_plot():
    global metric_values, multires_iterations

    metric_values = []
    multires_iterations = []


# Callback invoked when the EndEvent happens, do cleanup of data and figure.
def end_plot():
    global metric_values, multires_iterations

    del metric_values
    del multires_iterations
    # Close figure, we don't want to get a duplicate of the plot latter on.
    plt.close()


# Callback invoked when the IterationEvent happens, update our data and display new figure.
def plot_values(registration_method):
    global metric_values, multires_iterations

    metric_values.append(registration_method.GetMetricValue())
    # Clear the output area (wait=True, to reduce flickering), and plot current data
    clear_output(wait=True)
    # Plot the similarity metric values
    plt.plot(metric_values, "r")
    plt.plot(
        multires_iterations,
        [metric_values[index] for index in multires_iterations],
        "b*",
    )
    plt.xlabel("Iteration Number", fontsize=12)
    plt.ylabel("Metric Value", fontsize=12)
    plt.show()


# Callback invoked when the sitkMultiResolutionIterationEvent happens, update the index into the
# metric_values list.
def update_multires_iterations():
    global metric_values, multires_iterations
    multires_iterations.append(len(metric_values))

In [5]:

interact(
    display_images,
    fixed_image_z=(0, itk.size(fixed_image)[2] - 1),
    moving_image_z=(0, itk.size(moving_image)[2] - 1),
    fixed_npa=fixed(itk.GetArrayFromImage(fixed_image)),
    moving_npa=fixed(itk.GetArrayFromImage(moving_image)),
);

interactive(children=(IntSlider(value=16, description='fixed_image_z', max=32), IntSlider(value=31, descriptio…

## Initialize parameters and perform image registration

In [6]:

Dimension = 3
PixelType = itk.F

ImageType = itk.Image[PixelType, Dimension]
TransformType = itk.TranslationTransform[itk.D, Dimension]
OptimizerType = itk.RegularStepGradientDescentOptimizer
#MetricType = itk.MeanSquaresImageToImageMetric[ImageType, ImageType]
MetricType = itk.MattesMutualInformationImageToImageMetric[ImageType, ImageType]
InterpolatorType = itk.NearestNeighborInterpolateImageFunction[ImageType, itk.D]
RegistrationType = itk.ImageRegistrationMethod[ImageType, ImageType]

FixedImageType = ImageType
MovingImageType = ImageType

In [7]:
metric = MetricType.New()
transform = TransformType.New()
optimizer = OptimizerType.New()
interpolator = InterpolatorType.New()
registration = RegistrationType.New()

In [8]:
registration.SetFixedImage(fixed_image)
registration.SetMovingImage(moving_image)
registration.SetMetric(metric)
registration.SetOptimizer(optimizer)
registration.SetTransform(transform)
registration.SetInterpolator(interpolator)
registration.SetFixedImageRegion(fixed_image.GetLargestPossibleRegion())
registration.SetInitialTransformParameters(transform.GetParameters())

In [9]:
optimizer.SetMaximumStepLength(4.00)
optimizer.SetMinimumStepLength(0.001)
optimizer.SetNumberOfIterations(200)

optimizer.SetRelaxationFactor(0.5)

registration.Update()

finalParameters = registration.GetLastTransformParameters()
TranslationAlongX = finalParameters[0]
TranslationAlongY = finalParameters[1]
TranslationAlongZ = finalParameters[2]
numberOfIterations = optimizer.GetCurrentIteration()
bestValue = optimizer.GetValue()

print("Result = ") 
print(" Translation X = " , TranslationAlongX)
print(" Translation Y = " , TranslationAlongY)
print(" Translation Z = " , TranslationAlongZ)
print(" Iterations    = " ,numberOfIterations)
print(" Metric value  = " ,bestValue)


Result = 
 Translation X =  -0.296669490292377
 Translation Y =  -8.196658153847325
 Translation Z =  -2.4992790081625036
 Iterations    =  22
 Metric value  =  -0.7161121058731182


## Resample moving image 

In [10]:
ResampleFilterType = itk.ResampleImageFilter[ImageType, ImageType]
resampler = ResampleFilterType.New()
resampler.SetInput(moving_image)
resampler.SetTransform(registration.GetOutput().Get())
resampler.SetSize(fixed_image.GetLargestPossibleRegion().GetSize())
resampler.SetOutputOrigin(fixed_image.GetOrigin())
resampler.SetOutputSpacing(fixed_image.GetSpacing())
resampler.SetOutputDirection(fixed_image.GetDirection())
resampler.SetDefaultPixelValue(100)
resampler.SetInterpolator(interpolator)
resampler.Update()

## Image registration results 

In [11]:
interact(
    display_images,
    fixed_image_z=(0, itk.size(fixed_image)[2] - 1),
    moving_image_z=(0, itk.size(resampler.GetOutput())[2] - 1),
    fixed_npa=fixed(itk.GetArrayFromImage(fixed_image)),
    moving_npa=fixed(itk.GetArrayFromImage(resampler.GetOutput())),
);

interactive(children=(IntSlider(value=16, description='fixed_image_z', max=32), IntSlider(value=16, descriptio…

In [12]:
final_transform=registration.GetTransform()
print(final_transform)

TranslationTransform (000001D5EB3D0B70)
  RTTI typeinfo:   class itk::TranslationTransform<double,3>
  Reference Count: 6
  Modified Time: 594
  Debug: Off
  Object Name: 
  Observers: 
    none
  Offset: [-0.296669, -8.19666, -2.49928]



## Save results

In [13]:
itk.imwrite(resampler.GetOutput(),'moving_resampled_itk_MMI.mha')

In [14]:
writer = itk.TransformFileWriterTemplate[itk.F].New()
transform_filename = 'Transform_between_T1_T2_MRI.tfm'
writer.SetInput(final_transform)
writer.SetFileName(transform_filename)
writer.Update()