In [1]:
import IPython
import numpy as np
import tvm
from tvm.ir.module import IRModule
from tvm.script import tir as T


**Simple add operation**

In [16]:
# init data
# The TensorIR kernels below declare their buffers as "int64", so the dtype is
# pinned here instead of relying on NumPy's platform-dependent default integer.
a_np = np.arange(16, dtype=np.int64).reshape(4, 4)
b_np = np.arange(16, dtype=np.int64).reshape(4, 4)

# Separate arrays with the same contents, used as the kernel inputs.
a = np.arange(16, dtype=np.int64).reshape(4, 4)
b = np.arange(16, dtype=np.int64).reshape(4, 4)
# numpy version (reference result)
c_np = a_np + b_np
c_np

array([[ 0,  2,  4,  6],
       [ 8, 10, 12, 14],
       [16, 18, 20, 22],
       [24, 26, 28, 30]])

In [17]:
# low-level numpy version
def lnumpy_add(a: np.ndarray, b: np.ndarray, c: np.ndarray):
  """Element-wise add of two 2-D arrays written as explicit loops.

  The loop bounds are taken from ``a.shape`` so the function works for any
  2-D size, not just 4x4.

  Args:
    a: Left operand, 2-D.
    b: Right operand, indexed with the same (i, j) as ``a``.
    c: Output array, written in place at every (i, j) of ``a``.

  Returns:
    None. The result is stored in ``c``.
  """
  rows, cols = a.shape
  for i in range(rows):
    for j in range(cols):
      c[i, j] = a[i, j] + b[i, j]
# lnumpy_add writes its result in place, so the output buffer is allocated first.
c_lnumpy = np.empty(shape=(4, 4), dtype="int64")
lnumpy_add(a, b, c=c_lnumpy)
c_lnumpy

array([[ 0,  2,  4,  6],
       [ 8, 10, 12, 14],
       [16, 18, 20, 22],
       [24, 26, 28, 30]])

In [20]:
# TensorIR version
# Element-wise add of two 4x4 int64 buffers: C[i, j] = A[i, j] + B[i, j].
@tvm.script.ir_module
class MyAdd:
  @T.prim_func
  def add(A: T.Buffer((4, 4), "int64"),
          B: T.Buffer((4, 4), "int64"),
          C: T.Buffer((4, 4), "int64")):
    T.func_attr({"global_symbol": "add"})
    for i, j in T.grid(4, 4):
      with T.block("C"):
        # Both block axes are spatial and bound one-to-one to the loops.
        vi, vj = T.axis.remap("SS", [i, j])
        C[vi, vj] = A[vi, vj] + B[vi, vj]

# Compile for the host CPU, run the kernel, and compare with the NumPy result.
rt_lib = tvm.build(MyAdd, target="llvm")
a_tvm, b_tvm = tvm.nd.array(a), tvm.nd.array(b)
c_tvm = tvm.nd.array(np.empty(shape=(4, 4), dtype="int64"))
rt_lib["add"](a_tvm, b_tvm, c_tvm)
np.testing.assert_allclose(c_tvm.numpy(), c_np, rtol=1e-5)

In [24]:
# init data
# dtype is pinned to int64 to match the TensorIR buffer declarations below;
# NumPy's default integer width is platform dependent.
a = np.arange(16, dtype=np.int64).reshape(4, 4)
# Already 1-D with 4 elements (4, 3, 2, 1), so no reshape is needed.
b = np.arange(4, 0, -1, dtype=np.int64)
a, b

(array([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]]),
 array([4, 3, 2, 1]))

In [25]:
# NumPy reference: b has shape (4,) and is broadcast across every row of a.
c_np = np.add(a, b)
c_np

array([[ 4,  4,  4,  4],
       [ 8,  8,  8,  8],
       [12, 12, 12, 12],
       [16, 16, 16, 16]])

In [32]:
# Broadcast add: C[i, j] = A[i, j] + B[j], with the 1-D B reused for every row.
# NOTE: this rebinds the name MyAdd used by the element-wise version above.
@tvm.script.ir_module
class MyAdd:
  @T.prim_func
  def add(A: T.Buffer((4, 4), "int64"),
          B: T.Buffer((4,), "int64"),
          C: T.Buffer((4, 4), "int64")):
    T.func_attr({"global_symbol": "add", "tir.noalias": True})
    for i, j in T.grid(4, 4):
      with T.block("C"):
        vi = T.axis.spatial(4, i)
        vj = T.axis.spatial(4, j)
        # No reduction axis here, so C is written once per element: no
        # T.init() and no read-back of C are needed.
        C[vi, vj] = A[vi, vj] + B[vj]
    

    

# Compile the broadcast kernel, run it, and compare with the NumPy reference.
rt_lib = tvm.build(MyAdd, target="llvm")
a_tvm, b_tvm = tvm.nd.array(a), tvm.nd.array(b)
c_tvm = tvm.nd.array(np.empty(shape=(4, 4), dtype="int64"))
rt_lib["add"](a_tvm, b_tvm, c_tvm)
np.testing.assert_allclose(c_tvm.numpy(), c_np, rtol=1e-5)

**2-D Convolution**

In [33]:
# Convolution problem size: batch, input channels, input height/width,
# output channels, and (square) kernel size.
N, CI, H, W, CO, K = 1, 1, 8, 8, 2, 3
# Output spatial size: the kernel slides over every fully-contained position.
OUT_H, OUT_W = H - K + 1, W - K + 1
# dtype is pinned to int64 to match the TensorIR buffer declarations;
# NumPy's default integer width is platform dependent.
data = np.arange(N*CI*H*W, dtype=np.int64).reshape(N, CI, H, W)
weight = np.arange(CO*CI*K*K, dtype=np.int64).reshape(CO, CI, K, K)

In [35]:
# torch version
import torch

# Compute the reference in float64 and round before casting back to int64:
# astype() alone truncates toward zero, so any floating-point result that lands
# just below an integer would silently be off by one.
data_torch = torch.from_numpy(data).to(torch.float64)
weight_torch = torch.from_numpy(weight).to(torch.float64)
conv_torch = torch.nn.functional.conv2d(data_torch, weight_torch)
conv_torch = np.rint(conv_torch.numpy()).astype(np.int64)
conv_torch

array([[[[ 474,  510,  546,  582,  618,  654],
         [ 762,  798,  834,  870,  906,  942],
         [1050, 1086, 1122, 1158, 1194, 1230],
         [1338, 1374, 1410, 1446, 1482, 1518],
         [1626, 1662, 1698, 1734, 1770, 1806],
         [1914, 1950, 1986, 2022, 2058, 2094]],

        [[1203, 1320, 1437, 1554, 1671, 1788],
         [2139, 2256, 2373, 2490, 2607, 2724],
         [3075, 3192, 3309, 3426, 3543, 3660],
         [4011, 4128, 4245, 4362, 4479, 4596],
         [4947, 5064, 5181, 5298, 5415, 5532],
         [5883, 6000, 6117, 6234, 6351, 6468]]]])

In [49]:
# 2-D convolution in NCHW layout:
#   C[b, k, i, j] = sum over (di, dj, q) of D[b, q, i + di, j + dj] * WEIGHT[k, q, di, dj]
# The accumulator is zeroed by T.init() inside the reduction block itself, so
# block "C" is a self-contained reduction block.
@tvm.script.ir_module
class MyConv:
  @T.prim_func
  def conv(D: T.Buffer((N, CI, H, W), "int64"),
           WEIGHT: T.Buffer((CO, CI, K, K), "int64"),
           C: T.Buffer((N, CO, OUT_H, OUT_W), "int64")):
    T.func_attr({"global_symbol": "conv", "tir.noalias": True})
    for b, k, i, j, di, dj, q in T.grid(N, CO, OUT_H, OUT_W, K, K, CI):
      with T.block("C"):
        # Spatial axes index the output element.
        vb = T.axis.spatial(N, b)
        vk = T.axis.spatial(CO, k)
        vi = T.axis.spatial(OUT_H, i)
        vj = T.axis.spatial(OUT_W, j)
        # Reduction axes: kernel window offsets and input channel.
        vdi = T.axis.reduce(K, di)
        vdj = T.axis.reduce(K, dj)
        vq = T.axis.reduce(CI, q)
        with T.init():
          C[vb, vk, vi, vj] = 0
        C[vb, vk, vi, vj] = C[vb, vk, vi, vj] + D[vb, vq, vi + vdi, vj + vdj] * WEIGHT[vk, vq, vdi, vdj]
          
              
