In [None]:
import os
import sys
import numpy as np
from skimage import io
from matplotlib import pyplot as plt
import cv2
import SimpleITK as sitk
from tifffile import imwrite

In [None]:
def compare_two_images(img1, img2, figsize=(20, 20), alpha=0.5,
                       cmap1='Blues', cmap2='Reds'):
    """Overlay two images in a single new figure for a visual alignment check.

    ``img1`` is drawn first with ``cmap1`` (blue by default). ``img2`` is drawn
    on top with ``cmap2`` (red by default) at opacity ``alpha``, so both remain
    visible where they overlap.

    Args:
        img1: image array drawn underneath (e.g. the fixed image).
        img2: image array drawn on top (e.g. the registered moving image).
        figsize: figure size passed to ``plt.figure``.
        alpha: opacity of ``img2``.
        cmap1: matplotlib colormap name for ``img1``.
        cmap2: matplotlib colormap name for ``img2``.

    Returns:
        None. The overlay is left on the current pyplot figure.
    """
    plt.figure(figsize=figsize)
    plt.imshow(img1, cmap=cmap1)
    # Semi-transparent second layer so misalignment shows up as colour fringes.
    plt.imshow(img2, alpha=alpha, cmap=cmap2)

In [None]:
# Both locations live under the same Active Atlas data root.
#   DATA    -> the DK52 `preps/CH1` directory (the moving thumbnails are read from here)
#   REGDATA -> registration reference files (the Allen volume is read from here)
DATA, REGDATA = (
    '/'.join(('/net/birdstore/Active_Atlas_Data/data_root',) + subdirs)
    for subdirs in (
        ('pipeline_data', 'DK52', 'preps', 'CH1'),
        ('brains_info', 'registration'),
    )
)

In [None]:
# Load the Allen reference volume; it provides the fixed image for registration.
filename = 'allen_10um_sagittal.tif'
fixedFilepath = os.path.join(REGDATA, filename)
fixed_volume = io.imread(fixedFilepath)
print('Fixed volume shape={} dtype={}'.format(fixed_volume.shape, fixed_volume.dtype))

In [None]:
# Use the middle slice along axis 0 of the fixed volume as the 2D fixed image.
# `.copy()` matters: a basic slice is a view that keeps a reference to the whole
# volume, so without it `del fixed_volume` would not release any memory.
# NOTE: this cell deletes `fixed_volume`; re-run the loading cell before re-running it.
fixed_midpoint = fixed_volume.shape[0] // 2
fixed_image = fixed_volume[fixed_midpoint, :, :].copy()
del fixed_volume
print(f'Fixed image shape={fixed_image.shape} dtype={fixed_image.dtype}')

In [None]:
# Pick the middle section of the aligned thumbnails as the 2D moving image.
# Region-of-interest notes from earlier experiments:
# moving_volume around x=1200, y = 750 for 10um
# moving volume around x=600, y=350 for 25um
moving_dir = os.path.join(DATA, 'thumbnail_aligned')


def _section_sort_key(name):
    """Sort purely numeric stems ('7', '007', '243') numerically; any other names after, by name."""
    stem = os.path.splitext(name)[0]
    return (0, int(stem), name) if stem.isdecimal() else (1, 0, name)


# Count only .tif files. os.listdir returns every directory entry, and any
# non-image file would shift the midpoint or point at a file that does not exist.
files = sorted(
    (f for f in os.listdir(moving_dir) if f.lower().endswith('.tif')),
    key=_section_sort_key,
)
if not files:
    raise FileNotFoundError(f'No .tif files found in {moving_dir}')
midpoint = len(files) // 2
# Use the real file name instead of rebuilding it from the index, so that
# zero-padded names (e.g. '007.tif') also work.
filename = files[midpoint]
movingFilepath = os.path.join(moving_dir, filename)
moving_image = io.imread(movingFilepath)
print(f'Shape of {filename}: {moving_image.shape}')
#moving_image = moving_image[200:,200:]
#print(f'Shape of {filename}: {moving_image.shape}')

In [None]:
# Scratch arithmetic; the meaning of these constants is not recorded in the notebook.
# NOTE(review): presumably fx/fy are full-resolution image dimensions and
# 2252/1220 the matching thumbnail dimensions (i.e. a downsampling factor) — confirm.
fx, fy = 65500, 35500
print(fx / 2252, fy / 1220, sep='\n')
(1 / 29.85) * 100

In [None]:
# Rough moment-based (centre-of-mass) pre-alignment of the moving slice onto the fixed slice.
# The previous version read the full 3D Allen volume from fixedFilepath but only a
# 2D section from movingFilepath, and used Euler3DTransform. The initializer needs
# both images to have the same dimension, so that cell could not run.
# This version works on the two 2D slices already loaded above.
# NOTE(review): assumes moving_image is single-channel (2D); an RGB thumbnail
# would become a 3D image here — confirm.
# Both images are cast to float32 so the initializer sees a common pixel type.
fixedImage = sitk.Cast(sitk.GetImageFromArray(fixed_image), sitk.sitkFloat32)
movingImage = sitk.Cast(sitk.GetImageFromArray(moving_image), sitk.sitkFloat32)

initial_transform = sitk.CenteredTransformInitializer(
    fixedImage,
    movingImage,
    sitk.Euler2DTransform(),
    sitk.CenteredTransformInitializerFilter.MOMENTS,
)

# Resample the moving slice onto the fixed grid with linear interpolation; background 0.
moving_resampled = sitk.Resample(movingImage, fixedImage, initial_transform,
                                 sitk.sitkLinear, 0.0, movingImage.GetPixelID())
