# [2.5. Exercises for TensorIR](https://mlc.ai/chapter_tensor_program/tensorir_exercises.html)

Solutions for the proposed exercises:

In [1]:
import numpy as np
import tvm
from tvm.ir.module import IRModule
from tvm.script import tir as T

In [57]:
import torch

In [2]:
# Sample operands: `a` counts up from 0 and `b` counts down from 16,
# so every entry of a + b equals 16.
a = np.arange(0, 16).reshape((4, 4))
b = 16 - a

In [3]:
# Reference result computed with NumPy's vectorized element-wise add.
c_np = np.add(a, b)
c_np

array([[16, 16, 16, 16],
       [16, 16, 16, 16],
       [16, 16, 16, 16],
       [16, 16, 16, 16]])

In [4]:
# low-level numpy version
def lnumpy_add(a: np.ndarray, b: np.ndarray, c: np.ndarray) -> None:
  """Element-wise add of two 2-D arrays, written as explicit loops.

  The result is stored in place into the preallocated output ``c``; nothing
  is returned. Loop bounds are taken from ``c.shape`` instead of being fixed
  to 4x4, so any 2-D size works as long as ``a`` and ``b`` are at least as
  large as ``c`` in both dimensions.

  Args:
    a: left operand, 2-D.
    b: right operand, 2-D.
    c: output buffer, 2-D; overwritten with ``a + b``.
  """
  rows, cols = c.shape
  for i in range(rows):
    for j in range(cols):
      c[i, j] = a[i, j] + b[i, j]
# Preallocate the output buffer; lnumpy_add fills it in place.
c_lnumpy = np.empty(shape=(4, 4), dtype="int64")
lnumpy_add(a=a, b=b, c=c_lnumpy)
c_lnumpy

array([[16, 16, 16, 16],
       [16, 16, 16, 16],
       [16, 16, 16, 16],
       [16, 16, 16, 16]])

In [5]:
# TensorIR version
@tvm.script.ir_module
class MyAdd:
  # TensorIR counterpart of lnumpy_add: C[i, j] = A[i, j] + B[i, j] on 4x4 int64 buffers.
  @T.prim_func
  def add(
      A: T.Buffer((4, 4), "int64"),
      B: T.Buffer((4, 4), "int64"),
      C: T.Buffer((4, 4), "int64"),
  ):
    T.func_attr({"global_symbol": "add"})
    for i, j in T.grid(4, 4):
      with T.block("C"):
        # "SS": bind both loop variables as spatial block axes.
        vi, vj = T.axis.remap("SS", [i, j])
        C[vi, vj] = A[vi, vj] + B[vi, vj]

# Build MyAdd with the LLVM target, run it on TVM arrays, and compare with the NumPy result.
rt_lib = tvm.build(MyAdd, target="llvm")
a_tvm, b_tvm = tvm.nd.array(a), tvm.nd.array(b)
c_tvm = tvm.nd.array(np.empty(shape=(4, 4), dtype="int64"))
rt_lib["add"](a_tvm, b_tvm, c_tvm)
np.testing.assert_allclose(actual=c_tvm.numpy(), desired=c_np, rtol=1e-5)

# 2.5.1.2. Exercise 1: Broadcast Add

In [6]:
# init data
# `a` is a 4x4 matrix; `b` is a length-4 vector that gets broadcast across the rows of `a`.
a = np.arange(0, 16).reshape((4, 4))
b = np.arange(4, 0, -1).reshape(-1)

In [7]:
# numpy version
# NumPy broadcasts the length-4 vector `b` over the rows of `a`.
c_np = np.add(a, b)
c_np

array([[ 4,  4,  4,  4],
       [ 8,  8,  8,  8],
       [12, 12, 12, 12],
       [16, 16, 16, 16]])

## Low level numpy

In [8]:
# low-level numpy version
def lnumpy_add_bc(a: np.ndarray, b: np.ndarray, c: np.ndarray) -> None:
  """Broadcast add written as explicit loops: ``c[i, j] = a[i, j] + b[j]``.

  ``b`` is a 1-D vector indexed by the column ``j``, so it is added to every
  row of ``a``. The result is stored in place into ``c``; nothing is
  returned. Loop bounds come from ``c.shape`` rather than being fixed to 4x4.

  Args:
    a: 2-D left operand.
    b: 1-D right operand with one entry per column of ``c``.
    c: 2-D output buffer; overwritten with the broadcast sum.
  """
  rows, cols = c.shape
  for i in range(rows):
    for j in range(cols):
      c[i, j] = a[i, j] + b[j]
