diff --git a/autoconf/jax_wrapper.py b/autoconf/jax_wrapper.py index 15b2214..2fc3b22 100644 --- a/autoconf/jax_wrapper.py +++ b/autoconf/jax_wrapper.py @@ -47,8 +47,40 @@ """" JAX 64-bit precision has been automatically enabled for you (JAX_ENABLE_X64=True), as double precision is required for most scientific computing applications. - - To enable 64 precision as default in JAX, set the environment variable + + To enable 64 precision as default in JAX, set the environment variable JAX_ENABLE_X64=true before running your script. """ ) + + +def register_pytree_node_class(cls): + """Opt-in JAX pytree class registration that defers the JAX import. + + The previous eager registration in ``autofit.mapper.prior_model.prior_model`` + forced ``jax.tree_util`` to load whenever ``import autofit`` ran. To keep + JAX an optional dependency, library code now exposes ``tree_flatten`` / + ``tree_unflatten`` methods but does NOT register the class itself; callers + that want JAX integration call this helper explicitly (typically via + ``autofit.jax.enable_pytrees()``). + + No-ops if JAX is not installed. + """ + try: + from jax.tree_util import register_pytree_node_class as _r + except ImportError: + return cls + return _r(cls) + + +def register_pytree_node(nodetype, flatten_func, unflatten_func): + """Opt-in JAX pytree registration for an externally-defined class. + + Lazy counterpart to :func:`register_pytree_node_class` for the case where + the class cannot be decorated directly. No-ops if JAX is not installed. + """ + try: + from jax.tree_util import register_pytree_node as _r + except ImportError: + return None + return _r(nodetype, flatten_func, unflatten_func)