# VTK - ITK Project : Etude longitudinale de l'évolution d'une tumeur

### Par Raphaël Duhen, Maël Conan et Nigel Andrews

In [1]:
# Imports
import vtk
import itk
import matplotlib.pyplot as plt
import numpy as np

import ipywidgets as widgets
from ipywidgets import interact
from IPython.display import display, clear_output

In [2]:
def run_window(path: str):
    reader = vtk.vtkNrrdReader()
    reader.SetFileName(path)
    
    window = vtk.vtkRenderWindow()
    renderer = vtk.vtkRenderer()
    
    window.AddRenderer(renderer)
    
    interactor = vtk.vtkRenderWindowInteractor()
    window.SetInteractor(interactor)
    
    contour = vtk.vtkContourFilter()
    contour.SetInputConnection(reader.GetOutputPort())
    contour.SetValue(0, 135)
    contour.SetFastOff()
    
    contourMapper = vtk.vtkPolyDataMapper()
    contourMapper.SetInputConnection(contour.GetOutputPort())
    contourMapper.ScalarVisibilityOff()
    
    contourActor = vtk.vtkActor()
    contourActor.SetMapper(contourMapper)
    
    renderer.AddActor(contourActor)
    
    window.Render()
    interactor.Start()

In [3]:
# Define types
PixelType = itk.D
Dimension = 3
ImageType = itk.Image[PixelType, Dimension]
TransformType = itk.TranslationTransform[PixelType, Dimension]

In [4]:
# Load images
image = itk.imread("Data/case6_gre1.nrrd", pixel_type=PixelType)
image2 = itk.imread("Data/case6_gre2.nrrd", pixel_type=PixelType) 

In [5]:
def registration(image, image2):
    # Définir le type de transformation
    dimension = image.GetImageDimension()
    TransformType = itk.TranslationTransform[itk.D, dimension]

    # Initialiser la transformation
    initial_transform = TransformType.New()
    initial_transform.SetIdentity()

    # Définir les paramètres de l'optimiseur
    optimizer = itk.RegularStepGradientDescentOptimizerv4.New(
        LearningRate=4,
        MinimumStepLength=0.001,
        RelaxationFactor=0.5,
        NumberOfIterations=200,
    )

    # Définir les paramètres de la métrique
    metric = itk.MeanSquaresImageToImageMetricv4[ImageType, ImageType].New()

    # Définir les paramètres de l'interpolateur
    # interpolator = itk.LinearInterpolateImageFunction[ImageType, itk.D].New()

    # Définir les paramètres de la méthode de recalage
    registration = itk.ImageRegistrationMethodv4[ImageType, ImageType].New(
        Metric=metric,
        Optimizer=optimizer,
        InitialTransform=initial_transform,
        FixedImage=image,
        MovingImage=image2,
    )

    # Lancer le recalage
    registration.Update()

    return registration, optimizer

In [6]:
def transform_image(image, image2, registration):
    # Transform image
    transformed_image = itk.resample_image_filter(
        image2,
        transform=registration.GetTransform(),
        use_reference_image=True,
        reference_image=image,
        default_pixel_value=100,
    )

    return transformed_image

In [8]:
try:
    transformed_image = itk.imread("Data/transformed_image.nrrd", pixel_type=PixelType)
except:
    registration, optimizer = registration(image, image2)
    transformed_image = transform_image(image, image2, registration)

In [10]:
clear_output()

array_view = itk.array_view_from_image(transformed_image)
original_fixed_view = itk.array_view_from_image(image)
original_moving_view = itk.array_view_from_image(image2)

# fixed_slider = widgets.IntSlider(value=135, min=0, max=original_fixed_view.shape[0])
# display(fixed_slider)

# @interact(fixed=widgets.IntSlider(value=135, min=0, max=original_fixed_view.shape[0]),
#           moving=widgets.IntSlider(value=135, min=0, max=original_moving_view.shape[0]),
#           transformed=widgets.IntSlider(value=135, min=0, max=array_view.shape[0]))
@interact(fixed=(0, original_fixed_view.shape[0]),
          moving=(0, original_moving_view.shape[0]),
          transformed=(0, array_view.shape[0]))
