In [1]:
import torch

## Fuse SiLU

In [2]:
from torch.nn import functional as F

class M(torch.nn.Module):
    def forward(self, x):
        x = x * torch.sigmoid(x)
        x = 1 + x
        return F.silu(x) + x

x = torch.randn(3, 3)
model = M()

In [3]:
from tvm import relay

with torch.no_grad():
    scripted_model = torch.jit.trace(model, x).eval()

mod, params = relay.frontend.from_pytorch(scripted_model, [('x', x.shape)])
mod['main']

fn (%x: Tensor[(3, 3), float32]) {
  %0 = sigmoid(%x);
  %1 = multiply(%x, %0);
  %2 = add(%1, 1f);
  %3 = sigmoid(%2);
  %4 = multiply(%2, %3);
  add(%4, %2)
}

上述tvm relay图包含了两次"sigmoid + multiply"操作（以add为分界线），其中一次是由于`F.silu()`被拆分了。

由于tvm当前版本不支持`silu`算子，于是通过定义pattern将"sigmoid + multiply"合并为`nn.relu`：

In [4]:
from tvm.relay.dataflow_pattern import DFPatternCallback, is_op, wildcard

class Silu(DFPatternCallback):
    # A callback class to rewrite the matched pattern to a batch_norm op.
    def __init__(self, require_type=False):
        super().__init__(require_type)
        self.x = wildcard()
        self.pattern = is_op('multiply')(self.x, is_op("sigmoid")(self.x))

    def callback(self, pre, post, node_map):
        x = node_map[self.x][0]
        return relay.op.nn.relu(x)
        # return relay.op.silu(x)

from tvm.relay.dataflow_pattern import rewrite
out = rewrite(Silu(), mod['main'])
out

fn (%x: Tensor[(3, 3), float32]) {
  %0 = nn.relu(%x);
  %1 = add(%0, 1f);
  %2 = nn.relu(%1);
  add(%2, %1)
}

---

In [19]:
import tvm
import numpy as np

from tvm import relay
from tvm.relay.testing import run_infer_type, create_workload
from tvm.relay.build_module import bind_params_by_name

remove_bn_pass = tvm.transform.Sequential(
    [
        relay.transform.InferType(),
        relay.transform.SimplifyInference(),
        relay.transform.FoldConstant(),
        relay.transform.FoldScaleAxis(),
    ]
)

data = relay.var("data", relay.TensorType((1, 3, 224, 224), "float32"))
weight = relay.var("weight")
bn_gamma = relay.var("bn_gamma")
bn_beta = relay.var("bn_beta")
bn_mmean = relay.var("bn_mean")
bn_mvar = relay.var("bn_var")

conv = relay.nn.conv2d(
    data=data, weight=weight, kernel_size=(3, 3), channels=16, padding=(1, 1)
)
bn_output = relay.nn.batch_norm(conv, bn_gamma, bn_beta, bn_mmean, bn_mvar)

def initializer(_, param):
    param = np.zeros(param.shape)

mod, params = create_workload(bn_output[0], initializer)
mod["main"] = bind_params_by_name(mod["main"], params)

with tvm.transform.PassContext(opt_level=3):
    mod = remove_bn_pass(mod)
mod["main"]

fn (%data: Tensor[(1, 3, 224, 224), float32]) -> Tensor[(1, 16, 224, 224), float32] {
  %0 = nn.conv2d(%data, meta[relay.Constant][0] /* ty=Tensor[(16, 3, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 224, 224), float32] */;
  add(%0, meta[relay.Constant][1] /* ty=Tensor[(16, 1, 1), float32] */) /* ty=Tensor[(1, 16, 224, 224), float32] */
}

In [6]:
mod["main"]

fn (%data: Tensor[(1, 3, 224, 224), float32], %weight: Tensor[(16, 3, 3, 3), float32], %bn_gamma: Tensor[(16), float32], %bn_beta: Tensor[(16), float32], %bn_mean: Tensor[(16), float32], %bn_var: Tensor[(16), float32]) -> Tensor[(1, 16, 224, 224), float32] {
  %0 = nn.conv2d(%data, %weight, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 224, 224), float32] */;
  %1 = nn.batch_norm(%0, %bn_gamma, %bn_beta, %bn_mean, %bn_var) /* ty=(Tensor[(1, 16, 224, 224), float32], Tensor[(16), float32], Tensor[(16), float32]) */;
  %1.0
}

In [20]:
import torch

from torch import nn
from torch.nn import functional as F
from tvm import relay

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 3)

    def forward(self, x):
        x1 = self.conv(x)
        x2 = self.conv(x)
        return x1 + x2

x = torch.randn(1, 1, 3, 3)
model = M()

with torch.no_grad():
    scripted_model = torch.jit.trace(model, x).eval()

mod, params = relay.frontend.from_pytorch(scripted_model, [('x', x.shape)])
mod['main']

fn (%x: Tensor[(1, 1, 3, 3), float32], %conv.weight: Tensor[(1, 1, 3, 3), float32], %conv.bias: Tensor[(1), float32]) {
  %0 = nn.conv2d(%x, %conv.weight, padding=[0, 0, 0, 0], channels=1, kernel_size=[3, 3]);
  %1 = nn.conv2d(%x, %conv.weight, padding=[0, 0, 0, 0], channels=1, kernel_size=[3, 3]);
  %2 = nn.bias_add(%0, %conv.bias);
  %3 = nn.bias_add(%1, %conv.bias);
  add(%2, %3)
}

In [17]:
from tvm.relay.dataflow_pattern import DFPatternCallback, is_op, wildcard

class Conv2d(DFPatternCallback):
    # A callback class to rewrite the matched pattern to a batch_norm op.
    def __init__(self, require_type=False):
        super().__init__(require_type)
        self.x = wildcard()
        self.w = wildcard()
        b = wildcard()
        self.pattern = is_op('nn.bias_add')(is_op("nn.conv2d")(self.x, self.w), b)

    def callback(self, pre, post, node_map):
        x = node_map[self.x][0]
        return relay.op.nn.conv2d(x)
        # return relay.op.silu(x)

from tvm.relay.dataflow_pattern import rewrite
out = rewrite(Conv2d(), mod['main'])
out

fn (%x: Tensor[(1, 1, 3, 3), float32], %conv.weight: Tensor[(1, 1, 3, 3), float32], %conv.bias: Tensor[(1), float32]) {
  %0 = nn.relu(%x);
  add(%0, 1f)
}