In [None]:
import tvm

import tvm.te as te
import tvm.relax as rx
import tvm.tir as tir

from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R

from tvm.relax.binding_rewrite import DataflowBlockRewrite
from tvm.relax.analysis import name_to_binding


def showmod(mod: "tvm.runtime.Object", show_meta=False, **show_kwargs):
    """Pretty-print a TVM object as TVMScript using this notebook's default options.

    Despite the parameter name, ``mod`` is not limited to an ``IRModule``: the
    cells in this notebook also pass relax vars, bindings, binding blocks,
    sequence expressions, functions, shape/prim values and struct info. Any
    object exposing a ``show(**kwargs)`` method can be printed.

    Args:
        mod: The object to print; it must provide a ``show`` method.
        show_meta: Forwarded to ``show`` (whether to print the object's metadata).
        **show_kwargs: Extra keyword options forwarded to ``show``. They override
            the defaults below, e.g. ``showmod(mod, print_line_numbers=True)``.
    """
    options = dict(
        black_format=True,
        show_meta=show_meta,
        verbose_expr=True,
        show_object_address=False,
        show_all_struct_info=True,
    )
    # Caller-supplied options win over the notebook-wide defaults.
    options.update(show_kwargs)
    mod.show(**options)


def createandshowmod(ops, name="test"):
    """Build a PrimFunc from TE operands, wrap it in an IRModule and print it.

    Args:
        ops: The TE tensors/vars handed to ``te.create_prim_func`` as its
            argument list.
        name: Used both as the function's ``global_symbol`` attribute and as the
            key under which it is stored in the module, so the two always agree.
            Defaults to ``"test"``.
    """
    te_func = te.create_prim_func(ops).with_attrs({"global_symbol": name})
    mod = tvm.IRModule({name: te_func})
    showmod(mod)

In [None]:
testSuites = {}


def register(func):
    """Decorator that records ``func`` in ``testSuites`` under its ``__name__``.

    The function is returned unchanged, so the decorated name stays callable.

    Raises:
        ValueError: if a suite with the same name is already registered. Because
            ``testSuites`` is only reset when this cell runs, re-executing a
            ``@register`` cell on its own also triggers this error.
    """
    name = func.__name__
    if name in testSuites:
        # ValueError subclasses Exception, so existing `except Exception` still works.
        raise ValueError("Duplicated test suite name: " + name)
    testSuites[name] = func
    return func

In [None]:
@register
def test_var():
    """Construct a relax ``Var`` with tensor struct info and check its fields."""
    var = rx.Var("v0", R.Tensor((1, 2, 3), "float32"))
    showmod(var)

    # Compare against an independently built struct info (structural check).
    assert var.name_hint == "v0"
    assert var.struct_info == R.Tensor((1, 2, 3), "float32")


test_var()

In [None]:
@register
def test_dataflow_var():
    """Construct a relax ``DataflowVar`` with tensor struct info and check its fields."""
    dataflow_var = rx.DataflowVar("v0", R.Tensor((1, 2, 3), "float32"))
    showmod(dataflow_var)

    # Compare against an independently built struct info (structural check).
    assert dataflow_var.name_hint == "v0"
    assert dataflow_var.struct_info == R.Tensor((1, 2, 3), "float32")


test_dataflow_var()

In [None]:
@register
def test_match_cast():
    """Construct ``rx.MatchCast`` bindings and inspect their fields.

    ``rx.MatchCast(var, value, struct_info)``: if ``value`` matches
    ``struct_info`` at runtime, ``value``'s struct info is cast to
    ``struct_info`` and ``value`` is then bound to ``var``. It is mainly used
    for dynamic-shape inference.
    """
    m = tir.Var("m", "int64")
    n = tir.Var("n", "int64")

    x = rx.Var("x", R.Tensor([m, n], "float32"))
    y = rx.MatchCast(rx.Var("y"), x, R.Tensor([n, m], "float32"))
    showmod(y)

    assert y.struct_info == R.Tensor([n, m], "float32")

    # NOTE: `shape` is a 1-D constant with two elements, so this rank-2 cast is
    # only built here for inspection; it would not match (rank 1 vs. rank 2)
    # if it were executed.
    shape = rx.const([16, 8], "int32")
    b0 = rx.MatchCast(rx.Var("b0"), shape, R.Tensor([m, n], "int32"))
    showmod(b0)
    assert b0.struct_info == R.Tensor([m, n], "int32")

    value = rx.Var("value", R.Tensor(None, "float32", ndim=-1))
    # The bound var carries the struct info produced by the cast, so declare it
    # with the same [10, 10] shape rather than an unrelated [m, n].
    var = rx.Var("var", R.Tensor([10, 10], "float32"))
    b1 = rx.MatchCast(var, value, R.Tensor([10, 10], "float32"))
    showmod(b1)

    assert b1.value == value
    assert b1.struct_info == R.Tensor([10, 10], "float32")


testSuites["test_match_cast"]()

