<a href="https://colab.research.google.com/github/Sanzo00/mlc-summer22/blob/master/2_5_tensorir_exercises.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!python3 -m  pip install mlc-ai-nightly -f https://mlc.ai/wheels

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in links: https://mlc.ai/wheels


In [2]:
import numpy as np
import tvm
from tvm.ir.module import IRModule
from tvm.script import tir as T
import time

In [3]:
import IPython

def code2html(code):
    """Render a Python source string as syntax-highlighted HTML via pygments.

    Args:
        code: Source text to highlight.

    Returns:
        A string made of a ``<style>`` element holding the pygments
        ``.highlight`` CSS rules, followed by the highlighted markup and a
        trailing newline.
    """
    from pygments import highlight
    from pygments.formatters import HtmlFormatter
    from pygments.lexers import Python3Lexer

    html_formatter = HtmlFormatter()
    highlighted = highlight(code, Python3Lexer(), html_formatter)
    css_rules = html_formatter.get_style_defs(".highlight")
    return "<style>%s</style>%s\n" % (css_rules, highlighted)

# IPython.display.Code(MyModule.script(), language="python")
# IPython.display.HTML(code2html(MyModule.script()))

In [4]:
# high level sum
# NumPy reference: element-wise add of two 4x4 integer matrices.
a = np.arange(16).reshape((4, 4))
b = np.arange(16, 0, -1).reshape((4, 4))
c_np = np.add(a, b)
print(c_np)

[[16 16 16 16]
 [16 16 16 16]
 [16 16 16 16]
 [16 16 16 16]]


In [5]:
# low level sum
def lnumpy_add(a: np.ndarray, b: np.ndarray, c: np.ndarray):
  """Element-wise add of two 4x4 arrays, one scalar at a time.

  For every index of the 4x4 grid, stores ``a[i, j] + b[i, j]`` into
  ``c[i, j]``. ``c`` is written in place; nothing is returned.
  """
  # np.ndindex(4, 4) visits (0, 0), (0, 1), ... (3, 3) in row-major order.
  for row, col in np.ndindex(4, 4):
    c[row, col] = a[row, col] + b[row, col]

In [6]:
# Run the loop-based add into a caller-allocated int64 output buffer.
c_lnumpy = np.empty((4, 4), dtype="int64")
lnumpy_add(a, b, c_lnumpy)
print(c_lnumpy)

[[16 16 16 16]
 [16 16 16 16]
 [16 16 16 16]
 [16 16 16 16]]


In [7]:
# tensorIR sum
# TensorIR version of the 4x4 element-wise add: C[vi, vj] = A[vi, vj] + B[vi, vj].
@tvm.script.ir_module
class MyAdd:
  @T.prim_func
  def add(A: T.Buffer[(4, 4), "int64"],
          B: T.Buffer[(4, 4), "int64"],
          C: T.Buffer[(4, 4), "int64"]):
    # "add" matches the name later looked up on the built library (rt_lib["add"]).
    T.func_attr({"global_symbol": "add"})
    # T.grid(4, 4) yields the loop pair (i, j) over the 4x4 index space.
    for i, j in T.grid(4, 4):
      with T.block("C"):
        # Both block axes are declared spatial and bound to the loop variables.
        vi = T.axis.spatial(4, i)
        vj = T.axis.spatial(4, j)
        C[vi, vj] = A[vi, vj] + B[vi, vj]


In [8]:
# Build MyAdd with the "llvm" target and compare its output with the NumPy result.
rt_lib = tvm.build(MyAdd, target="llvm")
a_tvm, b_tvm = tvm.nd.array(a), tvm.nd.array(b)
c_tvm = tvm.nd.array(np.empty((4, 4), dtype=np.int64))
rt_lib["add"](a_tvm, b_tvm, c_tvm)
np.testing.assert_allclose(actual=c_tvm.numpy(), desired=c_np, rtol=1e-5)
print("test sum is good!")

test sum is good!


In [9]:
# 1.2 broadcast
# NumPy reference: the length-4 vector b is broadcast across every row of a.
a = np.arange(16).reshape((4, 4))
b = np.arange(4, 0, -1).reshape((4,))
c_np = np.add(a, b)
print(c_np)

[[ 4  4  4  4]
 [ 8  8  8  8]
 [12 12 12 12]
 [16 16 16 16]]


