In [41]:
import numpy as np
import tvm
from tvm.ir.module import IRModule
from tvm.script import tir as T      # TensorIR 是标准机器学习编译框架 Apache TVM 中使用的张量程序抽象

### 使用张量程序抽象的主要目的是表示循环和相关的硬件加速选择，如多线程、特殊硬件指令的使用和内存访问。

In [42]:
### NumPy reference implementation of mm_relu: C = relu(A @ B)
dtype = "float32"
# Seed the generator so the inputs (and hence every result checked against
# c_mm_relu below) are reproducible across Restart & Run All.
SEED = 0
rng = np.random.default_rng(SEED)
a_np = rng.random((128, 128)).astype(dtype)
b_np = rng.random((128, 128)).astype(dtype)
# a @ b is equivalent to np.matmul(a, b)
c_mm_relu = np.maximum(a_np @ b_np, 0)

In [43]:
### python 

def lnumpy_mm_relu(A: np.ndarray, B: np.ndarray, C: np.ndarray):
    """Loop-level NumPy reference for ``C = relu(A @ B)``, written to mirror TensorIR.

    The explicit scalar loops and the ``k == 0`` initialization deliberately line
    up with the ``T.grid`` loop nest and the ``T.init()`` block of the TVMScript
    ``mm_relu``, so the two can be read side by side.

    Args:
        A: input matrix of shape (n, K).
        B: input matrix of shape (K, m).
        C: preallocated output array of shape (n, m); written in place.

    Raises:
        ValueError: if A or B is not 2-D, or the shapes of A, B and C are inconsistent.
    """
    n, reduce_len = A.shape
    reduce_len_b, m = B.shape
    if reduce_len_b != reduce_len or tuple(C.shape) != (n, m):
        raise ValueError(
            f"incompatible shapes: A{A.shape}, B{B.shape}, C{C.shape}"
        )
    # Intermediate matmul result (float32, like the TensorIR buffer Y).
    # Zero-initialized so the degenerate K == 0 case yields 0 instead of garbage.
    Y = np.zeros((n, m), dtype="float32")
    for i in range(n):
        for j in range(m):
            for k in range(reduce_len):
                if k == 0:
                    # Mirrors T.init(): reset the accumulator on the first reduction step.
                    Y[i, j] = 0
                Y[i, j] = Y[i, j] + A[i, k] * B[k, j]
    # Elementwise ReLU into the caller's output buffer.
    for i in range(n):
        for j in range(m):
            C[i, j] = max(Y[i, j], 0)

# Run the loop-level reference and confirm it agrees with the vectorized NumPy result.
c_np = np.empty(shape=(128, 128), dtype=dtype)
lnumpy_mm_relu(A=a_np, B=b_np, C=c_np)
np.testing.assert_allclose(actual=c_mm_relu, desired=c_np, rtol=1e-5)

In [44]:
# TVMScript version of mm_relu (C = relu(A @ B)), written as a TensorIR prim_func
# inside an IRModule. The block names "Y" and "C" are looked up by the schedule
# cells below (sch.get_block), so they must not be renamed.

@tvm.script.ir_module
class MyModule:
    @T.prim_func
    def mm_relu(A: T.Buffer((128, 128), "float32"),
                B: T.Buffer((128, 128), "float32"),
                C: T.Buffer((128, 128), "float32")):
        # "global_symbol" is the name the compiled function is exported under
        # (it is looked up later as rt_lib["mm_relu"]); "tir.noalias": True
        # declares that the buffer arguments do not overlap in memory.
        T.func_attr({"global_symbol": "mm_relu", "tir.noalias": True})
        # Intermediate buffer holding the matmul result before the ReLU.
        Y = T.alloc_buffer((128, 128), dtype="float32")
        for i, j, k in T.grid(128, 128, 128):
            with T.block("Y"):
                # Long-hand equivalent of the remap below:
                # vi = T.axis.spatial(128, i)   # spatial axis: appears in the output
                # vj = T.axis.spatial(128, j)
                # vk = T.axis.reduce(128, k)    # reduction (summation) axis: does not appear in the output
                # "SSR" gives the kinds of the three axes: spatial, spatial, reduce.
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])  # more concise declaration
                # T.init() zeroes the accumulator at the start of the reduction
                # (the counterpart of `if k == 0` in lnumpy_mm_relu).
                with T.init():
                    Y[vi, vj] = T.float32(0)
                Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
        # Elementwise ReLU: C = max(Y, 0).
        for i, j in T.grid(128, 128):
            with T.block("C"):
                vi = T.axis.spatial(128, i)
                vj = T.axis.spatial(128, j)
                C[vi, vj] = T.max(Y[vi, vj], T.float32(0))

In [45]:
## Another example: one IRModule holding two separate prim_funcs (mm and relu)
## instead of a single fused mm_relu.

