In [None]:
import os
import itertools
import itk
import itkwidgets
from skimage import io
import numpy as np
assert "ElastixRegistrationMethod" in dir(itk)  # Ensure itk-elastix is installed

In [None]:
# Input locations. Each may be overridden with an environment variable so the
# notebook can run outside the birdstore mount; the defaults are the original
# hard-coded paths, so behavior is unchanged when the variables are unset.
DATA = os.environ.get(
    'DK52_CH1_DIR',
    '/net/birdstore/Active_Atlas_Data/data_root/pipeline_data/DK52/preps/CH1',
)  # directory holding the moving volume (aligned_volume.256.tif)
REGDATA = os.environ.get(
    'ATLAS_REGISTRATION_DIR',
    '/net/birdstore/Active_Atlas_Data/data_root/brains_info/registration',
)  # directory holding the fixed atlas volume (allen_100um_sagittal.tif)

In [None]:
# Fixed (reference) volume for registration: the Allen atlas stack
# (the filename indicates 100 um voxels, sagittal slicing).
filename = 'allen_100um_sagittal.tif'
fixedFilepath = os.path.join(REGDATA, filename)
fixed_volume = io.imread(fixedFilepath)

# Unpacking doubles as a sanity check: it raises unless the stack is 3-D.
(fz, fy, fx) = fixed_volume.shape
print(f'Fixed volume shape={fixed_volume.shape} dtype={fixed_volume.dtype}')

In [None]:
# Moving volume: the DK52 CH1 aligned stack that will be registered onto the atlas.
movingFilepath = os.path.join(DATA, 'aligned_volume.256.tif')
moving_volume = io.imread(movingFilepath)

# As above, unpacking the shape fails loudly if the stack is not 3-D.
(mz, my, mx) = moving_volume.shape
print(f'Moving volume shape={moving_volume.shape} dtype={moving_volume.dtype}')

In [None]:
# Wrap both NumPy volumes as float32 ITK images (moving -> source, fixed -> target).
# NOTE(review): no SetSpacing/SetOrigin/SetDirection call is made on either image,
# so both keep itk.image_from_array's defaults (unit spacing, zero origin,
# identity direction) and every "physical" quantity computed later is really in
# voxel units. The atlas filename says 100 um; the moving volume's voxel size is
# not shown in this notebook. TODO confirm and set spacing before registering.
source_image, target_image = (
    itk.image_from_array(volume.astype(np.float32))
    for volume in (moving_volume, fixed_volume)
)

In [None]:
# Intended anatomical orientation of the input voxel data.
#
# ITK physical space is LPS: +x runs right-to-left, +y anterior-to-posterior,
# +z inferior-to-superior. "RAI" here means the same axis order with every axis
# inverted (left-to-right, posterior-to-anterior, superior-to-inferior).
#
# NOTE(review): this description was taken from the "acquisition.json" of a
# different dataset (SmartSPIM_631680_2022-09-09_13-52-33), and an earlier note
# stated that the moving DK52 volume is LPI -- yet the constant is RAI. Confirm
# which orientation DK52 really has. Also verify the enum semantics: ITK's
# legacy ITK_COORDINATE_ORIENTATION_* names are commonly described as naming the
# direction each axis points *from*, which may invert the reading above.
# The constant is not applied to any image in the cells shown here.
INPUT_COORDINATE_ORIENTATION = (
    itk.SpatialOrientationEnums
    .ValidCoordinateOrientations_ITK_COORDINATE_ORIENTATION_RAI
)

In [None]:
def get_bounds(image, transform=None):
    """Get the physical boundaries of the space sampled by an ITK image.

    Each voxel is treated as a sample of the volume it occupies, taken at that
    volume's spatial center, so the physical point of integer index ``i`` is a
    voxel center. The sampled region therefore extends half a voxel beyond the
    first and last voxel centers: continuous index ``-0.5`` on the low side and
    ``(size - 1) + 0.5 == size - 0.5`` on the high side of every axis.

    Args:
        image: ITK image whose sampled region is measured.
        transform: optional ITK transform; when given, both corner points are
            mapped through ``transform.TransformPoint``.

    Returns:
        A two-element list ``[lower_corner, upper_corner]`` of physical points.
        Only these two opposite corners are computed (and transformed), so with
        a non-identity direction matrix or a rotating transform they are not
        necessarily the per-axis minimum / maximum.
    """
    HALF_VOXEL_STEP = 0.5
    dimension = image.GetImageDimension()
    image_size = itk.size(image)

    lower_index = itk.ContinuousIndex[itk.D, dimension]()
    lower_index.Fill(-HALF_VOXEL_STEP)

    upper_index = itk.ContinuousIndex[itk.D, dimension]()
    for dim in range(dimension):
        # The last voxel center is at index size-1; its far edge lies half a
        # voxel beyond it. (size + 0.5 would overshoot by one whole voxel.)
        upper_index.SetElement(dim, image_size[dim] - HALF_VOXEL_STEP)

    image_bounds = [
        image.TransformContinuousIndexToPhysicalPoint(lower_index),
        image.TransformContinuousIndexToPhysicalPoint(upper_index),
    ]
    if transform is None:
        return image_bounds
    return [transform.TransformPoint(pt) for pt in image_bounds]

