In [2]:
import IPython
import numpy as np
import tvm
from tvm.ir.module import IRModule
from tvm.script import tir as T

In [3]:
# init data: explicit int64 so the arrays match the "int64" TensorIR buffers below
# (np.arange's default integer dtype is int32 on Windows with NumPy < 2).
a = np.arange(16, dtype=np.int64).reshape(4, 4)
b = np.arange(16, 0, -1, dtype=np.int64).reshape(4, 4)

In [4]:
# numpy version: element-wise sum, written with the ufunc behind `+`
c_np = np.add(a, b)
c_np

array([[16, 16, 16, 16],
       [16, 16, 16, 16],
       [16, 16, 16, 16],
       [16, 16, 16, 16]])

In [5]:
# low-level numpy version
def lnumpy_add(a: np.ndarray, b: np.ndarray, c: np.ndarray):
    """Element-wise ``c = a + b`` for 2-D arrays, written as explicit loops.

    Mirrors the loop nest of the TensorIR ``add`` kernel. The iteration space
    is taken from ``c``'s shape, so ``a`` and ``b`` must be at least that
    large along each axis (the notebook uses 4x4 everywhere).

    Args:
        a, b: 2-D input arrays.
        c: 2-D output array, written in place.
    """
    rows, cols = c.shape
    for i in range(rows):
        for j in range(cols):
            c[i, j] = a[i, j] + b[i, j]
# Preallocate the output; lnumpy_add fills it in place.
c_lnumpy = np.empty(shape=(4, 4), dtype=np.int64)
lnumpy_add(a, b, c_lnumpy)
c_lnumpy

array([[16, 16, 16, 16],
       [16, 16, 16, 16],
       [16, 16, 16, 16],
       [16, 16, 16, 16]])

In [None]:
# TensorIR version: the same 4x4 element-wise add, written in TVMScript.
@tvm.script.ir_module
class MyAdd:
	# Only `#` comments are used inside: this body is parsed as TVMScript.
	@T.prim_func
	def add(A: T.Buffer((4, 4), "int64"), # type: ignore
			B: T.Buffer((4, 4), "int64"), # type: ignore
			C: T.Buffer((4, 4), "int64")): # type: ignore
		# Symbol name the built kernel is looked up by (rt_lib["add"] below).
		T.func_attr({"global_symbol": "add"})
		for i, j in T.grid(4, 4):
			# Block "C" computes one output element; vi/vj are its spatial iterators.
			with T.block("C"):
				vi = T.axis.spatial(4, i)
				vj = T.axis.spatial(4, j)
				C[vi, vj] = A[vi, vj] + B[vi, vj]

# Compile for the host CPU, run on TVM NDArrays, and compare with numpy.
rt_lib = tvm.build(MyAdd, target="llvm")
a_tvm, b_tvm = (tvm.nd.array(arr) for arr in (a, b))
c_tvm = tvm.nd.array(np.empty(shape=(4, 4), dtype=np.int64))
rt_lib["add"](a_tvm, b_tvm, c_tvm)  # writes the result into c_tvm
np.testing.assert_allclose(actual=c_tvm.numpy(), desired=c_np, rtol=1e-5)

In [13]:
# init data for the broadcast example: b is 1-D and is added to every row of a.
# Explicit int64 matches the "int64" TensorIR buffers below.
a = np.arange(16, dtype=np.int64).reshape(4, 4)
b = np.arange(4, 0, -1, dtype=np.int64)
b.shape

(4,)

In [9]:
# numpy version: b is broadcast along the last axis of a
c_np = np.add(a, b)
c_np

array([[ 4,  4,  4,  4],
       [ 8,  8,  8,  8],
       [12, 12, 12, 12],
       [16, 16, 16, 16]])

In [None]:
# TensorIR version of the broadcast add: C[i, j] = A[i, j] + B[j].
@tvm.script.ir_module
class MyAdd:
    @T.prim_func
    def add(A: T.Buffer((4, 4), "int64"), # type: ignore
            B: T.Buffer((4,), "int64"), # type: ignore
            C: T.Buffer((4, 4), "int64")): # type: ignore
        # "tir.noalias" tells TVM that A, B and C do not overlap in memory.
        T.func_attr({"global_symbol": "add", "tir.noalias": True}) # type: ignore
        for i, j in T.grid(4, 4):
            with T.block("C"):
                vi = T.axis.spatial(4, i)
                vj = T.axis.spatial(4, j)
                # B is indexed only by the column iterator, i.e. broadcast over rows.
                C[vi, vj] = A[vi, vj] + B[vj]

