Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,9 @@ The `xp` parameter pattern controls the backend:

### JAX and the `jax.jit` boundary

Autoarray types (`Array2D`, `ArrayIrregular`, `VectorYX2DIrregular`, etc.) are **not registered as JAX pytrees**. They can be constructed inside a JIT trace, but **cannot be returned** as the output of a `jax.jit`-compiled function.
Two patterns coexist for crossing the JIT boundary:

Functions intended to be called directly inside `jax.jit` must guard autoarray wrapping with `if xp is np:`:
**Pattern 1: `if xp is np:` guard (raw `jax.Array` return).** Functions intended to be called directly inside `jax.jit` as the outermost op — where no wrapper is needed on the JAX path — guard their autoarray wrapping:

```python
def convergence_2d_via_hessian_from(self, grid, xp=np):
Expand All @@ -134,7 +134,9 @@ def convergence_2d_via_hessian_from(self, grid, xp=np):
return convergence # jax: raw jax.Array
```

Functions that are only called as intermediate steps (e.g. `deflections_yx_2d_from`) do not need this guard — they are consumed by downstream Python before the JIT boundary.
All `LensCalc` hessian-derived methods use this pattern. Intermediate helpers (e.g. `deflections_yx_2d_from`) don't need the guard — they're consumed by downstream Python before the JIT boundary.

**Pattern 2: pytree-registered wrapper return.** Functions that must return a real autoarray wrapper (or a structured object built from them) opt in to JAX pytree registration. `AbstractNDArray` auto-registers its subclass with `jax.tree_util` the first time an instance is built with `xp=jnp` (via `autoarray.abstract_ndarray._register_as_pytree`). Higher-level types (`FitImaging`, `Tracer`, `DatasetModel`) use `autoarray.abstract_ndarray.register_instance_pytree(cls, no_flatten=...)`, which flattens `__dict__` and carries `no_flatten` names through `aux_data` for per-analysis constants (dataset, settings, cosmology). `AnalysisImaging._register_fit_imaging_pytrees` wires these up when `use_jax=True`, so `jax.jit(analysis.fit_from)(instance)` returns a real `FitImaging` with `jax.Array` leaves.

### `LensCalc` (autogalaxy)

Expand Down
25 changes: 25 additions & 0 deletions autolens/imaging/model/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,9 @@ def fit_from(
The fit of the plane to the imaging dataset, which includes the log likelihood.
"""

if self._use_jax:
self._register_fit_imaging_pytrees()

tracer = self.tracer_via_instance_from(
instance=instance,
)
Expand All @@ -125,3 +128,25 @@ def fit_from(
xp=self._xp
)

@staticmethod
def _register_fit_imaging_pytrees() -> None:
"""Register every type reachable from a ``FitImaging`` return value
so ``jax.jit(fit_from)`` can flatten its output.

``dataset``, ``adapt_images`` and ``settings`` are constants per
analysis — ride as aux so JAX does not recurse into them. Everything
else (``tracer``, ``dataset_model`` and the autoarray wrappers they
carry) is dynamic per fit.
"""
from autoarray.abstract_ndarray import register_instance_pytree
from autoarray.dataset.dataset_model import DatasetModel
from autolens.lens.tracer import Tracer

register_instance_pytree(
FitImaging,
no_flatten=("dataset", "adapt_images", "settings"),
)
register_instance_pytree(DatasetModel)
# ``cosmology`` is a fixed physical constant per fit; ride as aux.
register_instance_pytree(Tracer, no_flatten=("cosmology",))

Loading