# Microlens array playground

In [None]:
import chromatix.functional as cf
import jax.numpy as jnp
import matplotlib.pyplot as plt

# # Import the microlens array element
# from chromatix.elements import ThinLens, MicroLensArray

### Create a plane wave field

In [None]:
# Polarised (scalar=False) plane wave on a 512 x 512 grid.
# Note: `1 / 2 * jnp.pi` evaluates left-to-right, i.e. to pi / 2 (not 1 / (2 * pi)).
field = cf.plane_wave(
    (512, 512),  # grid shape in pixels
    1.0,  # presumably the pixel spacing dx -- confirm against chromatix signature
    0.532,  # presumably the wavelength (spectrum)
    1.0,  # presumably the spectral density
    amplitude=cf.linear(1 / 2 * jnp.pi),
    scalar=False,
)

### Define helpful variables and functions

In [None]:
# Functions
def retrieve_micron_dims(field):
    """Return the field's physical extent as ``[left, right, bottom, top]``.

    The window is centred on the origin. Its width comes from axis -3 of
    ``field.shape`` and its height from axis -4. Both are scaled by the first
    entry of the squeezed ``field.dx``, so the result is in the same length unit
    as ``dx`` (presumably microns). The list is suitable as an ``imshow`` extent.
    """
    pitch = field.dx.squeeze()[0]
    width = field.shape[-3] * pitch
    height = field.shape[-4] * pitch
    half_w = width / 2
    half_h = height / 2
    return [-half_w, half_w, -half_h, half_h]


def add_intensity_to_axes(field, ax, extent=None):
    """Draw ``field``'s intensity on ``ax`` with both axes labelled in microns.

    Args:
        field: Field whose ``intensity`` (squeezed to drop singleton axes) is
            drawn as an image.
        ax: Matplotlib-style axes providing ``imshow``, ``set_xlabel``,
            ``set_ylabel``, ``set_xlim`` and ``set_ylim``.
        extent: ``[left, right, bottom, top]`` passed to ``imshow``. When
            omitted, it defaults to ``retrieve_micron_dims(field)``.

    Returns:
        The same ``ax``, so calls can be chained.
    """
    if extent is None:
        extent = retrieve_micron_dims(field)
    left, right, bottom, top = extent
    ax.imshow(field.intensity.squeeze(), extent=extent)
    ax.set_xlabel("microns")
    ax.set_ylabel("microns")
    # Take the view limits from the same extent as the image. The old version
    # recomputed them from the field and ignored the caller's `extent`.
    ax.set_xlim([left, right])
    ax.set_ylim([bottom, top])
    return ax

# Earlier setup experiments, kept for reference only (not executed):
# z = 100
# spectrum = 0.532
# n = 1.33
# Nf = (D / 2) ** 2 / (spectrum / n * z)
# source_field = VectorPlaneWave(shape=(512, 512), dx = 0.0001, n = 1, spectrum=spectrum, spectral_density=1.0, k = k, Ep = Ep)
# field = cx.empty_field((N, N), dxi, 0.532, 1.0, polarized=True)
# plane_wave_field = cx.plane_wave(field, pupil=lambda field: cx.square_pupil(field, dxi * N))

# --- Optical parameters ---
wavelength = 0.532
n_medium = 1.33
vectorial = False

# --- Grid parameters ---
D = 256  # 40
N = 256
# dxi = D / N
Q = 5
N_pad = Q * N  # padded grid size (Q-fold zero padding, presumably)

# --- Plane-wave direction ---
# phi: polar angle of propagation measured from the z axis. The transverse
#      wavevector scales with sin(phi), so phi = 0 means travel along z.
# theta: azimuth in the transverse plane, measured from the x axis
#        (kx ~ cos(theta), ky ~ sin(theta)).
phi = 0
theta = 0
k = (n_medium * 2 * jnp.pi / wavelength) * jnp.array(
    [jnp.sin(phi) * jnp.sin(theta), jnp.sin(phi) * jnp.cos(theta)]
)  # ordered (ky, kx)
Ep = jnp.array((1, 0, 0))  # polarisation vector

spacing = D / N  # sample spacing implied by D and N

### Create a plane wave field that has passed through a circular aperture

