Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 89 additions & 38 deletions autofit/jax/pytrees.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@

_ENABLED = False
_REGISTERED_INSTANCE_CLASSES: set = set()
_CLASS_FIELD_CLASSIFIERS: dict = {}
_CLASS_CONSTRUCTOR_ARGS: dict = {}


def enable_pytrees() -> bool:
Expand Down Expand Up @@ -71,14 +73,26 @@ def register_model(model) -> bool:
if not enable_pytrees():
return False

from autofit.mapper.prior.abstract import Prior
from autofit.mapper.prior_model.prior_model import Model
from autofit.mapper.prior_model.collection import Collection
from autofit.mapper.prior_model.abstract import AbstractPriorModel

def _walk(node):
if isinstance(node, Model):
cls = node.cls
classifier = _CLASS_FIELD_CLASSIFIERS.setdefault(cls, {})
for name, value in node.items():
is_dynamic = isinstance(value, (Prior, AbstractPriorModel))
# setdefault: earliest classification wins. Different models
# sharing the same cls (e.g. lens vs source Galaxy) may
# declare different attribute sets; we accumulate them all.
classifier.setdefault(name, is_dynamic)
_CLASS_CONSTRUCTOR_ARGS.setdefault(
cls, tuple(node.constructor_argument_names)
)
if cls not in _REGISTERED_INSTANCE_CLASSES:
flatten, unflatten = _build_instance_pytree_funcs(node)
flatten, unflatten = _build_instance_pytree_funcs(cls)
try:
register_pytree_node(cls, flatten, unflatten)
except ValueError:
Expand All @@ -95,55 +109,92 @@ def _walk(node):
return True


def _build_instance_pytree_funcs(model):
"""Build flatten/unflatten functions for instances of ``model.cls``.

Constants from the original model definition (e.g. ``Galaxy(redshift=0.5)``)
are placed in the JAX ``aux_data`` so they remain concrete Python values
inside a ``jax.jit`` trace. Only prior-derived constructor arguments are
placed in ``children`` (and therefore become JAX tracers).

This is critical for code that uses constants for control flow — e.g.
``sorted(galaxies, key=lambda g: g.redshift)`` in ``Tracer`` — which would
otherwise raise ``TracerBoolConversionError`` under JIT.
def _build_instance_pytree_funcs(cls):
"""Build flatten/unflatten functions for any instance of ``cls``.

At flatten time we iterate the instance's own public attributes rather
than a list captured from a single model at registration time. This is
required because classes like ``Galaxy`` use ``**kwargs`` in ``__init__``,
so different ``af.Model(Galaxy, ...)`` instances (e.g. a lens galaxy vs a
source galaxy) produce instances with completely different attribute sets
that nonetheless share the same ``cls`` — and JAX allows only one
registration per class.

Each attribute is classified as:

* **Dynamic** (prior-derived): the corresponding ``Model`` attribute was
a ``Prior`` or ``AbstractPriorModel``. Resolved to concrete numbers
(or nested instances) per sampled point, so it becomes a JAX child leaf
and gets traced under ``jax.jit``.
* **Constant**: everything else — a fixed ``redshift=0.5``, or a concrete
non-prior kwarg like ``Galaxy(pixelization=<Pixelization>)``. Goes into
``aux_data`` so it stays as the original Python object inside a trace.
Required for control flow that reads constants
(``sorted(..., key=lambda g: g.redshift)``) and for ``isinstance``
dispatch on concrete kwargs (``isinstance(obj, Pixelization)``).

