Skip to content

Commit

Permalink
Advanced options: Support for point set registration, initial transfo…
Browse files Browse the repository at this point in the history
…rms and proper error pop ups (#8)

* ENH:Point_set_registration

* ENH:InitialTransform

* ENH:PointsetMetricCoupled

* ENH:Error_pop_ups_and_GUI_update

* ENH:Version to 0.1.3
  • Loading branch information
ViktorvdValk committed Apr 16, 2021
1 parent ada8dcc commit d13a4c0
Show file tree
Hide file tree
Showing 7 changed files with 301 additions and 89 deletions.
7 changes: 3 additions & 4 deletions elastix_napari/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@

from .elastix_registration import napari_experimental_provide_dock_widget

__author__ = "Viktor van der Valk"
__email__ = "v.o.van_der_valk@lumc.nl"

__version__ = "0.1.2"
__version__ = "0.1.3"


def get_module_version():
return __version__

from .elastix_registration import napari_experimental_provide_dock_widget
227 changes: 167 additions & 60 deletions elastix_napari/elastix_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,55 +15,136 @@
import itk
from pathlib import Path
from typing import Sequence
from qtpy.QtWidgets import QMessageBox


def error(message):
e = QMessageBox()
label = QMessageBox()
e.setText(message)
e.setIcon(QMessageBox.Critical)
e.setWindowTitle("Error")
e.show()
return e


def check_pointset(pointset):
if '.txt' in str(pointset[0]) or '.vtk' in str(pointset[0]):
return True
else:
return False


def on_init(widget):
widget.native.setStyleSheet("QWidget{font-size: 12pt;}")

widget.fixed_mask.visible = False
widget.moving_mask.visible = False
widget.param1.visible = False
widget.param2.visible = False
widget.param3.visible = False

for x in ['fixed_mask', 'moving_mask', 'param1', 'param2', 'param3',
'fixed_ps', 'moving_ps', 'metric', 'init_trans',
'resolutions', 'max_iterations', 'nr_spatial_samples',
'max_step_length']:
setattr(getattr(widget, x), 'visible', False)

def toggle_mask_widgets(event):
widget.fixed_mask.visible = event.value
widget.moving_mask.visible = event.value
for x in ['fixed_mask', 'moving_mask']:
setattr(getattr(widget, x), 'visible', event.value)

def toggle_preset_widget(event):
if event.value == "custom":
widget.param1.visible = True
widget.param2.visible = True
widget.param3.visible = True
for x in ['param1', 'param2', 'param3']:
setattr(getattr(widget, x), 'visible', True)
for y in ['metric', 'resolutions', 'max_iterations',
'nr_spatial_samples', 'max_step_length']:
setattr(getattr(widget, y), 'visible', False)

else:
for x in ['param1', 'param2', 'param3']:
setattr(getattr(widget, x), 'visible', False)
for x in ['metric', 'init_trans', 'resolutions', 'max_iterations',
'nr_spatial_samples', 'max_step_length', 'moving_ps',
'fixed_ps']:
setattr(getattr(widget, x), 'visible', widget.advanced.value)

def toggle_advanced_widget(event):
if widget.preset.value == "custom":
for x in ['init_trans', 'fixed_ps', 'moving_ps']:
setattr(getattr(widget, x), 'visible', event.value)
else:
widget.param1.visible = False
widget.param2.visible = False
widget.param3.visible = False
for x in ['metric', 'init_trans', 'resolutions',
'max_iterations', 'nr_spatial_samples',
'max_step_length', 'fixed_ps', 'moving_ps']:
setattr(getattr(widget, x), 'visible', event.value)

widget.preset.changed.connect(toggle_preset_widget)
widget.use_masks.changed.connect(toggle_mask_widgets)
widget.masks.changed.connect(toggle_mask_widgets)
widget.advanced.changed.connect(toggle_advanced_widget)
widget.native.layout().addStretch()


@magic_factory(widget_init=on_init, layout='vertical', call_button="register",
preset={"choices": ["rigid", "affine", "bspline", "custom"]},
param1={"label": "parameterfile:",
"filter": "*.txt"}, param2={"label": "parameterfile 2",
"filter": "*.txt"}, param3={"label": "parameterfile 3",
"filter": "*.txt"})
preset={"choices": ["rigid", "affine", "bspline", "custom"],
"tooltip": "Select a preset parameter file or select "
"custom to load a custom one"},
fixed_ps={"label": "fixed point set", "filter": "*.txt",
"tooltip": "Load a fixed point set"},
moving_ps={"label": "moving point set", "filter": "*.txt",
"tooltip": "Load a moving point set"},
param1={"label": "parameterfile 1", "filter": "*.txt",
"tooltip": 'Load a custom parameter file'},
param2={"label": "parameterfile 2", "filter": "*.txt",
"tooltip": 'Optionally load a second custom parameter '
'file'},
param3={"label": "parameterfile 3", "filter": "*.txt",
"tooltip": 'Optionally load a third custom parameter '
'file'},
metric={"choices": ["from preset",
"AdvancedMattesMutualInformation",
"AdvancedNormalizedCorrelation",
"AdvancedMeanSquares"],
"tooltip": 'Select a metric to use'},
init_trans={"label": "initial transform", "filter": "*.txt",
"tooltip": 'Load a initial transform from a .txt '
'file'},
nr_spatial_samples={"max": 8192, "step": 256,
"tooltip": 'Select the number of spatial '
'samples to use'})
def elastix_registration(fixed: 'napari.types.ImageData',
moving: 'napari.types.ImageData',
moving: 'napari.types.ImageData', preset: str,
fixed_mask: 'napari.types.ImageData',
moving_mask: 'napari.types.ImageData', preset: str,
moving_mask: 'napari.types.ImageData',
fixed_ps: Sequence[Path], moving_ps: Sequence[Path],
param1: Sequence[Path], param2: Sequence[Path],
param3: Sequence[Path], use_masks: bool = False
param3: Sequence[Path], init_trans: Sequence[Path],
metric: str, resolutions: int = 4,
max_iterations: int = 500,
nr_spatial_samples: int = 512,
max_step_length: float = 1.0, masks: bool = False,
advanced: bool = False
) -> 'napari.types.LayerDataTuple':

)-> 'napari.types.LayerDataTuple':
if fixed is None or moving is None:
print("No images selected for registration.")
return
return error("No images selected for registration.")
if check_pointset(fixed_ps) != check_pointset(moving_ps):
print("Select both fixed and moving point set.")
return error("Select both fixed and moving point set.")

