diff --git a/src/causalprog/algorithms/__init__.py b/src/causalprog/algorithms/__init__.py
index 1f87481..879d9aa 100644
--- a/src/causalprog/algorithms/__init__.py
+++ b/src/causalprog/algorithms/__init__.py
@@ -1,3 +1,4 @@
 """Algorithms."""
 
+from .do import do
 from .expectation import expectation, standard_deviation
diff --git a/src/causalprog/algorithms/do.py b/src/causalprog/algorithms/do.py
new file mode 100644
index 0000000..b49d1b8
--- /dev/null
+++ b/src/causalprog/algorithms/do.py
@@ -0,0 +1,84 @@
+"""Algorithms for applying do to a graph."""
+
+from copy import deepcopy
+
+from causalprog.graph import Graph
+
+
+def do(graph: Graph, node: str, value: float, label: str | None = None) -> Graph:
+    """
+    Apply do to a graph.
+
+    The node being fixed is removed from a copy of the graph. Every node that took
+    a parameter from the fixed node is replaced by a copy of itself in which that
+    parameter is a constant equal to ``value``.
+
+    Args:
+        graph: The graph to apply do to. This will be copied.
+        node: The label of the node to apply do to.
+        value: The value to set the node to.
+        label: The label of the new graph. If omitted, a label is built from the
+            label of ``graph``, ``node`` and ``value``.
+
+    Returns:
+        A copy of the graph with do applied
+
+    Raises:
+        KeyError: If ``graph`` has no node labelled ``node``.
+
+    """
+    if label is None:
+        label = f"{graph.label}|do({node}={value})"
+
+    old_g = graph._graph  # noqa: SLF001
+    g = deepcopy(old_g)
+
+    # Lookup of the copied nodes held by ``g``, taken before ``g`` is modified.
+    nodes_by_label = {n.label: n for n in g.nodes}
+    if node not in nodes_by_label:
+        msg = f'Cannot apply do: graph "{graph.label}" has no node "{node}"'
+        raise KeyError(msg)
+    g.remove_node(nodes_by_label[node])
+
+    new_nodes = {}
+    # Search through the old graph, identifying nodes that had parameters which were
+    # defined by the node being fixed in the DO operation.
+    # We recreate these nodes, but replace each such parameter we encounter with
+    # a constant parameter that takes the fixed value given as an input.
+    for original_node in old_g.nodes:
+        new_n = None
+        for parameter_name, parameter_target_node in original_node.parameters.items():
+            if parameter_target_node == node:
+                # If this parameter in the original_node was determined by the node we
+                # are fixing with DO.
+                if new_n is None:
+                    new_n = deepcopy(original_node)
+                # Swap the parameter to a constant parameter, giving it the fixed value
+                new_n.constant_parameters[parameter_name] = value
+                # Remove the parameter from the node's record of non-constant parameters
+                new_n.parameters.pop(parameter_name)
+        # If we had to recreate a new node, add it to the new (Di)Graph.
+        # Also record the name of the node that it is set to replace
+        if new_n is not None:
+            g.add_node(new_n)
+            new_nodes[original_node.label] = new_n
+
+    # Any new_nodes whose counterparts connect to other nodes in the network need
+    # to mimic these links.
+    for source, target in old_g.edges:
+        if node in (source.label, target.label):
+            # These edges go away together with the fixed node.
+            continue
+        if source.label in new_nodes or target.label in new_nodes:
+            # Endpoints read from old_g are not the objects stored in g (g is a deep
+            # copy), so an endpoint that was not recreated is looked up by label.
+            g.add_edge(
+                new_nodes.get(source.label, nodes_by_label[source.label]),
+                new_nodes.get(target.label, nodes_by_label[target.label]),
+            )
+    # Now that the new_nodes are present in the graph, and correctly connected, remove
+    # their counterparts from the graph.
+    for replaced_label in new_nodes:
+        g.remove_node(nodes_by_label[replaced_label])
+
+    return Graph(label=label, graph=g)
diff --git a/src/causalprog/graph/graph.py b/src/causalprog/graph/graph.py
index 17be27e..7dbd6dd 100644
--- a/src/causalprog/graph/graph.py
+++ b/src/causalprog/graph/graph.py
@@ -12,11 +12,16 @@ class Graph(Labelled):
 
     _nodes_by_label: dict[str, Node]
 
-    def __init__(self, *, label: str) -> None:
-        """Create end empty graph."""
+    def __init__(self, *, label: str, graph: nx.DiGraph | None = None) -> None:
+        """Create a graph, optionally wrapping an existing ``nx.DiGraph``."""
         super().__init__(label=label)
-        self._graph = nx.DiGraph()
         self._nodes_by_label = {}
