<a href="https://colab.research.google.com/github/XueyanZhang/MachineLearningCompilation/blob/master/MLC_TensorIR_Excercises.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
!python3 -m  pip install mlc-ai-nightly -f https://mlc.ai/wheels

import numpy as np
import tvm
from tvm.ir.module import IRModule
from tvm.script import tir as T

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in links: https://mlc.ai/wheels
Collecting mlc-ai-nightly
  Downloading https://github.com/mlc-ai/utils/releases/download/v0.9.dev0/mlc_ai_nightly-0.12.dev819%2Bg209d99f09-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (52.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m52.1/52.1 MB[0m [31m10.9 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: mlc-ai-nightly
Successfully installed mlc-ai-nightly-0.12.dev819+g209d99f09


# Write Tensor IR

## 1. element-wise add



### high-level numpy addition

In [3]:
a = np.arange(16).reshape(4, 4)
b = np.arange(16, 0, -1).reshape(4, 4)

print(a)
print(type(a))
print()
print(b)
print(type(b))

[[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]
 [12 13 14 15]]
<class 'numpy.ndarray'>

[[16 15 14 13]
 [12 11 10  9]
 [ 8  7  6  5]
 [ 4  3  2  1]]
<class 'numpy.ndarray'>


In [4]:
# numpy addition
c_np = a + b
c_np

array([[16, 16, 16, 16],
       [16, 16, 16, 16],
       [16, 16, 16, 16],
       [16, 16, 16, 16]])

### low-level numpy addition

In [5]:
# low level numpy (w/ loops)
def lnumpy_add(a: np.ndarray, b: np.ndarray, c: np.ndarray):
    for i in range(4):
        for j in range(4):
            c[i, j] = a[i, j] + b[i, j]

c_np_low = np.empty((4, 4), dtype=np.int64)
lnumpy_add(a, b, c_np_low)
c_np_low

array([[16, 16, 16, 16],
       [16, 16, 16, 16],
       [16, 16, 16, 16],
       [16, 16, 16, 16]])

### Tensor IR addition

In [6]:
# tensor ir
@tvm.script.ir_module
class MyAdd:
    @T.prim_func
    def add(A: T.Buffer((4, 4), "int64"),
            B: T.Buffer((4, 4), "int64"),
            C: T.Buffer((4, 4), "int64")):
        T.func_attr({"global_symbol": "add"})
        for i, j in T.grid(4, 4):
            with T.block("C"):
                vi, vj = T.axis.remap("SS", [i, j])
                C[vi, vj] = A[vi, vj] + B[vi, vj]

In [7]:
# build and run
rt_lib = tvm.build(MyAdd, target="llvm")
a_tvm = tvm.nd.array(a)
b_tvm = tvm.nd.array(b)
c_tvm = tvm.nd.empty((4, 4), dtype="int64")

rt_lib["add"](a_tvm, b_tvm, c_tvm)
np.testing.assert_allclose(c_tvm.numpy(), c_np, rtol=1e-5)
c_tvm


<tvm.nd.NDArray shape=(4, 4), cpu(0)>
array([[16, 16, 16, 16],
       [16, 16, 16, 16],
       [16, 16, 16, 16],
       [16, 16, 16, 16]])

## 2. broadcast add

In [8]:
a = np.arange(16).reshape(4, 4)
b = np.arange(4, 0, -1).reshape(4)

print(a)
print()
print(b)

[[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]
 [12 13 14 15]]

[4 3 2 1]


In [9]:
# high level numpy
c_np = a + b
c_np

array([[ 4,  4,  4,  4],
       [ 8,  8,  8,  8],
       [12, 12, 12, 12],
       [16, 16, 16, 16]])

In [10]:
# low level numpy
def lnumpy_broadcastAdd(a: np.ndarray, b: np.ndarray, c: np.ndarray):
    for i in range(4):
        for j in range(4):
            c[i, j] = a[i, j] + b[j]

c_np_low = np.empty((4, 4), dtype=np.int64)
lnumpy_broadcastAdd(a, b, c_np_low)
c_np_low

array([[ 4,  4,  4,  4],
       [ 8,  8,  8,  8],
       [12, 12, 12, 12],
       [16, 16, 16, 16]])

In [11]:
# tensor ir
@tvm.script.ir_module
class MyBCSTAdd:
    @T.prim_func
    def bcstadd(A: T.Buffer((4, 4), "int64"),
                B: T.Buffer((4), "int64"),
                C: T.Buffer((4, 4), "int64")):
        T.func_attr({"global_symbol": "bcstadd"})
        for i, j in T.grid(4, 4):
            with T.block("C"):
                vi, vj = T.axis.remap("SS", [i, j])
                C[vi, vj] = A[vi, vj] + B[vj]

