-
Notifications
You must be signed in to change notification settings - Fork 0
Implement do #42
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Implement do #42
Changes from all commits
Commits
Show all changes
30 commits
Select commit
Hold shift + click to select a range
2e5fd31
add parameter node
mscroggs b12cd1d
Update src/causalprog/graph/node.py
mscroggs 525ca6c
Update src/causalprog/graph/node.py
mscroggs df01d7f
Update src/causalprog/graph/node.py
mscroggs cd5b06d
Update src/causalprog/graph/node.py
mscroggs 5e4f786
ruff
mscroggs 23605e1
fix
mscroggs b7376cb
fix test
mscroggs f2a8aa3
start adding do
mscroggs 5428618
Merge branch 'mscroggs/parameter-node' into mscroggs/do
mscroggs 124e353
do
mscroggs 462a875
need noqa here as they're somethimes used as kwargs
mscroggs a455662
Merge branch 'mscroggs/parameter-node' into mscroggs/do
mscroggs f4f625a
make do work, ruff
mscroggs ac9d1ee
no need to add back in edges going to now-constant node
mscroggs ec4b935
Merge branch 'main' into mscroggs/do
mscroggs 194a632
remove constant parameter node
mscroggs d933857
removing a node also removed edges attached to it?
mscroggs 74a56ed
mypy
mscroggs 4eeb6b6
Update tests/test_graph.py
mscroggs d20904f
assert mean in correct parameter sets
mscroggs 850d293
Merge branch 'main' into mscroggs/do
mscroggs 5e286ea
keyword args
mscroggs 432714f
deepcopy
mscroggs 774b335
Update src/causalprog/graph/graph.py
mscroggs 213a252
Update src/causalprog/algorithms/do.py
mscroggs 010db2e
Update src/causalprog/algorithms/do.py
mscroggs e9d79f6
Merge branch 'mscroggs/do' of github.com:UCL/causalprog into mscroggs/do
mscroggs 55c39a4
fix deepcopy
mscroggs 5a0b76c
Merge branch 'main' into mscroggs/do
mscroggs File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,3 +1,4 @@ | ||
| """Algorithms.""" | ||
|
|
||
| from .do import do | ||
| from .expectation import expectation, standard_deviation |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,67 @@ | ||
| """Algorithms for applying do to a graph.""" | ||
|
|
||
| from copy import deepcopy | ||
|
|
||
| from causalprog.graph import Graph | ||
|
|
||
|
|
||
| def do(graph: Graph, node: str, value: float, label: str | None = None) -> Graph: | ||
| """ | ||
| Apply do to a graph. | ||
|
|
||
| Args: | ||
| graph: The graph to apply do to. This will be copied. | ||
| node: The label of the node to apply do to. | ||
| value: The value to set the node to. | ||
| label: The label of the new graph | ||
|
|
||
| Returns: | ||
| A copy of the graph with do applied | ||
|
|
||
| """ | ||
| if label is None: | ||
| label = f"{graph.label}|do({node}={value})" | ||
|
|
||
| old_g = graph._graph # noqa: SLF001 | ||
| g = deepcopy(old_g) | ||
|
|
||
| nodes_by_label = {n.label: n for n in g.nodes} | ||
| g.remove_node(nodes_by_label[node]) | ||
|
|
||
| new_nodes = {} | ||
| # Search through the old graph, identifying nodes that had parameters which were | ||
| # defined by the node being fixed in the DO operation. | ||
| # We recreate these nodes, but replace each such parameter we encounter with | ||
| # a constant parameter equal that takes the fixed value given as an input. | ||
| for original_node in old_g.nodes: | ||
| new_n = None | ||
| for parameter_name, parameter_target_node in original_node.parameters.items(): | ||
| if parameter_target_node == node: | ||
| # If this parameter in the original_node was determined by the node we | ||
| # are fixing with DO. | ||
| if new_n is None: | ||
| new_n = deepcopy(original_node) | ||
| # Swap the parameter to a constant parameter, giving it the fixed value | ||
| new_n.constant_parameters[parameter_name] = value | ||
| # Remove the parameter from the node's record of non-constant parameters | ||
| new_n.parameters.pop(parameter_name) | ||
| # If we had to recreate a new node, add it to the new (Di)Graph. | ||
| # Also record the name of the node that it is set to replace | ||
| if new_n is not None: | ||
| g.add_node(new_n) | ||
| new_nodes[original_node.label] = new_n | ||
|
|
||
| # Any new_nodes whose counterparts connect to other nodes in the network need | ||
| # to mimic these links. | ||
| for edge in old_g.edges: | ||
| if edge[0].label in new_nodes or edge[1].label in new_nodes: | ||
| g.add_edge( | ||
| new_nodes.get(edge[0].label, edge[0]), | ||
| new_nodes.get(edge[1].label, edge[1]), | ||
| ) | ||
| # Now that the new_nodes are present in the graph, and correctly connected, remove | ||
| # their counterparts from the graph. | ||
| for original_node in new_nodes: | ||
| g.remove_node(nodes_by_label[original_node]) | ||
|
|
||
| return Graph(label=label, graph=g) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.