Permalink
Cannot retrieve contributors at this time
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
333 lines (290 sloc)
11.5 KB
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import inspect | |
import ast | |
import re | |
def deindent_docstring(doc):
    """Strip the common leading indentation from a docstring.

    The indent to remove is detected from the first two non-empty lines:

    - if the first non-empty line starts with tabs/spaces, that run of
      whitespace is the indent;
    - otherwise the leading whitespace of the second non-empty line (if any)
      is used.

    This covers docstrings whose text starts right after the opening quotes
    as well as those whose text starts on the following line, indented or not.

    Returns "" for a falsy `doc`. When an indent is found, every
    "newline + indent" occurrence is replaced by a bare newline and the result
    is stripped; when no indent is found, `doc` is returned unchanged (and
    therefore unstripped).
    """
    if not doc:
        return ""

    indent = None
    first_line_seen = False
    for text in doc.splitlines():
        if not text:
            # blank lines carry no indentation information
            continue
        indent = re.match("[\t ]+", text)
        # Stop at the first indented line, or at the second non-empty line
        # whether or not it is indented.
        if indent is not None or first_line_seen:
            break
        first_line_seen = True

    if not indent:
        return doc
    return re.sub(r"\n" + indent.group(), "\n", doc).strip()
class DAGNode(object):
    """One step of a flow, described by static inspection of its function AST.

    Parameters
    ----------
    func_ast : ast.FunctionDef
        AST of the step function; its name, line number, argument count and
        last statement are inspected.
    decos : iterable
        Decorator objects attached to the step. A decorator exposing a truthy
        `IS_PARALLEL` attribute marks the node as a parallel step.
    doc : str or None
        Raw docstring of the step; stored de-indented in `self.doc`.

    After construction `self.type` is one of "start", "linear", "join",
    "split", "foreach", "end", or None when the last statement of the step
    is not a recognizable `self.next(...)` call.
    """

    def __init__(self, func_ast, decos, doc):
        self.name = func_ast.name
        self.func_lineno = func_ast.lineno
        self.decorators = decos
        self.doc = deindent_docstring(doc)
        self.parallel_step = any(getattr(deco, "IS_PARALLEL", False) for deco in decos)

        # these attributes are populated by _parse
        self.tail_next_lineno = 0
        self.type = None
        self.out_funcs = []
        self.has_tail_next = False
        self.invalid_tail_next = False
        self.num_args = 0
        self.foreach_param = None
        self.num_parallel = 0
        self.parallel_foreach = False
        self._parse(func_ast)

        # these attributes are populated by _traverse_graph
        self.in_funcs = set()
        self.split_parents = []
        self.matching_join = None

        # these attributes are populated by _postprocess
        self.is_inside_foreach = False

    def _expr_str(self, expr):
        """Render an `ast.Attribute` on a plain name as "<name>.<attr>".

        Raises AttributeError for any other expression shape; `_parse` relies
        on that to reject malformed transitions.
        """
        return "%s.%s" % (expr.value.id, expr.attr)

    @staticmethod
    def _keyword_literal(value_node):
        """Return the literal carried by a keyword-argument AST node, or None.

        `ast.Constant.s` is a deprecated alias of `ast.Constant.value` that
        newer Python versions warn about and eventually drop, so constants are
        read through `.value`. The `.s` fallback keeps interpreters that still
        produce `ast.Str` nodes working. Non-literal nodes yield None.
        """
        # ast.Constant is looked up defensively so that interpreters without
        # it simply take the legacy path below.
        constant_cls = getattr(ast, "Constant", None)
        if constant_cls is not None and isinstance(value_node, constant_cls):
            return value_node.value
        return getattr(value_node, "s", None)

    def _parse(self, func_ast):
        """Classify the step from the last statement of its body.

        Sets `num_args`, `type`, `out_funcs`, `has_tail_next`,
        `invalid_tail_next`, `tail_next_lineno` and, for foreach-style
        transitions, `foreach_param` / `num_parallel` / `parallel_foreach`.
        Returns early (leaving the remaining attributes untouched) when the
        tail statement is not a `self.next(...)` call.
        """
        self.num_args = len(func_ast.args.args)
        tail = func_ast.body[-1]

        # end doesn't need a transition
        if self.name == "end":
            # TYPE: end
            self.type = "end"

        # ensure that the tail is an expression
        if not isinstance(tail, ast.Expr):
            return

        # determine the type of self.next transition
        try:
            if not self._expr_str(tail.value.func) == "self.next":
                return

            # From here on the step has a tail self.next(); it is considered
            # invalid until one of the recognized shapes below matches.
            self.has_tail_next = True
            self.invalid_tail_next = True
            self.tail_next_lineno = tail.lineno
            self.out_funcs = [e.attr for e in tail.value.args]

            keywords = dict(
                (k.arg, self._keyword_literal(k.value)) for k in tail.value.keywords
            )
            if len(keywords) == 1:
                if "foreach" in keywords:
                    # TYPE: foreach
                    self.type = "foreach"
                    if len(self.out_funcs) == 1:
                        self.foreach_param = keywords["foreach"]
                        self.invalid_tail_next = False
                elif "num_parallel" in keywords:
                    # parallel transitions are modelled as a foreach flavour
                    self.type = "foreach"
                    self.parallel_foreach = True
                    if len(self.out_funcs) == 1:
                        self.num_parallel = keywords["num_parallel"]
                        self.invalid_tail_next = False
            elif len(keywords) == 0:
                if len(self.out_funcs) > 1:
                    # TYPE: split
                    self.type = "split"
                    self.invalid_tail_next = False
                elif len(self.out_funcs) == 1:
                    # TYPE: linear
                    if self.name == "start":
                        self.type = "start"
                    elif self.num_args > 1:
                        # a step taking an extra argument besides self is a join
                        self.type = "join"
                    else:
                        self.type = "linear"
                    self.invalid_tail_next = False
        except AttributeError:
            # Any AST node of an unexpected shape (e.g. a non-attribute
            # argument to self.next) leaves the transition flagged as is.
            return

    def __str__(self):
        # The continuation lines of the template are flush-left on purpose:
        # any leading whitespace would become part of the rendered text.
        return """*[{0.name} {0.type} (line {0.func_lineno})]*
in_funcs={in_funcs}
out_funcs={out_funcs}
split_parents={parents}
matching_join={matching_join}
is_inside_foreach={is_inside_foreach}
decorators={decos}
num_args={0.num_args}
has_tail_next={0.has_tail_next} (line {0.tail_next_lineno})
invalid_tail_next={0.invalid_tail_next}
foreach_param={0.foreach_param}
parallel_step={0.parallel_step}
parallel_foreach={0.parallel_foreach}
-> {out}""".format(
            self,
            matching_join=self.matching_join and "[%s]" % self.matching_join,
            is_inside_foreach=self.is_inside_foreach,
            out_funcs=", ".join("[%s]" % x for x in self.out_funcs),
            in_funcs=", ".join("[%s]" % x for x in self.in_funcs),
            parents=", ".join("[%s]" % x for x in self.split_parents),
            decos=" | ".join(map(str, self.decorators)),
            out=", ".join("[%s]" % x for x in self.out_funcs),
        )