In [12]:
rt_lib = tvm.build(MyBCSTAdd, target="llvm")
a_tvm = tvm.nd.array(a)
b_tvm = tvm.nd.array(b)
c_tvm = tvm.nd.empty((4, 4), dtype="int64")

rt_lib["bcstadd"](a_tvm, b_tvm, c_tvm)
np.testing.assert_allclose(c_tvm.numpy(), c_np, rtol=1e-5)
c_tvm

<tvm.nd.NDArray shape=(4, 4), cpu(0)>
array([[ 4,  4,  4,  4],
       [ 8,  8,  8,  8],
       [12, 12, 12, 12],
       [16, 16, 16, 16]])

## 3. 2D convolution

Here is the mathematical definition of convolution with NCHW layout:
 
C
o
n
v
[
b
,
k
,
i
,
j
]
=
∑_
(d
i
,
d
j
,
q)
A
[
b
,
q
,
s
t
r
i
d
e
s
∗
i
+
d
i
,
s
t
r
i
d
e
s
∗
j
+
d
j
]
∗
W
[
k
,
q
,
d
i
,
d
j
]

, where, A is the input tensor, W is the weight tensor, b is the batch index, k is the out channels, i and j are indices for image hight and width, di and dj are the indices of the weight, q is the input channel, and strides is the stride of the filter window.

In [13]:
# string = 1, padding = 0
N, CI, H, W, CO, K = 1, 1, 8, 8, 2, 3
OUT_H, OUT_W = H - K + 1, W - K + 1
data = np.arange(N * CI * H * W).reshape(N, CI, H, W)
weight = np.arange(CO * CI * K * K).reshape(CO, CI, K, K)

# print(data)
# print()
# print(weight)

In [14]:
# torch version
import torch

data_torch = torch.Tensor(data)
weight_torch = torch.Tensor(weight)
conv_torch = torch.nn.functional.conv2d(data_torch, weight_torch)
conv_torch = conv_torch.numpy().astype(np.int64)
conv_torch

array([[[[ 474,  510,  546,  582,  618,  654],
         [ 762,  798,  834,  870,  906,  942],
         [1050, 1086, 1122, 1158, 1194, 1230],
         [1338, 1374, 1410, 1446, 1482, 1518],
         [1626, 1662, 1698, 1734, 1770, 1806],
         [1914, 1950, 1986, 2022, 2058, 2094]],

        [[1203, 1320, 1437, 1554, 1671, 1788],
         [2139, 2256, 2373, 2490, 2607, 2724],
         [3075, 3192, 3309, 3426, 3543, 3660],
         [4011, 4128, 4245, 4362, 4479, 4596],
         [4947, 5064, 5181, 5298, 5415, 5532],
         [5883, 6000, 6117, 6234, 6351, 6468]]]])

In [15]:
# low level numpy
def lnumpy_conv2d(A: np.ndarray, B: np.ndarray, C: np.ndarray):
    for b in range(N):
        for c_out in range(CO):
            for h_out in range(OUT_H):
                for w_out in range(OUT_W):
                    # compute the output tensor at (b, c_out, h_out, w_out)
                    for c_in in range(CI):
                        for kh in range(K):
                            for kw in range(K):
                                h_in = h_out + kh
                                w_in = w_out + kw
                                C[b, c_out, h_out, w_out] += (
                                    A[b, c_in, h_in, w_in] * B[c_out, c_in, kh, kw]
                                )

conv_lnumpy = np.zeros((N, CO, OUT_H, OUT_W), dtype=np.int64)
lnumpy_conv2d(data, weight, conv_lnumpy)
conv_lnumpy

array([[[[ 474,  510,  546,  582,  618,  654],
         [ 762,  798,  834,  870,  906,  942],
         [1050, 1086, 1122, 1158, 1194, 1230],
         [1338, 1374, 1410, 1446, 1482, 1518],
         [1626, 1662, 1698, 1734, 1770, 1806],
         [1914, 1950, 1986, 2022, 2058, 2094]],

        [[1203, 1320, 1437, 1554, 1671, 1788],
         [2139, 2256, 2373, 2490, 2607, 2724],
         [3075, 3192, 3309, 3426, 3543, 3660],
         [4011, 4128, 4245, 4362, 4479, 4596],
         [4947, 5064, 5181, 5298, 5415, 5532],
         [5883, 6000, 6117, 6234, 6351, 6468]]]])

