# Getting Started with Shape Cohort Generator

## Before you start!

- This notebook assumes that the shapeworks conda environment has been activated using `conda activate shapeworks` on the terminal.
- See [Setting Up ShapeWorks Environment](setting-up-shapeworks-environment.ipynb) to learn how to set up your environment to start using the shapeworks library. Note that the prerequisite steps below use the same code to set up the environment for this notebook and to import the `shapeworks` library.


## In this notebook, you will learn:

1. How to use the `ShapeCohortGenerator` package to generate meshes and segmentation images for <br>
   i. Ellipsoids <br>
   ii. Supershapes


We will also define modular/generic helper functions as we walk through these items to reuse functionalities without duplicating code.


## Prerequisites

- Setting up `shapeworks` environment. See [Setting Up ShapeWorks Environment](setting-up-shapeworks-environment.ipynb). To avoid code clutter, the `setup_shapeworks_env` function can be found in the `Examples/Python/setupenv.py` module.
- Importing `shapeworks` library. See [Setting Up ShapeWorks Environment](setting-up-shapeworks-environment.ipynb).

## Note about shapeworks APIs

shapeworks functions operate in place, i.e., `<swObject>.<function>()` applies that function to the data of `swObject` itself. To keep the original data unchanged, first copy it to another variable and then apply the function to the copy.
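
The in-place behavior can be illustrated without ShapeWorks at all. The sketch below uses a hypothetical `ToyMesh` class (a stand-in, not part of `shapeworks`) whose `scale()` method mutates the object, and keeps the original by copying it first with `copy.deepcopy`. Whether a particular `shapeworks` object supports `deepcopy` or offers its own copy method is an assumption to verify against the library documentation.

```python
import copy

class ToyMesh:
    """Hypothetical stand-in for an object with in-place methods (not a shapeworks class)."""

    def __init__(self, points):
        self.points = list(points)

    def scale(self, factor):
        # mutate this object's data and return self, mimicking an in-place API
        self.points = [p * factor for p in self.points]
        return self

original = ToyMesh([1.0, 2.0, 3.0])

# copy first, then apply the in-place function to the copy only
scaled = copy.deepcopy(original).scale(2.0)

print(original.points)  # the original data is unchanged
print(scaled.points)
```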

## Notebook keyboard shortcuts

- `Esc + H`: displays a complete list of keyboard shortcuts
- `Esc + A`: insert new cell above the current cell
- `Esc + B`: insert new cell below the current cell
- `Esc + D + D`: delete current cell
- `Esc + Z`: undo
- `Shift + enter`: run current cell and move to next
- To show a function's argument list (i.e., signature), use `(` then `shift-tab`
- Use `shift-tab-tab` to show more help for a function
- To show the help of a function, use `help(function)` or `function?`
- To show all functions supported by an object, use `dot-tab` after the variable name

## Prerequisites

### Setting up shapeworks environment 

Here, we will append to both your `PYTHONPATH` and your system `PATH` to set up the shapeworks environment for this notebook. See [Setting Up ShapeWorks Environment](setting-up-shapeworks-environment.ipynb) for more details.

In this notebook, we assume the following.

- This notebook is located in `Examples/Python/notebooks/tutorials`
- You have built shapeworks in `build` directory within the shapeworks code directory
- You have built shapeworks dependencies (using `build_dependencies.sh`) in the same parent directory of shapeworks code

In [None]:
%%capture
# import relevant libraries 
import sys 
import os
import itkwidgets as itkw
import pyvista as pv
import numpy
import itk
# add parent-parent directory (where setupenv.py is) to python path
sys.path.insert(0,'../..')

# importing setupenv from Examples/Python

import setupenv

# indicate the bin directories for shapeworks and its dependencies
shapeworks_bin_dir   = "../../../../build_pybind/bin/"
dependencies_bin_dir = "../../../../dependencies/install/bin/"

# set up shapeworks environment
setupenv.setup_shapeworks_env(shapeworks_bin_dir,  
                              dependencies_bin_dir, 
                              verbose = True)

### Importing `shapeworks` library

In [None]:
# let's import shapeworks library to test whether shapeworks is now set
try:
    import shapeworks as sw
except ImportError:
    print('ERROR: shapeworks library failed to import')
else:
    print('SUCCESS: shapeworks library is successfully imported!!!')

### Importing `ShapeCohortGen` library

In [None]:
# let's import ShapeCohortGen library 
try:
    import ShapeCohortGen
except ImportError:
    print('ERROR: ShapeCohortGen library failed to import')
else:
    print('SUCCESS: ShapeCohortGen library is successfully imported!!!')


In [None]:
'''
Get files with specific extensions
'''
def get_file_with_ext(file_list,extension):

    extList =[]
    for file in file_list:
        ext = file.split(".")[-1]
        if(ext==extension):
            extList.append(file)
    extList = sorted(extList)
    return extList
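
As an alternative sketch, the same filter can be written with the standard `pathlib` module. The name `get_files_with_ext` is hypothetical. Unlike splitting on `.`, comparing `Path.suffix` is made case-insensitive here and does not treat a file name without any dot as if it were an extension.

```python
from pathlib import Path

def get_files_with_ext(file_list, extension):
    """Return the sorted entries of file_list whose suffix matches extension (case-insensitive)."""
    wanted = "." + extension.lower().lstrip(".")
    return sorted(f for f in file_list if Path(f).suffix.lower() == wanted)

files = ["out/b.vtk", "out/a.VTK", "out/a.ply", "vtk"]
print(get_files_with_ext(files, "vtk"))
```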

### Shape Cohort Generation Overview
This shape cohort generator package can generate ellipsoids and supershapes. To use it, first define a generator, then call `generate()`, which creates shapes in mesh format (both .ply and .vtk). Once the meshes exist, segmentations and images can be created from them.

Each generator has three functions:
1. `generate()` - mesh generation (function specific to generator type)
2. `generate_segmentations()` - segmentation generation based on meshes (general function shared by all generator types)
3. `generate_images()` - image generation based on segmentations (general function shared by all generator types)
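
The order of the three calls matters, since each stage builds on the output of the previous one. A minimal sketch of a wrapper that enforces this order is given here. `build_cohort` and its keyword-dictionary parameters are hypothetical names, and the function only assumes the three methods listed above.

```python
def build_cohort(generator, num_samples, seg_kwargs=None, image_kwargs=None):
    """Run mesh, segmentation, and image generation in the required order.

    `generator` is any object exposing generate(), generate_segmentations(),
    and generate_images(); the keyword dictionaries are passed through as-is.
    """
    # 1. meshes first: the later stages are derived from them
    mesh_files = generator.generate(num_samples)
    # 2. segmentations are generated from the meshes
    seg_files = generator.generate_segmentations(**(seg_kwargs or {}))
    # 3. images are generated from the segmentations
    image_files = generator.generate_images(**(image_kwargs or {}))
    return mesh_files, seg_files, image_files
```

With a generator such as the `ellipsoid_generator` created later in this notebook, the call would look like `build_cohort(ellipsoid_generator, 8)`.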



## Generating Ellipsoid Dataset

### 1. Mesh Generation

Initialize an ellipsoid cohort generator. If no output directory is specified, one is generated automatically.

Arguments:
1. `out_dir` - path where the dataset should be saved<br>
    Datatype : `string`<br> 
    Default value : 'current_directory/generated_ellipsoid_cohort/' <br>
    <br>

In [None]:
out_dir = "../Output/Generated_Ellipsoids/"
ellipsoid_generator = ShapeCohortGen.EllipsoidCohortGenerator(out_dir)

For the ellipsoid mesh generation, you can specify the following arguments:
1. `num_samples` - number of samples in the cohort (dataset)<br>
    Datatype : `int` <br>
    Default value : 3 <br>
    <br>
2. `randomize_center` - randomizes the centers for ellipsoid mesh generation if set to `True`<br>
    Datatype : `bool` <br> 
    Default value : `True` <br>
    <br>   
3. `randomize_rotation` - randomizes the orientation of the ellipsoids if set to `True` <br>
    Datatype : `bool` <br> 
    Default value : `True`
    <br>

In [None]:
num_samples = 8
meshFiles = ellipsoid_generator.generate(num_samples)

In [None]:
# get all the .vtk files for plotting
VTKmeshFiles = get_file_with_ext(meshFiles,'vtk')
print(VTKmeshFiles)

### 2. Read the meshes and display them
We will use the `shapeworks` Mesh class to load these surface meshes.

In [None]:
%%capture
swMeshList = []
for i in range(len(VTKmeshFiles)):
    shapeMesh = sw.Mesh(VTKmeshFiles[i])
    swMeshList.append(shapeMesh)

### 3. Visualizing surface mesh using `itkwidgets`

[`itkwidgets`](https://github.com/InsightSoftwareConsortium/itkwidgets) is a python library that supports interactive Jupyter widgets to visualize images, point sets, and meshes. 

`itkwidgets` supports `itk`, `vtk`, and `pyvista` data structures. Hence, to visualize a `shapeworks` mesh, we need first to convert it to a `vtk` mesh. First we will define a helper function for the conversion.

In [None]:
# a helper function that converts shapeworks Mesh object to vtk mesh 
# TODO: to be modified when #825 is addressed

def sw2vtkMesh(swMesh, verbose = True):
    
    if verbose:
        print('Header information: ')
        print(swMesh)

    # save mesh
    swMesh.write('temp.vtk')

    # read the mesh back with pyvista, which returns a vtk-based mesh
    vtkMesh = pv.read('temp.vtk')
    
    # remove the temp mesh file
    os.remove('temp.vtk')
    
    return vtkMesh
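
Writing to a fixed `temp.vtk` in the working directory can clash with an existing file of that name or with another notebook running in the same folder. One possible refinement, sketched with the standard `tempfile` module, is a small context manager that hands out a unique path and cleans it up afterwards. `temporary_path` is a hypothetical helper name, and the placeholder write stands in for `swMesh.write(path)` followed by `pv.read(path)`.

```python
import os
import tempfile
from contextlib import contextmanager

@contextmanager
def temporary_path(suffix):
    """Yield a unique file path with the given suffix and delete the file on exit."""
    fd, path = tempfile.mkstemp(suffix=suffix)
    os.close(fd)  # only the path is needed; the caller writes the file itself
    try:
        yield path
    finally:
        if os.path.exists(path):
            os.remove(path)

# usage: in the helper above, the body would write the mesh and read it back
with temporary_path(".vtk") as path:
    with open(path, "w") as f:
        f.write("placeholder")
    print(os.path.basename(path).endswith(".vtk"))
```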

In [None]:
%%capture
vtkMeshList = []
for i in range(len(swMeshList)):
    shapeMesh_vtk = sw2vtkMesh(swMeshList[i],verbose=False)
    vtkMeshList.append(shapeMesh_vtk)

In [None]:
# visualize with axes and auto rotation
itkw.view(  geometries       = vtkMeshList, 
            rotate           = True, # enable auto rotation
            axes             = True)

### 4. Segmentation Generation
For segmentation generation, you can specify the following arguments:

1. `randomize_size` - randomizes the size of the images to include more background if set to `True`<br>
    Datatype : `bool` <br> 
    Default value : `True`
    <br>
2. `spacing` - sets the spacing of the segmentation image <br>
    Datatype: `list` <br>
    Default value: `[1,1,1]` <br>
    <br>
3. `allow_on_boundary` - if set to `True`, randomly selects 20% of the samples and ensures that those shapes touch the boundary along two randomly selected axes out of `[x,y,z]`<br>
    Datatype : `bool` <br> 
    Default value : `True`
    <br>

In [None]:
segFiles = ellipsoid_generator.generate_segmentations()

### 5. Read segmentation images
A segmentation is just image data, so we will use the `shapeworks` Image class to load it.

In [None]:
shapeSeg = sw.Image(segFiles[1])

# let's print out header information of this segmentation 
print('Header information: ')
print(shapeSeg)

### 6. Visualize Segmentation Image
To visualize our `shapeworks` image, we first need to convert it to a `vtk` data structure. Let's add a helper function for this purpose.



In [None]:
# a helper function that converts shapeworks Image object to vtk image
def sw2vtkImage(swImg, verbose = True):
            
    # get the numpy array of the shapeworks image
    array  = swImg.toArray()
    
    # the numpy array needs to be permuted to match the shapeworks image dimensions
    array = numpy.transpose(array,(2,1,0))
    
    # converting a numpy array to a vtk image using pyvista's wrap function
    vtkImg = pv.wrap(array)
    
    if verbose:
        print('shapeworks image header information: ')
        print(swImg)

        print('\nvtk image header information: ')
        print(vtkImg) 
    
    return vtkImg

In [None]:
def getvtkImages(segs):
    # load each segmentation file and convert it to a vtk image
    vtkImage = []
    for i in range(len(segs)):
        swImg = sw.Image(segs[i])
        vtkImage.append(sw2vtkImage(swImg))
    return vtkImage
vtkImages = getvtkImages(segFiles)

### Defining `pyvista` plotter

Next, we will render the segmentations in a `pyvista` plotter with multiple rendering windows arranged as a grid of plots. We start with a function that adds one volume to a given subplot of such a plotter.

In [None]:
def add_volume_to_plotter( pvPlotter,      # pyvista plotter
                           vtkImg,         # vtk image to be added
                           rowIdx, colIdx, # subplot row and column index
                           title = None,   # text to be added to the subplot, use None to not show text 
                           shade_volumes  = True,  # use shading when performing volume rendering
                           color_map      = "coolwarm", # color map for volume rendering, e.g., 'bone', 'coolwarm', 'cool', 'viridis', 'magma'
                           show_axes      = True,  # show a vtk axes widget for each rendering window
                           show_bounds    = False, # show volume bounding box
                           show_all_edges = True,  # add an unlabeled and unticked box at the boundaries of plot. 
                           font_size      = 10     # text font size for windows
                         ):
    
    # which subplot to add the volume to
    pvPlotter.subplot(rowIdx, colIdx)
    
    # add the volume
    pvPlotter.add_volume(vtkImg, shade = shade_volumes, cmap = color_map)

    if show_axes:
        pvPlotter.show_axes()

    if show_bounds:
        pvPlotter.show_bounds(all_edges = show_all_edges)

    # add a text to this subplot to indicate which volume is being visualized
    if title is not None:
        pvPlotter.add_text(title, font_size = font_size)

### Defining a helper function

Let's define a helper function that creates the `pyvista` plotter and adds eight segmentations to it in a 2 x 4 grid.

In [None]:
def plot8segs(segs, vtkImages):
    grid_rows = 2
    grid_cols = 4
    # define parameters that control the plotter
    is_interactive = True  # to enable interactive plots
    show_borders   = True  # show borders for each rendering window
    shade_volumes  = True  # use shading when performing volume rendering
    color_map      = "coolwarm" # color map for volume rendering, e.g., 'bone', 'coolwarm', 'cool', 'viridis', 'magma'
    show_axes      = True  # show a vtk axes widget for each rendering window
    show_bounds    = False # show volume bounding box
    show_all_edges = True  # add an unlabeled and unticked box at the boundaries of plot. 
    font_size      = 10    # text font size for windows
    link_views     = True  # link all rendering windows so that they share same camera and axes boundaries
    # define the plotter
    plotter = pv.Plotter(shape    = (grid_rows, grid_cols),
                         notebook = is_interactive, 
                         border   = show_borders) 
    for r in range(grid_rows):
        for c in range(grid_cols):
            index = grid_cols * r + c
            index = int(index)

            shapeName = segs[index].split("/")[-1].split(".")[0]

            add_volume_to_plotter( plotter, vtkImages[index],   
                           r,c, 
                           title          = shapeName,
                           shade_volumes  = shade_volumes, 
                           color_map      = color_map,
                           show_axes      = show_axes, 
                           show_bounds    = show_bounds, 
                           show_all_edges = show_all_edges, 
                           font_size      = font_size)
            plotter.show_bounds()
    # link all rendering windows so that they share the same camera and axes boundaries
    if link_views:
        plotter.link_views()

    plotter.show()

In [None]:
plot8segs(segFiles, vtkImages)
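
`plot8segs` hard-codes a 2 x 4 grid, so it assumes at least eight segmentations and shows only the first eight. A sketch of how the grid shape and the row-major index used above (`index = grid_cols * r + c`) could be generalized to any number of images follows. `grid_shape` and `grid_cells` are hypothetical helper names.

```python
import math

def grid_shape(num_items, max_cols=4):
    """Return (rows, cols) of the smallest grid with at most max_cols columns."""
    if num_items < 1:
        raise ValueError("num_items must be at least 1")
    cols = min(num_items, max_cols)
    rows = math.ceil(num_items / cols)
    return rows, cols

def grid_cells(num_items, max_cols=4):
    """Yield (index, row, col) in row-major order, i.e. index = cols * row + col."""
    _, cols = grid_shape(num_items, max_cols)
    for index in range(num_items):
        row, col = divmod(index, cols)
        yield index, row, col

print(grid_shape(8))
print(list(grid_cells(5, max_cols=3)))
```

Looping over `grid_cells(len(segs))` instead of the two nested `range` loops would leave trailing subplots empty rather than raising an `IndexError` when fewer images are available.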

### 7. Image Generation - Turning segmentations into non-binary images

For image generation, foreground and background pixel values are drawn from Gaussian distributions, and the boundary is then blurred with a Gaussian filter whose size is set by a blur factor. You can specify the following arguments:

1. `blur_factor` - size of Gaussian filter to use for boundary blurring <br>
    Datatype : `int` <br> 
    Default value : `1`
    <br>
2. `foreground_mean` - mean of the foreground pixel value distribution <br>
    Datatype: `int` <br>
    Default value: `180` <br>
    <br>
3. `foreground_var` - variance of the foreground pixel value distribution <br>
    Datatype : `int` <br> 
    Default value : `30`
    <br>
4. `background_mean` - mean of the background pixel value distribution <br>
    Datatype: `int` <br>
    Default value: `80` <br>
    <br>
5. `background_var` - variance of the background pixel value distribution <br>
    Datatype : `int` <br> 
    Default value : `30`
    <br>

In [None]:
imageFiles = ellipsoid_generator.generate_images()
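
To make the role of the mean and variance arguments concrete, here is a pure-Python sketch of the idea described above, using the documented default values. It is an illustration only and not the `ShapeCohortGen` implementation: `synthesize_intensities` is a hypothetical name, the mask is a small nested list rather than a real image, and the Gaussian blur step is left out.

```python
import random

def synthesize_intensities(mask, foreground_mean=180, foreground_var=30,
                           background_mean=80, background_var=30, seed=0):
    """Replace each 0/1 mask entry with a sample from the matching Gaussian."""
    rng = random.Random(seed)
    fg_std = foreground_var ** 0.5  # random.gauss expects a standard deviation
    bg_std = background_var ** 0.5
    return [[rng.gauss(foreground_mean, fg_std) if value else
             rng.gauss(background_mean, bg_std)
             for value in row]
            for row in mask]

# a tiny binary "segmentation": 1 marks foreground, 0 marks background
mask = [[0, 0, 0, 0],
        [0, 1, 1, 0],
        [0, 1, 1, 0],
        [0, 0, 0, 0]]
image = synthesize_intensities(mask)
for row in image:
    print([round(v) for v in row])
```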

### 8. Visualize generated image
Let's compare a segmentation to its corresponding image.

In [None]:
def sw2itkImage(swImg):
    print(swImg)
    array = swImg.toArray()
    itkImg = itk.GetImageFromArray(array)
    return itkImg

In [None]:
print("Segmentation:")
seg0 = sw.Image(segFiles[0])
itkw.view(sw2itkImage(seg0))

In [None]:
print("Image:")
img0 = sw.Image(imageFiles[0])
itkw.view(sw2itkImage(img0))

## Generate Supershapes Dataset
Supershapes are parameterized shapes whose geometry is determined by a given number of lobes, `m`.
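
Supershapes are commonly defined through the Gielis superformula. The sketch below evaluates its 2D form to show how `m` controls the number of lobes. It assumes this standard formula; the parameters `a`, `b`, `n1`, `n2`, `n3` and the function name `superformula_radius` are not taken from `ShapeCohortGen`, whose internal parameterization may differ.

```python
import math

def superformula_radius(phi, m, n1=1.0, n2=1.0, n3=1.0, a=1.0, b=1.0):
    """Radius of the 2D Gielis superformula at angle phi (radians)."""
    t1 = abs(math.cos(m * phi / 4.0) / a) ** n2
    t2 = abs(math.sin(m * phi / 4.0) / b) ** n3
    return (t1 + t2) ** (-1.0 / n1)

# with n2 == n3 and a == b, the outline repeats every 2*pi/m,
# so m = 3 gives three identical lobes around the origin
for k in range(3):
    phi = 0.4 + k * 2.0 * math.pi / 3.0
    print(round(superformula_radius(phi, m=3), 6))
```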

### 1. Generate meshes

Initialize a supershapes cohort generator. If no output directory is specified, one is generated automatically.

Argument:
1. `out_dir` - path where the dataset should be saved<br>
    Datatype : `string`<br> 
    Default value : 'current_directory/generated_supershapes_cohort/' <br>
    <br>

Each generator has three functions:
1. `generate()` - mesh generation
2. `generate_segmentations()` - segmentation generation (based on meshes)
3. `generate_images()` - image generation (based on segmentations)

In [None]:
out_dir = "../Output/Generated_Supershapes/"
ss_generator = ShapeCohortGen.SupershapesCohortGenerator(out_dir)

For the supershapes mesh generation, you can specify the following arguments:
1. `num_samples` - number of samples in the cohort (dataset)<br>
    Datatype : `int` <br>
    Default value : 3 <br>
    <br>
2. `randomize_center` - randomizes the centers for supershape mesh generation if set to `True`<br>
    Datatype : `bool` <br> 
    Default value : `True` <br>
    <br>   
3. `randomize_rotation` - randomizes the orientation of the supershapes if set to `True` <br>
    Datatype : `bool` <br> 
    Default value : `True`
    <br>
4. `m` - number of lobes supershapes should have <br>
    Datatype : `int` <br>
    Default value: `3` <br>
    <br>
5. `size` - size of meshes (won't be more than 'size' away from center in any direction) <br>
    Datatype: `int` <br>
    Default value: `20` <br>
    <br>

In [None]:
num_samples = 8
meshFiles = ss_generator.generate(num_samples)

In [None]:
# get all the .vtk files for plotting
VTKmeshFiles = get_file_with_ext(meshFiles,'vtk')
print(VTKmeshFiles)

### 2. Visualize meshes
We will use the `shapeworks` Mesh class to load these surface meshes.

In [None]:
swMeshList = []
for i in range(len(VTKmeshFiles)):
    shapeMesh = sw.Mesh(VTKmeshFiles[i])
    swMeshList.append(shapeMesh)

In [None]:
%%capture
vtkMeshList = []
for i in range(len(swMeshList)):
    shapeMesh_vtk = sw2vtkMesh(swMeshList[i],verbose=False)
    vtkMeshList.append(shapeMesh_vtk)

In [None]:
# visualize with axes and auto rotation
itkw.view(  geometries       = vtkMeshList,
            rotate           = True, # enable auto rotation
            axes             = True)

### 3. Segmentation Generation
This step is independent of the shape type; the options are the same as for the ellipsoids.

In [None]:
segFiles = ss_generator.generate_segmentations()

### 4. Visualize segmentations

In [None]:
vtkImages = getvtkImages(segFiles)

In [None]:
plot8segs(segFiles, vtkImages)

### 5. Image generation
This is also a function shared by all generator types and has the same options as listed before.

In [None]:
imageFiles = ss_generator.generate_images()

### 6. Visualize image
Let's again compare one segmentation to its corresponding image.

In [None]:
print("Segmentation:")
seg0 = sw.Image(segFiles[0])
itkw.view(sw2itkImage(seg0))

In [None]:
print("Image:")
img0 = sw.Image(imageFiles[0])
itkw.view(sw2itkImage(img0))