In [29]:
import IPython
import numpy as np
import tvm
from tvm.ir.module import IRModule
from tvm.script import tir as T

#### 练习 1：广播加法

请编写一个 TensorIR 函数，将两个数组以广播的方式相加。

In [30]:
# init data
a = np.arange(16).reshape(4, 4)
b = np.arange(4, 0, -1).reshape(4)

In [31]:
# numpy version
c_np = a + b
c_np

array([[ 4,  4,  4,  4],
       [ 8,  8,  8,  8],
       [12, 12, 12, 12],
       [16, 16, 16, 16]])

In [32]:
@tvm.script.ir_module
class MyAdd:
  @T.prim_func
  def add(A: T.Buffer[(4,4),"int64"],
          B: T.Buffer[(4),"int64"],
          C: T.Buffer[(4,4),"int64"]):
    T.func_attr({"global_symbol": "add", "tir.noalias": True})
    for i, j in T.grid(4, 4):
        with T.block("C"):
            vi=T.axis.spatial(4, i)
            vj=T.axis.spatial(4, j)
            C[vi, vj] = A[vi, vj] + B[vj]

rt_lib = tvm.build(MyAdd, target="llvm")
a_tvm = tvm.nd.array(a)
b_tvm = tvm.nd.array(b)
c_tvm = tvm.nd.array(np.empty((4, 4), dtype="int64"))
rt_lib["add"](a_tvm, b_tvm, c_tvm)
f_timer = rt_lib.time_evaluator("add", tvm.cpu())
print("Time cost of transformed mod_transformed %g sec" % f_timer(a_tvm, b_tvm, c_tvm).mean)
np.testing.assert_allclose(c_tvm.numpy(), c_np, rtol=1e-5)

Time cost of transformed mod_transformed 2.92e-08 sec




#### 练习 2：二维卷积

然后，让我们尝试做一些具有挑战性的事情：二维卷积。这是图像处理中的常见操作。

这是使用 NCHW 布局的卷积的数学定义：
$$Conv[b, k, i, j] =
    \sum_{di, dj, q} A[b, q, strides * i + di, strides * j + dj] * W[k, q, di, dj],$$
其中，`A` 是输入张量，`W` 是权重张量，`b` 是批次索引，`k` 是输出通道，`i` 和 `j` 是图像高度和宽度的索引，`di` 和 `dj` 是权重的索引，`q` 是输入通道，`strides` 是过滤器窗口的步幅。

在练习中，我们选择了一个小而简单的情况，即 `stride=1, padding=0`。

In [33]:
N, CI, H, W, CO, K = 1, 1, 8, 8, 2, 3
OUT_H, OUT_W = H - K + 1, W - K + 1
data = np.arange(N*CI*H*W).reshape(N, CI, H, W)
weight = np.arange(CO*CI*K*K).reshape(CO, CI, K, K)

In [34]:
# torch version
import torch
data_torch = torch.Tensor(data)
weight_torch = torch.Tensor(weight)
conv_torch = torch.nn.functional.conv2d(data_torch, weight_torch)
conv_torch = conv_torch.numpy().astype(np.int64)
conv_torch

array([[[[ 474,  510,  546,  582,  618,  654],
         [ 762,  798,  834,  870,  906,  942],
         [1050, 1086, 1122, 1158, 1194, 1230],
         [1338, 1374, 1410, 1446, 1482, 1518],
         [1626, 1662, 1698, 1734, 1770, 1806],
         [1914, 1950, 1986, 2022, 2058, 2094]],

        [[1203, 1320, 1437, 1554, 1671, 1788],
         [2139, 2256, 2373, 2490, 2607, 2724],
         [3075, 3192, 3309, 3426, 3543, 3660],
         [4011, 4128, 4245, 4362, 4479, 4596],
         [4947, 5064, 5181, 5298, 5415, 5532],
         [5883, 6000, 6117, 6234, 6351, 6468]]]])

In [35]:
@tvm.script.ir_module
class MyConv:
  @T.prim_func
  def conv(A: T.Buffer[(N, CI, H, W), "int64"], B: T.Buffer[(CO, CI, K, K), "int64"], C: T.Buffer[(N, CO, OUT_H, OUT_W), "int64"]):
    T.func_attr({"global_symbol": "conv", "tir.noalias": True})
    for b, k, i, j, q, di, dj in T.grid(N, CO, OUT_H, OUT_W, CI, K, K):
      with T.block("C"):
        vb=T.axis.spatial(N, b)
        vk=T.axis.spatial(CO, k)
        vi=T.axis.spatial(OUT_H, i)
        vj=T.axis.spatial(OUT_W, j)
        vq=T.axis.reduce(CI, q)
        vdi=T.axis.reduce(K, di)
        vdj=T.axis.reduce(K, dj)
        
        # hbsun: need init first
        with T.init():
          C[vb, vk, vi, vj] = T.int64(0)
        C[vb, vk, vi, vj] += A[vb, vq, vi+vdi, vj+vdj] * B[vk, vq, vdi, vdj]    

