diff --git a/autofit/mapper/model.py b/autofit/mapper/model.py
index 010dcac4a..d24ef82bc 100644
--- a/autofit/mapper/model.py
+++ b/autofit/mapper/model.py
@@ -83,9 +83,28 @@ def __init__(self, label=None, id_=None):
         self._frozen_cache = dict()
         super().__init__(label=label, id_=id_)
 
+    @classmethod
+    def _cached_property_names(cls) -> frozenset:
+        """
+        Return the names of every ``cached_property``-style descriptor
+        declared anywhere in ``cls``'s MRO.
+
+        Used by the ``__dict__``-iteration sites in this module and in
+        ``autofit/mapper/prior_model/`` to exclude cached descriptor values
+        from instance construction, ``ModelInstance.dict``, pickling, and
+        downstream JAX pytree flattening. See PyAutoFit#1300 for the
+        diagnosed leak this defends against.
+        """
+        from autoconf.tools.decorators import cached_property_names
+
+        return cached_property_names(cls)
+
     def __getstate__(self):
+        excluded = type(self)._cached_property_names()
         return {
-            key: value for key, value in self.__dict__.items() if key != "_frozen_cache"
+            key: value
+            for key, value in self.__dict__.items()
+            if key != "_frozen_cache" and key not in excluded
         }
 
     def __setstate__(self, state):
@@ -446,11 +465,13 @@ def __hash__(self):
 
     @property
     def dict(self):
+        excluded = type(self)._cached_property_names()
         return {
             key: value
             for key, value in self.__dict__.items()
             if key not in ("id", "component_number", "item_number")
             and not (isinstance(key, str) and key.startswith("_"))
+            and key not in excluded
         }
 
     def tree_flatten(self) -> Tuple[List, Tuple]:
diff --git a/autofit/mapper/model_object.py b/autofit/mapper/model_object.py
index 641c04e77..3316c7324 100644
--- a/autofit/mapper/model_object.py
+++ b/autofit/mapper/model_object.py
@@ -330,9 +330,19 @@ def dict(self) -> dict:
 
     @property
     def _dict(self):
+        # Pick up any cached_property descriptors declared on the class so
+        # their cached values don't propagate via `Collection.items()` (which
+        # delegates here) or any other downstream consumer. The classmethod is
+        # looked up with getattr because ModelObject is the base for the whole
+        # mapper module: a few non-AbstractModel descendants do not carry the
+        # ``_cached_property_names`` classmethod. Only the lookup is guarded,
+        # so an AttributeError raised inside the walker itself still surfaces.
+        names_for = getattr(type(self), "_cached_property_names", None)
+        excluded = frozenset() if names_for is None else names_for()
         return {
             key: value
             for key, value in self.__dict__.items()
             if key not in ("component_number", "item_number", "id", "cls", "label")
             and not key.startswith("_")
+            and key not in excluded
         }