# Preallocate the output buffer; lnumpy_add_bc fills it in place.
c_lnumpy = np.empty(shape=(4, 4), dtype="int64")
lnumpy_add_bc(a=a, b=b, c=c_lnumpy)
c_lnumpy

array([[ 4,  4,  4,  4],
       [ 8,  8,  8,  8],
       [12, 12, 12, 12],
       [16, 16, 16, 16]])

## TensorIR

In [9]:
# TensorIR version
@tvm.script.ir_module
class MyAdd:
  # Broadcast add: the 1-D buffer B is indexed only by the column axis,
  # so the same B[vj] is added to every row of A.
  @T.prim_func
  def add(A: T.Buffer((4, 4), "int64"),
          B: T.Buffer((4,), "int64"),  # one-element tuple; a bare `(4)` is just the int 4
          C: T.Buffer((4, 4), "int64")):
    T.func_attr({"global_symbol": "add"})
    for i, j in T.grid(4, 4):
      with T.block("C"):
        vi = T.axis.spatial(4, i)
        vj = T.axis.spatial(4, j)
        C[vi, vj] = A[vi, vj] + B[vj]

# Build the broadcast-add module with the LLVM target and check it against the NumPy result.
rt_lib = tvm.build(MyAdd, target="llvm")
a_tvm, b_tvm = tvm.nd.array(a), tvm.nd.array(b)
c_tvm = tvm.nd.array(np.empty(shape=(4, 4), dtype="int64"))
rt_lib["add"](a_tvm, b_tvm, c_tvm)
np.testing.assert_allclose(actual=c_tvm.numpy(), desired=c_np, rtol=1e-5)

# 2.5.1.3. Exercise 2: 2D Convolution

In [10]:
# Problem size: batch, input channels, input height and width.
N, CI, H, W = 1, 1, 8, 8
# Output channels and (square) kernel size.
CO, K = 2, 3
# Output extent when the K x K kernel slides over the input without padding.
OUT_H = H - K + 1
OUT_W = W - K + 1
# Inputs and weights are filled with consecutive integers.
in1 = np.arange(N * CI * H * W).reshape((N, CI, H, W))
in2 = np.arange(CO * CI * K * K).reshape((CO, CI, K, K))

In [11]:
# Spot-check one element: in1 is filled row-major with 0..63, so row 3, column 3 is 3 * 8 + 3 = 27.
in1[0, 0, 3, 3]

27

In [12]:
# Display the kernel weights: one 3x3 kernel per output channel.
in2

array([[[[ 0,  1,  2],
         [ 3,  4,  5],
         [ 6,  7,  8]]],


       [[[ 9, 10, 11],
         [12, 13, 14],
         [15, 16, 17]]]])

In [13]:
# Reference result from PyTorch. torch.Tensor(...) converts the integer
# NumPy inputs to torch's default floating-point dtype for conv2d.
data_torch, weight_torch = torch.Tensor(in1), torch.Tensor(in2)
conv_torch = (
    torch.nn.functional.conv2d(data_torch, weight_torch)
    .numpy()
    .astype(np.int64)
)
conv_torch.shape

(1, 2, 6, 6)

## Low level numpy

In [14]:
# low-level numpy version
def lnumpy_conv2d(data: np.ndarray, weight: np.ndarray, H, W, K, CO):
  """2-D convolution (cross-correlation, no padding, stride 1) as explicit loops.

  Only the first batch entry (``data[0]``) is processed. The output size is
  derived from the ``H``, ``W`` and ``K`` arguments rather than from notebook
  globals, and the height/width axes are indexed consistently, so
  non-square inputs work too. All input channels found in ``weight`` are
  accumulated.

  Args:
    data: input indexed as ``[batch, in_channel, row, col]`` with ``H`` rows
      and ``W`` columns.
    weight: kernels indexed as ``[out_channel, in_channel, k_row, k_col]``.
    H: input height.
    W: input width.
    K: kernel size (the kernel is K x K).
    CO: number of output channels.

  Returns:
    Integer array of shape ``(CO, H - K + 1, W - K + 1)``.
  """
  out_h, out_w = H - K + 1, W - K + 1
  in_channels = weight.shape[1]
  C = np.zeros([CO, out_h, out_w], dtype=int)
  print(data.shape)
  print(weight.shape)
  for co in range(CO):
    for ci in range(in_channels):
      for oh in range(out_h):
        for ow in range(out_w):
          for kh in range(K):
            for kw in range(K):
              C[co, oh, ow] = C[co, oh, ow] + data[0, ci, oh + kh, ow + kw] * weight[co, ci, kh, kw]
  return C