error: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
 --> /var/folders/bg/vzcr5j8d297c300rzc9gyb2w0000gn/T/ipykernel_16879/3303365678.py:6:5
   |  
 6 |      for b, k, i, j, q, di, dj in T.grid(N, CO, OUT_H, OUT_W, CI, K, K):
   |      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^


DiagnosticError: Traceback (most recent call last):
  [bt] (8) 9   python3.10                          0x00000001007586c4 _PyObject_MakeTpCall + 136
  [bt] (7) 8   _ctypes.cpython-310-darwin.so       0x0000000101422178 PyCFuncPtr_call + 220
  [bt] (6) 7   _ctypes.cpython-310-darwin.so       0x000000010142842c _ctypes_callproc + 936
  [bt] (5) 6   libffi.8.dylib                      0x0000000101445790 ffi_call_int + 1256
  [bt] (4) 5   libffi.8.dylib                      0x000000010144804c ffi_call_SYSV + 76
  [bt] (3) 4   libtvm.dylib                        0x000000013cdc5af4 TVMFuncCall + 60
  [bt] (2) 3   libtvm.dylib                        0x000000013b31b468 tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<void tvm::runtime::TypedPackedFunc<void (tvm::DiagnosticContext)>::AssignTypedLambda<tvm::$_8>(tvm::$_8, std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> >)::'lambda'(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) + 776
  [bt] (1) 2   libtvm.dylib                        0x000000013b311368 tvm::DiagnosticContext::Render() + 468
  [bt] (0) 1   libtvm.dylib                        0x000000013b006e28 tvm::runtime::detail::LogFatal::Entry::Finalize() + 84
  File "/Users/yd/Documents/tvm/src/ir/diagnostic.cc", line 105
DiagnosticError: one or more error diagnostics were emitted, please check diagnostic render for output.

### 第二节：如何变换 TensorIR

在讲座中，我们了解到 TensorIR 不仅是一种编程语言，而且还是一种程序变换的抽象。在本节中，让我们尝试变换程序。我们在采用了 `bmm_relu` (`batched_matmul_relu`)，这是一种常见于 Transformer 等模型中的操作变体。

#### 并行化、向量化与循环展开

首先，我们介绍一些新的原语：`parallel`、`vectorize` 和 `unroll`。这三个原语被应用于循环上，指示循环应当如何执行。这是示例：

In [None]:
@tvm.script.ir_module
class MyAdd:
  @T.prim_func
  def add(A: T.Buffer[(4, 4), "int64"],
          B: T.Buffer[(4, 4), "int64"],
          C: T.Buffer[(4, 4), "int64"]):
    T.func_attr({"global_symbol": "add"})
    for i, j in T.grid(4, 4):
      with T.block("C"):
        vi = T.axis.spatial(4, i)
        vj = T.axis.spatial(4, j)
        C[vi, vj] = A[vi, vj] + B[vi, vj]

sch = tvm.tir.Schedule(MyAdd)
block = sch.get_block("C", func_name="add")
i, j = sch.get_loops(block)
i0, i1 = sch.split(i, factors=[None, 2])
sch.parallel(i0)
sch.unroll(i1)
sch.vectorize(j)
IPython.display.Code(sch.mod.script(), language="python")

In [None]:
def lnumpy_mm_relu_v2(A: np.ndarray, B: np.ndarray, C: np.ndarray):
    Y = np.empty((16, 128, 128), dtype="float32")
    for n in range(16):
        for i in range(128):
            for j in range(128):
                for k in range(128):
                    if k == 0:
                        Y[n, i, j] = 0
                    Y[n, i, j] = Y[n, i, j] + A[n, i, k] * B[n, k, j]
    for n in range(16):
        for i in range(128):
            for j in range(128):
                C[n, i, j] = max(Y[n, i, j], 0)

