<a href="https://colab.research.google.com/github/XueyanZhang/MachineLearningCompilation/blob/master/MLC_TensorProgram.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
!python3 -m  pip install mlc-ai-nightly -f https://mlc.ai/wheels

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in links: https://mlc.ai/wheels
Collecting mlc-ai-nightly
  Downloading https://github.com/mlc-ai/utils/releases/download/v0.9.dev0/mlc_ai_nightly-0.12.dev803%2Bg62125abf8-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (52.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m52.0/52.0 MB[0m [31m11.3 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: mlc-ai-nightly
Successfully installed mlc-ai-nightly-0.12.dev803+g62125abf8


In [3]:
import numpy as np
import tvm
from tvm.ir.module import IRModule
from tvm.script import tir as T # tensor level IR

## Basic Example (linear layer + relu)

In [4]:
# linear layer with relu (high level)
dtype: str = 'float32'
M = K = N = 128
# Seeded generator so the inputs (and every comparison below) are identical
# after Restart Kernel -> Run All.
SEED = 0
rng = np.random.default_rng(SEED)
a_np = rng.random((M, K)).astype(dtype)
b_np = rng.random((K, N)).astype(dtype)
# Reference result: matrix multiply followed by an element-wise ReLU.
c_mm_relu = np.maximum(a_np @ b_np, 0)

In [5]:
# linear layer with relu (low level)
def ll_numpy_mm_relu(C: np.ndarray, A: np.ndarray, B: np.ndarray):
    """Compute ``C = max(A @ B, 0)`` with explicit scalar loops, in place.

    The loop nest is deliberately low level (no vectorised NumPy calls in the
    hot path). Sizes are taken from the arguments, not from module-level
    constants, so the function works for any compatible shapes.

    Args:
        C: Output array of shape ``(A.shape[0], B.shape[1])``; overwritten.
        A: Left operand, 2-D, shape ``(m, k)``.
        B: Right operand, 2-D, shape ``(k, n)``.

    Raises:
        ValueError: If the inner dimensions of ``A`` and ``B`` differ, or if
            ``C`` does not have shape ``(m, n)``.
    """
    n_rows, n_inner = A.shape
    n_inner_b, n_cols = B.shape
    if n_inner != n_inner_b:
        raise ValueError(
            f"inner dimensions differ: A is {A.shape}, B is {B.shape}"
        )
    if C.shape != (n_rows, n_cols):
        raise ValueError(
            f"C has shape {C.shape}, expected {(n_rows, n_cols)}"
        )

    # Intermediate matmul result. Zero-filled so that an empty reduction
    # (k == 0 columns) still yields a defined value.
    Y = np.zeros((n_rows, n_cols), dtype=np.result_type(A.dtype, B.dtype))
    for i in range(n_rows):
        for j in range(n_cols):
            for k in range(n_inner):
                # Reset the accumulator on the first reduction step.
                if k == 0:
                    Y[i, j] = 0
                Y[i, j] = Y[i, j] + A[i, k] * B[k, j]
    # Element-wise ReLU of the intermediate result.
    for i in range(n_rows):
        for j in range(n_cols):
            C[i, j] = max(Y[i, j], 0)

In [6]:
# Allocate the output buffer, then let the loop-based kernel fill it in place.
c_np = np.empty(shape=(M, N), dtype=dtype)
ll_numpy_mm_relu(C=c_np, A=a_np, B=b_np)

In [7]:
# Compare the two results: the loop-based output is the value under test
# (`actual`); the high-level NumPy result is the reference (`desired`), which
# is also what `rtol` is measured against.
np.testing.assert_allclose(c_np, c_mm_relu, rtol=1e-5)

## TensorIR

In [8]:
# TensorIR counterpart of the loop-based NumPy kernel: a single primitive
# function computing C = max(A @ B, 0) on fixed 128x128 float32 buffers.
@tvm.script.ir_module
class MyModule:
    @T.prim_func
    def mm_relu(A: T.Buffer((128, 128), "float32"),
                B: T.Buffer((128, 128), "float32"),
                C: T.Buffer((128, 128), "float32")):
        # A and B are read, C is written (see the two blocks below).
        T.func_attr({"global_symbol": "mm_relu", "tir.noalias": True})

        # Intermediate buffer holding the matmul result before the ReLU.
        Y = T.alloc_buffer((128, 128), dtype="float32")
        for i, j, k in T.grid(128, 128, 128):
            # Block "Y": matrix multiply. vi/vj index the output element,
            # vk is the axis being summed over.
            with T.block("Y"):
                vi = T.axis.spatial(128, i)
                vj = T.axis.spatial(128, j)
                vk = T.axis.reduce(128, k)
                # Accumulator reset; plays the role of the `if k == 0` branch
                # in the NumPy loop version.
                with T.init():
                    Y[vi, vj] = T.float32(0)
                Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]

        for i, j in T.grid(128, 128):
            # Block "C": element-wise ReLU of Y, written to the output buffer.
            with T.block("C"):
                vi = T.axis.spatial(128, i)
                vj = T.axis.spatial(128, j)
                C[vi, vj] = T.max(Y[vi, vj], T.float32(0))

### Axes

- `vi` and `vj` are **spatial axes**, as they directly correspond to the spatial region of `Y` that the block writes.

- `vk` is a **reduce axis**, as it is the axis that is summed over (the reduction).

- equivalent using `remap`
```
vi, vj, vk = T.axis.remap("SSR", [i, j , k])
```
S: spatial, R: reduce

### Block
- a basic unit of computation in TensorIR

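- Self-contained: the block itself carries the details (such as which axes are spatial and which are reduce) needed to decide how it can be parallelized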



In [9]:
# Inspect what the @tvm.script.ir_module decorator produced from the class definition.
type(MyModule)

tvm.ir.module.IRModule

In [10]:
# Functions inside the module are looked up by name; inspect the type of one entry.
type(MyModule["mm_relu"])

tvm.tir.function.PrimFunc

In [11]:
# An IRModule can contain multiple tensor functions
# Here the matmul and the ReLU from `mm_relu` are split into two separate
# primitive functions, `mm` and `relu`.
@tvm.script.ir_module
class MyModuleWith2Func:
    @T.prim_func
    def mm(A: T.Buffer((128, 128), "float32"),
           B: T.Buffer((128, 128), "float32"),
           Y: T.Buffer((128, 128), "float32")):
        # Y = A @ B; A and B are read, Y is written.
        T.func_attr({"global_symbol": "mm", "tir.noalias": True})

        for i, j, k in T.grid(128, 128, 128):
            with T.block("Y"):
                vi = T.axis.spatial(128, i)
                vj = T.axis.spatial(128, j)
                vk = T.axis.reduce(128, k)
                # Accumulator reset for the reduction over vk.
                with T.init():
                    Y[vi, vj] = T.float32(0)
                Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]

    @T.prim_func
    def relu(C: T.Buffer((128, 128), "float32"),
             Y: T.Buffer((128, 128), "float32")):
        # C = max(Y, 0). Note the parameter order: the output C comes first
        # here, whereas `mm` takes its output last.
        T.func_attr({"global_symbol": "relu", "tir.noalias": True})
        for i, j in T.grid(128, 128):
            with T.block("C"):
                vi = T.axis.spatial(128, i)
                vj = T.axis.spatial(128, j)
                C[vi, vj] = T.max(Y[vi, vj], T.float32(0))


## Transformation

Obtain different implementation variants.

E.g., split loop j into two loops
```
def lnumpy_mm_relu_v2(A: np.ndarray, B: np.ndarray, C: np.ndarray):
    Y = np.empty((128, 128), dtype="float32")
    for i in range(128):
        for j0 in range(32):
            for k in range(128):
                for j1 in range(4):
                    j = j0 * 4 + j1
                    if k == 0:
                        Y[i, j] = 0
                    Y[i, j] = Y[i, j] + A[i, k] * B[k, j]
    for i in range(128):
        for j in range(128):
            C[i, j] = max(Y[i, j], 0)
```

In [13]:
# NOTE(review): this import sits mid-notebook; consider moving it to the import cell at the top.
import IPython
# Render the module's script text with Python syntax highlighting.
IPython.display.Code(MyModule.script(), language="python")

In [18]:
# NOTE(review): the imported name `func_name` is not referenced in the visible
# notebook (the `func_name=` below is a keyword argument) — this looks like an
# editor auto-import; confirm and drop.
from tvm.script.ir_builder.tir.ir import func_name
# Create a schedule over MyModule; the transformations below are applied
# through `sch`, and the transformed module is read back from `sch.mod`.
sch = tvm.tir.Schedule(MyModule)
# mm part
# Look up the matmul block "Y" in `mm_relu` and the three loops around it.
block_Y = sch.get_block("Y", func_name="mm_relu")
i, j, k = sch.get_loops(block_Y)
# split j into ranges 32 (auto compute) and 4
# (`None` leaves the outer extent to be inferred: 128 / 4 = 32)
j0, j1 = sch.split(j, factors=[None, 4])

# Show the module after the split.
IPython.display.Code(sch.mod.script(), language="python")

In [19]:
# Reorder the loops to j0 -> k -> j1, the same nesting as the hand-written
# `lnumpy_mm_relu_v2` sketch in the Transformation section.
sch.reorder(j0, k, j1)

# Show the module after the reorder.
IPython.display.Code(sch.mod.script(), language="python")

In [20]:
# relu part
# Look up the ReLU block by name, using the same keyword form as the "Y" lookup.
block_C = sch.get_block("C", func_name="mm_relu")
# Move the consumer block C under loop j0 (see TVM's Schedule.reverse_compute_at).
sch.reverse_compute_at(block_C, j0)

IPython.display.Code(sch.mod.script(), language="python")

In [21]:
# Separate block Y's initialization from its update step, at loop k
# (see TVM's Schedule.decompose_reduction), then show the resulting module.
sch.decompose_reduction(block_Y, k)
IPython.display.Code(sch.mod.script(), language="python")

### Run Variant Programs

In [22]:
# run MyModule
# Compile the untransformed module for the LLVM (CPU) target.
rt_lib = tvm.build(MyModule, target="llvm")

# Wrap the NumPy inputs as TVM arrays and allocate the output buffer.
a_nd, b_nd = tvm.nd.array(a_np), tvm.nd.array(b_np)
c_nd = tvm.nd.empty(shape=(128, 128), dtype="float32")
type(c_nd)

tvm.runtime.ndarray.NDArray

In [24]:
# Fetch the compiled function by name and run it; the result is written to c_nd.
func_mm_relu = rt_lib["mm_relu"]
func_mm_relu(a_nd, b_nd, c_nd)

# c_mm_relu is computed (way) above
# The TVM output is the value under test (`actual`); the NumPy result is the
# reference (`desired`), which is also what `rtol` is measured against.
np.testing.assert_allclose(c_nd.numpy(), c_mm_relu, rtol=1e-5)

In [25]:
# run variant
# Build the transformed module and run it into its own output buffer.
rt_lib_mod = tvm.build(sch.mod, target="llvm")
c_nd_mod = tvm.nd.empty((128, 128), dtype="float32")
rt_lib_mod["mm_relu"](a_nd, b_nd, c_nd_mod)

# The variant's output is the value under test (`actual`); the NumPy result
# is the reference (`desired`), which is also what `rtol` is measured against.
np.testing.assert_allclose(c_nd_mod.numpy(), c_mm_relu, rtol=1e-5)

In [28]:
# compare runtime
# One timer per build, both measuring `mm_relu` on the CPU device.
f_timer = rt_lib.time_evaluator("mm_relu", tvm.cpu())
f_timer_mod = rt_lib_mod.time_evaluator("mm_relu", tvm.cpu())

# Collect the mean run time of each build, then report them side by side.
baseline_mean = f_timer(a_nd, b_nd, c_nd).mean
variant_mean = f_timer_mod(a_nd, b_nd, c_nd_mod).mean
print("MyModule takes %g sec" % baseline_mean)
print("sch.mod takes %g sec" % variant_mean)

MyModule takes 0.00361607 sec
sch.mod takes 0.000712287 sec


- The performance difference comes from the loop reordering, which changes the memory access pattern.

- The variant program makes better use of memory locality, and thus runs faster.