In [98]:
import tvm

import tvm.te as te
import tvm.relax as rx
import tvm.tir as tir

from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R

from tvm.relax.binding_rewrite import DataflowBlockRewrite
from tvm.relax.analysis import name_to_binding


def showmod(mod: tvm.ir.module.IRModule):
    """Print ``mod`` as TVMScript via ``IRModule.show``.

    Options used: black formatting on, metadata hidden, verbose expressions,
    object addresses hidden, and all struct info shown.
    """
    display_options = dict(
        black_format=True,
        show_meta=False,
        verbose_expr=True,
        show_object_address=False,
        show_all_struct_info=True,
    )
    mod.show(**display_options)


def createandshowmod(ops, name: str = "test"):
    """Build a PrimFunc from TE ops, wrap it in an IRModule, and print it.

    Parameters
    ----------
    ops :
        Passed unchanged to ``te.create_prim_func`` (presumably the list of
        TE tensors forming the function's signature -- TODO confirm).
    name : str, optional
        Used both as the PrimFunc's ``global_symbol`` attribute and as the
        IRModule key, so the two always agree. Defaults to ``"test"``, the
        value that was previously hard-coded.
    """
    te_func = te.create_prim_func(ops).with_attrs({"global_symbol": name})
    mod = tvm.IRModule({name: te_func})
    showmod(mod)


from tvm.relax import Expr, PrimValue, StringImm, Tuple
from typing import List
from tvm.relax.type_converter import args_converter

In [99]:
# NOTE `_ArgsConverter` (used below through the `args_converter` object) is a
# class that automates converting arguments given as native Python/TVM values
# into TVM Relax `Expr` or `List[Expr]`. Conversion rules:

#   tvm.PrimExpr            -> relax.PrimValue
#   tvm.String or str       -> relax.StringImm
#   tuple/list of PrimExpr  -> relax.Tuple


def test_args_to_expr(prim_value: PrimValue, string_imm: StringImm, tuple: Tuple):
    """Assert each argument is already the expected Relax node, then print them.

    Calling this directly with plain Python values fails the assertions; it
    is meant to be wrapped by ``args_converter.to_expr`` so the conversion
    happens before the body runs.
    """
    expected_kinds = (
        (prim_value, PrimValue),
        (string_imm, StringImm),
        (tuple, Tuple),
    )
    # Checked in declaration order, so the first mismatch is the one reported.
    for value, kind in expected_kinds:
        assert isinstance(value, kind)
    print(f"prim_value: {prim_value}, string_imm: {string_imm}, tuple: {tuple}")


# Plain Python values that have not been converted to Relax expressions yet.
# NOTE(review): the global `tuple` shadows the built-in `tuple` for the rest of
# the notebook; it is kept because later cells refer to it by this name.
prim_value, string_imm, tuple = 1, "hello", (1, 2, 3)

# Passing these straight to the unwrapped function would fail its isinstance
# assertions (AssertionError): nothing converts them to relax.PrimValue,
# relax.StringImm, and relax.Tuple.
# test_args_to_expr(prim_value, string_imm, tuple)

In [100]:
print(f"prim_value: {prim_value}, string_imm: {string_imm}, tuple: {tuple}")

# Apply the converter by hand: `args_converter.to_expr(...)` takes the names of
# the parameters to convert and returns a decorator. The resulting wrapper
# turns those Python arguments into relax.PrimValue, relax.StringImm, and
# relax.Tuple before calling `test_args_to_expr`.
test_args_to_expr2 = args_converter.to_expr(
    "prim_value",
    "string_imm",
    "tuple",
)(test_args_to_expr)

test_args_to_expr2(prim_value, string_imm, tuple)

prim_value: 1, string_imm: hello, tuple: (1, 2, 3)
prim_value: R.prim_value(1), string_imm: R.str("hello"), tuple: (R.prim_value(1), R.prim_value(2), R.prim_value(3))


In [101]:
print(f"prim_value: {prim_value}, string_imm: {string_imm}, tuple: {tuple}")


# The same conversion as wrapping by hand, expressed with decorator syntax.
@args_converter.to_expr("prim_value", "string_imm", "tuple")
def test_args_to_expr_decorator(
    prim_value: PrimValue, string_imm: StringImm, tuple: Tuple
):
    """Assert the decorator delivered Relax nodes, then print them."""
    for value, expected_type in zip(
        (prim_value, string_imm, tuple),
        (PrimValue, StringImm, Tuple),
    ):
        assert isinstance(value, expected_type)
    print(f"prim_value: {prim_value}, string_imm: {string_imm}, tuple: {tuple}")