+        if graph is None:
+            graph = nx.DiGraph()
+
+        self._graph = graph
+        for node in graph.nodes:
+            self._nodes_by_label[node.label] = node
 
     def get_node(self, label: str) -> Node:
         """Get a node from its label."""
diff --git a/src/causalprog/graph/node.py b/src/causalprog/graph/node.py
index 0bf01d5..c9e0b2a 100644
--- a/src/causalprog/graph/node.py
+++ b/src/causalprog/graph/node.py
@@ -40,6 +40,10 @@ def sample(
     ) -> float:
         """Sample a value from the node."""
 
+    @abstractmethod
+    def copy(self) -> Node:
+        """Make a copy of a node."""
+
     @property
     def is_outcome(self) -> bool:
         """Identify if the node is an outcome."""
@@ -50,6 +54,16 @@ def is_parameter(self) -> bool:
         """Identify if the node is a parameter."""
         return self._is_parameter
 
+    @property
+    @abstractmethod
+    def constant_parameters(self) -> dict[str, float]:
+        """Named constants that this node depends on."""
+
+    @property
+    @abstractmethod
+    def parameters(self) -> dict[str, str]:
+        """Nodes that this node depends on."""
+
 
 class DistributionNode(Node):
     """A node containing a distribution."""
@@ -91,9 +105,29 @@ def sample(
             output[sample] = concrete_dist.sample(new_key[sample], 1)[0][0]
         return output
 
+    def copy(self) -> Node:
+        """Make a copy of a node."""
+        return DistributionNode(
+            self._dist,
+            label=self.label,
+            parameters=dict(self._parameters),
+            constant_parameters=dict(self._constant_parameters),
+            is_outcome=self.is_outcome,
+        )
+
     def __repr__(self) -> str:
         return f'DistributionNode("{self.label}")'
 
+    @property
+    def constant_parameters(self) -> dict[str, float]:
+        """Named constants that this node depends on."""
+        return self._constant_parameters
+
+    @property
+    def parameters(self) -> dict[str, str]:
+        """Nodes that this node depends on."""
+        return self._parameters
+
 
 class ParameterNode(Node):
     """
@@ -135,5 +169,23 @@ def sample(
             raise ValueError(msg)
         return np.full(samples, self.value)
 
+    def copy(self) -> Node:
+        """Make a copy of a node."""
+        return ParameterNode(
+            label=self.label,
+            value=self.value,
+            is_outcome=self.is_outcome,
+        )
+
     def __repr__(self) -> str:
         return f'ParameterNode("{self.label}")'
+
+    @property
+    def constant_parameters(self) -> dict[str, float]:
+        """Named constants that this node depends on."""
+        return {}
+
+    @property
+    def parameters(self) -> dict[str, str]:
+        """Nodes that this node depends on."""
+        return {}
diff --git a/tests/test_graph.py b/tests/test_graph.py
index c110414..58b5675 100644
--- a/tests/test_graph.py
+++ b/tests/test_graph.py
@@ -349,3 +349,48 @@ def test_parameter_node(rng_key):
     node.value = 0.3
 
     assert np.allclose(node.sample({}, 10, rng_key)[0], [0.3] * 10)
+
+
+def test_do(rng_key):
+    graph = causalprog.graph.Graph(label="G0")
+    graph.add_node(
+        DistributionNode(
+            NormalFamily(), "UX", constant_parameters={"mean": 5.0, "cov": 1.0}
+        )
+    )
+    graph.add_node(
+        DistributionNode(
+            NormalFamily(),
+            label="X",
+            parameters={"mean": "UX"},
+            constant_parameters={"cov": 1.0},
+            is_outcome=True,
+        )
+    )
+    graph.add_edge("UX", "X")
+
+    graph2 = causalprog.algorithms.do(graph, "UX", 4.0)
+
+    assert "mean" in graph.get_node("X").parameters
+    assert "mean" not in graph.get_node("X").constant_parameters
+    assert "mean" not in graph2.get_node("X").parameters
+    assert "mean" in graph2.get_node("X").constant_parameters
+
+    # The fixed node itself must no longer be part of the new graph
+    assert "UX" not in graph2._nodes_by_label  # noqa: SLF001
+
+    assert np.isclose(
+        causalprog.algorithms.expectation(
+            graph, outcome_node_label="X", samples=1000, rng_key=rng_key
+        ),
+        5.0,
+        rtol=1e-1,
+    )
+
+    assert np.isclose(
+        causalprog.algorithms.expectation(
+            graph2, outcome_node_label="X", samples=1000, rng_key=rng_key
+        ),
+        4.0,
+        rtol=1e-1,
+    )