Skip to content

fix(jax): keep parameterization cache off ModelInstance + auto-register pytrees#1300

Merged
Jammy2211 merged 1 commit into
mainfrom
feature/parameterization-cache-fix
May 28, 2026
Merged

fix(jax): keep parameterization cache off ModelInstance + auto-register pytrees#1300
Jammy2211 merged 1 commit into
mainfrom
feature/parameterization-cache-fix

Conversation

@Jammy2211
Copy link
Copy Markdown
Collaborator

Summary

Two coupled fixes restoring the JAX jit(fit_from) path that broke when commit 4564ae9a1 made AbstractPriorModel.parameterization a functools.cached_property. Surfaced as 40 failing scripts across autogalaxy_workspace_test, autolens_workspace_test, and autofit_workspace in the 2026-05-28 release-prep triage.

  • Fix 1 — parameterization cache: store the cached string under the underscore-prefixed key _parameterization_cache so Collection._instance_for_arguments (collection.py:289) and ModelInstance.dict (model.py:451) — both of which iterate __dict__ and skip only _-prefixed keys — filter it out. Preserves the 2.7s → 0.05s perf win from 4564ae9a1.
  • Fix 2 — auto-register pytrees: call enable_pytrees() + register_model(self.model) from Fitness.__init__ whenever analysis._use_jax=True. Both helpers are idempotent, so workspaces that still call them explicitly keep working. New JAX-enabled workspaces don't need the boilerplate.

Root cause

cached_property writes to self.__dict__["parameterization"]. After any model.info access (every script does this), Collection._instance_for_arguments copies the string onto every ModelInstance. The string then either:

  • surfaces as a non-array JAX pytree leaf → TypeError: Error interpreting argument to <jit>... fit_from ... ModelInstance ... static_argnums in the 38 jax_likelihood_functions/* scripts (Cluster C1 of the triage), or
  • gets yielded by for x in instanceAttributeError: 'str' object has no attribute 'model_data_from' in overview/overview_1_the_basics.py (Cluster C4).

Test plan

  • pytest test_autofit — 1413/1413 pass + new test_parameterization_cache_does_not_leak_into_instance regression (numpy-only per feedback_no_jax_in_unit_tests).
  • python autofit_workspace/scripts/overview/overview_1_the_basics.py (C4 reproducer) — runs to completion.
  • python autolens_workspace_test/scripts/jax_likelihood_functions/imaging/rectangular.py (C1 reproducer) — prints PASS: jit(fit_from) round-trip matches NumPy scalar.. Notably worked without any explicit enable_pytrees() + register_model() in the script, confirming the Fitness.__init__ auto-call landed correctly.
  • Reviewer to sanity-check that the _parameterization_cache key is filtered by every __dict__-iterator in autofit/mapper/ (a follow-up issue will harden this with a _cached_property_names(cls) classmethod — see below).

Follow-up

A systemic audit during planning found the class of bug is structural — 4 __dict__ iterators in autofit/mapper/ plus autoarray/abstract_ndarray.py:108-119 all use opt-out filters that default to leak. The first future @cached_property added to AbstractPriorModel/Model/Collection/ModelInstance (or to a JAX-pytree-registered Fit class with a non-array return) reintroduces the bug. Plus AbstractModel.__getstate__ (model.py:86) currently pickles the cached parameterization string into the SQLAlchemy DB.

A separate feat(jax): structural defense against cached_property pytree leaks issue/PR will add a _cached_property_names(cls) classmethod applied as an extra exclusion at every leak site, so this entire class of bug is closed permanently in one shot.

🤖 Generated with Claude Code

…er pytrees

Two coupled fixes restoring the JAX `jit(fit_from)` path that broke
when commit 4564ae9 made `AbstractPriorModel.parameterization` a
`functools.cached_property`.

`cached_property` writes to `self.__dict__["parameterization"]`. After
any `model.info` access, `Collection._instance_for_arguments` (which
iterates `__dict__` and skips only underscore-prefixed keys) propagates
the cached string onto every `ModelInstance`. The string then surfaces
as a non-array JAX pytree leaf (autogalaxy_workspace_test +
autolens_workspace_test `jax_likelihood_functions/*` — 38 scripts) and
makes `for x in instance:` yield strings instead of profiles
(autofit_workspace `overview/overview_1_the_basics.py`).

Fix 1: store the cache under the underscore-prefixed key
`_parameterization_cache` so both `Collection._instance_for_arguments`
and `ModelInstance.dict` filter it out. Preserves the 2.7s → 0.05s
perf win from 4564ae9.

Fix 2: auto-call `enable_pytrees() + register_model(self.model)` from
`Fitness.__init__` whenever `analysis._use_jax=True`. Both helpers are
idempotent, so workspaces that still call them explicitly keep
working. New JAX-enabled workspaces don't need the boilerplate.

Verified locally:
- 1413/1413 PyAutoFit unit tests pass + new
  `test_parameterization_cache_does_not_leak_into_instance` regression
- `autofit_workspace/scripts/overview/overview_1_the_basics.py` runs
  to completion (cluster C4 reproducer)
- `autolens_workspace_test/scripts/jax_likelihood_functions/imaging/rectangular.py`
  prints "PASS: jit(fit_from) round-trip matches NumPy scalar"
  (cluster C1 reproducer)

Follow-up: a structural defense across the four `__dict__`-iterators
in `autofit/mapper/` plus `autoarray/abstract_ndarray.py` will ship as
a separate PR — a `_cached_property_names(cls)` classmethod applied
as an extra filter at every leak site so the next future
`@cached_property` on a model class cannot reintroduce this bug.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@Jammy2211 Jammy2211 merged commit 2a8f5c7 into main May 28, 2026
7 checks passed
@Jammy2211 Jammy2211 deleted the feature/parameterization-cache-fix branch May 28, 2026 18:44
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant