[Relax] [PyTorch] Add support for torch.nn.Hardsigmoid (#17085)
add hardsigmoid support to fx_frontend
mshr-h committed Jun 13, 2024
1 parent 5618628 commit d7ae4c7
Showing 2 changed files with 45 additions and 0 deletions.
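
As a quick illustration of what this change enables (a minimal sketch, not part of the diff — the Model class and input shape below are invented for the example, and it assumes a TVM build that includes this commit):

    import torch
    from torch import fx
    from tvm.relax.frontend.torch import from_fx

    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.hs = torch.nn.Hardsigmoid()

        def forward(self, x):
            return self.hs(x)

    # Trace with torch.fx, then import into Relax; each entry of
    # input_info pairs an input shape with its dtype.
    graph_model = fx.symbolic_trace(Model())
    mod = from_fx(graph_model, [([1, 3, 10, 10], "float32")])
    mod.show()  # should print IR equivalent to the expected1 module in the test below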
10 changes: 10 additions & 0 deletions python/tvm/relax/frontend/torch/fx_translator.py
@@ -243,6 +243,14 @@ def _gelu(self, node: fx.node.Node) -> relax.Expr:
        else:
            raise KeyError("Unrecognized approximate algorithm for gelu: {}.".format(approximate))

    def _hardsigmoid(self, node: fx.node.Node) -> relax.Var:
        # hardsigmoid(x) = clip(x + 3, 0, 6) / 6
        args = self.retrieve_args(node)
        x = args[0]
        dtype = x.struct_info.dtype
        x0 = relax.op.add(x, relax.const(3, dtype))
        x1 = relax.op.clip(x0, 0, 6)
        return self.block_builder.emit(relax.op.divide(x1, relax.const(6, dtype)))

    def _hardswish(self, node: fx.node.Node) -> relax.Var:
        args = self.retrieve_args(node)
        x = args[0]
@@ -1367,6 +1375,7 @@ def create_convert_map(self):
            nn.Sigmoid: self._sigmoid,
            nn.Tanh: lambda node: self.block_builder.emit(relax.op.tanh(self.env[node.args[0]])),
            nn.SiLU: lambda node: self.block_builder.emit(relax.op.nn.silu(self.env[node.args[0]])),
            nn.Hardsigmoid: self._hardsigmoid,
            nn.Hardswish: self._hardswish,
            nn.Flatten: self._flatten,
            nn.BatchNorm2d: self._batch_norm_2d,
@@ -1447,6 +1456,7 @@ def create_convert_map(self):
"leaky_relu": self._leakyrelu,
"gelu": self._gelu,
"silu": lambda node: self.block_builder.emit(relax.op.nn.silu(self.env[node.args[0]])),
"hardsigmoid": self._hardsigmoid,
"hardswish": self._hardswish,
"interpolate": self._interpolate,
"size": self._size,
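For reference, the decomposition emitted by _hardsigmoid above follows PyTorch's definition hardsigmoid(x) = clip(x + 3, 0, 6) / 6. A quick numerical sanity check (illustrative only, not part of this commit):

    import torch

    # Compare the manual decomposition against torch's own implementation.
    x = torch.linspace(-5.0, 5.0, steps=11)
    manual = torch.clip(x + 3.0, 0.0, 6.0) / 6.0
    assert torch.allclose(torch.nn.functional.hardsigmoid(x), manual)
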
35 changes: 35 additions & 0 deletions tests/python/relax/test_frontend_from_fx.py
@@ -1416,6 +1416,41 @@ def main(
    verify_model(SiLU2(), input_info, {}, expected1)


def test_hardsigmoid():
    input_info = [([1, 3, 10, 10], "float32")]

    class Hardsigmoid(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.hs = torch.nn.Hardsigmoid()

        def forward(self, input):
            return self.hs(input)

    class Hardsigmoid2(torch.nn.Module):
        def forward(self, input):
            return torch.nn.functional.hardsigmoid(input)

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            inp_0: R.Tensor((1, 3, 10, 10), dtype="float32")
        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
            with R.dataflow():
                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(inp_0, R.const(3, "float32"))
                lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(lv, 0, 6)
                lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide(
                    lv1, R.const(6, "float32")
                )
                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv2
                R.output(gv)
            return gv

    verify_model(Hardsigmoid(), input_info, {}, expected1)
    verify_model(Hardsigmoid2(), input_info, {}, expected1)


def test_hardswish():
    input_info = [([1, 3, 10, 10], "float32")]

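The new test can be run in isolation with pytest tests/python/relax/test_frontend_from_fx.py -k test_hardsigmoid (assuming a development build of TVM that includes this change).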