@tvm.script.ir_module
class MyModuleWithTwoFunctions:
    # Matrix multiply Y = A @ B; here Y is an output argument, not an internal buffer.
    @T.prim_func
    def mm(A: T.Buffer((128, 128), "float32"),
           B: T.Buffer((128, 128), "float32"),
           Y: T.Buffer((128, 128), "float32")):
        T.func_attr({"global_symbol": "mm", "tir.noalias": True})
        for i, j, k in T.grid(128, 128, 128):
            with T.block("Y"):
                # Axis kinds: spatial, spatial, reduce.
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                # Zero the accumulator at the start of the reduction.
                with T.init():
                    Y[vi, vj] = T.float32(0)
                Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]

    # Elementwise ReLU: B = max(A, 0).
    @T.prim_func
    def relu(A: T.Buffer((128, 128), "float32"),
             B: T.Buffer((128, 128), "float32")):
        T.func_attr({"global_symbol": "relu", "tir.noalias": True})
        for i, j in T.grid(128, 128):
            with T.block("B"):
                # Both axes are spatial; there is no reduction here.
                vi, vj = T.axis.remap("SS", [i, j])
                B[vi, vj] = T.max(A[vi, vj], T.float32(0))

In [46]:
# Import the display submodule explicitly instead of relying on `import IPython`
# having loaded `IPython.display` as a side effect. This still binds the name
# `IPython`, which later cells use.
import IPython.display

# Render the TVMScript of MyModule.
IPython.display.Code(MyModule.script(), language="python")

In [47]:
sch = tvm.tir.Schedule(MyModule)   # Create a Schedule helper that takes MyModule as its input.
block_Y = sch.get_block("Y", func_name="mm_relu")  # Get references to block Y and its surrounding loops.
i, j, k = sch.get_loops(block_Y)  # `k` is reused later by decompose_reduction.

# First transformation: split loop j into two loops whose inner loop has length 4
# (None in `factors` lets TVM infer the outer extent).
j0, j1 = sch.split(j, factors=[None, 4])

IPython.display.Code(sch.mod.script(), language="python")

In [48]:
# Reorder the loops to i, j0, k, j1 so that the length-4 j1 loop becomes innermost.
# Show the result, like the other schedule cells, instead of leaving the display commented out.
sch.reorder(j0, k, j1)
IPython.display.Code(sch.mod.script(), language="python")

In [49]:
# 使用名为 reverse_compute_at 的原语将块 C 移动到 Y 的内循环里。 这一步是融合算子 ！ 
# Fuse the operators: reverse_compute_at moves block C into Y's loop nest under
# loop j0, so the ReLU is computed inside the matmul loops.
block_C = sch.get_block("C", func_name="mm_relu")
sch.reverse_compute_at(block_C, j0)

IPython.display.Code(sch.mod.script(), language="python")

In [50]:
# Separate the initialization of Y's elements from the reduction update.
# NOTE(review): this mutates `sch` in place, so running this cell twice on the same
# schedule will likely fail; re-run from the cell that creates `sch` instead.
sch.decompose_reduction(block_Y, k)
IPython.display.Code(sch.mod.script(), language="python")

## 构建与运行

In [51]:
# Compile the original, unscheduled MyModule with the LLVM (CPU) backend.
rt_lib = tvm.build(MyModule, target="llvm")

In [52]:
# Wrap the NumPy inputs as TVM runtime arrays and allocate an uninitialized output buffer.
a_nd, b_nd = (tvm.nd.array(src) for src in (a_np, b_np))
c_nd = tvm.nd.empty((128, 128), dtype="float32")
type(c_nd)  # tvm.runtime.ndarray.NDArray

tvm.runtime.ndarray.NDArray

In [53]:
# Look up the compiled kernel by its global symbol, run it, and compare with NumPy.
func_mm_relu = rt_lib["mm_relu"]
func_mm_relu(a_nd, b_nd, c_nd)
np.testing.assert_allclose(actual=c_mm_relu, desired=c_nd.numpy(), rtol=1e-5)

In [55]:
# Build the scheduled module and check it against the NumPy reference.
rt_lib_after = tvm.build(sch.mod, target="llvm")
# Use a fresh NaN-filled output buffer. Reusing c_nd would still hold the correct
# result from the previous cell, so a kernel that failed to write its output
# would pass the check unnoticed.
c_nd_after = tvm.nd.array(np.full((128, 128), np.nan, dtype="float32"))
rt_lib_after["mm_relu"](a_nd, b_nd, c_nd_after)
np.testing.assert_allclose(c_mm_relu, c_nd_after.numpy(), rtol=1e-5)

In [62]:
# Create timers for both builds, then report the mean run time (seconds) of each.
f_timer_before = rt_lib.time_evaluator("mm_relu", tvm.cpu())
f_timer_after = rt_lib_after.time_evaluator("mm_relu", tvm.cpu())
print("Time cost of MyModule %g sec" % f_timer_before(a_nd, b_nd, c_nd).mean)
print("Time cost of transformed sch.mod %g sec" % f_timer_after(a_nd, b_nd, c_nd).mean)

