# TensorIR 练习

In [1]:
import IPython
import numpy as np
import tvm
from tvm.ir.module import IRModule
from tvm.script import tir as T

## 第一节：如何编写 TensorIR
在本节中，让我们尝试根据高级指令（例如 Numpy 或 Torch）手动编写 TensorIR。首先，我们给出一个逐位相加函数的例子，来展示我们应该如何编写一个 TensorIR 函数。

### 示例：逐位相加
首先，让我们尝试使用 Numpy 编写一个逐位相加函数。

In [2]:
# init data
a = np.arange(16).reshape(4, 4)
b = np.arange(16, 0, -1).reshape(4, 4)

In [3]:
c_np = a + b
c_np

array([[16, 16, 16, 16],
       [16, 16, 16, 16],
       [16, 16, 16, 16],
       [16, 16, 16, 16]])

在我们直接编写 TensorIR 之前，我们应该首先将高级计算抽象（例如，ndarray + ndarray）转换为低级 Python 实现（具有元素访问和操作的循环的标准）。

值得注意的是，输出数组（或缓冲区）的初始值并不总是 0。我们需要在我们的实现中编写或初始化它，这对于归约运算符（例如 matmul 和 conv）很重要。

In [4]:
# low level numpy
def low_level_np_add(A: np.ndarray, B: np.ndarray, C: np.ndarray):
    for i in range(4):
        for j in range(4):
            C[i, j] = A[i, j] + B[i, j]
c_lnumpy = np.empty((4, 4), dtype="int64")
low_level_np_add(a, b, c_lnumpy)
c_lnumpy

array([[16, 16, 16, 16],
       [16, 16, 16, 16],
       [16, 16, 16, 16],
       [16, 16, 16, 16]])

现在，让我们更进一步：将低级 NumPy 实现转换为 TensorIR，并将结果与来自 NumPy 的结果进行比较。

In [5]:
@tvm.script.ir_module
class MyAdd:
    @T.prim_func
    def add(A: T.Buffer[(4, 4), "int64"],
            B: T.Buffer[(4, 4), "int64"],
            C: T.Buffer[(4, 4), "int64"]):
        T.func_attr({"global_symbol":"add"})
        for i, j in T.grid(4, 4):
            with T.block("C"):
                vi, vj = T.axis.remap("SS", [i, j])
                C[vi, vj] = A[vi, vj] + B[vi, vj]

rt_build = tvm.build(MyAdd, target="llvm")
a_nd = tvm.nd.array(a)
b_nd = tvm.nd.array(b)
c_nd = tvm.nd.empty((4, 4), dtype="int64")
rt_build["add"](a_nd, b_nd, c_nd)
c_nd

<tvm.nd.NDArray shape=(4, 4), cpu(0)>
array([[16, 16, 16, 16],
       [16, 16, 16, 16],
       [16, 16, 16, 16],
       [16, 16, 16, 16]])

### 练习 1：广播加法
请编写一个 TensorIR 函数，将两个数组以广播的方式相加。

In [6]:
# init data
a = np.arange(16).reshape(4, 4)
b = np.arange(4, 0, -1).reshape(4)

In [7]:
# numpy version
c_np = a + b
c_np

array([[ 4,  4,  4,  4],
       [ 8,  8,  8,  8],
       [12, 12, 12, 12],
       [16, 16, 16, 16]])

In [8]:
# low level numpy
def low_level_np_brocast_add(A: np.ndarray, B: np.ndarray, C: np.ndarray):
    for i in range(4):
        for j in range(4):
            C[i, j] = A[i, j] + B[j]

c_lnumpy = np.empty((4, 4), dtype=np.int64)
low_level_np_brocast_add(a, b, c_lnumpy)
c_lnumpy

array([[ 4,  4,  4,  4],
       [ 8,  8,  8,  8],
       [12, 12, 12, 12],
       [16, 16, 16, 16]])

In [9]:
@tvm.script.ir_module
class MyAdd:
    @T.prim_func
    # -------------------------- YOUR CODE -----------------------------
    def add(A: T.Buffer[(4, 4), "int64"],
          B: T.Buffer[(4,  ), "int64"],
          C: T.Buffer[(4, 4), "int64"]):
        T.func_attr({"global_symbol": "add", "tir.noalias": True})
        for i, j in T.grid(4, 4):
            with T.block("C"):
                vi, vj = T.axis.remap("SS", (i, j))
                C[vi, vj] = A[vi, vj] + B[vj]
    # -------------------------- END OF YOUR CODE -----------------------------
    
