# 3D Abdominal CT Registration Code

This Python notebook implements a two-stage (affine, then B-spline) 3D registration of abdominal CT images, and generates a segmentation mask for the moving image from the registration result.

## Importing Packages

In [None]:
import SimpleITK as sitk
import numpy as np
from matplotlib import pyplot as plt

from IPython.display import clear_output
from ipywidgets import interact, fixed

## Defining Functions

In [None]:
def display_images(fixed_image_z, moving_image_z, fixed_npa, moving_npa):
    """Draw one slice of the fixed volume and one of the moving volume side by side.

    Meant to be passed to ``ipywidgets.interact`` so that two sliders pick the
    slice shown from each stack.

    Args:
        fixed_image_z: Index into the first axis of ``fixed_npa``.
        moving_image_z: Index into the first axis of ``moving_npa``.
        fixed_npa: Fixed volume as a 3D array, sliced as ``[z, :, :]``.
        moving_npa: Moving volume as a 3D array, sliced as ``[z, :, :]``.
    """
    # One 10x8 figure laid out as a 1x2 grid of panels.
    plt.subplots(1, 2, figsize=(10, 8))

    panels = (
        (1, fixed_npa, fixed_image_z, 'fixed image'),
        (2, moving_npa, moving_image_z, 'moving image'),
    )
    for slot, volume, z, caption in panels:
        plt.subplot(1, 2, slot)
        plt.imshow(volume[z, :, :], cmap=plt.cm.Greys_r)
        plt.title(caption)
        plt.axis('off')

    plt.show()

def display_images_with_alpha(image_z, alpha, fixed, moving):
    """Show slice ``image_z`` of a linear alpha blend of two co-located images.

    ``alpha`` = 0 shows only ``fixed`` and ``alpha`` = 1 shows only ``moving``.
    Both are SimpleITK images sliced as ``[:, :, z]``. They are expected to
    occupy the same physical space (e.g. ``moving`` already resampled onto
    ``fixed``). Intended as a callback for ``ipywidgets.interact``.
    """
    fixed_slice = fixed[:, :, image_z]
    moving_slice = moving[:, :, image_z]
    blended = (1.0 - alpha) * fixed_slice + alpha * moving_slice

    plt.imshow(sitk.GetArrayViewFromImage(blended), cmap=plt.cm.Greys_r)
    plt.axis('off')
    plt.show()
    
def start_plot():
    """StartEvent callback: reset the module-level buffers used for live plotting."""
    global metric_values, multires_iterations

    # New, separate lists so a re-run registration starts from an empty history.
    metric_values, multires_iterations = [], []

def end_plot():
    """EndEvent callback: drop the plotting buffers and close the live figure."""
    global metric_values, multires_iterations

    # Remove both module-level buffers; start_plot recreates them for the next run.
    del metric_values, multires_iterations
    # Close the figure so the final plot is not displayed a second time later on.
    plt.close()

def plot_values(registration_method):
    """IterationEvent callback: record the current metric value and redraw the curve.

    Args:
        registration_method: The running ``sitk.ImageRegistrationMethod``; only
            its ``GetMetricValue()`` is read.

    The red line is the full metric history. Blue stars mark the iteration
    indices stored in ``multires_iterations`` (the first iteration of each
    resolution level).
    """
    global metric_values, multires_iterations

    current = registration_method.GetMetricValue()
    metric_values.append(current)

    # wait=True keeps the old output until the new plot arrives, reducing flicker.
    clear_output(wait=True)

    plt.plot(metric_values, 'r')
    level_start_values = [metric_values[i] for i in multires_iterations]
    plt.plot(multires_iterations, level_start_values, 'b*')
    plt.xlabel('Iteration Number', fontsize=12)
    plt.ylabel('Metric Value', fontsize=12)
    plt.show()
    
def update_multires_iterations():
    """MultiResolutionIterationEvent callback: remember where a new level starts.

    Appends the current length of ``metric_values``. That is the index the next
    recorded metric value (the new level's first iteration) will occupy.
    """
    global metric_values, multires_iterations
    next_index = len(metric_values)
    multires_iterations.append(next_index)

## Loading Image Data

In [None]:
# Placeholder locations of the fixed and moving CT scans (a single DICOM file or
# a DICOM series directory, per the loading cells). Replace with real paths.
fixed_img_path, moving_img_path = 'fixed_img_path', 'moving_img_path'

In [None]:
import os

