In [None]:
import numpy as np
import matplotlib
from matplotlib import pyplot as plt
from skimage import io
from os.path import expanduser
from tqdm import tqdm
HOME = expanduser("~")
import os, sys
import SimpleITK as sitk
from ipywidgets import interact, fixed
from IPython.display import clear_output
%load_ext autoreload
%autoreload 2

In [None]:
def start_plot():
    """StartEvent callback: reset the global metric history for a new registration run.

    Rebinds the module-level ``metric_values`` and ``multires_iterations`` to two
    new, independent empty lists.
    """
    global metric_values, multires_iterations
    metric_values, multires_iterations = [], []
# Callback invoked when the EndEvent happens: discard the collected data and the figure.
def end_plot():
    """EndEvent callback: delete the global metric history and close the current figure.

    Raises NameError if ``start_plot`` has not created the globals first.
    """
    global metric_values, multires_iterations
    del metric_values, multires_iterations
    # Close the figure so a duplicate of the plot does not show up later on.
    plt.close()
# Callback invoked when the sitkMultiResolutionIterationEvent happens: remember the
# position in metric_values at which the new resolution level starts.
def update_multires_iterations():
    """MultiResolutionIterationEvent callback: record where the new level begins.

    Appends the current length of ``metric_values`` to ``multires_iterations``.
    """
    global metric_values, multires_iterations
    level_start = len(metric_values)
    multires_iterations.append(level_start)
# Callback invoked when the IterationEvent happens: record the newest metric value and redraw.
def plot_values(registration_method):
    """IterationEvent callback: append the current metric value and re-plot the history.

    Args:
        registration_method: object exposing ``GetMetricValue()`` (here the
            SimpleITK ``ImageRegistrationMethod`` being executed).

    The red line is the full metric history; blue stars mark the first value
    of each resolution level (positions stored in ``multires_iterations``).
    """
    global metric_values, multires_iterations

    latest = registration_method.GetMetricValue()
    metric_values.append(latest)
    # wait=True keeps the old output until new output arrives, which reduces flicker.
    clear_output(wait=True)
    plt.plot(metric_values, 'r')
    level_start_values = [metric_values[idx] for idx in multires_iterations]
    plt.plot(multires_iterations, level_start_values, 'b*')
    plt.xlabel('Iteration Number', fontsize=12)
    plt.ylabel('Metric Value', fontsize=12)
    plt.show()

In [None]:
# Animal (brain) to process; every path below is derived from it.
animal = 'DK39'
DIR = os.path.join('/net/birdstore/Active_Atlas_Data/data_root/pipeline_data', animal, 'preps')
# Input: cleaned CH1 thumbnails.  Output: one .tfm transform file per consecutive pair.
INPUT = os.path.join(DIR, 'CH1', 'thumbnail_cleaned')
ELASTIX = os.path.join(DIR, 'elastix')

In [None]:
def register(fixed_image, moving_image, sampling_seed=42):
    """Rigidly register ``moving_image`` onto ``fixed_image`` with a 2D Euler transform.

    The transform is initialized by aligning the images' geometric centers, then
    optimized with gradient descent on Mattes mutual information over a
    3-level (shrink 4/2/1, physical-unit smoothing 2/1/0) multi-resolution pyramid.

    Args:
        fixed_image: SimpleITK image that defines the target space.
        moving_image: SimpleITK image to be aligned to ``fixed_image``.
        sampling_seed: seed for the RANDOM metric sampling. The default, a fixed
            value, makes repeated runs reproducible. Pass ``None`` to use
            SimpleITK's wall-clock seeding (the previous behaviour).

    Returns:
        The optimized transform returned by ``ImageRegistrationMethod.Execute``.

    Raises:
        RuntimeError: propagated from SimpleITK if registration fails.
    """
    initial_transform = sitk.CenteredTransformInitializer(
        fixed_image,
        moving_image,
        sitk.Euler2DTransform(),
        sitk.CenteredTransformInitializerFilter.GEOMETRY,
    )

    registration_method = sitk.ImageRegistrationMethod()

    # Similarity metric settings: mutual information on a random 1% of the pixels.
    registration_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)
    registration_method.SetMetricSamplingStrategy(registration_method.RANDOM)
    if sampling_seed is None:
        registration_method.SetMetricSamplingPercentage(0.01)
    else:
        # A fixed seed makes the random sample, and therefore the result, reproducible.
        registration_method.SetMetricSamplingPercentage(0.01, sampling_seed)

    registration_method.SetInterpolator(sitk.sitkLinear)

    # Optimizer settings.
    registration_method.SetOptimizerAsGradientDescent(learningRate=1.0,
                                                      numberOfIterations=100,
                                                      convergenceMinimumValue=1e-6,
                                                      convergenceWindowSize=10)
    registration_method.SetOptimizerScalesFromPhysicalShift()

    # Setup for the multi-resolution framework.
    registration_method.SetShrinkFactorsPerLevel(shrinkFactors=[4, 2, 1])
    registration_method.SetSmoothingSigmasPerLevel(smoothingSigmas=[2, 1, 0])
    registration_method.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()

    # initial_transform is created fresh on every call, so optimizing it in place is safe.
    registration_method.SetInitialTransform(initial_transform, inPlace=True)

    final_transform = registration_method.Execute(sitk.Cast(fixed_image, sitk.sitkFloat32),
                                                  sitk.Cast(moving_image, sitk.sitkFloat32))
    return final_transform

