Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 34 additions & 2 deletions autoconf/jax_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,40 @@
""""
JAX 64-bit precision has been automatically enabled for you (JAX_ENABLE_X64=True),
as double precision is required for most scientific computing applications.
To enable 64 precision as default in JAX, set the environment variable

To enable 64 precision as default in JAX, set the environment variable
JAX_ENABLE_X64=true before running your script.
"""
)


def register_pytree_node_class(cls):
"""Opt-in JAX pytree class registration that defers the JAX import.

The previous eager registration in ``autofit.mapper.prior_model.prior_model``
forced ``jax.tree_util`` to load whenever ``import autofit`` ran. To keep
JAX an optional dependency, library code now exposes ``tree_flatten`` /
``tree_unflatten`` methods but does NOT register the class itself; callers
that want JAX integration call this helper explicitly (typically via
``autofit.jax.enable_pytrees()``).

No-ops if JAX is not installed.
"""
try:
from jax.tree_util import register_pytree_node_class as _r
except ImportError:
return cls
return _r(cls)


def register_pytree_node(nodetype, flatten_func, unflatten_func):
"""Opt-in JAX pytree registration for an externally-defined class.

Lazy counterpart to :func:`register_pytree_node_class` for the case where
the class cannot be decorated directly. No-ops if JAX is not installed.
"""
try:
from jax.tree_util import register_pytree_node as _r
except ImportError:
return None
return _r(nodetype, flatten_func, unflatten_func)
Loading