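# Simple Relay mutator and visitor demo

Uses an `ExprMutator` to rewrite `nn.conv2d` calls into `nn.contrib_conv2d_NCHWc`, then an `ExprVisitor` to print the operators of the rewritten function.

## TODO: check correctness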

In [1]:
import tvm
from tvm import relay

In [2]:
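# Create a mutator and a visitor by subclassing ExprMutator / ExprVisitor. This is a minimal
# example; prefer explicit isinstance checks (e.g. isinstance(c.op, tvm.ir.Op)) over assuming
# that node attributes such as `.name` exist.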
from tvm.relay.expr_functor import ExprMutator, ExprVisitor

class ChangeOperator(ExprMutator):
    """Rewrite every ``nn.conv2d`` call into ``nn.contrib_conv2d_NCHWc``.

    All other expressions are rebuilt by the default ``ExprMutator`` logic.
    """

    def visit_call(self, c):
        """Replace a ``nn.conv2d`` call, rewriting its arguments first.

        Parameters
        ----------
        c : relay.Call
            The call node being visited.

        Returns
        -------
        relay.Expr
            A ``contrib_conv2d_nchwc`` call when ``c`` is an ``nn.conv2d``;
            otherwise the result of ``ExprMutator.visit_call``.
        """
        # The callee may be a Function, Var or GlobalVar, none of which has a
        # ``.name``; only primitive ops are candidates for this rewrite.
        if isinstance(c.op, tvm.ir.Op) and c.op.name == "nn.conv2d":
            # Visit the inputs so that conv2d calls nested inside them (e.g. a
            # conv feeding this conv) are rewritten as well.
            data, weight = [self.visit(arg) for arg in c.args]
            attrs = c.attrs
            # Carry over the convolution hyper-parameters so that a strided,
            # dilated or grouped conv is not silently turned into a plain one.
            # NOTE(review): data/kernel layouts are left at the
            # contrib_conv2d_nchwc defaults and the inputs are NOT repacked,
            # so the result only matches the original for inputs already in
            # that packed layout. TODO: insert layout_transform if needed.
            return relay.nn.contrib_conv2d_nchwc(
                data,
                weight,
                strides=attrs.strides,
                padding=attrs.padding,
                dilation=attrs.dilation,
                groups=attrs.groups,
                channels=attrs.channels,
                kernel_size=attrs.kernel_size,
                out_dtype=attrs.out_dtype,
            )
        return super().visit_call(c)
    
class Vis(ExprVisitor):
    """Print the operator of every call, from the inputs towards the output."""

    def visit_call(self, c):
        """Visit the call's children first, then print the call's operator."""
        # Post-order: recursing before printing means producers are printed
        # ahead of the calls that consume them.
        super().visit_call(c)
        callee = c.op
        print(callee)

In [3]:
# A tiny conv2d -> batch_norm -> relu network used to exercise the mutator.
out_channels = 16
batch_size = 1

data = relay.var("data", relay.TensorType((batch_size, 3, 224, 224), "float32"))
# Only the input carries a type; the parameters are left untyped.
weight, bn_gamma, bn_beta, bn_mmean, bn_mvar = (
    relay.var(var_name)
    for var_name in ("weight", "bn_gamma", "bn_beta", "bn_mean", "bn_var")
)

conv_out = relay.nn.conv2d(
    data=data, weight=weight, kernel_size=(3, 3), channels=out_channels, padding=(1, 1)
)
# batch_norm yields a tuple (shown as `%1.0` below); only element 0 feeds relu.
bn_out = relay.nn.batch_norm(conv_out, bn_gamma, bn_beta, bn_mmean, bn_mvar)[0]
relu_out = relay.nn.relu(bn_out)
simple_net = relay.Function(relay.analysis.free_vars(relu_out), relu_out)

simple_net

fn (%data: Tensor[(1, 3, 224, 224), float32], %weight, %bn_gamma, %bn_beta, %bn_mean, %bn_var) {
  %0 = nn.conv2d(%data, %weight, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]);
  %1 = nn.batch_norm(%0, %bn_gamma, %bn_beta, %bn_mean, %bn_var);
  %2 = %1.0;
  nn.relu(%2)
}

In [4]:
# Swap nn.conv2d for nn.contrib_conv2d_NCHWc; the other calls (batch_norm, relu) keep their op.
# Note that the replacement op reports its own default data_layout ("NCHW8c" in the output
# below), so its attrs are not all identical to the original conv2d's.
conv_rewriter = ChangeOperator()
new_net = conv_rewriter.visit(simple_net)
new_net

fn (%data: Tensor[(1, 3, 224, 224), float32], %weight, %bn_gamma, %bn_beta, %bn_mean, %bn_var) {
  %0 = nn.contrib_conv2d_NCHWc(%data, %weight, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3], data_layout="NCHW8c");
  %1 = nn.batch_norm(%0, %bn_gamma, %bn_beta, %bn_mean, %bn_var);
  %2 = %1.0;
  nn.relu(%2)
}

In [5]:
# Walk the rewritten function, printing each call's operator with inputs before outputs.
op_printer = Vis()
op_printer.visit(new_net)

nn.contrib_conv2d_NCHWc
nn.batch_norm
nn.relu