In [None]:
# Register every section to its predecessor. Section i is the fixed image and section
# i-1 the moving image; the transform is written to ELASTIX as "<curr>-<prev>.tfm".
# Pairs whose transform file already exists are skipped, so re-running resumes the batch.
os.makedirs(ELASTIX, exist_ok=True)  # WriteTransform fails if the directory is missing
image_name_list = sorted(os.listdir(INPUT))
for i in tqdm(range(1, len(image_name_list))):
    final_transform = None
    prev_img_name = os.path.splitext(image_name_list[i - 1])[0]
    curr_img_name = os.path.splitext(image_name_list[i])[0]
    moving_file = os.path.join(INPUT, image_name_list[i - 1])
    fixed_file = os.path.join(INPUT, image_name_list[i])
    outfile = f'{curr_img_name}-{prev_img_name}.tfm'
    outpath = os.path.join(ELASTIX, outfile)
    if os.path.exists(outpath):
        continue

    try:
        moving_image = sitk.ReadImage(moving_file, sitk.sitkUInt16)
        fixed_image = sitk.ReadImage(fixed_file, sitk.sitkUInt16)
        final_transform = register(fixed_image, moving_image)
    except Exception as err:
        # Catch Exception (not a bare except) so Ctrl-C / KeyboardInterrupt still stops
        # the batch. Report the reason and move on to the next pair.
        # tqdm.write prints without breaking the progress bar.
        tqdm.write(f'Could not create transform for {outfile}: {err}')
    if final_transform is not None:
        sitk.WriteTransform(final_transform, outpath)


In [None]:
# Explicitly (re-)register one pair and overwrite its transform file. Previously this
# cell wrote whatever `final_transform` the loop happened to leave behind (the last pair,
# or None if every pair was skipped) under the name 225-224.tfm.
FIXED_STEM, MOVING_STEM = '225', '224'  # fixed = current section, moving = previous
stem_to_file = {os.path.splitext(name)[0]: name for name in image_name_list}
fixed_image = sitk.ReadImage(os.path.join(INPUT, stem_to_file[FIXED_STEM]), sitk.sitkUInt16)
moving_image = sitk.ReadImage(os.path.join(INPUT, stem_to_file[MOVING_STEM]), sitk.sitkUInt16)
final_transform = register(fixed_image, moving_image)
outfile = os.path.join(ELASTIX, f'{FIXED_STEM}-{MOVING_STEM}.tfm')
sitk.WriteTransform(final_transform, outfile)

In [None]:
# Visual check of the most recently registered pair: fixed image, moving image, and
# the moving image resampled into the fixed image's space with final_transform.
assert final_transform is not None, 'no transform available - run a registration cell first'
# `moving_resampled` was never defined before; compute it here so the cell is self-contained.
moving_resampled = sitk.Resample(moving_image, fixed_image, final_transform,
                                 sitk.sitkLinear, 0.0, moving_image.GetPixelID())

fig, axes = plt.subplots(1, 3, figsize=(20, 10))
panels = [
    (fixed_image, 'fixed image'),
    (moving_image, 'moving image'),
    (moving_resampled, 'resampled moving image'),
]
for ax, (panel_image, panel_title) in zip(axes, panels):
    ax.imshow(sitk.GetArrayViewFromImage(panel_image), cmap='gray')
    ax.set_title(panel_title, fontsize=10)
plt.show()
    