diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 8c1cf8009435..de0d67d228d4 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -64,6 +64,26 @@ def _reciprocal(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] return self.block_builder.emit(relax.op.divide(relax.const(1.0, x.struct_info.dtype), x)) + def _sqrt(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + dtype = x.struct_info.dtype + + # Check if input is integer type and convert to float32 if needed + if dtype in ("int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64"): + x = self.block_builder.emit(relax.op.astype(x, "float32")) + + return self.block_builder.emit(relax.op.sqrt(x)) + + def _rsqrt(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + dtype = x.struct_info.dtype + + # Check if input is integer type and convert to float32 if needed + if dtype in ("int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64"): + x = self.block_builder.emit(relax.op.astype(x, "float32")) + + return self.block_builder.emit(relax.op.rsqrt(x)) + ########## Neural Network ########## def _batch_norm(self, node: fx.Node, training: bool) -> relax.Var: @@ -872,7 +892,7 @@ def create_convert_map( "relu6.default": self._unary_op(relax.op.nn.relu6), "relu6_.default": self._unary_op(relax.op.nn.relu6), "round.default": self._round, - "rsqrt.default": self._unary_op(relax.op.rsqrt), + "rsqrt.default": self._rsqrt, "scalar_tensor.default": self._scalar_tensor, "rsub.Tensor": self._rsub, "rsub.Scalar": self._rsub, @@ -888,7 +908,7 @@ def create_convert_map( "softplus.default": self._softplus, "softshrink.default": self._softshrink, "softsign.default": self._softsign, - "sqrt.default": self._unary_op(relax.op.sqrt), + "sqrt.default": self._sqrt, "square.default": self._unary_op(relax.op.square), "tan.default": self._unary_op(relax.op.tan), "tanh.default": self._unary_op(relax.op.tanh), diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 0d2e240be641..a93f78866910 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -96,6 +96,26 @@ def _log1p(self, node: fx.Node) -> relax.Var: one = relax.const(1, x.struct_info.dtype) return self.block_builder.emit(relax.op.log(relax.op.add(x, one))) + def _sqrt(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + dtype = x.struct_info.dtype + + # Check if input is integer type and convert to float32 if needed + if dtype in ["int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64"]: + x = self.block_builder.emit(relax.op.astype(x, "float32")) + + return self.block_builder.emit(relax.op.sqrt(x)) + + def _rsqrt(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + dtype = x.struct_info.dtype + + # Check if input is integer type and convert to float32 if needed + if dtype in ["int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64"]: + x = self.block_builder.emit(relax.op.astype(x, "float32")) + + return self.block_builder.emit(relax.op.rsqrt(x)) + def _log_softmax_module(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] @@ -825,7 +845,7 @@ def create_convert_map( "relu": self._unary_op(relax.op.nn.relu), "relu6": self._unary_op(relax.op.nn.relu6), "round": self._round, - "rsqrt": self._unary_op(relax.op.rsqrt), + "rsqrt": self._rsqrt, "selu": self._unary_op(relax.op.nn.selu), "sigmoid": self._unary_op(relax.op.sigmoid), "sign": self._unary_op(relax.op.sign), @@ -834,7 +854,7 @@ def create_convert_map( "sinh": self._unary_op(relax.op.sinh), "softmax": self._softmax, "softplus": self._softplus, - "sqrt": self._unary_op(relax.op.sqrt), + "sqrt": self._sqrt, "square": self._unary_op(relax.op.square), "tan": self._unary_op(relax.op.tan), "tanh": self._unary_op(relax.op.tanh), diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 4bf041710801..986100391fa8 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -126,6 +126,47 @@ def main( verify_model(UnaryOp(), example_args, {}, expected, run_ep_decomposition=True) +def test_sqrt_integer_input(): + """Test that sqrt operation works with integer tensors by auto-converting to float.""" + example_args = (torch.tensor([[4, 9, 16, 25]], dtype=torch.int64),) + + class SqrtIntModel(Module): + def forward(self, input): + return torch.sqrt(input) + + @tvm.script.ir_module + class expected_int64: + @R.function + def main( + input_1: R.Tensor((1, 4), dtype="int64") + ) -> R.Tuple(R.Tensor((1, 4), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((1, 4), dtype="float32") = R.astype(input_1, dtype="float32") + lv1: R.Tensor((1, 4), dtype="float32") = R.sqrt(lv) + gv: R.Tuple(R.Tensor((1, 4), dtype="float32")) = (lv1,) + R.output(gv) + return gv + + verify_model(SqrtIntModel(), example_args, {}, expected_int64, run_ep_decomposition=True) + + example_args_int32 = (torch.tensor([[1, 4, 9]], dtype=torch.int32),) + + @tvm.script.ir_module + class expected_int32: + @R.function + def main( + input_1: R.Tensor((1, 3), dtype="int32") + ) -> R.Tuple(R.Tensor((1, 3), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((1, 3), dtype="float32") = R.astype(input_1, dtype="float32") + lv1: R.Tensor((1, 3), dtype="float32") = R.sqrt(lv) + gv: R.Tuple(R.Tensor((1, 3), dtype="float32")) = (lv1,) + R.output(gv) + return gv + + verify_model(SqrtIntModel(), example_args_int32, {}, expected_int32, run_ep_decomposition=True) + + def test_extended_unary_ops(): example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 69ebdcbf76bc..d377bb7574df 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -2749,6 +2749,27 @@ def main( verify_model(Unary(), input_info, {}, expected_unary) +def test_sqrt_integer_input_fx(): + input_info = [([1, 4], "int64")] + + class SqrtIntModel(Module): + def forward(self, input): + return torch.sqrt(input) + + @tvm.script.ir_module + class expected: + @R.function + def main(input_1: R.Tensor((1, 4), dtype="int64")) -> R.Tensor((1, 4), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((1, 4), dtype="float32") = R.astype(input_1, dtype="float32") + lv1: R.Tensor((1, 4), dtype="float32") = R.sqrt(lv) + gv: R.Tensor((1, 4), dtype="float32") = lv1 + R.output(gv) + return gv + + verify_model(SqrtIntModel(), input_info, {}, expected) + + operator_bool_unary = [ (torch.isnan, R.isnan), (torch.isinf, R.isinf),