In [50]:
# @tvm.script.ir_module
# class MyBmmRelu:
#   @T.prim_func
#   def bmm_relu(A: T.Buffer[(16, 128, 128), "float32"], B: T.Buffer[(16, 128, 128), "float32"], C: T.Buffer[(16, 128, 128), "float32"]):
#     T.func_attr({"global_symbol": "bmm_relu", "tir.noalias": True})
#     Y = T.alloc_buffer([16, 128, 128], dtype="float32")
#     for i, j, k, n in T.grid(128, 128, 128, 16):
#       with T.block("Y"):
#         vi, vj, vk, vn = T.axis.remap("SSSR", [n, i, j, k])
#         with T.init():
#           Y[vn, vi, vj] = T.float32(0)
#         Y[vn, vi, vj] += A[vn, vi, vk] * B[vn, vk, vj]
#     for i, j, n in T.grid(128, 128, 16):
#       with T.block("C"):
#         vi, vj, vn = T.axis.remap("SSR", [n, i, j])
#         Y[vn, vi, vj] = T.max(Y[vn, vi, vj], T.float32(0))
@tvm.script.ir_module
class MyBmmRelu:
  @T.prim_func
  def bmm_relu(A: T.Buffer[(16, 128, 128), "float32"],
               B: T.Buffer[(16, 128, 128), "float32"],
               C: T.Buffer[(16, 128, 128), "float32"]):
    T.func_attr({"global_symbol": "bmm_relu", "tir.noalias": True})
    Y = T.alloc_buffer((16, 128, 128), dtype="float32")
    for n, i, j, k in T.grid(16, 128, 128, 128):
        with T.block("Y"):
            vn, vi, vj, vk = T.axis.remap("SSSR", [n, i, j, k])
            with T.init():
                Y[vn, vi, vj] = T.float32(0)
            Y[vn, vi, vj] = Y[vn, vi, vj] + A[vn, vi, vk] * B[vn, vk, vj]

    for n, i, j in T.grid(16, 128, 128):
                with T.block("C"):
                    vn, vi, vj = T.axis.remap("SSS", [n, i, j])
                    C[vn, vi, vj] = T.max(Y[vn, vi, vj], T.float32(0))


sch = tvm.tir.Schedule(MyBmmRelu)
print(sch.mod.script())

