## Opti-jax Fourier Ptychography example.

* This example shows how to use opti-jax to fit Fourier Ptychography (FPty) microscopy data.


In [None]:
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import tifffile

import opti_jax.optics_fpty as optics_fpty


In [None]:
# Create the opti-jax Fourier ptychography (FPty) solver object.
#
# Parameters:
#   NA         - the objective numerical aperture.
#   pixelSize  - the camera pixel size at the image plane of the microscope
#                objective (presumably the same length units as wavelength).
#   shape      - (X, Y) size of the images in pixels.
#   wavelength - the center wavelength of the detection.
ofpty = optics_fpty.OpticsFPty(
    NA = 0.3,
    pixelSize = 0.1,
    shape = (128, 128),
    wavelength = 0.514,
)


### Load test image

In [None]:
# We will use a USAF 1951 test target example image.
#
img = tifffile.imread("1951usaf_test_target.tif").astype(float)

# Crop a (2*hw) x (2*hw) window near the image center. 2*hw should match the
# `shape` given to the solver above (128 x 128).
c0 = img.shape[0]//2 - 2
c1 = img.shape[1]//2 + 22
hw = 64
slc1 = slice(c0-hw,c0+hw)
slc2 = slice(c1-hw,c1+hw)

# Crop first so the complex conversion only runs on the pixels we keep
# (the values are identical to converting first and cropping afterwards).
img = img[slc1,slc2]

# Phase image: constant amplitude 0.8, with the grey level scaled so that a
# value of 330 maps to a phase of pi/2 radians.
img = 0.8*np.exp(1j*0.5*np.pi*img/330.0)

fig, axs = plt.subplots(1, 2, figsize = (10, 5))
axs[0].imshow(ofpty.intensity(img), cmap = "gray", vmin = 0.0, vmax = 1.0)
axs[0].set_title("Intensity")
axs[1].imshow(np.angle(img), cmap = "gray", vmin = -1.0, vmax = 1.0)
axs[1].set_title("Phase (radians)")
plt.show()


### Create FPty illumination vectors and plot.

In [None]:
# Figure out k space shift values to use.
#
# As an optimization, and also to reduce edge effects, this solver shifts the
# current best fit image in k space by integer amounts. It is possible to use
# arbitrary k values but this is slower and you will see ringing at the edge
# of the image.
#

# Print the maximum k values in X/Y based on the objective numerical aperture.
#
print(ofpty.kvalue_range())

# Create illumination vectors.
#
# The shifts are integers, symmetric about zero, spaced kstep apart, up to
# +/- kmax. Multiplying by ofpty.dk0 gives the corresponding k values
# (presumably dk0 is the k-space pixel spacing).
# With kstep = 6 and kmax = 18 this is [-18, -12, -6, 0, 6, 12, 18].
#
kstep = 6
kmax = 18
ikv = list(range(-kmax, kmax + 1, kstep))
print(ikv)
print(np.array(ikv)*ofpty.dk0)

# All (v0, v1) pairs from ikv, with v1 varying fastest (row-major order).
rxy = jnp.array([[v0, v1] for v0 in ikv for v1 in ikv])

# Plot vectors in k space.
ofpty.plot_illumination_vectors(rxy)

In [None]:
# Use the test image to generate the images that we will fit.
#
# The object is passed to the solver as [real part, imaginary part].
xrc = [jnp.array(img.real), jnp.array(img.imag)]
Y = ofpty.fpty_illumination(xrc, rxy)

print(Y.shape)

# Show the simulated images on a grid matching the illumination grid.
# rxy was built with its second component varying fastest, so (assuming Y is
# ordered like rxy) image i*ny + j corresponds to the shift (ikv[i], ikv[j]).
# The grid size follows ikv instead of being hard-coded.
nx = len(ikv)
ny = len(ikv)
# squeeze=False keeps axs 2-D even for a 1 x 1 grid; figsize is (width, height).
fig, axs = plt.subplots(nx, ny, figsize = (ny*3, nx*3), squeeze = False)
for i in range(nx):
    for j in range(ny):
        ax = axs[i, j]
        ax.imshow(Y[i*ny + j], cmap = "gray")
        ax.set_xticks([])
        ax.set_yticks([])

plt.show()


### Solve and check results

In [None]:
# Fit the object: solve for the complex object whose simulated FPty
# illumination images best match the images in Y.
#
# The fit uses total variation (TV) regularization:
#   lval         - weight of the TV term.
#   order        - order of the TV term (1 = first derivative,
#                  2 = second derivative).
#   learningRate - step size handed to the solver (presumably the optimizer's
#                  learning rate).
#
x, stats = ofpty.tv_solve(
    Y,
    rxy,
    lval = 1.0e-6,
    order = 2,
    learningRate = 1.0e-1,
)


In [None]:
# Plot convergence using the statistics returned by tv_solve.
ofpty.plot_stats(stats)


In [None]:
# Plot the recovered object's pixel values in the complex plane.
# NOTE(review): x is presumably the [real, imag] pair (same layout as xrc).
ofpty.plot_x(x)


In [None]:
# Compare the object's amplitude and phase to the ground truth image.
#
# Top row:    (unused), ground truth amplitude, ground truth phase.
# Bottom row: mean of the fitted images Y, recovered amplitude, recovered phase.
#
# Combine the [real, imag] solution once instead of once per panel.
x_complex = x[0] + 1j*x[1]

fig, axs = plt.subplots(2, 3, figsize = (12, 8))
# Nothing is drawn in the top-left slot; hide it instead of showing an empty frame.
axs[0,0].axis("off")
axs[0,1].imshow(jnp.abs(img), cmap = "gray", vmin = 0.0, vmax = 1.0)
axs[0,1].set_title("Ground truth amplitude")
axs[0,2].imshow(jnp.angle(img), cmap = "gray")
axs[0,2].set_title("Ground truth phase")
axs[1,0].imshow(jnp.mean(Y, axis = 0), cmap = "gray")
axs[1,0].set_title("Mean of images Y")
axs[1,1].imshow(jnp.abs(x_complex), cmap = "gray", vmin = 0.0, vmax = 1.0)
axs[1,1].set_title("Recovered amplitude")
axs[1,2].imshow(jnp.angle(x_complex), cmap = "gray")
axs[1,2].set_title("Recovered phase")
plt.show()
