In [18]:
%load_ext autoreload
%autoreload 2

# Shape Analysis in 3D

0. Why analyze shape in biology?
1. Fourier approximation 
2. Spherical harmonics
3. Applied spherical harmonics

# Why do shape analysis in biology? 

<img src="resources/shapes_in_biology.png"/>

## Shape "type" determines appropriate tool 

1. **2D contours** 
2. **Simple 3D shapes**
3. Multi-component 3D shapes
4. Shapes with underlying network topology
5. Complex 3D shapes

**What is shape?**

D.G. Kendall (1984): "what is left when the differences which can be attributed to translations, rotations, and dilatations have been quotiented out"

In [None]:
import pyvista as pv
pv.start_xvfb()

In [None]:
sample_nuc_mesh_1 = pv.read("resources/sample_nuc.vtk")
sample_nuc_mesh_2 = sample_nuc_mesh_1.copy()
sample_nuc_mesh_2 = sample_nuc_mesh_2.translate((-150,0,20), inplace=True)
sample_nuc_mesh_2 = sample_nuc_mesh_2.rotate_x(-30, inplace=True)
sample_nuc_mesh_2 = sample_nuc_mesh_2.scale([1.5,1.5,1.5], inplace=False)

In [None]:
sample_nuc_mesh_1.volume

In [None]:
sample_nuc_mesh_2.volume

In [None]:
plotter = pv.Plotter(window_size=[900,400],) 
plotter.add_mesh(sample_nuc_mesh_1, color='lightgray')
plotter.add_mesh(sample_nuc_mesh_2, color='lightgray')
plotter.add_bounding_box(line_width=5, color='black')
plotter.view_xz()
plotter.camera.zoom(1.8)
plotter.set_background('white')
plotter.show(jupyter_backend='pythreejs')

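By Kendall's definition of shape, these two nuclei are equivalent! The second is a translated, rotated, and 1.5× scaled copy of the first (so its volume is 1.5³ ≈ 3.4 times larger), and those are exactly the differences the definition quotients out.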
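One standard way to "quotient out" translation, scale, and rotation is Procrustes alignment. The sketch below uses `scipy.spatial.procrustes`, which centers both point sets, scales them to unit norm, and finds the best orthogonal transform (reflections included) between them. The random point cloud is a hypothetical stand-in for mesh vertices. Because `sample_nuc_mesh_2` is a transformed copy of `sample_nuc_mesh_1`, their vertices correspond one-to-one, so the same call should also work on `sample_nuc_mesh_1.points` and `sample_nuc_mesh_2.points`.

```python
import numpy as np
from scipy.spatial import procrustes
from scipy.spatial.transform import Rotation

rng = np.random.default_rng(0)
points_1 = rng.normal(size=(500, 3))  # hypothetical stand-in for mesh vertices

# Same kinds of transforms as above: translate, rotate -30 degrees about x, scale by 1.5
rotation = Rotation.from_euler("x", -30, degrees=True)
points_2 = 1.5 * rotation.apply(points_1 + np.array([-150.0, 0.0, 20.0]))

# procrustes centres both sets, scales them to unit Frobenius norm, then rotates the
# second onto the first; disparity is the remaining sum of squared differences
standardized_1, standardized_2, disparity = procrustes(points_1, points_2)
print(f"disparity after removing translation, rotation and scale: {disparity:.2e}")
```

A disparity near zero means the two point sets have the same shape in Kendall's sense.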

TODO: explanation of preprocessing steps

## Shape parameterization

In short, shape parameterization is how we obtain a set of numbers representing each shape, so that we can run further analyses on them.

<img src="resources/param_workflow.png"/>

TODO: explanation of desirable properties (generative). perhaps examples of "statistics"

## Introduction to Fourier Approximation: parameterization in 1D/2D

Fourier theory states that *any periodic function (satisfying mild regularity conditions) can be represented by an infinite sum of sine and cosine terms*. In practice, we use a finite number of terms and obtain an approximation of our original function. 

When we approximate periodic functions, we *expand* the function into a Fourier series which looks like this: 

$y=\frac{A_0}{2}+A_1\cos(\frac{2\pi x}{L})+B_1 \sin(\frac{2\pi x}{L})+A_2 \cos(\frac{4\pi x}{L})+B_2 \sin(\frac{4\pi x}{L})+ \space ...$

Equivalently, truncated after $N$ terms:

$y= \frac{A_0}{2} + \sum_{n=1}^N \left[ A_n \cos(\frac{2 \pi n x}{L}) + B_n \sin(\frac{2 \pi n x}{L}) \right]$

- $L$ is the period of the function
- $A_n$ and $B_n$ are coefficients we must calculate 

**How do we compute $A_n$ and $B_n$?**

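Given a function or set of data $f(x)$ that we want to approximate, we set up a least-squares minimization: define the error $E = \int_{x_1}^{x_2} (f(x) - y(x))^2 dx$ over one period ($x_2 - x_1 = L$), differentiate $E$ with respect to each $A_n$ and $B_n$, and set $\frac{\partial E}{\partial A_n}$ and $\frac{\partial E}{\partial B_n}$ to zero. Because the sine and cosine terms are orthogonal over a period, each equation involves only one coefficient, and we arrive at: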

$A_n = \frac{2}{L} \int_{x_1}^{x_2} f(x) \cos(\frac{2 \pi n x}{L})dx$

$B_n = \frac{2}{L} \int_{x_1}^{x_2} f(x) \sin(\frac{2 \pi n x}{L})dx$

In [5]:
import numpy as np
from scipy.signal import square 
from scipy.integrate import simpson  # `simps` has been removed from recent SciPy releases; `simpson` replaces it
from scipy import interpolate
import matplotlib.pyplot as plt

In [6]:
L=4
n_points=256

In [7]:
x=np.linspace(0,L,n_points,endpoint=False)
y=square(np.pi*x, duty=0.5) # Define simple square waveform (period 2, so two periods on [0, L))

In [9]:
# Functions for computing Fourier coefficients A_n and B_n using Simpson's integration technique
an=lambda n:2.0/L*simpson(y*np.cos(2.*np.pi*n*x/L),x=x)
bn=lambda n:2.0/L*simpson(y*np.sin(2.*np.pi*n*x/L),x=x)

In [10]:
# Interactive plot demonstrating Fourier series approximation for periodic functions

from viz import get_square_wave_fig

fig = get_square_wave_fig(x=x,
                          y=y,
                          L=L,
                          an=an,
                          bn=bn)
fig.show()
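To see what the interactive plot computes, here is a minimal standalone sketch of the truncated series $\frac{A_0}{2} + \sum_{n=1}^{N}[\dots]$ for the same square wave. One swap relative to the cells above: instead of Simpson's rule we use the rectangle rule, because with `endpoint=False` the samples cover exactly one period of length $L$, so $\frac{2}{L}\int f\,dx$ reduces to `2 * mean(f)`. The helper names `coeffs` and `partial_sum` are illustrative, not from the notebook's `viz` module.

```python
import numpy as np
from scipy.signal import square

L = 4.0
n_points = 256
x = np.linspace(0, L, n_points, endpoint=False)
y = square(np.pi * x, duty=0.5)  # same square wave as above (period 2 on [0, L))

def coeffs(n):
    """A_n, B_n = (2/L) * integral over one period, via the rectangle rule."""
    a = 2.0 * np.mean(y * np.cos(2 * np.pi * n * x / L))
    b = 2.0 * np.mean(y * np.sin(2 * np.pi * n * x / L))
    return a, b

def partial_sum(N):
    """A_0/2 + sum_{n=1}^{N} [A_n cos(2 pi n x / L) + B_n sin(2 pi n x / L)]."""
    a0, _ = coeffs(0)
    y_hat = np.full_like(x, a0 / 2.0)
    for n in range(1, N + 1):
        a, b = coeffs(n)
        y_hat += a * np.cos(2 * np.pi * n * x / L) + b * np.sin(2 * np.pi * n * x / L)
    return y_hat

mses = [np.mean((y - partial_sum(N)) ** 2) for N in (2, 10, 50)]
print(mses)
```

The mean squared error shrinks as $N$ grows. Since this wave has period 2 while $L = 4$, its fundamental sits at $n = 2$ and the $n = 1$ coefficients are essentially zero.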

We can similarly use Fourier-based techniques to model closed 2D contours. For instance, consider this simple 2D closed contour.

In [15]:
n_points = 64

# Define square contour
rq = np.linspace(-1,1-2.0/n_points, n_points).tolist()
lq = (-np.linspace(-1,1-2.0/n_points, n_points)).tolist()
x = rq + [1]*n_points + lq + [-1]*n_points
y = [-1]*n_points + rq + [1]*n_points + lq
x = np.array(x + [x[0]])
y = np.array(y + [y[0]])

<img src="resources/square_cartesian_vs_polar.png" width="70%"/>

**By mapping the square into polar coordinates, we can represent it with a single-parameter function $r(\theta)$ of the angle $\theta$, rather than with a description in terms of both $x$ and $y$.**
- $r(\theta)$ gives the distance from the origin to the point on the contour at angle $\theta$.

Thus we can do a Fourier expansion of $r(\theta)$:

$r(\theta) =  \frac{A_0}{2} + \sum_{n=1}^{\infty}(A_n \cos n \theta + B_n \sin n \theta)$

where the Fourier coefficients are:

$A_n = \frac{1}{\pi} \int_{-\pi}^{\pi} r(\theta) \cos n \theta d \theta$

$B_n = \frac{1}{\pi} \int_{-\pi}^{\pi} r(\theta) \sin n \theta d \theta$

In [16]:
# Convert to polar coordinates
r = np.sqrt(x**2+y**2)
theta = np.arctan2(y,x)
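The $\theta$ values from `arctan2` are not evenly spaced, which makes direct integration awkward. The sketch below instead assumes the closed form $r(\theta) = 1 / \max(|\cos\theta|, |\sin\theta|)$ for a square with corners at $(\pm 1, \pm 1)$, samples it on a uniform grid, and evaluates the coefficient integrals with the rectangle rule ($\frac{1}{\pi}\cdot\frac{2\pi}{M}\sum = 2\cdot\text{mean}$). Function names are illustrative.

```python
import numpy as np

n_samples = 1024
theta = np.linspace(-np.pi, np.pi, n_samples, endpoint=False)
# Distance from the origin to a square with corners at (+-1, +-1), as a function of angle
r = 1.0 / np.maximum(np.abs(np.cos(theta)), np.abs(np.sin(theta)))

def polar_fourier(r, theta, n_max):
    """A_n, B_n = (1/pi) * integral_{-pi}^{pi} r(theta) cos/sin(n theta) d theta."""
    n = np.arange(n_max + 1)[:, None]                 # harmonic indices 0..n_max as a column
    A = 2.0 * np.mean(r * np.cos(n * theta), axis=1)  # rectangle rule: (1/pi)(2 pi / M) sum = 2 mean
    B = 2.0 * np.mean(r * np.sin(n * theta), axis=1)
    return A, B

def polar_reconstruct(A, B, theta):
    """A_0/2 + sum_{n>=1} (A_n cos n theta + B_n sin n theta)."""
    n = np.arange(1, len(A))[:, None]
    return A[0] / 2 + A[1:] @ np.cos(n * theta) + B[1:] @ np.sin(n * theta)

errs = []
for n_max in (4, 16, 64):
    A, B = polar_fourier(r, theta, n_max)
    err = np.sqrt(np.mean((polar_reconstruct(A, B, theta) - r) ** 2))
    errs.append(err)
    print(f"n_max={n_max}: RMS radial error {err:.4f}")
```

Because the square is symmetric under reflection and 90° rotation, all $B_n$ vanish and only $A_n$ with $n$ a multiple of 4 are nonzero, so a handful of numbers describes the whole contour.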

In [53]:
from viz import get_one_param_polar_fig

fig = get_one_param_polar_fig(theta, r, x, y)
fig.show()

This Fourier basis representation has limitations. Shown below is a simple 2D "C" contour that cannot be described by the polar Fourier decomposition above. 

<img src="resources/C_contour_approximation.png"/>

**Why?**
- for some $\theta$, the ray from the origin does not intersect the contour at all
- some $\theta$ values map to more than one $r$ value

Note: in this case, these issues cannot be resolved by moving the origin.

As a solution, we can use an extended Fourier method by Kuhl and Giardina (1982): the *elliptic Fourier variant*. Instead of one function of angle, we **use two parametric functions $x(t)$ and $y(t)$**, where $t$ is the *arc length* along the contour, measured from a starting point on it (rescaled to run from $0$ to $2\pi$), rather than an angle $\theta$ about the origin as above. The simple idea is that $x(t)$ and $y(t)$ trace the $x$- and $y$-Cartesian coordinates of the 2D contour as we walk along it.

Now we have 4 sets of coefficients, $A_n$, $B_n$, $C_n$, and $D_n$:

$x(t) = \frac{A_0}{2} + \sum_{n=1}^{\infty}(A_n \cos n t + B_n \sin n t)$

$y(t) = \frac{C_0}{2} + \sum_{n=1}^{\infty}(C_n \cos n t + D_n \sin n t)$

In [116]:
# Define "C" shape in cartesian coordinates
n_terms = 50
n_points = 100

xy = np.array([[-0.5,0.5], [0,0.5], [0.5,0.5], 
               [0.5,0.75], [0,0.75], [-0.75,0.75], 
               [-0.75,0], [-0.75,-0.75], [0,-0.75], 
               [0.5,-0.75], [0.5,-0.5], [0,-0.5],
               [-0.5,-0.5], [-0.5,0], [-0.5,0.5]])

In [117]:
import pyefd
coeffs = pyefd.elliptic_fourier_descriptors(xy, order=n_terms)
a0, c0 = pyefd.calculate_dc_coefficients(xy)
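`pyefd` computes these coefficients for us. For a polygon, $x(t)$ and $y(t)$ are piecewise linear in arc length, and integrating by parts gives Kuhl and Giardina's closed form, e.g. $A_n = \frac{T}{2n^2\pi^2}\sum_p \frac{\Delta x_p}{\Delta t_p}\left[\cos\frac{2\pi n t_p}{T} - \cos\frac{2\pi n t_{p-1}}{T}\right]$ with $T$ the perimeter. The standalone NumPy sketch below implements that formula. We expect `efd_coefficients(xy, n_terms)[0]` to match `pyefd.elliptic_fourier_descriptors(xy, order=n_terms)` up to floating-point error, but this is an assumption rather than a statement about pyefd's internals. The DC term here is the arc-length mean position; `efd_coefficients` and `efd_reconstruct` are illustrative names.

```python
import numpy as np

def efd_coefficients(contour, order):
    """Elliptic Fourier coefficients (Kuhl & Giardina, 1982) of a closed polygon.

    Returns an (order, 4) array with rows [A_n, B_n, C_n, D_n] for n = 1..order,
    plus the DC term (mean x and y over arc length).
    """
    pts = np.asarray(contour, dtype=float)
    if not np.allclose(pts[0], pts[-1]):
        pts = np.vstack([pts, pts[:1]])            # close the polygon
    seg = np.diff(pts, axis=0)                      # (dx_p, dy_p) for each segment
    dt = np.hypot(seg[:, 0], seg[:, 1])             # segment lengths
    keep = dt > 0                                   # skip repeated points
    seg, dt, starts = seg[keep], dt[keep], pts[:-1][keep]
    t = np.concatenate([[0.0], np.cumsum(dt)])      # arc length at each vertex
    T = t[-1]                                       # perimeter
    phi = 2 * np.pi * t / T                         # arc length rescaled to [0, 2 pi]
    slope_x, slope_y = seg[:, 0] / dt, seg[:, 1] / dt
    coeffs = np.zeros((order, 4))
    for n in range(1, order + 1):
        scale = T / (2 * n**2 * np.pi**2)
        d_cos = np.cos(n * phi[1:]) - np.cos(n * phi[:-1])
        d_sin = np.sin(n * phi[1:]) - np.sin(n * phi[:-1])
        coeffs[n - 1] = scale * np.array([
            np.sum(slope_x * d_cos), np.sum(slope_x * d_sin),
            np.sum(slope_y * d_cos), np.sum(slope_y * d_sin),
        ])
    # DC term: (1/T) * integral of position over arc length (exact for straight segments)
    dc = np.sum(dt[:, None] * (starts + seg / 2), axis=0) / T
    return coeffs, dc

def efd_reconstruct(coeffs, dc, num_points=300):
    """Evaluate x(s), y(s) at num_points evenly spaced values of s in [0, 2 pi)."""
    s = np.linspace(0, 2 * np.pi, num_points, endpoint=False)
    n = np.arange(1, len(coeffs) + 1)[:, None]
    cos_ns, sin_ns = np.cos(n * s), np.sin(n * s)
    x = dc[0] + coeffs[:, 0] @ cos_ns + coeffs[:, 1] @ sin_ns
    y = dc[1] + coeffs[:, 2] @ cos_ns + coeffs[:, 3] @ sin_ns
    return np.column_stack([x, y])

# Sanity check on a finely sampled circle of radius 2, traversed counter-clockwise from (2, 0)
angles = np.linspace(0, 2 * np.pi, 720, endpoint=False)
circle = np.column_stack([2 * np.cos(angles), 2 * np.sin(angles)])
circle_coeffs, circle_dc = efd_coefficients(circle, order=10)
print(np.round(circle_coeffs[0], 4), np.round(circle_dc, 4))
```

For a circle, essentially all the information lands in the first harmonic ($A_1 \approx D_1 \approx 2$), which is the "ellipse" that gives the method its name.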

In [125]:
from viz import get_two_param_2d_fig

fig = get_two_param_2d_fig(coeffs, a0, c0, 
                           xy, 
                           n_points, 
                           n_terms)
fig.show()

In [119]:
from viz import get_two_param_coeff_table

coeff_table = get_two_param_coeff_table(xy)
coeff_table

| n_terms | a0 | b0 | c0 | d0 | a1 | b1 | c1 | d1 | a2 | b2 | ... | c21 | d21 | a22 | b22 | c22 | d22 | a23 | b23 | c23 | d23 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 1 term | 0.215 | -0.096 | 0.321 | 0.721 |  |  |  |  |  |  | ... |  |  |  |  |  |  |  |  |  |  |
| 8 terms | 0.215 | -0.096 | 0.321 | 0.721 | -0.352 | 0.39 | -0.083 | -0.075 | -0.045 | 0.139 | ... |  |  |  |  |  |  |  |  |  |  |
| 16 terms | 0.215 | -0.096 | 0.321 | 0.721 | -0.352 | 0.39 | -0.083 | -0.075 | -0.045 | 0.139 | ... |  |  |  |  |  |  |  |  |  |  |
| 24 terms | 0.215 | -0.096 | 0.321 | 0.721 | -0.352 | 0.39 | -0.083 | -0.075 | -0.045 | 0.139 | ... | -0.0 | 0.001 | 0.001 | -0.0 | -0.001 | -0.003 | -0.001 | 0.001 | 0.001 | 0.001 |


## COOL! So how can we use this in the real world?

<img src="resources/fourier_workflow.png"/>

In [127]:
cell_2d_contour = np.load("resources/cell_contour.npy")
cell_2d_contour.shape

(713, 2)

In [128]:
coeffs = pyefd.elliptic_fourier_descriptors(cell_2d_contour, order=n_terms)
a0, c0 = pyefd.calculate_dc_coefficients(cell_2d_contour)

How do we know when our reconstruction is "good"? We use **reconstruction error**! This measures the difference between the reconstructed contour and the original. Here we match nearest points and compute a mean squared error. Many other metrics are possible, for example the Hausdorff distance.
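The exact metric used by the `viz` helper is not shown here, so the following is only a minimal sketch of one nearest-point mean squared error. It assumes a symmetric definition: every reconstructed point is matched to its closest original point, every original point to its closest reconstructed point, and the two mean squared distances are averaged. Nearest neighbours come from `scipy.spatial.cKDTree`; `nearest_point_mse` is an illustrative name.

```python
import numpy as np
from scipy.spatial import cKDTree

def nearest_point_mse(original, reconstructed):
    """Symmetric nearest-neighbour mean squared error between two 2D point sets."""
    original = np.asarray(original, dtype=float)
    reconstructed = np.asarray(reconstructed, dtype=float)
    d_rec, _ = cKDTree(original).query(reconstructed)   # reconstructed point -> closest original point
    d_orig, _ = cKDTree(reconstructed).query(original)  # original point -> closest reconstructed point
    return 0.5 * (np.mean(d_rec ** 2) + np.mean(d_orig ** 2))

# Toy check: a unit circle versus a circle of radius 1.1 sampled at the same angles
angles = np.linspace(0, 2 * np.pi, 200, endpoint=False)
unit_circle = np.column_stack([np.cos(angles), np.sin(angles)])
print(nearest_point_mse(unit_circle, 1.1 * unit_circle))
```

Symmetrizing matters: a reconstruction that covers only part of the original contour can still have every one of its points close to the original, and the one-directional error would miss that.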

In [129]:
from viz import get_two_param_2d_fig

fig = get_two_param_2d_fig(coeffs, a0, c0, 
                           cell_2d_contour, 
                           n_points, 
                           n_terms, 
                           show_recon_err=True,
                           set_aspect_ratio=True)
fig.show()

# Intro to Spherical Harmonics

Spherical harmonics are special functions defined on the sphere. Using spherical harmonics is one of a few techniques we can use to efficiently represent 3D shapes (surfaces). **The way we can use spherical harmonics is analogous to the way we used Fourier series to approximate functions**. The theory of spherical harmonics states that *any (square-integrable) function on the sphere $f(\theta, \phi)$ can be decomposed as a sum of its harmonics*:

$f(\theta, \phi) = \sum_{l=0}^{\infty} \sum_{m=-l}^{l} a_{lm} Y_l^m(\theta, \phi)$

A spherical harmonic representation is composed of the coefficients $a_{lm}$ associated with these functions. For a shape, $f$ is typically the distance $r(\theta, \phi)$ from the shape's center to its surface, the 3D analogue of $r(\theta)$ for 2D contours.


<img src="resources/shcoeff_workflow.png"/>

Practical notes
- $\text{L}_{\text{max}}$ is analogous to "number of terms"
- Spherical harmonics are most appropriate in our domain to describe relatively simple, closed forms
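`aicsshparam` does the expansion for us. To show the idea, here is a minimal least-squares sketch that fits coefficients $a_{lm}$ to a radial function sampled at random directions. Assumptions: the surface is star-shaped around the origin, we use *real* orthonormal spherical harmonics built from `scipy.special.lpmv`, and we fit by least squares; `aicsshparam` may use a different normalization and fitting method. The cube example uses the radial function $r = 1/\max(|x|, |y|, |z|)$ of a cube with half-width 1.

```python
import numpy as np
from math import factorial
from scipy.special import lpmv

def real_sph_harm(l, m, polar, azimuth):
    """Orthonormal real spherical harmonic Y_lm; polar in [0, pi], azimuth in [0, 2 pi)."""
    am = abs(m)
    norm = np.sqrt((2 * l + 1) / (4 * np.pi) * factorial(l - am) / factorial(l + am))
    p = lpmv(am, l, np.cos(polar))  # associated Legendre function P_l^|m|(cos polar)
    if m == 0:
        return norm * p
    if m > 0:
        return np.sqrt(2) * norm * p * np.cos(m * azimuth)
    return np.sqrt(2) * norm * p * np.sin(am * azimuth)

def fit_sh(r, polar, azimuth, lmax):
    """Least-squares a_lm such that sum_lm a_lm Y_lm(polar, azimuth) ~= r."""
    lm = [(l, m) for l in range(lmax + 1) for m in range(-l, l + 1)]
    basis = np.column_stack([real_sph_harm(l, m, polar, azimuth) for l, m in lm])
    coeffs, *_ = np.linalg.lstsq(basis, r, rcond=None)
    rms = np.sqrt(np.mean((basis @ coeffs - r) ** 2))
    return dict(zip(lm, coeffs)), rms

# Random directions, uniformly distributed on the unit sphere
rng = np.random.default_rng(0)
n = 4000
polar = np.arccos(1 - 2 * rng.random(n))
azimuth = 2 * np.pi * rng.random(n)

# Radius of a cube with half-width 1, seen from its centre
ux = np.sin(polar) * np.cos(azimuth)
uy = np.sin(polar) * np.sin(azimuth)
uz = np.cos(polar)
r_cube = 1.0 / np.max(np.abs([ux, uy, uz]), axis=0)

for lmax in (2, 4, 8):
    _, rms = fit_sh(r_cube, polar, azimuth, lmax)
    print(f"lmax={lmax}: RMS radial error {rms:.4f}")
```

As with "number of terms" in the Fourier case, raising $\text{L}_{\text{max}}$ lowers the error, at the cost of $(\text{L}_{\text{max}} + 1)^2$ coefficients per shape.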

In [None]:
from aicsimageio import AICSImage

sample_cell_img = AICSImage("resources/416089.tiff").data.squeeze()

In [None]:
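%%time
import os
from utils import get_mesh_from_series
from aicsshparam import shparam, shtools

MAX_LMAX = 16
os.makedirs("output", exist_ok=True)  # make sure the output folder exists before saving meshes
recon_errors = []
recon_meshes = []
for l in range(1, MAX_LMAX + 1):
    # Expand the 3D image into spherical harmonics up to degree lmax = l
    (coeffs, grid_rec), (image_, mesh, grid, transform) = shparam.get_shcoeffs(
        image=sample_cell_img[0, :, :, :], lmax=l
    )
    shcoeffs_mesh = get_mesh_from_series(coeffs, l)
    # Compare the original grid with the grid reconstructed from the coefficients
    mse = shtools.get_reconstruction_error(grid, grid_rec)
    recon_errors.append(mse)
    pv.wrap(shcoeffs_mesh).save(f"output/recon-{l:02d}.vtk")  # zero-padded: recon-01 ... recon-16
    recon_meshes.append(shcoeffs_mesh)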

In [None]:
gt_mesh, _, _ = shtools.get_mesh_from_image(sample_cell_img[0,:,:,:])
gt_mesh = pv.wrap(gt_mesh)
gt_mesh = gt_mesh.translate((-200, 0, 0), inplace=True)

In [None]:
from viz import interactive_reconstruction_plot

interactive_reconstruction_plot(recon_errors, recon_meshes)

# Using spherical harmonics on a toy dataset

In [None]:
import pandas as pd
import vtk

In [None]:
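# Create toy shape dataset s.t. each shape has volume ~= 1.0
base_cube = pv.Cube()                           # default 1 x 1 x 1 cube
base_cylinder = pv.Cylinder(radius=0.564)       # default height 1: pi * 0.564^2 ~= 1
base_cone = pv.Cone(height=2.0, radius=0.7596)  # default resolution=6, i.e. a hexagonal pyramid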

In [None]:
base_cone.volume

In [None]:
base_cube.volume

In [None]:
base_cylinder.volume
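The three `.volume` cells should print values close to 1. The arithmetic behind the chosen radii is easy to check by hand, assuming PyVista's defaults of height 1 for `pv.Cube`/`pv.Cylinder` and a six-sided base (`resolution=6`) for `pv.Cone`. The last line shows that a truly circular cone with the same radius would be about 20% larger, which is why the radius 0.7596 only makes sense for the hexagonal base.

```python
import math

cube = 1.0 ** 3                                               # 1 x 1 x 1
cylinder = math.pi * 0.564 ** 2 * 1.0                         # pi r^2 h, height 1
hex_pyramid = (3 * math.sqrt(3) / 2) * 0.7596 ** 2 * 2.0 / 3  # (regular hexagon area) * h / 3
round_cone = math.pi * 0.7596 ** 2 * 2.0 / 3                  # same radius, circular base
print(f"cube {cube:.3f}, cylinder {cylinder:.3f}, "
      f"hexagonal pyramid {hex_pyramid:.3f}, circular cone {round_cone:.3f}")
```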

In [None]:
plotter = pv.Plotter(window_size=[900,400], shape=(1,3)) 
plotter.subplot(0,0)
plotter.add_mesh(base_cube, color='lightgray', show_edges=True)
plotter.subplot(0,1)
plotter.add_mesh(base_cylinder, color='lightgray', show_edges=True)
plotter.set_background('white')
plotter.subplot(0,2)
plotter.add_mesh(base_cone, color='lightgray', show_edges=True)
plotter.set_background('white')
plotter.show(jupyter_backend='pythreejs')

In [None]:
from aicsshparam import shparam, shtools
from utils import get_image_from_polydata, get_mesh_from_series

LMAX = 16

def get_shcoeffs_and_mesh(polydata, lmax=LMAX):
    """Convert a mesh to a 3D image, expand it in spherical harmonics up to lmax, and rebuild a mesh."""
    image = get_image_from_polydata(polydata)
    (shcoeffs, _), _ = shparam.get_shcoeffs(image, lmax)
    return image, shcoeffs, get_mesh_from_series(shcoeffs, lmax)

base_cube_im, base_cube_shcoeffs, base_cube_shcoeffs_mesh = get_shcoeffs_and_mesh(base_cube)
base_cyl_im, base_cyl_shcoeffs, base_cyl_shcoeffs_mesh = get_shcoeffs_and_mesh(base_cylinder)
base_cone_im, base_cone_shcoeffs, base_cone_shcoeffs_mesh = get_shcoeffs_and_mesh(base_cone)

In [None]:
from viz import get_recon_mesh_plotter

In [None]:
pl = get_recon_mesh_plotter(base_cube, base_cube_shcoeffs_mesh)
pl.show(jupyter_backend='pythreejs')

In [None]:
pl = get_recon_mesh_plotter(base_cylinder, base_cyl_shcoeffs_mesh)
pl.show(jupyter_backend='pythreejs')

In [None]:
pl = get_recon_mesh_plotter(base_cone, base_cone_shcoeffs_mesh)
pl.show(jupyter_backend='pythreejs')

In [None]:
base_cube_im.shape

In [None]:
np.power(base_cube_im.shape[0],3)

In [None]:
N_EXAMPLES = 20
N_COEFFS = len(pd.Series(base_cube_shcoeffs))

# Gaussian noise (mean 0, std 0.1) for every coefficient of every noisy example, one seed per shape
np.random.seed(20)
cube_noise = np.random.normal(0.0, 0.1, [N_COEFFS * N_EXAMPLES])

np.random.seed(21)
cyl_noise = np.random.normal(0.0, 0.1, [N_COEFFS * N_EXAMPLES])

np.random.seed(22)
cone_noise = np.random.normal(0.0, 0.1, [N_COEFFS * N_EXAMPLES])

In [None]:
cubes = [(cube_noise[i*N_COEFFS:i*N_COEFFS+N_COEFFS] + pd.Series(base_cube_shcoeffs)).to_numpy() for i in range(0,N_EXAMPLES)]
cylinders = [(cyl_noise[i*N_COEFFS:i*N_COEFFS+N_COEFFS] + pd.Series(base_cyl_shcoeffs)).to_numpy() for i in range(0,N_EXAMPLES)]
cones = [(cone_noise[i*N_COEFFS:i*N_COEFFS+N_COEFFS] + pd.Series(base_cone_shcoeffs)).to_numpy() for i in range(0,N_EXAMPLES)]

In [None]:
all_shapes = np.vstack([cubes, cylinders, cones])
labels = ["cube"] * N_EXAMPLES + ["cylinder"] * N_EXAMPLES + ["cone"] * N_EXAMPLES

In [None]:
all_shapes.shape

In [None]:
from sklearn.decomposition import PCA

pca = PCA(2)
pca = pca.fit(all_shapes)
axes = pca.transform(all_shapes)
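For intuition about what `fit`, `transform`, and (later) `inverse_transform` do, here is a minimal NumPy sketch of PCA via the SVD of mean-centred data. It mirrors what scikit-learn's `PCA` computes without whitening, although component signs may differ. The random matrix is a hypothetical stand-in for the `(60, N_COEFFS)` coefficient matrix, and the function names are illustrative.

```python
import numpy as np

def pca_fit(X, n_components):
    """Principal axes of X: top right-singular vectors of the mean-centred data."""
    mean = X.mean(axis=0)
    _, _, vt = np.linalg.svd(X - mean, full_matrices=False)
    return mean, vt[:n_components]          # rows are orthonormal principal axes

def pca_transform(X, mean, components):
    return (X - mean) @ components.T        # coordinates along the principal axes

def pca_inverse_transform(Z, mean, components):
    return Z @ components + mean            # map PC coordinates back to coefficient space

rng = np.random.default_rng(0)
X = rng.normal(size=(60, 10))               # stand-in for the shape-coefficient matrix
mean, comps = pca_fit(X, 2)
Z = pca_transform(X, mean, comps)
X_hat = pca_inverse_transform(Z, mean, comps)
print(Z.shape, X_hat.shape)
```

`inverse_transform` is what makes the latent walk possible: any point in the 2D PC plane maps back to a full vector of spherical harmonic coefficients, which can then be turned into a mesh.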

In [None]:
from viz import get_pca_result_fig

fig = get_pca_result_fig(axes, labels)

fig.show()

In [None]:
pca_df = pd.DataFrame({"PC1":axes[:,0], "PC2":axes[:,1], "shape":labels})
pca_df.head()

In [None]:
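# Average PC coordinates per class; select the numeric columns, since "shape" holds strings
# and DataFrame.mean() raises on non-numeric columns in recent pandas versions
cube_centroid = pca_df.loc[pca_df["shape"] == "cube", ["PC1", "PC2"]].mean().to_numpy()
cone_centroid = pca_df.loc[pca_df["shape"] == "cone", ["PC1", "PC2"]].mean().to_numpy()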

x = np.vstack([cube_centroid, cone_centroid])[:,0]
y = np.vstack([cube_centroid, cone_centroid])[:,1]

In [None]:
from scipy.interpolate import interp1d

# Latent walk in PCA space: get equally spaced points along a line connecting our centroids
n_steps = 5
distance = np.cumsum(np.sqrt(np.ediff1d(x, to_begin=0)**2 + np.ediff1d(y, to_begin=0)**2))
distance = distance/distance[-1]
fx, fy = interp1d(distance, x), interp1d(distance, y)
alpha = np.linspace(0, 1, n_steps)
latent_line_x, latent_line_y = fx(alpha), fy(alpha)
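With only two endpoints, the arc-length and `interp1d` construction reduces to a straight linear interpolation, $p(\alpha) = (1-\alpha)\,c_{\text{cube}} + \alpha\,c_{\text{cone}}$. The sketch below checks this on hypothetical centroid coordinates. The convex-combination form generalizes directly to more than two PCs, while the `interp1d` form is useful if the walk ever passes through intermediate waypoints.

```python
import numpy as np
from scipy.interpolate import interp1d

cube_centroid = np.array([-3.0, 1.0])  # hypothetical (PC1, PC2) centroids for illustration
cone_centroid = np.array([4.0, -2.0])
n_steps = 5
alpha = np.linspace(0, 1, n_steps)

# Convex combination: alpha = 0 gives the cube centroid, alpha = 1 the cone centroid
walk = (1 - alpha)[:, None] * cube_centroid + alpha[:, None] * cone_centroid

# Same points via normalized cumulative distance + interp1d, as in the cell above
x = np.array([cube_centroid[0], cone_centroid[0]])
y = np.array([cube_centroid[1], cone_centroid[1]])
distance = np.cumsum(np.sqrt(np.ediff1d(x, to_begin=0) ** 2 + np.ediff1d(y, to_begin=0) ** 2))
distance = distance / distance[-1]
walk_interp = np.column_stack([interp1d(distance, x)(alpha), interp1d(distance, y)(alpha)])
print(walk)
```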

In [None]:
from viz import get_pca_clust_latent_walk_fig

fig = get_pca_clust_latent_walk_fig(axes, latent_line_x, latent_line_y, labels)
fig.show()

In [None]:
latent_walk_meshes = []
for i in range(n_steps):
    # Map the i-th point of the walk from PC space back to SH coefficient space
    shcoeffs_i = pca.inverse_transform(np.array([latent_line_x[i], latent_line_y[i]]).reshape(1, 2))
    # Re-attach the coefficient names so the coefficients can be turned back into a mesh
    shcoeffs_dict = dict(zip(base_cube_shcoeffs.keys(), shcoeffs_i.squeeze()))
    recon_mesh = get_mesh_from_series(shcoeffs_dict, LMAX)
    latent_walk_meshes.append(recon_mesh)

In [None]:
pl = pv.Plotter(window_size=[900,300], shape=(1, n_steps))
pl.set_background("white")

for i in range(n_steps):
    pl.subplot(0,i)
    pl.add_mesh(latent_walk_meshes[i], color="lightgrey")
    pl.add_title(f"Latent index {i}", font_size=8)
pl.show()

## Conclusion

<img src="resources/variance_paper_fig.png" width="80%"/>
TODO
- link to resources
- possibly mention additional notebook on real data