In [None]:
import os
# These flags are set before the cell that imports jax.
os.environ.update(
    {
        'JAX_ENABLE_X64': '1',
        'XLA_PYTHON_CLIENT_MEM_FRACTION': '0.5',
    }
)

In [None]:
import numpy as np
import jax.numpy as jnp
from temgym_core.components import Detector
from temgym_core.gaussian import make_gaussian_image, GaussianRay
from temgym_core.source import ParallelBeam
from temgym_core.components import Component
import jax_dataclasses as jdc
from temgym_core.ray import Ray

In [None]:
%matplotlib widget
import matplotlib.pyplot as plt

In [None]:
# Detector setup
# `shape` is the two-element grid size and `pixel_size` the pitch of one pixel;
# both are handed to Detector(...) when the model is built.
shape = (512,) * 2
pixel_size = 5e-4

@jdc.pytree_dataclass
class CrazyLens(Component):
    """Toy optical element whose slope change is quadratic in ray position.

    Calling the instance on a ray returns a new ``Ray`` with the same
    ``x``, ``y`` and ``z`` but with modified slopes and path length.
    """

    # Axial position of the element (not read inside ``__call__``).
    z: float
    # Scale factor dividing both the slope kick and the path-length term.
    focal_length: float

    def __call__(self, ray: Ray):
        """Apply the element to ``ray`` and return the outgoing ``Ray``.

        The slopes become ``dx - x**2 / f`` and ``dy - y**2 / f``; the path
        length is reduced by ``(x**2 + y**2) / (2 * f)``.

        NOTE(review): the slope kick is quadratic in position while the
        path-length term is the quadratic thin-lens form -- confirm the two
        are intended to differ.
        """
        f = self.focal_length

        # Position-dependent changes in slope, one per transverse axis.
        kick_x = -(ray.x ** 2) / f
        kick_y = -(ray.y ** 2) / f

        # Path-length change introduced by the element.
        opl_shift = (ray.x**2 + ray.y**2) / (2 * f)

        return Ray(
            x=ray.x,
            y=ray.y,
            z=ray.z,
            dx=kick_x + ray.dx,
            dy=kick_y + ray.dy,
            pathlength=ray.pathlength - opl_shift,
            # Multiplying by 1.0 keeps `_one` as a computed output of the call.
            _one=ray._one * 1.0,
        )

In [None]:
import jax
from temgym_core.utils import custom_jacobian_matrix

# Probe the lens with an on-axis ray that has zero slope.
ray = Ray(x=0.0, y=0.0, dx=0.0, dy=0.0, z=0.0, pathlength=0.0, _one=1.0)
probe_lens = CrazyLens(z=0.0, focal_length=5e-3)
new_ray = probe_lens(ray)

# Differentiate the lens map with respect to the input ray, then let
# custom_jacobian_matrix repack the result (layout defined in temgym_core.utils).
new_ray_abcd = custom_jacobian_matrix(jax.jacobian(probe_lens)(ray))

print(new_ray_abcd)


In [None]:
# Model Creation
f = 5e-3
defocus = 2e-1

# (1 / f) ** -1 evaluates to f (up to rounding), so the detector plane sits
# `defocus` beyond that distance.
z2 = defocus + (1 / f) ** -1

model = [
    CrazyLens(z=0.0, focal_length=f),
    Detector(z=z2, shape=shape, pixel_size=(pixel_size,) * 2),
]

In [None]:
# Beam parameters: wavelength, waist `wo`, and the wavenumber derived from them.
wavelength, wo = 1e-8, 0.5e-3
k = 2 * jnp.pi / wavelength

# Seed rays come from a parallel beam.
# NOTE(review): the second ParallelBeam argument repeats the literal 0.5e-3
# rather than reusing `wo` -- confirm whether the two are meant to stay equal.
num_rays = 1
beam = ParallelBeam(0., 0.5e-3)
base_rays = beam.make_rays(num_rays, random=False)

# From here on `num_rays` is the number of rays actually returned.
num_rays = jnp.size(base_rays.x)

# Gaussian Beam Input
# Positions and slopes are copied from the seed rays.
# NOTE(review): wrapping each field in a list adds a leading axis, so these
# arrays may have one more dimension than the (num_rays,) arrays below --
# confirm this is the layout GaussianRay expects.
xs, ys, dxs, dys = (
    jnp.array(np.asarray([component]))
    for component in (base_rays.x, base_rays.y, base_rays.dx, base_rays.dy)
)

# Per-ray scalars: zero-initialised and one-initialised fields.
zs, pathlengths, theta = (jnp.array(np.zeros(num_rays)) for _ in range(3))
ones, amplitudes = (jnp.array(np.ones(num_rays)) for _ in range(2))

# Per-ray beam parameters; the two-column arrays carry two values per ray.
wavelengths = jnp.array(np.full((num_rays,), wavelength))
waist_xy = jnp.array(np.full((num_rays, 2), wo))
radii_of_curv = jnp.array(np.full((num_rays, 2), np.inf))


# Assemble the ray bundle: geometry first, then bookkeeping, then the
# per-ray Gaussian beam parameters.
rays = GaussianRay(
    # position and slope
    x=xs,
    y=ys,
    z=zs,
    dx=dxs,
    dy=dys,
    # bookkeeping
    pathlength=pathlengths,
    _one=ones,
    # beam parameters
    amplitude=amplitudes,
    wavelength=wavelengths,
    theta=theta,
    waist_xy=waist_xy,  # two entries per Gaussian ray
    radii_of_curv=radii_of_curv,  # two entries per Gaussian ray
)

In [None]:
# Run the Gaussian rays through the model. The result is treated as a complex
# image further down (its magnitude and angle are plotted).
det_image = make_gaussian_image(
    rays,
    model,
    batch_size=1_000,
)

In [None]:
# Amplitude view: 2-D map next to a horizontal cut through the central row.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4))
img = np.abs(det_image)
ax1.imshow(img, cmap="gray")
ax1.set(title="Amplitude |det_image|", xlabel="pixel column", ylabel="pixel row")
ax2.plot(img[img.shape[0] // 2, :])
ax2.set(title="Amplitude along central row", xlabel="pixel column", ylabel="amplitude")
fig.tight_layout()

# Phase view, with the colour scale pinned to [-pi, pi] so figures are comparable.
fig, ax = plt.subplots()
phase_im = ax.imshow(np.angle(det_image), cmap="viridis", vmin=-np.pi, vmax=np.pi)
ax.set(title="Phase of det_image", xlabel="pixel column", ylabel="pixel row")
# Trailing semicolon suppresses the colour-bar repr in the cell output.
fig.colorbar(phase_im, ax=ax, label="phase (rad)");