class StepVisitor(ast.NodeVisitor):
    """AST visitor that records a DAGNode for every step function it meets.

    Parameters
    ----------
    nodes : dict
        Mapping filled in place: step name -> DAGNode.
    flow : class
        The flow class whose attributes are consulted to decide whether a
        function definition is a step (it must expose `is_step`).
    """

    def __init__(self, nodes, flow):
        self.nodes = nodes
        self.flow = flow
        super(StepVisitor, self).__init__()

    def visit_FunctionDef(self, node):
        # A function definition reached by the visitor is not necessarily an
        # attribute of the flow class (e.g. a method of a nested class), so a
        # missing attribute is treated as "not a step" instead of raising.
        func = getattr(self.flow, node.name, None)
        if hasattr(func, "is_step"):
            self.nodes[node.name] = DAGNode(node, func.decorators, func.__doc__)
class FlowGraph(object):
    """Static graph of a flow's steps, built by parsing the flow's source.

    Parameters
    ----------
    flow : class
        The flow class. The source of its defining module is parsed, the
        top-level class with the same name is located, and one DAGNode is
        created per step found in it.

    The graph behaves like a read-only mapping of step name -> DAGNode for
    `in` and `[]`, and iterates over the nodes themselves.
    """

    def __init__(self, flow):
        self.name = flow.__name__
        self.nodes = self._create_nodes(flow)
        self.doc = deindent_docstring(flow.__doc__)
        self._traverse_graph()
        self._postprocess()

    def _create_nodes(self, flow):
        """Parse the module defining `flow` and return {step name: DAGNode}.

        Raises IndexError when no top-level class named like the flow exists
        in the module source.
        """
        # Function-scope import keeps this block self-contained.
        import importlib

        # __import__("a.b.c") returns the top-level package "a", which would
        # make us parse the wrong source file; import_module returns the
        # module actually named.
        module = importlib.import_module(flow.__module__)
        tree = ast.parse(inspect.getsource(module)).body
        root = [n for n in tree if isinstance(n, ast.ClassDef) and n.name == self.name][
            0
        ]
        nodes = {}
        StepVisitor(nodes, flow).visit(root)
        return nodes

    def _postprocess(self):
        """Flag nodes that sit inside a foreach that they do not join."""
        # any node who has a foreach as any of its split parents
        # has is_inside_foreach=True *unless* all of those foreaches
        # are joined by the node
        for node in self.nodes.values():
            foreaches = [
                p for p in node.split_parents if self.nodes[p].type == "foreach"
            ]
            if [f for f in foreaches if self.nodes[f].matching_join != node.name]:
                node.is_inside_foreach = True

    def _traverse_graph(self):
        """Walk the graph from "start", filling in edges and split context.

        Populates `in_funcs`, `split_parents` and `matching_join` on the
        nodes. Nothing is traversed when the graph has no "start" node.
        """

        def traverse(node, seen, split_parents):
            if node.type in ("split", "foreach"):
                node.split_parents = split_parents
                # children see this node as their innermost open split
                split_parents = split_parents + [node.name]
            elif node.type == "join":
                # ignore joins without splits
                if split_parents:
                    self[split_parents[-1]].matching_join = node.name
                    node.split_parents = split_parents
                    # the join closes the innermost split
                    split_parents = split_parents[:-1]
            else:
                node.split_parents = split_parents

            for n in node.out_funcs:
                # graph may contain loops - ignore them
                if n not in seen:
                    # graph may contain unknown transitions - ignore them
                    if n in self:
                        child = self[n]
                        child.in_funcs.add(node.name)
                        traverse(child, seen + [n], split_parents)

        if "start" in self:
            traverse(self["start"], [], [])

        # fix the order of in_funcs
        for node in self.nodes.values():
            node.in_funcs = sorted(node.in_funcs)

    def __getitem__(self, x):
        return self.nodes[x]

    def __contains__(self, x):
        return x in self.nodes

    def __iter__(self):
        return iter(self.nodes.values())

    def __str__(self):
        # Sort by definition line using a key so that node objects themselves
        # are never compared, even if two nodes report the same line.
        ordered = sorted(self.nodes.values(), key=lambda n: n.func_lineno)
        return "\n".join(str(n) for n in ordered)

    def output_dot(self):
        """Return the graph as a Graphviz "digraph" description string."""

        def edge_specs():
            for node in self.nodes.values():
                for edge in node.out_funcs:
                    yield "%s -> %s;" % (node.name, edge)

        def node_specs():
            for node in self.nodes.values():
                # any step taking an extra argument is labelled as a join
                nodetype = "join" if node.num_args > 1 else node.type
                yield '"{0.name}"' '[ label = <<b>{0.name}</b> | <font point-size="10">{type}</font>> ' ' fontname = "Helvetica" ' ' shape = "record" ];'.format(
                    node, type=nodetype
                )

        return (
            "digraph {0.name} {{\n"
            "{nodes}\n"
            "{edges}\n"
            "}}".format(
                self, nodes="\n".join(node_specs()), edges="\n".join(edge_specs())
            )
        )

    def output_steps(self):
        """Describe the graph as plain data.

        Returns
        -------
        tuple (steps_info, graph_structure)
            `steps_info` maps each visited step name to a dict with its name,
            type, line, doc, public decorators and next steps (plus
            foreach/parallel/matching-join details where applicable).
            `graph_structure` is a nested list of step names from "start" to
            "end"; after a splitting step comes a list holding one sub-list
            per branch.

        Raises KeyError when the graph lacks a "start" or "end" node or when
        a transition leads to an unknown step.
        """
        steps_info = {}

        def node_to_type(node):
            if node.type in ["linear", "start", "end", "join"]:
                return node.type
            elif node.type == "split":
                return "split-static"
            elif node.type == "foreach":
                if node.parallel_foreach:
                    return "split-parallel"
                return "split-foreach"
            return "unknown"  # Should never happen

        def node_to_dict(name, node):
            d = {
                "name": name,
                "type": node_to_type(node),
                "line": node.func_lineno,
                "doc": node.doc,
                "decorators": [
                    {
                        "name": deco.name,
                        "attributes": deco.attributes,
                        "statically_defined": deco.statically_defined,
                    }
                    for deco in node.decorators
                    # decorators whose name starts with "_" are not reported
                    if not deco.name.startswith("_")
                ],
                "next": node.out_funcs,
            }
            if d["type"] == "split-foreach":
                d["foreach_artifact"] = node.foreach_param
            elif d["type"] == "split-parallel":
                d["num_parallel"] = node.num_parallel
            if node.matching_join:
                d["matching_join"] = node.matching_join
            return d

        def populate_block(start_name, end_name):
            # Collect step names from start_name up to (excluding) end_name,
            # recursing into each branch of a splitting step.
            cur_name = start_name
            resulting_list = []
            while cur_name != end_name:
                cur_node = self.nodes[cur_name]
                node_dict = node_to_dict(cur_name, cur_node)

                steps_info[cur_name] = node_dict
                resulting_list.append(cur_name)

                if cur_node.type not in ("start", "linear", "join"):
                    # We need to look at the different branches for this
                    resulting_list.append(
                        [
                            populate_block(s, cur_node.matching_join)
                            for s in cur_node.out_funcs
                        ]
                    )
                    cur_name = cur_node.matching_join
                else:
                    cur_name = cur_node.out_funcs[0]
            return resulting_list

        graph_structure = populate_block("start", "end")

        # populate_block stops before "end", so add it explicitly
        steps_info["end"] = node_to_dict("end", self.nodes["end"])
        graph_structure.append("end")

        return steps_info, graph_structure