In [4]:
import temgymbasic.jax_components as comp
from temgymbasic.jax_ray import Ray, propagate
import jax.numpy as jnp
import jax

In [5]:

# Run the model
@jax.jit #Unsure if I need jit here also????
def run_model(ray, model):
    for component in model:

        distance = component.z - ray.z
        ray = propagate(distance, ray)
        ray = component.step(ray)

    return ray.matrix

model = [comp.Lens(z= 0.5, focal_length=0.1),
         comp.Detector(z=1.0, pixel_size=0.01, shape=(100, 100))]

n_rays = 4
n_rays_dim = int(jnp.sqrt(n_rays))

input_amplitude = 1.0
input_pathlength = 0.0
input_wavelength = 1.0

# Prepare multiple input rays (each row is a separate ray)
r1mx = jnp.linspace(-0.1, 0.1, n_rays_dim) #x off set
r1my = jnp.linspace(-0.1, 0.1, n_rays_dim) #y off set
r1m = jnp.stack(jnp.meshgrid(r1mx, r1my), axis=-1).reshape(2, -1)

theta1mx = jnp.linspace(-0.0, 0.0, n_rays_dim) #dx off set
theta1my = jnp.linspace(-0.0, 0.0, n_rays_dim) #dy off set
theta1m = jnp.stack(jnp.meshgrid(theta1mx, theta1my), axis=-1).reshape(2, -1)

#Create ray input
rays_input = jnp.vstack([r1m[0, :], r1m[1, :], theta1m[0, :], theta1m[1, :], jnp.ones((n_rays))]).T

def create_ray(ray_data):
    return Ray(z=0.0, 
               matrix=ray_data,
               amplitude=input_amplitude,
               pathlength=input_pathlength,
               wavelength=input_wavelength)

# vmap over run_model to process each ray in rays_input
batched_run_model = jax.jit(jax.vmap(lambda r: run_model(create_ray(r), model)))
ray_outputs = batched_run_model(rays_input)

# # Partial derivatives of ray output with respect to ray input
batched_dro_dri = jax.jit(jax.vmap(lambda r: jax.jacobian(run_model, argnums=0)(create_ray(r), model)))
dro_dri = batched_dro_dri(rays_input)

ABCD = dro_dri.matrix

# Extract ABCD matrices
A = ABCD[:, :2, :2]
B = ABCD[:, :2, 2:4]
C = ABCD[:, 2:4, :2]
D = ABCD[:, 2:4, 2:4]

In [6]:
focal_length = model[0].focal_length
expected_C = -1 / focal_length * jnp.eye(2)
expected_C = jnp.repeat(expected_C[None, :, :], C.shape[0], axis=0)

# Verify that the 2x2 matric C is equal to -1/f on each diagonal element
assert jnp.allclose(C, expected_C), "C does not match -1/f * I"
print("Test passed: C is equal to -1/f * I.")

Test passed: C is equal to -1/f * I.


In [7]:
# # Partial derivatives of ray output with respect to model parameters
batched_dro_dri = jax.jit(jax.vmap(lambda r: jax.jacobian(run_model, argnums=1)(create_ray(r), model)))
dro_dmodel = batched_dro_dri(rays_input)

print(dro_dmodel)


[Lens(z=Array([[-1., -1.,  0.,  0.,  0.],
       [-1.,  1.,  0.,  0.,  0.],
       [ 1.,  1.,  0.,  0.,  0.],
       [-1.,  1.,  0.,  0.,  0.]], dtype=float32, weak_type=True), focal_length=Array([[-4.9999995, -4.9999995, -9.999999 , -9.999999 , -0.       ],
       [-4.9999995,  4.9999995, -9.999999 ,  9.999999 ,  0.       ],
       [ 4.9999995,  4.9999995,  9.999999 ,  9.999999 ,  0.       ],
       [-4.9999995,  4.9999995, -9.999999 ,  9.999999 ,  0.       ]],      dtype=float32, weak_type=True)), Detector(z=Array([[ 1.,  1.,  0.,  0.,  0.],
       [ 1., -1.,  0.,  0.,  0.],
       [-1., -1., -0., -0., -0.],
       [ 1., -1.,  0.,  0.,  0.]], dtype=float32, weak_type=True), pixel_size=Array([[0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.]], dtype=float32, weak_type=True), shape=(100, 100), rotation=Array([[0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.]], dtype=float