test_args_to_expr_decorator(prim_value, string_imm, tuple)

prim_value: 1, string_imm: hello, tuple: (1, 2, 3)
prim_value: R.prim_value(1), string_imm: R.str("hello"), tuple: (R.prim_value(1), R.prim_value(2), R.prim_value(3))


In [102]:
@args_converter.to_list_expr("prim_value", "string_imm", "tuple")
def test_args_to_list_expr(
    prim_value: List[PrimValue], string_imm: List[StringImm], tuple: List[Tuple]
):
    """Assert each named argument arrived as a Python list of the expected Relax node.

    ``args_converter.to_list_expr`` converts every element of the named
    arguments before this body runs.
    """
    # Check against the built-in `list` (`typing.List` is an annotation alias,
    # deprecated since Python 3.9) and pass generators so `all` stops early.
    assert isinstance(prim_value, list) and all(
        isinstance(arg, PrimValue) for arg in prim_value
    )
    assert isinstance(string_imm, list) and all(
        isinstance(arg, StringImm) for arg in string_imm
    )
    assert isinstance(tuple, list) and all(isinstance(arg, Tuple) for arg in tuple)
    print(f"prim_value: {prim_value}, \nstring_imm: {string_imm}, \ntuple: {tuple}")


# Test list of arguments
# NOTE(review): this rebinds the notebook-level `prim_value`, `string_imm` and
# `tuple`; re-running earlier cells after this one will see these lists.
prim_value = [1, 2, 3]
string_imm = ["hello", "world"]
tuple = [(1, 2, 3), (4, 5, 6)]
test_args_to_list_expr(prim_value, string_imm, tuple)

prim_value: [R.prim_value(1), R.prim_value(2), R.prim_value(3)], 
string_imm: [R.str("hello"), R.str("world")], 
tuple: [(R.prim_value(1), R.prim_value(2), R.prim_value(3)), (R.prim_value(4), R.prim_value(5), R.prim_value(6))]


In [103]:
@args_converter.auto
def test_auto_to_list_expr(
    prim_value: List[Expr], string_imm: List[Expr], tuple: List[Expr]
):
    """Assert `auto` converted every element of the list arguments, then print them.

    The ``List[Expr]`` annotations are what ``args_converter.auto`` relies on
    to decide what to convert, so they must stay as written.
    """
    # Built-in `list` instead of the `typing.List` alias, and generators so
    # `all` short-circuits without building a temporary list.
    assert isinstance(prim_value, list) and all(
        isinstance(arg, PrimValue) for arg in prim_value
    )
    assert isinstance(string_imm, list) and all(
        isinstance(arg, StringImm) for arg in string_imm
    )
    assert isinstance(tuple, list) and all(isinstance(arg, Tuple) for arg in tuple)

    print(prim_value)
    print(string_imm)
    print(tuple)


test_auto_to_list_expr(prim_value, string_imm, tuple)

[R.prim_value(1), R.prim_value(2), R.prim_value(3)]
[R.str("hello"), R.str("world")]
[(R.prim_value(1), R.prim_value(2), R.prim_value(3)), (R.prim_value(4), R.prim_value(5), R.prim_value(6))]


In [104]:
# NOTE `args_converter.auto` needs no list of argument names: it relies on the
# parameter annotations instead, so every parameter to be converted must be
# annotated as `Expr` or `List[Expr]`.
@args_converter.auto
def test_auto_to_expr(prim_value: Expr, string_imm: Expr, tuple: Expr):
    """Assert `auto` converted each `Expr`-annotated argument, then print them."""
    assert_pairs = [
        (prim_value, PrimValue),
        (string_imm, StringImm),
        (tuple, Tuple),
    ]
    for converted, node_type in assert_pairs:
        assert isinstance(converted, node_type)
    print(f"prim_value: {prim_value}, string_imm: {string_imm}, tuple: {tuple}")


test_auto_to_expr(1, "abc", (1, 2, 3))

prim_value: R.prim_value(1), string_imm: R.str("abc"), tuple: (R.prim_value(1), R.prim_value(2), R.prim_value(3))