Time cost of MyModule 0.00167354 sec
Time cost of transformed sch.mod 0.000434446 sec


# Exercise: 尝试使用不同的 jfactor（内层循环长度）并观察性能变化

In [64]:
def transform(mod, jfactor, func_name="mm_relu", decompose=False):
    """Apply the split / reorder / fuse schedule from the cells above to ``mod``.

    Args:
        mod: IRModule containing ``func_name`` with blocks "Y" (matmul) and "C" (ReLU).
        jfactor: length of the inner loop produced by splitting Y's ``j`` loop.
        func_name: name of the prim_func to schedule (default "mm_relu").
        decompose: if True, also separate Y's initialization from its reduction
            update with ``decompose_reduction``, as done interactively above.
            The default False keeps the original behavior.

    Returns:
        The transformed IRModule (``sch.mod``); ``mod`` itself is left unchanged.
    """
    sch = tvm.tir.Schedule(mod)
    block_Y = sch.get_block("Y", func_name=func_name)
    _, j, k = sch.get_loops(block_Y)
    j0, j1 = sch.split(j, factors=[None, jfactor])
    # Loop order becomes i, j0, k, j1: the jfactor-long j1 loop is innermost.
    sch.reorder(j0, k, j1)
    # Fuse the ReLU into the matmul by computing block C under loop j0.
    block_C = sch.get_block("C", func_name=func_name)
    sch.reverse_compute_at(block_C, j0)
    if decompose:
        sch.decompose_reduction(block_Y, k)
    return sch.mod

# Sweep the inner-loop length jfactor. Verify each transformed module before timing it.
for jfactor in [1, 2, 4, 8, 16, 32, 64]:
    print("jfactor = ", jfactor)
    mod_transformed = transform(MyModule, jfactor=jfactor)

    # Pass the target by keyword, as the earlier build cells do. In the
    # tvm.build(inputs, args=None, target=None, ...) API used here, a positional
    # "llvm" lands in `args` and only works because llvm is the fallback target.
    rt_lib_transformed = tvm.build(mod_transformed, target="llvm")

    # Check correctness before timing, using a NaN-filled buffer so that a kernel
    # that leaves its output unwritten cannot pass on stale data.
    c_nd_check = tvm.nd.array(np.full((128, 128), np.nan, dtype="float32"))
    rt_lib_transformed["mm_relu"](a_nd, b_nd, c_nd_check)
    np.testing.assert_allclose(c_mm_relu, c_nd_check.numpy(), rtol=1e-5)

    f_timer_transformed = rt_lib_transformed.time_evaluator("mm_relu", tvm.cpu())
    print("Time cost of transformed mod_transformed %g sec" % f_timer_transformed(a_nd, b_nd, c_nd).mean)

jfactor =  1
Time cost of transformed mod_transformed 0.00155437 sec
jfactor =  2
Time cost of transformed mod_transformed 0.000888377 sec
jfactor =  4
Time cost of transformed mod_transformed 0.000441852 sec
jfactor =  8
Time cost of transformed mod_transformed 0.000223969 sec
jfactor =  16
Time cost of transformed mod_transformed 0.000178636 sec
jfactor =  32
Time cost of transformed mod_transformed 0.000169721 sec
jfactor =  64
Time cost of transformed mod_transformed 0.000226198 sec


## 更高层级的抽象 

### 使用张量表达式（TE）生成 TensorIR 函数

In [66]:
from tvm import te

# Describe mm_relu declaratively with Tensor Expressions (TE):
# Y = A @ B as a sum over a reduction axis, then C = relu(Y).
A = te.placeholder((128, 128), "float32", name="A")
B = te.placeholder((128, 128), "float32", name="B")
# Bound to `k_axis` rather than `k` so this cell does not overwrite the schedule
# loop reference `k` used by `sch.decompose_reduction(block_Y, k)`. The axis is
# still named "k" in the generated TensorIR.
k_axis = te.reduce_axis((0, 128), "k")
Y = te.compute((128, 128), lambda i, j: te.sum(A[i, k_axis] * B[k_axis, j], axis=k_axis), name="Y")
C = te.compute((128, 128), lambda i, j: te.max(Y[i, j], 0), name="C")

In [67]:
# Turn the TE description into a TensorIR prim_func exported as "mm_relu",
# wrap it in an IRModule, and show the generated TVMScript.
te_func = te.create_prim_func([A, B, C])
te_func = te_func.with_attr({"global_symbol": "mm_relu"})
MyModuleFromTE = tvm.IRModule({"mm_relu": te_func})
IPython.display.Code(MyModuleFromTE.script(), language="python")