In [None]:
import tvm

import tvm.te as te

from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R


def showmod(mod: tvm.ir.module.IRModule):
    """Pretty-print ``mod`` as TVMScript using a fixed set of display options.

    The printout is black-formatted, hides metadata and object addresses,
    prints expressions verbosely, and annotates struct info everywhere.
    Returns None; output goes through ``mod.show``.
    """
    display_options = dict(
        black_format=True,
        show_meta=False,
        verbose_expr=True,
        show_object_address=False,
        show_all_struct_info=True,
    )
    mod.show(**display_options)


def createandshowmod(ops, name="test"):
    """Wrap TE ops in a single-PrimFunc IRModule, print it, and return it.

    Parameters
    ----------
    ops
        Passed straight to ``te.create_prim_func`` (presumably a list of
        ``te.Tensor`` arguments/outputs -- whatever that API accepts).
    name : str, optional
        Used both as the function's ``global_symbol`` attribute and as its key
        in the module. Defaults to ``"test"``, the previously hard-coded value.

    Returns
    -------
    tvm.IRModule
        The module that was printed.
    """
    prim_func = te.create_prim_func(ops).with_attrs({"global_symbol": name})
    # Keep the module key and global_symbol in sync so the printed name matches.
    mod = tvm.IRModule({name: prim_func})
    showmod(mod)
    return mod


from tvm.ir.analysis import *

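#### collect_call_map

`collect_call_map(module)` maps each function in an `IRModule` to the list of `GlobalVar`s (subroutines in the same module) that it calls.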

In [None]:
dtype = "float32"


# Constructing an End to End IRModule in TVMScript
@tvm.script.ir_module
class Network:
    # T.handle creates a TIR var that represents a pointer.
    @T.prim_func
    def relu0(x: T.handle, y: T.handle):
        # Elementwise ReLU over a (1, n) buffer: Y = max(X, 0).
        n = T.int64()
        X = T.match_buffer(param=x, shape=(1, n), dtype=dtype)
        Y = T.match_buffer(param=y, shape=(1, n), dtype=dtype)
        for i, j in T.grid(1, n):
            with T.block("Y"):
                vi, vj = T.axis.remap(kinds="SS", bindings=[i, j])
                Y[vi, vj] = T.max(X[vi, vj], T.float32(0))

    @T.prim_func
    def linear0(x: T.handle, w: T.handle, b: T.handle, z: T.handle):
        # Dense layer with bias:  Z = X @ W^T + B
        #   X: (1, m), W: (n, m), B: (n,)  ->  Z: (1, n)
        # Y is a scratch (1, n) buffer holding the reduction X @ W^T.
        m, n = T.int64(), T.int64()
        X = T.match_buffer(param=x, shape=(1, m), dtype=dtype)
        W = T.match_buffer(param=w, shape=(n, m), dtype=dtype)
        B = T.match_buffer(param=b, shape=(n,), dtype=dtype)
        Z = T.match_buffer(param=z, shape=(1, n), dtype=dtype)
        Y = T.alloc_buffer(shape=(1, n), dtype=dtype)
        # Reduction over k (the shared `m` axis of X and W).
        for i, j, k in T.grid(1, n, m):
            with T.block("Y"):
                vi, vj, vk = T.axis.remap(kinds="SSR", bindings=[i, j, k])
                with T.init():
                    Y[vi, vj] = T.float32(0)
                Y[vi, vj] = Y[vi, vj] + X[vi, vk] * W[vj, vk]
        # Bias add, broadcasting B along the leading axis.
        for i, j in T.grid(1, n):
            with T.block("Z"):
                vi, vj = T.axis.remap(kinds="SS", bindings=[i, j])
                Z[vi, vj] = B[vj] + Y[vi, vj]

    @R.function
    def main(
        x: R.Tensor((1, "m"), "float32"),
        w0: R.Tensor(("n", "m"), "float32"),
        b0: R.Tensor(("n",), "float32"),
        w1: R.Tensor(("k", "n"), "float32"),
        b1: R.Tensor(("k",), "float32"),
    ):
        # Bring the symbolic shape vars named in the signature into scope so
        # `n` and `k` can be used in the out_sinfo annotations below.
        m, k, n = T.int64(), T.int64(), T.int64()
        # NOTE: `func` is passed as a string, so these calls name external
        # packed functions rather than referencing the `linear0` / `relu0`
        # PrimFuncs of this module as GlobalVars. This is why
        # `collect_call_map` reports no callees for `main`.
        with R.dataflow():
            lv0 = R.call_dps_packed(
                func="linear0", args=(x, w0, b0), out_sinfo=R.Tensor((1, n), "float32")
            )
            # `(lv0,)` is a one-element tuple; `(lv0)` would just be `lv0`.
            lv1 = R.call_dps_packed(
                func="relu0", args=(lv0,), out_sinfo=R.Tensor((1, n), "float32")
            )
            lv2 = R.call_dps_packed(
                func="linear0",
                args=(lv1, w1, b1),
                out_sinfo=R.Tensor((1, k), "float32"),
            )
            R.output(lv2)

        return lv2


mod = Network
showmod(mod)

In [None]:
# collect_call_map: collect the call map of a module.
#
# Parameters
# ----------
# module: tvm.ir.IRModule
#     The module to inspect
#
# Returns
# -------
# call_map: Dict[tvm.ir.GlobalVar, List[tvm.ir.GlobalVar]]
#     A map from functions to the subroutines they call.
#
# In `Network`, `main` uses `R.call_dps_packed` with string function names,
# which are not GlobalVar references, so every entry in the map is empty.
call_map = collect_call_map(mod)
print(call_map)

{I.GlobalVar("relu0"): [], I.GlobalVar("linear0"): [], I.GlobalVar("main"): []}


#### test_collect_relax_to_relax

A Relax function calling another Relax function in the same module.

In [None]:
# Relax -> Relax: `main` calls the Relax function `subroutine` of the same module.
@I.ir_module
class Module:
    @R.function
    def main():
        # `Module.subroutine` refers to the module's own function; the output
        # below shows it recorded as the GlobalVar callee of `main`.
        return Module.subroutine()

    @R.function
    def subroutine():
        # Leaf function: returns an empty Relax tuple and calls nothing.
        return R.tuple()

call_map = collect_call_map(Module)
print(call_map)

{I.GlobalVar("main"): [I.GlobalVar("subroutine")], I.GlobalVar("subroutine"): []}


#### test_collect_relax_to_tir

A Relax function calling a TIR `PrimFunc` in the same module.

In [None]:
# Relax -> TIR: a Relax function calls a TIR PrimFunc of the same module.
@I.ir_module
class Module:
    @R.function
    def main() -> R.Prim("int32"):
        # The int32 argument is wrapped with R.prim_value so it can be passed
        # from Relax to the PrimFunc.
        return Module.subroutine(R.prim_value(T.int32(42)))

    @T.prim_func
    def subroutine(i: T.int32) -> T.int32:
        # Leaf PrimFunc: returns i + 1 and calls nothing.
        return i + 1

call_map = collect_call_map(Module)
print(call_map)

{I.GlobalVar("main"): [I.GlobalVar("subroutine")], I.GlobalVar("subroutine"): []}


#### test_collect_tir_to_tir

A TIR `PrimFunc` calling another `PrimFunc` in the same module.

In [None]:
# TIR -> TIR: one PrimFunc calls another PrimFunc of the same module.
@I.ir_module
class Module:
    @T.prim_func
    def main() -> T.int32:
        # Calls the sibling PrimFunc directly with a Python int literal.
        return Module.subroutine(42)

    @T.prim_func
    def subroutine(i: T.int32) -> T.int32:
        # Leaf PrimFunc: returns i + 1 and calls nothing.
        return i + 1

call_map = collect_call_map(Module)
print(call_map)

{I.GlobalVar("main"): [I.GlobalVar("subroutine")], I.GlobalVar("subroutine"): []}
