/
graph.py
183 lines (149 loc) · 6.82 KB
/
graph.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
from queue import PriorityQueue, Queue
from typing import Callable, Dict, Iterator, List, Optional, Set, TypeVar
import networkx as nx
from lineapy.data.types import LineaID, Node, SessionContext
from lineapy.graph_reader.graph_printer import GraphPrinter
from lineapy.utils.utils import listify, prettify
class Graph(object):
    """
    A directed acyclic graph over ``Node`` objects.

    The structure is stored in a networkx ``DiGraph`` whose vertices are node
    ids and whose edges point from a parent id to the id of the node that
    lists it as a parent.
    """

    def __init__(self, nodes: List[Node], session_context: SessionContext):
        """
        Build the graph from a list of nodes and validate that it is acyclic.

        :param nodes: the nodes that make up the graph. Every id returned by
            ``node.parents()`` becomes an edge ``parent_id -> node.id``, even
            when the parent itself is not part of ``nodes``.
        :param session_context: the session context associated with the graph
        :raises AssertionError: if the resulting edges contain a cycle

        NOTE:
        # TODO: Possibly remove session context since we aren't using it
        # anywhere we refer to the graph
        - It makes sense to include session_context in the constructor of
          the graph because the information in session_context is semantically
          important to the notion of a Graph. Concretely, we are starting
          to also use the code entry from the session_context.
        """
        self.nodes: List[Node] = nodes
        # Lookup table from node id to node, restricted to the nodes we own.
        self.ids: Dict[LineaID, Node] = {n.id: n for n in nodes}
        self.nx_graph = nx.DiGraph()
        self.nx_graph.add_nodes_from([node.id for node in nodes])
        self.nx_graph.add_edges_from(
            (parent_id, node.id)
            for node in nodes
            for parent_id in node.parents()
        )
        self.session_context = session_context

        # validation
        if not nx.is_directed_acyclic_graph(self.nx_graph):
            raise AssertionError("Graph should not be cyclic")

    def __eq__(self, other: object) -> bool:
        """
        Two graphs are equal when their underlying DiGraphs are isomorphic.

        Only the edge structure is compared (no node matcher is passed to
        networkx), so node contents do not take part in the comparison.
        Comparing against something that is not a ``Graph`` defers to the
        other operand instead of failing on a missing ``nx_graph`` attribute.

        Note that, as with any class defining ``__eq__`` without ``__hash__``,
        instances are not hashable.
        """
        if not isinstance(other, Graph):
            return NotImplemented
        return nx.is_isomorphic(self.nx_graph, other.nx_graph)

    def print(self, **kwargs) -> str:
        """
        Render the graph with ``GraphPrinter``; ``kwargs`` are forwarded to
        the printer's constructor unchanged.
        """
        return GraphPrinter(self, **kwargs).print()

    @listify
    def visit_order(self) -> Iterator[Node]:
        """
        Yield the nodes so that every parent comes before its children.

        Just using the line number as tie breaker for now since we don't have
        a good way to track dependencies.
        Note that we cannot just use the line number to sort because
        there are nodes created by us that do not have line numbers...

        NOTE(review): the ``listify`` decorator presumably collects the
        yielded nodes into a list for the caller; confirm against
        ``lineapy.utils.utils``.
        """
        # TODO: See if we could replace this with python's built in
        # topological sort
        # https://docs.python.org/3/library/graphlib.html
        # It seems to suggest that resulting order is determined by insertion
        # order so this could possibly also implicitly sort by line number,
        # if we insert in that order.

        # Generally, we want to traverse the graph in a way to maintain two
        # constraints:
        # 1. All parents must be traversed before their children
        # 2. If we have any freedom, those with earlier line number should
        #    come first
        # To do this, we do a breadth first traversal, keeping our queue
        # ordered by their line number. The sorting is done via the __lt__
        # method of the Node.
        queue: PriorityQueue[Node] = PriorityQueue()

        # We also keep track of all nodes we have already added to the queue
        # so that we don't add them again.
        seen: Set[LineaID] = set()

        # We also keep a mapping of each node to the number of parents left
        # which have not been visited yet.
        # Note that we want to skip counting parents which are not part of
        # our nodes. This can happen when we evaluate part of a graph, then
        # another part. When evaluating the next part, we just have those
        # nodes, so some of the parents will be missing; we assume they are
        # already executed.
        remaining_parents: Dict[LineaID, int] = {}

        for node in self.nodes:
            n_remaining_parents = sum(
                1
                for parent_id in self.nx_graph.pred[node.id]
                if parent_id in self.ids
            )
            # First we add all of the nodes to the queue which have no parents
            if n_remaining_parents == 0:
                seen.add(node.id)
                queue.put(node)
            remaining_parents[node.id] = n_remaining_parents

        while not queue.empty():
            # Find the first node in the queue which has all its parents
            # removed
            node = queue_get_when(
                queue, lambda n: remaining_parents[n.id] == 0
            )

            # Then, we add all of its children to the queue, making sure to
            # mark for each that we have seen one of its parents
            yield node
            for child_id in self.get_children(node.id):
                remaining_parents[child_id] -= 1
                if child_id in seen:
                    continue
                queue.put(self.ids[child_id])
                seen.add(child_id)

    def get_parents(self, node_id: LineaID) -> List[LineaID]:
        """Return the ids of the direct predecessors of ``node_id``."""
        return list(self.nx_graph.predecessors(node_id))

    def get_ancestors(self, node_id: LineaID) -> List[LineaID]:
        """Return the ids of every node from which ``node_id`` is reachable."""
        return list(nx.ancestors(self.nx_graph, node_id))

    def get_children(self, node_id: LineaID) -> List[LineaID]:
        """Return the ids of the direct successors of ``node_id``."""
        return list(self.nx_graph.successors(node_id))

    def get_descendants(self, node_id: LineaID) -> List[LineaID]:
        """Return the ids of every node reachable from ``node_id``."""
        return list(nx.descendants(self.nx_graph, node_id))

    def get_leaf_nodes(self) -> List[LineaID]:
        """Return the ids of all vertices that have no outgoing edge."""
        return [
            node
            for node in self.nx_graph.nodes
            if self.nx_graph.out_degree(node) == 0
        ]

    def get_node(self, node_id: Optional[LineaID]) -> Optional[Node]:
        """
        Look up a node by id.

        :return: the node, or ``None`` when ``node_id`` is ``None`` or is not
            the id of one of this graph's nodes
        """
        if node_id is None:
            return None
        return self.ids.get(node_id)

    def get_subgraph(self, nodes: List[Node]) -> "Graph":
        """
        Return a new ``Graph`` built from ``nodes`` that shares this graph's
        session context.

        FIXME (marker kept from the original; it does not say what is
        outstanding)
        """
        return Graph(nodes, self.session_context)

    def __str__(self):
        """Prettified printout with id fields, without source locations or session."""
        return prettify(
            self.print(
                include_source_location=False,
                include_id_field=True,
                include_session=False,
            )
        )

    def __repr__(self):
        """Prettified printout using the printer's default options."""
        return prettify(self.print())
