Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion autofit/mapper/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,28 @@ def __init__(self, label=None, id_=None):
self._frozen_cache = dict()
super().__init__(label=label, id_=id_)

@classmethod
def _cached_property_names(cls) -> frozenset:
    """
    Collect the names of the ``cached_property``-style descriptors that
    ``cls`` carries, including those inherited through its MRO.

    The lookup itself is delegated to
    ``autoconf.tools.decorators.cached_property_names``.

    The ``__dict__``-iteration sites in this module and in
    ``autofit/mapper/prior_model/`` use the returned names to keep cached
    descriptor values out of instance construction, ``ModelInstance.dict``,
    pickling, and downstream JAX pytree flattening. PyAutoFit#1300
    documents the leak this guards against.
    """
    # Imported at call time, as in the original implementation.
    import autoconf.tools.decorators as decorators

    return decorators.cached_property_names(cls)

def __getstate__(self):
    """
    Build the state dictionary used when this object is pickled.

    Returns
    -------
    A shallow copy of ``__dict__`` that omits the ``_frozen_cache`` entry
    and every attribute whose name is reported by
    ``_cached_property_names``, so cached descriptor values are not
    serialised.
    """
    excluded = type(self)._cached_property_names()
    return {
        key: value
        for key, value in self.__dict__.items()
        if key != "_frozen_cache" and key not in excluded
    }

def __setstate__(self, state):
Expand Down Expand Up @@ -446,11 +465,13 @@ def __hash__(self):

@property
def dict(self):
    """
    Mapping of this instance's public attributes, name to value.

    Left out of the result are the bookkeeping entries ``id``,
    ``component_number`` and ``item_number``, any string key that starts
    with an underscore, and any name reported by
    ``_cached_property_names``.
    """
    reserved = ("id", "component_number", "item_number")
    cached = type(self)._cached_property_names()

    public = {}
    for name, attribute in self.__dict__.items():
        if name in reserved or name in cached:
            continue
        # Only string keys can be "private"; other key types pass through.
        if isinstance(name, str) and name.startswith("_"):
            continue
        public[name] = attribute
    return public

def tree_flatten(self) -> Tuple[List, Tuple]:
Expand Down
11 changes: 11 additions & 0 deletions autofit/mapper/model_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,9 +330,20 @@ def dict(self) -> dict:

@property
def _dict(self):
    """
    Mapping of this object's public attributes, name to value.

    Left out of the result are ``component_number``, ``item_number``,
    ``id``, ``cls`` and ``label``, any key starting with an underscore, and
    any name reported by the class's ``_cached_property_names`` classmethod
    when the class has one.
    """
    # Pick up any cached_property descriptors declared on the class so
    # their cached values don't propagate via `Collection.items()` (which
    # delegates here) or any other downstream consumer. ModelObject is the
    # base for the whole mapper module and a few non-AbstractModel
    # descendants do not carry the classmethod, so it is looked up with a
    # default. Looking it up (rather than wrapping the call in try/except)
    # means an AttributeError raised *inside* the classmethod surfaces
    # instead of silently disabling the exclusion.
    names_lookup = getattr(type(self), "_cached_property_names", None)
    excluded = names_lookup() if names_lookup is not None else frozenset()
    return {
        key: value
        for key, value in self.__dict__.items()
        if key not in ("component_number", "item_number", "id", "cls", "label")
        and not key.startswith("_")
        and key not in excluded
    }