# Compile for the host CPU, run on TVM NDArrays, and compare with numpy.
rt_lib = tvm.build(MyAdd, target="llvm")
a_tvm, b_tvm = (tvm.nd.array(arr) for arr in (a, b))
c_tvm = tvm.nd.array(np.empty(shape=(4, 4), dtype=np.int64))
rt_lib["add"](a_tvm, b_tvm, c_tvm)  # writes the result into c_tvm
np.testing.assert_allclose(actual=c_tvm.numpy(), desired=c_np, rtol=1e-5)

In [17]:
# Conv2d sizes: batch N, input channels CI, H x W input, CO output channels, K x K kernel.
N, CI, H, W, CO, K = 1, 1, 8, 8, 2, 3
# Output size of an unpadded, stride-1 convolution.
OUT_H, OUT_W = H - K + 1, W - K + 1
# Explicit int64 matches the "int64" buffers of MyConv on every platform.
data = np.arange(N * CI * H * W, dtype=np.int64).reshape(N, CI, H, W)
weight = np.arange(CO * CI * K * K, dtype=np.int64).reshape(CO, CI, K, K)

In [40]:
# torch version: reference conv2d, computed on float tensors and cast back to int64
import torch

data_torch, weight_torch = torch.Tensor(data), torch.Tensor(weight)
conv_torch = (
    torch.nn.functional.conv2d(data_torch, weight_torch)
    .numpy()
    .astype(np.int64)
)
conv_torch

array([[[[ 474,  510,  546,  582,  618,  654],
         [ 762,  798,  834,  870,  906,  942],
         [1050, 1086, 1122, 1158, 1194, 1230],
         [1338, 1374, 1410, 1446, 1482, 1518],
         [1626, 1662, 1698, 1734, 1770, 1806],
         [1914, 1950, 1986, 2022, 2058, 2094]],

        [[1203, 1320, 1437, 1554, 1671, 1788],
         [2139, 2256, 2373, 2490, 2607, 2724],
         [3075, 3192, 3309, 3426, 3543, 3660],
         [4011, 4128, 4245, 4362, 4479, 4596],
         [4947, 5064, 5181, 5298, 5415, 5532],
         [5883, 6000, 6117, 6234, 6351, 6468]]]])

In [None]:
# TensorIR conv2d: a direct 7-deep loop nest reducing over (CI, K, K).
@tvm.script.ir_module
class MyConv:
    @T.prim_func
    def conv(Input: T.Buffer((N, CI, H, W), "int64"), # type: ignore
             Weight: T.Buffer((CO, CI, K, K), "int64"), # type: ignore
             Output: T.Buffer((N, CO, OUT_H, OUT_W), "int64")): # type: ignore
        # Shapes come from the notebook-level N, CI, H, W, CO, K, OUT_H, OUT_W.
        # "tir.noalias" tells TVM that the three buffers do not overlap in memory.
        T.func_attr({"global_symbol": "conv", "tir.noalias": True}) # type: ignore
        for b, k, q, h, w, i, j in T.grid(N, CO, CI, OUT_H, OUT_W, K, K):
            with T.block("Output"):
                vb = T.axis.spatial(N, b)
                vk = T.axis.spatial(CO, k)
                vq = T.axis.reduce(CI, q)
                vh = T.axis.spatial(OUT_H, h)
                vw = T.axis.spatial(OUT_W, w)
                vi = T.axis.reduce(K, i)
                vj = T.axis.reduce(K, j)
                # Zero the output element before the first reduction step.
                with T.init():
                    Output[vb, vk, vh, vw] = T.int64(0)
                Output[vb, vk, vh, vw] += Input[vb, vq, vh + vi, vw + vj] * Weight[vk, vq, vi, vj]

# Compile the conv kernel, run it, and compare against the torch reference.
rt_lib = tvm.build(MyConv, target="llvm")
data_tvm, weight_tvm = (tvm.nd.array(arr) for arr in (data, weight))
conv_tvm = tvm.nd.array(np.empty(shape=(N, CO, OUT_H, OUT_W), dtype=np.int64))
rt_lib["conv"](data_tvm, weight_tvm, conv_tvm)  # fills conv_tvm in place
np.testing.assert_allclose(actual=conv_tvm.numpy(), desired=conv_torch, rtol=1e-5)

