In [1]:
import torch

## Fuse SiLU

In [2]:
from torch.nn import functional as F

class M(torch.nn.Module):
    """Toy model containing two SiLU computations for the fusion demo.

    The first SiLU is written by hand as ``x * sigmoid(x)``; the second comes
    from ``F.silu``. A constant ``1`` is added to the hand-written result before
    it is fed to ``F.silu``, and that shifted value is added to the output.
    """

    def forward(self, x):
        # Hand-written SiLU. Operand order (x first, sigmoid second) is kept on
        # purpose so the traced graph matches the pattern multiply(x, sigmoid(x)).
        manual_silu = x * torch.sigmoid(x)
        shifted = 1 + manual_silu
        return F.silu(shifted) + shifted

# Seed the RNG so the example input (and anything computed from it) is
# reproducible across kernel restarts.
torch.manual_seed(0)
x = torch.randn(3, 3)
model = M()

In [3]:
from tvm import relay

# Put the module in inference mode *before* tracing: torch.jit.trace records the
# operations executed at trace time, so train-mode behaviour (e.g. dropout) can
# end up baked into the traced graph. Calling .eval() only on the traced module
# afterwards does not undo that.
model.eval()
with torch.no_grad():
    scripted_model = torch.jit.trace(model, x).eval()

# Convert the traced module to Relay; the single graph input is named 'x'.
mod, params = relay.frontend.from_pytorch(scripted_model, [('x', x.shape)])
mod['main']

fn (%x: Tensor[(3, 3), float32]) {
  %0 = sigmoid(%x);
  %1 = multiply(%x, %0);
  %2 = add(%1, 1f);
  %3 = sigmoid(%2);
  %4 = multiply(%2, %3);
  add(%4, %2)
}

上述 TVM Relay 图包含两处 "sigmoid + multiply" 子图（以中间的 `add` 为分界）：第一处（`%0`、`%1`）来自手写的 `x * torch.sigmoid(x)`，第二处（`%3`、`%4`）则是 `F.silu()` 在转换时被拆分成的 sigmoid 与 multiply。

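由于当前版本的 TVM 不支持 `silu` 算子，这里通过定义 pattern 将 "sigmoid + multiply" 合并为 `nn.relu`，以演示 pattern 重写的流程。注意：`nn.relu` 与 SiLU（$\mathrm{silu}(x) = x \cdot \sigma(x)$）在数值上并不等价，这只是占位替换，重写后的图不能直接用于实际推理：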

In [4]:
from tvm.relay.dataflow_pattern import DFPatternCallback, is_op, wildcard

class Silu(DFPatternCallback):
    """Rewrite the SiLU subgraph ``multiply(x, sigmoid(x))`` into a single op.

    The same ``wildcard`` object is used for both operands, so the pattern only
    matches when the multiply's first input is the very expression fed to the
    sigmoid.

    NOTE: by default the match is replaced with ``relay.op.nn.relu`` because the
    TVM version used here has no ``silu`` op. ReLU is *not* numerically
    equivalent to SiLU (``silu(x) = x * sigmoid(x)``), so the default rewrite only
    demonstrates the pattern machinery. Pass ``make_op`` to substitute a real
    fused op once one is available.

    Args:
        require_type: forwarded unchanged to ``DFPatternCallback``.
        make_op: optional callable taking the matched input expression ``x`` and
            returning the replacement expression. ``None`` (default) uses
            ``relay.op.nn.relu``.
    """

    def __init__(self, require_type=False, make_op=None):
        super().__init__(require_type)
        self.x = wildcard()
        self.pattern = is_op('multiply')(self.x, is_op("sigmoid")(self.x))
        self.make_op = make_op

    def callback(self, pre, post, node_map):
        """Return the replacement for one match: ``make_op(x)``, ReLU by default."""
        x = node_map[self.x][0]
        if self.make_op is None:
            # Placeholder replacement; see the class docstring.
            return relay.op.nn.relu(x)
        return self.make_op(x)

from tvm.relay.dataflow_pattern import rewrite

# Run the pattern rewriter over the module's entry function and display the
# rewritten function.
main_fn = mod["main"]
silu_rewriter = Silu()
out = rewrite(silu_rewriter, main_fn)
out

fn (%x: Tensor[(3, 3), float32]) {
  %0 = nn.relu(%x);
  %1 = add(%0, 1f);
  %2 = nn.relu(%1);
  add(%2, %1)
}