Skip to content

fix(jax): route concrete kwarg attrs to aux_data in pytree flatten#1221

Merged
Jammy2211 merged 1 commit intomainfrom
feature/pixelization-pytree-migration
Apr 17, 2026
Merged

fix(jax): route concrete kwarg attrs to aux_data in pytree flatten#1221
Jammy2211 merged 1 commit intomainfrom
feature/pixelization-pytree-migration

Conversation

@Jammy2211
Copy link
Copy Markdown
Collaborator

Summary

Fixes two bugs in autofit.jax._build_instance_pytree_funcs that surfaced when migrating jax_profiling/imaging/pixelization.py to pass a pytree ModelInstance into JIT/vmap directly (follow-up to PyAutoLabs/autolens_workspace_developer#10, the same work that produced PyAutoFit#1220 for MGE):

  1. Concrete non-Prior kwargs became JAX tracers. A Galaxy(pixelization=<concrete Pixelization>) attribute is stored directly on model.__dict__ but is not typed as Prior/AbstractPriorModel. The old classification path (direct_argument_names / direct_instance_tuples) never saw it, so it was routed into children instead of aux_data. Under jax.jit this turned the pixelization into a tracer, causing isinstance(x, aa.Pixelization) inside FitImaging to silently return False and fit.inversion to resolve to None.
  2. **kwargs classes shared a broken flatten. Galaxy.__init__(self, redshift, **kwargs) produces instances with different attribute sets per-model (lens: bulge/mass/shear; source: pixelization). JAX registers one flatten per class, so the first model's captured name list blew up with AttributeError: 'Galaxy' object has no attribute 'bulge' when flattening the source instance.

API Changes

None — internal changes only. enable_pytrees() and register_model(model) signatures and semantics are unchanged; these are pure correctness fixes to the internal flatten/unflatten closures registered with jax.tree_util.

Test Plan

  • New regression test test_register_model_keeps_kwarg_constants_static — exercises the exact Galaxy(pixelization=...) pattern: a concrete kwarg must stay identity-equal to the original object under JIT and pass isinstance.
  • Existing test_register_model_keeps_constants_static (Galaxy-redshift-style constant) still passes.
  • Full PyAutoFit suite green: pytest test_autofit/ -x → 1224 passed.
  • End-to-end smoke: jax_profiling/imaging/pixelization.py runs cleanly, both correctness assertions pass (step-by-step log_evidence matches reference within rtol=1e-4; vmap batch of 3 matches single-JIT result).
Full API Changes (for automation & release notes)

Removed

None.

Added

None.

Renamed

None.

Changed Signature

None.

Changed Behaviour

  • Internal only: autofit.jax._build_instance_pytree_funcs is now instance-driven. flatten iterates vars(instance) at call time against a shared _CLASS_FIELD_CLASSIFIERS[cls] dict populated across all register_model walks (rather than closing over a fixed name list captured from a single model at registration time). Attributes unknown to the classifier default to aux_data (safer than tracing an unknown object).

Migration

No migration required. Callers of autofit.jax.enable_pytrees / autofit.jax.register_model see only a behavioural fix — patterns that previously silently produced broken JAX traces now work correctly.

🤖 Generated with Claude Code

`_build_instance_pytree_funcs` previously classified attributes from a
single model at registration time and captured a fixed name list in the
flatten closure. Two problems:

1. `model.direct_argument_names` / `direct_instance_tuples` miss concrete
   non-Prior kwargs (e.g. `af.Model(Galaxy, pixelization=<Pixelization>)`),
   which are stored directly on `model.__dict__` but not typed as Prior.
   These slipped into `children` and became JAX tracers under JIT, so
   `isinstance(x, Pixelization)` inside `FitImaging` returned False and
   `fit.inversion` silently resolved to None.

2. Classes using `**kwargs` (like Galaxy) produce instances with different
   attribute sets per model (lens: bulge/mass/shear; source: pixelization)
   while sharing one `cls`. JAX registers flatten per-class, so the first
   model's captured names fail `getattr` on later instances.

Fix: maintain a shared `_CLASS_FIELD_CLASSIFIERS[cls]` populated across
all `register_model` walks. `flatten` now iterates `vars(instance)` at
call time and consults the classifier, defaulting unknown attrs to
constant (aux_data) — safer than tracing an unknown object.

Regression test `test_register_model_keeps_kwarg_constants_static`
exercises the Pixelization-style pattern: a concrete kwarg must remain
identity-equal to the original object under JIT and pass `isinstance`.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
@Jammy2211 Jammy2211 added the pending-release Pending next release build label Apr 17, 2026
@Jammy2211
Copy link
Copy Markdown
Collaborator Author

@Jammy2211 Jammy2211 merged commit c917349 into main Apr 17, 2026
2 checks passed
@Jammy2211 Jammy2211 deleted the feature/pixelization-pytree-migration branch April 17, 2026 09:25
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

pending-release Pending next release build

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant