In [1]:
import os
import sys
import numpy as np
from skimage import io
import itk
from scipy import ndimage


In [2]:
def compare_two_images(img1, img2, figsize=(20, 20), alpha=0.5):
    """Overlay two 2D images on one new figure for a visual alignment check.

    ``img1`` is drawn first with the 'Blues' colormap, then ``img2`` is drawn
    on top of it with the 'Reds' colormap at opacity ``alpha``, so the two
    images can be compared by color where they overlap.

    Args:
        img1: first image, drawn underneath in blue.
        img2: second image, drawn on top in red, semi-transparent.
        figsize: figure size handed to ``plt.figure`` (default keeps the
            previous hard-coded 20 x 20).
        alpha: opacity of the ``img2`` layer (default keeps the previous 0.5).
    """
    # NOTE(review): `plt` is not imported in any visible cell; this relies on a
    # global `plt` (presumably matplotlib.pyplot) -- TODO add it to the imports cell.
    plt.figure(figsize=list(figsize))
    plt.imshow(img1, cmap='Blues')
    plt.imshow(img2, alpha=alpha, cmap='Reds')
    
def compute_dice_coefficient(source_image: itk.Image, target_image: itk.Image) -> float:
    """Return the Dice coefficient of the label overlap between two ITK images.

    Args:
        source_image: image set as the filter input; its Python type selects
            the ``LabelOverlapMeasuresImageFilter`` template instantiation.
        target_image: image compared against ``source_image`` (presumably it
            must have the same image type -- TODO confirm).

    Returns:
        The value reported by ``GetDiceCoefficient()`` once the filter has run.
    """
    filter_type = itk.LabelOverlapMeasuresImageFilter[type(source_image)]
    overlap_measures = filter_type.New()
    overlap_measures.SetInput(source_image)
    overlap_measures.SetTargetImage(target_image)
    # Execute the pipeline so the overlap measures are populated.
    overlap_measures.Update()
    dice = overlap_measures.GetDiceCoefficient()
    return dice

In [3]:
# Root of the Active Atlas data share. Set the ACTIVE_ATLAS_DATA_ROOT environment
# variable to run this notebook against a different mount; the default is the
# original hard-coded location.
DATA_ROOT = os.environ.get('ACTIVE_ATLAS_DATA_ROOT', '/net/birdstore/Active_Atlas_Data/data_root')
ANIMAL = 'DK52'
# Preprocessed CH1 directory for ANIMAL (holds the aligned volume and thumbnails read below).
DATA = os.path.join(DATA_ROOT, 'pipeline_data', ANIMAL, 'preps', 'CH1')
# Registration inputs shared across animals (the Allen reference volume is read from here).
REGDATA = os.path.join(DATA_ROOT, 'brains_info', 'registration')

In [4]:
# Fixed (reference) volume for registration: the Allen atlas TIFF named below.
filename = 'allen_50um_sagittal.tif'
fixedFilepath = os.path.join(REGDATA, filename)
fixed_volume = io.imread(fixedFilepath)
# Unpadded dimensions along array axes 0, 1, 2 (the padding cell relies on these).
fz, fy, fx = fixed_volume.shape
print(f'Fixed volume shape={fixed_volume.shape} dtype={fixed_volume.dtype}')

Fixed volume shape=(228, 160, 264) dtype=uint16


In [None]:
# Zero-pad the fixed volume at the far end of axis 2 (by COL_PAD) and of
# axis 1 (by ROW_PAD). The padded array replaces `fixed_volume`, so the work is
# guarded by the unpadded shape (fz, fy, fx) recorded when the volume was
# loaded: re-running this cell is then a no-op instead of padding twice.
COL_PAD = 100
ROW_PAD = 50
if fixed_volume.shape == (fz, fy, fx):
    # Pads use the volume's own dtype so concatenation never upcasts it.
    colpad = np.zeros((fixed_volume.shape[0], fixed_volume.shape[1], COL_PAD), fixed_volume.dtype)
    print(f'colspad shape={colpad.shape}')
    fixed_volume = np.concatenate((fixed_volume, colpad), axis=2)
    rowpad = np.zeros((fixed_volume.shape[0], ROW_PAD, fixed_volume.shape[2]), fixed_volume.dtype)
    print(f'rowpad shape={rowpad.shape}')
    fixed_volume = np.concatenate((fixed_volume, rowpad), axis=1)