diff --git a/autofit/mapper/prior_model/abstract.py b/autofit/mapper/prior_model/abstract.py
index ba9d0ea99..663d9fc69 100644
--- a/autofit/mapper/prior_model/abstract.py
+++ b/autofit/mapper/prior_model/abstract.py
@@ -1275,12 +1275,17 @@ def from_instance(
     def items(self):
         """Return (name, value) pairs for all public, non-internal attributes.
 
-        Excludes private attributes (prefixed with ``_``), ``cls``, and ``id``.
+        Excludes private attributes (prefixed with ``_``), ``cls``, ``id``,
+        and any ``cached_property``-style descriptors declared on the class
+        (see ``AbstractModel._cached_property_names``).
         """
+        excluded = type(self)._cached_property_names()
         return [
             (key, value)
             for key, value in self.__dict__.items()
-            if not key.startswith("_") and key not in ("cls", "id")
+            if not key.startswith("_")
+            and key not in ("cls", "id")
+            and key not in excluded
         ]
 
     @property
diff --git a/autofit/mapper/prior_model/collection.py b/autofit/mapper/prior_model/collection.py
index b06d8878b..5198dfe57 100644
--- a/autofit/mapper/prior_model/collection.py
+++ b/autofit/mapper/prior_model/collection.py
@@ -286,8 +286,9 @@ def _instance_for_arguments(
         A list of instances constructed from the list of prior models.
         """
         result = ModelInstance()
+        excluded = type(self)._cached_property_names()
         for key, value in self.__dict__.items():
-            if key.startswith("_"):
+            if key.startswith("_") or key in excluded:
                 continue
             if isinstance(value, AbstractPriorModel):
                 value = value.instance_for_arguments(
diff --git a/autofit/mapper/prior_model/prior_model.py b/autofit/mapper/prior_model/prior_model.py
index c0153e8ae..e758164e2 100644
--- a/autofit/mapper/prior_model/prior_model.py
+++ b/autofit/mapper/prior_model/prior_model.py
@@ -494,12 +494,14 @@ def _instance_for_arguments(
         else:
             result = self.cls(**constructor_arguments)
 
+        excluded = type(self)._cached_property_names()
         for key, value in self.__dict__.items():
             if (
                 not hasattr(result, key)
                 and not isinstance(value, Prior)
                 and not key == "cls"
                 and not key.startswith("_")
+                and key not in excluded
             ):
                 if isinstance(value, Model):
                     value = value.instance_for_arguments(
diff --git a/test_autofit/mapper/test_parameterization.py b/test_autofit/mapper/test_parameterization.py
index 416c07313..a8e845125 100644
--- a/test_autofit/mapper/test_parameterization.py
+++ b/test_autofit/mapper/test_parameterization.py
@@ -1,3 +1,4 @@
+import functools
 import itertools
 
 import pytest
@@ -176,6 +177,88 @@ def test_parameterization_cache_does_not_leak_into_instance():
         assert not isinstance(child, str)
 
 
+def test_cached_property_names_classmethod_walks_mro():
+    """The ``_cached_property_names`` classmethod on AbstractModel exposes the
+    autoconf ``cached_property_names`` MRO walker. It must pick up
+    descriptors declared on any ancestor and memoise the result on the class."""
+
+    # Build a synthetic subclass with a cached_property to verify the walker
+    # finds it. We use af.Collection because both AbstractPriorModel and
+    # ModelInstance inherit from AbstractModel.
+    class SyntheticCollection(af.Collection):
+        @functools.cached_property
+        def synthetic_value(self):
+            return "a synthetic cached string"
+
+    names = SyntheticCollection._cached_property_names()
+    assert "synthetic_value" in names
+
+    # Result is memoised on the synthetic class.
+    assert "__cached_property_names_cache__" in SyntheticCollection.__dict__
+
+    # Plain af.Collection (no synthetic_value) has its own cache.
+    base_names = af.Collection._cached_property_names()
+    assert "synthetic_value" not in base_names
+
+
+class _GuardedCollection(af.Collection):
+    """Module-level subclass used by
+    ``test_cached_property_excluded_from_all_dict_walks`` — must live at
+    module scope so ``pickle.dumps`` can locate the class on round-trip."""
+
+    @functools.cached_property
+    def derived(self):
+        return "leaky-string"
+
+
+def test_cached_property_excluded_from_all_dict_walks():
+    """Regression: a future ``@functools.cached_property`` declared anywhere
+    in the model class hierarchy must not surface through any of:
+    ``Collection._instance_for_arguments`` (via ``instance.__dict__``),
+    ``ModelInstance.dict``, ``ModelInstance.tree_flatten()``,
+    ``AbstractModel.items()``, ``ModelObject._dict``, or pickling via
+    ``__getstate__``.
+
+    Covers the class of bug PyAutoFit#1300 fixed for ``parameterization``;
+    this test will fail if a maintainer reintroduces an un-prefixed
+    cached_property on the model hierarchy without the
+    ``_cached_property_names`` defense applied at every site."""
+
+    import pickle
+
+    model = _GuardedCollection(gaussian=af.Model(af.ex.Gaussian))
+
+    # Trigger the cache. After this, model.__dict__["derived"] = "leaky-string".
+    _ = model.derived
+    assert model.__dict__.get("derived") == "leaky-string"
+
+    instance = model.instance_from_prior_medians()
+
+    # Site 1+4: Collection._instance_for_arguments + ModelInstance.dict
+    assert "derived" not in instance.__dict__
+    assert "derived" not in instance.dict
+
+    # Site 4 also feeds tree_flatten — no string leaves.
+    leaves = instance.dict.values()
+    for leaf in leaves:
+        assert not isinstance(leaf, str)
+
+    # Site 3: AbstractModel.items() on the model itself.
+    assert all(key != "derived" for key, _ in model.items())
+
+    # Site 5: __getstate__ drops the cached value from pickles.
+    state = model.__getstate__()
+    assert "derived" not in state
+
+    # Round-trip via pickle: the unpickled model re-computes the cached value,
+    # rather than carrying the pickled string on the wire.
+    blob = pickle.dumps(model)
+    revived = pickle.loads(blob)
+    assert "derived" not in revived.__dict__
+    # Touching it recomputes.
+    assert revived.derived == "leaky-string"
+
+
 def test_integer_attributes():
     model = af.Model(af.ex.Gaussian)
 