diff --git a/src/causalprog/causal_problem.py b/src/causalprog/causal_problem.py deleted file mode 100644 index e1232cc..0000000 --- a/src/causalprog/causal_problem.py +++ /dev/null @@ -1,319 +0,0 @@ -"""Extension of the Graph class providing features for solving causal problems.""" - -from collections.abc import Callable -from inspect import signature -from typing import Literal, TypeAlias - -import jax -import jax.numpy as jnp - -from causalprog._abc.labelled import Labelled -from causalprog.graph import Graph, Node - -CausalEstimand: TypeAlias = Callable[..., float] -Constraints: TypeAlias = Callable[..., float] - - -def raises(exception: Exception) -> Callable[[], float]: - """Create a callable that raises ``exception`` when called.""" - - def _inner() -> float: - raise exception - - return _inner - - -class CausalProblem(Labelled): - """ - Container class for handling a causal problem. - - A causal problem - requires an underlying ``Graph`` to describe the relationships between the random - variables and parameters, plus a causal estimand and list of (data) constraints. - Structural constraints are handled by imposing restrictions on forms of the random - variables and constraints directly. - - A ``CausalProblem`` instance brings together these components, providing a container - for a causal problem that can be given inputs like empirical data, a solver - tolerance, etc, and will provide (estimates of) the bounds for the causal estimand. - - - The ``.graph`` attribute stores the underlying ``Graph`` object. - - The ``.causal_estimand`` method evaluates the causal estimand, given values for - the parameters. - - The ``.constraints`` method evaluates the (vector-valued) constraints, given - values for the parameters. - - The user must specify each of the above before a ``CausalProblem`` can be solved. 
- The primary way for this to be done is to construct or load the corresponding - ``Graph``, and provide it by setting the ``CausalProblem.graph`` attribute directly. - Then, `set_causal_estimand` and `set_constraints` can be used to provide the causal - estimand and constraints functions, in terms of the random variables. The - ``CausalProblem`` instance will handle turning them into functions of the parameter - values under the hood. Initial parameter values (for the purposes of solving) can be - provided to the solver method directly or set beforehand via ``set_parameters``. It - should never be necessary for the user to interact with, or provide, a vector of - parameters (as this is taken care of under the hood). - """ - - _graph: Graph | None - _sigma: CausalEstimand - _sigma_mapping: dict[str, Node] - _constraints: Constraints - _constraints_mapping: dict[str, Node] - _parameter_values: dict[str, float] - - @property - def graph(self) -> Graph: - """Graph defining the structure of the `CausalProblem`.""" - if self._graph is None: - msg = f"No graph set for {self.label}." - raise ValueError(msg) - return self._graph - - @graph.setter - def graph(self, new_graph: Graph) -> None: - if not isinstance(new_graph, Graph): - msg = f"{self.label}.graph must be a Graph instance." 
- raise TypeError(msg) - self._graph = new_graph - - @property - def parameter_values(self) -> dict[str, float]: - """Dictionary mapping parameter labels to their (current) values.""" - return self._parameter_vector_to_dict(self.parameter_vector) - - @property - def parameter_vector(self) -> jax.Array: - """Returns the (current) vector of parameter values.""" - return jnp.array( - tuple( - self._parameter_values[node.label] - if node.label in self._parameter_values - else float("NaN") - for node in self.graph.parameter_nodes - ), - ndmin=1, - ) - - def __init__( - self, - graph: Graph | None = None, - *, - label: str = "CausalProblem", - ) -> None: - """Set up a new CausalProblem.""" - super().__init__(label=label) - - self._parameter_values = {} - self._graph = graph - - # Callables cannot be evaluated until they are explicitly set - self.set_causal_estimand( - raises(NotImplementedError(f"Causal estimand not set for {self.label}.")) - ) - self.set_constraints( - raises(NotImplementedError(f"Constraints not set for {self.label}.")) - ) - - def _parameter_vector_to_dict( - self, parameter_vector: jax.Array - ) -> dict[str, float]: - """ - Convert a parameter vector to a dictionary mapping labels to parameter values. - - Convention is that a vector of parameter values contains values in the same - order as self.graph.parameter_nodes. - """ - # Avoid recomputing the parameter node tuple every time. - pn = self.graph.parameter_nodes - return {pn[i].label: value for i, value in enumerate(parameter_vector)} - - def _eval_callable( - self, which: Literal["sigma", "constraints"], at: jax.Array - ) -> jax.Array: - """ - Evaluate a callable method of this instance. - - This is an abstraction method for when the causal estimand or constraints - functions need to be evaluated. In each case, the process is the same: - - - Update the values of the parameter nodes. - - Call the underlying function composed with its mapping of RVs to Nodes. 
- - The method is abstracted here so that any changes to the process are reflected - in both methods automatically. - """ - self._set_parameters_via_vector(at) - if "parameter_values" in signature(getattr(self, f"_{which}")).parameters: - return getattr(self, f"_{which}")( - **getattr(self, f"_{which}_mapping"), - parameter_values=self._parameter_values, - ) - return getattr(self, f"_{which}")(**getattr(self, f"_{which}_mapping")) - - def _set_callable( - self, - which: Literal["sigma", "constraints"], - *, - fn: CausalEstimand | Constraints, - rvs_to_nodes: dict[str, str] | None = None, - graph_argument: str | None = None, - ) -> None: - """ - Abstracted method for setting the Causal Estimand and/or Constraints functions. - - The functionality for setting up these two methods of an instance are identical, - save for the attributes which need to be updated. As such, we can refactor the - common functionality into a single, hidden, method and provide a friendlier - access point for users to employ. - """ - fn_attr = f"_{which}" - map_attr = f"_{which}_mapping" - debug_name = "constraints" if which == "constraints" else "causal estimand" - - setattr(self, fn_attr, fn) - setattr(self, map_attr, {}) - - if rvs_to_nodes is None: - rvs_to_nodes = {} - fn_args = signature(fn).parameters - - for rv_name, node_label in rvs_to_nodes.items(): - if rv_name not in fn_args: - msg = f"{rv_name} is not an argument to provided {debug_name} function." - raise ValueError(msg) - getattr(self, map_attr)[rv_name] = self.graph.get_node(node_label) - - # Any unaccounted-for RV arguments to sigma are assumed to match - # the label of the corresponding node. - args_not_used = set(fn_args) - set(getattr(self, map_attr)) - - ## Temporary hack to ensure that we can use expectation(graph, X) syntax. 
- if graph_argument: - getattr(self, map_attr)[graph_argument] = self.graph - args_not_used -= {graph_argument} - ## END HACK - - for arg in args_not_used: - if arg != "parameter_values": - getattr(self, map_attr)[arg] = self.graph.get_node(arg) - - def _set_parameters_via_vector(self, parameter_vector: jax.Array | None) -> None: - """ - Shorthand to set parameter node values from a parameter vector. - - No intended for frontend use - primary use will be internal when running - optimisation methods over the CausalProblem, when we need to treat the - parameters as a vector or array of function inputs. - """ - self.set_parameter_values(**self._parameter_vector_to_dict(parameter_vector)) - - def set_parameter_values(self, **parameter_values: float | None) -> None: - """ - Set (initial) parameter values for this CausalProblem. - - See ``Graph.set_parameters`` for input details. - """ - for parameter, value in parameter_values.items(): - if value is None: - if parameter in self._parameter_values: - del self._parameter_values[parameter] - else: - self._parameter_values[parameter] = value - - def set_causal_estimand( - self, - sigma: CausalEstimand, - rvs_to_nodes: dict[str, str] | None = None, - graph_argument: str | None = None, - ) -> None: - """ - Set the Causal Estimand for this problem. - - `sigma` should be a callable object that defines the Causal Estimand of - interest, in terms of the random variables of to the problem. The - random variables are in turn represented by `Node`s, with this association being - recorded in the `rv_to_nodes` dictionary. - - The `causal_estimand` method of the instance will be usable once this method - completes. - - Args: - sigma (CausalEstimand): Callable object that evaluates the causal estimand - of interest for this `CausalProblem`, in terms of the random variables, - which are the arguments to this callable. `sigma`s with additional - arguments are not currently supported. 
- rvs_to_nodes (dict[str, str]): Mapping of random variable (argument) names - of `sigma` to the labels of the corresponding `Node`s representing the - random variables. Argument names that match their corresponding `Node` - label can be omitted. - graph_argument (str): Argument to `sigma` that should be replaced with - `self.graph`. This argument is only temporary, as we are currently - limited to the syntax `expectation(Graph, Node)` rather than just - `expectation(Node)`. It will be removed in the future when methods like - `expectation` can be called solely on `Node` objects. - - """ - self._set_callable( - "sigma", fn=sigma, rvs_to_nodes=rvs_to_nodes, graph_argument=graph_argument - ) - - def set_constraints( - self, - constraints: CausalEstimand, - rvs_to_nodes: dict[str, str] | None = None, - graph_argument: str | None = None, - ) -> None: - """ - Set the Constraints for this problem. - - ``constraints`` should be a callable object that defines the Data Constraints of - interest, in terms of the random variables of to the problem. The - random variables are in turn represented by `Node`s, with this association being - recorded in the `rv_to_nodes` dictionary. - - The `constraints` method of the instance will be usable once this method - completes. - - Args: - constraints (Constraints): Callable object that evaluates the constraints - of interest for this `CausalProblem`, in terms of the random variables, - which are the arguments to this callable. ``constraints``s with - additional arguments are not currently supported. - rvs_to_nodes (dict[str, str]): Mapping of random variable (argument) names - of `sigma` to the labels of the corresponding `Node`s representing the - random variables. Argument names that match their corresponding `Node` - label can be omitted. - graph_argument (str): Argument to `sigma` that should be replaced with - `self.graph`. 
This argument is only temporary, as we are currently - limited to the syntax `expectation(Graph, Node)` rather than just - `expectation(Node)`. It will be removed in the future when methods like - `expectation` can be called solely on `Node` objects. - - """ - self._set_callable( - "constraints", - fn=constraints, - rvs_to_nodes=rvs_to_nodes, - graph_argument=graph_argument, - ) - - def causal_estimand(self, p: jax.Array) -> float: - """ - Evaluate the Causal Estimand at parameter vector `p`. - - Args: - p (jax.Array): Vector of parameter values to evaluate at. - - """ - return self._eval_callable("sigma", p) - - def constraints(self, p: jax.Array) -> jax.Array: - """ - Evaluate the Constraints at parameter vector `p`. - - Args: - p (jax.Array): Vector of parameter values to evaluate at. - - """ - return self._eval_callable("constraints", p) diff --git a/src/causalprog/causal_problem/__init__.py b/src/causalprog/causal_problem/__init__.py new file mode 100644 index 0000000..3d45fe4 --- /dev/null +++ b/src/causalprog/causal_problem/__init__.py @@ -0,0 +1 @@ +"""Classes for defining causal problems.""" diff --git a/src/causalprog/causal_problem/causal_estimand.py b/src/causalprog/causal_problem/causal_estimand.py new file mode 100644 index 0000000..7166ec0 --- /dev/null +++ b/src/causalprog/causal_problem/causal_estimand.py @@ -0,0 +1,92 @@ +"""Classes for defining causal estimands and constraints of causal problems.""" + +from collections.abc import Callable +from typing import Any, Concatenate, TypeAlias + +import numpy.typing as npt + +Model: TypeAlias = Callable[..., Any] +EffectHandler: TypeAlias = Callable[Concatenate[Model, ...], Model] +ModelMask: TypeAlias = tuple[EffectHandler, dict] + + +class _CPComponent: + """ + Base class for components of a Causal Problem. 
+ + A _CPComponent has an attached method that it can apply to samples + (`do_with_samples`), which will be passed sample values of the RVs + during solution of a Causal Problem and used to evaluate the causal + estimand or constraint the instance represents. + + It also has a sequence of effect handlers that need to be applied + to the sampling model before samples can be drawn to evaluate this + component. For example, if a component requires conditioning on the + value of a RV, the `condition` handler needs to be applied to the + underlying model, before generating samples to pass to the + `do_with_samples` method. `effect_handlers` will be applied to the model + in the order they are given. + """ + + do_with_samples: Callable[..., npt.ArrayLike] + effect_handlers: tuple[ModelMask, ...] + + @property + def requires_model_adaption(self) -> bool: + """Return True if effect handlers need to be applied to model.""" + return len(self.effect_handlers) > 0 + + def __init__( + self, + *effect_handlers: ModelMask, + do_with_samples: Callable[..., npt.ArrayLike], + ) -> None: + self.effect_handlers = tuple(effect_handlers) + self.do_with_samples = do_with_samples + + def apply_effects(self, model: Model) -> Model: + """Apply any necessary effect handlers prior to evaluating.""" + adapted_model = model + for handler, handler_options in self.effect_handlers: + adapted_model = handler(adapted_model, **handler_options) + return adapted_model + + +class CausalEstimand(_CPComponent): + """ + A Causal Estimand. + + The causal estimand is the function that we want to minimise (and maximise) + as part of a causal problem. It should be a scalar-valued function of the + random variables appearing in a graph. + """ + + +class Constraint(_CPComponent): + r""" + A Constraint that forms part of a causal problem. + + Constraints of a causal problem are derived properties of RVs for which we + have observed data. 
The causal estimand is minimised (or maximised) subject + to the predicted values of the constraints being close to their observed + values in the data. + + Adding a constraint $g(\theta)$ to a causal problem (where $\theta$ are the + parameters of the causal problem) essentially imposes an additional + constraint on the minimisation problem; + + $$ g(\theta) - g_{\text{data}} \leq \epsilon, $$ + + where $g_{\text{data}}$ is the observed data value for the quantity $g$, + and $\epsilon$ is some tolerance. + """ + + # TODO: (https://github.com/UCL/causalprog/issues/89) + # Should explain that Constraint needs more inputs and slightly different + # interpretation of the `do_with_samples` object. + # Inputs: + # - include epsilon as an input (allows constraints to have different tolerances) + # - `do_with_samples` should just be $g(\theta)$. Then have the instance build the + # full constraint that will need to be called in the Lagrangian. + # - $g$ still needs to be scalar valued? Allow a wrapper function to be applied in + # the event $g$ is vector-valued. diff --git a/src/causalprog/causal_problem/causal_problem.py b/src/causalprog/causal_problem/causal_problem.py new file mode 100644 index 0000000..1feb441 --- /dev/null +++ b/src/causalprog/causal_problem/causal_problem.py @@ -0,0 +1,119 @@ +"""Classes for representing causal problems.""" + +from collections.abc import Callable + +import jax +import numpy.typing as npt +from numpyro.infer import Predictive + +from causalprog.causal_problem.causal_estimand import CausalEstimand, Constraint +from causalprog.graph import Graph + + +# TODO: https://github.com/UCL/causalprog/issues/88 +def sample_model( + model: Predictive, rng_key: jax.Array, parameter_values: dict[str, npt.ArrayLike] +) -> dict[str, npt.ArrayLike]: + """ + Draw samples from the predictive model. + + Args: + model: Predictive model to draw samples from. + rng_key: PRNG Key to use in pseudorandom number generation. 
+ parameter_values: Model parameter values to substitute. + + Returns: + `dict` of samples, with RV labels as keys and sample values (`jax.Array`s) as + values. + + """ + return jax.vmap(lambda pv, key: model(key, **pv), in_axes=(None, 0))( + parameter_values, + jax.random.split(rng_key, model.num_samples), + ) + + +class CausalProblem: + """Defines a causal problem.""" + + _underlying_graph: Graph + causal_estimand: CausalEstimand + constraints: list[Constraint] + + def __init__( + self, + graph: Graph, + *constraints: Constraint, + causal_estimand: CausalEstimand, + ) -> None: + """Create a new causal problem.""" + self._underlying_graph = graph + self.causal_estimand = causal_estimand + self.constraints = list(constraints) + + def lagrangian( + self, n_samples: int = 1000, *, maximum_problem: bool = False + ) -> Callable[[dict[str, npt.ArrayLike], npt.ArrayLike, jax.Array], npt.ArrayLike]: + """ + Return a function that evaluates the Lagrangian of this `CausalProblem`. + + Following the + [KKT theorem](https://en.wikipedia.org/wiki/Karush%E2%80%93Kuhn%E2%80%93Tucker_conditions), + given the causal estimand and the constraints we can assemble a Lagrangian and + seek its stationary points, to in turn identify minimisers of the constrained + optimisation problem that we started with. + + The Lagrangian returned is a mathematical function of its first two arguments. + The first argument is the same dictionary of parameters that is passed to models + like `Graph.model`, and gives the values the parameters (represented by the + `ParameterNode`s) are taking. The second argument is a 1D vector of Lagrange + multipliers, whose length is equal to the number of constraints. + + The remaining argument of the Lagrangian is the PRNG Key that should be used + when drawing samples. + + Note that our current implementation assumes that no equality constraints are + imposed (otherwise the Lagrangian would need a third mathematical argument). 
+ + Args: + n_samples: The number of random samples to be drawn when estimating the + value of functions of the RVs. + maximum_problem: If passed as `True`, assemble the Lagrangian for the + maximisation problem. Otherwise assemble that for the minimisation + problem (default behaviour). + + Returns: + The Lagrangian, as a function of the model parameters, Lagrange multipliers, + and PRNG key. + + """ + maximisation_prefactor = -1.0 if maximum_problem else 1.0 + + def _inner( + parameter_values: dict[str, npt.ArrayLike], + l_mult: jax.Array, + rng_key: jax.Array, + ) -> npt.ArrayLike: + # In general, we will need to check which of our CE/CONs require masking, + # and do multiple predictive models to account for this... + # We can always pre-build the predictive models too, so we should replace + # the "model" input with something that can map the right predictive models + # to the CE/CONS that need them. + # TODO: https://github.com/UCL/causalprog/issues/90 + predictive_model = Predictive( + model=self._underlying_graph.model, num_samples=n_samples + ) + all_samples = sample_model(predictive_model, rng_key, parameter_values) + + # TODO: https://github.com/UCL/causalprog/issues/86 + value = maximisation_prefactor * self.causal_estimand.do_with_samples( + **all_samples + ) + # TODO: https://github.com/UCL/causalprog/issues/87 + value += sum( + l_mult[i] * c.do_with_samples(**all_samples) + for i, c in enumerate(self.constraints) + ) + return value + + return _inner diff --git a/tests/test_causal_problem/test_callables.py b/tests/test_causal_problem/test_callables.py deleted file mode 100644 index f70d2f2..0000000 --- a/tests/test_causal_problem/test_callables.py +++ /dev/null @@ -1,212 +0,0 @@ -from collections.abc import Callable -from typing import Literal - -import jax -import jax.numpy as jnp -import pytest - -from causalprog.algorithms import expectation, standard_deviation -from causalprog.causal_problem import CausalProblem -from causalprog.graph import Graph, 
Node - - -@pytest.fixture -def n_samples_for_estimands() -> int: - return 1000 - - -@pytest.fixture -def expectation_fixture( - rng_key: jax.Array, - n_samples_for_estimands: int, -) -> Callable[[Graph, Node, dict[str, float] | None], float]: - def _inner(g: Graph, x: Node, parameter_values=None) -> float: - return expectation( - g, - x.label, - samples=n_samples_for_estimands, - rng_key=rng_key, - parameter_values=parameter_values, - ) - - return _inner - - -@pytest.fixture -def std_fixture( - rng_key: jax.Array, - n_samples_for_estimands: int, -) -> Callable[[Graph, Node, dict[str, float] | None], float]: - def _inner(g: Graph, x: Node, parameter_values=None) -> float: - return ( - standard_deviation( - g, - x.label, - samples=n_samples_for_estimands, - rng_key=rng_key, - parameter_values=parameter_values, - ) - ** 2 - ) - - return _inner - - -@pytest.fixture -def vector_fixture( - rng_key: jax.Array, - n_samples_for_estimands: int, -) -> Callable[[Graph, Node, Node, dict[str, float] | None], jax.Array]: - """vector_fixture(g, x1, x2) = [mean of x1, std of x2].""" - - def _inner(g: Graph, x1: Node, x2: Node, parameter_values=None) -> jax.Array: - return jnp.array( - [ - expectation( - g, - x1.label, - samples=n_samples_for_estimands, - rng_key=rng_key, - parameter_values=parameter_values, - ), - standard_deviation( - g, - x2.label, - samples=n_samples_for_estimands, - rng_key=rng_key, - parameter_values=parameter_values, - ) - ** 2, - ] - ) - - return _inner - - -@pytest.fixture(params=["causal_estimand", "constraints"]) -def which(request: pytest.FixtureRequest) -> Literal["causal_estimand", "constraints"]: - """For tests applicable to both the causal_estimand and constraints methods.""" - return request.param - - -@pytest.mark.parametrize( - ("initial_param_values", "args_to_setter", "expected", "atol"), - [ - pytest.param( - {"mean": 1.0, "cov2": 1.0}, - { - "fn": "expectation_fixture", - "rvs_to_nodes": {"x": "mean"}, - "graph_argument": "g", - }, - 1.0, - 
1.0e-12, - id="mean", - ), - pytest.param( - {"mean": 1.0, "cov2": 1.0}, - { - "fn": "expectation_fixture", - "rvs_to_nodes": {"x": "cov2"}, - "graph_argument": "g", - }, - 1.0, - 1.0e-12, - id="cov2", - ), - pytest.param( - {"mean": 0.0, "cov2": 1.0}, - { - "fn": "expectation_fixture", - "rvs_to_nodes": {"x": "X"}, - "graph_argument": "g", - }, - 0.0, - 3.0e-2, - id="E[x], infer association", - ), - pytest.param( - {"mean": 0.0, "cov2": 1.0}, - { - "fn": "std_fixture", - "rvs_to_nodes": {"x": "X"}, - "graph_argument": "g", - }, - # UX has fixed std 1, and cov2 will be set to 1. - 1.0**2 + 1.0**2, - 3.0e-1, - id="Var[y]", - ), - pytest.param( - {"mean": 0.0, "cov2": 1.0}, - { - "fn": "vector_fixture", - "rvs_to_nodes": {"x1": "UX", "x2": "X"}, - "graph_argument": "g", - }, - # As per the previous test cases - jnp.array([0.0, 1.0**2 + 1.0**2]), - jnp.array([3.0e-2, 2.0e-1]), - id="E[x], Var[y]", - ), - ], -) -def test_callables( - two_normal_graph: Callable[..., Graph], - which: Literal["causal_estimand", "constraints"], - initial_param_values: dict[str, float], - args_to_setter: dict[str, Callable[..., float] | dict[str, str] | str], - expected: float | jax.Array, - atol: float, - request: pytest.FixtureRequest, - raises_context, -) -> None: - """ - Test the set_{causal_estimand, constraints} and .{casual_estimand, constraints} - evaluation method. - - Test works by: - - Set the parameter values using the initial_param_values. - - Call the setter method using the given arguments. - - Evaluate the method that should have been set at the current parameter_vector, - which should evaluate the corresponding function at the current values of the - parameter vector, which will be the initial values just set. - - Check the result (lies within a given tolerance). 
- - In theory, there is no difference between the causal estimand and constraints when - it comes to this test - the constraints may be vector-valued but there is nothing - preventing the ``causal_estimand`` (programmatically) from being vector-valued - either. - """ - # Parametrised fixtures edit-in-place objects - args_to_setter = dict(args_to_setter) - if isinstance(args_to_setter["fn"], str): - args_to_setter["fn"] = request.getfixturevalue(args_to_setter["fn"]) - - if which == "constraints": - args_to_setter["constraints"] = args_to_setter.pop("fn") - else: - args_to_setter["sigma"] = args_to_setter.pop("fn") - - expected = jnp.array(expected, ndmin=1) - - # Test properly begins. - graph = two_normal_graph(cov=1.0) - cp = CausalProblem(graph) - - method = getattr(cp, which) - setter_method = getattr(cp, f"set_{which}") - - # Before setting the causal estimand, it should throw an error if called. - with raises_context( - NotImplementedError( - f"{which.replace('_', ' ').capitalize()} not set for CausalProblem." 
- ) - ): - method(cp.parameter_vector) - - cp.set_parameter_values(**initial_param_values) - setter_method(**args_to_setter) - result = jnp.array(method(cp.parameter_vector), ndmin=1) - - assert jnp.allclose(result, expected, atol=atol) diff --git a/tests/test_causal_problem/test_graph_and_param.py b/tests/test_causal_problem/test_graph_and_param.py deleted file mode 100644 index 96ea924..0000000 --- a/tests/test_causal_problem/test_graph_and_param.py +++ /dev/null @@ -1,42 +0,0 @@ -from collections.abc import Callable - -import jax.numpy as jnp - -from causalprog.causal_problem import CausalProblem -from causalprog.graph import Graph - - -def test_graph_and_parameter_interactions( - two_normal_graph: Callable[..., Graph], - raises_context, -) -> None: - cp = CausalProblem(label="TestCP") - - # Without a graph, we can't do anything - with raises_context(ValueError("No graph set for TestCP")): - cp.graph # noqa: B018 - with raises_context(ValueError("No graph set for TestCP")): - cp.parameter_values # noqa: B018 - - # Cannot set graph to non-graph value - with raises_context(TypeError("TestCP.graph must be a Graph instance")): - cp.graph = 1.0 - - # Provide an actual graph value - cp.graph = two_normal_graph(cov=1.0) - - # We should now be able to fetch parameter values, but they are all unset. - assert jnp.all(jnp.isnan(cp.parameter_vector)) - assert cp.parameter_vector.shape == (len(cp.graph.parameter_nodes),) - assert all(jnp.isnan(value) for value in cp.parameter_values.values()) - assert set(cp.parameter_values.keys()) == {"mean", "cov2"} - - # Users should only ever need to set parameter values via their names. - cp.set_parameter_values(mean=1.0, cov2=2.0) - assert cp.parameter_values == {"mean": 1.0, "cov2": 2.0} - # We don't know which way round the internal parameter vector is being stored, - # but that doesn't matter. We do know that it should contain the values 1 & 2 - # in some order though. 
- assert jnp.allclose(cp.parameter_vector, jnp.array([1.0, 2.0])) or jnp.allclose( - cp.parameter_vector, jnp.array([2.0, 1.0]) - ) diff --git a/tests/test_integration/test_two_normal_example.py b/tests/test_integration/test_two_normal_example.py index 1957a61..bbae1b8 100644 --- a/tests/test_integration/test_two_normal_example.py +++ b/tests/test_integration/test_two_normal_example.py @@ -1,12 +1,13 @@ +import sys from collections.abc import Callable import jax import jax.numpy as jnp -import numpy.typing as npt import optax import pytest -from numpyro.infer import Predictive +from causalprog.causal_problem.causal_estimand import CausalEstimand, Constraint +from causalprog.causal_problem.causal_problem import CausalProblem from causalprog.graph import Graph @@ -71,56 +72,27 @@ def test_two_normal_example( (again with $+\mu_{ux}$ in the maximisation case). In both cases, $\mathcal{L}$ is minimised at $(\mu_{ux}^{*}, \nu_x, 1)$. """ - # Setup the optimisation problem from the graph - g = two_normal_graph(cov=1.0) - predictive_model = Predictive(g.model, num_samples=n_samples) - - def lagrangian( - parameter_values: dict[str, npt.ArrayLike], - predictive_model: Predictive, - rng_key: jax.Array, - *, - ce_prefactor: float, - ): - subkeys = jax.random.split(rng_key, predictive_model.num_samples) - l_mult = parameter_values["_l_mult"] - - def _ux_sampler(pv: dict[str, npt.ArrayLike], key: jax.Array) -> float: - return predictive_model(key, **pv)["UX"] - - def _x_sampler(pv: dict[str, npt.ArrayLike], key: jax.Array) -> float: - return predictive_model(key, **pv)["X"] - - def _ce(pv, subkeys): - return ( - ce_prefactor - * jax.vmap(_x_sampler, in_axes=(None, 0))(pv, subkeys).mean() - ) - - def _constraint(pv, subkeys): - return ( - jnp.abs( - jax.vmap(_ux_sampler, in_axes=(None, 0))(pv, subkeys).mean() - - phi_observed - ) - - epsilon - ) - - return _ce(parameter_values, subkeys) + l_mult * _constraint( - parameter_values, subkeys - ) - - # In both cases, the Lagrange 
multiplier has the value 1.0 at the minimum. - lambda_sol = 1.0 ce_prefactor = 1.0 if not is_solving_max else -1.0 mu_x_sol = phi_observed - ce_prefactor * epsilon + # Setup the optimisation problem from the graph + ce = CausalEstimand(do_with_samples=lambda **pv: pv["X"].mean()) + con = Constraint( + do_with_samples=lambda **pv: jnp.abs(pv["UX"].mean() - phi_observed) - epsilon + ) + cp = CausalProblem( + two_normal_graph(cov=1.0), + con, + causal_estimand=ce, + ) + lagrangian = cp.lagrangian(n_samples=n_samples, maximum_problem=is_solving_max) + # We'll be seeking stationary points of the Lagrangian, using the # naive approach of minimising the norm of its gradient. We will need to # ensure we "converge" to a minimum value suitably close to 0. - def objective(params, predictive, key, ce_prefactor=ce_prefactor): - v = jax.grad(lagrangian)(params, predictive, key, ce_prefactor=ce_prefactor) - return sum(value**2 for value in v.values()) + def objective(params, l_mult, key): + v = jax.grad(lagrangian, argnums=(0, 1))(params, l_mult, key) + return sum(value**2 for value in v[0].values()) + (v[1] ** 2).sum() # Choose a starting guess that is at the optimal solution, in the hopes that # SGD converges quickly. We almost certainly will not have this luxury in general. 
@@ -130,26 +102,29 @@ def objective(params, predictive, key, ce_prefactor=ce_prefactor): params = { "mean": mu_x_sol, "cov2": nu_x_starting_value, - "_l_mult": lambda_sol, } + l_mult = jnp.atleast_1d(lagrange_mult_sol) + # Setup SGD optimiser optimiser = optax.adam(adams_learning_rate) - opt_state = optimiser.init(params) + opt_state = optimiser.init((params, l_mult)) + # Run optimisation loop on gradient of the Lagrangian converged = False for _ in range(maxiter): # Actual iteration loop - grads = jax.jacobian(objective)( - params, predictive_model, rng_key, ce_prefactor=ce_prefactor - ) + grads = jax.jacobian(objective, argnums=(0, 1))(params, l_mult, rng_key) updates, opt_state = optimiser.update(grads, opt_state) - params = optax.apply_updates(params, updates) + params, l_mult = optax.apply_updates((params, l_mult), updates) # Convergence "check" and progress update - objective_value = objective( - params, predictive_model, rng_key, ce_prefactor=ce_prefactor + objective_value = objective(params, l_mult, rng_key) + sys.stdout.write( + f"{_}, F_val={objective_value:.4e}, " + f"mu_ux={params['mean']:.4e}, " + f"nu_x={params['cov2']:.4e}, " + f"lambda={l_mult[0]:.4e}\n" ) - if jnp.abs(objective_value) <= minimisation_tolerance: converged = True break @@ -162,12 +137,12 @@ def objective(params, predictive, key, ce_prefactor=ce_prefactor): ) # Confirm that we found a minimiser that does satisfy the inequality constraints. - assert params["_l_mult"] > 0.0, ( - f"Converged, but not to a minimiser (lagrange multiplier = {params['_l_mult']})" + assert jnp.all(l_mult > 0.0), ( + f"Converged, but not to a minimiser (lagrange multiplier = {l_mult})" ) # Give a generous error margin in mu_ux and the Lagrange multiplier, # given SGD is being used on MC-integral functions. 
rtol = jnp.sqrt(1.0 / n_samples) assert jnp.isclose(params["mean"], mu_x_sol, rtol=rtol) - assert jnp.isclose(params["_l_mult"], lagrange_mult_sol, atol=rtol) + assert jnp.allclose(l_mult, lagrange_mult_sol, atol=rtol)