# Tensor Program Abstraction in Action





In [17]:
import tvm
from tvm.ir.module import IRModule
from tvm.script import tir as T
import numpy as np

print('tvm versin: %s' % tvm.__version__)

tvm versin: 0.13.dev0


## Constructing Tensor Program

使用 TVMScript 编写一个 tensor program，计算内容为 C = A + B。
调用 MyModule.show() 可以看到，输出中自动增加了以下 2 行：

```python
T.reads(A[vi], B[vi])
T.writes(C[vi])
```

In [4]:
# IRModule written in TVMScript. It holds a single primitive tensor function,
# `main`, which stores the element-wise sum of A and B into C.
@tvm.script.ir_module
class MyModule:
    @T.prim_func
    def main(A: T.Buffer(128, "float32"), 
             B: T.Buffer(128, "float32"), 
             C: T.Buffer(128, "float32")):
        # extra annotations for the function:
        # "global_symbol" is the name under which the built function is looked up later
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        for i in range(128):
            # the block name "C" is what the schedule later uses to find this computation
            with T.block("C"):
                # declare a data parallel iterator on spatial domain
                vi = T.axis.spatial(128, i)
                C[vi] = A[vi] + B[vi]
# Last expression of the cell: displays the type of the decorated class.
type(MyModule)

tvm.ir.module.IRModule

In [5]:
# Print the module as TVMScript to inspect what was actually constructed.
MyModule.show()

### Build and run

运行 IRModule，需要先 build 成执行 module，然后使用 tvm runtime 运行。

In [6]:
# Compile the IRModule for the "llvm" target.
rt_mod = tvm.build(MyModule, target="llvm")  # The module for CPU backends.
# Last expression of the cell: displays the type of the built module.
type(rt_mod)

tvm.driver.build_module.OperatorModule

After the build, `rt_mod` contains a collection of runnable functions. We can retrieve each function by its name.

In [7]:
# Look up the compiled function by name; "main" matches the global_symbol set on the prim_func.
func = rt_mod["main"]
# Last expression of the cell: displays the function handle.
func

<tvm.runtime.packed_func.PackedFunc at 0x7fbe284366c0>

In [8]:
# Inputs for the vector add: a counts 0..127, b is all ones.
a = tvm.nd.array(np.arange(128).astype("float32"))
b = tvm.nd.array(np.ones((128,), dtype="float32"))
# Output buffer; its contents are written by the call below.
c = tvm.nd.empty((128,), dtype="float32")

# Run the compiled function: c receives a + b.
func(a, b, c)

In [9]:
# Show the two inputs followed by the computed output, one array per line group.
print(a, b, c, sep="\n")

[  0.   1.   2.   3.   4.   5.   6.   7.   8.   9.  10.  11.  12.  13.
  14.  15.  16.  17.  18.  19.  20.  21.  22.  23.  24.  25.  26.  27.
  28.  29.  30.  31.  32.  33.  34.  35.  36.  37.  38.  39.  40.  41.
  42.  43.  44.  45.  46.  47.  48.  49.  50.  51.  52.  53.  54.  55.
  56.  57.  58.  59.  60.  61.  62.  63.  64.  65.  66.  67.  68.  69.
  70.  71.  72.  73.  74.  75.  76.  77.  78.  79.  80.  81.  82.  83.
  84.  85.  86.  87.  88.  89.  90.  91.  92.  93.  94.  95.  96.  97.
  98.  99. 100. 101. 102. 103. 104. 105. 106. 107. 108. 109. 110. 111.
 112. 113. 114. 115. 116. 117. 118. 119. 120. 121. 122. 123. 124. 125.
 126. 127.]
