2.5. Exercises for TensorIR¶

In [43]:
import IPython
import numpy as np
import tvm
from tvm.ir.module import IRModule
from tvm.script import tir as T

2.5.1.1. Example: Element-wise Add

In [44]:
# init data
a = np.arange(16).reshape(4, 4)
b = np.arange(16, 0, -1).reshape(4, 4)

In [45]:
# numpy version
c_np = a + b
c_np

array([[16, 16, 16, 16],
       [16, 16, 16, 16],
       [16, 16, 16, 16],
       [16, 16, 16, 16]])

In [46]:
# low-level numpy version
def lnumpy_add(a: np.ndarray, b: np.ndarray, c: np.ndarray):
  for i in range(4):
    for j in range(4):
      c[i, j] = a[i, j] + b[i, j]
c_lnumpy = np.empty((4, 4), dtype=np.int64)
lnumpy_add(a, b, c_lnumpy)
c_lnumpy

array([[16, 16, 16, 16],
       [16, 16, 16, 16],
       [16, 16, 16, 16],
       [16, 16, 16, 16]])

In [47]:
# TensorIR version
@tvm.script.ir_module
class MyAdd:
  @T.prim_func
  def add(A: T.Buffer[(4, 4), "int64"],
          B: T.Buffer[(4, 4), "int64"],
          C: T.Buffer[(4, 4), "int64"]):
    T.func_attr({"global_symbol": "add"})
    for i, j in T.grid(4, 4):
      with T.block("C"):
        vi = T.axis.spatial(4, i)
        vj = T.axis.spatial(4, j)
        C[vi, vj] = A[vi, vj] + B[vi, vj]

rt_lib = tvm.build(MyAdd, target="llvm")
a_tvm = tvm.nd.array(a)
b_tvm = tvm.nd.array(b)
c_tvm = tvm.nd.array(np.empty((4, 4), dtype=np.int64))
rt_lib["add"](a_tvm, b_tvm, c_tvm)
np.testing.assert_allclose(c_tvm.numpy(), c_np, rtol=1e-5)

2.5.1.2. Exercise 1: Broadcast Add

In [48]:
# init data
a = np.arange(16).reshape(4, 4)
b = np.arange(4, 0, -1).reshape(4)

In [49]:
# numpy version
c_np = a + b
c_np

array([[ 4,  4,  4,  4],
       [ 8,  8,  8,  8],
       [12, 12, 12, 12],
       [16, 16, 16, 16]])

In [50]:
@tvm.script.ir_module
class MyAdd:
  @T.prim_func
  def add(
      A: T.Buffer[(4, 4), "int64"], 
      B: T.Buffer[(4), "int64"],
      C: T.Buffer[(4, 4), "int64"]
    ):
        T.func_attr({"global_symbol": "add", "tir.noalias": True})
        for i, j in T.grid(4, 4):
            with T.block("C"):
                vi = T.axis.spatial(4, i)
                vj = T.axis.spatial(4, j)
                C[vi, vj] = A[vi, vj] + B[vj]

rt_lib = tvm.build(MyAdd, target="llvm")
a_tvm = tvm.nd.array(a)
b_tvm = tvm.nd.array(b)
c_tvm = tvm.nd.array(np.empty((4, 4), dtype=np.int64))
rt_lib["add"](a_tvm, b_tvm, c_tvm)
np.testing.assert_allclose(c_tvm.numpy(), c_np, rtol=1e-5)

2.5.1.3. Exercise 2: 2D Convolution¶


In [156]:
N, CI, H, W, CO, K = 1, 1, 8, 8, 2, 3
OUT_H, OUT_W = H - K + 1, W - K + 1
data = np.arange(N*CI*H*W).reshape(N, CI, H, W)
weight = np.arange(CO*CI*K*K).reshape(CO, CI, K, K)

In [157]:
# torch version
import torch

data_torch = torch.Tensor(data)
weight_torch = torch.Tensor(weight)
conv_torch = torch.nn.functional.conv2d(data_torch, weight_torch)
conv_torch = conv_torch.numpy().astype(np.int64)
conv_torch

array([[[[ 474,  510,  546,  582,  618,  654],
         [ 762,  798,  834,  870,  906,  942],
         [1050, 1086, 1122, 1158, 1194, 1230],
         [1338, 1374, 1410, 1446, 1482, 1518],
         [1626, 1662, 1698, 1734, 1770, 1806],
         [1914, 1950, 1986, 2022, 2058, 2094]],

        [[1203, 1320, 1437, 1554, 1671, 1788],
         [2139, 2256, 2373, 2490, 2607, 2724],
         [3075, 3192, 3309, 3426, 3543, 3660],
         [4011, 4128, 4245, 4362, 4479, 4596],
         [4947, 5064, 5181, 5298, 5415, 5532],
         [5883, 6000, 6117, 6234, 6351, 6468]]]])

