In [None]:
import tvm

import tvm.te as te

from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R


def showmod(mod: "tvm.runtime.Object", **overrides):
    """Pretty-print a TVM IR object as TVMScript via its ``show`` method.

    Used in this notebook for both whole IRModules and individual functions
    (e.g. PrimFuncs taken out of a module), so the parameter is not restricted
    to ``IRModule``; the annotation is a string so it is never evaluated.

    Parameters
    ----------
    mod :
        Any object exposing ``show(**kwargs)`` (IRModule, PrimFunc, ...).
    **overrides :
        Optional ``show`` keyword arguments that replace the defaults below.
    """
    options = dict(
        black_format=True,
        show_meta=False,
        verbose_expr=True,
        show_object_address=False,
        show_all_struct_info=True,
    )
    options.update(overrides)
    mod.show(**options)


def createandshowmod(ops, name="test"):
    """Build an IRModule from TE tensors, print it, and return it.

    Parameters
    ----------
    ops :
        Passed straight to ``te.create_prim_func`` -- presumably the TE
        placeholders/outputs describing the computation (verify against callers).
    name : str
        Used both as the ``global_symbol`` attribute of the PrimFunc and as its
        key in the returned module. Defaults to ``"test"``, the previously
        hard-coded value.

    Returns
    -------
    tvm.IRModule
        A module holding the single generated PrimFunc.
    """
    te_func = te.create_prim_func(ops).with_attrs({"global_symbol": name})
    mod = tvm.IRModule({name: te_func})
    showmod(mod)
    return mod


from tvm.ir.base import *

#### class EnvFunc

In [None]:
M, K, N = 128, 256, 512


