In [1]:
import jax
import jax_dataclasses as jdc
import os
import jax.numpy as jnp

# Select the CPU platform for JAX.
jax.config.update("jax_platform_name", "cpu")

# XLA client memory settings, exported through the process environment.
# NOTE(review): XLA_PYTHON_CLIENT_MEM_LIMIT_MB is not among the commonly
# documented JAX memory variables -- confirm the installed version honours it.
os.environ.update(
    {
        "XLA_PYTHON_CLIENT_PREALLOCATE": "false",
        "XLA_PYTHON_CLIENT_MEM_LIMIT_MB": "400",
    }
)

Option 1:

Ray dataclass with all variables explicitly named.

Advantages: 

1. Gives an easy-to-use partial-derivative structure, i.e. `ray_jac.x.dx` = $\frac{\partial x}{\partial dx}$ (the derivative of the output $x$ with respect to the input $dx$).

2. Can pass a complete ray object into the step function of each component.

Disadvantages: 

1. The Jacobian call returns a somewhat verbose nested structure rather than a clean matrix, so a custom function is needed to assemble the Jacobian matrix.

2. To compute gradients with respect to only specific ray parameters, a wrapper function has to be written outside the model.



In [2]:
def custom_jacobian_matrix(ray_jac):
    """Assemble a Jacobian matrix from a nested, attribute-addressed Jacobian.

    Args:
        ray_jac: Object whose ``x``, ``y``, ``dx``, ``dy`` and ``_one``
            attributes each expose those same five attributes, so that
            e.g. ``ray_jac.x.dx`` is the derivative of output ``x`` with
            respect to input ``dx``.

    Returns:
        ``jnp.ndarray`` whose entry ``[i, j]`` is
        ``ray_jac.<field i>.<field j>``, with fields ordered
        ``(x, y, dx, dy, _one)`` -- a 5x5 matrix when every entry is a scalar.
    """
    field_names = ("x", "y", "dx", "dy", "_one")
    return jnp.array(
        [
            [getattr(getattr(ray_jac, out_name), in_name) for in_name in field_names]
            for out_name in field_names
        ]
    )


@jdc.pytree_dataclass
class Ray:
    """Ray state with every variable stored as an explicitly named field.

    Because the class is a pytree dataclass, differentiating a Ray-valued
    function of a Ray yields a Ray whose fields are themselves Rays, so
    partial derivatives are addressed by name (e.g. ``jac.x.dx``).
    """

    # presumably the transverse position of the ray -- TODO confirm
    x: float
    y: float
    # presumably the ray slopes in x and y -- TODO confirm
    dx: float
    dy: float
    z: float
    pathlength: float
    # Defaults to 1.0; used as the fifth row/column when the Jacobian matrix
    # is assembled from the nested Jacobian.
    _one: float = 1.0


@jdc.pytree_dataclass
class Lens:
    """Lens component described by ``f`` and ``z``; callable on a whole Ray."""

    f: float
    z: float

    def __call__(self, ray: Ray) -> Ray:
        """Apply the lens to ``ray`` and return the outgoing Ray.

        ``x``, ``y`` and ``pathlength`` pass through unchanged, each slope is
        reduced by the matching coordinate divided by ``f``, and the returned
        ``z`` is the lens's own ``z``.
        """
        focal = self.f
        slope_x = ray.dx - ray.x / focal
        slope_y = ray.dy - ray.y / focal

        return Ray(
            x=ray.x,
            y=ray.y,
            dx=slope_x,
            dy=slope_y,
            z=self.z,
            pathlength=ray.pathlength,
            # `_one` is carried through, scaled by 1.0.
            _one=ray._one * 1.0,
        )


# Single ray through a lens.
lens = Lens(f=0.25, z=0.0)
ray = Ray(x=0.1, y=0.2, dx=0.3, dy=0.4, z=0.0, pathlength=0.6, _one=1.0)

# Differentiating the lens with respect to the whole Ray gives a Ray of Rays;
# flatten it into a plain matrix for display.
lens_jacobian = jax.jacobian(lens)
jac = lens_jacobian(ray)
jacobian_matrix = custom_jacobian_matrix(jac)
print(jacobian_matrix)

# vmapping over a batch of rays works too: every field holds one value per ray.
rays = Ray(
    x=jnp.array([0.1, 0.2]),
    y=jnp.array([0.2, 0.3]),
    dx=jnp.array([0.3, 0.4]),
    dy=jnp.array([0.4, 0.5]),
    z=jnp.array([0.0, 0.0]),
    pathlength=jnp.array([0.6, 0.7]),
    _one=jnp.array([1.0, 1.0]),
)
jacs = jax.vmap(lens_jacobian)(rays)

# Each named entry is a simple vector with one Jacobian value per ray,
# which is convenient.
print(jacs.x.x)

[[ 1.  0.  0.  0.  0.]
 [ 0.  1.  0.  0.  0.]
 [-4.  0.  1.  0.  0.]
 [ 0. -4.  0.  1.  0.]
 [ 0.  0.  0.  0.  1.]]
[1. 1.]


Option 2: 

Ray tuple, with explicit unravelling of ray variables into each step function:

Advantages:

1. Can easily choose which parameters to calculate gradients with respect to (the `argnums` parameter of `jax.jacobian`).

Disadvantages:

1. Loses the explicit naming structure of the dataclass: instead of `ray_jac.x.dx`, the derivative has to be looked up by position, e.g. `ray_jac.x[2]`.

2. Very rigid interface: the step function of every component has to accept the ray unravelled into individual arguments.

In [23]:
from dataclasses import astuple


def custom_jacobian_matrix(ray_jac):
    """Assemble a Jacobian matrix from a Ray of per-argument derivative tuples.

    Args:
        ray_jac: Object whose ``x``, ``y``, ``dx``, ``dy`` and ``_one``
            attributes are each an indexable sequence of derivatives, one
            entry per differentiated argument, in the order the arguments
            were listed in ``argnums``.

    Returns:
        ``jnp.ndarray`` with one row per output field
        ``(x, y, dx, dy, _one)``; the columns are the first four entries of
        that field's derivative sequence followed by its last entry.
    """
    # Columns are positions inside each derivative tuple (i.e. positions in
    # the ``argnums`` list), not ray enum values, so they are spelled out
    # here instead of being read from module-level enum constants.
    # -1 always selects the final entry, which is the ``_one`` derivative.
    columns = (0, 1, 2, 3, -1)
    rows = (ray_jac.x, ray_jac.y, ray_jac.dx, ray_jac.dy, ray_jac._one)
    return jnp.array([[row[col] for col in columns] for row in rows])


# Positional index of each ray variable, in the same order as the Ray fields.
X, Y, DX, DY, Z, PATHLENGTH, _ONE = range(7)


@jdc.pytree_dataclass
class Ray:
    """Ray state with every variable stored as an explicitly named field.

    The field order matters here: ``dataclasses.astuple`` unpacks a Ray in
    this order, and the positional enum constants
    (``X, Y, DX, DY, Z, PATHLENGTH, _ONE``) index the resulting tuple.
    """

    # presumably the transverse position of the ray -- TODO confirm
    x: float
    y: float
    # presumably the ray slopes in x and y -- TODO confirm
    dx: float
    dy: float
    z: float
    pathlength: float
    # Defaults to 1.0; used as the fifth row/column of the assembled Jacobian matrix.
    _one: float = 1.0


@jdc.pytree_dataclass
class Lens:
    """Lens component described by ``f`` and ``z``; callable on unpacked ray variables."""

    f: float
    z: float

    def __call__(self, x, y, dx, dy, z, pathlength, _one) -> Ray:
        """Apply the lens to the unpacked ray variables and return a Ray.

        The arguments follow the Ray field order so that a ray tuple can be
        splatted in with ``*``. ``x``, ``y`` and ``pathlength`` pass through
        unchanged, and each slope is reduced by the matching coordinate
        divided by ``f``.
        """
        focal = self.f
        return Ray(
            x=x,
            y=y,
            dx=dx - x / focal,
            dy=dy - y / focal,
            # The incoming `z` is not used: the output takes the lens's own z.
            z=self.z,
            pathlength=pathlength,
            # `_one` is carried through, scaled by 1.0.
            _one=_one * 1.0,
        )


lens = Lens(f=0.25, z=0.0)

# Gradients can now be chosen easily and explicitly through `argnums`, but the
# final structure of the ray still has to be formed manually: the Ray is
# flattened into a plain tuple of its fields.
ray = astuple(Ray(x=0.1, y=0.2, dx=0.3, dy=0.4, z=0.0, pathlength=0.6, _one=1.0))

lens_jacobian = jax.jacobian(lens, argnums=[X, Y, DX, DY, _ONE])
jac = lens_jacobian(*ray)  # NOTE THE STAR TO UNWRAP THE RAY.

jacobian_matrix = custom_jacobian_matrix(jac)
print(jacobian_matrix)

