Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion autofit/graphical/declarative/factor/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,17 @@ def __getattr__(self, item):
return getattr(self.prior_model, item)

def name_for_variable(self, variable):
path = ".".join(self.prior_model.path_for_prior(variable))
path_iter = self.prior_model.path_for_prior(variable)
if path_iter is None:
# Variable is no longer in the factor's prior model — typically
# because it was replaced by a fixed scalar after the factor was
# registered in the graph (the constant-element pattern of
# ``Array.__setitem__``). The factor still appears in the graph's
# variable→factor map, but it does not reference this variable
# anymore, so info-only callers (graph.info, results text) get
# ``None`` and are expected to skip it.
return None
path = ".".join(path_iter)
return f"{self.name}.{path}"

def visualize(
Expand Down Expand Up @@ -236,3 +246,22 @@ def set_cavity_dist(self, cavity_dist):
cavity Gaussian summary.
"""
self.analysis._cavity_mean_field = cavity_dist

def set_model_approx(self, model_approx):
"""
Store the full ``EPMeanField`` on the wrapped ``Analysis``.

Called by :func:`factor_step` before ``set_cavity_dist`` on
each EP iteration. ``set_cavity_dist`` only exposes the cavity
``MeanField`` over this factor's *own* variables; some
hierarchical use cases additionally need to inspect the
per-factor messages of *sibling* factors (e.g. a "global"
Analysis that reads each upstream local fit's posterior on a
variable that the global model itself does not formally vary).

The default implementation just attaches ``model_approx`` to
the Analysis as ``_mean_field``. Subclasses (e.g. workspace
factors that freeze a subset of priors at sibling-fit means
between iterations) can override to add more behaviour.
"""
self.analysis._mean_field = model_approx
12 changes: 7 additions & 5 deletions autofit/graphical/declarative/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,11 +91,13 @@ def _related_factor_names(
factor.distribution_model.name
)
else:
names.add(
factor.name_for_variable(
variable
)
)
name = factor.name_for_variable(variable)
if name is None:
# Factor no longer references this variable (typically
# because the variable was fixed to a scalar after graph
# construction). Skip it for info purposes.
continue
names.add(name)

return ", ".join(sorted(names))

Expand Down
21 changes: 17 additions & 4 deletions autofit/graphical/expectation_propagation/optimiser.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,21 @@ def delta(self, factor, model_approx):
)


def factor_step(factor_approx, optimiser):
def factor_step(factor_approx, optimiser, model_approx=None):
factor = factor_approx.factor
factor_logger = logging.getLogger(factor.name)
factor_logger.debug("Optimising...")
# Mean-field opt-in: factors that implement ``set_model_approx``
# (e.g. ``EPAnalysisFactor``) get the *full* ``EPMeanField`` first.
# This is needed by hierarchical factors that have to inspect
# sibling factors' messages on variables the current factor's
# cavity no longer contains — e.g. variables that were dropped from
# this factor's prior model via the constant-element pattern of
# ``Array.__setitem__``. Default factors lack the method, so the
# call is a no-op for them. ``model_approx`` is None when called
# from a context that does not propagate it (older call sites).
if hasattr(factor, "set_model_approx") and model_approx is not None:
factor.set_model_approx(model_approx)
# Cavity-message opt-in: factors that implement ``set_cavity_dist``
# (e.g. ``EPAnalysisFactor``) receive the current cavity distribution
# before optimisation so their Analysis can read per-variable cavity
Expand Down Expand Up @@ -277,8 +288,8 @@ def _log_factor(self, factor: Factor):
except exc.HistoryException as e:
factor_logger.exception(e)

def factor_step(self, factor_approx, optimiser):
return factor_step(factor_approx, optimiser)
def factor_step(self, factor_approx, optimiser, model_approx=None):
return factor_step(factor_approx, optimiser, model_approx=model_approx)

def run(
self,
Expand Down Expand Up @@ -323,7 +334,9 @@ def run(
_should_output = should_output()
for factor, optimiser in self.factor_optimisers.items():
factor_approx = model_approx.factor_approximation(factor)
new_model_dist, status = self.factor_step(factor_approx, optimiser)
new_model_dist, status = self.factor_step(
factor_approx, optimiser, model_approx=model_approx,
)
model_approx, status = self.updater.update_model_approx(
new_model_dist, factor_approx, model_approx, status
)
Expand Down
11 changes: 10 additions & 1 deletion autofit/mapper/prior_model/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,15 @@ def gaussian_prior_model_for_arguments(self, arguments: Dict[Prior, Prior]):
"""
new_array = Array(self.shape)
for index in self.indices:
new_array[index] = self[index].gaussian_prior_model_for_arguments(arguments)
element = self[index]
try:
new_array[index] = element.gaussian_prior_model_for_arguments(arguments)
except AttributeError:
# Element is a fixed scalar (float / int / np.ndarray) — the
# documented constant-element pattern of ``Array.__setitem__``.
# Mirrors the try/except already used in
# ``_instance_for_arguments`` so fixed elements pass through
# unchanged when the model is rebuilt from posterior samples.
new_array[index] = element

return new_array
26 changes: 26 additions & 0 deletions test_autofit/mapper/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,32 @@ def test_1d_array_modify_prior(array_1d):
assert (array_1d.instance_from_prior_medians() == np.array([1.0, 0.0])).all()


def test_gaussian_prior_model_for_arguments_with_fixed_element(array):
"""
``Array.gaussian_prior_model_for_arguments`` is invoked by
``AbstractSearch.optimise`` to build a posterior ``GaussianPrior``
model from search arguments. Fixed scalar elements (set via
``arr[i, j] = float``) must pass through unchanged — they have no
prior to update from posterior samples. Mirrors the try/except
already in ``_instance_for_arguments``.
"""
array[0, 0] = 1.5
array[1, 1] = 2.5

arguments = {
prior: af.GaussianPrior(mean=10.0, sigma=0.1)
for prior in array.priors
}
new_array = array.gaussian_prior_model_for_arguments(arguments)

assert new_array.prior_count == 2
instance = new_array.instance_from_prior_medians()
assert instance[0, 0] == 1.5
assert instance[1, 1] == 2.5
assert instance[0, 1] == 10.0
assert instance[1, 0] == 10.0


def test_tree_flatten(array):
children, aux = array.tree_flatten()
assert len(children) == 4
Expand Down
Loading