else:
    print('fixed_volume is already padded; skipping')
print(f'Fixed volume after padding shape={fixed_volume.shape} dtype={fixed_volume.dtype}')

In [None]:
#outpath = os.path.join(REGDATA, 'allen_25um_sagittal_padded.tif')
#imwrite(outpath, fixed_volume)

In [5]:
# Moving volume: the aligned volume TIFF from the animal's preprocessed directory.
movingFilepath = os.path.join(DATA, 'aligned_volume.128.tif')
moving_volume = io.imread(movingFilepath)
# Dimensions along array axes 0, 1, 2.
mz, my, mx = moving_volume.shape
print(f'Moving volume shape={moving_volume.shape} dtype={moving_volume.dtype}')

Moving volume shape=(486, 139, 256) dtype=uint16


In [None]:
# 2D fixed image: the middle section along axis 0 of the fixed volume
# (already padded if the padding cell has run).
fixed_midpoint = fixed_volume.shape[0] // 2
fixed_image = fixed_volume[fixed_midpoint]
print(f'Fixed image shape={fixed_image.shape} dtype={fixed_image.dtype}')

In [None]:
# 2D moving image: the middle section of the animal's aligned thumbnails.
# Crop hints kept from earlier experiments: moving_volume around x=1200, y=750
# for 10um; around x=600, y=350 for 25um (e.g. moving_image[200:-200, 200:-200]).
thumbnail_dir = os.path.join(DATA, 'thumbnail_aligned')
# Count only section images; stray non-TIFF entries would shift the midpoint.
files = [f for f in os.listdir(thumbnail_dir) if f.lower().endswith('.tif')]
midpoint = len(files) // 2
# Dedicated names instead of reusing `filename` / `movingFilepath`: those globals
# hold the Allen file name and the 3D moving-volume path, and the SimpleITK cell
# further down reads `movingFilepath` as a 3D volume.
# NOTE(review): assumes sections are named `<index>.tif` with no zero padding -- confirm.
moving_slice_filename = f'{midpoint}.tif'
moving_slice_path = os.path.join(thumbnail_dir, moving_slice_filename)
moving_image = io.imread(moving_slice_path)
print(f'Moving image shape={moving_image.shape} dtype={moving_image.dtype}')

In [None]:
# Wrap the 2D numpy slices as ITK images for itk-elastix. ElastixRegistrationMethod
# is wrapped for floating-point pixel types only, so the uint16 slices are
# converted to float32 here; uint16 ITK images would fail the template lookup in
# the registration cells below.
movingImage = itk.GetImageFromArray(np.ascontiguousarray(moving_image, dtype=np.float32))
fixedImage = itk.GetImageFromArray(np.ascontiguousarray(fixed_image, dtype=np.float32))

In [None]:
# Multi-stage elastix configuration: translation, rigid, affine, then B-spline,
# added to the parameter object in that order.
parameter_object = itk.ParameterObject.New()
trans_parameter_map = parameter_object.GetDefaultParameterMap('translation')
rigid_parameter_map = parameter_object.GetDefaultParameterMap('rigid')
affine_parameter_map = parameter_object.GetDefaultParameterMap('affine')
# Extra arguments are presumably (number of resolutions, final grid spacing) -- confirm in itk-elastix docs.
bspline_parameter_map = parameter_object.GetDefaultParameterMap("bspline", 3, 20.0)

for stage_map in (trans_parameter_map, rigid_parameter_map, affine_parameter_map, bspline_parameter_map):
    parameter_object.AddParameterMap(stage_map)

In [None]:
# itk-elastix registration of the 2D moving slice onto the 2D fixed slice, driven
# by the multi-stage parameter object built above.
elastix_object = itk.ElastixRegistrationMethod.New(fixedImage, movingImage)
elastix_object.SetParameterObject(parameter_object)
elastix_object.SetLogToConsole(False)  # keep elastix's log out of the notebook output

# Nothing is computed until the filter is updated (this call is required).
elastix_object.UpdateLargestPossibleRegion()

# Registration results: the output image and the transform parameter object.
resultImage = elastix_object.GetOutput()
result_transform_parameters = elastix_object.GetTransformParameterObject()