# from tvm.script import tir as T
@tvm.script.ir_module
class Module:
    @T.prim_func
    def bmm_relu(A: T.Buffer[(16, 128, 128), "float32"], B: T.Buffer[(16, 128, 128), "float32"], C: T.Buffer[(16, 128, 128), "float32"]):
        # function attr dict
        T.func_attr({"tir.noalias": True, "global_symbol": "bmm_relu"})
        # body
        # with T.block("root")
        Y = T.alloc_buffer([16, 128, 128], dtype="float32")
        for n, i, j, k in T.grid(16, 128, 128, 128):
            with T.block("Y"):
                vn, vi, vj, vk = T.axis.remap("SSSR", [n, i, j, k])
                T.reads(A[vn, vi, vk], B[vn, vk, vj])
                T.writes(Y[vn, vi, vj])
                with T.init():
                    Y[vn, vi, vj] = T.float32(0)
                Y[vn, vi, vj] = Y[vn, vi, vj] + A[vn, vi, vk] * B[vn, vk, vj]
        for n, i, j in T.grid(16, 128, 128):
            with T.block("C"):
                vn, vi, vj = T.axis.remap("SSS", [n, i, j])
                T.reads(Y[vn

In [63]:
@tvm.script.ir_module
class TargetModule:
    @T.prim_func
    def bmm_relu(A: T.Buffer[(16, 128, 128), "float32"], B: T.Buffer[(16, 128, 128), "float32"], C: T.Buffer[(16, 128, 128), "float32"]) -> None:
        T.func_attr({"global_symbol": "bmm_relu", "tir.noalias": True})
        Y = T.alloc_buffer([16, 128, 128], dtype="float32")
        for i0 in T.parallel(16):
            for i1, i2_0 in T.grid(128, 16):
                for ax0_init in T.vectorized(8):
                    with T.block("Y_init"):
                        n, i = T.axis.remap("SS", [i0, i1])
                        j = T.axis.spatial(128, i2_0 * 8 + ax0_init)
                        Y[n, i, j] = T.float32(0)
                for ax1_0 in T.serial(32):
                    for ax1_1 in T.unroll(4):
                        for ax0 in T.serial(8):
                            with T.block("Y_update"):
                                n, i = T.axis.remap("SS", [i0, i1])
                                j = T.axis.spatial(128, i2_0 * 8 + ax0)
                                k = T.axis.reduce(128, ax1_0 * 4 + ax1_1)
                                Y[n, i, j] = Y[n, i, j] + A[n, i, k] * B[n, k, j]
                for i2_1 in T.vectorized(8):
                    with T.block("C"):
                        n, i = T.axis.remap("SS", [i0, i1])
                        j = T.axis.spatial(128, i2_0 * 8 + i2_1)
                        C[n, i, j] = T.max(Y[n, i, j], T.float32(0))

In [61]:
sch = tvm.tir.Schedule(MyBmmRelu)
# TODO: transformations
# Hints: you can use
# `IPython.display.Code(sch.mod.script(), language="python")`
# or `print(sch.mod.script())`
# to show the current program at any time during the transformation.

# Step 1. Get blocks
block_Y = sch.get_block("Y", func_name="bmm_relu")

# Step 2. Get loops
b, i, j, k = sch.get_loops(block_Y)

# parallelize the outermost loop
sch.parallel(b)

# Organize the loops
k0, k1 = sch.split(k, factors=[32, 4])
j0, j1 = sch.split(j, factors=[16, 8])
                  
sch.reorder(j0, k0, k1, j1)
print(sch.mod.script())

# reverse compute
block_C = sch.get_block("C", func_name="bmm_relu")
sch.reverse_compute_at(block_C, j0)
print(sch.mod.script())

# decompose reduction
Y_init = sch.decompose_reduction(block_Y, k0)
print(sch.mod.script())

# vectorize /  unroll
n, i, j_0, j_1_init = sch.get_loops(Y_init)
sch.vectorize(j_1_init)
n, i, j_0, i2_1 = sch.get_loops(block_C)
sch.vectorize(i2_1)

block_Y_update = sch.get_block("Y_update", func_name="bmm_relu")
n, i, j_0, k_0, k_1, j_1 = sch.get_loops(block_Y_update)
sch.unroll(k_1)

# n, i, j_0, k_0, k_1, j_1 = sch.get_loops(block_Y_update)
# sch.parallel(n)
...

IPython.display.Code(sch.mod.script(), language="python")

# from tvm.script import tir as T
@tvm.script.ir_module
class Module:
    @T.prim_func
    def bmm_relu(A: T.Buffer[(16, 128, 128), "float32"], B: T.Buffer[(16, 128, 128), "float32"], C: T.Buffer[(16, 128, 128), "float32"]):
        # function attr dict
        T.func_attr({"tir.noalias": True, "global_symbol": "bmm_relu"})
        # body
        # with T.block("root")
        Y = T.alloc_buffer([16, 128, 128], dtype="float32")
        for n in T.parallel(16):
            for i, j_0, k_0, k_1, j_1 in T.grid(128, 16, 32, 4, 8):
                with T.block("Y"):
                    vn, vi = T.axis.remap("SS", [n, i])
                    vj = T.axis.spatial(128, j_0 * 8 + j_1)
                    vk = T.axis.reduce(128, k_0 * 4 + k_1)
                    T.reads(A[vn, vi, vk], B[vn, vk, vj])
                    T.writes(Y[vn, vi, vj])
                    with T.init():
                        Y[vn, vi, vj] = T.float32(0)
                    Y[vn, vi, vj] = Y[vn, vi, vj] + A[vn, vi, vk] *

In [64]:
tvm.ir.assert_structural_equal(sch.mod, TargetModule)
print("Pass")

Pass


In [72]:
before_rt_lib = tvm.build(MyBmmRelu, target="llvm")
after_rt_lib = tvm.build(sch.mod, target="llvm")

a_tvm = tvm.nd.array(np.random.rand(16, 128, 128).astype("float32"))
b_tvm = tvm.nd.array(np.random.rand(16, 128, 128).astype("float32"))
c_tvm = tvm.nd.array(np.random.rand(16, 128, 128).astype("float32"))
before_timer = before_rt_lib.time_evaluator("bmm_relu", tvm.cpu())
f_timer = after_rt_lib.time_evaluator("bmm_relu", tvm.cpu())
print("Before transformation:")
print(before_timer(a_tvm, b_tvm, c_tvm))

print("After transformation:")
print(f_timer(a_tvm, b_tvm, c_tvm))



Before transformation:
Execution time summary:
 mean (ms)   median (ms)    max (ms)     min (ms)     std (ms)  
  42.2089      42.2089      42.2089      42.2089       0.0000   
               
After transformation:
Execution time summary:
 mean (ms)   median (ms)    max (ms)     min (ms)     std (ms)  
   1.8328       1.8328       1.8328       1.8328       0.0000   
               
