In [None]:
import os
# Turn on 64-bit (double precision) mode for JAX via its environment switch.
# This runs before `jax` is imported in the next cell so the setting is picked up.
os.environ.update({'JAX_ENABLE_X64': '1'})

In [None]:
import numpy as np
import jax
import jax.numpy as jnp
from temgym_core.components import Lens, Detector
from temgym_core.ray import Ray
from temgym_core.gaussian import get_image, q_inv
from temgym_core.propagator import FreeSpaceParaxial


In [None]:
%matplotlib widget
import matplotlib.pyplot as plt

In [None]:
# Beam and detector parameters.
# NOTE(review): all lengths are presumably in the same unit (metres?) — confirm.
wavelength = 1e-4            # beam wavelength
k = 2 * jnp.pi / wavelength  # wavenumber 2*pi / wavelength
wo = 5e-2                    # Gaussian waist size, passed to q_inv when building Q1_invs
pixel_size = 5e-6            # detector pixel pitch (same value used for both axes)
shape = (2048,) * 2          # detector shape in pixels

In [None]:
# Optical model: a thin lens at z = 0 followed by the detector.
f = 5e-3  # lens focal length
# (1/f)**-1 is the thin-lens image distance for an object at infinity
# (numerically just f); the detector is placed a further 2e-3 beyond it.
z2 = (1 / f) ** -1 + 2e-3
model = [
    Lens(z=0.0, focal_length=f),
    Detector(z=z2, pixel_size=(pixel_size,) * 2, shape=shape),
]

In [None]:
# Two rays starting at z = 0: one on the axis, one displaced by 0.5e-3 in
# both x and y. Their dx/dy (presumably slopes — TODO confirm) and
# accumulated path lengths are zero. Every per-ray array below shares the
# shape and dtype of `xs`.
xs = jnp.array([0.0, 0.5e-3])
ys = jnp.array([0.0, 0.5e-3])
dxs, dys = jnp.zeros_like(xs), jnp.zeros_like(xs)
zs = jnp.zeros_like(xs)
pathlengths = jnp.zeros_like(xs)
ones = jnp.ones_like(xs)  # fills Ray's `_one` field with 1.0 per ray

rays = Ray(
    x=xs,
    y=ys,
    dx=dxs,
    dy=dys,
    z=zs,
    pathlength=pathlengths,
    _one=ones,
)

In [None]:
# One complex-amplitude weight per ray.
amplitudes = jnp.array([1.0, 1.0])
# Inverse complex beam parameter for each ray's Gaussian beamlet.
# NOTE(review): first argument is presumably the distance from the waist
# (0.0 -> evaluated at the waist), second the waist size — confirm against
# temgym_core.gaussian.q_inv.
q1_invs = q_inv(jnp.array([0.0, 0.0]), jnp.array([wo, wo]), wavelength)

# Each beamlet uses the isotropic 2x2 matrix q_inv * I. Broadcasting builds the
# whole (n, 2, 2) stack in one op; growing it with jnp.append in a loop copied
# the entire (immutable) array on every iteration.
Q1_invs = q1_invs[:, None, None] * jnp.eye(2)
assert Q1_invs.shape == (amplitudes.shape[0], 2, 2), "need one 2x2 Q matrix per amplitude"

det_image = get_image(rays, model, amplitudes, Q1_invs, wavelength)

In [None]:
# Show the magnitude of the computed detector image; axes are pixel indices.
fig, ax = plt.subplots()
# Assigning the AxesImage keeps its repr out of the cell output.
im = ax.imshow(np.abs(det_image), cmap="gray")
ax.set(title="Detector image |det_image|", xlabel="column (pixel)", ylabel="row (pixel)")
fig.colorbar(im, ax=ax, label="|det_image|")