Skip to content

Commit

Permalink
[Unity] Check for symbolic vars in PrimValue in when lowering to TIR (#…
Browse files Browse the repository at this point in the history
…16564)

Prior to this commit, a fused relax function could accept a `R.Prim`
value, but wouldn't use it to provide symbolic variables to the fused
function.
  • Loading branch information
Lunderberg committed Feb 14, 2024
1 parent 685355e commit c5aaa99
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 2 deletions.
7 changes: 5 additions & 2 deletions python/tvm/relax/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,8 +370,11 @@ def _convert_te_arg_helper(arg):
arg, ShapeExpr
), "For Expr having ShapeStructInfo, emit_te now only supports ShapeExpr"
return [_convert_te_arg_helper(val) for val in arg.values]
if isinstance(arg.struct_info, PrimStructInfo):
return arg.value
if (
isinstance(arg.struct_info, PrimStructInfo)
and arg.struct_info.value is not None
):
return _convert_te_arg_helper(arg.struct_info.value)
elif isinstance(arg, (list, Array)):
return [_convert_te_arg_helper(x) for x in arg]
elif isinstance(arg, tuple):
Expand Down
54 changes: 54 additions & 0 deletions tests/python/relax/test_blockbuilder_emit_te.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
""" This file tests advanced emit_te features with help of TVMScript assertion"""
# The tests here depend on tvmscript
import tvm
from tvm import te, tir
from tvm import relax as rx
from tvm.ir.base import assert_structural_equal
Expand Down Expand Up @@ -69,3 +70,56 @@ def main(
return gv

assert_structural_equal(after, Expected)


def test_symbolic_shape_in_prim_value():
"""Symbolic vars may be provided to TE in R.Prim"""

def te_slice(tensor, i):
return tvm.te.compute([tensor.shape[1]], lambda j: tensor[i, j], name="slice")

def from_builder():
bb = rx.BlockBuilder()
A = rx.Var("A", R.Tensor([16, 16], "float32"))
tir_i = tvm.tir.Var("tir_i", "int64")
relax_i = rx.Var("relax_i", R.Prim(value=tir_i))

with bb.function("main", params=[A, relax_i]):
A_sliced = bb.emit_te(te_slice, A, relax_i)
bb.emit_func_output(A_sliced)

return bb.get()

@I.ir_module
class Expected:
@T.prim_func(private=True)
def te_slice(
A: T.Buffer([T.int64(16), T.int64(16)], "float32"),
Output: T.Buffer(T.int64(16), "float32"),
row_index: T.int64,
):
T.func_attr({"tir.noalias": T.bool(True)})

for i in range(A.shape[1]):
with T.block("slice"):
vi = T.axis.remap("S", [i])
Output[vi] = A[row_index, vi]

@R.function
def main(
A: R.Tensor([16, 16], "float32"),
arg_row_index: R.Prim(value="row_index"),
):
cls = Expected

row_index = T.int64()

gv = R.call_tir(
cls.te_slice,
A,
tir_vars=[row_index],
out_sinfo=R.Tensor([16], "float32"),
)
return gv

tvm.ir.assert_structural_equal(from_builder(), Expected)

0 comments on commit c5aaa99

Please sign in to comment.