def plot_slices(fixed: int, moving: int, transformed: int):
    plt.figure(figsize=(15, 15))
    plt.subplot(1, 3, 1)
    plt.imshow(original_fixed_view[fixed, :, :], cmap="gray")
    plt.title("Fixed image")
    plt.subplot(1, 3, 2)
    plt.imshow(original_moving_view[moving, :, :], cmap="gray")
    plt.title("Moving image")
    plt.subplot(1, 3, 3)
    plt.imshow(array_view[transformed, :, :], cmap="gray")
    plt.title("Transformed image")
    plt.show()

interactive(children=(IntSlider(value=88, description='fixed', max=176), IntSlider(value=88, description='movi…

In [31]:
# Save image
itk.imwrite(transformed_image, "Data/transformed_image.nrrd")

In [37]:
# Get final parameters
final_parameters = registration.GetOutput().Get().GetParameters()

# # Get final metric value
# final_metric_value = registration.GetMetricValue()

# Get final number of iterations
final_number_of_iterations = optimizer.GetCurrentIteration()

# Print final parameters in detail
print("Final parameters = ", final_parameters.GetElement(0))
print("Final parameters = ", final_parameters.GetElement(1))
print("Final parameters = ", final_parameters.GetElement(2))
print("Final parameters = ", final_parameters.GetElement(3))
print("Final parameters = ", final_parameters.GetElement(4))
print("Final parameters = ", final_parameters.GetElement(5))

# print("Metric value = ", final_metric_value)
print("Number of iterations = ", final_number_of_iterations)

Final parameters =  -0.8339130025804753
Final parameters =  -3.54112087911544
Final parameters =  -59.455066840126456
Final parameters =  3.88e-321
Final parameters =  1.5e-323
Final parameters =  6.94223094131595e-310
Number of iterations =  20


In [14]:
# Segment tumour using ITK 
def segment_tumour(image):
    # Define types
    PixelType = itk.UC
    Dimension = 3
    ImageType = itk.Image[PixelType, Dimension]

    # Load image
    image = itk.imread(image, pixel_type=PixelType)

    # Define threshold filter
    threshold_filter = itk.BinaryThresholdImageFilter[ImageType, ImageType].New(
        Input=image,
        LowerThreshold=0,
        UpperThreshold=100,
        InsideValue=0,
        OutsideValue=1,
    )

    # Execute threshold filter
    threshold_filter.Update()

    # Get output of threshold filter
    thresholded_image = threshold_filter.GetOutput()

    # Define connected component filter
    connected_component_filter = itk.ConnectedComponentImageFilter[ImageType, ImageType].New(
        Input=thresholded_image
    )

    # Execute connected component filter
    connected_component_filter.Update()

    # Get output of connected component filter
    connected_image = connected_component_filter.GetOutput()

    # Define relabel component filter
    relabel_component_filter = itk.RelabelComponentImageFilter[ImageType, ImageType].New(
        Input=connected_image
    )

    # Execute relabel component filter
    relabel_component_filter.Update()

    # Get output of relabel component filter
    relabelled_image = relabel_component_filter.GetOutput()

    # Define label statistics filter
    label_statistics_filter = itk.LabelStatisticsImageFilter[ImageType, ImageType].New(
        Input=image,
        LabelInput=relabelled_image,
        UseHistograms=True,
        HistogramParameters=itk.HistogramMatchingFilter.GetDefaultHistogramParameters(),
    )

    # Execute label statistics filter
    label_statistics_filter.Update()

    # Get output of label statistics filter
    labelled_image = label_statistics_filter.GetOutput()

    # Define connected threshold filter
    connected_threshold_filter = itk.ConnectedThresholdImageFilter[ImageType, ImageType].New(
        Input=image,
        Lower=0,
        Upper=100,
        Seed=label_statistics_filter.GetMean(1),
        ReplaceValue=1,
    )

    # Execute connected threshold filter
    connected_threshold_filter.Update()

    # Get output of connected threshold filter
    segmented_image = connected_threshold_filter.GetOutput()

    return segmented_image, labelled_image