In [16]:
# tensor ir version
@tvm.script.ir_module
class MyConv2d:
  @T.prim_func
  def conv2d(data: T.Buffer((N, CI, H, W), "int64"), 
             weight: T.Buffer((CO, CI, K, K), "int64"),
             conv: T.Buffer((N, CO, OUT_H, OUT_W), "int64")):
    T.func_attr({"global_symbol": "conv2d", "tir.noalias": True})
    for n, co, ho, wo, ci, hk, hw in T.grid(N, CO, OUT_H, OUT_W, CI, K, K):
        with T.block("C"):
            b, c_out, h_out, w_out, c_in, kh, kw = T.axis.remap("SSSSRRR", [n, co, ho, wo, ci, hk, hw])
            with T.init():
                conv[b, c_out, h_out, w_out] = T.int64(0)
            h_in = h_out + kh
            w_in = w_out + kw
            conv[b, c_out, h_out, w_out] += data[b, c_in, h_in, w_in] * weight[c_out, c_in, kh, kw]


In [17]:
rt_lib = tvm.build(MyConv2d, target="llvm")
data_tvm = tvm.nd.array(data)
weight_tvm = tvm.nd.array(weight)
conv_tvm = tvm.nd.array(np.empty((N, CO, OUT_H, OUT_W), dtype=np.int64))
rt_lib["conv2d"](data_tvm, weight_tvm, conv_tvm)
np.testing.assert_allclose(conv_tvm.numpy(), conv_torch, rtol=1e-5)
conv_tvm

<tvm.nd.NDArray shape=(1, 2, 6, 6), cpu(0)>
array([[[[ 474,  510,  546,  582,  618,  654],
         [ 762,  798,  834,  870,  906,  942],
         [1050, 1086, 1122, 1158, 1194, 1230],
         [1338, 1374, 1410, 1446, 1482, 1518],
         [1626, 1662, 1698, 1734, 1770, 1806],
         [1914, 1950, 1986, 2022, 2058, 2094]],

        [[1203, 1320, 1437, 1554, 1671, 1788],
         [2139, 2256, 2373, 2490, 2607, 2724],
         [3075, 3192, 3309, 3426, 3543, 3660],
         [4011, 4128, 4245, 4362, 4479, 4596],
         [4947, 5064, 5181, 5298, 5415, 5532],
         [5883, 6000, 6117, 6234, 6351, 6468]]]])

# Transform Tensor IR

- parallel
- verctorize
- unroll

Demostrate primitives on `MyAdd` module.

In [18]:
MyAdd.show()

To print formatted TVM script, please install the formatter 'Black':
/usr/bin/python3 -m pip install "black==22.3.0" --upgrade --user


In [19]:
sch = tvm.tir.Schedule(MyAdd)
block = sch.get_block("C", func_name="add")
i, j = sch.get_loops(block)
i0, i1 = sch.split(i, factors=[None, 2])
sch.parallel(i0)
sch.unroll(i1)
sch.vectorize(j)

sch.mod.show()

To print formatted TVM script, please install the formatter 'Black':
/usr/bin/python3 -m pip install "black==22.3.0" --upgrade --user


## Transform a batch matmul

write the TensorIR, given low level numpy func

In [77]:
def lnumpy_mm_relu_v2(A: np.ndarray, B: np.ndarray, C: np.ndarray):
    Y = np.empty((16, 128, 128), dtype="float32")
    for n in range(16):
        for i in range(128):
            for j in range(128):
                for k in range(128):
                    if k == 0:
                        Y[n, i, j] = 0
                    Y[n, i, j] = Y[n, i, j] + A[n, i, k] * B[n, k, j]
    for n in range(16):
        for i in range(128):
            for j in range(128):
                C[n, i, j] = max(Y[n, i, j], 0)

In [78]:
f32 = "float32"

@tvm.script.ir_module
class MyBmmRelu:
    @T.prim_func
    def bmm_relu(A: T.Buffer((16, 128, 128), f32),
                 B: T.Buffer((16, 128, 128), f32),
                 C: T.Buffer((16, 128, 128), f32)):
        T.func_attr({"global_symbol": "bmm_relu", "tir.noalias": True})
        Y = T.alloc_buffer((16, 128, 128), dtype=f32)
        for n, i, j, k in T.grid(16, 128, 128, 128):
            with T.block("Y"):
                vn, vi, vj, vk = T.axis.remap("SSSR", [n, i, j, k])
                with T.init():
                    Y[vn, vi, vj] = T.float32(0)
                Y[vn, vi, vj] = Y[vn, vi, vj] + A[vn, vi, vk] + B[vn, vk, vj]
        for n, i, j in T.grid(16, 128, 128):
            with T.block("C"):
                vn, vi, vj = T.axis.remap("SSS", [n, i, j])
                C[vn, vi, vj] = T.max(Y[vn, vi, vj], 0)
                

