## Opti-jax DPC example.

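* This example shows how to use opti-jax to fit differential phase contrast (DPC) / quantitative phase microscopy (QPM) data, simulated here from a USAF 1951 test target image.
* This example also fits for the pupil function, which in this simulation contains spherical aberration.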


In [None]:
import jax.numpy as jnp
import math
import matplotlib.pyplot as plt
import numpy as np
import tifffile

import opti_jax.optics_dpc as optics_dpc
import opti_jax.zernike as zernike


In [None]:
# Create the opti-jax DPC solver object.
#
# NA is the objective numerical aperture.
# pixelSize is the camera pixel size referred to the sample side of the
#   microscope objective (presumably microns, the same units as wavelength).
# shape (X, Y) is the image size in pixels.
# wavelength is the center wavelength of the detection (0.514 suggests
#   microns, i.e. 514 nm).
#
odpc = optics_dpc.OpticsDPCVp(NA = 0.8, pixelSize = 0.1, shape = (128, 128), wavelength = 0.514)

# Random number generator, used later to add noise to the simulated DPC
# images. It is seeded so that re-running the notebook reproduces the same
# results; change the seed (or pass None) for a different noise realization.
RNG_SEED = 0
rng = np.random.default_rng(RNG_SEED)

### Load test image

In [None]:
# We will use a USAF 1951 test target example image.
#
img = tifffile.imread("1951usaf_test_target.tif").astype(float)

# Crop a (2*hw) x (2*hw) region around a point offset from the image center.
# hw = 64 gives the 128 x 128 size the solver was created with; the offsets
# select the part of the target used in this example.
c0 = img.shape[0]//2 - 2
c1 = img.shape[1]//2 + 22
hw = 64
slc1 = slice(c0-hw,c0+hw)
slc2 = slice(c1-hw,c1+hw)

# Crop before converting to a complex field. The conversion is element-wise,
# so the result is identical but only the kept pixels are exponentiated.
img = img[slc1,slc2]

# Slicing silently returns fewer pixels if the test image is too small, which
# would only fail later inside the solver; fail here with a clear message.
if img.shape != (2*hw, 2*hw):
    raise ValueError(f"Cropped test image has shape {img.shape}, expected ({2*hw}, {2*hw}).")

# Phase image: uniform amplitude 0.8, with a phase of 0.5*pi*value/330 radians
# that is proportional to the test target's pixel values.
img = 0.8*np.exp(1j*0.5*np.pi*img/330.0)

# Left: odpc.intensity() of the field. Right: its phase in radians (values
# outside [-1, 1] are clipped in this display).
fig, axs = plt.subplots(1, 2, figsize = (10, 5))
axs[0].imshow(odpc.intensity(img), cmap = "gray", vmin = 0.0, vmax = 1.0)
axs[1].imshow(np.angle(img), cmap = "gray", vmin = -1.0, vmax = 1.0)
plt.show()


### Create DPC illumination patterns and simulate the DPC images.

In [None]:
# Figure out k space shift values to use.
#
# As an optimization, and also to reduce edge effects, this solver shifts the
# current best fit image in k space by integer amounts. It is possible to use
# arbitrary k values but this is slower and you will see ringing at the edge
# of the image.
#

# Print the maximum k values in X/Y based on the objective numerical aperture.
#
maxkv = odpc.kvalue_range()
print(maxkv)

# Create an array of integer k values, kstep apart, that covers the k space
# range.
#
# The grid is built outward from zero, so it always contains k = 0 and has as
# many negative values as positive ones. Starting at -ceil(maxkv) and stepping
# by kstep only achieves that when ceil(maxkv) is a multiple of kstep. Rounding
# the limit up to the next multiple of kstep keeps the whole range covered.
#
# Only the X range (maxkv[0]) is used because this example's image is square.
#
kstep = 5
kmax = kstep * int(math.ceil(maxkv[0] / kstep))
ikv = np.arange(-kmax, kmax + 1, kstep)
print(ikv, len(ikv))


In [None]:
# DPC illumination patterns.
#
# make_dpc_patterns takes one array of k values per image axis. This image is
# square, so the same values are used for both; for non-square images define
# the k values for each axis separately. Depending on the DPC data sequence,
# the patterns may also need to be adjusted.
#
pattern_kvalues = (ikv, ikv)
pats = odpc.make_dpc_patterns(odpc.NA, *pattern_kvalues)

# Display the illumination patterns.
odpc.plot_patterns(pats)


In [None]:
# Spherical aberration pupil function.
#
# A single Zernike polynomial evaluated for this solver. This example uses
# index 10 for spherical aberration; the indexing convention is defined by
# opti_jax.zernike.zern_poly (TODO confirm Noll vs. OSA ordering).
ZERNIKE_SPHERICAL = 10
pfsa = zernike.zern_poly(odpc, ZERNIKE_SPHERICAL)

fig, ax = plt.subplots()
ax.imshow(pfsa, cmap = "gray")
plt.show()


In [None]:
# Use the test image to generate the DPC images that we will fit.
#
# xrc is the ground truth in the same layout as the solver's result: the real
# part of the object, the imaginary part of the object, and the pupil function
# (the spherical aberration term scaled by PUPIL_SCALE).
PUPIL_SCALE = 0.5
xrc = [jnp.array(img.real), jnp.array(img.imag), PUPIL_SCALE*jnp.array(pfsa)]
Y = odpc.y_pred(xrc, pats)

# Show one simulated image per illumination pattern. The panel count follows
# the number of images instead of being hard-coded, so this keeps working if
# the pattern set changes. squeeze = False keeps axs 2D even for one image.
n_images = len(Y)
fig, axs = plt.subplots(1, n_images, figsize = (5*n_images, 5), squeeze = False)
for i, y_image in enumerate(Y):
    axs[0, i].imshow(y_image, cmap = "gray")
plt.show()


In [None]:
# Add Gaussian noise for additional realism. The noise is at about the 1%
# level: a standard deviation of 0.006 compared with intensities of order
# 0.64 (= 0.8**2, the test object's amplitude squared).
Yn = Y + rng.normal(scale = 0.006, size = Y.shape)


### Solve and check results

In [None]:
# Solve for the object (and pupil function) whose predicted DPC images best
# match the noisy DPC images Yn.
#
# The solver uses total variation (TV) regularization:
#   lval         - weight of the TV term.
#   lvalp        - presumably the corresponding weight for the pupil function
#                  (TODO confirm against OpticsDPCVp.solve_tv).
#   order        - order of the TV term (1 = first derivative, 2 = second derivative).
#   learningRate - presumably the optimizer step size.
#
# x holds the fitted [real part, imaginary part, pupil] arrays and stats the
# convergence statistics.
#
tvOptions = {
    "lval" : 1.0e-5,
    "lvalp" : 1.0e-3,
    "order" : 2,
    "learningRate" : 1.0e-1,
}
x, stats = odpc.solve_tv(Yn, pats, **tvOptions)


In [None]:
# Plot the solver's convergence using the statistics returned by solve_tv().
odpc.plot_stats(stats)


In [None]:
# Plot the fitted object's complex pixel values (real part x[0], imaginary
# part x[1]) in the complex plane.
odpc.plot_x(x)


In [None]:
# Compare the fitted object's amplitude and phase, and the fitted pupil
# function, to the ground truth. Top row: ground truth. Bottom row: fit.
#
# The ground-truth pupil is xrc[2], the scaled pupil that was actually used to
# simulate the DPC images, rather than the unscaled pfsa. Both pupils have their
# [0,0] value subtracted (presumably to remove a constant offset that the DPC
# images cannot determine - TODO confirm). Colorbars show their scales, since
# imshow auto-scales each panel independently.
#
truePupil = xrc[2] - xrc[2][0,0]
fitPupil = x[2] - x[2][0,0]
xComplex = x[0] + 1j*x[1]

fig, axs = plt.subplots(2, 3, figsize = (12, 8))
axs[0,0].imshow(jnp.abs(img), cmap = "gray", vmin = 0.0, vmax = 1.0)
axs[0,0].set_title("amplitude (truth)")
axs[0,1].imshow(jnp.angle(img), cmap = "gray")
axs[0,1].set_title("phase (truth)")
im = axs[0,2].imshow(truePupil, cmap = "gray")
axs[0,2].set_title("pupil (truth)")
fig.colorbar(im, ax = axs[0,2])

axs[1,0].imshow(jnp.abs(xComplex), cmap = "gray", vmin = 0.0, vmax = 1.0)
axs[1,0].set_title("amplitude (fit)")
axs[1,1].imshow(jnp.angle(xComplex), cmap = "gray")
axs[1,1].set_title("phase (fit)")
im = axs[1,2].imshow(fitPupil, cmap = "gray")
axs[1,2].set_title("pupil (fit)")
fig.colorbar(im, ax = axs[1,2])

plt.show()


In [None]:
# Check how well the DPC images predicted from the fitted object match the
# simulated DPC images.
#
# NOTE(review): this compares against the noise-free images Y, while the solver
# was fit to the noisy images Yn. Pass Yn instead to compare against the images
# that were actually fit.
#
odpc.plot_fit_images(Y, x, pats)
