fix(jax): keep parameterization cache off ModelInstance + auto-register pytrees#1300
Merged
Merged
Conversation
…er pytrees Two coupled fixes restoring the JAX `jit(fit_from)` path that broke when commit 4564ae9 made `AbstractPriorModel.parameterization` a `functools.cached_property`. `cached_property` writes to `self.__dict__["parameterization"]`. After any `model.info` access, `Collection._instance_for_arguments` (which iterates `__dict__` and skips only underscore-prefixed keys) propagates the cached string onto every `ModelInstance`. The string then surfaces as a non-array JAX pytree leaf (autogalaxy_workspace_test + autolens_workspace_test `jax_likelihood_functions/*` — 38 scripts) and makes `for x in instance:` yield strings instead of profiles (autofit_workspace `overview/overview_1_the_basics.py`). Fix 1: store the cache under the underscore-prefixed key `_parameterization_cache` so both `Collection._instance_for_arguments` and `ModelInstance.dict` filter it out. Preserves the 2.7s → 0.05s perf win from 4564ae9. Fix 2: auto-call `enable_pytrees() + register_model(self.model)` from `Fitness.__init__` whenever `analysis._use_jax=True`. Both helpers are idempotent, so workspaces that still call them explicitly keep working. New JAX-enabled workspaces don't need the boilerplate. Verified locally: - 1413/1413 PyAutoFit unit tests pass + new `test_parameterization_cache_does_not_leak_into_instance` regression - `autofit_workspace/scripts/overview/overview_1_the_basics.py` runs to completion (cluster C4 reproducer) - `autolens_workspace_test/scripts/jax_likelihood_functions/imaging/rectangular.py` prints "PASS: jit(fit_from) round-trip matches NumPy scalar" (cluster C1 reproducer) Follow-up: a structural defense across the four `__dict__`-iterators in `autofit/mapper/` plus `autoarray/abstract_ndarray.py` will ship as a separate PR — a `_cached_property_names(cls)` classmethod applied as an extra filter at every leak site so the next future `@cached_property` on a model class cannot reintroduce this bug. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Two coupled fixes restoring the JAX
jit(fit_from)path that broke when commit4564ae9a1madeAbstractPriorModel.parameterizationafunctools.cached_property. Surfaced as 40 failing scripts acrossautogalaxy_workspace_test,autolens_workspace_test, andautofit_workspacein the 2026-05-28 release-prep triage._parameterization_cachesoCollection._instance_for_arguments(collection.py:289) andModelInstance.dict(model.py:451) — both of which iterate__dict__and skip only_-prefixed keys — filter it out. Preserves the 2.7s → 0.05s perf win from4564ae9a1.enable_pytrees() + register_model(self.model)fromFitness.__init__wheneveranalysis._use_jax=True. Both helpers are idempotent, so workspaces that still call them explicitly keep working. New JAX-enabled workspaces don't need the boilerplate.Root cause
cached_propertywrites toself.__dict__["parameterization"]. After anymodel.infoaccess (every script does this),Collection._instance_for_argumentscopies the string onto everyModelInstance. The string then either:TypeError: Error interpreting argument to <jit>... fit_from ... ModelInstance ... static_argnumsin the 38jax_likelihood_functions/*scripts (Cluster C1 of the triage), orfor x in instance→AttributeError: 'str' object has no attribute 'model_data_from'inoverview/overview_1_the_basics.py(Cluster C4).Test plan
pytest test_autofit— 1413/1413 pass + newtest_parameterization_cache_does_not_leak_into_instanceregression (numpy-only perfeedback_no_jax_in_unit_tests).python autofit_workspace/scripts/overview/overview_1_the_basics.py(C4 reproducer) — runs to completion.python autolens_workspace_test/scripts/jax_likelihood_functions/imaging/rectangular.py(C1 reproducer) — printsPASS: jit(fit_from) round-trip matches NumPy scalar.. Notably worked without any explicitenable_pytrees() + register_model()in the script, confirming theFitness.__init__auto-call landed correctly._parameterization_cachekey is filtered by every__dict__-iterator inautofit/mapper/(a follow-up issue will harden this with a_cached_property_names(cls)classmethod — see below).Follow-up
A systemic audit during planning found the class of bug is structural — 4
__dict__iterators inautofit/mapper/plusautoarray/abstract_ndarray.py:108-119all use opt-out filters that default to leak. The first future@cached_propertyadded toAbstractPriorModel/Model/Collection/ModelInstance(or to a JAX-pytree-registered Fit class with a non-array return) reintroduces the bug. PlusAbstractModel.__getstate__(model.py:86) currently pickles the cachedparameterizationstring into the SQLAlchemy DB.A separate
feat(jax): structural defense against cached_property pytree leaksissue/PR will add a_cached_property_names(cls)classmethod applied as an extra exclusion at every leak site, so this entire class of bug is closed permanently in one shot.🤖 Generated with Claude Code