Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions autofit/mapper/prior_model/abstract.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import copy
import functools
import inspect
import json
import logging
Expand Down Expand Up @@ -1860,12 +1859,26 @@ def order_no(self) -> str:
]
return ":".join(values)

@functools.cached_property
@property
def parameterization(self) -> str:
"""
Describes the path to each of the PriorModels, its class
and its number of free parameters
"""
and its number of free parameters.

Cached on first access in ``self.__dict__`` under the
``_`` -prefixed key ``_parameterization_cache`` so that
``Collection._instance_for_arguments`` and
``ModelInstance.dict`` (which iterate ``__dict__`` and filter
underscore-prefixed keys) do not propagate the cached string
onto the constructed ``ModelInstance``. A plain
``functools.cached_property`` writes to ``__dict__[name]``
without a leading underscore, which would leak the string as
a non-array JAX pytree leaf and break ``jax.jit(fit_from)``.
"""
cached = self.__dict__.get("_parameterization_cache")
if cached is not None:
return cached

from .prior_model import Model

formatter = TextFormatter(line_length=info_whitespace())
Expand Down Expand Up @@ -1900,6 +1913,7 @@ def parameterization(self) -> str:
for group in find_groups(paths, limit=0):
formatter.add(*group)

self.__dict__["_parameterization_cache"] = formatter.text
return formatter.text

@property
Expand Down
6 changes: 6 additions & 0 deletions autofit/non_linear/fitness.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,12 @@ def __init__(
self.use_jax_vmap = use_jax_vmap
self.use_jax_jit = use_jax_jit

if getattr(self.analysis, "_use_jax", False):
from autofit.jax.pytrees import enable_pytrees, register_model

enable_pytrees()
register_model(self.model)

self._call = self.call

if self.use_jax_vmap:
Expand Down
35 changes: 35 additions & 0 deletions test_autofit/mapper/test_parameterization.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,41 @@ def test_tuple_instance_model_info(self, mapper):
assert len(info.split("\n")) == len(mapper.info.split("\n"))


def test_parameterization_cache_does_not_leak_into_instance():
"""Regression: ``parameterization`` is cached in
``self.__dict__["_parameterization_cache"]`` so that
``Collection._instance_for_arguments`` and ``ModelInstance.dict``
(which skip underscore-prefixed keys) do not propagate the cached
string onto the constructed instance. A plain
``functools.cached_property`` would write to ``__dict__["parameterization"]``
without an underscore, leaking the string into ``ModelInstance.dict``
and downstream JAX pytree flattening — see commit 4564ae9a1."""

model = af.Collection(gaussian=af.Model(af.ex.Gaussian))

# Touch model.info → exercises the same propagation path that every
# workspace script hits at construction time.
_ = model.info
_ = model.parameterization # second access uses the cache

# The cache must live behind an underscore key on the model.
assert "_parameterization_cache" in model.__dict__
assert "parameterization" not in model.__dict__

instance = model.instance_from_prior_medians()

# Neither the cached key nor the public name may appear on the
# constructed instance.
assert "parameterization" not in instance.__dict__
assert "_parameterization_cache" not in instance.__dict__
assert "parameterization" not in instance.dict
assert "_parameterization_cache" not in instance.dict

# The instance must yield only model components when iterated.
for child in instance:
assert not isinstance(child, str)


def test_integer_attributes():
model = af.Model(af.ex.Gaussian)

Expand Down
Loading