In [10]:
# Broadcast add in TensorIR: the length-4 vector B is added to every row of A.
# NOTE: this re-binds the name MyAdd that an earlier cell already defined.
@tvm.script.ir_module
class MyAdd:
  @T.prim_func
  def add(A: T.Buffer[(4, 4), "int64"],
          # (4,) is a real one-element shape tuple; the previous "(4)" was just the int 4.
          B: T.Buffer[(4,), "int64"],
          C: T.Buffer[(4, 4), "int64"]):
    T.func_attr({"global_symbol": "add", "tir.noalias": True})
    for i, j in T.grid(4, 4):
      with T.block("C"):
        vi = T.axis.spatial(4, i)
        vj = T.axis.spatial(4, j)
        # B is indexed by the column axis only, which is what broadcasts it over rows.
        C[vi, vj] = A[vi, vj] + B[vj]


In [11]:
# Build the broadcast add and compare it with the NumPy broadcast result.
rt_lib = tvm.build(MyAdd, target="llvm")
a_tvm, b_tvm = tvm.nd.array(a), tvm.nd.array(b)
c_tvm = tvm.nd.array(np.empty([4, 4], dtype="int64"))
rt_lib["add"](a_tvm, b_tvm, c_tvm)
np.testing.assert_allclose(actual=c_np, desired=c_tvm.numpy(), rtol=1e-5)
print("test broadcast sum is good!")

test broadcast sum is good!


In [12]:
# 2-d convolution
# Problem sizes, as used by the array shapes below: data is (N, CI, H, W) and
# weight is (CO, CI, K, K).
N, CI, H, W = 1, 1, 8, 8
CO, K = 2, 3
# Output spatial extent for a K x K window slid over the full input.
OUT_H = H - K + 1
OUT_W = W - K + 1
data = np.arange(N * CI * H * W).reshape((N, CI, H, W))
weight = np.arange(CO * CI * K * K).reshape((CO, CI, K, K))

In [13]:
import torch

# Reference convolution computed by PyTorch.
# torch.tensor with an explicit dtype replaces the legacy torch.Tensor(...)
# constructor, which silently converted the integer arrays to float32.
# float64 keeps the integer-valued sums exact over a much wider range.
data_torch = torch.tensor(data, dtype=torch.float64)
weight_torch = torch.tensor(weight, dtype=torch.float64)
conv_torch = torch.nn.functional.conv2d(data_torch, weight_torch)
# Round before casting: astype(np.int64) alone truncates toward zero, which would
# turn a value such as 473.9999 into 473 instead of 474.
conv_torch = np.rint(conv_torch.numpy()).astype(np.int64)
print(conv_torch)

[[[[ 474  510  546  582  618  654]
   [ 762  798  834  870  906  942]
   [1050 1086 1122 1158 1194 1230]
   [1338 1374 1410 1446 1482 1518]
   [1626 1662 1698 1734 1770 1806]
   [1914 1950 1986 2022 2058 2094]]

  [[1203 1320 1437 1554 1671 1788]
   [2139 2256 2373 2490 2607 2724]
   [3075 3192 3309 3426 3543 3660]
   [4011 4128 4245 4362 4479 4596]
   [4947 5064 5181 5298 5415 5532]
   [5883 6000 6117 6234 6351 6468]]]]


In [14]:
# 2-d convolution (python)
# Reference implementation written as explicit scalar loops; it is compared
# with the PyTorch result (conv_torch) in a later cell.
c_np = np.empty(conv_torch.shape, dtype=np.int64)

# Output index space: i over N, j over CO, h over OUT_H, w over OUT_W.
for i in range(N):
  for j in range(CO):
    for h in range(OUT_H):
      for w in range(OUT_W):
        # c_np comes from np.empty, so the accumulator is zeroed before summing.
        c_np[i, j, h, w] = 0
        # Accumulate over input channel ci and kernel offsets (k1, k2).
        for ci in range(CI):
          for k1 in range(K):
            for k2 in range(K):
              c_np[i, j, h, w] += data[i, ci, h+k1, w+k2] * weight[j, ci, k1, k2]

In [15]:
# The hand-written loops must reproduce the PyTorch reference.
np.testing.assert_allclose(actual=c_np, desired=conv_torch, rtol=1e-5)
print("test 2d-conv(python) is good!")

test 2d-conv(python) is good!


