# Mesh to Mesh Registration Example

ITK natively supports image-to-image registration, which is a common operation for medical images with symmetry. Another common method of storing 3D volumetric data is to represent volume surfaces with meshes. In this example we seek to register two meshes using various ITK metrics and optimization techniques.

Registration classes are defined in the Python `hasi` submodule and built on top of the ITK Python wrapping. The `MeanSquaresRegistrar` and `DiffeoRegistrar` classes apply registration techniques to images derived from mesh inputs, while the `PointSetEntropyRegistrar` registers meshes directly via point set entropy metrics. This notebook runs mesh registration with each class on sample femur mesh data downloaded to the `Input` folder.

This notebook requires the following modules, which can be either acquired via `pip` or built alongside the ITK `master` branch:
- [ITK](https://github.com/InsightSoftwareConsortium/ITK/)
- [ITKBoneEnhancement](https://github.com/InsightSoftwareConsortium/ITKBoneEnhancement)
- [ITKMeshToPolyData](https://github.com/InsightSoftwareConsortium/ITKMeshToPolyData)
- [ITKWidgets](https://github.com/InsightSoftwareConsortium/itkwidgets)

In [1]:
# Update sys.path to reference src/ modules
import os
import sys
import copy
import importlib
from urllib.request import urlretrieve

import itk
from itkwidgets import view, checkerboard, compare
from ipywidgets import FloatProgress, Label, HBox, VBox, FloatText, ColorPicker, Button
PATTERN_COUNT = 5

module_path = os.path.abspath(os.path.join('..'))

if module_path not in sys.path:
    sys.path.append(module_path)

In [2]:
os.makedirs('Input', exist_ok=True)
os.makedirs('Output', exist_ok=True)

In [3]:
MESH_TO_USE = '901-R'
TARGET_MESH_FILE = f'Input/{MESH_TO_USE}-mesh.vtk'
TEMPLATE_MESH_FILE = f'Input/906-R-atlas.obj'

MEANSQUARES_OUTPUT_FILE = f'Output/{MESH_TO_USE}-meansquares-registered.obj'
DIFFEO_OUTPUT_FILE = f'Output/{MESH_TO_USE}-diffeo-registered.obj'
POINTSET_OUTPUT_FILE = f'Output/{MESH_TO_USE}-pointset-registered.obj'
POINTSET_RESAMPLED_OUTPUT_FILE = f'Output/{MESH_TO_USE}-pointset-resampled.obj'

In [4]:
# Download meshes
if not os.path.exists(TARGET_MESH_FILE):
    url = 'https://data.kitware.com/api/v1/file/5f9daaba50a41e3d1924dae9/download'
    urlretrieve(url, TARGET_MESH_FILE)
if not os.path.exists(TEMPLATE_MESH_FILE):
    url = 'https://data.kitware.com/api/v1/file/608b006d2fa25629b970f139/download'
    urlretrieve(url, TEMPLATE_MESH_FILE)

In [5]:
template_mesh = itk.meshread(TEMPLATE_MESH_FILE, itk.F)
target_mesh = itk.meshread(TARGET_MESH_FILE, itk.F)

## Compare images with ITKWidgets

We can use `view`, `compare`, and `checkerboard` to inspect mesh and image data.

In [6]:
#view(geometries=[template_mesh,target_mesh])

## Mesh-To-Image Conversion in the Base Class
Python registration classes inheriting from `MeshToMeshRegistrar` each implement a distinct registration algorithm. The base class provides common definitions for 3D mesh-to-mesh registration, abstract methods for subclasses to implement, and a mesh-to-image conversion method.
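
The base-class/subclass split can be sketched in plain Python. This is a hypothetical illustration, not the `hasi` source: `MeshToMeshRegistrarSketch`, `CentroidRegistrarSketch`, and the `register(template_points, target_points)` signature are assumptions, and meshes are reduced to lists of (x, y, z) tuples. The toy subclass uses centroid alignment (a translation-only initializer) only to show where an algorithm-specific implementation plugs into the shared base.

```python
from abc import ABC, abstractmethod
from typing import List, Tuple

Point = Tuple[float, float, float]


class MeshToMeshRegistrarSketch(ABC):
    """Hypothetical base class: shared helpers plus an abstract register()."""

    @staticmethod
    def centroid(points: List[Point]) -> Point:
        # Shared helper that every subclass can reuse.
        n = len(points)
        return tuple(sum(p[d] for p in points) / n for d in range(3))

    @abstractmethod
    def register(self, template_points: List[Point], target_points: List[Point]) -> List[Point]:
        """Return the template points moved into alignment with the target."""


class CentroidRegistrarSketch(MeshToMeshRegistrarSketch):
    """Toy subclass: translate the template so both centroids coincide."""

    def register(self, template_points, target_points):
        c_template = self.centroid(template_points)
        c_target = self.centroid(target_points)
        shift = tuple(c_target[d] - c_template[d] for d in range(3))
        return [tuple(p[d] + shift[d] for d in range(3)) for p in template_points]
```

For example, `CentroidRegistrarSketch().register([(0, 0, 0), (2, 0, 0)], [(10, 1, 1), (12, 1, 1)])` returns `[(10.0, 1.0, 1.0), (12.0, 1.0, 1.0)]`, while instantiating the abstract base directly raises `TypeError`.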

The `mesh_to_image` function takes in a 3D mesh and converts it into an ITK 3D image object. The spacing, origin, and size of the 3D image may be calculated from the minimum bounding box of the mesh or set from a reference image.

If a list of meshes is passed to `mesh_to_image`, the output images all share the same size, spacing, and origin, chosen so that every mesh is fully contained within the image region.
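
The shared-grid idea can be sketched as follows: take the union of the meshes' bounding boxes, optionally pad it, and derive one origin and voxel count for a given spacing. `shared_image_grid` is a hypothetical helper, not the `hasi` implementation; meshes are assumed to be lists of (x, y, z) tuples and the spacing is assumed to be chosen by the caller.

```python
import math
from typing import Iterable, Sequence, Tuple

Point = Tuple[float, float, float]


def shared_image_grid(
    point_sets: Iterable[Sequence[Point]],
    spacing: Tuple[float, float, float],
    padding: float = 0.0,
) -> Tuple[Tuple[float, ...], Tuple[int, ...]]:
    """Return (origin, size) of one 3D voxel grid that contains every point set.

    The grid comes from the union of the per-mesh bounding boxes, padded on every
    side, so all meshes share the same origin, spacing, and size.
    """
    lo = [math.inf] * 3
    hi = [-math.inf] * 3
    for points in point_sets:
        for p in points:
            for d in range(3):
                lo[d] = min(lo[d], p[d])
                hi[d] = max(hi[d], p[d])
    if lo[0] == math.inf:
        raise ValueError("at least one non-empty point set is required")
    origin = tuple(lo[d] - padding for d in range(3))
    # Voxel count per axis so that the last voxel center reaches hi + padding.
    size = tuple(
        int(math.ceil((hi[d] + padding - origin[d]) / spacing[d])) + 1 for d in range(3)
    )
    return origin, size
```

With meshes `[(0, 0, 0), (1, 2, 3)]` and `[(-1, 0.5, 0), (0, 0, 4)]` and unit spacing, the shared grid has origin `(-1.0, 0.0, 0.0)` and size `(3, 3, 5)`; each mesh could then be rasterized onto that same grid.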

In [7]:
from src.hasi.hasi.meshtomeshregistrar import MeshToMeshRegistrar
registrar = MeshToMeshRegistrar()

In [8]:
target_image, template_image = registrar.mesh_to_image([target_mesh, template_mesh])

Comparing the images generated from the two meshes, we see that the two bones are generally similar but do not line up exactly. Registration will transform one mesh so that the two bones coincide more closely.

In [9]:
#compare(template_image,target_image)

In [10]:
#checkerboard(template_image, target_image, pattern=PATTERN_COUNT)

## Point Set Downsizing in the Base Class

The base `MeshToMeshRegistrar` class also provides decimation functionality for uniform random point set sampling from a given mesh, primarily used to improve performance for point set based registration of dense meshes. Note that uniform random sampling may not be suitable for all shapes, in which case application-specific resampling should be carried out externally prior to registration.
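
Uniform random decimation amounts to drawing point indices without replacement. The sketch below assumes points are held in a plain Python sequence; `randomly_sample_points` is a hypothetical stand-in for `randomly_sample_mesh_points`, whose exact internals may differ. A seed is exposed so the subset is reproducible.

```python
import random
from typing import List, Optional, Sequence, TypeVar

T = TypeVar("T")


def randomly_sample_points(
    points: Sequence[T], sampling_rate: float, seed: Optional[int] = None
) -> List[T]:
    """Return a uniform random subset of roughly sampling_rate * len(points) points."""
    if not 0.0 < sampling_rate <= 1.0:
        raise ValueError("sampling_rate must be in (0, 1]")
    if not points:
        return []
    rng = random.Random(seed)  # seeded generator makes the subset reproducible
    count = max(1, round(len(points) * sampling_rate))
    # Sample indices without replacement, then keep the original point order.
    indices = sorted(rng.sample(range(len(points)), count))
    return [points[i] for i in indices]
```

With `sampling_rate=0.01`, as in the cell that follows, a 1,000-point input keeps 10 points.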

For subsequent registration in this notebook we rely on a template atlas with approximately 4,000 points, which does not require decimation.

In [11]:
target_points_reduced = \
    registrar.randomly_sample_mesh_points(mesh=target_mesh, sampling_rate=0.01)

In [12]:
#view(geometries=[target_mesh],point_sets=[target_points_reduced])

## Run Mean Squares Image Registration

The `MeanSquaresRegistrar` class converts meshes to images and runs [limited-memory Broyden-Fletcher-Goldfarb-Shanno optimization with bounds (L-BFGS-B)](https://itk.org/Doxygen/html/classitk_1_1LBFGSBOptimizerv4.html) on a [BSplineTransform](https://itk.org/Doxygen/html/classitk_1_1BSplineTransform.html) to iteratively reduce the mean squared intensity difference between the two images. The resulting transform is then applied to resample the target mesh into the template mesh domain.
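
To make the quantity being minimized concrete, here is a toy 1D sketch. `mean_squares` and `best_integer_shift` are hypothetical helpers, not part of `hasi` or ITK: an exhaustive search over integer shifts stands in for L-BFGS-B over B-spline coefficients, and only the overlapping samples are compared. The metric drops to zero exactly when the shifted moving signal lines up with the fixed one.

```python
from typing import Optional, Sequence, Tuple


def mean_squares(fixed: Sequence[float], moving: Sequence[float]) -> float:
    """Mean of squared intensity differences over corresponding samples."""
    if len(fixed) != len(moving) or not fixed:
        raise ValueError("need two non-empty sequences of equal length")
    return sum((f - m) ** 2 for f, m in zip(fixed, moving)) / len(fixed)


def best_integer_shift(
    fixed: Sequence[float], moving: Sequence[float], max_shift: int
) -> Tuple[float, int]:
    """Exhaustively find the shift s minimizing mean_squares(fixed[i], moving[i + s])."""
    best: Optional[Tuple[float, int]] = None
    for s in range(-max_shift, max_shift + 1):
        # Only compare samples where both signals are defined (the overlap).
        idx = [i for i in range(len(fixed)) if 0 <= i + s < len(moving)]
        if not idx:
            continue
        value = mean_squares([fixed[i] for i in idx], [moving[i + s] for i in idx])
        if best is None or value < best[0]:
            best = (value, s)
    if best is None:
        raise ValueError("no overlapping shift found")
    return best
```

For example, with `fixed = [0, 0, 1, 2, 1, 0, 0, 0]` and `moving = [0, 0, 0, 0, 1, 2, 1, 0]`, `best_integer_shift(fixed, moving, 3)` returns `(0.0, 2)`. A gradient-based optimizer such as L-BFGS-B reaches the minimum without trying every candidate, which matters when a B-spline transform has many parameters.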

Progress is shown with an ipywidgets progress bar, updated via hooks into the ITK event-observer system. The resulting mesh is returned as a Python object and may optionally be written out to a file. Iteration updates may also be printed to the output window with the optional `verbose` flag.
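
The observer pattern behind the progress bar can be illustrated without ITK. `ToyOptimizer` below is hypothetical and only mimics the `AddObserver` / `RemoveObserver` / `GetCurrentIteration` calls used in this notebook; events are plain strings here rather than ITK event objects. As in ITK, `AddObserver` returns an integer tag, which is the `0` printed after the observer cells below.

```python
from typing import Callable, Dict, Tuple


class ToyOptimizer:
    """Toy stand-in for an ITK optimizer that notifies observers each iteration."""

    def __init__(self) -> None:
        self._observers: Dict[int, Tuple[str, Callable[[], None]]] = {}
        self._next_tag = 0
        self._iteration = 0

    def AddObserver(self, event: str, callback: Callable[[], None]) -> int:
        # Return an integer tag that identifies this observer.
        tag = self._next_tag
        self._observers[tag] = (event, callback)
        self._next_tag += 1
        return tag

    def RemoveObserver(self, tag: int) -> None:
        self._observers.pop(tag, None)

    def GetCurrentIteration(self) -> int:
        return self._iteration

    def _invoke(self, event: str) -> None:
        # Call every observer registered for this event name.
        for registered_event, callback in list(self._observers.values()):
            if registered_event == event:
                callback()

    def run(self, num_iterations: int) -> None:
        for i in range(num_iterations):
            self._iteration = i
            # ... one optimization step would update the parameters here ...
            self._invoke("IterationEvent")
        self._invoke("EndEvent")
```

A callback that reads `GetCurrentIteration()` and writes it to a widget's `value`, as `update_progress` does below, is all a progress bar needs.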

In [13]:
from src.hasi.hasi.meansquaresregistrar import MeanSquaresRegistrar

In [14]:
# Must instantiate a registration object to initialize optimizers
registrar = MeanSquaresRegistrar()

In [15]:
progress = FloatProgress(
        min=0.0,
        max=150.0,
        step=1
    )
box = HBox([
    Label('Register images'),
    progress
])
box


In [16]:
def update_progress():
    progress.value = registrar.optimizer.GetCurrentIteration()
registrar.optimizer.AddObserver(itk.IterationEvent(), update_progress)

0

In [17]:
(transform_result, mesh_result) = registrar.register(template_mesh,
                                             target_mesh,
                                             filepath=MEANSQUARES_OUTPUT_FILE,
                                             verbose=True,
                                             num_iterations=200)

0 0.4953344442828443 0.0
0 0.3679928041625301 0.0
0 0.3679928041625301 0.0
1 0.1382597111142386 0.002047170989457994
1 0.1382597111142386 0.002047170989457994
2 0.09620259423214085 0.001261876484024073
2 0.09620259423214085 0.001261876484024073
3 0.0810644365222312 0.0011402423999786182
3 0.0810644365222312 0.0011402423999786182
4 0.05506265821062781 0.0008026746053761812
4 0.05506265821062781 0.0008026746053761812
5 0.036229100368238225 0.0005999657307756083
5 0.036229100368238225 0.0005999657307756083
6 0.02201122948495042 0.0006569240263049952
6 0.02201122948495042 0.0006569240263049952
7 0.019790912974060545 0.000227595005185355
7 0.019790912974060545 0.000227595005185355
8 0.017138213201441897 0.00012848767601378967
8 0.017138213201441897 0.00012848767601378967
9 0.01643926158089219 0.00011380109594120822
9 0.01643926158089219 0.00011380109594120822
10 0.015742424102559768 0.000229305332514829
10 0.015742424102559768 0.000229305332514829
11 0.015332139820951112 9.86042590450415e-0

100 0.007166497278198204 3.225428203746481e-05
100 0.007166497278198204 3.225428203746481e-05
101 0.007149650939008197 1.2455706971964613e-05
101 0.007149650939008197 1.2455706971964613e-05
102 0.007121663744128236 1.2982402927792889e-05
102 0.007121663744128236 1.2982402927792889e-05
103 0.007107483587565755 1.609346523245832e-05
103 0.007107483587565755 1.609346523245832e-05
104 0.007092042896733735 2.3113347269329504e-05
104 0.007092042896733735 2.3113347269329504e-05
105 0.0070876421136930245 3.544249173464179e-05
105 0.0070876421136930245 3.544249173464179e-05
106 0.0070856698373387254 1.011908203672434e-05
106 0.0070856698373387254 1.011908203672434e-05
107 0.007076003884998143 9.5015908970922e-06
107 0.007076003884998143 9.5015908970922e-06
108 0.0070623753300841675 1.2760766981931491e-05
108 0.0070623753300841675 1.2760766981931491e-05
109 0.007066499924291111 1.4079264563422209e-05
109 0.007056548175037092 1.4079264563422209e-05
109 0.007056548175037092 1.4079264563422209e-05


Comparison of the resulting mesh with the target shows successful registration.

In [18]:
#view(geometries=[mesh_result,target_mesh])

## Diffeomorphic Registration

The `DiffeoRegistrar` class converts meshes to images and performs registration using the [Diffeomorphic Demons registration algorithm](https://itk.org/Doxygen/html/classitk_1_1DiffeomorphicDemonsRegistrationFilter.html). The resultant deformation field is then applied to resample the target mesh into the template mesh domain.
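
The core of demons-style registration is a per-voxel force that pushes the moving image toward the fixed one. The 1D sketch below implements the classic (Thirion-style) demons update that diffeomorphic demons builds on; it is a simplification, not ITK's filter. It omits Gaussian smoothing of the field and the exponential-map composition that makes the diffeomorphic variant invertible, and the normalizer `K = spacing**2` is an assumption modeled loosely on ITK's demons function. `demons_update` and `demons_1d` are hypothetical names.

```python
from typing import Callable, List, Sequence


def demons_update(
    fixed: Sequence[float], warped_moving: Sequence[float], spacing: float = 1.0
) -> List[float]:
    """One classic demons force on a 1D grid.

    update = (F - M_w) * dF / (dF**2 + (F - M_w)**2 / K), with K = spacing**2.
    Boundary samples get no update because the central difference is undefined there.
    """
    k = spacing * spacing
    update = [0.0] * len(fixed)
    for i in range(1, len(fixed) - 1):
        grad = (fixed[i + 1] - fixed[i - 1]) / (2.0 * spacing)  # central difference of F
        diff = fixed[i] - warped_moving[i]
        denom = grad * grad + diff * diff / k
        if denom > 1e-12:  # skip flat, already-matched samples
            update[i] = diff * grad / denom
    return update


def demons_1d(
    fixed_fn: Callable[[float], float],
    moving_fn: Callable[[float], float],
    xs: Sequence[float],
    iterations: int = 50,
) -> List[float]:
    """Accumulate updates into a displacement u so that moving_fn(x + u) ~ fixed_fn(x)."""
    spacing = xs[1] - xs[0]
    fixed = [fixed_fn(x) for x in xs]
    u = [0.0] * len(xs)
    for _ in range(iterations):
        # Resample the moving signal with the current displacement field.
        warped = [moving_fn(x + ui) for x, ui in zip(xs, u)]
        du = demons_update(fixed, warped, spacing)
        # Additive update; diffeomorphic demons composes with exp(du) instead.
        u = [ui + d for ui, d in zip(u, du)]
    return u
```

For a fixed ramp `F(x) = x` and a moving ramp `M(x) = x - 1`, the interior displacements converge to 1, the shift that maps the moving signal onto the fixed one.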

Custom observers may print out iteration data accessed via the `registrar.filter` object.

In [19]:
from src.hasi.hasi.diffeoregistrar import DiffeoRegistrar

In [20]:
registrar = DiffeoRegistrar()

In [21]:
diffeoProgress = FloatProgress(
        min=0.0,
        max=200.0,
        step=1
    )
diffeoBox = HBox([
    Label('Register images'),
    diffeoProgress
])
diffeoBox


In [22]:
def update_diff_progress():
    diffeoProgress.value = registrar.filter.GetElapsedIterations()

registrar.filter.AddObserver(itk.IterationEvent(),update_diff_progress)

(transform_result, mesh_result) = registrar.register(template_mesh,
                                        target_mesh,
                                        filepath=DIFFEO_OUTPUT_FILE,
                                        verbose=True)

1.7976931348623157e+308
0.4953344445027531
0.4919978146151957
0.49014250375875607
0.4883152615673534
0.4876264321046578
0.4863758875349746
0.4848329850901018
0.4832380374425963
0.4816613820019586
0.47992311701764906
0.47819615858812126
0.4764180158719083
0.4746375325575763
0.4725753735317665
0.47054456107198983
0.4683531446736118
0.46619317072217226
0.4638398448774739
0.4617263143098048
0.45963088709668165
0.45765992990626214
0.45564904932109535
0.45405493940994296
0.4526567041931353
0.45126435618474847
0.4499785970781415
0.4486827830549383
0.4475157724666353
0.44628145441325595
0.4450578744943134
0.4438195343732895
0.4425877546195735
0.441339593969622
0.4401098214527867
0.4388818595160368
0.43749092474381784
0.43564875311473567
0.43370194230716685
0.4311092442088367
0.4285241607433405
0.42582081376755554
0.4231977163312322
0.4205323448942155
0.4176938430560469
0.41429447935831476
0.4105150579147669
0.4078069432083754
0.4054533839896938
0.4031936651465812
0.4008872988176601
0.398616092

Compare the registered mesh to the target mesh.

In [23]:
#view(geometries=[mesh_result, target_mesh])

## Entropy-based Registration

The `PointSetEntropyRegistrar` class registers two 3D meshes by computing the transform that minimizes entropy measures between the two point clouds. In this example we substitute an [`EuclideanDistancePointSetToPointSetMetric`](https://itk.org/Doxygen/html/classitk_1_1EuclideanDistancePointSetToPointSetMetricv4.html) to compare the two clouds.
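
A Euclidean point-set metric of this kind can be approximated as the average distance from each point in one cloud to its closest point in the other. The sketch below is an assumption about the general idea, not ITK's code: ITK uses a point locator rather than brute force, and its exact averaging and derivative terms should be checked in the linked Doxygen page. `mean_closest_point_distance` is a hypothetical name; points are (x, y, z) tuples.

```python
import math
from typing import Sequence, Tuple

Point = Tuple[float, float, float]


def mean_closest_point_distance(fixed: Sequence[Point], moving: Sequence[Point]) -> float:
    """Average, over fixed points, of the Euclidean distance to the closest moving point."""
    if not fixed or not moving:
        raise ValueError("both point sets must be non-empty")
    total = 0.0
    for p in fixed:
        total += min(math.dist(p, q) for q in moving)  # brute-force nearest neighbour
    return total / len(fixed)
```

The measure is not symmetric: `mean_closest_point_distance([(0, 0, 0)], [(0, 0, 0), (10, 0, 0)])` is `0.0`, but swapping the arguments gives `5.0`, because the far point in the second set has no close partner.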

In [24]:
from src.hasi.hasi.pointsetentropyregistrar import PointSetEntropyRegistrar
registrar = PointSetEntropyRegistrar()

In [25]:
progress = FloatProgress(
        min=0.0,
        max=2000.0,
        step=1
    )
progressBox = HBox([
    Label('Register images'),
    progress
])
progressBox


In [26]:
def update_progress():
    progress.value = registrar.optimizer.GetCurrentIteration()

registrar.optimizer.AddObserver(itk.IterationEvent(),update_progress)

0

In [27]:
metric = itk.EuclideanDistancePointSetToPointSetMetricv4[itk.PointSet[itk.F,3]].New()

(transform_result, mesh_result) = registrar.register(
                       template_mesh=template_mesh,
                       target_mesh=target_mesh,
                       metric=metric,
                       filepath=POINTSET_OUTPUT_FILE,
                       verbose=True,
                       learning_rate=1.0,
                       max_iterations=300,
                       resample_from_target=False)

0 0.0
0 0.18402767981233376
1 0.13636549597252293
2 0.1136153451996774
3 0.10070102504901322
4 0.0921356833845075
5 0.08639007757362922
6 0.08227116489324207
7 0.07907374804042967
8 0.07637873779025549
9 0.07405272544249199
10 0.0719671439265937
11 0.07005118273844133
12 0.06828105853186942
13 0.06660776118233713
14 0.06498893467311129
15 0.06343569110339238
16 0.06194662338114907
17 0.06055190889404482
18 0.059165763253996674
19 0.05782261337640995
20 0.05652837680168472
21 0.05530507143716108
22 0.054134602834920165
23 0.053020636495240706
24 0.051944194750049855
25 0.05091371680519884
26 0.04990854226515862
27 0.048950925455557744
28 0.048024307153977146
29 0.04712948446720358
30 0.0462846441709203
31 0.04545128218920508
32 0.044643358857768024
33 0.043888957619121746
34 0.04316445250215185
35 0.04248362994983003
36 0.04183262747532614
37 0.0412037429531749
38 0.04059294224542183
39 0.040016154126255296
40 0.03945396290133864
41 0.03892002005470285
42 0.03840443107046075
43 0.037910

In [28]:
#view(geometries=[mesh_result,target_mesh])

## Resample From Target

A common procedure for establishing point correspondences across samples is to register a template atlas to each sample mesh and then move each template point onto the target surface. The `MeshToMeshRegistrar` class provides an interface to ITK's [`KdTree`](https://itk.org/Doxygen/html/classitk_1_1Statistics_1_1KdTree.html) that sets each template point to its nearest neighbor on the target mesh.
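
The nearest-neighbor snapping step can be sketched with a small pure-Python k-d tree standing in for ITK's `KdTree`. `KdNode`, `build_kdtree`, `nearest`, and `snap_to_nearest` are hypothetical names, and the sketch only handles point coordinates as (x, y, z) tuples, not ITK mesh objects. The tree splits on the median along x, y, and z in turn, and the search skips any subtree whose splitting plane is farther away than the best match found so far.

```python
from typing import List, Optional, Sequence, Tuple

Point = Tuple[float, float, float]


class KdNode:
    """One node of a 3D k-d tree; stores the index of its splitting point."""

    __slots__ = ("index", "axis", "left", "right")

    def __init__(self, index: int, axis: int, left: "Optional[KdNode]", right: "Optional[KdNode]") -> None:
        self.index = index
        self.axis = axis
        self.left = left
        self.right = right


def build_kdtree(points: Sequence[Point], indices: Optional[List[int]] = None, depth: int = 0) -> Optional[KdNode]:
    """Recursively split on the median along x, y, z in turn."""
    if indices is None:
        indices = list(range(len(points)))
    if not indices:
        return None
    axis = depth % 3
    indices = sorted(indices, key=lambda i: points[i][axis])
    mid = len(indices) // 2
    return KdNode(
        indices[mid],
        axis,
        build_kdtree(points, indices[:mid], depth + 1),
        build_kdtree(points, indices[mid + 1:], depth + 1),
    )


def nearest(points: Sequence[Point], node: Optional[KdNode], query: Point,
            best: Optional[Tuple[float, int]] = None) -> Optional[Tuple[float, int]]:
    """Return (squared distance, index) of the point in `points` closest to `query`."""
    if node is None:
        return best
    p = points[node.index]
    d2 = sum((a - b) ** 2 for a, b in zip(p, query))
    if best is None or d2 < best[0]:
        best = (d2, node.index)
    diff = query[node.axis] - p[node.axis]
    near, far = (node.left, node.right) if diff < 0 else (node.right, node.left)
    best = nearest(points, near, query, best)
    # Visit the far side only if the splitting plane is closer than the current best.
    if diff * diff < best[0]:
        best = nearest(points, far, query, best)
    return best


def snap_to_nearest(template_points: Sequence[Point], target_points: Sequence[Point]) -> List[Point]:
    """Replace every template point by its nearest neighbour on the target."""
    if not target_points:
        raise ValueError("target point set must be non-empty")
    tree = build_kdtree(target_points)
    return [target_points[nearest(target_points, tree, q)[1]] for q in template_points]
```

For example, `snap_to_nearest([(0.1, 0, 0), (0.9, 1.1, 0)], [(0, 0, 0), (1, 1, 0), (5, 5, 5)])` returns `[(0, 0, 0), (1, 1, 0)]`. Building the tree once makes each query roughly logarithmic in the target size for well-spread points, instead of a full linear scan per template point.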

In [None]:
template_resampled = \
    registrar.resample_template_from_target(mesh_result, target_mesh)

itk.meshwrite(template_resampled,POINTSET_RESAMPLED_OUTPUT_FILE)

In [None]:
#view(geometries=[target_mesh,mesh_result,template_resampled])

In [None]:
# Clean up file output
os.remove(MEANSQUARES_OUTPUT_FILE)
os.remove(DIFFEO_OUTPUT_FILE)
os.remove(POINTSET_OUTPUT_FILE)
os.remove(POINTSET_RESAMPLED_OUTPUT_FILE)