# SerialTomo Quickstart
**SerialTomo** is a Python library designed to aid research in serial section tomography. Its key functionality is built on top of *JAX*, a high-performance NumPy-like library that provides GPU acceleration and automatic differentiation.

In this guide we'll use SerialTomo to simulate and reconstruct one section from a dual-tilt series. This will allow us to highlight 4 key functions:
- ```project``` Applies a Radon transform to a 3D volume. The output is differentiable with respect to the volume, tilt angles, and tilt axes.
- ```alignstacks``` Coarsely aligns multiple adjacent tilt series into a *linogram* representation. Uses SIFT to find correspondences and RANSAC to estimate a projective transformation between pairs.
- ```minimize``` Performs gradient descent with automatic learning rate tuning.
- ```viewstack``` Creates a widget for scrolling through 3D stacks in a Jupyter Notebook.

Let's get started!

## Setup

In [None]:
%matplotlib widget
%load_ext autoreload
%autoreload 2

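# cuda functionality: choose which GPU is visible (must be set before JAX initializes; adjust the index for your machine)
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '3'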

# imports
import jax 
import jax.numpy as jnp
from jax import value_and_grad, jit

import tifffile
import numpy as np
import matplotlib.pyplot as plt

# serialtomo functions
from serialtomo.align import alignstacks
from serialtomo.project import project
from serialtomo.minimize import minimize
from serialtomo.visualize import viewstack

In [None]:
# first we need a dataset for our simulation. We'll load a FIBSEM volume of fly brain from Janelia
volume = tifffile.imread('../example_data/density.tif').transpose(2,1,0).astype('float32')
volume = volume / volume.std()
print(f'{volume.shape=}')

## ```viewstack```

We'll show the volume using the `viewstack` utility, which is built on top of the Python library *stackview*.  
The **SerialTomo** interface adds a few convenience utilities to adjust the size, change the view, and adjust brightness & contrast.

In [None]:
viewstack(volume,view='xy',size=1/3,pmin=1.0,pmax=99.)
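The brightness and contrast adjustment can be pictured as percentile clipping. The sketch below assumes that `pmin` and `pmax` are intensity percentiles, which is how the call above reads but is not confirmed here. `percentile_rescale` is a hypothetical helper, not part of SerialTomo.

```python
import numpy as np

def percentile_rescale(stack, pmin=1.0, pmax=99.0):
    """Clip a stack to its [pmin, pmax] percentile range and rescale to [0, 1].

    Hypothetical helper; assumes pmin/pmax are percentiles in [0, 100].
    """
    lo, hi = np.percentile(stack, [pmin, pmax])
    return np.clip((stack - lo) / (hi - lo), 0.0, 1.0)

# toy data: intensities 0..100, so the 1st and 99th percentiles are 1 and 99
demo_stack = np.arange(101, dtype=float)
demo_rescaled = percentile_rescale(demo_stack, pmin=1.0, pmax=99.0)
```

Clipping to percentiles rather than to the raw minimum and maximum keeps a few outlier voxels from washing out the display.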

## ```project```

This function applies a Radon transform to a 3D volume. The output can be differentiated with respect to the volume, the tilt angle and the tilt axis.  
**Warning** This Radon transform preserves the image size and global rotation for all tilt angles/axes, unlike conventional implementations of the discrete Radon transform.

We diagram this "stretched" Radon transform below.

<!-- ![image-formation](../images/tutorial/image-formation.png =250x250) -->

<img src="../images/quickstart/image-formation.png" width="900">
To generate a tilt image, the volume is held fixed and the detector is displaced by an amount determined by the tilt angle $\theta$ and tilt axis $\phi$: by $L \tan(\theta) \cos(\phi)$ in the $x$ direction and $L \tan(\theta) \sin(\phi)$ in the $y$ direction, where $L$ is the distance between the center of the volume and the detector plane.
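As a quick numerical check of the displacement formula, here is a minimal sketch. `detector_shift` is a hypothetical helper written for this guide, not a SerialTomo function, and it assumes angles are given in degrees.

```python
import numpy as np

def detector_shift(L, tilt_angle_deg, tilt_axis_deg):
    """Detector displacement (dx, dy) for tilt angle theta and tilt axis phi.

    Hypothetical helper implementing dx = L tan(theta) cos(phi),
    dy = L tan(theta) sin(phi), with both angles in degrees.
    """
    theta = np.deg2rad(tilt_angle_deg)
    phi = np.deg2rad(tilt_axis_deg)
    return L * np.tan(theta) * np.cos(phi), L * np.tan(theta) * np.sin(phi)

# at 45 degrees the detector moves by L, along x for phi=0 and along y for phi=90
shift_axis_0 = detector_shift(10.0, 45.0, 0.0)
shift_axis_90 = detector_shift(10.0, 45.0, 90.0)
```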

Rays are cast from the center (in the xy plane) of each central (in the depth direction) voxel down to the center of each detector pixel. The volume is interpolated and summed along each ray to generate a tilt image. This operation preserves the size of the image for all tilt angles. Note that because the rays are always spaced 1 pixel apart in the volume, this operation is a convolution (and it is in fact implemented as a sparse convolution in the `project` function).
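The ray-casting description above amounts to shifting each depth slice in proportion to its distance from the central slice and then summing. The toy version below illustrates that structure in plain NumPy. It is not the `project` implementation: `toy_project` is a hypothetical name, shifts are rounded to whole pixels instead of interpolated, `np.roll` wraps around at the borders, and the sign conventions are assumptions.

```python
import numpy as np

def toy_project(volume, tilt_angle_deg, tilt_axis_deg):
    """Shear-and-sum projection of a (depth, height, width) volume.

    Toy illustration only: integer shifts, wrap-around borders.
    """
    depth = volume.shape[0]
    theta = np.deg2rad(tilt_angle_deg)
    phi = np.deg2rad(tilt_axis_deg)
    image = np.zeros(volume.shape[1:], dtype=float)
    for k in range(depth):
        # lateral offset of slice k relative to the central slice
        offset = (k - (depth - 1) / 2) * np.tan(theta)
        dx = int(round(offset * np.cos(phi)))
        dy = int(round(offset * np.sin(phi)))
        # shift the slice by (dy, dx) pixels and accumulate along the ray
        image += np.roll(volume[k], shift=(dy, dx), axis=(0, 1))
    return image

# a single bright voxel one slice away from the center of a 3-slice volume
toy_volume = np.zeros((3, 8, 8))
toy_volume[0, 5, 5] = 1.0
untilted = toy_project(toy_volume, 0.0, 0.0)   # plain sum over depth
tilted = toy_project(toy_volume, 45.0, 0.0)    # voxel moves one pixel along x
```

Because every slice is only shifted, never rescaled, the output keeps the same height and width at every tilt angle, which is the size-preserving property described above.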

<!-- $$ T[i,j] = \sum_k x(r[i,j,k,\theta,\phi])$$ -->

Now let's use this operator to simulate a dual tilt series from the FIBSEM volume, with one tilt axis at $\phi=0^\circ$ and the other at $\phi=90^\circ$.

In [None]:
# choose tilt-angles 
tilt_angles = np.linspace(-45,45,45)

# apply the projection operator over two axes
tilts_a = project(volume, tilt_angles, tilt_axes=0.0)
tilts_b = project(volume, tilt_angles, tilt_axes=90.0)

# The output of project is a jax array of size n_tilts x height x width
print(f'{tilts_a.shape=}, {type(tilts_a)=}')

In [None]:
# concatenate and view the tilt series
viewstack(np.concatenate([tilts_a,tilts_b]),size=1/3,pmin=0,pmax=100)

Observe that the borders are darker, because rays near the edges are not fully contained in the volume.

In [None]:
# crop the borders because obliquely oriented border rays are not fully contained in the volume
tilts_a = tilts_a[:,100:-100,100:-100]
tilts_b = tilts_b[:,100:-100,100:-100]

A seemingly bigger issue is that, because our `project` operator does not stretch or rotate, these images do not look like the ones that would be observed in a typical real-life experiment.

This is not a problem. If one wishes to simulate stretching and rotation, the images can simply be rotated and stretched after applying `project`.

In [None]:
from jax.scipy.ndimage import map_coordinates

def foreshorten_rotate(image, tilt_angle, tilt_axis):
    """ Rotate the image so the tilt_axis is horizontal and foreshorten by cos(tilt_angle)"""
    coords = jnp.mgrid[:image.shape[0],:image.shape[1]].astype('float32')
    center = coords.mean(axis=(1,2),keepdims=True)
    coords -= center
    
    # stretch the coordinates about the center of the image
    theta = np.pi / 180 * tilt_angle
    stretch = jnp.array([1.0, jnp.cos(theta)]).reshape((2,1,1))
    coords /= stretch

    # rotate the coordinates
    phi = np.pi / 180 * tilt_axis
    rot_mat = jnp.array([[jnp.cos(phi), jnp.sin(phi)],[-jnp.sin(phi), jnp.cos(phi)]])
    coords = jnp.einsum('ij,jyx->iyx', rot_mat, coords)

    # uncenter
    coords += center
    
    # interpolate via bilinear interpolation
    return map_coordinates(image, coords, order=1)

# rotate and foreshorten tilts
tilts_a = np.array([foreshorten_rotate(tilt, angle, 0.0) for tilt, angle in zip(tilts_a, tilt_angles)])
tilts_b = np.array([foreshorten_rotate(tilt, angle, 90.0) for tilt, angle in zip(tilts_b, tilt_angles)])

In [None]:
# concatenate and view the tilt series
viewstack(np.concatenate([tilts_a,tilts_b]),size=1/3,pmin=0,pmax=100)

Now we have a more realistic simulation, so it's time to move on to the reconstruction.

## ```alignstacks```

This function aligns multiple stacks to each other. These stacks can come from either multiple axes or multiple adjacent sections, though aligning adjacent sections is very challenging for thick sections.  
**Warning** This function is likely to undergo significant implementation (and possibly API) changes, as robust alignment is a challenging task, especially across thick sections, and we are actively working to improve registration quality. Additionally, it does not currently use the GPU.

This method uses SIFT keypoints to register pairs of sections. It then uses RANSAC to estimate a projective transform (characterized by a 3x3 coordinate transformation matrix) from the SIFT keypoints. Within each stack, sections are registered sequentially, and the transformation matrices are composed so that each section is registered to the central section (the presumed $0^\circ$ tilt). These central sections are in turn registered to the stack specified by the `ref_idx` parameter of `alignstacks`.
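To make the "3x3 coordinate transformation matrix" concrete, the sketch below fits a projective transform to matched points with the standard direct linear transform (DLT). It is an illustration, not the code inside `alignstacks`: `fit_homography` and `apply_homography` are hypothetical names, the correspondences are synthetic, and the SIFT matching and the RANSAC loop (which would call a fit like this on random subsets and keep the model with the most inliers) are omitted.

```python
import numpy as np

def fit_homography(src, dst):
    """Least-squares 3x3 projective transform H such that dst ~ H @ src.

    src, dst: (N, 2) arrays of matching (x, y) points, N >= 4.
    Assumes H[2, 2] is nonzero so the matrix can be normalized.
    """
    rows = []
    for (x, y), (u, v) in zip(src, dst):
        rows.append([-x, -y, -1, 0, 0, 0, u * x, u * y, u])
        rows.append([0, 0, 0, -x, -y, -1, v * x, v * y, v])
    # the solution is the right singular vector with the smallest singular value
    _, _, vt = np.linalg.svd(np.asarray(rows, dtype=float))
    H = vt[-1].reshape(3, 3)
    return H / H[2, 2]

def apply_homography(H, pts):
    """Map (N, 2) points through H using homogeneous coordinates."""
    pts_h = np.column_stack([pts, np.ones(len(pts))])
    out = pts_h @ H.T
    return out[:, :2] / out[:, 2:3]

# synthetic correspondences generated from a known transform
H_true = np.array([[1.1, 0.05, 3.0],
                   [-0.02, 0.95, -2.0],
                   [1e-3, 2e-3, 1.0]])
src_pts = np.array([[0, 0], [1, 0], [0, 1], [1, 1], [2, 1], [1, 3]], dtype=float)
dst_pts = apply_homography(H_true, src_pts)
H_est = fit_homography(src_pts, dst_pts)
```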

<img src="../images/quickstart/alignment-strategy.png" width="600" height="500">
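The sequential strategy can be sketched as a chain of matrix products. The example assumes a convention in which `pairwise[i]` maps coordinates of section `i` into section `i+1`; `compose_to_center` is a hypothetical helper, and the real `alignstacks` bookkeeping may differ.

```python
import numpy as np

def compose_to_center(pairwise, center):
    """Chain pairwise 3x3 transforms so every section maps to section `center`.

    pairwise[i] is assumed to map section i coordinates to section i+1.
    """
    n = len(pairwise) + 1
    to_center = [np.eye(3) for _ in range(n)]
    # sections before the center: step forward (i -> i+1), then reuse i+1 -> center
    for i in range(center - 1, -1, -1):
        to_center[i] = to_center[i + 1] @ pairwise[i]
    # sections after the center: step backward (i -> i-1) with the inverse transform
    for i in range(center + 1, n):
        to_center[i] = to_center[i - 1] @ np.linalg.inv(pairwise[i - 1])
    return to_center

def translation(tx, ty):
    """3x3 homogeneous translation matrix."""
    return np.array([[1.0, 0.0, tx], [0.0, 1.0, ty], [0.0, 0.0, 1.0]])

# five sections, each shifted by one pixel in x relative to the previous one
pairwise_demo = [translation(1.0, 0.0) for _ in range(4)]
to_center_demo = compose_to_center(pairwise_demo, center=2)
```

With this convention, section 0 ends up shifted by +2 pixels and section 4 by -2 pixels relative to the central section.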


In [None]:
[aligned_a, aligned_b], info = alignstacks([tilts_a,tilts_b], downsample=4, ref_idx=0)
# downsampling is used to speed up registration; the transformed stacks are still full resolution

In [None]:
viewstack(np.concatenate([aligned_a,aligned_b]), size=1/3, pmin=0, pmax=100)

#### Linogram vs Sinogram

This alignment has unstretched and derotated the images. In other words, it has undone the `foreshorten_rotate` function we applied after `project`-ing, giving us what is known as a *linogram* representation. We can understand the origin of the term by looking at an 'xz' slice of the aligned stack.

In [None]:
# linogram representation
viewstack(aligned_a, view='xz',size=2)

Contrast the linogram with the sinogram, the traditional representation of a tilt series, where features take large sinusoidal trajectories through the tilt series.

In [None]:
# sinogram representation
viewstack(tilts_a, view='xz',size=2)
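The two trajectories can also be written down directly. The sketch below assumes the geometry described for `project`: a feature at depth offset `z` from the central slice shifts laterally by `z * tan(theta)`, and foreshortening multiplies positions by `cos(theta)`. The feature position is made up and the sign conventions are assumptions.

```python
import numpy as np

tilt_angles_deg = np.linspace(-45, 45, 45)   # same sampling as the simulation
theta = np.deg2rad(tilt_angles_deg)

# hypothetical feature: lateral and depth offsets from the volume center, in pixels
x0, z = 20.0, 15.0

# linogram: the position is a straight line when plotted against tan(theta)
linogram_x = x0 + z * np.tan(theta)

# sinogram: foreshortening by cos(theta) turns the line into a sinusoid,
# since (x0 + z tan(theta)) cos(theta) = x0 cos(theta) + z sin(theta)
sinogram_x = linogram_x * np.cos(theta)

# a straight-line fit against tan(theta) recovers the depth (slope) and x0 (intercept)
slope, intercept = np.polyfit(np.tan(theta), linogram_x, 1)
```

Note that the line is straight in $\tan(\theta)$, so with tilt angles sampled uniformly in $\theta$ the trajectory in the stack is only approximately straight at high tilt.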

## ```minimize```

Now we're ready to generate a reconstruction. To do so we will create an energy function which is just the sum-of-squared errors between predicted and "measured" (i.e. simulated) tilts. We will minimize this function with the `minimize` utility provided by SerialTomo.

This function performs gradient descent, using a backtracking line search to estimate a good step size at each iteration. Constraints (such as non-negativity) can be incorporated, and multiple parameter groups can be optimized simultaneously, each with its own learning rate.
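For readers unfamiliar with backtracking line search, here is a generic sketch using the Armijo sufficient-decrease rule on a toy quadratic. It shows the general technique only and may differ from what `minimize` does internally. `backtracking_descent` and all of its parameters are hypothetical, and constraints and parameter groups are left out.

```python
import numpy as np

def backtracking_descent(value_and_grad, x0, step=1.0, shrink=0.5, grow=2.0,
                         c=1e-4, maxiter=100):
    """Gradient descent with an Armijo backtracking line search (generic sketch)."""
    x = np.asarray(x0, dtype=float)
    for _ in range(maxiter):
        value, grad = value_and_grad(x)
        grad_sq = float(np.sum(grad * grad))
        if grad_sq == 0.0:
            break  # stationary point reached
        # halve the step until the energy decreases sufficiently (at most 50 halvings)
        for _ in range(50):
            if value_and_grad(x - step * grad)[0] <= value - c * step * grad_sq:
                break
            step *= shrink
        x = x - step * grad
        step *= grow  # be optimistic: try a larger step at the next iteration
    return x, step

# toy energy: a quadratic bowl with different curvature in each coordinate
scales = np.array([1.0, 10.0])
target = np.array([1.0, -2.0])

def quadratic(x):
    err = x - target
    return float(np.sum(scales * err ** 2)), 2.0 * scales * err

solution, last_step = backtracking_descent(quadratic, np.zeros(2))
```

The step size is never tuned by hand here: it shrinks when a trial step fails the decrease test and grows again after every accepted step.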

In [None]:
# construct the energy function
def energy(volume):
    # generate predictions for each stretched tilt series
    pred_a = project(volume, tilt_angles, tilt_axes=0.0)
    pred_b = project(volume, tilt_angles, tilt_axes=90.0)
    
    # squared error 
    err = ((pred_a - aligned_a)**2).sum() + ((pred_b - aligned_b)**2).sum()

    return err

func = value_and_grad(energy) # func(volume) returns (energy, d_energy/d_volume)
func = jit(func) # just-in-time compilation for improved speed and reduced memory consumption, though the first call may be quite slow

In [None]:
# initialize the reconstruction to be all zeros
init_reconstruction = jnp.zeros((64,1000,1000))

# perform the minimization. 
reconstruction, info = minimize(func, init_reconstruction, maxiter=20)
# regularization is provided by early stopping (i.e. by only performing at most maxiter=20 updates)

Before analyzing the reconstruction, it's worth mentioning that many other methods can be used to minimize the energy:
- hand-implemented gradient descent or stochastic gradient descent with hand-tuned step sizes
- TensorFlow L-BFGS
- Optax for SGD and variants

These may be worth checking out, but the `minimize` function has proven quite robust in experiments thus far.

#### Analysis of reconstructions

Now let's visualize the reconstruction.

In [None]:
# XY slices from the reconstruction
viewstack(reconstruction,size=1/2)

In [None]:
# YZ slices
viewstack(reconstruction, view='yz')

In [None]:
# Compare the real volume and the reconstructed volume
# if one provides viewstack with two stacks, it overlays them in a curtain fashion.
viewstack(volume[:,100:-100,100:-100], reconstruction, size=1/2, view='xy')

#### Analysis of errors

We can compare the tilts predicted from the reconstruction to the aligned "measured" tilts.

In [None]:
pred_a = project(reconstruction, tilt_angles, tilt_axes=0.0)
viewstack(pred_a, aligned_a, size=1/2, view='xy')

#### Solve for volume and tilt axes
Before concluding, let's do something more complicated. What if we guessed the initial tilt axes incorrectly? Let's just add them to the energy as parameters and let `minimize` solve for them.

In [None]:
# construct the energy function
def energy(params):
    # unpack params. params can be any pytree of jax arrays
    volume, tilt_axes = params
    
    # generate predictions for each stretched tilt series
    pred_a = project(volume, tilt_angles, tilt_axes[0])
    pred_b = project(volume, tilt_angles, tilt_axes[1])
    
    # squared error 
    err = ((pred_a - aligned_a)**2).sum() + ((pred_b - aligned_b)**2).sum()

    return err

func = value_and_grad(energy) # func(params) returns (energy, grads), where grads has the same structure as params
func = jit(func) # just-in-time compilation for improved speed and reduced memory consumption, though the first call may be quite slow
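To see what differentiating with respect to a pytree returns, here is a small standalone JAX example. The toy energy and every name in it (`toy_energy`, `toy_params`, `gains`) are made up for illustration and do not involve `project`.

```python
import jax
import jax.numpy as jnp

def toy_energy(params):
    # unpack the pytree, here a tuple of two arrays
    signal, gains = params
    pred_a = gains[0] * signal
    pred_b = gains[1] * signal
    # squared error against two constant "measurements"
    return ((pred_a - 1.0) ** 2).sum() + ((pred_b - 2.0) ** 2).sum()

toy_func = jax.jit(jax.value_and_grad(toy_energy))

toy_params = (jnp.ones(3), jnp.array([1.0, 1.0]))
toy_value, toy_grads = toy_func(toy_params)

# the gradient is a tuple with one array per parameter group, matching their shapes
grad_signal, grad_gains = toy_grads
```

Because the gradient mirrors the structure of the parameters, an optimizer can keep a separate step size for each group.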

In [None]:
# initialize the reconstruction to be all zeros
init_reconstruction = jnp.zeros((64,1000,1000))
# initial tilt-axis guesses in degrees (the second is deliberately off from the true 90)
init_angles = jnp.array([0.0,80.0])
params = (init_reconstruction, init_angles)

# perform the minimization. 
params, info = minimize(func, params, maxiter=30)
# regularization is provided by early stopping (i.e. by performing at most maxiter=30 updates)

This example better illustrates the utility of the `minimize` function. The step sizes for the tilt axes are nearly 4 orders of magnitude smaller than those for the volume. We could eventually have found this by hand, but `minimize` determined them automatically.

In [None]:
tilt_axes = params[1]
print(f're-estimated axis #1: {tilt_axes[0]} degrees')
print(f're-estimated axis #2: {tilt_axes[1]} degrees')

In [None]:
viewstack(params[0],size=1/2)