In [1]:
import temgymbasic.jax_components as comp
from temgymbasic.jax_model import run_model_to_end, run_model_iter
from temgymbasic.jax_ray import Ray
import jax.numpy as jnp
import jax
# from .utils import circular_beam, point_beam

In [12]:
model = [comp.Lens(z= 0.5, focal_length=0.1),
         comp.DoubleDeflector(z=0.325,
                              first=comp.Deflector(z=0.3, def_x=0.1, def_y=0.0), 
                              second=comp.Deflector(z=0.35, def_x=0.1, def_y=0.0)),
         comp.Biprism(z=0.75, deflection=0.001),
         comp.Aperture(z=0.8, radius=0.01),
         comp.Detector(z=1.0, pixel_size=0.01, shape=(100, 100))]

n_rays = 1024
n_rays_dim = int(jnp.sqrt(n_rays))

input_amplitude = 1.0
input_pathlength = 0.0
input_wavelength = 1.0

# Prepare multiple input rays (each row is a separate ray)
r1mx = jnp.linspace(-0.1, 0.1, n_rays_dim) #x off set
r1my = jnp.linspace(-0.1, 0.1, n_rays_dim) #y off set
r1m = jnp.stack(jnp.meshgrid(r1mx, r1my), axis=-1).reshape(2, -1)

theta1mx = jnp.linspace(-0.0, 0.0, n_rays_dim) #dx off set
theta1my = jnp.linspace(-0.0, 0.0, n_rays_dim) #dy off set
theta1m = jnp.stack(jnp.meshgrid(theta1mx, theta1my), axis=-1).reshape(2, -1)

#Create ray input
input_z = jnp.zeros((n_rays))
input_matrix = jnp.vstack([r1m[0, :], r1m[1, :], theta1m[0, :], theta1m[1, :], jnp.ones((n_rays))]).T
input_amplitude = jnp.ones((n_rays))
input_pathlength = jnp.zeros((n_rays))
input_wavelength = jnp.ones((n_rays))

Rays = Ray(z=input_z, 
           matrix=input_matrix, 
           amplitude=input_amplitude, 
           pathlength=input_pathlength, 
           wavelength=input_wavelength)


def create_ray(ray):
    return Ray(z=ray.z, 
               matrix=ray.matrix,
               amplitude=ray.amplitude,
               pathlength=ray.pathlength,
               wavelength=ray.wavelength)

# vmap over run_model to process each ray in rays_input

# Method 1
batched_run_model_one_ray = jax.jit(jax.vmap(lambda r: run_model_to_end(create_ray(r), model)))
all_rays = batched_run_model_one_ray(Rays)

# Method 2
batched_run_model_array_of_rays = jax.jit(jax.vmap(run_model_to_end, in_axes=(0, None)))
all_rays = batched_run_model_array_of_rays(Rays, model)

# Method 3
all_rays = jax.jit(run_model_to_end)(Ray(z=0.0, 
                                         matrix=input_matrix, 
                                         amplitude=1., 
                                         pathlength=0., 
                                         wavelength=1.),
                                         model)

In [13]:
%timeit batched_run_model_one_ray(Rays)

52.9 μs ± 2.71 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [14]:
%timeit batched_run_model_array_of_rays(Rays, model)

125 μs ± 9.46 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [15]:
func = jax.jit(run_model_to_end)
%timeit func(Ray(z=0.0, matrix=input_matrix, amplitude=1., pathlength=0., wavelength=1.), model)

136 μs ± 3.42 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [10]:
# Method 1
batched_dro_dri = jax.vmap(lambda ray: jax.jacobian(run_model_to_end, argnums=0)(ray, model))
dro_dri = batched_dro_dri(Rays)

# # Method 2
batched_run_model_array_of_rays = jax.vmap(jax.jacobian(run_model_to_end), in_axes=(0, None))
dro_dri = batched_run_model_array_of_rays(Rays, model)


# Method 3
dro_dri = jax.jacobian(run_model_to_end, argnums=0)(Ray(z=0.0, matrix=input_matrix, amplitude=1., pathlength=0., wavelength=1.), model)


In [5]:

# batched_run_model = jax.vmap(lambda ray: run_model_iter(ray, model))
# all_rays = batched_run_model(Rays)

# # # Partial derivatives of ray output with respect to ray input
# batched_dro_dri = jax.vmap(lambda ray: jax.jacobian(run_model_to_end, argnums=0)(ray, model))
# dro_dri = batched_dro_dri(Rays)

# ABCD = dro_dri.matrix.matrix

# # Extract ABCD matrices
# A = ABCD[:, :2, :2]
# B = ABCD[:, :2, 2:4]
# C = ABCD[:, 2:4, :2]
# D = ABCD[:, 2:4, 2:4]



In [None]:
focal_length = model[0].focal_length
expected_C = -1 / focal_length * jnp.eye(2)
expected_C = jnp.repeat(expected_C[None, :, :], C.shape[0], axis=0)

# Verify that the 2x2 matric C is equal to -1/f on each diagonal element
assert jnp.allclose(C, expected_C), "C does not match -1/f * I"
print("Test passed: C is equal to -1/f * I.")

NameError: name 'C' is not defined

In [8]:
# # Partial derivatives of ray output with respect to model parameters
batched_dro_dri = jax.jit(jax.vmap(lambda r: jax.jacobian(run_model_to_end, argnums=1)(create_ray(r), model)))
dro_dmodel = batched_dro_dri(rays_input)

print(dro_dmodel)


[Lens(z=Array([[-1.000000e+00, -1.000000e+00,  4.656613e-10,  0.000000e+00,
         0.000000e+00],
       [-1.000000e+00,  1.000000e+00,  4.656613e-10,  0.000000e+00,
         0.000000e+00],
       [ 1.000000e+00,  1.000000e+00,  0.000000e+00,  0.000000e+00,
         0.000000e+00],
       [-1.000000e+00,  1.000000e+00,  4.656613e-10,  0.000000e+00,
         0.000000e+00]], dtype=float32, weak_type=True), focal_length=Array([[-4.999999, -4.999999, -9.999999, -9.999999, -0.      ],
       [-4.999999,  4.999999, -9.999999,  9.999999,  0.      ],
       [ 4.999999,  4.999999,  9.999999,  9.999999,  0.      ],
       [-4.999999,  4.999999, -9.999999,  9.999999,  0.      ]],      dtype=float32, weak_type=True)), DoubleDeflector(z=Array([[-2.0000005e-01,  0.0000000e+00,  9.3132280e-11,  0.0000000e+00,
         0.0000000e+00],
       [-2.0000005e-01,  0.0000000e+00,  9.3132280e-11,  0.0000000e+00,
         0.0000000e+00],
       [-1.9999999e-01,  0.0000000e+00,  0.0000000e+00,  0.0000000e+00,