rt_lib = tvm.build(MyAdd, target="llvm")
a_tvm = tvm.nd.array(a)
b_tvm = tvm.nd.array(b)
c_tvm = tvm.nd.array(np.empty((4, 4), dtype=np.int64))
rt_lib["add"](a_tvm, b_tvm, c_tvm)
np.testing.assert_allclose(c_tvm.numpy(), c_np, rtol=1e-5)
c_tvm

<tvm.nd.NDArray shape=(4, 4), cpu(0)>
array([[ 4,  4,  4,  4],
       [ 8,  8,  8,  8],
       [12, 12, 12, 12],
       [16, 16, 16, 16]])

### 练习 2：二维卷积
stride = 1, padding = 0

In [10]:
"""
    N: batch size
    CI: Image input channel
    H: 图片 hight
    W: 图片 width
    CO: Image output channel
"""

N, CI, H, W, CO, K = 1, 1, 8, 8, 2, 3
OUT_H, OUT_W = H - K + 1, W - K + 1
data = np.arange(N*CI*H*W).reshape(N, CI, H, W)
weight = np.arange(CO*CI*K*K).reshape(CO, CI, K, K)
print(data)
print(weight)

[[[[ 0  1  2  3  4  5  6  7]
   [ 8  9 10 11 12 13 14 15]
   [16 17 18 19 20 21 22 23]
   [24 25 26 27 28 29 30 31]
   [32 33 34 35 36 37 38 39]
   [40 41 42 43 44 45 46 47]
   [48 49 50 51 52 53 54 55]
   [56 57 58 59 60 61 62 63]]]]
[[[[ 0  1  2]
   [ 3  4  5]
   [ 6  7  8]]]


 [[[ 9 10 11]
   [12 13 14]
   [15 16 17]]]]


In [11]:
# torch version
import torch

data_torch = torch.Tensor(data)
weight_torch = torch.Tensor(weight)
res_torch = torch.nn.functional.conv2d(data_torch, weight_torch)
res_torch = res_torch.numpy().astype(np.int64)
res_torch

  from .autonotebook import tqdm as notebook_tqdm


array([[[[ 474,  510,  546,  582,  618,  654],
         [ 762,  798,  834,  870,  906,  942],
         [1050, 1086, 1122, 1158, 1194, 1230],
         [1338, 1374, 1410, 1446, 1482, 1518],
         [1626, 1662, 1698, 1734, 1770, 1806],
         [1914, 1950, 1986, 2022, 2058, 2094]],

        [[1203, 1320, 1437, 1554, 1671, 1788],
         [2139, 2256, 2373, 2490, 2607, 2724],
         [3075, 3192, 3309, 3426, 3543, 3660],
         [4011, 4128, 4245, 4362, 4479, 4596],
         [4947, 5064, 5181, 5298, 5415, 5532],
         [5883, 6000, 6117, 6234, 6351, 6468]]]])

In [12]:
# low level numpy
def low_level_np_conv2d(data: np.ndarray, weight: np.ndarray, res: np.ndarray):
    """
        data shape: (N, CI, H, W)
        weight shape: (CO, CI, K, K)
        res shape: (N, CO, H - K + 1, W - K + 1)
    """
    for N_ind in range(N):
        for oc_ind in range(CO):
            for h_ind in range(OUT_H):
                for w_ind in range(OUT_W):
                    
                    for ci_ind in range(CI):
                        for h_K_ind in range(K):
                            for w_K_ind in range(K):
                                if ci_ind == 0 and h_K_ind == 0 and w_K_ind == 0:
                                    res[N_ind, oc_ind, h_ind, w_ind] = 0
                                res[N_ind, oc_ind, h_ind, w_ind] = res[N_ind, oc_ind, h_ind, w_ind] + data[N_ind, ci_ind, h_K_ind + h_ind, w_K_ind + w_ind] * weight[oc_ind, ci_ind, h_K_ind, w_K_ind]
data_np = data_torch.numpy()
weight_np = weight_torch.numpy()
res_np = np.empty((N, CO, H - K + 1, W - K + 1), dtype="int64")
low_level_np_conv2d(data_np, weight_np, res_np)
res_np

