In [8]:
from typing import Dict, Union, Tuple, List

import tvm
from tvm import relay
from tvm.contrib import relay_viz
from tvm.contrib.relay_viz.interface import VizEdge, VizNode, VizParser
from tvm.contrib.relay_viz.terminal import TermGraph, TermPlotter, TermVizParser

DEBUG:graphviz._tools:deprecate positional args: graphviz.backend.piping.pipe(['renderer', 'formatter', 'neato_no_op', 'quiet'])
DEBUG:graphviz._tools:deprecate positional args: graphviz.backend.rendering.render(['renderer', 'formatter', 'neato_no_op', 'quiet'])
DEBUG:graphviz._tools:deprecate positional args: graphviz.backend.unflattening.unflatten(['stagger', 'fanout', 'chain', 'encoding'])
DEBUG:graphviz._tools:deprecate positional args: graphviz.backend.viewing.view(['quiet'])
DEBUG:graphviz._tools:deprecate positional args: graphviz.quoting.quote(['is_html_string', 'is_valid_id', 'dot_keywords', 'endswith_odd_number_of_backslashes', 'escape_unescaped_quotes'])
DEBUG:graphviz._tools:deprecate positional args: graphviz.quoting.a_list(['kwargs', 'attributes'])
DEBUG:graphviz._tools:deprecate positional args: graphviz.quoting.attr_list(['kwargs', 'attributes'])
DEBUG:graphviz._tools:deprecate positional args: graphviz.dot.Dot.clear(['keep_attrs'])
DEBUG:graphviz._tools:deprecate posit

In [9]:
# --- Helper function "AddFunc": takes (data, bias) and returns data + bias ---
data, bias = relay.var("data"), relay.var("bias")
add_op = relay.add(data, bias)
add_func = relay.Function([data, bias], add_op)
add_gvar = relay.GlobalVar("AddFunc")

# --- Entry point "main": AddFunc(input2, AddFunc(input0, input1)) ---
input0, input1, input2 = map(relay.var, ("input0", "input1", "input2"))
add_01 = relay.Call(add_gvar, [input0, input1])
add_012 = relay.Call(add_gvar, [input2, add_01])
main_func = relay.Function([input0, input1, input2], add_012)
main_gvar = relay.GlobalVar("main")

# Bundle both functions into a single IRModule, keyed by their global variables.
mod = tvm.IRModule({main_gvar: main_func, add_gvar: add_func})

In [10]:
# Only the module is supplied; plotter and parser are left at the
# visualizer's own defaults.
viz = relay_viz.RelayVisualizer(relay_mod=mod)
viz.render()

@main([Var(input0), Var(input1), Var(input2)])
`--Call 
   |--GlobalVar AddFunc
   |--Var(Input) name_hint: input2
   `--Call 
      |--GlobalVar AddFunc
      |--Var(Input) name_hint: input0
      `--Var(Input) name_hint: input1
@AddFunc([Var(data), Var(bias)])
`--Call 
   |--add 
   |--Var(Input) name_hint: data
   `--Var(Input) name_hint: bias


In [11]:
class YourAwesomeParser(VizParser):
    """Parser that renders ``relay.Var`` nodes as "AwesomeVar".

    Every expression that is not a ``relay.Var`` is forwarded unchanged to a
    wrapped ``TermVizParser``.
    """

    def __init__(self):
        # Fallback parser for every node type this class does not customize.
        self._delegate = TermVizParser()

    def get_node_edges(
        self,
        node: relay.Expr,
        relay_param: Dict[str, tvm.runtime.NDArray],
        node_to_id: Dict[relay.Expr, str],
    ) -> Tuple[Union[VizNode, None], List[VizEdge]]:
        """Translate one Relay expression into a visualization node and edges.

        Parameters
        ----------
        node : relay.Expr
            The expression being visited.
        relay_param : Dict[str, tvm.runtime.NDArray]
            Not inspected here; handed to the delegate parser as-is.
        node_to_id : Dict[relay.Expr, str]
            Lookup table used to obtain the identifier for ``node``.

        Returns
        -------
        Tuple[Union[VizNode, None], List[VizEdge]]
            For a ``relay.Var``: an "AwesomeVar" ``VizNode`` and an empty edge
            list. For anything else: whatever the delegate parser returns.
        """
        # Anything that is not a variable is handled by the wrapped parser.
        if not isinstance(node, relay.Var):
            return self._delegate.get_node_edges(node, relay_param, node_to_id)

        awesome_var = VizNode(node_to_id[node], "AwesomeVar", f"name_hint {node.name_hint}")
        # A variable introduces no edges, hence the empty list.
        return awesome_var, []

In [12]:
# Same module, but with the custom parser plugged in alongside an explicit
# terminal plotter.
viz = relay_viz.RelayVisualizer(
    relay_mod=mod,
    relay_param={},
    plotter=TermPlotter(),
    parser=YourAwesomeParser(),
)
viz.render()

@main([Var(input0), Var(input1), Var(input2)])
`--Call 
   |--GlobalVar AddFunc
   |--AwesomeVar name_hint input2
   `--Call 
      |--GlobalVar AddFunc
      |--AwesomeVar name_hint input0
      `--AwesomeVar name_hint input1
@AddFunc([Var(data), Var(bias)])
`--Call 
   |--add 
   |--AwesomeVar name_hint data
   `--AwesomeVar name_hint bias


In [13]:
class AwesomeGraph(TermGraph):
    """``TermGraph`` that attaches a companion node to every "AwesomeVar"."""

    def node(self, viz_node):
        """Add ``viz_node``; for an "AwesomeVar", also add a linked duplicate.

        The duplicate has type "double AwesomeVar", an empty detail string and
        the id ``duplicated_<original id>``. An edge is added from the
        duplicate to the original node.
        """
        base = super()
        # The incoming node is always registered first.
        base.node(viz_node)

        # Only AwesomeVar nodes get a companion.
        if viz_node.type_name != "AwesomeVar":
            return

        twin_id = f"duplicated_{viz_node.identity}"
        base.node(VizNode(twin_id, "double AwesomeVar", ""))
        # Link the duplicate to the node it was derived from.
        base.edge(VizEdge(twin_id, viz_node.identity))


# override TermPlotter to use `AwesomeGraph` instead
class AwesomePlotter(TermPlotter):
    """``TermPlotter`` whose graphs are ``AwesomeGraph`` instances."""

    def create_graph(self, name):
        """Build an ``AwesomeGraph`` for ``name``, register it, and return it."""
        graph = AwesomeGraph(name)
        # Record the graph under its name in the plotter's registry.
        self._name_to_graph[name] = graph
        return graph


# Render with both customizations: the custom plotter and the custom parser.
viz = relay_viz.RelayVisualizer(
    relay_mod=mod,
    relay_param={},
    plotter=AwesomePlotter(),
    parser=YourAwesomeParser(),
)
viz.render()

@main([Var(input0), Var(input1), Var(input2)])
`--Call 
   |--GlobalVar AddFunc
   |--AwesomeVar name_hint input2
   |  `--double AwesomeVar 
   `--Call 
      |--GlobalVar AddFunc
      |--AwesomeVar name_hint input0
      |  `--double AwesomeVar 
      `--AwesomeVar name_hint input1
         `--double AwesomeVar 
@AddFunc([Var(data), Var(bias)])
`--Call 
   |--add 
   |--AwesomeVar name_hint data
   |  `--double AwesomeVar 
   `--AwesomeVar name_hint bias
      `--double AwesomeVar 
