Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Unity][TVMScript] Parse R.Object return type from call_pure_packed #16593

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
13 changes: 13 additions & 0 deletions python/tvm/relax/op/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import tvm
import tvm.runtime
from tvm.runtime.object import Object
from tvm.runtime import ObjectGeneric

from . import _ffi_api
from ..expr import Expr, StringImm, ShapeExpr, Call, ExternFunc, GlobalVar, Var
Expand Down Expand Up @@ -709,12 +710,24 @@ def call_pure_packed(
func = func.global_symbol

op = ExternFunc(func)

if sinfo_args is None:
raise ValueError("R.call_pure_packed is required to have type_args")

if isinstance(sinfo_args, tuple): # type: ignore
sinfo_args = list(sinfo_args)
elif not isinstance(sinfo_args, list):
sinfo_args = [sinfo_args]

sinfo_args = [
sinfo()
if callable(sinfo)
else sinfo.asobject()
if isinstance(sinfo, ObjectGeneric)
else sinfo
for sinfo in sinfo_args
]

# note: if we need attributes, we can also take them here

return _ffi_api.call_pure_packed(op, args, None, sinfo_args) # type: ignore # pylint: disable=no-member
Expand Down
16 changes: 9 additions & 7 deletions python/tvm/script/ir_builder/relax/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,13 +357,15 @@ def call_packed(
sinfo_args = list(sinfo_args)
elif not isinstance(sinfo_args, list):
sinfo_args = [sinfo_args]
for i, sinfo_arg in enumerate(sinfo_args):
if callable(sinfo_arg):
sinfo_arg = sinfo_arg()
# Convert possible StructInfoProxy to StructInfo
if isinstance(sinfo_arg, ObjectGeneric):
sinfo_arg = sinfo_arg.asobject()
sinfo_args[i] = sinfo_arg

sinfo_args = [
sinfo()
if callable(sinfo)
else sinfo.asobject()
if isinstance(sinfo, ObjectGeneric)
else sinfo
for sinfo in sinfo_args
]

is_default = False
if "attrs_type_key" in kwargs:
Expand Down
14 changes: 14 additions & 0 deletions tests/python/relax/test_tvmscript_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -1800,6 +1800,20 @@ def foo(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
_check(foo, bb.get()["foo"])


def test_call_pure_packed_returning_object():
@R.function
def foo() -> R.Object:
z = R.call_pure_packed("dummy_func", sinfo_args=R.Object)
return z

bb = relax.BlockBuilder()
with bb.function("foo", params=[]):
z = bb.emit(R.call_pure_packed("dummy_func", sinfo_args=[relax.ObjectStructInfo()]))
bb.emit_func_output(z)

_check(foo, bb.get()["foo"])


def test_private_function():
@I.ir_module
class Addition:
Expand Down