Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
2e5fd31
add parameter node
mscroggs Apr 29, 2025
b12cd1d
Update src/causalprog/graph/node.py
mscroggs Apr 29, 2025
525ca6c
Update src/causalprog/graph/node.py
mscroggs Apr 29, 2025
df01d7f
Update src/causalprog/graph/node.py
mscroggs Apr 29, 2025
cd5b06d
Update src/causalprog/graph/node.py
mscroggs Apr 29, 2025
5e4f786
ruff
mscroggs Apr 29, 2025
23605e1
fix
mscroggs Apr 29, 2025
b7376cb
fix test
mscroggs Apr 29, 2025
f2a8aa3
start adding do
mscroggs Apr 30, 2025
5428618
Merge branch 'mscroggs/parameter-node' into mscroggs/do
mscroggs Apr 30, 2025
124e353
do
mscroggs Apr 30, 2025
462a875
need noqa here as they're somethimes used as kwargs
mscroggs Apr 30, 2025
a455662
Merge branch 'mscroggs/parameter-node' into mscroggs/do
mscroggs Apr 30, 2025
f4f625a
make do work, ruff
mscroggs Apr 30, 2025
ac9d1ee
no need to add back in edges going to now-constant node
mscroggs May 1, 2025
ec4b935
Merge branch 'main' into mscroggs/do
mscroggs May 1, 2025
194a632
remove constant parameter node
mscroggs May 1, 2025
d933857
removing a node also removed edges attached to it?
mscroggs May 1, 2025
74a56ed
mypy
mscroggs May 1, 2025
4eeb6b6
Update tests/test_graph.py
mscroggs May 1, 2025
d20904f
assert mean in correct parameter sets
mscroggs May 7, 2025
850d293
Merge branch 'main' into mscroggs/do
mscroggs May 15, 2025
5e286ea
keyword args
mscroggs May 15, 2025
432714f
deepcopy
mscroggs Aug 4, 2025
774b335
Update src/causalprog/graph/graph.py
mscroggs Aug 4, 2025
213a252
Update src/causalprog/algorithms/do.py
mscroggs Aug 4, 2025
010db2e
Update src/causalprog/algorithms/do.py
mscroggs Aug 4, 2025
e9d79f6
Merge branch 'mscroggs/do' of github.com:UCL/causalprog into mscroggs/do
mscroggs Aug 4, 2025
55c39a4
fix deepcopy
mscroggs Aug 4, 2025
5a0b76c
Merge branch 'main' into mscroggs/do
mscroggs Aug 5, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/causalprog/algorithms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
"""Algorithms."""

from .do import do
from .expectation import expectation, standard_deviation
67 changes: 67 additions & 0 deletions src/causalprog/algorithms/do.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
"""Algorithms for applying do to a graph."""

from copy import deepcopy

from causalprog.graph import Graph


def do(graph: Graph, node: str, value: float, label: str | None = None) -> Graph:
"""
Apply do to a graph.

Args:
graph: The graph to apply do to. This will be copied.
node: The label of the node to apply do to.
value: The value to set the node to.
label: The label of the new graph

Returns:
A copy of the graph with do applied

"""
if label is None:
label = f"{graph.label}|do({node}={value})"

old_g = graph._graph # noqa: SLF001
g = deepcopy(old_g)

nodes_by_label = {n.label: n for n in g.nodes}
g.remove_node(nodes_by_label[node])

new_nodes = {}
# Search through the old graph, identifying nodes that had parameters which were
# defined by the node being fixed in the DO operation.
# We recreate these nodes, but replace each such parameter we encounter with
# a constant parameter equal that takes the fixed value given as an input.
for original_node in old_g.nodes:
new_n = None
for parameter_name, parameter_target_node in original_node.parameters.items():
if parameter_target_node == node:
# If this parameter in the original_node was determined by the node we
# are fixing with DO.
if new_n is None:
new_n = deepcopy(original_node)
# Swap the parameter to a constant parameter, giving it the fixed value
new_n.constant_parameters[parameter_name] = value
# Remove the parameter from the node's record of non-constant parameters
new_n.parameters.pop(parameter_name)
# If we had to recreate a new node, add it to the new (Di)Graph.
# Also record the name of the node that it is set to replace
if new_n is not None:
g.add_node(new_n)
new_nodes[original_node.label] = new_n

