In [1]:
!python3 -m  pip install mlc-ai-nightly -f https://mlc.ai/wheels

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in links: https://mlc.ai/wheels
Collecting mlc-ai-nightly
  Downloading https://github.com/mlc-ai/utils/releases/download/v0.9.dev0/mlc_ai_nightly-0.9.dev1664%2Bg1f3985de0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (43.3 MB)
[K     |████████████████████████████████| 43.3 MB 30.7 MB/s 
Collecting synr==0.6.0
  Downloading synr-0.6.0-py3-none-any.whl (18 kB)
Installing collected packages: synr, mlc-ai-nightly
Successfully installed mlc-ai-nightly-0.9.dev1664+g1f3985de0 synr-0.6.0


In [99]:
import IPython
import numpy as np
import tvm
from tvm.ir.module import IRModule
from tvm.script import tir as T

## Section 1: How to Write TensorIR

In this section, let’s try to write TensorIR manually according to high-level instructions (e.g., NumPy or Torch). First, we give an example of an element-wise add function to show what we need to do to write a TensorIR function.

### Elementwise add

In [38]:
# init data
a = np.arange(16).reshape(4, 4)
b = np.arange(16, 0, -1).reshape(4, 4)

In [39]:
# numpy version
c_np = a + b
c_np

array([[16, 16, 16, 16],
       [16, 16, 16, 16],
       [16, 16, 16, 16],
       [16, 16, 16, 16]])

Before we directly write TensorIR, we should first translate the high-level computation abstraction (e.g., ndarray + ndarray) into a low-level Python implementation (standard for loops with element access and operation).

Notably, the initial value of the output array (or buffer) is not always 0. We need to write or initialize it in our implementation, which is important for reduction operators (e.g., matmul and conv).

In [40]:
# low-level numpy version
def lnumpy_add(a: np.ndarray, b: np.ndarray, c: np.ndarray):
  """Element-wise add of two 2-D arrays, written with explicit loops.

  Writes ``c[i, j] = a[i, j] + b[i, j]`` for every position of ``c``.

  Args:
    a: first 2-D operand, same shape as ``c``.
    b: second 2-D operand, same shape as ``c``.
    c: 2-D output buffer; every element is overwritten in place.
  """
  # Take the loop bounds from the output buffer instead of hard-coding 4 x 4.
  rows, cols = c.shape
  for i in range(rows):
    for j in range(cols):
      c[i, j] = a[i, j] + b[i, j]
c_lnumpy = np.empty((4, 4), dtype=np.int64)
lnumpy_add(a, b, c_lnumpy)
c_lnumpy

array([[16, 16, 16, 16],
       [16, 16, 16, 16],
       [16, 16, 16, 16],
       [16, 16, 16, 16]])

Now, let’s take a further step: translate the low-level NumPy implementation into TensorIR, and compare the result with the one that comes from NumPy.



In [34]:
# TensorIR version
@tvm.script.ir_module
class MyAdd:
  # IRModule with a single 4x4 int64 element-wise add function.
  @T.prim_func
  def add(A: T.Buffer[(4, 4), "int64"],
          B: T.Buffer[(4, 4), "int64"],
          C: T.Buffer[(4, 4), "int64"]):
    T.func_attr({"global_symbol": "add"})
    for i, j in T.grid(4, 4):
      with T.block("C"):
        # Both loop variables bind one-to-one to spatial block axes.
        vi, vj = T.axis.remap("SS", [i, j])
        C[vi, vj] = A[vi, vj] + B[vi, vj]

# Wrap the NumPy inputs and an uninitialised 4x4 output buffer as TVM NDArrays.
a_tvm = tvm.nd.array(a)
b_tvm = tvm.nd.array(b)
c_tvm = tvm.nd.array(np.empty((4, 4), dtype=np.int64))

# Build with the LLVM target, run "add", and compare against the NumPy result.
rt_lib = tvm.build(MyAdd, target="llvm")
rt_lib["add"](a_tvm, b_tvm, c_tvm)
np.testing.assert_allclose(c_tvm.numpy(), c_np, rtol=1e-5)


### Exercise 1: Broadcast Add


In [43]:
# init data
a = np.arange(16).reshape(4, 4)
b = np.arange(4, 0, -1).reshape(4)

In [44]:
# numpy version
c_np = a + b
c_np

array([[ 4,  4,  4,  4],
       [ 8,  8,  8,  8],
       [12, 12, 12, 12],
       [16, 16, 16, 16]])

(4,)

In [66]:
@tvm.script.ir_module
class MyAdd:
  # Broadcast add: a length-4 vector B is added to every row of the 4x4 matrix A.
  @T.prim_func
  def add(A: T.Buffer[(4, 4), "int64"],
          B: T.Buffer[(4,), "int64"],
          C: T.Buffer[(4, 4), "int64"]):
    # B's shape is written as the one-element tuple (4,); a bare (4) is just the int 4.
    T.func_attr({"global_symbol": "add", "tir.noalias": True})
    for i, j in T.grid(4, 4):
      with T.block("C"):
        vi = T.axis.spatial(4, i)
        vj = T.axis.spatial(4, j)
        # B has a single axis, so it is indexed by the column only.
        C[vi, vj] = A[vi, vj] + B[vj]

# Wrap the 4x4 matrix, the length-4 vector and an uninitialised output as TVM NDArrays.
a_tvm = tvm.nd.array(a)
b_tvm = tvm.nd.array(b)
c_tvm = tvm.nd.array(np.empty((4, 4), dtype=np.int64))

# Build with the LLVM target, run "add", and compare against NumPy's broadcast result.
rt_lib = tvm.build(MyAdd, target="llvm")
rt_lib["add"](a_tvm, b_tvm, c_tvm)
np.testing.assert_allclose(c_tvm.numpy(), c_np, rtol=1e-5)

### Exercise 2: 2D Convolution

Then, let’s try to do something challenging: 2D convolution, which is a common operation in image processing.

Here is the mathematical definition of convolution with NCHW layout:

$Conv[b, k, i, j] = \sum_{di, dj, q} A[b, q, strides * i + di, strides * j + dj] * W[k, q, di, dj]$

where A is the input tensor, W is the weight tensor, b is the batch index, k is the output channel index, i and j are the indices for image height and width, di and dj are the indices of the weight, q is the input channel, and strides is the stride of the filter window.

In the exercise, we pick a small and simple case with stride=1, padding=0.

In [78]:
N, CI, H, W, CO, K = 1, 1, 8, 8, 2, 3
OUT_H, OUT_W = H - K + 1, W - K + 1
data = np.arange(N*CI*H*W).reshape(N, CI, H, W)
weight = np.arange(CO*CI*K*K).reshape(CO, CI, K, K)

In [50]:
# torch version
import torch

# The convolution is run on float32 tensors, so make the int64 -> float32
# conversion explicit instead of relying on the legacy torch.Tensor constructor.
data_torch = torch.tensor(data, dtype=torch.float32)
weight_torch = torch.tensor(weight, dtype=torch.float32)

# Reference result (stride 1, no padding), converted back to an int64 ndarray
# so that it can be compared with the integer TensorIR output.
conv_torch = (
    torch.nn.functional.conv2d(data_torch, weight_torch)
    .numpy()
    .astype(np.int64)
)
conv_torch

array([[[[ 474,  510,  546,  582,  618,  654],
         [ 762,  798,  834,  870,  906,  942],
         [1050, 1086, 1122, 1158, 1194, 1230],
         [1338, 1374, 1410, 1446, 1482, 1518],
         [1626, 1662, 1698, 1734, 1770, 1806],
         [1914, 1950, 1986, 2022, 2058, 2094]],

        [[1203, 1320, 1437, 1554, 1671, 1788],
         [2139, 2256, 2373, 2490, 2607, 2724],
         [3075, 3192, 3309, 3426, 3543, 3660],
         [4011, 4128, 4245, 4362, 4479, 4596],
         [4947, 5064, 5181, 5298, 5415, 5532],
         [5883, 6000, 6117, 6234, 6351, 6468]]]])

In [101]:
@tvm.script.ir_module
class MyConv:
  # 2D convolution in NCHW layout over the shape constants defined in an earlier cell.
  # NOTE(review): A's height/width are written as literal 8s rather than H and W,
  # presumably because the name W is taken by the weight buffer -- keep them in sync.
  @T.prim_func
  def conv(A: T.Buffer[(N,CI,8,8), "int64"],
           W: T.Buffer[(CO,CI,K,K), "int64"],
           C: T.Buffer[(N,CO,OUT_H,OUT_W), "int64"]):
    T.func_attr({"global_symbol": "conv", "tir.noalias": True})
    for b, k, i, j, di, dj, q in T.grid(N, CO, OUT_H, OUT_W, K, K, CI):
      with T.block("C"):
        # b, k, i, j index the output (spatial); di, dj, q are summed over (reduce).
        vb, vk, vi, vj, vdi, vdj, vq = T.axis.remap("SSSSRRR", [b, k, i, j, di, dj, q])
        with T.init():
          C[vb, vk, vi, vj] = T.int64(0)
        C[vb, vk, vi, vj] = C[vb, vk, vi, vj] + A[vb, vq, vi + vdi, vj + vdj] * W[vk, vq, vdi, vdj]
      

# Wrap the input, the weight and an uninitialised output buffer as TVM NDArrays.
data_tvm = tvm.nd.array(data)
weight_tvm = tvm.nd.array(weight)
conv_tvm = tvm.nd.array(np.empty((N, CO, OUT_H, OUT_W), dtype=np.int64))

# Build with the LLVM target, run "conv", and compare against the torch reference.
rt_lib = tvm.build(MyConv, target="llvm")
rt_lib["conv"](data_tvm, weight_tvm, conv_tvm)
np.testing.assert_allclose(conv_tvm.numpy(), conv_torch, rtol=1e-5)

## Section 2: How to Transform TensorIR

In the lecture, we learned that TensorIR is not only a programming language but also an abstraction for program transformation. In this section, let’s try to transform the program. We take bmm_relu (batched_matmul_relu) as our example, which is a variant of operations that commonly appear in models such as transformers.

### Parallel, Vectorize and Unroll


First, we introduce some new primitives: parallel, vectorize and unroll. These three primitives operate on loops to indicate how each loop executes. Here is an example:

In [83]:
import IPython

def code2html(code):
    """Render a code string as pygments-highlighted HTML, with the CSS inlined."""
    from pygments import highlight
    from pygments.formatters import HtmlFormatter
    from pygments.lexers import Python3Lexer

    html_formatter = HtmlFormatter()
    highlighted = highlight(code, Python3Lexer(), html_formatter)
    # Style rules are scoped to the ".highlight" selector.
    stylesheet = html_formatter.get_style_defs(".highlight")
    return "<style>%s</style>%s\n" % (stylesheet, highlighted)

In [87]:
@tvm.script.ir_module
class MyAdd:
  # 4x4 int64 element-wise add, used as the starting point for the loop primitives.
  @T.prim_func
  def add(A: T.Buffer[(4, 4), "int64"],
          B: T.Buffer[(4, 4), "int64"],
          C: T.Buffer[(4, 4), "int64"]):
    T.func_attr({"global_symbol": "add"})
    for i, j in T.grid(4, 4):
      with T.block("C"):
        # Both loop variables bind one-to-one to spatial block axes.
        vi, vj = T.axis.remap("SS", [i, j])
        C[vi, vj] = A[vi, vj] + B[vi, vj]

sch = tvm.tir.Schedule(MyAdd)
add_block = sch.get_block("C", func_name="add")
loop_i, loop_j = sch.get_loops(add_block)

# Split the outer loop 2 x 2 so its two halves can be annotated differently.
outer_i, inner_i = sch.split(loop_i, factors=[2, 2])
sch.parallel(outer_i)
sch.unroll(inner_i)
sch.vectorize(loop_j)


In [88]:
print(sch.mod.script())

@tvm.script.ir_module
class Module:
    @tir.prim_func
    def add(A: tir.Buffer[(4, 4), "int64"], B: tir.Buffer[(4, 4), "int64"], C: tir.Buffer[(4, 4), "int64"]) -> None:
        # function attr dict
        tir.func_attr({"global_symbol": "add"})
        # body
        # with tir.block("root")
        for i_0 in tir.parallel(2):
            for i_1 in tir.unroll(2):
                for j in tir.vectorized(4):
                    with tir.block("C"):
                        vi = tir.axis.spatial(4, i_0 * 2 + i_1)
                        vj = tir.axis.spatial(4, j)
                        tir.reads(A[vi, vj], B[vi, vj])
                        tir.writes(C[vi, vj])
                        C[vi, vj] = A[vi, vj] + B[vi, vj]
    


In [90]:
from tvm.script import tir
# The script printed by the previous cell, pasted back in unchanged. It refers to
# the TVMScript namespace as `tir`, hence the import on the preceding line.
@tvm.script.ir_module
class Module:
    @tir.prim_func
    def add(A: tir.Buffer[(4, 4), "int64"], B: tir.Buffer[(4, 4), "int64"], C: tir.Buffer[(4, 4), "int64"]) -> None:
        # function attr dict
        tir.func_attr({"global_symbol": "add"})
        # body
        # with tir.block("root")
        # i was split into i_0 (parallel) and i_1 (unrolled); j is vectorized.
        for i_0 in tir.parallel(2):
            for i_1 in tir.unroll(2):
                for j in tir.vectorized(4):
                    with tir.block("C"):
                        # vi is recomposed from the two halves of the split loop.
                        vi = tir.axis.spatial(4, i_0 * 2 + i_1)
                        vj = tir.axis.spatial(4, j)
                        tir.reads(A[vi, vj], B[vi, vj])
                        tir.writes(C[vi, vj])
                        C[vi, vj] = A[vi, vj] + B[vi, vj]

### Exercise 3: Transform a batch matmul program

Now, let’s go back to the bmm_relu exercise. First, let’s see the definition of bmm: $Y_{n, i, j} = \sum_k A_{n, i, k} \times B_{n, k, j}$, $C_{n, i, j} = \mathrm{relu}(Y_{n, i, j}) = \max(Y_{n, i, j}, 0)$

It’s your turn to write the TensorIR for bmm_relu. We provide the lnumpy function as a hint:



In [127]:
def lnumpy_mm_relu_v2(A: np.ndarray, B: np.ndarray, C: np.ndarray):
    """Batched matrix multiply followed by ReLU, written with explicit loops.

    Computes ``C[n, i, j] = max(sum_k A[n, i, k] * B[n, k, j], 0)`` in place.

    Args:
        A: left operand of shape (batch, rows, inner).
        B: right operand of shape (batch, inner, cols).
        C: output buffer of shape (batch, rows, cols); overwritten in place.
    """
    # Loop bounds come from the operands; (16, 128, 128) inputs give the
    # 16 x 128 x 128 x 128 iteration space used elsewhere in the notebook.
    batch, rows, inner = A.shape
    cols = B.shape[2]
    Y = np.empty((batch, rows, cols), dtype="float32")
    for n in range(batch):
        for i in range(rows):
            for j in range(cols):
                for k in range(inner):
                    if k == 0:
                        # First reduction step: initialise this accumulator.
                        # The batch index n is required; Y is 3-D.
                        Y[n, i, j] = 0
                    Y[n, i, j] = Y[n, i, j] + A[n, i, k] * B[n, k, j]
    for n in range(batch):
        for i in range(rows):
            for j in range(cols):
                # ReLU per element, again indexed by batch.
                C[n, i, j] = max(Y[n, i, j], 0)


In [140]:
@tvm.script.ir_module
class MyBmmRelu:
  # Batched matmul (16 batches of 128x128) followed by ReLU, as two blocks.
  @T.prim_func
  def bmm_relu(A: T.Buffer[(16, 128, 128), "float32"],
               B: T.Buffer[(16, 128, 128), "float32"],
               C: T.Buffer[(16, 128, 128), "float32"]):
    T.func_attr({"global_symbol": "bmm_relu", "tir.noalias": True})
    # Intermediate buffer holding the matmul result before ReLU is applied.
    Y = T.alloc_buffer((16, 128, 128), dtype="float32")
    for n, i, j, k in T.grid(16, 128, 128, 128):
      with T.block("Y"):
        # n, i, j are spatial axes; k is the reduction axis.
        vn = T.axis.spatial(16, n)
        vi = T.axis.spatial(128, i)
        vj = T.axis.spatial(128, j)
        vk = T.axis.reduce(128, k)
        with T.init():
          Y[vn, vi, vj] = T.float32(0)
        Y[vn, vi, vj] = Y[vn, vi, vj] + A[vn, vi, vk] * B[vn, vk, vj]
    for n, i, j in T.grid(16, 128, 128):
      with T.block("C"):
        vn = T.axis.spatial(16, n)
        vi = T.axis.spatial(128, i)
        vj = T.axis.spatial(128, j)
        C[vn, vi, vj] = T.max(Y[vn, vi, vj], T.float32(0))

# Create a schedule over the module defined in this cell (the class is MyBmmRelu).
sch = tvm.tir.Schedule(MyBmmRelu)
#IPython.display.Code(sch.mod.script(), language="python")
# Also please validate your result

Now let us verify the program.

In [129]:
print(sch.mod.script())

@tvm.script.ir_module
class Module:
    @tir.prim_func
    def bmm_relu(A: tir.Buffer[(16, 128, 128), "float32"], B: tir.Buffer[(16, 128, 128), "float32"], C: tir.Buffer[(16, 128, 128), "float32"]) -> None:
        # function attr dict
        tir.func_attr({"global_symbol": "bmm_relu", "tir.noalias": True})
        # body
        # with tir.block("root")
        Y = tir.alloc_buffer([16, 128, 128], dtype="float32")
        for n, i, j, k in tir.grid(16, 128, 128, 128):
            with tir.block("Y"):
                vn, vi, vj, vk = tir.axis.remap("SSSR", [n, i, j, k])
                tir.reads(A[vn, vi, vk], B[vn, vk, vj])
                tir.writes(Y[vn, vi, vj])
                with tir.init():
                    Y[vn, vi, vj] = tir.float32(0)
                Y[vn, vi, vj] = Y[vn, vi, vj] + A[vn, vi, vk] * B[vn, vk, vj]
        for n, i, j in tir.grid(16, 128, 128):
            with tir.block("C"):
                vn, vi, vj = tir.axis.remap("SSS", [n, i, j])
                tir.

In this exercise, let’s focus on transforming the original program into a specific target. Note that the target program may not be the best one, since the best choice depends on the hardware. But this exercise aims to help students understand how to transform a program into a desired one. Here is the target program:



In [130]:
# Reference loop structure that the scheduled bmm_relu is expected to reach.
@tvm.script.ir_module
class TargetModule:
    @T.prim_func
    def bmm_relu(A: T.Buffer[(16, 128, 128), "float32"], B: T.Buffer[(16, 128, 128), "float32"], C: T.Buffer[(16, 128, 128), "float32"]) -> None:
        T.func_attr({"global_symbol": "bmm_relu", "tir.noalias": True})
        Y = T.alloc_buffer([16, 128, 128], dtype="float32")
        # The batch loop is parallel; each (i1, i2_0) pair handles 8 consecutive values of j.
        for i0 in T.parallel(16):
            for i1, i2_0 in T.grid(128, 16):
                # 1) Zero the 8 accumulators (vectorized).
                for ax0_init in T.vectorized(8):
                    with T.block("Y_init"):
                        n, i = T.axis.remap("SS", [i0, i1])
                        j = T.axis.spatial(128, i2_0 * 8 + ax0_init)
                        Y[n, i, j] = T.float32(0)
                # 2) Accumulate over k, split 32 x 4 with the inner 4 unrolled.
                for ax1_0 in T.serial(32):
                    for ax1_1 in T.unroll(4):
                        for ax0 in T.serial(8):
                            with T.block("Y_update"):
                                n, i = T.axis.remap("SS", [i0, i1])
                                j = T.axis.spatial(128, i2_0 * 8 + ax0)
                                k = T.axis.reduce(128, ax1_0 * 4 + ax1_1)
                                Y[n, i, j] = Y[n, i, j] + A[n, i, k] * B[n, k, j]
                # 3) Apply ReLU to the same 8 values (vectorized).
                for i2_1 in T.vectorized(8):
                    with T.block("C"):
                        n, i = T.axis.remap("SS", [i0, i1])
                        j = T.axis.spatial(128, i2_0 * 8 + i2_1)
                        C[n, i, j] = T.max(Y[n, i, j], T.float32(0))

In [152]:
sch = tvm.tir.Schedule(MyBmmRelu)
# Transformations that reshape MyBmmRelu into the loop structure of TargetModule.
# Hints: you can use
# `IPython.display.Code(sch.mod.script(), language="python")`
# or `print(sch.mod.script())`
# to show the current program at any time during the transformation.

# Step 1. Get blocks
Y = sch.get_block("Y", func_name="bmm_relu")
C = sch.get_block("C", func_name = "bmm_relu")

# Step 2. Get loops
# Block Y sits under four loops (n, i, j, k); the outermost one is parallelized.
i0, i, j, k = sch.get_loops(Y)
sch.parallel(i0)



# Step 3. Organize the loops
# Split j into 16 x 8, move the reduction loop k between the two halves,
# then move block C under j0.
j0, j1 = sch.split(j, factors = [16, 8])
sch.reorder(j0, k, j1)
sch.reverse_compute_at(C, j0)


# Step 4. decompose reduction
# The init statement becomes its own block (Y_init); the remaining update
# block is looked up again under the name "Y_update".
Y_init = sch.decompose_reduction(Y, k)
n, i, j_0, j_1_init = sch.get_loops(Y_init)
_, _, _, ax0 = sch.get_loops(C)
Y_update = sch.get_block("Y_update", func_name="bmm_relu")
_, _, _, k, j_1 = sch.get_loops(Y_update)
# Split k into 32 x 4 so that the inner 4 iterations can be unrolled below.
k0, k1 = sch.split(k, factors=[32, 4])


# Step 5. vectorize / parallel / unroll
sch.vectorize(j_1_init)
sch.vectorize(ax0)
sch.unroll(k1)

print(sch.mod.script())

@tvm.script.ir_module
class Module:
    @tir.prim_func
    def bmm_relu(A: tir.Buffer[(16, 128, 128), "float32"], B: tir.Buffer[(16, 128, 128), "float32"], C: tir.Buffer[(16, 128, 128), "float32"]) -> None:
        # function attr dict
        tir.func_attr({"global_symbol": "bmm_relu", "tir.noalias": True})
        # body
        # with tir.block("root")
        Y = tir.alloc_buffer([16, 128, 128], dtype="float32")
        for n in tir.parallel(16):
            for i, j_0 in tir.grid(128, 16):
                for j_1_init in tir.vectorized(8):
                    with tir.block("Y_init"):
                        vn, vi = tir.axis.remap("SS", [n, i])
                        vj = tir.axis.spatial(128, j_0 * 8 + j_1_init)
                        tir.reads()
                        tir.writes(Y[vn, vi, vj])
                        Y[vn, vi, vj] = tir.float32(0)
                for k_0 in tir.serial(32):
                    for k_1 in tir.unroll(4):
                        for j_1 in tir.s

**Optional:** If we want to make sure the transformed program is exactly the same as the given target, we can use assert_structural_equal. Note that this step is optional in this exercise. It’s good enough if you transformed the program towards the target and got a performance improvement.

In [149]:
# Compare the scheduled module with TargetModule structurally; "Pass" is only
# printed when the assertion does not fail.
tvm.ir.assert_structural_equal(sch.mod, TargetModule)
print("Pass")

Pass


### Build and Evaluate


Finally we can evaluate the performance of the transformed program.



In [150]:
# Build the original module (MyBmmRelu) and the scheduled one with the same LLVM target.
before_rt_lib = tvm.build(MyBmmRelu, target="llvm")
after_rt_lib = tvm.build(sch.mod, target="llvm")

# Random float32 inputs; c_tvm is only an output buffer and gets overwritten.
a_tvm = tvm.nd.array(np.random.rand(16, 128, 128).astype("float32"))
b_tvm = tvm.nd.array(np.random.rand(16, 128, 128).astype("float32"))
c_tvm = tvm.nd.array(np.random.rand(16, 128, 128).astype("float32"))
after_rt_lib["bmm_relu"](a_tvm, b_tvm, c_tvm)

# Time both versions on the CPU device with identical arguments.
before_timer = before_rt_lib.time_evaluator("bmm_relu", tvm.cpu())
print("Before transformation:")
print(before_timer(a_tvm, b_tvm, c_tvm))

f_timer = after_rt_lib.time_evaluator("bmm_relu", tvm.cpu())
print("After transformation:")
print(f_timer(a_tvm, b_tvm, c_tvm))

Before transformation:
Execution time summary:
 mean (ms)   median (ms)    max (ms)     min (ms)     std (ms)  
  49.5689      49.5689      49.5689      49.5689       0.0000   
               
After transformation:
Execution time summary:
 mean (ms)   median (ms)    max (ms)     min (ms)     std (ms)  
  14.0756      14.0756      14.0756      14.0756       0.0000   
               
