From 0af0d342a0c6f0e4a334a63b7caaebb6b6239102 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 26 Oct 2023 11:10:51 -0500 Subject: [PATCH] [Unity][BlockBuilder] Allow emitting nested tuple Prior to this commit, all expressions needed to be explicitly converted from native python types to relax expressions before being emitted by the block builder. This commit relaxes that requirement, and allows python tuples to be converted to relax tuples. --- python/tvm/relax/block_builder.py | 24 +++++++-- tests/python/relax/test_blockbuilder_core.py | 57 +++++++++++++++++++- 2 files changed, 75 insertions(+), 6 deletions(-) diff --git a/python/tvm/relax/block_builder.py b/python/tvm/relax/block_builder.py index 331a3905f321..5bf36ce33034 100644 --- a/python/tvm/relax/block_builder.py +++ b/python/tvm/relax/block_builder.py @@ -17,7 +17,7 @@ # pylint: disable=no-else-return, invalid-name, unused-argument """Developer API of constructing Relax AST.""" -from typing import Dict, List, Optional, Union, Any, Callable +from typing import Dict, List, Optional, Union, Any, Callable, Sequence from tvm.ir.module import IRModule from tvm.runtime import Object from tvm import relax as rx, tir @@ -286,6 +286,21 @@ def dataflow(self) -> DataflowScope: """ return DataflowScope(self) + def _normalize_python_tuple(self, expr: Union[Expr, Sequence[Expr]]): + """Internal utility function to convert to relax.Tuple + + The `emit`, `emit_output`, and `emit_func_output` can be + called with python `list` or `tuple` objects. These objects + should be converted to `relax.Tuple` prior to calling an FFI + function, as they would otherwise be converted to + `tvm.runtime.Array`. In addition, any nested tuple objects + should be converted. + """ + if isinstance(expr, (list, tuple)): + return Tuple([self._normalize_python_tuple(element) for element in expr]) + else: + return expr + def emit(self, expr: Expr, name_hint: str = "") -> Var: """Emit an expr. This infers the shape and type of the expr, create a variable, @@ -304,6 +319,7 @@ def emit(self, expr: Expr, name_hint: str = "") -> Var: ret : tvm.relax.Var A newly created variable that gets bound to the input expr. """ + expr = self._normalize_python_tuple(expr) return _ffi_api.BlockBuilderEmit(self, expr, name_hint) # type: ignore def call_te(self, func: Callable, *args: Any, **kwargs: Any) -> Expr: @@ -557,8 +573,7 @@ def emit_output(self, output: Union[Expr, Tuple, List[Expr]], name_hint: str = " ret : tvm.relax.Var The return variable which gets bound to the output. """ - if isinstance(output, (list, tuple)): - output = Tuple(output) + output = self._normalize_python_tuple(output) return _ffi_api.BlockBuilderEmitOutput(self, output, name_hint) # type: ignore def emit_func_output( @@ -601,8 +616,7 @@ def emit_func_output( if BlockBuilder.current() is not self: raise RuntimeError("BlockBuilder.current() must be self.") - if isinstance(output, (list, tuple)): - output = Tuple(output) + output = self._normalize_python_tuple(output) block = self._end_block() if len(block.bindings) > 0: diff --git a/tests/python/relax/test_blockbuilder_core.py b/tests/python/relax/test_blockbuilder_core.py index 4ba25bdffc34..f09b7bc6f538 100644 --- a/tests/python/relax/test_blockbuilder_core.py +++ b/tests/python/relax/test_blockbuilder_core.py @@ -24,7 +24,7 @@ from tvm import relax as rx, relay from tvm.ir.base import assert_structural_equal from tvm.relax import ExternFunc -from tvm.script import relax as R +from tvm.script import relax as R, tir as T from tvm.tir.function import PrimFunc @@ -628,5 +628,60 @@ def test_block_builder_scope_recovery(): bb.emit_func_output(gv0) +@pytest.mark.parametrize("emit_nested_tuple", [True, False]) +def test_emit_nested_tuple(emit_nested_tuple): + """Convert nested tuples when emitting relax""" + + def make_function(emit_nested_tuple: bool): + bb = rx.BlockBuilder() + + n_sym = tir.Var("n", "int64") + m_sym = tir.Var("m", "int64") + n = rx.Var("n", rx.PrimStructInfo(value=n_sym)) + m = rx.Var("m", rx.PrimStructInfo(value=m_sym)) + x = rx.Var("x", rx.TensorStructInfo([n_sym, m_sym], "float32")) + y = rx.Var("y", rx.TensorStructInfo([m_sym, n_sym], "float32")) + + with bb.function("func", [n, m, x, y]): + scalars = (n, m) + if not emit_nested_tuple: + scalars = bb.emit(scalars) + output = (scalars, x, y) + bb.emit_func_output(output) + + return bb.get()["func"] + + def make_expected(emit_nested_tuple: bool): + if emit_nested_tuple: + + @R.function + def func( + n_1: R.Prim(value="n"), + m_1: R.Prim(value="m"), + x: R.Tensor(("n", "m"), dtype="float32"), + y: R.Tensor(("m", "n"), dtype="float32"), + ): + return ((n_1, m_1), x, y) + + else: + + @R.function + def func( + n_1: R.Prim(value="n"), + m_1: R.Prim(value="m"), + x: R.Tensor(("n", "m"), dtype="float32"), + y: R.Tensor(("m", "n"), dtype="float32"), + ): + gv = n_1, m_1 + return (gv, x, y) + + return func + + expected = make_expected(emit_nested_tuple) + actual = make_function(emit_nested_tuple) + + tvm.ir.assert_structural_equal(expected, actual) + + if __name__ == "__main__": tvm.testing.main()