From 19b7f1e20cbc1892064781e8b2b4b180038351e2 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Fri, 15 Aug 2025 15:07:46 +0100 Subject: [PATCH 01/33] Reinstate check after renaming caused bugs in sampling --- tests/test_integration/test_two_normal_example.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/tests/test_integration/test_two_normal_example.py b/tests/test_integration/test_two_normal_example.py index 5283fb1..1957a61 100644 --- a/tests/test_integration/test_two_normal_example.py +++ b/tests/test_integration/test_two_normal_example.py @@ -85,18 +85,18 @@ def lagrangian( subkeys = jax.random.split(rng_key, predictive_model.num_samples) l_mult = parameter_values["_l_mult"] - def _x_sampler(pv: dict[str, npt.ArrayLike], key: jax.Array) -> float: + def _ux_sampler(pv: dict[str, npt.ArrayLike], key: jax.Array) -> float: return predictive_model(key, **pv)["UX"] + def _x_sampler(pv: dict[str, npt.ArrayLike], key: jax.Array) -> float: + return predictive_model(key, **pv)["X"] + def _ce(pv, subkeys): return ( ce_prefactor * jax.vmap(_x_sampler, in_axes=(None, 0))(pv, subkeys).mean() ) - def _ux_sampler(pv: dict[str, npt.ArrayLike], key: jax.Array) -> float: - return predictive_model(key, **pv)["X"] - def _constraint(pv, subkeys): return ( jnp.abs( @@ -156,6 +156,11 @@ def objective(params, predictive, key, ce_prefactor=ce_prefactor): assert converged, f"Did not converge, final objective value: {objective_value}" + # The lagrangian is independent of nu_x, thus it should not have changed value. + assert jnp.isclose(params["cov2"], nu_x_starting_value), ( + "nu_x has changed significantly from the starting value." + ) + # Confirm that we found a minimiser that does satisfy the inequality constraints. assert params["_l_mult"] > 0.0, ( f"Converged, but not to a minimiser (lagrange multiplier = {params['_l_mult']})" From a5f55bca04b6e4a1aeac821e7d6cce99af05ecc9 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Fri, 15 Aug 2025 14:56:37 +0100 Subject: [PATCH 02/33] Remove old CausalProblem class --- src/causalprog/causal_problem.py | 319 ------------------ tests/test_causal_problem/test_callables.py | 212 ------------ .../test_graph_and_param.py | 42 --- 3 files changed, 573 deletions(-) delete mode 100644 src/causalprog/causal_problem.py delete mode 100644 tests/test_causal_problem/test_callables.py delete mode 100644 tests/test_causal_problem/test_graph_and_param.py diff --git a/src/causalprog/causal_problem.py b/src/causalprog/causal_problem.py deleted file mode 100644 index e1232cc..0000000 --- a/src/causalprog/causal_problem.py +++ /dev/null @@ -1,319 +0,0 @@ -"""Extension of the Graph class providing features for solving causal problems.""" - -from collections.abc import Callable -from inspect import signature -from typing import Literal, TypeAlias - -import jax -import jax.numpy as jnp - -from causalprog._abc.labelled import Labelled -from causalprog.graph import Graph, Node - -CausalEstimand: TypeAlias = Callable[..., float] -Constraints: TypeAlias = Callable[..., float] - - -def raises(exception: Exception) -> Callable[[], float]: - """Create a callable that raises ``exception`` when called.""" - - def _inner() -> float: - raise exception - - return _inner - - -class CausalProblem(Labelled): - """ - Container class for handling a causal problem. 
- - A causal problem - requires an underlying ``Graph`` to describe the relationships between the random - variables and parameters, plus a causal estimand and list of (data) constraints. - Structural constraints are handled by imposing restrictions on forms of the random - variables and constraints directly. - - A ``CausalProblem`` instance brings together these components, providing a container - for a causal problem that can be given inputs like empirical data, a solver - tolerance, etc, and will provide (estimates of) the bounds for the causal estimand. - - - The ``.graph`` attribute stores the underlying ``Graph`` object. - - The ``.causal_estimand`` method evaluates the causal estimand, given values for - the parameters. - - The ``.constraints`` method evaluates the (vector-valued) constraints, given - values for the parameters. - - The user must specify each of the above before a ``CausalProblem`` can be solved. - The primary way for this to be done is to construct or load the corresponding - ``Graph``, and provide it by setting the ``CausalProblem.graph`` attribute directly. - Then, `set_causal_estimand` and `set_constraints` can be used to provide the causal - estimand and constraints functions, in terms of the random variables. The - ``CausalProblem`` instance will handle turning them into functions of the parameter - values under the hood. Initial parameter values (for the purposes of solving) can be - provided to the solver method directly or set beforehand via ``set_parameters``. It - should never be necessary for the user to interact with, or provide, a vector of - parameters (as this is taken care of under the hood). - """ - - _graph: Graph | None - _sigma: CausalEstimand - _sigma_mapping: dict[str, Node] - _constraints: Constraints - _constraints_mapping: dict[str, Node] - _parameter_values: dict[str, float] - - @property - def graph(self) -> Graph: - """Graph defining the structure of the `CausalProblem`.""" - if self._graph is None: - msg = f"No graph set for {self.label}." - raise ValueError(msg) - return self._graph - - @graph.setter - def graph(self, new_graph: Graph) -> None: - if not isinstance(new_graph, Graph): - msg = f"{self.label}.graph must be a Graph instance." - raise TypeError(msg) - self._graph = new_graph - - @property - def parameter_values(self) -> dict[str, float]: - """Dictionary mapping parameter labels to their (current) values.""" - return self._parameter_vector_to_dict(self.parameter_vector) - - @property - def parameter_vector(self) -> jax.Array: - """Returns the (current) vector of parameter values.""" - return jnp.array( - tuple( - self._parameter_values[node.label] - if node.label in self._parameter_values - else float("NaN") - for node in self.graph.parameter_nodes - ), - ndmin=1, - ) - - def __init__( - self, - graph: Graph | None = None, - *, - label: str = "CausalProblem", - ) -> None: - """Set up a new CausalProblem.""" - super().__init__(label=label) - - self._parameter_values = {} - self._graph = graph - - # Callables cannot be evaluated until they are explicitly set - self.set_causal_estimand( - raises(NotImplementedError(f"Causal estimand not set for {self.label}.")) - ) - self.set_constraints( - raises(NotImplementedError(f"Constraints not set for {self.label}.")) - ) - - def _parameter_vector_to_dict( - self, parameter_vector: jax.Array - ) -> dict[str, float]: - """ - Convert a parameter vector to a dictionary mapping labels to parameter values. 
- - Convention is that a vector of parameter values contains values in the same - order as self.graph.parameter_nodes. - """ - # Avoid recomputing the parameter node tuple every time. - pn = self.graph.parameter_nodes - return {pn[i].label: value for i, value in enumerate(parameter_vector)} - - def _eval_callable( - self, which: Literal["sigma", "constraints"], at: jax.Array - ) -> jax.Array: - """ - Evaluate a callable method of this instance. - - This is an abstraction method for when the causal estimand or constraints - functions need to be evaluated. In each case, the process is the same: - - - Update the values of the parameter nodes. - - Call the underlying function composed with its mapping of RVs to Nodes. - - The method is abstracted here so that any changes to the process are reflected - in both methods automatically. - """ - self._set_parameters_via_vector(at) - if "parameter_values" in signature(getattr(self, f"_{which}")).parameters: - return getattr(self, f"_{which}")( - **getattr(self, f"_{which}_mapping"), - parameter_values=self._parameter_values, - ) - return getattr(self, f"_{which}")(**getattr(self, f"_{which}_mapping")) - - def _set_callable( - self, - which: Literal["sigma", "constraints"], - *, - fn: CausalEstimand | Constraints, - rvs_to_nodes: dict[str, str] | None = None, - graph_argument: str | None = None, - ) -> None: - """ - Abstracted method for setting the Causal Estimand and/or Constraints functions. - - The functionality for setting up these two methods of an instance are identical, - save for the attributes which need to be updated. As such, we can refactor the - common functionality into a single, hidden, method and provide a friendlier - access point for users to employ. - """ - fn_attr = f"_{which}" - map_attr = f"_{which}_mapping" - debug_name = "constraints" if which == "constraints" else "causal estimand" - - setattr(self, fn_attr, fn) - setattr(self, map_attr, {}) - - if rvs_to_nodes is None: - rvs_to_nodes = {} - fn_args = signature(fn).parameters - - for rv_name, node_label in rvs_to_nodes.items(): - if rv_name not in fn_args: - msg = f"{rv_name} is not an argument to provided {debug_name} function." - raise ValueError(msg) - getattr(self, map_attr)[rv_name] = self.graph.get_node(node_label) - - # Any unaccounted-for RV arguments to sigma are assumed to match - # the label of the corresponding node. - args_not_used = set(fn_args) - set(getattr(self, map_attr)) - - ## Temporary hack to ensure that we can use expectation(graph, X) syntax. - if graph_argument: - getattr(self, map_attr)[graph_argument] = self.graph - args_not_used -= {graph_argument} - ## END HACK - - for arg in args_not_used: - if arg != "parameter_values": - getattr(self, map_attr)[arg] = self.graph.get_node(arg) - - def _set_parameters_via_vector(self, parameter_vector: jax.Array | None) -> None: - """ - Shorthand to set parameter node values from a parameter vector. - - No intended for frontend use - primary use will be internal when running - optimisation methods over the CausalProblem, when we need to treat the - parameters as a vector or array of function inputs. - """ - self.set_parameter_values(**self._parameter_vector_to_dict(parameter_vector)) - - def set_parameter_values(self, **parameter_values: float | None) -> None: - """ - Set (initial) parameter values for this CausalProblem. - - See ``Graph.set_parameters`` for input details. 
- """ - for parameter, value in parameter_values.items(): - if value is None: - if parameter in self._parameter_values: - del self._parameter_values[parameter] - else: - self._parameter_values[parameter] = value - - def set_causal_estimand( - self, - sigma: CausalEstimand, - rvs_to_nodes: dict[str, str] | None = None, - graph_argument: str | None = None, - ) -> None: - """ - Set the Causal Estimand for this problem. - - `sigma` should be a callable object that defines the Causal Estimand of - interest, in terms of the random variables of to the problem. The - random variables are in turn represented by `Node`s, with this association being - recorded in the `rv_to_nodes` dictionary. - - The `causal_estimand` method of the instance will be usable once this method - completes. - - Args: - sigma (CausalEstimand): Callable object that evaluates the causal estimand - of interest for this `CausalProblem`, in terms of the random variables, - which are the arguments to this callable. `sigma`s with additional - arguments are not currently supported. - rvs_to_nodes (dict[str, str]): Mapping of random variable (argument) names - of `sigma` to the labels of the corresponding `Node`s representing the - random variables. Argument names that match their corresponding `Node` - label can be omitted. - graph_argument (str): Argument to `sigma` that should be replaced with - `self.graph`. This argument is only temporary, as we are currently - limited to the syntax `expectation(Graph, Node)` rather than just - `expectation(Node)`. It will be removed in the future when methods like - `expectation` can be called solely on `Node` objects. - - """ - self._set_callable( - "sigma", fn=sigma, rvs_to_nodes=rvs_to_nodes, graph_argument=graph_argument - ) - - def set_constraints( - self, - constraints: CausalEstimand, - rvs_to_nodes: dict[str, str] | None = None, - graph_argument: str | None = None, - ) -> None: - """ - Set the Constraints for this problem. - - ``constraints`` should be a callable object that defines the Data Constraints of - interest, in terms of the random variables of to the problem. The - random variables are in turn represented by `Node`s, with this association being - recorded in the `rv_to_nodes` dictionary. - - The `constraints` method of the instance will be usable once this method - completes. - - Args: - constraints (Constraints): Callable object that evaluates the constraints - of interest for this `CausalProblem`, in terms of the random variables, - which are the arguments to this callable. ``constraints``s with - additional arguments are not currently supported. - rvs_to_nodes (dict[str, str]): Mapping of random variable (argument) names - of `sigma` to the labels of the corresponding `Node`s representing the - random variables. Argument names that match their corresponding `Node` - label can be omitted. - graph_argument (str): Argument to `sigma` that should be replaced with - `self.graph`. This argument is only temporary, as we are currently - limited to the syntax `expectation(Graph, Node)` rather than just - `expectation(Node)`. It will be removed in the future when methods like - `expectation` can be called solely on `Node` objects. - - """ - self._set_callable( - "constraints", - fn=constraints, - rvs_to_nodes=rvs_to_nodes, - graph_argument=graph_argument, - ) - - def causal_estimand(self, p: jax.Array) -> float: - """ - Evaluate the Causal Estimand at parameter vector `p`. - - Args: - p (jax.Array): Vector of parameter values to evaluate at. 
- - """ - return self._eval_callable("sigma", p) - - def constraints(self, p: jax.Array) -> jax.Array: - """ - Evaluate the Constraints at parameter vector `p`. - - Args: - p (jax.Array): Vector of parameter values to evaluate at. - - """ - return self._eval_callable("constraints", p) diff --git a/tests/test_causal_problem/test_callables.py b/tests/test_causal_problem/test_callables.py deleted file mode 100644 index f70d2f2..0000000 --- a/tests/test_causal_problem/test_callables.py +++ /dev/null @@ -1,212 +0,0 @@ -from collections.abc import Callable -from typing import Literal - -import jax -import jax.numpy as jnp -import pytest - -from causalprog.algorithms import expectation, standard_deviation -from causalprog.causal_problem import CausalProblem -from causalprog.graph import Graph, Node - - -@pytest.fixture -def n_samples_for_estimands() -> int: - return 1000 - - -@pytest.fixture -def expectation_fixture( - rng_key: jax.Array, - n_samples_for_estimands: int, -) -> Callable[[Graph, Node, dict[str, float] | None], float]: - def _inner(g: Graph, x: Node, parameter_values=None) -> float: - return expectation( - g, - x.label, - samples=n_samples_for_estimands, - rng_key=rng_key, - parameter_values=parameter_values, - ) - - return _inner - - -@pytest.fixture -def std_fixture( - rng_key: jax.Array, - n_samples_for_estimands: int, -) -> Callable[[Graph, Node, dict[str, float] | None], float]: - def _inner(g: Graph, x: Node, parameter_values=None) -> float: - return ( - standard_deviation( - g, - x.label, - samples=n_samples_for_estimands, - rng_key=rng_key, - parameter_values=parameter_values, - ) - ** 2 - ) - - return _inner - - -@pytest.fixture -def vector_fixture( - rng_key: jax.Array, - n_samples_for_estimands: int, -) -> Callable[[Graph, Node, Node, dict[str, float] | None], jax.Array]: - """vector_fixture(g, x1, x2) = [mean of x1, std of x2].""" - - def _inner(g: Graph, x1: Node, x2: Node, parameter_values=None) -> jax.Array: - return jnp.array( - [ - expectation( - g, - x1.label, - samples=n_samples_for_estimands, - rng_key=rng_key, - parameter_values=parameter_values, - ), - standard_deviation( - g, - x2.label, - samples=n_samples_for_estimands, - rng_key=rng_key, - parameter_values=parameter_values, - ) - ** 2, - ] - ) - - return _inner - - -@pytest.fixture(params=["causal_estimand", "constraints"]) -def which(request: pytest.FixtureRequest) -> Literal["causal_estimand", "constraints"]: - """For tests applicable to both the causal_estimand and constraints methods.""" - return request.param - - -@pytest.mark.parametrize( - ("initial_param_values", "args_to_setter", "expected", "atol"), - [ - pytest.param( - {"mean": 1.0, "cov2": 1.0}, - { - "fn": "expectation_fixture", - "rvs_to_nodes": {"x": "mean"}, - "graph_argument": "g", - }, - 1.0, - 1.0e-12, - id="mean", - ), - pytest.param( - {"mean": 1.0, "cov2": 1.0}, - { - "fn": "expectation_fixture", - "rvs_to_nodes": {"x": "cov2"}, - "graph_argument": "g", - }, - 1.0, - 1.0e-12, - id="cov2", - ), - pytest.param( - {"mean": 0.0, "cov2": 1.0}, - { - "fn": "expectation_fixture", - "rvs_to_nodes": {"x": "X"}, - "graph_argument": "g", - }, - 0.0, - 3.0e-2, - id="E[x], infer association", - ), - pytest.param( - {"mean": 0.0, "cov2": 1.0}, - { - "fn": "std_fixture", - "rvs_to_nodes": {"x": "X"}, - "graph_argument": "g", - }, - # UX has fixed std 1, and cov2 will be set to 1. 
- 1.0**2 + 1.0**2, - 3.0e-1, - id="Var[y]", - ), - pytest.param( - {"mean": 0.0, "cov2": 1.0}, - { - "fn": "vector_fixture", - "rvs_to_nodes": {"x1": "UX", "x2": "X"}, - "graph_argument": "g", - }, - # As per the previous test cases - jnp.array([0.0, 1.0**2 + 1.0**2]), - jnp.array([3.0e-2, 2.0e-1]), - id="E[x], Var[y]", - ), - ], -) -def test_callables( - two_normal_graph: Callable[..., Graph], - which: Literal["causal_estimand", "constraints"], - initial_param_values: dict[str, float], - args_to_setter: dict[str, Callable[..., float] | dict[str, str] | str], - expected: float | jax.Array, - atol: float, - request: pytest.FixtureRequest, - raises_context, -) -> None: - """ - Test the set_{causal_estimand, constraints} and .{casual_estimand, constraints} - evaluation method. - - Test works by: - - Set the parameter values using the initial_param_values. - - Call the setter method using the given arguments. - - Evaluate the method that should have been set at the current parameter_vector, - which should evaluate the corresponding function at the current values of the - parameter vector, which will be the initial values just set. - - Check the result (lies within a given tolerance). - - In theory, there is no difference between the causal estimand and constraints when - it comes to this test - the constraints may be vector-valued but there is nothing - preventing the ``causal_estimand`` (programmatically) from being vector-valued - either. - """ - # Parametrised fixtures edit-in-place objects - args_to_setter = dict(args_to_setter) - if isinstance(args_to_setter["fn"], str): - args_to_setter["fn"] = request.getfixturevalue(args_to_setter["fn"]) - - if which == "constraints": - args_to_setter["constraints"] = args_to_setter.pop("fn") - else: - args_to_setter["sigma"] = args_to_setter.pop("fn") - - expected = jnp.array(expected, ndmin=1) - - # Test properly begins. - graph = two_normal_graph(cov=1.0) - cp = CausalProblem(graph) - - method = getattr(cp, which) - setter_method = getattr(cp, f"set_{which}") - - # Before setting the causal estimand, it should throw an error if called. - with raises_context( - NotImplementedError( - f"{which.replace('_', ' ').capitalize()} not set for CausalProblem." - ) - ): - method(cp.parameter_vector) - - cp.set_parameter_values(**initial_param_values) - setter_method(**args_to_setter) - result = jnp.array(method(cp.parameter_vector), ndmin=1) - - assert jnp.allclose(result, expected, atol=atol) diff --git a/tests/test_causal_problem/test_graph_and_param.py b/tests/test_causal_problem/test_graph_and_param.py deleted file mode 100644 index 96ea924..0000000 --- a/tests/test_causal_problem/test_graph_and_param.py +++ /dev/null @@ -1,42 +0,0 @@ -from collections.abc import Callable - -import jax.numpy as jnp - -from causalprog.causal_problem import CausalProblem -from causalprog.graph import Graph - - -def test_graph_and_parameter_interactions( - two_normal_graph: Callable[..., Graph], - raises_context, -) -> None: - cp = CausalProblem(label="TestCP") - - # Without a graph, we can't do anything - with raises_context(ValueError("No graph set for TestCP")): - cp.graph # noqa: B018 - with raises_context(ValueError("No graph set for TestCP")): - cp.parameter_values # noqa: B018 - - # Cannot set graph to non-graph value - with raises_context(TypeError("TestCP.graph must be a Graph instance")): - cp.graph = 1.0 - - # Provide an actual graph value - cp.graph = two_normal_graph(cov=1.0) - - # We should now be able to fetch parameter values, but they are all unset. 
- assert jnp.all(jnp.isnan(cp.parameter_vector)) - assert cp.parameter_vector.shape == (len(cp.graph.parameter_nodes),) - assert all(jnp.isnan(value) for value in cp.parameter_values.values()) - assert set(cp.parameter_values.keys()) == {"mean", "cov2"} - - # Users should only ever need to set parameter values via their names. - cp.set_parameter_values(mean=1.0, cov2=2.0) - assert cp.parameter_values == {"mean": 1.0, "cov2": 2.0} - # We don't know which way round the internal parameter vector is being stored, - # but that doesn't matter. We do know that it should contain the values 1 & 2 - # in some order though. - assert jnp.allclose(cp.parameter_vector, jnp.array([1.0, 2.0])) or jnp.allclose( - cp.parameter_vector, jnp.array([2.0, 1.0]) - ) From de904bff028424455729701a2a878aa803d9dad2 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Fri, 15 Aug 2025 14:57:22 +0100 Subject: [PATCH 03/33] Add barebones classes to be populated later --- src/causalprog/causal_problem/__init__.py | 0 .../causal_problem/causal_estimand.py | 45 ++++++++++++++ .../causal_problem/causal_problem.py | 58 +++++++++++++++++++ 3 files changed, 103 insertions(+) create mode 100644 src/causalprog/causal_problem/__init__.py create mode 100644 src/causalprog/causal_problem/causal_estimand.py create mode 100644 src/causalprog/causal_problem/causal_problem.py diff --git a/src/causalprog/causal_problem/__init__.py b/src/causalprog/causal_problem/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/causalprog/causal_problem/causal_estimand.py b/src/causalprog/causal_problem/causal_estimand.py new file mode 100644 index 0000000..05effd1 --- /dev/null +++ b/src/causalprog/causal_problem/causal_estimand.py @@ -0,0 +1,45 @@ +"""C.""" + +from collections.abc import Callable +from typing import Any, Concatenate, TypeAlias + +import numpy.typing as npt + +Model: TypeAlias = Callable[..., Any] +EffectHandler: TypeAlias = Callable[Concatenate[Model, ...], Model] +ModelMask: TypeAlias = tuple[EffectHandler, dict] + + +class _CPComponent: + """""" + + do_with_samples: Callable[..., npt.ArrayLike] + effect_handlers: tuple[ModelMask, ...] 
+ + @property + def requires_model_adaption(self) -> bool: + """Return True if effect handlers need to be applied to model.""" + return len(self.effect_handlers) > 0 + + def __init__( + self, + *effect_handlers: ModelMask, + do_with_samples: Callable[..., npt.ArrayLike], + ): + self.effect_handlers = tuple(effect_handlers) + self.do_with_samples = do_with_samples + + def apply_effects(self, model: Model) -> Model: + """Apply any necessary effect handlers prior to evaluating.""" + adapted_model = model + for handler, handler_options in self.effect_handlers: + adapted_model = handler(adapted_model, **handler_options) + return adapted_model + + +class CausalEstimand(_CPComponent): + """""" + + +class Constraint(_CPComponent): + """""" diff --git a/src/causalprog/causal_problem/causal_problem.py b/src/causalprog/causal_problem/causal_problem.py new file mode 100644 index 0000000..ae44f7e --- /dev/null +++ b/src/causalprog/causal_problem/causal_problem.py @@ -0,0 +1,58 @@ +"""C.""" + +from collections.abc import Callable + +import jax +import numpy.typing as npt +from numpyro.infer import Predictive + +from causalprog.causal_problem.causal_estimand import CausalEstimand, Constraint, Model + + +def sample_model( + model: Predictive, rng_key: jax.Array, parameter_values: dict[str, npt.ArrayLike] +) -> dict[str, npt.ArrayLike]: + return jax.vmap(lambda pv, key: model(key, **pv), in_axes=(None, 0))( + parameter_values, + jax.random.split(rng_key, model.num_samples), + ) + + +class CausalProblem: + causal_estimand: CausalEstimand + constraints: list[Constraint] + + def __init__( + self, + *constraints: Constraint, + causal_estimand: CausalEstimand, + ): + self.causal_estimand = causal_estimand + self.constraints = list(constraints) + + def lagrangian( + self, n_samples: int = 1000 + ) -> Callable[[dict[str, npt.ArrayLike], Model, jax.Array], npt.ArrayLike]: + """Assemble the Lagrangian.""" + + def _inner( + parameter_values: dict[str, npt.ArrayLike], model: Model, rng_key: jax.Array + ) -> npt.ArrayLike: + # In general, we will need to check which of our CE/CONs require masking, + # and do multiple predictive models to account for this... + # We can always pre-build the predictive models too, so we should replace + # the "model" input with something that can map the right predictive models + # to the CE/CONS that need them. + predictive_model = Predictive(model=model, num_samples=n_samples) + all_samples = sample_model(predictive_model, rng_key, parameter_values) + + value = self.causal_estimand.do_with_samples(**all_samples) + # CLEANER IF THE LAGRANGE MULTIPLIERS COULD BE A SECOND FUNCTION ARG, + # as right now they have to be inside the parameter dict... 
+ value += sum( + parameter_values[f"_l_mult{i}"] * c.do_with_samples(**all_samples) + for i, c in enumerate(self.constraints) + ) + return value + + return _inner From 32ccf0b38573037791f13ab8fead1bbe6532f725 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Fri, 15 Aug 2025 15:02:05 +0100 Subject: [PATCH 04/33] Update two normal example test to use new infrastructure --- .../test_two_normal_example.py | 82 +++++++------------ 1 file changed, 29 insertions(+), 53 deletions(-) diff --git a/tests/test_integration/test_two_normal_example.py b/tests/test_integration/test_two_normal_example.py index 1957a61..b5ebcf9 100644 --- a/tests/test_integration/test_two_normal_example.py +++ b/tests/test_integration/test_two_normal_example.py @@ -1,12 +1,13 @@ +import sys from collections.abc import Callable import jax import jax.numpy as jnp -import numpy.typing as npt import optax import pytest -from numpyro.infer import Predictive +from causalprog.causal_problem.causal_estimand import CausalEstimand, Constraint +from causalprog.causal_problem.causal_problem import CausalProblem from causalprog.graph import Graph @@ -71,55 +72,28 @@ def test_two_normal_example( (again with $+\mu_{ux}$ in the maximisation case). In both cases, $\mathcal{L}$ is minimised at $(\mu_{ux}^{*}, \nu_x, 1)$. """ - # Setup the optimisation problem from the graph - g = two_normal_graph(cov=1.0) - predictive_model = Predictive(g.model, num_samples=n_samples) - - def lagrangian( - parameter_values: dict[str, npt.ArrayLike], - predictive_model: Predictive, - rng_key: jax.Array, - *, - ce_prefactor: float, - ): - subkeys = jax.random.split(rng_key, predictive_model.num_samples) - l_mult = parameter_values["_l_mult"] - - def _ux_sampler(pv: dict[str, npt.ArrayLike], key: jax.Array) -> float: - return predictive_model(key, **pv)["UX"] - - def _x_sampler(pv: dict[str, npt.ArrayLike], key: jax.Array) -> float: - return predictive_model(key, **pv)["X"] - - def _ce(pv, subkeys): - return ( - ce_prefactor - * jax.vmap(_x_sampler, in_axes=(None, 0))(pv, subkeys).mean() - ) - - def _constraint(pv, subkeys): - return ( - jnp.abs( - jax.vmap(_ux_sampler, in_axes=(None, 0))(pv, subkeys).mean() - - phi_observed - ) - - epsilon - ) - - return _ce(parameter_values, subkeys) + l_mult * _constraint( - parameter_values, subkeys - ) - # In both cases, the Lagrange multiplier has the value 1.0 at the minimum. lambda_sol = 1.0 ce_prefactor = 1.0 if not is_solving_max else -1.0 mu_x_sol = phi_observed - ce_prefactor * epsilon + # Setup the optimisation problem from the graph + g = two_normal_graph(cov=1.0) + ce = CausalEstimand(do_with_samples=lambda **pv: ce_prefactor * pv["X"].mean()) + con = Constraint( + do_with_samples=lambda **pv: jnp.abs(pv["UX"].mean() - phi_observed) - epsilon + ) + cp = CausalProblem( + con, + causal_estimand=ce, + ) + lagrangian = cp.lagrangian(n_samples=n_samples) + # We'll be seeking stationary points of the Lagrangian, using the # naive approach of minimising the norm of its gradient. We will need to # ensure we "converge" to a minimum value suitably close to 0. 
- def objective(params, predictive, key, ce_prefactor=ce_prefactor): - v = jax.grad(lagrangian)(params, predictive, key, ce_prefactor=ce_prefactor) + def objective(params, predictive, key): + v = jax.grad(lagrangian)(params, predictive, key) return sum(value**2 for value in v.values()) # Choose a starting guess that is at the optimal solution, in the hopes that @@ -130,7 +104,7 @@ def objective(params, predictive, key, ce_prefactor=ce_prefactor): params = { "mean": mu_x_sol, "cov2": nu_x_starting_value, - "_l_mult": lambda_sol, + "_l_mult0": lambda_sol, } # Setup SGD optimiser optimiser = optax.adam(adams_learning_rate) @@ -139,17 +113,18 @@ def objective(params, predictive, key, ce_prefactor=ce_prefactor): converged = False for _ in range(maxiter): # Actual iteration loop - grads = jax.jacobian(objective)( - params, predictive_model, rng_key, ce_prefactor=ce_prefactor - ) + grads = jax.jacobian(objective)(params, g.model, rng_key) updates, opt_state = optimiser.update(grads, opt_state) params = optax.apply_updates(params, updates) # Convergence "check" and progress update - objective_value = objective( - params, predictive_model, rng_key, ce_prefactor=ce_prefactor + objective_value = objective(params, g.model, rng_key) + sys.stdout.write( + f"{_}, F_val={objective_value:.4e}, " + f"mu_ux={params['mean']:.4e}, " + f"nu_x={params['cov2']:.4e}, " + f"lambda={params['_l_mult0']:.4e}\n" ) - if jnp.abs(objective_value) <= minimisation_tolerance: converged = True break @@ -162,12 +137,13 @@ def objective(params, predictive, key, ce_prefactor=ce_prefactor): ) # Confirm that we found a minimiser that does satisfy the inequality constraints. - assert params["_l_mult"] > 0.0, ( - f"Converged, but not to a minimiser (lagrange multiplier = {params['_l_mult']})" + assert params["_l_mult0"] > 0.0, ( + "Converged, but not to a minimiser " + f"(lagrange multiplier = {params['_l_mult0']})" ) # Give a generous error margin in mu_ux and the Lagrange multiplier, # given SGD is being used on MC-integral functions. 
rtol = jnp.sqrt(1.0 / n_samples) assert jnp.isclose(params["mean"], mu_x_sol, rtol=rtol) - assert jnp.isclose(params["_l_mult"], lagrange_mult_sol, atol=rtol) + assert jnp.isclose(params["_l_mult0"], lagrange_mult_sol, atol=rtol) From c8fbaa5ad5764a94b9e1971fa25ae4ffd044e196 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Fri, 15 Aug 2025 15:28:42 +0100 Subject: [PATCH 05/33] Rework so that the lagrangian can be passed model parameters and the multiplier values as separate args --- .../causal_problem/causal_problem.py | 13 ++++++-- .../test_two_normal_example.py | 31 ++++++++++--------- 2 files changed, 26 insertions(+), 18 deletions(-) diff --git a/src/causalprog/causal_problem/causal_problem.py b/src/causalprog/causal_problem/causal_problem.py index ae44f7e..08b1ca9 100644 --- a/src/causalprog/causal_problem/causal_problem.py +++ b/src/causalprog/causal_problem/causal_problem.py @@ -32,11 +32,16 @@ def __init__( def lagrangian( self, n_samples: int = 1000 - ) -> Callable[[dict[str, npt.ArrayLike], Model, jax.Array], npt.ArrayLike]: + ) -> Callable[ + [dict[str, npt.ArrayLike], npt.ArrayLike, Model, jax.Array], npt.ArrayLike + ]: """Assemble the Lagrangian.""" def _inner( - parameter_values: dict[str, npt.ArrayLike], model: Model, rng_key: jax.Array + parameter_values: dict[str, npt.ArrayLike], + l_mult: jax.Array, + model: Model, + rng_key: jax.Array, ) -> npt.ArrayLike: # In general, we will need to check which of our CE/CONs require masking, # and do multiple predictive models to account for this... @@ -49,8 +54,10 @@ def _inner( value = self.causal_estimand.do_with_samples(**all_samples) # CLEANER IF THE LAGRANGE MULTIPLIERS COULD BE A SECOND FUNCTION ARG, # as right now they have to be inside the parameter dict... + # Cleaner if we could somehow build a vector-valued function of the + # constraints and then take a dot product, but this works for now value += sum( - parameter_values[f"_l_mult{i}"] * c.do_with_samples(**all_samples) + l_mult[i] * c.do_with_samples(**all_samples) for i, c in enumerate(self.constraints) ) return value diff --git a/tests/test_integration/test_two_normal_example.py b/tests/test_integration/test_two_normal_example.py index b5ebcf9..f13fe81 100644 --- a/tests/test_integration/test_two_normal_example.py +++ b/tests/test_integration/test_two_normal_example.py @@ -72,8 +72,6 @@ def test_two_normal_example( (again with $+\mu_{ux}$ in the maximisation case). In both cases, $\mathcal{L}$ is minimised at $(\mu_{ux}^{*}, \nu_x, 1)$. """ - # In both cases, the Lagrange multiplier has the value 1.0 at the minimum. - lambda_sol = 1.0 ce_prefactor = 1.0 if not is_solving_max else -1.0 mu_x_sol = phi_observed - ce_prefactor * epsilon @@ -92,9 +90,9 @@ def test_two_normal_example( # We'll be seeking stationary points of the Lagrangian, using the # naive approach of minimising the norm of its gradient. We will need to # ensure we "converge" to a minimum value suitably close to 0. - def objective(params, predictive, key): - v = jax.grad(lagrangian)(params, predictive, key) - return sum(value**2 for value in v.values()) + def objective(params, l_mult, predictive, key): + v = jax.grad(lagrangian, argnums=(0, 1))(params, l_mult, predictive, key) + return sum(value**2 for value in v[0].values()) + (v[1] ** 2).sum() # Choose a starting guess that is at the optimal solution, in the hopes that # SGD converges quickly. We almost certainly will not have this luxury in general. 
@@ -104,26 +102,30 @@ def objective(params, predictive, key): params = { "mean": mu_x_sol, "cov2": nu_x_starting_value, - "_l_mult0": lambda_sol, } + l_mult = jnp.atleast_1d(lagrange_mult_sol) + # Setup SGD optimiser optimiser = optax.adam(adams_learning_rate) - opt_state = optimiser.init(params) + opt_state = optimiser.init((params, l_mult)) + # Run optimisation loop on gradient of the Lagrangian converged = False for _ in range(maxiter): # Actual iteration loop - grads = jax.jacobian(objective)(params, g.model, rng_key) + grads = jax.jacobian(objective, argnums=(0, 1))( + params, l_mult, g.model, rng_key + ) updates, opt_state = optimiser.update(grads, opt_state) - params = optax.apply_updates(params, updates) + params, l_mult = optax.apply_updates((params, l_mult), updates) # Convergence "check" and progress update - objective_value = objective(params, g.model, rng_key) + objective_value = objective(params, l_mult, g.model, rng_key) sys.stdout.write( f"{_}, F_val={objective_value:.4e}, " f"mu_ux={params['mean']:.4e}, " f"nu_x={params['cov2']:.4e}, " - f"lambda={params['_l_mult0']:.4e}\n" + f"lambda={l_mult[0]:.4e}\n" ) if jnp.abs(objective_value) <= minimisation_tolerance: converged = True @@ -137,13 +139,12 @@ def objective(params, predictive, key): ) # Confirm that we found a minimiser that does satisfy the inequality constraints. - assert params["_l_mult0"] > 0.0, ( - "Converged, but not to a minimiser " - f"(lagrange multiplier = {params['_l_mult0']})" + assert jnp.all(l_mult > 0.0), ( + f"Converged, but not to a minimiser (lagrange multiplier = {l_mult})" ) # Give a generous error margin in mu_ux and the Lagrange multiplier, # given SGD is being used on MC-integral functions. rtol = jnp.sqrt(1.0 / n_samples) assert jnp.isclose(params["mean"], mu_x_sol, rtol=rtol) - assert jnp.isclose(params["_l_mult0"], lagrange_mult_sol, atol=rtol) + assert jnp.allclose(l_mult, lagrange_mult_sol, atol=rtol) From 87d128fba284bd3e13a66d3ad51d4f9b58533bce Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Fri, 15 Aug 2025 15:39:34 +0100 Subject: [PATCH 06/33] ruffing --- .../causal_problem/causal_estimand.py | 8 +-- .../causal_problem/causal_problem.py | 52 +++++++++++++++++-- 2 files changed, 52 insertions(+), 8 deletions(-) diff --git a/src/causalprog/causal_problem/causal_estimand.py b/src/causalprog/causal_problem/causal_estimand.py index 05effd1..8ff50a4 100644 --- a/src/causalprog/causal_problem/causal_estimand.py +++ b/src/causalprog/causal_problem/causal_estimand.py @@ -11,7 +11,7 @@ class _CPComponent: - """""" + """Base.""" do_with_samples: Callable[..., npt.ArrayLike] effect_handlers: tuple[ModelMask, ...] @@ -25,7 +25,7 @@ def __init__( self, *effect_handlers: ModelMask, do_with_samples: Callable[..., npt.ArrayLike], - ): + ) -> None: self.effect_handlers = tuple(effect_handlers) self.do_with_samples = do_with_samples @@ -38,8 +38,8 @@ def apply_effects(self, model: Model) -> Model: class CausalEstimand(_CPComponent): - """""" + """A Causal Estimand.""" class Constraint(_CPComponent): - """""" + """A constraint function.""" diff --git a/src/causalprog/causal_problem/causal_problem.py b/src/causalprog/causal_problem/causal_problem.py index 08b1ca9..10fdd50 100644 --- a/src/causalprog/causal_problem/causal_problem.py +++ b/src/causalprog/causal_problem/causal_problem.py @@ -12,6 +12,21 @@ def sample_model( model: Predictive, rng_key: jax.Array, parameter_values: dict[str, npt.ArrayLike] ) -> dict[str, npt.ArrayLike]: + """ + Draw samples from the predictive model. 
+ + TODO: Move function to somewhere more appropriate. + + Args: + model: Predictive model to draw samples from. + rng_key: PRNG Key to use in pseudorandom number generation. + parameter_values: Model parameter values to substitute. + + Returns: + `dict` of samples, with RV labels as keys and sample values (`jax.Array`s) as + values. + + """ return jax.vmap(lambda pv, key: model(key, **pv), in_axes=(None, 0))( parameter_values, jax.random.split(rng_key, model.num_samples), @@ -19,6 +34,8 @@ def sample_model( class CausalProblem: + """Defines a causal problem.""" + causal_estimand: CausalEstimand constraints: list[Constraint] @@ -26,16 +43,41 @@ def __init__( self, *constraints: Constraint, causal_estimand: CausalEstimand, - ): + ) -> None: + """Create a new causal problem.""" self.causal_estimand = causal_estimand self.constraints = list(constraints) def lagrangian( - self, n_samples: int = 1000 + self, n_samples: int = 1000, *, maximum_problem: bool = False ) -> Callable[ [dict[str, npt.ArrayLike], npt.ArrayLike, Model, jax.Array], npt.ArrayLike ]: - """Assemble the Lagrangian.""" + """ + Return a function that evaluates the Lagrangian of this `CausalProblem`. + + Following the + [KKT theorem](https://en.wikipedia.org/wiki/Karush%E2%80%93Kuhn%E2%80%93Tucker_conditions), + given the causal estimand and the constraints we can assemble a Lagrangian and + seek its stationary points, to in turn identify minimisers of the constrained + optimisation problem that we started with. + + The Lagrangian returned is a mathematical function of its first two arguments. + The first argument is the same dictionary of parameters that is passed to models + like `Graph.model`, and is the values the parameters (represented by the + `ParameterNode`s) are taking. The second argument is a 1D vector of Lagrange + multipliers, whose length is equal to the number of constraints. + + The remaining two arguments of the Lagrangian are the underlying model that it + should use to generate and draw samples of the RVs from, and the PRNGKey that + should be used in this generation. + + Note that our current implementation assumes there are no equality constraints + being imposed (in which case, we would need a 3-argument Lagrangian function). + + TODO: Can we store g.model in the class, and have this still work??? + """ + maximisation_prefactor = -1.0 if maximum_problem else 1.0 def _inner( parameter_values: dict[str, npt.ArrayLike], @@ -51,7 +93,9 @@ def _inner( predictive_model = Predictive(model=model, num_samples=n_samples) all_samples = sample_model(predictive_model, rng_key, parameter_values) - value = self.causal_estimand.do_with_samples(**all_samples) + value = maximisation_prefactor * self.causal_estimand.do_with_samples( + **all_samples + ) # CLEANER IF THE LAGRANGE MULTIPLIERS COULD BE A SECOND FUNCTION ARG, # as right now they have to be inside the parameter dict... 
# Cleaner if we could somehow build a vector-valued function of the From 7c171c401333b0842c25ab5e6ba20dac0681a285 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Fri, 15 Aug 2025 15:47:46 +0100 Subject: [PATCH 07/33] Refactor out g.model argument from the Lagrangian call --- .../causal_problem/causal_problem.py | 34 +++++++++++++------ .../test_two_normal_example.py | 16 ++++----- 2 files changed, 31 insertions(+), 19 deletions(-) diff --git a/src/causalprog/causal_problem/causal_problem.py b/src/causalprog/causal_problem/causal_problem.py index 10fdd50..e41824b 100644 --- a/src/causalprog/causal_problem/causal_problem.py +++ b/src/causalprog/causal_problem/causal_problem.py @@ -6,7 +6,8 @@ import numpy.typing as npt from numpyro.infer import Predictive -from causalprog.causal_problem.causal_estimand import CausalEstimand, Constraint, Model +from causalprog.causal_problem.causal_estimand import CausalEstimand, Constraint +from causalprog.graph import Graph def sample_model( @@ -36,23 +37,26 @@ def sample_model( class CausalProblem: """Defines a causal problem.""" + # NB: A CausalProblem could just BE a graph!!!! + # But separation of concerns and all... + _underlying_graph: Graph causal_estimand: CausalEstimand constraints: list[Constraint] def __init__( self, + graph: Graph, *constraints: Constraint, causal_estimand: CausalEstimand, ) -> None: """Create a new causal problem.""" + self._underlying_graph = graph self.causal_estimand = causal_estimand self.constraints = list(constraints) def lagrangian( self, n_samples: int = 1000, *, maximum_problem: bool = False - ) -> Callable[ - [dict[str, npt.ArrayLike], npt.ArrayLike, Model, jax.Array], npt.ArrayLike - ]: + ) -> Callable[[dict[str, npt.ArrayLike], npt.ArrayLike, jax.Array], npt.ArrayLike]: """ Return a function that evaluates the Lagrangian of this `CausalProblem`. @@ -68,21 +72,29 @@ def lagrangian( `ParameterNode`s) are taking. The second argument is a 1D vector of Lagrange multipliers, whose length is equal to the number of constraints. - The remaining two arguments of the Lagrangian are the underlying model that it - should use to generate and draw samples of the RVs from, and the PRNGKey that - should be used in this generation. + The remaining argument of the Lagrangian is the PRNGKey that should be used + when drawing samples. Note that our current implementation assumes there are no equality constraints being imposed (in which case, we would need a 3-argument Lagrangian function). - TODO: Can we store g.model in the class, and have this still work??? + Args: + n_samples: The number of random samples to be drawn when estimating the + value of functions of the RVs. + maximum_problem: If passed as `True`, assemble the Lagrangian for the + maximisation problem. Otherwise assemble that for the minimisation + problem (default behaviour). + + Returns: + The Lagrangian, as a function of the model parameters, Lagrange multipliers, + and PRNG key. + """ maximisation_prefactor = -1.0 if maximum_problem else 1.0 def _inner( parameter_values: dict[str, npt.ArrayLike], l_mult: jax.Array, - model: Model, rng_key: jax.Array, ) -> npt.ArrayLike: # In general, we will need to check which of our CE/CONs require masking, @@ -90,7 +102,9 @@ def _inner( # We can always pre-build the predictive models too, so we should replace # the "model" input with something that can map the right predictive models # to the CE/CONS that need them. 
- predictive_model = Predictive(model=model, num_samples=n_samples) + predictive_model = Predictive( + model=self._underlying_graph.model, num_samples=n_samples + ) all_samples = sample_model(predictive_model, rng_key, parameter_values) value = maximisation_prefactor * self.causal_estimand.do_with_samples( diff --git a/tests/test_integration/test_two_normal_example.py b/tests/test_integration/test_two_normal_example.py index f13fe81..bbae1b8 100644 --- a/tests/test_integration/test_two_normal_example.py +++ b/tests/test_integration/test_two_normal_example.py @@ -76,22 +76,22 @@ def test_two_normal_example( mu_x_sol = phi_observed - ce_prefactor * epsilon # Setup the optimisation problem from the graph - g = two_normal_graph(cov=1.0) - ce = CausalEstimand(do_with_samples=lambda **pv: ce_prefactor * pv["X"].mean()) + ce = CausalEstimand(do_with_samples=lambda **pv: pv["X"].mean()) con = Constraint( do_with_samples=lambda **pv: jnp.abs(pv["UX"].mean() - phi_observed) - epsilon ) cp = CausalProblem( + two_normal_graph(cov=1.0), con, causal_estimand=ce, ) - lagrangian = cp.lagrangian(n_samples=n_samples) + lagrangian = cp.lagrangian(n_samples=n_samples, maximum_problem=is_solving_max) # We'll be seeking stationary points of the Lagrangian, using the # naive approach of minimising the norm of its gradient. We will need to # ensure we "converge" to a minimum value suitably close to 0. - def objective(params, l_mult, predictive, key): - v = jax.grad(lagrangian, argnums=(0, 1))(params, l_mult, predictive, key) + def objective(params, l_mult, key): + v = jax.grad(lagrangian, argnums=(0, 1))(params, l_mult, key) return sum(value**2 for value in v[0].values()) + (v[1] ** 2).sum() # Choose a starting guess that is at the optimal solution, in the hopes that @@ -113,14 +113,12 @@ def objective(params, l_mult, predictive, key): converged = False for _ in range(maxiter): # Actual iteration loop - grads = jax.jacobian(objective, argnums=(0, 1))( - params, l_mult, g.model, rng_key - ) + grads = jax.jacobian(objective, argnums=(0, 1))(params, l_mult, rng_key) updates, opt_state = optimiser.update(grads, opt_state) params, l_mult = optax.apply_updates((params, l_mult), updates) # Convergence "check" and progress update - objective_value = objective(params, l_mult, g.model, rng_key) + objective_value = objective(params, l_mult, rng_key) sys.stdout.write( f"{_}, F_val={objective_value:.4e}, " f"mu_ux={params['mean']:.4e}, " From 33721938ca41b928cfebcc2d991765c7598a5454 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Fri, 15 Aug 2025 15:50:36 +0100 Subject: [PATCH 08/33] Make TODOs obvious so I don't forget to do them --- src/causalprog/causal_problem/causal_estimand.py | 3 +++ src/causalprog/causal_problem/causal_problem.py | 5 ++--- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/causalprog/causal_problem/causal_estimand.py b/src/causalprog/causal_problem/causal_estimand.py index 8ff50a4..02179a0 100644 --- a/src/causalprog/causal_problem/causal_estimand.py +++ b/src/causalprog/causal_problem/causal_estimand.py @@ -9,6 +9,9 @@ EffectHandler: TypeAlias = Callable[Concatenate[Model, ...], Model] ModelMask: TypeAlias = tuple[EffectHandler, dict] +# TODO: Docstrings, split into multiple files, +# auto-assembly of constraints wrt the tolerance? That kind of thing. 
+ class _CPComponent: """Base.""" diff --git a/src/causalprog/causal_problem/causal_problem.py b/src/causalprog/causal_problem/causal_problem.py index e41824b..7677126 100644 --- a/src/causalprog/causal_problem/causal_problem.py +++ b/src/causalprog/causal_problem/causal_problem.py @@ -102,6 +102,7 @@ def _inner( # We can always pre-build the predictive models too, so we should replace # the "model" input with something that can map the right predictive models # to the CE/CONS that need them. + # TODO: Address pre-handlers that may apply from CEs/CONstraints predictive_model = Predictive( model=self._underlying_graph.model, num_samples=n_samples ) @@ -110,9 +111,7 @@ def _inner( value = maximisation_prefactor * self.causal_estimand.do_with_samples( **all_samples ) - # CLEANER IF THE LAGRANGE MULTIPLIERS COULD BE A SECOND FUNCTION ARG, - # as right now they have to be inside the parameter dict... - # Cleaner if we could somehow build a vector-valued function of the + # TODO: Cleaner if we could somehow build a vector-valued function of the # constraints and then take a dot product, but this works for now value += sum( l_mult[i] * c.do_with_samples(**all_samples) From 87fc5a2ec81c035a0cdf68c76d8c1dfaa7e69871 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Mon, 18 Aug 2025 09:39:27 +0100 Subject: [PATCH 09/33] Add docstrings and more TODOs --- src/causalprog/causal_problem/__init__.py | 1 + .../causal_problem/causal_estimand.py | 55 ++++++++++++++++++- .../causal_problem/causal_problem.py | 2 +- 3 files changed, 54 insertions(+), 4 deletions(-) diff --git a/src/causalprog/causal_problem/__init__.py b/src/causalprog/causal_problem/__init__.py index e69de29..3d45fe4 100644 --- a/src/causalprog/causal_problem/__init__.py +++ b/src/causalprog/causal_problem/__init__.py @@ -0,0 +1 @@ +"""Classes for defining causal problems.""" diff --git a/src/causalprog/causal_problem/causal_estimand.py b/src/causalprog/causal_problem/causal_estimand.py index 02179a0..e7a72fa 100644 --- a/src/causalprog/causal_problem/causal_estimand.py +++ b/src/causalprog/causal_problem/causal_estimand.py @@ -14,7 +14,22 @@ class _CPComponent: - """Base.""" + """ + Base class for components of a Causal Problem. + + A _CPComponent has an attached method that it can apply to samples + (`do_with_samples`), which will be passed sample values of the RVs + during solution of a Causal Problem and used to evaluate the causal + estimand or constraint the instance represents. + + It also has a sequence of effect handlers that need to be applied + to the sampling model before samples can be drawn to evaluate this + component. For example, if a component requires conditioning on the + value of a RV, the `condition` handler needs to be applied to the + underlying model, before generating samples to pass to the + `do_with_sample` method. `effect_handlers` will be applied to the model + in the order they are given. + """ do_with_samples: Callable[..., npt.ArrayLike] effect_handlers: tuple[ModelMask, ...] @@ -40,9 +55,43 @@ def apply_effects(self, model: Model) -> Model: return adapted_model +# TODO: Turn CausalEstimand and Constraint into callables that evaluate do_with_samples + + class CausalEstimand(_CPComponent): - """A Causal Estimand.""" + """ + A Causal Estimand. + + The causal estimand is the function that we want to minimise (and maximise) + as part of a causal problem. It should be a scalar-valued function of the + random variables appearing in a graph. 
+ """ class Constraint(_CPComponent): - """A constraint function.""" + r""" + A Constraint that forms part of a causal problem. + + Constraints of a causal problem are derived properties of RVs for which we + have observed data. The causal estimand is minimised (or maximised) subject + to the predicted values of the constraints being close to their observed + values in the data. + + Adding a constraint $g(\theta)$ to a causal problem (where $\theta$ are the + parameters of the causal problem) essentially imposes an additional + constraint on the minimisation problem; + + $$ g(\theta) - g_{\text{data}} \leq \epsilon, $$ + + where $g_{\text{data}}$ is the observed data value for the quantity $g$, + and $\epsilon$ is some tolerance. + """ + + # TODO: Should explain that Constraint needs more inputs and slightly different + # interpretation of the `do_with_samples` object. + # Inputs: + # - include epsilon as an input (allows constraints to have different tolerances) + # - `do_with_samples` should just be $g(\theta)$. Then have the instance build the + # full constraint that will need to be called in the Lagrangian. + # - $g$ still needs to be scalar valued? Allow a wrapper function to be applied in + # the event $g$ is vector-valued. diff --git a/src/causalprog/causal_problem/causal_problem.py b/src/causalprog/causal_problem/causal_problem.py index 7677126..7b3780f 100644 --- a/src/causalprog/causal_problem/causal_problem.py +++ b/src/causalprog/causal_problem/causal_problem.py @@ -72,7 +72,7 @@ def lagrangian( `ParameterNode`s) are taking. The second argument is a 1D vector of Lagrange multipliers, whose length is equal to the number of constraints. - The remaining argument of the Lagrangian is the PRNGKey that should be used + The remaining argument of the Lagrangian is the PRNG Key that should be used when drawing samples. Note that our current implementation assumes there are no equality constraints From d202568b2ff58a05b50e01092f76e10d034b0b11 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Mon, 18 Aug 2025 10:13:55 +0100 Subject: [PATCH 10/33] Todo resolution and addition --- src/causalprog/causal_problem/causal_estimand.py | 3 --- src/causalprog/causal_problem/causal_problem.py | 8 ++++---- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/src/causalprog/causal_problem/causal_estimand.py b/src/causalprog/causal_problem/causal_estimand.py index e7a72fa..bb9635c 100644 --- a/src/causalprog/causal_problem/causal_estimand.py +++ b/src/causalprog/causal_problem/causal_estimand.py @@ -9,9 +9,6 @@ EffectHandler: TypeAlias = Callable[Concatenate[Model, ...], Model] ModelMask: TypeAlias = tuple[EffectHandler, dict] -# TODO: Docstrings, split into multiple files, -# auto-assembly of constraints wrt the tolerance? That kind of thing. - class _CPComponent: """ diff --git a/src/causalprog/causal_problem/causal_problem.py b/src/causalprog/causal_problem/causal_problem.py index 7b3780f..ffd8c16 100644 --- a/src/causalprog/causal_problem/causal_problem.py +++ b/src/causalprog/causal_problem/causal_problem.py @@ -10,14 +10,13 @@ from causalprog.graph import Graph +# TODO: Move somewhere more appropriate def sample_model( model: Predictive, rng_key: jax.Array, parameter_values: dict[str, npt.ArrayLike] ) -> dict[str, npt.ArrayLike]: """ Draw samples from the predictive model. - TODO: Move function to somewhere more appropriate. - Args: model: Predictive model to draw samples from. rng_key: PRNG Key to use in pseudorandom number generation. 
@@ -37,8 +36,6 @@ def sample_model( class CausalProblem: """Defines a causal problem.""" - # NB: A CausalProblem could just BE a graph!!!! - # But separation of concerns and all... _underlying_graph: Graph causal_estimand: CausalEstimand constraints: list[Constraint] @@ -108,6 +105,9 @@ def _inner( ) all_samples = sample_model(predictive_model, rng_key, parameter_values) + # TODO: would be cleaner if causal_estimand (and constraint) were just + # directly callable. This would also let us hide do_with_samples to avoid + # runtime edits... value = maximisation_prefactor * self.causal_estimand.do_with_samples( **all_samples ) From 7f3231945bed604cfb80b18c6be4ee8d74a1f4d5 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Mon, 18 Aug 2025 10:46:16 +0100 Subject: [PATCH 11/33] Make _CPConstraint callable --- src/causalprog/causal_problem/causal_estimand.py | 13 +++++++++++++ src/causalprog/causal_problem/causal_problem.py | 7 ++----- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/src/causalprog/causal_problem/causal_estimand.py b/src/causalprog/causal_problem/causal_estimand.py index bb9635c..acdbc54 100644 --- a/src/causalprog/causal_problem/causal_estimand.py +++ b/src/causalprog/causal_problem/causal_estimand.py @@ -36,6 +36,19 @@ def requires_model_adaption(self) -> bool: """Return True if effect handlers need to be applied to model.""" return len(self.effect_handlers) > 0 + def __call__(self, samples: dict[str, npt.ArrayLike]) -> npt.ArrayLike: + """ + Evaluate the estimand or constraint, given sample values. + + Args: + samples: Mapping of RV (node) labels to samples of that RV. + + Returns: + Value of the estimand or constraint, given the samples. + + """ + return self.do_with_samples(**samples) + def __init__( self, *effect_handlers: ModelMask, diff --git a/src/causalprog/causal_problem/causal_problem.py b/src/causalprog/causal_problem/causal_problem.py index ffd8c16..0664ced 100644 --- a/src/causalprog/causal_problem/causal_problem.py +++ b/src/causalprog/causal_problem/causal_problem.py @@ -108,14 +108,11 @@ def _inner( # TODO: would be cleaner if causal_estimand (and constraint) were just # directly callable. This would also let us hide do_with_samples to avoid # runtime edits... - value = maximisation_prefactor * self.causal_estimand.do_with_samples( - **all_samples - ) + value = maximisation_prefactor * self.causal_estimand(all_samples) # TODO: Cleaner if we could somehow build a vector-valued function of the # constraints and then take a dot product, but this works for now value += sum( - l_mult[i] * c.do_with_samples(**all_samples) - for i, c in enumerate(self.constraints) + l_mult[i] * c(all_samples) for i, c in enumerate(self.constraints) ) return value From 3189f2d4123d155741e20c33151d93946ce97d0a Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Mon, 18 Aug 2025 10:47:36 +0100 Subject: [PATCH 12/33] Hide _CPComponent attributes that we don't expect to change --- .../causal_problem/causal_estimand.py | 17 +++++++---------- src/causalprog/causal_problem/causal_problem.py | 3 --- 2 files changed, 7 insertions(+), 13 deletions(-) diff --git a/src/causalprog/causal_problem/causal_estimand.py b/src/causalprog/causal_problem/causal_estimand.py index acdbc54..708235f 100644 --- a/src/causalprog/causal_problem/causal_estimand.py +++ b/src/causalprog/causal_problem/causal_estimand.py @@ -28,13 +28,13 @@ class _CPComponent: in the order they are given. 
""" - do_with_samples: Callable[..., npt.ArrayLike] - effect_handlers: tuple[ModelMask, ...] + _do_with_samples: Callable[..., npt.ArrayLike] + _effect_handlers: tuple[ModelMask, ...] @property def requires_model_adaption(self) -> bool: """Return True if effect handlers need to be applied to model.""" - return len(self.effect_handlers) > 0 + return len(self._effect_handlers) > 0 def __call__(self, samples: dict[str, npt.ArrayLike]) -> npt.ArrayLike: """ @@ -47,27 +47,24 @@ def __call__(self, samples: dict[str, npt.ArrayLike]) -> npt.ArrayLike: Value of the estimand or constraint, given the samples. """ - return self.do_with_samples(**samples) + return self._do_with_samples(**samples) def __init__( self, *effect_handlers: ModelMask, do_with_samples: Callable[..., npt.ArrayLike], ) -> None: - self.effect_handlers = tuple(effect_handlers) - self.do_with_samples = do_with_samples + self._effect_handlers = tuple(effect_handlers) + self._do_with_samples = do_with_samples def apply_effects(self, model: Model) -> Model: """Apply any necessary effect handlers prior to evaluating.""" adapted_model = model - for handler, handler_options in self.effect_handlers: + for handler, handler_options in self._effect_handlers: adapted_model = handler(adapted_model, **handler_options) return adapted_model -# TODO: Turn CausalEstimand and Constraint into callables that evaluate do_with_samples - - class CausalEstimand(_CPComponent): """ A Causal Estimand. diff --git a/src/causalprog/causal_problem/causal_problem.py b/src/causalprog/causal_problem/causal_problem.py index 0664ced..c42cdc8 100644 --- a/src/causalprog/causal_problem/causal_problem.py +++ b/src/causalprog/causal_problem/causal_problem.py @@ -105,9 +105,6 @@ def _inner( ) all_samples = sample_model(predictive_model, rng_key, parameter_values) - # TODO: would be cleaner if causal_estimand (and constraint) were just - # directly callable. This would also let us hide do_with_samples to avoid - # runtime edits... value = maximisation_prefactor * self.causal_estimand(all_samples) # TODO: Cleaner if we could somehow build a vector-valued function of the # constraints and then take a dot product, but this works for now From c2dab0cdd399c5a269c4740d42828e6e05781b79 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Mon, 18 Aug 2025 10:59:03 +0100 Subject: [PATCH 13/33] Test __call__ for _CPComponents --- .../causal_problem/causal_estimand.py | 1 + tests/test_causal_problem/test_cpcomponent.py | 52 +++++++++++++++++++ 2 files changed, 53 insertions(+) create mode 100644 tests/test_causal_problem/test_cpcomponent.py diff --git a/src/causalprog/causal_problem/causal_estimand.py b/src/causalprog/causal_problem/causal_estimand.py index 708235f..c199203 100644 --- a/src/causalprog/causal_problem/causal_estimand.py +++ b/src/causalprog/causal_problem/causal_estimand.py @@ -102,3 +102,4 @@ class Constraint(_CPComponent): # full constraint that will need to be called in the Lagrangian. # - $g$ still needs to be scalar valued? Allow a wrapper function to be applied in # the event $g$ is vector-valued. + # If we do this, will also need to override __call__... 
diff --git a/tests/test_causal_problem/test_cpcomponent.py b/tests/test_causal_problem/test_cpcomponent.py new file mode 100644 index 0000000..b833ba3 --- /dev/null +++ b/tests/test_causal_problem/test_cpcomponent.py @@ -0,0 +1,52 @@ +from collections.abc import Callable + +import jax.numpy as jnp +import numpy.typing as npt +import pytest + +from causalprog.causal_problem.causal_estimand import _CPComponent + + +@pytest.mark.parametrize( + ("expression", "samples", "expect_error"), + [ + pytest.param( + lambda **pv: jnp.atleast_1d(0.0), {}, None, id="Constant expression" + ), + pytest.param( + lambda **pv: jnp.atleast_1d(0.0), + {"not_needed": jnp.atleast_1d(0.0)}, + None, + id="Un-needed samples", + ), + pytest.param( + lambda **pv: pv["a"], + {"a": jnp.atleast_1d(1.0)}, + None, + id="All needed samples given", + ), + pytest.param( + lambda **pv: pv["b"], + {"a": jnp.atleast_1d(1.0)}, + KeyError("b"), + id="Missing sample", + ), + ], +) +def test_call( + expression: Callable, + samples: dict[str, npt.ArrayLike], + expect_error: Exception | None, + raises_context, +) -> None: + """Check that _CPComponent correctly calls its _do_with_samples attribute.""" + + component = _CPComponent(do_with_samples=expression) + + assert callable(component) + + if expect_error: + with raises_context(expect_error): + component(samples) + else: + assert jnp.allclose(component(samples), expression(**samples)) From 648ff9d1e58988f09b88c938abd1d85560ee4d8f Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Mon, 18 Aug 2025 11:19:43 +0100 Subject: [PATCH 14/33] Add note about __call__ in docstring --- src/causalprog/causal_problem/causal_estimand.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/causalprog/causal_problem/causal_estimand.py b/src/causalprog/causal_problem/causal_estimand.py index c199203..3a9546f 100644 --- a/src/causalprog/causal_problem/causal_estimand.py +++ b/src/causalprog/causal_problem/causal_estimand.py @@ -26,6 +26,10 @@ class _CPComponent: underlying model, before generating samples to pass to the `do_with_sample` method. `effect_handlers` will be applied to the model in the order they are given. + + `_CPComponent`s are callable, taking the sample values (as a mapping + from strings to arrays) as arguments and returning the value of the + component. 
""" _do_with_samples: Callable[..., npt.ArrayLike] From 77eb893f72daf73d7e34925a233ea263e9220a6a Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Mon, 18 Aug 2025 15:04:31 +0100 Subject: [PATCH 15/33] Fix bug in how handlers are applied --- src/causalprog/causal_problem/causal_estimand.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/causalprog/causal_problem/causal_estimand.py b/src/causalprog/causal_problem/causal_estimand.py index 5821696..460c776 100644 --- a/src/causalprog/causal_problem/causal_estimand.py +++ b/src/causalprog/causal_problem/causal_estimand.py @@ -61,7 +61,7 @@ def apply_effects(self, model: Model) -> Model: """Apply any necessary effect handlers prior to evaluating.""" adapted_model = model for handler, handler_options in self._effect_handlers: - adapted_model = handler(adapted_model, **handler_options) + adapted_model = handler(adapted_model, handler_options) return adapted_model From cb7d92382c3fda51604fd46860dfd25dcc95ac92 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Mon, 18 Aug 2025 15:04:40 +0100 Subject: [PATCH 16/33] Write tests for features --- tests/test_causal_problem/test_cpcomponent.py | 134 +++++++++++++++++- 1 file changed, 133 insertions(+), 1 deletion(-) diff --git a/tests/test_causal_problem/test_cpcomponent.py b/tests/test_causal_problem/test_cpcomponent.py index b833ba3..a331d10 100644 --- a/tests/test_causal_problem/test_cpcomponent.py +++ b/tests/test_causal_problem/test_cpcomponent.py @@ -3,8 +3,10 @@ import jax.numpy as jnp import numpy.typing as npt import pytest +from numpyro.handlers import condition, do -from causalprog.causal_problem.causal_estimand import _CPComponent +from causalprog.causal_problem.causal_estimand import Model, ModelMask, _CPComponent +from causalprog.graph import Graph @pytest.mark.parametrize( @@ -50,3 +52,133 @@ def test_call( component(samples) else: assert jnp.allclose(component(samples), expression(**samples)) + + +@pytest.fixture +def conditioned_on_x_1( + two_normal_graph: Callable[..., Graph], +) -> Callable[..., Callable[..., None]]: + """ + Only intended for use in test_apply_handlers. + + Builds the model expected when we condition on X=1. + """ + + def _inner(**two_normal_graph_options: float) -> Callable[..., None]: + return condition( + two_normal_graph(**two_normal_graph_options).model, + {"X": jnp.atleast_1d(1.0)}, + ) + + return _inner + + +@pytest.fixture +def double_condition( + two_normal_graph: Callable[..., Graph], +) -> Callable[..., Callable[..., None]]: + """ + Only intended for use in test_apply_handlers. + + Builds the model expected when we condition on UX=-10, then again on + UX=10 (which should override the first action). + """ + + def _inner(**two_normal_graph_options: float) -> Callable[..., None]: + return condition( + condition( + two_normal_graph(**two_normal_graph_options).model, + {"UX": jnp.atleast_1d(-10.0)}, + ), + {"UX": jnp.atleast_1d(10.0)}, + ) + + return _inner + + +@pytest.fixture +def condition_then_do( + two_normal_graph: Callable[..., Graph], +) -> Callable[..., Callable[..., None]]: + """ + Only intended for use in test_apply_handlers. + + Builds the model expected when we first condition on UX=0, and then + apply do(X = 10). When sampling, we should still draw samples from + X as per a N(UX, 1.0). 
+ """ + + def _inner(**two_normal_graph_options: float) -> Callable[..., None]: + return do( + condition( + two_normal_graph(**two_normal_graph_options).model, + {"UX": jnp.atleast_1d(0.0)}, + ), + {"X": jnp.atleast_1d(10.0)}, + ) + + return _inner + + +@pytest.mark.parametrize( + ("handlers", "expected_model"), + [ + pytest.param( + ((condition, {"X": jnp.atleast_1d(1.0)}),), + "conditioned_on_x_1", + id="Condition X to 1", + ), + # Should condition on UX=-10, then OVERRIDE this with UX=10. + pytest.param( + ( + (condition, {"UX": jnp.atleast_1d(-10.0)}), + (condition, {"UX": jnp.atleast_1d(10.0)}), + ), + "double_condition", + id="Condition twice on same variable", + ), + # Condition UX=0, but then do X=10. + # Should still observe samples of X given by N(0, 1). + pytest.param( + ( + (condition, {"UX": jnp.atleast_1d(0.0)}), + (do, {"X": jnp.atleast_1d(10.0)}), + ), + "condition_then_do", + id="Condition then do", + ), + ], +) +def test_apply_handlers( + handlers: tuple[ModelMask], + expected_model: Model, + two_normal_graph: Callable[..., Graph], + request: pytest.FixtureRequest, + assert_samples_are_identical, + run_default_nuts_mcmc, + two_normal_graph_params: dict[str, float] | None = None, + do_with_samples: Callable[..., npt.ArrayLike] = lambda **pv: pv["X"].mean(), +) -> None: + """ + Test that model handlers are correctly applied to graphs. + + Note that the order of the handlers is important, as it dictates + which effects are applied first. + """ + if two_normal_graph_params is None: + two_normal_graph_params = {"mean": 0.0, "cov": 1.0, "cov2": 1.0} + if isinstance(expected_model, str): + expected_model = request.getfixturevalue(expected_model)( + **two_normal_graph_params + ) + + g = two_normal_graph(**two_normal_graph_params) + + cp = _CPComponent(*handlers, do_with_samples=do_with_samples) + + handled_model = cp.apply_effects(g.model) + + handled_mcmc = run_default_nuts_mcmc(handled_model) + expected_mcmc = run_default_nuts_mcmc(expected_model) + + assert_samples_are_identical(handled_mcmc, expected_mcmc) From 05a3d5b3b42e9b35342d11065c173a02ec019b86 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Tue, 19 Aug 2025 09:33:42 +0100 Subject: [PATCH 17/33] Edit Constraint so it is created in pieces --- .../causal_problem/causal_estimand.py | 48 +++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/src/causalprog/causal_problem/causal_estimand.py b/src/causalprog/causal_problem/causal_estimand.py index 460c776..14a6913 100644 --- a/src/causalprog/causal_problem/causal_estimand.py +++ b/src/causalprog/causal_problem/causal_estimand.py @@ -3,6 +3,7 @@ from collections.abc import Callable from typing import Any, Concatenate, TypeAlias +import jax.numpy as jnp import numpy.typing as npt Model: TypeAlias = Callable[..., Any] @@ -94,6 +95,53 @@ class Constraint(_CPComponent): and $\epsilon$ is some tolerance. """ + data: npt.ArrayLike + tolerance: npt.ArrayLike + _outer_norm: Callable[npt.ArrayLike, float] + + def __init__( + self, + *effect_handlers: ModelMask, + model_quantity: Callable[..., npt.ArrayLike], + outer_norm: Callable[npt.ArrayLike, float] | None = None, + data: npt.ArrayLike = 0.0, + ): + super().__init__(*effect_handlers, do_with_samples=model_quantity) + + if outer_norm is not None: + self._outer_norm = outer_norm + else: + self._outer_norm = jnp.linalg.vector_norm + + self.data = data + + def __call__(self, samples: dict[str, npt.ArrayLike]) -> npt.ArrayLike: + r""" + Evaluate the constraint, given RV samples. 
+ + Constraints are evaluated as + + $$ \mathrm{norm}\left( g(\theta) - g_{\mathrm{data}} \right) - \epsilon $$ + + where; + - $\mathrm{norm}$ is the outer norm of the constraint (`self._outer_norm`), + - $g(\theta)$ is the model quantity involved in the constraint + (`self._do_with_samples`), + - $\epsilon$ is the tolerance in the data (`self.tolerance`), + - $g_{\mathrm{data}}$ is the observed data (`self.data`). + + Args: + samples: Mapping of RV (node) labels to drawn samples. + + Returns: + Value of the constraint. + + """ + return ( + self._outer_norm(self._do_with_samples(**samples) - self.data) + - self.tolerance + ) + # TODO: (https://github.com/UCL/causalprog/issues/89) # Should explain that Constraint needs more inputs and slightly different # interpretation of the `do_with_samples` object. From f87b2fa2856a102323f2ba1517266c85e3238e95 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Tue, 19 Aug 2025 10:50:02 +0100 Subject: [PATCH 18/33] Rework Constraint.__init__ and docstring to match new format --- .../causal_problem/causal_estimand.py | 53 +++++++++++++------ 1 file changed, 38 insertions(+), 15 deletions(-) diff --git a/src/causalprog/causal_problem/causal_estimand.py b/src/causalprog/causal_problem/causal_estimand.py index 14a6913..cc37c25 100644 --- a/src/causalprog/causal_problem/causal_estimand.py +++ b/src/causalprog/causal_problem/causal_estimand.py @@ -97,15 +97,49 @@ class Constraint(_CPComponent): data: npt.ArrayLike tolerance: npt.ArrayLike - _outer_norm: Callable[npt.ArrayLike, float] + _outer_norm: Callable[[npt.ArrayLike], float] def __init__( self, *effect_handlers: ModelMask, model_quantity: Callable[..., npt.ArrayLike], - outer_norm: Callable[npt.ArrayLike, float] | None = None, + outer_norm: Callable[[npt.ArrayLike], float] | None = None, data: npt.ArrayLike = 0.0, - ): + ) -> None: + r""" + Create a new constraint. + + Constraints have the form + + $$ c(\theta) := + \mathrm{norm}\left( g(\theta) + - g_{\mathrm{data}} \right) + - \epsilon $$ + + where; + - $\mathrm{norm}$ is the outer norm of the constraint (`outer_norm`), + - $g(\theta)$ is the model quantity involved in the constraint + (`model_quantity`), + - $g_{\mathrm{data}}$ is the observed data (`data`), + - $\epsilon$ is the tolerance in the data (`tolerance`). + + In a causal problem, each constraint appears as the condition $c(\theta)\leq 0$ + in the minimisation / maximisation (hence the inclusion of the $-\epsilon$ + term within $c(\theta)$ itself). + + $g$ should be a (possibly vector-valued) function that acts on (a subset of) + samples from the random variables of the causal problem. It must accept + variable keyword-arguments only, and should access the samples for each random + variable by indexing via the RV names (node labels). It should return the + model quantity as computed from the samples, that $g_{\mathrm{data}}$ observed. + + $g_{\mathrm{data}}$ should be a fixed value whose shape is broadcast-able with + the return shape of $g$. It defaults to $0$ if not explicitly set. + + $\mathrm{norm}$ should be a suitable norm to take on the difference between the + model quantity as predicted by the samples ($g$) and the observed data + ($g_{\mathrm{data}}$). It must return a scalar value. The default is the 2-norm. 
+ """ super().__init__(*effect_handlers, do_with_samples=model_quantity) if outer_norm is not None: @@ -116,20 +150,9 @@ def __init__( self.data = data def __call__(self, samples: dict[str, npt.ArrayLike]) -> npt.ArrayLike: - r""" + """ Evaluate the constraint, given RV samples. - Constraints are evaluated as - - $$ \mathrm{norm}\left( g(\theta) - g_{\mathrm{data}} \right) - \epsilon $$ - - where; - - $\mathrm{norm}$ is the outer norm of the constraint (`self._outer_norm`), - - $g(\theta)$ is the model quantity involved in the constraint - (`self._do_with_samples`), - - $\epsilon$ is the tolerance in the data (`self.tolerance`), - - $g_{\mathrm{data}}$ is the observed data (`self.data`). - Args: samples: Mapping of RV (node) labels to drawn samples. From de2f6235f3b8fbaf41db848e1524f9da192972ba Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Tue, 19 Aug 2025 11:08:10 +0100 Subject: [PATCH 19/33] Update two_normal_example integration test --- src/causalprog/causal_problem/causal_estimand.py | 2 ++ tests/test_integration/test_two_normal_example.py | 4 +++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/causalprog/causal_problem/causal_estimand.py b/src/causalprog/causal_problem/causal_estimand.py index cc37c25..5d20fa5 100644 --- a/src/causalprog/causal_problem/causal_estimand.py +++ b/src/causalprog/causal_problem/causal_estimand.py @@ -105,6 +105,7 @@ def __init__( model_quantity: Callable[..., npt.ArrayLike], outer_norm: Callable[[npt.ArrayLike], float] | None = None, data: npt.ArrayLike = 0.0, + tolerance: float = 1.0e-6, ) -> None: r""" Create a new constraint. @@ -148,6 +149,7 @@ def __init__( self._outer_norm = jnp.linalg.vector_norm self.data = data + self.tolerance = tolerance def __call__(self, samples: dict[str, npt.ArrayLike]) -> npt.ArrayLike: """ diff --git a/tests/test_integration/test_two_normal_example.py b/tests/test_integration/test_two_normal_example.py index bbae1b8..381915d 100644 --- a/tests/test_integration/test_two_normal_example.py +++ b/tests/test_integration/test_two_normal_example.py @@ -78,7 +78,9 @@ def test_two_normal_example( # Setup the optimisation problem from the graph ce = CausalEstimand(do_with_samples=lambda **pv: pv["X"].mean()) con = Constraint( - do_with_samples=lambda **pv: jnp.abs(pv["UX"].mean() - phi_observed) - epsilon + model_quantity=lambda **pv: pv["UX"].mean(), + data=phi_observed, + tolerance=epsilon, ) cp = CausalProblem( two_normal_graph(cov=1.0), From 28f07727f6718fc1410b4ce9c7ad4de900e65b28 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Tue, 19 Aug 2025 11:11:57 +0100 Subject: [PATCH 20/33] Remove todo note --- src/causalprog/causal_problem/causal_estimand.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/src/causalprog/causal_problem/causal_estimand.py b/src/causalprog/causal_problem/causal_estimand.py index 5d20fa5..239b927 100644 --- a/src/causalprog/causal_problem/causal_estimand.py +++ b/src/causalprog/causal_problem/causal_estimand.py @@ -166,14 +166,3 @@ def __call__(self, samples: dict[str, npt.ArrayLike]) -> npt.ArrayLike: self._outer_norm(self._do_with_samples(**samples) - self.data) - self.tolerance ) - - # TODO: (https://github.com/UCL/causalprog/issues/89) - # Should explain that Constraint needs more inputs and slightly different - # interpretation of the `do_with_samples` object. - # Inputs: - # - include epsilon as an input (allows constraints to have different tolerances) - # - `do_with_samples` should just be $g(\theta)$. 
Then have the instance build the - # full constraint that will need to be called in the Lagrangian. - # - $g$ still needs to be scalar valued? Allow a wrapper function to be applied in - # the event $g$ is vector-valued. - # If we do this, will also need to override __call__... From 3366e94d879c54d8d9f65222a31b60a3bdd163ec Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Tue, 19 Aug 2025 12:33:16 +0100 Subject: [PATCH 21/33] Create wrapper class to make handlers easier. __eq__ placeholder for now --- .../causal_problem/causal_estimand.py | 44 ++++++++++++++++--- .../causal_problem/causal_problem.py | 10 +++++ tests/test_causal_problem/test_cpcomponent.py | 8 +++- 3 files changed, 54 insertions(+), 8 deletions(-) diff --git a/src/causalprog/causal_problem/causal_estimand.py b/src/causalprog/causal_problem/causal_estimand.py index 460c776..c4a66f1 100644 --- a/src/causalprog/causal_problem/causal_estimand.py +++ b/src/causalprog/causal_problem/causal_estimand.py @@ -1,13 +1,42 @@ """Classes for defining causal estimands and constraints of causal problems.""" from collections.abc import Callable +from dataclasses import dataclass, field from typing import Any, Concatenate, TypeAlias import numpy.typing as npt Model: TypeAlias = Callable[..., Any] EffectHandler: TypeAlias = Callable[Concatenate[Model, ...], Model] -ModelMask: TypeAlias = tuple[EffectHandler, dict] + + +@dataclass +class HandlerToApply: + """ """ + + handler: EffectHandler + options: dict[str, Any] = field(default_factory=dict) + + @classmethod + def from_pair(cls, pair: tuple[EffectHandler, dict]) -> "HandlerToApply": + """TODO: make pair just any lenght-2 container, and auto-identify which itme is the options and which item is the callable""" + return cls(handler=pair[0], options=pair[1]) + + def __post_init__(self) -> None: + if not callable(self.handler): + msg = f"{self.handler} is not callable!" + raise TypeError(msg) + if not isinstance(self.options, dict): + msg = f"{self.options} should be keyword-argument mapping." + raise TypeError(msg) + + def __eq__(self, other: object) -> bool: + """ """ + return ( + isinstance(other, HandlerToApply) + and self.handler is other.handler + and self.options == other.options + ) class _CPComponent: @@ -29,7 +58,7 @@ class _CPComponent: """ do_with_samples: Callable[..., npt.ArrayLike] - effect_handlers: tuple[ModelMask, ...] + effect_handlers: tuple[HandlerToApply, ...] 
@property def requires_model_adaption(self) -> bool: @@ -51,17 +80,20 @@ def __call__(self, samples: dict[str, npt.ArrayLike]) -> npt.ArrayLike: def __init__( self, - *effect_handlers: ModelMask, + *effect_handlers: HandlerToApply | tuple[EffectHandler, dict[str, Any]], do_with_samples: Callable[..., npt.ArrayLike], ) -> None: - self._effect_handlers = tuple(effect_handlers) + self._effect_handlers = tuple( + h if isinstance(h, HandlerToApply) else HandlerToApply.from_pair(h) + for h in effect_handlers + ) self._do_with_samples = do_with_samples def apply_effects(self, model: Model) -> Model: """Apply any necessary effect handlers prior to evaluating.""" adapted_model = model - for handler, handler_options in self._effect_handlers: - adapted_model = handler(adapted_model, handler_options) + for handler in self._effect_handlers: + adapted_model = handler.handler(adapted_model, handler.options) return adapted_model diff --git a/src/causalprog/causal_problem/causal_problem.py b/src/causalprog/causal_problem/causal_problem.py index b0fd20e..c08587c 100644 --- a/src/causalprog/causal_problem/causal_problem.py +++ b/src/causalprog/causal_problem/causal_problem.py @@ -89,6 +89,16 @@ def lagrangian( """ maximisation_prefactor = -1.0 if maximum_problem else 1.0 + # Build association between self.constraints and the model-samples that each + # one needs to use. + # self._associate_constraints_to_models() + # + # self._models : store of Predictives that have had the apply_effects done to + # them + # self._constraint_index_to_model_index : maps index of item in self.constraints + # to corresponding model in self._models + # These could just be objects too, don't need to be class attributes. + def _inner( parameter_values: dict[str, npt.ArrayLike], l_mult: jax.Array, diff --git a/tests/test_causal_problem/test_cpcomponent.py b/tests/test_causal_problem/test_cpcomponent.py index a331d10..1fe91fe 100644 --- a/tests/test_causal_problem/test_cpcomponent.py +++ b/tests/test_causal_problem/test_cpcomponent.py @@ -5,7 +5,11 @@ import pytest from numpyro.handlers import condition, do -from causalprog.causal_problem.causal_estimand import Model, ModelMask, _CPComponent +from causalprog.causal_problem.causal_estimand import ( + HandlerToApply, + Model, + _CPComponent, +) from causalprog.graph import Graph @@ -150,7 +154,7 @@ def _inner(**two_normal_graph_options: float) -> Callable[..., None]: ], ) def test_apply_handlers( - handlers: tuple[ModelMask], + handlers: tuple[HandlerToApply], expected_model: Model, two_normal_graph: Callable[..., Graph], request: pytest.FixtureRequest, From fe7361460982bd7ed86519bb4cfc6917beccc266 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Tue, 19 Aug 2025 13:43:32 +0100 Subject: [PATCH 22/33] Tidy eq docstring --- .../causal_problem/causal_estimand.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/src/causalprog/causal_problem/causal_estimand.py b/src/causalprog/causal_problem/causal_estimand.py index c4a66f1..db6e170 100644 --- a/src/causalprog/causal_problem/causal_estimand.py +++ b/src/causalprog/causal_problem/causal_estimand.py @@ -12,14 +12,18 @@ @dataclass class HandlerToApply: - """ """ + """Specifies a handler than needs to be applied to a model at runtime.""" handler: EffectHandler options: dict[str, Any] = field(default_factory=dict) @classmethod def from_pair(cls, pair: tuple[EffectHandler, dict]) -> "HandlerToApply": - """TODO: make pair just any lenght-2 container, and auto-identify which itme is the 
options and which item is the callable""" + """ + TODO: make pair just any lenght-2 container. + + and auto-identify which time is the options and which item is the callable + """ return cls(handler=pair[0], options=pair[1]) def __post_init__(self) -> None: @@ -31,7 +35,14 @@ def __post_init__(self) -> None: raise TypeError(msg) def __eq__(self, other: object) -> bool: - """ """ + """ + Equality operation. + + `HandlerToApply`s are considered equal if they use the same handler function and + provide the same options to this function. + + Comparison to other types returns `False`. + """ return ( isinstance(other, HandlerToApply) and self.handler is other.handler From 1e396e228bad42e373bac5f96420c7795c194429 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Tue, 19 Aug 2025 13:50:53 +0100 Subject: [PATCH 23/33] Write necessary comparison methods --- .../causal_problem/causal_estimand.py | 20 +++++++++++++++++++ .../test_compare_handlers.py | 3 +++ 2 files changed, 23 insertions(+) create mode 100644 tests/test_causal_problem/test_compare_handlers.py diff --git a/src/causalprog/causal_problem/causal_estimand.py b/src/causalprog/causal_problem/causal_estimand.py index db6e170..cd64a10 100644 --- a/src/causalprog/causal_problem/causal_estimand.py +++ b/src/causalprog/causal_problem/causal_estimand.py @@ -107,6 +107,26 @@ def apply_effects(self, model: Model) -> Model: adapted_model = handler.handler(adapted_model, handler.options) return adapted_model + def can_use_same_model_as(self, other: "_CPComponent") -> bool: + """ + Determine if two components use the same (predictive) model. + + Two components rely on the same model if they apply the same handlers + to the model, which occurs if and only if `self.effect_handlers` and + `other.effect_handlers` contain identical entries, in the same order. 
+ """ + if (not isinstance(other, _CPComponent)) or ( + len(self.effect_handlers) != len(other.effect_handlers) + ): + return False + + return all( + my_handler == their_handler + for my_handler, their_handler in zip( + self.effect_handlers, other.effect_handlers, strict=True + ) + ) + class CausalEstimand(_CPComponent): """ diff --git a/tests/test_causal_problem/test_compare_handlers.py b/tests/test_causal_problem/test_compare_handlers.py new file mode 100644 index 0000000..667e29a --- /dev/null +++ b/tests/test_causal_problem/test_compare_handlers.py @@ -0,0 +1,3 @@ +# Write test for _CPComponent.can_use_same_model + +# Write tests for HandlerToApply class From 440873b6c05d05b08ff8b3822de952e83dc4b327 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Wed, 20 Aug 2025 09:07:54 +0100 Subject: [PATCH 24/33] Write model association method and implement in Lagrangian --- .../causal_problem/causal_problem.py | 118 +++++++++++++++--- 1 file changed, 98 insertions(+), 20 deletions(-) diff --git a/src/causalprog/causal_problem/causal_problem.py b/src/causalprog/causal_problem/causal_problem.py index c08587c..b73e1b7 100644 --- a/src/causalprog/causal_problem/causal_problem.py +++ b/src/causalprog/causal_problem/causal_problem.py @@ -6,7 +6,11 @@ import numpy.typing as npt from numpyro.infer import Predictive -from causalprog.causal_problem.causal_estimand import CausalEstimand, Constraint +from causalprog.causal_problem.causal_estimand import ( + CausalEstimand, + Constraint, + _CPComponent, +) from causalprog.graph import Graph @@ -40,6 +44,11 @@ class CausalProblem: causal_estimand: CausalEstimand constraints: list[Constraint] + @property + def _ordered_components(self) -> list[_CPComponent]: + """Internal ordering for components of the causal problem.""" + return [*self.constraints, self.causal_estimand] + def __init__( self, graph: Graph, @@ -51,6 +60,83 @@ def __init__( self.causal_estimand = causal_estimand self.constraints = list(constraints) + def _associate_models_to_components( + self, n_samples: int + ) -> tuple[list[Predictive], list[int]]: + """ + Create models to be used by components of the problem. + + Depending on how many constraints (and the causal estimand) require effect + handlers to wrap `self._underlying_graph.model`, we will need to create several + predictive models to sample from. However, we also want to minimise the number + of such models we have to make, in order to minimise the time we spend + actually computing samples. + + As such, in this method we determine: + - How many models we will need to build, by grouping the constraints and the + causal estimand by the handlers they use. + - Build these models, returning them in a list called `models`. + - Build another list that maps the index of components in + `self._ordered_components` to the index of the model in `models` that they + use. The causal estimand is by convention the component at index -1 of this + returned list. + + Args: + n_samples: Value to be passed to `numpyro.Predictive`'s `num_samples` + argument for each of the models that are constructed from the underlying + graph. + + Returns: + list[Predictive]: List of Predictive models, whose elements contain all the + models needed by the components. + list[int]: Mapping of component indexes (as per `self_ordered_components`) + to the index of the model in the first return argument that the + component uses. 
+ + """ + models: list[Predictive] = [] + grouped_component_indexes: list[list[int]] = [] + for index, component in enumerate(self._ordered_components): + # Determine if this constraint uses the same handlers as those of any of + # the other sets. + belongs_to_existing_group = False + for group in grouped_component_indexes: + # Pull any element from the group to compare models to. + # Items in a group are known to have the same model, so we can just + # pull out the first one. + group_element = self._ordered_components[group[0]] + # Check if the current constraint can also use this model. + if component.can_use_same_model_as(group_element): + group.append(index) + belongs_to_existing_group = True + break + + # If the component does not fit into any existing group, create a new + # group for it. And add the model corresponding to the group to the + # list of models. + if not belongs_to_existing_group: + grouped_component_indexes.append([index]) + + models.append( + Predictive( + component.apply_effects(self._underlying_graph.model), + num_samples=n_samples, + ) + ) + + # Now "invert" the grouping, creating a mapping that maps the index of a + # component to the (index of the) model it uses. + component_index_to_model_index = [] + for index in range(len(self._ordered_components)): + for group_index, group in enumerate(grouped_component_indexes): + if index in group: + component_index_to_model_index.append(group_index) + break + # All indexes should belong to at least one group (worst case scenario, + # their own individual group). Thus, it is safe to do the above to create + # the mapping from component index -> model (group) index. + return models, component_index_to_model_index + def lagrangian( self, n_samples: int = 1000, *, maximum_problem: bool = False ) -> Callable[[dict[str, npt.ArrayLike], npt.ArrayLike, jax.Array], npt.ArrayLike]: @@ -90,35 +176,27 @@ def lagrangian( maximisation_prefactor = -1.0 if maximum_problem else 1.0 # Build association between self.constraints and the model-samples that each - # one needs to use. - # self._associate_constraints_to_models() - # - # self._models : store of Predictives that have had the apply_effects done to - # them - # self._constraint_index_to_model_index : maps index of item in self.constraints - # to corresponding model in self._models - # These could just be objects too, don't need to be class attributes. + # one needs to use. We do this here, since once it is constructed, it is + # FIXED, and doesn't need to be done each time we call the Lagrangian. + models, component_to_index_mapping = self._associate_models_to_components( + n_samples + ) def _inner( parameter_values: dict[str, npt.ArrayLike], l_mult: jax.Array, rng_key: jax.Array, ) -> npt.ArrayLike: - # In general, we will need to check which of our CE/CONs require masking, - # and do multiple predictive models to account for this... - # We can always pre-build the predictive models too, so we should replace - # the "model" input with something that can map the right predictive models - # to the CE/CONS that need them. 
- # TODO: https://github.com/UCL/causalprog/issues/90 - predictive_model = Predictive( - model=self._underlying_graph.model, num_samples=n_samples + # Draw samples from all models + all_samples = tuple( + sample_model(model, rng_key, parameter_values) for model in models ) - all_samples = sample_model(predictive_model, rng_key, parameter_values) - value = maximisation_prefactor * self.causal_estimand(all_samples) + value = maximisation_prefactor * self.causal_estimand(all_samples[-1]) # TODO: https://github.com/UCL/causalprog/issues/87 value += sum( - l_mult[i] * c(all_samples) for i, c in enumerate(self.constraints) + l_mult[i] * c(all_samples[component_to_index_mapping[i]]) + for i, c in enumerate(self.constraints) ) return value From 12d2b71fd5219d95a4b379f832fb26f700fd2a20 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Wed, 20 Aug 2025 09:27:41 +0100 Subject: [PATCH 25/33] Docstirngs and breakout HandlerToApply class to submodule --- .../causal_problem/causal_estimand.py | 46 +---------- src/causalprog/causal_problem/handlers.py | 82 +++++++++++++++++++ 2 files changed, 84 insertions(+), 44 deletions(-) create mode 100644 src/causalprog/causal_problem/handlers.py diff --git a/src/causalprog/causal_problem/causal_estimand.py b/src/causalprog/causal_problem/causal_estimand.py index cd64a10..80b37d2 100644 --- a/src/causalprog/causal_problem/causal_estimand.py +++ b/src/causalprog/causal_problem/causal_estimand.py @@ -1,53 +1,11 @@ """Classes for defining causal estimands and constraints of causal problems.""" from collections.abc import Callable -from dataclasses import dataclass, field -from typing import Any, Concatenate, TypeAlias +from typing import Any import numpy.typing as npt -Model: TypeAlias = Callable[..., Any] -EffectHandler: TypeAlias = Callable[Concatenate[Model, ...], Model] - - -@dataclass -class HandlerToApply: - """Specifies a handler than needs to be applied to a model at runtime.""" - - handler: EffectHandler - options: dict[str, Any] = field(default_factory=dict) - - @classmethod - def from_pair(cls, pair: tuple[EffectHandler, dict]) -> "HandlerToApply": - """ - TODO: make pair just any lenght-2 container. - - and auto-identify which time is the options and which item is the callable - """ - return cls(handler=pair[0], options=pair[1]) - - def __post_init__(self) -> None: - if not callable(self.handler): - msg = f"{self.handler} is not callable!" - raise TypeError(msg) - if not isinstance(self.options, dict): - msg = f"{self.options} should be keyword-argument mapping." - raise TypeError(msg) - - def __eq__(self, other: object) -> bool: - """ - Equality operation. - - `HandlerToApply`s are considered equal if they use the same handler function and - provide the same options to this function. - - Comparison to other types returns `False`. 
- """ - return ( - isinstance(other, HandlerToApply) - and self.handler is other.handler - and self.options == other.options - ) +from causalprog.causal_problem.handlers import EffectHandler, HandlerToApply, Model class _CPComponent: diff --git a/src/causalprog/causal_problem/handlers.py b/src/causalprog/causal_problem/handlers.py new file mode 100644 index 0000000..41ce206 --- /dev/null +++ b/src/causalprog/causal_problem/handlers.py @@ -0,0 +1,82 @@ +"""Container class for specifying effect handlers that need to be applied at runtime.""" + +from collections.abc import Callable, Sequence +from dataclasses import dataclass, field +from typing import Any, Concatenate, TypeAlias + +Model: TypeAlias = Callable[..., None] +EffectHandler: TypeAlias = Callable[Concatenate[Model, ...], Model] + + +@dataclass +class HandlerToApply: + """Specifies a handler that needs to be applied to a model at runtime.""" + + handler: EffectHandler + options: dict[str, Any] = field(default_factory=dict) + + @classmethod + def from_pair(cls, pair: Sequence) -> "HandlerToApply": + """ + Create an instance from an effect handler and its options. + + The two objects should be passed in as the elements of a container of length + 2. They can be passed in any order; + - One element must be a dictionary, which will be interpreted as the `options` + for the effect handler. + - The other element must be callable, and will be interpreted as the `handler` + itself. + + Args: + pair: Container of two elements, one being the effect handler callable and + the other being the options to pass to it (as a dictionary). + + Returns: + Class instance corresponding to the effect handler and options passed. + + """ + if len(pair) != 2: # noqa: PLR2004 + msg = ( + f"{cls.__name__} can only be constructed from a container of 2 elements" + ) + raise ValueError(msg) + + # __post_init__ will catch cases when the incorrect types for one or both items + # is passed, so we can just naively if-else here. + if callable(pair[0]): + handler: EffectHandler = pair[0] + options: dict = pair[1] + else: + handler = pair[1] + options = pair[0] + + return cls(handler=handler, options=options) + + def __post_init__(self) -> None: + """ + Validate set attributes. + + - The handler is a callable object. + - The options have been passed as a dictionary of keyword-value pairs. + """ + if not callable(self.handler): + msg = f"{self.handler} is not callable!" + raise TypeError(msg) + if not isinstance(self.options, dict): + msg = f"{self.options} should be keyword-argument mapping." + raise TypeError(msg) + + def __eq__(self, other: object) -> bool: + """ + Equality operation. + + `HandlerToApply`s are considered equal if they use the same handler function and + provide the same options to this function. + + Comparison to other types returns `False`. 
+ """ + return ( + isinstance(other, HandlerToApply) + and self.handler is other.handler + and self.options == other.options + ) From cc3d2a05ff872d45f966982eb31c390aba206263 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Wed, 20 Aug 2025 09:28:05 +0100 Subject: [PATCH 26/33] Docstirngs and breakout HandlerToApply class to submodule --- src/causalprog/causal_problem/handlers.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/causalprog/causal_problem/handlers.py b/src/causalprog/causal_problem/handlers.py index 41ce206..4903c5d 100644 --- a/src/causalprog/causal_problem/handlers.py +++ b/src/causalprog/causal_problem/handlers.py @@ -43,9 +43,11 @@ def from_pair(cls, pair: Sequence) -> "HandlerToApply": # __post_init__ will catch cases when the incorrect types for one or both items # is passed, so we can just naively if-else here. + handler: EffectHandler + options: dict if callable(pair[0]): - handler: EffectHandler = pair[0] - options: dict = pair[1] + handler = pair[0] + options = pair[1] else: handler = pair[1] options = pair[0] From 65ab7815b0bfc4d3edf96302decd6d1d24703f4c Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Wed, 20 Aug 2025 09:54:10 +0100 Subject: [PATCH 27/33] Tests for HandlerToApply class --- .../causal_problem/causal_problem.py | 1 + src/causalprog/causal_problem/handlers.py | 7 +- .../test_compare_handlers.py | 3 - tests/test_causal_problem/test_handlers.py | 100 ++++++++++++++++++ 4 files changed, 106 insertions(+), 5 deletions(-) delete mode 100644 tests/test_causal_problem/test_compare_handlers.py create mode 100644 tests/test_causal_problem/test_handlers.py diff --git a/src/causalprog/causal_problem/causal_problem.py b/src/causalprog/causal_problem/causal_problem.py index b73e1b7..fe80ac2 100644 --- a/src/causalprog/causal_problem/causal_problem.py +++ b/src/causalprog/causal_problem/causal_problem.py @@ -6,6 +6,7 @@ import numpy.typing as npt from numpyro.infer import Predictive +# TODO: Rename module to components from causalprog.causal_problem.causal_estimand import ( CausalEstimand, Constraint, diff --git a/src/causalprog/causal_problem/handlers.py b/src/causalprog/causal_problem/handlers.py index 4903c5d..f9958ff 100644 --- a/src/causalprog/causal_problem/handlers.py +++ b/src/causalprog/causal_problem/handlers.py @@ -62,10 +62,13 @@ def __post_init__(self) -> None: - The options have been passed as a dictionary of keyword-value pairs. """ if not callable(self.handler): - msg = f"{self.handler} is not callable!" + msg = f"{type(self.handler).__name__} is not callable." raise TypeError(msg) if not isinstance(self.options, dict): - msg = f"{self.options} should be keyword-argument mapping." + msg = ( + "Options should be dictionary mapping option arguments to values " + f"(got {type(self.options).__name__})." 
+ ) raise TypeError(msg) def __eq__(self, other: object) -> bool: diff --git a/tests/test_causal_problem/test_compare_handlers.py b/tests/test_causal_problem/test_compare_handlers.py deleted file mode 100644 index 667e29a..0000000 --- a/tests/test_causal_problem/test_compare_handlers.py +++ /dev/null @@ -1,3 +0,0 @@ -# Write test for _CPComponent.can_use_same_model - -# Write tests for HandlerToApply class diff --git a/tests/test_causal_problem/test_handlers.py b/tests/test_causal_problem/test_handlers.py new file mode 100644 index 0000000..8f428f4 --- /dev/null +++ b/tests/test_causal_problem/test_handlers.py @@ -0,0 +1,100 @@ +import pytest + +from causalprog.causal_problem.handlers import EffectHandler, HandlerToApply + + +def placeholder_callable() -> EffectHandler: + """Stand-in for an effect handler.""" + return lambda model, **kwargs: (lambda **pv: model(**kwargs)) + + +@pytest.mark.parametrize( + ( + "args", + "expected_error", + "use_classmethod", + ), + [ + pytest.param( + (placeholder_callable, {}), None, False, id="Standard construction" + ), + pytest.param((placeholder_callable, {}), None, True, id="Via from_pair"), + pytest.param( + ({}, placeholder_callable), None, True, id="Via from_pair (out of order)" + ), + pytest.param( + (placeholder_callable, []), + TypeError( + "Options should be dictionary mapping option arguments to values " + "(got list)." + ), + False, + id="Wrong options type", + ), + pytest.param( + (1.0, {}), + TypeError("float is not callable."), + False, + id="Handler is not callable", + ), + pytest.param( + (0, 0, 0), + ValueError( + "HandlerToApply can only be constructed from a container of 2 elements" + ), + True, + id="Tuple too long", + ), + ], +) +def test_handlertoapply_creation( + args: tuple[EffectHandler | dict, ...], + expected_error: Exception | None, + raises_context, + *, + use_classmethod: bool, +): + if isinstance(expected_error, Exception): + if use_classmethod: + with raises_context(expected_error): + HandlerToApply.from_pair(args) + else: + with raises_context(expected_error): + HandlerToApply(*args) + else: + handler = ( + HandlerToApply.from_pair(args) if use_classmethod else HandlerToApply(*args) + ) + + assert isinstance(handler.options, dict) + assert callable(handler.handler) + + +@pytest.mark.parametrize( + ("left", "right", "expected_result"), + [ + pytest.param( + HandlerToApply(placeholder_callable, {"option1": 1.0}), + HandlerToApply(placeholder_callable, {"option1": 1.0}), + True, + id="Identical", + ), + pytest.param( + HandlerToApply(lambda **pv: None, {"option1": 1.0}), + HandlerToApply(lambda **pv: None, {"option1": 1.0}), + False, + id="callables compared using IS", + ), + pytest.param( + HandlerToApply(placeholder_callable, {"option1": 1.0}), + HandlerToApply(placeholder_callable, {"option2": 1.0}), + False, + id="Options must match", + ), + ], +) +def test_handlertoapply_equality( + left: object, right: object, *, expected_result: bool +) -> None: + assert (left == right) == expected_result + assert (left == right) == (right == left) From 9335477b0b3a79b76e104cb1d0941bdff65b3d7e Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Wed, 20 Aug 2025 10:18:03 +0100 Subject: [PATCH 28/33] Tests for can_use_same_model --- .../causal_problem/causal_estimand.py | 6 +- .../test_cpcomponent_same_model.py | 107 ++++++++++++++++++ tests/test_causal_problem/test_handlers.py | 6 + 3 files changed, 116 insertions(+), 3 deletions(-) create mode 100644 tests/test_causal_problem/test_cpcomponent_same_model.py diff 
--git a/src/causalprog/causal_problem/causal_estimand.py b/src/causalprog/causal_problem/causal_estimand.py index 80b37d2..8f4b1d1 100644 --- a/src/causalprog/causal_problem/causal_estimand.py +++ b/src/causalprog/causal_problem/causal_estimand.py @@ -32,7 +32,7 @@ class _CPComponent: @property def requires_model_adaption(self) -> bool: """Return True if effect handlers need to be applied to model.""" - return len(self._effect_handlers) > 0 + return len(self.effect_handlers) > 0 def __call__(self, samples: dict[str, npt.ArrayLike]) -> npt.ArrayLike: """ @@ -52,7 +52,7 @@ def __init__( *effect_handlers: HandlerToApply | tuple[EffectHandler, dict[str, Any]], do_with_samples: Callable[..., npt.ArrayLike], ) -> None: - self._effect_handlers = tuple( + self.effect_handlers = tuple( h if isinstance(h, HandlerToApply) else HandlerToApply.from_pair(h) for h in effect_handlers ) @@ -61,7 +61,7 @@ def __init__( def apply_effects(self, model: Model) -> Model: """Apply any necessary effect handlers prior to evaluating.""" adapted_model = model - for handler in self._effect_handlers: + for handler in self.effect_handlers: adapted_model = handler.handler(adapted_model, handler.options) return adapted_model diff --git a/tests/test_causal_problem/test_cpcomponent_same_model.py b/tests/test_causal_problem/test_cpcomponent_same_model.py new file mode 100644 index 0000000..e1c0af5 --- /dev/null +++ b/tests/test_causal_problem/test_cpcomponent_same_model.py @@ -0,0 +1,107 @@ +"""Tests for _CPComponent.can_use_same_model. + +- This method should be symmetric in its arguments. +- Returns False when one of the arguments is not a _CPComponent instance. +- CausalEstimands and Constraints should still be able to share models. +- Models can be shared IFF the same handlers, in the same order, are applied. +""" + +import pytest + +from causalprog.causal_problem.causal_estimand import ( + CausalEstimand, + Constraint, + _CPComponent, +) +from causalprog.causal_problem.handlers import HandlerToApply + + +# HandlerToApply compares the handler argument with IS, so we need to instantiate here. 
+def handler_a(**pv) -> None: + return + + +def handler_b(**pv) -> None: + return + + +@pytest.mark.parametrize( + ("component_1", "component_2", "expected_result"), + [ + pytest.param( + _CPComponent( + HandlerToApply(handler_a, {}), do_with_samples=lambda **pv: None + ), + _CPComponent( + HandlerToApply(handler_a, {}), do_with_samples=lambda **pv: None + ), + True, + id="Same model as self", + ), + pytest.param( + _CPComponent(do_with_samples=lambda **pv: None), + _CPComponent(do_with_samples=lambda **pv: None), + True, + id="No effect handlers case is handled", + ), + pytest.param( + _CPComponent( + HandlerToApply(handler_a, {}), do_with_samples=lambda **pv: 1.0 + ), + _CPComponent( + HandlerToApply(handler_a, {}), do_with_samples=lambda **pv: 2.0 + ), + True, + id="_do_with_samples does not affect model compatibility", + ), + pytest.param( + CausalEstimand( + HandlerToApply(handler_a, {}), do_with_samples=lambda **pv: None + ), + Constraint( + HandlerToApply(handler_a, {}), do_with_samples=lambda **pv: None + ), + True, + id="CausalEstimand and Constraints can share models", + ), + pytest.param( + _CPComponent( + HandlerToApply(handler_a, {"option": "a"}), + do_with_samples=lambda **pv: None, + ), + _CPComponent( + HandlerToApply(handler_a, {"option": "b"}), + do_with_samples=lambda **pv: None, + ), + False, + id="Different handlers deny same model", + ), + pytest.param( + _CPComponent( + HandlerToApply(handler_a, {"option": "a"}), + HandlerToApply(handler_a, {"option": "b"}), + do_with_samples=lambda **pv: None, + ), + _CPComponent( + HandlerToApply(handler_a, {"option": "b"}), + HandlerToApply(handler_a, {"option": "a"}), + do_with_samples=lambda **pv: None, + ), + False, + id="Different handler order denies same model", + ), + pytest.param( + _CPComponent(do_with_samples=lambda **pv: None), + 1.0, + False, + id="Compare to non-_CPComponent", + ), + ], +) +def test_can_use_same_model( + component_1: _CPComponent, component_2: _CPComponent, *, expected_result: bool +) -> None: + if isinstance(component_1, _CPComponent): + assert component_1.can_use_same_model_as(component_2) == expected_result + if isinstance(component_2, _CPComponent): + assert component_2.can_use_same_model_as(component_1) == expected_result diff --git a/tests/test_causal_problem/test_handlers.py b/tests/test_causal_problem/test_handlers.py index 8f428f4..399a63c 100644 --- a/tests/test_causal_problem/test_handlers.py +++ b/tests/test_causal_problem/test_handlers.py @@ -91,6 +91,12 @@ def test_handlertoapply_creation( False, id="Options must match", ), + pytest.param( + HandlerToApply(placeholder_callable, {"option1": 1.0}), + 1.0, + False, + id="Comparison to different object class", + ), ], ) def test_handlertoapply_equality( From 8a4483e79434ed16c72e0b20f7c371a55d0ff487 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Wed, 20 Aug 2025 11:09:07 +0100 Subject: [PATCH 29/33] Tests for associating models to components of the CP --- .../test_associate_models_to_components.py | 112 ++++++++++++++++++ 1 file changed, 112 insertions(+) create mode 100644 tests/test_causal_problem/test_associate_models_to_components.py diff --git a/tests/test_causal_problem/test_associate_models_to_components.py b/tests/test_causal_problem/test_associate_models_to_components.py new file mode 100644 index 0000000..0aa76a7 --- /dev/null +++ b/tests/test_causal_problem/test_associate_models_to_components.py @@ -0,0 +1,112 @@ +"""Test the association of models to the components of the CausalProblem. 
+
+One model is needed for each unique combination of handlers that the Constraints and
+CausalEstimand possess. We can mimic this behaviour by defining a single handler, and
+then passing different numbers of copies of this handler into our Constraints and
+CausalEstimand. Different numbers of handlers force different models, and thus we should
+end up with one model for each unique number of copies that we use.
+
+Components are examined in the `_ordered_components` order, which goes through the
+`constraints` list first, and then the CausalEstimand at the end. As such, models are
+also "created" in this order.
+"""
+
+from collections.abc import Callable
+
+import pytest
+
+from causalprog.causal_problem.causal_estimand import CausalEstimand, Constraint
+from causalprog.causal_problem.causal_problem import CausalProblem
+from causalprog.causal_problem.handlers import HandlerToApply
+from causalprog.graph import Graph, ParameterNode
+
+
+@pytest.fixture
+def underlying_graph() -> Graph:
+    """The underlying graph is not actually important for checking model association,
+    so we just return a single node graph.
+    """
+    g = Graph(label="Placeholder")
+    g.add_node(ParameterNode(label="p"))
+    return g
+
+
+def placeholder_handler_fn(*args, **kwargs) -> None:
+    """IS comparison means this function does need to be statically defined.
+
+    args[0] is the model input, so we just effectively return the same model.
+    """
+    return args[0]
+
+
+@pytest.fixture
+def placeholder_handler() -> Callable[[], HandlerToApply]:
+    """Creates a HandlerToApply instance.
+
+    We will use copies of the returned instance to "trick" the CausalProblem
+    class into creating additional models due to different numbers of handlers
+    being applied to its components.
+    """
+
+    def _inner() -> HandlerToApply:
+        return HandlerToApply(placeholder_handler_fn, {})
+
+    return _inner
+
+
+@pytest.mark.parametrize(
+    (
+        "handlers_to_give_to_constraints",
+        "handlers_to_give_estimand",
+        "expected_components_to_models_mapping",
+    ),
+    [
+        # There are no constraints, so the CausalEstimand uses the only model created.
+        pytest.param([], 0, [0], id="No constraints"),
+        # All constraints and the causal estimand use the same model.
+        pytest.param([0] * 3, 0, [0] * 4, id="Same model used by all"),
+        # 1st constraint: Model w/ 1 handler created, taking model index 0.
+        # 2nd constraint: re-uses model index 0 (with 1 handler).
+        # CE: Model w/ 0 handlers created, taking model index 1.
+        pytest.param([1] * 2, 0, [0, 0, 1], id="CausalEstimand is always last model"),
+        # 1st constraint: Model w/ 1 handler created, taking model index 0.
+        # 2nd constraint: Model w/ 0 handlers created, taking model index 1.
+        # 3rd constraint: re-uses model index 0 (with 1 handler).
+        # 4th constraint: Model w/ 2 handlers created, taking model index 2.
+        # CE: re-uses model index 0 (with 1 handler).
+        pytest.param(
+            [1, 0, 1, 2], 1, [0, 1, 0, 2, 0], id="Models created in particular order"
+        ),
+    ],
+)
+def test_associate_models_to_components(
+    handlers_to_give_to_constraints: list[int],
+    handlers_to_give_estimand: int,
+    expected_components_to_models_mapping: list[int],
+    placeholder_handler: Callable[[], HandlerToApply],
+    underlying_graph: Graph,
+    n_samples: int = 1,
+) -> None:
+    # The number of models is the number of 'unique numbers of handlers' given to the
+    # constraints and causal estimand.
+ expected_number_of_models = len( + set(handlers_to_give_to_constraints).union({handlers_to_give_estimand}) + ) + + constraints = [ + Constraint( + *(placeholder_handler() for _ in range(copies)), + do_with_samples=lambda **pv: None, + ) + for copies in handlers_to_give_to_constraints + ] + ce = CausalEstimand( + *(placeholder_handler() for _ in range(handlers_to_give_estimand)), + do_with_samples=lambda **pv: None, + ) + + cp = CausalProblem(underlying_graph, *constraints, causal_estimand=ce) + models, association = cp._associate_models_to_components(n_samples) # noqa: SLF001 + + assert len(models) == expected_number_of_models + assert association == expected_components_to_models_mapping From 821072e2d5ddf2efe21559cf0ff933984a1d7051 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Wed, 20 Aug 2025 11:26:07 +0100 Subject: [PATCH 30/33] Reorganise classes now that structure is somewhat rigid --- ...{causal_estimand.py => _base_component.py} | 43 +----------------- .../causal_problem/causal_problem.py | 5 +-- src/causalprog/causal_problem/components.py | 44 +++++++++++++++++++ .../test_associate_models_to_components.py | 2 +- tests/test_causal_problem/test_cpcomponent.py | 7 +-- .../test_cpcomponent_same_model.py | 2 +- .../test_two_normal_example.py | 2 +- 7 files changed, 52 insertions(+), 53 deletions(-) rename src/causalprog/causal_problem/{causal_estimand.py => _base_component.py} (64%) create mode 100644 src/causalprog/causal_problem/components.py diff --git a/src/causalprog/causal_problem/causal_estimand.py b/src/causalprog/causal_problem/_base_component.py similarity index 64% rename from src/causalprog/causal_problem/causal_estimand.py rename to src/causalprog/causal_problem/_base_component.py index 8f4b1d1..4be1812 100644 --- a/src/causalprog/causal_problem/causal_estimand.py +++ b/src/causalprog/causal_problem/_base_component.py @@ -1,4 +1,4 @@ -"""Classes for defining causal estimands and constraints of causal problems.""" +"""Base class for components of causal problems.""" from collections.abc import Callable from typing import Any @@ -84,44 +84,3 @@ def can_use_same_model_as(self, other: "_CPComponent") -> bool: self.effect_handlers, other.effect_handlers, strict=True ) ) - - -class CausalEstimand(_CPComponent): - """ - A Causal Estimand. - - The causal estimand is the function that we want to minimise (and maximise) - as part of a causal problem. It should be a scalar-valued function of the - random variables appearing in a graph. - """ - - -class Constraint(_CPComponent): - r""" - A Constraint that forms part of a causal problem. - - Constraints of a causal problem are derived properties of RVs for which we - have observed data. The causal estimand is minimised (or maximised) subject - to the predicted values of the constraints being close to their observed - values in the data. - - Adding a constraint $g(\theta)$ to a causal problem (where $\theta$ are the - parameters of the causal problem) essentially imposes an additional - constraint on the minimisation problem; - - $$ g(\theta) - g_{\text{data}} \leq \epsilon, $$ - - where $g_{\text{data}}$ is the observed data value for the quantity $g$, - and $\epsilon$ is some tolerance. - """ - - # TODO: (https://github.com/UCL/causalprog/issues/89) - # Should explain that Constraint needs more inputs and slightly different - # interpretation of the `do_with_samples` object. 
- # Inputs: - # - include epsilon as an input (allows constraints to have different tolerances) - # - `do_with_samples` should just be $g(\theta)$. Then have the instance build the - # full constraint that will need to be called in the Lagrangian. - # - $g$ still needs to be scalar valued? Allow a wrapper function to be applied in - # the event $g$ is vector-valued. - # If we do this, will also need to override __call__... diff --git a/src/causalprog/causal_problem/causal_problem.py b/src/causalprog/causal_problem/causal_problem.py index fe80ac2..dc5e0e7 100644 --- a/src/causalprog/causal_problem/causal_problem.py +++ b/src/causalprog/causal_problem/causal_problem.py @@ -6,11 +6,10 @@ import numpy.typing as npt from numpyro.infer import Predictive -# TODO: Rename module to components -from causalprog.causal_problem.causal_estimand import ( +from causalprog.causal_problem._base_component import _CPComponent +from causalprog.causal_problem.components import ( CausalEstimand, Constraint, - _CPComponent, ) from causalprog.graph import Graph diff --git a/src/causalprog/causal_problem/components.py b/src/causalprog/causal_problem/components.py new file mode 100644 index 0000000..11a5d5c --- /dev/null +++ b/src/causalprog/causal_problem/components.py @@ -0,0 +1,44 @@ +"""Classes for defining causal estimands and constraints of causal problems.""" + +from causalprog.causal_problem._base_component import _CPComponent + + +class CausalEstimand(_CPComponent): + """ + A Causal Estimand. + + The causal estimand is the function that we want to minimise (and maximise) + as part of a causal problem. It should be a scalar-valued function of the + random variables appearing in a graph. + """ + + +class Constraint(_CPComponent): + r""" + A Constraint that forms part of a causal problem. + + Constraints of a causal problem are derived properties of RVs for which we + have observed data. The causal estimand is minimised (or maximised) subject + to the predicted values of the constraints being close to their observed + values in the data. + + Adding a constraint $g(\theta)$ to a causal problem (where $\theta$ are the + parameters of the causal problem) essentially imposes an additional + constraint on the minimisation problem; + + $$ g(\theta) - g_{\text{data}} \leq \epsilon, $$ + + where $g_{\text{data}}$ is the observed data value for the quantity $g$, + and $\epsilon$ is some tolerance. + """ + + # TODO: (https://github.com/UCL/causalprog/issues/89) + # Should explain that Constraint needs more inputs and slightly different + # interpretation of the `do_with_samples` object. + # Inputs: + # - include epsilon as an input (allows constraints to have different tolerances) + # - `do_with_samples` should just be $g(\theta)$. Then have the instance build the + # full constraint that will need to be called in the Lagrangian. + # - $g$ still needs to be scalar valued? Allow a wrapper function to be applied in + # the event $g$ is vector-valued. + # If we do this, will also need to override __call__... 
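
To make the reorganised layout concrete, the sketch below mirrors the test fixtures earlier
in this series: two constraints that share a handler also share a sampling model, while the
handler-free causal estimand triggers a second model. Only the import paths and the
constructor shapes (`*handlers, do_with_samples=...`; `CausalProblem(graph, *constraints,
causal_estimand=...)`) are taken from the diffs themselves; the identity handler, the
single-node graph and the `None`-returning `do_with_samples` bodies are illustrative
assumptions borrowed from the tests, not a documented public API.

from causalprog.causal_problem.causal_problem import CausalProblem
from causalprog.causal_problem.components import CausalEstimand, Constraint
from causalprog.causal_problem.handlers import HandlerToApply
from causalprog.graph import Graph, ParameterNode


def identity_handler_fn(*args, **kwargs):
    # Defined once at module level: HandlerToApply compares handlers by
    # identity, so every copy below must wrap the same function object.
    # args[0] is the model, which is returned unchanged.
    return args[0]


graph = Graph(label="Placeholder")
graph.add_node(ParameterNode(label="p"))

# Two constraints with one handler each, and an estimand with none.
constraints = [
    Constraint(
        HandlerToApply(identity_handler_fn, {}),
        do_with_samples=lambda **samples: None,
    )
    for _ in range(2)
]
ce = CausalEstimand(do_with_samples=lambda **samples: None)

cp = CausalProblem(graph, *constraints, causal_estimand=ce)
# Private method, called here exactly as in the test above.
models, association = cp._associate_models_to_components(1)
assert len(models) == 2  # one model with a handler, one without
assert association == [0, 0, 1]  # the CausalEstimand is associated last
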
diff --git a/tests/test_causal_problem/test_associate_models_to_components.py b/tests/test_causal_problem/test_associate_models_to_components.py index 0aa76a7..448097a 100644 --- a/tests/test_causal_problem/test_associate_models_to_components.py +++ b/tests/test_causal_problem/test_associate_models_to_components.py @@ -15,8 +15,8 @@ import pytest -from causalprog.causal_problem.causal_estimand import CausalEstimand, Constraint from causalprog.causal_problem.causal_problem import CausalProblem +from causalprog.causal_problem.components import CausalEstimand, Constraint from causalprog.causal_problem.handlers import HandlerToApply from causalprog.graph import Graph, ParameterNode diff --git a/tests/test_causal_problem/test_cpcomponent.py b/tests/test_causal_problem/test_cpcomponent.py index 1fe91fe..df70e29 100644 --- a/tests/test_causal_problem/test_cpcomponent.py +++ b/tests/test_causal_problem/test_cpcomponent.py @@ -5,11 +5,8 @@ import pytest from numpyro.handlers import condition, do -from causalprog.causal_problem.causal_estimand import ( - HandlerToApply, - Model, - _CPComponent, -) +from causalprog.causal_problem.components import _CPComponent +from causalprog.causal_problem.handlers import HandlerToApply, Model from causalprog.graph import Graph diff --git a/tests/test_causal_problem/test_cpcomponent_same_model.py b/tests/test_causal_problem/test_cpcomponent_same_model.py index e1c0af5..810ee48 100644 --- a/tests/test_causal_problem/test_cpcomponent_same_model.py +++ b/tests/test_causal_problem/test_cpcomponent_same_model.py @@ -8,7 +8,7 @@ import pytest -from causalprog.causal_problem.causal_estimand import ( +from causalprog.causal_problem.components import ( CausalEstimand, Constraint, _CPComponent, diff --git a/tests/test_integration/test_two_normal_example.py b/tests/test_integration/test_two_normal_example.py index bbae1b8..9fbb0c3 100644 --- a/tests/test_integration/test_two_normal_example.py +++ b/tests/test_integration/test_two_normal_example.py @@ -6,8 +6,8 @@ import optax import pytest -from causalprog.causal_problem.causal_estimand import CausalEstimand, Constraint from causalprog.causal_problem.causal_problem import CausalProblem +from causalprog.causal_problem.components import CausalEstimand, Constraint from causalprog.graph import Graph From 2c359eddaec5a9761805646ceed00623360e47ad Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Wed, 20 Aug 2025 11:29:12 +0100 Subject: [PATCH 31/33] Add module-level import for user-facing functions --- src/causalprog/causal_problem/__init__.py | 6 ++++++ .../test_associate_models_to_components.py | 9 ++++++--- tests/test_causal_problem/test_cpcomponent.py | 4 ++-- tests/test_causal_problem/test_cpcomponent_same_model.py | 6 +++--- tests/test_causal_problem/test_handlers.py | 3 ++- 5 files changed, 19 insertions(+), 9 deletions(-) diff --git a/src/causalprog/causal_problem/__init__.py b/src/causalprog/causal_problem/__init__.py index 3d45fe4..62cdc3d 100644 --- a/src/causalprog/causal_problem/__init__.py +++ b/src/causalprog/causal_problem/__init__.py @@ -1 +1,7 @@ """Classes for defining causal problems.""" + +from .causal_problem import CausalProblem +from .components import CausalEstimand, Constraint +from .handlers import HandlerToApply + +__all__ = ("CausalEstimand", "CausalProblem", "Constraint", "HandlerToApply") diff --git a/tests/test_causal_problem/test_associate_models_to_components.py b/tests/test_causal_problem/test_associate_models_to_components.py index 448097a..28901eb 100644 --- 
a/tests/test_causal_problem/test_associate_models_to_components.py +++ b/tests/test_causal_problem/test_associate_models_to_components.py @@ -15,9 +15,12 @@ import pytest -from causalprog.causal_problem.causal_problem import CausalProblem -from causalprog.causal_problem.components import CausalEstimand, Constraint -from causalprog.causal_problem.handlers import HandlerToApply +from causalprog.causal_problem import ( + CausalEstimand, + CausalProblem, + Constraint, + HandlerToApply, +) from causalprog.graph import Graph, ParameterNode diff --git a/tests/test_causal_problem/test_cpcomponent.py b/tests/test_causal_problem/test_cpcomponent.py index df70e29..d5d46e6 100644 --- a/tests/test_causal_problem/test_cpcomponent.py +++ b/tests/test_causal_problem/test_cpcomponent.py @@ -5,8 +5,8 @@ import pytest from numpyro.handlers import condition, do -from causalprog.causal_problem.components import _CPComponent -from causalprog.causal_problem.handlers import HandlerToApply, Model +from causalprog.causal_problem import HandlerToApply +from causalprog.causal_problem._base_component import Model, _CPComponent from causalprog.graph import Graph diff --git a/tests/test_causal_problem/test_cpcomponent_same_model.py b/tests/test_causal_problem/test_cpcomponent_same_model.py index 810ee48..dc64e70 100644 --- a/tests/test_causal_problem/test_cpcomponent_same_model.py +++ b/tests/test_causal_problem/test_cpcomponent_same_model.py @@ -8,12 +8,12 @@ import pytest -from causalprog.causal_problem.components import ( +from causalprog.causal_problem import ( CausalEstimand, Constraint, - _CPComponent, + HandlerToApply, ) -from causalprog.causal_problem.handlers import HandlerToApply +from causalprog.causal_problem._base_component import _CPComponent # HandlerToApply compares the handler argument with IS, so we need to instantiate here. diff --git a/tests/test_causal_problem/test_handlers.py b/tests/test_causal_problem/test_handlers.py index 399a63c..3801201 100644 --- a/tests/test_causal_problem/test_handlers.py +++ b/tests/test_causal_problem/test_handlers.py @@ -1,6 +1,7 @@ import pytest -from causalprog.causal_problem.handlers import EffectHandler, HandlerToApply +from causalprog.causal_problem import HandlerToApply +from causalprog.causal_problem._base_component import EffectHandler def placeholder_callable() -> EffectHandler: From e7c7b268964d283e417d2a175405ea24d8f6403f Mon Sep 17 00:00:00 2001 From: Will Graham <32364977+willGraham01@users.noreply.github.com> Date: Wed, 3 Sep 2025 09:30:06 +0100 Subject: [PATCH 32/33] Update src/causalprog/causal_problem/causal_problem.py Co-authored-by: Matthew Scroggs --- src/causalprog/causal_problem/causal_problem.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/causalprog/causal_problem/causal_problem.py b/src/causalprog/causal_problem/causal_problem.py index dc5e0e7..2b5785a 100644 --- a/src/causalprog/causal_problem/causal_problem.py +++ b/src/causalprog/causal_problem/causal_problem.py @@ -177,7 +177,7 @@ def lagrangian( # Build association between self.constraints and the model-samples that each # one needs to use. We do this here, since once it is constructed, it is - # FIXED, and doesn't need to be done each time we call the Lagrangian. + # fixed, and doesn't need to be done each time we call the Lagrangian. 
models, component_to_index_mapping = self._associate_models_to_components( n_samples ) From 5036d810ab5e80574cbe1f5459154e513c1159ca Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Wed, 3 Sep 2025 09:34:24 +0100 Subject: [PATCH 33/33] Reinstate if not else fix --- src/causalprog/causal_problem/components.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/causalprog/causal_problem/components.py b/src/causalprog/causal_problem/components.py index a5b790a..2b9900c 100644 --- a/src/causalprog/causal_problem/components.py +++ b/src/causalprog/causal_problem/components.py @@ -90,10 +90,10 @@ def __init__( """ super().__init__(*effect_handlers, do_with_samples=model_quantity) - if outer_norm is not None: - self._outer_norm = outer_norm - else: + if outer_norm is None: self._outer_norm = jnp.linalg.vector_norm + else: + self._outer_norm = outer_norm self.data = data self.tolerance = tolerance
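
Taken together, patch 31/33 makes the user-facing classes importable directly from the
subpackage, and patch 33/33 means that a Constraint built without an explicit outer_norm
falls back to JAX's vector norm. The resolve_outer_norm helper below is a hypothetical,
stand-alone rewrite of that branch for illustration only; it is not part of the package,
and the constructor's full signature is not shown in this series.

import jax.numpy as jnp

# Module-level imports added in patch 31/33 (see __all__ above); imported here
# only to show the new entry point.
from causalprog.causal_problem import (
    CausalEstimand,
    CausalProblem,
    Constraint,
    HandlerToApply,
)


def resolve_outer_norm(outer_norm=None):
    # Hypothetical helper mirroring the reinstated branch in Constraint.__init__:
    # test for None first, otherwise use the caller-supplied norm.
    if outer_norm is None:
        return jnp.linalg.vector_norm
    return outer_norm


norm = resolve_outer_norm()
print(norm(jnp.array([3.0, 4.0])))  # 5.0 -- the Euclidean norm by default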