In [11]:
# Import necessary libraries
import imageio
import SimpleITK as sitk
import os
import time
import matplotlib
import matplotlib.pyplot as plt
import glob
import natsort
from ipywidgets import interact, fixed
from IPython.display import clear_output 
from tqdm import tqdm
%matplotlib inline
import numpy as np
matplotlib.use('agg')
import matplotlib.image as mpimg


In [12]:
# Input zone hardcoded variables

path_of_mr_navigators           =               '/Users/lalith/Documents/GAN/MR-navigators' # change this
wild_card_string                =               'Transform_MRnav_*nii'
path_of_tfms                    =               '/Users/lalith/Documents/GAN/MR-navigators';
path_to_store_animation         =               '/Users/lalith/Documents/GAN/MR-navigators';

In [13]:
# Program start


# Utility functions borrowed from Simple ITK's notebook

from ipywidgets import interact, fixed
from IPython.display import clear_output

# Paste the two given images together. On the left will be image1 and on the right image2.
# image2 is also centered vertically in the combined image.
def write_combined_image(image1, image2, horizontal_space, file_name):
    combined_image = sitk.Image((image1.GetWidth() + image2.GetWidth() + horizontal_space,
                                max(image1.GetHeight(), image2.GetHeight())), 
                                image1.GetPixelID(), image1.GetNumberOfComponentsPerPixel())
    combined_image = sitk.Paste(combined_image, image1, image1.GetSize(), (0, 0), (0, 0))
    combined_image = sitk.Paste(combined_image, image2, image2.GetSize(), (0, 0), 
                                (image1.GetWidth()+horizontal_space, 
                                 round((combined_image.GetHeight()-image2.GetHeight())/2)))
    sitk.WriteImage(combined_image, file_name)


# Callback invoked when the StartEvent happens, sets up our new data.
def start_plot():
    global metric_values, multires_iterations
    
    metric_values = []
    multires_iterations = []


# Callback invoked when the EndEvent happens, do cleanup of data and figure.
def end_plot():
    global metric_values, multires_iterations
    
    del metric_values
    del multires_iterations
    # Close figure, we don't want to get a duplicate of the plot latter on.
    plt.close()


# Callback invoked when the IterationEvent happens, update our data and 
# save an image that includes a visualization of the registered images and
# the metric value plot.    

def save_plot(registration_method, fixed, moving, transform, file_name_prefix):

    #
    # Plotting the similarity metric values, resolution changes are marked with 
    # a blue star.
    #
    global metric_values, multires_iterations
    
    metric_values.append(registration_method.GetMetricValue())                                       
    # Plot the similarity metric values
    plt.plot(metric_values, 'r')
    plt.plot(multires_iterations, [metric_values[index] for index in multires_iterations], 'b*')
    plt.xlabel('Iteration Number',fontsize=12)
    plt.ylabel('Metric Value',fontsize=12)

    # Convert the plot to a SimpleITK image (works with the agg matplotlib backend, doesn't work
    # with the default - the relevant method is canvas_tostring_rgb())
    plt.gcf().canvas.draw()    
    plot_data = np.fromstring(plt.gcf().canvas.tostring_rgb(), dtype=np.uint8, sep='')
    plot_data = plot_data.reshape(plt.gcf().canvas.get_width_height()[::-1] + (3,))
    plot_image = sitk.GetImageFromArray(plot_data, isVector=True)

    
    #
    # Extract the central axial slice from the two volumes, compose it using the transformation
    # and alpha blend it.
    #
    alpha = 0.5
    
    central_index = round((fixed.GetSize())[2]/2)
    
    moving_transformed = sitk.Resample(moving, fixed, transform, 
                                       sitk.sitkLinear, 0.0, 
                                       moving_image.GetPixelIDValue())
    # Extract the central slice in xy and alpha blend them                                   
    combined = (1.0 - alpha)*fixed[:,:,central_index] + \
               alpha*moving_transformed[:,:,central_index]

    # Assume the alpha blended images are isotropic and rescale intensity
    # Values so that they are in [0,255], convert the grayscale image to
    # color (r,g,b).
    combined_slices_image = sitk.Cast(sitk.RescaleIntensity(combined), sitk.sitkUInt8)
    combined_slices_image = sitk.Compose(combined_slices_image,
                                         combined_slices_image,
                                         combined_slices_image)
    write_combined_image(combined_slices_image, plot_image, 0, 
                         file_name_prefix + str.format(str(len(metric_values)), '03d') + '.png')
    
# Callback invoked when the sitkMultiResolutionIterationEvent happens, update the index into the 
# metric_values list. 
def update_multires_iterations():
    global metric_values, multires_iterations
    multires_iterations.append(len(metric_values))




In [16]:
# Registration module

registration_method = sitk.ImageRegistrationMethod()

# Similarity metric settings.
registration_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=100)
registration_method.SetMetricSamplingStrategy(registration_method.RANDOM)
registration_method.SetMetricSamplingPercentage(0.01)