In [None]:
# Polarised plane wave on a 256 x 256 grid, masked by a circular pupil of size
# `spacing * N` (equal to D with the values set above).
# Note: `1 / 2 * jnp.pi` evaluates to pi / 2.
field = cf.plane_wave(
    (256, 256),  # grid shape in pixels
    1.0,  # presumably dx
    0.532,  # presumably the wavelength
    1.0,  # presumably the spectral density
    amplitude=cf.linear(1 / 2 * jnp.pi),
    scalar=False,
    pupil=lambda field: cf.circular_pupil(field, spacing * N),
)

In [None]:
# Plot the circular-aperture field's intensity on micron-scaled axes.
source_field = field
extent = retrieve_micron_dims(source_field)
fig, ax = plt.subplots(nrows=1, ncols=1)
ax = add_intensity_to_axes(source_field, ax, extent)

In [None]:
## This cell currently fails for an unknown reason, so it is left commented out.
## Propagate through MicroLensArray
# model = ThinLens(10.0, 1.33, 0.8)
# field_after_lens = model(field)

In [None]:
# Scalar counterpart of the circular-aperture plane wave above (scalar=True).
# Note: `1 / 2 * jnp.pi` evaluates to pi / 2.
scalar_field = cf.plane_wave(
    (256, 256),  # grid shape in pixels
    1.0,  # presumably dx
    0.532,  # presumably the wavelength
    1.0,  # presumably the spectral density
    amplitude=cf.linear(1 / 2 * jnp.pi),
    scalar=True,
    pupil=lambda field: cf.circular_pupil(field, spacing * N),
)

In [None]:
# Inspect the scalar field's array shape (shown as the cell output).
scalar_field.shape

Note: the following microlens array call fails with a shape-incompatibility error. The axis of size 3 in the error likely indicates that the field is a vector (three-component) field.

In [None]:
# Apply a microlens array with a single lens (one entry in fs / ns / NAs) to the
# scalar field. `centers` has shape (2, 1): presumably one lens at (0, 0).
# NOTE(review): the note above reports a shape-incompatibility error from this call.
field_after_lens = cf.lenses.microlens_array(
    scalar_field,
    centers=jnp.array([[0], [0]]),
    fs=jnp.array([2500]),
    ns=[1],
    NAs=[0.02]
    )

In [None]:
# Point source imaged through an objective (f = 10.0, n = 1.0, NA = 0.8),
# sampled on a 128 x 128 grid.
field_after_first_lens = cf.objective_point_source(
    (128, 128),  # grid shape in pixels
    0.3,  # presumably dx
    0.532,  # presumably the wavelength
    1.0,  # presumably the spectral density
    0,  # presumably the defocus z
    f=10.0,
    n=1.0,
    NA=0.8,
)

In [None]:
# Inspect the point-source field's array shape (shown as the cell output).
field_after_first_lens.shape

In [None]:
# Send the objective output through a one-lens microlens array
# (f = 2500, n = 1, NA = 0.02; `centers` of shape (2, 1), presumably a lens at (0, 0)).
field_after_mla = cf.lenses.microlens_array(
    field_after_first_lens,
    centers=jnp.array([[0], [0]]),
    fs=jnp.array([2500]),
    ns=[1],
    NAs=[0.02],
)

In [None]:
# Intensity after the microlens array, with singleton axes squeezed away.
plt.imshow(field_after_mla.intensity.squeeze())

In [None]:
# Probe the phase at pixel (row 64, col 64): axes 1 and 2 are spatial. This is the
# grid centre if the MLA output keeps the 128 x 128 input shape (presumably).
field_after_mla.phase[0, 64, 64, 0, 0]

In [None]:
# Phase of the field before the microlens array, with a colour bar.
plt.imshow(field_after_first_lens.phase.squeeze())
plt.colorbar()

In [None]:
# 2D spatial phase map after the microlens array. Axes 1 and 2 of the field
# arrays are the spatial (row, column) axes.
# A second assignment, `my_phase = field_after_mla.phase[0, 0, 0, :, :]`, used to
# overwrite this immediately. That selected only the trailing-axis values at pixel
# (0, 0), which after `.squeeze()` is generally not a 2D image for `plt.imshow`.
my_phase = field_after_mla.phase[0, :, :, 0, 0]

In [None]:
# Inspect the full shape of the post-MLA phase array (shown as the cell output).
field_after_mla.phase.shape

In [None]:
# Show the extracted phase slice with a colour bar. `imshow` needs `my_phase`
# to squeeze down to a 2D array.
plt.imshow(my_phase.squeeze())
plt.colorbar()

In [None]:
# Phase of `field`. Run top to bottom, `field` here is still the circular-pupil
# plane wave; it is reassigned to the point-source field further down.
plt.imshow(field.phase.squeeze())