# Load the fixed CT volume. fixed_img_path may be a single DICOM file or a
# directory holding a DICOM series. Only the matching reader is run, because
# each reader fails on the other kind of path.
if os.path.isdir(fixed_img_path):
    # DICOM series format: gather the slice files GDCM finds in the directory.
    series_file_names = sitk.ImageSeriesReader.GetGDCMSeriesFileNames(fixed_img_path)
    if not series_file_names:
        raise FileNotFoundError('No DICOM series found in directory: {}'.format(fixed_img_path))
    series_reader = sitk.ImageSeriesReader()
    series_reader.SetFileNames(series_file_names)
    fixed_image_original = series_reader.Execute()
else:
    # Single dcm file, read through GDCM.
    reader = sitk.ImageFileReader()
    reader.SetImageIO("GDCMImageIO")
    reader.SetFileName(fixed_img_path)
    fixed_image_original = reader.Execute()

fixed_image_array = sitk.GetArrayFromImage(fixed_image_original)

print('The size of image is: ', fixed_image_original.GetSize())
print('The range of intensity is from ', np.min(fixed_image_array), 'to ', np.max(fixed_image_array))

In [None]:
import os

# Load the moving CT volume. moving_img_path may be a single DICOM file or a
# directory holding a DICOM series. Only the matching reader is run, because
# each reader fails on the other kind of path.
if os.path.isdir(moving_img_path):
    # DICOM series format: gather the slice files GDCM finds in the directory.
    series_file_names = sitk.ImageSeriesReader.GetGDCMSeriesFileNames(moving_img_path)
    if not series_file_names:
        raise FileNotFoundError('No DICOM series found in directory: {}'.format(moving_img_path))
    series_reader = sitk.ImageSeriesReader()
    series_reader.SetFileNames(series_file_names)
    moving_image_original = series_reader.Execute()
else:
    # Single dcm file, read through GDCM.
    reader = sitk.ImageFileReader()
    reader.SetImageIO("GDCMImageIO")
    reader.SetFileName(moving_img_path)
    moving_image_original = reader.Execute()

moving_image_array = sitk.GetArrayFromImage(moving_image_original)

# Report the image just loaded (moving_image itself is only created later, by windowing).
print('The size of image is: ', moving_image_original.GetSize())
print('The range of intensity is from ', np.min(moving_image_array), 'to ', np.max(moving_image_array))

## Plotting Histogram

In [None]:
# Histogram of the fixed image's raw intensities (256 bins over its full range),
# displayed over [-1100, 900] HU with the y-axis capped at 4,000,000.
plt.figure('historgram')
fixed_hist = fixed_image_array.flatten()
n, bins, patches = plt.hist(
    fixed_hist,
    bins=256,
    range=(fixed_hist.min(), fixed_hist.max()),
    histtype='step',
)
plt.axis([-1100, 900, 0, 4000000])
plt.gca().set(
    title='Intensity Histogram of Abdominal CT Image',
    xlabel='Intensity Value (HU)',
    ylabel='Number of Pixels',
)
plt.show()

In [None]:
# Plotting the histogram of moving image.
# The bin range comes from the moving image's own min/max, so no moving voxel
# falls outside the histogram and this cell does not depend on fixed_hist.
plt.figure('historgram')
moving_hist = moving_image_array.flatten()
n, bins, patches = plt.hist(moving_hist, bins=256, range=(moving_hist.min(), moving_hist.max()), histtype='step')
plt.axis([-1100, 900, 0, 4000000])
plt.title('Intensity Histogram of Abdominal CT Image')
plt.xlabel('Intensity Value (HU)')
plt.ylabel('Number of Pixels')
plt.show()

## Windowing

In [None]:
# Window the CT intensities with a width of 700 and a level of 50, i.e. the
# range [-300, 400] HU, and linearly map that window onto [0, 1]. Intensities
# below/above the window map to the output minimum/maximum.
# SimpleITK's IntensityWindowing keeps the input pixel type, so the volumes are
# cast to float32 *before* windowing. Windowing an integer (e.g. int16) CT volume
# first would truncate the [0, 1] output to the integers 0 and 1.
fixed_image = sitk.IntensityWindowing(sitk.Cast(fixed_image_original, sitk.sitkFloat32),
                                      windowMinimum=-300, windowMaximum=400,
                                      outputMinimum=0.0, outputMaximum=1.0)