moving_volume = sitk.GetArrayFromImage(moving_resampled)

In [None]:
# Wrap the 2D numpy slices as SimpleITK images for elastix.
# NOTE(review): no spacing is set, so both images presumably get the default
# spacing even though the fixed slice comes from a 10um volume and the moving
# one from a thumbnail. A rigid transform cannot absorb a scale difference —
# confirm this is intended.
fixedImage, movingImage = (sitk.GetImageFromArray(arr) for arr in (fixed_image, moving_image))
print(type(fixedImage))

In [None]:
%%time
elastixImageFilter = sitk.ElastixImageFilter()
elastixImageFilter.SetFixedImage(fixedImage)
elastixImageFilter.SetMovingImage(movingImage)
translateParameterMap = sitk.GetDefaultParameterMap('translation')
translateParameterMap["ResultImagePixelType"] = ["unsigned char"]
translateParameterMap["MaximumNumberOfIterations"] = ["15"] # 250 works ok        
rigidParameterMap = sitk.GetDefaultParameterMap('rigid')
rigidParameterMap["ResultImagePixelType"] = ["unsigned char"]
rigidParameterMap["MaximumNumberOfIterations"] = ["1500"] # 250 works ok        
rigidParameterMap["MaximumNumberOfSamplingAttempts"] = ["10"]
rigidParameterMap["UseDirectionCosines"] = ["true"]
rigidParameterMap["NumberOfResolutions"]= ["6"]
rigidParameterMap["NumberOfSpatialSamples"] = ["4000"]
rigidParameterMap["WriteResultImage"] = ["false"]


affineParameterMap = sitk.GetDefaultParameterMap('affine')
affineParameterMap["UseDirectionCosines"] = ["true"]
affineParameterMap["MaximumNumberOfIterations"] = ["1000"] # 250 works ok
affineParameterMap["MaximumNumberOfSamplingAttempts"] = ["10"]
affineParameterMap["NumberOfResolutions"]= ["6"]
affineParameterMap["NumberOfSpatialSamples"] = ["4000"]
affineParameterMap["WriteResultImage"] = ["false"]

bsplineParameterMap = sitk.GetDefaultParameterMap('bspline')

bsplineParameterMap["MaximumNumberOfIterations"] = ["1500"] # 150 works ok

bsplineParameterMap["WriteResultImage"] = ["true"]
"""
bsplineParameterMap["UseDirectionCosines"] = ["true"]
bsplineParameterMap["FinalGridSpacingInVoxels"] = ["10"]
bsplineParameterMap["MaximumNumberOfSamplingAttempts"] = ["10"]
bsplineParameterMap["NumberOfResolutions"]= ["6"]
bsplineParameterMap["GridSpacingSchedule"] = ["6.219", "4.1", "2.8", "1.9", "1.4", "1.0"]
bsplineParameterMap["NumberOfSpatialSamples"] = ["4000"]
del bsplineParameterMap["FinalGridSpacingInPhysicalUnits"]
"""
#elastixImageFilter.SetParameterMap(translateParameterMap)
elastixImageFilter.SetParameterMap(rigidParameterMap)
#elastixImageFilter.AddParameterMap(affineParameterMap)
#elastixImageFilter.AddParameterMap(bsplineParameterMap)
resultImage = elastixImageFilter.Execute()

In [None]:
# Print elastix's stock rigid parameter map, for comparison with the customised one.
sitk.PrintParameterMap(
    sitk.GetDefaultParameterMap('rigid')
)

In [None]:
# Convert the elastix result (cast to UInt16) into a numpy array and preview it.
r = sitk.Cast(resultImage, sitk.sitkUInt16)
registered_image = sitk.GetArrayFromImage(r)
plt.imshow(registered_image, cmap='gray')
plt.title('all')
plt.show()

In [None]:
# Show the dtype of the registered array (the previous cell casts the elastix result to UInt16).
registered_image.dtype

In [None]:
# Overlay: blue = fixed image, red = registered (moving) image.
compare_two_images(img1=fixed_image, img2=registered_image)

In [None]:
# Fixed (Allen) slice on its own, in grayscale.
plt.imshow(fixed_image, cmap='gray')
plt.title('fixed image')
plt.show()

In [None]:
# Moving section on its own, before registration.
plt.imshow(moving_image, cmap='gray')
plt.title('moving image')
plt.show()

In [None]:
# Registered (transformed moving) image on its own.
plt.imshow(registered_image, cmap='gray')
plt.title('registered image')
plt.show()

In [None]:
# Save the registered slice into DATA. The section id comes from the moving
# file that was actually registered, not a hard-coded number ('243'), so the
# output name stays correct when a different section is chosen.
section_id = os.path.splitext(os.path.basename(movingFilepath))[0]
outpath = os.path.join(DATA, f'registered_{section_id}.allparams.tif')
imwrite(outpath, registered_image)

In [None]:
# Affine elastix registration of the same 2D slices.
# The previous version called itk.elastix_registration_method(fixed, moving, ...),
# but `itk`, `fixed`, `moving` and `parameter_object` are not defined anywhere in
# this notebook, so the cell raised NameError. This version uses the SimpleITK
# elastix interface that the notebook already uses.
affine_filter = sitk.ElastixImageFilter()
affine_filter.SetFixedImage(sitk.GetImageFromArray(fixed_image))
affine_filter.SetMovingImage(sitk.GetImageFromArray(moving_image))
affine_filter.SetParameterMap(sitk.GetDefaultParameterMap('affine'))
affine_filter.SetLogToConsole(True)  # equivalent of log_to_console=True
result_image_affine = affine_filter.Execute()
result_transform_parameters = affine_filter.GetTransformParameterMap()
