diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index e26e9bc7dc4c..a5efcce27859 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -243,6 +243,15 @@ def _gelu(self, node: fx.node.Node) -> relax.Expr:
         else:
             raise KeyError("Unregonized approximate algorithm for gelu: {}.".format(approximate))
 
+    def _hardswish(self, node: fx.node.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        x = args[0]
+        dtype = x.struct_info.dtype
+        x0 = relax.op.add(x, relax.const(3, dtype))
+        x1 = relax.op.clip(x0, 0, 6)
+        x2 = relax.op.divide(x1, relax.const(6, dtype))
+        return self.block_builder.emit(relax.op.multiply(x, x2))
+
     ########## Compare ##########
 
     def _lt(self, node: fx.node.Node) -> relax.Expr:
@@ -1358,6 +1367,7 @@ def create_convert_map(self):
             nn.Sigmoid: self._sigmoid,
             nn.Tanh: lambda node: self.block_builder.emit(relax.op.tanh(self.env[node.args[0]])),
             nn.SiLU: lambda node: self.block_builder.emit(relax.op.nn.silu(self.env[node.args[0]])),
+            nn.Hardswish: self._hardswish,
             nn.Flatten: self._flatten,
             nn.BatchNorm2d: self._batch_norm_2d,
             nn.LayerNorm: self._layer_norm,
@@ -1437,6 +1447,7 @@ def create_convert_map(self):
             "leaky_relu": self._leakyrelu,
             "gelu": self._gelu,
             "silu": lambda node: self.block_builder.emit(relax.op.nn.silu(self.env[node.args[0]])),
+            "hardswish": self._hardswish,
             "interpolate": self._interpolate,
             "size": self._size,
             "getattr": self._getattr,
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index dfa5cad4a5a7..49131b5ff891 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -1416,6 +1416,42 @@ def main(
     verify_model(SiLU2(), input_info, {}, expected1)
 
 
+def test_hardswish():
+    input_info = [([1, 3, 10, 10], "float32")]
+
+    class Hardswish(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.hs = torch.nn.Hardswish()
+
+        def forward(self, input):
+            return self.hs(input)
+
+    class Hardswish2(torch.nn.Module):
+        def forward(self, input):
+            return torch.nn.functional.hardswish(input)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            inp_0: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(inp_0, R.const(3, "float32"))
+                lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(lv, 0, 6)
+                lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide(
+                    lv1, R.const(6, "float32")
+                )
+                lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(inp_0, lv2)
+                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv3
+                R.output(gv)
+            return gv
+
+    verify_model(Hardswish(), input_info, {}, expected1)
+    verify_model(Hardswish2(), input_info, {}, expected1)
+
+
 def test_groupnorm():
     import torch
     from torch.nn import Module