moving_image = sitk.IntensityWindowing(sitk.Cast(moving_image_original, sitk.sitkFloat32),
                                       windowMinimum=-300, windowMaximum=400,
                                       outputMinimum=0.0, outputMaximum=1.0)

## Displaying 3D Images

In [None]:
# Scroll through the windowed fixed and moving stacks independently.
interact(
    display_images,
    fixed_image_z=(0, fixed_image.GetSize()[2] - 1),
    moving_image_z=(0, moving_image.GetSize()[2] - 1),
    fixed_npa=fixed(sitk.GetArrayViewFromImage(fixed_image)),
    moving_npa=fixed(sitk.GetArrayViewFromImage(moving_image)),
);

## Registration Step 1 (Affine Transformation)

In [None]:
dimension = 3

# Starting point for stage 1: a fresh 3D affine transform, initialised by
# CenteredTransformInitializer in GEOMETRY mode (based on the two image grids).
initial_transform_1 = sitk.CenteredTransformInitializer(
    fixed_image,
    moving_image,
    sitk.AffineTransform(dimension),
    sitk.CenteredTransformInitializerFilter.GEOMETRY,
)

# Preview: the moving image resampled onto the fixed grid with the initial transform only.
moving_resampled = sitk.Resample(
    moving_image, fixed_image, initial_transform_1,
    sitk.sitkLinear, 0.0, moving_image.GetPixelID(),
)

In [None]:
# Stage 1: affine registration of the windowed moving image onto the windowed fixed image.
registration_method_1 = sitk.ImageRegistrationMethod()

# Similarity metric settings.
# Mean squares compares intensities directly; both images were windowed to the
# same [0, 1] range beforehand. The metric is evaluated on a random 1% sample of voxels.
# NOTE(review): no seed is passed to SetMetricSamplingPercentage, so the sample may
# differ between runs — pass an explicit seed if reproducibility matters.
registration_method_1.SetMetricAsMeanSquares()
registration_method_1.SetMetricSamplingStrategy(registration_method_1.RANDOM)
registration_method_1.SetMetricSamplingPercentage(0.01)

# Linear interpolation of the moving image while evaluating the metric.
registration_method_1.SetInterpolator(sitk.sitkLinear)

# Optimizer settings.
# Gradient descent with at most 500 iterations; convergenceMinimumValue and
# convergenceWindowSize control early stopping. Parameter scales are estimated
# from physical shifts so matrix and translation parameters are balanced.
registration_method_1.SetOptimizerAsGradientDescent(learningRate=1.0, numberOfIterations=500, convergenceMinimumValue=1e-6, convergenceWindowSize=5)
registration_method_1.SetOptimizerScalesFromPhysicalShift()

# Setup for the multi-resolution framework: three levels from coarse to fine
# (shrink 4x, 2x, 1x) with smoothing sigmas 2, 1, 0 given in physical units.
registration_method_1.SetShrinkFactorsPerLevel(shrinkFactors = [4,2,1])
registration_method_1.SetSmoothingSigmasPerLevel(smoothingSigmas=[2,1,0])
registration_method_1.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()

# Don't optimize in-place, we would possibly like to run this cell multiple times.
registration_method_1.SetInitialTransform(initial_transform_1, inPlace=False)

# Connect all of the observers so that we can perform plotting during registration:
# start/end create and release the plot buffers, each new resolution level records
# its first iteration index, and every iteration redraws the metric curve.
registration_method_1.AddCommand(sitk.sitkStartEvent, start_plot)
registration_method_1.AddCommand(sitk.sitkEndEvent, end_plot)
registration_method_1.AddCommand(sitk.sitkMultiResolutionIterationEvent, update_multires_iterations) 
registration_method_1.AddCommand(sitk.sitkIterationEvent, lambda: plot_values(registration_method_1))

# fixed_image and moving_image were already cast to float32 in the windowing cell;
# the casts here are defensive.
final_transform_1 = registration_method_1.Execute(sitk.Cast(fixed_image, sitk.sitkFloat32), 
                                               sitk.Cast(moving_image, sitk.sitkFloat32))

In [None]:
# Summary of the stage-1 (affine) optimisation.
print('Final metric value:', registration_method_1.GetMetricValue())
print("Optimizer's stopping condition,", registration_method_1.GetOptimizerStopConditionDescription())

