# TensorIR 实践

## 1. 准备阶段

In [1]:
!python3 -m pip install mlc-ai-nightly -f https://mlc.ai/wheels

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in links: https://mlc.ai/wheels
Collecting mlc-ai-nightly
  Downloading https://github.com/mlc-ai/utils/releases/download/v0.9.dev0/mlc_ai_nightly-0.9.dev1664%2Bg1f3985de0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (43.3 MB)
[K     |████████████████████████████████| 43.3 MB 501 kB/s 
Collecting synr==0.6.0
  Downloading synr-0.6.0-py3-none-any.whl (18 kB)
Installing collected packages: synr, mlc-ai-nightly
Successfully installed mlc-ai-nightly-0.9.dev1664+g1f3985de0 synr-0.6.0


In [2]:
import IPython
import numpy as np
import tvm
from tvm.ir.module import IRModule
from tvm.script import tir as T

In [3]:
import IPython

def code2html(code):
    """Render a Python source string as syntax-highlighted HTML.

    The returned string starts with a <style> element holding the pygments
    CSS for the ``.highlight`` class, followed by the highlighted markup and a
    trailing newline. It is intended to be wrapped in ``IPython.display.HTML``.
    """
    from pygments import highlight
    from pygments.formatters import HtmlFormatter
    from pygments.lexers import Python3Lexer

    fmt = HtmlFormatter()
    body = highlight(code, Python3Lexer(), fmt)
    css = fmt.get_style_defs(".highlight")
    return "<style>%s</style>%s\n" % (css, body)

## 练习 1 广播加法

In [11]:
# Inputs for the broadcast-add exercise. The dtype is pinned to int64 because
# MyAdd declares int64 buffers, while np.arange's default integer type is
# platform dependent (e.g. int32 on Windows with older NumPy).
a = np.arange(16, dtype=np.int64).reshape(4, 4)
print(a)
b = np.arange(4, 0, -1, dtype=np.int64)  # [4, 3, 2, 1], shape (4,)
print(b[0])

[[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]
 [12 13 14 15]]
4


In [5]:
# numpy reference: b (shape (4,)) is broadcast across every row of a.
c_np = np.add(a, b)
c_np

array([[ 4,  4,  4,  4],
       [ 8,  8,  8,  8],
       [12, 12, 12, 12],
       [16, 16, 16, 16]])

In [13]:
@tvm.script.ir_module
class MyAdd:
  # Broadcast add: C[i, j] = A[i, j] + B[j], i.e. the 1-D B is added to every
  # row of the 4x4 A (mirrors numpy broadcasting of `a + b`).
  @T.prim_func
  def add(A:T.Buffer[(4,4), "int64"],
          B:T.Buffer[(4,), "int64"],
          C:T.Buffer[(4,4), "int64"]):
    # global_symbol is the name looked up in the built library (rt_lib["add"]);
    # tir.noalias declares that A, B and C do not alias each other.
    T.func_attr({"global_symbol": "add", "tir.noalias": True})
    for i, j in T.grid(4,4):
        with T.block("C"):
          # Both block iterators are spatial: each (vi, vj) writes one C element.
          vi = T.axis.spatial(4, i)
          vj = T.axis.spatial(4, j)
          C[vi, vj] = A[vi, vj] + B[vj]
    

# Build MyAdd for the CPU and check it against the numpy reference.
rt_lib = tvm.build(MyAdd, target="llvm")
a_tvm = tvm.nd.array(a)
b_tvm = tvm.nd.array(b)
c_tvm = tvm.nd.array(np.empty((4, 4), dtype=np.int64))
rt_lib["add"](a_tvm, b_tvm, c_tvm)
# Both sides are int64, so require exact equality rather than a tolerance.
np.testing.assert_array_equal(c_tvm.numpy(), c_np)

## 练习 2 二维卷积

In [54]:
# Problem size: batch N, input channels CI, input H x W, output channels CO,
# square K x K kernel (stride 1, no padding).
N, CI, H, W, CO, K = 1, 1, 8, 8, 2, 3
OUT_H, OUT_W = H - K + 1, W - K + 1
print(OUT_H)
print(OUT_W)
# int64 explicitly: the TVM kernels below declare int64 buffers, and
# np.arange's default integer width is platform dependent.
data = np.arange(N * CI * H * W, dtype=np.int64).reshape(N, CI, H, W)
print(data)
print(data.shape)
weight = np.arange(CO * CI * K * K, dtype=np.int64).reshape(CO, CI, K, K)
print(weight)
print(weight.shape)

# Output buffer shaped from the constants above instead of hard-coded sizes.
out = np.zeros((N, CO, OUT_H, OUT_W))
print(out)
print(out.shape)

6
6
[[[[ 0  1  2  3  4  5  6  7]
   [ 8  9 10 11 12 13 14 15]
   [16 17 18 19 20 21 22 23]
   [24 25 26 27 28 29 30 31]
   [32 33 34 35 36 37 38 39]
   [40 41 42 43 44 45 46 47]
   [48 49 50 51 52 53 54 55]
   [56 57 58 59 60 61 62 63]]]]
(1, 1, 8, 8)
[[[[ 0  1  2]
   [ 3  4  5]
   [ 6  7  8]]]


 [[[ 9 10 11]
   [12 13 14]
   [15 16 17]]]]
(2, 1, 3, 3)
[[[[0. 0. 0. 0. 0. 0.]
   [0. 0. 0. 0. 0. 0.]
   [0. 0. 0. 0. 0. 0.]
   [0. 0. 0. 0. 0. 0.]
   [0. 0. 0. 0. 0. 0.]
   [0. 0. 0. 0. 0. 0.]]

  [[0. 0. 0. 0. 0. 0.]
   [0. 0. 0. 0. 0. 0.]
   [0. 0. 0. 0. 0. 0.]
   [0. 0. 0. 0. 0. 0.]
   [0. 0. 0. 0. 0. 0.]
   [0. 0. 0. 0. 0. 0.]]]]
(1, 2, 6, 6)


In [None]:
def convol(A: np.ndarray, B: np.ndarray, C: np.ndarray):
    """Direct 2-D convolution (stride 1, no padding) written with explicit loops.

    Computes ``C[b, co, y, x] = sum_{ci, m, n} A[b, ci, y + m, x + n] * B[co, ci, m, n]``
    and stores it into ``C`` in place; every element of ``C`` is overwritten.

    Args:
        A: input of shape (batch, in_channels, H, W).
        B: weights of shape (out_channels, in_channels, KH, KW).
        C: output of shape (batch, out_channels, H - KH + 1, W - KW + 1).

    Raises:
        ValueError: if the shapes of A, B and C are inconsistent.
    """
    batch, in_ch, height, width = A.shape
    out_ch, w_in_ch, kh, kw = B.shape
    out_h, out_w = height - kh + 1, width - kw + 1
    if w_in_ch != in_ch or tuple(C.shape) != (batch, out_ch, out_h, out_w):
        raise ValueError(
            f"inconsistent shapes: A{A.shape}, B{B.shape}, C{C.shape}")

    for b in range(batch):
        for co in range(out_ch):
            for y in range(out_h):
                for x in range(out_w):
                    acc = 0
                    # Reduce over the input channels as well as the kernel
                    # window (the channel index must not reuse the batch index).
                    for ci in range(in_ch):
                        for m in range(kh):
                            for n in range(kw):
                                acc += A[b, ci, y + m, x + n] * B[co, ci, m, n]
                    C[b, co, y, x] = acc

                
           

convol(data, weight, out)
print(out, out.shape, sep="\n")

In [None]:
out1 = np.zeros((N, CO, OUT_H, OUT_W))  # zeroed, since convol_1 accumulates into it

def convol_1(A: np.ndarray, B: np.ndarray, C: np.ndarray):
    """Direct 2-D convolution (stride 1, no padding) that accumulates into ``C``.

    Adds ``sum_{ci, m, n} A[b, ci, y + m, x + n] * B[co, ci, m, n]`` to
    ``C[b, co, y, x]`` in place. ``C`` is not reset first, so it must be
    zero-initialised to obtain a plain convolution.

    Args:
        A: input of shape (batch, in_channels, H, W).
        B: weights of shape (out_channels, in_channels, KH, KW).
        C: accumulator of shape (batch, out_channels, H - KH + 1, W - KW + 1).

    Raises:
        ValueError: if the shapes of A, B and C are inconsistent.
    """
    batch, in_ch, height, width = A.shape
    out_ch, w_in_ch, kh, kw = B.shape
    out_h, out_w = height - kh + 1, width - kw + 1
    if w_in_ch != in_ch or tuple(C.shape) != (batch, out_ch, out_h, out_w):
        raise ValueError(
            f"inconsistent shapes: A{A.shape}, B{B.shape}, C{C.shape}")

    for b in range(batch):
        for co in range(out_ch):
            for y in range(out_h):
                for x in range(out_w):
                    # Reduce over input channels too; the channel index is ci,
                    # not the batch index.
                    for ci in range(in_ch):
                        for m in range(kh):
                            for n in range(kw):
                                C[b, co, y, x] += A[b, ci, y + m, x + n] * B[co, ci, m, n]
                  
              

convol_1(data, weight, out1)
print(out1, out1.shape, sep="\n")

In [31]:
# Reference result from PyTorch.
# torch.Tensor(...) converts the int64 arrays to float32. For the small
# non-negative integers used here the float32 results are exact; np.rint makes
# the cast back to int64 robust to any float rounding anyway (astype truncates).
import torch

data_torch = torch.Tensor(data)
weight_torch = torch.Tensor(weight)
conv_torch = torch.nn.functional.conv2d(data_torch, weight_torch)
conv_torch = np.rint(conv_torch.numpy()).astype(np.int64)
print(conv_torch)
print(conv_torch.shape)

[[[[ 474  510  546  582  618  654]
   [ 762  798  834  870  906  942]
   [1050 1086 1122 1158 1194 1230]
   [1338 1374 1410 1446 1482 1518]
   [1626 1662 1698 1734 1770 1806]
   [1914 1950 1986 2022 2058 2094]]

  [[1203 1320 1437 1554 1671 1788]
   [2139 2256 2373 2490 2607 2724]
   [3075 3192 3309 3426 3543 3660]
   [4011 4128 4245 4362 4479 4596]
   [4947 5064 5181 5298 5415 5532]
   [5883 6000 6117 6234 6351 6468]]]]
(1, 2, 6, 6)


In [52]:
@tvm.script.ir_module
class MyConv:
  # 2-D convolution (stride 1, no padding) with int64 buffers:
  # C[b, co, y, x] = sum_{ci, m, n} A[b, ci, y + m, x + n] * B[co, ci, m, n]
  @T.prim_func
  def conv(A:T.Buffer[(1, 1, 8, 8),"int64"],
        B:T.Buffer[(2, 1, 3, 3),"int64"],
        C:T.Buffer[(1, 2, 6, 6),"int64"]):
    T.func_attr({"global_symbol": "conv", "tir.noalias": True})
    # Spatial axes: batch i, out-channel j, output row k, output column l.
    # Reduction axes: input channel c and the 3x3 window (m, n). The channel
    # axis is explicit so A's channel is not indexed with the batch iterator.
    for i, j, k, l, c, m, n in T.grid(1, 2, 6, 6, 1, 3, 3):
      with T.block("C"):
        vi, vj, vk, vl, vc, vm, vn = T.axis.remap("SSSSRRR", [i, j, k, l, c, m, n])
        with T.init():
          C[vi, vj, vk, vl] = T.int64(0)
        C[vi, vj, vk, vl] = C[vi, vj, vk, vl] + A[vi, vc, vk + vm, vl + vn] * B[vj, vc, vm, vn]



# Build MyConv and check it against the PyTorch reference.
rt_lib = tvm.build(MyConv, target="llvm")
data_tvm = tvm.nd.array(data)
weight_tvm = tvm.nd.array(weight)
conv_tvm = tvm.nd.array(np.empty((N, CO, OUT_H, OUT_W), dtype=np.int64))
rt_lib["conv"](data_tvm, weight_tvm, conv_tvm)
print(conv_tvm)
# Both sides are int64, so require exact equality rather than a tolerance.
np.testing.assert_array_equal(conv_tvm.numpy(), conv_torch)

[[[[ 474  510  546  582  618  654]
   [ 762  798  834  870  906  942]
   [1050 1086 1122 1158 1194 1230]
   [1338 1374 1410 1446 1482 1518]
   [1626 1662 1698 1734 1770 1806]
   [1914 1950 1986 2022 2058 2094]]

  [[1203 1320 1437 1554 1671 1788]
   [2139 2256 2373 2490 2607 2724]
   [3075 3192 3309 3426 3543 3660]
   [4011 4128 4245 4362 4479 4596]
   [4947 5064 5181 5298 5415 5532]
   [5883 6000 6117 6234 6351 6468]]]]


**用 Python 变量（N、CI、H 等）作为形状参数的 T.Buffer，目前的 TVMScript 似乎还不支持（下面的单元格未运行）**

In [None]:
@tvm.script.ir_module
class MyConv:
  # Same convolution as the fixed-size kernel, but with buffer shapes and loop
  # extents taken from the Python constants N, CI, H, W, CO, K, OUT_H, OUT_W.
  # Note that this re-defines (shadows) MyConv.
  # NOTE(review): this cell was never executed (In [None]); whether this
  # TVMScript parser version resolves outer Python names here is unverified.
  @T.prim_func
  def conv(A:T.Buffer[(N, CI, H, W),"int64"],
        B:T.Buffer[(CO, CI, K, K),"int64"],
        C:T.Buffer[(N, CO, OUT_H, OUT_W),"int64"]):
    T.func_attr({"global_symbol": "conv", "tir.noalias": True})
    # Reduce over the input channel c as well as the K x K window; indexing
    # A's channel with the batch iterator would only be right when N == CI == 1.
    for i, j, k, l, c, m, n in T.grid(N, CO, OUT_H, OUT_W, CI, K, K):
      with T.block("C"):
        vi, vj, vk, vl, vc, vm, vn = T.axis.remap("SSSSRRR", [i, j, k, l, c, m, n])
        with T.init():
          C[vi, vj, vk, vl] = T.int64(0)
        C[vi, vj, vk, vl] = C[vi, vj, vk, vl] + A[vi, vc, vk + vm, vl + vn] * B[vj, vc, vm, vn]



# Build the parametrized MyConv and check it against the PyTorch reference.
rt_lib = tvm.build(MyConv, target="llvm")
data_tvm = tvm.nd.array(data)
weight_tvm = tvm.nd.array(weight)
conv_tvm = tvm.nd.array(np.empty((N, CO, OUT_H, OUT_W), dtype=np.int64))
rt_lib["conv"](data_tvm, weight_tvm, conv_tvm)
print(conv_tvm)
# Both sides are int64, so require exact equality rather than a tolerance.
np.testing.assert_array_equal(conv_tvm.numpy(), conv_torch)

## 练习 3 变换批量矩阵乘法程序

In [4]:
def lnumpy_mm_relu_v2(A: np.ndarray, B: np.ndarray, C: np.ndarray):
    """Loop-level numpy reference for batched matmul followed by ReLU.

    For every batch n computes ``Y[n] = A[n] @ B[n]`` and then
    ``C[n, i, j] = max(Y[n, i, j], 0)``, writing ``C`` in place. The loop nest
    mirrors the TVMScript ``bmm_relu`` block: Y is reset when k == 0 (the
    T.init part), then accumulated over k.

    Shapes: A is (batch, M, K), B is (batch, K, N), C is (batch, M, N).
    The intermediate Y is float32, matching the float32 buffers of the TVM
    kernel. Extents are taken from the inputs, so (16, 128, 128) inputs behave
    exactly as before while smaller inputs also work.
    """
    batch, rows, red = A.shape
    cols = B.shape[2]
    Y = np.empty((batch, rows, cols), dtype="float32")
    for n in range(batch):
        for i in range(rows):
            for j in range(cols):
                for k in range(red):
                    if k == 0:
                        Y[n, i, j] = 0
                    Y[n, i, j] = Y[n, i, j] + A[n, i, k] * B[n, k, j]
    for n in range(batch):
        for i in range(rows):
            for j in range(cols):
                C[n, i, j] = max(Y[n, i, j], 0)

**mm_relu 的TVM实现**

In [49]:
@tvm.script.ir_module
class MyModule:
    # Single (non-batched) matmul + ReLU: C = relu(A @ B) on 128x128 float32.
    @T.prim_func
    def mm_relu(A: T.Buffer[(128, 128), "float32"],
                B: T.Buffer[(128, 128), "float32"],
                C: T.Buffer[(128, 128), "float32"]):
        T.func_attr({"global_symbol": "mm_relu", "tir.noalias": True})
        # Intermediate buffer holding the matmul result before the ReLU.
        Y = T.alloc_buffer((128, 128), dtype="float32")
        for i, j, k in T.grid(128, 128, 128):
            with T.block("Y"):
                # vi, vj index the output element; vk is the reduction axis.
                vi = T.axis.spatial(128, i)
                vj = T.axis.spatial(128, j)
                vk = T.axis.reduce(128, k)
                # Reset Y[vi, vj] before the reduction starts (the role of
                # `if k == 0` in the loop-level numpy version).
                with T.init():
                    Y[vi, vj] = T.float32(0)
                Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
        for i, j in T.grid(128, 128):
            with T.block("C"):
                vi = T.axis.spatial(128, i)
                vj = T.axis.spatial(128, j)
                # ReLU: clamp negative values to zero.
                C[vi, vj] = T.max(Y[vi, vj], T.float32(0))

**Bmm_relu 的TVM实现**

In [48]:
@tvm.script.ir_module
class MyBmmRelu:
  # Batched matmul + ReLU: for each of the 16 batches n,
  # C[n] = relu(A[n] @ B[n]); same computation as lnumpy_mm_relu_v2.
  @T.prim_func
  def bmm_relu(A: T.Buffer[(16, 128, 128), "float32"],
               B: T.Buffer[(16, 128, 128), "float32"],
               C: T.Buffer[(16, 128, 128), "float32"]):
    T.func_attr({"global_symbol": "bmm_relu", "tir.noalias": True})
    # Intermediate matmul result; no dtype given, so the default (float32) applies.
    Y = T.alloc_buffer((16, 128, 128))
    for n ,i, j, k in T.grid(16, 128, 128, 128):
      with T.block("Y"):
        # "SSSR": n, i, j are spatial axes, k is the reduction axis.
        vn, vi, vj ,vk = T.axis.remap("SSSR", [n, i, j, k])
        with T.init():
            Y[vn, vi, vj] = T.float32(0)
        Y[vn, vi, vj] = Y[vn, vi, vj] + A[vn, vi ,vk] * B[vn, vk, vj]
    
    for n, i, j in T.grid(16, 128, 128):
      with T.block("C"):
        vn, vi ,vj = T.axis.remap("SSS", [n, i, j])
        # ReLU: clamp negative values to zero.
        C[vn, vi, vj]  = T.max(Y[vn, vi, vj], T.float32(0))
    

sch = tvm.tir.Schedule(MyBmmRelu)
# Show the untransformed module as highlighted HTML (last expression of the cell).
IPython.display.HTML(code2html(sch.mod.script()))
# The module is validated against the numpy reference in the test cell below.

**测试所写的 TVM 版本是否正确**

In [50]:
# Validate the TVMScript bmm_relu against the loop-level numpy reference.
# The two outputs use separate buffers on purpose. Aliasing them
# (c_2 = c_1) would pre-fill the TVM output buffer with the reference answer,
# so any element the kernel failed to write would still "pass".
# Both implementations receive the same float32 inputs.
# Note: the pure-Python reference runs 16*128*128*128 (~33.5M) inner
# iterations, so it is slow.
rng = np.random.default_rng(0)
a = rng.random((16, 128, 128)).astype("float32")
b = rng.random((16, 128, 128)).astype("float32")
c_1 = np.empty((16, 128, 128), dtype="float32")
lnumpy_mm_relu_v2(a, b, c_1)

before_rt_lib = tvm.build(MyBmmRelu, target="llvm")
a_tvm = tvm.nd.array(a)
b_tvm = tvm.nd.array(b)
c_tvm = tvm.nd.array(np.empty((16, 128, 128), dtype="float32"))
before_rt_lib["bmm_relu"](a_tvm, b_tvm, c_tvm)

np.testing.assert_allclose(c_tvm.numpy(), c_1, rtol=1e-5)

**并行化 向量化 与循环展开**

In [57]:
@tvm.script.ir_module
class MyBmmRelu:
  # NOTE(review): identical re-definition of MyBmmRelu from the previous
  # section; presumably repeated so this section can be run on its own.
  # Computes C[n] = relu(A[n] @ B[n]) for each of the 16 batches.
  @T.prim_func
  def bmm_relu(A: T.Buffer[(16, 128, 128), "float32"],
               B: T.Buffer[(16, 128, 128), "float32"],
               C: T.Buffer[(16, 128, 128), "float32"]):
    T.func_attr({"global_symbol": "bmm_relu", "tir.noalias": True})
    # Intermediate matmul result (default dtype, float32).
    Y = T.alloc_buffer((16, 128, 128))
    for n ,i, j, k in T.grid(16, 128, 128, 128):
      with T.block("Y"):
        # n, i, j spatial; k is the reduction axis.
        vn, vi, vj ,vk = T.axis.remap("SSSR", [n, i, j, k])
        with T.init():
            Y[vn, vi, vj] = T.float32(0)
        Y[vn, vi, vj] = Y[vn, vi, vj] + A[vn, vi ,vk] * B[vn, vk, vj]
    
    for n, i, j in T.grid(16, 128, 128):
      with T.block("C"):
        vn, vi ,vj = T.axis.remap("SSS", [n, i, j])
        # ReLU: clamp negative values to zero.
        C[vn, vi, vj]  = T.max(Y[vn, vi, vj], T.float32(0))
    

sch = tvm.tir.Schedule(MyBmmRelu)
# Show the module before any transformation (last expression of the cell).
IPython.display.HTML(code2html(sch.mod.script()))

In [126]:
# Transform MyBmmRelu step by step so that it matches TargetModule (next cell).
sch = tvm.tir.Schedule(MyBmmRelu)

# Step 1. Get blocks: start from the matmul block "Y".
Y = sch.get_block("Y", func_name="bmm_relu")

# Step 2. Get loops (batch n, row i, column j, reduction k).
n, i, j, k = sch.get_loops(Y)

# Step 3. Organize the loops.
# In the target, n and i are unchanged and j (128) is split into 16 x 8.
j0, j1 = sch.split(j, factors = [None, 8])
# Reorder to the target nest n > i > j0 > k > j1.
sch.reorder(n, i, j0, k, j1)
# Re-fetch the loop handles in their new order.
n, i, j0, k, j1 = sch.get_loops(Y)
# Parallelize the outer batch loop.
sch.parallel(n)

# Move block C (the ReLU) under loop j0 of Y, so each 8-wide chunk of a row is
# rectified right after it has been computed.
C = sch.get_block("C", "bmm_relu")
sch.reverse_compute_at(C, j0)

# Step 4. Decompose the reduction: split Y into "Y_init" (zeroing) and
# "Y_update" (accumulation).
sch.decompose_reduction(Y, k)

# Step 5. vectorize / parallel / unroll

# Y_init: vectorize its innermost loop (index 3: n, i, j0, then the 8-wide init loop).
Y_init = sch.get_block("Y_init", "bmm_relu")
ax0_init = sch.get_loops(Y_init)
sch.vectorize(ax0_init[3])

# C: vectorize its innermost (8-wide) loop.
C = sch.get_block("C", "bmm_relu")
ax0 = sch.get_loops(C)
sch.vectorize(ax0[3])

# Y_update: split the reduction loop k into 32 x 4 and unroll the inner 4.
# j1 is intentionally left serial; the target keeps it as T.serial(8).
k0, k1 = sch.split(k, factors = [None, 4])
sch.unroll(k1)

IPython.display.HTML(code2html(sch.mod.script()))

In [127]:
@tvm.script.ir_module
class TargetModule:
    # Expected result of the schedule in the previous cell, written out by hand;
    # the structural-equality check at the end of this cell compares the two.
    @T.prim_func
    def bmm_relu(A: T.Buffer[(16, 128, 128), "float32"], B: T.Buffer[(16, 128, 128), "float32"], C: T.Buffer[(16, 128, 128), "float32"]) -> None:
        T.func_attr({"global_symbol": "bmm_relu", "tir.noalias": True})
        Y = T.alloc_buffer([16, 128, 128], dtype="float32")
        for i0 in T.parallel(16): # batch loop n, parallelized
            for i1, i2_0 in T.grid(128, 16): # i1 = row i, i2_0 = j0 (outer part of the split j)
                # Zero an 8-wide chunk of Y (vectorized).
                for ax0_init in T.vectorized(8):
                    with T.block("Y_init"):
                        n, i = T.axis.remap("SS", [i0, i1])
                        j = T.axis.spatial(128, i2_0 * 8 + ax0_init)
                        Y[n, i, j] = T.float32(0)
                # Accumulate Y over k = ax1_0 * 4 + ax1_1 (32 x 4, inner 4 unrolled).
                for ax1_0 in T.serial(32):
                    for ax1_1 in T.unroll(4):
                        for ax0 in T.serial(8): # ax0 plays the role of j1
                            with T.block("Y_update"):
                                n, i = T.axis.remap("SS", [i0, i1])
                                j = T.axis.spatial(128, i2_0 * 8 + ax0)
                                k = T.axis.reduce(128, ax1_0 * 4 + ax1_1)
                                Y[n, i, j] = Y[n, i, j] + A[n, i, k] * B[n, k, j]
                # ReLU on the same 8-wide chunk (vectorized).
                for i2_1 in T.vectorized(8):
                    with T.block("C"):
                        n, i = T.axis.remap("SS", [i0, i1])
                        j = T.axis.spatial(128, i2_0 * 8 + i2_1)
                        C[n, i, j] = T.max(Y[n, i, j], T.float32(0))

# Raises (with a diff) if the scheduled module differs from TargetModule.
tvm.ir.assert_structural_equal(sch.mod, TargetModule)
print("Pass")

Pass


In [128]:
# Build the original and the scheduled kernels, check that they agree on the
# same inputs, then time both.
before_rt_lib = tvm.build(MyBmmRelu, target="llvm")
after_rt_lib = tvm.build(sch.mod, target="llvm")
rng = np.random.default_rng(0)
a_tvm = tvm.nd.array(rng.random((16, 128, 128)).astype("float32"))
b_tvm = tvm.nd.array(rng.random((16, 128, 128)).astype("float32"))
c_before = tvm.nd.array(np.empty((16, 128, 128), dtype="float32"))
c_tvm = tvm.nd.array(np.empty((16, 128, 128), dtype="float32"))
before_rt_lib["bmm_relu"](a_tvm, b_tvm, c_before)
after_rt_lib["bmm_relu"](a_tvm, b_tvm, c_tvm)
# The transformation must not change the result (tolerance for float32 rounding).
np.testing.assert_allclose(c_tvm.numpy(), c_before.numpy(), rtol=1e-5)

before_timer = before_rt_lib.time_evaluator("bmm_relu", tvm.cpu())
print("Before transformation:")
print(before_timer(a_tvm, b_tvm, c_tvm))

f_timer = after_rt_lib.time_evaluator("bmm_relu", tvm.cpu())
print("After transformation:")
print(f_timer(a_tvm, b_tvm, c_tvm))

Before transformation:
Execution time summary:
 mean (ms)   median (ms)    max (ms)     min (ms)     std (ms)  
  53.4332      53.4332      53.4332      53.4332       0.0000   
               
After transformation:
Execution time summary:
 mean (ms)   median (ms)    max (ms)     min (ms)     std (ms)  
  15.1879      15.1879      15.1879      15.1879       0.0000   
               