In [16]:
# 2-d convolution (TensorIR)
@tvm.script.ir_module
class MyConv:
  @T.prim_func
  # NOTE(review): the input height/width are written as the literal 8 (the values
  # of H and W) instead of the symbolic signature kept below for reference;
  # confirm why the symbolic form was disabled before switching back.
  # def conv(A: T.Buffer[(N, CI, H, W), "int64"],
  #          B: T.Buffer[(CO, CI, K, K), "int64"],
  #          C: T.Buffer[(N, CO, OUT_H, OUT_W), "int64"]):
  def conv(A: T.Buffer[(N, CI, 8, 8), "int64"],
           B: T.Buffer[(CO, CI, K, K), "int64"],
           C: T.Buffer[(N, CO, OUT_H, OUT_W), "int64"]):
    T.func_attr({"global_symbol": "conv", "tir.noalias": True})
    for i, j, h, w, ci, k1, k2 in T.grid(N, CO, OUT_H, OUT_W, CI, K, K):
      with T.block("C"):
        # Output coordinates are spatial axes.
        vi, vj, vh, vw = T.axis.remap("SSSS", [i, j, h, w])
        # The input channel and both kernel offsets are summed over and do not
        # index C, so all three are reduction axes. Declaring vci spatial ("SRR")
        # only worked because CI == 1: with more channels T.init() would reset
        # the accumulator at the start of every channel.
        vci, vk1, vk2 = T.axis.remap("RRR", [ci, k1, k2])
        with T.init():
          C[vi, vj, vh, vw] = T.int64(0)
        C[vi, vj, vh, vw] += A[vi, vci, vh+vk1, vw+vk2] * B[vj, vci, vk1, vk2]

In [17]:
# Build the TensorIR convolution and compare it with the PyTorch reference.
rt_lib = tvm.build(MyConv, target="llvm")
data_tvm, weight_tvm = tvm.nd.array(data), tvm.nd.array(weight)
conv_tvm = tvm.nd.array(np.empty([N, CO, OUT_H, OUT_W], dtype="int64"))
rt_lib["conv"](data_tvm, weight_tvm, conv_tvm)
np.testing.assert_allclose(actual=conv_tvm.numpy(), desired=conv_torch, rtol=1e-5)
print('test 2d-conv(tensorIR) is good!')

test 2d-conv(tensorIR) is good!


In [18]:
# 2.1 parallel, unroll, vectorize
# Plain 4x4 element-wise add, re-declared here as the starting point for the
# schedule transformations in the next cell (the name MyAdd was re-bound to the
# broadcast variant earlier in the notebook).
@tvm.script.ir_module
class MyAdd:
  @T.prim_func
  def add(A: T.Buffer[(4, 4), "int64"],
          B: T.Buffer[(4, 4), "int64"],
          C: T.Buffer[(4, 4), "int64"]):
    T.func_attr({"global_symbol": "add"})
    for i, j in T.grid(4, 4):
      # Block "C" is the handle the schedule looks up via get_block("C", ...).
      with T.block("C"):
        vi = T.axis.spatial(4, i)
        vj = T.axis.spatial(4, j)
        C[vi, vj] = A[vi, vj] + B[vi, vj]

In [19]:
# Create a schedule for MyAdd and fetch the two loops around block "C".
sch = tvm.tir.Schedule(MyAdd)
block = sch.get_block("C", func_name="add")
i, j = sch.get_loops(block)
# Split loop i (extent 4) into an outer/inner pair with factors 2 x 2.
i0, i1 = sch.split(i, factors=[2, 2])
# Annotate the three loops: outer parallel, inner unrolled, j vectorized.
sch.parallel(i0)
sch.unroll(i1)
sch.vectorize(j)
# Render the transformed module as highlighted HTML (cell output).
IPython.display.HTML(code2html(sch.mod.script()))

In [20]:
# 2.2 transform matmul
def lnumpy_mm_relu_v2(A: np.ndarray, B: np.ndarray, C: np.ndarray):
  """Batched matmul followed by ReLU, written as explicit scalar loops.

  Computes ``C[n] = max(A[n] @ B[n], 0)`` for every batch index ``n``.
  Loop extents are read from the operand shapes instead of being fixed at
  16 x 128 x 128, so the same reference also works on small inputs; for the
  original (16, 128, 128) operands the behaviour is unchanged.

  Args:
    A: left operand of shape (batch, rows, inner).
    B: right operand of shape (batch, inner, cols).
    C: output of shape (batch, rows, cols); overwritten in place.

  Returns:
    None. The result is written into ``C``.
  """
  batch, rows, inner = A.shape
  cols = B.shape[2]
  # float32 intermediate buffer holding the matmul result before the ReLU.
  Y = np.empty((batch, rows, cols), dtype="float32")
  for n in range(batch):
    for i in range(rows):
      for j in range(cols):
        for k in range(inner):
          # Y comes from np.empty, so zero the accumulator on the first
          # reduction step (the loop-level analogue of T.init()).
          if k == 0:
            Y[n, i, j] = 0
          Y[n, i, j] = Y[n, i, j] + A[n, i, k] * B[n, k, j]

  for n in range(batch):
    for i in range(rows):
      for j in range(cols):
        C[n, i, j] = max(Y[n, i, j], 0)