# Any new_nodes whose counterparts connect to other nodes in the network need
# to mimic these links.
for edge in old_g.edges:
if edge[0].label in new_nodes or edge[1].label in new_nodes:
g.add_edge(
new_nodes.get(edge[0].label, edge[0]),
new_nodes.get(edge[1].label, edge[1]),
)
# Now that the new_nodes are present in the graph, and correctly connected, remove
# their counterparts from the graph.
for original_node in new_nodes:
g.remove_node(nodes_by_label[original_node])

return Graph(label=label, graph=g)
9 changes: 7 additions & 2 deletions src/causalprog/graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,16 @@ class Graph(Labelled):

_nodes_by_label: dict[str, Node]

def __init__(self, *, label: str) -> None:
def __init__(self, *, label: str, graph: nx.DiGraph | None = None) -> None:
"""Create end empty graph."""
super().__init__(label=label)
self._graph = nx.DiGraph()
self._nodes_by_label = {}
if graph is None:
graph = nx.DiGraph()

self._graph = graph
for node in graph.nodes:
self._nodes_by_label[node.label] = node

def get_node(self, label: str) -> Node:
"""Get a node from its label."""
Expand Down
52 changes: 52 additions & 0 deletions src/causalprog/graph/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ def sample(
) -> float:
"""Sample a value from the node."""

@abstractmethod
def copy(self) -> Node:
"""Make a copy of a node."""

@property
def is_outcome(self) -> bool:
"""Identify if the node is an outcome."""
Expand All @@ -50,6 +54,16 @@ def is_parameter(self) -> bool:
"""Identify if the node is a parameter."""
return self._is_parameter

@property
@abstractmethod
def constant_parameters(self) -> dict[str, float]:
"""Named constants that this node depends on."""

@property
@abstractmethod
def parameters(self) -> dict[str, str]:
"""Nodes that this node depends on."""


class DistributionNode(Node):
"""A node containing a distribution."""
Expand Down Expand Up @@ -91,9 +105,29 @@ def sample(
output[sample] = concrete_dist.sample(new_key[sample], 1)[0][0]
return output

def copy(self) -> Node:
"""Make a copy of a node."""
return DistributionNode(
self._dist,
label=self.label,
parameters=dict(self._parameters),
constant_parameters=dict(self._constant_parameters.items()),
is_outcome=self.is_outcome,
)

def __repr__(self) -> str:
return f'DistributionNode("{self.label}")'

@property
def constant_parameters(self) -> dict[str, float]:
"""Named constants that this node depends on."""
return self._constant_parameters

@property
def parameters(self) -> dict[str, str]:
"""Nodes that this node depends on."""
return self._parameters


class ParameterNode(Node):
"""
Expand Down Expand Up @@ -135,5 +169,23 @@ def sample(
raise ValueError(msg)
return np.full(samples, self.value)

def copy(self) -> Node:
"""Make a copy of a node."""
return ParameterNode(
label=self.label,
value=self.value,
is_outcome=self.is_outcome,
)

def __repr__(self) -> str:
return f'ParameterNode("{self.label}")'

@property
def constant_parameters(self) -> dict[str, float]:
"""Named constants that this node depends on."""
return {}

@property
def parameters(self) -> dict[str, str]:
"""Nodes that this node depends on."""
return {}
42 changes: 42 additions & 0 deletions tests/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,3 +349,45 @@ def test_parameter_node(rng_key):
node.value = 0.3

assert np.allclose(node.sample({}, 10, rng_key)[0], [0.3] * 10)


def test_do(rng_key):
graph = causalprog.graph.Graph(label="G0")
graph.add_node(
DistributionNode(
NormalFamily(), "UX", constant_parameters={"mean": 5.0, "cov": 1.0}
)
)
graph.add_node(
DistributionNode(
NormalFamily(),
label="X",
parameters={"mean": "UX"},
constant_parameters={"cov": 1.0},
is_outcome=True,
)
)
graph.add_edge("UX", "X")

graph2 = causalprog.algorithms.do(graph, "UX", 4.0)

assert "mean" in graph.get_node("X").parameters
assert "mean" not in graph.get_node("X").constant_parameters
assert "mean" not in graph2.get_node("X").parameters
assert "mean" in graph2.get_node("X").constant_parameters

assert np.isclose(
causalprog.algorithms.expectation(
graph, outcome_node_label="X", samples=1000, rng_key=rng_key
),
5.0,
rtol=1e-1,
)

assert np.isclose(
causalprog.algorithms.expectation(
graph2, outcome_node_label="X", samples=1000, rng_key=rng_key
),
4.0,
rtol=1e-1,
)