In [28]:
import SimpleITK as sitk
import numpy as np
import os
import matplotlib.pyplot as plt

In [None]:
def start_plot():
    """Reset the registration metric history.

    Registered on ``sitk.sitkStartEvent``. Rebinds the module-level
    ``metric_values`` and ``multires_iterations`` to fresh empty lists so each
    registration run starts with a clean history.
    """
    global metric_values, multires_iterations
    metric_values, multires_iterations = [], []

# Callback for the registration EndEvent: persist the metric history and close the figure.
def end_plot(patientNumber):
    """Write the recorded metric values to disk and close the current figure.

    Writes one value per line from the module-level ``metric_values`` list to
    ``<cwd>/input_files/pacient_<patientNumber>/metricValues<patientNumber>.txt``.
    The patient folder must already exist.

    Args:
        patientNumber: Patient identifier used in the folder and file names
            (e.g. ``"01"``).
    """
    textPath = os.path.join(
        os.getcwd(),
        "input_files",
        f"pacient_{patientNumber}",
        f"metricValues{patientNumber}.txt",
    )
    with open(textPath, "w") as textFile:
        # writelines consumes the generator directly; no throwaway list is built
        # just for its side effects.
        textFile.writelines(f"{value}\n" for value in metric_values)

    # Close the current figure so a duplicate of the plot doesn't show up later.
    plt.close()

def plot_values(registration_method):
    """Record the current metric value of a running registration.

    Registered on ``sitk.sitkIterationEvent``. Appends
    ``registration_method.GetMetricValue()`` to the module-level
    ``metric_values`` list, which ``start_plot`` initialises. No live plotting
    is done here.

    Args:
        registration_method: The ``sitk.ImageRegistrationMethod`` being executed.
    """
    # append mutates the existing list, so no ``global`` declaration is needed.
    metric_values.append(registration_method.GetMetricValue())

def update_multires_iterations():
    """Mark a resolution-level boundary in the metric history.

    Appends the number of metric values recorded so far to the module-level
    ``multires_iterations`` list. It is meant for
    ``sitk.sitkMultiResolutionIterationEvent``.
    """
    boundary = len(metric_values)
    multires_iterations.append(boundary)

In [6]:
baseDir = os.getcwd()
patientNumber = "01"

# Per-patient DRR inputs; both are read as float32 for registration.
preopDrrPath = os.path.join(baseDir, f"input_files/pacient_{patientNumber}/pacient{patientNumber}Preop.mha")
intraopDrrPath = os.path.join(baseDir, f"input_files/pacient_{patientNumber}/pacient{patientNumber}Intraop.mha")
preopImage, intraopImage = (sitk.ReadImage(path, sitk.sitkFloat32)
                            for path in (preopDrrPath, intraopDrrPath))

# Invert the intraoperative image (InvertIntensity maps v -> maximum - v).
intraopImageInverted = sitk.InvertIntensity(intraopImage, maximum=1)
print(f"moving: {preopImage.GetSize()}, fixed: {intraopImageInverted.GetSize()}")

fixedImage = intraopImageInverted
# Take slice 0 along the third axis of the preop volume as the 2-D moving image,
# and force it onto the fixed image's physical origin and spacing.
movingImage = preopImage[:, :, 0]
movingImage.SetOrigin(fixedImage.GetOrigin())
movingImage.SetSpacing(fixedImage.GetSpacing())
print(f"Fixed image, spacing: {fixedImage.GetSpacing()}, size: {fixedImage.GetSize()}, direction: {fixedImage.GetDirection()}, origin: {fixedImage.GetOrigin()}")
print(f"Moving image, spacing: {movingImage.GetSpacing()}, size: {movingImage.GetSize()}, direction: {movingImage.GetDirection()}, origin: {movingImage.GetOrigin()}")

# Overlay before registration: fixed underneath, moving at 50 % opacity.
for layer, opacity in ((fixedImage, None), (movingImage, 0.5)):
    plt.imshow(sitk.GetArrayViewFromImage(layer), cmap="gray", alpha=opacity)
plt.axis("off")
plt.show()