In [21]:
# Batched matmul + ReLU in TensorIR, the counterpart of lnumpy_mm_relu_v2:
# block "Y" accumulates A[n, i, k] * B[n, k, j] into Y, block "C" applies max(Y, 0).
@tvm.script.ir_module
class MyBmmRelu:
  @T.prim_func
  def bmm_relu(A: T.Buffer[(16, 128, 128), "float32"],
               B: T.Buffer[(16, 128, 128), "float32"],
               C: T.Buffer[(16, 128, 128), "float32"]):
    T.func_attr({"global_symbol": "bmm_relu", "tir.noalias": True})
    # Intermediate buffer holding the matmul result before the ReLU.
    Y = T.alloc_buffer((16, 128, 128), dtype="float32")
    # Loop variables are named i0..i3 and block axes n, i, j, k, mirroring the
    # naming used in TargetModule. Earlier naming kept for reference:
    # for n, i, j, k in T.grid(16, 128, 128, 128):
    for i0, i1, i2, i3 in T.grid(16, 128, 128, 128):
      with T.block("Y"):
        # vn, vi, vj, vk = T.axis.remap("SSSR", [n, i, j, k])
        # n, i, j are spatial; k is the reduction axis.
        n, i, j, k = T.axis.remap("SSSR", [i0, i1, i2, i3])
        with T.init():
          Y[n, i, j] = T.float32(0)
        Y[n, i, j] += A[n, i, k] * B[n, k, j]

    for i0, i1, i2 in T.grid(16, 128, 128):
      with T.block("C"):
        n, i, j = T.axis.remap("SSS", [i0, i1, i2])
        C[n, i, j] = T.max(Y[n, i, j], T.float32(0))

In [22]:
# Wrap MyBmmRelu in a schedule and render the untransformed TensorIR for reference.
sch = tvm.tir.Schedule(MyBmmRelu)
IPython.display.HTML(code2html(sch.mod.script()))

In [23]:
# target tensorIR
# Hand-written goal of the scheduling exercise: the scheduled MyBmmRelu is
# compared with this module via tvm.ir.assert_structural_equal later on.
@tvm.script.ir_module
class TargetModule:
    @T.prim_func
    def bmm_relu(A: T.Buffer[(16, 128, 128), "float32"], B: T.Buffer[(16, 128, 128), "float32"], C: T.Buffer[(16, 128, 128), "float32"]) -> None:
        T.func_attr({"global_symbol": "bmm_relu", "tir.noalias": True})
        Y = T.alloc_buffer([16, 128, 128], dtype="float32")
        # The batch loop is parallel.
        for i0 in T.parallel(16):
            # i1 walks the 128 rows; i2_0 walks 16 tiles of 8 output columns.
            for i1, i2_0 in T.grid(128, 16):
                # Zero the current 8-column tile of Y (vectorized).
                for ax0_init in T.vectorized(8):
                    with T.block("Y_init"):
                        n, i = T.axis.remap("SS", [i0, i1])
                        j = T.axis.spatial(128, i2_0 * 8 + ax0_init)
                        Y[n, i, j] = T.float32(0)
                # Reduction over k = ax1_0 * 4 + ax1_1 (32 x 4 = 128); the inner
                # 4 steps are unrolled.
                for ax1_0 in T.serial(32):
                    for ax1_1 in T.unroll(4):
                        for ax0 in T.serial(8):
                            with T.block("Y_update"):
                                n, i = T.axis.remap("SS", [i0, i1])
                                j = T.axis.spatial(128, i2_0 * 8 + ax0)
                                k = T.axis.reduce(128, ax1_0 * 4 + ax1_1)
                                Y[n, i, j] = Y[n, i, j] + A[n, i, k] * B[n, k, j]
                # ReLU over the same 8-column tile (vectorized).
                for i2_1 in T.vectorized(8):
                    with T.block("C"):
                        n, i = T.axis.remap("SS", [i0, i1])
                        j = T.axis.spatial(128, i2_0 * 8 + i2_1)
                        C[n, i, j] = T.max(Y[n, i, j], T.float32(0))

