In [None]:
import os
# Ask JAX for 64-bit precision. This cell runs before the cell that first
# imports jax, so the flag is already in the environment when JAX loads.
os.environ["JAX_ENABLE_X64"] = "1"

In [None]:
import numpy as np
import jax.numpy as jnp
from temgym_core.components import Lens, Detector
from temgym_core.ray import GaussianRay
from temgym_core.gaussian import get_image

In [None]:
%matplotlib widget
import matplotlib.pyplot as plt

In [None]:
# Detector sampling: pixel-grid dimensions and the scalar pixel size that the
# model cell passes to Detector for both axes.
shape = (2048, 2048)
pixel_size = 5e-6

In [None]:
# Optical model: a Lens at z = 0 followed by a Detector at z = z2.
f = 5e-3
# (1 / f) ** -1 evaluates to f, so the detector plane sits at f plus a 2e-3 offset.
z2 = (1 / f) ** -1 + 2e-3
model = [
    Lens(z=0.0, focal_length=f),
    Detector(z=z2, pixel_size=(pixel_size, pixel_size), shape=shape),
]

In [None]:
# Beam and sampling parameters
num_rays = 1024
wavelength = 1e-4
wo = 0.5e-3  # waist value assigned to every ray along both axes (see waist_xy)
k = 2 * jnp.pi / wavelength  # wavenumber, 2*pi / wavelength

# Seeded generator: without a fixed seed every Restart & Run All draws a
# different ray bundle, so the detector image cannot be reproduced.
seed = 0
rng = np.random.default_rng(seed)

# Gaussian Beam Input
# Ray positions (xs, ys) and slopes (dxs, dys) are drawn uniformly from
# [-0.5e-3, 0.5e-3); all other per-ray quantities are constant across rays.
xs = jnp.array(rng.uniform(-0.5e-3, 0.5e-3, num_rays))
ys = jnp.array(rng.uniform(-0.5e-3, 0.5e-3, num_rays))
dxs = jnp.array(rng.uniform(-0.5e-3, 0.5e-3, num_rays))
dys = jnp.array(rng.uniform(-0.5e-3, 0.5e-3, num_rays))
zs = jnp.array(np.zeros(num_rays))
pathlengths = jnp.array(np.zeros(num_rays))
ones = jnp.array(np.ones(num_rays))
amplitudes = jnp.array(np.ones(num_rays))
# Shape (num_rays, 2): an infinite radius of curvature for both entries of every ray.
radii_of_curv = jnp.array(np.full((num_rays, 2), np.inf))
theta = jnp.array(np.zeros(num_rays))
# Shape (num_rays, 2): the same waist wo for both entries of every ray.
waist_xy = jnp.array(np.full((num_rays, 2), wo))

# Bundle the per-ray arrays into one GaussianRay batch (all passed by keyword).
rays = GaussianRay(
    # ray position and slopes
    x=xs,
    y=ys,
    z=zs,
    dx=dxs,
    dy=dys,
    # accumulated path length and the constant-one entry
    pathlength=pathlengths,
    _one=ones,
    # Gaussian beam parameters; waist_xy and radii_of_curv hold two values per ray
    amplitude=amplitudes,
    waist_xy=waist_xy,
    radii_of_curv=radii_of_curv,
    theta=theta,
    wavelength=wavelength,
)

In [None]:
# Evaluate the ray bundle against the lens + detector model. The plotting cell
# treats the result as a complex-valued image (it takes np.abs and np.angle).
# NOTE(review): presumably one value per detector pixel — confirm against get_image.
det_image = get_image(rays, model)

In [None]:
# Amplitude of the detector image. Titles and axis labels let each figure
# stand alone when the notebook is skimmed.
fig, ax = plt.subplots()
ax.imshow(np.abs(det_image), cmap="gray")
ax.set_title("Amplitude: |det_image|")
ax.set_xlabel("column (px)")
ax.set_ylabel("row (px)")

# Phase of the detector image (np.angle returns radians). As before, fig/ax
# end up bound to this second figure.
fig, ax = plt.subplots()
ax.imshow(np.angle(det_image), cmap="gray")
ax.set_title("Phase: angle(det_image) [rad]")
ax.set_xlabel("column (px)")
ax.set_ylabel("row (px)");  # trailing ';' suppresses the text repr in the cell output