# Getting Started with Grooming Segmentations 

## Before you start!

- This [notebook](getting-started-with-grooming-segmentations.ipynb) assumes that the shapeworks conda environment has been activated using `conda activate shapeworks` on the terminal.
- See [Setting Up ShapeWorks Environment](setting-up-shapeworks-environment.ipynb) to learn how to set up your environment to start using the shapeworks library. Please note that the prerequisite steps use the same code to set up the environment for this notebook and to import the `shapeworks` library.
- See [Getting Started with Segmentations](getting-started-with-segmentations.ipynb) to learn how to load and visualize binary segmentations.
- See [Getting Started with Exploring Segmentations](getting-started-with-exploring-segmentations.ipynb) to learn how to decide the grooming pipeline needed for your dataset.


## In this notebook, you will learn:

1. How to resample segmentations to have smaller and isotropic voxel spacing   
2. How to rigidly align segmentations to factor out global transformations   
3. How to crop segmentations to remove unnecessary background voxels, which would otherwise increase the memory footprint when optimizing the shape model, and how to pad segmentations to create more room along each dimension for correspondence optimization   
4. How to convert segmentations to smooth signed distance transforms that serve as numerically stable inputs for correspondence optimization  
5. How to run `ShapeWorksStudio` and use the groomed segmentations to optimize a shape model


We will also define modular/generic helper functions as we walk through these items to reuse functionalities without duplicating code.

## Prerequisites

- Setting up `shapeworks` environment. See [Setting Up ShapeWorks Environment](setting-up-shapeworks-environment.ipynb). To avoid code clutter, the `setup_shapeworks_env` function can be found in the `Examples/Python/setupenv.py` module.
- Importing `shapeworks` library. See [Setting Up ShapeWorks Environment](setting-up-shapeworks-environment.ipynb).
- Helper functions for segmentations. See [Getting Started with Segmentations](getting-started-with-segmentations.ipynb) and [Getting Started with Exploring Segmentations](getting-started-with-exploring-segmentations.ipynb).
- Helper functions for meshes. See [Getting Started with Meshes](getting-started-with-meshes.ipynb).
- Helper functions for visualization. See [Getting Started with Segmentations](getting-started-with-segmentations.ipynb), [Getting Started with Meshes](getting-started-with-meshes.ipynb), and [Getting Started with Exploring Segmentations](getting-started-with-exploring-segmentations.ipynb).
- Defining your dataset location. See [Getting Started with Exploring Segmentations](getting-started-with-exploring-segmentations.ipynb).
- Loading your dataset. See [Getting Started with Exploring Segmentations](getting-started-with-exploring-segmentations.ipynb).
- Defining parameters for `pyvista` plotter. See [Getting Started with Exploring Segmentations](getting-started-with-exploring-segmentations.ipynb).
- Tentative grooming pipeline. See [Getting Started with Exploring Segmentations](getting-started-with-exploring-segmentations.ipynb). 

## Note about `shapeworks` APIs

`shapeworks` functions are in-place, i.e., `<swObject>.<function>()` applies that function directly to the data of `swObject`. To keep the original data unchanged, you must first copy it to another variable before applying the function.
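As an analogy only (this is NumPy, not the shapeworks API), NumPy's augmented assignment also modifies an array in place, so assigning an object to a second variable name does not copy it. With shapeworks images, the copy is made with the `sw.Image(...)` copy constructor, as the antialiasing cell later in this notebook does.

```python
import numpy as np

original = np.array([0.0, 1.0, 1.0, 0.0])

# 'alias' is a second name for the same array object, not a copy
alias = original
alias += 1.0  # in-place update, so 'original' changes as well

# an explicit copy owns its own data, so in-place edits on it leave 'original' alone
independent = original.copy()
independent *= 10.0

print(original)     # [1. 2. 2. 1.]
print(independent)  # [10. 20. 20. 10.]
```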

## Notebook keyboard shortcuts

- `Esc + H`: displays a complete list of keyboard shortcuts
- `Esc + A`: insert new cell above the current cell
- `Esc + B`: insert new cell below the current cell
- `Esc + D + D`: delete current cell
- `Esc + Z`: undo
- `Shift + enter`: run current cell and move to next
- To show a function's argument list (i.e., signature), use `(` then `shift-tab`
- Use `shift-tab-tab` to show more help for a function
- To show the help of a function, use `help(function)` or `function?`
- To show all functions supported by an object, use `dot-tab` after the variable name



## Prerequisites

### Setting up `shapeworks` environment 

Here, we will append the shapeworks directories to both your `PYTHONPATH` and your system `PATH` to set up the shapeworks environment for this notebook. See [Setting Up ShapeWorks Environment](setting-up-shapeworks-environment.ipynb) for more details.

In this notebook, we assume the following.

- This notebook is located in `Examples/Python/notebooks/tutorials`
- You have built shapeworks from source in the `build` directory within the shapeworks code directory
- You have built the shapeworks dependencies (using `build_dependencies.sh`) in the same parent directory as the shapeworks code

**Note:** If you run from a ShapeWorks installation, you don't need to set the dependencies path, and `shapeworks_bin_dir` should be set to `../../../../bin`.
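For intuition, here is a minimal sketch of what adding the bin directories to the module search path and to `PATH` amounts to. `setup_env_sketch` is a hypothetical helper written for illustration, not the implementation of `setupenv.setup_shapeworks_env`. It edits `sys.path` (the in-process equivalent of `PYTHONPATH`) and the `PATH` environment variable of the running process.

```python
import os
import sys

def setup_env_sketch(shapeworks_bin_dir, dependencies_bin_dir=None):
    """Hypothetical helper: put the given bin directories first on sys.path and PATH."""
    bin_dirs = [shapeworks_bin_dir]
    if dependencies_bin_dir is not None:
        bin_dirs.append(dependencies_bin_dir)

    added = []
    for bin_dir in bin_dirs:
        abs_dir = os.path.abspath(bin_dir)

        # sys.path is what the running interpreter searches for importable modules
        if abs_dir not in sys.path:
            sys.path.insert(0, abs_dir)

        # PATH is searched for executables, e.g., when launching command-line tools
        path_entries = [p for p in os.environ.get("PATH", "").split(os.pathsep) if p]
        if abs_dir not in path_entries:
            os.environ["PATH"] = os.pathsep.join([abs_dir] + path_entries)

        added.append(abs_dir)
    return added
```

For example, `setup_env_sketch("../../../../build/bin", "../../../../../shapeworks-dependencies/bin")` would mirror the paths used in the next cell. The sketch prepends rather than appends so that the freshly built directories take precedence over any other installation on the path. That is a design choice of the sketch, not necessarily what `setupenv` does.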

In [None]:
# import relevant libraries 
import sys 

# add parent-parent directory (where setupenv.py is) to python path
sys.path.insert(0,'../..')

# importing setupenv from Examples/Python
import setupenv

# indicate the bin directories for shapeworks and its dependencies
shapeworks_bin_dir   = "../../../../build/bin"
dependencies_bin_dir = "../../../../../shapeworks-dependencies/bin"

# set up shapeworks environment
setupenv.setup_shapeworks_env(shapeworks_bin_dir,  
                              dependencies_bin_dir, 
                              verbose = False)

### Importing `shapeworks` library

In [None]:
# let's import shapeworks library to test whether shapeworks is now set
try:
    import shapeworks as sw
except ImportError:
    print('ERROR: shapeworks library failed to import')
else:
    print('SUCCESS: shapeworks library is successfully imported!!!')

### Helper functions for IO

In [None]:
# os is needed here; it is otherwise only imported in a later cell
import os

# a helper function that saves a list of shapeworks images in a directory
# this could be used to save final and intermediate results (if needed)
def save_images(outDir,        # path to the directory where we want to save the images
                swImageList,   # list of shapeworks images to be saved
                swImageNames,  # list of image names to be used as filenames
                extension        = 'nrrd',
                compressed       = False, # set to False to be able to load the images in ParaView
                verbose          = True):

    if (len(swImageList) != len(swImageNames)):
        print('swImageNames list is not consistent with number of images in swImageList')
        return

    # create the output directory in case it does not exist
    if not os.path.exists(outDir):
        os.makedirs(outDir)

    for curImg, curName in zip(swImageList, swImageNames):
        # os.path.join works whether or not outDir ends with a path separator
        filename = os.path.join(outDir, curName + '.' + extension)
        if verbose:
            print('Writing: ' + filename)
        curImg.write(filename, compressed=compressed) 

### Helper functions for segmentations

In [None]:
# importing relevant libraries
import pyvista as pv
import numpy as np

# a helper function that converts shapeworks Image object to vtk image
def sw2vtkImage(swImg, verbose = False):
            
    # get the numpy array of the shapeworks image
    array  = swImg.toArray()
    
    # the numpy array needs to be permuted to match the shapeworks image dimensions
    array = np.transpose(array,(2,1,0))
    
    # converting a numpy array to a vtk image using pyvista's wrap function
    vtkImg = pv.wrap(array)
    
    if verbose:
        print('shapeworks image header information: ')
        print(swImg)

        print('\nvtk image header information: ')
        print(vtkImg) 
    
    return vtkImg

### Helper functions for meshes

In [None]:
# importing relevant libraries
import os

# a helper function that converts shapeworks Mesh object to vtk mesh 
# TODO: to be modified when #825 is addressed
def sw2vtkMesh(swMesh, verbose = False):
    
    if verbose:
        print('Header information: ')
        print(swMesh)

    # save mesh
    swMesh.write('temp.vtk')

    # read the mesh back as a pyvista (vtk) mesh
    vtkMesh = pv.read('temp.vtk')
    
    # remove the temp mesh file
    os.remove('temp.vtk')
    
    return vtkMesh

### Helper functions for visualization

In [None]:
# importing itkwidgets to visualize single segmentations
import itkwidgets as itkw

# itkwidgets.view returns a Viewer object, and the IPython/Jupyter kernel 
# displays only the last return value of a cell by default, so we use the display function
# to show itkwidgets viewers from within functions and if statements
from IPython.display import display

# enable use_ipyvtk by default for interactive plots
pv.rcParams['use_ipyvtk'] = True 
    
# a helper function that adds a vtk image to a pyvista plotter
def add_volume_to_plotter( pvPlotter,      # pyvista plotter
                           vtkImg,         # vtk image to be added
                           rowIdx, colIdx, # subplot row and column index
                           title = None,   # text to be added to the subplot, use None to not show text 
                           shade_volumes  = True,  # use shading when performing volume rendering
                           color_map      = "coolwarm", # color map for volume rendering, e.g., 'bone', 'coolwarm', 'cool', 'viridis', 'magma'
                           show_axes      = True,  # show a vtk axes widget for each rendering window
                           show_bounds    = False, # show volume bounding box
                           show_all_edges = True,  # add an unlabeled and unticked box at the boundaries of plot. 
                           font_size      = 10     # text font size for windows
                         ):
    
    # which subplot to add the volume to
    pvPlotter.subplot(rowIdx, colIdx)
    
    # add the volume
    pvPlotter.add_volume(vtkImg, 
                         shade   = shade_volumes, 
                         cmap    = color_map)

    if show_axes:
        pvPlotter.show_axes()

    if show_bounds:
        pvPlotter.show_bounds(all_edges = show_all_edges)

    # add a text to this subplot to indicate which volume is being visualized
    if title is not None:
        pvPlotter.add_text(title, font_size = font_size)
        
# a helper function that adds a mesh to a `pyvista` plotter.
def add_mesh_to_plotter( pvPlotter,      # pyvista plotter
                         vtkMesh,         # vtk mesh to be added
                         rowIdx, colIdx, # subplot row and column index
                         title = None,    # text to be added to the subplot, use None to not show text 
                         mesh_color      = "tan",  # string or 3 item list
                         mesh_style      = "surface", # visualization style of the mesh. style='surface', style='wireframe', style='points'. 
                         show_mesh_edges = False, # show mesh edges
                         opacity         = 1,
                         show_axes       = True,  # show a vtk axes widget for each rendering window
                         show_bounds     = False, # show volume bounding box
                         show_all_edges  = True,  # add an unlabeled and unticked box at the boundaries of plot. 
                         font_size       = 10     # text font size for windows
                         ):
    
    # which subplot to add the mesh to
    pvPlotter.subplot(rowIdx, colIdx)

    # add the surface mesh
    pvPlotter.add_mesh(vtkMesh, 
                       color      = mesh_color, 
                       style      = mesh_style,
                       show_edges = show_mesh_edges,
                       opacity    = opacity)

    if show_axes:
        pvPlotter.show_axes()

    if show_bounds:
        pvPlotter.show_bounds(all_edges = show_all_edges)

    # add text to this subplot to indicate which mesh is being visualized
    if title is not None:
        pvPlotter.add_text(title, font_size = font_size)
        
# helper functions to define the best grid size for subplots
def positive_factors(num_samples):
    factors = []
    
    for whole_number in range(1, num_samples + 1):
        if num_samples % whole_number == 0:
            factors.append(whole_number)
    
    return factors

def num_subplots(num_samples):
    factors = positive_factors(num_samples)
    cols    = min(int(np.ceil(np.sqrt(num_samples))),max(factors))
    rows    = int(np.ceil(num_samples/cols))
    
    return rows, cols

# helper function to add and plot a list of volumes
def plot_volumes(volumeList,           # list of shapeworks images to be visualized
                 volumeNames     = None,  # list of strings of same size as shape list used to add text for each plot window, use None to not show text per window 
                 use_same_window = False, # plot using multiple rendering windows if false
                 is_interactive  = True,  # to enable interactive plots
                 show_borders    = True,  # show borders for each rendering window
                 shade_volumes   = True,  # use shading when performing volume rendering
                 color_map       = "coolwarm", # color map for volume rendering, e.g., 'bone', 'coolwarm', 'cool', 'viridis', 'magma'
                 show_axes       = True,  # show a vtk axes widget for each rendering window
                 show_bounds     = True,  # show volume bounding box
                 show_all_edges  = True,  # add an unlabeled and unticked box at the boundaries of plot. 
                 font_size       = 10,    # text font size for windows
                 link_views      = True   # link all rendering windows so that they share same camera and axes boundaries
                ):
    
    num_samples = len(volumeList)
    
    if volumeNames is not None:
        if use_same_window and (len(volumeNames) > 1):
            print('A single title needed when all volumes are to be displayed on the same window')
            return
        elif (not use_same_window) and (len(volumeNames) != num_samples):
            print('volumeNames list is not consistent with number of samples')
            return
        
    if use_same_window:
        grid_rows, grid_cols = 1, 1
    else:
        # define grid size for the given number of samples
        grid_rows, grid_cols  = num_subplots(num_samples)

    # define the plotter
    plotter = pv.Plotter(shape    = (grid_rows, grid_cols),
                         notebook = is_interactive, 
                         border   = show_borders) 
    
    # add the given volume list (one at a time) to the plotter
    for volumeIdx in range(num_samples):
        
        # which window to add the current volume
        if use_same_window:
            rowIdx, colIdx = 0, 0
            titleIdx       = 0
        else:
            idUnraveled     = np.unravel_index(volumeIdx, (grid_rows, grid_cols))
            rowIdx, colIdx  = idUnraveled[0], idUnraveled[1]
            titleIdx        = volumeIdx
        
        # which title to use
        if volumeNames is not None:
            volumeName = volumeNames[titleIdx]
        else:
            volumeName = None

        # convert sw image to vtk image
        if type(volumeList[volumeIdx]) == sw.Image:
            volume_vtk = sw2vtkImage(volumeList[volumeIdx], 
                                       verbose = False)
        else:
            volume_vtk = volumeList[volumeIdx]

        # add the current volume
        add_volume_to_plotter( plotter, volume_vtk,   
                               rowIdx = rowIdx, colIdx = colIdx, 
                               title          = volumeName,
                               shade_volumes  = shade_volumes, 
                               color_map      = color_map,
                               show_axes      = show_axes, 
                               show_bounds    = show_bounds, 
                               show_all_edges = show_all_edges, 
                               font_size      = font_size)
    # link views
    if link_views and (not use_same_window):
        plotter.link_views()  

    # now, time to render our volumes
    plotter.show(use_ipyvtk = is_interactive)
        
# helper function to add and plot a list of meshes
def plot_meshes(meshList,           # list of shapeworks meshes to be visualized
                meshNames       = None,  # list of strings of same size as shape list used to add text for each plot window, use None to not show text per window 
                use_same_window = False, # plot using multiple rendering windows if false
                is_interactive  = True,  # to enable interactive plots
                show_borders    = True,  # show borders for each rendering window
                meshes_color    = 'tan', # color to be used for meshes (can be a list with the same size as meshList if different colors are needed)
                mesh_style      = "surface", # visualization style of the mesh. style='surface', style='wireframe', style='points'. 
                show_mesh_edges = False, # show mesh edges
                opacities       = 1,     # opacity to be used for meshes (can be a list with the same size as meshList if different opacities are needed) 
                show_axes       = True,  # show a vtk axes widget for each rendering window
                show_bounds     = True,  # show volume bounding box
                show_all_edges  = True,  # add an unlabeled and unticked box at the boundaries of plot. 
                font_size       = 10,    # text font size for windows
                link_views      = True   # link all rendering windows so that they share same camera and axes boundaries
               ):
    
    num_samples = len(meshList)
    
    if meshNames is not None:
        if use_same_window and (len(meshNames) > 1):
            print('A single title needed when all meshes are to be displayed on the same window')
            return
        elif (not use_same_window) and  (len(meshNames) != num_samples):
            print('meshNames list is not consistent with number of samples')
            return
            
    if type(meshes_color) is not list: # single color given
        meshes_color = [meshes_color] * num_samples
        
    if type(opacities) is not list: # single opacity given
        opacities = [opacities] * num_samples
        
    if use_same_window:
        grid_rows, grid_cols = 1, 1
    else:
        # define grid size for the given number of samples
        grid_rows, grid_cols  = num_subplots(num_samples)

    # define the plotter
    plotter = pv.Plotter(shape    = (grid_rows, grid_cols),
                         notebook = is_interactive, 
                         border   = show_borders) 
    
    # add the given mesh list (one at a time) to the plotter
    for meshIdx in range(num_samples):
        
        # which window to add the current mesh
        if use_same_window:
            rowIdx, colIdx = 0, 0
            titleIdx       = 0
        else:
            idUnraveled     = np.unravel_index(meshIdx, (grid_rows, grid_cols))
            rowIdx, colIdx  = idUnraveled[0], idUnraveled[1]
            titleIdx        = meshIdx
        
        # which title to use
        if meshNames is not None:
            meshName = meshNames[titleIdx]
        else:
            meshName = None

        # convert sw mesh to vtk mesh
        if type(meshList[meshIdx]) == sw.Mesh:
            mesh_vtk = sw2vtkMesh(meshList[meshIdx], 
                                  verbose = False)
        else:
            mesh_vtk = meshList[meshIdx]

        # add the current mesh
        add_mesh_to_plotter( plotter, mesh_vtk,   
                             rowIdx = rowIdx, colIdx = colIdx, 
                             title           = meshName,
                             mesh_color      = meshes_color[meshIdx],
                             mesh_style      = mesh_style,
                             show_mesh_edges = show_mesh_edges,
                             opacity         = opacities[meshIdx],
                             show_axes       = show_axes, 
                             show_bounds     = show_bounds, 
                             show_all_edges  = show_all_edges, 
                             font_size       = font_size)
        
    # link views
    if link_views and (not use_same_window):
        plotter.link_views()  

    # now, time to render our meshes
    plotter.show(use_ipyvtk = is_interactive)
    
    
# helper function to add and plot a list of meshes/volumes mix
def plot_meshes_volumes_mix(objectList,      # list of shapeworks meshes and/or images to be visualized
                            objectsType,     # list of 'vol', 'mesh' of same size as objectList
                            objectNames     = None,  # list of strings of same size as shape list used to add text for each plot window, use None to not show text per window 
                            use_same_window = False, # plot using multiple rendering windows if false
                            is_interactive  = True,  # to enable interactive plots
                            show_borders    = True,  # show borders for each rendering window
                            meshes_color    = 'tan', # color to be used for meshes (can be a list with the same size as meshList if different colors are needed)
                            mesh_style      = "surface", # visualization style of the mesh. style='surface', style='wireframe', style='points'. 
                            shade_volumes   = True,  # use shading when performing volume rendering
                            color_map       = "coolwarm", # color map for volume rendering, e.g., 'bone', 'coolwarm', 'cool', 'viridis', 'magma' 
                            show_mesh_edges = False, # show mesh edges
                            opacities       = 1,     # opacity to be used for meshes (can be a list with the same size as meshList if different opacities are needed) 
                            show_axes       = True,  # show a vtk axes widget for each rendering window
                            show_bounds     = True,  # show volume bounding box
                            show_all_edges  = True,  # add an unlabeled and unticked box at the boundaries of plot. 
                            font_size       = 10,    # text font size for windows
                            link_views      = True   # link all rendering windows so that they share same camera and axes boundaries
                           ):
    
    num_samples = len(objectList)
    
    if objectNames is not None:
        if use_same_window and (len(objectNames) > 1):
            print('A single title needed when all objects are to be displayed on the same window')
            return
        elif (not use_same_window) and  (len(objectNames) != num_samples):
            print('objectNames list is not consistent with number of samples')
            return
            
    if type(meshes_color) is not list: # single color given
        meshes_color = [meshes_color] * num_samples
        
    if type(opacities) is not list: # single opacity given
        opacities = [opacities] * num_samples
        
    if use_same_window:
        grid_rows, grid_cols = 1, 1
    else:
        # define grid size for the given number of samples
        grid_rows, grid_cols  = num_subplots(num_samples)

    # define the plotter
    plotter = pv.Plotter(shape    = (grid_rows, grid_cols),
                         notebook = is_interactive, 
                         border   = show_borders) 
    
    # add the given object list (one at a time) to the plotter
    for objectIdx in range(num_samples):
        
        # which window to add the current object
        if use_same_window:
            rowIdx, colIdx = 0, 0
            titleIdx       = 0
        else:
            idUnraveled     = np.unravel_index(objectIdx, (grid_rows, grid_cols))
            rowIdx, colIdx  = idUnraveled[0], idUnraveled[1]
            titleIdx        = objectIdx
        
        # which title to use
        if objectNames is not None:
            objectName = objectNames[titleIdx]
        else:
            objectName = None

        if objectsType[objectIdx] == 'vol':
            
            # convert sw image to vtk image
            if type(objectList[objectIdx]) == sw.Image:
                object_vtk = sw2vtkImage(objectList[objectIdx], 
                                         verbose = False)
            else:
                object_vtk = objectList[objectIdx]
            
            # add the current volume
            add_volume_to_plotter( plotter, object_vtk,   
                                   rowIdx = rowIdx, colIdx = colIdx, 
                                   title          = objectName,
                                   shade_volumes  = shade_volumes, 
                                   color_map      = color_map,
                                   show_axes      = show_axes, 
                                   show_bounds    = show_bounds, 
                                   show_all_edges = show_all_edges, 
                                   font_size      = font_size)

        else: # 'mesh'
            # convert sw mesh to vtk mesh
            if type(objectList[objectIdx]) == sw.Mesh:
                object_vtk = sw2vtkMesh(objectList[objectIdx], 
                                        verbose = False)
            else:
                object_vtk = objectList[objectIdx]

            # add the current mesh
            add_mesh_to_plotter( plotter, object_vtk,   
                                 rowIdx = rowIdx, colIdx = colIdx, 
                                 title           = objectName,
                                 mesh_color      = meshes_color[objectIdx],
                                 mesh_style      = mesh_style,
                                 show_mesh_edges = show_mesh_edges,
                                 opacity         = opacities[objectIdx],
                                 show_axes       = show_axes, 
                                 show_bounds     = show_bounds, 
                                 show_all_edges  = show_all_edges, 
                                 font_size       = font_size)
        
    # link views
    if link_views and (not use_same_window):
        plotter.link_views()  

    # now, time to render our mesh/volume mix
    plotter.show(use_ipyvtk = is_interactive)

### Defining dataset location

In [None]:
import glob # for paths and file-directory search

# dataset name is the folder name for your dataset
datasetName  = 'ellipsoid-v2'

# file extension for the shape data
shapeExtention = '.nrrd'

# path to the dataset where we can find shape data 
# here we assume shape data are given as binary segmentations
shapeDir      = '../../Data/' + datasetName + '/segmentations/'

# name for this notebook to use for output directory
notebookName = 'getting-started-groom-segmentations'

# path to the directory where we want to save results from this notebook
outDir       = '../../Output/Notebooks/' + notebookName + '/' + datasetName + '/'

# create the output directory in case it does not exist
if not os.path.exists(outDir):
    os.makedirs(outDir)

# let's get a list of files for available segmentations in this dataset
# the * wildcard matches any filename, so this retrieves all files
# in the shape directory with the given file extension
shapeFilenames = sorted(glob.glob(shapeDir + '*' + shapeExtention)) 

print('Dataset Name:     ' + datasetName)
print('Shape Directory:  ' + shapeDir)
print('Output Directory: ' + outDir)
print('Number of shapes: ' + str(len(shapeFilenames)))
print('Shape files found:')
for shapeFilename in shapeFilenames:
    print('\t' + shapeFilename)

### Loading your dataset

In [None]:
# list of shape segmentations
shapeSegList = []

# list of shape names (shape files prefixes) to be used 
# for saving outputs and visualizations
shapeNames   = [] 

# loop over all shape files and load individual segmentations
for shapeFilename in shapeFilenames:
    print('Loading: ' + shapeFilename)
    
    # current shape name
    segFilename = os.path.basename(shapeFilename) # portable across path separators
    shapeName   = segFilename[:-len(shapeExtention)]
    shapeNames.append(shapeName)
    
    # load segmentation
    shapeSeg = sw.Image(shapeFilename)
    
    # append to the shape list
    shapeSegList.append(shapeSeg)

num_samples = len(shapeSegList)
print('\n' + str(num_samples) + 
      ' segmentations are loaded for the ' + datasetName + ' dataset ...')

### Defining parameters for `pyvista` plotter

In [None]:
# define parameters that control the plotter

# common for volumes and meshes visualization
is_interactive = True  # to enable interactive plots
show_borders   = True  # show borders for each rendering window
show_axes      = True  # show a vtk axes widget for each rendering window
show_bounds    = True  # show volume bounding box
show_all_edges = True  # add an unlabeled and unticked box at the boundaries of plot. 
font_size      = 10    # text font size for windows
link_views     = True  # link all rendering windows so that they share same camera and axes boundaries

# for volumes
shade_volumes  = True  # use shading when performing volume rendering
color_map       = 'coolwarm' # color map for volume rendering, e.g., 'bone', 'coolwarm', 'cool', 'viridis', 'magma'

# for meshes
meshes_color    = 'tan' # color to be used for meshes (can be a list with the same size as meshList if different colors are needed)
mesh_style      = 'surface' # visualization style of the mesh. style='surface', style='wireframe', style='points'. 
show_mesh_edges = False  # show mesh edges 
                

### Tentative grooming

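Based on the exploration in [Getting Started with Exploring Segmentations](getting-started-with-exploring-segmentations.ipynb), a tentative grooming pipeline entails the following steps:   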

1. Resampling segmentations to have smaller and isotropic voxel spacing   
2. Rigidly aligning shapes   
3. Cropping and padding segmentations   
4. Converting segmentations to smooth signed distance transforms   


Let the fun begin!

## 1. Resampling segmentations


This grooming step resamples all the binary volumes, which in a raw setting could be in different physical spaces (i.e., have different dimensions and voxel spacings), so that all segmentations share the same voxel spacing. 

Using a smaller voxel spacing improves the resolution of the segmentations and reduces the staircase effect seen in the volume rendering.
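To see what a common, smaller voxel spacing means for image size, here is a small sketch that computes how many voxels each axis needs after resampling to an isotropic spacing, assuming the physical extent (number of voxels × spacing) of each axis is preserved. `resampled_dims` is a hypothetical helper for illustration; the exact rounding used by shapeworks' resampling may differ.

```python
import math

def resampled_dims(dims, spacing, new_spacing):
    """Hypothetical helper: voxels per axis after resampling to an isotropic spacing,
    assuming the physical extent (dims * spacing) of each axis is preserved."""
    # round first so floating-point noise (e.g., 29.999999999999996) does not change the count
    return tuple(math.ceil(round(d * s / new_spacing, 6)) for d, s in zip(dims, spacing))

# anisotropic input: 1 mm in-plane spacing, 2 mm between slices
print(resampled_dims((100, 100, 50), (1.0, 1.0, 2.0), 0.5))  # (200, 200, 200)
```

Halving the spacing doubles the voxel count along each axis (8× overall in 3D), which is one reason the later cropping step matters for memory.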

Since image resampling entails interpolation, directly resampling binary segmentations will not result in a binary segmentation, but rather an interpolated version that does not have two distinct labels (i.e., foreground and background).
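A tiny 1D illustration with NumPy (not shapeworks) of why directly interpolating a binary profile breaks the two-label property: linear interpolation onto a finer grid creates an intermediate value at the foreground-background interface.

```python
import numpy as np

# 1D "binary segmentation" sampled every 1 mm: background (0), then foreground (1)
x_coarse = np.arange(8.0)
seg = np.array([0, 0, 0, 1, 1, 1, 1, 1], dtype=float)

# linear interpolation onto a 0.5 mm grid covering the same extent
x_fine = np.arange(0.0, 7.5, 0.5)
seg_fine = np.interp(x_fine, x_coarse, seg)

# the sample that falls on the interface gets a value that is neither 0 nor 1
print(np.unique(seg_fine))
```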

To mitigate this behavior, we first need to convert the binary segmentations (with zero-one voxels) to a continuous-valued (gray-scale) image. This can be done by antialiasing the segmentations, which smooths the foreground-background interface. 

Hence, the resampling pipeline for a binary segmentation includes the following steps:

- Antialiasing the binary segmentation to convert it to a smooth continuous-valued image
- Resampling the antialiased image using the same (and possibly smaller) voxel spacing for all dimensions
- Binarizing (aka thresholding) the resampled image to obtain a binary segmentation with an isotropic voxel spacing
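The three steps above can be sketched in 1D with NumPy. This is an illustration under stated assumptions, not the shapeworks implementation: a moving-average filter stands in for shapeworks' antialiasing, so the smoothed profile stays in [0, 1] and is thresholded at 0.5 rather than at zero. `smooth`, `resample`, and `binarize` are hypothetical helpers.

```python
import numpy as np

def smooth(profile, width=3):
    """Moving-average smoothing: a stand-in for antialiasing, NOT shapeworks' algorithm.
    `width` should be odd so the output has the same length as the input."""
    kernel = np.ones(width) / width
    # repeat the edge values so the ends of the profile are not pulled toward zero
    padded = np.pad(profile, width // 2, mode="edge")
    return np.convolve(padded, kernel, mode="valid")

def resample(profile, spacing, new_spacing):
    """Linear interpolation onto a grid with a new spacing over the same extent."""
    x_old = np.arange(len(profile)) * spacing
    x_new = np.arange(0.0, x_old[-1] + 1e-9, new_spacing)
    return x_new, np.interp(x_new, x_old, profile)

def binarize(profile, threshold=0.5):
    """Values above the threshold become foreground (1), the rest background (0)."""
    return (profile > threshold).astype(np.uint8)

seg = np.array([0, 0, 0, 1, 1, 1, 1, 1], dtype=float)  # 1 mm spacing, interface at 2.5 mm
smoothed = smooth(seg)                                   # continuous-valued, values in [0, 1]
x_new, resampled = resample(smoothed, spacing=1.0, new_spacing=0.4)
seg_resampled = binarize(resampled)

print(seg_resampled)                    # only 0s and 1s again
print(x_new[np.argmax(seg_resampled)])  # first foreground sample, just past 2.5 mm
```

After resampling from 1 mm to 0.4 mm, the output is strictly binary again, and the foreground starts at the first sample past the original interface (x ≈ 2.8 mm), so the shape boundary is preserved at the finer spacing.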



### Antialiasing: which `shapeworks` API to use?

In [None]:
# let's see if there's a function that antialiases a shapeworks image
# use dot-tab to get a list of functions/APIs available for shapeSeg

shapeIdx = 6
shapeSeg = shapeSegList[shapeIdx]

# found it - antialias, let's see its help
help(shapeSeg.antialias)
#shapeSeg.antialias?


### Exploring antialiasing for a single segmentation

In [None]:
# parameters for antialiasing
antialiasIterations  = 50 # number of iterations to perform antialiasing

# note that shapeworks APIs use in-place computations, so if we call shapeSeg.antialias, 
# shapeSeg itself will hold the antialiased segmentation

# to visualize the effect of antialiasing, let's keep the original segmentation unchanged
shapeSegAntialiased = sw.Image(shapeSeg)

# let's antialias this segmentation
shapeSegAntialiased.antialias(antialiasIterations)

# note antialiasing does not change header information
print('\nHeader information before antialiasing: ')
print(shapeSeg)

print('\nHeader information after antialiasing: ')
print(shapeSegAntialiased)
   
plot_volumes([shapeSeg, shapeSegAntialiased], 
             ['before antialiasing','after antialiasing'],
             use_same_window = False,
             is_interactive  = is_interactive, 
             show_borders    = show_borders,  
             shade_volumes   = shade_volumes,  
             color_map       = color_map,
             show_axes       = show_axes,  
             show_bounds     = show_bounds,
             show_all_edges  = show_all_edges, 
             font_size       = font_size,   
             link_views      = link_views)

<p><video src="https://sci.utah.edu/~shapeworks/doc-resources/mp4s/nb-groom-antialias.mp4" autoplay muted loop controls style="width:100%"></p>

Note that the antialiasing operation has converted a 0-1 binary volume to a continuous-valued volume with voxel values ranging from -4.0 to 4.0. The zero level set (i.e., voxels with zero value) is the foreground-background interface.

Let's extract the isosurface of this continuous-valued volume and compare it with the segmentation isosurface to make sure that the shape information is not lost during antialiasing.
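
Marching-cubes-style isosurfacing typically places the surface by linearly interpolating where the isovalue is crossed between neighbouring voxels. The hypothetical 1D helper below does the same for a profile and shows that the binary profile at isovalue 0.5 and a smooth, signed-distance-like profile at isovalue 0.0 put the interface at the same positions. This is a sketch of the idea, not how `toMesh` is implemented.

```python
import numpy as np

def level_crossings(values, isovalue):
    """Fractional sample positions where a 1D profile crosses `isovalue`,
    found by linear interpolation between neighbouring samples."""
    v = np.asarray(values, dtype=float) - isovalue
    crossings = []
    for i in range(len(v) - 1):
        if v[i] == 0.0:
            crossings.append(float(i))
        elif v[i] * v[i + 1] < 0:
            # zero of the line through (i, v[i]) and (i + 1, v[i + 1])
            crossings.append(float(i + v[i] / (v[i] - v[i + 1])))
    if len(v) and v[-1] == 0.0:
        crossings.append(float(len(v) - 1))
    return crossings

binary = [0, 0, 1, 1, 1, 0, 0]
smooth = [-1.5, -0.5, 0.5, 1.5, 0.5, -0.5, -1.5]
print(level_crossings(binary, 0.5))  # [1.5, 4.5]
print(level_crossings(smooth, 0.0))  # [1.5, 4.5]
```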

In [None]:
# extract the segmentation isosurface
shapeSegIso     = shapeSeg.toMesh(isovalue = 0.5)

# extract the antialiased isosurface
shapeSegAntialiasedIso = shapeSegAntialiased.toMesh(isovalue = 0.0)


plot_meshes([shapeSegIso, shapeSegAntialiasedIso], 
             ['before (red) and after (tan) antialiasing'],
             use_same_window = True,
             is_interactive  = is_interactive, 
             show_borders    = show_borders,  
             meshes_color    = ['red', 'tan'], 
             opacities       = [0.5, 1],
             mesh_style      = mesh_style,
             show_mesh_edges = show_mesh_edges,
             show_axes       = show_axes,  
             show_bounds     = show_bounds,
             show_all_edges  = show_all_edges, 
             font_size       = font_size,   
             link_views      = link_views)

<p><video src="https://sci.utah.edu/~shapeworks/doc-resources/mp4s/nb-groom-antialias-iso.mp4" autoplay muted loop controls style="width:100%"></p>

### Resampling: which `shapeworks` API to use?

In [None]:
# let's see if there's a function that resamples a shapeworks image
# use tab completion after the dot to list the functions/APIs available for shapeSeg

# found it - resample, let's see its help
help(shapeSegAntialiased.resample) # issue #827

### Exploring resampling for a single segmentation

There are three options to perform resampling:
- Isotropic resampling with all dimensions having the same voxel size (set isSpacing parameter)
- Anisotropic resampling by setting spacex, spacey, and spacez parameters to indicate a different voxel size for each dimension
- Resampling given the output image dimensions by setting sizex, sizey, and sizez parameters. Voxel spacing along each dimension is thus computed based on the output image size.

In this case, we will perform isotropic resampling. 
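
The three options are tied together by the physical extent of the image, roughly dims * spacing along each axis. The sketch below assumes exactly that relation and rounds to the nearest voxel count; ShapeWorks may round or align the grid differently, and the helper names and example numbers are hypothetical.

```python
def dims_for_spacing(dims, spacing, new_spacing):
    """Voxel counts that keep (approximately) the same physical extent at new_spacing."""
    return [int(round(n * s / ns)) for n, s, ns in zip(dims, spacing, new_spacing)]

def spacing_for_dims(dims, spacing, new_dims):
    """Voxel spacing that keeps the same physical extent with new_dims voxels."""
    return [n * s / nn for n, s, nn in zip(dims, spacing, new_dims)]

# hypothetical image whose z-spacing is double the x- and y-spacing
dims, spacing = [120, 120, 60], [0.5, 0.5, 1.0]
iso = min(spacing)                                     # isotropic resampling (option 1)
print(dims_for_spacing(dims, spacing, [iso] * 3))      # [120, 120, 120]
print(spacing_for_dims(dims, spacing, [60, 60, 60]))   # [1.0, 1.0, 1.0] (option 3)
```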

In [None]:
# define the voxel spacing for resampling

# let's check the spacing of the original segmentation
spacing = shapeSegAntialiased.spacing()

print(spacing) # voxel spacing in the z-dimension is double that of x- and y-dimension

In [None]:
# let's define the isotropic spacing to be the minimum of the spacings of all dimensions
# this will improve the segmentation resolution along 
# the z-dimension (i.e., reduce slice thickness)
# #830 - add toArray to swVectors
spacing_array = [spacing[d] for d in range(3)]
isoSpacing    = min(spacing_array) # The isotropic spacing in all dimensions.

print('isoSpacing = ' + str(isoSpacing))

In [None]:
# to visualize the effect of resampling, let's keep the antialiased segmentation unchanged
shapeSegResampled = sw.Image(shapeSegAntialiased)

# define the interpolation type
interp = sw.InterpolationType.Linear 

# perform image resampling
#shapeSegResampled.resample(voxelSpacing, interp)
shapeSegResampled.resample([isoSpacing,isoSpacing,isoSpacing], interp)

# or using image utils
#sw.ImageUtils.isoresample(shapeSegResampled, isoSpacing)

print('\nHeader information before resampling: ')
print(shapeSegAntialiased)

print('\nHeader information after resampling: ')
print(shapeSegResampled)
  
# note that pyvista's volume plots do not reflect physical coordinates; they show the
# voxel grid (matrix), so the resampled image appears larger
plot_volumes([shapeSegAntialiased, shapeSegResampled], 
             ['before resampling','after resampling'],
             is_interactive = is_interactive, 
             show_borders   = show_borders,  
             shade_volumes  = shade_volumes,
             show_axes      = show_axes,  
             show_bounds    = show_bounds,
             show_all_edges = show_all_edges, 
             font_size      = font_size,   
             link_views     = link_views)

<p><video src="https://sci.utah.edu/~shapeworks/doc-resources/mp4s/nb-groom-resample.mp4" autoplay muted loop controls style="width:100%"></p>

In [None]:
# isosurfaces, on the other hand, are rendered in physical coordinates

# extract the resampled isosurface 
shapeSegResampledIso = shapeSegResampled.toMesh(isovalue = 1e-20) # Issue #833

plot_meshes([shapeSegIso, shapeSegResampledIso], 
             ['segmentation (red) and resampled (tan)'],
             use_same_window = True,
             is_interactive  = is_interactive, 
             show_borders    = show_borders,  
             meshes_color    = ['red', 'tan'], 
             opacities       = [0.5, 1],
             mesh_style      = mesh_style,
             show_mesh_edges = show_mesh_edges,
             show_axes       = show_axes,  
             show_bounds     = show_bounds,
             show_all_edges  = show_all_edges, 
             font_size       = font_size,   
             link_views      = link_views)


<p><video src="https://sci.utah.edu/~shapeworks/doc-resources/mp4s/nb-groom-resample-iso.mp4" autoplay muted loop controls style="width:100%"></p>

### Binarizing: which shapeworks API to use?

In [None]:
# let's see if there's a function that binarizes a shapeworks image
# use tab completion after the dot to list the functions/APIs available for shapeSeg

# found it - binarize, let's see its help
help(shapeSegResampled.binarize)

### Exploring binarization for a single segmentation

In [None]:
# parameters for binarization
# all voxels between minVal and maxVal are set to innerVal
minVal   = 0
maxVal   = 5
innerVal = 1
outerVal = 0

# to visualize the effect of binarization, let's keep the resampled segmentation unchanged
shapeSegResampledBin = sw.Image(shapeSegResampled)

shapeSegResampledBin.binarize(minVal, maxVal, innerVal, outerVal)
   
plot_volumes([shapeSegResampled, shapeSegResampledBin], 
             ['before binarization','after binarization'],
             is_interactive = is_interactive, 
             show_borders   = show_borders,  
             shade_volumes  = shade_volumes,  
             show_axes      = show_axes,  
             show_bounds    = show_bounds,
             show_all_edges = show_all_edges, 
             font_size      = font_size,   
             link_views     = link_views)

<p><video src="https://sci.utah.edu/~shapeworks/doc-resources/mp4s/nb-groom-bin.mp4" autoplay muted loop controls style="width:100%"></p>
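
As a reference for what the parameters mean, here is a NumPy equivalent of the behaviour described in the cell above: voxels between `minVal` and `maxVal` become `innerVal`, and all others become `outerVal`. Treating both bounds as inclusive is an assumption of this sketch, and `binarize` here is a hypothetical helper, not the `shapeworks` method.

```python
import numpy as np

def binarize(values, min_val, max_val, inner_val=1.0, outer_val=0.0):
    """Map values in [min_val, max_val] to inner_val and all others to outer_val
    (inclusive bounds are an assumption of this sketch)."""
    values = np.asarray(values, dtype=float)
    inside = (values >= min_val) & (values <= max_val)
    return np.where(inside, inner_val, outer_val)

# hypothetical antialiased voxel values (the notebook's range is about -4 to 4)
antialiased = np.array([-4.0, -1.2, -0.1, 0.0, 0.3, 2.5, 4.0])
print(binarize(antialiased, min_val=0, max_val=5, inner_val=1, outer_val=0))  # [0 0 0 1 1 1 1]
```

With `maxVal = 5`, every non-negative antialiased value counts as foreground, which is why the notebook picks an upper bound above the largest antialiased value.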

### Comparing the original segmentation and the resampled one

In [None]:
plot_meshes([shapeSegIso, shapeSegResampledBin.toMesh(0.5)], 
             ['segmentation (red) and resampled (tan)'],
             use_same_window = True,
             is_interactive  = is_interactive, 
             show_borders    = show_borders,  
             meshes_color    = ['red', 'tan'], 
             opacities       = [0.5, 1],
             mesh_style      = mesh_style,
             show_mesh_edges = show_mesh_edges,
             show_axes       = show_axes,  
             show_bounds     = show_bounds,
             show_all_edges  = show_all_edges, 
             font_size       = font_size,   
             link_views      = link_views)

<p><video src="https://sci.utah.edu/~shapeworks/doc-resources/mp4s/nb-groom-resample-final.mp4" autoplay muted loop controls style="width:100%"></p>

### Resampling all segmentations

In [None]:
# let's now resample all segmentations, here we will use in-place updates of the shapeworks images
# this step will bring all segmentations to the same voxel spacing

# parameters for resampling
iso_spacing           = 1  # use 1 for memory constraints on notebooks
antialias_iterations  = 50

# all voxels between minVal and maxVal are set to innerVal
minVal   = 0
maxVal   = 5
innerVal = 1
outerVal = 0

for shapeSeg, shapeName in zip(shapeSegList, shapeNames):
    print('Resampling segmentation: ' + shapeName)
    
    # we chain antialiasing, then resampling, then binarization
        
    # antialiasing binary segmentation
    shapeSeg.antialias(antialias_iterations)
    
    # perform image resampling
    shapeSeg.resample([iso_spacing, iso_spacing, iso_spacing], sw.InterpolationType.Linear)
    #sw.ImageUtils.isoresample(shapeSeg, iso_spacing)

    # binarize resampled image
    # all voxels between minVal and maxVal are set to innerVal
    shapeSeg.binarize(minVal, maxVal, innerVal, outerVal)


In [None]:
shapeSeg = shapeSegList[10]
itkw.view( image          = sw2vtkImage(shapeSeg),  
           slicing_planes = True,
           axes           = True,
           rotate         = True, 
           interpolation  = True)

<p><video src="https://sci.utah.edu/~shapeworks/doc-resources/mp4s/nb-groom-resample-final-itkw.mp4" autoplay muted loop controls style="width:100%"></p>

### Visualizing resampled segmentations

In [None]:
# plot all segmentations in the shape list as surfaces
shapeSegIsoList = []
for shapeSeg, shapeName in zip(shapeSegList, shapeNames):
    print('Isosurfacing segmentation: ' + shapeName)
    shapeSegIsoList.append(shapeSeg.toMesh(1e-10))

In [None]:
plot_meshes(shapeSegIsoList, 
            shapeNames,
            use_same_window = False,
            is_interactive  = is_interactive, 
            show_borders    = show_borders,  
            meshes_color    = meshes_color, 
            mesh_style      = mesh_style,
            show_mesh_edges = show_mesh_edges,
            show_axes       = show_axes,  
            show_bounds     = False,
            show_all_edges  = show_all_edges, 
            font_size       = font_size,   
            link_views      = False) # try out True to link views


<p><video src="https://sci.utah.edu/~shapeworks/doc-resources/mp4s/nb-groom-resample-final-pv.mp4" autoplay muted loop controls style="width:100%"></p>

### Saving resampled segmentations

In [None]:
save_images(outDir + 'resampled/',     
            shapeSegList,  
            shapeNames,  
            extension        = 'nrrd',
            compressed       = False, # use false to load in paraview 
            verbose          = True)

## 2. Aligning shapes  

Rigidly aligning a cohort of shapes entails removing differences across these shapes pertaining to global transformations, i.e., translation and rotation. This step requires a reference coordinate frame to align all shapes to, where one of the shapes can be selected as a reference. 

Rigid alignment (aka registration) is an optimization process that might get stuck in a poor local minimum if shapes are significantly out of alignment. To bring shapes closer, we can first remove translation differences using center-of-mass alignment. This reduces the risk of misalignment and allows a medoid sample to be automatically selected as the reference for the subsequent rigid alignment.

Hence, the shapes alignment pipeline includes the following steps:
- Center-of-mass alignment
- Reference shape selection
- Rigid alignment



### Center-of-mass alignment

This step takes in a binary volume and translates the center of mass of the shape to the center of the 3D volume space.
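
The computation behind this step can be sketched in NumPy for an axis-aligned grid, assuming physical position = origin + index * spacing (this ignores any image direction matrix). Taking the image center as the middle of the voxel grid, index (dims - 1) / 2, is also an assumption; `shapeSeg.center()` may define it slightly differently. The helper names are hypothetical.

```python
import numpy as np

def center_of_mass(volume, origin, spacing):
    """Physical center of mass of the foreground (non-zero) voxels of a binary (x, y, z) array."""
    indices = np.argwhere(volume > 0)
    return np.asarray(origin, dtype=float) + indices.mean(axis=0) * np.asarray(spacing, dtype=float)

def image_center(dims, origin, spacing):
    """Physical coordinates of the middle of the voxel grid (index (dims - 1) / 2, an assumption)."""
    return np.asarray(origin, dtype=float) + (np.asarray(dims) - 1) / 2.0 * np.asarray(spacing, dtype=float)

volume = np.zeros((5, 5, 5))
volume[0:3, 1, 1] = 1                  # foreground voxels at x = 0, 1, 2 (y = 1, z = 1)
origin, spacing = (0.0, 0.0, 0.0), (1.0, 1.0, 1.0)

com = center_of_mass(volume, origin, spacing)
center = image_center(volume.shape, origin, spacing)
translation = center - com             # moves the shape's center of mass to the image center
print(com, center, translation)        # [1. 1. 1.] [2. 2. 2.] [1. 1. 1.]

# this translation is a whole number of voxels, so it needs no interpolation;
# np.roll wraps around the border, which is harmless here because the shape stays inside the grid
shift = tuple(int(s) for s in np.round(translation / np.asarray(spacing)))
moved = np.roll(volume, shift=shift, axis=(0, 1, 2))
print(center_of_mass(moved, origin, spacing))  # [2. 2. 2.]
```

A sub-voxel translation would require interpolation, which is why the notebook antialiases before translating and binarizes afterwards.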

#### Exploring center-of-mass alignment for a single segmentation

In [None]:
shapeIdx = 3
shapeSeg = sw.Image(shapeSegList[shapeIdx])

# to compute the translation vector for this registration step, we need
# the image's center and the shape's center of mass.
# the shapeworks Image class has APIs that provide both (each returns a shapeworks point)
help(shapeSeg.centerOfMass)
help(shapeSeg.center)

itkw.view( image          = sw2vtkImage(shapeSeg),  
           slicing_planes = True,
           axes           = True,
           rotate         = True, 
           interpolation  = True)

<p><video src="https://sci.utah.edu/~shapeworks/doc-resources/mp4s/nb-groom-com-itkw.mp4" autoplay muted loop controls style="width:100%"></p>

In [None]:
# compute the center of mass of this segmentation
shapeCenter = shapeSeg.centerOfMass() # returns shapeworks point
print('shapeCenter: ' + str(shapeCenter))

# get the center of the image domain - physical coordinates of the "image center" (middle voxel) 
imageCenter = shapeSeg.center() # returns shapeworks point
print('imageCenter: ' + str(imageCenter))

# image origin - physical coordinates of voxel(0,0,0)
imageOrigin = shapeSeg.origin() # returns shapeworks point
print('imageOrigin: ' + str(imageOrigin))

# end point of the image domain diagonal: the origin plus the physical size of the image
imageEnd = imageOrigin + shapeSeg.size()
print('imageEnd: ' + str(imageEnd))

In [None]:
# let's visualize those centers in relation to the image domain and the shape/segmentation
# to better understand this relation, we render the segmentation surface and these points (as spheres) in one view
shapeCenter_vtk = pv.Sphere(radius = 2, 
                            center = (shapeCenter[0], shapeCenter[1], shapeCenter[2]))
imageCenter_vtk = pv.Sphere(radius = 2, 
                            center = (imageCenter[0], imageCenter[1], imageCenter[2]))

origin_vtk = pv.Sphere(radius = 2, 
                       center = (imageOrigin[0], imageOrigin[1], imageOrigin[2]))

end_vtk = pv.Sphere(radius = 2, 
                    center = (imageEnd[0], imageEnd[1], imageEnd[2]))


plot_meshes([shapeSeg.toMesh(0.5), shapeCenter_vtk, imageCenter_vtk, origin_vtk, end_vtk], 
             ['seg (tan), shape center (red), image center (green), image origin (blue), image end (cyan)'],
             use_same_window = True,
             is_interactive  = is_interactive, 
             show_borders    = show_borders,  
             meshes_color    = ['tan', 'red', 'green', 'blue', 'cyan'], 
             opacities       = [0.5, 1, 1, 1, 1],
             mesh_style      = mesh_style,
             show_mesh_edges = show_mesh_edges,
             show_axes       = show_axes,  
             show_bounds     = show_bounds,
             show_all_edges  = show_all_edges, 
             font_size       = font_size,   
             link_views      = link_views)

<p><video src="https://sci.utah.edu/~shapeworks/doc-resources/mp4s/nb-groom-com-pv.mp4" autoplay muted loop controls style="width:100%"></p>

In [None]:
# now define the translation to move the shape center to the image center
translationVector =  imageCenter - shapeCenter
print('translationVector: ' + str(translationVector))

Applying this translation to segmentations entails interpolation due to image resampling in the new coordinate frame. Directly resampling binary segmentations will not result in a binary segmentation, but rather an interpolated version that does not have two distinct labels (i.e., foreground and background).

To mitigate this behavior, similar to the resampling workflow, we will first antialias the segmentation to convert it to a continuous-valued image with a smooth foreground-background interface, then apply the translation, and finally binarize the translated image.

In [None]:
# let's first inspect the api for image translation
help(shapeSeg.translate)

In [None]:
# let's keep shapeSeg unchanged
shapeSegCentered = sw.Image(shapeSeg)

# here we will perform antialias-translate-binarize by chaining the apis
shapeSegCentered.antialias(antialias_iterations).translate(translationVector).binarize()

# notice no change in header information
print('\nHeader information before center of mass alignment: ')
print(shapeSeg)

print('\nHeader information after center of mass alignment: ')
print(shapeSegCentered)

In [None]:
# let's visualize the segmentation before and after center of mass alignment
# notice how the shape is moved (i.e., translated) to the center of the image

print('shape segmentation before center of mass alignment')
display(itkw.view( image          = sw2vtkImage(shapeSeg), # for orthogonal image plane
                   label_image    = sw2vtkImage(shapeSeg),  # for volume rendering segmentation
                   slicing_planes = True, 
                   axes           = True,
                   rotate         = True, # enable auto rotation
                   interpolation  = True)
       )

print('shape segmentation after center of mass alignment')
display(itkw.view( image          = sw2vtkImage(shapeSegCentered), # for orthogonal image plane
                   label_image    = sw2vtkImage(shapeSegCentered),  # for volume rendering segmentation
                   slicing_planes = True, 
                   axes           = True,
                   rotate         = True, # enable auto rotation
                   interpolation  = True)
       )

<p><video src="https://sci.utah.edu/~shapeworks/doc-resources/mp4s/nb-groom-com-res.mp4" autoplay muted loop controls style="width:100%"></p>

#### Center-of-mass aligning all shapes

In [None]:
# let's apply center of mass alignment to all shapes

for shapeSeg, shapeName in zip(shapeSegList, shapeNames):
    print('Center of mass alignment: ' + shapeName)    
    
    # compute the center of mass of this segmentation
    shapeCenter = shapeSeg.centerOfMass() # returns shapeworks point
    #print('\tshapeCenter: ' + str(shapeCenter))

    # get the center of the image domain
    imageCenter = shapeSeg.center() # returns shapeworks point
    #print('\timageCenter: ' + str(imageCenter))

    # now define the translation to move the shape center to the image center
    translationVector =  imageCenter - shapeCenter
    #print('\ttranslationVector: ' + str(translationVector))
    
    # perform antialias-translate-binarize by chaining the apis
    shapeSeg.antialias(antialias_iterations).translate(translationVector).binarize()

#### Visualizing center-of-mass aligned segmentations

In [None]:
# .. and visualize center of mass aligned dataset
plot_volumes(shapeSegList,    
             volumeNames     = shapeNames, 
             is_interactive = is_interactive, 
             show_borders   = show_borders,  
             shade_volumes  = shade_volumes,  
             show_axes      = show_axes,  
             show_bounds    = show_bounds,
             show_all_edges = show_all_edges, 
             font_size      = font_size,   
             link_views     = True ) #link_views)

<p><video src="https://sci.utah.edu/~shapeworks/doc-resources/mp4s/nb-groom-com-res-pv.mp4" autoplay muted loop controls style="width:100%"></p>

#### Saving center-of-mass aligned segmentations

In [None]:
save_images(outDir + 'com/',     
            shapeSegList,  
            shapeNames,  
            extension        = 'nrrd',
            compressed       = False, # use false to load in paraview 
            verbose          = True)

### Reference shape selection

One option for a reference is to select the shape that is closest to the mean shape, i.e., the medoid shape. Note that computing the mean shape and the distances to it requires all segmentations to have the same image dimensions.

Selecting the medoid shape entails the following steps:

- Bring all segmentations to the same image size if needed 
- Compute mean shape
- Compute distance to mean shape
- Select the shape sample that is closest to the mean shape
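
These steps can be condensed into one hypothetical helper. The sketch below pads every array to the largest size while keeping it centered (as the notebook does next), averages the padded arrays, and returns the index of the array with the smallest Euclidean (Frobenius) distance to the mean. Toy 1D arrays keep the example readable; the same code works on 3D NumPy arrays.

```python
import numpy as np

def pad_to_shape_centered(array, target_shape, pad_value=0):
    """Pad `array` to `target_shape`, keeping it centered (an odd extra voxel goes after)."""
    pad_width = []
    for size, target in zip(array.shape, target_shape):
        before = (target - size) // 2
        pad_width.append((before, target - size - before))
    return np.pad(array, pad_width, mode="constant", constant_values=pad_value)

def select_closest_to_mean(volumes):
    """Return (index, distances) of the volume closest to the mean of all padded volumes."""
    target = tuple(max(v.shape[d] for v in volumes) for d in range(volumes[0].ndim))
    padded = [pad_to_shape_centered(np.asarray(v, dtype=float), target) for v in volumes]
    mean_shape = np.mean(padded, axis=0)
    distances = np.array([np.linalg.norm(mean_shape - p) for p in padded])
    return int(distances.argmin()), distances

# toy 1D "segmentations"; the last one is smaller and gets padded to [0, 0, 1, 1, 0]
volumes = [np.array([1, 1, 0, 0, 0]),
           np.array([0, 1, 1, 0, 0]),
           np.array([0, 1, 1])]
reference_idx, distances = select_closest_to_mean(volumes)
print(reference_idx)  # 1
```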

#### Bring all segmentations to the same image size if needed 

In [None]:
# compute maximum image size in each dimension
x_dims, y_dims, z_dims = [], [], []
for shapeSeg in shapeSegList:
    dims = shapeSeg.dims()
    x_dims.append(dims[0])
    y_dims.append(dims[1])
    z_dims.append(dims[2])

print('x_dims: ' + str(x_dims))
print('y_dims: ' + str(y_dims))
print('z_dims: ' + str(z_dims))

# define the maximum image size in the given dataset - notice different images have different sizes
x_max_dim = max(x_dims)
y_max_dim = max(y_dims)
z_max_dim = max(z_dims) 

# this is the common size that we need to pad individual segmentations 
# to compute the mean shape and distance to the mean shape
print('x_max_dim = ' + str(x_max_dim))
print('y_max_dim = ' + str(y_max_dim))
print('z_max_dim = ' + str(z_max_dim))

#### Compute mean shape

In [None]:
# now let's compute the mean shape
meanShape = np.zeros((x_max_dim, y_max_dim, z_max_dim))
padValue  = 0

# let's keep a list for the padded ones to compute distances to the mean shape next
paddedList = [] 
for shapeSeg in shapeSegList:
    dims   = shapeSeg.dims()

    # convert to numpy array
    shapeSeg_array = shapeSeg.toArray()
    
    # the numpy array needs to be permuted to match the shapeworks image dimensions
    shapeSeg_array = np.transpose(shapeSeg_array,(2,1,0))
    
    # now let's pad and still maintain the shape in the center of the image domain
    x_diff   = x_max_dim - dims[0]
    before_x = x_diff // 2
    after_x  = x_max_dim - (dims[0] + before_x)
    
    y_diff   = y_max_dim - dims[1]
    before_y = y_diff // 2
    after_y  = y_max_dim - (dims[1] + before_y)
    
    z_diff   = z_max_dim - dims[2]
    before_z = z_diff // 2
    after_z  = z_max_dim - (dims[2] + before_z)
    
    paddedList.append( np.pad(shapeSeg_array, ( (before_x, after_x), 
                                                (before_y, after_y), 
                                                (before_z, after_z)
                                              ),
                              mode = 'constant', constant_values = padValue
                             )
                     )
    meanShape += paddedList[-1]
meanShape /= len(shapeSegList)


In [None]:
# ... and visualize the mean shape
meanShape_vtk = pv.wrap(meanShape)

print('mean shape')
itkw.view( #image          = meanShape_vtk, # for orthogonal image plane
           label_image    = meanShape_vtk,  # for volume rendering segmentation
           slicing_planes = True, 
           axes           = True,
           rotate         = True, # enable auto rotation
           interpolation  = True
         )

<p><video src="https://sci.utah.edu/~shapeworks/doc-resources/mp4s/nb-groom-mean-b4-align.mp4" autoplay muted loop controls style="width:100%"></p>

#### Compute distance to mean shape

In [None]:
distances = np.zeros((num_samples,))
for ii, segPadded in enumerate(paddedList):
    distances[ii] = np.linalg.norm(meanShape - segPadded)

#### Select the shape sample that is closest to the mean shape

In [None]:
import matplotlib.pyplot as plt

plt.figure(dpi=150)
plt.bar(range(num_samples), distances)
plt.xlabel('shape #')
plt.ylabel('distance to mean shape')

referenceIdx = distances.argmin()
plt.bar(referenceIdx, distances[referenceIdx]);

<p><img src="https://sci.utah.edu/~shapeworks/doc-resources/mp4s/nb-groom-ref.png"></p>

### Rigid alignment

This step rigidly aligns each shape to the selected reference. Rigid alignment involves interpolation, hence we need to convert binary segmentations to continuous-valued images.

#### Exploring alignment for a single segmentation

Aligning to a reference shape entails two steps:
- computing the rigid transformation parameters that would align a segmentation to the reference shape
- applying the rigid transformation to the segmentation
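
For intuition about the first step, here is a generic point-to-point ICP sketch on synthetic point clouds: each iteration matches every moving point to its nearest reference point and then solves for the best rotation and translation with the SVD-based Kabsch method. This is not the ShapeWorks implementation; according to the API description in the next cell, `createRigidRegistrationTransform` works on the images at a given isovalue with a given number of ICP iterations, and its internals may differ. All names below are hypothetical.

```python
import numpy as np

def best_rigid_transform(src, dst):
    """Least-squares R, t with R @ src[i] + t ~ dst[i] for paired points (Kabsch/SVD)."""
    src_mean, dst_mean = src.mean(axis=0), dst.mean(axis=0)
    H = (src - src_mean).T @ (dst - dst_mean)        # 3x3 cross-covariance
    U, _, Vt = np.linalg.svd(H)
    d = np.sign(np.linalg.det(Vt.T @ U.T))           # avoid returning a reflection
    R = Vt.T @ np.diag([1.0, 1.0, d]) @ U.T
    return R, dst_mean - R @ src_mean

def nearest_neighbours(points, cloud):
    """Brute-force nearest point in `cloud` for every row of `points`, plus squared distances."""
    d2 = ((points[:, None, :] - cloud[None, :, :]) ** 2).sum(axis=2)
    return cloud[d2.argmin(axis=1)], d2.min(axis=1)

def icp(moving, reference, iterations=30):
    """Point-to-point ICP: alternate nearest-neighbour matching and a Kabsch update."""
    R_total, t_total = np.eye(3), np.zeros(3)
    current = moving.copy()
    for _ in range(iterations):
        matches, _ = nearest_neighbours(current, reference)
        R, t = best_rigid_transform(current, matches)
        current = current @ R.T + t
        R_total, t_total = R @ R_total, R @ t_total + t  # compose with previous steps
    return R_total, t_total

def mean_sq_nn_distance(points, cloud):
    return nearest_neighbours(points, cloud)[1].mean()

# synthetic data: an anisotropic cloud and a rotated (8 degrees about z) + translated copy
rng = np.random.default_rng(0)
reference = rng.normal(size=(200, 3)) * np.array([3.0, 2.0, 1.0])
theta = np.deg2rad(8.0)
Rz = np.array([[np.cos(theta), -np.sin(theta), 0.0],
               [np.sin(theta),  np.cos(theta), 0.0],
               [0.0,            0.0,           1.0]])
moving = reference @ Rz.T + np.array([0.5, -0.3, 0.2])

R, t = icp(moving, reference)
aligned = moving @ R.T + t
print(mean_sq_nn_distance(moving, reference), mean_sq_nn_distance(aligned, reference))
```

Brute-force matching costs O(N * M) time and memory; for real surfaces a spatial index such as a k-d tree is the usual choice.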

In [None]:
# this API takes in 
#      (1) source image  
#      (2) target/reference image 
#      (3) isovalue that defines the shape's surface in the given images
#      (4) number of iterations for ICP alignment
help(sw.ImageUtils.createRigidRegistrationTransform)

In [None]:
# first step is to compute rigid transformation parameters

shapeIdx = 5 # try other indices too, e.g., 0, 13, or referenceIdx
shapeSeg = sw.Image(shapeSegList[shapeIdx])     # need copy to avoid inplace antialiasing
refSeg   = sw.Image(shapeSegList[referenceIdx]) # need copy to avoid inplace antialiasing 

# notice the differences in origin, dims and size
print('\nHeader information before rigid alignment: ')
print(shapeSeg)

print('\nHeader information of reference: ')
print(refSeg)

In [None]:
# here we will antialias segmentations to convert to continuous-valued images
isoValue       = 1e-20
icp_iterations = 200

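# compute rigid transformation
# antialias copies so that shapeSeg and refSeg stay binary for the steps below
shapeSegAA = sw.Image(shapeSeg)
shapeSegAA.antialias(antialias_iterations)
refSegAA = sw.Image(refSeg)
refSegAA.antialias(antialias_iterations)

rigidTransform = sw.ImageUtils.createRigidRegistrationTransform(shapeSegAA, 
                                                                refSegAA, 
                                                                isoValue,
                                                                icp_iterations
                                                               )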

print(rigidTransform)

In [None]:
# second, we apply the computed transformation
# the transformed segmentation is resampled onto the reference image grid (origin, dims,
# spacing, coordsys) so that it does not go out of the image bounds and shares the
# reference's image information for subsequent steps
shapeSegAligned = sw.Image(shapeSeg)
shapeSegAligned.antialias(antialias_iterations)
shapeSegAligned.applyTransform(rigidTransform, 
                               refSeg.origin(),  refSeg.dims(), 
                               refSeg.spacing(), refSeg.coordsys(), 
                               sw.InterpolationType.Linear)
shapeSegAligned.binarize()

# notice the change in header information
print('\nHeader information before rigid alignment: ')
print(shapeSeg)

print('\nHeader information after rigid alignment: ')
print(shapeSegAligned)

print('\nHeader information of reference: ')
print(refSeg)
                               

In [None]:
plot_meshes([shapeSeg.toMesh(0.5), shapeSegAligned.toMesh(0.5), refSeg.toMesh(0.5)], 
             ['seg (tan), aligned (red), reference (green)'],
             use_same_window = True,
             is_interactive  = is_interactive, 
             show_borders    = show_borders,  
             meshes_color    = ['tan', 'red', 'green'], 
             opacities       = [0.5, 0.5, 0.5],
             mesh_style      = mesh_style,
             show_mesh_edges = show_mesh_edges,
             show_axes       = show_axes,  
             show_bounds     = show_bounds,
             show_all_edges  = show_all_edges, 
             font_size       = font_size,   
             link_views      = link_views)

<p><video src="https://sci.utah.edu/~shapeworks/doc-resources/mp4s/nb-groom-align.mp4" autoplay muted loop controls style="width:100%"></p>

#### Aligning all segmentations to the reference shape



In [None]:
# get a copy of the reference segmentation to avoid in-place antialiasing 
refSeg   = sw.Image(shapeSegList[referenceIdx])

# antialias reference segmentation once
refSeg.antialias(antialias_iterations)

# alignment parameters
isoValue       = 1e-20
icp_iterations = 200

for shapeSeg, shapeName in zip(shapeSegList, shapeNames):
    print('Aligning ' + shapeName + ' to ' + shapeNames[referenceIdx]) 

    # compute rigid transformation
    rigidTransform = sw.ImageUtils.createRigidRegistrationTransform(shapeSeg.antialias(antialias_iterations), 
                                                                    refSeg,
                                                                    isoValue,
                                                                    icp_iterations
                                                                   )
    
    # second we apply the computed transformation, note that shapeSeg has 
    # already been antialiased, so we can directly apply the transformation 
    shapeSeg.applyTransform(rigidTransform, 
                            refSeg.origin(),  refSeg.dims(), 
                            refSeg.spacing(), refSeg.coordsys(), 
                            sw.InterpolationType.Linear)
    
    # then turn the antialiased, transformed segmentation into a binary segmentation
    shapeSeg.binarize()

#### Visualizing aligned shapes

In [None]:
# .. and visualize  aligned dataset
plot_volumes(shapeSegList,    
             volumeNames     = shapeNames, 
             is_interactive = is_interactive, 
             show_borders   = show_borders,  
             shade_volumes  = shade_volumes,  
             show_axes      = show_axes,  
             show_bounds    = show_bounds,
             show_all_edges = show_all_edges, 
             font_size      = font_size,   
             link_views     = link_views)

<p><video src="https://sci.utah.edu/~shapeworks/doc-resources/mp4s/nb-groom-align-pv.mp4" autoplay muted loop controls style="width:100%"></p>

#### Comparing mean shapes before & after alignment

In [None]:
# let's compute the mean shape - note that all images have the same reference
# dimensions so no need to worry about different size images and no need for padding

dims      = refSeg.dims()
meanShapeAfterAlignment = sw.Image(np.zeros((dims[2], dims[1], dims[0]))) # note the flipped dims

for shapeSeg, shapeName in zip(shapeSegList, shapeNames):
    meanShapeAfterAlignment += shapeSeg
    
meanShapeAfterAlignment /= num_samples

In [None]:
# let's visualize the mean shape (segmentation) before and after alignment
# notice how the mean shape after alignment reflects a better ellipsoidal shape
meanShapeAfterAlignment_vtk = sw2vtkImage(meanShapeAfterAlignment)

print('mean shape before alignment')
display(
        itkw.view( #image          = meanShape_vtk, # for orthogonal image plane
                   label_image    = meanShape_vtk,  # for volume rendering segmentation
                   slicing_planes = True, 
                   axes           = True,
                   rotate         = True, # enable auto rotation
                   interpolation  = True
                 )
        )

print('mean shape after alignment')
display(
        itkw.view( #image          = meanShapeAfterAlignment_vtk, # for orthogonal image plane
                   label_image    = meanShapeAfterAlignment_vtk,  # for volume rendering segmentation
                   slicing_planes = True, 
                   axes           = True,
                   rotate         = True, # enable auto rotation
                   interpolation  = True
                 )
        )

<p><video src="https://sci.utah.edu/~shapeworks/doc-resources/mp4s/nb-groom-mean-align.mp4" autoplay muted loop controls style="width:100%"></p>

#### Saving aligned segmentations

In [None]:
save_images(outDir + 'aligned/',     
            shapeSegList,  
            shapeNames,  
            extension        = 'nrrd',
            compressed       = False, # use false to load in paraview 
            verbose          = True)

## 3. Cropping and padding segmentations   


As you can observe from the mean shape (after alignment), the image boundaries are not tight around the shapes. This leaves too many irrelevant background voxels, which might increase the memory footprint when optimizing the shape model. Hence, let's learn how we can remove this irrelevant background while keeping our segmentations intact and preventing the cropped segmentations from touching the image boundaries.

This step entails:
- Finding the largest bounding box in which all segmentations are inscribed
- Cropping the segmentations using the computed bounding box
- Padding cropped segmentations
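
The three steps can be sketched in NumPy in voxel-index space, assuming all segmentations already share one grid (the case after the rigid alignment above, since every shape was resampled onto the reference grid). The helper names are hypothetical, and `sw.ImageUtils.boundingBox` may return its region in a different form.

```python
import numpy as np

def union_bounding_box(volumes, isovalue=0.5):
    """Index-space box [lo, hi) enclosing every voxel above `isovalue` in all volumes.
    Assumes all volumes share one voxel grid and each has some foreground."""
    lo = np.full(volumes[0].ndim, np.iinfo(np.int64).max, dtype=np.int64)
    hi = np.full(volumes[0].ndim, -1, dtype=np.int64)
    for vol in volumes:
        idx = np.argwhere(vol > isovalue)
        lo = np.minimum(lo, idx.min(axis=0))
        hi = np.maximum(hi, idx.max(axis=0))
    return lo, hi + 1          # hi is exclusive, ready for slicing

def crop_and_pad(volume, lo, hi, padding, pad_value=0):
    """Crop to [lo, hi) and add `padding` voxels of `pad_value` on every side."""
    cropped = volume[tuple(slice(start, stop) for start, stop in zip(lo, hi))]
    return np.pad(cropped, padding, mode="constant", constant_values=pad_value)

a = np.zeros((10, 10, 10)); a[2:4, 3:5, 4:6] = 1
b = np.zeros((10, 10, 10)); b[5:7, 3:6, 1:3] = 1
lo, hi = union_bounding_box([a, b])
print(lo, hi)                                    # [2 3 1] [7 6 6]
print(crop_and_pad(a, lo, hi, padding=2).shape)  # (9, 7, 9)
```

Using one box for all shapes keeps every cropped segmentation the same size, and the padding leaves room around each shape along every dimension.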

In [None]:
itkw.view( image          = meanShapeAfterAlignment_vtk, # for orthogonal image plane
           #label_image    = meanShapeAfterAlignment_vtk,  # for volume rendering segmentation
           slicing_planes = True, 
           axes           = True,
           rotate         = True, # enable auto rotation
           interpolation  = True
         )

### Finding the largest bounding box

Because finding the bounding box must take all the segmentations in our dataset into account, the `ImageUtils` library in `shapeworks` is the right place to look for the API.
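To make the idea concrete, here is a NumPy-only sketch of what a dataset-wide bounding box computes: the smallest box, in voxel indices, that contains every voxel above the isovalue in *any* segmentation. It assumes all segmentations are NumPy arrays on one shared voxel grid. `union_bounding_box` is a hypothetical helper for illustration and is not part of `shapeworks`; `sw.ImageUtils.boundingBox` operates on `sw.Image` objects instead.

```python
import numpy as np


def union_bounding_box(arrays, iso_value=0.5):
    """Return (lower, upper) inclusive voxel indices of the box containing
    every voxel above iso_value in any of the input arrays.

    Hypothetical helper for illustration; assumes all arrays share one grid.
    """
    lower, upper = None, None
    for arr in arrays:
        idx = np.argwhere(arr > iso_value)  # (num_voxels, ndim) foreground indices
        if idx.size == 0:
            continue  # skip empty segmentations
        lo, hi = idx.min(axis=0), idx.max(axis=0)
        lower = lo if lower is None else np.minimum(lower, lo)
        upper = hi if upper is None else np.maximum(upper, hi)
    if lower is None:
        raise ValueError("no voxel exceeds the isovalue in any array")
    return lower, upper


# two toy binary "segmentations" on a shared 10x10x10 grid
a = np.zeros((10, 10, 10)); a[2:4, 3:5, 4:6] = 1
b = np.zeros((10, 10, 10)); b[5:8, 1:3, 4:9] = 1
lo, hi = union_bounding_box([a, b], 0.5)
print(lo, hi)  # → [2 1 4] [7 4 8]
```

Each face of the resulting box is set by whichever segmentation extends farthest in that direction, so no single shape needs to span the whole box.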

In [None]:
# finding the bounding box takes in a list of images and the isovalue that 
# defines the level set of shapes implicitly defined in these images 
help(sw.ImageUtils.boundingBox)

In [None]:
# note that the aligned segmentations are still binary images, so a good isovalue 
# that reflects the foreground-background interface would be 0.5
isoValue = 0.5
segsBoundingBox = sw.ImageUtils.boundingBox(shapeSegList, isoValue)

print('Computed bounding box:')
print(segsBoundingBox)
print(type(segsBoundingBox))


### Cropping segmentations

Cropping is an operation applied to a single segmentation, so let's look for the API in the `Image` class.
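In array terms, cropping to a box amounts to slicing. The NumPy sketch below assumes the inclusive `(lower, upper)` voxel-index convention of a hypothetical `crop_to_box` helper. In the notebook, `shapeSeg.crop` instead takes the region returned by `sw.ImageUtils.boundingBox`.

```python
import numpy as np


def crop_to_box(arr, lower, upper):
    """Return a copy of arr restricted to the inclusive voxel box [lower, upper].

    Hypothetical helper: indices are inclusive, so 1 is added to each upper bound.
    """
    slices = tuple(slice(int(lo), int(hi) + 1) for lo, hi in zip(lower, upper))
    return arr[slices].copy()  # copy so later edits don't touch the original


seg = np.zeros((10, 10, 10))
seg[2:4, 3:5, 4:6] = 1
cropped = crop_to_box(seg, [2, 1, 4], [7, 4, 8])
print(cropped.shape)               # → (6, 4, 5)
print(cropped.sum() == seg.sum())  # → True (no foreground voxel lost)
```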

#### Exploring cropping for a single segmentation

In [None]:
# let's try this out first on a single segmentation
shapeIdx = 0

# keep a copy to avoid an in-place update for now
shapeSeg = sw.Image(shapeSegList[shapeIdx]) 

# dot-tab magic (type shapeSeg. and press Tab to list the available methods)
help(shapeSeg.crop)

In [None]:
# to see the impact of cropping, let's keep a copy
shapeSegCropped = sw.Image(shapeSeg)

shapeSegCropped.crop(segsBoundingBox)

print('Header information before cropping')
print(shapeSeg)

print('Header information after cropping')
print(shapeSegCropped)

In [None]:
# let's visualize them
plot_volumes([shapeSeg, shapeSegCropped], 
             ['before cropping','after cropping'],
             is_interactive = is_interactive, 
             show_borders   = show_borders,  
             shade_volumes  = shade_volumes,  
             show_axes      = show_axes,  
             show_bounds    = show_bounds,
             show_all_edges = show_all_edges, 
             font_size      = font_size,   
             link_views     = link_views)


<p><video src="https://sci.utah.edu/~shapeworks/doc-resources/mp4s/nb-groom-crop.mp4" autoplay muted loop controls style="width:100%"></p>

#### Inspecting how tight the cropping is

Note that this cropping is tight, with no extra padding. To see this, let's find the segmentation with the largest volume and inspect how cropping affects it.
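A quick way to confirm that a crop is tight is to check whether any foreground voxel lies on the outermost layer of the array. The following is a NumPy-only sketch; `touches_boundary` is a hypothetical helper, not a `shapeworks` function.

```python
import numpy as np


def touches_boundary(arr, iso_value=0.5):
    """Return True if any voxel above iso_value lies on a face of the array."""
    fg = arr > iso_value
    for axis in range(fg.ndim):
        first = np.take(fg, 0, axis=axis)   # first slice along this axis
        last = np.take(fg, -1, axis=axis)   # last slice along this axis
        if first.any() or last.any():
            return True
    return False


seg = np.zeros((10, 10, 10))
seg[2:8, 3:6, 4:9] = 1
lower = np.argwhere(seg > 0.5).min(axis=0)
upper = np.argwhere(seg > 0.5).max(axis=0)
tight = seg[tuple(slice(lo, hi + 1) for lo, hi in zip(lower, upper))]
print(touches_boundary(seg), touches_boundary(tight))  # → False True
```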

In [None]:
# compute the volume (number of voxels with value 1) of each segmentation
volumes = np.zeros((num_samples,))
for ii, shapeSeg in enumerate(shapeSegList):
    volumes[ii] = np.sum(shapeSeg.toArray())
    
plt.figure(dpi=150)
plt.bar(range(num_samples), volumes)
plt.xlabel('shape #')
plt.ylabel('volume')

largestIdx = volumes.argmax()
plt.bar(largestIdx, volumes[largestIdx]);    

<p><img src="https://sci.utah.edu/~shapeworks/doc-resources/mp4s/nb-groom-vols.png"></p>

In [None]:
# let's inspect the cropping of the segmentation with the largest volume
# keep a copy to avoid an in-place update for now
shapeSeg = sw.Image(shapeSegList[largestIdx]) 

# keep a copy to see the impact
shapeSegCropped = sw.Image(shapeSeg)

# apply cropping
shapeSegCropped.crop(segsBoundingBox)


shapeSeg_vtk        = sw2vtkImage(shapeSeg)
shapeSegCropped_vtk = sw2vtkImage(shapeSegCropped)

# to see how the segmentation relates to the image boundaries before and after cropping
# let's visualize the volume with orthogonal image slices
# note how tight the cropping is
print('before cropping')
display(
        itkw.view( image          = shapeSeg_vtk, # for orthogonal image plane
                   label_image    = shapeSeg_vtk,  # for volume rendering segmentation
                   slicing_planes = True, 
                   axes           = True,
                   rotate         = True, # enable auto rotation
                   interpolation  = True
                 )
        )

print('after cropping')
display(
        itkw.view( image          = shapeSegCropped_vtk, # for orthogonal image plane
                   label_image    = shapeSegCropped_vtk,  # for volume rendering segmentation
                   slicing_planes = True, 
                   axes           = True,
                   rotate         = True, # enable auto rotation
                   interpolation  = True
                 )
        )

<p><video src="https://sci.utah.edu/~shapeworks/doc-resources/mp4s/nb-groom-crop-itkw.mp4" autoplay muted loop controls style="width:100%"></p>

Hence, to prevent the cropped segmentations from touching the image boundaries, we will crop and then pad them.

### Padding segmentations

Given a segmentation volume, we would like to pad the volume in the three dimensions with a constant value. In this case, we will use zero padding, which will add more background voxels to all dimensions.
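Constant zero padding maps directly onto NumPy's `np.pad`. The sketch below shows that padding `p` voxels on both sides of each axis grows every dimension by `2 * p` and adds only background. Whether `sw.Image.pad` pads both sides in the same way is an assumption here; confirm it with `help(shapeSeg.pad)`.

```python
import numpy as np

padding_size = 10   # voxels added on each side of every axis
padding_value = 0   # constant (background) value for the new voxels

cropped = np.ones((6, 4, 5))  # stand-in for a tightly cropped segmentation
padded = np.pad(cropped, pad_width=padding_size, mode="constant",
                constant_values=padding_value)

print(padded.shape)                   # → (26, 24, 25)
print(padded.sum() == cropped.sum())  # → True: only background was added
```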

Let's start with one segmentation and figure out which `shapeworks` API to use and what parameters are expected.

#### Which `shapeworks` API to use?

In [None]:
# let's see if there's a function that pads a shapeworks image
# use dot-tab (type shapeSeg. and press Tab) to list the functions/APIs available for shapeSeg

# found it - pad, let's see its help
help(shapeSeg.pad)

#### Exploring padding for a single segmentation

In [None]:
# parameters for padding 
# The padding size and padding value are the two parameters for this step.
paddingSize  = 10 # number of voxels to pad for each dimension
paddingValue = 0  # the constant value used to pad the segmentations

# to visualize the effect of padding, let's keep the cropped segmentation unchanged
shapeSegPadded = sw.Image(shapeSegCropped)

shapeSegPadded.pad(paddingSize, paddingValue)

print('\nHeader information before padding: ')
print(shapeSegCropped)

print('\nHeader information after padding: ')
print(shapeSegPadded)
   

plot_volumes([shapeSegCropped, shapeSegPadded], 
             ['before padding','after padding'],
             is_interactive = is_interactive, 
             show_borders   = show_borders,  
             shade_volumes  = shade_volumes,  
             show_axes      = show_axes,  
             show_bounds    = show_bounds,
             show_all_edges = show_all_edges, 
             font_size      = font_size,   
             link_views     = link_views)

<p><video src="https://sci.utah.edu/~shapeworks/doc-resources/mp4s/nb-groom-pad.mp4" autoplay muted loop controls style="width:100%"></p>

### Cropping and padding all segmentations

In [None]:
# let's now crop and pad all segmentations; here we use in-place updates of the shapeworks images

# parameters for padding 
paddingSize  = 10 # number of voxels to pad for each dimension
paddingValue = 0  # the constant value used to pad the segmentations

for shapeSeg, shapeName in zip(shapeSegList, shapeNames):
    print('Cropping & padding segmentation: ' + shapeName)    

    # crop-pad segmentation by chaining apis
    shapeSeg.crop(segsBoundingBox).pad(paddingSize, paddingValue)

### Visualizing cropped and padded segmentations

In [None]:
# ... and visualize the cropped and padded dataset
plot_volumes(shapeSegList,    
             volumeNames     = shapeNames, 
             is_interactive = is_interactive, 
             show_borders   = show_borders,  
             shade_volumes  = shade_volumes,  
             show_axes      = show_axes,  
             show_bounds    = show_bounds,
             show_all_edges = show_all_edges, 
             font_size      = font_size,   
             link_views     = link_views)

<p><video src="https://sci.utah.edu/~shapeworks/doc-resources/mp4s/nb-groom-crop-pad.mp4" autoplay muted loop controls style="width:100%"></p>

### Saving cropped and padded segmentations

In [None]:
save_images(outDir + 'cropped/',     
            shapeSegList,  
            shapeNames,  
            extension        = 'nrrd',
            compressed       = False, # use False to load in ParaView
            verbose          = True)

## 4. Converting segmentations to smooth signed distance transforms

We are now one step away from optimizing our shape model. Stay tuned!

For the numerical computations in correspondence optimization, we need to convert binary segmentations to a continuous-valued image that satisfies the following requirements:
- is smooth, for stable gradient updates
- reflects the shape's surface (i.e., the foreground-background interface)
- provides a signal for particles to snap (move back) to the surface if they drift off it during optimization, which typically happens with gradient-descent-based optimization

So far, we have antialiased segmentations whenever we needed to convert a binary segmentation to a continuous-valued image (e.g., for resampling and alignment). An antialiased segmentation satisfies the first two requirements. However, if a particle leaves the surface (i.e., the zero-level set), it is hard to snap it back. The antialiased image varies mainly in a thin band around the interface, so it offers little gradient signal farther away.
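The snapping problem can be seen in a toy 1D model. Below, an antialiased segmentation is approximated by a profile that saturates at ±1 outside a narrow band around the interface. This profile is an assumption made for illustration; ShapeWorks' actual antialias output may differ. The signed distance, in contrast, keeps growing linearly. Far from the interface, only the distance transform still has a nonzero gradient that points back toward the surface.

```python
import numpy as np

x = np.linspace(-10.0, 10.0, 201)       # 1D positions, interface at x = 0
spacing = x[1] - x[0]

antialiased = np.clip(x / 2.0, -1.0, 1.0)  # toy profile: saturates beyond |x| = 2
signed_distance = x                        # exact signed distance to x = 0

grad_aa = np.gradient(antialiased, spacing)
grad_dt = np.gradient(signed_distance, spacing)

far = np.abs(x) > 3.0                      # points well off the surface
print(np.abs(grad_aa[far]).max())          # → 0.0: no signal to move back
print(np.allclose(grad_dt[far], 1.0))      # → True: unit gradient everywhere
```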

Another representation that satisfies all the requirements is the *signed distance transform*.

- A signed distance transform assigns to each voxel the physical distance to the closest point on the surface (i.e., the minimum distance from that voxel to the foreground-background interface). 
- The sign indicates whether the voxel is inside or outside the foreground object. 
- The zero-level set (zero distance to the surface) is the foreground-background interface (i.e., the shape's surface).
- The gradient of a signed distance transform at a voxel points in the direction in which the distance increases most rapidly, and it has unit length. Combined with the signed distance value, it tells a particle how to move back to the surface: stepping from a point `x` to `x - phi(x) * grad_phi(x)`, where `phi` is the signed distance transform, lands on the closest surface point, whether the particle is inside or outside.
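The snapping rule can be checked on an analytic example. For a sphere of radius `r` centered at `c`, the signed distance is `phi(x) = |x - c| - r`. This sketch uses the convention "negative inside, positive outside", which is an assumption; `computeDT` may use the opposite sign, but the projection works either way. The gradient is estimated with central finite differences, and `snap_to_surface` is a hypothetical helper, not ShapeWorks' optimizer code.

```python
import numpy as np

center = np.array([0.0, 0.0, 0.0])
radius = 5.0


def phi(p):
    """Signed distance to a sphere: negative inside, positive outside."""
    return np.linalg.norm(p - center) - radius


def grad_phi(p, h=1e-5):
    """Central finite-difference estimate of the gradient of phi at p."""
    g = np.zeros(3)
    for i in range(3):
        e = np.zeros(3)
        e[i] = h
        g[i] = (phi(p + e) - phi(p - e)) / (2.0 * h)
    return g


def snap_to_surface(p):
    """Closest-point projection: move p along the gradient by its signed distance."""
    return p - phi(p) * grad_phi(p)


outside = np.array([3.0, 4.0, 12.0])   # |p| = 13, so phi = 8
inside = np.array([1.0, 2.0, 2.0])     # |p| = 3,  so phi = -2
for p in (outside, inside):
    q = snap_to_surface(p)
    print(np.round(q, 4), abs(phi(q)) < 1e-6)
```

Both points land on the sphere: the outside point moves inward by 8 and the inside point moves outward by 2, each along the radial direction.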

### Exploring distance transforms for a single segmentation

In [None]:
shapeIdx = 0

# let's keep a copy to avoid an in-place update
shapeSeg = sw.Image(shapeSegList[shapeIdx])

# computing distance transforms is a single-image operation - let's use the dot-tab magic
help(shapeSeg.computeDT)

The `computeDT` API needs an isovalue that defines the foreground-background interface. For binary segmentations, this interface is not smooth (notice the staircase effect) because of the aliasing introduced by binarization. A smoother interface can be obtained by first antialiasing the segmentation and then computing the distance transform at the zero-level set.
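For intuition about what a signed distance transform contains, here is a brute-force version on a small 2D binary image. Each pixel gets the Euclidean distance (scaled by `spacing`) to the nearest pixel of the opposite label, and the value is negative inside. This illustrative O(N²) approximation measures distances between pixel centers and is not necessarily how `computeDT` works. Because it runs directly on the binary mask, its level sets still follow the pixel grid, which is the staircase effect described above.

```python
import numpy as np


def brute_force_sdt(mask, spacing=(1.0, 1.0)):
    """Signed distance between pixel centers: negative inside, positive outside.

    Each pixel gets the distance to the nearest pixel with the opposite label.
    O(N^2) in the number of pixels, so only suitable for tiny images.
    """
    coords = np.argwhere(np.ones_like(mask, dtype=bool)) * np.asarray(spacing)
    labels = mask.ravel().astype(bool)          # same C order as argwhere
    inside_pts, outside_pts = coords[labels], coords[~labels]
    sdt = np.empty(len(coords))
    for k, (p, is_in) in enumerate(zip(coords, labels)):
        others = outside_pts if is_in else inside_pts
        d = np.sqrt(((others - p) ** 2).sum(axis=1)).min()
        sdt[k] = -d if is_in else d
    return sdt.reshape(mask.shape)


mask = np.zeros((7, 7), dtype=np.uint8)
mask[2:5, 2:5] = 1                      # a 3x3 square "segmentation"
sdt = brute_force_sdt(mask)
print(sdt)
```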

In [None]:
isoValue = 0
shapeDT  = sw.Image(shapeSeg)
shapeDT.antialias(antialias_iterations).computeDT(isoValue)

In [None]:
# let's visualize them
plot_volumes([shapeSeg, shapeDT], 
             ['segmentation','signed distance transform'],
             is_interactive = is_interactive, 
             show_borders   = show_borders,  
             shade_volumes  = shade_volumes,  
             show_axes      = show_axes,  
             show_bounds    = show_bounds,
             show_all_edges = show_all_edges, 
             font_size      = font_size,   
             link_views     = link_views)

<p><video src="https://sci.utah.edu/~shapeworks/doc-resources/mp4s/nb-groom-dt.mp4" autoplay muted loop controls style="width:100%"></p>

In [None]:
# let's compare isosurfaces
plot_meshes([shapeSeg.toMesh(0.5), shapeDT.toMesh(0)], 
             ['seg (yellow), dt (red)'],
             use_same_window = True,
             is_interactive  = is_interactive, 
             show_borders    = show_borders,  
             meshes_color    = ['yellow', 'red'], 
             opacities       = [0.7, 0.7],
             mesh_style      = mesh_style,
             show_mesh_edges = show_mesh_edges,
             show_axes       = show_axes,  
             show_bounds     = show_bounds,
             show_all_edges  = show_all_edges, 
             font_size       = font_size,   
             link_views      = link_views)

<p><video src="https://sci.utah.edu/~shapeworks/doc-resources/mp4s/nb-groom-dt-iso.mp4" autoplay muted loop controls style="width:100%"></p>

In [None]:
# let's have a closer look, is it smooth?
plot_meshes([shapeDT.toMesh(0)], 
             ['isosurface of dt'],
             use_same_window = True,
             is_interactive  = is_interactive, 
             show_borders    = show_borders,  
             meshes_color    = ['tan'], 
             opacities       = [1],
             mesh_style      = mesh_style,
             show_mesh_edges = show_mesh_edges,
             show_axes       = show_axes,  
             show_bounds     = show_bounds,
             show_all_edges  = show_all_edges, 
             font_size       = font_size,   
             link_views      = link_views)

<p><video src="https://sci.utah.edu/~shapeworks/doc-resources/mp4s/nb-groom-dt-iso2.mp4" autoplay muted loop controls style="width:100%"></p>

### Exploring smoothing the distance transform of a single segmentation

We can still see some leftovers from the aliasing effect of binarization. Let's try to smooth this out without significantly impacting the shape features that we want to model.

The `shapeworks` library supports two smoothing methods: Gaussian blur and topology-preserving smoothing. Gaussian blur convolves the image with a 3D Gaussian filter of a given sigma (in physical coordinates). This method works well for blob-like structures. However, for shapes with thin features and high-curvature regions, Gaussian blurring can distort the underlying geometric features. For these shapes, topology-preserving smoothing is recommended. 


**Note:**  topology-preserving smoothing is currently being debugged and tested. Issue #850
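To see what a sigma "in physical coordinates" means on a voxel grid, here is a NumPy sketch of separable Gaussian blurring. The physical sigma is divided by the voxel spacing along each axis, and a normalized 1D kernel truncated at 3 sigma is convolved along each axis in turn. The truncation radius and the zero padding at the borders are assumptions of this sketch; `gaussianBlur` may handle both differently.

```python
import numpy as np


def gaussian_kernel_1d(sigma_vox, truncate=3.0):
    """Normalized 1D Gaussian kernel with radius ceil(truncate * sigma_vox)."""
    radius = int(np.ceil(truncate * sigma_vox))
    t = np.arange(-radius, radius + 1, dtype=float)
    k = np.exp(-0.5 * (t / sigma_vox) ** 2)
    return k / k.sum()


def gaussian_blur(volume, sigma_phys, spacing):
    """Blur a 3D array with a Gaussian whose sigma is given in physical units.

    Assumes each kernel is shorter than the array dimension it is applied to.
    """
    out = volume.astype(float)
    for axis, sp in enumerate(spacing):
        kernel = gaussian_kernel_1d(sigma_phys / sp)  # sigma in voxels on this axis
        # convolve every 1D line along this axis; 'same' keeps the shape,
        # and values beyond the border are treated as zero
        out = np.apply_along_axis(
            lambda line: np.convolve(line, kernel, mode="same"), axis, out)
    return out


vol = np.zeros((21, 21, 21))
vol[10, 10, 10] = 1.0                   # a single bright voxel
blurred = gaussian_blur(vol, sigma_phys=1.3, spacing=(1.0, 1.0, 0.5))
print(round(float(blurred.sum()), 6))   # total intensity is preserved
print(blurred[10, 10, 10] < 1.0)        # the peak is spread to its neighbors
```

With a spacing of 0.5 along the last axis, the same physical sigma covers twice as many voxels there, so the blur spreads farther in voxel terms along that axis.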

In [None]:
# let's try gaussian blur first
shapeDTgauss = sw.Image(shapeDT)
help(shapeDTgauss.gaussianBlur)

In [None]:
# gaussian blur needs the convolution filter size (i.e., sigma)
blur_sigma = 2 # physical coordinates 
shapeDTgauss.gaussianBlur(blur_sigma)

In [None]:
# let's have a look
plot_meshes([shapeDT.toMesh(0), shapeDTgauss.toMesh(0)], 
             ['distance transform', 'gaussian blur'],
             use_same_window = False,
             is_interactive  = is_interactive, 
             show_borders    = show_borders,  
             meshes_color    = 'tan', 
             opacities       = 1,
             mesh_style      = mesh_style,
             show_mesh_edges = show_mesh_edges,
             show_axes       = show_axes,  
             show_bounds     = show_bounds,
             show_all_edges  = show_all_edges, 
             font_size       = font_size,   
             link_views      = link_views)

<p><video src="https://sci.utah.edu/~shapeworks/doc-resources/mp4s/nb-groom-dt-smooth.mp4" autoplay muted loop controls style="width:100%"></p>

In [None]:
# this seems to be a nicely smooth shape, let's compare isosurfaces
plot_meshes([shapeDT.toMesh(0), shapeDTgauss.toMesh(0)], 
             ['dt (yellow), gaussian-blur (red)'],
             use_same_window = True,
             is_interactive  = is_interactive, 
             show_borders    = show_borders,  
             meshes_color    = ['yellow', 'red'], 
             opacities       = [0.7, 0.7],
             mesh_style      = mesh_style,
             show_mesh_edges = show_mesh_edges,
             show_axes       = show_axes,  
             show_bounds     = show_bounds,
             show_all_edges  = show_all_edges, 
             font_size       = font_size,   
             link_views      = link_views)

<p><video src="https://sci.utah.edu/~shapeworks/doc-resources/mp4s/nb-groom-dt-smooth2.mp4" autoplay muted loop controls style="width:100%"></p>

In [None]:
# this smoothing seems to affect the major radius of the ellipsoid more,
# which is a shape feature that we would like to model,
# so let's try a smaller gaussian sigma

blur_sigma   = 1.3 # physical coordinates 
shapeDTgauss = sw.Image(shapeDT)
shapeDTgauss.gaussianBlur(blur_sigma)

plot_meshes([shapeDT.toMesh(0), shapeDTgauss.toMesh(0)], 
             ['dt (yellow), gaussian-blur (red)'],
             use_same_window = True,
             is_interactive  = is_interactive, 
             show_borders    = show_borders,  
             meshes_color    = ['yellow', 'red'], 
             opacities       = [0.7, 0.7],
             mesh_style      = mesh_style,
             show_mesh_edges = show_mesh_edges,
             show_axes       = show_axes,  
             show_bounds     = show_bounds,
             show_all_edges  = show_all_edges, 
             font_size       = font_size,   
             link_views      = link_views)

<p><video src="https://sci.utah.edu/~shapeworks/doc-resources/mp4s/nb-groom-dt-smooth-small-sigma.mp4" autoplay muted loop controls style="width:100%"></p>

In [None]:
# not that bad
# let's check how smooth it is
plot_meshes([shapeDT.toMesh(0), shapeDTgauss.toMesh(0)], 
             ['distance transform', 'gaussian blur'],
             use_same_window = False,
             is_interactive  = is_interactive, 
             show_borders    = show_borders,  
             meshes_color    = 'tan', 
             opacities       = 1,
             mesh_style      = mesh_style,
             show_mesh_edges = show_mesh_edges,
             show_axes       = show_axes,  
             show_bounds     = show_bounds,
             show_all_edges  = show_all_edges, 
             font_size       = font_size,   
             link_views      = link_views)

<p><video src="https://sci.utah.edu/~shapeworks/doc-resources/mp4s/nb-groom-dt-smooth-small-sigma2.mp4" autoplay muted loop controls style="width:100%"></p>

### Computing smooth distance transforms for all segmentations

In [None]:
isoValue = 0
sigma    = 1.3
for shapeSeg, shapeName in zip(shapeSegList, shapeNames):
    print('Computing DT for segmentation: ' + shapeName)

    # antialias, compute the distance transform, then blur (chained in-place updates)
    shapeSeg.antialias(antialias_iterations).computeDT(isoValue).gaussianBlur(sigma)

### Saving distance transforms

In [None]:
save_images(outDir + 'dts/',     
            shapeSegList,  
            shapeNames,  
            extension        = 'nrrd',
            compressed       = False, # use False to load in ParaView
            verbose          = True)

## 5. Optimizing correspondence: just scratching the surface

In [None]:
# here is the list of groomed segmentations
!echo {outDir}dts/*.nrrd

In [None]:
# let's launch studio to optimize our shape model 
# see the video in the next cell for an illustration

!ShapeWorksStudio {outDir}dts/*.nrrd

<p><video src="https://sci.utah.edu/~shapeworks/doc-resources/mp4s/studio_optimize2.mp4" autoplay muted loop controls style="width:100%"></p>

## Congrats! You have been able to groom your dataset ...