Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 19 additions & 5 deletions python/tvm/relax/block_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# pylint: disable=no-else-return, invalid-name, unused-argument
"""Developer API of constructing Relax AST."""

from typing import Dict, List, Optional, Union, Any, Callable
from typing import Dict, List, Optional, Union, Any, Callable, Sequence
from tvm.ir.module import IRModule
from tvm.runtime import Object
from tvm import relax as rx, tir
Expand Down Expand Up @@ -286,6 +286,21 @@ def dataflow(self) -> DataflowScope:
"""
return DataflowScope(self)

def _normalize_python_tuple(self, expr: Union[Expr, Sequence[Expr]]):
"""Internal utility function to convert to relax.Tuple

The `emit`, `emit_output`, and `emit_func_output` can be
called with python `list` or `tuple` objects. These objects
should be converted to `relax.Tuple` prior to calling an FFI
function, as they would otherwise be converted to
`tvm.runtime.Array`. In addition, any nested tuple objects
should be converted.
"""
if isinstance(expr, (list, tuple)):
return Tuple([self._normalize_python_tuple(element) for element in expr])
else:
return expr

def emit(self, expr: Expr, name_hint: str = "") -> Var:
"""Emit an expr.
This infers the shape and type of the expr, create a variable,
Expand All @@ -304,6 +319,7 @@ def emit(self, expr: Expr, name_hint: str = "") -> Var:
ret : tvm.relax.Var
A newly created variable that gets bound to the input expr.
"""
expr = self._normalize_python_tuple(expr)
return _ffi_api.BlockBuilderEmit(self, expr, name_hint) # type: ignore

def call_te(self, func: Callable, *args: Any, **kwargs: Any) -> Expr:
Expand Down Expand Up @@ -557,8 +573,7 @@ def emit_output(self, output: Union[Expr, Tuple, List[Expr]], name_hint: str = "
ret : tvm.relax.Var
The return variable which gets bound to the output.
"""
if isinstance(output, (list, tuple)):
output = Tuple(output)
output = self._normalize_python_tuple(output)
return _ffi_api.BlockBuilderEmitOutput(self, output, name_hint) # type: ignore

def emit_func_output(
Expand Down Expand Up @@ -601,8 +616,7 @@ def emit_func_output(
if BlockBuilder.current() is not self:
raise RuntimeError("BlockBuilder.current() must be self.")

if isinstance(output, (list, tuple)):
output = Tuple(output)
output = self._normalize_python_tuple(output)

block = self._end_block()
if len(block.bindings) > 0:
Expand Down
57 changes: 56 additions & 1 deletion tests/python/relax/test_blockbuilder_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from tvm import relax as rx, relay
from tvm.ir.base import assert_structural_equal
from tvm.relax import ExternFunc
from tvm.script import relax as R
from tvm.script import relax as R, tir as T
from tvm.tir.function import PrimFunc


Expand Down Expand Up @@ -628,5 +628,60 @@ def test_block_builder_scope_recovery():
bb.emit_func_output(gv0)


@pytest.mark.parametrize("emit_nested_tuple", [True, False])
def test_emit_nested_tuple(emit_nested_tuple):
"""Convert nested tuples when emitting relax"""

def make_function(emit_nested_tuple: bool):
bb = rx.BlockBuilder()

n_sym = tir.Var("n", "int64")
m_sym = tir.Var("m", "int64")
n = rx.Var("n", rx.PrimStructInfo(value=n_sym))
m = rx.Var("m", rx.PrimStructInfo(value=m_sym))
x = rx.Var("x", rx.TensorStructInfo([n_sym, m_sym], "float32"))
y = rx.Var("y", rx.TensorStructInfo([m_sym, n_sym], "float32"))

with bb.function("func", [n, m, x, y]):
scalars = (n, m)
if not emit_nested_tuple:
scalars = bb.emit(scalars)
output = (scalars, x, y)
bb.emit_func_output(output)

return bb.get()["func"]

def make_expected(emit_nested_tuple: bool):
if emit_nested_tuple:

@R.function
def func(
n_1: R.Prim(value="n"),
m_1: R.Prim(value="m"),
x: R.Tensor(("n", "m"), dtype="float32"),
y: R.Tensor(("m", "n"), dtype="float32"),
):
return ((n_1, m_1), x, y)

else:

@R.function
def func(
n_1: R.Prim(value="n"),
m_1: R.Prim(value="m"),
x: R.Tensor(("n", "m"), dtype="float32"),
y: R.Tensor(("m", "n"), dtype="float32"),
):
gv = n_1, m_1
return (gv, x, y)

return func

expected = make_expected(emit_nested_tuple)
actual = make_function(emit_nested_tuple)

tvm.ir.assert_structural_equal(expected, actual)


if __name__ == "__main__":
tvm.testing.main()