T = TypeVar("T")


def queue_get_when(queue: "Queue[T]", filter_fn: Callable[[T], bool]) -> T:
    """
    Remove and return the first element in the queue that satisfies
    ``filter_fn``.

    Elements are taken off in the queue's own retrieval order. Those that do
    not satisfy the filter are put back afterwards, so on return the queue
    holds everything it held before except the element that was found.

    :param queue: the queue to search; it is modified in place
    :param filter_fn: predicate applied to each element as it is popped
    :return: the first popped element for which ``filter_fn`` is true
    :raises queue.Empty: if the queue runs out before any element matches.
        The queue's contents are restored before the exception propagates.
    """
    # We have to pop off a number of elements, stopping when we find one that
    # satisfies our conditional, since we can't iterate through a queue.
    popped_off: List[T] = []
    try:
        while True:
            # Non-blocking get: if we have some bug where we are trying to
            # get off the queue and it is empty, this raises instead of
            # blocking forever.
            candidate = queue.get_nowait()
            # Record the element before testing it so it is not lost if
            # filter_fn itself raises.
            popped_off.append(candidate)
            if filter_fn(candidate):
                popped_off.pop()
                return candidate
    finally:
        # Runs on success and on failure alike: everything we skipped over
        # goes back, so a failed search never drains the caller's queue.
        for skipped in popped_off:
            queue.put(skipped)