From a39b5a9b39912143b79503f728db86952681a357 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Thu, 2 Apr 2026 19:06:47 +0100 Subject: [PATCH 1/2] add expanded unit tests for model mapping API 81 new tests covering collection composition, shared priors, vector mapping, tree navigation, assertions, subsetting, freezing, serialization and edge cases. Co-Authored-By: Claude Opus 4.6 (1M context) --- .../mapper/test_model_mapping_expanded.py | 631 ++++++++++++++++++ 1 file changed, 631 insertions(+) create mode 100644 test_autofit/mapper/test_model_mapping_expanded.py diff --git a/test_autofit/mapper/test_model_mapping_expanded.py b/test_autofit/mapper/test_model_mapping_expanded.py new file mode 100644 index 000000000..8be6fd534 --- /dev/null +++ b/test_autofit/mapper/test_model_mapping_expanded.py @@ -0,0 +1,631 @@ +""" +Expanded tests for the model mapping API, covering gaps identified in: +- Collection composition and instance creation +- Shared (linked) priors across model types +- Direct use of instance_for_arguments with argument dicts +- Model tree navigation (object_for_path, path_for_prior, name_for_prior) +- Edge cases (empty models, deeply nested models, single-parameter models) +- Model subsetting (with_paths, without_paths) +- Freezing behavior +- Assertion checking +- from_instance round-trips +- mapper_from_prior_arguments and mapper_from_partial_prior_arguments +""" +import copy + +import numpy as np +import pytest + +import autofit as af +from autofit import exc +from autofit.mapper.prior.abstract import Prior + + +# --------------------------------------------------------------------------- +# Collection: composition, nesting, instance creation, iteration +# --------------------------------------------------------------------------- +class TestCollectionComposition: + def test_collection_from_dict(self): + model = af.Collection( + one=af.Model(af.m.MockClassx2), + two=af.Model(af.m.MockClassx2), + ) + assert model.prior_count == 4 + + def test_collection_from_list(self): + model = af.Collection([af.m.MockClassx2, af.m.MockClassx2]) + assert model.prior_count == 4 + + def test_collection_from_generator(self): + model = af.Collection(af.Model(af.m.MockClassx2) for _ in range(3)) + assert model.prior_count == 6 + + def test_nested_collection(self): + inner = af.Collection(a=af.m.MockClassx2) + outer = af.Collection(inner=inner, extra=af.m.MockClassx2) + assert outer.prior_count == 4 + + def test_deeply_nested_collection(self): + model = af.Collection( + level1=af.Collection( + level2=af.Collection( + leaf=af.m.MockClassx2, + ) + ) + ) + assert model.prior_count == 2 + + def test_collection_instance_attribute_access(self): + model = af.Collection(gaussian=af.m.MockClassx2, exp=af.m.MockClassx2) + instance = model.instance_from_vector([1.0, 2.0, 3.0, 4.0]) + assert instance.gaussian.one == 1.0 + assert instance.gaussian.two == 2.0 + assert instance.exp.one == 3.0 + assert instance.exp.two == 4.0 + + def test_collection_instance_index_access(self): + model = af.Collection([af.m.MockClassx2, af.m.MockClassx2]) + instance = model.instance_from_vector([1.0, 2.0, 3.0, 4.0]) + assert instance[0].one == 1.0 + assert instance[1].one == 3.0 + + def test_collection_len(self): + model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2) + assert len(model) == 2 + + def test_collection_contains(self): + model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2) + assert "a" in model + assert "c" not in model + + def test_collection_items(self): + model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2) + keys = [k for k, v in model.items()] + assert "a" in keys + assert "b" in keys + + def test_collection_getitem_string(self): + model = af.Collection(a=af.m.MockClassx2) + assert isinstance(model["a"], af.Model) + + def test_collection_append(self): + model = af.Collection() + model.append(af.m.MockClassx2) + model.append(af.m.MockClassx2) + assert model.prior_count == 4 + + def test_collection_mixed_model_and_fixed(self): + """Collection with one free model and one fixed instance.""" + model = af.Collection( + free=af.Model(af.m.MockClassx2), + ) + assert model.prior_count == 2 + + def test_empty_collection(self): + model = af.Collection() + assert model.prior_count == 0 + + +# --------------------------------------------------------------------------- +# Shared (linked) priors +# --------------------------------------------------------------------------- +class TestSharedPriors: + def test_link_within_model(self): + model = af.Model(af.m.MockClassx2) + model.one = model.two + assert model.prior_count == 1 + instance = model.instance_from_vector([5.0]) + assert instance.one == instance.two == 5.0 + + def test_link_across_collection_children(self): + model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2) + model.a.one = model.b.one # Link a.one to b.one + assert model.prior_count == 3 # 4 - 1 shared + + def test_linked_priors_same_value(self): + model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2) + model.a.one = model.b.one + instance = model.instance_from_vector([10.0, 20.0, 30.0]) + assert instance.a.one == instance.b.one + + def test_link_reduces_unique_prior_count(self): + model = af.Model(af.m.MockClassx2) + original_count = len(model.unique_prior_tuples) + model.one = model.two + assert len(model.unique_prior_tuples) == original_count - 1 + + def test_linked_prior_identity(self): + model = af.Model(af.m.MockClassx2) + model.one = model.two + assert model.one is model.two + + +# --------------------------------------------------------------------------- +# instance_for_arguments (direct argument dict usage) +# --------------------------------------------------------------------------- +class TestInstanceForArguments: + def test_model_instance_for_arguments(self): + model = af.Model(af.m.MockClassx2) + args = {model.one: 10.0, model.two: 20.0} + instance = model.instance_for_arguments(args) + assert instance.one == 10.0 + assert instance.two == 20.0 + + def test_collection_instance_for_arguments(self): + model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2) + args = {} + for name, prior in model.prior_tuples_ordered_by_id: + args[prior] = 1.0 + instance = model.instance_for_arguments(args) + assert instance.a.one == 1.0 + assert instance.b.two == 1.0 + + def test_shared_prior_in_arguments(self): + """When priors are linked, only one entry is needed in the arguments dict.""" + model = af.Model(af.m.MockClassx2) + model.one = model.two + shared_prior = model.one + args = {shared_prior: 42.0} + instance = model.instance_for_arguments(args) + assert instance.one == 42.0 + assert instance.two == 42.0 + + def test_missing_prior_raises(self): + model = af.Model(af.m.MockClassx2) + args = {model.one: 10.0} # missing model.two + with pytest.raises(KeyError): + model.instance_for_arguments(args) + + +# --------------------------------------------------------------------------- +# Vector and unit vector mapping +# --------------------------------------------------------------------------- +class TestVectorMapping: + def test_instance_from_vector_basic(self): + model = af.Model(af.m.MockClassx2) + instance = model.instance_from_vector([3.0, 4.0]) + assert instance.one == 3.0 + assert instance.two == 4.0 + + def test_vector_length_mismatch_raises(self): + model = af.Model(af.m.MockClassx2) + with pytest.raises(AssertionError): + model.instance_from_vector([1.0]) + + def test_unit_vector_length_mismatch_raises(self): + model = af.Model(af.m.MockClassx2) + with pytest.raises(AssertionError): + model.instance_from_unit_vector([0.5]) + + def test_vector_from_unit_vector(self): + model = af.Model(af.m.MockClassx2) + model.one = af.UniformPrior(lower_limit=0.0, upper_limit=10.0) + model.two = af.UniformPrior(lower_limit=0.0, upper_limit=10.0) + physical = model.vector_from_unit_vector([0.0, 1.0]) + assert physical[0] == pytest.approx(0.0, abs=1e-6) + assert physical[1] == pytest.approx(10.0, abs=1e-6) + + def test_instance_from_prior_medians(self): + model = af.Model(af.m.MockClassx2) + model.one = af.UniformPrior(lower_limit=0.0, upper_limit=100.0) + model.two = af.UniformPrior(lower_limit=0.0, upper_limit=100.0) + instance = model.instance_from_prior_medians() + assert instance.one == pytest.approx(50.0) + assert instance.two == pytest.approx(50.0) + + def test_random_instance(self): + model = af.Model(af.m.MockClassx2) + model.one = af.UniformPrior(lower_limit=0.0, upper_limit=1.0) + model.two = af.UniformPrior(lower_limit=0.0, upper_limit=1.0) + instance = model.random_instance() + assert 0.0 <= instance.one <= 1.0 + assert 0.0 <= instance.two <= 1.0 + + +# --------------------------------------------------------------------------- +# Model tree navigation +# --------------------------------------------------------------------------- +class TestModelTreeNavigation: + def test_object_for_path_child_model(self): + model = af.Collection(g=af.Model(af.m.MockClassx2)) + child = model.object_for_path(("g",)) + assert isinstance(child, af.Model) + + def test_object_for_path_prior(self): + model = af.Collection(g=af.Model(af.m.MockClassx2)) + prior = model.object_for_path(("g", "one")) + assert isinstance(prior, Prior) + + def test_paths_matches_prior_count(self): + model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2) + assert len(model.paths) == model.prior_count + + def test_path_for_prior(self): + model = af.Collection(g=af.Model(af.m.MockClassx2)) + prior = model.g.one + path = model.path_for_prior(prior) + assert path == ("g", "one") + + def test_name_for_prior(self): + model = af.Collection(g=af.Model(af.m.MockClassx2)) + prior = model.g.one + name = model.name_for_prior(prior) + assert name == "g_one" + + def test_path_instance_tuples_for_class(self): + model = af.Collection(g=af.Model(af.m.MockClassx2)) + tuples = model.path_instance_tuples_for_class(Prior) + paths = [t[0] for t in tuples] + assert ("g", "one") in paths + assert ("g", "two") in paths + + def test_deeply_nested_path(self): + inner_model = af.Model(af.m.MockClassx2) + inner_collection = af.Collection(leaf=inner_model) + outer = af.Collection(branch=inner_collection) + + prior = outer.branch.leaf.one + path = outer.path_for_prior(prior) + assert path == ("branch", "leaf", "one") + + def test_direct_vs_recursive_prior_tuples(self): + model = af.Collection(a=af.m.MockClassx2) + assert len(model.direct_prior_tuples) == 0 # Collection has no direct priors + assert len(model.prior_tuples) == 2 # But has 2 recursive priors + + def test_direct_prior_model_tuples(self): + model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2) + assert len(model.direct_prior_model_tuples) == 2 + + +# --------------------------------------------------------------------------- +# instance_from_path_arguments and instance_from_prior_name_arguments +# --------------------------------------------------------------------------- +class TestPathAndNameArguments: + def test_instance_from_path_arguments(self): + model = af.Collection(g=af.m.MockClassx2) + instance = model.instance_from_path_arguments( + {("g", "one"): 10.0, ("g", "two"): 20.0} + ) + assert instance.g.one == 10.0 + assert instance.g.two == 20.0 + + def test_instance_from_prior_name_arguments(self): + model = af.Collection(g=af.m.MockClassx2) + instance = model.instance_from_prior_name_arguments( + {"g_one": 10.0, "g_two": 20.0} + ) + assert instance.g.one == 10.0 + assert instance.g.two == 20.0 + + +# --------------------------------------------------------------------------- +# Assertions +# --------------------------------------------------------------------------- +class TestAssertions: + def test_assertion_passes(self): + model = af.Model(af.m.MockClassx2) + model.add_assertion(model.one > model.two) + # one=10 > two=5 should pass + instance = model.instance_from_vector([10.0, 5.0]) + assert instance.one == 10.0 + + def test_assertion_fails(self): + model = af.Model(af.m.MockClassx2) + model.add_assertion(model.one > model.two) + with pytest.raises(exc.FitException): + model.instance_from_vector([1.0, 10.0]) + + def test_ignore_assertions(self): + model = af.Model(af.m.MockClassx2) + model.add_assertion(model.one > model.two) + # Should not raise even though assertion fails + instance = model.instance_from_vector([1.0, 10.0], ignore_assertions=True) + assert instance.one == 1.0 + + def test_multiple_assertions(self): + model = af.Model(af.m.MockClassx4) + model.add_assertion(model.one > model.two) + model.add_assertion(model.three > model.four) + # Both pass + instance = model.instance_from_vector([10.0, 5.0, 10.0, 5.0]) + assert instance.one == 10.0 + # First fails + with pytest.raises(exc.FitException): + model.instance_from_vector([1.0, 10.0, 10.0, 5.0]) + + def test_true_assertion_ignored(self): + """Adding True as an assertion should be a no-op.""" + model = af.Model(af.m.MockClassx2) + model.add_assertion(True) + assert len(model.assertions) == 0 + + +# --------------------------------------------------------------------------- +# Model subsetting (with_paths, without_paths) +# --------------------------------------------------------------------------- +class TestModelSubsetting: + def test_with_paths_single_child(self): + model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2) + subset = model.with_paths([("a",)]) + assert subset.prior_count == 2 + + def test_without_paths_single_child(self): + model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2) + subset = model.without_paths([("a",)]) + assert subset.prior_count == 2 + + def test_with_paths_specific_prior(self): + model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2) + subset = model.with_paths([("a", "one")]) + assert subset.prior_count == 1 + + def test_with_prefix(self): + model = af.Collection(ab_one=af.m.MockClassx2, cd_two=af.m.MockClassx2) + subset = model.with_prefix("ab") + assert subset.prior_count == 2 + + +# --------------------------------------------------------------------------- +# Freezing behavior +# --------------------------------------------------------------------------- +class TestFreezing: + def test_freeze_prevents_modification(self): + model = af.Model(af.m.MockClassx2) + model.freeze() + with pytest.raises(AssertionError): + model.one = af.UniformPrior(lower_limit=0.0, upper_limit=1.0) + + def test_unfreeze_allows_modification(self): + model = af.Model(af.m.MockClassx2) + model.freeze() + model.unfreeze() + model.one = af.UniformPrior(lower_limit=0.0, upper_limit=1.0) + assert isinstance(model.one, af.UniformPrior) + + def test_frozen_model_still_creates_instances(self): + model = af.Model(af.m.MockClassx2) + model.freeze() + instance = model.instance_from_vector([1.0, 2.0]) + assert instance.one == 1.0 + + def test_freeze_propagates_to_children(self): + model = af.Collection(a=af.m.MockClassx2) + model.freeze() + with pytest.raises(AssertionError): + model.a.one = 1.0 + + def test_cached_results_consistent(self): + model = af.Model(af.m.MockClassx2) + model.freeze() + result1 = model.prior_tuples_ordered_by_id + result2 = model.prior_tuples_ordered_by_id + assert result1 == result2 + + +# --------------------------------------------------------------------------- +# mapper_from_prior_arguments and related +# --------------------------------------------------------------------------- +class TestMapperFromPriorArguments: + def test_replace_all_priors(self): + model = af.Model(af.m.MockClassx2) + new_one = af.GaussianPrior(mean=0.0, sigma=1.0) + new_two = af.GaussianPrior(mean=5.0, sigma=2.0) + new_model = model.mapper_from_prior_arguments( + {model.one: new_one, model.two: new_two} + ) + assert new_model.prior_count == 2 + assert isinstance(new_model.one, af.GaussianPrior) + + def test_partial_replacement(self): + model = af.Model(af.m.MockClassx2) + new_one = af.GaussianPrior(mean=0.0, sigma=1.0) + new_model = model.mapper_from_partial_prior_arguments( + {model.one: new_one} + ) + assert new_model.prior_count == 2 + assert isinstance(new_model.one, af.GaussianPrior) + # two should retain its original prior type + assert new_model.two is not None + + def test_fix_via_mapper_from_prior_arguments(self): + """Replacing a prior with a float effectively fixes that parameter.""" + model = af.Model(af.m.MockClassx2) + new_model = model.mapper_from_prior_arguments( + {model.one: 5.0, model.two: model.two} + ) + assert new_model.prior_count == 1 + + def test_with_limits(self): + model = af.Model(af.m.MockClassx2) + model.one = af.UniformPrior(lower_limit=0.0, upper_limit=100.0) + model.two = af.UniformPrior(lower_limit=0.0, upper_limit=100.0) + new_model = model.with_limits([(10.0, 20.0), (30.0, 40.0)]) + assert new_model.prior_count == 2 + + +# --------------------------------------------------------------------------- +# from_instance round trips +# --------------------------------------------------------------------------- +class TestFromInstance: + def test_from_simple_instance(self): + instance = af.m.MockClassx2(1.0, 2.0) + model = af.AbstractPriorModel.from_instance(instance) + assert model.prior_count == 0 + + def test_from_instance_as_model(self): + instance = af.m.MockClassx2(1.0, 2.0) + model = af.AbstractPriorModel.from_instance(instance) + free_model = model.as_model() + assert free_model.prior_count == 2 + + def test_from_instance_with_model_classes(self): + instance = af.m.MockClassx2(1.0, 2.0) + model = af.AbstractPriorModel.from_instance( + instance, model_classes=(af.m.MockClassx2,) + ) + assert model.prior_count == 2 + + def test_from_list_instance(self): + instance_list = [af.m.MockClassx2(1.0, 2.0), af.m.MockClassx2(3.0, 4.0)] + model = af.AbstractPriorModel.from_instance(instance_list) + assert model.prior_count == 0 + + def test_from_dict_instance(self): + instance_dict = { + "one": af.m.MockClassx2(1.0, 2.0), + "two": af.m.MockClassx2(3.0, 4.0), + } + model = af.AbstractPriorModel.from_instance(instance_dict) + assert model.prior_count == 0 + + +# --------------------------------------------------------------------------- +# Fixing parameters and Constant values +# --------------------------------------------------------------------------- +class TestFixedParameters: + def test_fix_reduces_prior_count(self): + model = af.Model(af.m.MockClassx2) + model.one = 5.0 + assert model.prior_count == 1 + + def test_fixed_value_in_instance(self): + model = af.Model(af.m.MockClassx2) + model.one = 5.0 + instance = model.instance_from_vector([10.0]) + assert instance.one == 5.0 + assert instance.two == 10.0 + + def test_fix_all_parameters(self): + model = af.Model(af.m.MockClassx2) + model.one = 5.0 + model.two = 10.0 + assert model.prior_count == 0 + instance = model.instance_from_vector([]) + assert instance.one == 5.0 + assert instance.two == 10.0 + + +# --------------------------------------------------------------------------- +# take_attributes (prior passing) +# --------------------------------------------------------------------------- +class TestTakeAttributes: + def test_take_from_instance(self): + model = af.Model(af.m.MockClassx2) + source = af.m.MockClassx2(10.0, 20.0) + model.take_attributes(source) + assert model.prior_count == 0 + + def test_take_from_model(self): + """Taking attributes from another model copies priors.""" + source_model = af.Model(af.m.MockClassx2) + source_model.one = af.GaussianPrior(mean=5.0, sigma=1.0) + source_model.two = af.GaussianPrior(mean=10.0, sigma=2.0) + + target_model = af.Model(af.m.MockClassx2) + target_model.take_attributes(source_model) + assert isinstance(target_model.one, af.GaussianPrior) + + +# --------------------------------------------------------------------------- +# Serialization (dict / from_dict) +# --------------------------------------------------------------------------- +class TestSerialization: + def test_model_dict_roundtrip(self): + model = af.Model(af.m.MockClassx2) + d = model.dict() + loaded = af.AbstractPriorModel.from_dict(d) + assert loaded.prior_count == model.prior_count + + def test_collection_dict_roundtrip(self): + model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2) + d = model.dict() + loaded = af.AbstractPriorModel.from_dict(d) + assert loaded.prior_count == model.prior_count + + def test_fixed_parameter_survives_roundtrip(self): + model = af.Model(af.m.MockClassx2) + model.one = 5.0 + d = model.dict() + loaded = af.AbstractPriorModel.from_dict(d) + assert loaded.prior_count == 1 + + def test_linked_prior_survives_roundtrip(self): + model = af.Model(af.m.MockClassx2) + model.one = model.two + assert model.prior_count == 1 + d = model.dict() + loaded = af.AbstractPriorModel.from_dict(d) + assert loaded.prior_count == 1 + + +# --------------------------------------------------------------------------- +# Log prior computation +# --------------------------------------------------------------------------- +class TestLogPrior: + def test_log_prior_within_bounds(self): + model = af.Model(af.m.MockClassx2) + model.one = af.UniformPrior(lower_limit=0.0, upper_limit=10.0) + model.two = af.UniformPrior(lower_limit=0.0, upper_limit=10.0) + log_priors = model.log_prior_list_from_vector([5.0, 5.0]) + assert all(np.isfinite(lp) for lp in log_priors) + + def test_log_prior_outside_bounds(self): + model = af.Model(af.m.MockClassx2) + model.one = af.UniformPrior(lower_limit=0.0, upper_limit=10.0) + model.two = af.UniformPrior(lower_limit=0.0, upper_limit=10.0) + log_priors = model.log_prior_list_from_vector([15.0, 5.0]) + # Out-of-bounds value should have a lower (or zero) log prior than in-bounds + assert log_priors[0] <= log_priors[1] + + +# --------------------------------------------------------------------------- +# Edge cases +# --------------------------------------------------------------------------- +class TestEdgeCases: + def test_single_parameter_model(self): + """A model with a single free parameter using explicit prior.""" + model = af.Model(af.m.MockClassx2) + model.two = 5.0 # Fix one parameter + assert model.prior_count == 1 + instance = model.instance_from_vector([42.0]) + assert instance.one == 42.0 + assert instance.two == 5.0 + + def test_model_copy_preserves_priors(self): + model = af.Model(af.m.MockClassx2) + copied = model.copy() + assert copied.prior_count == model.prior_count + # Priors are independent copies (different objects) + assert copied.one is not model.one + + def test_model_copy_linked_priors_independent(self): + """Copying a model with linked priors preserves the link in the copy.""" + model = af.Model(af.m.MockClassx2) + model.one = model.two + assert model.prior_count == 1 + copied = model.copy() + assert copied.prior_count == 1 + # The copy's internal link is preserved + assert copied.one is copied.two + + def test_prior_ordering_is_deterministic(self): + """prior_tuples_ordered_by_id should be stable across calls.""" + model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2) + order1 = [(n, p.id) for n, p in model.prior_tuples_ordered_by_id] + order2 = [(n, p.id) for n, p in model.prior_tuples_ordered_by_id] + assert order1 == order2 + + def test_prior_count_equals_total_free_parameters(self): + model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx4) + assert model.prior_count == model.total_free_parameters + + def test_has_model(self): + model = af.Collection(a=af.Model(af.m.MockClassx2)) + assert model.has_model(af.m.MockClassx2) + assert not model.has_model(af.m.MockClassx4) + + def test_has_instance(self): + model = af.Model(af.m.MockClassx2) + assert model.has_instance(Prior) + assert not model.has_instance(af.m.MockClassx4) From 2d91219a922a8ec9cb75dd7c6728c9c1c8b4dc8b Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Fri, 3 Apr 2026 09:30:25 +0100 Subject: [PATCH 2/2] add docstrings to mapper module public methods and properties Coverage improved from 63% to 96% across prior/abstract.py, prior/uniform.py, prior/gaussian.py, prior_model/abstract.py, prior_model/collection.py, and prior_model/array.py. Uses numpy-style docstrings matching existing conventions. Co-Authored-By: Claude Opus 4.6 (1M context) --- autofit/mapper/prior/abstract.py | 70 +++++++++++++ autofit/mapper/prior/gaussian.py | 7 ++ autofit/mapper/prior/uniform.py | 29 ++++++ autofit/mapper/prior_model/abstract.py | 122 +++++++++++++++++++++++ autofit/mapper/prior_model/array.py | 1 + autofit/mapper/prior_model/collection.py | 67 +++++++++++++ 6 files changed, 296 insertions(+) diff --git a/autofit/mapper/prior/abstract.py b/autofit/mapper/prior/abstract.py index 29380bd36..90fe2b6e3 100644 --- a/autofit/mapper/prior/abstract.py +++ b/autofit/mapper/prior/abstract.py @@ -61,6 +61,18 @@ def unit_value_for(self, physical_value: float) -> float: return self.message.cdf(physical_value) def with_message(self, message): + """Return a copy of this prior with a different message (distribution). + + Parameters + ---------- + message + The new message object defining the prior's distribution. + + Returns + ------- + Prior + A copy of this prior using the new message. + """ new = copy(self) new.message = message return new @@ -88,6 +100,23 @@ def factor(self): @staticmethod def for_class_and_attribute_name(cls, attribute_name): + """Create a prior from the configuration for a given class and attribute. + + Looks up the prior type and parameters in the prior config files + for the specified class and attribute name. + + Parameters + ---------- + cls + The model class whose config is looked up. + attribute_name + The name of the attribute on that class. + + Returns + ------- + Prior + A prior instance constructed from the config entry. + """ prior_dict = conf.instance.prior_config.for_class_and_suffix_path( cls, [attribute_name] ) @@ -129,10 +158,31 @@ def instance_for_arguments( arguments, ignore_assertions=False, ): + """Look up this prior's value in an arguments dictionary. + + Parameters + ---------- + arguments + A dictionary mapping Prior objects to physical values. + ignore_assertions + Unused for priors (present for interface compatibility). + """ _ = ignore_assertions return arguments[self] def project(self, samples, weights): + """Project this prior given samples and log weights from a search. + + Returns a copy of this prior whose message has been updated to + reflect the posterior information from the samples. + + Parameters + ---------- + samples + Array of sample values for this parameter. + weights + Log weights for each sample. + """ result = copy(self) result.message = self.message.project( samples=samples, @@ -170,6 +220,11 @@ def __str__(self): @property @abstractmethod def parameter_string(self) -> str: + """A human-readable string summarizing this prior's parameters. + + Subclasses must implement this to return a description such as + ``"mean = 0.0, sigma = 1.0"`` or ``"lower_limit = 0.0, upper_limit = 1.0"``. + """ pass def __float__(self): @@ -254,7 +309,22 @@ def name_of_class(cls) -> str: @property def limits(self) -> Tuple[float, float]: + """The (lower, upper) bounds of this prior. + + Returns (-inf, inf) by default. Subclasses with finite bounds + (e.g. UniformPrior) override this. + """ return (float("-inf"), float("inf")) def gaussian_prior_model_for_arguments(self, arguments): + """Look up this prior in an arguments dict and return the mapped value. + + Used during prior replacement workflows where each prior is mapped + to a new prior or fixed value via an arguments dictionary. + + Parameters + ---------- + arguments + A dictionary mapping Prior objects to their replacement values. + """ return arguments[self] diff --git a/autofit/mapper/prior/gaussian.py b/autofit/mapper/prior/gaussian.py index fe74b6016..7d6039bba 100644 --- a/autofit/mapper/prior/gaussian.py +++ b/autofit/mapper/prior/gaussian.py @@ -57,6 +57,13 @@ def __init__( ) def tree_flatten(self): + """Flatten this prior into a JAX-compatible PyTree representation. + + Returns + ------- + tuple + A (children, aux_data) pair where children are (mean, sigma, id). + """ return (self.mean, self.sigma, self.id), () @classmethod diff --git a/autofit/mapper/prior/uniform.py b/autofit/mapper/prior/uniform.py index af04f23d5..0bc2b97e5 100644 --- a/autofit/mapper/prior/uniform.py +++ b/autofit/mapper/prior/uniform.py @@ -64,10 +64,18 @@ def __init__( ) def tree_flatten(self): + """Flatten this prior into a JAX-compatible PyTree representation. + + Returns + ------- + tuple + A (children, aux_data) pair where children are (lower_limit, upper_limit, id). + """ return (self.lower_limit, self.upper_limit, self.id), () @property def width(self): + """The width of the uniform distribution (upper_limit - lower_limit).""" return self.upper_limit - self.lower_limit def with_limits( @@ -75,12 +83,31 @@ def with_limits( lower_limit: float, upper_limit: float, ) -> "Prior": + """Create a new UniformPrior with different bounds. + + Parameters + ---------- + lower_limit + The new lower bound. + upper_limit + The new upper bound. + """ return UniformPrior( lower_limit=lower_limit, upper_limit=upper_limit, ) def logpdf(self, x): + """Compute the log probability density at x. + + Adjusts boundary values by epsilon to avoid evaluating exactly at + the distribution edges where the PDF is undefined. + + Parameters + ---------- + x + The value at which to evaluate the log PDF. + """ # TODO: handle x as a numpy array if x == self.lower_limit: x += epsilon @@ -102,6 +129,7 @@ def dict(self) -> dict: @property def parameter_string(self) -> str: + """A human-readable string summarizing the prior's lower and upper limits.""" return f"lower_limit = {self.lower_limit}, upper_limit = {self.upper_limit}" def value_for(self, unit: float) -> float: @@ -142,4 +170,5 @@ def log_prior_from_value(self, value, xp=np): @property def limits(self) -> Tuple[float, float]: + """The (lower_limit, upper_limit) bounds of this uniform prior.""" return self.lower_limit, self.upper_limit \ No newline at end of file diff --git a/autofit/mapper/prior_model/abstract.py b/autofit/mapper/prior_model/abstract.py index f6b7f3939..bce4b0d3d 100644 --- a/autofit/mapper/prior_model/abstract.py +++ b/autofit/mapper/prior_model/abstract.py @@ -41,6 +41,20 @@ class Limits: @staticmethod def for_class_and_attributes_name(cls, attribute_name): + """Look up the (lower, upper) limits for a class attribute from config. + + Parameters + ---------- + cls + The model class. + attribute_name + The name of the attribute on that class. + + Returns + ------- + tuple + A (lower, upper) pair of limit values. + """ limit_dict = conf.instance.prior_config.for_class_and_suffix_path( cls, [attribute_name, "limits"] ) @@ -165,6 +179,11 @@ def __init__(self, label=None): @property def assertions(self): + """The list of assertion constraints attached to this model. + + Assertions are checked when creating instances; a failed assertion + raises FitException, causing the non-linear search to resample. + """ return self._assertions @assertions.setter @@ -441,11 +460,28 @@ def add_assertion(self, assertion, name=""): @property def name(self): + """The class name of this prior model (e.g. ``"Model"`` or ``"Collection"``).""" return self.__class__.__name__ # noinspection PyUnusedLocal @staticmethod def from_object(t, *args, **kwargs): + """Convert an arbitrary object into an appropriate prior model representation. + + - Classes become ``Model`` instances. + - Lists and dicts become ``Collection`` instances. + - Floats become ``Constant`` instances. + - Existing prior models and other objects are returned as-is. + + Parameters + ---------- + t + A class, list, dict, float, or existing prior model. + + Returns + ------- + An AbstractPriorModel, Constant, or the original object. + """ if inspect.isclass(t): from .prior_model import Model @@ -606,6 +642,7 @@ def prior_tuples_ordered_by_id(self): @property def priors_ordered_by_id(self): + """Unique priors sorted by their id, defining the canonical parameter ordering.""" return [prior for _, prior in self.prior_tuples_ordered_by_id] def vector_from_unit_vector(self, unit_vector): @@ -836,6 +873,16 @@ def is_only_model(self, cls) -> bool: return len(cls_models) > 0 and len(cls_models) == len(other_models) def replacing(self, arguments): + """Return a new model with some priors replaced. + + This is a convenience alias for ``mapper_from_partial_prior_arguments``. + Priors not present in the arguments dict are kept unchanged. + + Parameters + ---------- + arguments : dict + A dictionary mapping existing Prior objects to new priors or fixed values. + """ return self.mapper_from_partial_prior_arguments(arguments) @classmethod @@ -1211,6 +1258,10 @@ def from_instance( return result def items(self): + """Return (name, value) pairs for all public, non-internal attributes. + + Excludes private attributes (prefixed with ``_``), ``cls``, and ``id``. + """ return [ (key, value) for key, value in self.__dict__.items() @@ -1225,6 +1276,7 @@ def direct_prior_tuples(self): @property @cast_collection(InstanceNameValue) def direct_instance_tuples(self): + """(name, value) tuples for direct float and Constant attributes.""" return self.direct_tuples_with_type(float) + self.direct_tuples_with_type( Constant ) @@ -1232,16 +1284,19 @@ def direct_instance_tuples(self): @property @cast_collection(PriorModelNameValue) def prior_model_tuples(self): + """(name, prior_model) tuples for direct child AbstractPriorModel attributes.""" return self.direct_tuples_with_type(AbstractPriorModel) @property @cast_collection(PriorModelNameValue) def direct_prior_model_tuples(self): + """(name, prior_model) tuples for immediate child prior models (non-recursive).""" return self.direct_tuples_with_type(AbstractPriorModel) @property @cast_collection(PriorModelNameValue) def direct_tuple_priors(self): + """(name, tuple_prior) tuples for direct TuplePrior attributes.""" return self.direct_tuples_with_type(TuplePrior) @property @@ -1267,6 +1322,7 @@ def direct_prior_tuples(self): @property @cast_collection(DeferredNameValue) def direct_deferred_tuples(self): + """(name, deferred_argument) tuples for direct DeferredArgument attributes.""" return self.direct_tuples_with_type(DeferredArgument) @property @@ -1298,6 +1354,11 @@ def instance_tuples(self): @property def prior_class_dict(self): + """Map each prior to the class it will produce when instantiated. + + Direct priors on this model map to ``self.cls``. Child prior models + contribute their own mappings recursively. + """ from autofit.mapper.prior_model.annotation import AnnotationPriorModel d = {prior[1]: self.cls for prior in self.prior_tuples} @@ -1467,6 +1528,7 @@ def total_free_parameters(self) -> int: @property def priors(self): + """A list of all Prior objects in this model (may contain duplicates for shared priors).""" return [prior_tuple.prior for prior_tuple in self.prior_tuples] @property @@ -1474,9 +1536,41 @@ def _prior_id_map(self): return {prior.id: prior for prior in self.priors} def prior_with_id(self, prior_id): + """Retrieve a prior by its unique integer id. + + Parameters + ---------- + prior_id : int + The id of the prior to find. + + Returns + ------- + Prior + The prior with the matching id. + + Raises + ------ + KeyError + If no prior with the given id exists in this model. + """ return self._prior_id_map[prior_id] def name_for_prior(self, prior): + """Get the underscore-separated name for a prior in this model. + + Searches child prior models recursively. Returns None if the prior + is not found. + + Parameters + ---------- + prior : Prior + The prior to find. + + Returns + ------- + str or None + The name path joined by underscores, e.g. ``"galaxy_centre"``. + """ for prior_model_name, prior_model in self.direct_prior_model_tuples: prior_name = prior_model.name_for_prior(prior) if prior_name is not None: @@ -1525,6 +1619,11 @@ def copy_with_fixed_priors(self, instance, excluded_classes=tuple()): @property def path_priors_tuples(self) -> List[Tuple[Path, Prior]]: + """All (path, prior) tuples in this model, sorted by prior id. + + Unlike ``unique_path_prior_tuples``, this includes duplicate entries + when a prior appears at multiple paths (shared priors). + """ path_priors_tuples = self.path_instance_tuples_for_class(Prior) return sorted(path_priors_tuples, key=lambda item: item[1].id) @@ -1546,6 +1645,10 @@ def paths_formatted(self) -> List[Path]: @property def composition(self): + """A list of dot-separated path strings for each prior, ordered by prior id. + + For example: ``["galaxy.centre", "galaxy.normalization", "galaxy.sigma"]``. + """ return [".".join(path) for path in self.paths] def sort_priors_alphabetically(self, priors: Iterable[Prior]) -> List[Prior]: @@ -1617,14 +1720,20 @@ def all_paths_for_prior(self, prior: Prior) -> Generator[Path, None, None]: @property def path_float_tuples(self): + """(path, float) tuples for all fixed float values, excluding Prior objects.""" return self.path_instance_tuples_for_class(float, ignore_class=Prior) @property def unique_prior_paths(self): + """Paths to each unique prior (deduplicated for shared priors), ordered by id.""" return [item[0] for item in self.unique_path_prior_tuples] @property def unique_path_prior_tuples(self): + """(path, prior) tuples deduplicated by prior identity, ordered by id. + + When a prior is shared across multiple paths, only one path is kept. + """ unique = {item[1]: item for item in self.path_priors_tuples}.values() return sorted(unique, key=lambda item: item[1].id) @@ -1645,6 +1754,18 @@ def prior_prior_model_dict(self): } def log_prior_list_from(self, parameter_lists: List[List]) -> List: + """Compute the total log prior for each parameter vector in a list. + + Parameters + ---------- + parameter_lists + A list of physical parameter vectors. + + Returns + ------- + list + The summed log prior for each vector. + """ return [ sum(self.log_prior_list_from_vector(vector=vector)) for vector in parameter_lists @@ -1809,6 +1930,7 @@ def model_component_and_parameter_names(self) -> List[str]: @property def joined_paths(self) -> List[str]: + """Dot-joined path strings for each unique prior, ordered by id.""" prior_paths = self.unique_prior_paths return [".".join(path) for path in prior_paths] diff --git a/autofit/mapper/prior_model/array.py b/autofit/mapper/prior_model/array.py index 317680818..e2d1f8afb 100644 --- a/autofit/mapper/prior_model/array.py +++ b/autofit/mapper/prior_model/array.py @@ -197,6 +197,7 @@ def tree_unflatten(cls, aux_data, children): @property def prior_class_dict(self): + """Map each prior to the class it produces (np.ndarray for direct priors).""" return { **{ prior: cls diff --git a/autofit/mapper/prior_model/collection.py b/autofit/mapper/prior_model/collection.py index 13073c10e..ed6f19beb 100644 --- a/autofit/mapper/prior_model/collection.py +++ b/autofit/mapper/prior_model/collection.py @@ -31,11 +31,28 @@ def name_for_prior(self, prior: Prior) -> str: return name def tree_flatten(self): + """Flatten this collection into a JAX-compatible PyTree representation. + + Returns + ------- + tuple + A (children, aux_data) pair where children are the values and + aux_data are the corresponding keys. + """ keys, values = zip(*self.items()) return values, keys @classmethod def tree_unflatten(cls, aux_data, children): + """Reconstruct a Collection from a flattened PyTree. + + Parameters + ---------- + aux_data + The keys of the collection items. + children + The values of the collection items. + """ instance = cls() for key, value in zip(aux_data, children): @@ -46,6 +63,14 @@ def __contains__(self, item): return item in self._dict or item in self._dict.values() def __getitem__(self, item): + """Retrieve an item by string key or integer index. + + Parameters + ---------- + item : str or int + A string key for dict-style access, or an integer index + for positional access into the values list. + """ if isinstance(item, str): return self._dict[item] return self.values[item] @@ -64,9 +89,11 @@ def __repr__(self): @property def values(self): + """The model components in this collection as a list.""" return list(self._dict.values()) def items(self): + """The (key, model_component) pairs in this collection.""" return self._dict.items() def with_prefix(self, prefix: str): @@ -79,6 +106,11 @@ def with_prefix(self, prefix: str): ) def as_model(self): + """Convert all prior models in this collection to Model instances. + + Returns a new Collection where each AbstractPriorModel child has + been converted via its own as_model() method. + """ return Collection( { key: value.as_model() @@ -162,6 +194,13 @@ def __init__( @assert_not_frozen def add_dict_items(self, item_dict): + """Add all entries from a dictionary, converting values to prior models. + + Parameters + ---------- + item_dict + A dictionary mapping string keys to classes, instances, or prior models. + """ for key, value in item_dict.items(): if isinstance(key, tuple): key = ".".join(key) @@ -179,11 +218,20 @@ def __eq__(self, other): @assert_not_frozen def append(self, item): + """Append an item to the collection with an auto-incremented numeric key. + + The item is converted to an AbstractPriorModel if it is not already one. + """ setattr(self, str(self.item_number), AbstractPriorModel.from_object(item)) self.item_number += 1 @assert_not_frozen def __setitem__(self, key, value): + """Set an item by key, converting the value to a prior model. + + Preserves the id of any existing item at the same key so that + prior identity is maintained across replacements. + """ obj = AbstractPriorModel.from_object(value) try: obj.id = getattr(self, str(key)).id @@ -193,6 +241,12 @@ def __setitem__(self, key, value): @assert_not_frozen def __setattr__(self, key, value): + """Set an attribute, automatically converting values to prior models. + + Private attributes (starting with ``_``) are set directly. All other + values are wrapped via ``AbstractPriorModel.from_object`` so that + plain classes become ``Model`` instances and floats become fixed values. + """ if key.startswith("_"): super().__setattr__(key, value) else: @@ -202,6 +256,14 @@ def __setattr__(self, key, value): pass def remove(self, item): + """Remove an item from the collection by value equality. + + Parameters + ---------- + item + The item to remove. All entries whose value equals this item + are deleted. + """ for key, value in self.__dict__.copy().items(): if value == item: del self.__dict__[key] @@ -271,6 +333,11 @@ def gaussian_prior_model_for_arguments(self, arguments): @property def prior_class_dict(self): + """Map each prior to the class it will produce when instantiated. + + For child prior models, delegates to their own prior_class_dict. + Direct priors on the collection itself map to ModelInstance. + """ return { **{ prior: cls