-
Notifications
You must be signed in to change notification settings - Fork 3.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] Relay visualization: exporter + visualizer #4370
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,11 +22,14 @@ | |
""" | ||
from . import _analysis | ||
from . import _make | ||
from .expr import Expr | ||
from .expr import Expr, Function, Var, Call, TupleGetItem | ||
from .op.op import Op | ||
from .ty import Type | ||
from .module import Module | ||
from .feature import Feature | ||
|
||
import json | ||
|
||
|
||
def post_order_visit(expr, fvisit): | ||
"""Recursively visit the ir in post DFS order node, | ||
|
@@ -408,3 +411,116 @@ def structural_hash(value): | |
msg = ("found value of type {0} expected" + | ||
"relay.Expr or relay.Type").format(type(value)) | ||
raise TypeError(msg) | ||
|
||
def _export_as_relayviz(expr): | ||
"""Export a Relay function as a nested dictionary, following the RelayViz spec | ||
(https://discuss.tvm.ai/t/rfc-visualizing-relay-program-as-graph/4825/10). The dictionary will | ||
contain all information useful for visualizing the Relay program and is meant to be consumed | ||
by other visualizers. | ||
|
||
Parameters | ||
---------- | ||
expr : tvm.relay.Expr | ||
The input expression. | ||
|
||
Returns | ||
------- | ||
viz : dict | ||
Nested dictionary | ||
""" | ||
|
||
# node_dict maps a Relay node to an index (node ID) | ||
def _traverse_expr(node, node_dict): | ||
if node in node_dict: | ||
return | ||
node_dict[node] = len(node_dict) | ||
|
||
node_dict = {} | ||
post_order_visit(expr, lambda x: _traverse_expr(x, node_dict)) | ||
|
||
relayviz_nodes = [] | ||
|
||
# Sort by node ID | ||
for node, node_idx in sorted(node_dict.items(), key=lambda x: x[1]): | ||
if isinstance(node, Function): | ||
relayviz_nodes.append({ | ||
'node_kind': 'Function', | ||
'body': node_dict[node.body], | ||
'params': [node_dict[x] for x in node.params], | ||
'ret_type': { | ||
'dtype': node.ret_type.dtype, | ||
'shape': [int(x) for x in node.ret_type.shape] | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How about there is a type attribute on every op and it contains a string that describes the type of what the node yields? It can be an empty string if there is no value. Any visualizer that wants to actually understand the structure of e.g. types can just deal with Relay directly. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @broune I don't see a way to obtain the return type of the op. In fact, it looks like the "ret_type" field in the Function node is optional. It appears that type inference occurs when the Relay program gets compiled. Let me go ahead and remove this field. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. One of the main use cases for this is to print out the state of the IR as it gets mutated through compilation, where type inference is one of the first things (IIRC actually the first thing) that happens, so it is good to show the type if it is present - that accurately reflects the IR. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'll have to ask TVM experts for help. Specifically: 1) where type inference occurs, and 2) how to get the return type of each Op. |
||
}) | ||
elif isinstance(node, Var): | ||
relayviz_nodes.append({ | ||
'node_kind': 'Var', | ||
'name': node.name_hint, | ||
'dtype': node.type_annotation.dtype, | ||
'shape': [int(x) for x in node.type_annotation.shape] | ||
}) | ||
elif isinstance(node, Call): | ||
relayviz_nodes.append({ | ||
'node_kind': 'Call', | ||
'op': node_dict[node.op], | ||
'args': [node_dict[arg] for arg in node.args] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we call the within-graph operands of all nodes the same regardless of what kind of node it is? E.g. operands. |
||
}) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Wouldn't you want this to also have a type field that says what type is yielded by the call? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. See my comment above about inability to obtain the return type of an op. |
||
elif isinstance(node, Op): | ||
relayviz_nodes.append({ | ||
'node_kind': 'Op', | ||
'name': node.name, | ||
'attrs': {} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Note to myself: it is actually possible to obtain attributes of an op using Python. |
||
}) | ||
elif isinstance(node, TupleGetItem): | ||
relayviz_nodes.append({ | ||
'node_kind': 'TupleGetItem', | ||
'tuple_value': node_dict[node.tuple_value], | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is also jus the operand of this operation, so could be called with a consistent name to other ops. |
||
'index': node.index | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this the same as an attribute on an op? Why does this node type need to be special? (actually, why does any node type need to be special?) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @broune Relay is not a dataflow graph but a programming language. So the current prototype makes differentiation between different kinds of language constructs. @junrushao1994 What's your take on this? Should we collapse Relay nodes into generic universal node type? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is about providing sufficient information to the visualizer so that it can create a good visualization while being as simple as possible. In this case, I wonder whether the visualizer has a good use of a special case for the index field or TupleGetItem versus other ops. If not, seems simpler to not have the special case for this, and same question on a case-by-case basis for other cases. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @broune I see your point. In this case, you're okay with grouping Relay nodes into, say, 3-4 possible categories? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sorry for the late reply. I'm OK with whatever you think makes sense on this; you're more involved in the details of this than I am, and what I most want (and am very OK with!) is to have the ability to see Relay graphs easily visualized. I just wanted to raise this as a thing to consider. |
||
}) | ||
else: | ||
raise RuntimeError( | ||
'Unknown node type. node_idx: {}, node: {}'.format(node_idx, type(node))) | ||
|
||
obj = {} | ||
obj['format'] = 'relayviz' | ||
obj['version'] = [1, 0] | ||
obj['nodes'] = relayviz_nodes | ||
return obj | ||
|
||
def _export_as_graphviz(relayviz_obj): | ||
from graphviz import Digraph | ||
dot = Digraph(format='svg') | ||
dot.attr(rankdir='BT') | ||
dot.attr('node', shape='box') | ||
for node_id, node in enumerate(relayviz_obj['nodes']): | ||
if node['node_kind'] == 'Var': | ||
dot.node(str(node_id), | ||
'{}:\nTensor[{}, {}])'.format( | ||
node['name'], tuple(node['shape']), node['dtype'] | ||
)) | ||
elif node['node_kind'] == 'Call': | ||
dot.node(str(node_id), 'Call(op={})'.format(relayviz_obj['nodes'][ node['op'] ]['name'])) | ||
for arg in node['args']: | ||
dot.edge(str(arg), str(node_id)) | ||
elif node['node_kind'] == 'Function': | ||
dot.node(str(node_id), 'Function') | ||
dot.edge(str(node['body']), str(node_id)) | ||
elif node['node_kind'] == 'TupleGetItem': | ||
dot.node(str(node_id), 'TupleGetItem(idx={})'.format(node['index'])) | ||
dot.edge(str(node['tuple_value']), str(node_id)) | ||
elif node['node_kind'] == 'Op': | ||
pass | ||
else: | ||
raise RuntimeError( | ||
'Node type {} not supported by GraphViz visualizer.'.format(node['node_kind'])) | ||
return dot | ||
|
||
|
||
def visualize(expr, output_format='graphviz'): | ||
possible_format = ['graphviz'] | ||
if output_format not in possible_format: | ||
raise RuntimeError('output_format should be one of {}'.format(possible_format)) | ||
|
||
relayviz_obj = _export_as_relayviz(expr) | ||
if output_format == 'graphviz': | ||
return _export_as_graphviz(relayviz_obj) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sort imports