In [1]:
import os
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".1"

from jax import numpy as jnp
import jax
import temgymbasic.jax_components as comp
from temgymbasic.jax_run import run_to_end_model_wrapper
from temgymbasic.jax_ray import Ray, propagate
from temgymbasic.jax_utils import get_pytree_idx_from_model
%matplotlib widget

In [5]:
descan_error = [0.0, 0.0, 0.0, 0.0]

key = jax.random.PRNGKey(0)
complex_image = jax.random.normal(key, (1, 1)) + 1j * jax.random.normal(key, (1, 1))

defocus = 0.01

model = [comp.Sample(z=defocus, complex_image=complex_image, pixel_size=0.001),
         comp.Descanner(z=defocus, scan_position=(0.0, 0.0), descan_error=descan_error),
         comp.Detector(z=1.0, pixel_size=0.01, shape=(256, 256))]

n_rays = 1
n_rays_dim = int(jnp.sqrt(n_rays))

input_amplitude = 1.0
input_pathlength = 0.0
input_wavelength = 1.0

# Prepare multiple input rays (each row is a separate ray)
r1mx = jnp.linspace(0.0, 0.0, n_rays_dim) #x off set
r1my = jnp.linspace(0.0, 0.0, n_rays_dim) #y off set
r1m = jnp.stack(jnp.meshgrid(r1mx, r1my), axis=-1).reshape(2, -1)

theta1mx = jnp.linspace(0.0, 0.0, n_rays_dim) #dx off set
theta1my = jnp.linspace(0.0, 0.0, n_rays_dim) #dy off set
theta1m = jnp.stack(jnp.meshgrid(theta1mx, theta1my), axis=-1).reshape(2, -1)

#Create ray input
crossover_z = jnp.zeros((n_rays_dim))
sample_z = jnp.ones((n_rays_dim)) * defocus
# input_matrix = jnp.array([0., 0., 0., 0., 1.])
input_matrix = jnp.vstack([
    r1m[0:1, :],
    r1m[1:2, :],
    theta1m[0:1, :],
    theta1m[1:2, :],
    jnp.ones((n_rays)).reshape(n_rays)
]).T

input_amplitude = jnp.ones(n_rays_dim)
input_pathlength = jnp.zeros(n_rays_dim)
input_wavelength = jnp.ones(n_rays_dim)
input_blocked = jnp.zeros(n_rays_dim, dtype=float)

Rays = Ray(z=crossover_z, 
           matrix=input_matrix, 
           amplitude=input_amplitude, 
           pathlength=input_pathlength, 
           wavelength=input_wavelength,
           blocked=input_blocked)

# batched_run_model = jax.vmap(run_to_end, in_axes=(0, None))
# jac_batched_run_model = jax.vmap(jax.jacobian(run_to_end), in_axes=(0, None))

# batched_run_model(Rays, model)
# print(jac_batched_run_model(Rays, model))

# Usage:
# Suppose you want to differentiate with respect to the 'scan_position' of model component 1
target = {1: 'descan_error'}
indices, model_flat, unravel_fn = get_pytree_idx_from_model(model, target)

# Here, indices['scan_position'] is the slice you want.
# Now you can compute the Jacobian with respect to that slice:
jac = jax.jacobian(run_to_end_model_wrapper, argnums=indices['descan_error'])(model_flat, unravel_fn, Rays)



TypeError: 'jaxlib.xla_extension.pytree.PyTreeDef' object is not iterable

In [4]:
Ray.__dataclass_fields__


{'matrix': Field(name='matrix',type=<class 'jax.Array'>,default=<dataclasses._MISSING_TYPE object at 0x725841661490>,default_factory=<dataclasses._MISSING_TYPE object at 0x725841661490>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),kw_only=False,_field_type=_FIELD),
 'z': Field(name='z',type=<class 'float'>,default=<dataclasses._MISSING_TYPE object at 0x725841661490>,default_factory=<dataclasses._MISSING_TYPE object at 0x725841661490>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),kw_only=False,_field_type=_FIELD),
 'amplitude': Field(name='amplitude',type=<class 'float'>,default=<dataclasses._MISSING_TYPE object at 0x725841661490>,default_factory=<dataclasses._MISSING_TYPE object at 0x725841661490>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),kw_only=False,_field_type=_FIELD),
 'pathlength': Field(name='pathlength',type=<class 'float'>,default=<dataclasses._MISSING_TYPE object at 0x725841661490>,default_factory=<data

In [3]:
print(jac.shape)

(1, 5, 65550)


In [4]:
from jax.tree_util import tree_structure
from jax.tree_util import tree_flatten, tree_unflatten
print(tree_structure(Rays))

tree_flatten(Rays)

PyTreeDef(CustomNode(Ray[()], [*, *, *, *, *, *]))


([Array([[0., 0., 0., 0., 1.]], dtype=float32),
  Array([0.], dtype=float32),
  Array([1.], dtype=float32),
  Array([0.], dtype=float32),
  Array([1.], dtype=float32),
  Array([0.], dtype=float32)],
 PyTreeDef(CustomNode(Ray[()], [*, *, *, *, *, *])))

In [5]:
for coordinate_transform in coordinate_transforms:
    print(coordinate_transform)


NameError: name 'coordinate_transforms' is not defined