In [None]:
# Convert the elastix output (presumably float pixels) to uint16 for display and saving.
# Resampling can produce values outside the uint16 range (e.g. small negatives
# from interpolation overshoot); a bare astype() would wrap those around to
# huge values, so round and clip to the uint16 range first.
registered_image = np.clip(
    np.rint(np.asarray(resultImage)),
    np.iinfo(np.uint16).min,
    np.iinfo(np.uint16).max,
).astype(np.uint16)
print(f'image dtype={registered_image.dtype} shape={registered_image.shape}')

In [None]:
# Inspect the Python type of the elastix output image.
type(resultImage)

In [None]:
# compute_dice_coefficient instantiates LabelOverlapMeasuresImageFilter for the
# type of its first argument and also hands it the second, so both inputs must
# share one image type. `resultImage` (elastix output) and `fixedImage` are not
# guaranteed to match, so compare the uint16 arrays instead, wrapped as ITK
# images of the same type.
# NOTE(review): these are intensity images, not label maps; every distinct
# intensity is treated as a label, so interpret the score with care.
registered_label_image = itk.GetImageFromArray(np.ascontiguousarray(registered_image))
fixed_label_image = itk.GetImageFromArray(np.ascontiguousarray(fixed_image))
dice_score = compute_dice_coefficient(registered_label_image, fixed_label_image)
print(f'Evaluated dice value: {dice_score}')

In [None]:
# 3D registration with SimpleITK on the volumes read from disk (the fixed file is
# the unpadded Allen volume; the padding earlier happened in memory only).
# NOTE(review): `sitk` is not imported in any visible cell -- TODO add the
# SimpleITK import to the imports cell.
# The moving-volume path is rebuilt explicitly: `movingFilepath` is reassigned
# to a 2D thumbnail slice in the moving-image cell, and a 2D image cannot be
# paired with the 3D fixed volume and Euler3DTransform below.
moving_volume_path = os.path.join(DATA, 'aligned_volume.128.tif')
fixedImage = sitk.ReadImage(fixedFilepath)
movingImage = sitk.ReadImage(moving_volume_path)

# Rigid 3D starting transform, centered using image moments (MOMENTS mode).
initial_transform = sitk.CenteredTransformInitializer(
    fixedImage,
    movingImage,
    sitk.Euler3DTransform(),
    sitk.CenteredTransformInitializerFilter.MOMENTS,
)

# Resample the moving volume onto the fixed grid: linear interpolation, 0.0 fill
# value, moving pixel type kept. Note this overwrites the `moving_volume` loaded earlier.
moving_resampled = sitk.Resample(movingImage, fixedImage, initial_transform, sitk.sitkLinear, 0.0, movingImage.GetPixelID())
moving_volume = sitk.GetArrayFromImage(moving_resampled)

In [None]:
%%time
# Multi-stage SimpleElastix registration of the 3D volumes loaded above.
# NOTE(review): `sitk` is not imported in any visible cell -- TODO add the import.
elastixImageFilter = sitk.ElastixImageFilter()
elastixImageFilter.SetFixedImage(fixedImage)
elastixImageFilter.SetMovingImage(movingImage)

# Settings shared by the rigid, affine and B-spline stages.
shared_settings = {
    "MaximumNumberOfSamplingAttempts": ["10"],
    "UseDirectionCosines": ["true"],
    "NumberOfResolutions": ["6"],
    "NumberOfSpatialSamples": ["4000"],
    "WriteResultImage": ["false"],
}

translateParameterMap = sitk.GetDefaultParameterMap('translation')
translateParameterMap["ResultImagePixelType"] = ["unsigned char"]
translateParameterMap["MaximumNumberOfIterations"] = ["15"]  # 250 works ok

rigidParameterMap = sitk.GetDefaultParameterMap('rigid')
rigidParameterMap["ResultImagePixelType"] = ["unsigned char"]
rigidParameterMap["MaximumNumberOfIterations"] = ["1500"]  # 250 works ok

affineParameterMap = sitk.GetDefaultParameterMap('affine')
affineParameterMap["MaximumNumberOfIterations"] = ["1000"]  # 250 works ok

