# 快速上手 IRModule

In [1]:
from IPython.display import display_svg
import tvm
from tvm.script.parser import ir_module
from tvm.ir.module import IRModule
from tvm.script import tir as T
import numpy as np

In [2]:
# Hand-written TVMScript module with one function, "main", that computes
# B[i] = A[i] + 1.0 element-wise over two float32 buffers of length 8.
@ir_module
class MyModule:
    @T.prim_func
    def main(a: T.handle, b: T.handle):
        # Data is exchanged between functions through handles, which are similar to pointers.
        # Attach the function attributes: the global symbol name and the tir.noalias flag.
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # Create buffers from the handles.
        A = T.match_buffer(a, (8,), dtype="float32")
        B = T.match_buffer(b, (8,), dtype="float32")
        for i in range(8):
            # A block is an abstraction for computation.
            with T.block("B"):
                # Define a spatial block iterator and bind it to the value i.
                vi = T.axis.spatial(8, i)
                B[vi] = A[vi] + 1.0


# Rebind `ir_module` to the decorated class. From this point on the name refers
# to the module object, no longer to the `ir_module` decorator imported above.
ir_module = MyModule
# The recorded output shows this type as tvm.ir.module.IRModule.
print(type(ir_module))
# Print the module rendered back as TVMScript text.
print(ir_module.script())

<class 'tvm.ir.module.IRModule'>
# from tvm.script import tir as T
@tvm.script.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer[8, "float32"], B: T.Buffer[8, "float32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # body
        # with T.block("root")
        for i in T.serial(8):
            with T.block("B"):
                vi = T.axis.spatial(8, i)
                T.reads(A[vi])
                T.writes(B[vi])
                B[vi] = A[vi] + T.float32(1)
    



In [3]:
from tvm import te

# Describe the same add-one computation with tensor expressions (TE).
# A is a 1-D float32 placeholder of length 8.
A = te.placeholder((8,), dtype="float32", name="A")
# B has the same shape; each element is the matching element of A plus 1.0.
B = te.compute((8,), lambda *i: A(*i) + 1.0, name="B")
# Create a PrimFunc from the TE description, with A and B as its parameters.
func = te.create_prim_func([A, B])
# Wrap the function in an IRModule under the name "main" and print it as TVMScript.
ir_module_from_te = IRModule({"main": func})
print(ir_module_from_te.script())

# from tvm.script import tir as T
@tvm.script.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer[8, "float32"], B: T.Buffer[8, "float32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # body
        # with T.block("root")
        for i0 in T.serial(8):
            with T.block("B"):
                i0_1 = T.axis.spatial(8, i0)
                T.reads(A[i0_1])
                T.writes(B[i0_1])
                B[i0_1] = A[i0_1] + T.float32(1)
    



In [4]:
# Look up the function stored under the name "main". As the last expression of
# the cell, its repr is what the notebook displays.
ir_module["main"]

PrimFunc([a, b]) attrs={"global_symbol": "main", "tir.noalias": (bool)1} {
  block root() {
    reads([])
    writes([])
    for (i, 0, 8) {
      block B(iter_var(vi, range(min=0, ext=8))) {
        bind(vi, i)
        reads([A[vi]])
        writes([B[vi]])
        B[vi] = (A[vi] + 1f)
      }
    }
  }
}

In [5]:
# Compile the IRModule for the "llvm" target.
mod = tvm.build(ir_module, target="llvm")  # module for the CPU backend
# The recorded output shows this type as tvm.driver.build_module.OperatorModule.
print(type(mod))

<class 'tvm.driver.build_module.OperatorModule'>


In [6]:
# Input holds the values 0..7 as float32; the output buffer starts zero-filled.
a = tvm.nd.array(np.arange(8, dtype="float32"))
b = tvm.nd.array(np.zeros(8, dtype="float32"))
# Call the compiled module with the input and output arrays, then show both.
mod(a, b)
print(a)
print(b)

[0. 1. 2. 3. 4. 5. 6. 7.]
[1. 2. 3. 4. 5. 6. 7. 8.]


In [7]:
# Create a schedule for the module. The transformed IR is read back later
# through `sch.mod`.
sch = tvm.tir.Schedule(ir_module)

In [8]:
# Get the block by name.
block_b = sch.get_block("B")
# Get the loops surrounding the block.
(i,) = sch.get_loops(block_b)
# Tile the loop nest: split loop `i` into three loops using factors [2, 2, 2].
i_0, i_1, i_2 = sch.split(i, factors=[2, 2, 2])
# Print the scheduled module as TVMScript.
print(sch.mod.script())

# from tvm.script import tir as T
@tvm.script.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer[8, "float32"], B: T.Buffer[8, "float32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # body
        # with T.block("root")
        for i_0, i_1, i_2 in T.grid(2, 2, 2):
            with T.block("B"):
                vi = T.axis.spatial(8, i_0 * 4 + i_1 * 2 + i_2)
                T.reads(A[vi])
                T.writes(B[vi])
                B[vi] = A[vi] + T.float32(1)
    



In [11]:
# Display "main" from the original module again. The recorded output below is
# the un-split single loop, identical to the earlier display of ir_module["main"];
# the split version was printed from `sch.mod`.
ir_module["main"]

PrimFunc([a, b]) attrs={"global_symbol": "main", "tir.noalias": (bool)1} {
  block root() {
    reads([])
    writes([])
    for (i, 0, 8) {
      block B(iter_var(vi, range(min=0, ext=8))) {
        bind(vi, i)
        reads([A[vi]])
        writes([B[vi]])
        B[vi] = (A[vi] + 1f)
      }
    }
  }
}

In [None]:
from tvm import te

# NOTE(review): this cell repeats the earlier TE cell verbatim and has no
# execution count; confirm whether it is still needed.
# A is a 1-D float32 placeholder of length 8.
A = te.placeholder((8,), dtype="float32", name="A")
# B has the same shape; each element is the matching element of A plus 1.0.
B = te.compute((8,), lambda *i: A(*i) + 1.0, name="B")
# Create a PrimFunc from the TE description, with A and B as its parameters.
func = te.create_prim_func([A, B])
# Wrap the function in an IRModule under the name "main" and print it as TVMScript.
ir_module_from_te = IRModule({"main": func})
print(ir_module_from_te.script())