In [209]:
@tvm.script.ir_module
class MyConv:
  N, CI, H, W, CO, K = 1, 1, 8, 8, 2, 3
  OUT_H, OUT_W = H - K + 1, W - K + 1
  @T.prim_func
  def conv(
    in_img: T.Buffer[(1, 1, 8, 8), "int64"],
    filters: T.Buffer[(2, 1, 3, 3), "int64"],
    out_img: T.Buffer[(1, 2, 6, 6), "int64"]
  ):
    T.func_attr({"global_symbol": "conv", "tir.noalias": True})
    for n, co, i, j, kh, kw in T.grid(1, 2, 6, 6, 3, 3):
      with T.block("conv"):
        vn, vco, vi, vj, vkh, vkw = T.axis.remap("SSSSRR", [n, co, i, j, kh, kw]) 
        with T.init():
          out_img[vn, vco, vi, vj] = T.int64(0)
        out_img[vn, vco, vi, vj] += in_img[0, 0, vi + vkh, vj + vkw] * filters[vco, 0, vkh, vkw]

print(MyConv)

@conv = primfn(in_img_handle: handle, filters_handle: handle, out_img_handle: handle) -> ()
  attr = {"global_symbol": "conv", "tir.noalias": True}
  buffers = {in_img: Buffer(in_img_1: Pointer(global int64), int64, [1, 1, 8, 8], []),
             filters: Buffer(filters_1: Pointer(global int64), int64, [2, 1, 3, 3], []),
             out_img: Buffer(out_img_1: Pointer(global int64), int64, [1, 2, 6, 6], [])}
  buffer_map = {in_img_handle: in_img, filters_handle: filters, out_img_handle: out_img} {
  block([], "root") {
    tir.reads([])
    tir.writes([])
    for (n: int32, 0, 1) {
      for (co: int32, 0, 2) {
        for (i: int32, 0, 6) {
          for (j: int32, 0, 6) {
            for (kh: int32, 0, 3) {
              for (kw: int32, 0, 3) {
                block([1, 2, 6, 6, tir.reduce_axis(0, 3), tir.reduce_axis(0, 3)], "conv") as [vn, vco, vi, vj, vkh, vkw] {
                  bind(vn, n)
                  bind(vco, co)
                  bind(vi, i)
                  bind(vj, 

In [207]:
rt_lib = tvm.build(MyConv, target="llvm")
data_tvm = tvm.nd.array(data)
weight_tvm = tvm.nd.array(weight)
conv_tvm = tvm.nd.array(np.empty((N, CO, OUT_H, OUT_W), dtype=np.int64))
rt_lib["conv"](data_tvm, weight_tvm, conv_tvm)
np.testing.assert_allclose(conv_tvm.numpy(), conv_torch, rtol=1e-5)

result = conv_tvm.numpy()
#for index in np.ndindex(result.shape):
#    print(index, result[index])
print(data)
print(result)
print('*' * 20)
print(conv_torch)

[[[[ 0  1  2  3  4  5  6  7]
   [ 8  9 10 11 12 13 14 15]
   [16 17 18 19 20 21 22 23]
   [24 25 26 27 28 29 30 31]
   [32 33 34 35 36 37 38 39]
   [40 41 42 43 44 45 46 47]
   [48 49 50 51 52 53 54 55]
   [56 57 58 59 60 61 62 63]]]]
[[[[ 474  510  546  582  618  654]
   [ 762  798  834  870  906  942]
   [1050 1086 1122 1158 1194 1230]
   [1338 1374 1410 1446 1482 1518]
   [1626 1662 1698 1734 1770 1806]
   [1914 1950 1986 2022 2058 2094]]

  [[1203 1320 1437 1554 1671 1788]
   [2139 2256 2373 2490 2607 2724]
   [3075 3192 3309 3426 3543 3660]
   [4011 4128 4245 4362 4479 4596]
   [4947 5064 5181 5298 5415 5532]
   [5883 6000 6117 6234 6351 6468]]]]
********************
[[[[ 474  510  546  582  618  654]
   [ 762  798  834  870  906  942]
   [1050 1086 1122 1158 1194 1230]
   [1338 1374 1410 1446 1482 1518]
   [1626 1662 1698 1734 1770 1806]
   [1914 1950 1986 2022 2058 2094]]

  [[1203 1320 1437 1554 1671 1788]
   [2139 2256 2373 2490 2607 2724]
   [3075 3192 3309 3426 3543 3660]
  