From https://tvm.apache.org/docs/arch/index.html

# Use TVMScript to create a TensorIR PrimFunc with fixed shape

In [None]:
import tvm

from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R

from tvm.tir.analysis import estimate_tir_flops

def showmod(mod: tvm.ir.module.IRModule):
    """Print ``mod`` by calling ``mod.show`` with this notebook's fixed display options."""
    display_options = {
        "black_format": True,
        "show_meta": False,
        "verbose_expr": True,
        "show_object_address": False,
        "show_all_struct_info": True,
    }
    mod.show(**display_options)


# Fixed-shape matrix multiplication written in TVMScript:
# Y (128x512) = A (128x256) @ B (256x512), all float32.
@I.ir_module
class Module:
    @T.prim_func
    def main(
        A: T.Buffer(shape=(128, 256), dtype="float32"),
        B: T.Buffer(shape=(256, 512), dtype="float32"),
        Y: T.Buffer(shape=(128, 512), dtype="float32"),
    ):
        # i and j index the output; k runs over the shared dimension of A and B.
        for i, j, k in T.grid(128, 512, 256):
            with T.block("Y"):
                # "SSR": vi and vj are spatial axes, vk is the reduction axis.
                vi, vj, vk = T.axis.remap(
                    kinds="SSR", bindings=[i, j, k], dtype="int32"
                )
                # Zero the accumulator before the reduction over vk starts.
                with T.init():
                    Y[vi, vj] = 0
                Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]


# Print the module and its estimated FLOP count.
mod = Module
showmod(mod)
print(estimate_tir_flops(mod))


# Schedule the "main" PrimFunc: tile the two output loops of block "Y".
sch = tvm.tir.Schedule(mod["main"])
block_Y = sch.get_block("Y")
# Get the outer loops of the block
i, j, k = sch.get_loops(block_Y)

# Split i (extent 128) into 4 x 32 and j (extent 512) into 8 x 64.
sch.split(i, [4, 32])
sch.split(j, [8, 64])
showmod(sch.mod)

# Re-fetch the loops after splitting, then add cache-read stages for the
# buffers named "A" and "B" in the "local" storage scope.
i_0, i_1, j_0, j_1, k = sch.get_loops(block_Y)
A_local = sch.cache_read(block_Y, "A", "local")
B_local = sch.cache_read(block_Y, "B", "local")

# Place the A cache stage under loop i_0 and the B cache stage under loop i_1.
# NOTE(review): the recorded output of this cell ends with
# "TypeError: missing a required argument: 'block'"; which call raises it is
# not evident from the code here -- confirm against the installed TVM version.
sch.compute_at(A_local, i_0)
sch.compute_at(B_local, i_1)
showmod(sch.mod)

33554432.0


TypeError: missing a required argument: 'block'

# Use TVMScript to create a TensorIR PrimFunc with dynamic shape

In [None]:
import numpy as np

# Element type used by every buffer in the module below and by the test inputs.
dtype = "float32"


