Skip to content

Commit

Permalink
World coordinates support (#14)
Browse files Browse the repository at this point in the history
* ENH:Use_itk_napari_conversion

* STYLE:Adhere_to_pep8

* ENH: World coordinate support for images

* ENH: World coordinate support for masks

* ENH:Version_bump_0.1.5
  • Loading branch information
ViktorvdValk committed May 21, 2021
1 parent 83afebc commit 1565b4f
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 33 deletions.
2 changes: 1 addition & 1 deletion elastix_napari/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
__author__ = "Viktor van der Valk"
__email__ = "v.o.van_der_valk@lumc.nl"

__version__ = "0.1.4"
__version__ = "0.1.5"


def get_module_version():
Expand Down
52 changes: 31 additions & 21 deletions elastix_napari/elastix_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import itk
from pathlib import Path
from typing import Sequence
from itk_napari_conversion import image_from_image_layer
from itk_napari_conversion import image_layer_from_image


def on_init(widget):
Expand Down Expand Up @@ -84,10 +86,10 @@ def toggle_advanced_widget(event):
nr_spatial_samples={"max": 8192, "step": 256,
"tooltip": 'Select the number of spatial '
'samples to use'})
def elastix_registration(fixed: 'napari.types.ImageData',
moving: 'napari.types.ImageData', preset: str,
fixed_mask: 'napari.types.ImageData',
moving_mask: 'napari.types.ImageData',
def elastix_registration(fixed: 'napari.layers.Image',
moving: 'napari.layers.Image', preset: str,
fixed_mask: 'napari.layers.Image',
moving_mask: 'napari.layers.Image',
fixed_ps: Sequence[Path], moving_ps: Sequence[Path],
param1: Sequence[Path], param2: Sequence[Path],
param3: Sequence[Path], init_trans: Sequence[Path],
Expand All @@ -96,7 +98,7 @@ def elastix_registration(fixed: 'napari.types.ImageData',
nr_spatial_samples: int = 512,
max_step_length: float = 1.0, masks: bool = False,
advanced: bool = False
) -> 'napari.types.LayerDataTuple':
) -> 'napari.layers.Image':
"""
Takes user input and calls elastix' registration function in itkelastix.
"""
Expand Down Expand Up @@ -125,10 +127,11 @@ def elastix_registration(fixed: 'napari.types.ImageData',
fixed_ps = ''
moving_ps = ''

# Casting to numpy is currently necessary
# because of napari's type ambiguity.
fixed = np.asarray(fixed).astype(np.float32)
moving = np.asarray(moving).astype(np.float32)
# Convert image layer to itk_image
fixed = image_from_image_layer(fixed)
moving = image_from_image_layer(moving)
fixed = fixed.astype(itk.F)
moving = moving.astype(itk.F)

parameter_object = itk.ParameterObject.New()
if preset == "custom":
Expand Down Expand Up @@ -174,12 +177,12 @@ def elastix_registration(fixed: 'napari.types.ImageData',
print("No masks selected for registration")
return utils.error("No masks selected for registration")
else:
# Casting to numpy and itk is currently necessary
# because of napari's type ambiguity.

if moving_mask is None:
fixed_mask = np.asarray(fixed_mask).astype(np.uint8)
fixed_mask = itk.image_view_from_array(fixed_mask)
# Convert mask layer to itk_image
fixed_mask = image_from_image_layer(fixed_mask)
fixed_mask = fixed_mask.astype(itk.UC)

# Call elastix
result_image, rtp = itk.elastix_registration_method(
fixed, moving, parameter_object, fixed_mask=fixed_mask,
fixed_point_set_file_name=fixed_ps,
Expand All @@ -188,20 +191,25 @@ def elastix_registration(fixed: 'napari.types.ImageData',
log_to_console=False)

elif fixed_mask is None:
moving_mask = np.asarray(moving_mask).astype(np.uint8)
moving_mask = itk.image_view_from_array(moving_mask)
# Convert mask layer to itk_image
moving_mask = image_from_image_layer(moving_mask)
moving_mask = moving_mask.astype(itk.UC)

# Call elastix
result_image, rtp = itk.elastix_registration_method(
fixed, moving, parameter_object, moving_mask=moving_mask,
fixed_point_set_file_name=fixed_ps,
moving_point_set_file_name=moving_ps,
initial_transform_parameter_file_name=init_trans,
log_to_console=False)
else:
fixed_mask = np.asarray(fixed_mask).astype(np.uint8)
fixed_mask = itk.image_view_from_array(fixed_mask)
moving_mask = np.asarray(moving_mask).astype(np.uint8)
moving_mask = itk.image_view_from_array(moving_mask)
# Convert mask layer to itk_image
fixed_mask = image_from_image_layer(fixed_mask)
fixed_mask = fixed_mask.astype(itk.UC)
moving_mask = image_from_image_layer(moving_mask)
moving_mask = moving_mask.astype(itk.UC)

# Call elastix
result_image, rtp = itk.elastix_registration_method(
fixed, moving, parameter_object, fixed_mask=fixed_mask,
moving_mask=moving_mask,
Expand All @@ -210,7 +218,9 @@ def elastix_registration(fixed: 'napari.types.ImageData',
initial_transform_parameter_file_name=init_trans,
log_to_console=False)

return np.asarray(result_image).astype(np.float32), {'name': preset + ' Registration'}
layer = image_layer_from_image(result_image)
layer.name = preset + " Registration"
return layer


@napari_hook_implementation
Expand Down
35 changes: 24 additions & 11 deletions elastix_napari/tests/test_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from elastix_napari import elastix_registration
import numpy as np
from qtpy.QtWidgets import QMessageBox
from itk_napari_conversion import image_layer_from_image
from itk_napari_conversion import image_from_image_layer


# Test widget function
Expand Down Expand Up @@ -33,6 +35,7 @@ def image_generator(x1, x2, y1, y2, mask=False, artefact=False,
if artefact:
image[-10:, :] = 1
image = itk.image_view_from_array(image)
image = image_layer_from_image(image)
return image


Expand All @@ -45,8 +48,10 @@ def get_er(*args, **kwargs):
def test_registration():
fixed_image = image_generator(25, 75, 25, 75)
moving_image = image_generator(1, 51, 10, 60)
result_image = get_er(fixed_image, moving_image, preset='rigid')[0]
mean_diff = np.absolute(np.subtract(result_image, fixed_image)).mean()
result_image = get_er(fixed_image, moving_image, preset='rigid')
mean_diff = np.absolute(np.subtract(
np.asarray(image_from_image_layer(result_image)),
np.asarray(image_from_image_layer(fixed_image)))).mean()
assert mean_diff < 0.001


Expand All @@ -61,11 +66,13 @@ def test_masked_registration():

result_image = get_er(fixed=fixed_image, moving=moving_image,
fixed_mask=fixed_mask, moving_mask=moving_mask,
preset='rigid', masks=True)[0]
preset='rigid', masks=True)

# Filter artifacts out of the images.
masked_fixed_image = np.asarray(fixed_image)[0:90, 0:90]
masked_result_image = result_image[0:90, 0:90]
masked_fixed_image = np.asarray(
image_from_image_layer(fixed_image))[0:90, 0:90]
masked_result_image = np.asarray(
image_from_image_layer(result_image))[0:90, 0:90]

mean_diff = np.absolute(np.subtract(masked_fixed_image,
masked_result_image)).mean()
Expand All @@ -85,9 +92,11 @@ def test_pointset_registration(data_dir):

result_image = get_er(fixed_image, moving_image, fixed_ps=fixed_ps,
moving_ps=moving_ps, preset='rigid',
advanced=True)[0]
advanced=True)

mean_diff = np.absolute(np.subtract(result_image, fixed_image)).mean()
mean_diff = np.absolute(np.subtract(
np.asarray(image_from_image_layer(result_image)),
np.asarray(image_from_image_layer(fixed_image)))).mean()
assert mean_diff < 0.001


Expand All @@ -99,9 +108,11 @@ def test_custom_registration(data_dir):
filename = "parameters_Rigid.txt"
result_image = get_er(fixed_image, moving_image, preset='custom',
param1=(str(data_dir / filename), 'x'),
param2=(str(data_dir / filename), 'x'))[0]
param2=(str(data_dir / filename), 'x'))

mean_diff = np.absolute(np.subtract(result_image, fixed_image)).mean()
mean_diff = np.absolute(np.subtract(
np.asarray(image_from_image_layer(result_image)),
np.asarray(image_from_image_layer(fixed_image)))).mean()
assert mean_diff < 0.01


Expand All @@ -113,8 +124,10 @@ def test_initial_transform(data_dir):
result_image = get_er(
fixed_image, moving_image, preset='rigid',
init_trans=(str(data_dir / init_trans_filename), 'x'), resolutions=6,
max_iterations=500, advanced=True)[0]
mean_diff = np.absolute(np.subtract(result_image, fixed_image)).mean()
max_iterations=500, advanced=True)
mean_diff = np.absolute(np.subtract(
np.asarray(image_from_image_layer(result_image)),
np.asarray(image_from_image_layer(fixed_image)))).mean()
assert mean_diff < 0.01


Expand Down
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@ numpy>=1.19.0
napari>=0.4.6
napari-plugin-engine>=0.1.4
magicgui>=0.2.6
itk_napari_conversion>=0.3.1
napari-itk-io>=0.1.0

0 comments on commit 1565b4f

Please sign in to comment.