def get_physical_size(image, transform=None):
    """Return the per-axis extent of the physical space sampled by ``image``.

    The extent is the absolute difference between the two corner points
    reported by ``get_bounds`` (mapped through ``transform`` when one is
    given), returned as a NumPy array with one entry per axis.
    """
    lower_corner, upper_corner = get_bounds(image, transform)
    return np.abs(np.asarray(upper_corner) - np.asarray(lower_corner))

In [None]:
# Physical bounding corners of the fixed (CCF) and moving images, before alignment.
ccf_bounds = get_bounds(target_image)
print(f"CCF physical bounds: {ccf_bounds}")
moving_bounds = get_bounds(source_image)
print(f"Moving image physical bounds: {moving_bounds}")

In [None]:
# Per-axis physical extent of each image (same units as the image spacing).
ccf_size = get_physical_size(target_image)
print(f"CCF physical size: {ccf_size}")
moving_size = get_physical_size(source_image)
print(f"Moving image physical size: {moving_size}")

In [None]:
# Touch itk.CenteredTransformInitializer while progress reporting is enabled.
# The bare attribute access is presumably meant to force ITK's lazily-loaded
# Python package to load the submodule that defines it. try/finally makes sure
# progress reporting is switched back off even if that load raises.
itk.auto_progress(1)
try:
    itk.CenteredTransformInitializer
finally:
    itk.auto_progress(0)

In [None]:
# Translate to roughly position sample data on top of CCF data
init_transform = itk.VersorRigid3DTransform[
    itk.D
].New()  # Represents 3D rigid transformation with unit quaternion
init_transform.SetIdentity()

transform_initializer = itk.CenteredVersorTransformInitializer[
    type(target_image), type(source_image)
].New()
transform_initializer.SetFixedImage(target_image)
transform_initializer.SetMovingImage(source_image)
transform_initializer.SetTransform(init_transform)
transform_initializer.GeometryOn()  # We compute translation between the center of each image
transform_initializer.ComputeRotationOff()  # We have previously verified that spatial orientation aligns

transform_initializer.InitializeTransform()

# initializer maps from the fixed image to the moving image,
# whereas we want to map from the moving image to the fixed image.
init_transform = init_transform.GetInverseTransform()

print(init_transform)

In [None]:
# Apply translation without resampling the image by updating the image origin directly
change_information_filter = itk.ChangeInformationImageFilter[type(source_image)].New()
change_information_filter.SetInput(source_image)
change_information_filter.SetOutputOrigin(
    init_transform.TransformPoint(itk.origin(source_image))
)
change_information_filter.ChangeOriginOn()
change_information_filter.UpdateOutputInformation()

source_image_init = change_information_filter.GetOutput()
print(source_image_init)

In [None]:
# Verify that the initialized source image bounds overlap with the target image

# Each image's bounds are computed once and unpacked into (lower, upper) corners.
source_lower, source_upper = get_bounds(source_image)
print(f"Original input source image bounds: {source_lower}, {source_upper}")

init_lower, init_upper = get_bounds(source_image_init)
print(f"Translated source image bounds: {init_lower}, {init_upper}")

target_lower, target_upper = get_bounds(target_image)
print(f"Target image bounds: {target_lower}, {target_upper}")

In [None]:
# Touch itk.ElastixRegistrationMethod while progress reporting is enabled.
# The bare attribute access is presumably meant to force ITK's lazily-loaded
# Python package to load the itk-elastix submodule that defines it.
# try/finally makes sure progress reporting is switched back off even if that
# load raises.
itk.auto_progress(1)
try:
    itk.ElastixRegistrationMethod
finally:
    itk.auto_progress(0)

In [None]:
# Multi-stage elastix schedule: rigid, then affine, then B-spline. Each stage
# starts from elastix's default parameter map for that transform type.
parameter_object = itk.ParameterObject.New()
for stage in ("rigid", "affine"):
    parameter_object.AddParameterMap(parameter_object.GetDefaultParameterMap(stage))

bspline_map = parameter_object.GetDefaultParameterMap("bspline")
# NOTE(review): no image spacing is set earlier, so a final grid spacing of
# 0.5 "physical units" is half a voxel. That gives a control-point grid finer
# than the image itself, which is very expensive. TODO confirm the intended
# units/value once spacing is set on the images.
bspline_map["FinalGridSpacingInPhysicalUnits"] = ("0.5000",)
parameter_object.AddParameterMap(bspline_map)

print(parameter_object)

In [None]:
# Fixed image = CCF atlas; moving image = the sample volume after the initial
# origin shift. The method is templated on the original image types, which the
# ChangeInformationImageFilter output shares with source_image.
RegistrationType = itk.ElastixRegistrationMethod[type(target_image), type(source_image)]
registration_method = RegistrationType.New(
    fixed_image=target_image,
    moving_image=source_image_init,
    parameter_object=parameter_object,
    log_to_console=False,
)

In [None]:
# Run registration with `itk-elastix`, may take a few minutes.
# Executes every stage added to parameter_object (rigid, affine, B-spline) on
# the fixed CCF image and the origin-shifted moving image. Console logging was
# disabled when the method was constructed (log_to_console=False).
registration_method.Update()