In [1]:
import itk

In [2]:
import os
import matplotlib
import matplotlib.pyplot as plt

In [3]:
# Image and reader types for 3-D 'signed short' data.
# Among the visible cells, FileReaderType is referenced only by the
# commented-out explicit-reader cell, and PixelType is rebound to 'float'
# by the image-loading cell that follows.
PixelType = itk.ctype('signed short')
DiffusionImageType = itk.Image[PixelType, 3]
FileReaderType = itk.ImageFileReader[DiffusionImageType]

In [22]:
# Load the fixed and moving images as float data, then derive the image
# types used by the rest of the pipeline from what was actually read.
PixelType = itk.ctype('float')
fixed_image, moving_image = (
    itk.imread(path, PixelType)
    for path in ('Data/case6_gre1.nrrd', 'Data/case6_gre2.nrrd')
)
dimension = fixed_image.GetImageDimension()
# Both images share pixel type and dimension, so both names bind one image type.
FixedImageType = MovingImageType = itk.Image[PixelType, dimension]

In [17]:
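# # Reading the images (disabled alternative to itk.imread, using FileReaderType)
# # NOTE(review): a single reader is reused for both files, so image1 and image2
# # would presumably refer to the same output object after the second Update().
# # Confirm, and use one reader per file, before re-enabling this cell.
# reader = FileReaderType.New()
# reader.SetFileName('Data/case6_gre1.nrrd')
# reader.Update()
# image1 = reader.GetOutput()
# reader.SetFileName('Data/case6_gre2.nrrd')
# reader.Update()
# image2 = reader.GetOutput()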

In [27]:
# Translation transform type (double precision, same dimension as the images).
# initial_transform is later passed to the registration as its InitialTransform.
TransformType = itk.TranslationTransform[itk.D, dimension]
initial_transform = TransformType.New()

In [28]:
# Regular-step gradient descent optimizer, configured at construction time.
# The keyword arguments map onto the corresponding Set<Name>() calls.
optimizer = itk.RegularStepGradientDescentOptimizerv4.New(
    LearningRate=4,
    MinimumStepLength=0.001,
    NumberOfIterations=200,
)

In [29]:
# Mean-squares similarity metric, with an explicit linear interpolator
# for the fixed image supplied at construction.
fixed_image_interpolation = itk.LinearInterpolateImageFunction[FixedImageType, itk.D].New()
metric = itk.MeanSquaresImageToImageMetricv4[FixedImageType, MovingImageType].New(
    FixedInterpolator=fixed_image_interpolation
)

In [30]:
# Assemble the registration method from the pieces built in earlier cells:
# the two images, the similarity metric, the optimizer and the transform
# that the optimizer will update.
registration = itk.ImageRegistrationMethodv4[FixedImageType, MovingImageType].New()
registration.SetFixedImage(fixed_image)
registration.SetMovingImage(moving_image)
registration.SetMetric(metric)
registration.SetOptimizer(optimizer)
registration.SetInitialTransform(initial_transform)

In [31]:
# Moving initial transform: a translation with every parameter set to zero.
# All parameters are zeroed (not just indices 0 and 1) so the cell stays
# correct whatever `dimension` the loaded images turn out to have.
moving_initial_transform = TransformType.New()
initial_parameters = moving_initial_transform.GetParameters()
for parameter_index in range(moving_initial_transform.GetNumberOfParameters()):
    initial_parameters[parameter_index] = 0
moving_initial_transform.SetParameters(initial_parameters)
registration.SetMovingInitialTransform(moving_initial_transform)

In [32]:
# Single-resolution run: one level, no smoothing, no shrinking.
registration.SetNumberOfLevels(1)
registration.SetSmoothingSigmasPerLevel([0])
registration.SetShrinkFactorsPerLevel([1])

# The fixed image gets an explicit identity as its initial transform.
identity_transform = TransformType.New()
identity_transform.SetIdentity()
registration.SetFixedInitialTransform(identity_transform)

# Everything is configured; run the registration.
registration.Update()

In [33]:
# Pull the resulting transform out of the registration and read back
# its first two parameters (reported below as the X and Y translation).
transform = registration.GetTransform()
finalParameters = transform.GetParameters()
translationAlongX, translationAlongY = (
    finalParameters.GetElement(axis) for axis in (0, 1)
)

In [34]:
# Iteration counter reported by the optimizer after the registration run.
number_of_iterations = optimizer.GetCurrentIteration()

In [35]:
# Value reported by the optimizer after the run (printed below as the metric value).
best_value = optimizer.GetValue()

In [37]:
# Report the registration outcome, one labelled value per line.
# sep="" keeps each label and its value adjacent, exactly like concatenation.
print("Result = ")
print(" Translation X = ", translationAlongX, sep="")
print(" Translation Y = ", translationAlongY, sep="")
print(" Iterations    = ", number_of_iterations, sep="")
print(" Metric value  = ", best_value, sep="")

Result = 
 Translation X = -0.835883044801508
 Translation Y = -3.5406643556563693
 Iterations    = 31
 Metric value  = 11177.92375512222


In [41]:
# Compose the full moving-image mapping: the moving initial transform
# followed by the transform optimized by the registration.
#
# `initial_transform` must NOT be added here. It was handed to the
# registration as InitialTransform, i.e. it is the transform being optimized;
# with ITK's default in-place behaviour it is the very object returned by
# GetModifiableTransform(), so adding both would stack the optimized
# translation twice while leaving out the moving initial transform.
CompositeTransformType = itk.CompositeTransform[itk.D, dimension]
outputCompositeTransform = CompositeTransformType.New()
outputCompositeTransform.AddTransform(moving_initial_transform)
outputCompositeTransform.AddTransform(registration.GetModifiableTransform())

In [42]:
# Resample the moving image onto the fixed image's grid (the fixed image is
# used as the reference image) through the transform obtained from the
# registration.
# NOTE(review): only `transform` is applied here; outputCompositeTransform is
# not used, so the moving initial transform is not composed in. That is
# harmless while the moving initial transform has all-zero parameters, but
# confirm whether the composite was meant to be passed instead.
resampler = itk.ResampleImageFilter.New(
    Input=moving_image,
    Transform=transform,
    UseReferenceImage=True,
    ReferenceImage=fixed_image
)

In [43]:
# Default pixel value for the resampler output.
# NOTE(review): presumably used for output pixels that map outside the moving image - confirm.
resampler.SetDefaultPixelValue(100)

In [44]:
# Difference between the fixed image and the resampled moving image.
# Uses the snake_case procedural function: calling the CamelCase filter class
# directly is ITK's deprecated procedural interface. Like that call, this
# runs the filter immediately and returns its output.
subtraction = itk.subtract_image_filter(Input1=fixed_image, Input2=resampler)

It might also be interesting to test this with an intensity rescaler (e.g. `itk.RescaleIntensityImageFilter`).

In [46]:
# Cast the resampled image to unsigned char for writing.
# NOTE(review): this is a plain cast with no intensity rescaling; values
# outside the unsigned char range presumably will not be preserved - confirm.
OutputPixelType = itk.ctype('unsigned char')
OutputImageType = itk.Image[OutputPixelType, dimension]
caster = itk.CastImageFilter[FixedImageType, OutputImageType].New(Input=resampler)

In [47]:
# Write the cast result to "res.nrrd" (path relative to the working directory).
itk.imwrite(caster, "res.nrrd")