In [15]:
npconf = lnumpy_conv2d(in1, in2, H, W, K, CO)
# lnumpy_conv2d returns no batch axis, so add one back before comparing
# against the (N, CO, OUT_H, OUT_W) torch reference.
np.testing.assert_allclose(npconf[np.newaxis], conv_torch, rtol=1e-5)
npconf.shape

(1, 1, 8, 8)
(2, 1, 3, 3)


(2, 6, 6)

## TensorIR

__TODO:__ Generalize the loops so that each hard-coded `0` index is replaced by the corresponding parameter.

In [42]:
@tvm.script.ir_module
class MyConv:
   # 2-D convolution without padding, stride 1:
   #   C[n, co, oh, ow] = sum over (ci, kh, kw) of A[n, ci, oh + kh, ow + kw] * B[co, ci, kh, kw]
   # Written as one reduction block: the output coordinates are spatial axes,
   # the input channel and the kernel offsets are reduce axes, and T.init()
   # zeroes each output element before accumulation starts.
   @T.prim_func
   def conv (A: T.Buffer((N, CI, H, W),      "int64"),
             B: T.Buffer((CO, CI, K, K),     "int64"),
             C: T.Buffer((N, CO, OUT_H, OUT_W), "int64")):
      T.func_attr({"global_symbol": "conv", "tir.noalias": True})
      # Reduction loops (ci, kh, kw) are innermost.
      for n, co, oh, ow, ci, kh, kw in T.grid(N, CO, OUT_H, OUT_W, CI, K, K):
         with T.block("C"):
            vn  = T.axis.spatial(N, n)
            vco = T.axis.spatial(CO, co)
            # Height indexes dimension 2 and width dimension 3 of both A and C.
            voh = T.axis.spatial(OUT_H, oh)
            vow = T.axis.spatial(OUT_W, ow)
            vci = T.axis.reduce(CI, ci)
            vkh = T.axis.reduce(K, kh)
            vkw = T.axis.reduce(K, kw)
            with T.init():
               C[vn, vco, voh, vow] = T.int64(0)
            C[vn, vco, voh, vow] = C[vn, vco, voh, vow] + A[vn, vci, voh + vkh, vow + vkw] * B[vco, vci, vkh, vkw]

In [43]:
# Build the conv module with the LLVM target and run it on the NumPy inputs.
rt_lib = tvm.build(MyConv, target="llvm")
data_tvm, weight_tvm = tvm.nd.array(in1), tvm.nd.array(in2)
conv_tvm = tvm.nd.array(np.empty(shape=(N, CO, OUT_H, OUT_W), dtype="int64"))
rt_lib["conv"](data_tvm, weight_tvm, conv_tvm)
print(conv_tvm.numpy())
# The PyTorch convolution is the reference.
np.testing.assert_allclose(actual=conv_tvm.numpy(), desired=conv_torch, rtol=1e-5)

[[[[ 474  510  546  582  618  654]
   [ 762  798  834  870  906  942]
   [1050 1086 1122 1158 1194 1230]
   [1338 1374 1410 1446 1482 1518]
   [1626 1662 1698 1734 1770 1806]
   [1914 1950 1986 2022 2058 2094]]

  [[1203 1320 1437 1554 1671 1788]
   [2139 2256 2373 2490 2607 2724]
   [3075 3192 3309 3426 3543 3660]
   [4011 4128 4245 4362 4479 4596]
   [4947 5064 5181 5298 5415 5532]
   [5883 6000 6117 6234 6351 6468]]]]