In [24]:
# Start from a fresh schedule of MyBmmRelu; the following cells transform it
# step by step until it matches TargetModule.
sch = tvm.tir.Schedule(MyBmmRelu)
# Hints: you can use
# `IPython.display.Code(sch.mod.script(), language="python")`
# or `print(sch.mod.script())`
# to show the current program at any time during the transformation.

# Step 1. Get blocks
Y = sch.get_block("Y", func_name="bmm_relu")

# Step 2. Get loops
# The four loops around block "Y" (i0..i3 in MyBmmRelu): batch, row, column, reduction.
b, i, j, k = sch.get_loops(Y)

In [25]:
# Step 3. Organize the loops
# Split the reduction loop k (128) into 32 x 4 and the column loop j (128)
# into 16 x 8, then order them j0, k0, k1, j1 as in TargetModule's loop nest.
k0, k1 = sch.split(k, [32, 4])
j0, j1 = sch.split(j, [16, 8])
sch.reorder(j0, k0, k1, j1)
IPython.display.HTML(code2html(sch.mod.script()))

In [26]:
# sch.compute_at/reverse_compute_at(...)
# Move the consumer block "C" under loop j0, so the ReLU sits inside the same
# column-tile loop as the matmul (as in TargetModule).
C = sch.get_block("C", func_name="bmm_relu")
sch.reverse_compute_at(C, j0)
IPython.display.HTML(code2html(sch.mod.script()))

In [27]:
# Step 4. decompose reduction
# Parallelize the batch loop first: per the author's note, sch.parallel(b) has to
# run before decompose_reduction splits block "Y".
# NOTE(review): presumably the decomposed update block no longer passes the
# checks sch.parallel performs - confirm against the TVM schedule docs.
sch.parallel(b)
# Separate the T.init() part of "Y" into its own block placed before loop k0.
Y_init = sch.decompose_reduction(Y, k0)
IPython.display.HTML(code2html(sch.mod.script()))

In [28]:
# Step 5. vectorize / parallel / unroll
# Loops enclosing the init block; only the innermost one (Yj1) is used below.
Yn, Yb, Yj0, Yj1 = sch.get_loops(Y_init)
# Innermost loop around block "C".
_, _, _, Cj1, = sch.get_loops(C)
sch.vectorize(Yj1)
sch.vectorize(Cj1)
# sch.parallel(b) is deliberately not called here: it was already applied
# before decompose_reduction in the previous step.
sch.unroll(k1)
IPython.display.HTML(code2html(sch.mod.script()))

In [29]:
# Raises if the scheduled module differs structurally from the hand-written target.
tvm.ir.assert_structural_equal(sch.mod, TargetModule)
print("Pass: transform equal TargetModule")

Pass: transform equal TargetModule


In [30]:
# eval performance
before_rt_lib = tvm.build(MyBmmRelu, target="llvm")
after_rt_lib = tvm.build(sch.mod, target="llvm")

# Seed the inputs so a Restart & Run All reproduces the same data, and keep the
# NumPy copies so the kernel output can be validated.
np.random.seed(0)
a_np = np.random.rand(16, 128, 128).astype("float32")
b_np = np.random.rand(16, 128, 128).astype("float32")
a_tvm = tvm.nd.array(a_np)
b_tvm = tvm.nd.array(b_np)
# Output buffer; its initial contents are irrelevant because the kernel overwrites it.
c_tvm = tvm.nd.array(np.empty((16, 128, 128), dtype="float32"))

# Run the transformed kernel once and check it against a NumPy reference before
# timing anything: a fast but wrong schedule is not a speed-up.
after_rt_lib["bmm_relu"](a_tvm, b_tvm, c_tvm)
np.testing.assert_allclose(c_tvm.numpy(), np.maximum(a_np @ b_np, 0), rtol=1e-4)

before_timer = before_rt_lib.time_evaluator("bmm_relu", tvm.cpu())
print("Before transformation:")
print(before_timer(a_tvm, b_tvm, c_tvm))

Before transformation:
Execution time summary:
 mean (ms)   median (ms)    max (ms)     min (ms)     std (ms)  
  57.3902      57.3902      57.3902      57.3902       0.0000   
               


In [31]:
# Time the transformed module on the same inputs used for the "before" measurement.
f_timer = after_rt_lib.time_evaluator("bmm_relu", tvm.cpu())
print("After transformation:")
print(f_timer(a_tvm, b_tvm, c_tvm))

After transformation:
Execution time summary:
 mean (ms)   median (ms)    max (ms)     min (ms)     std (ms)  
  15.3903      15.3903      15.3903      15.3903       0.0000   
               
