# Iterative Template Generation

This notebook demonstrates one way to generate a template mesh atlas from a collection of segmented binary images. Each binary image of a mouse femur is downsampled to reduce pixel density, and marching cubes is then applied to generate a mesh from the downsampled image. One arbitrary mesh is selected as the template, then registered to and resampled from each original mesh to produce a full set of meshes with point correspondence. The meshes are then groupwise registered via Procrustes alignment, and the mean mesh is taken as the new template. This process is repeated for a fixed number of iterations to produce a template mesh atlas that represents the average of all meshes.

This pipeline uses the [ITKShape](https://github.com/slicersalt/ITKShape) module for shape analysis.

In [1]:
import os
import glob
import time

import numpy as np
import itk
import sys
from itkwidgets import view, checkerboard

module_path = os.path.abspath(os.path.join('..'))

if module_path not in sys.path:
    sys.path.append(module_path)

from src.hasi.hasi.pointsetentropyregistrar import PointSetEntropyRegistrar

### Read images

Input images represent the results of automatic binary segmentations of mouse femur data. Each image contains only the femur object and potentially represents a different region in space.

In [2]:
IMAGE_FOLDER = 'Input/femurs/'
DENSE_MESH_OUTPUT_FOLDER = 'Output/femurs/'
TEMPLATE_OUTPUT_FOLDER = 'Output/templates/'
MEAN_OUTPUT_FOLDER = 'Output/mean/'

for folder in [IMAGE_FOLDER, DENSE_MESH_OUTPUT_FOLDER, TEMPLATE_OUTPUT_FOLDER, MEAN_OUTPUT_FOLDER]:
    os.makedirs(folder, exist_ok=True)

In [3]:
# Get healthy femur segmentation binary images at 
# https://data.kitware.com/#collection/5dcc6691e3566bda4b802172/folder/5e0b8d6baf2e2eed35c326f7

input_paths = glob.glob(IMAGE_FOLDER + '*-R-*')
assert(len(input_paths) == 14)

print(input_paths)

['Input/femurs\\901-R-femur-label.nrrd', 'Input/femurs\\902-R-femur-label.nrrd', 'Input/femurs\\906-R-femur-label.nrrd', 'Input/femurs\\907-R-femur-label.nrrd', 'Input/femurs\\908-R-femur-label.nrrd', 'Input/femurs\\915-R-femur-label.nrrd', 'Input/femurs\\916-R-femur-label.nrrd', 'Input/femurs\\917-R-femur-label.nrrd', 'Input/femurs\\918-R-femur-label.nrrd', 'Input/femurs\\F9-3wk-01-R-femur-label.nrrd', 'Input/femurs\\F9-3wk-02-R-femur-label.nrrd', 'Input/femurs\\F9-3wk-03-R-femur-label.nrrd', 'Input/femurs\\F9-8wk-01-R-femur-label.nrrd', 'Input/femurs\\F9-8wk-02-R-femur-label.nrrd']


In [4]:
MESH_FILENAMES = [os.path.basename(file).replace('.nrrd','.obj') for file in input_paths]

In [5]:
images = list()
for path in input_paths:
    images.append(itk.imread(path, itk.UC))

### Paste images into same space

Here we standardize physical space across the femur images. This is primarily for viewing convenience with `itkwidgets`, which expects a standard viewing region, but it can also help standardize output from the subsequent image downsampling and mesh conversion operations.
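As a minimal sketch of the pasting idea, the pure-Python example below copies a small 2D grid into the index-0 corner of a larger grid. The helper name `paste_into_common_grid` and the background fill value of 0 are illustrative assumptions; the notebook itself performs this step in 3D with `itk.PasteImageFilter`.

```python
def paste_into_common_grid(source, common_shape, fill=0):
    """Copy a 2D grid into the index-0 corner of a larger, fill-valued grid."""
    rows, cols = common_shape
    if len(source) > rows or any(len(row) > cols for row in source):
        raise ValueError("source does not fit inside the common grid")
    # Start from a background-only grid of the common size
    grid = [[fill] * cols for _ in range(rows)]
    # Overwrite the corner region with the source values
    for r, row in enumerate(source):
        for c, value in enumerate(row):
            grid[r][c] = value
    return grid


small = [[1, 1],
         [1, 0]]
padded = paste_into_common_grid(small, common_shape=(3, 4))
for row in padded:
    print(row)
```

Every pasted image ends up with the same size, so later steps can treat the whole collection uniformly.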

In [6]:
# Verify image spacing is equivalent
TOLERANCE = 0.0000001
assert(all([abs(itk.spacing(images[i])[j] - itk.spacing(images[i+1])[j]) < TOLERANCE
            for j in range(0,images[0].GetImageDimension())
            for i in range(0,len(images) - 1)]))

In [7]:
# Get largest range containing images
max_size = itk.size(images[0])
for image in images:
    for i in range(0,len(max_size)):
        max_size[i] = max(max_size[i], itk.size(image)[i])
print(max_size)

itkSize3 ([1279, 954, 1039])


In [8]:
# Make image of common size
def paste_common_size(orig_image, size, spacing):
    dimension = orig_image.GetImageDimension()
    region = itk.ImageRegion[dimension]()
    region.SetSize(size)
    region.SetIndex([0] * dimension)

    new_image = type(orig_image).New()
    new_image.SetRegions(region)
    new_image.SetSpacing(spacing)
    new_image.Allocate()
    new_image.FillBuffer(0)  # Allocate() alone does not initialize pixel values
    
    paste_filter = itk.PasteImageFilter[type(orig_image)].New()
    paste_filter.SetSourceImage(orig_image)
    paste_filter.SetSourceRegion(orig_image.GetLargestPossibleRegion())
    paste_filter.SetDestinationImage(new_image)
    paste_filter.SetDestinationIndex([0] * dimension)
    paste_filter.Update()
    
    return paste_filter.GetOutput()

In [9]:
# Make images common size
for i in range(0, len(images)):
    images[i] = paste_common_size(images[i], max_size, images[0].GetSpacing())

In [10]:
# View a 3D image with itkwidgets
view(images[1])


### Downsample images

The marching cubes algorithm returns a mesh with vertex density related to the pixel density of the original image. Here we downsample each image at two different ratios: once to get a "dense" image that retains most of the information density, and once to get a "sparse" image that is easier to use for finding correspondence points.

Meshes generated later from the dense images have approximately 200,000 vertices each, while meshes generated from the sparse images have approximately 4,000 vertices. We will use the dense meshes to sample feature information and iteratively refine the atlas to generalize the shape population. We can select a single sparse mesh to act as the initial atlas, or carry out iterative refinement on multiple sparse meshes and compare the results to determine which "best" reflects the population.
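The size and spacing arithmetic behind the downsampling can be sketched in plain Python. The helper name `downsampled_geometry` and the unit spacing are illustrative assumptions; the input size is the common image size printed earlier in this notebook, and the ratios are the ones used below.

```python
def downsampled_geometry(size, spacing, ratio):
    """Return (output_size, output_spacing) for an integer downsample ratio."""
    # Fewer voxels along each axis (fractional voxels are truncated)...
    output_size = [int(s / ratio) for s in size]
    # ...each covering proportionally more physical space
    output_spacing = [sp * ratio for sp in spacing]
    return output_size, output_spacing


# Common image size from this notebook; unit spacing is a placeholder assumption
full_size = [1279, 954, 1039]
unit_spacing = [1.0, 1.0, 1.0]

dense_size, dense_spacing = downsampled_geometry(full_size, unit_spacing, 2)
sparse_size, sparse_spacing = downsampled_geometry(full_size, unit_spacing, 14)

print(dense_size)   # [639, 477, 519]
print(sparse_size)  # [91, 68, 74]
```

Because the size is truncated, the downsampled image can cover slightly less physical extent than the original (for example, 639 voxels at spacing 2 span 1278 units rather than 1279).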

In [11]:
SPARSE_DOWNSAMPLE_RATIO = 14
DENSE_DOWNSAMPLE_RATIO = 2

In [12]:
def downsample_images(image_list, ratio) -> list:
    downsamples = list()
    for image in image_list:
        output_spacing = [spacing * ratio for spacing in itk.spacing(image)]
        output_size = [int(size / ratio) for size in itk.size(image)]

        downsample = itk.resample_image_filter(Input=image,
                                               Size=output_size,
                                               OutputOrigin=itk.origin(image),
                                               OutputSpacing=output_spacing,)
        downsamples.append(downsample)
    return downsamples

In [15]:
sparse_downsampled_images = downsample_images(images,SPARSE_DOWNSAMPLE_RATIO)
dense_downsampled_images = downsample_images(images,DENSE_DOWNSAMPLE_RATIO)

print(itk.size(dense_downsampled_images[0]))
print(itk.size(sparse_downsampled_images[0]))

itkSize3 ([639, 477, 519])
itkSize3 ([91, 68, 74])


In [16]:
view(dense_downsampled_images[0])


### Generate Meshes
The `itk.BinaryMask3DMeshSource` class makes use of the Marching Cubes algorithm to generate a mesh from a given object. Each binary image here uses the value '1' to indicate the femur is present at a pixel and '0' to indicate the femur is not present. Marching Cubes rapidly fills the femur space and generates surfaces at pixel region boundaries.

Note that it may be useful to visually examine intermediate results. In the case where a mesh is not well aligned with others it is useful to correct the transformation in an external mesh editor.

In [17]:
# Here each pixel interior to the femur has value "1" and exterior has value "0".
# This may change for a different segmentation image.
FEMUR_OBJECT_PIXEL_VALUE = 1

Dimension = itk.template(sparse_downsampled_images[0])[1][1]
MeshType = itk.Mesh[itk.F,Dimension]

In [16]:
def generate_meshes(images:list, mesh_type=itk.Mesh[itk.F,3]) -> list:
    meshes = list()
    for image in images:
        # Extract the surface of the region labeled with the femur pixel value
        mesh = itk.binary_mask3_d_mesh_source(image,
                               object_value=FEMUR_OBJECT_PIXEL_VALUE,
                               ttype=[type(images[0]),mesh_type])
        meshes.append(mesh)
    return meshes

In [17]:
dense_meshes = generate_meshes(dense_downsampled_images)

# Expect ~200K vertices
print('Average dense mesh points: ' +
      str(sum([mesh.GetNumberOfPoints() for mesh in dense_meshes]) / len(dense_meshes)))

Average dense mesh points: 215290.5


In [18]:
sparse_meshes = generate_meshes(sparse_downsampled_images)

# Expect ~5K vertices
print('Average sparse mesh points: ' +
      str(sum([mesh.GetNumberOfPoints() for mesh in sparse_meshes]) / len(sparse_meshes)))

Average sparse mesh points: 4373.357142857143


In [None]:
# Write out each mesh to disk
for i in range(0,len(dense_meshes)):
    itk.meshwrite(dense_meshes[i], f'{DENSE_MESH_OUTPUT_FOLDER}{MESH_FILENAMES[i]}')

for i in range(0,len(sparse_meshes)):
    itk.meshwrite(sparse_meshes[i], f'{TEMPLATE_OUTPUT_FOLDER}{MESH_FILENAMES[i]}')

In [None]:
# visualize with itkwidgets
#view(geometries=sparse_meshes)

### Define Iterative Registration
We want to find correspondence points between the template and each sample mesh. To get correspondence, two steps are employed:
- First, a copy of the template mesh is registered to the target sample;
- Second, each mesh point is repositioned at its nearest neighbor on the target sample.

The result of this process is a full collection of meshes having the same number of points and correspondence between each point such that it represents the same approximate feature on each femur.
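The second step can be illustrated with a brute-force nearest-neighbor search in plain Python. This is only a sketch under assumptions: the helper name `resample_to_nearest` is hypothetical, and the notebook delegates the real work to `PointSetEntropyRegistrar` with `resample_from_target=True`, whose internal implementation is not shown here.

```python
import math


def resample_to_nearest(template_points, target_points):
    """Move each template point onto its closest target point (brute force)."""
    return [min(target_points, key=lambda q: math.dist(p, q))
            for p in template_points]


# Two template points, already roughly registered to a three-point target
template = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0)]
target = [(0.1, 0.0, 0.0), (0.9, 0.1, 0.0), (5.0, 5.0, 5.0)]

resampled = resample_to_nearest(template, target)
print(resampled)  # [(0.1, 0.0, 0.0), (0.9, 0.1, 0.0)]
```

The output has as many points as the template, in the template's order, regardless of how many points the target has. Registration needs to come first: if the template is poorly aligned, several template points can collapse onto the same target point.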

In [20]:
def register_template_to_sample(template_mesh, 
                                sample_mesh,
                                learning_rate=1.0,
                                max_iterations=500):
    
    registrar = PointSetEntropyRegistrar()
    metric = itk.EuclideanDistancePointSetToPointSetMetricv4[itk.PointSet[itk.F,3]].New()
    transform = itk.Euler3DTransform[itk.D].New()
    
    # Make a deep copy of the template point set to register to the target
    template_copy = itk.Mesh[itk.F,3].New()
    for i in range(0, template_mesh.GetNumberOfPoints()):
        template_copy.SetPoint(i, template_mesh.GetPoint(i))
    template_copy.SetCells(template_mesh.GetCells())
    
    # Run registration and transform points to target
    (transform, deformed_mesh) = registrar.register(template_mesh=template_copy,
                                                    target_mesh=sample_mesh,
                                                    metric=metric,
                                                    transform=transform,
                                                    learning_rate=learning_rate,
                                                    max_iterations=max_iterations,
                                                    resample_from_target=True,
                                                    verbose=False)
    return deformed_mesh

### Define Procrustes Alignment Parameters
Once template meshes have been aligned to represent each input mesh with correspondence we can run Procrustes alignment and get out a mean mesh as the new template.
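To make the idea concrete, here is a simplified 2D sketch of aligning shapes that are in correspondence and then averaging them. It is an illustration under assumptions, not the filter's implementation: it works in 2D, aligns every shape once to the first shape instead of iterating to convergence, and the helper names `center`, `rotate_onto`, and `mean_shape` are hypothetical. Like the configuration used below, it applies no scaling.

```python
import math


def center(points):
    """Translate 2D points so their centroid sits at the origin."""
    n = len(points)
    cx = sum(x for x, _ in points) / n
    cy = sum(y for _, y in points) / n
    return [(x - cx, y - cy) for x, y in points]


def rotate_onto(reference, shape):
    """Rotate a centered shape to best match a centered reference.

    Points are assumed to be in correspondence (same index, same feature).
    """
    dot = sum(ax * bx + ay * by for (ax, ay), (bx, by) in zip(reference, shape))
    cross = sum(ay * bx - ax * by for (ax, ay), (bx, by) in zip(reference, shape))
    theta = math.atan2(cross, dot)  # least-squares optimal rotation angle
    c, s = math.cos(theta), math.sin(theta)
    return [(c * x - s * y, s * x + c * y) for x, y in shape]


def mean_shape(shapes):
    """Average corresponding points across aligned shapes."""
    n = len(shapes)
    return [(sum(s[i][0] for s in shapes) / n, sum(s[i][1] for s in shapes) / n)
            for i in range(len(shapes[0]))]


triangle = [(0.0, 0.0), (2.0, 0.0), (0.0, 1.0)]
# The same triangle rotated by 90 degrees and shifted by (5, 5)
moved = [(5.0, 5.0), (5.0, 7.0), (4.0, 5.0)]

reference = center(triangle)
aligned = [reference, rotate_onto(reference, center(moved))]
mean = mean_shape(aligned)
```

Since `moved` is a rigid copy of `triangle`, the mean of the aligned shapes matches the centered triangle up to floating-point error. With real data the shapes differ, and the mean becomes the new template.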

In [21]:
def align_procrustes(deformed_meshes, 
                     template_index, 
                     verbose=False,
                     convergence=0.08):
    
    procrustes_filter = itk.MeshProcrustesAlignFilter[type(deformed_meshes[0]), type(deformed_meshes[0])].New()
    
    procrustes_filter.SetUseInitialAverageOff()
    procrustes_filter.SetUseNormalizationOff()
    procrustes_filter.SetUseScalingOff()
    procrustes_filter.SetConvergence(convergence)  # Minimum threshold to exit alignment
    
    # Set mesh correspondence inputs
    procrustes_filter.SetNumberOfInputs(len(deformed_meshes))
    for i in range(0, len(deformed_meshes)):
        procrustes_filter.SetInput(i, deformed_meshes[i])
    
    # Run filter
    procrustes_filter.Update()
    
    if verbose:
        print(f'Alignment converged at {procrustes_filter.GetMeanPointsDifference()}')
    
    mean_result = procrustes_filter.GetMean()
    mean_result.SetCells(deformed_meshes[template_index].GetCells())
    return mean_result

### Compute Hausdorff Distance

We can calculate the Hausdorff distance between the current and previous mesh atlas at each iterative refinement to quantify the amount of change between iterations. Because the meshes are in point-to-point correspondence, we simplify the computation to the largest Euclidean distance between any pair of corresponding points, which is an upper bound on the true Hausdorff distance.
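The difference between the correspondence-based measure and the general Hausdorff distance can be seen in a small plain-Python sketch. The function names here are illustrative assumptions and the brute-force search is only practical for small point sets.

```python
import math


def max_correspondence_distance(points_a, points_b):
    """Largest distance between points that share an index."""
    return max(math.dist(p, q) for p, q in zip(points_a, points_b))


def directed_hausdorff(points_a, points_b):
    """Largest distance from a point in A to its nearest point in B."""
    return max(min(math.dist(p, q) for q in points_b) for p in points_a)


def hausdorff(points_a, points_b):
    """Symmetric Hausdorff distance between two point sets."""
    return max(directed_hausdorff(points_a, points_b),
               directed_hausdorff(points_b, points_a))


a = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0)]
shifted = [(0.0, 0.0, 0.5), (1.0, 0.0, 0.5)]   # every point moved by 0.5
swapped = [(1.0, 0.0, 0.0), (0.0, 0.0, 0.0)]   # same points, indices swapped

print(max_correspondence_distance(a, shifted), hausdorff(a, shifted))  # 0.5 0.5
print(max_correspondence_distance(a, swapped), hausdorff(a, swapped))  # 1.0 0.0
```

The two measures agree when corresponding points are also nearest neighbors, as in the `shifted` case. The `swapped` case shows that the correspondence-based value depends on point ordering, which is why it is only meaningful for meshes that are in correspondence.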

In [22]:
def calculate_hausdorff_distance(mesh1, mesh2):
    assert(mesh1.GetNumberOfPoints() == mesh2.GetNumberOfPoints())
    max_dist = 0.0
    
    for pt_idx in range(mesh1.GetNumberOfPoints()):
        pt1 = mesh1.GetPoint(pt_idx)
        pt2 = mesh2.GetPoint(pt_idx)
        dist = sum((pt1[dim] - pt2[dim]) ** 2 for dim in range(0,3)) ** 0.5
        max_dist = max(max_dist, dist)
    return max_dist

### Run Iterative Refinement

Run registration and alignment on every mesh. There are two ways to apply the results: update meshes in place as soon as each mean is computed, so that later alignments in the same iteration take advantage of it, or update meshes only between iterations, so that each alignment procedure is independent of the order of meshes in the list.

In [23]:
# Fix the number of iterative refinements to run for each atlas.
# One iteration includes registration to each dense mesh, followed by
# subsequent Procrustes alignment of all correspondence meshes.
NUM_ITERATIONS = 1

# Select indices of atlas templates to refine.
TEMPLATES_TO_ALIGN = [0]

# Prepare directory to write out atlas iterations.
# ex. 'Output/mean/0/901-R-femur-label.obj'
for i in range(0,NUM_ITERATIONS):
    os.makedirs(os.path.join(MEAN_OUTPUT_FOLDER, str(i)), exist_ok=True)

In [24]:
aligned_templates = dict()
for iteration in range(0,NUM_ITERATIONS):
    print(f'Now at iteration {iteration}')
    
    for template_mesh_index in TEMPLATES_TO_ALIGN:
        starttime = time.time()
        print(f'Now at template mesh {template_mesh_index}')
        template_mesh = sparse_meshes[template_mesh_index]

        deformed_templates = list()
        for sample_index in range(0,len(dense_meshes)):
            print(f'Resampling from mesh {sample_index}')
            deformed_template = register_template_to_sample(template_mesh, 
                                                            dense_meshes[sample_index],
                                                            max_iterations=50)
            deformed_templates.append(deformed_template)

        print('Running alignment')
        mesh_result = align_procrustes(deformed_templates, template_mesh_index, verbose=True)

        # Save intermediate results to disk
        # These meshes are very small so this is not a significant expense (~400 KB/mesh)
        output_path = f'{MEAN_OUTPUT_FOLDER}/{iteration}/{MESH_FILENAMES[template_mesh_index]}'
        print(f'Writing mean to {output_path}')
        #itk.meshwrite(mesh_result, output_path)
        
        # Optionally update current mesh in place for use in subsequent alignments
        aligned_templates[template_mesh_index] = mesh_result
        
        endtime = time.time()
        print(f'Elapsed: {endtime - starttime}')
    
    # Optionally update templates only between iterations
    for template_mesh_index in TEMPLATES_TO_ALIGN:
        dist = calculate_hausdorff_distance(sparse_meshes[template_mesh_index],
                                            aligned_templates[template_mesh_index])
        print(f'Mesh {template_mesh_index} Hausdorff distance from previous iteration: {dist}')
        sparse_meshes[template_mesh_index] = aligned_templates[template_mesh_index]

Now at iteration 0
Now at template mesh 0
Resampling from mesh 0
Resampling from mesh 1
Resampling from mesh 2
Resampling from mesh 3
Resampling from mesh 4
Resampling from mesh 5
Resampling from mesh 6
Resampling from mesh 7
Resampling from mesh 8
Resampling from mesh 9
Resampling from mesh 10
Resampling from mesh 11
Resampling from mesh 12
Resampling from mesh 13
Running alignment
Alignment converged at 0.0034345665327492492
Writing mean to Output/mean//0/901-R-femur-label.obj
Elapsed: 137.288982629776
Mesh 0 Hausdorff distance from previous iteration: 2.228128978646963


In [None]:
# visualize template
# view(geometries=sparse_meshes)