In [None]:
def code2html(code, style="github-dark"):
    """Highlight Python source with pygments and return it as an HTML string.

    Args:
        code: Source text to highlight (e.g. ``sch.mod.script()``).
        style: Name of a pygments style, as listed by
            ``pygments.styles.get_all_styles()``. Defaults to ``"github-dark"``.

    Returns:
        A ``<style>`` element with the rules for the ``.highlight`` class,
        followed by the highlighted HTML; suitable for ``IPython.display.HTML``.
    """
    import pygments
    from pygments.formatters import HtmlFormatter
    from pygments.lexers import Python3Lexer

    formatter = HtmlFormatter(style=style)
    html = pygments.highlight(code, Python3Lexer(), formatter)
    return "<style>%s</style>%s\n" % (formatter.get_style_defs(".highlight"), html)

In [199]:
from pygments.styles import get_all_styles

# Style names usable with pygments' HtmlFormatter (as in code2html).
print([*get_all_styles()])

['abap', 'algol', 'algol_nu', 'arduino', 'autumn', 'bw', 'borland', 'coffee', 'colorful', 'default', 'dracula', 'emacs', 'friendly_grayscale', 'friendly', 'fruity', 'github-dark', 'gruvbox-dark', 'gruvbox-light', 'igor', 'inkpot', 'lightbulb', 'lilypond', 'lovelace', 'manni', 'material', 'monokai', 'murphy', 'native', 'nord-darker', 'nord', 'one-dark', 'paraiso-dark', 'paraiso-light', 'pastie', 'perldoc', 'rainbow_dash', 'rrt', 'sas', 'solarized-dark', 'solarized-light', 'staroffice', 'stata-dark', 'stata-light', 'tango', 'trac', 'vim', 'vs', 'xcode', 'zenburn']


In [79]:
# Redefines MyAdd as the 4x4 + 4x4 element-wise add so it can be scheduled below.
@tvm.script.ir_module
class MyAdd:
	# Only `#` comments are used inside: this body is parsed as TVMScript.
	@T.prim_func
	def add(A: T.Buffer((4, 4), "int64"), # type: ignore
			B: T.Buffer((4, 4), "int64"), # type: ignore
			C: T.Buffer((4, 4), "int64")): # type: ignore
		T.func_attr({"global_symbol": "add"})
		for i, j in T.grid(4, 4):
			# Block "C" is the handle the schedule looks up with get_block("C").
			with T.block("C"):
				vi = T.axis.spatial(4, i)
				vj = T.axis.spatial(4, j)
				C[vi, vj] = A[vi, vj] + B[vi, vj]

# Schedule the add kernel: split the row loop 2x2, run the outer half in
# parallel, unroll the inner half, and vectorize the column loop.
sch = tvm.tir.Schedule(MyAdd)
loop_i, loop_j = sch.get_loops(sch.get_block("C", func_name="add"))
i_outer, i_inner = sch.split(loop_i, factors=[2, 2])
sch.parallel(i_outer)
sch.unroll(i_inner)
sch.vectorize(loop_j)
IPython.display.HTML(code2html(sch.mod.script()))

In [None]:
def lnumpy_mm_relu_v2(A: np.ndarray, B: np.ndarray, C: np.ndarray):
    """Batched ``C = relu(A @ B)`` written as explicit loops (low-level numpy).

    The loop nest mirrors the TensorIR ``bmm_relu`` kernel: a reduction into an
    intermediate ``Y`` (zeroed on the first ``k`` step, like ``T.init()``),
    followed by an element-wise ReLU into ``C``.

    Sizes come from the arguments: ``A`` is ``(batch, rows, inner)``, ``B`` is
    ``(batch, inner, cols)`` and ``C`` is ``(batch, rows, cols)``. The notebook
    uses 16x128x128. ``inner`` must be at least 1, otherwise ``Y`` is never
    initialized.

    Args:
        A, B: input batches of matrices.
        C: output array, written in place.
    """
    batch, rows, inner = A.shape
    cols = B.shape[2]
    Y = np.empty((batch, rows, cols), dtype="float32")
    for n in range(batch):
        for i in range(rows):
            for j in range(cols):
                for k in range(inner):
                    if k == 0:
                        Y[n, i, j] = 0
                    Y[n, i, j] = Y[n, i, j] + A[n, i, k] * B[n, k, j]
    for n in range(batch):
        for i in range(rows):
            for j in range(cols):
                C[n, i, j] = max(Y[n, i, j], 0)

