From 2e5fd3146c521c27c5772dcce002f3fe0f87d90c Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Tue, 29 Apr 2025 07:41:44 +0100 Subject: [PATCH 01/24] add parameter node --- src/causalprog/backend/_convert_signature.py | 4 +- src/causalprog/graph/__init__.py | 2 +- src/causalprog/graph/node.py | 39 +++++++++++++++++++- tests/test_graph.py | 11 +++++- 4 files changed, 49 insertions(+), 7 deletions(-) diff --git a/src/causalprog/backend/_convert_signature.py b/src/causalprog/backend/_convert_signature.py index 6b96419..f04791b 100644 --- a/src/causalprog/backend/_convert_signature.py +++ b/src/causalprog/backend/_convert_signature.py @@ -35,9 +35,7 @@ def _check_variable_length_params( parameter of that type exists in the signature. """ - named_args: dict[ParamKind, str | None] = { - kind: None for kind in _VARLENGTH_PARAM_TYPES - } + named_args: dict[ParamKind, str | None] = dict.fromkeys(_VARLENGTH_PARAM_TYPES) for kind in _VARLENGTH_PARAM_TYPES: possible_parameters = [ p_name for p_name, p in sig.parameters.items() if p.kind == kind diff --git a/src/causalprog/graph/__init__.py b/src/causalprog/graph/__init__.py index 89144ec..10b5794 100644 --- a/src/causalprog/graph/__init__.py +++ b/src/causalprog/graph/__init__.py @@ -1,4 +1,4 @@ """Creation and storage of graphs.""" from .graph import Graph -from .node import DistributionNode, Node +from .node import DistributionNode, Node, ParameterNode diff --git a/src/causalprog/graph/node.py b/src/causalprog/graph/node.py index 05fc183..bf9cbb4 100644 --- a/src/causalprog/graph/node.py +++ b/src/causalprog/graph/node.py @@ -19,10 +19,17 @@ class Node(Labelled): """An abstract node in a graph.""" - def __init__(self, label: str, *, is_outcome: bool = False) -> None: + def __init__( + self, + label: str, + *, + is_outcome: bool = False, + is_parameter: bool = False, + ) -> None: """Initialise.""" super().__init__(label=label) self._is_outcome = is_outcome + self._is_parameter = is_parameter @abstractmethod def sample( @@ -38,6 +45,11 @@ def is_outcome(self) -> bool: """Identify if the node is an outcome.""" return self._is_outcome + @property + def is_parameter(self) -> bool: + """Identify if the node is a parameter.""" + return self._is_parameter + class DistributionNode(Node): """A node containing a distribution.""" @@ -55,7 +67,7 @@ def __init__( self._dist = distribution self._constant_parameters = constant_parameters if constant_parameters else {} self._parameters = parameters if parameters else {} - super().__init__(label, is_outcome=is_outcome) + super().__init__(label, is_outcome=is_outcome, is_parameter=False) def sample( self, @@ -81,3 +93,26 @@ def sample( def __repr__(self) -> str: return f'DistributionNode("{self.label}")' + + +class ParameterNode(Node): + """A node containing a parameter.""" + + def __init__(self, label: str, *, is_outcome: bool = False) -> None: + """Initialise.""" + super().__init__(label, is_outcome=is_outcome, is_parameter=True) + + def sample( + self, + sampled_dependencies: dict[str, npt.NDArray[float]], + _samples: int, + _rng_key: jax.Array, + ) -> npt.NDArray[float]: + """Sample a value from the node.""" + if self.label not in sampled_dependencies: + msg = "Cannot sample an undetermined parameter node." + raise ValueError(msg) + return sampled_dependencies[self.label] + + def __repr__(self) -> str: + return f'ParameterNode("{self.label}")' diff --git a/tests/test_graph.py b/tests/test_graph.py index 7ad091f..ba481b0 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -8,7 +8,7 @@ import causalprog from causalprog.distribution.normal import NormalFamily -from causalprog.graph import DistributionNode, Graph +from causalprog.graph import DistributionNode, Graph, ParameterNode def test_label(): @@ -236,3 +236,12 @@ def test_two_node_graph(samples, rtol, mean, stdev, stdev2, rng_key): np.sqrt(stdev**2 + stdev2**2), rtol=rtol, ) + + +def test_paramater_node(rng_key): + node = ParameterNode("mu") + + with pytest.raises(ValueError, match="Cannot sample"): + node.sample({}, 1, rng_key) + + assert np.isclose(node.sample({"mu": np.array([0.3])}, 1, rng_key)[0], 0.3) From b12cd1d266195796314a25047b80a257489b1634 Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Tue, 29 Apr 2025 11:52:05 +0100 Subject: [PATCH 02/24] Update src/causalprog/graph/node.py --- src/causalprog/graph/node.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/causalprog/graph/node.py b/src/causalprog/graph/node.py index bf9cbb4..85ac44b 100644 --- a/src/causalprog/graph/node.py +++ b/src/causalprog/graph/node.py @@ -101,6 +101,7 @@ class ParameterNode(Node): def __init__(self, label: str, *, is_outcome: bool = False) -> None: """Initialise.""" super().__init__(label, is_outcome=is_outcome, is_parameter=True) + self.value: int | None = None def sample( self, From 525ca6caf715c7ab3f1fadb6966fc4cf06ef1e70 Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Tue, 29 Apr 2025 11:52:12 +0100 Subject: [PATCH 03/24] Update src/causalprog/graph/node.py --- src/causalprog/graph/node.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/causalprog/graph/node.py b/src/causalprog/graph/node.py index 85ac44b..369876a 100644 --- a/src/causalprog/graph/node.py +++ b/src/causalprog/graph/node.py @@ -105,7 +105,7 @@ def __init__(self, label: str, *, is_outcome: bool = False) -> None: def sample( self, - sampled_dependencies: dict[str, npt.NDArray[float]], + _sampled_dependencies: dict[str, npt.NDArray[float]], _samples: int, _rng_key: jax.Array, ) -> npt.NDArray[float]: From df01d7fd9d18606a5db3dad70737a1279386476e Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Tue, 29 Apr 2025 11:52:18 +0100 Subject: [PATCH 04/24] Update src/causalprog/graph/node.py --- src/causalprog/graph/node.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/causalprog/graph/node.py b/src/causalprog/graph/node.py index 369876a..7fcc7c8 100644 --- a/src/causalprog/graph/node.py +++ b/src/causalprog/graph/node.py @@ -110,10 +110,10 @@ def sample( _rng_key: jax.Array, ) -> npt.NDArray[float]: """Sample a value from the node.""" - if self.label not in sampled_dependencies: + if self.value is None: msg = "Cannot sample an undetermined parameter node." raise ValueError(msg) - return sampled_dependencies[self.label] + return np.full(self.value, samples) def __repr__(self) -> str: return f'ParameterNode("{self.label}")' From cd5b06d85b4b6c29915f21793d0b27fdec5b7e6c Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Tue, 29 Apr 2025 12:00:23 +0100 Subject: [PATCH 05/24] Update src/causalprog/graph/node.py --- src/causalprog/graph/node.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/causalprog/graph/node.py b/src/causalprog/graph/node.py index 7fcc7c8..e55bf75 100644 --- a/src/causalprog/graph/node.py +++ b/src/causalprog/graph/node.py @@ -98,10 +98,10 @@ def __repr__(self) -> str: class ParameterNode(Node): """A node containing a parameter.""" - def __init__(self, label: str, *, is_outcome: bool = False) -> None: + def __init__(self, label: str, value: int | None, *, is_outcome: bool = False) -> None: """Initialise.""" super().__init__(label, is_outcome=is_outcome, is_parameter=True) - self.value: int | None = None + self.value = value def sample( self, From 5e4f786e2098aac004332de419304d7259656c9c Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Tue, 29 Apr 2025 12:02:21 +0100 Subject: [PATCH 06/24] ruff --- src/causalprog/graph/node.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/causalprog/graph/node.py b/src/causalprog/graph/node.py index e55bf75..c52bd32 100644 --- a/src/causalprog/graph/node.py +++ b/src/causalprog/graph/node.py @@ -98,7 +98,9 @@ def __repr__(self) -> str: class ParameterNode(Node): """A node containing a parameter.""" - def __init__(self, label: str, value: int | None, *, is_outcome: bool = False) -> None: + def __init__( + self, label: str, value: int | None, *, is_outcome: bool = False + ) -> None: """Initialise.""" super().__init__(label, is_outcome=is_outcome, is_parameter=True) self.value = value @@ -106,7 +108,7 @@ def __init__(self, label: str, value: int | None, *, is_outcome: bool = False) - def sample( self, _sampled_dependencies: dict[str, npt.NDArray[float]], - _samples: int, + samples: int, _rng_key: jax.Array, ) -> npt.NDArray[float]: """Sample a value from the node.""" From 23605e138f1c2cd8e4efb3776d386dd78ce51d33 Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Tue, 29 Apr 2025 12:06:00 +0100 Subject: [PATCH 07/24] fix --- src/causalprog/graph/node.py | 2 +- tests/test_graph.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/causalprog/graph/node.py b/src/causalprog/graph/node.py index c52bd32..1f04518 100644 --- a/src/causalprog/graph/node.py +++ b/src/causalprog/graph/node.py @@ -99,7 +99,7 @@ class ParameterNode(Node): """A node containing a parameter.""" def __init__( - self, label: str, value: int | None, *, is_outcome: bool = False + self, label: str, *, value: int | None = None, is_outcome: bool = False ) -> None: """Initialise.""" super().__init__(label, is_outcome=is_outcome, is_parameter=True) diff --git a/tests/test_graph.py b/tests/test_graph.py index ba481b0..1ac7ce9 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -244,4 +244,6 @@ def test_paramater_node(rng_key): with pytest.raises(ValueError, match="Cannot sample"): node.sample({}, 1, rng_key) - assert np.isclose(node.sample({"mu": np.array([0.3])}, 1, rng_key)[0], 0.3) + node.value = 0.3 + + assert np.isclose(node.sample({}, 1, rng_key)[0], 0.3) From b7376cba58730b384ad65d0c377ef91e64e54fd6 Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Tue, 29 Apr 2025 12:43:07 +0100 Subject: [PATCH 08/24] fix test --- src/causalprog/graph/node.py | 2 +- tests/test_graph.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/causalprog/graph/node.py b/src/causalprog/graph/node.py index 1f04518..7ab5e7f 100644 --- a/src/causalprog/graph/node.py +++ b/src/causalprog/graph/node.py @@ -115,7 +115,7 @@ def sample( if self.value is None: msg = "Cannot sample an undetermined parameter node." raise ValueError(msg) - return np.full(self.value, samples) + return np.full(samples, self.value) def __repr__(self) -> str: return f'ParameterNode("{self.label}")' diff --git a/tests/test_graph.py b/tests/test_graph.py index 1ac7ce9..b96ee51 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -246,4 +246,4 @@ def test_paramater_node(rng_key): node.value = 0.3 - assert np.isclose(node.sample({}, 1, rng_key)[0], 0.3) + assert np.allclose(node.sample({}, 10, rng_key)[0], [0.3] * 10) From f2a8aa38202a72350596cfa0f35cd894ccc753ca Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Wed, 30 Apr 2025 14:49:29 +0100 Subject: [PATCH 09/24] start adding do --- src/causalprog/algorithms/__init__.py | 1 + src/causalprog/algorithms/do.py | 17 ++++++++++++ src/causalprog/graph/graph.py | 9 +++++-- tests/test_graph.py | 37 +++++++++++++++++++++++++++ 4 files changed, 62 insertions(+), 2 deletions(-) create mode 100644 src/causalprog/algorithms/do.py diff --git a/src/causalprog/algorithms/__init__.py b/src/causalprog/algorithms/__init__.py index 1f87481..2134f9c 100644 --- a/src/causalprog/algorithms/__init__.py +++ b/src/causalprog/algorithms/__init__.py @@ -1,3 +1,4 @@ """Algorithms.""" from .expectation import expectation, standard_deviation +from .do import do diff --git a/src/causalprog/algorithms/do.py b/src/causalprog/algorithms/do.py new file mode 100644 index 0000000..17d8d37 --- /dev/null +++ b/src/causalprog/algorithms/do.py @@ -0,0 +1,17 @@ +"""Algorithms for applying do to a graph.""" + +import numpy.typing as npt + +from causalprog.graph import Graph + + +def do(graph: Graph, node: str, value: float, label: str | None = None) -> Graph: + if label is None: + label = f"{graph.label}|do({node}={value})" + + g = graph._graph.copy() + return Graph(label, g) + + from IPython import embed; embed() + + return graph.copy() diff --git a/src/causalprog/graph/graph.py b/src/causalprog/graph/graph.py index fdb604c..84b1297 100644 --- a/src/causalprog/graph/graph.py +++ b/src/causalprog/graph/graph.py @@ -12,11 +12,16 @@ class Graph(Labelled): _nodes_by_label: dict[str, Node] - def __init__(self, label: str) -> None: + def __init__(self, label: str, graph: nx.DiGraph | None = None) -> None: """Create end empty graph.""" super().__init__(label=label) - self._graph = nx.DiGraph() self._nodes_by_label = {} + if graph is None: + self._graph = nx.DiGraph() + else: + self._graph = graph + for node in graph.nodes: + self._nodes_by_label[node.label] = node def get_node(self, label: str) -> Node: """Get a node from its label.""" diff --git a/tests/test_graph.py b/tests/test_graph.py index 7ad091f..6b6ac93 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -236,3 +236,40 @@ def test_two_node_graph(samples, rtol, mean, stdev, stdev2, rng_key): np.sqrt(stdev**2 + stdev2**2), rtol=rtol, ) + + +def test_do(rng_key): + graph = causalprog.graph.Graph("G0") + graph.add_node( + DistributionNode( + NormalFamily(), "UX", constant_parameters={"mean": 5.0, "cov": 1.0} + ) + ) + graph.add_node( + DistributionNode( + NormalFamily(), + "X", + parameters={"mean": "UX"}, + constant_parameters={"cov": 1.0}, + is_outcome=True, + ) + ) + graph.add_edge("UX", "X") + + graph2 = causalprog.algorithms.do(graph, "UX", 4.0) + + assert np.isclose( + causalprog.algorithms.expectation( + graph, outcome_node_label="X", samples=1000, rng_key=rng_key + ), + 5.0, + rtol=1e-1, + ) + + assert np.isclose( + causalprog.algorithms.expectation( + graph2, outcome_node_label="X", samples=1000, rng_key=rng_key + ), + 4.0, + rtol=1e-1, + ) From 124e3531ade9bc6e20fb340acb9c135d2504a536 Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Wed, 30 Apr 2025 15:07:48 +0100 Subject: [PATCH 10/24] do --- src/causalprog/algorithms/do.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/src/causalprog/algorithms/do.py b/src/causalprog/algorithms/do.py index 17d8d37..88814cb 100644 --- a/src/causalprog/algorithms/do.py +++ b/src/causalprog/algorithms/do.py @@ -2,7 +2,7 @@ import numpy.typing as npt -from causalprog.graph import Graph +from causalprog.graph import Graph, ParameterNode def do(graph: Graph, node: str, value: float, label: str | None = None) -> Graph: @@ -10,8 +10,22 @@ def do(graph: Graph, node: str, value: float, label: str | None = None) -> Graph label = f"{graph.label}|do({node}={value})" g = graph._graph.copy() - return Graph(label, g) + + new_node = ParameterNode(node, value=value) + g.add_node(new_node) + + edges_to_remove = [] + for e in graph._graph.edges: + if e[0].label == node: + g.add_edge(new_node, e[1]) + g.remove_edge(*e) + if e[1].label == node: + g.add_edge(e[0], new_node) + g.remove_edge(*e) + + g.remove_node(graph.get_node(node)) + from IPython import embed; embed() - return graph.copy() + return Graph(label, g) From 462a875d430b8f7cdc1da37cf5aeb6662be756ec Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Wed, 30 Apr 2025 15:09:05 +0100 Subject: [PATCH 11/24] need noqa here as they're somethimes used as kwargs --- src/causalprog/graph/node.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/causalprog/graph/node.py b/src/causalprog/graph/node.py index 7ab5e7f..1558070 100644 --- a/src/causalprog/graph/node.py +++ b/src/causalprog/graph/node.py @@ -107,9 +107,9 @@ def __init__( def sample( self, - _sampled_dependencies: dict[str, npt.NDArray[float]], + sampled_dependencies: dict[str, npt.NDArray[float]], # noqa: ARG002 samples: int, - _rng_key: jax.Array, + rng_key: jax.Array, # noqa: ARG002 ) -> npt.NDArray[float]: """Sample a value from the node.""" if self.value is None: From f4f625a07374bef78e6a69b0403d87dc7704c2dc Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Wed, 30 Apr 2025 15:13:45 +0100 Subject: [PATCH 12/24] make do work, ruff --- src/causalprog/algorithms/__init__.py | 2 +- src/causalprog/algorithms/do.py | 24 ++++++++++++++++-------- 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/src/causalprog/algorithms/__init__.py b/src/causalprog/algorithms/__init__.py index 2134f9c..879d9aa 100644 --- a/src/causalprog/algorithms/__init__.py +++ b/src/causalprog/algorithms/__init__.py @@ -1,4 +1,4 @@ """Algorithms.""" -from .expectation import expectation, standard_deviation from .do import do +from .expectation import expectation, standard_deviation diff --git a/src/causalprog/algorithms/do.py b/src/causalprog/algorithms/do.py index 88814cb..6694b23 100644 --- a/src/causalprog/algorithms/do.py +++ b/src/causalprog/algorithms/do.py @@ -1,21 +1,32 @@ """Algorithms for applying do to a graph.""" -import numpy.typing as npt - from causalprog.graph import Graph, ParameterNode def do(graph: Graph, node: str, value: float, label: str | None = None) -> Graph: + """ + Apply do to a graph. + + Args: + graph: The graph to apply do to. This will be copied. + node: The label of the node to apply do to. + value: The value to set the node to. + label: The label of the new graph + + Returns: + A copy of the graph with do applied + + """ if label is None: label = f"{graph.label}|do({node}={value})" - g = graph._graph.copy() + old_g = graph._graph # noqa: SLF001 + g = old_g.copy() new_node = ParameterNode(node, value=value) g.add_node(new_node) - edges_to_remove = [] - for e in graph._graph.edges: + for e in old_g.edges: if e[0].label == node: g.add_edge(new_node, e[1]) g.remove_edge(*e) @@ -25,7 +36,4 @@ def do(graph: Graph, node: str, value: float, label: str | None = None) -> Graph g.remove_node(graph.get_node(node)) - - from IPython import embed; embed() - return Graph(label, g) From ac9d1eefc6f637d67f114595d07d745b28b04a51 Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Thu, 1 May 2025 07:31:41 +0100 Subject: [PATCH 13/24] no need to add back in edges going to now-constant node --- src/causalprog/algorithms/do.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/causalprog/algorithms/do.py b/src/causalprog/algorithms/do.py index 6694b23..1eb9014 100644 --- a/src/causalprog/algorithms/do.py +++ b/src/causalprog/algorithms/do.py @@ -30,8 +30,7 @@ def do(graph: Graph, node: str, value: float, label: str | None = None) -> Graph if e[0].label == node: g.add_edge(new_node, e[1]) g.remove_edge(*e) - if e[1].label == node: - g.add_edge(e[0], new_node) + elif e[1].label == node: g.remove_edge(*e) g.remove_node(graph.get_node(node)) From 194a632d55829f4536717c03ce4cee2ee304929a Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Thu, 1 May 2025 10:23:54 +0100 Subject: [PATCH 14/24] remove constant parameter node --- src/causalprog/algorithms/do.py | 29 +++++++++++++------ src/causalprog/graph/node.py | 51 +++++++++++++++++++++++++++++++++ 2 files changed, 72 insertions(+), 8 deletions(-) diff --git a/src/causalprog/algorithms/do.py b/src/causalprog/algorithms/do.py index 1eb9014..85df45e 100644 --- a/src/causalprog/algorithms/do.py +++ b/src/causalprog/algorithms/do.py @@ -1,6 +1,6 @@ """Algorithms for applying do to a graph.""" -from causalprog.graph import Graph, ParameterNode +from causalprog.graph import Graph def do(graph: Graph, node: str, value: float, label: str | None = None) -> Graph: @@ -23,16 +23,29 @@ def do(graph: Graph, node: str, value: float, label: str | None = None) -> Graph old_g = graph._graph # noqa: SLF001 g = old_g.copy() - new_node = ParameterNode(node, value=value) - g.add_node(new_node) - for e in old_g.edges: - if e[0].label == node: - g.add_edge(new_node, e[1]) - g.remove_edge(*e) - elif e[1].label == node: + if e[0].label == node or e[1].label == node: g.remove_edge(*e) g.remove_node(graph.get_node(node)) + new_nodes = {} + for n in old_g.nodes: + new_n = None + for i, j in n.parameters.items(): + if j == node: + if new_n is None: + new_n = n.copy() + new_n.constant_parameters[i] = value + del new_n.parameters[i] + if new_n is not None: + g.add_node(new_n) + + for e in old_g.edges: + if e[0].label in new_nodes or e[1].label in new_nodes: + g.add_edge(new_nodes.get(e[0].label, e[0]), new_nodes.get(e[1].label, e[1])) + g.remove_edge(*e) + for n in new_nodes: + g.remove_node(graph.get_node(n)) + return Graph(label, g) diff --git a/src/causalprog/graph/node.py b/src/causalprog/graph/node.py index 4455691..56f05db 100644 --- a/src/causalprog/graph/node.py +++ b/src/causalprog/graph/node.py @@ -40,6 +40,10 @@ def sample( ) -> float: """Sample a value from the node.""" + @abstractmethod + def copy(self) -> Node: + """Make a copy of a node.""" + @property def is_outcome(self) -> bool: """Identify if the node is an outcome.""" @@ -50,6 +54,15 @@ def is_parameter(self) -> bool: """Identify if the node is a parameter.""" return self._is_parameter + @property + @abstractmethod + def constant_parameters(self) -> dict[str, float]: + """Named constants that this node depends on.""" + + @property + def parameters(self) -> dict[str, str]: + """Nodes that this node depends on.""" + class DistributionNode(Node): """A node containing a distribution.""" @@ -91,9 +104,29 @@ def sample( output[sample] = concrete_dist.sample(new_key[sample], 1)[0][0] return output + def copy(self) -> Node: + """Make a copy of a node.""" + return DistributionNode( + self._dist, + label=self.label, + parameters=dict(self._parameters), + constant_parameters=dict(self._constant_parameters.items()), + is_outcome=self.is_outcome, + ) + def __repr__(self) -> str: return f'DistributionNode("{self.label}")' + @property + def constant_parameters(self) -> dict[str, float]: + """Named constants that this node depends on.""" + return self._constant_parameters + + @property + def parameters(self) -> dict[str, str]: + """Nodes that this node depends on.""" + return self._parameters + class ParameterNode(Node): """ @@ -135,5 +168,23 @@ def sample( raise ValueError(msg) return np.full(samples, self.value) + def copy(self) -> Node: + """Make a copy of a node.""" + return ParameterNode( + label=self.label, + value=self.value, + is_outcome=self.is_outcome, + ) + def __repr__(self) -> str: return f'ParameterNode("{self.label}")' + + @property + def constant_parameters(self) -> dict[str, float]: + """Named constants that this node depends on.""" + return {} + + @property + def parameters(self) -> dict[str, str]: + """Nodes that this node depends on.""" + return {} From d933857bc099fde83eed5c1f4d8e41935ad002d2 Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Thu, 1 May 2025 10:26:48 +0100 Subject: [PATCH 15/24] removing a node also removed edges attached to it? --- src/causalprog/algorithms/do.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/causalprog/algorithms/do.py b/src/causalprog/algorithms/do.py index 85df45e..c2e5e79 100644 --- a/src/causalprog/algorithms/do.py +++ b/src/causalprog/algorithms/do.py @@ -23,10 +23,6 @@ def do(graph: Graph, node: str, value: float, label: str | None = None) -> Graph old_g = graph._graph # noqa: SLF001 g = old_g.copy() - for e in old_g.edges: - if e[0].label == node or e[1].label == node: - g.remove_edge(*e) - g.remove_node(graph.get_node(node)) new_nodes = {} @@ -44,7 +40,6 @@ def do(graph: Graph, node: str, value: float, label: str | None = None) -> Graph for e in old_g.edges: if e[0].label in new_nodes or e[1].label in new_nodes: g.add_edge(new_nodes.get(e[0].label, e[0]), new_nodes.get(e[1].label, e[1])) - g.remove_edge(*e) for n in new_nodes: g.remove_node(graph.get_node(n)) From 74a56edbf148c74b1bc0c8426fdb94f37e990dd6 Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Thu, 1 May 2025 11:44:24 +0100 Subject: [PATCH 16/24] mypy --- src/causalprog/algorithms/do.py | 4 ++-- src/causalprog/graph/node.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/causalprog/algorithms/do.py b/src/causalprog/algorithms/do.py index c2e5e79..35d66a1 100644 --- a/src/causalprog/algorithms/do.py +++ b/src/causalprog/algorithms/do.py @@ -1,6 +1,6 @@ """Algorithms for applying do to a graph.""" -from causalprog.graph import Graph +from causalprog.graph import Graph, Node def do(graph: Graph, node: str, value: float, label: str | None = None) -> Graph: @@ -25,7 +25,7 @@ def do(graph: Graph, node: str, value: float, label: str | None = None) -> Graph g.remove_node(graph.get_node(node)) - new_nodes = {} + new_nodes: dict[str, Node] = {} for n in old_g.nodes: new_n = None for i, j in n.parameters.items(): diff --git a/src/causalprog/graph/node.py b/src/causalprog/graph/node.py index 56f05db..4ca3950 100644 --- a/src/causalprog/graph/node.py +++ b/src/causalprog/graph/node.py @@ -60,6 +60,7 @@ def constant_parameters(self) -> dict[str, float]: """Named constants that this node depends on.""" @property + @abstractmethod def parameters(self) -> dict[str, str]: """Nodes that this node depends on.""" From 4eeb6b6be85d0e97d608e1bff2f25f80997d9285 Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Thu, 1 May 2025 11:46:20 +0100 Subject: [PATCH 17/24] Update tests/test_graph.py --- tests/test_graph.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_graph.py b/tests/test_graph.py index bf40218..db24a47 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -338,7 +338,7 @@ def test_set_parameters( assert normal_graph.get_node(node_name).value == expected_value -def test_paramater_node(rng_key): +def test_parameter_node(rng_key): node = ParameterNode("mu") with pytest.raises(ValueError, match="Cannot sample"): From d20904ff846ace897b288d36410b301f1f0091f8 Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Wed, 7 May 2025 10:58:38 +0100 Subject: [PATCH 18/24] assert mean in correct parameter sets --- tests/test_graph.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/test_graph.py b/tests/test_graph.py index db24a47..be2dde2 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -369,6 +369,11 @@ def test_do(rng_key): graph2 = causalprog.algorithms.do(graph, "UX", 4.0) + assert "mean" in graph.get_node("X").parameters + assert "mean" not in graph.get_node("X").constant_parameters + assert "mean" not in graph2.get_node("X").parameters + assert "mean" in graph2.get_node("X").constant_parameters + assert np.isclose( causalprog.algorithms.expectation( graph, outcome_node_label="X", samples=1000, rng_key=rng_key From 5e286eaaee77b3aa5016046e946fd7bc50cabd90 Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Thu, 15 May 2025 09:42:37 +0100 Subject: [PATCH 19/24] keyword args --- src/causalprog/algorithms/do.py | 2 +- tests/test_graph.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/causalprog/algorithms/do.py b/src/causalprog/algorithms/do.py index 35d66a1..be276d5 100644 --- a/src/causalprog/algorithms/do.py +++ b/src/causalprog/algorithms/do.py @@ -43,4 +43,4 @@ def do(graph: Graph, node: str, value: float, label: str | None = None) -> Graph for n in new_nodes: g.remove_node(graph.get_node(n)) - return Graph(label, g) + return Graph(label=label, graph=g) diff --git a/tests/test_graph.py b/tests/test_graph.py index 5a42aeb..58b5675 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -352,7 +352,7 @@ def test_parameter_node(rng_key): def test_do(rng_key): - graph = causalprog.graph.Graph("G0") + graph = causalprog.graph.Graph(label="G0") graph.add_node( DistributionNode( NormalFamily(), "UX", constant_parameters={"mean": 5.0, "cov": 1.0} @@ -361,7 +361,7 @@ def test_do(rng_key): graph.add_node( DistributionNode( NormalFamily(), - "X", + label="X", parameters={"mean": "UX"}, constant_parameters={"cov": 1.0}, is_outcome=True, From 432714ff54eacbdf0bf019839ff79db50d973ba1 Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Mon, 4 Aug 2025 15:55:20 +0100 Subject: [PATCH 20/24] deepcopy --- src/causalprog/algorithms/do.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/causalprog/algorithms/do.py b/src/causalprog/algorithms/do.py index be276d5..f4f8872 100644 --- a/src/causalprog/algorithms/do.py +++ b/src/causalprog/algorithms/do.py @@ -1,6 +1,7 @@ """Algorithms for applying do to a graph.""" from causalprog.graph import Graph, Node +from copy import deepcopy def do(graph: Graph, node: str, value: float, label: str | None = None) -> Graph: @@ -21,7 +22,7 @@ def do(graph: Graph, node: str, value: float, label: str | None = None) -> Graph label = f"{graph.label}|do({node}={value})" old_g = graph._graph # noqa: SLF001 - g = old_g.copy() + g = deepcopy(old_g) g.remove_node(graph.get_node(node)) @@ -31,7 +32,7 @@ def do(graph: Graph, node: str, value: float, label: str | None = None) -> Graph for i, j in n.parameters.items(): if j == node: if new_n is None: - new_n = n.copy() + new_n = deepcopy(n) new_n.constant_parameters[i] = value del new_n.parameters[i] if new_n is not None: From 774b335b61a49ddea18e6aecee80845857026998 Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Mon, 4 Aug 2025 15:57:17 +0100 Subject: [PATCH 21/24] Update src/causalprog/graph/graph.py Co-authored-by: Will Graham <32364977+willGraham01@users.noreply.github.com> --- src/causalprog/graph/graph.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/causalprog/graph/graph.py b/src/causalprog/graph/graph.py index 3b640d4..ebe0ef4 100644 --- a/src/causalprog/graph/graph.py +++ b/src/causalprog/graph/graph.py @@ -17,11 +17,11 @@ def __init__(self, *, label: str, graph: nx.DiGraph | None = None) -> None: super().__init__(label=label) self._nodes_by_label = {} if graph is None: - self._graph = nx.DiGraph() - else: - self._graph = graph - for node in graph.nodes: - self._nodes_by_label[node.label] = node + graph = nx.DiGraph() + + self._graph = graph + for node in graph.nodes: + self._nodes_by_label[node.label] = node def get_node(self, label: str) -> Node: """Get a node from its label.""" From 213a2526510de351eb99ce690123eb97b2af5e86 Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Mon, 4 Aug 2025 15:58:33 +0100 Subject: [PATCH 22/24] Update src/causalprog/algorithms/do.py Co-authored-by: Will Graham <32364977+willGraham01@users.noreply.github.com> --- src/causalprog/algorithms/do.py | 40 ++++++++++++++++++++++++--------- 1 file changed, 29 insertions(+), 11 deletions(-) diff --git a/src/causalprog/algorithms/do.py b/src/causalprog/algorithms/do.py index be276d5..c3c046d 100644 --- a/src/causalprog/algorithms/do.py +++ b/src/causalprog/algorithms/do.py @@ -26,21 +26,39 @@ def do(graph: Graph, node: str, value: float, label: str | None = None) -> Graph g.remove_node(graph.get_node(node)) new_nodes: dict[str, Node] = {} - for n in old_g.nodes: + # Search through the old graph, identifying nodes that had parameters which were + # defined by the node being fixed in the DO operation. + # We recreate these nodes, but replace each such parameter we encounter with + # a constant parameter equal that takes the fixed value given as an input. + for original_node in old_g.nodes: new_n = None - for i, j in n.parameters.items(): - if j == node: + for parameter_name, parameter_target_node in original_node.parameters.items(): + if parameter_target_node == node: + # If this parameter in the original_node was determined by the node we + # are fixing with DO. if new_n is None: - new_n = n.copy() - new_n.constant_parameters[i] = value - del new_n.parameters[i] + new_n = original_node.copy() + # Swap the parameter to a constant parameter, giving it the fixed value + new_n.constant_parameters[parameter_name] = value + # Remove the parameter from the node's record of non-constant parameters + new_n.parameters.pop(parameter_name) + # If we had to recreate a new node, add it to the new (Di)Graph. + # Also record the name of the node that it is set to replace if new_n is not None: g.add_node(new_n) + # new_nodes[original_node.label] = new_node ? - for e in old_g.edges: - if e[0].label in new_nodes or e[1].label in new_nodes: - g.add_edge(new_nodes.get(e[0].label, e[0]), new_nodes.get(e[1].label, e[1])) - for n in new_nodes: - g.remove_node(graph.get_node(n)) + # Any new_nodes whose counterparts connect to other nodes in the network need + # to mimic these links. + for edge in old_g.edges: + if edge[0].label in new_nodes or edge[1].label in new_nodes: + g.add_edge( + new_nodes.get(edge[0].label, edge[0]), + new_nodes.get(edge[1].label, edge[1]), + ) + # Now that the new_nodes are present in the graph, and correctly connected, remove + # their counterparts from the graph. + for original_node in new_nodes: + g.remove_node(graph.get_node(original_node)) return Graph(label=label, graph=g) From 010db2eeaaa808d045a5e4aca3fd76de77bdf416 Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Mon, 4 Aug 2025 15:59:21 +0100 Subject: [PATCH 23/24] Update src/causalprog/algorithms/do.py --- src/causalprog/algorithms/do.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/causalprog/algorithms/do.py b/src/causalprog/algorithms/do.py index c3c046d..8abd430 100644 --- a/src/causalprog/algorithms/do.py +++ b/src/causalprog/algorithms/do.py @@ -46,7 +46,7 @@ def do(graph: Graph, node: str, value: float, label: str | None = None) -> Graph # Also record the name of the node that it is set to replace if new_n is not None: g.add_node(new_n) - # new_nodes[original_node.label] = new_node ? + new_nodes[original_node.label] = new_node # Any new_nodes whose counterparts connect to other nodes in the network need # to mimic these links. From 55c39a409e5c1d661c8a616ac95af9a69c232ef3 Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Mon, 4 Aug 2025 16:18:14 +0100 Subject: [PATCH 24/24] fix deepcopy --- src/causalprog/algorithms/do.py | 12 +++++++----- src/causalprog/graph/graph.py | 2 +- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/causalprog/algorithms/do.py b/src/causalprog/algorithms/do.py index 2d39cfb..b49d1b8 100644 --- a/src/causalprog/algorithms/do.py +++ b/src/causalprog/algorithms/do.py @@ -1,8 +1,9 @@ """Algorithms for applying do to a graph.""" -from causalprog.graph import Graph, Node from copy import deepcopy +from causalprog.graph import Graph + def do(graph: Graph, node: str, value: float, label: str | None = None) -> Graph: """ @@ -24,9 +25,10 @@ def do(graph: Graph, node: str, value: float, label: str | None = None) -> Graph old_g = graph._graph # noqa: SLF001 g = deepcopy(old_g) - g.remove_node(graph.get_node(node)) + nodes_by_label = {n.label: n for n in g.nodes} + g.remove_node(nodes_by_label[node]) - new_nodes: dict[str, Node] = {} + new_nodes = {} # Search through the old graph, identifying nodes that had parameters which were # defined by the node being fixed in the DO operation. # We recreate these nodes, but replace each such parameter we encounter with @@ -47,7 +49,7 @@ def do(graph: Graph, node: str, value: float, label: str | None = None) -> Graph # Also record the name of the node that it is set to replace if new_n is not None: g.add_node(new_n) - new_nodes[original_node.label] = new_node + new_nodes[original_node.label] = new_n # Any new_nodes whose counterparts connect to other nodes in the network need # to mimic these links. @@ -60,6 +62,6 @@ def do(graph: Graph, node: str, value: float, label: str | None = None) -> Graph # Now that the new_nodes are present in the graph, and correctly connected, remove # their counterparts from the graph. for original_node in new_nodes: - g.remove_node(graph.get_node(original_node)) + g.remove_node(nodes_by_label[original_node]) return Graph(label=label, graph=g) diff --git a/src/causalprog/graph/graph.py b/src/causalprog/graph/graph.py index ebe0ef4..7dbd6dd 100644 --- a/src/causalprog/graph/graph.py +++ b/src/causalprog/graph/graph.py @@ -18,7 +18,7 @@ def __init__(self, *, label: str, graph: nx.DiGraph | None = None) -> None: self._nodes_by_label = {} if graph is None: graph = nx.DiGraph() - + self._graph = graph for node in graph.nodes: self._nodes_by_label[node.label] = node