In [None]:
# Moving image resampled onto the fixed grid with the optimised affine transform.
moving_resampled_1 = sitk.Resample(
    moving_image, fixed_image, final_transform_1,
    sitk.sitkLinear, 0.0, moving_image.GetPixelID(),
)

interact(
    display_images,
    fixed_image_z=(0, fixed_image.GetSize()[2] - 1),
    moving_image_z=(0, moving_resampled_1.GetSize()[2] - 1),
    fixed_npa=fixed(sitk.GetArrayViewFromImage(fixed_image)),
    moving_npa=fixed(sitk.GetArrayViewFromImage(moving_resampled_1)),
);

## Registration Step 2 (B-spline Transformation)

In [None]:
# B-spline mesh size of 3 along every image axis, laid over the fixed image domain.
transformDomainMeshSize = [3] * fixed_image.GetDimension()

initial_transform_2 = sitk.BSplineTransformInitializer(fixed_image, transformDomainMeshSize)

# Preview: the affinely aligned moving image under the freshly initialised B-spline transform.
moving_resampled_2 = sitk.Resample(
    moving_resampled_1, fixed_image, initial_transform_2,
    sitk.sitkLinear, 0.0, moving_resampled_1.GetPixelID(),
)

In [None]:
# Stage 2: deformable (B-spline) registration of the affinely aligned moving
# image (moving_resampled_1) onto the fixed image.
registration_method_2 = sitk.ImageRegistrationMethod()

# Similarity metric settings (same as stage 1: mean squares on a random 1% sample).
# NOTE(review): no seed is passed, so the sample may differ between runs.
registration_method_2.SetMetricAsMeanSquares()
registration_method_2.SetMetricSamplingStrategy(registration_method_2.RANDOM)
registration_method_2.SetMetricSamplingPercentage(0.01)

registration_method_2.SetInterpolator(sitk.sitkLinear)

# Optimizer settings: larger learning rate (5.0) and iteration cap (1000) than stage 1.
registration_method_2.SetOptimizerAsGradientDescent(learningRate=5.0, numberOfIterations=1000, convergenceMinimumValue=1e-6, convergenceWindowSize=5)
registration_method_2.SetOptimizerScalesFromPhysicalShift()

# Setup for the multi-resolution framework: shrink factors 4, 2, 1 with smoothing
# sigmas 4, 2, 1 (coarse to fine).
# NOTE(review): unlike stage 1, SmoothingSigmasAreSpecifiedInPhysicalUnitsOn() is not
# called here, so the sigma units follow SimpleITK's default — confirm this is intended.
registration_method_2.SetShrinkFactorsPerLevel(shrinkFactors = [4,2,1])
registration_method_2.SetSmoothingSigmasPerLevel(smoothingSigmas=[4,2,1])

# Optimize a copy of initial_transform_2 (inPlace=False). scaleFactors has three
# entries, one per resolution level; presumably the B-spline control mesh is
# refined by these factors across levels — TODO confirm against the SimpleITK docs.
registration_method_2.SetInitialTransformAsBSpline(initial_transform_2, inPlace=False, scaleFactors=[1,2,5])

# Connect all of the observers so that we can perform plotting during registration.
registration_method_2.AddCommand(sitk.sitkStartEvent, start_plot)
registration_method_2.AddCommand(sitk.sitkEndEvent, end_plot)
registration_method_2.AddCommand(sitk.sitkMultiResolutionIterationEvent, update_multires_iterations) 
registration_method_2.AddCommand(sitk.sitkIterationEvent, lambda: plot_values(registration_method_2))

# The moving input is the stage-1 result, so this transform only models the
# deformation remaining after the affine alignment. Both inputs are already
# float32 (windowing cell / resampling with the moving image's pixel type);
# the casts are defensive.
final_transform_2 = registration_method_2.Execute(sitk.Cast(fixed_image, sitk.sitkFloat32), 
                                               sitk.Cast(moving_resampled_1, sitk.sitkFloat32))

In [None]:
# Summary of the stage-2 (B-spline) optimisation.
print('Final metric value:', registration_method_2.GetMetricValue())
print("Optimizer's stopping condition,", registration_method_2.GetOptimizerStopConditionDescription())

In [None]:
# Affinely aligned moving image warped onto the fixed grid with the optimised B-spline.
moving_resampled_2 = sitk.Resample(
    moving_resampled_1, fixed_image, final_transform_2,
    sitk.sitkLinear, 0.0, moving_resampled_1.GetPixelID(),
)

