# Iterative Template Generation

This notebook demonstrates one way to generate a template mesh atlas from a collection of segmented binary images. Each binary image of a mouse femur is downsampled to reduce pixel density, and marching cubes is then applied to generate a mesh from the binary image. One arbitrary mesh is selected as the template, registered to each original mesh, and resampled from it to obtain a full set of meshes with point correspondence. The meshes are then groupwise registered via Procrustes alignment, and the mean mesh is taken as the new template. This process is repeated for a fixed number of iterations to produce a template mesh atlas that represents the average shape of all meshes.

In [1]:
import os
import glob

import numpy as np
import itk
import sys
from itkwidgets import view, checkerboard

module_path = os.path.abspath(os.path.join('..'))

if module_path not in sys.path:
    sys.path.append(module_path)

from src.hasi.hasi.pointsetentropyregistrar import PointSetEntropyRegistrar

### Read images

Input images are the results of automatic binary segmentation of mouse femur data. Each image contains only the femur object, and different images may occupy different regions in physical space.

In [2]:
# TODO download femurs if not locally available
IMAGE_FOLDER = 'Input/femurs/'
MESH_OUTPUT_FOLDER = 'Output/femurs/'
MEAN_OUTPUT_FOLDER = 'Output/mean/'

In [3]:
for folder in [IMAGE_FOLDER, MESH_OUTPUT_FOLDER, MEAN_OUTPUT_FOLDER]:
    os.makedirs(folder, exist_ok=True)

In [4]:
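input_paths = sorted(glob.glob(IMAGE_FOLDER + '*'))
assert len(input_paths) == 28

# FIXME remove 901-R for now because it is misaligned
# (filter by name rather than by list position, since glob order is arbitrary)
input_paths = [path for path in input_paths
               if '901-R' not in os.path.basename(path)]
assert len(input_paths) == 27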

In [5]:
MESH_FILENAMES = [os.path.basename(file).replace('.nrrd','.obj') for file in input_paths]

In [6]:
images = list()

for path in input_paths:
    images.append(itk.imread(path, itk.UC))

### Paste images into the same space
For viewing convenience, each image is pasted into an image of a common size.
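The pasting step can be pictured with plain NumPy arrays. This is a minimal sketch, not the ITK code used below: the function name `paste_arrays_common_size` is hypothetical, it works on arrays rather than `itk.Image` objects, and it ignores spacing and origin. Note that NumPy arrays obtained from ITK images index axes as (z, y, x), the reverse of the order reported by `itk.size`.

```python
import numpy as np

def paste_arrays_common_size(arrays):
    """Hypothetical NumPy analogue of pasting: zero-pad every array to a shared shape.

    Spacing and origin are ignored; only voxel indices are considered.
    """
    ndim = arrays[0].ndim
    # Largest extent along each axis across all arrays
    common_shape = tuple(max(a.shape[d] for a in arrays) for d in range(ndim))
    padded = []
    for a in arrays:
        canvas = np.zeros(common_shape, dtype=a.dtype)
        # Copy the original voxels into the corner starting at index (0, 0, ..., 0)
        canvas[tuple(slice(0, n) for n in a.shape)] = a
        padded.append(canvas)
    return padded

a = np.ones((2, 3, 4), dtype=np.uint8)
b = np.ones((5, 1, 2), dtype=np.uint8)
pa, pb = paste_arrays_common_size([a, b])
print(pa.shape, pb.shape)  # (5, 3, 4) (5, 3, 4)
```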

In [7]:
# Verify image spacing is equivalent across all images (absolute difference)
TOLERANCE = 1e-7
reference_spacing = itk.spacing(images[0])
assert all(abs(itk.spacing(image)[j] - reference_spacing[j]) < TOLERANCE
           for image in images
           for j in range(images[0].GetImageDimension()))

In [8]:
# Get the largest size along each axis across all images
max_size = itk.size(images[0])
for image in images:
    for i in range(0,len(max_size)):
        max_size[i] = max(max_size[i], itk.size(image)[i])
print(max_size)

itkSize3 ([1392, 983, 1247])


In [9]:
# Paste an image into the corner of a zero-filled image of a common size
def paste_common_size(orig_image, size, spacing):
    dimension = orig_image.GetImageDimension()
    region = itk.ImageRegion[dimension]()
    region.SetSize(size)
    region.SetIndex([0] * dimension)

    new_image = type(orig_image).New()
    new_image.SetRegions(region)
    new_image.SetSpacing(spacing)
    new_image.Allocate()
    new_image.FillBuffer(0)  # Allocate() does not initialize pixel values

    paste_filter = itk.PasteImageFilter[type(orig_image)].New()
    paste_filter.SetSourceImage(orig_image)
    paste_filter.SetSourceRegion(orig_image.GetLargestPossibleRegion())
    paste_filter.SetDestinationImage(new_image)
    paste_filter.SetDestinationIndex([0] * dimension)
    paste_filter.Update()

    return paste_filter.GetOutput()

In [10]:
# Make images common size
for i in range(0, len(images)):
    images[i] = paste_common_size(images[i], max_size, images[0].GetSpacing())

In [11]:
view(images[1])


### Downsample images
The marching cubes algorithm returns a mesh whose vertex density depends on the pixel density of the original image. Here, running marching cubes on the full-resolution images would produce meshes of approximately 800,000 points each, whereas the template mesh atlas should contain only about 5,000 points. Each image is therefore downsampled to yield a sparser mesh.
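A rough rule of thumb helps explain the choice of ratio. Assuming the marching cubes vertex count scales with the object's surface area measured in pixels, downsampling by a ratio r along each of the three axes divides the pixel count by about r³ but the surface vertex count by only about r². The helper names below are hypothetical, and the quadratic scaling is an approximation, not a guarantee.

```python
import math

def estimate_points_after_downsampling(full_res_points, ratio):
    # Assumption: vertex count scales with surface area, i.e. as 1 / ratio**2
    return full_res_points / ratio ** 2

def ratio_for_target_points(full_res_points, target_points):
    # Invert the same (assumed) quadratic relationship to pick a ratio
    return math.sqrt(full_res_points / target_points)

print(round(estimate_points_after_downsampling(800_000, 14)))  # 4082
print(round(ratio_for_target_points(800_000, 5_000), 2))       # 12.65
```

The estimate of about 4,100 points for a ratio of 14 is in the same range as the average of about 4,800 points reported after mesh generation below, so the quadratic rule is only a ballpark guide.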

In [12]:
DOWNSAMPLE_RATIO = 14

In [13]:
downsampled_images = list()

In [14]:
for image in images:
    # Coarser spacing with a proportionally smaller size covers approximately the same physical extent
    output_spacing = [spacing * DOWNSAMPLE_RATIO for spacing in itk.spacing(image)]
    output_size = [int(size / DOWNSAMPLE_RATIO) for size in itk.size(image)]

    downsample = itk.resample_image_filter(Input=image,
                                           Size=output_size,
                                           OutputOrigin=itk.origin(image),
                                           OutputSpacing=output_spacing)
    downsampled_images.append(downsample)

In [15]:
view(downsample)


### Generate Meshes
The `itk.BinaryMask3DMeshSource` class uses the Marching Cubes algorithm to generate a mesh from an object in a binary image. Each binary image here uses the value `1` to indicate that the femur is present at a pixel and `0` to indicate that it is not. Marching Cubes visits each cube of eight neighboring pixels and emits triangles wherever the cube straddles the boundary between object and background, so the output surface follows the boundary of the femur region.
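To make the boundary behaviour concrete, the sketch below reproduces only the first stage of Marching Cubes in NumPy: it assigns every cube of 2x2x2 neighboring pixels an 8-bit case index built from its corner values. Cases 0 and 255 (all background or all object) produce no triangles; every other case contributes surface. The sketch does not look up the triangle table or place vertices, uses an arbitrary corner ordering, and is not how `itk.BinaryMask3DMeshSource` is implemented internally; the function names are hypothetical.

```python
import numpy as np

# Corner offsets of a 2x2x2 cube; any fixed ordering works for this sketch
CORNERS = [(0, 0, 0), (1, 0, 0), (1, 1, 0), (0, 1, 0),
           (0, 0, 1), (1, 0, 1), (1, 1, 1), (0, 1, 1)]

def cube_case_indices(mask):
    """Return the 8-bit case index of every cube of 2x2x2 neighboring pixels."""
    m = (np.asarray(mask) > 0).astype(np.uint8)
    nx, ny, nz = m.shape
    index = np.zeros((nx - 1, ny - 1, nz - 1), dtype=np.uint8)
    for bit, (dx, dy, dz) in enumerate(CORNERS):
        # Value of this corner for every cube at once
        corner = m[dx:nx - 1 + dx, dy:ny - 1 + dy, dz:nz - 1 + dz]
        index |= corner << np.uint8(bit)
    return index

def count_boundary_cubes(mask):
    # Cases 0 (all background) and 255 (all object) produce no triangles
    idx = cube_case_indices(mask)
    return int(np.count_nonzero((idx != 0) & (idx != 255)))

mask = np.zeros((3, 3, 3), dtype=np.uint8)
mask[1, 1, 1] = 1  # a single object pixel touches all 8 surrounding cubes
print(count_boundary_cubes(mask))  # 8
```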

In [16]:
FEMUR_OBJECT_PIXEL_VALUE = 1

Dimension = itk.template(downsampled_images[0])[1][1]
MeshType = itk.Mesh[itk.F,Dimension]
MeshSourceType = itk.BinaryMask3DMeshSource[type(downsampled_images[0]), MeshType]

In [17]:
meshes = list()

for image in downsampled_images:
    mesh_source = MeshSourceType.New()
    mesh_source.SetObjectValue(FEMUR_OBJECT_PIXEL_VALUE)
    mesh_source.SetInput(image)
    mesh_source.Update()
    
    mesh_output = mesh_source.GetOutput()
    meshes.append(mesh_output)
    
print('Average mesh points: ' +
      str(sum([mesh.GetNumberOfPoints() for mesh in meshes]) / len(meshes)))

Average mesh points: 4822.518518518518


In [18]:
# Write out each mesh to disk
for i in range(0,len(meshes)):
    itk.meshwrite(meshes[i], f'{MESH_OUTPUT_FOLDER}{MESH_FILENAMES[i]}')

In [19]:
# TODO visualize with itkwidgets
# view(geometries=meshes)

### Select Template Mesh
It is now necessary to select a mesh that will act as the 'standard' for correspondence point updates from this point forward. The first mesh in the list is arbitrarily selected as this standard template.

In [20]:
TEMPLATE_MESH_INDEX = 0
template_mesh = meshes[TEMPLATE_MESH_INDEX]

In [21]:
# FIXME visualize template
# view(geometries=[template_mesh])

### Define Iterative Registration
The goal is to find correspondence points between the template and each sample mesh. Correspondence is obtained in two steps:
- First, a copy of the template mesh is registered to the target sample;
- Second, each point of the registered copy is moved to its nearest neighbor on the target sample.

The result is a collection of meshes that all have the same number of points, where the point at a given index represents approximately the same feature on each femur.
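The second step, snapping points to their nearest neighbors, can be sketched in NumPy as follows. We assume the `resample_from_target=True` option of `PointSetEntropyRegistrar` does something along these lines, but we have not confirmed its implementation; `snap_to_nearest` is a hypothetical name. Because the output keeps the template's point count and order, point i in every resampled mesh refers to the same template point, which is what establishes correspondence.

```python
import numpy as np

def snap_to_nearest(template_points, target_points):
    """Hypothetical helper: replace each template point with its nearest target point.

    Brute force; returns (snapped_points, index_of_nearest_target_point).
    """
    template_points = np.asarray(template_points, dtype=float)
    target_points = np.asarray(target_points, dtype=float)
    # Pairwise squared distances, shape (n_template, n_target)
    diff = template_points[:, None, :] - target_points[None, :, :]
    sq_dist = np.einsum('ijk,ijk->ij', diff, diff)
    nearest = np.argmin(sq_dist, axis=1)
    return target_points[nearest], nearest

template = [[0.0, 0.0, 0.0], [10.0, 0.0, 0.0]]
target = [[9.0, 0.0, 0.0], [1.0, 0.0, 0.0], [5.0, 5.0, 5.0]]
snapped, idx = snap_to_nearest(template, target)
print(idx)  # [1 0]
```

The brute-force distance matrix grows as the product of the two point counts, so for meshes of around 5,000 points a k-d tree such as `scipy.spatial.cKDTree` is a better fit. Several template points may also snap to the same target point.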

In [22]:
#FIXME set mesh1 to better initial transform for registration
#transform = itk.Euler3DTransform[itk.D].New()
#params = transform.GetParameters()
#params.SetElement(0,3.14)
#params.SetElement(1,3.14)
#params.SetElement(2,3.14 / 2)
#transform.SetParameters(params)

In [23]:
def register_template_to_sample(template_mesh, 
                                sample_mesh,
                                learning_rate=1.0,
                                max_iterations=2000):
    
    registrar = PointSetEntropyRegistrar()
    metric = itk.EuclideanDistancePointSetToPointSetMetricv4[itk.PointSet[itk.F,3]].New()
    transform = itk.Euler3DTransform[itk.D].New()
    
    # Copy the template points into a new mesh so resampling does not modify the original
    # (the cell container is shared with the template mesh, not copied)
    template_copy = itk.Mesh[itk.F,3].New()
    for i in range(0, template_mesh.GetNumberOfPoints()):
        template_copy.SetPoint(i, template_mesh.GetPoint(i))
    template_copy.SetCells(template_mesh.GetCells())
    
    # Run registration and resample from target
    (transform, deformed_mesh) = registrar.register(template_mesh=template_copy,
                                                    target_mesh=sample_mesh,
                                                    metric=metric,
                                                    transform=transform,
                                                    learning_rate=learning_rate,
                                                    max_iterations=max_iterations,
                                                    resample_from_target=True,
                                                    verbose=False)
    return deformed_mesh

### Define Procrustes Alignment Parameters
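Once a copy of the template has been registered to and resampled from each input mesh, the resulting meshes share point correspondence. Procrustes alignment then groupwise aligns these meshes to one another (rigidly, since scaling is disabled below), and their mean is taken as the new template.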
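For intuition, here is a minimal NumPy sketch of rigid generalized Procrustes alignment (rotation and translation only, in the spirit of `SetUseScalingOff`). It assumes every shape is an (N, 3) array whose row i corresponds across shapes. It is not the algorithm inside `itk.MeshProcrustesAlignFilter`: the function names, the convergence measure, and the choice of the first shape as the initial reference are our own assumptions.

```python
import numpy as np

def rigid_align(source, target):
    """Kabsch alignment: best rotation + translation (no scaling) of source onto target."""
    src_mean, tgt_mean = source.mean(axis=0), target.mean(axis=0)
    src_c, tgt_c = source - src_mean, target - tgt_mean
    u, _, vt = np.linalg.svd(src_c.T @ tgt_c)
    # Flip the last axis if needed so the result is a rotation, not a reflection
    sign = 1.0 if np.linalg.det(u @ vt) >= 0 else -1.0
    rotation = u @ np.diag([1.0, 1.0, sign]) @ vt
    return src_c @ rotation + tgt_mean

def generalized_procrustes(shapes, tol=1e-8, max_iter=100):
    """Iteratively align all shapes to their running mean; return (mean, aligned)."""
    aligned = [np.asarray(s, dtype=float) for s in shapes]
    mean = aligned[0]  # assumption: start from the first shape as the reference
    for _ in range(max_iter):
        aligned = [rigid_align(s, mean) for s in aligned]
        new_mean = np.mean(aligned, axis=0)
        # Mean per-point displacement of the mean shape between iterations
        change = np.mean(np.linalg.norm(new_mean - mean, axis=1))
        mean = new_mean
        if change < tol:
            break
    return mean, aligned

# Usage: a random shape and a rotated, translated copy collapse onto one mean
rng = np.random.default_rng(0)
base = rng.normal(size=(20, 3))
theta = 0.7
rot_z = np.array([[np.cos(theta), -np.sin(theta), 0.0],
                  [np.sin(theta),  np.cos(theta), 0.0],
                  [0.0, 0.0, 1.0]])
moved = base @ rot_z.T + np.array([5.0, -2.0, 1.0])
mean, aligned = generalized_procrustes([base, moved])
print(np.allclose(aligned[0], aligned[1]))  # True
```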

In [24]:
def align_procrustes(deformed_meshes, 
                     template_index, 
                     verbose=False,
                     convergence=0.08):
    
    procrustes_filter = itk.MeshProcrustesAlignFilter[type(deformed_meshes[0]), type(deformed_meshes[0])].New()
    
    procrustes_filter.SetUseInitialAverageOff()
    procrustes_filter.SetUseNormalizationOff()
    procrustes_filter.SetUseScalingOff()
    procrustes_filter.SetConvergence(convergence)  # Minimum threshold to exit alignment
    
    # Set mesh correspondence inputs
    procrustes_filter.SetNumberOfInputs(len(deformed_meshes))
    for i in range(0, len(deformed_meshes)):
        procrustes_filter.SetInput(i, deformed_meshes[i])
    
    # Run filter
    procrustes_filter.Update()
    
    if verbose:
        print(f'Alignment converged at {procrustes_filter.GetMeanPointsDifference()}')
    
    mean_result = procrustes_filter.GetMean()
    mean_result.SetCells(deformed_meshes[template_index].GetCells())
    return mean_result

### Run Iterative Refinement

Run registration and alignment for every mesh. `UPDATE_CONTINUOUSLY` chooses between two strategies: update each mesh in place as soon as its mean is computed, so that later alignments in the same iteration benefit from earlier ones, or hold the results aside until the iteration ends, so that each alignment within an iteration is independent of the order of meshes in the list.
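The difference between the two strategies can be illustrated with a toy example in which each "mesh" is a number and a hypothetical `refine` callback, standing in for the register-and-align step, simply returns the mean of all current values. The names `refine_all` and `mean_of_all` are illustrative only.

```python
def refine_all(items, refine, update_continuously):
    """Apply refine(i, items) to every index using one of two update orders.

    refine is a hypothetical stand-in for the per-mesh register-and-align step.
    """
    if update_continuously:
        current = list(items)
        for i in range(len(current)):
            # Later indices see results already written for earlier indices
            current[i] = refine(i, current)
        return current
    # Every index is refined against the same unchanged snapshot
    snapshot = list(items)
    return [refine(i, snapshot) for i in range(len(snapshot))]

def mean_of_all(i, xs):
    # Toy refinement: ignore the index and return the mean of all values
    return sum(xs) / len(xs)

print(refine_all([0.0, 3.0, 6.0], mean_of_all, update_continuously=True))   # [3.0, 4.0, 4.333333333333333]
print(refine_all([0.0, 3.0, 6.0], mean_of_all, update_continuously=False))  # [3.0, 3.0, 3.0]
```

With continuous updates the outcome depends on list order (reversing the input to `[6.0, 3.0, 0.0]` gives `[3.0, 2.0, 1.6666666666666667]`), while batch updates refine every index against the same snapshot.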

In [26]:
NUM_ITERATIONS = 1
UPDATE_CONTINUOUSLY = True
aligned_templates = list()

# Write out mean meshes by iteration
# ex. 'Output/mean/0/901-L-femur-label.obj'
for i in range(NUM_ITERATIONS):
    os.makedirs(os.path.join(MEAN_OUTPUT_FOLDER, str(i)), exist_ok=True)

In [29]:
for iteration in range(NUM_ITERATIONS):
    print(f'Now at iteration {iteration}')
    aligned_templates = list()  # results held aside when not updating continuously

    for template_mesh_index in range(len(meshes)):
        print(f'Now at template mesh {template_mesh_index}')
        template_mesh = meshes[template_mesh_index]

        deformed_templates = list()
        for sample_index in range(len(meshes)):
            if sample_index == template_mesh_index:
                deformed_templates.append(template_mesh)
            else:
                print(f'Resampling from mesh {sample_index}')
                deformed_template = register_template_to_sample(template_mesh, meshes[sample_index])
                deformed_templates.append(deformed_template)

        print('Running alignment')
        mesh_result = align_procrustes(deformed_templates, template_mesh_index, verbose=True)

        # Save intermediate results to disk
        # These meshes are very small so this is not a significant expense (~400 KB/mesh)
        output_path = os.path.join(MEAN_OUTPUT_FOLDER, str(iteration), MESH_FILENAMES[template_mesh_index])
        print(f'Writing mean to {output_path}')
        itk.meshwrite(mesh_result, output_path)

        # Optionally update current mesh in place for use in subsequent alignments
        if UPDATE_CONTINUOUSLY:
            meshes[template_mesh_index] = mesh_result
        else:
            aligned_templates.append(mesh_result)

    # Otherwise, replace all meshes at once between iterations
    if not UPDATE_CONTINUOUSLY:
        meshes = aligned_templates

Now at iteration 0
Now at template mesh 0
Resampling from mesh 1
Resampling from mesh 2
Resampling from mesh 3
Resampling from mesh 4
Resampling from mesh 5
Resampling from mesh 6
Resampling from mesh 7
Resampling from mesh 8
Resampling from mesh 9
Resampling from mesh 10
Resampling from mesh 11
Resampling from mesh 12
Resampling from mesh 13
Resampling from mesh 14
Resampling from mesh 15
Resampling from mesh 16
Resampling from mesh 17
Resampling from mesh 18
Resampling from mesh 19
Resampling from mesh 20
Resampling from mesh 21
Resampling from mesh 22
Resampling from mesh 23
Resampling from mesh 24
Resampling from mesh 25
Resampling from mesh 26
Running alignment
Alignment converged at 0.0013539076418325581
Writing mean to Output/mean//0/901-L-femur-label.obj
Now at template mesh 1
Resampling from mesh 0
Resampling from mesh 2
Resampling from mesh 3
Resampling from mesh 4
Resampling from mesh 5
Resampling from mesh 6
Resampling from mesh 7
Resampling from mesh 8
Resampling from mesh

Running alignment
Alignment converged at 0.0007549640420944888
Writing mean to Output/mean//0/915-R-femur-label.obj
Now at template mesh 11
Resampling from mesh 0
Resampling from mesh 1
Resampling from mesh 2
Resampling from mesh 3
Resampling from mesh 4
Resampling from mesh 5
Resampling from mesh 6
Resampling from mesh 7
Resampling from mesh 8
Resampling from mesh 9
Resampling from mesh 10
Resampling from mesh 12
Resampling from mesh 13
Resampling from mesh 14
Resampling from mesh 15
Resampling from mesh 16
Resampling from mesh 17
Resampling from mesh 18
Resampling from mesh 19
Resampling from mesh 20
Resampling from mesh 21
Resampling from mesh 22
Resampling from mesh 23
Resampling from mesh 24
Resampling from mesh 25
Resampling from mesh 26
Running alignment
Alignment converged at 0.0008808021989129835
Writing mean to Output/mean//0/916-L-femur-label.obj
Now at template mesh 12
Resampling from mesh 0
Resampling from mesh 1
Resampling from mesh 2
Resampling from mesh 3
Resampling fro

KeyboardInterrupt: 

In [None]:
# FIXME visualize refined meshes
# view(geometries=meshes)