diff --git a/src/causalprog/algorithms/__init__.py b/src/causalprog/algorithms/__init__.py index 879d9aa..da37e9a 100644 --- a/src/causalprog/algorithms/__init__.py +++ b/src/causalprog/algorithms/__init__.py @@ -1,4 +1,4 @@ """Algorithms.""" from .do import do -from .expectation import expectation, standard_deviation +from .moments import expectation, moment, standard_deviation diff --git a/src/causalprog/algorithms/expectation.py b/src/causalprog/algorithms/expectation.py deleted file mode 100644 index f1e4ce4..0000000 --- a/src/causalprog/algorithms/expectation.py +++ /dev/null @@ -1,46 +0,0 @@ -"""Algorithms for estimating the expectation and standard deviation.""" - -import jax -import numpy.typing as npt - -from causalprog.graph import Graph - - -def sample( - graph: Graph, - outcome_node_label: str, - samples: int, - *, - rng_key: jax.Array, -) -> npt.NDArray[float]: - """Sample data from a graph.""" - nodes = graph.roots_down_to_outcome(outcome_node_label) - - values: dict[str, npt.NDArray[float]] = {} - keys = jax.random.split(rng_key, len(nodes)) - - for node, key in zip(nodes, keys, strict=False): - values[node.label] = node.sample(values, samples, rng_key=key) - return values[outcome_node_label] - - -def expectation( - graph: Graph, - outcome_node_label: str, - samples: int, - *, - rng_key: jax.Array, -) -> float: - """Estimate the expectation of a graph.""" - return sample(graph, outcome_node_label, samples, rng_key=rng_key).mean() - - -def standard_deviation( - graph: Graph, - outcome_node_label: str, - samples: int, - *, - rng_key: jax.Array, -) -> float: - """Estimate the standard deviation of a graph.""" - return sample(graph, outcome_node_label, samples, rng_key=rng_key).std() diff --git a/src/causalprog/algorithms/moments.py b/src/causalprog/algorithms/moments.py new file mode 100644 index 0000000..c196d01 --- /dev/null +++ b/src/causalprog/algorithms/moments.py @@ -0,0 +1,72 @@ +"""Algorithms for estimating the expectation and standard deviation.""" + +import jax +import numpy.typing as npt + +from causalprog.graph import Graph + + +def sample( + graph: Graph, + outcome_node_label: str, + samples: int, + *, + rng_key: jax.Array, +) -> npt.NDArray[float]: + """Sample data from (a random variable attached to) a node in a graph.""" + nodes = graph.roots_down_to_outcome(outcome_node_label) + + values: dict[str, npt.NDArray[float]] = {} + keys = jax.random.split(rng_key, len(nodes)) + + for node, key in zip(nodes, keys, strict=False): + values[node.label] = node.sample(values, samples, rng_key=key) + return values[outcome_node_label] + + +def expectation( + graph: Graph, + outcome_node_label: str, + samples: int, + *, + rng_key: jax.Array, +) -> float: + """Estimate the expectation of (a random variable attached to) a node in a graph.""" + return moment(1, graph, outcome_node_label, samples, rng_key=rng_key) + + +def standard_deviation( + graph: Graph, + outcome_node_label: str, + samples: int, + *, + rng_key: jax.Array, + rng_key_first_moment: jax.Array | None = None, +) -> float: + """Estimate the standard deviation of (a RV attached to) a node in a graph.""" + return ( + moment(2, graph, outcome_node_label, samples, rng_key=rng_key) + - moment( + 1, + graph, + outcome_node_label, + samples, + rng_key=rng_key if rng_key_first_moment is None else rng_key_first_moment, + ) + ** 2 + ) ** 0.5 + + +def moment( + order: int, + graph: Graph, + outcome_node_label: str, + samples: int, + *, + rng_key: jax.Array, +) -> float: + """Estimate a moment of (a random variable attached to) a node in a graph.""" + return ( + sum(sample(graph, outcome_node_label, samples, rng_key=rng_key) ** order) + / samples + ) diff --git a/src/causalprog/graph/node.py b/src/causalprog/graph/node.py index ac9f571..b161c70 100644 --- a/src/causalprog/graph/node.py +++ b/src/causalprog/graph/node.py @@ -169,19 +169,20 @@ def sample( samples: int, rng_key: jax.Array, ) -> npt.NDArray[float]: + d = self._dist( + # Pass in node values derived from construction so far + **{ + native_name: sampled_dependencies[node_name] + for native_name, node_name in self.parameters.items() + }, + # Pass in any constant parameters this node sets + **self.constant_parameters, + ) return numpyro.sample( self.label, - self._dist( - # Pass in node values derived from construction so far - **{ - native_name: sampled_dependencies[node_name] - for native_name, node_name in self.parameters.items() - }, - # Pass in any constant parameters this node sets - **self.constant_parameters, - ), + d, rng_key=rng_key, - sample_shape=(samples,), + sample_shape=(samples,) if d.batch_shape == () and samples > 1 else (), ) @override diff --git a/tests/fixtures/graph.py b/tests/fixtures/graph.py index 7975b3f..62143d6 100644 --- a/tests/fixtures/graph.py +++ b/tests/fixtures/graph.py @@ -17,31 +17,36 @@ @pytest.fixture -def normal_graph() -> Callable[[NormalGraphNodes | None], Graph]: - """Creates a 3-node graph: +def normal_graph() -> Callable[[float, float], Graph]: + """Creates a graph with one normal distribution X. - mean (P) cov (P) - |---> outcome <----| - - where outcome is a normal distribution. - - Parameter nodes are initialised with no `value` set. + Parameter nodes are included if no values are given for the mean and covariance. """ - def _inner(normal_graph_nodes: NormalGraphNodes | None = None) -> Graph: - if normal_graph_nodes is None: - normal_graph_nodes = { - "mean": ParameterNode(label="mean"), - "cov": ParameterNode(label="cov"), - "outcome": DistributionNode( - Normal, - label="outcome", - parameters={"loc": "mean", "scale": "std"}, - ), - } - graph = Graph(label="normal dist") - graph.add_edge(normal_graph_nodes["mean"], normal_graph_nodes["outcome"]) - graph.add_edge(normal_graph_nodes["cov"], normal_graph_nodes["outcome"]) + def _inner(mean: float | None = None, cov: float | None = None): + graph = Graph(label="normal_graph") + parameters = {} + constant_parameters = {} + if mean is None: + graph.add_node(ParameterNode(label="mean")) + parameters["loc"] = "mean" + else: + constant_parameters["loc"] = mean + if cov is None: + graph.add_node(ParameterNode(label="cov")) + parameters["scale"] = "cov" + else: + constant_parameters["scale"] = cov + graph.add_node( + DistributionNode( + Normal, + label="X", + parameters=parameters, + constant_parameters=constant_parameters, + ) + ) + for node in parameters.values(): + graph.add_edge(node, "X") return graph return _inner @@ -56,65 +61,56 @@ def two_normal_graph() -> Callable[[float, float, float], Graph]: where UX is a normal distribution with mean `mean` and covariance `cov`, and X is a normal distrubution with mean UX and covariance `cov2`. - """ - - def _inner(mean: float = 5.0, cov: float = 1.0, cov2: float = 1.0) -> Graph: - graph = Graph(label="G0") - graph.add_node( - DistributionNode( - Normal, - label="UX", - constant_parameters={"loc": mean, "scale": cov**2}, - ) - ) - graph.add_node( - DistributionNode( - Normal, - label="X", - parameters={"loc": "UX"}, - constant_parameters={"scale": cov2**2}, - ) - ) - graph.add_edge("UX", "X") - - return graph - - return _inner - - -@pytest.fixture -def two_normal_graph_parametrized_mean() -> Callable[[float], Graph]: - """Creates a graph: - SDUX SDX - | | - V v - mu_x --> UX --> X - - where UX is a normal distribution with mean mu_x and covariance `co`, and X is - a normal distrubution with mean UX and covariance nu_y. + Parameter nodes are included if no values are given for the mean and covariances. """ - def _inner(cov: float = 1.0) -> Graph: - graph = Graph(label="G0") - graph.add_node(ParameterNode(label="nu_y")) - graph.add_node(ParameterNode(label="mu_x")) + def _inner( + mean: float | None = None, cov: float | None = None, cov2: float | None = None + ) -> Graph: + graph = Graph(label="two_normal_graph") + + x_parameters = {"loc": "UX"} + x_constant_parameters = {} + ux_parameters = {} + ux_constant_parameters = {} + if mean is None: + graph.add_node(ParameterNode(label="mean")) + ux_parameters["loc"] = "mean" + else: + ux_constant_parameters["loc"] = mean + if cov is None: + graph.add_node(ParameterNode(label="cov")) + ux_parameters["scale"] = "cov" + else: + ux_constant_parameters["scale"] = cov + if cov2 is None: + graph.add_node(ParameterNode(label="cov2")) + x_parameters["scale"] = "cov2" + else: + x_constant_parameters["scale"] = cov2 + graph.add_node( DistributionNode( Normal, label="UX", - parameters={"loc": "mu_x"}, - constant_parameters={"scale": cov}, + parameters=ux_parameters, + constant_parameters=ux_constant_parameters, ) ) graph.add_node( DistributionNode( Normal, label="X", - parameters={"loc": "UX", "scale": "nu_y"}, + parameters=x_parameters, + constant_parameters=x_constant_parameters, ) ) graph.add_edge("UX", "X") + for node in ux_parameters.values(): + graph.add_edge(node, "UX") + for node in x_parameters.values(): + graph.add_edge(node, "X") return graph @@ -125,9 +121,9 @@ def _inner(cov: float = 1.0) -> Graph: def two_normal_graph_expected_model() -> Callable[..., dict[str, npt.ArrayLike]]: """Creates the model that the two_normal_graph should produce.""" - def _inner(mu_x: float, nu_y: float) -> dict[str, npt.ArrayLike]: - ux = numpyro.sample("UX", Normal(loc=mu_x, scale=1.0)) - x = numpyro.sample("X", Normal(loc=ux, scale=nu_y)) + def _inner(mean: float, cov2: float) -> dict[str, npt.ArrayLike]: + ux = numpyro.sample("UX", Normal(loc=mean, scale=1.0)) + x = numpyro.sample("X", Normal(loc=ux, scale=cov2)) return {"X": x, "UX": ux} diff --git a/tests/test_causal_problem/test_callables.py b/tests/test_causal_problem/test_callables.py index 7aecf29..9e686c8 100644 --- a/tests/test_causal_problem/test_callables.py +++ b/tests/test_causal_problem/test_callables.py @@ -92,9 +92,7 @@ def which(request: pytest.FixtureRequest) -> Literal["causal_estimand", "constra "graph_argument": "g", }, 0.0, - # Empirical calculation with 1000 samples with fixture RNG key - # should give 1.8808e-2 as the empirical expectation. - 2.0e-2, + 3.0e-2, id="E[x], infer association", ), pytest.param( @@ -106,9 +104,7 @@ def which(request: pytest.FixtureRequest) -> Literal["causal_estimand", "constra }, # x has fixed std 1, and nu_y will be set to 1. 1.0**2 + 1.0**2, - # Empirical calculation with 1000 samples with fixture RNG key - # should give 1.8506 as the empirical std of y. - 2.0e-1, + 3.0e-1, id="Var[y]", ), pytest.param( @@ -120,9 +116,7 @@ def which(request: pytest.FixtureRequest) -> Literal["causal_estimand", "constra }, # As per the previous test cases jnp.array([0.0, 1.0**2 + 1.0**2]), - # As per the above cases, both components should be within - # 2.0e-1 of the analytical value. - jnp.array([2.0e-2, 2.0e-1]), + jnp.array([3.0e-2, 2.0e-1]), id="E[x], Var[y]", ), ], diff --git a/tests/test_graph/test_algorithms.py b/tests/test_graph/test_algorithms.py index 4daed23..fe5f46b 100644 --- a/tests/test_graph/test_algorithms.py +++ b/tests/test_graph/test_algorithms.py @@ -1,13 +1,15 @@ """Tests for graph algorithms.""" -import jax import numpy as np import pytest from numpyro.distributions import Normal import causalprog +from causalprog import algorithms from causalprog.graph import DistributionNode, Graph +max_samples = 10**5 + def test_roots_down_to_outcome() -> None: graph = Graph(label="G0") @@ -43,7 +45,7 @@ def test_roots_down_to_outcome() -> None: def test_do(rng_key, two_normal_graph): - graph = two_normal_graph() + graph = two_normal_graph(5.0, 1.2, 0.8) graph2 = causalprog.algorithms.do(graph, "UX", 4.0) assert "loc" in graph.get_node("X").parameters @@ -60,7 +62,7 @@ def test_do(rng_key, two_normal_graph): ) assert np.isclose( - causalprog.algorithms.expectation( + algorithms.expectation( graph2, outcome_node_label="X", samples=1000, rng_key=rng_key ), 4.0, @@ -77,54 +79,29 @@ def test_do(rng_key, two_normal_graph): pytest.param(1.0, 1.2, 10000000, 1e-3, id="N(mean=1, stdev=1.2), 10^7 samples"), ], ) -def test_mean_stdev_single_normal_node(samples, rtol, mean, stdev, rng_key): - node = DistributionNode( - Normal, - label="X", - constant_parameters={"loc": mean, "scale": stdev}, - ) +def test_expectation_stdev_single_normal_node( + normal_graph, samples, rtol, mean, stdev, rng_key +): + if samples > max_samples: + pytest.xfail("Test currently too slow") - graph = Graph(label="G0") - graph.add_node(node) - - # To compensate for rng-key splitting in sample methods, note the "split" key - # that is actually used to draw the samples from the distribution, so we can - # attempt to replicate its behaviour explicitly. - key = jax.random.split(rng_key, 1)[0] - what_we_should_get = jax.random.multivariate_normal( - key, jax.numpy.atleast_1d(mean), jax.numpy.atleast_2d(stdev**2), shape=samples - ) - expected_mean = what_we_should_get.mean() - expected_std_dev = what_we_should_get.std() + graph = normal_graph(mean, stdev) # Check within hand-computation assert np.isclose( - causalprog.algorithms.expectation( + algorithms.expectation( graph, outcome_node_label="X", samples=samples, rng_key=rng_key ), mean, rtol=rtol, ) assert np.isclose( - causalprog.algorithms.standard_deviation( + algorithms.standard_deviation( graph, outcome_node_label="X", samples=samples, rng_key=rng_key ), stdev, rtol=rtol, ) - # Check within computational distance - assert np.isclose( - causalprog.algorithms.expectation( - graph, outcome_node_label="X", samples=samples, rng_key=rng_key - ), - expected_mean, - ) - assert np.isclose( - causalprog.algorithms.standard_deviation( - graph, outcome_node_label="X", samples=samples, rng_key=rng_key - ), - expected_std_dev, - ) @pytest.mark.parametrize( @@ -159,8 +136,8 @@ def test_mean_stdev_single_normal_node(samples, rtol, mean, stdev, rng_key): def test_mean_stdev_two_node_graph( two_normal_graph, samples, rtol, mean, stdev, stdev2, rng_key ): - if samples > 100000: # noqa: PLR2004 - pytest.xfail("Test currently runs out of memory") + if samples > max_samples: + pytest.xfail("Test currently too slow") graph = two_normal_graph(mean=mean, cov=stdev, cov2=stdev2) @@ -172,9 +149,68 @@ def test_mean_stdev_two_node_graph( rtol=rtol, ) assert np.isclose( - causalprog.algorithms.standard_deviation( + algorithms.standard_deviation( graph, outcome_node_label="X", samples=samples, rng_key=rng_key ), np.sqrt(stdev**2 + stdev2**2), rtol=rtol, ) + + +@pytest.mark.parametrize( + ("samples", "rtol"), + [ + pytest.param(100, 1, id="100 samples"), + pytest.param(10000, 1e-1, id="10^4 samples"), + pytest.param(1000000, 1e-2, id="10^6 samples"), + ], +) +def test_expectation(two_normal_graph, rng_key, samples, rtol): + if samples > max_samples: + pytest.xfail("Test currently too slow") + graph = two_normal_graph(1.0, 1.2, 0.8) + + assert np.isclose( + algorithms.expectation( + graph, outcome_node_label="X", samples=samples, rng_key=rng_key + ), + algorithms.moments.sample( + graph, outcome_node_label="X", samples=samples, rng_key=rng_key + ).mean(), + rtol=rtol, + ) + + +@pytest.mark.parametrize( + ("samples", "rtol"), + [ + pytest.param(100, 1, id="100 samples"), + pytest.param(10000, 1e-1, id="10^4 samples"), + pytest.param(1000000, 1e-2, id="10^6 samples"), + ], +) +def test_stdev(two_normal_graph, rng_key, samples, rtol): + if samples > max_samples: + pytest.xfail("Test currently too slow") + graph = two_normal_graph(1.0, 1.2, 0.8) + + assert np.isclose( + algorithms.standard_deviation( + graph, outcome_node_label="X", samples=samples, rng_key=rng_key + ), + algorithms.moments.sample( + graph, outcome_node_label="X", samples=samples, rng_key=rng_key + ).std(), + rtol=rtol, + ) + + +@pytest.mark.parametrize("samples", [1, 2, 10, 100]) +def test_sample_shape(two_normal_graph, rng_key, samples): + graph = two_normal_graph(1.0, 1.2, 0.8) + + s1 = algorithms.moments.sample(graph, "X", samples, rng_key=rng_key) + assert s1.shape == () if samples == 1 else (samples,) + + s2 = algorithms.moments.sample(graph, "UX", samples, rng_key=rng_key) + assert s2.shape == () if samples == 1 else (samples,) diff --git a/tests/test_graph/test_model.py b/tests/test_graph/test_model.py index cddeb72..26e0b04 100644 --- a/tests/test_graph/test_model.py +++ b/tests/test_graph/test_model.py @@ -8,13 +8,13 @@ @pytest.mark.parametrize( "param_values", [ - pytest.param({"mu_x": 0.0, "nu_y": 1.0}, id="mean(X) = 0, cov(UX) = 1"), - pytest.param({"mu_x": 1.0, "nu_y": 2.0}, id="mean(X) = 1, cov(UX) = 2"), + pytest.param({"mean": 0.0, "cov2": 1.0}, id="mean(X) = 0, cov(UX) = 1"), + pytest.param({"mean": 1.0, "cov2": 2.0}, id="mean(X) = 1, cov(UX) = 2"), ], ) def test_model( param_values: dict[str, npt.ArrayLike], - two_normal_graph_parametrized_mean, + two_normal_graph, two_normal_graph_expected_model, assert_samples_are_identical, run_default_nuts_mcmc, @@ -31,7 +31,7 @@ def test_model( MCMC sampling output of `Graph.model` with the explicit model that it should correspond to. """ - graph = two_normal_graph_parametrized_mean() + graph = two_normal_graph(cov=1.0) assert callable(graph.model) via_model = run_default_nuts_mcmc( @@ -47,19 +47,19 @@ def test_model( def test_model_missing_parameter( - two_normal_graph_parametrized_mean, + two_normal_graph, raises_context, seed: int, ) -> None: """`Graph.model` will raise a `KeyError` when a value is not passed for a `ParameterNode`. """ - graph = two_normal_graph_parametrized_mean() + graph = two_normal_graph(cov=1.0) - # Deliberately leave out the "nu_y" variable. - parameter_values = {"mu_x": 0.0} + # Deliberately leave out the "cov2" variable. + parameter_values = {"mean": 0.0} # Which should result in the error below. - expected_exception = KeyError("ParameterNode 'nu_y' not assigned") + expected_exception = KeyError("ParameterNode 'cov2' not assigned") # Not passing enough parameters should be picked up by the model. with raises_context(expected_exception), numpyro.handlers.seed(rng_seed=seed): @@ -67,33 +67,33 @@ def test_model_missing_parameter( def test_model_extension( - two_normal_graph_parametrized_mean, + two_normal_graph, assert_samples_are_identical, run_default_nuts_mcmc, ) -> None: """Test that `Graph.model` can be extended.""" - graph = two_normal_graph_parametrized_mean() + graph = two_normal_graph(cov=1.0) - parameter_values = {"mu_x": 0.0, "nu_y": 1.0} + parameter_values = {"mean": 0.0, "cov2": 1.0} # Build the graph, but without the X-node. - mu_x = ParameterNode(label="mu_x") + mean = ParameterNode(label="mean") x = DistributionNode( numpyro.distributions.Normal, label="UX", - parameters={"loc": "mu_x"}, + parameters={"loc": "mean"}, constant_parameters={"scale": 1.0}, ) one_normal_graph = Graph(label="One normal") - one_normal_graph.add_edge(mu_x, x) + one_normal_graph.add_edge(mean, x) - def extended_model(*, nu_y, **parameter_values): + def extended_model(*, cov2, **parameter_values): sites = one_normal_graph.model(**parameter_values) numpyro.sample( "X", numpyro.distributions.Normal( loc=sites["UX"], - scale=nu_y, + scale=cov2, ), ) diff --git a/tests/test_graph/test_parameters.py b/tests/test_graph/test_parameters.py index b7336a3..0e7c790 100644 --- a/tests/test_graph/test_parameters.py +++ b/tests/test_graph/test_parameters.py @@ -7,7 +7,7 @@ from causalprog.graph import DistributionNode, Graph, ParameterNode -NormalGraphNodeNames: TypeAlias = Literal["mean", "cov", "outcome"] +NormalGraphNodeNames: TypeAlias = Literal["mean", "cov", "X"] NormalGraphNodes: TypeAlias = dict[ NormalGraphNodeNames, DistributionNode | ParameterNode ] @@ -18,8 +18,8 @@ [ pytest.param( {}, - {"outcome": 4.0}, - TypeError("Node outcome is not a parameter node."), + {"X": 4.0}, + TypeError("Node X is not a parameter node."), id="Give non-parameter node", ), pytest.param( @@ -54,7 +54,7 @@ def test_set_parameters( parameter_nodes = graph.parameter_nodes assert graph.get_node("mean") in parameter_nodes assert graph.get_node("cov") in parameter_nodes - assert graph.get_node("outcome") not in parameter_nodes + assert graph.get_node("X") not in parameter_nodes # Set any pre-existing values we might want the parameter nodes to have in # this test.