# Mesh to Mesh Registration Example

ITK natively supports image-to-image registration, which is a common operation for medical images with symmetry. Another common way to store 3D volumetric data is to represent volume surfaces as meshes. In this example we register two meshes using several metrics and optimization techniques built on top of ITK.

Registration classes are defined in the Python `hasi` submodule and built on top of the ITK Python wrapping. The `MeanSquaresRegistrar` and `DiffeoRegistrar` classes apply registration techniques to images derived from mesh inputs, while the `PointSetEntropyRegistrar` registers meshes via point set metrics. This notebook runs mesh registration with each class on sample femur mesh data in the `examples/Data` folder.

This notebook requires the following modules, which can be either acquired via `pip` or built alongside the ITK `master` branch:
- [ITK](https://github.com/InsightSoftwareConsortium/ITK/)
- [ITKBoneEnhancement](https://github.com/InsightSoftwareConsortium/ITKBoneEnhancement)
- [ITKMeshToPolyData](https://github.com/InsightSoftwareConsortium/ITKMeshToPolyData)

In [1]:
# Update sys.path to reference src/ modules
import os
import sys
import copy
import importlib
from urllib.request import urlretrieve

import itk
from itkwidgets import view, checkerboard, compare
from ipywidgets import FloatProgress, Label, HBox, VBox, FloatText, ColorPicker, Button
PATTERN_COUNT = 5

module_path = os.path.abspath(os.path.join('..'))

if module_path not in sys.path:
    sys.path.append(module_path)

In [2]:
os.makedirs('Input', exist_ok=True)
os.makedirs('Output', exist_ok=True)

In [3]:
meshes = ['901','902','906','907','908','915','916','917','918']
MESH_TO_USE = meshes[0]

TARGET_MESH_FILE = f'Input/{MESH_TO_USE}-R-mesh.vtk'
# Future update will replace with smaller template
TEMPLATE_MESH_FILE = 'Input/template-901-L.obj'

MEANSQUARES_OUTPUT_FILE = f'Output/{MESH_TO_USE}-meansquares-registered.obj'
DIFFEO_OUTPUT_FILE = f'Output/{MESH_TO_USE}-diffeo-registered.obj'
POINTSET_OUTPUT_FILE = f'Output/{MESH_TO_USE}-pointset-registered.obj'

In [4]:
# Download meshes
if not os.path.exists(TARGET_MESH_FILE):
    url = 'https://data.kitware.com/api/v1/file/5f9daaba50a41e3d1924dae9/download'
    urlretrieve(url, TARGET_MESH_FILE)
if not os.path.exists(TEMPLATE_MESH_FILE):
    url = 'https://data.kitware.com/api/v1/file/60762a4b2fa25629b9bbefef/download'
    urlretrieve(url, TEMPLATE_MESH_FILE)

In [5]:
template_mesh = itk.meshread(TEMPLATE_MESH_FILE, itk.F)
target_mesh = itk.meshread(TARGET_MESH_FILE, itk.F)

In [6]:
# Also save the target mesh in OBJ format, named after the selected mesh
itk.meshwrite(target_mesh, f'Input/{MESH_TO_USE}-R-mesh.obj')

### Compare images with ITKWidgets

We can use `view`, `compare`, and `checkerboard` to inspect mesh and image data.

In [48]:
#view(geometries=[template_mesh,target_mesh])

## Mesh-To-Image Conversion in the Base Class
Python registration classes inheriting from `MeshToMeshRegistrar` each implement a different registration algorithm. The base class provides common definitions for 3D mesh-to-mesh registration, abstract methods, and a mesh-to-image conversion method.

The `mesh_to_image` function takes in a 3D mesh and converts it into an ITK 3D image object. The spacing, origin, and size of the 3D image may be calculated from the minimum bounding box of the mesh or set from a reference image.

In this case the target mesh occupies a slightly larger region than the template mesh, so we create the target image first and then create the template image with the same size, origin, and spacing.
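To make the bounding-box calculation concrete, here is a minimal sketch in plain Python. It is not the `hasi` implementation: the helper name `image_geometry_from_points` and its `spacing` and `padding` parameters are assumptions chosen for illustration only.

```python
import math

def image_geometry_from_points(points, spacing=1.0, padding=1):
    """Return (origin, size) of a voxel grid enclosing `points`.

    `points` is an iterable of (x, y, z) tuples, `spacing` is the isotropic
    voxel edge length, and `padding` is the number of empty voxels added on
    every side. All names here are hypothetical, not part of `hasi`.
    """
    points = list(points)
    mins = [min(p[axis] for p in points) for axis in range(3)]
    maxs = [max(p[axis] for p in points) for axis in range(3)]

    # Shift the origin outward so the mesh does not touch the image border.
    origin = tuple(lo - padding * spacing for lo in mins)

    # Number of voxels needed to span the padded extent along each axis.
    size = tuple(
        math.ceil((hi - lo) / spacing) + 1 + 2 * padding
        for lo, hi in zip(mins, maxs)
    )
    return origin, size

vertices = [(0.0, 0.0, 0.0), (10.0, 4.0, 2.5), (3.0, -2.0, 1.0)]
origin, size = image_geometry_from_points(vertices, spacing=0.5, padding=2)
print(origin, size)
```

Passing a reference image instead skips this calculation and reuses that image's geometry, which is why the template image above is built from `target_image`: both images then share one voxel grid and can be compared voxel by voxel.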

In [7]:
from src.hasi.hasi.meshtomeshregistrar import MeshToMeshRegistrar
registrar = MeshToMeshRegistrar()

In [8]:
target_image = registrar.mesh_to_image(target_mesh)

In [9]:
template_image = registrar.mesh_to_image(template_mesh, target_image)

Comparing the images generated from the two meshes, we see that the two bones are generally similar but do not line up exactly. Registration will transform one mesh so that the two bones better coincide.

In [17]:
compare(template_image,target_image)


In [10]:
checkerboard(template_image, target_image, pattern=PATTERN_COUNT)


## Point Set Downsampling in the Base Class

The base `MeshToMeshRegistrar` class also provides functionality for uniform random point set sampling from a given mesh, primarily used to improve performance for point set based registration of dense meshes. Note that uniform random sampling may not be suitable for all shapes, in which case application-specific resampling should be carried out externally prior to registration.
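As a rough picture of what uniform random sampling does, the sketch below keeps a fixed fraction of a point list using only the standard library. The function name `sample_points` and its `seed` parameter are hypothetical, and the actual `randomly_sample_mesh_points` method may be implemented differently.

```python
import random

def sample_points(points, sampling_rate, seed=None):
    """Keep roughly `sampling_rate` of `points`, chosen uniformly at random."""
    if not 0.0 < sampling_rate <= 1.0:
        raise ValueError("sampling_rate must be in (0, 1]")
    points = list(points)
    # Keep at least one point (when any exist) so a metric has data to use.
    count = min(len(points), max(1, round(len(points) * sampling_rate)))
    rng = random.Random(seed)
    # random.sample draws without replacement, so no point is duplicated.
    return rng.sample(points, count)

dense = [(float(i), 0.0, 0.0) for i in range(1000)]
sparse = sample_points(dense, sampling_rate=0.01, seed=0)
print(len(sparse))
```

Every point has the same chance of being kept, so densely tessellated regions of a mesh stay over-represented in the sample. That is the situation in which application-specific resampling is the better choice.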

In [13]:
target_points_reduced = \
    registrar.randomly_sample_mesh_points(mesh=target_mesh, sampling_rate=0.01)

In [None]:
#view(geometries=[target_mesh],point_sets=[target_points_reduced])

## Run Mean Squares Image Registration

The `MeanSquaresRegistrar` class converts meshes to images and runs [limited-memory Broyden-Fletcher-Goldfarb-Shanno optimization with bounds (L-BFGS-B)](https://itk.org/Doxygen/html/classitk_1_1LBFGSBOptimizerv4.html) on a [BSplineTransform](https://itk.org/Doxygen/html/classitk_1_1BSplineTransform.html) to iteratively reduce the mean squared error between the two images. The resulting transform is then applied to resample the target mesh into the template mesh domain.
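The quantity being minimized can be illustrated with a small, self-contained sketch. It operates on flattened lists of voxel values rather than ITK images, and `mean_squares` is a hypothetical helper, not the ITK metric class.

```python
def mean_squares(fixed, moving):
    """Mean of squared intensity differences over two same-length voxel lists."""
    if len(fixed) != len(moving):
        raise ValueError("images must have the same number of voxels")
    return sum((f - m) ** 2 for f, m in zip(fixed, moving)) / len(fixed)

# Two flattened binary "images": 1.0 inside the bone, 0.0 outside.
target = [0.0, 1.0, 1.0, 1.0, 0.0, 0.0]
template = [0.0, 0.0, 1.0, 1.0, 1.0, 0.0]
# The template after moving it one voxel to the left.
shifted = [0.0, 1.0, 1.0, 1.0, 0.0, 0.0]

misaligned_error = mean_squares(target, template)
aligned_error = mean_squares(target, shifted)
print(misaligned_error, aligned_error)
```

For binary images like these, the metric equals the fraction of voxels on which the two images disagree, so it falls to zero once the shapes overlap exactly.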

Progress is shown with an `ipywidgets` progress bar updated via hooks into the ITK event-observer system. The resultant mesh is returned as an object in the Python environment and may optionally be written out to a file. Iteration updates may also be printed to the output window with the optional `verbose` flag.

In [18]:
from src.hasi.hasi.meansquaresregistrar import MeanSquaresRegistrar

In [19]:
# Must instantiate a registration object to initialize optimizers
registrar = MeanSquaresRegistrar()

In [20]:
progress = FloatProgress(
        min=0.0,
        max=150.0,
        step=1
    )
box = HBox([
    Label('Register images'),
    progress
])
box


In [21]:
def update_progress():
    progress.value = registrar.optimizer.GetCurrentIteration()
registrar.optimizer.AddObserver(itk.IterationEvent(), update_progress)

0

In [22]:
(transform, mesh_result) = registrar.register(template_mesh,
                                             target_mesh,
                                             filepath=MEANSQUARES_OUTPUT_FILE,
                                             verbose=True,
                                             num_iterations=200)

0 0.49874914874144727 0.0
0 0.35646806992164465 0.0
0 0.35646806992164465 0.0
1 0.10295348512391408 0.0026620884996106735
1 0.10295348512391408 0.0026620884996106735
2 0.08035788007325083 0.0009668692767732778
2 0.08035788007325083 0.0009668692767732778
3 0.07187943515785958 0.00031890859179923886
3 0.07187943515785958 0.00031890859179923886
4 0.038170146192536496 0.00037037996170112954
4 0.038170146192536496 0.00037037996170112954
5 0.028223556307958524 0.0006233894309459988
5 0.028223556307958524 0.0006233894309459988
6 0.02590328552918505 0.0002853482012446442
6 0.02590328552918505 0.0002853482012446442
7 0.025198775406897213 0.0001970344038145016
7 0.025198775406897213 0.0001970344038145016
8 0.023445122744411937 0.00014519944498974952
8 0.023445122744411937 0.00014519944498974952
9 0.022026775927756863 0.00013405114330135054
9 0.022026775927756863 0.00013405114330135054
10 0.020840563571749304 0.00012373405753563282
10 0.020840563571749304 0.00012373405753563282
11 0.0207193446228

131 0.014067941013986313 4.751221167143262e-06
131 0.014067964267936404 4.751221167143262e-06
131 0.014067920705364313 4.751221167143262e-06
131 0.014067920706291915 4.751221167143262e-06
131 0.014067920705364313 4.751221167143262e-06
131 0.014067920705364313 4.751221167143262e-06
Solution = [0.014559554532286315, 0.12014350981921557, 0.7987503672311127, 0.12030057710564017, 0.6788038918612957, 1.5143760317904724, 4.745896727046189, 0.8406573458978361, 0.6651776106824592, 1.0193041810049719, 2.7805916020355346, 0.6228623629908627, -0.016504320743012342, -0.2782048145186771, 0.06187687622661249, 0.05234694478796978, 0.6744924117051408, 0.5999101628805228, 2.177563761137019, 0.4782038222375929, 7.189322119112429, -4.102057176569233, 2.757327852881054, 3.0415623211282186, 6.587801209038333, -2.580599375582684, 5.1114076016347925, 3.57606572764389, 0.4978279852600288, -0.9940029033501432, 0.13186057262262493, 0.44372284163806985, 0.16809784969253344, -0.12290561481264638, 2.268174833600322

Comparison of the resulting mesh with the target shows successful registration.

In [38]:
#view(geometries=[mesh_result, target_mesh])
result_image = registrar.mesh_to_image(mesh_result, target_image)
compare(result_image, target_image)


## Diffeomorphic Registration

The `DiffeoRegistrar` class converts meshes to images and performs registration using the [Diffeomorphic Demons registration algorithm](https://itk.org/Doxygen/html/classitk_1_1DiffeomorphicDemonsRegistrationFilter.html). The resultant deformation field is then applied to resample the target mesh into the template mesh domain.
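To give an intuition for applying a deformation field to mesh vertices, the sketch below moves each point by the displacement stored at its nearest grid node. It is a deliberately simplified illustration: the `warp_points` helper and the dictionary-based field are assumptions, and it ignores the interpolation and direction conventions that a real ITK displacement field involves.

```python
def warp_points(points, displacement, origin, spacing):
    """Move each point by the displacement stored at its nearest grid node.

    `displacement` maps integer (i, j, k) grid indices to (dx, dy, dz)
    vectors; nodes missing from the mapping count as zero displacement.
    """
    warped = []
    for point in points:
        # Convert the physical coordinate to the nearest grid index.
        index = tuple(
            round((coord - o) / spacing) for coord, o in zip(point, origin)
        )
        shift = displacement.get(index, (0.0, 0.0, 0.0))
        warped.append(tuple(c + s for c, s in zip(point, shift)))
    return warped

field = {(1, 0, 0): (0.5, 0.0, 0.0), (2, 0, 0): (0.0, -1.0, 0.0)}
vertices = [(1.1, 0.0, 0.0), (1.9, 0.2, 0.0), (5.0, 5.0, 5.0)]
warped = warp_points(vertices, field, origin=(0.0, 0.0, 0.0), spacing=1.0)
print(warped)
```

Unlike a single B-spline or affine transform, a dense field stores one vector per grid node, which lets the deformation vary freely across the bone.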

Custom observers may print out iteration data accessed via the `registrar.filter` object.

In [24]:
from src.hasi.hasi.diffeoregistrar import DiffeoRegistrar

In [25]:
registrar = DiffeoRegistrar()

In [26]:
diffeoProgress = FloatProgress(
        min=0.0,
        max=200.0,
        step=1
    )
diffeoBox = HBox([
    Label('Register images'),
    diffeoProgress
])
diffeoBox


In [27]:
def update_diff_progress():
    diffeoProgress.value = registrar.filter.GetElapsedIterations()

registrar.filter.AddObserver(itk.IterationEvent(),update_diff_progress)

(transform, mesh_result) = registrar.register(template_mesh,
                                        target_mesh,
                                        filepath=DIFFEO_OUTPUT_FILE,
                                        verbose=True)

1.7976931348623157e+308
0.4987491488238776
0.4957951476333167
0.4941950264434974
0.4930808966130274
0.49244047805993507
0.4914935131356182
0.49018818252226226
0.48846731263850734
0.4865865765164624
0.4845592439304183
0.4825032200711701
0.4805217333057509
0.4784276470217681
0.4763028060774365
0.47410199155061833
0.47183553547207796
0.46945634655126595
0.4670089080804245
0.46447236777377016
0.4618700059751327
0.4594918977465529
0.45712312522024456
0.45522632305652644
0.4533702021556196
0.4515470716093708
0.4497243459441331
0.4479271184542627
0.44616428966981375
0.4443887288887757
0.4426435844212059
0.4409485711023183
0.43921519289735944
0.4375637934800705
0.4357279227468467
0.43395033416663575
0.4320027254100205
0.42996833652422484
0.4278856395645356
0.42554150404513424
0.42268157967961434
0.4194370314646928
0.416190044045965
0.413215106680567
0.41022460807160394
0.40679013365077954
0.40136523982974737
0.39619486958540806
0.3918454867090814
0.3885642069086108
0.3859829635395982
0.3842073

Compare the registered image to the target image.

In [37]:
#view(geometries=[mesh_result, target_mesh])
result_image = registrar.mesh_to_image(mesh_result, target_image)
compare(result_image, target_image)


## Entropy-based Registration

The `PointSetEntropyRegistrar` class registers two 3D meshes by converting them to point clouds and computing the transform that minimizes an entropy measure between the two clouds. In this example we substitute a [`EuclideanDistancePointSetToPointSetMetricv4`](https://itk.org/Doxygen/html/classitk_1_1EuclideanDistancePointSetToPointSetMetricv4.html) to compare the two clouds.
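The idea behind a Euclidean distance point set metric can be sketched with a brute-force mean closest-point distance. This is an illustration of the concept only: `mean_closest_point_distance` is a hypothetical helper, and the ITK class may differ in details such as how neighbors are searched.

```python
import math

def mean_closest_point_distance(fixed_points, moving_points):
    """Average distance from each fixed point to its nearest moving point."""
    total = 0.0
    for p in fixed_points:
        # Brute-force nearest neighbor search: O(len(fixed) * len(moving)).
        total += min(math.dist(p, q) for q in moving_points)
    return total / len(fixed_points)

fixed = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 1.0, 0.0)]
# The same cloud shifted by 0.5 along x.
moving = [(x + 0.5, y, z) for x, y, z in fixed]
# Undo the shift, as a successful registration would.
realigned = [(x - 0.5, y, z) for x, y, z in moving]

before = mean_closest_point_distance(fixed, moving)
after = mean_closest_point_distance(fixed, realigned)
print(before, after)
```

Because every fixed point is compared against every moving point, the cost grows with the product of the two cloud sizes. Downsampling dense meshes first, as the `resample_rate` argument below does, keeps that cost manageable.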

In [29]:
from src.hasi.hasi.pointsetentropyregistrar import PointSetEntropyRegistrar
registrar = PointSetEntropyRegistrar()

In [30]:
progress = FloatProgress(
        min=0.0,
        max=2000.0,
        step=1
    )
progressBox = HBox([
    Label('Register images'),
    progress
])
progressBox


In [31]:
def update_progress():
    # Update the progress bar itself; the enclosing HBox has no `value` trait
    progress.value = registrar.optimizer.GetCurrentIteration()

registrar.optimizer.AddObserver(itk.IterationEvent(),update_progress)

0

In [32]:
metric = itk.EuclideanDistancePointSetToPointSetMetricv4[itk.PointSet[itk.F,3]].New()

(transform, mesh_result) = registrar.register(
                       template_mesh=template_mesh,
                       target_mesh=target_mesh,
                       metric=metric,
                       filepath=POINTSET_OUTPUT_FILE,
                       resample_rate=0.005,
                       verbose=True,
                       learning_rate=1.0,
                       max_iterations=300,
                       resample_from_target=False)

0 0.0
0 0.2931251218908875
1 0.1205193122124783
2 0.10441760162035982
3 0.09834692909340771
4 0.09362681805832306
5 0.08661904369141127
6 0.08230682202306248
7 0.07942040892548036
8 0.07747047345196857
9 0.0759834773217011
10 0.07431930257841414
11 0.07352427455528829
12 0.07292767471290339
13 0.07255832119738176
14 0.07227852401408703
15 0.07203357375770734
16 0.07162628370337433
17 0.07124927954132765
18 0.0705768147400119
19 0.0704106434338947
20 0.06978295403744568
21 0.06845888597949865
22 0.06800386975512246
23 0.06762741927421313
24 0.06650859049205682
25 0.06529362447371807
26 0.06446744560528696
27 0.06344765417959368
28 0.06251106364577003
29 0.06184677765942656
30 0.06141954778967147
31 0.06111198557453498
32 0.060795775352791286
33 0.05943190314826571
34 0.05757744627271415
35 0.05647908427828097
36 0.05568996446341291
37 0.05512690087778857
38 0.05469473397314897
39 0.054158985845869395
40 0.05380578359136856
41 0.05358108151555352
42 0.05340742678577612
43 0.0532742197051

In [33]:
compare(registrar.mesh_to_image(mesh_result),registrar.mesh_to_image(target_mesh))


In [35]:
template_resampled = \
    registrar.resample_template_from_target(mesh_result, target_mesh)

itk.meshwrite(template_resampled,'Output/pointset-resampled.obj')

In [36]:
#view(geometries=[target_mesh,mesh_result,template_resampled])
compare(registrar.mesh_to_image(template_resampled),registrar.mesh_to_image(target_mesh))


In [None]:
# Clean up file output
os.remove(MEANSQUARES_OUTPUT_FILE)
os.remove(DIFFEO_OUTPUT_FILE)
os.remove(POINTSET_OUTPUT_FILE)