if advanced:
if check_pointset(init_trans):
init_trans = str(init_trans[0])
else:
init_trans = ''
if check_pointset(fixed_ps):
fixed_ps = str(fixed_ps[0])
else:
fixed_ps = ''
if check_pointset(moving_ps):
moving_ps = str(moving_ps[0])
else:
moving_ps = ''
else:
init_trans = ''
fixed_ps = ''
moving_ps = ''

# Casting to numpy is currently necessary
# because of napari's type ambiguity.
Expand All @@ -78,56 +159,82 @@ def elastix_registration(fixed: 'napari.types.ImageData',
try:
parameter_object.AddParameterFile(par)
except:
raise TypeError("Parameter file not found or not valid")
return error("Parameter file not found or not valid")
else:
pass
else:
default_parameter_map = parameter_object.GetDefaultParameterMap(preset)
parameter_object.AddParameterMap(default_parameter_map)
if advanced:
parameter_map = \
parameter_object.GetDefaultParameterMap(preset, resolutions)
if metric != 'from preset':
parameter_map['Metric'] = [metric]
if fixed_ps != '' and moving_ps != '':
parameter_map['Registration'] = [
'MultiMetricMultiResolutionRegistration']
original_metric = parameter_map['Metric']
parameter_map['Metric'] = \
[original_metric[0],
'CorrespondingPointsEuclideanDistanceMetric']
parameter_map['MaximumStepLength'] = [str(max_step_length)]
parameter_map['NumberOfSpatialSamples'] = [str(nr_spatial_samples)]
parameter_map['MaximumNumberOfIterations'] = [str(max_iterations)]
else:
parameter_map = parameter_object.GetDefaultParameterMap(preset)
parameter_object.AddParameterMap(parameter_map)

if use_masks:
if not masks:
result_image, result_transform_parameters = \
itk.elastix_registration_method(
fixed, moving, parameter_object,
fixed_point_set_file_name=fixed_ps,
moving_point_set_file_name=moving_ps,
initial_transform_parameter_file_name=init_trans,
log_to_console=True)

elif masks:
if fixed_mask is None and moving_mask is None:
print("No masks selected for registration")
return
return error("No masks selected for registration")
else:
# Casting to numpy and itk is currently necessary
# because of napari's type ambiguity.

if not (fixed_mask is None):
if moving_mask is None:
fixed_mask = np.asarray(fixed_mask).astype(np.uint8)
fixed_mask = itk.image_view_from_array(fixed_mask)

if not (moving_mask is None):
moving_mask = np.asarray(moving_mask).astype(np.uint8)
moving_mask = itk.image_view_from_array(moving_mask)

result_image, result_transform_parameters = \
itk.elastix_registration_method(
fixed, moving, parameter_object,
fixed_mask=fixed_mask,moving_mask=moving_mask,
log_to_console=False)
else:
result_image, result_transform_parameters = \
itk.elastix_registration_method(
fixed, moving, parameter_object,
fixed_mask=fixed_mask,log_to_console=False)
result_image, rtp = itk.elastix_registration_method(
fixed, moving, parameter_object, fixed_mask=fixed_mask,
fixed_point_set_file_name=fixed_ps,
moving_point_set_file_name=moving_ps,
initial_transform_parameter_file_name=init_trans,
log_to_console=False)

elif fixed_mask is None:
moving_mask = np.asarray(moving_mask).astype(np.uint8)
moving_mask = itk.image_view_from_array(moving_mask)
result_image, rtp = itk.elastix_registration_method(
fixed, moving, parameter_object, moving_mask=moving_mask,
fixed_point_set_file_name=fixed_ps,
moving_point_set_file_name=moving_ps,
initial_transform_parameter_file_name=init_trans,
log_to_console=False)
else:
if not (moving_mask is None):
moving_mask = np.asarray(moving_mask).astype(np.uint8)
moving_mask = itk.image_view_from_array(moving_mask)
fixed_mask = np.asarray(fixed_mask).astype(np.uint8)
fixed_mask = itk.image_view_from_array(fixed_mask)
moving_mask = np.asarray(moving_mask).astype(np.uint8)
moving_mask = itk.image_view_from_array(moving_mask)

result_image, result_transform_parameters = \
itk.elastix_registration_method(
fixed, moving, parameter_object,
moving_mask=moving_mask, log_to_console=False)
result_image, rtp = itk.elastix_registration_method(
fixed, moving, parameter_object, fixed_mask=fixed_mask,
moving_mask=moving_mask,
fixed_point_set_file_name=fixed_ps,
moving_point_set_file_name=moving_ps,
initial_transform_parameter_file_name=init_trans,
log_to_console=False)

else:
result_image, result_transform_parameters = \
itk.elastix_registration_method(fixed, moving, parameter_object,
log_to_console=False)
return np.asarray(result_image).astype(np.float32), {'name':preset + ' Registration'}
return np.asarray(result_image).astype(np.float32), {'name': preset + ' Registration'}


@napari_hook_implementation
def napari_experimental_provide_dock_widget():
return elastix_registration, {'area': 'bottom'}
return elastix_registration
40 changes: 40 additions & 0 deletions elastix_napari/tests/data/TransformParameters.0.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
(Transform "EulerTransform")
(NumberOfParameters 3)
(TransformParameters 0 0 0)
(InitialTransformParametersFileName "NoInitialTransform")
(UseBinaryFormatForTransformationParameters "false")
(HowToCombineTransforms "Compose")
(CenterOfRotationPoint 50 50)

// Image specific
(FixedImageDimension 2)
(MovingImageDimension 2)
(FixedInternalImagePixelType "float")
(MovingInternalImagePixelType "float")
(Size 100 100)
(Index 0 0)
(Spacing 1.0000000000 1.0000000000)
(Origin 0.0000000000 0.0000000000)
(Direction 1.0000000000 0.0000000000 0.0000000000 1.0000000000)
(UseDirectionCosines "true")


// BSplineTransform specific
(GridSize 19 19)
(GridIndex 0 0)
(GridSpacing 16.0000000000 16.0000000000)
(GridOrigin -16.5000000000 -16.5000000000)
(GridDirection 1.0000000000 0.0000000000 0.0000000000 1.0000000000)
(BSplineTransformSplineOrder 3)
(UseCyclicTransform "false")

// ResampleInterpolator specific
(ResampleInterpolator "FinalBSplineInterpolator")
(FinalBSplineInterpolationOrder 1)

// Resampler specific
(Resampler "DefaultResampler")
(DefaultPixelValue 0.000000)
(ResultImageFormat "mhd")
(ResultImagePixelType "float")
(CompressResultImage "false")
6 changes: 6 additions & 0 deletions elastix_napari/tests/data/fixed_point_set_test.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
point
4
25 25
25 75
75 25
75 75
6 changes: 6 additions & 0 deletions elastix_napari/tests/data/moving_point_set_test.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
point
4
1 10
1 60
51 10
51 60
12 changes: 6 additions & 6 deletions elastix_napari/tests/data/parameters_Rigid.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,14 @@

// The following components should usually be left as they are:
(Registration "MultiResolutionRegistration")
(Interpolator "BSplineInterpolator")
(Interpolator "LinearInterpolator")
(ResampleInterpolator "FinalBSplineInterpolator")
(Resampler "DefaultResampler")

// These may be changed to Fixed/MovingSmoothingImagePyramid.
// See the manual.
(FixedImagePyramid "FixedRecursiveImagePyramid")
(MovingImagePyramid "MovingRecursiveImagePyramid")
(FixedImagePyramid "FixedSmoothingImagePyramid")
(MovingImagePyramid "MovingSmoothingImagePyramid")

// The following components are most important:
// The optimizer AdaptiveStochasticGradientDescent (ASGD) works
Expand Down Expand Up @@ -80,7 +80,7 @@
// The number of resolutions. 1 Is only enough if the expected
// deformations are small. 3 or 4 mostly works fine. For large
// images and large deformations, 5 or 6 may even be useful.
(NumberOfResolutions 5)
(NumberOfResolutions 6)

// The downsampling/blurring factors for the image pyramids.
// By default, the images are downsampled by a factor of 2
Expand Down Expand Up @@ -120,7 +120,7 @@
// them randomly. See the manual for information on other sampling
// strategies.
(NewSamplesEveryIteration "true")
(ImageSampler "Random")
(ImageSampler "RandomCoordinate")

// ************* Interpolation and Resampling ****************

Expand Down Expand Up @@ -148,5 +148,5 @@
(WriteResultImage "true")

// The pixel type and format of the resulting deformed moving image
(ResultImagePixelType "short")
(ResultImagePixelType "float")
(ResultImageFormat "mhd")
Loading

0 comments on commit d13a4c0

Please sign in to comment.