From 19b7f1e20cbc1892064781e8b2b4b180038351e2 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Fri, 15 Aug 2025 15:07:46 +0100 Subject: [PATCH 01/12] Reinstate check after renaming caused bugs in sampling --- tests/test_integration/test_two_normal_example.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/tests/test_integration/test_two_normal_example.py b/tests/test_integration/test_two_normal_example.py index 5283fb1..1957a61 100644 --- a/tests/test_integration/test_two_normal_example.py +++ b/tests/test_integration/test_two_normal_example.py @@ -85,18 +85,18 @@ def lagrangian( subkeys = jax.random.split(rng_key, predictive_model.num_samples) l_mult = parameter_values["_l_mult"] - def _x_sampler(pv: dict[str, npt.ArrayLike], key: jax.Array) -> float: + def _ux_sampler(pv: dict[str, npt.ArrayLike], key: jax.Array) -> float: return predictive_model(key, **pv)["UX"] + def _x_sampler(pv: dict[str, npt.ArrayLike], key: jax.Array) -> float: + return predictive_model(key, **pv)["X"] + def _ce(pv, subkeys): return ( ce_prefactor * jax.vmap(_x_sampler, in_axes=(None, 0))(pv, subkeys).mean() ) - def _ux_sampler(pv: dict[str, npt.ArrayLike], key: jax.Array) -> float: - return predictive_model(key, **pv)["X"] - def _constraint(pv, subkeys): return ( jnp.abs( @@ -156,6 +156,11 @@ def objective(params, predictive, key, ce_prefactor=ce_prefactor): assert converged, f"Did not converge, final objective value: {objective_value}" + # The lagrangian is independent of nu_x, thus it should not have changed value. + assert jnp.isclose(params["cov2"], nu_x_starting_value), ( + "nu_x has changed significantly from the starting value." + ) + # Confirm that we found a minimiser that does satisfy the inequality constraints. assert params["_l_mult"] > 0.0, ( f"Converged, but not to a minimiser (lagrange multiplier = {params['_l_mult']})" From a5f55bca04b6e4a1aeac821e7d6cce99af05ecc9 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Fri, 15 Aug 2025 14:56:37 +0100 Subject: [PATCH 02/12] Remove old CausalProblem class --- src/causalprog/causal_problem.py | 319 ------------------ tests/test_causal_problem/test_callables.py | 212 ------------ .../test_graph_and_param.py | 42 --- 3 files changed, 573 deletions(-) delete mode 100644 src/causalprog/causal_problem.py delete mode 100644 tests/test_causal_problem/test_callables.py delete mode 100644 tests/test_causal_problem/test_graph_and_param.py diff --git a/src/causalprog/causal_problem.py b/src/causalprog/causal_problem.py deleted file mode 100644 index e1232cc..0000000 --- a/src/causalprog/causal_problem.py +++ /dev/null @@ -1,319 +0,0 @@ -"""Extension of the Graph class providing features for solving causal problems.""" - -from collections.abc import Callable -from inspect import signature -from typing import Literal, TypeAlias - -import jax -import jax.numpy as jnp - -from causalprog._abc.labelled import Labelled -from causalprog.graph import Graph, Node - -CausalEstimand: TypeAlias = Callable[..., float] -Constraints: TypeAlias = Callable[..., float] - - -def raises(exception: Exception) -> Callable[[], float]: - """Create a callable that raises ``exception`` when called.""" - - def _inner() -> float: - raise exception - - return _inner - - -class CausalProblem(Labelled): - """ - Container class for handling a causal problem. - - A causal problem - requires an underlying ``Graph`` to describe the relationships between the random - variables and parameters, plus a causal estimand and list of (data) constraints. - Structural constraints are handled by imposing restrictions on forms of the random - variables and constraints directly. - - A ``CausalProblem`` instance brings together these components, providing a container - for a causal problem that can be given inputs like empirical data, a solver - tolerance, etc, and will provide (estimates of) the bounds for the causal estimand. - - - The ``.graph`` attribute stores the underlying ``Graph`` object. - - The ``.causal_estimand`` method evaluates the causal estimand, given values for - the parameters. - - The ``.constraints`` method evaluates the (vector-valued) constraints, given - values for the parameters. - - The user must specify each of the above before a ``CausalProblem`` can be solved. - The primary way for this to be done is to construct or load the corresponding - ``Graph``, and provide it by setting the ``CausalProblem.graph`` attribute directly. - Then, `set_causal_estimand` and `set_constraints` can be used to provide the causal - estimand and constraints functions, in terms of the random variables. The - ``CausalProblem`` instance will handle turning them into functions of the parameter - values under the hood. Initial parameter values (for the purposes of solving) can be - provided to the solver method directly or set beforehand via ``set_parameters``. It - should never be necessary for the user to interact with, or provide, a vector of - parameters (as this is taken care of under the hood). - """ - - _graph: Graph | None - _sigma: CausalEstimand - _sigma_mapping: dict[str, Node] - _constraints: Constraints - _constraints_mapping: dict[str, Node] - _parameter_values: dict[str, float] - - @property - def graph(self) -> Graph: - """Graph defining the structure of the `CausalProblem`.""" - if self._graph is None: - msg = f"No graph set for {self.label}." - raise ValueError(msg) - return self._graph - - @graph.setter - def graph(self, new_graph: Graph) -> None: - if not isinstance(new_graph, Graph): - msg = f"{self.label}.graph must be a Graph instance." - raise TypeError(msg) - self._graph = new_graph - - @property - def parameter_values(self) -> dict[str, float]: - """Dictionary mapping parameter labels to their (current) values.""" - return self._parameter_vector_to_dict(self.parameter_vector) - - @property - def parameter_vector(self) -> jax.Array: - """Returns the (current) vector of parameter values.""" - return jnp.array( - tuple( - self._parameter_values[node.label] - if node.label in self._parameter_values - else float("NaN") - for node in self.graph.parameter_nodes - ), - ndmin=1, - ) - - def __init__( - self, - graph: Graph | None = None, - *, - label: str = "CausalProblem", - ) -> None: - """Set up a new CausalProblem.""" - super().__init__(label=label) - - self._parameter_values = {} - self._graph = graph - - # Callables cannot be evaluated until they are explicitly set - self.set_causal_estimand( - raises(NotImplementedError(f"Causal estimand not set for {self.label}.")) - ) - self.set_constraints( - raises(NotImplementedError(f"Constraints not set for {self.label}.")) - ) - - def _parameter_vector_to_dict( - self, parameter_vector: jax.Array - ) -> dict[str, float]: - """ - Convert a parameter vector to a dictionary mapping labels to parameter values. - - Convention is that a vector of parameter values contains values in the same - order as self.graph.parameter_nodes. - """ - # Avoid recomputing the parameter node tuple every time. - pn = self.graph.parameter_nodes - return {pn[i].label: value for i, value in enumerate(parameter_vector)} - - def _eval_callable( - self, which: Literal["sigma", "constraints"], at: jax.Array - ) -> jax.Array: - """ - Evaluate a callable method of this instance. - - This is an abstraction method for when the causal estimand or constraints - functions need to be evaluated. In each case, the process is the same: - - - Update the values of the parameter nodes. - - Call the underlying function composed with its mapping of RVs to Nodes. - - The method is abstracted here so that any changes to the process are reflected - in both methods automatically. - """ - self._set_parameters_via_vector(at) - if "parameter_values" in signature(getattr(self, f"_{which}")).parameters: - return getattr(self, f"_{which}")( - **getattr(self, f"_{which}_mapping"), - parameter_values=self._parameter_values, - ) - return getattr(self, f"_{which}")(**getattr(self, f"_{which}_mapping")) - - def _set_callable( - self, - which: Literal["sigma", "constraints"], - *, - fn: CausalEstimand | Constraints, - rvs_to_nodes: dict[str, str] | None = None, - graph_argument: str | None = None, - ) -> None: - """ - Abstracted method for setting the Causal Estimand and/or Constraints functions. - - The functionality for setting up these two methods of an instance are identical, - save for the attributes which need to be updated. As such, we can refactor the - common functionality into a single, hidden, method and provide a friendlier - access point for users to employ. - """ - fn_attr = f"_{which}" - map_attr = f"_{which}_mapping" - debug_name = "constraints" if which == "constraints" else "causal estimand" - - setattr(self, fn_attr, fn) - setattr(self, map_attr, {}) - - if rvs_to_nodes is None: - rvs_to_nodes = {} - fn_args = signature(fn).parameters - - for rv_name, node_label in rvs_to_nodes.items(): - if rv_name not in fn_args: - msg = f"{rv_name} is not an argument to provided {debug_name} function." - raise ValueError(msg) - getattr(self, map_attr)[rv_name] = self.graph.get_node(node_label) - - # Any unaccounted-for RV arguments to sigma are assumed to match - # the label of the corresponding node. - args_not_used = set(fn_args) - set(getattr(self, map_attr)) - - ## Temporary hack to ensure that we can use expectation(graph, X) syntax. - if graph_argument: - getattr(self, map_attr)[graph_argument] = self.graph - args_not_used -= {graph_argument} - ## END HACK - - for arg in args_not_used: - if arg != "parameter_values": - getattr(self, map_attr)[arg] = self.graph.get_node(arg) - - def _set_parameters_via_vector(self, parameter_vector: jax.Array | None) -> None: - """ - Shorthand to set parameter node values from a parameter vector. - - No intended for frontend use - primary use will be internal when running - optimisation methods over the CausalProblem, when we need to treat the - parameters as a vector or array of function inputs. - """ - self.set_parameter_values(**self._parameter_vector_to_dict(parameter_vector)) - - def set_parameter_values(self, **parameter_values: float | None) -> None: - """ - Set (initial) parameter values for this CausalProblem. - - See ``Graph.set_parameters`` for input details. - """ - for parameter, value in parameter_values.items(): - if value is None: - if parameter in self._parameter_values: - del self._parameter_values[parameter] - else: - self._parameter_values[parameter] = value - - def set_causal_estimand( - self, - sigma: CausalEstimand, - rvs_to_nodes: dict[str, str] | None = None, - graph_argument: str | None = None, - ) -> None: - """ - Set the Causal Estimand for this problem. - - `sigma` should be a callable object that defines the Causal Estimand of - interest, in terms of the random variables of to the problem. The - random variables are in turn represented by `Node`s, with this association being - recorded in the `rv_to_nodes` dictionary. - - The `causal_estimand` method of the instance will be usable once this method - completes. - - Args: - sigma (CausalEstimand): Callable object that evaluates the causal estimand - of interest for this `CausalProblem`, in terms of the random variables, - which are the arguments to this callable. `sigma`s with additional - arguments are not currently supported. - rvs_to_nodes (dict[str, str]): Mapping of random variable (argument) names - of `sigma` to the labels of the corresponding `Node`s representing the - random variables. Argument names that match their corresponding `Node` - label can be omitted. - graph_argument (str): Argument to `sigma` that should be replaced with - `self.graph`. This argument is only temporary, as we are currently - limited to the syntax `expectation(Graph, Node)` rather than just - `expectation(Node)`. It will be removed in the future when methods like - `expectation` can be called solely on `Node` objects. - - """ - self._set_callable( - "sigma", fn=sigma, rvs_to_nodes=rvs_to_nodes, graph_argument=graph_argument - ) - - def set_constraints( - self, - constraints: CausalEstimand, - rvs_to_nodes: dict[str, str] | None = None, - graph_argument: str | None = None, - ) -> None: - """ - Set the Constraints for this problem. - - ``constraints`` should be a callable object that defines the Data Constraints of - interest, in terms of the random variables of to the problem. The - random variables are in turn represented by `Node`s, with this association being - recorded in the `rv_to_nodes` dictionary. - - The `constraints` method of the instance will be usable once this method - completes. - - Args: - constraints (Constraints): Callable object that evaluates the constraints - of interest for this `CausalProblem`, in terms of the random variables, - which are the arguments to this callable. ``constraints``s with - additional arguments are not currently supported. - rvs_to_nodes (dict[str, str]): Mapping of random variable (argument) names - of `sigma` to the labels of the corresponding `Node`s representing the - random variables. Argument names that match their corresponding `Node` - label can be omitted. - graph_argument (str): Argument to `sigma` that should be replaced with - `self.graph`. This argument is only temporary, as we are currently - limited to the syntax `expectation(Graph, Node)` rather than just - `expectation(Node)`. It will be removed in the future when methods like - `expectation` can be called solely on `Node` objects. - - """ - self._set_callable( - "constraints", - fn=constraints, - rvs_to_nodes=rvs_to_nodes, - graph_argument=graph_argument, - ) - - def causal_estimand(self, p: jax.Array) -> float: - """ - Evaluate the Causal Estimand at parameter vector `p`. - - Args: - p (jax.Array): Vector of parameter values to evaluate at. - - """ - return self._eval_callable("sigma", p) - - def constraints(self, p: jax.Array) -> jax.Array: - """ - Evaluate the Constraints at parameter vector `p`. - - Args: - p (jax.Array): Vector of parameter values to evaluate at. - - """ - return self._eval_callable("constraints", p) diff --git a/tests/test_causal_problem/test_callables.py b/tests/test_causal_problem/test_callables.py deleted file mode 100644 index f70d2f2..0000000 --- a/tests/test_causal_problem/test_callables.py +++ /dev/null @@ -1,212 +0,0 @@ -from collections.abc import Callable -from typing import Literal - -import jax -import jax.numpy as jnp -import pytest - -from causalprog.algorithms import expectation, standard_deviation -from causalprog.causal_problem import CausalProblem -from causalprog.graph import Graph, Node - - -@pytest.fixture -def n_samples_for_estimands() -> int: - return 1000 - - -@pytest.fixture -def expectation_fixture( - rng_key: jax.Array, - n_samples_for_estimands: int, -) -> Callable[[Graph, Node, dict[str, float] | None], float]: - def _inner(g: Graph, x: Node, parameter_values=None) -> float: - return expectation( - g, - x.label, - samples=n_samples_for_estimands, - rng_key=rng_key, - parameter_values=parameter_values, - ) - - return _inner - - -@pytest.fixture -def std_fixture( - rng_key: jax.Array, - n_samples_for_estimands: int, -) -> Callable[[Graph, Node, dict[str, float] | None], float]: - def _inner(g: Graph, x: Node, parameter_values=None) -> float: - return ( - standard_deviation( - g, - x.label, - samples=n_samples_for_estimands, - rng_key=rng_key, - parameter_values=parameter_values, - ) - ** 2 - ) - - return _inner - - -@pytest.fixture -def vector_fixture( - rng_key: jax.Array, - n_samples_for_estimands: int, -) -> Callable[[Graph, Node, Node, dict[str, float] | None], jax.Array]: - """vector_fixture(g, x1, x2) = [mean of x1, std of x2].""" - - def _inner(g: Graph, x1: Node, x2: Node, parameter_values=None) -> jax.Array: - return jnp.array( - [ - expectation( - g, - x1.label, - samples=n_samples_for_estimands, - rng_key=rng_key, - parameter_values=parameter_values, - ), - standard_deviation( - g, - x2.label, - samples=n_samples_for_estimands, - rng_key=rng_key, - parameter_values=parameter_values, - ) - ** 2, - ] - ) - - return _inner - - -@pytest.fixture(params=["causal_estimand", "constraints"]) -def which(request: pytest.FixtureRequest) -> Literal["causal_estimand", "constraints"]: - """For tests applicable to both the causal_estimand and constraints methods.""" - return request.param - - -@pytest.mark.parametrize( - ("initial_param_values", "args_to_setter", "expected", "atol"), - [ - pytest.param( - {"mean": 1.0, "cov2": 1.0}, - { - "fn": "expectation_fixture", - "rvs_to_nodes": {"x": "mean"}, - "graph_argument": "g", - }, - 1.0, - 1.0e-12, - id="mean", - ), - pytest.param( - {"mean": 1.0, "cov2": 1.0}, - { - "fn": "expectation_fixture", - "rvs_to_nodes": {"x": "cov2"}, - "graph_argument": "g", - }, - 1.0, - 1.0e-12, - id="cov2", - ), - pytest.param( - {"mean": 0.0, "cov2": 1.0}, - { - "fn": "expectation_fixture", - "rvs_to_nodes": {"x": "X"}, - "graph_argument": "g", - }, - 0.0, - 3.0e-2, - id="E[x], infer association", - ), - pytest.param( - {"mean": 0.0, "cov2": 1.0}, - { - "fn": "std_fixture", - "rvs_to_nodes": {"x": "X"}, - "graph_argument": "g", - }, - # UX has fixed std 1, and cov2 will be set to 1. - 1.0**2 + 1.0**2, - 3.0e-1, - id="Var[y]", - ), - pytest.param( - {"mean": 0.0, "cov2": 1.0}, - { - "fn": "vector_fixture", - "rvs_to_nodes": {"x1": "UX", "x2": "X"}, - "graph_argument": "g", - }, - # As per the previous test cases - jnp.array([0.0, 1.0**2 + 1.0**2]), - jnp.array([3.0e-2, 2.0e-1]), - id="E[x], Var[y]", - ), - ], -) -def test_callables( - two_normal_graph: Callable[..., Graph], - which: Literal["causal_estimand", "constraints"], - initial_param_values: dict[str, float], - args_to_setter: dict[str, Callable[..., float] | dict[str, str] | str], - expected: float | jax.Array, - atol: float, - request: pytest.FixtureRequest, - raises_context, -) -> None: - """ - Test the set_{causal_estimand, constraints} and .{casual_estimand, constraints} - evaluation method. - - Test works by: - - Set the parameter values using the initial_param_values. - - Call the setter method using the given arguments. - - Evaluate the method that should have been set at the current parameter_vector, - which should evaluate the corresponding function at the current values of the - parameter vector, which will be the initial values just set. - - Check the result (lies within a given tolerance). - - In theory, there is no difference between the causal estimand and constraints when - it comes to this test - the constraints may be vector-valued but there is nothing - preventing the ``causal_estimand`` (programmatically) from being vector-valued - either. - """ - # Parametrised fixtures edit-in-place objects - args_to_setter = dict(args_to_setter) - if isinstance(args_to_setter["fn"], str): - args_to_setter["fn"] = request.getfixturevalue(args_to_setter["fn"]) - - if which == "constraints": - args_to_setter["constraints"] = args_to_setter.pop("fn") - else: - args_to_setter["sigma"] = args_to_setter.pop("fn") - - expected = jnp.array(expected, ndmin=1) - - # Test properly begins. - graph = two_normal_graph(cov=1.0) - cp = CausalProblem(graph) - - method = getattr(cp, which) - setter_method = getattr(cp, f"set_{which}") - - # Before setting the causal estimand, it should throw an error if called. - with raises_context( - NotImplementedError( - f"{which.replace('_', ' ').capitalize()} not set for CausalProblem." - ) - ): - method(cp.parameter_vector) - - cp.set_parameter_values(**initial_param_values) - setter_method(**args_to_setter) - result = jnp.array(method(cp.parameter_vector), ndmin=1) - - assert jnp.allclose(result, expected, atol=atol) diff --git a/tests/test_causal_problem/test_graph_and_param.py b/tests/test_causal_problem/test_graph_and_param.py deleted file mode 100644 index 96ea924..0000000 --- a/tests/test_causal_problem/test_graph_and_param.py +++ /dev/null @@ -1,42 +0,0 @@ -from collections.abc import Callable - -import jax.numpy as jnp - -from causalprog.causal_problem import CausalProblem -from causalprog.graph import Graph - - -def test_graph_and_parameter_interactions( - two_normal_graph: Callable[..., Graph], - raises_context, -) -> None: - cp = CausalProblem(label="TestCP") - - # Without a graph, we can't do anything - with raises_context(ValueError("No graph set for TestCP")): - cp.graph # noqa: B018 - with raises_context(ValueError("No graph set for TestCP")): - cp.parameter_values # noqa: B018 - - # Cannot set graph to non-graph value - with raises_context(TypeError("TestCP.graph must be a Graph instance")): - cp.graph = 1.0 - - # Provide an actual graph value - cp.graph = two_normal_graph(cov=1.0) - - # We should now be able to fetch parameter values, but they are all unset. - assert jnp.all(jnp.isnan(cp.parameter_vector)) - assert cp.parameter_vector.shape == (len(cp.graph.parameter_nodes),) - assert all(jnp.isnan(value) for value in cp.parameter_values.values()) - assert set(cp.parameter_values.keys()) == {"mean", "cov2"} - - # Users should only ever need to set parameter values via their names. - cp.set_parameter_values(mean=1.0, cov2=2.0) - assert cp.parameter_values == {"mean": 1.0, "cov2": 2.0} - # We don't know which way round the internal parameter vector is being stored, - # but that doesn't matter. We do know that it should contain the values 1 & 2 - # in some order though. - assert jnp.allclose(cp.parameter_vector, jnp.array([1.0, 2.0])) or jnp.allclose( - cp.parameter_vector, jnp.array([2.0, 1.0]) - ) From de904bff028424455729701a2a878aa803d9dad2 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Fri, 15 Aug 2025 14:57:22 +0100 Subject: [PATCH 03/12] Add barebones classes to be populated later --- src/causalprog/causal_problem/__init__.py | 0 .../causal_problem/causal_estimand.py | 45 ++++++++++++++ .../causal_problem/causal_problem.py | 58 +++++++++++++++++++ 3 files changed, 103 insertions(+) create mode 100644 src/causalprog/causal_problem/__init__.py create mode 100644 src/causalprog/causal_problem/causal_estimand.py create mode 100644 src/causalprog/causal_problem/causal_problem.py diff --git a/src/causalprog/causal_problem/__init__.py b/src/causalprog/causal_problem/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/causalprog/causal_problem/causal_estimand.py b/src/causalprog/causal_problem/causal_estimand.py new file mode 100644 index 0000000..05effd1 --- /dev/null +++ b/src/causalprog/causal_problem/causal_estimand.py @@ -0,0 +1,45 @@ +"""C.""" + +from collections.abc import Callable +from typing import Any, Concatenate, TypeAlias + +import numpy.typing as npt + +Model: TypeAlias = Callable[..., Any] +EffectHandler: TypeAlias = Callable[Concatenate[Model, ...], Model] +ModelMask: TypeAlias = tuple[EffectHandler, dict] + + +class _CPComponent: + """""" + + do_with_samples: Callable[..., npt.ArrayLike] + effect_handlers: tuple[ModelMask, ...] + + @property + def requires_model_adaption(self) -> bool: + """Return True if effect handlers need to be applied to model.""" + return len(self.effect_handlers) > 0 + + def __init__( + self, + *effect_handlers: ModelMask, + do_with_samples: Callable[..., npt.ArrayLike], + ): + self.effect_handlers = tuple(effect_handlers) + self.do_with_samples = do_with_samples + + def apply_effects(self, model: Model) -> Model: + """Apply any necessary effect handlers prior to evaluating.""" + adapted_model = model + for handler, handler_options in self.effect_handlers: + adapted_model = handler(adapted_model, **handler_options) + return adapted_model + + +class CausalEstimand(_CPComponent): + """""" + + +class Constraint(_CPComponent): + """""" diff --git a/src/causalprog/causal_problem/causal_problem.py b/src/causalprog/causal_problem/causal_problem.py new file mode 100644 index 0000000..ae44f7e --- /dev/null +++ b/src/causalprog/causal_problem/causal_problem.py @@ -0,0 +1,58 @@ +"""C.""" + +from collections.abc import Callable + +import jax +import numpy.typing as npt +from numpyro.infer import Predictive + +from causalprog.causal_problem.causal_estimand import CausalEstimand, Constraint, Model + + +def sample_model( + model: Predictive, rng_key: jax.Array, parameter_values: dict[str, npt.ArrayLike] +) -> dict[str, npt.ArrayLike]: + return jax.vmap(lambda pv, key: model(key, **pv), in_axes=(None, 0))( + parameter_values, + jax.random.split(rng_key, model.num_samples), + ) + + +class CausalProblem: + causal_estimand: CausalEstimand + constraints: list[Constraint] + + def __init__( + self, + *constraints: Constraint, + causal_estimand: CausalEstimand, + ): + self.causal_estimand = causal_estimand + self.constraints = list(constraints) + + def lagrangian( + self, n_samples: int = 1000 + ) -> Callable[[dict[str, npt.ArrayLike], Model, jax.Array], npt.ArrayLike]: + """Assemble the Lagrangian.""" + + def _inner( + parameter_values: dict[str, npt.ArrayLike], model: Model, rng_key: jax.Array + ) -> npt.ArrayLike: + # In general, we will need to check which of our CE/CONs require masking, + # and do multiple predictive models to account for this... + # We can always pre-build the predictive models too, so we should replace + # the "model" input with something that can map the right predictive models + # to the CE/CONS that need them. + predictive_model = Predictive(model=model, num_samples=n_samples) + all_samples = sample_model(predictive_model, rng_key, parameter_values) + + value = self.causal_estimand.do_with_samples(**all_samples) + # CLEANER IF THE LAGRANGE MULTIPLIERS COULD BE A SECOND FUNCTION ARG, + # as right now they have to be inside the parameter dict... + value += sum( + parameter_values[f"_l_mult{i}"] * c.do_with_samples(**all_samples) + for i, c in enumerate(self.constraints) + ) + return value + + return _inner From 32ccf0b38573037791f13ab8fead1bbe6532f725 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Fri, 15 Aug 2025 15:02:05 +0100 Subject: [PATCH 04/12] Update two normal example test to use new infrastructure --- .../test_two_normal_example.py | 82 +++++++------------ 1 file changed, 29 insertions(+), 53 deletions(-) diff --git a/tests/test_integration/test_two_normal_example.py b/tests/test_integration/test_two_normal_example.py index 1957a61..b5ebcf9 100644 --- a/tests/test_integration/test_two_normal_example.py +++ b/tests/test_integration/test_two_normal_example.py @@ -1,12 +1,13 @@ +import sys from collections.abc import Callable import jax import jax.numpy as jnp -import numpy.typing as npt import optax import pytest -from numpyro.infer import Predictive +from causalprog.causal_problem.causal_estimand import CausalEstimand, Constraint +from causalprog.causal_problem.causal_problem import CausalProblem from causalprog.graph import Graph @@ -71,55 +72,28 @@ def test_two_normal_example( (again with $+\mu_{ux}$ in the maximisation case). In both cases, $\mathcal{L}$ is minimised at $(\mu_{ux}^{*}, \nu_x, 1)$. """ - # Setup the optimisation problem from the graph - g = two_normal_graph(cov=1.0) - predictive_model = Predictive(g.model, num_samples=n_samples) - - def lagrangian( - parameter_values: dict[str, npt.ArrayLike], - predictive_model: Predictive, - rng_key: jax.Array, - *, - ce_prefactor: float, - ): - subkeys = jax.random.split(rng_key, predictive_model.num_samples) - l_mult = parameter_values["_l_mult"] - - def _ux_sampler(pv: dict[str, npt.ArrayLike], key: jax.Array) -> float: - return predictive_model(key, **pv)["UX"] - - def _x_sampler(pv: dict[str, npt.ArrayLike], key: jax.Array) -> float: - return predictive_model(key, **pv)["X"] - - def _ce(pv, subkeys): - return ( - ce_prefactor - * jax.vmap(_x_sampler, in_axes=(None, 0))(pv, subkeys).mean() - ) - - def _constraint(pv, subkeys): - return ( - jnp.abs( - jax.vmap(_ux_sampler, in_axes=(None, 0))(pv, subkeys).mean() - - phi_observed - ) - - epsilon - ) - - return _ce(parameter_values, subkeys) + l_mult * _constraint( - parameter_values, subkeys - ) - # In both cases, the Lagrange multiplier has the value 1.0 at the minimum. lambda_sol = 1.0 ce_prefactor = 1.0 if not is_solving_max else -1.0 mu_x_sol = phi_observed - ce_prefactor * epsilon + # Setup the optimisation problem from the graph + g = two_normal_graph(cov=1.0) + ce = CausalEstimand(do_with_samples=lambda **pv: ce_prefactor * pv["X"].mean()) + con = Constraint( + do_with_samples=lambda **pv: jnp.abs(pv["UX"].mean() - phi_observed) - epsilon + ) + cp = CausalProblem( + con, + causal_estimand=ce, + ) + lagrangian = cp.lagrangian(n_samples=n_samples) + # We'll be seeking stationary points of the Lagrangian, using the # naive approach of minimising the norm of its gradient. We will need to # ensure we "converge" to a minimum value suitably close to 0. - def objective(params, predictive, key, ce_prefactor=ce_prefactor): - v = jax.grad(lagrangian)(params, predictive, key, ce_prefactor=ce_prefactor) + def objective(params, predictive, key): + v = jax.grad(lagrangian)(params, predictive, key) return sum(value**2 for value in v.values()) # Choose a starting guess that is at the optimal solution, in the hopes that @@ -130,7 +104,7 @@ def objective(params, predictive, key, ce_prefactor=ce_prefactor): params = { "mean": mu_x_sol, "cov2": nu_x_starting_value, - "_l_mult": lambda_sol, + "_l_mult0": lambda_sol, } # Setup SGD optimiser optimiser = optax.adam(adams_learning_rate) @@ -139,17 +113,18 @@ def objective(params, predictive, key, ce_prefactor=ce_prefactor): converged = False for _ in range(maxiter): # Actual iteration loop - grads = jax.jacobian(objective)( - params, predictive_model, rng_key, ce_prefactor=ce_prefactor - ) + grads = jax.jacobian(objective)(params, g.model, rng_key) updates, opt_state = optimiser.update(grads, opt_state) params = optax.apply_updates(params, updates) # Convergence "check" and progress update - objective_value = objective( - params, predictive_model, rng_key, ce_prefactor=ce_prefactor + objective_value = objective(params, g.model, rng_key) + sys.stdout.write( + f"{_}, F_val={objective_value:.4e}, " + f"mu_ux={params['mean']:.4e}, " + f"nu_x={params['cov2']:.4e}, " + f"lambda={params['_l_mult0']:.4e}\n" ) - if jnp.abs(objective_value) <= minimisation_tolerance: converged = True break @@ -162,12 +137,13 @@ def objective(params, predictive, key, ce_prefactor=ce_prefactor): ) # Confirm that we found a minimiser that does satisfy the inequality constraints. - assert params["_l_mult"] > 0.0, ( - f"Converged, but not to a minimiser (lagrange multiplier = {params['_l_mult']})" + assert params["_l_mult0"] > 0.0, ( + "Converged, but not to a minimiser " + f"(lagrange multiplier = {params['_l_mult0']})" ) # Give a generous error margin in mu_ux and the Lagrange multiplier, # given SGD is being used on MC-integral functions. rtol = jnp.sqrt(1.0 / n_samples) assert jnp.isclose(params["mean"], mu_x_sol, rtol=rtol) - assert jnp.isclose(params["_l_mult"], lagrange_mult_sol, atol=rtol) + assert jnp.isclose(params["_l_mult0"], lagrange_mult_sol, atol=rtol) From c8fbaa5ad5764a94b9e1971fa25ae4ffd044e196 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Fri, 15 Aug 2025 15:28:42 +0100 Subject: [PATCH 05/12] Rework so that the lagrangian can be passed model parameters and the multiplier values as separate args --- .../causal_problem/causal_problem.py | 13 ++++++-- .../test_two_normal_example.py | 31 ++++++++++--------- 2 files changed, 26 insertions(+), 18 deletions(-) diff --git a/src/causalprog/causal_problem/causal_problem.py b/src/causalprog/causal_problem/causal_problem.py index ae44f7e..08b1ca9 100644 --- a/src/causalprog/causal_problem/causal_problem.py +++ b/src/causalprog/causal_problem/causal_problem.py @@ -32,11 +32,16 @@ def __init__( def lagrangian( self, n_samples: int = 1000 - ) -> Callable[[dict[str, npt.ArrayLike], Model, jax.Array], npt.ArrayLike]: + ) -> Callable[ + [dict[str, npt.ArrayLike], npt.ArrayLike, Model, jax.Array], npt.ArrayLike + ]: """Assemble the Lagrangian.""" def _inner( - parameter_values: dict[str, npt.ArrayLike], model: Model, rng_key: jax.Array + parameter_values: dict[str, npt.ArrayLike], + l_mult: jax.Array, + model: Model, + rng_key: jax.Array, ) -> npt.ArrayLike: # In general, we will need to check which of our CE/CONs require masking, # and do multiple predictive models to account for this... @@ -49,8 +54,10 @@ def _inner( value = self.causal_estimand.do_with_samples(**all_samples) # CLEANER IF THE LAGRANGE MULTIPLIERS COULD BE A SECOND FUNCTION ARG, # as right now they have to be inside the parameter dict... + # Cleaner if we could somehow build a vector-valued function of the + # constraints and then take a dot product, but this works for now value += sum( - parameter_values[f"_l_mult{i}"] * c.do_with_samples(**all_samples) + l_mult[i] * c.do_with_samples(**all_samples) for i, c in enumerate(self.constraints) ) return value diff --git a/tests/test_integration/test_two_normal_example.py b/tests/test_integration/test_two_normal_example.py index b5ebcf9..f13fe81 100644 --- a/tests/test_integration/test_two_normal_example.py +++ b/tests/test_integration/test_two_normal_example.py @@ -72,8 +72,6 @@ def test_two_normal_example( (again with $+\mu_{ux}$ in the maximisation case). In both cases, $\mathcal{L}$ is minimised at $(\mu_{ux}^{*}, \nu_x, 1)$. """ - # In both cases, the Lagrange multiplier has the value 1.0 at the minimum. - lambda_sol = 1.0 ce_prefactor = 1.0 if not is_solving_max else -1.0 mu_x_sol = phi_observed - ce_prefactor * epsilon @@ -92,9 +90,9 @@ def test_two_normal_example( # We'll be seeking stationary points of the Lagrangian, using the # naive approach of minimising the norm of its gradient. We will need to # ensure we "converge" to a minimum value suitably close to 0. - def objective(params, predictive, key): - v = jax.grad(lagrangian)(params, predictive, key) - return sum(value**2 for value in v.values()) + def objective(params, l_mult, predictive, key): + v = jax.grad(lagrangian, argnums=(0, 1))(params, l_mult, predictive, key) + return sum(value**2 for value in v[0].values()) + (v[1] ** 2).sum() # Choose a starting guess that is at the optimal solution, in the hopes that # SGD converges quickly. We almost certainly will not have this luxury in general. @@ -104,26 +102,30 @@ def objective(params, predictive, key): params = { "mean": mu_x_sol, "cov2": nu_x_starting_value, - "_l_mult0": lambda_sol, } + l_mult = jnp.atleast_1d(lagrange_mult_sol) + # Setup SGD optimiser optimiser = optax.adam(adams_learning_rate) - opt_state = optimiser.init(params) + opt_state = optimiser.init((params, l_mult)) + # Run optimisation loop on gradient of the Lagrangian converged = False for _ in range(maxiter): # Actual iteration loop - grads = jax.jacobian(objective)(params, g.model, rng_key) + grads = jax.jacobian(objective, argnums=(0, 1))( + params, l_mult, g.model, rng_key + ) updates, opt_state = optimiser.update(grads, opt_state) - params = optax.apply_updates(params, updates) + params, l_mult = optax.apply_updates((params, l_mult), updates) # Convergence "check" and progress update - objective_value = objective(params, g.model, rng_key) + objective_value = objective(params, l_mult, g.model, rng_key) sys.stdout.write( f"{_}, F_val={objective_value:.4e}, " f"mu_ux={params['mean']:.4e}, " f"nu_x={params['cov2']:.4e}, " - f"lambda={params['_l_mult0']:.4e}\n" + f"lambda={l_mult[0]:.4e}\n" ) if jnp.abs(objective_value) <= minimisation_tolerance: converged = True @@ -137,13 +139,12 @@ def objective(params, predictive, key): ) # Confirm that we found a minimiser that does satisfy the inequality constraints. - assert params["_l_mult0"] > 0.0, ( - "Converged, but not to a minimiser " - f"(lagrange multiplier = {params['_l_mult0']})" + assert jnp.all(l_mult > 0.0), ( + f"Converged, but not to a minimiser (lagrange multiplier = {l_mult})" ) # Give a generous error margin in mu_ux and the Lagrange multiplier, # given SGD is being used on MC-integral functions. rtol = jnp.sqrt(1.0 / n_samples) assert jnp.isclose(params["mean"], mu_x_sol, rtol=rtol) - assert jnp.isclose(params["_l_mult0"], lagrange_mult_sol, atol=rtol) + assert jnp.allclose(l_mult, lagrange_mult_sol, atol=rtol) From 87d128fba284bd3e13a66d3ad51d4f9b58533bce Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Fri, 15 Aug 2025 15:39:34 +0100 Subject: [PATCH 06/12] ruffing --- .../causal_problem/causal_estimand.py | 8 +-- .../causal_problem/causal_problem.py | 52 +++++++++++++++++-- 2 files changed, 52 insertions(+), 8 deletions(-) diff --git a/src/causalprog/causal_problem/causal_estimand.py b/src/causalprog/causal_problem/causal_estimand.py index 05effd1..8ff50a4 100644 --- a/src/causalprog/causal_problem/causal_estimand.py +++ b/src/causalprog/causal_problem/causal_estimand.py @@ -11,7 +11,7 @@ class _CPComponent: - """""" + """Base.""" do_with_samples: Callable[..., npt.ArrayLike] effect_handlers: tuple[ModelMask, ...] @@ -25,7 +25,7 @@ def __init__( self, *effect_handlers: ModelMask, do_with_samples: Callable[..., npt.ArrayLike], - ): + ) -> None: self.effect_handlers = tuple(effect_handlers) self.do_with_samples = do_with_samples @@ -38,8 +38,8 @@ def apply_effects(self, model: Model) -> Model: class CausalEstimand(_CPComponent): - """""" + """A Causal Estimand.""" class Constraint(_CPComponent): - """""" + """A constraint function.""" diff --git a/src/causalprog/causal_problem/causal_problem.py b/src/causalprog/causal_problem/causal_problem.py index 08b1ca9..10fdd50 100644 --- a/src/causalprog/causal_problem/causal_problem.py +++ b/src/causalprog/causal_problem/causal_problem.py @@ -12,6 +12,21 @@ def sample_model( model: Predictive, rng_key: jax.Array, parameter_values: dict[str, npt.ArrayLike] ) -> dict[str, npt.ArrayLike]: + """ + Draw samples from the predictive model. + + TODO: Move function to somewhere more appropriate. + + Args: + model: Predictive model to draw samples from. + rng_key: PRNG Key to use in pseudorandom number generation. + parameter_values: Model parameter values to substitute. + + Returns: + `dict` of samples, with RV labels as keys and sample values (`jax.Array`s) as + values. + + """ return jax.vmap(lambda pv, key: model(key, **pv), in_axes=(None, 0))( parameter_values, jax.random.split(rng_key, model.num_samples), @@ -19,6 +34,8 @@ def sample_model( class CausalProblem: + """Defines a causal problem.""" + causal_estimand: CausalEstimand constraints: list[Constraint] @@ -26,16 +43,41 @@ def __init__( self, *constraints: Constraint, causal_estimand: CausalEstimand, - ): + ) -> None: + """Create a new causal problem.""" self.causal_estimand = causal_estimand self.constraints = list(constraints) def lagrangian( - self, n_samples: int = 1000 + self, n_samples: int = 1000, *, maximum_problem: bool = False ) -> Callable[ [dict[str, npt.ArrayLike], npt.ArrayLike, Model, jax.Array], npt.ArrayLike ]: - """Assemble the Lagrangian.""" + """ + Return a function that evaluates the Lagrangian of this `CausalProblem`. + + Following the + [KKT theorem](https://en.wikipedia.org/wiki/Karush%E2%80%93Kuhn%E2%80%93Tucker_conditions), + given the causal estimand and the constraints we can assemble a Lagrangian and + seek its stationary points, to in turn identify minimisers of the constrained + optimisation problem that we started with. + + The Lagrangian returned is a mathematical function of its first two arguments. + The first argument is the same dictionary of parameters that is passed to models + like `Graph.model`, and is the values the parameters (represented by the + `ParameterNode`s) are taking. The second argument is a 1D vector of Lagrange + multipliers, whose length is equal to the number of constraints. + + The remaining two arguments of the Lagrangian are the underlying model that it + should use to generate and draw samples of the RVs from, and the PRNGKey that + should be used in this generation. + + Note that our current implementation assumes there are no equality constraints + being imposed (in which case, we would need a 3-argument Lagrangian function). + + TODO: Can we store g.model in the class, and have this still work??? + """ + maximisation_prefactor = -1.0 if maximum_problem else 1.0 def _inner( parameter_values: dict[str, npt.ArrayLike], @@ -51,7 +93,9 @@ def _inner( predictive_model = Predictive(model=model, num_samples=n_samples) all_samples = sample_model(predictive_model, rng_key, parameter_values) - value = self.causal_estimand.do_with_samples(**all_samples) + value = maximisation_prefactor * self.causal_estimand.do_with_samples( + **all_samples + ) # CLEANER IF THE LAGRANGE MULTIPLIERS COULD BE A SECOND FUNCTION ARG, # as right now they have to be inside the parameter dict... # Cleaner if we could somehow build a vector-valued function of the From 7c171c401333b0842c25ab5e6ba20dac0681a285 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Fri, 15 Aug 2025 15:47:46 +0100 Subject: [PATCH 07/12] Refactor out g.model argument from the Lagrangian call --- .../causal_problem/causal_problem.py | 34 +++++++++++++------ .../test_two_normal_example.py | 16 ++++----- 2 files changed, 31 insertions(+), 19 deletions(-) diff --git a/src/causalprog/causal_problem/causal_problem.py b/src/causalprog/causal_problem/causal_problem.py index 10fdd50..e41824b 100644 --- a/src/causalprog/causal_problem/causal_problem.py +++ b/src/causalprog/causal_problem/causal_problem.py @@ -6,7 +6,8 @@ import numpy.typing as npt from numpyro.infer import Predictive -from causalprog.causal_problem.causal_estimand import CausalEstimand, Constraint, Model +from causalprog.causal_problem.causal_estimand import CausalEstimand, Constraint +from causalprog.graph import Graph def sample_model( @@ -36,23 +37,26 @@ def sample_model( class CausalProblem: """Defines a causal problem.""" + # NB: A CausalProblem could just BE a graph!!!! + # But separation of concerns and all... + _underlying_graph: Graph causal_estimand: CausalEstimand constraints: list[Constraint] def __init__( self, + graph: Graph, *constraints: Constraint, causal_estimand: CausalEstimand, ) -> None: """Create a new causal problem.""" + self._underlying_graph = graph self.causal_estimand = causal_estimand self.constraints = list(constraints) def lagrangian( self, n_samples: int = 1000, *, maximum_problem: bool = False - ) -> Callable[ - [dict[str, npt.ArrayLike], npt.ArrayLike, Model, jax.Array], npt.ArrayLike - ]: + ) -> Callable[[dict[str, npt.ArrayLike], npt.ArrayLike, jax.Array], npt.ArrayLike]: """ Return a function that evaluates the Lagrangian of this `CausalProblem`. @@ -68,21 +72,29 @@ def lagrangian( `ParameterNode`s) are taking. The second argument is a 1D vector of Lagrange multipliers, whose length is equal to the number of constraints. - The remaining two arguments of the Lagrangian are the underlying model that it - should use to generate and draw samples of the RVs from, and the PRNGKey that - should be used in this generation. + The remaining argument of the Lagrangian is the PRNGKey that should be used + when drawing samples. Note that our current implementation assumes there are no equality constraints being imposed (in which case, we would need a 3-argument Lagrangian function). - TODO: Can we store g.model in the class, and have this still work??? + Args: + n_samples: The number of random samples to be drawn when estimating the + value of functions of the RVs. + maximum_problem: If passed as `True`, assemble the Lagrangian for the + maximisation problem. Otherwise assemble that for the minimisation + problem (default behaviour). + + Returns: + The Lagrangian, as a function of the model parameters, Lagrange multipliers, + and PRNG key. + """ maximisation_prefactor = -1.0 if maximum_problem else 1.0 def _inner( parameter_values: dict[str, npt.ArrayLike], l_mult: jax.Array, - model: Model, rng_key: jax.Array, ) -> npt.ArrayLike: # In general, we will need to check which of our CE/CONs require masking, @@ -90,7 +102,9 @@ def _inner( # We can always pre-build the predictive models too, so we should replace # the "model" input with something that can map the right predictive models # to the CE/CONS that need them. - predictive_model = Predictive(model=model, num_samples=n_samples) + predictive_model = Predictive( + model=self._underlying_graph.model, num_samples=n_samples + ) all_samples = sample_model(predictive_model, rng_key, parameter_values) value = maximisation_prefactor * self.causal_estimand.do_with_samples( diff --git a/tests/test_integration/test_two_normal_example.py b/tests/test_integration/test_two_normal_example.py index f13fe81..bbae1b8 100644 --- a/tests/test_integration/test_two_normal_example.py +++ b/tests/test_integration/test_two_normal_example.py @@ -76,22 +76,22 @@ def test_two_normal_example( mu_x_sol = phi_observed - ce_prefactor * epsilon # Setup the optimisation problem from the graph - g = two_normal_graph(cov=1.0) - ce = CausalEstimand(do_with_samples=lambda **pv: ce_prefactor * pv["X"].mean()) + ce = CausalEstimand(do_with_samples=lambda **pv: pv["X"].mean()) con = Constraint( do_with_samples=lambda **pv: jnp.abs(pv["UX"].mean() - phi_observed) - epsilon ) cp = CausalProblem( + two_normal_graph(cov=1.0), con, causal_estimand=ce, ) - lagrangian = cp.lagrangian(n_samples=n_samples) + lagrangian = cp.lagrangian(n_samples=n_samples, maximum_problem=is_solving_max) # We'll be seeking stationary points of the Lagrangian, using the # naive approach of minimising the norm of its gradient. We will need to # ensure we "converge" to a minimum value suitably close to 0. - def objective(params, l_mult, predictive, key): - v = jax.grad(lagrangian, argnums=(0, 1))(params, l_mult, predictive, key) + def objective(params, l_mult, key): + v = jax.grad(lagrangian, argnums=(0, 1))(params, l_mult, key) return sum(value**2 for value in v[0].values()) + (v[1] ** 2).sum() # Choose a starting guess that is at the optimal solution, in the hopes that @@ -113,14 +113,12 @@ def objective(params, l_mult, predictive, key): converged = False for _ in range(maxiter): # Actual iteration loop - grads = jax.jacobian(objective, argnums=(0, 1))( - params, l_mult, g.model, rng_key - ) + grads = jax.jacobian(objective, argnums=(0, 1))(params, l_mult, rng_key) updates, opt_state = optimiser.update(grads, opt_state) params, l_mult = optax.apply_updates((params, l_mult), updates) # Convergence "check" and progress update - objective_value = objective(params, l_mult, g.model, rng_key) + objective_value = objective(params, l_mult, rng_key) sys.stdout.write( f"{_}, F_val={objective_value:.4e}, " f"mu_ux={params['mean']:.4e}, " From 33721938ca41b928cfebcc2d991765c7598a5454 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Fri, 15 Aug 2025 15:50:36 +0100 Subject: [PATCH 08/12] Make TODOs obvious so I don't forget to do them --- src/causalprog/causal_problem/causal_estimand.py | 3 +++ src/causalprog/causal_problem/causal_problem.py | 5 ++--- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/causalprog/causal_problem/causal_estimand.py b/src/causalprog/causal_problem/causal_estimand.py index 8ff50a4..02179a0 100644 --- a/src/causalprog/causal_problem/causal_estimand.py +++ b/src/causalprog/causal_problem/causal_estimand.py @@ -9,6 +9,9 @@ EffectHandler: TypeAlias = Callable[Concatenate[Model, ...], Model] ModelMask: TypeAlias = tuple[EffectHandler, dict] +# TODO: Docstrings, split into multiple files, +# auto-assembly of constraints wrt the tolerance? That kind of thing. + class _CPComponent: """Base.""" diff --git a/src/causalprog/causal_problem/causal_problem.py b/src/causalprog/causal_problem/causal_problem.py index e41824b..7677126 100644 --- a/src/causalprog/causal_problem/causal_problem.py +++ b/src/causalprog/causal_problem/causal_problem.py @@ -102,6 +102,7 @@ def _inner( # We can always pre-build the predictive models too, so we should replace # the "model" input with something that can map the right predictive models # to the CE/CONS that need them. + # TODO: Address pre-handlers that may apply from CEs/CONstraints predictive_model = Predictive( model=self._underlying_graph.model, num_samples=n_samples ) @@ -110,9 +111,7 @@ def _inner( value = maximisation_prefactor * self.causal_estimand.do_with_samples( **all_samples ) - # CLEANER IF THE LAGRANGE MULTIPLIERS COULD BE A SECOND FUNCTION ARG, - # as right now they have to be inside the parameter dict... - # Cleaner if we could somehow build a vector-valued function of the + # TODO: Cleaner if we could somehow build a vector-valued function of the # constraints and then take a dot product, but this works for now value += sum( l_mult[i] * c.do_with_samples(**all_samples) From 87fc5a2ec81c035a0cdf68c76d8c1dfaa7e69871 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Mon, 18 Aug 2025 09:39:27 +0100 Subject: [PATCH 09/12] Add docstrings and more TODOs --- src/causalprog/causal_problem/__init__.py | 1 + .../causal_problem/causal_estimand.py | 55 ++++++++++++++++++- .../causal_problem/causal_problem.py | 2 +- 3 files changed, 54 insertions(+), 4 deletions(-) diff --git a/src/causalprog/causal_problem/__init__.py b/src/causalprog/causal_problem/__init__.py index e69de29..3d45fe4 100644 --- a/src/causalprog/causal_problem/__init__.py +++ b/src/causalprog/causal_problem/__init__.py @@ -0,0 +1 @@ +"""Classes for defining causal problems.""" diff --git a/src/causalprog/causal_problem/causal_estimand.py b/src/causalprog/causal_problem/causal_estimand.py index 02179a0..e7a72fa 100644 --- a/src/causalprog/causal_problem/causal_estimand.py +++ b/src/causalprog/causal_problem/causal_estimand.py @@ -14,7 +14,22 @@ class _CPComponent: - """Base.""" + """ + Base class for components of a Causal Problem. + + A _CPComponent has an attached method that it can apply to samples + (`do_with_samples`), which will be passed sample values of the RVs + during solution of a Causal Problem and used to evaluate the causal + estimand or constraint the instance represents. + + It also has a sequence of effect handlers that need to be applied + to the sampling model before samples can be drawn to evaluate this + component. For example, if a component requires conditioning on the + value of a RV, the `condition` handler needs to be applied to the + underlying model, before generating samples to pass to the + `do_with_sample` method. `effect_handlers` will be applied to the model + in the order they are given. + """ do_with_samples: Callable[..., npt.ArrayLike] effect_handlers: tuple[ModelMask, ...] @@ -40,9 +55,43 @@ def apply_effects(self, model: Model) -> Model: return adapted_model +# TODO: Turn CausalEstimand and Constraint into callables that evaluate do_with_samples + + class CausalEstimand(_CPComponent): - """A Causal Estimand.""" + """ + A Causal Estimand. + + The causal estimand is the function that we want to minimise (and maximise) + as part of a causal problem. It should be a scalar-valued function of the + random variables appearing in a graph. + """ class Constraint(_CPComponent): - """A constraint function.""" + r""" + A Constraint that forms part of a causal problem. + + Constraints of a causal problem are derived properties of RVs for which we + have observed data. The causal estimand is minimised (or maximised) subject + to the predicted values of the constraints being close to their observed + values in the data. + + Adding a constraint $g(\theta)$ to a causal problem (where $\theta$ are the + parameters of the causal problem) essentially imposes an additional + constraint on the minimisation problem; + + $$ g(\theta) - g_{\text{data}} \leq \epsilon, $$ + + where $g_{\text{data}}$ is the observed data value for the quantity $g$, + and $\epsilon$ is some tolerance. + """ + + # TODO: Should explain that Constraint needs more inputs and slightly different + # interpretation of the `do_with_samples` object. + # Inputs: + # - include epsilon as an input (allows constraints to have different tolerances) + # - `do_with_samples` should just be $g(\theta)$. Then have the instance build the + # full constraint that will need to be called in the Lagrangian. + # - $g$ still needs to be scalar valued? Allow a wrapper function to be applied in + # the event $g$ is vector-valued. diff --git a/src/causalprog/causal_problem/causal_problem.py b/src/causalprog/causal_problem/causal_problem.py index 7677126..7b3780f 100644 --- a/src/causalprog/causal_problem/causal_problem.py +++ b/src/causalprog/causal_problem/causal_problem.py @@ -72,7 +72,7 @@ def lagrangian( `ParameterNode`s) are taking. The second argument is a 1D vector of Lagrange multipliers, whose length is equal to the number of constraints. - The remaining argument of the Lagrangian is the PRNGKey that should be used + The remaining argument of the Lagrangian is the PRNG Key that should be used when drawing samples. Note that our current implementation assumes there are no equality constraints From d202568b2ff58a05b50e01092f76e10d034b0b11 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Mon, 18 Aug 2025 10:13:55 +0100 Subject: [PATCH 10/12] Todo resolution and addition --- src/causalprog/causal_problem/causal_estimand.py | 3 --- src/causalprog/causal_problem/causal_problem.py | 8 ++++---- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/src/causalprog/causal_problem/causal_estimand.py b/src/causalprog/causal_problem/causal_estimand.py index e7a72fa..bb9635c 100644 --- a/src/causalprog/causal_problem/causal_estimand.py +++ b/src/causalprog/causal_problem/causal_estimand.py @@ -9,9 +9,6 @@ EffectHandler: TypeAlias = Callable[Concatenate[Model, ...], Model] ModelMask: TypeAlias = tuple[EffectHandler, dict] -# TODO: Docstrings, split into multiple files, -# auto-assembly of constraints wrt the tolerance? That kind of thing. - class _CPComponent: """ diff --git a/src/causalprog/causal_problem/causal_problem.py b/src/causalprog/causal_problem/causal_problem.py index 7b3780f..ffd8c16 100644 --- a/src/causalprog/causal_problem/causal_problem.py +++ b/src/causalprog/causal_problem/causal_problem.py @@ -10,14 +10,13 @@ from causalprog.graph import Graph +# TODO: Move somewhere more appropriate def sample_model( model: Predictive, rng_key: jax.Array, parameter_values: dict[str, npt.ArrayLike] ) -> dict[str, npt.ArrayLike]: """ Draw samples from the predictive model. - TODO: Move function to somewhere more appropriate. - Args: model: Predictive model to draw samples from. rng_key: PRNG Key to use in pseudorandom number generation. @@ -37,8 +36,6 @@ def sample_model( class CausalProblem: """Defines a causal problem.""" - # NB: A CausalProblem could just BE a graph!!!! - # But separation of concerns and all... _underlying_graph: Graph causal_estimand: CausalEstimand constraints: list[Constraint] @@ -108,6 +105,9 @@ def _inner( ) all_samples = sample_model(predictive_model, rng_key, parameter_values) + # TODO: would be cleaner if causal_estimand (and constraint) were just + # directly callable. This would also let us hide do_with_samples to avoid + # runtime edits... value = maximisation_prefactor * self.causal_estimand.do_with_samples( **all_samples ) From 618aa357504125f53e40fece61d54af6ed653fce Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Mon, 18 Aug 2025 11:24:53 +0100 Subject: [PATCH 11/12] Fix module-level placeholder docstrings --- src/causalprog/causal_problem/causal_estimand.py | 2 +- src/causalprog/causal_problem/causal_problem.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/causalprog/causal_problem/causal_estimand.py b/src/causalprog/causal_problem/causal_estimand.py index bb9635c..d916c7f 100644 --- a/src/causalprog/causal_problem/causal_estimand.py +++ b/src/causalprog/causal_problem/causal_estimand.py @@ -1,4 +1,4 @@ -"""C.""" +"""Classes for defining causal estimands and constraints of causal problems.""" from collections.abc import Callable from typing import Any, Concatenate, TypeAlias diff --git a/src/causalprog/causal_problem/causal_problem.py b/src/causalprog/causal_problem/causal_problem.py index ffd8c16..0ac6714 100644 --- a/src/causalprog/causal_problem/causal_problem.py +++ b/src/causalprog/causal_problem/causal_problem.py @@ -1,4 +1,4 @@ -"""C.""" +"""Classes for representing causal problems.""" from collections.abc import Callable From 3876bb64927e9ff209bd03f28a28ce71f58316d1 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Mon, 18 Aug 2025 11:45:47 +0100 Subject: [PATCH 12/12] Add issue tags to TODOs --- src/causalprog/causal_problem/causal_estimand.py | 6 ++---- src/causalprog/causal_problem/causal_problem.py | 11 ++++------- 2 files changed, 6 insertions(+), 11 deletions(-) diff --git a/src/causalprog/causal_problem/causal_estimand.py b/src/causalprog/causal_problem/causal_estimand.py index d916c7f..7166ec0 100644 --- a/src/causalprog/causal_problem/causal_estimand.py +++ b/src/causalprog/causal_problem/causal_estimand.py @@ -52,9 +52,6 @@ def apply_effects(self, model: Model) -> Model: return adapted_model -# TODO: Turn CausalEstimand and Constraint into callables that evaluate do_with_samples - - class CausalEstimand(_CPComponent): """ A Causal Estimand. @@ -84,7 +81,8 @@ class Constraint(_CPComponent): and $\epsilon$ is some tolerance. """ - # TODO: Should explain that Constraint needs more inputs and slightly different + # TODO: (https://github.com/UCL/causalprog/issues/89) + # Should explain that Constraint needs more inputs and slightly different # interpretation of the `do_with_samples` object. # Inputs: # - include epsilon as an input (allows constraints to have different tolerances) diff --git a/src/causalprog/causal_problem/causal_problem.py b/src/causalprog/causal_problem/causal_problem.py index 0ac6714..1feb441 100644 --- a/src/causalprog/causal_problem/causal_problem.py +++ b/src/causalprog/causal_problem/causal_problem.py @@ -10,7 +10,7 @@ from causalprog.graph import Graph -# TODO: Move somewhere more appropriate +# TODO: https://github.com/UCL/causalprog/issues/88 def sample_model( model: Predictive, rng_key: jax.Array, parameter_values: dict[str, npt.ArrayLike] ) -> dict[str, npt.ArrayLike]: @@ -99,20 +99,17 @@ def _inner( # We can always pre-build the predictive models too, so we should replace # the "model" input with something that can map the right predictive models # to the CE/CONS that need them. - # TODO: Address pre-handlers that may apply from CEs/CONstraints + # TODO: https://github.com/UCL/causalprog/issues/90 predictive_model = Predictive( model=self._underlying_graph.model, num_samples=n_samples ) all_samples = sample_model(predictive_model, rng_key, parameter_values) - # TODO: would be cleaner if causal_estimand (and constraint) were just - # directly callable. This would also let us hide do_with_samples to avoid - # runtime edits... + # TODO: https://github.com/UCL/causalprog/issues/86 value = maximisation_prefactor * self.causal_estimand.do_with_samples( **all_samples ) - # TODO: Cleaner if we could somehow build a vector-valued function of the - # constraints and then take a dot product, but this works for now + # TODO: https://github.com/UCL/causalprog/issues/87 value += sum( l_mult[i] * c.do_with_samples(**all_samples) for i, c in enumerate(self.constraints)