In [1]:
# tensorIR 练习
# https://mlc.ai/zh/chapter_tensor_program/tensorir_exercises.html#id1

In [2]:
# 练习1：请编写一个 TensorIR 函数，将两个数组以广播的方式相加。

In [3]:
import IPython
import numpy as np
import tvm
from tvm.script import tir as T
from tvm.ir.module import IRModule

In [4]:
# init data

# Use an explicit int64 dtype: np.arange otherwise uses the platform's
# default integer (int32 on Windows with NumPy < 2), which would not match
# the int64 buffers declared for the TensorIR kernel.
a = np.arange(16, dtype=np.int64).reshape(4, 4)
b = np.arange(4, 0, -1, dtype=np.int64).reshape(4)

# numpy version: b (shape (4,)) is broadcast across every row of a
c_np = a + b
c_np

array([[ 4,  4,  4,  4],
       [ 8,  8,  8,  8],
       [12, 12, 12, 12],
       [16, 16, 16, 16]])

In [5]:
# python version
def lnumpy_broadcast_add(a: np.ndarray, b: np.ndarray, c: np.ndarray):
    """Loop-level reference for broadcast addition: c[i, j] = a[i, j] + b[j].

    The result is written into ``c`` in place; the function returns None.

    Args:
        a: 2-D input of shape (rows, cols).
        b: 1-D input of shape (cols,), broadcast across the rows of ``a``.
        c: 2-D output buffer of shape (rows, cols); every element is overwritten.
    """
    # Loop bounds come from the output buffer instead of a hard-coded 4,
    # so the same reference works for any (rows, cols) shape.
    rows, cols = c.shape
    for i in range(rows):
        for j in range(cols):
            c[i, j] = a[i, j] + b[j]
# test lnumpy_broadcast_add against the NumPy broadcast result
c_lbro_add = np.zeros(c_np.shape)
lnumpy_broadcast_add(a, b, c_lbro_add)
np.testing.assert_array_equal(c_lbro_add, c_np)
c_lbro_add

array([[ 4.,  4.,  4.,  4.],
       [ 8.,  8.,  8.,  8.],
       [12., 12., 12., 12.],
       [16., 16., 16., 16.]])

In [6]:
# Exercise 1 (TensorIR): broadcast add C[i, j] = A[i, j] + B[j] on int64 buffers.
@tvm.script.ir_module
class MyAdd:
  @T.prim_func
  def add(A: T.Buffer[(4, 4), "int64"],
          B: T.Buffer[(4), "int64"],
          C: T.Buffer[(4, 4), "int64"]):
    T.func_attr({"global_symbol": "add", "tir.noalias": True})
    # One spatial block instance per output element; B is indexed only by the
    # column axis, which is what broadcasts it over the rows.
    for row, col in T.grid(4, 4):
        with T.block("C"):
            v_row, v_col = T.axis.remap("SS", [row, col])
            C[v_row, v_col] = A[v_row, v_col] + B[v_col]

rt_lib = tvm.build(MyAdd, target="llvm")
# The kernel's buffers are int64; coerce the inputs so the call does not
# depend on the platform's default NumPy integer type.
a_tvm = tvm.nd.array(a.astype(np.int64))
b_tvm = tvm.nd.array(b.astype(np.int64))
# Output buffer shaped like the NumPy reference; the kernel writes every element.
c_tvm = tvm.nd.array(np.empty(c_np.shape, dtype=np.int64))
rt_lib["add"](a_tvm, b_tvm, c_tvm)
# Integer kernel: require an exact match rather than a relative tolerance.
np.testing.assert_array_equal(c_tvm.numpy(), c_np)

In [7]:
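# 练习2：二维卷积
# $$Conv[b, k, i, j] = \sum_{di, dj, q} A[b, q, strides * i + di, strides * j + dj] * W[k, q, di, dj]$$
# 其中，
# A 是输入张量，W 是权重张量，b 是批次索引，k 是输出通道，i 和 j 是图像高度和宽度的索引，di 和 dj 是权重的索引，
# q 是输入通道，strides 是过滤器窗口的步幅。
# 在练习中，我们选择了一个小而简单的情况，即 stride=1, padding=0。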

In [8]:
# Problem sizes: batch N, input channels CI, input height/width H x W,
# output channels CO, square kernel size K (stride 1, no padding).
N, CI, H, W, CO, K = 1, 1, 8, 8, 2, 3
OUT_H, OUT_W = H - K + 1, W - K + 1
# data-shape [N, CI, H, W] = [1, 1, 8, 8]; explicit int64 to match the
# int64 buffers of the TensorIR kernel on every platform
data = np.arange(N * CI * H * W, dtype=np.int64).reshape(N, CI, H, W)
# weight-shape [CO, CI, K, K] = [2, 1, 3, 3]
weight = np.arange(CO * CI * K * K, dtype=np.int64).reshape(CO, CI, K, K)
print("weight:", weight)
print("data:", data)

weight: [[[[ 0  1  2]
   [ 3  4  5]
   [ 6  7  8]]]


 [[[ 9 10 11]
   [12 13 14]
   [15 16 17]]]]
data: [[[[ 0  1  2  3  4  5  6  7]
   [ 8  9 10 11 12 13 14 15]
   [16 17 18 19 20 21 22 23]
   [24 25 26 27 28 29 30 31]
   [32 33 34 35 36 37 38 39]
   [40 41 42 43 44 45 46 47]
   [48 49 50 51 52 53 54 55]
   [56 57 58 59 60 61 62 63]]]]


In [9]:
# torch version: reference result from PyTorch's conv2d (default stride 1, no padding)
import torch

# conv2d works on floating-point tensors, so convert the integer arrays to
# float32 first; for this exercise's small integer inputs the float32 result
# is exact, and it is cast back to int64 for comparison.
data_torch = torch.as_tensor(data, dtype=torch.float32)
weight_torch = torch.as_tensor(weight, dtype=torch.float32)
conv_torch = torch.nn.functional.conv2d(data_torch, weight_torch).numpy().astype(np.int64)
conv_torch

array([[[[ 474,  510,  546,  582,  618,  654],
         [ 762,  798,  834,  870,  906,  942],
         [1050, 1086, 1122, 1158, 1194, 1230],
         [1338, 1374, 1410, 1446, 1482, 1518],
         [1626, 1662, 1698, 1734, 1770, 1806],
         [1914, 1950, 1986, 2022, 2058, 2094]],

        [[1203, 1320, 1437, 1554, 1671, 1788],
         [2139, 2256, 2373, 2490, 2607, 2724],
         [3075, 3192, 3309, 3426, 3543, 3660],
         [4011, 4128, 4245, 4362, 4479, 4596],
         [4947, 5064, 5181, 5298, 5415, 5532],
         [5883, 6000, 6117, 6234, 6351, 6468]]]])

In [10]:
# python version 
def lnumpy_conv(a_in: np.ndarray, b_in: np.ndarray, c_out: np.ndarray):
    """Loop-level reference for a 2-D convolution with stride 1 and no padding.

    Computes, in place,
        c_out[b, k, i, j] = sum over q, di, dj of
                            a_in[b, q, i + di, j + dj] * b_in[k, q, di, dj]

    Args:
        a_in: input of shape (batch, in_channels, height, width).
        b_in: weights of shape (out_channels, in_channels, kh, kw).
        c_out: output buffer of shape
            (batch, out_channels, height - kh + 1, width - kw + 1);
            every element is overwritten.
    """
    # Derive every loop bound from the operands instead of hard-coding the
    # exercise sizes (1, 2, 6, 6, 1, 3, 3).
    n_batch, n_out, out_h, out_w = c_out.shape
    n_in = a_in.shape[1]
    k_h, k_w = b_in.shape[2], b_in.shape[3]
    for b in range(n_batch):
        for k in range(n_out):
            for i in range(out_h):
                for j in range(out_w):
                    # Zero once per output element (the T.init() step), then
                    # accumulate over all input channels and kernel offsets.
                    c_out[b, k, i, j] = 0
                    for q in range(n_in):
                        for di in range(k_h):
                            for dj in range(k_w):
                                c_out[b, k, i, j] += a_in[b, q, i + di, j + dj] * b_in[k, q, di, dj]
                                
c_lnumpy_conv = np.zeros((N, CO, OUT_H, OUT_W))
lnumpy_conv(data, weight, c_lnumpy_conv)
# The loop reference must agree exactly with PyTorch's result.
np.testing.assert_array_equal(c_lnumpy_conv, conv_torch)
c_lnumpy_conv

array([[[[ 474.,  510.,  546.,  582.,  618.,  654.],
         [ 762.,  798.,  834.,  870.,  906.,  942.],
         [1050., 1086., 1122., 1158., 1194., 1230.],
         [1338., 1374., 1410., 1446., 1482., 1518.],
         [1626., 1662., 1698., 1734., 1770., 1806.],
         [1914., 1950., 1986., 2022., 2058., 2094.]],

        [[1203., 1320., 1437., 1554., 1671., 1788.],
         [2139., 2256., 2373., 2490., 2607., 2724.],
         [3075., 3192., 3309., 3426., 3543., 3660.],
         [4011., 4128., 4245., 4362., 4479., 4596.],
         [4947., 5064., 5181., 5298., 5415., 5532.],
         [5883., 6000., 6117., 6234., 6351., 6468.]]]])

In [11]:
# Exercise 2 (TensorIR): stride-1, no-padding 2-D convolution on int64 buffers:
#   C[b, k, i, j] = sum_{q, di, dj} A[b, q, i + di, j + dj] * B[k, q, di, dj]
@tvm.script.ir_module
class MyConv:
  @T.prim_func
  def conv(A: T.Buffer[(1, 1, 8, 8), "int64"],
           B: T.Buffer[(2, 1, 3, 3), "int64"],
           C: T.Buffer[(1, 2, 6, 6), "int64"]):
    T.func_attr({"global_symbol": "conv", "tir.noalias": True})
    # The batch loop is named b_0 rather than b; the original author noted a
    # naming conflict with the buffer parameter B when using b.
    for b_0, k, i, j, q, di, dj in T.grid(1, 2, 6, 6, 1, 3, 3):
        with T.block("C"):
            # Spatial axes: batch, out-channel, row, col.
            # Reduction axes: in-channel and the two kernel offsets.
            vb_0, vk, vi, vj, vq, vdi, vdj = T.axis.remap("SSSSRRR", [b_0, k, i, j, q, di, dj])
            with T.init():
                C[vb_0, vk, vi, vj] = T.int64(0)
            # The weight's second axis is the input channel, so it is indexed
            # by vq (not by the batch axis vb_0).
            C[vb_0, vk, vi, vj] = C[vb_0, vk, vi, vj] + A[vb_0, vq, vi + vdi, vj + vdj] * B[vk, vq, vdi, vdj]
rt_lib = tvm.build(MyConv, target="llvm")
# The kernel's buffers are int64; coerce so the call does not depend on the
# platform's default NumPy integer type.
data_tvm = tvm.nd.array(data.astype(np.int64))
weight_tvm = tvm.nd.array(weight.astype(np.int64))
# Output buffer; the kernel writes every element.
conv_tvm = tvm.nd.array(np.empty((N, CO, OUT_H, OUT_W), dtype=np.int64))
rt_lib["conv"](data_tvm, weight_tvm, conv_tvm)
# Integer kernel: require an exact match with the PyTorch reference.
np.testing.assert_array_equal(conv_tvm.numpy(), conv_torch)

In [12]:
#练习 3：变换批量矩阵乘法程序

In [13]:
def lnumpy_mm_relu_v2(A: np.ndarray, B: np.ndarray, C: np.ndarray):
    """Loop-level reference for batched matmul followed by ReLU.

    Computes, in place, C[n, i, j] = max(sum_k A[n, i, k] * B[n, k, j], 0),
    accumulating the matmul in a float32 scratch buffer Y.

    Args:
        A: input of shape (batch, rows, red).
        B: input of shape (batch, red, cols).
        C: output buffer of shape (batch, rows, cols); every element is overwritten.
    """
    # Take the sizes from the operands instead of hard-coding 16 x 128 x 128,
    # so small inputs can be checked quickly.
    batch, rows, red = A.shape
    cols = B.shape[2]
    Y = np.empty((batch, rows, cols), dtype="float32")
    for n in range(batch):
        for i in range(rows):
            for j in range(cols):
                # Zero-initialize before reducing (the T.init() step); doing it
                # outside the k loop also covers an empty reduction axis.
                Y[n, i, j] = 0
                for k in range(red):
                    Y[n, i, j] = Y[n, i, j] + A[n, i, k] * B[n, k, j]
    for n in range(batch):
        for i in range(rows):
            for j in range(cols):
                C[n, i, j] = max(Y[n, i, j], 0)

In [24]:
@tvm.script.ir_module
class MyBmmRelu:
    # Exercise 3 (TensorIR): batched matmul followed by ReLU,
    #   Y[n, i, j] = sum_k A[n, i, k] * B[n, k, j];  C = max(Y, 0)
    @T.prim_func
    def bmm_relu(A: T.Buffer[(16, 128, 128), "float32"],
                 B: T.Buffer[(16, 128, 128), "float32"],
                 C: T.Buffer[(16, 128, 128), "float32"]):
        T.func_attr({"global_symbol": "bmm_relu", "tir.noalias": True})
        # Intermediate matmul result, allocated inside the function.
        Y = T.alloc_buffer([16, 128, 128], dtype="float32")
        # Block "Y": spatial axes n, i, j and reduction axis k.
        for n, i, j, k in T.grid(16, 128, 128, 128):
            with T.block("Y"):
                vn, vi, vj, vk = T.axis.remap("SSSR", [n, i, j, k])
                with T.init():
                    Y[vn, vi, vj] = T.float32(0)
                Y[vn, vi, vj] = Y[vn, vi, vj] + A[vn, vi, vk] * B[vn, vk, vj]
        # Block "C": element-wise ReLU over Y.
        for n, i, j in T.grid(16, 128, 128):
            with T.block("C"):
                vn, vi, vj = T.axis.remap("SSS", [n, i, j])
                C[vn, vi, vj] = T.max(Y[vn, vi, vj], T.float32(0))
                                                      
sch = tvm.tir.Schedule(MyBmmRelu)

# Validate MyBmmRelu against a NumPy reference (batched matmul, then ReLU).
# Inputs are centred on 0 so the ReLU actually clips some outputs.
bmm_rng = np.random.default_rng(0)
bmm_a_np = bmm_rng.random((16, 128, 128), dtype=np.float32) - 0.5
bmm_b_np = bmm_rng.random((16, 128, 128), dtype=np.float32) - 0.5
bmm_c_tvm = tvm.nd.array(np.empty((16, 128, 128), dtype=np.float32))
tvm.build(MyBmmRelu, target="llvm")["bmm_relu"](
    tvm.nd.array(bmm_a_np), tvm.nd.array(bmm_b_np), bmm_c_tvm
)
# float32 sums are accumulated in a different order than NumPy's matmul,
# so compare with a tolerance rather than exactly.
np.testing.assert_allclose(
    bmm_c_tvm.numpy(), np.maximum(bmm_a_np @ bmm_b_np, 0), rtol=1e-4, atol=1e-4
)

IPython.display.Code(sch.mod.script(), language="python")

In [25]:
# Expected result for Exercise 3: the schedule applied to MyBmmRelu must yield
# a module structurally equal to this one (checked with
# tvm.ir.assert_structural_equal in the scheduling cell).
@tvm.script.ir_module
class TargetModule:
    @T.prim_func
    def bmm_relu(A: T.Buffer[(16, 128, 128), "float32"], B: T.Buffer[(16, 128, 128), "float32"], C: T.Buffer[(16, 128, 128), "float32"]) -> None:
        T.func_attr({"global_symbol": "bmm_relu", "tir.noalias": True})
        Y = T.alloc_buffer([16, 128, 128], dtype="float32")
        # i0: batch axis n (extent 16), parallelized
        for i0 in T.parallel(16):
            # i1: row axis i (128); i2_0: outer part j0 of the column axis (128 = 16 * 8)
            for i1, i2_0 in T.grid(128, 16):
                # ax0_init: inner part j1 of the column axis (8), vectorized zero-init of Y
                for ax0_init in T.vectorized(8):
                    with T.block("Y_init"):
                        n, i = T.axis.remap("SS", [i0, i1])
                        j = T.axis.spatial(128, i2_0 * 8 + ax0_init)
                        Y[n, i, j] = T.float32(0)
                # ax1_0: outer part k0 of the reduction axis (128 = 32 * 4)
                for ax1_0 in T.serial(32):
                    # ax1_1: inner part k1 of the reduction axis (4), unrolled
                    for ax1_1 in T.unroll(4):
                        # ax0: inner part j1 of the column axis (8)
                        for ax0 in T.serial(8):
                            with T.block("Y_update"):
                                n, i = T.axis.remap("SS", [i0, i1])
                                j = T.axis.spatial(128, i2_0 * 8 + ax0)
                                k = T.axis.reduce(128, ax1_0 * 4 + ax1_1)
                                Y[n, i, j] = Y[n, i, j] + A[n, i, k] * B[n, k, j]
                # i2_1: inner part j1 of the column axis (8), vectorized ReLU
                for i2_1 in T.vectorized(8):
                    with T.block("C"):
                        n, i = T.axis.remap("SS", [i0, i1])
                        j = T.axis.spatial(128, i2_0 * 8 + i2_1)
                        C[n, i, j] = T.max(Y[n, i, j], T.float32(0))

In [28]:
sch = tvm.tir.Schedule(MyBmmRelu)
# Transform MyBmmRelu step by step until it is structurally equal to TargetModule.
# Hint: print(sch.mod.script()) shows the current program at any point.

# Step 1. Get the matmul block.
Y = sch.get_block("Y", func_name="bmm_relu")

# Step 2. Get its loops: batch n, row i, column j, reduction k.
n, i, j, k = sch.get_loops(Y)

# Step 3. Organize the loops.
# Split j (128) into j0 (16) x j1 (8) and move the reduction loop between
# them; the loop handles obtained above remain usable after reorder.
j0, j1 = sch.split(j, factors=[None, 8])
sch.reorder(n, i, j0, k, j1)
sch.parallel(n)
# Move the ReLU block under j0 so it runs right after each 8-wide tile of Y.
C = sch.get_block("C", func_name="bmm_relu")
sch.reverse_compute_at(C, j0)

# Step 4. Decompose the reduction: Y's zero-initialization becomes a
# separate block "Y_init" placed in front of the k loop.
sch.decompose_reduction(Y, k)

# Step 5. Vectorize / unroll.
# Y_init and C both sit under (n, i, j0); their fourth loop is the 8-wide
# inner column loop, which is vectorized.
Y_init = sch.get_block("Y_init", func_name="bmm_relu")
y_init_loops = sch.get_loops(Y_init)
sch.vectorize(y_init_loops[3])

C = sch.get_block("C", func_name="bmm_relu")
c_loops = sch.get_loops(C)
sch.vectorize(c_loops[3])

# Split the reduction loop k (128) into k0 (32) x k1 (4) and unroll k1.
k0, k1 = sch.split(k, factors=[None, 4])
sch.unroll(k1)

tvm.ir.assert_structural_equal(sch.mod, TargetModule)
print("Pass")


Pass


In [29]:
# Evaluate the performance of the transformed program against the original.
before_rt_lib = tvm.build(MyBmmRelu, target="llvm")
after_rt_lib = tvm.build(sch.mod, target="llvm")
rng = np.random.default_rng(0)  # fixed seed so runs are reproducible
a_np = rng.random((16, 128, 128), dtype=np.float32)
b_np = rng.random((16, 128, 128), dtype=np.float32)
a_tvm = tvm.nd.array(a_np)
b_tvm = tvm.nd.array(b_np)
# Output buffer; the kernel writes every element.
c_tvm = tvm.nd.array(np.empty((16, 128, 128), dtype=np.float32))

# Check the transformed kernel is correct before timing it; the tolerance
# allows for float32 summation-order differences versus NumPy's matmul.
after_rt_lib["bmm_relu"](a_tvm, b_tvm, c_tvm)
np.testing.assert_allclose(c_tvm.numpy(), np.maximum(a_np @ b_np, 0), rtol=1e-4)

before_timer = before_rt_lib.time_evaluator("bmm_relu", tvm.cpu())
print("Before transformation:")
print(before_timer(a_tvm, b_tvm, c_tvm))

f_timer = after_rt_lib.time_evaluator("bmm_relu", tvm.cpu())
print("After transformation:")
print(f_timer(a_tvm, b_tvm, c_tvm))

Before transformation:
Execution time summary:
 mean (ms)   median (ms)    max (ms)     min (ms)     std (ms)  
  37.2572      37.2572      37.2572      37.2572       0.0000   
               
After transformation:
Execution time summary:
 mean (ms)   median (ms)    max (ms)     min (ms)     std (ms)  
   4.3325       4.3325       4.3325       4.3325       0.0000   
               