('c:\\Users\\vojte\\Desktop\\skola\\DP\\input_files/pacient_01/pacient01Preop.mha',
 'c:\\Users\\vojte\\Desktop\\skola\\DP\\input_files/pacient_01/pacient01Intraop.mha')

In [None]:
initialTransform = sitk.CenteredTransformInitializer(fixedImage, movingImage,
                                                     sitk.AffineTransform(2))
registration = sitk.ImageRegistrationMethod()
registration.SetMetricAsMattesMutualInformation()
registration.SetOptimizerScalesFromPhysicalShift()
registration.SetOptimizerAsGradientDescent(learningRate=1.0,
                                           numberOfIterations=200,
                                           convergenceMinimumValue=1e-5,
                                           convergenceWindowSize=5)
registration.SetInitialTransform(initialTransform)
registration.SetInterpolator(sitk.sitkLinear)

# Optional multi-resolution schedule. If enabled, also register
# update_multires_iterations on sitkMultiResolutionIterationEvent below.
#registration.SetShrinkFactorsPerLevel(shrinkFactors=[4, 2, 1])
#registration.SetSmoothingSigmasPerLevel(smoothingSigmas=[2, 1, 0])
#registration.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()

# AddCommand takes a zero-argument callable. Writing end_plot(patientNumber)
# directly would run it immediately (before Execute, when metric_values may not
# exist yet) and register its return value, None, instead of a callback.
registration.AddCommand(sitk.sitkStartEvent, start_plot)
registration.AddCommand(sitk.sitkEndEvent, lambda: end_plot(patientNumber))
#registration.AddCommand(sitk.sitkMultiResolutionIterationEvent, update_multires_iterations)
registration.AddCommand(sitk.sitkIterationEvent, lambda: plot_values(registration))
outTransform = registration.Execute(fixedImage, movingImage)
print(f"Optimizer stop condition: {registration.GetOptimizerStopConditionDescription()}")
print(f" Iteration: {registration.GetOptimizerIteration()}")
print(f" Metric value: {registration.GetMetricValue()}")

# Metric history collected by the iteration callback.
# NOTE(review): SimpleITK minimizes the Mattes metric, so these values are
# presumably negated mutual information; confirm before interpreting the sign.
plt.plot(range(len(metric_values)), metric_values)
plt.xlabel("Iterace")
plt.ylabel("Mutual information")
# plt.show() renders the line plot; plt.imshow() without an image raises TypeError.
plt.show()

In [None]:
# Warp the moving image onto the fixed image's grid with the optimized transform.
movingImageResampled = sitk.Resample(
    movingImage,
    fixedImage,                 # output grid is taken from the fixed image
    outTransform,
    sitk.sitkLinear,
    0.0,                        # value for pixels that map outside the moving image
    movingImage.GetPixelID(),
)

# Overlay after registration: fixed underneath, resampled moving at 50 % opacity.
for layer, opacity in ((fixedImage, None), (movingImageResampled, 0.5)):
    plt.imshow(sitk.GetArrayViewFromImage(layer), cmap="gray", alpha=opacity)
plt.axis("off")
plt.show()

In [None]:
def _window_to_uint8(image):
    """Map intensities in [0, 32767] linearly to [0, 255] and cast to UInt8."""
    windowed = sitk.IntensityWindowing(image,
                                       windowMinimum=0., windowMaximum=32767.,
                                       outputMinimum=0., outputMaximum=255.)
    return sitk.Cast(windowed, sitk.sitkUInt8)


# NOTE(review): the fixed image is rescaled into [0, 32767] before windowing,
# but the moving image is windowed as-is. This assumes the resampled preop DRR
# already lies in that range; confirm, otherwise it may render almost black.
movingImageResampled255 = _window_to_uint8(movingImageResampled)
fixedImage255 = _window_to_uint8(sitk.RescaleIntensity(fixedImage, 0, 32767))

# Alternate 6x6 tiles of the two images to inspect edge continuity.
checkerboardImage = sitk.CheckerBoard(movingImageResampled255, fixedImage255, [6, 6])
plt.imshow(sitk.GetArrayViewFromImage(checkerboardImage), cmap="gray")
plt.axis("off")
plt.show()