From 0afe05fc8aa043a14b9692625b5990df11f0ea2d Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Fri, 8 Aug 2025 09:43:22 +0100 Subject: [PATCH 01/18] move fixtures out of test file --- tests/fixtures/graph.py | 45 +++++++++++++++++++++++++++++++++++++++++ tests/test_graph.py | 32 ----------------------------- 2 files changed, 45 insertions(+), 32 deletions(-) create mode 100644 tests/fixtures/graph.py diff --git a/tests/fixtures/graph.py b/tests/fixtures/graph.py new file mode 100644 index 0000000..fabe030 --- /dev/null +++ b/tests/fixtures/graph.py @@ -0,0 +1,45 @@ +"""Tests for graph module.""" + +from typing import Literal, TypeAlias + +import pytest + +from causalprog.distribution.normal import NormalFamily +from causalprog.graph import DistributionNode, Graph, ParameterNode + +NormalGraphNodeNames: TypeAlias = Literal["mean", "cov", "outcome"] +NormalGraphNodes: TypeAlias = dict[ + NormalGraphNodeNames, DistributionNode | ParameterNode +] + + +@pytest.fixture +def normal_graph_nodes() -> NormalGraphNodes: + """Collection of Nodes used to construct `normal_graph`. + + See `normal_graph` docstring for more details. + """ + return { + "mean": ParameterNode(label="mean"), + "cov": ParameterNode(label="cov"), + "outcome": DistributionNode( + NormalFamily(), label="outcome", parameters={"mean": "mean", "cov": "std"} + ), + } + + +@pytest.fixture +def normal_graph(normal_graph_nodes: NormalGraphNodes) -> Graph: + """Creates a 3-node graph: + + mean (P) cov (P) + |---> outcome <----| + + where outcome is a normal distribution. + + Parameter nodes are initialised with no `value` set. + """ + graph = Graph(label="normal dist") + graph.add_edge(normal_graph_nodes["mean"], normal_graph_nodes["outcome"]) + graph.add_edge(normal_graph_nodes["cov"], normal_graph_nodes["outcome"]) + return graph diff --git a/tests/test_graph.py b/tests/test_graph.py index b630405..5291339 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -17,38 +17,6 @@ ] -@pytest.fixture -def normal_graph_nodes() -> NormalGraphNodes: - """Collection of Nodes used to construct `normal_graph`. - - See `normal_graph` docstring for more details. - """ - return { - "mean": ParameterNode(label="mean"), - "cov": ParameterNode(label="cov"), - "outcome": DistributionNode( - NormalFamily(), label="outcome", parameters={"mean": "mean", "cov": "std"} - ), - } - - -@pytest.fixture -def normal_graph(normal_graph_nodes: NormalGraphNodes) -> Graph: - """Creates a 3-node graph: - - mean (P) cov (P) - |---> outcome <----| - - where outcome is a normal distribution. - - Parameter nodes are initialised with no `value` set. - """ - graph = Graph(label="normal dist") - graph.add_edge(normal_graph_nodes["mean"], normal_graph_nodes["outcome"]) - graph.add_edge(normal_graph_nodes["cov"], normal_graph_nodes["outcome"]) - return graph - - def test_label(): d = NormalFamily() node = DistributionNode(d, label="X") From 221f79cad3eb2f310c998c68d8c60d2079ffe788 Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Fri, 8 Aug 2025 10:15:05 +0100 Subject: [PATCH 02/18] split test_graph.py into three files --- .../test_algorithms.py} | 271 +++++------------- tests/test_graph/test_construction.py | 78 +++++ tests/test_graph/test_parameters.py | 87 ++++++ 3 files changed, 230 insertions(+), 206 deletions(-) rename tests/{test_graph.py => test_graph/test_algorithms.py} (51%) create mode 100644 tests/test_graph/test_construction.py create mode 100644 tests/test_graph/test_parameters.py diff --git a/tests/test_graph.py b/tests/test_graph/test_algorithms.py similarity index 51% rename from tests/test_graph.py rename to tests/test_graph/test_algorithms.py index 5291339..5c98d44 100644 --- a/tests/test_graph.py +++ b/tests/test_graph/test_algorithms.py @@ -1,6 +1,5 @@ -"""Tests for graph module.""" +"""Tests for graph algorithms.""" -import re from typing import Literal, TypeAlias import jax @@ -17,106 +16,80 @@ ] -def test_label(): - d = NormalFamily() - node = DistributionNode(d, label="X") - node2 = DistributionNode(d, label="Y") - node_copy = node - - assert node.label == node_copy.label == "X" - assert node.label != node2.label - assert node2.label == "Y" - - assert isinstance(node, causalprog.graph.node.Node) - assert isinstance(node2, causalprog.graph.node.Node) - - -def test_duplicate_label(): - d = NormalFamily() - - graph = Graph(label="G0") - graph.add_node(DistributionNode(d, label="X")) - with pytest.raises(ValueError, match=re.escape("Duplicate node label: X")): - graph.add_node(DistributionNode(d, label="X")) - - -@pytest.mark.parametrize( - "use_labels", - [pytest.param(True, id="Via labels"), pytest.param(False, id="Via variables")], -) -def test_build_graph(*, use_labels: bool) -> None: - root_label = "root" - outcome_label = "outcome_label" - d = NormalFamily() - - root_node = DistributionNode(d, label=root_label) - outcome_node = DistributionNode(d, label=outcome_label) - - graph = Graph(label="G0") - graph.add_node(root_node) - graph.add_node(outcome_node) - - if use_labels: - graph.add_edge(root_label, outcome_label) - else: - graph.add_edge(root_node, outcome_node) - - assert graph.roots_down_to_outcome(outcome_label) == [root_node, outcome_node] - - def test_roots_down_to_outcome() -> None: d = NormalFamily() graph = Graph(label="G0") - u = DistributionNode(d, label="U") - v = DistributionNode(d, label="V") - w = DistributionNode(d, label="W") - x = DistributionNode(d, label="X") - y = DistributionNode(d, label="Y") - z = DistributionNode(d, label="Z") - - graph.add_node(u) - graph.add_node(v) - graph.add_node(w) - graph.add_node(x) - graph.add_node(y) - graph.add_node(z) - - graph.add_edge("V", "W") - graph.add_edge("V", "X") - graph.add_edge("V", "Y") - graph.add_edge("X", "Z") - graph.add_edge("Y", "Z") - graph.add_edge("U", "Z") - - assert graph.roots_down_to_outcome("V") == [v] - assert graph.roots_down_to_outcome("W") == [v, w] + graph.add_node(DistributionNode(d, label="U")) + graph.add_node(DistributionNode(d, label="V")) + graph.add_node(DistributionNode(d, label="W")) + graph.add_node(DistributionNode(d, label="X")) + graph.add_node(DistributionNode(d, label="Y")) + graph.add_node(DistributionNode(d, label="Z")) + + edges = [ + ["V", "W"], + ["V", "X"], + ["V", "Y"], + ["X", "Z"], + ["Y", "Z"], + ["U", "Z"], + ] + for e in edges: + graph.add_edge(*e) + + assert graph.roots_down_to_outcome("V") == [graph.get_node("V")] + assert graph.roots_down_to_outcome("W") == [ + graph.get_node("V"), + graph.get_node("W"), + ] nodes = graph.roots_down_to_outcome("Z") assert len(nodes) == 5 # noqa: PLR2004 - assert ( - nodes.index(v) - < min(nodes.index(x), nodes.index(y)) - < max(nodes.index(x), nodes.index(y)) - < nodes.index(z) - ) - assert nodes.index(u) < nodes.index(z) + for e in edges: + if "W" not in e: + assert nodes.index(graph.get_node(e[0])) < nodes.index(graph.get_node(e[1])) -def test_cycle() -> None: - d = NormalFamily() +def test_do(rng_key): + graph = causalprog.graph.Graph(label="G0") + graph.add_node( + DistributionNode( + NormalFamily(), label="UX", constant_parameters={"mean": 5.0, "cov": 1.0} + ) + ) + graph.add_node( + DistributionNode( + NormalFamily(), + label="X", + parameters={"mean": "UX"}, + constant_parameters={"cov": 1.0}, + ) + ) + graph.add_edge("UX", "X") - node0 = DistributionNode(d, label="X") - node1 = DistributionNode(d, label="Y") - node2 = DistributionNode(d, label="Z") + graph2 = causalprog.algorithms.do(graph, "UX", 4.0) - graph = Graph(label="G0") - graph.add_edge(node0, node1) - graph.add_edge(node1, node2) - graph.add_edge(node2, node0) + assert "mean" in graph.get_node("X").parameters + assert "mean" not in graph.get_node("X").constant_parameters + assert "mean" not in graph2.get_node("X").parameters + assert "mean" in graph2.get_node("X").constant_parameters + + assert np.isclose( + causalprog.algorithms.expectation( + graph, outcome_node_label="X", samples=1000, rng_key=rng_key + ), + 5.0, + rtol=1e-1, + ) - with pytest.raises(RuntimeError, match="Graph is not acyclic."): - graph.roots_down_to_outcome("X") + assert np.isclose( + causalprog.algorithms.expectation( + graph2, outcome_node_label="X", samples=1000, rng_key=rng_key + ), + 4.0, + rtol=1e-1, + ) @pytest.mark.parametrize( @@ -128,7 +101,7 @@ def test_cycle() -> None: pytest.param(1.0, 1.2, 10000000, 1e-3, id="N(mean=1, stdev=1.2), 10^7 samples"), ], ) -def test_single_normal_node(samples, rtol, mean, stdev, rng_key): +def test_mean_stdev_single_normal_node(samples, rtol, mean, stdev, rng_key): node = DistributionNode( NormalFamily(), label="X", @@ -207,7 +180,7 @@ def test_single_normal_node(samples, rtol, mean, stdev, rng_key): ), ], ) -def test_two_node_graph(samples, rtol, mean, stdev, stdev2, rng_key): +def test_mean_stdev_two_node_graph(samples, rtol, mean, stdev, stdev2, rng_key): if samples > 100: # noqa: PLR2004 pytest.xfail("Test currently too slow") graph = causalprog.graph.Graph(label="G0") @@ -242,117 +215,3 @@ def test_two_node_graph(samples, rtol, mean, stdev, stdev2, rng_key): np.sqrt(stdev**2 + stdev2**2), rtol=rtol, ) - - -@pytest.mark.parametrize( - ("param_values_before", "params_to_set", "expected"), - [ - pytest.param( - {}, - {"outcome": 4.0}, - TypeError("Node outcome is not a parameter node."), - id="Give non-parameter node", - ), - pytest.param( - {}, - {"mean": 4.0}, - {"mean": 4.0, "cov": None}, - id="Set only one parameter", - ), - pytest.param( - {}, - {}, - {"mean": None, "cov": None}, - id="Doing nothing is fine", - ), - pytest.param( - {"mean": 0.0, "cov": 0.0}, - {"cov": 1.0}, - {"mean": 0.0, "cov": 1.0}, - id="Omission preserves current value", - ), - ], -) -def test_set_parameters( - normal_graph_nodes: NormalGraphNodes, - normal_graph: Graph, - param_values_before: dict[NormalGraphNodeNames, float], - params_to_set: dict[str, float], - expected: Exception | dict[NormalGraphNodeNames, float], -) -> None: - """Test that we can identify parameter nodes, and set their values.""" - parameter_nodes = normal_graph.parameter_nodes - assert normal_graph_nodes["mean"] in parameter_nodes - assert normal_graph_nodes["cov"] in parameter_nodes - assert normal_graph_nodes["outcome"] not in parameter_nodes - - # Set any pre-existing values we might want the parameter nodes to have in - # this test. - for node_label, value in param_values_before.items(): - n = normal_graph.get_node(node_label) - assert isinstance(n, ParameterNode), ( - "Cannot set .value on non-parameter node (test input error)." - ) - n.value = value - - # Check behaviour of set_parameters method. - if isinstance(expected, Exception): - with pytest.raises(type(expected), match=re.escape(str(expected))): - normal_graph.set_parameters(**params_to_set) - else: - normal_graph.set_parameters(**params_to_set) - - for node_name, expected_value in expected.items(): - assert normal_graph.get_node(node_name).value == expected_value - - -def test_parameter_node(rng_key): - node = ParameterNode(label="mu") - - with pytest.raises(ValueError, match="Cannot sample"): - node.sample({}, 1, rng_key) - - node.value = 0.3 - - assert np.allclose(node.sample({}, 10, rng_key)[0], [0.3] * 10) - - -def test_do(rng_key): - graph = causalprog.graph.Graph(label="G0") - graph.add_node( - DistributionNode( - NormalFamily(), label="UX", constant_parameters={"mean": 5.0, "cov": 1.0} - ) - ) - graph.add_node( - DistributionNode( - NormalFamily(), - label="X", - parameters={"mean": "UX"}, - constant_parameters={"cov": 1.0}, - ) - ) - graph.add_edge("UX", "X") - - graph2 = causalprog.algorithms.do(graph, "UX", 4.0) - - assert "mean" in graph.get_node("X").parameters - assert "mean" not in graph.get_node("X").constant_parameters - assert "mean" not in graph2.get_node("X").parameters - assert "mean" in graph2.get_node("X").constant_parameters - - assert np.isclose( - causalprog.algorithms.expectation( - graph, outcome_node_label="X", samples=1000, rng_key=rng_key - ), - 5.0, - rtol=1e-1, - ) - - assert np.isclose( - causalprog.algorithms.expectation( - graph2, outcome_node_label="X", samples=1000, rng_key=rng_key - ), - 4.0, - rtol=1e-1, - ) diff --git a/tests/test_graph/test_construction.py b/tests/test_graph/test_construction.py new file mode 100644 index 0000000..d27e9b8 --- /dev/null +++ b/tests/test_graph/test_construction.py @@ -0,0 +1,78 @@ +"""Tests for graph module.""" + +import re +from typing import Literal, TypeAlias + +import pytest + +import causalprog +from causalprog.distribution.normal import NormalFamily +from causalprog.graph import DistributionNode, Graph, ParameterNode + +NormalGraphNodeNames: TypeAlias = Literal["mean", "cov", "outcome"] +NormalGraphNodes: TypeAlias = dict[ + NormalGraphNodeNames, DistributionNode | ParameterNode +] + + +def test_label(): + d = NormalFamily() + node = DistributionNode(d, label="X") + node2 = DistributionNode(d, label="Y") + node_copy = node + + assert node.label == node_copy.label == "X" + assert node.label != node2.label + assert node2.label == "Y" + + assert isinstance(node, causalprog.graph.node.Node) + assert isinstance(node2, causalprog.graph.node.Node) + + +def test_duplicate_label(): + d = NormalFamily() + + graph = Graph(label="G0") + graph.add_node(DistributionNode(d, label="X")) + with pytest.raises(ValueError, match=re.escape("Duplicate node label: X")): + graph.add_node(DistributionNode(d, label="X")) + + +@pytest.mark.parametrize( + "use_labels", + [pytest.param(True, id="Via labels"), pytest.param(False, id="Via variables")], +) +def test_build_graph(*, use_labels: bool) -> None: + root_label = "root" + outcome_label = "outcome_label" + d = NormalFamily() + + root_node = DistributionNode(d, label=root_label) + outcome_node = DistributionNode(d, label=outcome_label) + + graph = Graph(label="G0") + graph.add_node(root_node) + graph.add_node(outcome_node) + + if use_labels: + graph.add_edge(root_label, outcome_label) + else: + graph.add_edge(root_node, outcome_node) + + assert graph.roots_down_to_outcome(outcome_label) == [root_node, outcome_node] + + +def test_cycle() -> None: + d = NormalFamily() + + node0 = DistributionNode(d, label="X") + node1 = DistributionNode(d, label="Y") + node2 = DistributionNode(d, label="Z") + + graph = Graph(label="G0") + graph.add_edge(node0, node1) + graph.add_edge(node1, node2) + graph.add_edge(node2, node0) + + with pytest.raises(RuntimeError, match="Graph is not acyclic."): + graph.roots_down_to_outcome("X") diff --git a/tests/test_graph/test_parameters.py b/tests/test_graph/test_parameters.py new file mode 100644 index 0000000..20f7760 --- /dev/null +++ b/tests/test_graph/test_parameters.py @@ -0,0 +1,87 @@ +"""Tests for graph module.""" + +import re +from typing import Literal, TypeAlias + +import numpy as np +import pytest + +from causalprog.graph import DistributionNode, Graph, ParameterNode + +NormalGraphNodeNames: TypeAlias = Literal["mean", "cov", "outcome"] +NormalGraphNodes: TypeAlias = dict[ + NormalGraphNodeNames, DistributionNode | ParameterNode +] + + +@pytest.mark.parametrize( + ("param_values_before", "params_to_set", "expected"), + [ + pytest.param( + {}, + {"outcome": 4.0}, + TypeError("Node outcome is not a parameter node."), + id="Give non-parameter node", + ), + pytest.param( + {}, + {"mean": 4.0}, + {"mean": 4.0, "cov": None}, + id="Set only one parameter", + ), + pytest.param( + {}, + {}, + {"mean": None, "cov": None}, + id="Doing nothing is fine", + ), + pytest.param( + {"mean": 0.0, "cov": 0.0}, + {"cov": 1.0}, + {"mean": 0.0, "cov": 1.0}, + id="Omission preserves current value", + ), + ], +) +def test_set_parameters( + normal_graph_nodes: NormalGraphNodes, + normal_graph: Graph, + param_values_before: dict[NormalGraphNodeNames, float], + params_to_set: dict[str, float], + expected: Exception | dict[NormalGraphNodeNames, float], +) -> None: + """Test that we can identify parameter nodes, and set their values.""" + parameter_nodes = normal_graph.parameter_nodes + assert normal_graph_nodes["mean"] in parameter_nodes + assert normal_graph_nodes["cov"] in parameter_nodes + assert normal_graph_nodes["outcome"] not in parameter_nodes + + # Set any pre-existing values we might want the parameter nodes to have in + # this test. + for node_label, value in param_values_before.items(): + n = normal_graph.get_node(node_label) + assert isinstance(n, ParameterNode), ( + "Cannot set .value on non-parameter node (test input error)." + ) + n.value = value + + # Check behaviour of set_parameters method. + if isinstance(expected, Exception): + with pytest.raises(type(expected), match=re.escape(str(expected))): + normal_graph.set_parameters(**params_to_set) + else: + normal_graph.set_parameters(**params_to_set) + + for node_name, expected_value in expected.items(): + assert normal_graph.get_node(node_name).value == expected_value + + +def test_parameter_node(rng_key): + node = ParameterNode(label="mu") + + with pytest.raises(ValueError, match="Cannot sample"): + node.sample({}, 1, rng_key) + + node.value = 0.3 + + assert np.allclose(node.sample({}, 10, rng_key)[0], [0.3] * 10) From a6c04b7a2f0867d2fa955c32e68b9fa8b02babc6 Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Fri, 8 Aug 2025 11:14:22 +0100 Subject: [PATCH 03/18] factory as fixture (see https://docs.pytest.org/en/7.1.x/how-to/fixtures.html#factories-as-fixtures) --- tests/fixtures/graph.py | 67 ++++++++++++++++++++--------- tests/test_graph/test_algorithms.py | 32 ++++---------- tests/test_graph/test_parameters.py | 18 ++++---- 3 files changed, 63 insertions(+), 54 deletions(-) diff --git a/tests/fixtures/graph.py b/tests/fixtures/graph.py index fabe030..7068064 100644 --- a/tests/fixtures/graph.py +++ b/tests/fixtures/graph.py @@ -1,6 +1,6 @@ """Tests for graph module.""" -from typing import Literal, TypeAlias +from typing import Literal, TypeAlias, Callable import pytest @@ -14,32 +14,57 @@ @pytest.fixture -def normal_graph_nodes() -> NormalGraphNodes: - """Collection of Nodes used to construct `normal_graph`. +def normal_graph() -> Callable[NormalGraphNodes | None, Graph]: + def _normal_graph(normal_graph_nodes: NormalGraphNodes | None = None) -> Graph: + """Creates a 3-node graph: - See `normal_graph` docstring for more details. - """ - return { - "mean": ParameterNode(label="mean"), - "cov": ParameterNode(label="cov"), - "outcome": DistributionNode( - NormalFamily(), label="outcome", parameters={"mean": "mean", "cov": "std"} - ), - } + mean (P) cov (P) + |---> outcome <----| + + where outcome is a normal distribution. + + Parameter nodes are initialised with no `value` set. + """ + if normal_graph_nodes is None: + normal_graph_nodes = { + "mean": ParameterNode(label="mean"), + "cov": ParameterNode(label="cov"), + "outcome": DistributionNode( + NormalFamily(), label="outcome", parameters={"mean": "mean", "cov": "std"} + ), + } + graph = Graph(label="normal dist") + graph.add_edge(normal_graph_nodes["mean"], normal_graph_nodes["outcome"]) + graph.add_edge(normal_graph_nodes["cov"], normal_graph_nodes["outcome"]) + return graph + + return _normal_graph @pytest.fixture -def normal_graph(normal_graph_nodes: NormalGraphNodes) -> Graph: - """Creates a 3-node graph: +def ux_x_graph() -> Graph: + """Creates a 2 node graph: - mean (P) cov (P) - |---> outcome <----| + UX --> X - where outcome is a normal distribution. + where EX is a normal distribution with mean 5 and covariance 1, and X is + a normal distrubution with mean UX and covariance 1. - Parameter nodes are initialised with no `value` set. """ - graph = Graph(label="normal dist") - graph.add_edge(normal_graph_nodes["mean"], normal_graph_nodes["outcome"]) - graph.add_edge(normal_graph_nodes["cov"], normal_graph_nodes["outcome"]) + graph = Graph(label="G0") + graph.add_node( + DistributionNode( + NormalFamily(), label="UX", constant_parameters={"mean": 5.0, "cov": 1.0} + ) + ) + graph.add_node( + DistributionNode( + NormalFamily(), + label="X", + parameters={"mean": "UX"}, + constant_parameters={"cov": 1.0}, + ) + ) + graph.add_edge("UX", "X") + return graph diff --git a/tests/test_graph/test_algorithms.py b/tests/test_graph/test_algorithms.py index 5c98d44..e68df23 100644 --- a/tests/test_graph/test_algorithms.py +++ b/tests/test_graph/test_algorithms.py @@ -51,33 +51,17 @@ def test_roots_down_to_outcome() -> None: assert nodes.index(graph.get_node(e[0])) < nodes.index(graph.get_node(e[1])) -def test_do(rng_key): - graph = causalprog.graph.Graph(label="G0") - graph.add_node( - DistributionNode( - NormalFamily(), label="UX", constant_parameters={"mean": 5.0, "cov": 1.0} - ) - ) - graph.add_node( - DistributionNode( - NormalFamily(), - label="X", - parameters={"mean": "UX"}, - constant_parameters={"cov": 1.0}, - ) - ) - graph.add_edge("UX", "X") - - graph2 = causalprog.algorithms.do(graph, "UX", 4.0) +def test_do(rng_key, ux_x_graph): + do_graph = causalprog.algorithms.do(ux_x_graph, "UX", 4.0) - assert "mean" in graph.get_node("X").parameters - assert "mean" not in graph.get_node("X").constant_parameters - assert "mean" not in graph2.get_node("X").parameters - assert "mean" in graph2.get_node("X").constant_parameters + assert "mean" in ux_x_graph.get_node("X").parameters + assert "mean" not in ux_x_graph.get_node("X").constant_parameters + assert "mean" not in do_graph.get_node("X").parameters + assert "mean" in do_graph.get_node("X").constant_parameters assert np.isclose( causalprog.algorithms.expectation( - graph, outcome_node_label="X", samples=1000, rng_key=rng_key + ux_x_graph, outcome_node_label="X", samples=1000, rng_key=rng_key ), 5.0, rtol=1e-1, @@ -85,7 +69,7 @@ def test_do(rng_key): assert np.isclose( causalprog.algorithms.expectation( - graph2, outcome_node_label="X", samples=1000, rng_key=rng_key + do_graph, outcome_node_label="X", samples=1000, rng_key=rng_key ), 4.0, rtol=1e-1, diff --git a/tests/test_graph/test_parameters.py b/tests/test_graph/test_parameters.py index 20f7760..f5b2a3c 100644 --- a/tests/test_graph/test_parameters.py +++ b/tests/test_graph/test_parameters.py @@ -44,22 +44,22 @@ ], ) def test_set_parameters( - normal_graph_nodes: NormalGraphNodes, normal_graph: Graph, param_values_before: dict[NormalGraphNodeNames, float], params_to_set: dict[str, float], expected: Exception | dict[NormalGraphNodeNames, float], ) -> None: """Test that we can identify parameter nodes, and set their values.""" - parameter_nodes = normal_graph.parameter_nodes - assert normal_graph_nodes["mean"] in parameter_nodes - assert normal_graph_nodes["cov"] in parameter_nodes - assert normal_graph_nodes["outcome"] not in parameter_nodes + graph = normal_graph() + parameter_nodes = graph.parameter_nodes + assert graph.get_node("mean") in parameter_nodes + assert graph.get_node("cov") in parameter_nodes + assert graph.get_node("outcome") not in parameter_nodes # Set any pre-existing values we might want the parameter nodes to have in # this test. for node_label, value in param_values_before.items(): - n = normal_graph.get_node(node_label) + n = graph.get_node(node_label) assert isinstance(n, ParameterNode), ( "Cannot set .value on non-parameter node (test input error)." ) @@ -68,12 +68,12 @@ def test_set_parameters( # Check behaviour of set_parameters method. if isinstance(expected, Exception): with pytest.raises(type(expected), match=re.escape(str(expected))): - normal_graph.set_parameters(**params_to_set) + graph.set_parameters(**params_to_set) else: - normal_graph.set_parameters(**params_to_set) + graph.set_parameters(**params_to_set) for node_name, expected_value in expected.items(): - assert normal_graph.get_node(node_name).value == expected_value + assert graph.get_node(node_name).value == expected_value def test_parameter_node(rng_key): From 278a631068d5e0c39508b54631ce5c8e3ef0724b Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Fri, 8 Aug 2025 11:25:21 +0100 Subject: [PATCH 04/18] more fixtures --- tests/fixtures/graph.py | 64 +++++++++++++++++------------ tests/test_graph/test_algorithms.py | 37 ++++++----------- 2 files changed, 49 insertions(+), 52 deletions(-) diff --git a/tests/fixtures/graph.py b/tests/fixtures/graph.py index 7068064..7ac61c0 100644 --- a/tests/fixtures/graph.py +++ b/tests/fixtures/graph.py @@ -1,6 +1,7 @@ """Tests for graph module.""" -from typing import Literal, TypeAlias, Callable +from collections.abc import Callable +from typing import Literal, TypeAlias import pytest @@ -14,23 +15,26 @@ @pytest.fixture -def normal_graph() -> Callable[NormalGraphNodes | None, Graph]: - def _normal_graph(normal_graph_nodes: NormalGraphNodes | None = None) -> Graph: - """Creates a 3-node graph: +def normal_graph() -> Callable[[NormalGraphNodes | None], Graph]: + """Creates a 3-node graph: + + mean (P) cov (P) + |---> outcome <----| - mean (P) cov (P) - |---> outcome <----| + where outcome is a normal distribution. - where outcome is a normal distribution. + Parameter nodes are initialised with no `value` set. + """ - Parameter nodes are initialised with no `value` set. - """ + def _normal_graph(normal_graph_nodes: NormalGraphNodes | None = None) -> Graph: if normal_graph_nodes is None: normal_graph_nodes = { "mean": ParameterNode(label="mean"), "cov": ParameterNode(label="cov"), "outcome": DistributionNode( - NormalFamily(), label="outcome", parameters={"mean": "mean", "cov": "std"} + NormalFamily(), + label="outcome", + parameters={"mean": "mean", "cov": "std"}, ), } graph = Graph(label="normal dist") @@ -42,29 +46,35 @@ def _normal_graph(normal_graph_nodes: NormalGraphNodes | None = None) -> Graph: @pytest.fixture -def ux_x_graph() -> Graph: +def ux_x_graph() -> Callable[[NormalGraphNodes | None], Graph]: """Creates a 2 node graph: UX --> X - where EX is a normal distribution with mean 5 and covariance 1, and X is - a normal distrubution with mean UX and covariance 1. + where EX is a normal distribution with mean `mean` and covariance `cov`, and X is + a normal distrubution with mean UX and covariance `cov2`. """ - graph = Graph(label="G0") - graph.add_node( - DistributionNode( - NormalFamily(), label="UX", constant_parameters={"mean": 5.0, "cov": 1.0} + + def _ux_x_graph(mean: float = 5.0, cov: float = 1.0, cov2: float = 1.0) -> Graph: + graph = Graph(label="G0") + graph.add_node( + DistributionNode( + NormalFamily(), + label="UX", + constant_parameters={"mean": mean, "cov": cov}, + ) ) - ) - graph.add_node( - DistributionNode( - NormalFamily(), - label="X", - parameters={"mean": "UX"}, - constant_parameters={"cov": 1.0}, + graph.add_node( + DistributionNode( + NormalFamily(), + label="X", + parameters={"mean": "UX"}, + constant_parameters={"cov": cov2}, + ) ) - ) - graph.add_edge("UX", "X") + graph.add_edge("UX", "X") + + return graph - return graph + return _ux_x_graph diff --git a/tests/test_graph/test_algorithms.py b/tests/test_graph/test_algorithms.py index e68df23..955aa84 100644 --- a/tests/test_graph/test_algorithms.py +++ b/tests/test_graph/test_algorithms.py @@ -52,16 +52,17 @@ def test_roots_down_to_outcome() -> None: def test_do(rng_key, ux_x_graph): - do_graph = causalprog.algorithms.do(ux_x_graph, "UX", 4.0) + graph = ux_x_graph() + graph2 = causalprog.algorithms.do(graph, "UX", 4.0) - assert "mean" in ux_x_graph.get_node("X").parameters - assert "mean" not in ux_x_graph.get_node("X").constant_parameters - assert "mean" not in do_graph.get_node("X").parameters - assert "mean" in do_graph.get_node("X").constant_parameters + assert "mean" in graph.get_node("X").parameters + assert "mean" not in graph.get_node("X").constant_parameters + assert "mean" not in graph2.get_node("X").parameters + assert "mean" in graph2.get_node("X").constant_parameters assert np.isclose( causalprog.algorithms.expectation( - ux_x_graph, outcome_node_label="X", samples=1000, rng_key=rng_key + graph, outcome_node_label="X", samples=1000, rng_key=rng_key ), 5.0, rtol=1e-1, @@ -69,7 +70,7 @@ def test_do(rng_key, ux_x_graph): assert np.isclose( causalprog.algorithms.expectation( - do_graph, outcome_node_label="X", samples=1000, rng_key=rng_key + graph2, outcome_node_label="X", samples=1000, rng_key=rng_key ), 4.0, rtol=1e-1, @@ -164,26 +165,12 @@ def test_mean_stdev_single_normal_node(samples, rtol, mean, stdev, rng_key): ), ], ) -def test_mean_stdev_two_node_graph(samples, rtol, mean, stdev, stdev2, rng_key): +def test_mean_stdev_two_node_graph( + ux_x_graph, samples, rtol, mean, stdev, stdev2, rng_key +): if samples > 100: # noqa: PLR2004 pytest.xfail("Test currently too slow") - graph = causalprog.graph.Graph(label="G0") - graph.add_node( - DistributionNode( - NormalFamily(), - label="UX", - constant_parameters={"mean": mean, "cov": stdev**2}, - ) - ) - graph.add_node( - DistributionNode( - NormalFamily(), - label="X", - parameters={"mean": "UX"}, - constant_parameters={"cov": stdev2**2}, - ) - ) - graph.add_edge("UX", "X") + graph = ux_x_graph(mean=mean, cov=stdev**2, cov2=stdev2**2) assert np.isclose( causalprog.algorithms.expectation( From 1e77f8f3d3e49ea1db370aa194e5a156805095b3 Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Fri, 8 Aug 2025 11:31:31 +0100 Subject: [PATCH 05/18] typing --- tests/fixtures/graph.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/fixtures/graph.py b/tests/fixtures/graph.py index 7ac61c0..e3dd8d0 100644 --- a/tests/fixtures/graph.py +++ b/tests/fixtures/graph.py @@ -46,7 +46,7 @@ def _normal_graph(normal_graph_nodes: NormalGraphNodes | None = None) -> Graph: @pytest.fixture -def ux_x_graph() -> Callable[[NormalGraphNodes | None], Graph]: +def ux_x_graph() -> Callable[[float, float, float], Graph]: """Creates a 2 node graph: UX --> X From 83aca68a92d8c0b39490034f97f9c1a1a1c9f33e Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Fri, 8 Aug 2025 13:03:09 +0100 Subject: [PATCH 06/18] add moment function --- src/causalprog/algorithms/__init__.py | 2 +- .../algorithms/{expectation.py => moments.py} | 15 +++ tests/test_graph/test_algorithms.py | 97 ++++++++++++++++--- 3 files changed, 102 insertions(+), 12 deletions(-) rename src/causalprog/algorithms/{expectation.py => moments.py} (79%) diff --git a/src/causalprog/algorithms/__init__.py b/src/causalprog/algorithms/__init__.py index 879d9aa..da37e9a 100644 --- a/src/causalprog/algorithms/__init__.py +++ b/src/causalprog/algorithms/__init__.py @@ -1,4 +1,4 @@ """Algorithms.""" from .do import do -from .expectation import expectation, standard_deviation +from .moments import expectation, moment, standard_deviation diff --git a/src/causalprog/algorithms/expectation.py b/src/causalprog/algorithms/moments.py similarity index 79% rename from src/causalprog/algorithms/expectation.py rename to src/causalprog/algorithms/moments.py index 6ad8345..e9664d3 100644 --- a/src/causalprog/algorithms/expectation.py +++ b/src/causalprog/algorithms/moments.py @@ -47,3 +47,18 @@ def standard_deviation( ) -> float: """Estimate the standard deviation of a graph.""" return sample(graph, outcome_node_label, samples, rng_key=rng_key).std() + + +def moment( + order: int, + graph: Graph, + outcome_node_label: str, + samples: int, + *, + rng_key: jax.Array, +) -> float: + """Estimate moment of the given order of the data.""" + return ( + sum(sample(graph, outcome_node_label, samples, rng_key=rng_key) ** order) + / samples + ) diff --git a/tests/test_graph/test_algorithms.py b/tests/test_graph/test_algorithms.py index 955aa84..1dd57bc 100644 --- a/tests/test_graph/test_algorithms.py +++ b/tests/test_graph/test_algorithms.py @@ -6,7 +6,7 @@ import numpy as np import pytest -import causalprog +from causalprog import algorithms from causalprog.distribution.normal import NormalFamily from causalprog.graph import DistributionNode, Graph, ParameterNode @@ -53,7 +53,7 @@ def test_roots_down_to_outcome() -> None: def test_do(rng_key, ux_x_graph): graph = ux_x_graph() - graph2 = causalprog.algorithms.do(graph, "UX", 4.0) + graph2 = algorithms.do(graph, "UX", 4.0) assert "mean" in graph.get_node("X").parameters assert "mean" not in graph.get_node("X").constant_parameters @@ -61,7 +61,7 @@ def test_do(rng_key, ux_x_graph): assert "mean" in graph2.get_node("X").constant_parameters assert np.isclose( - causalprog.algorithms.expectation( + algorithms.expectation( graph, outcome_node_label="X", samples=1000, rng_key=rng_key ), 5.0, @@ -69,7 +69,7 @@ def test_do(rng_key, ux_x_graph): ) assert np.isclose( - causalprog.algorithms.expectation( + algorithms.expectation( graph2, outcome_node_label="X", samples=1000, rng_key=rng_key ), 4.0, @@ -86,7 +86,7 @@ def test_do(rng_key, ux_x_graph): pytest.param(1.0, 1.2, 10000000, 1e-3, id="N(mean=1, stdev=1.2), 10^7 samples"), ], ) -def test_mean_stdev_single_normal_node(samples, rtol, mean, stdev, rng_key): +def test_expectation_stdev_single_normal_node(samples, rtol, mean, stdev, rng_key): node = DistributionNode( NormalFamily(), label="X", @@ -108,14 +108,14 @@ def test_mean_stdev_single_normal_node(samples, rtol, mean, stdev, rng_key): # Check within hand-computation assert np.isclose( - causalprog.algorithms.expectation( + algorithms.expectation( graph, outcome_node_label="X", samples=samples, rng_key=rng_key ), mean, rtol=rtol, ) assert np.isclose( - causalprog.algorithms.standard_deviation( + algorithms.standard_deviation( graph, outcome_node_label="X", samples=samples, rng_key=rng_key ), stdev, @@ -123,13 +123,13 @@ def test_mean_stdev_single_normal_node(samples, rtol, mean, stdev, rng_key): ) # Check within computational distance assert np.isclose( - causalprog.algorithms.expectation( + algorithms.expectation( graph, outcome_node_label="X", samples=samples, rng_key=rng_key ), expected_mean, ) assert np.isclose( - causalprog.algorithms.standard_deviation( + algorithms.standard_deviation( graph, outcome_node_label="X", samples=samples, rng_key=rng_key ), expected_std_dev, @@ -173,16 +173,91 @@ def test_mean_stdev_two_node_graph( graph = ux_x_graph(mean=mean, cov=stdev**2, cov2=stdev2**2) assert np.isclose( - causalprog.algorithms.expectation( + algorithms.expectation( graph, outcome_node_label="X", samples=samples, rng_key=rng_key ), mean, rtol=rtol, ) assert np.isclose( - causalprog.algorithms.standard_deviation( + algorithms.standard_deviation( graph, outcome_node_label="X", samples=samples, rng_key=rng_key ), np.sqrt(stdev**2 + stdev2**2), rtol=rtol, ) + + +@pytest.mark.parametrize( + ("samples", "rtol"), + [ + pytest.param(100, 1, id="100 samples"), + pytest.param(10000, 1e-1, id="10^4 samples"), + pytest.param(1000000, 1e-2, id="10^6 samples"), + ], +) +def test_expectation(ux_x_graph, rng_key, samples, rtol): + if samples > 100: # noqa: PLR2004 + pytest.xfail("Test currently too slow") + graph = ux_x_graph() + + assert np.isclose( + algorithms.expectation( + graph, outcome_node_label="X", samples=samples, rng_key=rng_key + ), + sum( + algorithms.moments.sample( + graph, outcome_node_label="X", samples=samples, rng_key=rng_key + ) + ) + / samples, + rtol=rtol, + ) + assert np.isclose( + algorithms.expectation( + graph, outcome_node_label="X", samples=samples, rng_key=rng_key + ), + algorithms.moment( + 1, graph, outcome_node_label="X", samples=samples, rng_key=rng_key + ), + rtol=rtol, + ) + + +@pytest.mark.parametrize( + ("samples", "rtol"), + [ + pytest.param(100, 1, id="100 samples"), + pytest.param(10000, 1e-1, id="10^4 samples"), + pytest.param(1000000, 1e-2, id="10^6 samples"), + ], +) +def test_stdev(ux_x_graph, rng_key, samples, rtol): + if samples > 100: # noqa: PLR2004 + pytest.xfail("Test currently too slow") + graph = ux_x_graph() + + s = algorithms.moments.sample( + graph, outcome_node_label="X", samples=samples, rng_key=rng_key + ) + variance = (sum(s**2) - sum(s) ** 2 / samples) / samples + assert np.isclose( + algorithms.standard_deviation( + graph, outcome_node_label="X", samples=samples, rng_key=rng_key + ), + variance**0.5, + rtol=rtol, + ) + assert np.isclose( + algorithms.standard_deviation( + graph, outcome_node_label="X", samples=samples, rng_key=rng_key + ), + algorithms.moment( + 2, graph, outcome_node_label="X", samples=samples, rng_key=rng_key + ) + - algorithms.moment( + 1, graph, outcome_node_label="X", samples=samples, rng_key=rng_key + ) + ** 2, + rtol=rtol, + ) From 3d50a90288ea7dc688adbbe1eeac4897525c6b49 Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Fri, 8 Aug 2025 13:07:40 +0100 Subject: [PATCH 07/18] simplify tests --- tests/test_graph/test_algorithms.py | 35 ++++------------------------- 1 file changed, 4 insertions(+), 31 deletions(-) diff --git a/tests/test_graph/test_algorithms.py b/tests/test_graph/test_algorithms.py index 1dd57bc..8315ff4 100644 --- a/tests/test_graph/test_algorithms.py +++ b/tests/test_graph/test_algorithms.py @@ -205,22 +205,9 @@ def test_expectation(ux_x_graph, rng_key, samples, rtol): algorithms.expectation( graph, outcome_node_label="X", samples=samples, rng_key=rng_key ), - sum( - algorithms.moments.sample( - graph, outcome_node_label="X", samples=samples, rng_key=rng_key - ) - ) - / samples, - rtol=rtol, - ) - assert np.isclose( - algorithms.expectation( + algorithms.moments.sample( graph, outcome_node_label="X", samples=samples, rng_key=rng_key - ), - algorithms.moment( - 1, graph, outcome_node_label="X", samples=samples, rng_key=rng_key - ), - rtol=rtol, + ).mean(), ) @@ -237,27 +224,13 @@ def test_stdev(ux_x_graph, rng_key, samples, rtol): pytest.xfail("Test currently too slow") graph = ux_x_graph() - s = algorithms.moments.sample( - graph, outcome_node_label="X", samples=samples, rng_key=rng_key - ) variance = (sum(s**2) - sum(s) ** 2 / samples) / samples assert np.isclose( algorithms.standard_deviation( graph, outcome_node_label="X", samples=samples, rng_key=rng_key ), - variance**0.5, - rtol=rtol, - ) - assert np.isclose( - algorithms.standard_deviation( + algorithms.moments.sample( graph, outcome_node_label="X", samples=samples, rng_key=rng_key - ), - algorithms.moment( - 2, graph, outcome_node_label="X", samples=samples, rng_key=rng_key - ) - - algorithms.moment( - 1, graph, outcome_node_label="X", samples=samples, rng_key=rng_key - ) - ** 2, + ).std(), rtol=rtol, ) From 6cf500d26949a8197fa9b2e1eb8097ded0e0e151 Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Fri, 8 Aug 2025 13:14:39 +0100 Subject: [PATCH 08/18] make expectation and standard_deviation functions use moment function --- src/causalprog/algorithms/moments.py | 7 +++++-- tests/fixtures/graph.py | 2 +- tests/test_graph/test_algorithms.py | 6 ++++-- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/src/causalprog/algorithms/moments.py b/src/causalprog/algorithms/moments.py index e9664d3..b124d74 100644 --- a/src/causalprog/algorithms/moments.py +++ b/src/causalprog/algorithms/moments.py @@ -35,7 +35,7 @@ def expectation( rng_key: jax.Array, ) -> float: """Estimate the expectation of a graph.""" - return sample(graph, outcome_node_label, samples, rng_key=rng_key).mean() + return moment(1, graph, outcome_node_label, samples, rng_key=rng_key) def standard_deviation( @@ -46,7 +46,10 @@ def standard_deviation( rng_key: jax.Array, ) -> float: """Estimate the standard deviation of a graph.""" - return sample(graph, outcome_node_label, samples, rng_key=rng_key).std() + return ( + moment(2, graph, outcome_node_label, samples, rng_key=rng_key) + - moment(1, graph, outcome_node_label, samples, rng_key=rng_key) ** 2 + ) ** 0.5 def moment( diff --git a/tests/fixtures/graph.py b/tests/fixtures/graph.py index e3dd8d0..bfbbf68 100644 --- a/tests/fixtures/graph.py +++ b/tests/fixtures/graph.py @@ -56,7 +56,7 @@ def ux_x_graph() -> Callable[[float, float, float], Graph]: """ - def _ux_x_graph(mean: float = 5.0, cov: float = 1.0, cov2: float = 1.0) -> Graph: + def _ux_x_graph(mean: float = 5.0, cov: float = 0.5, cov2: float = 2.0) -> Graph: graph = Graph(label="G0") graph.add_node( DistributionNode( diff --git a/tests/test_graph/test_algorithms.py b/tests/test_graph/test_algorithms.py index 8315ff4..085219b 100644 --- a/tests/test_graph/test_algorithms.py +++ b/tests/test_graph/test_algorithms.py @@ -87,6 +87,8 @@ def test_do(rng_key, ux_x_graph): ], ) def test_expectation_stdev_single_normal_node(samples, rtol, mean, stdev, rng_key): + if samples > 100: # noqa: PLR2004 + pytest.xfail("Test currently too slow") node = DistributionNode( NormalFamily(), label="X", @@ -165,7 +167,7 @@ def test_expectation_stdev_single_normal_node(samples, rtol, mean, stdev, rng_ke ), ], ) -def test_mean_stdev_two_node_graph( +def test_expectation_stdev_two_node_graph( ux_x_graph, samples, rtol, mean, stdev, stdev2, rng_key ): if samples > 100: # noqa: PLR2004 @@ -208,6 +210,7 @@ def test_expectation(ux_x_graph, rng_key, samples, rtol): algorithms.moments.sample( graph, outcome_node_label="X", samples=samples, rng_key=rng_key ).mean(), + rtol=rtol, ) @@ -224,7 +227,6 @@ def test_stdev(ux_x_graph, rng_key, samples, rtol): pytest.xfail("Test currently too slow") graph = ux_x_graph() - variance = (sum(s**2) - sum(s) ** 2 / samples) / samples assert np.isclose( algorithms.standard_deviation( graph, outcome_node_label="X", samples=samples, rng_key=rng_key From 15ac29573bd424f4dbf5ffe7edffd1b17de8d23e Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Tue, 12 Aug 2025 14:01:59 +0100 Subject: [PATCH 09/18] ruff --- tests/fixtures/graph.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/fixtures/graph.py b/tests/fixtures/graph.py index eb574cc..9a98f4e 100644 --- a/tests/fixtures/graph.py +++ b/tests/fixtures/graph.py @@ -3,9 +3,6 @@ from collections.abc import Callable from typing import Literal, TypeAlias -import pytest - -from causalprog.distribution.normal import NormalFamily import numpy.typing as npt import numpyro import pytest From ef08b45916af0a4e8183ffaca4c96b08b4f3dbb4 Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Tue, 12 Aug 2025 14:32:57 +0100 Subject: [PATCH 10/18] fix typo --- tests/fixtures/graph.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/fixtures/graph.py b/tests/fixtures/graph.py index 9a98f4e..7975b3f 100644 --- a/tests/fixtures/graph.py +++ b/tests/fixtures/graph.py @@ -53,7 +53,7 @@ def two_normal_graph() -> Callable[[float, float, float], Graph]: UX --> X - where EX is a normal distribution with mean `mean` and covariance `cov`, and X is + where UX is a normal distribution with mean `mean` and covariance `cov`, and X is a normal distrubution with mean UX and covariance `cov2`. """ From 2291b6beacb3724b4f523186637beb6e7dbe0c44 Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Wed, 13 Aug 2025 09:26:13 +0100 Subject: [PATCH 11/18] work in progress from yesterday's train journey --- src/causalprog/graph/node.py | 21 +++++++------ tests/fixtures/graph.py | 47 ++++++++++++++++------------- tests/test_graph/test_algorithms.py | 33 +++++++++----------- 3 files changed, 52 insertions(+), 49 deletions(-) diff --git a/src/causalprog/graph/node.py b/src/causalprog/graph/node.py index e2db0a3..e65ca14 100644 --- a/src/causalprog/graph/node.py +++ b/src/causalprog/graph/node.py @@ -178,19 +178,20 @@ def sample( samples: int, rng_key: jax.Array, ) -> npt.NDArray[float]: + d = self._dist( + # Pass in node values derived from construction so far + **{ + native_name: sampled_dependencies[node_name] + for native_name, node_name in self.parameters.items() + }, + # Pass in any constant parameters this node sets + **self.constant_parameters, + ) return numpyro.sample( self.label, - self._dist( - # Pass in node values derived from construction so far - **{ - native_name: sampled_dependencies[node_name] - for native_name, node_name in self.parameters.items() - }, - # Pass in any constant parameters this node sets - **self.constant_parameters, - ), + d, rng_key=rng_key, - sample_shape=(samples,), + sample_shape=(samples,) if d.batch_shape == () and samples > 1 else (), ) @override diff --git a/tests/fixtures/graph.py b/tests/fixtures/graph.py index 7975b3f..9fb6a22 100644 --- a/tests/fixtures/graph.py +++ b/tests/fixtures/graph.py @@ -17,31 +17,36 @@ @pytest.fixture -def normal_graph() -> Callable[[NormalGraphNodes | None], Graph]: - """Creates a 3-node graph: +def normal_graph() -> Callable[[float, float], Graph]: + """Creates a graph with one normal distribution X. - mean (P) cov (P) - |---> outcome <----| - - where outcome is a normal distribution. - - Parameter nodes are initialised with no `value` set. + Parameter nodes are included if no values are given for the mean and covariance. """ - def _inner(normal_graph_nodes: NormalGraphNodes | None = None) -> Graph: - if normal_graph_nodes is None: - normal_graph_nodes = { - "mean": ParameterNode(label="mean"), - "cov": ParameterNode(label="cov"), - "outcome": DistributionNode( - Normal, - label="outcome", - parameters={"loc": "mean", "scale": "std"}, - ), - } + def _inner(mean: float | None = None, cov: float | None = None): graph = Graph(label="normal dist") - graph.add_edge(normal_graph_nodes["mean"], normal_graph_nodes["outcome"]) - graph.add_edge(normal_graph_nodes["cov"], normal_graph_nodes["outcome"]) + parameters = {} + constant_parameters = {} + if mean is None: + graph.add_node(ParameterNode(label="mean")) + parameters["loc"] = "mean" + else: + constant_parameters["loc"] = mean + if cov is None: + graph.add_node(ParameterNode(label="cov")) + parameters["scale"] = "cov" + else: + constant_parameters["scale"] = cov + graph.add_node( + DistributionNode( + Normal, + label="X", + parameters=parameters, + constant_parameters=constant_parameters, + ) + ) + for node in parameters.values(): + graph.add_edge(node, "X") return graph return _inner diff --git a/tests/test_graph/test_algorithms.py b/tests/test_graph/test_algorithms.py index 76b842e..e8a271f 100644 --- a/tests/test_graph/test_algorithms.py +++ b/tests/test_graph/test_algorithms.py @@ -9,6 +9,7 @@ from causalprog import algorithms from causalprog.graph import DistributionNode, Graph +max_samples = 10**5 def test_roots_down_to_outcome() -> None: graph = Graph(label="G0") @@ -78,24 +79,18 @@ def test_do(rng_key, two_normal_graph): pytest.param(1.0, 1.2, 10000000, 1e-3, id="N(mean=1, stdev=1.2), 10^7 samples"), ], ) -def test_expectation_stdev_single_normal_node(samples, rtol, mean, stdev, rng_key): - if samples > 100: # noqa: PLR2004 +def test_expectation_stdev_single_normal_node(normal_graph, samples, rtol, mean, stdev, rng_key): + if samples > max_samples: pytest.xfail("Test currently too slow") - node = DistributionNode( - Normal, - label="X", - constant_parameters={"loc": mean, "scale": stdev}, - ) - graph = Graph(label="G0") - graph.add_node(node) + graph = normal_graph(mean, stdev) # To compensate for rng-key splitting in sample methods, note the "split" key # that is actually used to draw the samples from the distribution, so we can # attempt to replicate its behaviour explicitly. key = jax.random.split(rng_key, 1)[0] what_we_should_get = jax.random.multivariate_normal( - key, jax.numpy.atleast_1d(mean), jax.numpy.atleast_2d(stdev**2), shape=samples + rng_key, jax.numpy.atleast_1d(mean), jax.numpy.atleast_2d(stdev**2), shape=samples ) expected_mean = what_we_should_get.mean() expected_std_dev = what_we_should_get.std() @@ -121,12 +116,14 @@ def test_expectation_stdev_single_normal_node(samples, rtol, mean, stdev, rng_ke graph, outcome_node_label="X", samples=samples, rng_key=rng_key ), expected_mean, + # rtol=rtol, ) assert np.isclose( algorithms.standard_deviation( graph, outcome_node_label="X", samples=samples, rng_key=rng_key ), expected_std_dev, + # rtol=rtol, ) @@ -162,8 +159,8 @@ def test_expectation_stdev_single_normal_node(samples, rtol, mean, stdev, rng_ke def test_mean_stdev_two_node_graph( two_normal_graph, samples, rtol, mean, stdev, stdev2, rng_key ): - if samples > 100000: # noqa: PLR2004 - pytest.xfail("Test currently runs out of memory") + if samples > max_samples: + pytest.xfail("Test currently too slow") graph = two_normal_graph(mean=mean, cov=stdev, cov2=stdev2) @@ -191,10 +188,10 @@ def test_mean_stdev_two_node_graph( pytest.param(1000000, 1e-2, id="10^6 samples"), ], ) -def test_expectation(ux_x_graph, rng_key, samples, rtol): - if samples > 100: # noqa: PLR2004 +def test_expectation(two_normal_graph, rng_key, samples, rtol): + if samples > max_samples: pytest.xfail("Test currently too slow") - graph = ux_x_graph() + graph = two_normal_graph() assert np.isclose( algorithms.expectation( @@ -215,10 +212,10 @@ def test_expectation(ux_x_graph, rng_key, samples, rtol): pytest.param(1000000, 1e-2, id="10^6 samples"), ], ) -def test_stdev(ux_x_graph, rng_key, samples, rtol): - if samples > 100: # noqa: PLR2004 +def test_stdev(two_normal_graph, rng_key, samples, rtol): + if samples > max_samples: pytest.xfail("Test currently too slow") - graph = ux_x_graph() + graph = two_normal_graph() assert np.isclose( algorithms.standard_deviation( From 477918848f51681615a7daaaf52de09fe721bbae Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Wed, 13 Aug 2025 11:30:55 +0100 Subject: [PATCH 12/18] adjust test tolerances --- tests/fixtures/graph.py | 4 +-- tests/test_causal_problem/test_callables.py | 12 ++------ tests/test_graph/test_algorithms.py | 31 +++------------------ tests/test_graph/test_parameters.py | 8 +++--- 4 files changed, 13 insertions(+), 42 deletions(-) diff --git a/tests/fixtures/graph.py b/tests/fixtures/graph.py index 9fb6a22..5cbc639 100644 --- a/tests/fixtures/graph.py +++ b/tests/fixtures/graph.py @@ -69,7 +69,7 @@ def _inner(mean: float = 5.0, cov: float = 1.0, cov2: float = 1.0) -> Graph: DistributionNode( Normal, label="UX", - constant_parameters={"loc": mean, "scale": cov**2}, + constant_parameters={"loc": mean, "scale": cov}, ) ) graph.add_node( @@ -77,7 +77,7 @@ def _inner(mean: float = 5.0, cov: float = 1.0, cov2: float = 1.0) -> Graph: Normal, label="X", parameters={"loc": "UX"}, - constant_parameters={"scale": cov2**2}, + constant_parameters={"scale": cov2}, ) ) graph.add_edge("UX", "X") diff --git a/tests/test_causal_problem/test_callables.py b/tests/test_causal_problem/test_callables.py index 7aecf29..9e686c8 100644 --- a/tests/test_causal_problem/test_callables.py +++ b/tests/test_causal_problem/test_callables.py @@ -92,9 +92,7 @@ def which(request: pytest.FixtureRequest) -> Literal["causal_estimand", "constra "graph_argument": "g", }, 0.0, - # Empirical calculation with 1000 samples with fixture RNG key - # should give 1.8808e-2 as the empirical expectation. - 2.0e-2, + 3.0e-2, id="E[x], infer association", ), pytest.param( @@ -106,9 +104,7 @@ def which(request: pytest.FixtureRequest) -> Literal["causal_estimand", "constra }, # x has fixed std 1, and nu_y will be set to 1. 1.0**2 + 1.0**2, - # Empirical calculation with 1000 samples with fixture RNG key - # should give 1.8506 as the empirical std of y. - 2.0e-1, + 3.0e-1, id="Var[y]", ), pytest.param( @@ -120,9 +116,7 @@ def which(request: pytest.FixtureRequest) -> Literal["causal_estimand", "constra }, # As per the previous test cases jnp.array([0.0, 1.0**2 + 1.0**2]), - # As per the above cases, both components should be within - # 2.0e-1 of the analytical value. - jnp.array([2.0e-2, 2.0e-1]), + jnp.array([3.0e-2, 2.0e-1]), id="E[x], Var[y]", ), ], diff --git a/tests/test_graph/test_algorithms.py b/tests/test_graph/test_algorithms.py index e8a271f..95de311 100644 --- a/tests/test_graph/test_algorithms.py +++ b/tests/test_graph/test_algorithms.py @@ -1,6 +1,5 @@ """Tests for graph algorithms.""" -import jax import numpy as np import pytest from numpyro.distributions import Normal @@ -11,6 +10,7 @@ max_samples = 10**5 + def test_roots_down_to_outcome() -> None: graph = Graph(label="G0") @@ -79,22 +79,14 @@ def test_do(rng_key, two_normal_graph): pytest.param(1.0, 1.2, 10000000, 1e-3, id="N(mean=1, stdev=1.2), 10^7 samples"), ], ) -def test_expectation_stdev_single_normal_node(normal_graph, samples, rtol, mean, stdev, rng_key): +def test_expectation_stdev_single_normal_node( + normal_graph, samples, rtol, mean, stdev, rng_key +): if samples > max_samples: pytest.xfail("Test currently too slow") graph = normal_graph(mean, stdev) - # To compensate for rng-key splitting in sample methods, note the "split" key - # that is actually used to draw the samples from the distribution, so we can - # attempt to replicate its behaviour explicitly. - key = jax.random.split(rng_key, 1)[0] - what_we_should_get = jax.random.multivariate_normal( - rng_key, jax.numpy.atleast_1d(mean), jax.numpy.atleast_2d(stdev**2), shape=samples - ) - expected_mean = what_we_should_get.mean() - expected_std_dev = what_we_should_get.std() - # Check within hand-computation assert np.isclose( algorithms.expectation( @@ -110,21 +102,6 @@ def test_expectation_stdev_single_normal_node(normal_graph, samples, rtol, mean, stdev, rtol=rtol, ) - # Check within computational distance - assert np.isclose( - algorithms.expectation( - graph, outcome_node_label="X", samples=samples, rng_key=rng_key - ), - expected_mean, - # rtol=rtol, - ) - assert np.isclose( - algorithms.standard_deviation( - graph, outcome_node_label="X", samples=samples, rng_key=rng_key - ), - expected_std_dev, - # rtol=rtol, - ) @pytest.mark.parametrize( diff --git a/tests/test_graph/test_parameters.py b/tests/test_graph/test_parameters.py index b7336a3..0e7c790 100644 --- a/tests/test_graph/test_parameters.py +++ b/tests/test_graph/test_parameters.py @@ -7,7 +7,7 @@ from causalprog.graph import DistributionNode, Graph, ParameterNode -NormalGraphNodeNames: TypeAlias = Literal["mean", "cov", "outcome"] +NormalGraphNodeNames: TypeAlias = Literal["mean", "cov", "X"] NormalGraphNodes: TypeAlias = dict[ NormalGraphNodeNames, DistributionNode | ParameterNode ] @@ -18,8 +18,8 @@ [ pytest.param( {}, - {"outcome": 4.0}, - TypeError("Node outcome is not a parameter node."), + {"X": 4.0}, + TypeError("Node X is not a parameter node."), id="Give non-parameter node", ), pytest.param( @@ -54,7 +54,7 @@ def test_set_parameters( parameter_nodes = graph.parameter_nodes assert graph.get_node("mean") in parameter_nodes assert graph.get_node("cov") in parameter_nodes - assert graph.get_node("outcome") not in parameter_nodes + assert graph.get_node("X") not in parameter_nodes # Set any pre-existing values we might want the parameter nodes to have in # this test. From 2276f0d489cb744293825f63f58f302b077a9c4f Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Wed, 13 Aug 2025 11:41:24 +0100 Subject: [PATCH 13/18] Add test that sample returns something with the correct shape --- tests/test_graph/test_algorithms.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/test_graph/test_algorithms.py b/tests/test_graph/test_algorithms.py index 95de311..200f58e 100644 --- a/tests/test_graph/test_algorithms.py +++ b/tests/test_graph/test_algorithms.py @@ -203,3 +203,14 @@ def test_stdev(two_normal_graph, rng_key, samples, rtol): ).std(), rtol=rtol, ) + + +@pytest.mark.parametrize("samples", [1, 2, 10, 100]) +def test_sample_shape(two_normal_graph, rng_key, samples): + graph = two_normal_graph() + + s1 = algorithms.moments.sample(graph, "X", samples, rng_key=rng_key) + assert s1.shape == () if samples == 1 else (samples,) + + s2 = algorithms.moments.sample(graph, "UX", samples, rng_key=rng_key) + assert s2.shape == () if samples == 1 else (samples,) From e8e191141698206921f7365b27e139eb981a81a1 Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Wed, 13 Aug 2025 12:18:40 +0100 Subject: [PATCH 14/18] use parameter nodes in fixture graph if no values requested --- tests/fixtures/graph.py | 79 +++++++++++++---------------- tests/test_graph/test_algorithms.py | 8 +-- tests/test_graph/test_model.py | 34 ++++++------- 3 files changed, 56 insertions(+), 65 deletions(-) diff --git a/tests/fixtures/graph.py b/tests/fixtures/graph.py index 5cbc639..11d52fa 100644 --- a/tests/fixtures/graph.py +++ b/tests/fixtures/graph.py @@ -61,65 +61,56 @@ def two_normal_graph() -> Callable[[float, float, float], Graph]: where UX is a normal distribution with mean `mean` and covariance `cov`, and X is a normal distrubution with mean UX and covariance `cov2`. + Parameter nodes are included if no values are given for the mean and covariances. + """ - def _inner(mean: float = 5.0, cov: float = 1.0, cov2: float = 1.0) -> Graph: + def _inner( + mean: float | None = None, cov: float | None = None, cov2: float | None = None + ) -> Graph: graph = Graph(label="G0") - graph.add_node( - DistributionNode( - Normal, - label="UX", - constant_parameters={"loc": mean, "scale": cov}, - ) - ) - graph.add_node( - DistributionNode( - Normal, - label="X", - parameters={"loc": "UX"}, - constant_parameters={"scale": cov2}, - ) - ) - graph.add_edge("UX", "X") - - return graph - return _inner - - -@pytest.fixture -def two_normal_graph_parametrized_mean() -> Callable[[float], Graph]: - """Creates a graph: - SDUX SDX - | | - V v - mu_x --> UX --> X - - where UX is a normal distribution with mean mu_x and covariance `co`, and X is - a normal distrubution with mean UX and covariance nu_y. - - """ + x_parameters = {"loc": "UX"} + x_constant_parameters = {} + ux_parameters = {} + ux_constant_parameters = {} + if mean is None: + graph.add_node(ParameterNode(label="mean")) + ux_parameters["loc"] = "mean" + else: + ux_constant_parameters["loc"] = mean + if cov is None: + graph.add_node(ParameterNode(label="cov")) + ux_parameters["scale"] = "cov" + else: + ux_constant_parameters["scale"] = cov + if cov2 is None: + graph.add_node(ParameterNode(label="cov2")) + x_parameters["scale"] = "cov2" + else: + x_constant_parameters["scale"] = cov2 - def _inner(cov: float = 1.0) -> Graph: - graph = Graph(label="G0") - graph.add_node(ParameterNode(label="nu_y")) - graph.add_node(ParameterNode(label="mu_x")) graph.add_node( DistributionNode( Normal, label="UX", - parameters={"loc": "mu_x"}, - constant_parameters={"scale": cov}, + parameters=ux_parameters, + constant_parameters=ux_constant_parameters, ) ) graph.add_node( DistributionNode( Normal, label="X", - parameters={"loc": "UX", "scale": "nu_y"}, + parameters=x_parameters, + constant_parameters=x_constant_parameters, ) ) graph.add_edge("UX", "X") + for node in ux_parameters.values(): + graph.add_edge(node, "UX") + for node in x_parameters.values(): + graph.add_edge(node, "X") return graph @@ -130,9 +121,9 @@ def _inner(cov: float = 1.0) -> Graph: def two_normal_graph_expected_model() -> Callable[..., dict[str, npt.ArrayLike]]: """Creates the model that the two_normal_graph should produce.""" - def _inner(mu_x: float, nu_y: float) -> dict[str, npt.ArrayLike]: - ux = numpyro.sample("UX", Normal(loc=mu_x, scale=1.0)) - x = numpyro.sample("X", Normal(loc=ux, scale=nu_y)) + def _inner(mean: float, cov2: float) -> dict[str, npt.ArrayLike]: + ux = numpyro.sample("UX", Normal(loc=mean, scale=1.0)) + x = numpyro.sample("X", Normal(loc=ux, scale=cov2)) return {"X": x, "UX": ux} diff --git a/tests/test_graph/test_algorithms.py b/tests/test_graph/test_algorithms.py index 200f58e..fe5f46b 100644 --- a/tests/test_graph/test_algorithms.py +++ b/tests/test_graph/test_algorithms.py @@ -45,7 +45,7 @@ def test_roots_down_to_outcome() -> None: def test_do(rng_key, two_normal_graph): - graph = two_normal_graph() + graph = two_normal_graph(5.0, 1.2, 0.8) graph2 = causalprog.algorithms.do(graph, "UX", 4.0) assert "loc" in graph.get_node("X").parameters @@ -168,7 +168,7 @@ def test_mean_stdev_two_node_graph( def test_expectation(two_normal_graph, rng_key, samples, rtol): if samples > max_samples: pytest.xfail("Test currently too slow") - graph = two_normal_graph() + graph = two_normal_graph(1.0, 1.2, 0.8) assert np.isclose( algorithms.expectation( @@ -192,7 +192,7 @@ def test_expectation(two_normal_graph, rng_key, samples, rtol): def test_stdev(two_normal_graph, rng_key, samples, rtol): if samples > max_samples: pytest.xfail("Test currently too slow") - graph = two_normal_graph() + graph = two_normal_graph(1.0, 1.2, 0.8) assert np.isclose( algorithms.standard_deviation( @@ -207,7 +207,7 @@ def test_stdev(two_normal_graph, rng_key, samples, rtol): @pytest.mark.parametrize("samples", [1, 2, 10, 100]) def test_sample_shape(two_normal_graph, rng_key, samples): - graph = two_normal_graph() + graph = two_normal_graph(1.0, 1.2, 0.8) s1 = algorithms.moments.sample(graph, "X", samples, rng_key=rng_key) assert s1.shape == () if samples == 1 else (samples,) diff --git a/tests/test_graph/test_model.py b/tests/test_graph/test_model.py index cddeb72..26e0b04 100644 --- a/tests/test_graph/test_model.py +++ b/tests/test_graph/test_model.py @@ -8,13 +8,13 @@ @pytest.mark.parametrize( "param_values", [ - pytest.param({"mu_x": 0.0, "nu_y": 1.0}, id="mean(X) = 0, cov(UX) = 1"), - pytest.param({"mu_x": 1.0, "nu_y": 2.0}, id="mean(X) = 1, cov(UX) = 2"), + pytest.param({"mean": 0.0, "cov2": 1.0}, id="mean(X) = 0, cov(UX) = 1"), + pytest.param({"mean": 1.0, "cov2": 2.0}, id="mean(X) = 1, cov(UX) = 2"), ], ) def test_model( param_values: dict[str, npt.ArrayLike], - two_normal_graph_parametrized_mean, + two_normal_graph, two_normal_graph_expected_model, assert_samples_are_identical, run_default_nuts_mcmc, @@ -31,7 +31,7 @@ def test_model( MCMC sampling output of `Graph.model` with the explicit model that it should correspond to. """ - graph = two_normal_graph_parametrized_mean() + graph = two_normal_graph(cov=1.0) assert callable(graph.model) via_model = run_default_nuts_mcmc( @@ -47,19 +47,19 @@ def test_model( def test_model_missing_parameter( - two_normal_graph_parametrized_mean, + two_normal_graph, raises_context, seed: int, ) -> None: """`Graph.model` will raise a `KeyError` when a value is not passed for a `ParameterNode`. """ - graph = two_normal_graph_parametrized_mean() + graph = two_normal_graph(cov=1.0) - # Deliberately leave out the "nu_y" variable. - parameter_values = {"mu_x": 0.0} + # Deliberately leave out the "cov2" variable. + parameter_values = {"mean": 0.0} # Which should result in the error below. - expected_exception = KeyError("ParameterNode 'nu_y' not assigned") + expected_exception = KeyError("ParameterNode 'cov2' not assigned") # Not passing enough parameters should be picked up by the model. with raises_context(expected_exception), numpyro.handlers.seed(rng_seed=seed): @@ -67,33 +67,33 @@ def test_model_missing_parameter( def test_model_extension( - two_normal_graph_parametrized_mean, + two_normal_graph, assert_samples_are_identical, run_default_nuts_mcmc, ) -> None: """Test that `Graph.model` can be extended.""" - graph = two_normal_graph_parametrized_mean() + graph = two_normal_graph(cov=1.0) - parameter_values = {"mu_x": 0.0, "nu_y": 1.0} + parameter_values = {"mean": 0.0, "cov2": 1.0} # Build the graph, but without the X-node. - mu_x = ParameterNode(label="mu_x") + mean = ParameterNode(label="mean") x = DistributionNode( numpyro.distributions.Normal, label="UX", - parameters={"loc": "mu_x"}, + parameters={"loc": "mean"}, constant_parameters={"scale": 1.0}, ) one_normal_graph = Graph(label="One normal") - one_normal_graph.add_edge(mu_x, x) + one_normal_graph.add_edge(mean, x) - def extended_model(*, nu_y, **parameter_values): + def extended_model(*, cov2, **parameter_values): sites = one_normal_graph.model(**parameter_values) numpyro.sample( "X", numpyro.distributions.Normal( loc=sites["UX"], - scale=nu_y, + scale=cov2, ), ) From 1762b4f4b24a6fa00a61ab2f7621c40df082a8dd Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Wed, 13 Aug 2025 12:19:40 +0100 Subject: [PATCH 15/18] better labels --- tests/fixtures/graph.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/fixtures/graph.py b/tests/fixtures/graph.py index 11d52fa..62143d6 100644 --- a/tests/fixtures/graph.py +++ b/tests/fixtures/graph.py @@ -24,7 +24,7 @@ def normal_graph() -> Callable[[float, float], Graph]: """ def _inner(mean: float | None = None, cov: float | None = None): - graph = Graph(label="normal dist") + graph = Graph(label="normal_graph") parameters = {} constant_parameters = {} if mean is None: @@ -68,7 +68,7 @@ def two_normal_graph() -> Callable[[float, float, float], Graph]: def _inner( mean: float | None = None, cov: float | None = None, cov2: float | None = None ) -> Graph: - graph = Graph(label="G0") + graph = Graph(label="two_normal_graph") x_parameters = {"loc": "UX"} x_constant_parameters = {} From b30169b45c7b57bede0ee4c835f338cb73f0151b Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Thu, 14 Aug 2025 11:17:21 +0100 Subject: [PATCH 16/18] accept second rng key --- src/causalprog/algorithms/moments.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/causalprog/algorithms/moments.py b/src/causalprog/algorithms/moments.py index f3bec09..3010e78 100644 --- a/src/causalprog/algorithms/moments.py +++ b/src/causalprog/algorithms/moments.py @@ -41,11 +41,19 @@ def standard_deviation( samples: int, *, rng_key: jax.Array, + rng_key_first_moment: jax.Array | None = None, ) -> float: """Estimate the standard deviation of a graph.""" return ( moment(2, graph, outcome_node_label, samples, rng_key=rng_key) - - moment(1, graph, outcome_node_label, samples, rng_key=rng_key) ** 2 + - moment( + 1, + graph, + outcome_node_label, + samples, + rng_key=rng_key if rng_key_first_moment is None else rng_key_first_moment, + ) + ** 2 ) ** 0.5 From cc70daa701209869ee56950700151ec63b7db633 Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Thu, 14 Aug 2025 11:18:52 +0100 Subject: [PATCH 17/18] update docs --- src/causalprog/algorithms/moments.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/causalprog/algorithms/moments.py b/src/causalprog/algorithms/moments.py index 3010e78..a91a65b 100644 --- a/src/causalprog/algorithms/moments.py +++ b/src/causalprog/algorithms/moments.py @@ -13,7 +13,7 @@ def sample( *, rng_key: jax.Array, ) -> npt.NDArray[float]: - """Sample data from a graph.""" + """Sample data from a random variable attached to a node in a graph.""" nodes = graph.roots_down_to_outcome(outcome_node_label) values: dict[str, npt.NDArray[float]] = {} @@ -31,7 +31,7 @@ def expectation( *, rng_key: jax.Array, ) -> float: - """Estimate the expectation of a graph.""" + """Estimate the expectation of a random variable attached to a node in a graph.""" return moment(1, graph, outcome_node_label, samples, rng_key=rng_key) @@ -43,7 +43,7 @@ def standard_deviation( rng_key: jax.Array, rng_key_first_moment: jax.Array | None = None, ) -> float: - """Estimate the standard deviation of a graph.""" + """Estimate the standard deviation of a random variable attached to a node in a graph.""" return ( moment(2, graph, outcome_node_label, samples, rng_key=rng_key) - moment( @@ -65,7 +65,7 @@ def moment( *, rng_key: jax.Array, ) -> float: - """Estimate moment of the given order of the data.""" + """Estimate a moment of a random variable attached to a node in a graph.""" return ( sum(sample(graph, outcome_node_label, samples, rng_key=rng_key) ** order) / samples From bf1cbcb830e474896c9372827410951e3a2b0de9 Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Thu, 14 Aug 2025 11:20:22 +0100 Subject: [PATCH 18/18] ruff --- src/causalprog/algorithms/moments.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/causalprog/algorithms/moments.py b/src/causalprog/algorithms/moments.py index a91a65b..c196d01 100644 --- a/src/causalprog/algorithms/moments.py +++ b/src/causalprog/algorithms/moments.py @@ -13,7 +13,7 @@ def sample( *, rng_key: jax.Array, ) -> npt.NDArray[float]: - """Sample data from a random variable attached to a node in a graph.""" + """Sample data from (a random variable attached to) a node in a graph.""" nodes = graph.roots_down_to_outcome(outcome_node_label) values: dict[str, npt.NDArray[float]] = {} @@ -31,7 +31,7 @@ def expectation( *, rng_key: jax.Array, ) -> float: - """Estimate the expectation of a random variable attached to a node in a graph.""" + """Estimate the expectation of (a random variable attached to) a node in a graph.""" return moment(1, graph, outcome_node_label, samples, rng_key=rng_key) @@ -43,7 +43,7 @@ def standard_deviation( rng_key: jax.Array, rng_key_first_moment: jax.Array | None = None, ) -> float: - """Estimate the standard deviation of a random variable attached to a node in a graph.""" + """Estimate the standard deviation of (a RV attached to) a node in a graph.""" return ( moment(2, graph, outcome_node_label, samples, rng_key=rng_key) - moment( @@ -65,7 +65,7 @@ def moment( *, rng_key: jax.Array, ) -> float: - """Estimate a moment of a random variable attached to a node in a graph.""" + """Estimate a moment of (a random variable attached to) a node in a graph.""" return ( sum(sample(graph, outcome_node_label, samples, rng_key=rng_key) ** order) / samples