# 2.5.2. Section 2: How to Transform TensorIR

## 2.5.2.1. Parallel, Vectorize and Unroll

In [46]:
@tvm.script.ir_module
class MyAdd:
  # Element-wise 4x4 add, redefined here as the starting point for the schedule transformations.
  @T.prim_func
  def add(
      A: T.Buffer((4, 4), "int64"),
      B: T.Buffer((4, 4), "int64"),
      C: T.Buffer((4, 4), "int64"),
  ):
    T.func_attr({"global_symbol": "add"})
    for i, j in T.grid(4, 4):
      with T.block("C"):
        # "SS": bind both loop variables as spatial block axes.
        vi, vj = T.axis.remap("SS", [i, j])
        C[vi, vj] = A[vi, vj] + B[vi, vj]

In [47]:
# Create a schedule for MyAdd and fetch the two loops around block "C".
sch = tvm.tir.Schedule(MyAdd)
block = sch.get_block("C", func_name="add")
i, j = sch.get_loops(block)
# Split loop i into an outer and an inner loop using factors [2, 2].
i0, i1 = sch.split(i, factors=[2, 2])
# Parallelize the outer part, unroll the inner part, and vectorize loop j.
sch.parallel(i0)
sch.unroll(i1)
sch.vectorize(j)

In [48]:
# Print the scheduled module to inspect the transformed loops.
sch.show()

## 2.5.2.2. Exercise 3: Transform a batch matmul program

### Low level numpy

In [49]:
def lnumpy_mm_relu_v2(A: np.ndarray, B: np.ndarray, C: np.ndarray):
    """Batched matrix multiply followed by ReLU, written as explicit loops.

    For every batch ``n``: ``C[n] = max(A[n] @ B[n], 0)``. The result is
    stored in place into ``C``; nothing is returned. Sizes are read from the
    operands instead of being fixed to (16, 128, 128).

    Args:
        A: left operand indexed as ``[batch, row, k]``.
        B: right operand indexed as ``[batch, k, col]``.
        C: output buffer indexed as ``[batch, row, col]``; overwritten.
    """
    batch, rows, inner = A.shape
    cols = B.shape[2]
    # Zero-filled so an empty reduction axis still yields a defined result.
    Y = np.zeros((batch, rows, cols), dtype="float32")
    for n in range(batch):
        for i in range(rows):
            for j in range(cols):
                for k in range(inner):
                    # Reset the accumulator on the first reduction step,
                    # as the T.init() block does in the TensorIR version.
                    if k == 0:
                        Y[n, i, j] = 0
                    Y[n, i, j] = Y[n, i, j] + A[n, i, k] * B[n, k, j]
    for n in range(batch):
        for i in range(rows):
            for j in range(cols):
                C[n, i, j] = max(Y[n, i, j], 0)

In [53]:
@tvm.script.ir_module
class MyBmmRelu:
   # Batched matmul into the intermediate buffer Y, followed by an element-wise ReLU into C.
   @T.prim_func
   def bmm_relu(
      A: T.Buffer((16, 128, 128), "float32"),
      B: T.Buffer((16, 128, 128), "float32"),
      C: T.Buffer((16, 128, 128), "float32"),
   ) -> None:
      T.func_attr({"global_symbol": "bmm_relu", "tir.noalias": True})
      Y = T.alloc_buffer((16, 128, 128), dtype="float32")
      for n, i, j, k in T.grid(16, 128, 128, 128):
         with T.block("Y"):
            # "SSSR": n, i, j are spatial axes; k is the reduction axis.
            vn, vi, vj, vk = T.axis.remap("SSSR", [n, i, j, k])
            with T.init():
               Y[vn, vi, vj] = T.float32(0)
            Y[vn, vi, vj] = Y[vn, vi, vj] + A[vn, vi, vk] * B[vn, vk, vj]
      for n, i, j in T.grid(16, 128, 128):
         with T.block("C"):
            vn, vi, vj = T.axis.remap("SSS", [n, i, j])
            C[vn, vi, vj] = T.max(Y[vn, vi, vj], T.float32(0))

In [54]:
# Create a schedule for MyBmmRelu and print it; no transformations are applied in this cell.
sch = tvm.tir.Schedule(MyBmmRelu)
sch.show()
# Also please validate your result