registration_method.SetInterpolator(sitk.sitkLinear)

# Optimizer settings.
registration_method.SetOptimizerAsGradientDescent(learningRate=1.0, numberOfIterations=200, convergenceMinimumValue=1e-6, convergenceWindowSize=10)
registration_method.SetOptimizerScalesFromPhysicalShift()

# Setup for the multi-resolution framework.            
registration_method.SetShrinkFactorsPerLevel(shrinkFactors = [4,2,1])
registration_method.SetSmoothingSigmasPerLevel(smoothingSigmas=[2,1,0])
registration_method.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()


# Connect all of the observers so that we can perform plotting during registration.
registration_method.AddCommand(sitk.sitkStartEvent, start_plot)
registration_method.AddCommand(sitk.sitkEndEvent, end_plot)
registration_method.AddCommand(sitk.sitkMultiResolutionIterationEvent, update_multires_iterations) 


0

In [17]:
# Registration cycle 

os.chdir(path_of_mr_navigators) # go to the path of the mr navigators 
mr_nav_files                    =               (glob.glob(wild_card_string))
mr_nav_files                    =               (natsort.natsorted(mr_nav_files))
path_fixed_image                =               os.path.join(path_of_mr_navigators,mr_nav_files[0])
for i in tqdm(range(1,(len(mr_nav_files)))):
    time.sleep(0.5)
    print('Processing '+mr_nav_files[i]+'...')
    path_moving_image           =               os.path.join(path_of_mr_navigators,mr_nav_files[i])
    fixed_image                 =               sitk.ReadImage(path_fixed_image,sitk.sitkFloat32)
    moving_image                =               sitk.ReadImage(path_moving_image,sitk.sitkFloat32)
    initial_transform           =               sitk.CenteredTransformInitializer(fixed_image, 
                                                                                  moving_image, 
                                                                                  sitk.Euler3DTransform())
                                                                                  
    registration_method.SetInitialTransform(initial_transform)
    registration_method.AddCommand(sitk.sitkIterationEvent, lambda: save_plot(registration_method, fixed_image, moving_image, initial_transform, os.path.join(path_to_store_animation+'/iteration_plot')))
    final_transform             =               registration_method.Execute(fixed_image,moving_image)
    tfm_file_name               =               str(i+1)+'->'+str(1)+'.tfm'
    sitk.WriteTransform(final_transform, os.path.join(path_of_tfms, tfm_file_name))
    wild_card_png               =               '*png'
    png_files                   =               (glob.glob(wild_card_png))
    png_files                   =               (natsort.natsorted(png_files))
    images                      =               []
    #for x in range(0,(len(png_files))):
    #    images.append(imageio.imread(os.path.join(path_to_store_animation,png_files[x])))
    #    os.remove((os.path.join(path_to_store_animation,png_files[x])))
    #    gif_name=tfm_file_name+'_reg_animation.gif'
    #    imageio.mimsave(os.path.join(path_to_store_animation,gif_name),images,loop=1)




  0%|          | 0/13 [00:00<?, ?it/s][A[A[A

Processing Transform_MRnav_2_To_PET.nii...





  8%|▊         | 1/13 [00:10<02:10, 10.91s/it][A[A[A

Processing Transform_MRnav_3_To_PET.nii...





 15%|█▌        | 2/13 [00:27<02:17, 12.52s/it][A[A[A

Processing Transform_MRnav_4_To_PET.nii...





 23%|██▎       | 3/13 [00:46<02:24, 14.50s/it][A[A[A

Processing Transform_MRnav_5_To_PET.nii...





 31%|███       | 4/13 [01:10<02:36, 17.39s/it][A[A[A

Processing Transform_MRnav_6_To_PET.nii...





 38%|███▊      | 5/13 [01:35<02:37, 19.70s/it][A[A[A

Processing Transform_MRnav_7_To_PET.nii...





 46%|████▌     | 6/13 [02:12<02:53, 24.84s/it][A[A[A

Processing Transform_MRnav_8_To_PET.nii...





 54%|█████▍    | 7/13 [03:14<03:36, 36.04s/it][A[A[A

Processing Transform_MRnav_9_To_PET.nii...





 62%|██████▏   | 8/13 [03:59<03:14, 38.81s/it][A[A[A

Processing Transform_MRnav_10_To_PET.nii...





 69%|██████▉   | 9/13 [05:02<03:04, 46.12s/it][A[A[A

Processing Transform_MRnav_11_To_PET.nii...





 77%|███████▋  | 10/13 [06:03<02:31, 50.37s/it][A[A[A

Processing Transform_MRnav_12_To_PET.nii...





 85%|████████▍ | 11/13 [07:04<01:47, 53.62s/it][A[A[A

Processing Transform_MRnav_13_To_PET.nii...





 92%|█████████▏| 12/13 [08:02<00:55, 55.02s/it][A[A[A

Processing Transform_MRnav_14_To_PET.nii...





100%|██████████| 13/13 [09:12<00:00, 42.52s/it][A[A[A