# vmapping over a batch of rays works too:
rays = astuple(
    Ray(
        x=jnp.array([0.1, 0.2]),
        y=jnp.array([0.2, 0.3]),
        dx=jnp.array([0.3, 0.4]),
        dy=jnp.array([0.4, 0.5]),
        z=jnp.array([0.0, 0.0]),
        pathlength=jnp.array([0.6, 0.7]),
        _one=jnp.array([1.0, 1.0]),
    )
)
jacs = jax.vmap(lens_jacobian)(*rays)

# Caveat: each derivative tuple is indexed by position in `argnums`, not by
# the original ray enum values (e.g. `_ONE` is 6, but its derivative sits at
# index 4 of the tuple).
print(jacs.x[X])

[[ 1.  0.  0.  0.  0.]
 [ 0.  1.  0.  0.  0.]
 [-4.  0.  1.  0.  0.]
 [ 0. -4.  0.  1.  0.]
 [ 0.  0.  0.  0.  1.]]
[1. 1.]


Option 3:

David's original solution: Ray dataclass with an explicit matrix structure, plus auxiliary parameters z and optical path length.

Advantages:

1. The matrix variable `[x, y, dx, dy, _one]` gives a natural vector for calculating Jacobians of optical components: with `jac = jax.jacobian(lens)(ray)`, `jac.matrix.matrix` is the 5x5 Jacobian of the optical component.


Disadvantages: 

1. Again, it is not easy to take gradients with respect to specific parameters of the ray - the easiest way is to write a wrapper.

2. The mixture of a matrix and plain floats on one dataclass is perhaps confusing.

In [20]:
@jdc.pytree_dataclass
class Ray:
    """Ray whose differentiable state is packed into a single ``matrix`` array.

    The last axis of ``matrix`` holds ``[x, y, dx, dy, _one]``; ``z`` and
    ``pathlength`` are stored alongside it as separate fields.
    """

    matrix: jnp.ndarray
    z: float
    pathlength: float

    def _component(self, index):
        """Return entry ``index`` along the last axis of ``matrix``."""
        return self.matrix[..., index]

    @property
    def x(self):
        """Entry 0 of the last axis of ``matrix``."""
        return self._component(0)

    @property
    def y(self):
        """Entry 1 of the last axis of ``matrix``."""
        return self._component(1)

    @property
    def dx(self):
        """Entry 2 of the last axis of ``matrix``."""
        return self._component(2)

    @property
    def dy(self):
        """Entry 3 of the last axis of ``matrix``."""
        return self._component(3)

    @property
    def _one(self):
        """Entry 4 of the last axis of ``matrix``."""
        return self._component(4)


@jdc.pytree_dataclass
class Lens:
    """Lens component described by ``f`` and ``z``; callable on a matrix-backed Ray."""

    f: float
    z: float

    def __call__(self, ray: Ray) -> Ray:
        """Apply the lens to ``ray`` and return the outgoing Ray.

        ``x``, ``y`` and ``pathlength`` pass through unchanged, each slope is
        reduced by the matching coordinate divided by ``f``, and the returned
        ``z`` is the lens's own ``z``.

        Args:
            ray: Ray whose ``matrix`` holds ``[x, y, dx, dy, _one]`` along
                its last axis; any leading axes are treated as batch axes.

        Returns:
            A new Ray whose ``matrix`` has the same layout as the input's.
        """
        z = self.z
        f = self.f
        dx = ray.dx - ray.x / f
        dy = ray.dy - ray.y / f
        one = ray._one * 1.0

        # Stack along the LAST axis so the result matches the `matrix[..., i]`
        # accessors on Ray. Building the array along axis 0 would turn a
        # batched (N, 5) matrix into (5, N). For a single ray (shape (5,))
        # the result is unchanged.
        matrix = jnp.stack([ray.x, ray.y, dx, dy, one], axis=-1)
        return Ray(matrix=matrix, pathlength=ray.pathlength, z=z)


# Ray state packed as [x, y, dx, dy, _one].
matrix = jnp.array([0.1, 0.2, 0.3, 0.4, 1.0])
ray = Ray(matrix=matrix, pathlength=0.6, z=0.0)
lens = Lens(f=0.25, z=0.0)

# Differentiating Ray -> Ray gives a Ray of Rays; the matrix-with-respect-to-
# matrix entry is the Jacobian matrix of the ray itself.
jac = jax.jacobian(lens)(ray)
print(jac.matrix.matrix)

[[ 1.  0.  0.  0.  0.]
 [ 0.  1.  0.  0.  0.]
 [-4.  0.  1.  0.  0.]
 [ 0. -4.  0.  1.  0.]
 [ 0.  0.  0.  0.  1.]]
