Skip to content

Commit

Permalink
Fix node_depth (#152)
Browse files Browse the repository at this point in the history
* Fix node_depth

* Fix typo
  • Loading branch information
YamLyubov committed Jul 24, 2023
1 parent 9f8dfcd commit 123a9d8
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 22 deletions.
52 changes: 35 additions & 17 deletions golem/core/dag/graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,31 +94,49 @@ def subtree_impl(node):
return subtree_impl(node)


def node_depth(nodes: Union['GraphNode', Sequence['GraphNode']]) -> Union[int, List[int]]:
"""Gets the depth of the provided ``nodes`` in the graph
def node_depth(nodes: Union['GraphNode', Sequence['GraphNode']]) -> int:
"""Gets the maximal depth among the provided ``nodes`` in the graph
Args:
nodes: nodes to calculate the depth for
Returns:
int or List[int]: depth(s) of the nodes in the graph
int: maximal depth
"""
nodes = ensure_wrapped_in_sequence(nodes)
visited_nodes = [[node] for node in nodes]
depth = 1
parents = [node.nodes_from for node in nodes]
while any(parents):
depth += 1
for i, ith_parents in enumerate(parents):
grandparents = []
for parent in ith_parents:
if parent in visited_nodes[i]:
final_depth = {}
subnodes = set()
for node in nodes:
max_depth = 0
# if node is a subnode of another node it has smaller depth
if node.uid in subnodes:
continue
depth = 1
visited = []
if node in visited:
return -1
visited.append(node)
stack = [(node, depth, iter(node.nodes_from))]
while stack:
curr_node, depth_now, parents = stack[-1]
try:
parent = next(parents)
subnodes.add(parent.uid)
if parent not in visited:
visited.append(parent)
if parent.uid in final_depth:
# depth of the parent has been already calculated
stack.append((parent, depth_now + final_depth[parent.uid], iter([])))
else:
stack.append((parent, depth_now + 1, iter(parent.nodes_from)))
else:
return -1
grandparents.extend(parent.nodes_from)
visited_nodes[i].extend(ith_parents)
parents[i] = grandparents

return depth
except StopIteration:
_, depth_now, _ = stack.pop()
visited.pop()
max_depth = max(max_depth, depth_now)
final_depth[node.uid] = max_depth
return max(final_depth.values())


def map_dag_nodes(transform: Callable, nodes: Sequence) -> Sequence:
Expand Down
2 changes: 1 addition & 1 deletion golem/core/dag/linked_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def update_node(self, old_node: GraphNode, new_node: GraphNode):
self.actualise_old_node_children(old_node, new_node)
new_node.nodes_from.extend(old_node.nodes_from)
self._nodes.remove(old_node)
self._nodes.append(new_node)
self.add_node(new_node)
self.sort_nodes()
self._postprocess_nodes(self, self._nodes)

Expand Down
6 changes: 3 additions & 3 deletions test/unit/dag/test_graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from golem.core.dag.linked_graph_node import LinkedGraphNode
from test.unit.dag.test_graph_operator import graph
from test.unit.utils import graph_first, simple_cycled_graph, branched_cycled_graph, graph_second, graph_third, \
graph_fifth, graph_with_multi_roots_first
graph_fifth, graph_with_multi_roots_first, joined_branches_graph

_ = graph

Expand Down Expand Up @@ -91,8 +91,8 @@ def test_graph_has_cycle():

@pytest.mark.parametrize('graph, nodes_names, correct_depths', [(simple_cycled_graph(), ['c', 'd', 'e'], -1),
(graph_fifth(), ['b', 'c', 'd'], 4),
(graph_with_multi_roots_first(), ['16', '13', '14'],
3)])
(graph_with_multi_roots_first(), ['16', '13', '14'], 3),
(joined_branches_graph(), ['d', 'f', 'c'], 5)])
def test_node_depth(graph, nodes_names, correct_depths):
nodes = [graph.get_nodes_by_name(name)[0] for name in nodes_names]
depths = node_depth(nodes)
Expand Down
18 changes: 17 additions & 1 deletion test/unit/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,6 @@ def simple_cycled_graph():


def branched_cycled_graph():
#
node_a_primary = LinkedGraphNode('a')
node_b = LinkedGraphNode('b', nodes_from=[node_a_primary])
node_c = LinkedGraphNode('c', nodes_from=[node_b])
Expand All @@ -270,6 +269,23 @@ def branched_cycled_graph():
return graph


def joined_branches_graph():
# a
# / \
# c - b
# | /
# d /
# | /
# f
node_a = LinkedGraphNode('a')
node_b = LinkedGraphNode('b', nodes_from=[node_a])
node_c = LinkedGraphNode('c', nodes_from=[node_b, node_a])
node_d = LinkedGraphNode('d', nodes_from=[node_c])
node_f = LinkedGraphNode('f', nodes_from=[node_d, node_b])
graph = GraphDelegate(node_f)
return graph


class RandomMetric:
@staticmethod
def get_value(graph, *args, delay=0, **kwargs) -> float:
Expand Down

0 comments on commit 123a9d8

Please sign in to comment.