## Opti-jax z (focus) stack example.

* This example shows how to use opti-jax to fit a z (focus) stack dataset.


In [None]:
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import tifffile

import opti_jax.optics_zstack as optics_zs


In [None]:
# Create the opti-jax z (focus) stack solver object.
#
# NA         - objective numerical aperture.
# pixelSize  - camera pixel size at the image plane of the microscope objective.
# shape      - image size (X, Y).
# wavelength - center wavelength of the detection.
#
# NOTE(review): pixelSize and wavelength are presumably given in the same
# length unit (microns?) - confirm against the OpticsZStack documentation.
ozs = optics_zs.OpticsZStack(
    NA=0.5,
    pixelSize=0.1,
    shape=(128, 128),
    wavelength=0.514,
)

# Random number generator, used further down to add noise to the simulated
# images. No seed is supplied, so the noise is different on every run.
rng = np.random.default_rng()

### Load test image

In [None]:
# Load the USAF 1951 test target image as a float array. It is the ground
# truth object for this example.
img = tifffile.imread("1951usaf_test_target.tif").astype(float)

# Crop window: 2*hw pixels on a side, offset from the image center by
# -2 along the first axis and +22 along the second axis.
hw = 64
c0, c1 = img.shape[0]//2 - 2, img.shape[1]//2 + 22
slc1, slc2 = slice(c0 - hw, c0 + hw), slice(c1 - hw, c1 + hw)

# Alternative amplitude-only object (disabled).
#img = 0.8*((255 - 0.5*img)/255)

# Phase object: constant amplitude of 0.8, with a phase that is proportional
# to the pixel value of the test target.
img = 0.8*np.exp(1j*0.5*np.pi*img/330.0)

# Keep only the crop window.
img = img[slc1, slc2]

# Left panel: ozs.intensity() of the object. Right panel: phase of the object.
fig, axs = plt.subplots(1, 2, figsize=(10, 5))
axs[0].imshow(ozs.intensity(img), cmap="gray", vmin=0.0, vmax=1.0)
axs[1].imshow(np.angle(img), cmap="gray", vmin=-1.0, vmax=1.0)
plt.show()


### Create brightfield illumination vectors and plot.

In [None]:
# Choose the k space shift values to use.
#
# As an optimization, and also to reduce edge effects, this solver shifts the
# current best fit image in k space by integer amounts. Arbitrary k values are
# possible, but they are slower and produce ringing at the edge of the image.
#

# Report the maximum k values in X/Y based on the objective numerical aperture.
print(ozs.kvalue_range())

# Integer k shifts from -12 to 12 in steps of 3, printed both as integers and
# scaled by ozs.dk0.
ikv = np.arange(-12, 13, 3)
print(ikv)
print(ikv*ozs.dk0)

# Build the brightfield illumination vectors. The same shift values are passed
# for both shift arguments.
# NOTE(review): the first argument (0.5) is presumably the illumination NA -
# confirm against make_bf_pattern().
pat, intens = ozs.make_bf_pattern(0.5, ikv, ikv)

# Show the illumination vectors in k space.
ozs.plot_pattern(pat, intens, mscale=20)

In [None]:
# Use the test image to generate the images that we will fit.

# Z (focus) values of the stack.
zvals = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])

# The object is handed to the solver as separate real and imaginary parts.
xrc = [jnp.array(img.real), jnp.array(img.imag)]

# Predicted (noise-free) images for this object, illumination and z values.
Y = ozs.y_pred(xrc, [pat, intens, zvals])
print(Y.shape)

# Show one panel per frame along the first axis of Y. The panel count is taken
# from Y instead of being hard-coded, so editing zvals does not require
# touching the plotting code. (Presumably there is one frame per z value.)
n_frames = Y.shape[0]
fig, axs = plt.subplots(1, n_frames, figsize=(4*n_frames, 4))

# plt.subplots() returns a bare Axes object (not an array) when there is only
# a single panel; np.atleast_1d() makes the loop work in that case too.
for i, ax in enumerate(np.atleast_1d(axs)):
    ax.imshow(Y[i], cmap="gray")
    ax.set_xticks([])
    ax.set_yticks([])
plt.show()


In [None]:
# Add Gaussian noise for additional realism; the noise is at about the 1% level.
# Yn is the noisy stack that is passed to the solver below.
Yn = Y + rng.normal(scale = 0.006, size = Y.shape)


### Solve and check results

In [None]:
# Solve for the object whose illumination best matches the Z stack images.
#
# This solves with total variation regularization.
# Yn is the noisy z stack to fit; [pat, intens, zvals] are the same illumination
# vectors, intensities and z values that were used to simulate the stack.
# lval is the weight of the TV term.
# order is the order of the TV term (1 = first derivative, 2 = second derivative).
# NOTE(review): learningRate is presumably the optimizer step size - confirm
# against solve_tv().
#
# Returns x, the fitted object as [real part, imaginary part], and stats, which
# is plotted below to check convergence.
#
x, stats = ozs.solve_tv(Yn, [pat, intens, zvals], lval = 1.0e-5, order = 2, learningRate = 1.0e-1)


In [None]:
# Plot the statistics returned by solve_tv() to check convergence.
ozs.plot_stats(stats)


In [None]:
# Plot the fitted object's pixels in the complex plane.
ozs.plot_x(x)


In [None]:
# Compare the fitted object's amplitude and phase to the ground truth image.
#
# Top row:    (nothing drawn), ground truth amplitude, ground truth phase.
# Bottom row: mean of the noise-free stack Y, fitted amplitude, fitted phase.
#
fig, axs = plt.subplots(2, 3, figsize=(12, 8))

# x holds the real and imaginary parts of the fit; combine them once.
xfit = x[0] + 1j*x[1]

# Ground truth. axs[0, 0] is not drawn into.
axs[0, 1].imshow(jnp.abs(img), cmap="gray", vmin=0.0, vmax=1.0)
axs[0, 2].imshow(jnp.angle(img), cmap="gray")

# Data and fit.
axs[1, 0].imshow(jnp.mean(Y, axis=0), cmap="gray")
axs[1, 1].imshow(jnp.abs(xfit), cmap="gray", vmin=0.0, vmax=1.0)
axs[1, 2].imshow(jnp.angle(xfit), cmap="gray")
plt.show()