array([[[[ 474,  510,  546,  582,  618,  654],
         [ 762,  798,  834,  870,  906,  942],
         [1050, 1086, 1122, 1158, 1194, 1230],
         [1338, 1374, 1410, 1446, 1482, 1518],
         [1626, 1662, 1698, 1734, 1770, 1806],
         [1914, 1950, 1986, 2022, 2058, 2094]],

        [[1203, 1320, 1437, 1554, 1671, 1788],
         [2139, 2256, 2373, 2490, 2607, 2724],
         [3075, 3192, 3309, 3426, 3543, 3660],
         [4011, 4128, 4245, 4362, 4479, 4596],
         [4947, 5064, 5181, 5298, 5415, 5532],
         [5883, 6000, 6117, 6234, 6351, 6468]]]])

In [27]:
# TODO (prim func 的参数问题，能不能使用变量)
# TVMScript
@tvm.script.ir_module
class MyConv:
    @T.prim_func
    def conv(data: T.Buffer[(N, CI, 8, 8), "int64"],
             weight: T.Buffer[(CO, CI, K, K), "int64"],
             res: T.Buffer[(N, CO, OUT_H, OUT_W), "int64"]):
        T.func_attr({"global_symbol": "conv", "tir.noalias": True})
        for N_ind, oc_ind, h_ind, w_ind, ci_ind, h_K_ind, w_K_ind in T.grid(N, CO, OUT_H, OUT_W, CI, K, K):
            with T.block("res"):
                v_N_ind, v_oc_ind, v_h_ind, v_w_ind, v_ci_ind, v_h_K_ind, v_w_K_ind = T.axis.remap("SSSSRRR", [N_ind, oc_ind, h_ind, w_ind, ci_ind, h_K_ind, w_K_ind])
                with T.init():
                    res[v_N_ind, v_oc_ind, v_h_ind, v_w_ind] = T.int64(0)
                res[v_N_ind, v_oc_ind, v_h_ind, v_w_ind] = res[v_N_ind, v_oc_ind, v_h_ind, v_w_ind] + data[v_N_ind, v_ci_ind, v_h_K_ind + v_h_ind, v_w_K_ind + v_w_ind] * weight[v_oc_ind, v_ci_ind, v_h_K_ind, v_w_K_ind]

rt_lib = tvm.build(MyConv, target="llvm")
data_tvm = tvm.nd.array(data)
weight_tvm = tvm.nd.array(weight)
conv_tvm = tvm.nd.array(np.empty((N, CO, OUT_H, OUT_W), dtype=np.int64))
rt_lib["conv"](data_tvm, weight_tvm, conv_tvm)
np.testing.assert_allclose(conv_tvm.numpy(), res_torch, rtol=1e-5)
conv_tvm

<tvm.nd.NDArray shape=(1, 2, 6, 6), cpu(0)>
array([[[[ 474,  510,  546,  582,  618,  654],
         [ 762,  798,  834,  870,  906,  942],
         [1050, 1086, 1122, 1158, 1194, 1230],
         [1338, 1374, 1410, 1446, 1482, 1518],
         [1626, 1662, 1698, 1734, 1770, 1806],
         [1914, 1950, 1986, 2022, 2058, 2094]],

        [[1203, 1320, 1437, 1554, 1671, 1788],
         [2139, 2256, 2373, 2490, 2607, 2724],
         [3075, 3192, 3309, 3426, 3543, 3660],
         [4011, 4128, 4245, 4362, 4479, 4596],
         [4947, 5064, 5181, 5298, 5415, 5532],
         [5883, 6000, 6117, 6234, 6351, 6468]]]])

# 第二节：变换 IR
在本节中，让我们尝试变换程序。我们在采用了 bmm_relu (batched_matmul_relu)，这是一种常见于 Transformer 等模型中的操作变体。

首先，我们介绍一些新的原语：parallel、vectorize 和 unroll。这三个原语被应用于循环上，指示循环应当如何执行。这是示例：

In [29]:
@tvm.script.ir_module
class MyAdd:
  @T.prim_func
  def add(A: T.Buffer[(4, 4), "int64"],
          B: T.Buffer[(4, 4), "int64"],
          C: T.Buffer[(4, 4), "int64"]):
    T.func_attr({"global_symbol": "add"})
    for i, j in T.grid(4, 4):
      with T.block("C"):
        vi = T.axis.spatial(4, i)
        vj = T.axis.spatial(4, j)
        C[vi, vj] = A[vi, vj] + B[vi, vj]

sch = tvm.tir.Schedule(MyAdd)
block = sch.get_block("C", func_name="add")
i, j = sch.get_loops(block)
i0, i1 = sch.split(i, factors=[2, 2])
sch.parallel(i0)
sch.unroll(i1)
sch.vectorize(j)
print(sch.mod.script())

