[Relax] [PyTorch] Add support for torch.nn.Hardswish (#17084)
* add hardswish support to fx_frontend

* run ./tests/lint/git-black.sh -i --rev upstream/main

* fix ci lint error
mshr-h committed Jun 12, 2024
1 parent ab02979 commit cc7eb2f
Showing 2 changed files with 47 additions and 0 deletions.
11 changes: 11 additions & 0 deletions python/tvm/relax/frontend/torch/fx_translator.py
@@ -243,6 +243,15 @@ def _gelu(self, node: fx.node.Node) -> relax.Expr:
        else:
            raise KeyError("Unregonized approximate algorithm for gelu: {}.".format(approximate))

    def _hardswish(self, node: fx.node.Node) -> relax.Var:
        args = self.retrieve_args(node)
        x = args[0]
        dtype = x.struct_info.dtype
        x0 = relax.op.add(x, relax.const(3, dtype))
        x1 = relax.op.clip(x0, 0, 6)
        x2 = relax.op.divide(x1, relax.const(6, dtype))
        return self.block_builder.emit(relax.op.multiply(x, x2))

    ########## Compare ##########

    def _lt(self, node: fx.node.Node) -> relax.Expr:
@@ -1358,6 +1367,7 @@ def create_convert_map(self):
            nn.Sigmoid: self._sigmoid,
            nn.Tanh: lambda node: self.block_builder.emit(relax.op.tanh(self.env[node.args[0]])),
            nn.SiLU: lambda node: self.block_builder.emit(relax.op.nn.silu(self.env[node.args[0]])),
            nn.Hardswish: self._hardswish,
            nn.Flatten: self._flatten,
            nn.BatchNorm2d: self._batch_norm_2d,
            nn.LayerNorm: self._layer_norm,
@@ -1437,6 +1447,7 @@ def create_convert_map(self):
            "leaky_relu": self._leakyrelu,
            "gelu": self._gelu,
            "silu": lambda node: self.block_builder.emit(relax.op.nn.silu(self.env[node.args[0]])),
            "hardswish": self._hardswish,
            "interpolate": self._interpolate,
            "size": self._size,
            "getattr": self._getattr,
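The converter lowers hardswish to primitive Relax ops rather than a single fused operator, using the decomposition x * clip(x + 3, 0, 6) / 6. A minimal sanity check of that decomposition against PyTorch's own implementation (plain PyTorch, no TVM required; the helper name below is illustrative):

import torch

def hardswish_decomposed(x: torch.Tensor) -> torch.Tensor:
    # Mirrors what _hardswish emits in Relax: multiply(x, divide(clip(add(x, 3), 0, 6), 6))
    return x * torch.clamp(x + 3.0, 0.0, 6.0) / 6.0

x = torch.randn(1, 3, 10, 10)
assert torch.allclose(hardswish_decomposed(x), torch.nn.functional.hardswish(x), atol=1e-6)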
36 changes: 36 additions & 0 deletions tests/python/relax/test_frontend_from_fx.py
@@ -1416,6 +1416,42 @@ def main(
    verify_model(SiLU2(), input_info, {}, expected1)


def test_hardswish():
    input_info = [([1, 3, 10, 10], "float32")]

    class Hardswish(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.hs = torch.nn.Hardswish()

        def forward(self, input):
            return self.hs(input)

    class Hardswish2(torch.nn.Module):
        def forward(self, input):
            return torch.nn.functional.hardswish(input)

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            inp_0: R.Tensor((1, 3, 10, 10), dtype="float32")
        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
            with R.dataflow():
                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(inp_0, R.const(3, "float32"))
                lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(lv, 0, 6)
                lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide(
                    lv1, R.const(6, "float32")
                )
                lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(inp_0, lv2)
                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv3
                R.output(gv)
            return gv

    verify_model(Hardswish(), input_info, {}, expected1)
    verify_model(Hardswish2(), input_info, {}, expected1)


def test_groupnorm():
    import torch
    from torch.nn import Module
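For reference, a rough sketch of how the new mapping could be exercised outside the test suite, assuming the from_fx entry point of the Relax PyTorch frontend (the exact import path and signature may vary between TVM versions):

import torch
from torch import fx
from tvm.relax.frontend.torch import from_fx

class Model(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.hardswish(x)

# Trace with torch.fx and hand the GraphModule to the Relax frontend.
graph_module = fx.symbolic_trace(Model())
mod = from_fx(graph_module, [([1, 3, 10, 10], "float32")])
mod.show()  # hardswish shows up as add / clip / divide / multiply in the printed Relax IR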
