Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Relay visualization: exporter + visualizer #4370

Closed
wants to merge 4 commits into from
Closed
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 117 additions & 1 deletion python/tvm/relay/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,14 @@
"""
from . import _analysis
from . import _make
from .expr import Expr
from .expr import Expr, Function, Var, Call, TupleGetItem
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sort imports

from .op.op import Op
from .ty import Type
from .module import Module
from .feature import Feature

import json


def post_order_visit(expr, fvisit):
"""Recursively visit the ir in post DFS order node,
Expand Down Expand Up @@ -408,3 +411,116 @@ def structural_hash(value):
msg = ("found value of type {0} expected" +
"relay.Expr or relay.Type").format(type(value))
raise TypeError(msg)

def _export_as_relayviz(expr):
"""Export a Relay function as a nested dictionary, following the RelayViz spec
(https://discuss.tvm.ai/t/rfc-visualizing-relay-program-as-graph/4825/10). The dictionary will
contain all information useful for visualizing the Relay program and is meant to be consumed
by other visualizers.

Parameters
----------
expr : tvm.relay.Expr
The input expression.

Returns
-------
viz : dict
Nested dictionary
"""

# node_dict maps a Relay node to an index (node ID)
def _traverse_expr(node, node_dict):
if node in node_dict:
return
node_dict[node] = len(node_dict)

node_dict = {}
post_order_visit(expr, lambda x: _traverse_expr(x, node_dict))

relayviz_nodes = []

# Sort by node ID
for node, node_idx in sorted(node_dict.items(), key=lambda x: x[1]):
if isinstance(node, Function):
relayviz_nodes.append({
'node_kind': 'Function',
'body': node_dict[node.body],
'params': [node_dict[x] for x in node.params],
'ret_type': {
'dtype': node.ret_type.dtype,
'shape': [int(x) for x in node.ret_type.shape]
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about there is a type attribute on every op and it contains a string that describes the type of what the node yields? It can be an empty string if there is no value. Any visualizer that wants to actually understand the structure of e.g. types can just deal with Relay directly.

Copy link
Contributor Author

@hcho3 hcho3 Nov 19, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@broune I don't see a way to obtain the return type of the op. In fact, it looks like the "ret_type" field in the Function node is optional. It appears that type inference occurs when the Relay program gets compiled. Let me go ahead and remove this field.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One of the main use cases for this is to print out the state of the IR as it gets mutated through compilation, where type inference is one of the first things (IIRC actually the first thing) that happens, so it is good to show the type if it is present - that accurately reflects the IR.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll have to ask TVM experts for help. Specifically: 1) where type inference occurs, and 2) how to get the return type of each Op.

})
elif isinstance(node, Var):
relayviz_nodes.append({
'node_kind': 'Var',
'name': node.name_hint,
'dtype': node.type_annotation.dtype,
'shape': [int(x) for x in node.type_annotation.shape]
})
elif isinstance(node, Call):
relayviz_nodes.append({
'node_kind': 'Call',
'op': node_dict[node.op],
'args': [node_dict[arg] for arg in node.args]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we call the within-graph operands of all nodes the same regardless of what kind of node it is? E.g. operands.

})
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't you want this to also have a type field that says what type is yielded by the call?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See my comment above about inability to obtain the return type of an op.

elif isinstance(node, Op):
relayviz_nodes.append({
'node_kind': 'Op',
'name': node.name,
'attrs': {}
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note to myself: it is actually possible to obtain attributes of an op using Python.

})
elif isinstance(node, TupleGetItem):
relayviz_nodes.append({
'node_kind': 'TupleGetItem',
'tuple_value': node_dict[node.tuple_value],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is also jus the operand of this operation, so could be called with a consistent name to other ops.

'index': node.index
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this the same as an attribute on an op? Why does this node type need to be special? (actually, why does any node type need to be special?)

Copy link
Contributor Author

@hcho3 hcho3 Nov 19, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@broune Relay is not a dataflow graph but a programming language. So the current prototype makes differentiation between different kinds of language constructs.

@junrushao1994 What's your take on this? Should we collapse Relay nodes into generic universal node type?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is about providing sufficient information to the visualizer so that it can create a good visualization while being as simple as possible. In this case, I wonder whether the visualizer has a good use of a special case for the index field or TupleGetItem versus other ops. If not, seems simpler to not have the special case for this, and same question on a case-by-case basis for other cases.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@broune I see your point. In this case, you're okay with grouping Relay nodes into, say, 3-4 possible categories?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for the late reply. I'm OK with whatever you think makes sense on this; you're more involved in the details of this than I am, and what I most want (and am very OK with!) is to have the ability to see Relay graphs easily visualized. I just wanted to raise this as a thing to consider.

})
else:
raise RuntimeError(
'Unknown node type. node_idx: {}, node: {}'.format(node_idx, type(node)))

obj = {}
obj['format'] = 'relayviz'
obj['version'] = [1, 0]
obj['nodes'] = relayviz_nodes
return obj

def _export_as_graphviz(relayviz_obj):
from graphviz import Digraph
dot = Digraph(format='svg')
dot.attr(rankdir='BT')
dot.attr('node', shape='box')
for node_id, node in enumerate(relayviz_obj['nodes']):
if node['node_kind'] == 'Var':
dot.node(str(node_id),
'{}:\nTensor[{}, {}])'.format(
node['name'], tuple(node['shape']), node['dtype']
))
elif node['node_kind'] == 'Call':
dot.node(str(node_id), 'Call(op={})'.format(relayviz_obj['nodes'][ node['op'] ]['name']))
for arg in node['args']:
dot.edge(str(arg), str(node_id))
elif node['node_kind'] == 'Function':
dot.node(str(node_id), 'Function')
dot.edge(str(node['body']), str(node_id))
elif node['node_kind'] == 'TupleGetItem':
dot.node(str(node_id), 'TupleGetItem(idx={})'.format(node['index']))
dot.edge(str(node['tuple_value']), str(node_id))
elif node['node_kind'] == 'Op':
pass
else:
raise RuntimeError(
'Node type {} not supported by GraphViz visualizer.'.format(node['node_kind']))
return dot


def visualize(expr, output_format='graphviz'):
possible_format = ['graphviz']
if output_format not in possible_format:
raise RuntimeError('output_format should be one of {}'.format(possible_format))

relayviz_obj = _export_as_relayviz(expr)
if output_format == 'graphviz':
return _export_as_graphviz(relayviz_obj)