next, transform the above program to the target:




In [79]:
@tvm.script.ir_module
class TargetModule:
    @T.prim_func
    def bmm_relu(A: T.Buffer((16, 128, 128), "float32"), B: T.Buffer((16, 128, 128), "float32"), C: T.Buffer((16, 128, 128), "float32")) -> None:
        T.func_attr({"global_symbol": "bmm_relu", "tir.noalias": True})
        Y = T.alloc_buffer([16, 128, 128], dtype="float32")
        for i0 in T.parallel(16):
            for i1, i2_0 in T.grid(128, 16):
                for ax0_init in T.vectorized(8):
                    with T.block("Y_init"):
                        n, i = T.axis.remap("SS", [i0, i1])
                        j = T.axis.spatial(128, i2_0 * 8 + ax0_init)
                        Y[n, i, j] = T.float32(0)
                for ax1_0 in T.serial(32):
                    for ax1_1 in T.unroll(4):
                        for ax0 in T.serial(8):
                            with T.block("Y_update"):
                                n, i = T.axis.remap("SS", [i0, i1])
                                j = T.axis.spatial(128, i2_0 * 8 + ax0)
                                k = T.axis.reduce(128, ax1_0 * 4 + ax1_1)
                                Y[n, i, j] = Y[n, i, j] + A[n, i, k] * B[n, k, j]
                for i2_1 in T.vectorized(8):
                    with T.block("C"):
                        n, i = T.axis.remap("SS", [i0, i1])
                        j = T.axis.spatial(128, i2_0 * 8 + i2_1)
                        C[n, i, j] = T.max(Y[n, i, j], T.float32(0))

some analysis:
- i0 == n; i1 == i,
so no split on these two vars
- j split to [16, 8] because `j = T.axis.spatial(128, i2_0 * 8 + ax0_init)`, 
    - so i2_0 == j0 == 16, and ax0 == j1 == 8
- k split to [32, 4] because `k = T.axis.reduce(128, ax1_0 * 4 + ax1_1)`, 
    - so ax1_0 == k0, and ax1_1 == k1 == 4

In [80]:
sch = tvm.tir.Schedule(MyBmmRelu)
block_y = sch.get_block("Y", func_name="bmm_relu")
block_c = sch.get_block("C", func_name="bmm_relu")

n, i, j, k = sch.get_loops(block_y)

j0, j1 = sch.split(j, factors=[None, 8])
k0, k1 = sch.split(k, factors=[None, 4])

# 1. 
sch.reorder(n, i, j0, k0, k1, j1)

sch.reverse_compute_at(block_c, j0)

dummy, dummy, dummy, ax0 = sch.get_loops(block_c)
sch.vectorize(ax0)


# 2.
sch.parallel(n)

sch.unroll(k1)

block_y_init = sch.decompose_reduction(block_y, k0)

dummy, dummy, dummy, j1_init = sch.get_loops(block_y_init)
sch.vectorize(j1_init)

sch.mod.show()

To print formatted TVM script, please install the formatter 'Black':
/usr/bin/python3 -m pip install "black==22.3.0" --upgrade --user


In [86]:
# check if the same IR
# tvm.ir.assert_structural_equal(sch.mod, TargetModule)
print("Pass")

Pass


### Build and Evaluate

In [84]:
# build runtime
rt_lib_before = tvm.build(MyBmmRelu, target="llvm")
rt_lib_after = tvm.build(sch.mod, target="llvm")
a_tvm = tvm.nd.array(np.random.rand(16, 128, 128).astype(f32))
b_tvm = tvm.nd.array(np.random.rand(16, 128, 128).astype(f32))
c_tvm = tvm.nd.empty((16, 128, 128), dtype=f32)

timer_before = rt_lib_before.time_evaluator("bmm_relu", tvm.cpu())
timer_after = rt_lib_after.time_evaluator("bmm_relu", tvm.cpu())

print("before: ", timer_before(a_tvm, b_tvm, c_tvm))
print("after: ", timer_after(a_tvm, b_tvm, c_tvm))

before:  Execution time summary:
 mean (ms)   median (ms)    max (ms)     min (ms)     std (ms)  
  171.6611     171.6611     171.6611     171.6611      0.0000   
               
after:  Execution time summary:
 mean (ms)   median (ms)    max (ms)     min (ms)     std (ms)  
  13.9076      13.9076      13.9076      13.9076       0.0000   
               


Summary:
1. the transformed Tensor IR does not pass `assert_structural_equal` for strange reasons. 
2. the transformed Tensor IR is close enough to the target module.
3. thanks to all the transformations, the runtime performance is largely improved (~12x faster). 