@tvm.script.ir_module
class Module:
    @tir.prim_func
    def func(A: tir.Buffer[(4, 4), "int64"], B: tir.Buffer[(4, 4), "int64"], C: tir.Buffer[(4, 4), "int64"]) -> None:
        # function attr dict
        tir.func_attr({"global_symbol": "add"})
        # body
        # with tir.block("root")
        for i_0 in tir.parallel(2):
            for i_1 in tir.unroll(2):
                for j in tir.vectorized(4):
                    with tir.block("C"):
                        vi = tir.axis.spatial(4, i_0 * 2 + i_1)
                        vj = tir.axis.spatial(4, j)
                        tir.reads(A[vi, vj], B[vi, vj])
                        tir.writes(C[vi, vj])
                        C[vi, vj] = A[vi, vj] + B[vi, vj]
    


变换批量矩阵乘法程序

现在，让我们回到 bmm_relu 练习。首先，让我们看看 bmm 的定义：

- $Y_{n,i,j} = \sum_k A_{n,i,k} * B{n,k,j}$
- $C_{n,i,j} =relu(Y_{n,i,j})$

现在是你为 bmm_relu 编写 TensorIR 的时候了。我们提供 lnumpy 函数作为提示：


In [39]:
def lnumpy_mm_relu_v2(A: np.ndarray, B: np.ndarray, C: np.ndarray):
    Y = np.empty((16, 128, 128), dtype="float32")
    for n in range(16):
        for i in range(128):
            for j in range(128):
                for k in range(128):
                    if k == 0:
                        Y[n, i, j] = 0
                    Y[n, i, j] = Y[n, i, j] + A[n, i, k] * B[n, k, j]
    for n in range(16):
        for i in range(128):
            for j in range(128):
                C[n, i, j] = max(Y[n, i, j], 0)

In [44]:
from numpy import dtype


@tvm.script.ir_module
class MyBmmRule:
  @T.prim_func
  def bmm_relu(A: T.Buffer[(16, 128, 128), "float32"],
               B: T.Buffer[(16, 128, 128), "float32"],
               C: T.Buffer[(16, 128, 128), "float32"]):
    T.func_attr({"global_symbol": "bmm_relu", "tir.noalias": True})
    Y = T.alloc_buffer((16, 128, 128), dtype="float32")
    for n, i, j, k in T.grid(16, 128, 128, 128):
        with T.block("Y"):
            vn, vi, vj, vk = T.axis.remap("SSSR", [n, i, j, k])
            with T.init():
                Y[vn, vi, vj] = T.float32(0)
            Y[vn, vi, vj] = Y[vn, vi, vj] + A[vn, vi, vk] * B[vn, vk, vj]

    for n, i, j in T.grid(16, 128, 128):
                with T.block("C"):
                    vn, vi, vj = T.axis.remap("SSS", [n, i, j])
                    C[vn, vi, vj] = T.max(Y[vn, vi, vj], T.float32(0))

sch = tvm.tir.Schedule(MyBmmRule)
print(sch.mod.script())
# Also please validate your result

mat_a = np.random.rand(16, 128, 128).astype("float32")
mat_b = np.random.rand(16, 128, 128).astype("float32")
mat_c_np = np.empty((16, 128, 128), dtype="float32")
lnumpy_mm_relu_v2(mat_a, mat_b, mat_c_np)


mat_a_nd = tvm.nd.array(mat_a)
mat_b_nd = tvm.nd.array(mat_b)
mat_c_nd = tvm.nd.array(np.empty((16, 128, 128), dtype="float32"))
rt_lib = tvm.build(MyBmmRule, target="llvm")
func = rt_lib["bmm_relu"]
func(mat_a_nd, mat_b_nd, mat_c_nd)
np.testing.assert_allclose(mat_c_nd.numpy(), mat_c_np, rtol=1e-5)

@tvm.script.ir_module
class Module:
    @tir.prim_func
    def func(A: tir.Buffer[(16, 128, 128), "float32"], B: tir.Buffer[(16, 128, 128), "float32"], C: tir.Buffer[(16, 128, 128), "float32"]) -> None:
        # function attr dict
        tir.func_attr({"global_symbol": "bmm_relu", "tir.noalias": True})
        # body
        # with tir.block("root")
        Y = tir.alloc_buffer([16, 128, 128], dtype="float32")
        for n, i, j, k in tir.grid(16, 128, 128, 128):
            with tir.block("Y"):
                vn, vi, vj, vk = tir.axis.remap("SSSR", [n, i, j, k])
                tir.reads(A[vn, vi, vk], B[vn, vk, vj])
                tir.writes(Y[vn, vi, vj])
                with tir.init():
                    Y[vn, vi, vj] = tir.float32(0)
                Y[vn, vi, vj] = Y[vn, vi, vj] + A[vn, vi, vk] * B[vn, vk, vj]
        for n, i, j in tir.grid(16, 128, 128):
            with tir.block("C"):
                vn, vi, vj = tir.axis.remap("SSS", [n, i, j])
                tir.read

