From 33ef84b1f052daa7adedbf0127268b9e2c7fd5f2 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Fri, 29 May 2026 08:29:47 +0100 Subject: [PATCH] fix(jax): structural defense against cached_property pytree/dict leaks MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit PR #1300 fixed a specific leak where AbstractPriorModel.parameterization (a `@functools.cached_property` added in commit 4564ae9a1) leaked its cached string into every ModelInstance via Collection._instance_for_arguments. That broke 38 JAX jit(fit_from) calls and the autofit_workspace overview_1 smoke (clusters C1+C4). The minimal fix renamed the cache key to `_parameterization_cache` so the existing `_`-prefix filter at each `__dict__` iterator skipped it. The structural problem remained: every walker uses an opt-out filter (blacklist + underscore prefix), so any future cached_property declared on a model class silently reproduces the same class of bug. This PR closes the class: - New classmethod `AbstractModel._cached_property_names(cls)` delegates to `autoconf.tools.decorators.cached_property_names` (PyAutoConf #111), returning a frozenset of every functools.cached_property and autoconf CachedProperty descriptor name in the MRO. - Extend the filter at every `__dict__` iteration site to union the pre-existing exclusion with this frozenset: autofit/mapper/model.py (__getstate__, ModelInstance.dict) autofit/mapper/model_object.py (ModelObject._dict — feeds Collection.items) autofit/mapper/prior_model/abstract.py (AbstractModel.items) autofit/mapper/prior_model/collection.py (Collection._instance_for_arguments) autofit/mapper/prior_model/prior_model.py (Model._instance_for_arguments) Identifier-hash stability verified: the unique_identifier walker at `autofit/mapper/identifier.py` does NOT call any of these 6 sites — it walks `__dict__` independently with its own `_`-prefix filter. Three representative model shapes (simple Collection, nested Collection, Model with tuple arg) all produce byte-identical identifier hashes pre- and post-defense: simple: f7f19073a8fb19b3d11231fb6eef7e3b ✓ nested: 04e1328c84a1e4c3a81a9d3544dd19f5 ✓ with_tuple: 36084d2c3fec27e0b7aa504add0bd898 ✓ Tests: - `test_cached_property_names_classmethod_walks_mro`: confirms the classmethod surfaces MRO-declared descriptors and memoises per-class. - `test_cached_property_excluded_from_all_dict_walks`: ships a synthetic GuardedCollection with a cached_property returning a string; asserts the value never appears in instance.__dict__, instance.dict, model.items(), tree_flatten() leaves, __getstate__, or pickle round-trip. - 1415/1415 PyAutoFit tests pass (1413 prior + 2 new). Depends on: PyAutoLabs/PyAutoConf#111. Co-Authored-By: Claude Opus 4.7 (1M context) --- autofit/mapper/model.py | 23 +++++- autofit/mapper/model_object.py | 11 +++ autofit/mapper/prior_model/abstract.py | 9 +- autofit/mapper/prior_model/collection.py | 3 +- autofit/mapper/prior_model/prior_model.py | 2 + test_autofit/mapper/test_parameterization.py | 87 ++++++++++++++++++++ 6 files changed, 131 insertions(+), 4 deletions(-) diff --git a/autofit/mapper/model.py b/autofit/mapper/model.py index 010dcac4a..d24ef82bc 100644 --- a/autofit/mapper/model.py +++ b/autofit/mapper/model.py @@ -83,9 +83,28 @@ def __init__(self, label=None, id_=None): self._frozen_cache = dict() super().__init__(label=label, id_=id_) + @classmethod + def _cached_property_names(cls) -> frozenset: + """ + Return the names of every ``cached_property``-style descriptor + declared anywhere in ``cls``'s MRO. + + Used by the ``__dict__``-iteration sites in this module and in + ``autofit/mapper/prior_model/`` to exclude cached descriptor values + from instance construction, ``ModelInstance.dict``, pickling, and + downstream JAX pytree flattening. See PyAutoFit#1300 for the + diagnosed leak this defends against. + """ + from autoconf.tools.decorators import cached_property_names + + return cached_property_names(cls) + def __getstate__(self): + excluded = type(self)._cached_property_names() return { - key: value for key, value in self.__dict__.items() if key != "_frozen_cache" + key: value + for key, value in self.__dict__.items() + if key != "_frozen_cache" and key not in excluded } def __setstate__(self, state): @@ -446,11 +465,13 @@ def __hash__(self): @property def dict(self): + excluded = type(self)._cached_property_names() return { key: value for key, value in self.__dict__.items() if key not in ("id", "component_number", "item_number") and not (isinstance(key, str) and key.startswith("_")) + and key not in excluded } def tree_flatten(self) -> Tuple[List, Tuple]: diff --git a/autofit/mapper/model_object.py b/autofit/mapper/model_object.py index 641c04e77..3316c7324 100644 --- a/autofit/mapper/model_object.py +++ b/autofit/mapper/model_object.py @@ -330,9 +330,20 @@ def dict(self) -> dict: @property def _dict(self): + # Pick up any cached_property descriptors declared on the class so + # their cached values don't propagate via `Collection.items()` (which + # delegates here) or any other downstream consumer. The lookup is + # gated on hasattr because ModelObject is the base for the whole + # mapper module: a few non-AbstractModel descendants do not carry the + # ``_cached_property_names`` classmethod. + try: + excluded = type(self)._cached_property_names() + except AttributeError: + excluded = frozenset() return { key: value for key, value in self.__dict__.items() if key not in ("component_number", "item_number", "id", "cls", "label") and not key.startswith("_") + and key not in excluded } diff --git a/autofit/mapper/prior_model/abstract.py b/autofit/mapper/prior_model/abstract.py index ba9d0ea99..663d9fc69 100644 --- a/autofit/mapper/prior_model/abstract.py +++ b/autofit/mapper/prior_model/abstract.py @@ -1275,12 +1275,17 @@ def from_instance( def items(self): """Return (name, value) pairs for all public, non-internal attributes. - Excludes private attributes (prefixed with ``_``), ``cls``, and ``id``. + Excludes private attributes (prefixed with ``_``), ``cls``, ``id``, + and any ``cached_property``-style descriptors declared on the class + (see ``AbstractModel._cached_property_names``). """ + excluded = type(self)._cached_property_names() return [ (key, value) for key, value in self.__dict__.items() - if not key.startswith("_") and key not in ("cls", "id") + if not key.startswith("_") + and key not in ("cls", "id") + and key not in excluded ] @property diff --git a/autofit/mapper/prior_model/collection.py b/autofit/mapper/prior_model/collection.py index b06d8878b..5198dfe57 100644 --- a/autofit/mapper/prior_model/collection.py +++ b/autofit/mapper/prior_model/collection.py @@ -286,8 +286,9 @@ def _instance_for_arguments( A list of instances constructed from the list of prior models. """ result = ModelInstance() + excluded = type(self)._cached_property_names() for key, value in self.__dict__.items(): - if key.startswith("_"): + if key.startswith("_") or key in excluded: continue if isinstance(value, AbstractPriorModel): value = value.instance_for_arguments( diff --git a/autofit/mapper/prior_model/prior_model.py b/autofit/mapper/prior_model/prior_model.py index c0153e8ae..e758164e2 100644 --- a/autofit/mapper/prior_model/prior_model.py +++ b/autofit/mapper/prior_model/prior_model.py @@ -494,12 +494,14 @@ def _instance_for_arguments( else: result = self.cls(**constructor_arguments) + excluded = type(self)._cached_property_names() for key, value in self.__dict__.items(): if ( not hasattr(result, key) and not isinstance(value, Prior) and not key == "cls" and not key.startswith("_") + and key not in excluded ): if isinstance(value, Model): value = value.instance_for_arguments( diff --git a/test_autofit/mapper/test_parameterization.py b/test_autofit/mapper/test_parameterization.py index 416c07313..a8e845125 100644 --- a/test_autofit/mapper/test_parameterization.py +++ b/test_autofit/mapper/test_parameterization.py @@ -1,3 +1,4 @@ +import functools import itertools import pytest @@ -176,6 +177,92 @@ def test_parameterization_cache_does_not_leak_into_instance(): assert not isinstance(child, str) +def test_cached_property_names_classmethod_walks_mro(): + """The ``_cached_property_names`` classmethod on AbstractModel exposes the + autoconf ``cached_property_names`` MRO walker. It must pick up + descriptors declared on any ancestor and memoise the result on the class.""" + + import functools + + import autofit as af + + # Build a synthetic subclass with a cached_property to verify the walker + # finds it. We use af.Collection because both AbstractPriorModel and + # ModelInstance inherit from AbstractModel. + class SyntheticCollection(af.Collection): + @functools.cached_property + def synthetic_value(self): + return "a synthetic cached string" + + names = SyntheticCollection._cached_property_names() + assert "synthetic_value" in names + + # Result is memoised on the synthetic class. + assert "__cached_property_names_cache__" in SyntheticCollection.__dict__ + + # Plain af.Collection (no synthetic_value) has its own cache. + base_names = af.Collection._cached_property_names() + assert "synthetic_value" not in base_names + + +class _GuardedCollection(af.Collection): + """Module-level subclass used by + ``test_cached_property_excluded_from_all_dict_walks`` — must live at + module scope so ``pickle.dumps`` can locate the class on round-trip.""" + + @functools.cached_property + def derived(self): + return "leaky-string" + + +def test_cached_property_excluded_from_all_dict_walks(): + """Regression: a future ``@functools.cached_property`` declared anywhere + in the model class hierarchy must not surface through any of: + ``Collection._instance_for_arguments`` (via ``instance.__dict__``), + ``ModelInstance.dict``, ``ModelInstance.tree_flatten()``, + ``AbstractModel.items()``, ``ModelObject._dict``, or pickling via + ``__getstate__``. + + Covers the class of bug PyAutoFit#1300 fixed for ``parameterization``; + this test will fail if a maintainer reintroduces an un-prefixed + cached_property on the model hierarchy without the + ``_cached_property_names`` defense applied at every site.""" + + import pickle + + model = _GuardedCollection(gaussian=af.Model(af.ex.Gaussian)) + + # Trigger the cache. After this, model.__dict__["derived"] = "leaky-string". + _ = model.derived + assert model.__dict__.get("derived") == "leaky-string" + + instance = model.instance_from_prior_medians() + + # Site 1+4: Collection._instance_for_arguments + ModelInstance.dict + assert "derived" not in instance.__dict__ + assert "derived" not in instance.dict + + # Site 4 also feeds tree_flatten — no string leaves. + leaves = instance.dict.values() + for leaf in leaves: + assert not isinstance(leaf, str) + + # Site 3: AbstractModel.items() on the model itself. + assert all(key != "derived" for key, _ in model.items()) + + # Site 5: __getstate__ drops the cached value from pickles. + state = model.__getstate__() + assert "derived" not in state + + # Round-trip via pickle: the unpickled model re-computes the cached value, + # rather than carrying the pickled string on the wire. + blob = pickle.dumps(model) + revived = pickle.loads(blob) + assert "derived" not in revived.__dict__ + # Touching it recomputes. + assert revived.derived == "leaky-string" + + def test_integer_attributes(): model = af.Model(af.ex.Gaussian)