interact(
    display_images,
    fixed_image_z=(0, fixed_image.GetSize()[2] - 1),
    moving_image_z=(0, moving_resampled_2.GetSize()[2] - 1),
    fixed_npa=fixed(sitk.GetArrayViewFromImage(fixed_image)),
    moving_npa=fixed(sitk.GetArrayViewFromImage(moving_resampled_2)),
);

## Loading Segmentation

In [None]:
fixed_seg_path = 'fixed_seg_path'

fixed_seg = sitk.ReadImage(fixed_seg_path)
# Convert the label image to a NumPy array once (indexed [z, y, x]) and keep it
# as fixed_seg_array for later inspection/display.
fixed_seg_array = sitk.GetArrayFromImage(fixed_seg)

print('The size of image is: ', fixed_seg.GetSize())
print('The range of intensity is from ', np.min(fixed_seg_array), 'to ', np.max(fixed_seg_array))

In [None]:
# Scroll through the fixed segmentation (left panel, captioned 'fixed image' by
# display_images) next to the windowed fixed image (right panel).
# The label array is converted here directly from fixed_seg, so this cell only
# depends on images loaded in earlier cells.
interact(
    display_images,
    fixed_image_z=(0, fixed_seg.GetSize()[2] - 1),
    moving_image_z=(0, fixed_image.GetSize()[2] - 1),
    fixed_npa=fixed(sitk.GetArrayFromImage(fixed_seg)),
    moving_npa=fixed(sitk.GetArrayFromImage(fixed_image)),
);

## Computing the Inverse B-spline Transformation

In [None]:
# Computing the deformation field of the B-spline transformation:
# final_transform_2 is sampled as a dense float32 vector image on the fixed
# image's grid (same size, origin, spacing and direction).
disp_field = sitk.TransformToDisplacementField(final_transform_2, 
                                               sitk.sitkVectorFloat32,
                                               fixed_image.GetSize(),
                                               fixed_image.GetOrigin(),
                                               fixed_image.GetSpacing(),
                                               fixed_image.GetDirection())

# Computing the inverse deformation field and converting it to a transformation.
# The inverse is estimated numerically from the sampled field rather than from the
# B-spline itself; the second argument (10) is presumably the maximum number of
# inversion iterations — TODO confirm against the SimpleITK docs.
# The field is cast to float64 vectors before wrapping it in a
# DisplacementFieldTransform (presumably the type that transform requires — confirm).
disp_field_inv = sitk.InvertDisplacementField(disp_field,10)
inverse_transform_2 = sitk.DisplacementFieldTransform(sitk.Cast(disp_field_inv, sitk.sitkVectorFloat64))

In [None]:
# Map the fixed-image segmentation into the affinely aligned moving space through
# the inverse B-spline; nearest-neighbour interpolation keeps label values unblended.
moving_seg_generated_1 = sitk.Resample(
    fixed_seg,
    moving_resampled_1,
    inverse_transform_2,
    sitk.sitkNearestNeighbor,
    0.0,
    fixed_seg.GetPixelID(),
)

## Computing Inverse Affine Transformation

In [None]:
# Inverse of the optimised stage-1 affine transform; used below to bring the
# generated mask from the affinely aligned space back onto the original moving grid.
inverse_transform_1 = final_transform_1.GetInverse()

In [None]:
# Resample the intermediate mask onto the original moving image grid through the
# inverse affine transform, again with nearest-neighbour interpolation for labels.
moving_seg_generated_2 = sitk.Resample(
    moving_seg_generated_1,
    moving_image,
    inverse_transform_1,
    sitk.sitkNearestNeighbor,
    0.0,
    moving_seg_generated_1.GetPixelID(),
)

## Writing and Saving the Generated Segmentation Mask

In [None]:
# Rebuild the generated mask as a standalone image from its voxel array.
# GetImageFromArray alone gives default geometry (unit spacing, zero origin,
# identity direction), so the physical metadata is copied back from the
# resampled mask it came from.
generated_seg = sitk.GetImageFromArray(sitk.GetArrayFromImage(moving_seg_generated_2))
generated_seg.CopyInformation(moving_seg_generated_2)

In [None]:
# Save the generated mask, which lives on the moving image's grid (resampled onto moving_image).
# NOTE(review): 'generated_seg_path' is a placeholder — replace it with a real output file path.
sitk.WriteImage(moving_seg_generated_2, 'generated_seg_path')