Classification is read from the shared ``_CLASS_FIELD_CLASSIFIERS`` dict,
which is updated by every ``register_model`` call. Attributes unknown to
the classifier (never declared on any walked model) default to constant —
safer than tracing an unknown object.
"""
constructor_args = list(model.constructor_argument_names)
constant_arg_names = [
name for name in constructor_args if name in dict(model.direct_instance_tuples)
]
constant_values = {
name: dict(model.direct_instance_tuples)[name] for name in constant_arg_names
}
dynamic_arg_names = [
name for name in constructor_args if name not in constant_arg_names
]
constructor_args = _CLASS_CONSTRUCTOR_ARGS.get(cls, ())
constructor_arg_set = set(constructor_args)

def _partition(instance):
classifier = _CLASS_FIELD_CLASSIFIERS.get(cls, {})
ctor_dyn: list = []
ctor_const: list = []
attr_dyn: list = []
attr_const: list = []
for name, value in vars(instance).items():
if name.startswith("_") or name in ("cls", "id"):
continue
is_dynamic = classifier.get(name, False)
in_ctor = name in constructor_arg_set
if in_ctor and is_dynamic:
ctor_dyn.append((name, value))
elif in_ctor:
ctor_const.append((name, value))
elif is_dynamic:
attr_dyn.append((name, value))
else:
attr_const.append((name, value))
return ctor_dyn, ctor_const, attr_dyn, attr_const

def flatten(instance):
attribute_names = [
name
for name in model.direct_argument_names
if hasattr(instance, name) and name not in constructor_args
]
ctor_dyn, ctor_const, attr_dyn, attr_const = _partition(instance)
children = (
[getattr(instance, name) for name in dynamic_arg_names],
[getattr(instance, name) for name in attribute_names],
[v for _, v in ctor_dyn],
[v for _, v in attr_dyn],
)
aux = (
tuple(dynamic_arg_names),
tuple(constant_arg_names),
tuple(constant_values[n] for n in constant_arg_names),
tuple(attribute_names),
tuple(n for n, _ in ctor_dyn),
tuple(n for n, _ in ctor_const),
tuple(v for _, v in ctor_const),
tuple(n for n, _ in attr_dyn),
tuple(n for n, _ in attr_const),
tuple(v for _, v in attr_const),
)
return children, aux

def unflatten(aux, children):
dyn_names, const_names, const_vals, attr_names = aux
dyn_vals, attr_vals = children
(
dyn_names,
const_names,
const_vals,
attr_dyn_names,
attr_const_names,
attr_const_vals,
) = aux
dyn_vals, attr_dyn_vals = children
kwargs = dict(zip(dyn_names, dyn_vals))
kwargs.update(zip(const_names, const_vals))
ordered = [kwargs[name] for name in constructor_args]
instance = model.cls(*ordered)
for name, value in zip(attr_names, attr_vals):
ordered = [kwargs[name] for name in constructor_args if name in kwargs]
instance = cls(*ordered)
for name, value in zip(attr_dyn_names, attr_dyn_vals):
setattr(instance, name, value)
for name, value in zip(attr_const_names, attr_const_vals):
setattr(instance, name, value)
return instance

Expand Down
42 changes: 42 additions & 0 deletions test_autofit/jax/test_enable_pytrees.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,48 @@ def use_redshift_for_control_flow(inst):
assert float(result) == pytest.approx(2.0 * instance.scale)


def test_register_model_keeps_kwarg_constants_static():
"""Constant ``**kwargs`` attributes must stay in aux_data, not children.

``Galaxy.__init__(self, redshift, **kwargs)`` stores every kwarg via
``setattr``. A concrete object passed as a kwarg (e.g. a ``Pixelization``)
is an instance attribute but NOT a constructor argument, so the old
flatten logic routed it to ``children`` and it became a JAX tracer.
Downstream ``isinstance(x, Pixelization)`` checks then returned False.
This test exercises the exact pattern.
"""
class Marker:
pass

class KwargHolder:
def __init__(self, redshift, **kwargs):
self.redshift = redshift
for k, v in kwargs.items():
setattr(self, k, v)

marker = Marker()
model = af.Model(
KwargHolder,
redshift=0.5,
marker=marker,
scale=af.GaussianPrior(mean=1.0, sigma=1.0),
)
register_model(model)
instance = model.instance_from_prior_medians()
assert instance.marker is marker

@jax.jit
def use_marker_isinstance(inst):
# isinstance on a tracer would return False; this only works if
# `marker` is kept concrete via aux_data.
if isinstance(inst.marker, Marker):
return inst.scale * 2
return inst.scale

result = use_marker_isinstance(instance)
assert float(result) == pytest.approx(2.0 * instance.scale)


def test_enable_pytrees_idempotent():
assert enable_pytrees() is True
assert enable_pytrees() is True
Expand Down
Loading