@I.ir_module
class Module:
    # main: Y = A @ B with A of shape (M, K) and B of shape (K, N).
    # NOTE(review): Y is a function-local buffer that is never copied to an
    # output parameter, so the result is not observable by a caller; the module
    # only serves as the argument passed to "test.myfunc" below.
    @T.prim_func
    def main(
        A: T.Buffer(shape=(M, K), dtype="float32"),
        B: T.Buffer(shape=(K, N), dtype="float32"),
    ):
        Y = T.alloc_buffer(shape=(M, N), dtype="float32")
        # Grid order (M, N, K) matches remap("SSR"): vi, vj are spatial and vk
        # is the reduction axis shared by A's columns and B's rows.
        for i, j, k in T.grid(M, N, K):
            with T.block("Y"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    Y[vi, vj] = 0.0
                Y[vi, vj] += A[vi, vk] * B[vk, vj]


@tvm.register_func("test.myfunc")
def printfuncofmod(mod: tvm.IRModule):
    """Print the GlobalVar and TVMScript of the first function in ``mod``.

    Registered under the global name ``"test.myfunc"`` so it can be fetched
    with ``tvm.get_global_func`` or ``tvm.ir.EnvFunc.get``.

    Parameters
    ----------
    mod : tvm.IRModule
        Module whose first entry of ``functions_items()`` is shown.

    Raises
    ------
    ValueError
        If ``mod`` contains no functions.
    """
    # Query the items once instead of re-fetching them for every access.
    items = mod.functions_items()
    if not items:
        # Explicit raise rather than ``assert`` so the check survives ``python -O``.
        raise ValueError("printfuncofmod: module has no functions")
    gvar, first_func = items[0]
    print("func: ", gvar)
    showmod(first_func)


# Fetch the function registered as "test.myfunc" through each lookup API in
# turn (global PackedFunc registry, then EnvFunc) and run it on Module.
for lookup in (tvm.get_global_func, tvm.ir.EnvFunc.get):
    func = lookup("test.myfunc")
    func(Module)

func:  I.GlobalVar("main")


func:  I.GlobalVar("main")


#### load_json and save_json

In [None]:
# Save and load a plain Python int. save_json boxes it as a runtime.BoxInt
# (not a PrimExpr), which is what the recorded "type_key" output shows.
x: int = 1
print(x)

json_str = tvm.ir.save_json(x)
print(json_str)

y = tvm.ir.load_json(json_str)
print(y)

# Save and load a Function: a hand-built PrimFunc whose statements describe a
# Pascal's-triangle recurrence over an (M, N) buffer.
M, N = 4, 256
shape = (M, N)
dtype = "float32"
name = "buffer"
data = None
strides = [N, 1]
elem_offset = tvm.tir.Var("elem_offset", dtype="int32")
scope = "global"
data_alignment = 4
offset_factor = 4
buffer_type = "auto_broadcast"
axis_separators = []

buffer = tvm.tir.decl_buffer(
    shape=shape,
    dtype=dtype,
    name=name,
    data=data,
    strides=strides,
    elem_offset=elem_offset,
    scope=scope,
    data_alignment=data_alignment,
    offset_factor=offset_factor,
    buffer_type=buffer_type,
    axis_separators=axis_separators,
)

i = tvm.tir.Var("i", "int32")
j = tvm.tir.Var("j", "int32")

# buffer[i, 0] = 1 for every row i.
# NOTE(review): this loop vectorizes the first (row) index, which conflicts with
# the note on outer_loop that only the last index may be vectorized -- confirm.
initial_loop_M = tvm.tir.For(
    loop_var=i,
    min=0,
    extent=M,
    kind=tvm.tir.ForKind.VECTORIZED,
    body=tvm.tir.BufferStore(buffer, 1.0, [i, 0]),
)
# buffer[0, j + 1] = 1 for j in [0, N - 1), i.e. columns 1..N-1 of row 0.
initial_loop_N = tvm.tir.For(
    loop_var=j,
    # TODO: a pass is missing that could automatically recognize loops which are
    # vectorizable but whose min is not zero.
    min=0,  # change to 0 for vectorized; the +1 offset moves into the index
    extent=N - 1,
    kind=tvm.tir.ForKind.VECTORIZED,
    body=tvm.tir.BufferStore(buffer, 1.0, [0, j + 1]),
)

m = tvm.tir.Var("m", "int32")
n = tvm.tir.Var("n", "int32")

# Row update: buffer[m, n + 1] = buffer[m - 1, n + 1] + buffer[m, n].
# NOTE(review): iteration n reads buffer[m, n], which iteration n - 1 wrote, so
# this loop carries a dependency and VECTORIZED would not be a safe kind if the
# function were lowered/built. Here it is only serialized and printed.
inner_loop = tvm.tir.For(
    loop_var=n,
    min=0,  # change to 0 for vectorized
    extent=N - 1,
    kind=tvm.tir.ForKind.VECTORIZED,  # serial/parallel/vectorized/unrolled
    body=tvm.tir.BufferStore(
        buffer,
        tvm.tir.BufferLoad(buffer, [m - 1, n + 1]) + tvm.tir.BufferLoad(buffer, [m, n]),
        [m, n + 1],
    ),
)
outer_loop = tvm.tir.For(
    loop_var=m,
    min=1,  # row 0 is already filled by initial_loop_M / initial_loop_N
    extent=M - 1,
    # Only the last index of a buffer can be used in vectorized loops, so this
    # row loop is unrolled rather than vectorized.
    # The UnrollLoop pass usually needs one of the following to take effect:
    # - the loop bounds are small compile-time constants (default threshold 16)
    # - the loop is explicitly marked kind=UNROLLED
    kind=tvm.tir.ForKind.UNROLLED,
    body=inner_loop,
)

# Passing the Buffer in params lets PrimFunc create a handle var for it.
func = tvm.tir.PrimFunc(
    params=[elem_offset, buffer],
    body=tvm.tir.SeqStmt([initial_loop_M, initial_loop_N, outer_loop]),
    ret_type=None,
)

# Save and load
func_str = tvm.ir.save_json(func)
func = tvm.ir.load_json(func_str)
showmod(func)

1
{
  "root": 1, 
  "nodes": [
    {
      "type_key": ""
    }, 
    {
      "type_key": "runtime.BoxInt", 
      "repr_str": "1"
    }
  ], 
  "b64ndarrays": [], 
  "attrs": {"tvm_version": "0.21.dev0"}
}
1


#### structural_equal

In [None]:
"""Check structural equality of lhs and rhs.

The structural equality is recursively defined in the DAG of IRNodes.
There are two kinds of nodes:

- Graph node: a graph node in lhs can only be mapped as equal to
  one and only one graph node in rhs.
- Normal node: equality is recursively defined without the restriction
  of graph nodes.

Vars(tir::Var, relax::Var) are graph nodes.

A var-type node (e.g. tir::Var) can be mapped as equal to another var
of the same type if one of the following conditions holds:

- They appear at the same definition point (e.g. a function argument).
- They point to the same VarNode via the same_as relation.
- They appear at the same usage point, and map_free_vars is set to True.

The rules for vars are used to remap variables that occur in function
arguments and let-bindings.

Parameters
----------
lhs : Object
    The left operand.

rhs : Object
    The right operand.

map_free_vars : bool
    Whether free variables (i.e. variables without a definition site) should be mapped
    as equal to each other.

Returns
-------
result : bool
    The comparison result.

See Also
--------
structural_hash
assert_structural_equal
"""


# Standard Format
# ***************
# Let's take the ``mm_relu`` example from :ref:`tir-learning`. Here is the
# complete, unsugared form of the IRModule in TVMScript:
# Fully explicit TVMScript: nested range() loops, explicit T.axis.spatial /
# T.axis.reduce bindings, and explicit T.reads / T.writes regions.
@I.ir_module
class MyModule:
    # mm_relu: C = max(A @ B, 0) on 128x128 float32 buffers.
    @T.prim_func
    def mm_relu(
        A: T.Buffer((128, 128), "float32"),
        B: T.Buffer((128, 128), "float32"),
        C: T.Buffer((128, 128), "float32"),
    ):
        # Intermediate matmul result, allocated inside the function.
        Y = T.alloc_buffer((128, 128), dtype="float32")
        # Y[vi, vj] = sum over vk of A[vi, vk] * B[vk, vj]
        for i in range(128):
            for j in range(128):
                for k in range(128):
                    with T.block("Y"):
                        vi = T.axis.spatial(128, i)
                        vj = T.axis.spatial(128, j)
                        vk = T.axis.reduce(128, k)  # reduction axis
                        T.reads(A[vi, vk], B[vk, vj])
                        T.writes(Y[vi, vj])
                        with T.init():
                            # Initial value of the reduction.
                            Y[vi, vj] = T.float32(0)
                        Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
        # C = ReLU(Y), element-wise.
        for i in range(128):
            for j in range(128):
                with T.block("C"):
                    vi = T.axis.spatial(128, i)
                    vj = T.axis.spatial(128, j)
                    T.reads(Y[vi, vj])
                    T.writes(C[vi, vj])
                    C[vi, vj] = T.max(Y[vi, vj], T.float32(0))


# Concise with Syntactic Sugar
# ****************************
# For ease of writing, we can employ the following syntactic sugar to
# streamline the code:
#
# - Utilize ``T.grid`` to condense nested loops;
# - Employ ``T.axis.remap`` to abbreviate block iterator annotations;
# - Exclude ``T.reads`` and ``T.writes`` for blocks whose content can
#   be inferred from the block body;
# Same mm_relu computation as MyModule, written with T.grid, T.axis.remap and
# without explicit T.reads / T.writes.
@I.ir_module
class ConciseModule:
    @T.prim_func
    def mm_relu(
        A: T.Buffer((128, 128), "float32"),
        B: T.Buffer((128, 128), "float32"),
        C: T.Buffer((128, 128), "float32"),
    ):
        Y = T.alloc_buffer((128, 128), dtype="float32")
        for i, j, k in T.grid(128, 128, 128):
            with T.block("Y"):
                # "SSR": vi and vj are spatial, vk is the reduction axis.
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    Y[vi, vj] = T.float32(0)
                Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
        for i, j in T.grid(128, 128):
            with T.block("C"):
                # "SS": both iterators are spatial.
                vi, vj = T.axis.remap("SS", [i, j])
                C[vi, vj] = T.max(Y[vi, vj], T.float32(0))


# Interactive with Python Variables
# *********************************
# Despite TVMScript not being executed by a Python interpreter, limited
# interaction with Python is feasible. For instance, Python variables can
# be used to ascertain the shape and data type of a TensorIR.

# Python variables
M = N = K = 128
dtype = "float32"


# IRModule in TVMScript
# M, N, K and dtype below are read from the Python variables above when the
# TVMScript is parsed.
@I.ir_module
class ConciseModuleFromPython:
    @T.prim_func
    def mm_relu(
        A: T.Buffer((M, K), dtype),
        B: T.Buffer((K, N), dtype),
        C: T.Buffer((M, N), dtype),
    ):
        Y = T.alloc_buffer((M, N), dtype)
        for i, j, k in T.grid(M, N, K):
            with T.block("Y"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    # Zero expressed in the chosen dtype.
                    Y[vi, vj] = T.cast(T.float32(0), dtype)
                Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
        for i, j in T.grid(M, N):
            with T.block("C"):
                vi, vj = T.axis.remap("SS", [i, j])
                C[vi, vj] = T.max(Y[vi, vj], T.cast(T.float32(0), dtype))


# Compare the explicit and sugared forms of mm_relu.
print(tvm.ir.structural_equal(MyModule, ConciseModule))
print(tvm.ir.structural_equal(MyModule, ConciseModuleFromPython))

# Qualified through tvm.ir like the calls above, instead of relying on the
# wildcard import of tvm.ir.base.
tvm.ir.assert_structural_equal(MyModule, ConciseModule, map_free_vars=False)

# Only the function objects are printed; the GlobalVar keys are unused.
for _gvar, func in MyModule.functions.items():
    showmod(func)

True
True


#### get_first_structural_mismatch

In [None]:
"""Like structural_equal(), but returns the ObjectPaths of the first detected mismatch.

Parameters
----------
lhs : Object
    The left operand.

rhs : Object
    The right operand.

map_free_vars : bool
    Whether free variables (i.e. variables without a definition site) should be mapped
    as equal to each other.

Returns
-------
mismatch: Optional[Tuple[ObjectPath, ObjectPath]]
    `None` if `lhs` and `rhs` are structurally equal.
    Otherwise, a tuple of two ObjectPath objects that point to the first detected mismatch.

"""


# Standard Format
# ***************
# Let's take the ``mm_relu`` example from :ref:`tir-learning`. Here is the
# complete, unsugared form of the IRModule in TVMScript:
# Reference module for the mismatch example: every buffer, including A, is
# declared (128, 128) here.
@I.ir_module
class MyModule:
    # mm_relu: C = max(A @ B, 0) on 128x128 float32 buffers.
    @T.prim_func
    def mm_relu(
        A: T.Buffer((128, 128), "float32"),
        B: T.Buffer((128, 128), "float32"),
        C: T.Buffer((128, 128), "float32"),
    ):
        # Intermediate matmul result, allocated inside the function.
        Y = T.alloc_buffer((128, 128), dtype="float32")
        # Y[vi, vj] = sum over vk of A[vi, vk] * B[vk, vj]
        for i in range(128):
            for j in range(128):
                for k in range(128):
                    with T.block("Y"):
                        vi = T.axis.spatial(128, i)
                        vj = T.axis.spatial(128, j)
                        vk = T.axis.reduce(128, k)  # reduction axis
                        T.reads(A[vi, vk], B[vk, vj])
                        T.writes(Y[vi, vj])
                        with T.init():
                            # Initial value of the reduction.
                            Y[vi, vj] = T.float32(0)
                        Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
        # C = ReLU(Y), element-wise.
        for i in range(128):
            for j in range(128):
                with T.block("C"):
                    vi = T.axis.spatial(128, i)
                    vj = T.axis.spatial(128, j)
                    T.reads(Y[vi, vj])
                    T.writes(C[vi, vj])
                    C[vi, vj] = T.max(Y[vi, vj], T.float32(0))


# Concise with Syntactic Sugar
# ****************************
# For ease of writing, we can employ the following syntactic sugar to
# streamline the code:
#
# - Utilize ``T.grid`` to condense nested loops;
# - Employ ``T.axis.remap`` to abbreviate block iterator annotations;
# - Exclude ``T.reads`` and ``T.writes`` for blocks whose content can
#   be inferred from the block body;
# Sugared variant that deliberately declares A as (128, 256) instead of
# (128, 128), so get_first_structural_mismatch has something to report.
@I.ir_module
class ConciseModule:
    @T.prim_func
    def mm_relu(
        A: T.Buffer(
            (128, 256), "float32"
        ),  # 128 vs 256 is the first structural mismatch
        B: T.Buffer((128, 128), "float32"),
        C: T.Buffer((128, 128), "float32"),
    ):
        Y = T.alloc_buffer((128, 128), dtype="float32")
        for i, j, k in T.grid(128, 128, 128):
            with T.block("Y"):
                # "SSR": vi and vj are spatial, vk is the reduction axis.
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    Y[vi, vj] = T.float32(0)
                Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
        for i, j in T.grid(128, 128):
            with T.block("C"):
                vi, vj = T.axis.remap("SS", [i, j])
                C[vi, vj] = T.max(Y[vi, vj], T.float32(0))


# Prints None when the modules are structurally equal; otherwise a pair of
# ObjectPaths (lhs, rhs) locating the first difference. Qualified through
# tvm.ir.base rather than relying on the wildcard import.
print(
    tvm.ir.base.get_first_structural_mismatch(
        MyModule, ConciseModule, map_free_vars=False
    )
)

(<root>.functions[I.GlobalVar("mm_relu")].buffer_map[A_handle].shape[1].value, <root>.functions[I.GlobalVar("mm_relu")].buffer_map[A_handle].shape[1].value)


-1092125359670513877