In [110]:
# Expected result of scheduling MyBmmRelu (checked below with
# tvm.ir.assert_structural_equal), so its IR must stay exactly as written.
@tvm.script.ir_module
class TargetModule:
    @T.prim_func
    def bmm_relu(A: T.Buffer((16, 128, 128), "float32"), B: T.Buffer((16, 128, 128), "float32"), C: T.Buffer((16, 128, 128), "float32")) -> None: # type: ignore
        T.func_attr({"global_symbol": "bmm_relu", "tir.noalias": True})
        # Intermediate matmul result, consumed by block "C".
        Y = T.alloc_buffer([16, 128, 128], dtype="float32")
        # Batch loop runs in parallel.
        for i0 in T.parallel(16):
            # Rows x column tiles of width 8: j = i2_0 * 8 + (lane in 0..7).
            for i1, i2_0 in T.grid(128, 16):
                # Vectorized zero-initialization of one 8-wide tile of Y.
                for ax0_init in T.vectorized(8):
                    with T.block("Y_init"):
                        n, i = T.axis.remap("SS", [i0, i1])
                        j = T.axis.spatial(128, i2_0 * 8 + ax0_init)
                        Y[n, i, j] = T.float32(0)
                # Reduction over k = ax1_0 * 4 + ax1_1, inner factor of 4 unrolled.
                for ax1_0 in T.serial(32):
                    for ax1_1 in T.unroll(4):
                        for ax0 in T.serial(8):
                            with T.block("Y_update"):
                                n, i = T.axis.remap("SS", [i0, i1])
                                j = T.axis.spatial(128, i2_0 * 8 + ax0)
                                k = T.axis.reduce(128, ax1_0 * 4 + ax1_1)
                                Y[n, i, j] = Y[n, i, j] + A[n, i, k] * B[n, k, j]
                # Vectorized ReLU of the same 8-wide tile into C.
                for i2_1 in T.vectorized(8):
                    with T.block("C"):
                        n, i = T.axis.remap("SS", [i0, i1])
                        j = T.axis.spatial(128, i2_0 * 8 + i2_1)
                        C[n, i, j] = T.max(Y[n, i, j], T.float32(0))

In [197]:
# Unscheduled batched matmul + ReLU: Y[n] = A[n] @ B[n], then C = max(Y, 0).
@tvm.script.ir_module
class MyBmmRelu:
	@T.prim_func
	def bmm_relu(A: T.Buffer((16, 128, 128), "float32"), # type: ignore
				 B: T.Buffer((16, 128, 128), "float32"), # type: ignore
				 C: T.Buffer((16, 128, 128), "float32")): # type: ignore
		T.func_attr({"global_symbol": "bmm_relu", "tir.noalias": True}) # type: ignore
		# Intermediate buffer holding the matmul result.
		Y = T.alloc_buffer([16, 128, 128], dtype="float32")
		for i0, i1, i2, i3 in T.grid(16, 128, 128, 128):
			# "SSSR": n, i, j are spatial iterators; k is the reduction iterator.
			with T.block("Y"):
				n, i, j, k = T.axis.remap("SSSR", [i0, i1, i2, i3])
				# Zero Y[n, i, j] before the first reduction step.
				with T.init():
					Y[n, i, j] = T.float32(0)
				Y[n, i, j] += A[n, i, k] * B[n, k, j]
		for i0, i1, i2 in T.grid(16, 128, 128):
			# Element-wise ReLU of Y into C.
			with T.block("C"):
				n, i, j = T.axis.remap("SSS", [i0, i1, i2])
				C[n, i, j] = T.max(Y[n, i, j], T.float32(0))