# Compile the convolution, run it on the host CPU, and check it against torch.
rt_lib = tvm.build(MyConv, target="llvm")
data_tvm, weight_tvm = tvm.nd.array(data), tvm.nd.array(weight)
conv_tvm = tvm.nd.array(np.empty(shape=(N, CO, OUT_H, OUT_W), dtype="int64"))
rt_lib["conv"](data_tvm, weight_tvm, conv_tvm)
print(conv_tvm.numpy())

np.testing.assert_allclose(conv_tvm.numpy(), conv_torch, rtol=1e-5)

[[[[ 474  510  546  582  618  654]
   [ 762  798  834  870  906  942]
   [1050 1086 1122 1158 1194 1230]
   [1338 1374 1410 1446 1482 1518]
   [1626 1662 1698 1734 1770 1806]
   [1914 1950 1986 2022 2058 2094]]

  [[1203 1320 1437 1554 1671 1788]
   [2139 2256 2373 2490 2607 2724]
   [3075 3192 3309 3426 3543 3660]
   [4011 4128 4245 4362 4479 4596]
   [4947 5064 5181 5298 5415 5532]
   [5883 6000 6117 6234 6351 6468]]]]


**Batch matmul + ReLU with parallel, vectorize, and unroll**

In [50]:
def lnumpy_mm_relu_v2(A: np.ndarray, B: np.ndarray, C: np.ndarray):
    """Batched matmul followed by ReLU, written as explicit loops.

    Computes ``C[n] = max(A[n] @ B[n], 0)`` for every batch index ``n``.
    Loop extents are derived from the operand shapes, so any batch size and
    matrix size can be used.

    Args:
        A: Left operand, indexed as ``A[n, i, k]``.
        B: Right operand, indexed as ``B[n, k, j]``.
        C: Output array, indexed as ``C[n, i, j]``; written in place.

    Returns:
        None. The result is stored in ``C``.
    """
    batch, rows, inner = A.shape
    cols = B.shape[2]
    # float32 intermediate holding the matmul result before the ReLU. It is
    # zero-initialised so an empty reduction axis yields zeros, not garbage.
    Y = np.zeros((batch, rows, cols), dtype="float32")
    for n in range(batch):
        for i in range(rows):
            for j in range(cols):
                for k in range(inner):
                    Y[n, i, j] = Y[n, i, j] + A[n, i, k] * B[n, k, j]
    for n in range(batch):
        for i in range(rows):
            for j in range(cols):
                C[n, i, j] = max(Y[n, i, j], 0)

In [114]:
# Batch size, then the (rows, cols) of the A and B matrices in each batch.
BS, AH, AW, BH, BW = 16, 128, 128, 128, 128

# Batched matmul followed by ReLU:
#   Y[n, i, j] = sum over k of A[n, i, k] * B[n, k, j]
#   C[n, i, j] = max(Y[n, i, j], 0)
@tvm.script.ir_module
class MyBmmRelu:
  @T.prim_func
  def bmm_relu(A: T.Buffer((BS, AH, AW), "float32"),
               B: T.Buffer((BS, BH, BW), "float32"),
               C: T.Buffer((BS, AH, BW), "float32")):
    T.func_attr({"global_symbol": "bmm_relu", "tir.noalias": True})
    # Intermediate buffer holding the matmul result before the ReLU.
    Y = T.alloc_buffer((BS, AH, BW), dtype="float32")

    for n, i, j, k in T.grid(BS, AH, BW, AW):
        with T.block("Y"):
            # n, i, j are spatial; k is the reduction axis.
            vn, vi, vj, vk = T.axis.remap("SSSR", [n, i, j, k])
            with T.init():
                Y[vn, vi, vj] = T.float32(0)
            Y[vn, vi, vj] = Y[vn, vi, vj] + A[vn, vi, vk] * B[vn, vk, vj]
    for n, i, j in T.grid(BS, AH, BW):
        with T.block("C"):
            vn, vi, vj = T.axis.remap("SSS", [n, i, j])
            # Purely element-wise, so there is no reduction and no T.init().
            C[vn, vi, vj] = T.max(Y[vn, vi, vj], T.float32(0))
            

           
