In [7]:
import tvm

import tvm.te as te
import tvm.relax as rx
import tvm.tir as tir

from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R

from tvm.relax.binding_rewrite import DataflowBlockRewrite
from tvm.relax.analysis import name_to_binding


def showmod(mod: "tvm.runtime.Scriptable", **show_options):
    """Pretty-print a TVM scriptable object as TVMScript.

    Despite the parameter name, this is not limited to ``IRModule``: the
    notebook also passes Relax ``StructInfo`` objects. Any object whose
    ``.show()`` accepts the keyword options below works.

    Parameters
    ----------
    mod
        Object to print, e.g. an ``IRModule`` or a Relax ``StructInfo``.
    **show_options
        Optional overrides for the keyword arguments forwarded to
        ``mod.show`` (for example ``show_meta=True``). Options that are not
        given keep the defaults defined below.
    """
    # Defaults: black-formatted, no metadata section, verbose expressions,
    # no object addresses, and struct info printed for every binding.
    options = {
        "black_format": True,
        "show_meta": False,
        "verbose_expr": True,
        "show_object_address": False,
        "show_all_struct_info": True,
    }
    options.update(show_options)
    mod.show(**options)


def createandshowmod(ops, name: str = "test"):
    """Build a TIR PrimFunc from TE ops, wrap it in an IRModule and print it.

    Parameters
    ----------
    ops
        Forwarded unchanged to ``te.create_prim_func`` (presumably the list
        of TE input/output tensors -- confirm at call sites).
    name
        Used both as the PrimFunc's ``global_symbol`` attribute and as its
        key in the module. Defaults to ``"test"``.
    """
    prim_func = te.create_prim_func(ops).with_attrs({"global_symbol": name})
    # Use the same name for the module key and the global symbol so the two
    # cannot drift apart.
    mod = tvm.IRModule({name: prim_func})
    showmod(mod)

In [8]:
def struct_info_test():
    """Build, print and sanity-check one instance of each Relax StructInfo kind.

    Covers ObjectStructInfo, ShapeStructInfo, TensorStructInfo (with a
    constant shape and with a symbolic shape variable), TupleStructInfo and
    FuncStructInfo (typed and opaque). Each object is printed with
    ``showmod``, and its basic fields are checked with ``assert``.
    """
    # ObjectStructInfo: also dump its JSON serialization.
    obj_sinfo = rx.ObjectStructInfo()
    showmod(obj_sinfo)
    print(tvm.ir.save_json(obj_sinfo))

    # ShapeStructInfo with known values.
    shape_sinfo = rx.ShapeStructInfo([1, 2, 3])
    showmod(shape_sinfo)
    assert shape_sinfo.ndim == 3

    # TensorStructInfo with a constant shape.
    t0 = rx.TensorStructInfo([1, 2, 3], "float32")
    showmod(t0)
    assert t0.ndim == 3
    assert t0.dtype == "float32"
    print(t0.shape)  # R.shape([1, 2, 3])

    # `==` on Relax expressions such as ShapeExpr is object identity, not
    # structural comparison, so comparing against a freshly built shape would
    # be False even though the shapes match. Use structural equality instead.
    tvm.ir.assert_structural_equal(t0.shape, rx.ShapeExpr([1, 2, 3]))
    assert list(t0.shape.values) == [1, 2, 3]

    # TensorStructInfo whose shape is a symbolic rank-3 shape variable.
    shape_var = rx.Var("shape", rx.ShapeStructInfo(ndim=3))
    t1 = rx.TensorStructInfo(shape_var, "float32")
    showmod(t1)
    assert t1.ndim == 3
    assert t1.dtype == "float32"
    # Identity check: the very Var object passed in is stored as the shape.
    assert t1.shape == shape_var

    # TupleStructInfo holding the two tensor struct infos above.
    t2 = rx.TupleStructInfo([t0, t1])
    showmod(t2)
    assert t2.fields[0] == t0
    assert t2.fields[1] == t1

    # FuncStructInfo with a matrix-multiply-shaped signature
    # (m, k) x (k, n) -> (m, n), using symbolic int64 TIR variables.
    m = tvm.tir.Var("m", "int64")
    n = tvm.tir.Var("n", "int64")
    k = tvm.tir.Var("k", "int64")

    a = rx.TensorStructInfo([m, k], "float32")
    b = rx.TensorStructInfo([k, n], "float32")
    c = rx.TensorStructInfo([m, n], "float32")

    f = rx.FuncStructInfo([a, b], c)
    showmod(f)
    assert f.params[0] == a
    assert f.params[1] == b
    assert f.ret == c

    # Opaque function struct info: parameters are left unspecified.
    f1 = rx.FuncStructInfo.opaque_func()
    showmod(f1)
    assert f1.params is None


struct_info_test()

{
  "root": 1, 
  "nodes": [
    {
      "type_key": ""
    }, 
    {
      "type_key": "relax.ObjectStructInfo", 
      "attrs": {"span": "0"}
    }
  ], 
  "b64ndarrays": [], 
  "attrs": {"tvm_version": "0.21.dev0"}
}


R.shape([1, 2, 3])