def schedule_bmm_relu(mod):
    """Apply the bmm_relu schedule that is expected to reproduce TargetModule.

    Loop handles are kept local (and named after what they really index), so
    the notebook-level input arrays such as ``b`` are not overwritten.

    Args:
        mod: IRModule containing the unscheduled ``bmm_relu`` (``MyBmmRelu``).

    Returns:
        The ``tvm.tir.Schedule``; its ``.mod`` is the transformed module.
    """
    sch = tvm.tir.Schedule(mod)
    block_C = sch.get_block("C", func_name="bmm_relu")
    block_Y = sch.get_block("Y", func_name="bmm_relu")

    # Tile C's j loop by 8 and compute Y inside each tile.
    c_b, c_i, c_j = sch.get_loops(block_C)
    c_j0, c_j1 = sch.split(c_j, factors=[None, 8])
    sch.compute_at(block_Y, c_j0)
    sch.parallel(c_b)
    sch.vectorize(c_j1)

    # After compute_at, Y's loops are (b, i, j0) shared with C, then its own (j1, k).
    _, _, _, y_j1, y_k = sch.get_loops(block_Y)
    sch.reorder(y_k, y_j1)
    y_k0, _ = sch.split(y_k, factors=[None, 4])
    sch.decompose_reduction(block_Y, y_k0)

    block_Y_init = sch.get_block("Y_init", func_name="bmm_relu")
    block_Y_update = sch.get_block("Y_update", func_name="bmm_relu")

    # Y_init loops: (b, i, j0, j1); vectorize the 8-wide j1.
    _, _, _, init_j1 = sch.get_loops(block_Y_init)
    sch.vectorize(init_j1)

    # Y_update loops: (b, i, j0, k0[32], k1[4], j1[8]); unroll the 4-wide k1,
    # matching ax1_1 in TargetModule.
    _, _, _, _, upd_k1, _ = sch.get_loops(block_Y_update)
    sch.unroll(upd_k1)
    return sch


sch = schedule_bmm_relu(MyBmmRelu)
tvm.ir.assert_structural_equal(sch.mod, TargetModule)
print("Pass")
IPython.display.HTML(code2html(sch.mod.script()))

Pass


In [198]:
# Build the original and the scheduled module, check that both are correct,
# then time them on the same inputs.
before_rt_lib = tvm.build(MyBmmRelu, target="llvm")
after_rt_lib = tvm.build(sch.mod, target="llvm")

rng = np.random.default_rng(0)  # fixed seed so the inputs are reproducible
a_bmm_np = rng.random((16, 128, 128)).astype("float32")
b_bmm_np = rng.random((16, 128, 128)).astype("float32")
c_bmm_ref = np.maximum(a_bmm_np @ b_bmm_np, 0)  # numpy reference: relu(bmm(A, B))

a_tvm = tvm.nd.array(a_bmm_np)
b_tvm = tvm.nd.array(b_bmm_np)
c_tvm = tvm.nd.empty((16, 128, 128), dtype="float32")
for lib in (before_rt_lib, after_rt_lib):
    lib["bmm_relu"](a_tvm, b_tvm, c_tvm)
    np.testing.assert_allclose(c_tvm.numpy(), c_bmm_ref, rtol=1e-5)

before_timer = before_rt_lib.time_evaluator("bmm_relu", tvm.cpu())
print("Before transformation:")
print(before_timer(a_tvm, b_tvm, c_tvm))

f_timer = after_rt_lib.time_evaluator("bmm_relu", tvm.cpu())
print("After transformation:")
print(f_timer(a_tvm, b_tvm, c_tvm))

Before transformation:
Execution time summary:
 mean (ms)   median (ms)    max (ms)     min (ms)     std (ms)  
  12.1840      12.1840      12.1840      12.1840       0.0000                  
After transformation:
Execution time summary:
 mean (ms)   median (ms)    max (ms)     min (ms)     std (ms)  
   0.6562       0.6562       0.6562       0.6562       0.0000                  