# Scheduling plan:
#   - run the outermost (batch) loop in parallel
#   - split the j loop into 16 x 8 and the k loop into 32 x 4
#   - vectorize the 8-wide inner j loop, unroll the 4-wide inner k loop
#   - move the ReLU block under the outer j loop
sch = tvm.tir.Schedule(MyBmmRelu)

block_y = sch.get_block("Y", func_name="bmm_relu")
block_c = sch.get_block("C", func_name="bmm_relu")

# Loop handles of the matmul block, outermost first.
loop_n, loop_i, loop_j, loop_k = sch.get_loops(block_y)

# Tile j as 16 x 8 and k as 32 x 4.
loop_j_outer, loop_j_inner = sch.split(loop_j, factors=[16, 8])
loop_k_outer, loop_k_inner = sch.split(loop_k, factors=[32, 4])

# Put the inner j loop innermost, below both k loops.
sch.reorder(loop_n, loop_i, loop_j_outer, loop_k_outer, loop_k_inner, loop_j_inner)

# Parallelize the outermost loop.
sch.parallel(loop_n)

# Vectorize the innermost loop of the update.
sch.vectorize(loop_j_inner)

# Separate the init of Y from its update, placing the init at the outer k loop.
init = sch.decompose_reduction(block_y, loop_k_outer)

# Vectorize the innermost loop of the init block as well.
sch.vectorize(sch.get_loops(init)[-1])

# Unroll the inner k loop.
sch.unroll(loop_k_inner)

# Move the ReLU block under the outer j loop.
sch.reverse_compute_at(block_c, loop_j_outer)

# Vectorize the innermost loop of the moved ReLU block.
sch.vectorize(sch.get_loops(block_c)[-1])

# Show the transformed module. Its output should also be validated.
IPython.display.Code(sch.mod.script(), language="python")

In [115]:
before_rt_lib = tvm.build(MyBmmRelu, target="llvm")
after_rt_lib = tvm.build(sch.mod, target="llvm")

# Seeded inputs so the run is reproducible.
bmm_rng = np.random.default_rng(0)
a_bmm_np = bmm_rng.random((16, 128, 128), dtype=np.float32)
b_bmm_np = bmm_rng.random((16, 128, 128), dtype=np.float32)
# NumPy reference: batched matmul followed by ReLU.
c_bmm_ref = np.maximum(a_bmm_np @ b_bmm_np, 0)

a_tvm = tvm.nd.array(a_bmm_np)
b_tvm = tvm.nd.array(b_bmm_np)

# Validate the unscheduled module first, then the scheduled one, so a mismatch
# can be attributed to either the kernel definition or the schedule.
c_before_tvm = tvm.nd.array(np.zeros((16, 128, 128), dtype="float32"))
before_rt_lib["bmm_relu"](a_tvm, b_tvm, c_before_tvm)
np.testing.assert_allclose(c_before_tvm.numpy(), c_bmm_ref, rtol=1e-5)

c_tvm = tvm.nd.array(np.zeros((16, 128, 128), dtype="float32"))
after_rt_lib["bmm_relu"](a_tvm, b_tvm, c_tvm)
np.testing.assert_allclose(c_tvm.numpy(), c_bmm_ref, rtol=1e-5)

# Time both builds on the host CPU.
before_timer = before_rt_lib.time_evaluator("bmm_relu", tvm.cpu())
print("Before transformation:")
print(before_timer(a_tvm, b_tvm, c_tvm))

f_timer = after_rt_lib.time_evaluator("bmm_relu", tvm.cpu())
print("After transformation:")
print(f_timer(a_tvm, b_tvm, c_tvm))

Before transformation:
Execution time summary:
 mean (ms)   median (ms)    max (ms)     min (ms)     std (ms)  
  26.1790      26.1790      26.1790      26.1790       0.0000                  
After transformation:
Execution time summary:
 mean (ms)   median (ms)    max (ms)     min (ms)     std (ms)  
   1.1526       1.1526       1.1526       1.1526       0.0000                  
