In [None]:
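"""
The following lines need to be in the notebook before importing this file:
import matplotlib.pyplot as plt
%matplotlib inline

from ipywidgets import interact, fixed
from IPython.display import clear_output

The code below also refers to SimpleITK under the name `sitk`
(conventionally `import SimpleITK as sitk`), which is not part of the list above.
"""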


# Callback invoked by the interact IPython method for scrolling through the image stacks of
# the two images (moving and fixed).
def display_images(fixed_image_z, moving_image_z, fixed_npa, moving_npa):
    """Show one slice of the fixed volume next to one slice of the moving volume.

    Args:
        fixed_image_z: Index of the slice to show from ``fixed_npa``.
        moving_image_z: Index of the slice to show from ``moving_npa``.
        fixed_npa: Fixed volume, indexed as ``[z, :, :]`` (presumably a numpy array).
        moving_npa: Moving volume, indexed as ``[z, :, :]`` (presumably a numpy array).
    """
    # One figure holding a 1x2 grid of subplots.
    plt.subplots(1, 2, figsize=(10, 8))

    # (volume, slice index, caption) for each panel, in left-to-right order.
    panels = (
        (fixed_npa, fixed_image_z, 'fixed image'),
        (moving_npa, moving_image_z, 'moving image'),
    )
    for position, (volume, z_index, caption) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.imshow(volume[z_index, :, :], cmap=plt.cm.Greys_r)
        plt.title(caption)
        plt.axis('off')

    plt.show()

# Callback invoked by the IPython interact method for scrolling and modifying the alpha blending
# of an image stack of two images that occupy the same physical space. 
def display_images_with_alpha(image_z, alpha, fixed, moving):
    """Show an alpha blend of the same slice taken from two images.

    The displayed slice is ``(1.0 - alpha) * fixed + alpha * moving`` at index
    ``image_z`` along the third index.

    Args:
        image_z: Index used as ``[:, :, image_z]`` on both images.
        alpha: Blend weight given to ``moving``; ``fixed`` gets ``1.0 - alpha``.
        fixed: First image. The blend is passed to ``sitk.GetArrayViewFromImage``,
            so this is presumably a SimpleITK image.
        moving: Second image, sliced with the same index as ``fixed``.
    """
    fixed_slice = fixed[:, :, image_z]
    moving_slice = moving[:, :, image_z]
    blended = (1.0 - alpha) * fixed_slice + alpha * moving_slice

    blended_view = sitk.GetArrayViewFromImage(blended)
    plt.imshow(blended_view, cmap=plt.cm.Greys_r)
    plt.axis('off')
    plt.show()
    
# Callback invoked when the StartEvent happens, sets up our new data.
def start_plot():
    """Reset the module-level plotting state to two fresh, empty lists.

    Rebinds the globals ``metric_values`` and ``multires_iterations``; any
    previous contents are discarded.
    """
    global metric_values, multires_iterations

    metric_values, multires_iterations = [], []

# Callback invoked when the EndEvent happens, do cleanup of data and figure.
def end_plot():
    """Delete the module-level plotting state and close the current figure.

    Removes the globals ``metric_values`` and ``multires_iterations``; a
    NameError is raised if they do not exist when this is called.
    """
    global metric_values, multires_iterations

    del metric_values, multires_iterations
    # Close the figure so the plot is not shown a second time later on.
    plt.close()

# Callback invoked when the IterationEvent happens, update our data and display new figure.    
def plot_values(registration_method):
    """Record the current metric value and redraw the metric-vs-iteration plot.

    Args:
        registration_method: Object whose ``GetMetricValue()`` result is appended
            to the module-level ``metric_values`` list.
    """
    # The module-level lists are only mutated or read, never rebound, so no
    # ``global`` declaration is required.
    metric_values.append(registration_method.GetMetricValue())

    # wait=True delays clearing until new output arrives, which reduces flickering.
    clear_output(wait=True)

    # Red line: every recorded metric value.
    plt.plot(metric_values, 'r')
    # Blue stars: the values at the indices stored in multires_iterations.
    starred_values = list(map(metric_values.__getitem__, multires_iterations))
    plt.plot(multires_iterations, starred_values, 'b*')

    for set_label, text in ((plt.xlabel, 'Iteration Number'), (plt.ylabel, 'Metric Value')):
        set_label(text, fontsize=12)
    plt.show()
    
# Callback invoked when the sitkMultiResolutionIterationEvent happens, update the index into the 
# metric_values list. 
def update_multires_iterations():
    """Append the current length of ``metric_values`` to ``multires_iterations``.

    The stored number is the index the next recorded metric value will have.
    """
    # Both lists are module-level and only read or mutated here, never rebound.
    next_value_index = len(metric_values)
    multires_iterations.append(next_value_index)
    
    

In [None]:
# Configure and run a multi-resolution BSpline registration with SimpleITK, plotting the
# metric value while the optimizer iterates. idimg and out_img are not defined in this cell
# and must already exist in the notebook.
registration_method = sitk.ImageRegistrationMethod()

# Similarity metric settings: mean squares with the RANDOM sampling strategy and a
# sampling percentage of 0.05.
# Disabled alternatives, kept for reference:
#registration_method.SetMetricAsMeanSquares()
#registration_method.SetMetricAsANTSNeighborhoodCorrelation(radius=3)
registration_method.SetMetricAsMeanSquares()
registration_method.SetMetricSamplingStrategy(registration_method.RANDOM)
# NOTE(review): no seed argument is passed, so the random sample presumably differs from
# run to run -- confirm against the SimpleITK documentation if reproducibility matters.
registration_method.SetMetricSamplingPercentage(0.05)

registration_method.SetInterpolator(sitk.sitkLinear)

# Optimizer settings: L-BFGS-B with numberOfIterations=100.
# Disabled alternative (gradient descent), kept for reference:
#registration_method.SetOptimizerAsGradientDescent(learningRate=1.0, numberOfIterations=100,
#                                                  convergenceMinimumValue=1e-6, convergenceWindowSize=10)

registration_method.SetOptimizerAsLBFGSB(gradientConvergenceTolerance=1e-5, numberOfIterations=100)
registration_method.SetOptimizerScalesFromPhysicalShift()

# Metric masks: the same object, idimg, is supplied as both the fixed and the moving mask.
# NOTE(review): idimg is also the second image passed to Execute() below -- confirm that
# using it as a mask (and as the mask for both images) is intended.
registration_method.SetMetricFixedMask(idimg)
registration_method.SetMetricMovingMask(idimg)

# Setup for the multi-resolution framework: two levels, shrink factors 4 then 2 and
# smoothing sigmas 4 then 2, with the sigmas flagged as being in physical units.
registration_method.SetShrinkFactorsPerLevel(shrinkFactors = [4,2])
registration_method.SetSmoothingSigmasPerLevel(smoothingSigmas=[4,2])
registration_method.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()

# Disabled: start from a default sitk.Transform() with inPlace=False ("don't optimize
# in-place, we would possibly like to run this cell multiple times").
# The active SetInitialTransform call further down passes no inPlace argument; the BSpline
# transform it receives is rebuilt every time this cell runs.
#idtransform=sitk.Transform()
#registration_method.SetInitialTransform(idtransform, inPlace=False)

# Derive the BSpline mesh size from the control-grid spacing we want in physical units.
grid_physical_spacing = [96.0, 96.0, 96.0] # Desired grid spacing per axis: 96.0 (presumably mm -- confirm)
# Physical extent per axis = number of pixels * pixel spacing.
image_physical_size = [size*spacing for size,spacing in zip(idimg.GetSize(), idimg.GetSpacing())]
# Grid cells per axis = extent / spacing; adding 0.5 before int() rounds to the nearest integer.
mesh_size = [int(image_size/grid_spacing + 0.5) \
             for image_size,grid_spacing in zip(image_physical_size,grid_physical_spacing)]

# Order-3 BSpline transform defined over idimg, used as the initial transform.
initial_transform = sitk.BSplineTransformInitializer(image1 = idimg, 
                                                     transformDomainMeshSize = mesh_size, order=3)    
registration_method.SetInitialTransform(initial_transform)

# Connect all of the observers so that we can perform plotting during registration.
# AddCommand is given zero-argument callables, hence the lambda that binds registration_method.
registration_method.AddCommand(sitk.sitkStartEvent, start_plot)
registration_method.AddCommand(sitk.sitkEndEvent, end_plot)
registration_method.AddCommand(sitk.sitkMultiResolutionIterationEvent, update_multires_iterations) 
registration_method.AddCommand(sitk.sitkIterationEvent, lambda: plot_values(registration_method))

# Run the registration on 32-bit float copies of both images; out_img is passed first and
# idimg second (SimpleITK's Execute presumably takes fixed, then moving -- confirm).
final_transform = registration_method.Execute(sitk.Cast(out_img, sitk.sitkFloat32), 
                                              sitk.Cast(idimg, sitk.sitkFloat32))

In [None]:
    """
    previous_mask = sitk.Image(in_mask)
    build_up_img = sitk.Image(idimg)*0
    for rg in range(1,12):
        next_mask=sitk.DilateObjectMorphology(previous_mask,4)
        RingMask=sitk.Cast(next_mask-previous_mask > 0, sitk.sitkInt32)
        smoothed_img = sitk.Cast(sitk.SmoothingRecursiveGaussian(idimg,rg),sitk.sitkInt32)
        build_up_img = smoothed_img * RingMask
        print(rg)
    RingMask=sitk.Cast(1-next_mask, sitk.sitkInt32)
    smoothed_img = sitk.Cast(sitk.SmoothingRecursiveGaussian(idimg,12),sitk.sitkInt32)
    build_up_img = smoothed_img * RingMask

    build_up_img= sitk.Cast(( in_mask * idimg + not_in_mask * build_up_img),sitk.sitkInt32)
    """