在本练习中，让我们专注于将原始程序变换为特定目标。请注意，由于硬件不同，目标程序可能不是最好的程序。但是这个练习旨在让你了解如何将程序变换为想要的程序。 这是目标程序：
```python
@tvm.script.ir_module
class TargetModule:
    @T.prim_func
    def bmm_relu(A: T.Buffer[(16, 128, 128), "float32"], B: T.Buffer[(16, 128, 128), "float32"], C: T.Buffer[(16, 128, 128), "float32"]) -> None:
        T.func_attr({"global_symbol": "bmm_relu", "tir.noalias": True})
        Y = T.alloc_buffer([16, 128, 128], dtype="float32")
        for i0 in T.parallel(16):
            for i1, i2_0 in T.grid(128, 16):
                for ax0_init in T.vectorized(8):
                    with T.block("Y_init"):
                        n, i = T.axis.remap("SS", [i0, i1])
                        j = T.axis.spatial(128, i2_0 * 8 + ax0_init)
                        Y[n, i, j] = T.float32(0)
                for ax1_0 in T.serial(32):
                    for ax1_1 in T.unroll(4):
                        for ax0 in T.serial(8):
                            with T.block("Y_update"):
                                n, i = T.axis.remap("SS", [i0, i1])
                                j = T.axis.spatial(128, i2_0 * 8 + ax0)
                                k = T.axis.reduce(128, ax1_0 * 4 + ax1_1)
                                Y[n, i, j] = Y[n, i, j] + A[n, i, k] * B[n, k, j]
                for i2_1 in T.vectorized(8):
                    with T.block("C"):
                        n, i = T.axis.remap("SS", [i0, i1])
                        j = T.axis.spatial(128, i2_0 * 8 + i2_1)
                        C[n, i, j] = T.max(Y[n, i, j], T.float32(0))

In [90]:
sch = tvm.tir.Schedule(MyBmmRule)
# TODO: transformations
# Hints: you can use
# `IPython.display.Code(sch.mod.script(), language="python")`
# or `print(sch.mod.script())`
# to show the current program at any time during the transformation.

# Step 1. Get blocks
Y = sch.get_block("Y", func_name="bmm_relu")
C = sch.get_block("C", func_name="bmm_relu")

# Step 2. Get loops
b, i, j, k = sch.get_loops(Y)

# Step 3. Organize the loops
k0, k1 = sch.split(k, [None, 4])
j0, j1 = sch.split(j, [None, 8])

sch.reorder(j0, k0, k1, j1)

sch.reverse_compute_at(C, j0)

# Step 4. decompose reduction
Y_init = sch.decompose_reduction(Y, k0)
_, _, _, y_init = sch.get_loops(Y_init)

# Step 5. vectorize / parallel / unroll
sch.vectorize(y_init)
# sch.parallel(b)
# sch.unroll(...)
print(sch.mod.script())

@tvm.script.ir_module
class Module:
    @tir.prim_func
    def func(A: tir.Buffer[(16, 128, 128), "float32"], B: tir.Buffer[(16, 128, 128), "float32"], C: tir.Buffer[(16, 128, 128), "float32"]) -> None:
        # function attr dict
        tir.func_attr({"global_symbol": "bmm_relu", "tir.noalias": True})
        # body
        # with tir.block("root")
        Y = tir.alloc_buffer([16, 128, 128], dtype="float32")
        for n, i, j_0 in tir.grid(16, 128, 16):
            for j_1_init in tir.vectorized(8):
                with tir.block("Y_init"):
                    vn, vi = tir.axis.remap("SS", [n, i])
                    vj = tir.axis.spatial(128, j_0 * 8 + j_1_init)
                    tir.reads()
                    tir.writes(Y[vn, vi, vj])
                    Y[vn, vi, vj] = tir.float32(0)
            for k_0, k_1, j_1 in tir.grid(32, 4, 8):
                with tir.block("Y_update"):
                    vn, vi = tir.axis.remap("SS", [n, i])
                    vj = tir.axis.spat

<built-in method __dir__ of method-wrapper object at 0x7f7eca16ad10>