9 changes: 7 additions & 2 deletions autofit/mapper/prior_model/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -1275,12 +1275,17 @@ def from_instance(
def items(self):
    """Return (name, value) pairs for all public, non-internal attributes.

    Excludes private attributes (prefixed with ``_``), ``cls``, ``id``,
    and any ``cached_property``-style descriptors declared on the class
    (see ``AbstractModel._cached_property_names``).

    Returns
    -------
    A list of ``(key, value)`` tuples in ``__dict__`` iteration order.
    """
    excluded = type(self)._cached_property_names()
    return [
        (key, value)
        for key, value in self.__dict__.items()
        if not key.startswith("_")
        and key not in ("cls", "id")
        and key not in excluded
    ]

@property
Expand Down
3 changes: 2 additions & 1 deletion autofit/mapper/prior_model/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,8 +286,9 @@ def _instance_for_arguments(
A list of instances constructed from the list of prior models.
"""
result = ModelInstance()
excluded = type(self)._cached_property_names()
for key, value in self.__dict__.items():
if key.startswith("_"):
if key.startswith("_") or key in excluded:
continue
if isinstance(value, AbstractPriorModel):
value = value.instance_for_arguments(
Expand Down
2 changes: 2 additions & 0 deletions autofit/mapper/prior_model/prior_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,12 +494,14 @@ def _instance_for_arguments(
else:
result = self.cls(**constructor_arguments)

excluded = type(self)._cached_property_names()
for key, value in self.__dict__.items():
if (
not hasattr(result, key)
and not isinstance(value, Prior)
and not key == "cls"
and not key.startswith("_")
and key not in excluded
):
if isinstance(value, Model):
value = value.instance_for_arguments(
Expand Down
87 changes: 87 additions & 0 deletions test_autofit/mapper/test_parameterization.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
import itertools

import pytest
Expand Down Expand Up @@ -176,6 +177,92 @@ def test_parameterization_cache_does_not_leak_into_instance():
assert not isinstance(child, str)


def test_cached_property_names_classmethod_walks_mro():
    """The ``_cached_property_names`` classmethod on AbstractModel exposes the
    autoconf ``cached_property_names`` MRO walker. It must pick up
    descriptors declared on a subclass and memoise the result on the class."""

    # ``functools`` is already imported at module level; only autofit is
    # imported locally here.
    import autofit as af

    # Build a synthetic subclass with a cached_property to verify the walker
    # finds it. We use af.Collection because both AbstractPriorModel and
    # ModelInstance inherit from AbstractModel.
    class SyntheticCollection(af.Collection):
        @functools.cached_property
        def synthetic_value(self):
            return "a synthetic cached string"

    names = SyntheticCollection._cached_property_names()
    assert "synthetic_value" in names

    # Result is memoised on the synthetic class.
    assert "__cached_property_names_cache__" in SyntheticCollection.__dict__

    # Plain af.Collection (no synthetic_value) has its own cache.
    base_names = af.Collection._cached_property_names()
    assert "synthetic_value" not in base_names


class _GuardedCollection(af.Collection):
    """Collection subclass carrying one un-prefixed ``cached_property``.

    Used by ``test_cached_property_excluded_from_all_dict_walks``. It is
    declared at module scope, not inside the test, so that
    ``pickle.dumps`` can locate the class on round-trip."""

    @functools.cached_property
    def derived(self) -> str:
        # A string value makes a leak easy to spot among model components.
        return "leaky-string"


def test_cached_property_excluded_from_all_dict_walks():
    """Regression: a future ``@functools.cached_property`` declared anywhere
    in the model class hierarchy must not surface through any of:
    ``Collection._instance_for_arguments`` (via ``instance.__dict__``),
    ``ModelInstance.dict``, ``AbstractModel.items()``,
    ``ModelObject._dict``, or pickling via ``__getstate__``.

    Covers the class of bug PyAutoFit#1300 fixed for ``parameterization``;
    this test will fail if a maintainer reintroduces an un-prefixed
    cached_property on the model hierarchy without the
    ``_cached_property_names`` defense applied at every site."""

    import pickle

    model = _GuardedCollection(gaussian=af.Model(af.ex.Gaussian))

    # Trigger the cache. After this, model.__dict__["derived"] = "leaky-string".
    _ = model.derived
    assert model.__dict__.get("derived") == "leaky-string"

    instance = model.instance_from_prior_medians()

    # Collection._instance_for_arguments + ModelInstance.dict
    assert "derived" not in instance.__dict__
    assert "derived" not in instance.dict

    # No string values in ``instance.dict`` (the original note states this
    # mapping also feeds tree_flatten, which is not called directly here).
    assert not any(isinstance(leaf, str) for leaf in instance.dict.values())

    # AbstractModel.items() on the model itself.
    assert all(key != "derived" for key, _ in model.items())

    # ModelObject._dict, checked directly and not only via its callers.
    assert "derived" not in model._dict

    # __getstate__ drops the cached value from pickles.
    state = model.__getstate__()
    assert "derived" not in state

    # Round-trip via pickle: the unpickled model re-computes the cached value,
    # rather than carrying the pickled string on the wire.
    blob = pickle.dumps(model)
    revived = pickle.loads(blob)
    assert "derived" not in revived.__dict__
    # Touching it recomputes.
    assert revived.derived == "leaky-string"


def test_integer_attributes():
model = af.Model(af.ex.Gaussian)

Expand Down
Loading