In [91]:
import tvm

import tvm.te as te
import tvm.relax as rx
import tvm.tir as tir

from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R

from tvm.relax.binding_rewrite import DataflowBlockRewrite
from tvm.relax.analysis import name_to_binding


def showmod(mod: tvm.ir.module.IRModule):
    """Print `mod` through its `show` method using a fixed set of display options.

    Parameters
    ----------
    mod : tvm.ir.module.IRModule
        The module to display.
    """
    # Keep all display switches in one place so they are easy to scan and tweak.
    display_options = {
        "black_format": True,
        "show_meta": False,
        "verbose_expr": True,
        "show_object_address": False,
        "show_all_struct_info": True,
    }
    mod.show(**display_options)


def createandshowmod(ops):
    """Build a PrimFunc from `ops`, wrap it in an IRModule and display it.

    `ops` is handed unchanged to `te.create_prim_func`. The resulting function
    gets the attribute `global_symbol = "test"` and is stored in the module
    under the key `"test"` before being shown with `showmod`.
    """
    prim_func = te.create_prim_func(ops)
    prim_func = prim_func.with_attrs({"global_symbol": "test"})
    showmod(tvm.IRModule({"test": prim_func}))

In [92]:
# NOTE: A `pipeline` is a sequence of `transform.Pass` objects.
# The file `pipeline.py` predefines several pipelines:
# - `zero_pipeline` is the basic pipeline; it applies some fundamental passes.
# - `default_build_pipeline` is the default pipeline used by `tvm.compile`.
# - `static_shape_tuning_pipeline` is for tuning models with static shapes.

def test_zero_pipeline():
    """Display a small Relax module before and after applying `zero_pipeline`."""

    @tvm.script.ir_module
    class MyModule:
        @R.function
        def matmul(
            x: R.Tensor((128, 128), "float32"),
            y: R.Tensor((128, 128), "float32"),
        ):
            z = R.matmul(x, y)
            return z

        @R.function
        def relu(z: R.Tensor((128, 128), "float32")):
            z_relu = R.nn.relu(z)
            return z_relu

    # First show the module exactly as written above.
    showmod(MyModule)

    from tvm.relax.pipeline import zero_pipeline

    # `zero_pipeline` will lower (legalize) each relax.function to a tir.prim_func.
    lowered = zero_pipeline()(MyModule)
    showmod(lowered)


test_zero_pipeline()

In [93]:
# TODO @Benkangpeng: Finish the remaining tests after reading the source code of transform.py.

# Element type shared by every buffer in the TIR kernels below.
dtype = "float32"


@tvm.script.ir_module
class Network:
    # A linear -> relu -> linear network with symbolic dimensions.
    # `relu0` and `linear0` are TIR kernels; `main` is the Relax function that
    # refers to them by name through `R.call_dps_packed`.

    # T.handle creates a TIR var that represents a pointer.
    @T.prim_func
    def relu0(x: T.handle, y: T.handle):
        # Element-wise ReLU over a (1, n) buffer: Y = max(X, 0).
        n = T.int64()
        X = T.match_buffer(param=x, shape=(1, n), dtype=dtype)
        Y = T.match_buffer(param=y, shape=(1, n), dtype=dtype)
        for i, j in T.grid(1, n):
            with T.block("Y"):
                vi, vj = T.axis.remap(kinds="SS", bindings=[i, j])
                Y[vi, vj] = T.max(X[vi, vj], T.float32(0))

    @T.prim_func
    def linear0(x: T.handle, w: T.handle, b: T.handle, z: T.handle):
        # Linear layer with symbolic sizes m (input width) and n (output width):
        #   X: (1, m)   W: (n, m)   B: (n,)   Y, Z: (1, n)
        #   Y = X @ W^T   (reduction over m)
        #   Z = Y + B
        m, n = T.int64(), T.int64()
        X = T.match_buffer(param=x, shape=(1, m), dtype=dtype)
        W = T.match_buffer(param=w, shape=(n, m), dtype=dtype)
        B = T.match_buffer(param=b, shape=(n,), dtype=dtype)
        Z = T.match_buffer(param=z, shape=(1, n), dtype=dtype)
        # Intermediate buffer holding the matmul result before the bias is added.
        Y = T.alloc_buffer(shape=(1, n), dtype=dtype)
        for i, j, k in T.grid(1, n, m):
            with T.block("Y"):
                # i, j are spatial axes; k is the reduction axis.
                vi, vj, vk = T.axis.remap(kinds="SSR", bindings=[i, j, k])
                with T.init():
                    Y[vi, vj] = T.float32(0)
                Y[vi, vj] = Y[vi, vj] + X[vi, vk] * W[vj, vk]
        # Use dedicated iterator names here so the loop variables do not shadow
        # the symbolic shape variables `m` and `n` declared above.
        for i, j in T.grid(1, n):
            with T.block("Z"):
                vi, vj = T.axis.remap(kinds="SS", bindings=[i, j])
                Z[vi, vj] = B[vj] + Y[vi, vj]

    @R.function
    def main(
        x: R.Tensor((1, "m"), "float32"),
        w0: R.Tensor(("n", "m"), "float32"),
        b0: R.Tensor(("n",), "float32"),
        w1: R.Tensor(("k", "n"), "float32"),
        b1: R.Tensor(("k",), "float32"),
    ):
        # Symbolic dimensions referenced by `out_sinfo` below; presumably these
        # resolve to the "m"/"n"/"k" named in the parameter annotations
        # (TODO confirm against the TVMScript Relax parser).
        m, k, n = T.int64(), T.int64(), T.int64()
        # NOTE(review): `R.call_dps_packed` names its callee with a string. If
        # the intent is to call the PrimFuncs defined in this module directly,
        # `R.call_tir(cls.linear0, ...)` may be the better fit -- confirm.
        with R.dataflow():
            lv0 = R.call_dps_packed(
                func="linear0",
                args=(x, w0, b0),
                out_sinfo=R.Tensor((1, n), "float32"),
            )
            lv1 = R.call_dps_packed(
                func="relu0",
                args=(lv0,),
                out_sinfo=R.Tensor((1, n), "float32"),
            )
            lv2 = R.call_dps_packed(
                func="linear0",
                args=(lv1, w1, b1),
                out_sinfo=R.Tensor((1, k), "float32"),
            )
            R.output(lv2)
        return lv2


def test_default_build_pipeline():
    """Display `Network` before and after applying `default_build_pipeline`."""
    # Module as authored.
    showmod(Network)

    from tvm.relax.pipeline import default_build_pipeline

    # Instantiate the pipeline pass, run it on the module, and show the result.
    build_pass = default_build_pipeline()
    showmod(build_pass(Network))


test_default_build_pipeline()

In [94]:
def test_static_shape_tuning_pipeline():
    """Display `Network` before and after a hand-built three-pass pipeline.

    The pipeline runs `DecomposeOpsForInference`, `CanonicalizeBindings` and
    `zero_pipeline` in sequence under an "llvm" target context.
    """
    showmod(Network)

    # The following pass is used in static_shape_tuning_pipeline().
    @tvm.transform.module_pass(opt_level=0)
    def _pipeline(
        mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext
    ) -> tvm.ir.IRModule:
        with tvm.target.Target(target="llvm"):
            stages = [
                rx.transform.DecomposeOpsForInference(),
                rx.transform.CanonicalizeBindings(),
                rx.pipeline.zero_pipeline(),
            ]
            sequence = tvm.transform.Sequential(stages)
            return sequence(mod)

    showmod(_pipeline(Network))


test_static_shape_tuning_pipeline()