In [None]:
# NOTE(review): run top to bottom, `field_after_lens` is only defined by the
# earlier microlens_array cell, which is reported to fail. The square_pupil cell
# that also assigns it runs after this one.
plt.imshow(field_after_lens.intensity.squeeze())

In [None]:
# Apply a square pupil (w = 100.0, presumably its width) to the scalar field.
# NOTE: the name `field_after_lens` is reused here although no lens is applied.
field_after_lens = cf.square_pupil(scalar_field, w=100.0)

In [None]:
# Small 15 x 15 point-source field for the explicit lenslet experiment
# (objective f = 10.0, n = 1.0, NA = 0.8).
field_pt_source = cf.objective_point_source(
    (15, 15),  # grid shape in pixels
    0.3,  # presumably dx
    0.532,  # presumably the wavelength
    1.0,  # presumably the spectral density
    0,  # presumably the defocus z
    f=10.0,
    n=1.0,
    NA=0.8,
)

In [None]:
# Use the 15 x 15 point-source field as the working `field` for the explicit
# microlens-array section below.
field = field_pt_source

## Explicitly propagating through a microlens array

In [None]:
# Parameters for an explicitly constructed microlens-array phase mask.
f = 2500  # lenslet focal length (same length unit as dx)
D = 100  # extent spanned by the per-lenslet coordinate grid below
ns = 5  # number of lenslets along each spatial axis
nu = 3  # expected pixels per lenslet

wavelength = field._spectrum
(_, M, N, _, _) = field.shape
# The tiling loop uses one lenslet_len for both axes and covers exactly
# ns * lenslet_len pixels. Refuse grids it would silently cover only partially.
if M != N or M % ns:
    raise ValueError(
        f"Field grid is {M} x {N}; the lenslet tiling needs a square grid "
        f"whose size is divisible by ns={ns}."
    )
lenslet_len = M // ns
nu = int(nu)
# assert lenslet_len == nu
dx = field._dx[0].item()
k = 2 * jnp.pi / wavelength
# NOTE(review): this grid spans D, so its step is (D - dx) / (lenslet_len - 1), not dx.
# It matches the field's own sampling only when D == lenslet_len * dx -- confirm intent.
x = jnp.linspace(-D / 2 + dx / 2, D / 2 - dx / 2, lenslet_len)
X, Y = jnp.meshgrid(x, x)
# Thin-lens quadratic phase exp(-i k (x^2 + y^2) / (2 f)) for one lenslet.
phase = jnp.exp(-1j * k / (2 * f) * (jnp.square(X) + jnp.square(Y)))

In [None]:
# Turn the (lenslet_len, lenslet_len) phase mask into shape
# (1, lenslet_len, lenslet_len, 1, 1). It then broadcasts against field slices
# whose spatial axes are 1 and 2.
phase_dim_ext = jnp.expand_dims(phase, axis=(0, -2, -1))

In [None]:
# Multiply every lenslet-sized tile of the field by the lenslet phase mask.
# JAX arrays are immutable: `.at[...]` updates return a NEW array, so the result
# must be kept and written back. The previous loop discarded the result of
# `field.u.at[...].set(...)` on every iteration, so `field` never changed.
# Re-running this cell applies the mask again.
u = field.u
for s in range(ns):
    rows = slice(s * lenslet_len, (s + 1) * lenslet_len)
    for t in range(ns):
        cols = slice(t * lenslet_len, (t + 1) * lenslet_len)
        u = u.at[:, rows, cols, :, :].multiply(phase_dim_ext)
# NOTE(review): assumes the chromatix Field supports `.replace(u=...)`
# (flax-struct style); confirm for the installed chromatix version.
field = field.replace(u=u)

In [None]:
# Intensity of the working field after the explicit lenslet cell above.
plt.imshow(field.intensity.squeeze())

In [None]:
# Unmodified 128 x 128 point-source field, kept as a reference for comparison
# (objective f = 10.0, n = 1.0, NA = 0.8).
field_pt_source_plain = cf.objective_point_source(
    (128, 128),  # grid shape in pixels
    0.3,  # presumably dx
    0.532,  # presumably the wavelength
    1.0,  # presumably the spectral density
    0,  # presumably the defocus z
    f=10.0,
    n=1.0,
    NA=0.8,
)

In [None]:
# Reference plot: intensity of the unmodified point-source field, with colour bar.
plt.imshow(field_pt_source_plain.intensity.squeeze())
plt.colorbar()