### Tests

In [56]:
# Batch size and matrix dimensions: (n, i, k) @ (n, k, j) -> (n, i, j).
n = 16
i = 128
k = 128
j = 128
# The left operand is laid out as (batch, rows, reduction), matching A[n, i, k];
# the right operand as (batch, reduction, cols), matching B[n, k, j].
# NOTE: `input` shadows the Python builtin; the name is kept because later cells use it.
input = np.arange(n*i*k).reshape(n, i, k)
mat2  = np.arange(n*k*j).reshape(n, k, j)

In [64]:
# PyTorch reference: batched matmul followed by ReLU.
input_t, mat2_t = torch.Tensor(input), torch.Tensor(mat2)
bmm_torch = input_t.bmm(mat2_t)
bmmr_torch = bmm_torch.relu()
bmmr_torch.size()

torch.Size([16, 128, 128])

In [65]:
def test_it(klasse, in1, in2):
   """Build ``klasse``, run its ``bmm_relu`` function, and compare with the torch reference.

   The output shape is derived from the operands rather than from the
   notebook-level ``n``, ``i``, ``j`` variables, which other cells rebind.
   The result is printed and then checked against the notebook-level
   ``bmmr_torch`` with a relative tolerance of 1e-5.

   Args:
      klasse: IRModule exposing a ``bmm_relu`` function.
      in1: left operand, indexed as (batch, rows, k); NumPy array or CPU torch tensor.
      in2: right operand, indexed as (batch, k, cols); NumPy array or CPU torch tensor.

   Raises:
      AssertionError: if the TVM result deviates from ``bmmr_torch``.
   """
   rt_lib = tvm.build(klasse, target="llvm")
   # Convert explicitly to float32 NumPy arrays so the operands match the
   # float32 buffers declared by the module, whatever array type was passed in.
   lhs = np.asarray(in1, dtype=np.float32)
   rhs = np.asarray(in2, dtype=np.float32)
   input_tvm = tvm.nd.array(lhs)
   mat2_tvm = tvm.nd.array(rhs)
   out_shape = (lhs.shape[0], lhs.shape[1], rhs.shape[2])
   bmmr_tvm = tvm.nd.array(np.empty(out_shape, dtype=np.float32))
   rt_lib["bmm_relu"](input_tvm, mat2_tvm, bmmr_tvm)
   print(bmmr_tvm.numpy())
   np.testing.assert_allclose(bmmr_tvm.numpy(), bmmr_torch, rtol=1e-5)

In [66]:
# Validate the unscheduled MyBmmRelu module against the torch reference, using the torch tensors as inputs.
test_it(MyBmmRelu, input_t, mat2_t)

[[[8.84326400e+07 8.84407520e+07 8.84488640e+07 ... 8.94486080e+07
   8.94568000e+07 8.94649120e+07]
  [2.21601792e+08 2.21626304e+08 2.21650752e+08 ... 2.24665744e+08
   2.24690368e+08 2.24714816e+08]
  [3.54770944e+08 3.54811936e+08 3.54852672e+08 ... 3.59882880e+08
   3.59923904e+08 3.59964640e+08]
  ...
  [1.67345684e+10 1.67366349e+10 1.67386941e+10 ... 1.69915945e+10
   1.69936527e+10 1.69957089e+10]
  [1.68677376e+10 1.68698194e+10 1.68718950e+10 ... 1.71268086e+10
   1.71288883e+10 1.71309548e+10]
  [1.70009068e+10 1.70029978e+10 1.70050888e+10 ... 1.72620308e+10
   1.72641198e+10 1.72662088e+10]]

 [[5.16269834e+10 5.16291011e+10 5.16311982e+10 ... 5.18901514e+10
   5.18922609e+10 5.18943580e+10]
  [5.20285880e+10 5.20307139e+10 5.20328151e+10 ... 5.22938040e+10
   5.22959258e+10 5.22980557e+10]
  [5.24301926e+10 5.24323226e+10 5.24344607e+10 ... 5.26974566e+10
   5.26995988e+10 5.27017247e+10]
  ...
  [1.01827576e+11 1.01831705e+11 1.01835866e+11 ... 1.02346736e+11
   1.02350