diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 753b0d791495..ed7811dd7102 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -2357,6 +2357,12 @@ def _item(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         return self.block_builder.emit(relax.op.take(x, relax.const(0, "int64"), axis=0))
 
+    def _sym_size_int(self, node: fx.Node) -> relax.Expr:
+        x = self.env[node.args[0]]
+        shape = self.shape_of(x)
+        dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0)
+        return self.block_builder.emit(relax.const(int(shape[dim]), "int32"))
+
     def _zeros_inplace(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         output = self.block_builder.emit(relax.op.zeros_like(x))
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index a2b9b2afa4cf..782c14e91cbd 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -1189,6 +1189,7 @@ def create_convert_map(
             # other
             "getitem": self._getitem,
             "item.default": self._item,
+            "sym_size.int": self._sym_size_int,
             "_local_scalar_dense.default": self._item,
         }
 
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index a93f78866910..6bf164430ad0 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -730,12 +730,6 @@ def _getattr(self, node: fx.Node) -> relax.Var:
             return self.shape_of(self.env[node.args[0]])
         return getattr(self.env[node.args[0]], node.args[1])
 
-    def _sym_size_int(self, node: fx.Node) -> relax.Expr:
-        x = self.env[node.args[0]]
-        shape = self.shape_of(x)
-        idx = node.args[1]
-        return self.block_builder.emit(relax.const(shape[idx].value, "int32"))
-
     def create_input_vars(self, input_info: List[Tuple[Tuple[int], str]]) -> List[relax.Var]:
         inputs = list()
         for idx, (shape, dtype) in enumerate(input_info):
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 1429dec5e731..60a91204453a 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -7508,5 +7508,36 @@ def main(
     tvm.ir.assert_structural_equal(mod, Expected, map_free_vars=True)
 
 
+def test_sym_size_int():
+    class SymSizeInt(Module):
+        def __init__(self, dim):
+            super().__init__()
+            self.dim = dim
+
+        def forward(self, x):
+            # TODO(@mshr-h): `torch.ops.aten.sym_size.int(x, self.dim)` would be ideal,
+            # but currently the ep frontend is not able to handle it.
+            return torch.add(x[0], torch.ops.aten.sym_size.int(x, self.dim))
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((1, 3, 4), dtype="float32")
+        ) -> R.Tuple(R.Tensor((3, 4), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((3, 4), dtype="float32") = R.take(
+                    x, R.const(0, "int64"), axis=0, mode="fast"
+                )
+                lv1: R.Tensor((3, 4), dtype="float32") = R.add(lv, R.const(3.0, "float32"))
+                gv: R.Tuple(R.Tensor((3, 4), dtype="float32")) = (lv1,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 3, 4),)
+    verify_model(SymSizeInt(dim=1), example_args, {}, Expected)
+    verify_model(SymSizeInt(dim=-2), example_args, {}, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
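
A minimal end-to-end sketch of how the new mapping is exercised, assuming a TVM build with the Relax PyTorch frontend available; the `SizeAdd` module and inputs below are illustrative, not part of the patch. Calling the aten overload directly keeps the `sym_size.int` node in the exported graph, whereas a plain `x.size(1)` is typically specialized to a constant during export when the shape is static.

import torch
from tvm.relax.frontend.torch import from_exported_program


class SizeAdd(torch.nn.Module):  # illustrative module, mirrors the test above
    def forward(self, x):
        # Keep an explicit aten.sym_size.int node in the exported graph.
        return torch.add(x[0], torch.ops.aten.sym_size.int(x, 1))


ep = torch.export.export(SizeAdd(), (torch.randn(1, 3, 4),))
mod = from_exported_program(ep)
# _sym_size_int folds the queried dimension into a constant, so the size
# lookup should appear as a constant (here 3) in the printed Relax module.
mod.show()

Because the converter emits `relax.const(int(shape[dim]), "int32")`, it only covers dimensions that are static at import time; a symbolic dimension would fail the `int()` conversion.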