In [202]:
# Unscheduled 128x128 matmul + ReLU; transform() below tiles its "Y" block.
@tvm.script.ir_module
class MyModule:
    @T.prim_func
    def mm_relu(A: T.Buffer((128, 128), "float32"), # type: ignore
                B: T.Buffer((128, 128), "float32"), # type: ignore
                C: T.Buffer((128, 128), "float32")): # type: ignore
        T.func_attr({"global_symbol": "mm_relu", "tir.noalias": True})
        # Intermediate buffer holding A @ B.
        Y = T.alloc_buffer((128, 128), dtype="float32")
        for i, j, k in T.grid(128, 128, 128):
            with T.block("Y"):
                vi = T.axis.spatial(128, i)
                vj = T.axis.spatial(128, j)
                # vk is the reduction iterator of the matmul.
                vk = T.axis.reduce(128, k)
                # Zero Y[vi, vj] before the first reduction step.
                with T.init():
                    Y[vi, vj] = T.float32(0)
                Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
        for i, j in T.grid(128, 128):
            # Element-wise ReLU of Y into C.
            with T.block("C"):
                vi = T.axis.spatial(128, i)
                vj = T.axis.spatial(128, j)
                C[vi, vj] = T.max(Y[vi, vj], T.float32(0))

In [236]:
def transform(mod, jfactor):
    """Tile the j loop of block "Y" in ``mm_relu`` and move block "C" under it.

    Splits j into (j0, j1) with inner extent ``jfactor``, reorders Y's loops
    to (i, j0, k, j1), and reverse-compute-ats block "C" at loop j0.

    Returns:
        The transformed IRModule (``sch.mod``).
    """
    sch = tvm.tir.Schedule(mod)
    block_Y = sch.get_block("Y", func_name="mm_relu")
    _, loop_j, loop_k = sch.get_loops(block_Y)
    j_outer, j_inner = sch.split(loop_j, factors=[None, jfactor])
    sch.reorder(j_outer, loop_k, j_inner)
    sch.reverse_compute_at(sch.get_block("C", func_name="mm_relu"), j_outer)
    return sch.mod

def testj(jfactor, a_nd, b_nd, c_nd, show_code=True):
    """Build ``transform(MyModule, jfactor)``, verify it, time it and optionally show it.

    Args:
        jfactor: Inner split factor for the j loop, passed to ``transform``.
        a_nd, b_nd: 128x128 float32 NDArray inputs of ``mm_relu``.
        c_nd: 128x128 float32 NDArray that receives the output.
        show_code: When True, render the transformed TVMScript as highlighted HTML.

    Raises:
        AssertionError: if the transformed kernel disagrees with numpy's relu(A @ B).
    """
    mod_transformed = transform(MyModule, jfactor=jfactor)

    rt_lib_transformed = tvm.build(mod_transformed, "llvm")

    # Check correctness first so a wrong transform cannot report a misleading time.
    rt_lib_transformed["mm_relu"](a_nd, b_nd, c_nd)
    c_ref = np.maximum(a_nd.numpy() @ b_nd.numpy(), 0)
    np.testing.assert_allclose(c_nd.numpy(), c_ref, rtol=1e-5)

    f_timer_transformed = rt_lib_transformed.time_evaluator("mm_relu", tvm.cpu())
    print("Time cost of transformed mod_transformed for jfactor %i %g sec" % (jfactor, f_timer_transformed(a_nd, b_nd, c_nd).mean))

    if show_code:
        # A bare HTML(...) expression only renders as the last line of a cell;
        # inside a function it must be passed to display() explicitly.
        IPython.display.display(IPython.display.HTML(code2html(mod_transformed.script())))

In [242]:
dtype = "float32"
rng = np.random.default_rng(0)  # fixed seed so every sweep uses the same inputs
a_np = rng.random((128, 128)).astype(dtype)
b_np = rng.random((128, 128)).astype(dtype)
a_nd = tvm.nd.array(a_np)
b_nd = tvm.nd.array(b_np)
c_nd = tvm.nd.empty((128, 128), dtype=dtype)

# Sweep the j split factor over the powers of two from 2 to 128.
for jfactor in (2 ** p for p in range(1, 8)):
    testj(jfactor, a_nd, b_nd, c_nd)

Time cost of transformed mod_transformed for jfactor 2 0.000454409 sec
Time cost of transformed mod_transformed for jfactor 4 0.000213332 sec
Time cost of transformed mod_transformed for jfactor 8 0.000139546 sec
Time cost of transformed mod_transformed for jfactor 16 0.000113309 sec
Time cost of transformed mod_transformed for jfactor 32 9.04959e-05 sec
Time cost of transformed mod_transformed for jfactor 64 0.000106782 sec
Time cost of transformed mod_transformed for jfactor 128 0.000112312 sec