[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.

## Transform the Tensor Program

使用程序 (tvm scheduler) 自动优化程序 (MyModule)。schedule 是 TVM 提供的调度原语，通俗的理解，就是优化算法的集合。

In [10]:
# Create a schedule over MyModule; the transformations below are applied through it
# and the resulting module is read back from sch.mod.
sch = tvm.tir.Schedule(MyModule)
# Last expression of the cell: displays the schedule's type.
type(sch)

tvm.tir.schedule.schedule.Schedule

Let us first try to split the loops.

In [11]:
# Get a handle to block "C", the block to transform (optimize)
block_c = sch.get_block("C")
# Get the loops around the block; the program has a single loop, hence the 1-tuple unpacking
(i,) = sch.get_loops(block_c)

# 3 transformation steps:
# 1. tile (split), 2. reorder, 3. parallel
# Split i into three nested loops: the two inner ones have extent 4, and the
# None factor leaves the outermost extent to be inferred.
i_0, i_1, i_2 = sch.split(i, factors=[None, 4, 4])
# Swap the two inner loops; the new nesting order is i_0, i_2, i_1.
sch.reorder(i_0, i_2, i_1)
# Mark the outermost loop for parallel execution.
sch.parallel(i_0)
# Print the transformed module.
sch.mod.show()

In [12]:
# Sanity check: the product of the three loop extents equals the original 128 iterations
# (8 is presumably the outer extent inferred for the None factor; confirm in the printed module).
8 * 4 * 4

128

Run 优化后的程序，并验证计算结果的正确性不变


In [13]:
# Separate output buffer, so the comparison below is against a freshly computed result.
c2 = tvm.nd.empty((128,), dtype="float32")

# Build the scheduled (transformed) module and run it on the same inputs.
transformed_mod = tvm.build(sch.mod, target="llvm")  # The module for CPU backends.
transformed_mod["main"](a, b, c2)

# check results are the same as before
# (NDArray.numpy() is used because NDArray.asnumpy() is deprecated)
np.testing.assert_allclose(c.numpy(), c2.numpy(), rtol=1e-5, atol=1e-5)

## Constructing Tensor Program using Tensor Expression

In the previous example, we directly used TVMScript to construct the tensor program. In practice, it is usually helpful to construct these functions programmatically from existing definitions. Tensor expression is an API that helps us build some of the expression-like array computations.

In [18]:
# Tensor expression (te) utilities
from tvm import te

# Describe the vector add declaratively: two 128-element inputs and an
# output whose i-th element is A[i] + B[i].
A = te.placeholder((128,), name="A")
B = te.placeholder((128,), name="B")
C = te.compute((128,), lambda i: A[i] + B[i], name="C")

# Turn the expression into a primitive function taking [A, B, C] as arguments,
# and tag it with the global symbol "main" in the same step.
func = te.create_prim_func([A, B, C]).with_attr("global_symbol", "main")
ir_mod_from_te = IRModule({"main": func})

# To inspect the generated module, call ir_mod_from_te.show()

## Transforming a matrix multiplication program

In the above example, we showed how to transform a vector addition. Now let us try to apply that to a slightly more complicated program (matrix multiplication). Let us first try to build the initial code using the tensor expression API.


In [22]:
from tvm import te

# Problem size: C (M x N) is the product of A (M x K) and B (K x N).
M = 1024
K = 1024
N = 1024

# The default tensor type in tvm
dtype = "float32"

# Single source of truth for the compilation target; both the device handle
# and the build below are derived from it.
target = "llvm"
dev = tvm.device(target, 0)

# Algorithm: C[m, n] = sum over k of A[m, k] * B[k, n]
k = te.reduce_axis((0, K), "k")
A = te.placeholder((M, K), name="A")
B = te.placeholder((K, N), name="B")
C = te.compute((M, N), lambda m, n: te.sum(A[m, k] * B[k, n], axis=k), name="C")

# Default schedule: wrap the expression in a prim_func named "main"
func = te.create_prim_func([A, B, C])
func = func.with_attr("global_symbol", "main")
ir_module = IRModule({"main": func})
ir_module.show()

# build and run
# Use `target` rather than a repeated literal so the build cannot drift from `dev`.
# From here on `func` refers to the built module, not the prim_func.
func = tvm.build(ir_module, target=target)

a = tvm.nd.array(np.random.rand(M, K).astype(dtype), dev)
b = tvm.nd.array(np.random.rand(K, N).astype(dtype), dev)
c = tvm.nd.array(np.zeros((M, N), dtype=dtype), dev)
func(a, b, c)

# Time the entry function (number=1) and report the mean.
evaluator = func.time_evaluator(func.entry_name, dev, number=1)
print(f"Baseline: {evaluator(a, b, c).mean:f}")

Baseline: 2.407493


We can transform the loop access pattern to make it more cache friendly. Let us use the following schedule.

In [23]:
sch = tvm.tir.Schedule(ir_module)
block_c = sch.get_block("C")
# Get the loops surrounding the block: two spatial loops and the reduction loop
(y, x, k) = sch.get_loops(block_c)
# step 1: tile (split) each spatial loop into an outer loop and an inner loop
# of extent block_size; None leaves the outer extent to be inferred
block_size = 32
yo, yi = sch.split(y, [None, block_size])
xo, xi = sch.split(x, [None, block_size])

# step 2: reorder so the reduction loop sits between the tile loops (yo, xo)
# and the intra-tile loops (yi, xi)
sch.reorder(yo, xo, k, yi, xi)
sch.mod.show()

# build and run
# Build for the same `target` that `dev` was created from, instead of a repeated literal.
func = tvm.build(sch.mod, target=target)

# Fresh zero-initialised output; inputs a and b are reused from the baseline run.
c = tvm.nd.array(np.zeros((M, N), dtype=dtype), dev)
func(a, b, c)

# Time the entry function (number=1) and report the mean.
evaluator = func.time_evaluator(func.entry_name, dev, number=1)
print(f"after transformation: {evaluator(a, b, c).mean:f}")

after transformation: 0.151193