In [None]:
@register
def test_var_binding():
    """Build ``VarBinding`` and ``MatchCast`` bindings and group them in a ``BindingBlock``."""
    import numpy as np

    m = tir.Var("m", "int64")
    n = tir.Var("n", "int64")

    # VarBinding: bind a value to a var. rx.const also accepts numpy arrays.
    const_value = rx.const(np.random.rand(24, 56))
    var_binding = rx.VarBinding(rx.Var("bind1"), const_value)
    showmod(var_binding)
    assert var_binding.var.name_hint == "bind1"
    assert var_binding.value == const_value

    # rx.MatchCast is also a relax.Binding, so it can share the block below.
    shape_const = rx.const(np.array([16, 8]), "int32")
    cast_binding = rx.MatchCast(
        rx.Var("bind2"), shape_const, R.Tensor([m, n], "int32")
    )
    showmod(cast_binding)
    assert cast_binding.struct_info == R.Tensor([m, n], "int32")
    assert cast_binding.value == shape_const

    block = rx.BindingBlock([var_binding, cast_binding])
    assert block.bindings[0] == var_binding
    assert block.bindings[1] == cast_binding
    showmod(block)


test_var_binding()

In [None]:
@register
def test_dataflow_block():
    """Group a ``MatchCast`` and a ``VarBinding`` into a ``DataflowBlock`` and show it."""
    m = tir.Var("m", "int64")
    n = tir.Var("n", "int64")

    shape = rx.const([16, 8], "int32")
    b0 = rx.MatchCast(rx.Var("v0"), shape, R.Tensor([m, n], "int32"))

    v1 = rx.Var("v1")
    val1 = rx.const([1, 2], "int32")
    b1 = rx.VarBinding(v1, val1)

    block1 = rx.DataflowBlock([b0, b1])
    assert block1.bindings[0] == b0
    assert block1.bindings[1] == b1
    showmod(block1)


testSuites["test_dataflow_block"]()

In [None]:
@register
def test_seq_expr():
    """Build a ``SeqExpr`` from one ``BindingBlock`` and check its body and bindings."""
    x = rx.Var("x", R.Tensor([2, 4], "int32"))
    # `relax.multiply` is elementwise, so `y` must be broadcast-compatible with
    # `x`; the previous [4, 8] shape was not.
    y = rx.Var("y", R.Tensor([2, 4], "int32"))
    # Bind `add(x, x)` to a fresh var: rebinding `x` to an expression that reads
    # `x` itself would redefine a variable, which Relax's SSA form does not allow.
    lv0 = rx.Var("lv0", R.Tensor([2, 4], "int32"))
    res = rx.Var("ret", R.Tensor(ndim=-1))

    varBind1 = rx.VarBinding(lv0, rx.Call(tvm.ir.Op.get("relax.add"), [x, x]))
    varBind2 = rx.VarBinding(res, rx.Call(tvm.ir.Op.get("relax.multiply"), [lv0, y]))

    bindBlock = rx.BindingBlock([varBind1, varBind2])

    seq1 = rx.SeqExpr([bindBlock], res)
    assert seq1.body == res
    assert seq1.blocks[0].bindings[0] == varBind1
    assert seq1.blocks[0].bindings[1] == varBind2

    showmod(seq1)
    print("=" * 10)
    showmod(res.struct_info)


testSuites["test_seq_expr"]()



In [None]:
@register
def test_func():
    """Wrap a relax matmul call in a ``Function``, name it, and put it in an ``IRModule``."""
    m = tir.Var("m", "int64")
    k = tir.Var("k", "int64")
    n = tir.Var("n", "int64")

    lhs = rx.Var("a", R.Tensor([m, k], "int32"))
    rhs = rx.Var("b", R.Tensor([k, n], "int32"))
    matmul = rx.Call(tvm.ir.Op.get("relax.matmul"), [lhs, rhs])

    # `with_attr` returns a new function; the global symbol is the name the
    # function is looked up by in the module below.
    named_func = rx.Function([lhs, rhs], matmul, R.Tensor(ndim=-1)).with_attr(
        "global_symbol", "func"
    )
    mod = tvm.IRModule.from_expr(named_func)
    showmod(mod)
    showmod(mod["func"])


test_func()

In [None]:
@register
def test_shape_expr():
    """Show the shape taken from a tensor var, then an explicitly built ``ShapeExpr``."""
    tensor_var = rx.Var("v1", R.Tensor([96, 54]))
    showmod(rx.get_shape_of(tensor_var))

    showmod(rx.ShapeExpr([10, 20]))


test_shape_expr()

In [None]:
@register
def test_prim_value():
    """Show ``PrimValue`` nodes built from an int, TIR expressions and an ``IntImm``."""
    prim_values = [
        rx.PrimValue(1),
        rx.PrimValue(tvm.tir.Mul(2, 3)),  # R.prim_value(T.Mul(2, 3))
        rx.PrimValue(tvm.tir.Var("n", "int32") + 1),  # R.prim_value(n + 1)
        rx.PrimValue(tvm.tir.IntImm("int64", 1)),
    ]
    for prim_value in prim_values:
        showmod(prim_value)


test_prim_value()

In [None]:
@register
def test_call():
    """Nest two ``add`` calls, one via ``rx.Call`` and one via ``rx.op.add``, in a ``Function``."""
    x = rx.Var("x", R.Tensor(ndim=-1))
    y = rx.Var("y", R.Tensor(ndim=-1))

    # Two construction styles: a raw Call on the registered op, and the op helper.
    inner = rx.Call(tvm.ir.Op.get("relax.add"), [x, y])
    outer = rx.op.add(x, inner)

    showmod(rx.Function([x, y], outer, R.Tensor(ndim=-1)))


test_call()