# Dynamic-shape matmul followed by ReLU: C = max(A @ B, 0).
@I.ir_module
class mm_relu:
    @T.prim_func
    def main(A: T.handle, B: T.handle, C: T.handle):
        # Symbolic dimensions; they are tied to the arguments by match_buffer below.
        M, K, N = T.int64(), T.int64(), T.int64()

        A_Buf = T.match_buffer(A, [M, K], dtype)
        B_Buf = T.match_buffer(B, [K, N], dtype)
        C_Buf = T.match_buffer(C, [M, N], dtype)

        # Intermediate buffer holding the matmul result before the ReLU.
        Y_Buf = T.alloc_buffer(shape=[M, N], dtype=dtype)

        # Stage 1: Y_Buf = A_Buf @ B_Buf (reduction over k).
        for i, j, k in T.grid(M, N, K):
            with T.block("Y_Buf"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    Y_Buf[vi, vj] = T.cast(0.0, dtype)
                Y_Buf[vi, vj] = Y_Buf[vi, vj] + A_Buf[vi, vk] * B_Buf[vk, vj]

        # Stage 2: C_Buf = max(Y_Buf, 0) elementwise.
        for i, j in T.grid(M, N):
            with T.block("C"):
                vi, vj = T.axis.remap("SS", [i, j])
                C_Buf[vi, vj] = T.max(Y_Buf[vi, vj], T.cast(0.0, dtype))


# Print the dynamic-shape module.
mod = mm_relu
showmod(mod)


def evaluate_dynamic_shape(lib: tvm.runtime.Module, M: int, N: int, K: int):
    """Call ``lib`` on freshly generated inputs and return the output as a NumPy array.

    The inputs are random arrays of shape (M, K) and (K, N); the output buffer
    is a zero-filled (M, N) array. All three use the module-level ``dtype``.
    """
    # Draw the two operands in the same order as before so the RNG stream is unchanged.
    lhs_host = np.random.rand(M, K).astype(dtype)
    rhs_host = np.random.rand(K, N).astype(dtype)
    out_host = np.zeros((M, N), dtype=dtype)

    lhs, rhs, out = (tvm.nd.array(host) for host in (lhs_host, rhs_host, out_host))
    # The library writes its result into the third argument.
    lib(lhs, rhs, out)
    return out.numpy()


# Compile once for the LLVM target, then run the same library with two different shapes.
dyn_shape_lib = tvm.compile(mod, target="llvm")
print(evaluate_dynamic_shape(dyn_shape_lib, M=4, N=4, K=4))
print(evaluate_dynamic_shape(dyn_shape_lib, M=64, N=64, K=128))

[[0.5969938  0.31348184 0.5708697  0.54372936]
 [0.27157515 0.15715148 0.9323122  0.83957475]
 [0.8692241  0.4826384  1.2326066  1.2651988 ]
 [0.6769788  0.35828868 1.1284438  0.994617  ]]
[[33.65542  34.19142  37.55951  ... 33.432407 34.989227 32.59262 ]
 [32.08593  31.031483 34.483974 ... 31.428686 32.827892 31.715357]
 [31.162546 28.180746 33.096756 ... 30.35415  32.720776 30.617868]
 ...
 [32.41506  33.08332  34.810047 ... 30.944891 33.084816 31.294012]
 [31.619503 31.252602 33.693348 ... 30.552355 33.015503 32.91438 ]
 [32.70215  32.70163  35.50897  ... 32.496212 33.199398 32.513847]]


# Use TE to create a TensorIR PrimFunc with fixed shape

In [None]:
from tvm import te

# Fixed problem size for the tensor-expression version of matmul + ReLU.
M, N, K = 128, 128, 128

A = te.placeholder((M, K), name="A")
B = te.placeholder((K, N), name="B")
# Reduction axis over the shared dimension of A and B.
k = te.reduce_axis((0, K), "k")
# Y = A @ B, then C = max(Y, 0) elementwise.
Y = te.compute((M, N), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name="Y")
C = te.compute((M, N), lambda i, j: te.max(Y[i, j], 0.0), name="C")

# Turn the expressions into a PrimFunc taking (A, B, C), tag it with the
# global symbol "mm_relu", and wrap it in an IRModule under the same name.
te_func = te.create_prim_func([A, B, C]).with_attr({"global_symbol": "mm_relu"})
TEModule = tvm.IRModule({"mm_relu": te_func})
TEModule.show()

# Use TE to create a TensorIR PrimFunc with dynamic shape

In [None]:
# Same matmul + ReLU as the fixed-shape TE cell, but with symbolic dimensions.
M, N, K = te.var("M"), te.var("N"), te.var("K")
A = te.placeholder((M, K), name="A")
B = te.placeholder((K, N), name="B")
k = te.reduce_axis((0, K), name="k")
Y = te.compute((M, N), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name="Y")
C = te.compute((M, N), lambda i, j: te.max(Y[i, j], 0.0), name="C")

te_func = te.create_prim_func([A, B, C]).with_attr({"global_symbol": "mm_relu"})
TEModule = tvm.IRModule({"mm_relu": te_func})
TEModule.show()


# Compile for LLVM and run with two concrete shapes using the helper defined
# in the TVMScript dynamic-shape cell.
dyn_shape_lib = tvm.compile(TEModule, target="llvm")
print(evaluate_dynamic_shape(dyn_shape_lib, M=4, N=4, K=4))
print(evaluate_dynamic_shape(dyn_shape_lib, M=64, N=64, K=128))

[[0.4420731  0.46422315 0.34757143 0.6245781 ]
 [1.1127197  0.7421403  0.86480963 1.0171852 ]
 [0.95629585 0.8491702  0.7548789  1.1682342 ]
 [1.5452828  1.3739009  1.3338703  1.681134  ]]
[[29.226852 32.625904 33.503204 ... 33.044163 28.95899  31.496515]
 [36.18077  41.110897 40.24666  ... 36.26573  36.107643 34.727505]
 [34.42203  38.699276 39.330524 ... 35.501095 36.18746  33.735184]
 ...
 [32.558502 32.58266  34.919712 ... 31.748478 31.427551 29.18895 ]
 [28.33954  32.832607 35.059547 ... 30.971407 29.516375 29.559652]
 [34.509884 36.96429  34.686565 ... 34.59144  32.244663 29.51299 ]]


# Transformation on TIR

In [None]:
import numpy as np

dtype = "float32"
# Concrete problem size used for the NumPy test data in this section; the
# PrimFunc below declares its own symbolic M, K, N.
M, K, N = 256, 128, 256


# Dynamic-shape matmul followed by ReLU: C = max(A @ B, 0).
@I.ir_module
class mm_relu:
    @T.prim_func
    def main(A: T.handle, B: T.handle, C: T.handle):
        # Symbolic dimensions; they are tied to the arguments by match_buffer below.
        M, K, N = T.int64(), T.int64(), T.int64()

        A_Buf = T.match_buffer(A, [M, K], dtype)
        B_Buf = T.match_buffer(B, [K, N], dtype)
        C_Buf = T.match_buffer(C, [M, N], dtype)

        # Intermediate buffer holding the matmul result before the ReLU.
        Y_Buf = T.alloc_buffer(shape=[M, N], dtype=dtype)

        # Stage 1: Y_Buf = A_Buf @ B_Buf (reduction over k).
        for i, j, k in T.grid(M, N, K):
            with T.block("Y_Buf"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    Y_Buf[vi, vj] = T.cast(0.0, dtype)
                Y_Buf[vi, vj] = Y_Buf[vi, vj] + A_Buf[vi, vk] * B_Buf[vk, vj]

        # Stage 2: C_Buf = max(Y_Buf, 0) elementwise.
        for i, j in T.grid(M, N):
            with T.block("C"):
                vi, vj = T.axis.remap("SS", [i, j])
                C_Buf[vi, vj] = T.max(Y_Buf[vi, vj], T.cast(0.0, dtype))


mod = mm_relu
showmod(mod)


# Host-side test data: uniform random operands in [0, 1).
a_np = np.random.uniform(size=(M, K)).astype("float32")
b_np = np.random.uniform(size=(K, N)).astype("float32")
# Reference result. mm_relu computes max(A @ B, 0), so the reference applies
# the ReLU as well instead of relying on the inputs being non-negative.
c_np = np.maximum(a_np @ b_np, 0)

# Device arrays passed to the compiled function; c_nd is the output buffer.
a_nd = tvm.nd.array(a_np)
b_nd = tvm.nd.array(b_np)
c_nd = tvm.nd.array(np.zeros((M, N), dtype="float32"))


def evaluate(mod: tvm.IRModule):
    """Build ``mod`` for LLVM, verify its output against ``c_np``, and print timings.

    Uses the module-level arrays ``a_nd``, ``b_nd`` and ``c_nd`` as arguments;
    ``c_nd`` is overwritten with the result.
    """
    built = tvm.tir.build(mod, target="llvm")

    # Correctness: run once and compare with the NumPy reference.
    built(a_nd, b_nd, c_nd)
    np.testing.assert_allclose(c_nd.numpy(), c_np, rtol=1e-5)

    # Performance: time the "main" function on the CPU device.
    timer = built.time_evaluator("main", tvm.cpu())
    timing = timer(a_nd, b_nd, c_nd)
    print(timing)


# Baseline: check and time the module before any schedule transformation.
evaluate(mod)

Execution time summary:
 mean (ms)   median (ms)    max (ms)     min (ms)     std (ms)  
   8.1039       8.1039       8.1039       8.1039       0.0000                  


In [None]:
# Loop Tiling
# Split loop j of block "Y_Buf" into (j0, j1) with an inner factor of 16 (the
# outer factor is left as None), then reorder the loops to j0, k, j1.
sch = tvm.tir.Schedule(mod)
block_Y = sch.get_block("Y_Buf")
i, j, k = sch.get_loops(block_Y)
j0, j1 = sch.split(j, factors=[None, 16])
sch.reorder(j0, k, j1)

showmod(sch.mod)
evaluate(sch.mod)

Execution time summary:
 mean (ms)   median (ms)    max (ms)     min (ms)     std (ms)  
   4.8435       4.8435       4.8435       4.8435       0.0000                  


In [None]:
# Leverage Localities
# Move the consumer block "C" under loop j0 of the schedule built in the
# loop-tiling cell (uses `sch` and `j0` from that cell).
block_C = sch.get_block("C")
sch.reverse_compute_at(block_C, j0)

showmod(sch.mod)
evaluate(sch.mod)

Execution time summary:
 mean (ms)   median (ms)    max (ms)     min (ms)     std (ms)  
   4.9527       4.9527       4.9527       4.9527       0.0000                  


In [None]:
# Rewrite Reduction
# Decompose the reduction block "Y_Buf" at loop k, separating its T.init()
# part from the update part (uses `sch`, `block_Y` and `k` from earlier cells).
sch.decompose_reduction(block_Y, k)

showmod(sch.mod)
evaluate(sch.mod)

Execution time summary:
 mean (ms)   median (ms)    max (ms)     min (ms)     std (ms)  
   6.0450       6.0450       6.0450       6.0450       0.0000                  


In [None]:
# Print the trace recorded by the schedule `sch`.
sch.trace.show()

# Create Relax programs using TVMScript

In [None]:
from tvm.script import relax as R


# Relax function written in TVMScript: out = sigmoid(x @ weight + bias)
# with fixed shapes (1, 784) x (784, 256) -> (1, 256).
@I.ir_module
class Module:
    @R.function
    def main(
        x: R.Tensor(shape=(1, 784), dtype="float32"),
        weight: R.Tensor(shape=(784, 256), dtype="float32"),
        bias: R.Tensor(shape=(256,), dtype="float32"),
    ) -> R.Tensor(shape=(1, 256), dtype="float32"):
        with R.dataflow():
            lv0 = R.matmul(x, weight)
            lv1 = R.add(lv0, bias)
            out = R.sigmoid(lv1)
            # Mark `out` as the value that leaves the dataflow block.
            R.output(out)
        return out


# Print the Relax module as written.
mod = Module
showmod(mod)

from tvm.relax.transform import LegalizeOps, ToNonDataflow

# Operator legalization (lower high-level operators such as matmul to TIR)
mod = LegalizeOps()(mod)

showmod(mod)

# Lower Relax Abstraction to TIR

In [None]:
from tvm.relax.transform import LegalizeOps, ToNonDataflow


# Two-layer MLP in Relax with a symbolic batch dimension "n":
# lv2 = relu(data @ w0 + b0) @ w1 + b1.
@I.ir_module
class Module:
    @R.function
    def main(
        data: R.Tensor(("n", 784), "float32"),
        w0: R.Tensor((784, 128), "float32"),
        b0: R.Tensor((128,), "float32"),
        w1: R.Tensor((128, 10), "float32"),
        b1: R.Tensor((10,), "float32"),
    ) -> R.Tensor(("n", 10), "float32"):
        with R.dataflow():
            lv0 = R.matmul(data, w0) + b0
            lv1 = R.nn.relu(lv0)
            lv2 = R.matmul(lv1, w1) + b1
            # Mark `lv2` as the value that leaves the dataflow block.
            R.output(lv2)
        return lv2


# Print the Relax module as written.
mod = Module
showmod(mod)

# Operator legalization (lower high-level operators such as matmul to TIR)
mod = LegalizeOps()(mod)

showmod(mod)

# The concept of pure and side-effect

- A function is pure or side-effect free if:
  - it only reads from its inputs and returns the result via its output
  - it will not change other parts of the program (such as incrementing a global counter).

For example, all R.call_tir functions are pure: they only read from their inputs and write the output to a newly allocated tensor. In-place operations, however, are not pure; they have side effects, because they modify existing intermediate or input tensors.

A dataflow block is a way to mark the computational-graph regions of the program. Specifically, ***within a dataflow block, all operations must be side-effect free***. Outside a dataflow block, operations may have side effects.

# Call TensorIR functions in Relax function

In [None]:
# Element type shared by the TensorIR kernels and the Relax function below.
dtype = "float32"


# One IRModule holding two TensorIR kernels (relu, matmul) and a Relax `main`
# that invokes them through R.call_tir.
@I.ir_module
class Module:
    @T.prim_func
    def relu(A: T.handle, Y: T.handle):
        # Elementwise Y = max(A, 0) over a symbolic (m, n) shape.
        m, n = T.int64(), T.int64()
        A_Buf = T.match_buffer(A, shape=(m, n), dtype=dtype)
        Y_Buf = T.match_buffer(Y, shape=(m, n), dtype=dtype)
        for i, j in T.grid(m, n):
            with T.block("Y"):
                vi, vj = T.axis.remap("SS", [i, j])
                Y_Buf[vi, vj] = T.max(A_Buf[vi, vj], T.cast(T.float32(0.0), dtype))

    @T.prim_func
    def matmul(A: T.handle, B: T.handle, Y: T.handle):
        # Y (m, n) = A (m, k) @ B (k, n) with symbolic m, n, k.
        m, n, k = T.int64(), T.int64(), T.int64()
        A_Buf = T.match_buffer(A, shape=(m, k), dtype=dtype)
        B_Buf = T.match_buffer(B, shape=(k, n), dtype=dtype)
        Y_Buf = T.match_buffer(Y, shape=(m, n), dtype=dtype)
        # NOTE(review): the loop variable `k` reuses the name of the symbolic
        # extent `k` declared above; consider renaming one of them for readability.
        for i, j, k in T.grid(m, n, k):
            with T.block("Y"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    Y_Buf[vi, vj] = T.cast(T.float32(0), dtype)
                Y_Buf[vi, vj] = Y_Buf[vi, vj] + A_Buf[vi, vk] * B_Buf[vk, vj]

    @R.function
    def main(
        A: R.Tensor(("m", "k"), dtype=dtype), B: R.Tensor(("k", "n"), dtype=dtype)
    ) -> R.Tensor(("m", "n"), dtype=dtype):
        with R.dataflow():
            m, k, n = T.int64(), T.int64(), T.int64()
            # `cls` names the enclosing module so the kernels above can be referenced.
            cls = Module
            # Each call_tir passes the kernel, its inputs, and the output tensor type.
            lv0 = R.call_tir(cls.matmul, (A, B), R.Tensor((m, n), dtype=dtype))
            lv1 = R.call_tir(cls.relu, (lv0,), R.Tensor((m, n), dtype=dtype))
            R.output(lv1)
        return lv1


# Print the module containing the TensorIR kernels and the Relax entry function.
mod = Module
showmod(mod)

# Create Relax programs using NNModule API

In [None]:
from tvm.relax.frontend import nn


class NNModule(nn.Module):
    """Two-layer perceptron: Linear(784, 128) -> ReLU -> Linear(128, 10)."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 128)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Apply fc1, relu1 and fc2 to ``x`` in that order and return the result."""
        hidden = self.relu1(self.fc1(x))
        return self.fc2(hidden)


# Export the module to an IRModule plus its parameters; the spec declares
# `forward` with input "x" of shape ("n", 784) and dtype float32.
mod, params = NNModule().export_tvm(
    spec={"forward": {"x": nn.spec.Tensor(("n", 784), "float32")}}
)
showmod(mod)

# Operator legalization (lower high-level operators such as matmul to TIR)
mod = LegalizeOps()(mod)

showmod(mod)

# Create Relax programs using Block Builder API

In [None]:
# Build a Relax function imperatively with the BlockBuilder API.
bb = tvm.relax.BlockBuilder()
n = T.int64()  # symbolic batch dimension shared by the input shape
x = tvm.relax.Var("x", R.Tensor((n, 784), "float32"))
# Weights are declared as (out_features, in_features) and transposed before use.
fc1_weight = tvm.relax.Var("fc1_weight", R.Tensor((128, 784), "float32"))
fc1_bias = tvm.relax.Var("fc1_bias", R.Tensor((128,), "float32"))
fc2_weight = tvm.relax.Var("fc2_weight", R.Tensor((10, 128), "float32"))
fc2_bias = tvm.relax.Var("fc2_bias", R.Tensor((10,), "float32"))

with bb.function(name="main", params=[x, fc1_weight, fc1_bias, fc2_weight, fc2_bias]):
    with bb.dataflow():
        # fc1: x @ fc1_weight^T + fc1_bias
        lv0 = bb.emit(
            tvm.relax.op.matmul(x, tvm.relax.op.permute_dims(fc1_weight, axes=(1, 0)))
            + fc1_bias
        )
        lv1 = bb.emit(tvm.relax.op.nn.relu(lv0))
        # fc2: lv1 @ fc2_weight^T + fc2_bias
        lv2 = bb.emit(
            tvm.relax.op.matmul(lv1, tvm.relax.op.permute_dims(fc2_weight, axes=(1, 0)))
            + fc2_bias
        )
        # Apply SiLU to the MLP result. Feeding `x` here would leave lv0..lv2
        # unused and make the function return silu(x) instead.
        lv3 = bb.emit(tvm.relax.op.nn.silu(lv2))
        # Variables emitted inside a dataflow block are scoped to that block;
        # emit_output returns the variable that may be used outside of it.
        gv = bb.emit_output(lv3)
    bb.emit_func_output(gv)

mod = bb.get()
showmod(mod)

# Operator legalization (lower high-level operators such as matmul to TIR)
mod = LegalizeOps()(mod)

showmod(mod)

In [None]:
# TensorIR linear layer: Z (M, N) = X (M, K) @ W (N, K)^T + B (N,).
# W is indexed as W[vj, vk], i.e. it is stored as (out_features, in_features).
@T.prim_func
def tir_linear(x: T.handle, w: T.handle, b: T.handle, z: T.handle):
    M, K = T.int64(), T.int64()
    X = T.match_buffer(x, (M, K))
    N = T.int64()
    W = T.match_buffer(w, (N, K))
    B = T.match_buffer(b, (N,))
    Z = T.match_buffer(z, (M, N))
    # with T.block("root"):
    # Stage 1: accumulate the matrix product into Z (reduction over k).
    for i, j, k in T.grid(M, N, K):
        with T.block("linear"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            T.reads(X[vi, vk], W[vj, vk])
            T.writes(Z[vi, vj])
            with T.init():
                Z[vi, vj] = T.float32(0.0)
            Z[vi, vj] = Z[vi, vj] + X[vi, vk] * W[vj, vk]
    # Stage 2: add the bias to Z in place.
    for i, j in T.grid(M, N):
        with T.block("add"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(Z[vi, vj], B[vj])
            T.writes(Z[vi, vj])
            Z[vi, vj] = Z[vi, vj] + B[vj]


# Build "forward" from three kinds of calls: a packed function referenced by
# name, a TE/TOPI operator, and the hand-written TensorIR function tir_linear.
# Reuses x, n and the weight/bias variables defined in the previous cell.
bb = tvm.relax.BlockBuilder()
with bb.function("forward", [x, fc1_weight, fc1_bias, fc2_weight, fc2_bias]):
    with bb.dataflow():
        # First layer via a destination-passing packed call to "env.linear".
        # NOTE(review): this name is only referenced here; a function registered
        # under it is presumably required at run time -- confirm.
        lv0 = bb.emit(
            tvm.relax.call_dps_packed(
                "env.linear",
                [x, fc1_weight, fc1_bias],
                out_sinfo=tvm.relax.TensorStructInfo((n, 128), "float32"),
            )
        )
        lv1 = bb.emit_te(tvm.topi.nn.relu, lv0)
        # Register the TensorIR function in the module and call it for the second layer.
        tir_gv = bb.add_func(tir_linear, "tir_linear")
        lv2 = bb.emit(
            tvm.relax.call_tir(
                tir_gv,
                [lv1, fc2_weight, fc2_bias],
                out_sinfo=tvm.relax.TensorStructInfo((n, 10), "float32"),
            )
        )
        # Variables emitted inside a dataflow block are scoped to that block;
        # emit_output returns the variable that may be used outside of it.
        gv = bb.emit_output(lv2)
    bb.emit_func_output(gv)
mod = bb.get()
showmod(mod)


# Operator legalization (lower high-level operators such as matmul to TIR)
mod = LegalizeOps()(mod)

showmod(mod)

# Transformation on Relax

In [None]:
import tvm
from tvm import IRModule, relax
from tvm.relax.frontend import nn


class NNModule(nn.Module):
    """Small MLP made of Linear(784, 128), ReLU and Linear(128, 10)."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 128)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Return ``fc2(relu1(fc1(x)))``."""
        activated = self.relu1(self.fc1(x))
        return self.fc2(activated)


# Export the module; `origin_mod` keeps the un-legalized IR for reuse by later cells.
origin_mod, params = NNModule().export_tvm(
    {"forward": {"x": nn.spec.Tensor(("n", 784), "float32")}}
)
showmod(origin_mod)

# Legalize the high-level operators into TensorIR functions.
mod = tvm.relax.transform.LegalizeOps()(origin_mod)
showmod(mod)

In [None]:
# Run three passes in sequence on the legalized module: annotate the TIR
# operator patterns, fuse operators, then fuse the resulting TIR functions.
mod = tvm.ir.transform.Sequential(
    [
        tvm.relax.transform.AnnotateTIROpPattern(),
        tvm.relax.transform.FuseOps(),
        tvm.relax.transform.FuseTIR(),
    ]
)(mod)
showmod(mod)

In [None]:
from tvm.relax.expr_functor import PyExprMutator, mutator


@mutator
class ReluRewriter(PyExprMutator):
    """Expression mutator that replaces ``relax.nn.relu`` calls with ``relax.nn.gelu``."""

    def __init__(self, mod):
        super().__init__(mod)

    def visit_call_(self, call: relax.Call) -> relax.Expr:
        """Return a gelu call for relu calls; defer every other call to the base class."""
        # `call.op` is not necessarily an Op (a call can also target e.g. a
        # GlobalVar), and only Op carries `.name`, so check the type first.
        if isinstance(call.op, tvm.ir.Op) and call.op.name == "relax.nn.relu":
            return relax.op.nn.gelu(call.args[0])

        return super().visit_call_(call)


@tvm.transform.module_pass(opt_level=0, name="ReluToGelu")
class ReluToGelu:  # pylint: disable=too-few-public-methods
    """Module pass that runs ``ReluRewriter`` over every Relax function."""

    def transform_module(
        self, mod: IRModule, _ctx: tvm.transform.PassContext
    ) -> IRModule:
        """IRModule-level transformation"""
        relu_rewriter = ReluRewriter(mod)
        for global_var, candidate in mod.functions_items():
            # Only Relax functions are rewritten; anything else is left as is.
            if not isinstance(candidate, relax.Function):
                continue
            rewritten = relu_rewriter.visit_expr(candidate)
            relu_rewriter.builder_.update_func(global_var, rewritten)
        # The rewriter's builder holds the updated module.
        return relu_rewriter.builder_.get()


# Apply the custom pass to the un-legalized module exported earlier.
mod = ReluToGelu()(origin_mod)
showmod(mod)

In [None]:
# Legalize the rewritten module, then run the same annotate / fuse-ops /
# fuse-TIR sequence used before the rewrite.
mod = tvm.relax.transform.LegalizeOps()(mod)
mod = tvm.ir.transform.Sequential(
    [
        tvm.relax.transform.AnnotateTIROpPattern(),
        tvm.relax.transform.FuseOps(),
        tvm.relax.transform.FuseTIR(),
    ]
)(mod)
showmod(mod)