bsplineParameterMap = sitk.GetDefaultParameterMap('bspline')
bsplineParameterMap["MaximumNumberOfIterations"] = ["1500"]  # 150 works ok
bsplineParameterMap["FinalGridSpacingInVoxels"] = ["10"]
# One grid-spacing factor per resolution level (NumberOfResolutions is 6).
bsplineParameterMap["GridSpacingSchedule"] = ["6.219", "4.1", "2.8", "1.9", "1.4", "1.0"]
# Presumably removed so the voxel-based spacing above is the one in effect -- confirm.
del bsplineParameterMap["FinalGridSpacingInPhysicalUnits"]

for stageMap in (rigidParameterMap, affineParameterMap, bsplineParameterMap):
    for key, value in shared_settings.items():
        stageMap[key] = value

# SetParameterMap replaces the filter's list of maps while AddParameterMap appends,
# so only the first stage uses Set. Setting the rigid map as well would silently
# discard the translation stage.
elastixImageFilter.SetParameterMap(translateParameterMap)
elastixImageFilter.AddParameterMap(rigidParameterMap)
elastixImageFilter.AddParameterMap(affineParameterMap)
#elastixImageFilter.AddParameterMap(bsplineParameterMap)  # B-spline stage currently disabled
resultImage = elastixImageFilter.Execute()

In [None]:
# Print SimpleElastix's default rigid parameter map, for comparison with the
# customized rigidParameterMap used in the registration cell.
sitk.PrintParameterMap(sitk.GetDefaultParameterMap("rigid"))

In [None]:
# Grayscale view of `registered_image`, the uint16 array converted from the
# itk-elastix output (the SimpleITK result lives in `resultImage`).
# NOTE(review): relies on a global `plt` that no visible cell imports -- TODO.
ax = plt.gca()
ax.set_title('all')
ax.imshow(registered_image, cmap="gray")
plt.show()

In [None]:
# Inspect the type of registered_image (built with np.asarray(...).astype(...), so a numpy array).
type(registered_image)

In [None]:
# Overlay: fixed Allen slice in blue underneath, registered moving (DK) slice in red on top.
compare_two_images(img1=fixed_image, img2=registered_image)

In [None]:
# Grayscale view of the 2D fixed slice taken from the Allen volume.
# NOTE(review): relies on a global `plt` that no visible cell imports -- TODO.
ax = plt.gca()
ax.set_title('fixed image')
ax.imshow(fixed_image, cmap="gray")
plt.show()

In [None]:
# `resultImage` was reassigned by the SimpleITK cell to elastixImageFilter.Execute()'s
# output, a 3D registration of the full volumes. imshow needs a 2D array, so
# convert it to numpy and show the middle section along axis 0.
# NOTE(review): relies on globals `plt` and `sitk` that no visible cell imports -- TODO.
result_array = sitk.GetArrayFromImage(resultImage)
if result_array.ndim == 3:
    result_array = result_array[result_array.shape[0] // 2]
plt.title('reg image')
plt.imshow(result_array, cmap="gray")
plt.show()

In [None]:
#ri = sitk.GetArrayFromImage(registered_image)
#plt.title('registered image')
#plt.imshow(ri, cmap="gray")
#plt.show()

In [None]:
# Save the 2D itk-elastix registration result in the animal's data directory.
# The file name carries the moving thumbnail's section index (`midpoint`) instead
# of a hard-coded 243. skimage's io.imsave is used because `imwrite` is never
# imported in this notebook.
outpath = os.path.join(DATA, f'registered_{midpoint}.allparams.tif')
io.imsave(outpath, registered_image)

In [None]:
# Functional itk-elastix API on the 2D slices. `fixed` / `moving` were never
# defined, and `fixedImage` / `movingImage` hold SimpleITK volumes by this point,
# so build float32 ITK images (the pixel type itk-elastix is wrapped for) from
# the numpy slices directly.
fixed_itk = itk.GetImageFromArray(np.ascontiguousarray(fixed_image, dtype=np.float32))
moving_itk = itk.GetImageFromArray(np.ascontiguousarray(moving_image, dtype=np.float32))
result_image_affine, result_transform_parameters = itk.elastix_registration_method(
    fixed_itk, moving_itk, parameter_object=parameter_object, log_to_console=True
)
