# Apache TVM - an in-depth look

This notebook demonstrates the basics of TVM tensor expressions and schedules.

Let's start by importing TVM and defining a small helper that compares two lowered functions:

In [1]:
import tvm
from tvm import te

import difflib


def compute_diff(s1: str, s2: str):
    """
    Prints the differences between two strings, line by line.

    Parameters
    ----------
    s1: str
        First sequence to compare
    s2: str
        Second sequence to compare
    """
    s1split = s1.split("\n")
    s2split = s2.split("\n")
    delta = difflib.ndiff(s1split, s2split)
    for line in delta:
        print(line)
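The diffs printed later in this notebook use the `difflib.ndiff` line markers: `- ` for a line present only in the first input, `+ ` for a line present only in the second, two spaces for a line common to both, and `? ` for hint lines that belong to neither input. A minimal sketch on two made-up snippets (the loop bodies below are illustrative strings, not TVM output):

```python
import difflib

# Two tiny, made-up "programs" that differ only in their first line.
before = ["for i in range(m):", "    body(i)"]
after = ["for i_outer in range(4):", "    body(i)"]

# ndiff yields one string per line, each prefixed with a two-character marker.
delta = list(difflib.ndiff(before, after))

# Group the lines by marker; "? " hint lines are deliberately ignored here.
removed = [line[2:] for line in delta if line.startswith("- ")]
added = [line[2:] for line in delta if line.startswith("+ ")]
unchanged = [line[2:] for line in delta if line.startswith("  ")]

print("\n".join(delta))
```

Filtering on the marker, as done for `removed` and `added`, is a convenient way to skip the `? ` hint lines when only the changed lines matter.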

## Defining schedules

A **schedule** is a set of transformations applied to a computation.

`tvm.te` provides Tensor Expressions, which are used both by Relay to represent operations in model functions and by schedules/optimization strategies to organize those operations.

### Creating computation

Let's perform element-wise matrix multiplication.

* `te.var` defines a single-value symbolic variable.
* `te.placeholder` declares an input tensor of a given shape (a placeholder; no data is attached at this point).
* `te.compute` constructs a new tensor by evaluating the given function over every index of the shape domain.
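To make the "function over the shape domain" idea concrete, here is a plain-Python sketch. The `compute` helper is a hypothetical stand-in, not part of TVM, and it evaluates eagerly on nested lists, whereas `te.compute` only builds a symbolic description of the computation:

```python
def compute(shape, fn):
    """Hypothetical stand-in for te.compute: evaluate fn at every index of a 2D shape."""
    rows, cols = shape
    return [[fn(i, j) for j in range(cols)] for i in range(rows)]


# Concrete 2x3 inputs instead of symbolic placeholders.
A = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
B = [[10.0, 20.0, 30.0], [40.0, 50.0, 60.0]]

# Same lambda shape as in the TVM cell: one output element per (i, j) index.
C = compute((2, 3), lambda i, j: A[i][j] * B[i][j])
print(C)
```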

In [2]:
n = te.var("n")
m = te.var("m")

A = te.placeholder((m, n), name="A")
B = te.placeholder((m, n), name="B")
C = te.compute((m, n), lambda i, j: A[i, j] * B[i, j], name="C")

schedule = te.create_schedule([C.op])

* `C.op` is the operation for which we define the schedule.
* `schedule` describes how the operations are computed. It is the object to which further optimizations are applied.

### Lowering computations

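`tvm.lower` transforms the schedule into a lowered TIR module, printed below as TVMScript. It is not yet executable; a callable function is produced by `tvm.build`, which is used later in this notebook.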

In [3]:
base_function = str(tvm.lower(schedule, [A, B, C], simple_mode=True))

print(base_function)

# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.handle, B: T.handle, C: T.handle):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
        m, n = T.int32(), T.int32()
        A_1 = T.match_buffer(A, (m, n), strides=("stride", "stride"), buffer_type="auto")
        B_1 = T.match_buffer(B, (m, n), strides=("stride", "stride"), buffer_type="auto")
        C_1 = T.match_buffer(C, (m, n), strides=("stride", "stride"), buffer_type="auto")
        for i, j in T.grid(m, n):
            C_2 = T.Buffer((C_1.strides[0] * m,), data=C_1.data, buffer_type="auto")
            A_2 = T.Buffer((A_1.strides[0] * m,), data=A_1.data, buffer_type="auto")
            B_2 = T.Buffer((B_1.strides[0] * m,), data=B_1.data, buffer_type="auto")
            C_2[i * C_1.strides[0] + j * C_1.strides[1]] = A_2[i * A_1.strides[0] + j * A_1.strides[1]] * B_2[i * B_1.strides[0] + j * B_1.strides[1

### Splitting and tiling computations

https://tvm.apache.org/docs/reference/api/python/te.html#tvm.te.Stage.split

#### Split

`split` divides a given axis into an outer and an inner axis, where the inner axis has length `factor`.
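The effect on the loop structure can be sketched in plain Python. `split_indices` is a hypothetical helper written for this illustration; it mirrors the outer/inner loops and the bounds guard that appear in the lowered code, without using TVM:

```python
def split_indices(extent, factor):
    """Visit 0..extent-1 with an outer/inner loop pair, as a split by `factor` would."""
    visited = []
    outer_extent = (extent + factor - 1) // factor  # ceil(extent / factor)
    for outer in range(outer_extent):
        for inner in range(factor):
            index = outer * factor + inner
            if index < extent:  # guard for the last, possibly partial, chunk
                visited.append(index)
    return visited


m = 70
# The split loop nest visits exactly the same indices, in the same order.
assert split_indices(m, 32) == list(range(m))
print((m + 31) // 32)  # number of outer iterations for m = 70
```

The guard is needed because `m` is not known to be a multiple of 32, which is why the lowered code wraps the body in `T.likely(...)`.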

In [4]:
n = te.var("n")
m = te.var("m")

A = te.placeholder((m, n), name="A")
B = te.placeholder((m, n), name="B")
C = te.compute((m, n), lambda i, j: A[i, j] * B[i, j], name="C")

schedule = te.create_schedule([C.op])

xo, xi = schedule[C].split(C.op.axis[0], factor=32)

split_function = str(tvm.lower(schedule, [A, B, C], simple_mode=True))

print(split_function)

# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.handle, B: T.handle, C: T.handle):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
        m, n = T.int32(), T.int32()
        A_1 = T.match_buffer(A, (m, n), strides=("stride", "stride"), buffer_type="auto")
        B_1 = T.match_buffer(B, (m, n), strides=("stride", "stride"), buffer_type="auto")
        C_1 = T.match_buffer(C, (m, n), strides=("stride", "stride"), buffer_type="auto")
        for i_outer, i_inner in T.grid((m + 31) // 32, 32):
            if T.likely(i_outer * 32 + i_inner < m):
                for j in range(n):
                    cse_var_1: T.int32 = i_outer * 32 + i_inner
                    C_2 = T.Buffer((C_1.strides[0] * m,), data=C_1.data, buffer_type="auto")
                    A_2 = T.Buffer((A_1.strides[0] * m,), data=A_1.data, buffer_type="auto")
                    B_2 = T.Buffer((B_1

In [5]:
compute_diff(base_function, split_function)

  # from tvm.script import ir as I
  # from tvm.script import tir as T
  
  @I.ir_module
  class Module:
      @T.prim_func
      def main(A: T.handle, B: T.handle, C: T.handle):
          T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
          m, n = T.int32(), T.int32()
          A_1 = T.match_buffer(A, (m, n), strides=("stride", "stride"), buffer_type="auto")
          B_1 = T.match_buffer(B, (m, n), strides=("stride", "stride"), buffer_type="auto")
          C_1 = T.match_buffer(C, (m, n), strides=("stride", "stride"), buffer_type="auto")
-         for i, j in T.grid(m, n):
+         for i_outer, i_inner in T.grid((m + 31) // 32, 32):
+             if T.likely(i_outer * 32 + i_inner < m):
+                 for j in range(n):
+                     cse_var_1: T.int32 = i_outer * 32 + i_inner
-             C_2 = T.Buffer((C_1.strides[0] * m,), data=C_1.data, buffer_type="auto")
+                     C_2 = T.Buffer((C_1.strides[0] * m,), data=C_1.d

#### Tile

https://tvm.apache.org/docs/reference/api/python/te.html#tvm.te.Stage.tile

`tile` works like `split`, but in 2D: it splits both given axes and iterates over the resulting tiles.
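A plain-Python sketch of the resulting iteration order, assuming the default loop order (outer row, outer column, inner row, inner column). `tiled_order` is a hypothetical helper for this illustration, not a TVM function:

```python
def tiled_order(m, n, x_factor, y_factor):
    """Return the (i, j) pairs in the order a tiled loop nest would visit them."""
    order = []
    for i_outer in range((m + x_factor - 1) // x_factor):
        for j_outer in range((n + y_factor - 1) // y_factor):
            for i_inner in range(x_factor):
                i = i_outer * x_factor + i_inner
                if i >= m:  # row guard for a partial tile
                    continue
                for j_inner in range(y_factor):
                    j = j_outer * y_factor + j_inner
                    if j < n:  # column guard for a partial tile
                        order.append((i, j))
    return order


order = tiled_order(20, 10, x_factor=16, y_factor=8)
print(order[:10])
```

Every element is still visited exactly once; only the order changes, so that one tile is finished before the next one starts.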

In [6]:
n = te.var("n")
m = te.var("m")

A = te.placeholder((m, n), name="A")
B = te.placeholder((m, n), name="B")
C = te.compute((m, n), lambda i, j: A[i, j] * B[i, j], name="C")

schedule = te.create_schedule([C.op])

# tile returns the axes in the order: x outer, y outer, x inner, y inner
xo, yo, xi, yi = schedule[C].tile(C.op.axis[0], C.op.axis[1], x_factor=16, y_factor=8)

tile_function = str(tvm.lower(schedule, [A, B, C], simple_mode=True))

print(tile_function)

# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.handle, B: T.handle, C: T.handle):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
        m, n = T.int32(), T.int32()
        A_1 = T.match_buffer(A, (m, n), strides=("stride", "stride"), buffer_type="auto")
        B_1 = T.match_buffer(B, (m, n), strides=("stride", "stride"), buffer_type="auto")
        C_1 = T.match_buffer(C, (m, n), strides=("stride", "stride"), buffer_type="auto")
        for i_outer, j_outer, i_inner in T.grid((m + 15) // 16, (n + 7) // 8, 16):
            if T.likely(i_outer * 16 + i_inner < m):
                for j_inner in range(8):
                    if T.likely(j_outer * 8 + j_inner < n):
                        cse_var_2: T.int32 = j_outer * 8 + j_inner
                        cse_var_1: T.int32 = i_outer * 16 + i_inner
                        C_2 = T.Buffer((C_1.strides[0] * m,), dat

In [7]:
compute_diff(base_function, tile_function)

  # from tvm.script import ir as I
  # from tvm.script import tir as T
  
  @I.ir_module
  class Module:
      @T.prim_func
      def main(A: T.handle, B: T.handle, C: T.handle):
          T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
          m, n = T.int32(), T.int32()
          A_1 = T.match_buffer(A, (m, n), strides=("stride", "stride"), buffer_type="auto")
          B_1 = T.match_buffer(B, (m, n), strides=("stride", "stride"), buffer_type="auto")
          C_1 = T.match_buffer(C, (m, n), strides=("stride", "stride"), buffer_type="auto")
-         for i, j in T.grid(m, n):
+         for i_outer, j_outer, i_inner in T.grid((m + 15) // 16, (n + 7) // 8, 16):
+             if T.likely(i_outer * 16 + i_inner < m):
+                 for j_inner in range(8):
+                     if T.likely(j_outer * 8 + j_inner < n):
+                         cse_var_2: T.int32 = j_outer * 8 + j_inner
+                         cse_var_1: T.int32 = i_outer * 16 + i

### Fusing axes

https://tvm.apache.org/docs/reference/api/python/te.html#tvm.te.Stage.fuse

`fuse` merges two consecutive axes into a single axis whose extent is the product of the two.
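The original indices are recovered from the fused one with integer division and modulo. A minimal plain-Python sketch (`fused_indices` is a hypothetical helper, not TVM code):

```python
def fused_indices(m, n):
    """Recover (i, j) from a single fused loop variable running over m * n."""
    return [(fused // n, fused % n) for fused in range(m * n)]


m, n = 3, 4
pairs = fused_indices(m, n)
print(pairs[:5])
```

This is the same `i_j_fused // n` and `i_j_fused % n` arithmetic that shows up in the lowered function.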

In [8]:
n = te.var("n")
m = te.var("m")

A = te.placeholder((m, n), name="A")
B = te.placeholder((m, n), name="B")
C = te.compute((m, n), lambda i, j: A[i, j] * B[i, j], name="C")

schedule = te.create_schedule([C.op])

fusedaxis = schedule[C].fuse(C.op.axis[0], C.op.axis[1])

fuse_function = str(tvm.lower(schedule, [A, B, C], simple_mode=True))

print(fuse_function)

# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.handle, B: T.handle, C: T.handle):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
        m, n = T.int32(), T.int32()
        A_1 = T.match_buffer(A, (m, n), strides=("stride", "stride"), buffer_type="auto")
        B_1 = T.match_buffer(B, (m, n), strides=("stride", "stride"), buffer_type="auto")
        C_1 = T.match_buffer(C, (m, n), strides=("stride", "stride"), buffer_type="auto")
        for i_j_fused in range(m * n):
            C_2 = T.Buffer((C_1.strides[0] * m,), data=C_1.data, buffer_type="auto")
            A_2 = T.Buffer((A_1.strides[0] * m,), data=A_1.data, buffer_type="auto")
            B_2 = T.Buffer((B_1.strides[0] * m,), data=B_1.data, buffer_type="auto")
            C_2[i_j_fused // n * C_1.strides[0] + i_j_fused % n * C_1.strides[1]] = A_2[i_j_fused // n * A_1.strides[0] + i_j_fused % n * A_1.s

In [9]:
compute_diff(base_function, fuse_function)

  # from tvm.script import ir as I
  # from tvm.script import tir as T
  
  @I.ir_module
  class Module:
      @T.prim_func
      def main(A: T.handle, B: T.handle, C: T.handle):
          T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
          m, n = T.int32(), T.int32()
          A_1 = T.match_buffer(A, (m, n), strides=("stride", "stride"), buffer_type="auto")
          B_1 = T.match_buffer(B, (m, n), strides=("stride", "stride"), buffer_type="auto")
          C_1 = T.match_buffer(C, (m, n), strides=("stride", "stride"), buffer_type="auto")
-         for i, j in T.grid(m, n):
+         for i_j_fused in range(m * n):
              C_2 = T.Buffer((C_1.strides[0] * m,), data=C_1.data, buffer_type="auto")
              A_2 = T.Buffer((A_1.strides[0] * m,), data=A_1.data, buffer_type="auto")
              B_2 = T.Buffer((B_1.strides[0] * m,), data=B_1.data, buffer_type="auto")
-             C_2[i * C_1.strides[0] + j * C_1.strides[1]] = A_2[i * A_1.st

### Binding thread axis

Parallel threads are commonly used in GEMM and other linear algebra computations. `bind` assigns a specified axis to a thread axis, e.g. CUDA thread blocks (`blockIdx.x`) and the threads within a block (`threadIdx.x`).
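Conceptually, a bound axis stops being a loop: each (block, thread) pair executes the body once for its own index. The sketch below emulates this sequentially in plain Python; `launch` and `kernel` are hypothetical names for this illustration, and a real GPU would run the invocations in parallel:

```python
def launch(num_blocks, threads_per_block, kernel):
    """Hypothetical sequential stand-in for a GPU launch: one call per (block, thread)."""
    for block_idx in range(num_blocks):
        for thread_idx in range(threads_per_block):
            kernel(block_idx, thread_idx)


m, n = 100, 3
A = [[float(i + j) for j in range(n)] for i in range(m)]
B = [[2.0] * n for _ in range(m)]
C = [[0.0] * n for _ in range(m)]


def kernel(block_idx, thread_idx):
    # Each invocation owns exactly one row, derived from its block and thread index.
    row = block_idx * 64 + thread_idx
    for j in range(n):
        if row < m:  # threads past the end of the data do nothing
            C[row][j] = A[row][j] * B[row][j]


launch((m + 63) // 64, 64, kernel)
print(C[99])
```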

In [10]:
n = te.var("n")
m = te.var("m")

A = te.placeholder((m, n), name="A")
B = te.placeholder((m, n), name="B")
C = te.compute((m, n), lambda i, j: A[i, j] * B[i, j], name="C")

schedule = te.create_schedule([C.op])

co, ci = schedule[C].split(C.op.axis[0], factor=64)

schedule[C].bind(co, te.thread_axis("blockIdx.x"))
schedule[C].bind(ci, te.thread_axis("threadIdx.x"))

bind_function = str(tvm.lower(schedule, [A, B, C], simple_mode=True))

print(bind_function)

# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.handle, B: T.handle, C: T.handle):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
        m, n = T.int32(), T.int32()
        A_1 = T.match_buffer(A, (m, n), strides=("stride", "stride"), buffer_type="auto")
        B_1 = T.match_buffer(B, (m, n), strides=("stride", "stride"), buffer_type="auto")
        C_1 = T.match_buffer(C, (m, n), strides=("stride", "stride"), buffer_type="auto")
        blockIdx_x = T.launch_thread("blockIdx.x", (m + 63) // 64)
        threadIdx_x = T.launch_thread("threadIdx.x", 64)
        for j in range(n):
            if T.likely(blockIdx_x * 64 + threadIdx_x < m):
                C_2 = T.Buffer((C_1.strides[0] * m,), data=C_1.data, buffer_type="auto")
                A_2 = T.Buffer((A_1.strides[0] * m,), data=A_1.data, buffer_type="auto")
                B_2 = T.Buffer((B_1.strides[0] *

In [11]:
compute_diff(base_function, bind_function)

  # from tvm.script import ir as I
  # from tvm.script import tir as T
  
  @I.ir_module
  class Module:
      @T.prim_func
      def main(A: T.handle, B: T.handle, C: T.handle):
          T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
          m, n = T.int32(), T.int32()
          A_1 = T.match_buffer(A, (m, n), strides=("stride", "stride"), buffer_type="auto")
          B_1 = T.match_buffer(B, (m, n), strides=("stride", "stride"), buffer_type="auto")
          C_1 = T.match_buffer(C, (m, n), strides=("stride", "stride"), buffer_type="auto")
-         for i, j in T.grid(m, n):
+         blockIdx_x = T.launch_thread("blockIdx.x", (m + 63) // 64)
+         threadIdx_x = T.launch_thread("threadIdx.x", 64)
+         for j in range(n):
+             if T.likely(blockIdx_x * 64 + threadIdx_x < m):
-             C_2 = T.Buffer((C_1.strides[0] * m,), data=C_1.data, buffer_type="auto")
+                 C_2 = T.Buffer((C_1.strides[0] * m,), data=C_1.data, 

### Reordering computation of axes

https://tvm.apache.org/docs/reference/api/python/te.html#tvm.te.Stage.reorder

`reorder` changes the order in which the axes are iterated. Let's test it on the tiled example.
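A plain-Python sketch of what changing the loop order means. The `visit` helper and the small 4x4 domain are assumptions made for this illustration; the extents divide evenly, so no bounds guards are needed:

```python
from itertools import product


def visit(extents, loop_order):
    """Return the (i, j) points visited by a tiled loop nest.

    extents maps axis name -> trip count; loop_order lists axes, outermost first.
    """
    points = []
    # itertools.product varies the last range fastest, i.e. it is the innermost loop.
    for values in product(*(range(extents[axis]) for axis in loop_order)):
        idx = dict(zip(loop_order, values))
        i = idx["xo"] * extents["xi"] + idx["xi"]
        j = idx["yo"] * extents["yi"] + idx["yi"]
        points.append((i, j))
    return points


extents = {"xo": 2, "yo": 2, "xi": 2, "yi": 2}  # a 4x4 domain in 2x2 tiles
default_order = visit(extents, ["xo", "yo", "xi", "yi"])
reordered = visit(extents, ["yo", "yi", "xo", "xi"])

print(default_order[:4])
print(reordered[:4])
```

Both orders cover the same set of points; only the sequence, and therefore the memory access pattern, differs.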

In [12]:
n = te.var("n")
m = te.var("m")

A = te.placeholder((m, n), name="A")
B = te.placeholder((m, n), name="B")
C = te.compute((m, n), lambda i, j: A[i, j] * B[i, j], name="C")

schedule = te.create_schedule([C.op])

xo, yo, xi, yi = schedule[C].tile(C.op.axis[0], C.op.axis[1], x_factor=16, y_factor=8)

schedule[C].reorder(yo, yi, xo, xi)

reordered_function_1 = str(tvm.lower(schedule, [A, B, C], simple_mode=True))

print(reordered_function_1)

# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.handle, B: T.handle, C: T.handle):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
        m, n = T.int32(), T.int32()
        A_1 = T.match_buffer(A, (m, n), strides=("stride", "stride"), buffer_type="auto")
        B_1 = T.match_buffer(B, (m, n), strides=("stride", "stride"), buffer_type="auto")
        C_1 = T.match_buffer(C, (m, n), strides=("stride", "stride"), buffer_type="auto")
        for j_outer, j_inner in T.grid((n + 7) // 8, 8):
            if T.likely(j_outer * 8 + j_inner < n):
                for i_outer, i_inner in T.grid((m + 15) // 16, 16):
                    if T.likely(i_outer * 16 + i_inner < m):
                        cse_var_2: T.int32 = j_outer * 8 + j_inner
                        cse_var_1: T.int32 = i_outer * 16 + i_inner
                        C_2 = T.Buffer((C_1.strides[0] * m,), da

In [14]:
compute_diff(tile_function, reordered_function_1)

  # from tvm.script import ir as I
  # from tvm.script import tir as T
  
  @I.ir_module
  class Module:
      @T.prim_func
      def main(A: T.handle, B: T.handle, C: T.handle):
          T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
          m, n = T.int32(), T.int32()
          A_1 = T.match_buffer(A, (m, n), strides=("stride", "stride"), buffer_type="auto")
          B_1 = T.match_buffer(B, (m, n), strides=("stride", "stride"), buffer_type="auto")
          C_1 = T.match_buffer(C, (m, n), strides=("stride", "stride"), buffer_type="auto")
+         for j_outer, j_inner in T.grid((n + 7) // 8, 8):
+             if T.likely(j_outer * 8 + j_inner < n):
-         for i_outer, j_outer, i_inner in T.grid((m + 15) // 16, (n + 7) // 8, 16):
?                      ---------                                 --------------

+                 for i_outer, i_inner in T.grid((m + 15) // 16, 16):
? ++++++++

-             if T.likely(i_outer * 16 + i_inner < m

In [16]:
n = te.var("n")
m = te.var("m")

A = te.placeholder((m, n), name="A")
B = te.placeholder((m, n), name="B")
C = te.compute((m, n), lambda i, j: A[i, j] * B[i, j], name="C")

schedule = te.create_schedule([C.op])

xo, yo, xi, yi = schedule[C].tile(C.op.axis[0], C.op.axis[1], x_factor=16, y_factor=8)

schedule[C].reorder(xo, yo, xi, yi)

reordered_function_2 = str(tvm.lower(schedule, [A, B, C], simple_mode=True))

print(reordered_function_2)

# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.handle, B: T.handle, C: T.handle):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
        m, n = T.int32(), T.int32()
        A_1 = T.match_buffer(A, (m, n), strides=("stride", "stride"), buffer_type="auto")
        B_1 = T.match_buffer(B, (m, n), strides=("stride", "stride"), buffer_type="auto")
        C_1 = T.match_buffer(C, (m, n), strides=("stride", "stride"), buffer_type="auto")
        for i_outer, j_outer, i_inner in T.grid((m + 15) // 16, (n + 7) // 8, 16):
            if T.likely(i_outer * 16 + i_inner < m):
                for j_inner in range(8):
                    if T.likely(j_outer * 8 + j_inner < n):
                        cse_var_2: T.int32 = j_outer * 8 + j_inner
                        cse_var_1: T.int32 = i_outer * 16 + i_inner
                        C_2 = T.Buffer((C_1.strides[0] * m,), dat

In [17]:
compute_diff(reordered_function_1, reordered_function_2)

  # from tvm.script import ir as I
  # from tvm.script import tir as T
  
  @I.ir_module
  class Module:
      @T.prim_func
      def main(A: T.handle, B: T.handle, C: T.handle):
          T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
          m, n = T.int32(), T.int32()
          A_1 = T.match_buffer(A, (m, n), strides=("stride", "stride"), buffer_type="auto")
          B_1 = T.match_buffer(B, (m, n), strides=("stride", "stride"), buffer_type="auto")
          C_1 = T.match_buffer(C, (m, n), strides=("stride", "stride"), buffer_type="auto")
-         for j_outer, j_inner in T.grid((n + 7) // 8, 8):
-             if T.likely(j_outer * 8 + j_inner < n):
-                 for i_outer, i_inner in T.grid((m + 15) // 16, 16):
? --------

+         for i_outer, j_outer, i_inner in T.grid((m + 15) // 16, (n + 7) // 8, 16):
?                      +++++++++                                 ++++++++++++++

-                     if T.likely(i_outer * 16 + i_i

### Shifting computations

Let's define a schedule with multiple operations in it

In [18]:
m = te.var("m")

A = te.placeholder((m,), name="A")
B = te.compute((m,), lambda i: A[i] + 1, name="B")
C = te.compute((m,), lambda i: B[i] * 2, name="C")

schedule = te.create_schedule(C.op)
base_op_chain = str(tvm.lower(schedule, [A, B, C], simple_mode=True))
print(base_op_chain)

# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.handle, B: T.handle, C: T.handle):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
        m = T.int32()
        A_1 = T.match_buffer(A, (m,), strides=("stride",), buffer_type="auto")
        B_1 = T.match_buffer(B, (m,), strides=("stride",), buffer_type="auto")
        C_1 = T.match_buffer(C, (m,), strides=("stride",), buffer_type="auto")
        B_2 = T.Buffer((B_1.strides[0] * m,), data=B_1.data, buffer_type="auto")
        for i in range(m):
            A_2 = T.Buffer((A_1.strides[0] * m,), data=A_1.data, buffer_type="auto")
            B_2[i * B_1.strides[0]] = A_2[i * A_1.strides[0]] + T.float32(1)
        for i in range(m):
            C_2 = T.Buffer((C_1.strides[0] * m,), data=C_1.data, buffer_type="auto")
            C_2[i * C_1.strides[0]] = B_2[i * B_1.strides[0]] * T.float32(2)


Each computation is handled separately.
However, it is possible to move computations so they can share the same loop.
For this, we can use `compute_at`.

https://tvm.apache.org/docs/reference/api/python/te.html#tvm.te.Stage.compute_at
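The difference between the two loop structures can be sketched in plain Python. This is only an illustration of the loop shapes on a small concrete input, not TVM code:

```python
A = [1.0, 2.0, 3.0, 4.0]
m = len(A)

# Default schedule: two separate loops, B is fully computed before C starts.
B_sep = [0.0] * m
C_sep = [0.0] * m
for i in range(m):
    B_sep[i] = A[i] + 1
for i in range(m):
    C_sep[i] = B_sep[i] * 2

# compute_at-style schedule: B[i] is produced inside C's loop, right before it is used.
B_at = [0.0] * m
C_at = [0.0] * m
for i in range(m):
    B_at[i] = A[i] + 1
    C_at[i] = B_at[i] * 2

print(C_sep, C_at)
```

The results are identical; the second form reads each `B[i]` immediately after writing it, which tends to be friendlier to caches.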

In [19]:
m = te.var("m")

A = te.placeholder((m,), name="A")
B = te.compute((m,), lambda i: A[i] + 1, name="B")
C = te.compute((m,), lambda i: B[i] * 2, name="C")

schedule = te.create_schedule(C.op)

schedule[B].compute_at(schedule[C], C.op.axis[0])

computeshift_op_chain = str(tvm.lower(schedule, [A, B, C], simple_mode=True))
print(computeshift_op_chain)

# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.handle, B: T.handle, C: T.handle):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
        m = T.int32()
        A_1 = T.match_buffer(A, (m,), strides=("stride",), buffer_type="auto")
        B_1 = T.match_buffer(B, (m,), strides=("stride",), buffer_type="auto")
        C_1 = T.match_buffer(C, (m,), strides=("stride",), buffer_type="auto")
        for i in range(m):
            B_2 = T.Buffer((B_1.strides[0] * m,), data=B_1.data, buffer_type="auto")
            A_2 = T.Buffer((A_1.strides[0] * m,), data=A_1.data, buffer_type="auto")
            B_2[i * B_1.strides[0]] = A_2[i * A_1.strides[0]] + T.float32(1)
            C_2 = T.Buffer((C_1.strides[0] * m,), data=C_1.data, buffer_type="auto")
            C_2[i * C_1.strides[0]] = B_2[i * B_1.strides[0]] * T.float32(2)


In [20]:
compute_diff(base_op_chain, computeshift_op_chain)

  # from tvm.script import ir as I
  # from tvm.script import tir as T
  
  @I.ir_module
  class Module:
      @T.prim_func
      def main(A: T.handle, B: T.handle, C: T.handle):
          T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
          m = T.int32()
          A_1 = T.match_buffer(A, (m,), strides=("stride",), buffer_type="auto")
          B_1 = T.match_buffer(B, (m,), strides=("stride",), buffer_type="auto")
          C_1 = T.match_buffer(C, (m,), strides=("stride",), buffer_type="auto")
-         B_2 = T.Buffer((B_1.strides[0] * m,), data=B_1.data, buffer_type="auto")
          for i in range(m):
+             B_2 = T.Buffer((B_1.strides[0] * m,), data=B_1.data, buffer_type="auto")
              A_2 = T.Buffer((A_1.strides[0] * m,), data=A_1.data, buffer_type="auto")
              B_2[i * B_1.strides[0]] = A_2[i * A_1.strides[0]] + T.float32(1)
-         for i in range(m):
              C_2 = T.Buffer((C_1.strides[0] * m,), data=C_1.data,

## Axis reduction

In neural network models, a common operation is a reduction along a given axis using functions such as sum, min, or max (`te.sum`, `te.min`, `te.max`).

In TVM, the axes along which a reduction occurs are created with `tvm.te.reduce_axis` and stored in `tvm.te.Tensor.op.reduce_axis` (regular axes are stored in `tvm.te.Tensor.op.axis`).
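As a point of reference, a row-wise sum written as explicit loops in plain Python looks as follows. It is a sketch on a small concrete matrix; `i` plays the role of a regular axis and `k` the role of the reduce axis:

```python
A = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
n, m = len(A), len(A[0])

B = [0.0] * n
for i in range(n):  # regular axis: one output element per row
    B[i] = 0.0  # start from the identity element of the sum
    for k in range(m):  # reduce axis: folded into the single output B[i]
        B[i] = B[i] + A[i][k]

print(B)
```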

In [21]:
n = te.var("n")
m = te.var("m")

A = te.placeholder((n, m), name="A")

k = te.reduce_axis((0, m), "k")

B = te.compute((n,), lambda i: te.sum(A[i, k], axis=k), name="B")

schedule = te.create_schedule(B.op)
reduced_axis = str(tvm.lower(schedule, [A, B], simple_mode=True))

print(reduced_axis)

# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.handle, B: T.handle):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
        n, m = T.int32(), T.int32()
        A_1 = T.match_buffer(A, (n, m), strides=("stride", "stride"), buffer_type="auto")
        B_1 = T.match_buffer(B, (n,), strides=("stride",), buffer_type="auto")
        for i in range(n):
            B_2 = T.Buffer((B_1.strides[0] * n,), data=B_1.data, buffer_type="auto")
            B_2[i * B_1.strides[0]] = T.float32(0)
            for k in range(m):
                A_2 = T.Buffer((A_1.strides[0] * n,), data=A_1.data, buffer_type="auto")
                B_2[i * B_1.strides[0]] = B_2[i * B_1.strides[0]] + A_2[i * A_1.strides[0] + k * A_1.strides[1]]


It is also possible to perform `split` and `bind` on a reduce axis.
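A plain-Python sketch of a split reduce axis, assuming a simple sequential accumulation. `row_sums_split` is a hypothetical helper for this illustration; the inputs are whole-valued floats, so the result matches the unsplit sum exactly:

```python
def row_sums_split(A, factor):
    """Row-wise sum where the reduce axis k is split into k_outer and k_inner."""
    n, m = len(A), len(A[0])
    B = [0.0] * n
    for i in range(n):
        for k_outer in range((m + factor - 1) // factor):
            for k_inner in range(factor):
                k = k_outer * factor + k_inner
                if k < m:  # guard for the last, possibly partial, chunk
                    B[i] += A[i][k]
    return B


A = [[float(i * 10 + k) for k in range(10)] for i in range(3)]
split_result = row_sums_split(A, factor=4)
print(split_result)
```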

## Lowering of operations in TVM

The `tvm.te` module provides many of the typical functions occurring in linear algebra and neural networks.

See [`tvm.te` documentation](https://tvm.apache.org/docs/reference/api/python/te.html) for more details.

When building the model, those functions (so-called **unified intrinsic calls**) are replaced with target-specific functions and/or implementations.

### Sample implementation of operation

Let's create a schedule computing the sigmoid function and check its OpenCL implementation.

*Note: the `blockIdx.x` and `threadIdx.x` thread axes correspond in OpenCL to work-groups and to the individual threads within them - in the generated code they are accessed via `get_group_id` and `get_local_id`*

In [22]:
n = te.var("n")
A = te.placeholder((n,), name="A")
B = te.compute(A.shape, lambda i: te.sigmoid(A[i]), name="B")
schedule = te.create_schedule(B.op)
num_thread = 64
bx, tx = schedule[B].split(B.op.axis[0], factor=num_thread)
schedule[B].bind(bx, te.thread_axis("blockIdx.x"))
schedule[B].bind(tx, te.thread_axis("threadIdx.x"))

print(tvm.lower(schedule, [A, B], simple_mode=True))

# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.handle, B: T.handle):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
        n = T.int32()
        A_1 = T.match_buffer(A, (n,), strides=("stride",), buffer_type="auto")
        B_1 = T.match_buffer(B, (n,), strides=("stride",), buffer_type="auto")
        blockIdx_x = T.launch_thread("blockIdx.x", (n + 63) // 64)
        threadIdx_x = T.launch_thread("threadIdx.x", 64)
        if T.likely(blockIdx_x * 64 + threadIdx_x < n):
            B_2 = T.Buffer((B_1.strides[0] * n,), data=B_1.data, buffer_type="auto")
            A_2 = T.Buffer((A_1.strides[0] * n,), data=A_1.data, buffer_type="auto")
            B_2[(blockIdx_x * 64 + threadIdx_x) * B_1.strides[0]] = T.sigmoid(A_2[(blockIdx_x * 64 + threadIdx_x) * A_1.strides[0]])


As can be observed, sigmoid is represented here by the TIR function `tir.sigmoid` (printed as `T.sigmoid`). Let's see how it is implemented in OpenCL:

In [23]:
fopencl = tvm.build(schedule, [A, B], "opencl", name="mysigm")
print(fopencl.imported_modules[0].get_source())

// Function: mysigm_kernel
__kernel void mysigm_kernel(__global float* restrict A, __global float* restrict B, int n, int stride, int stride_1);
__kernel void mysigm_kernel(__global float* restrict A, __global float* restrict B, int n, int stride, int stride_1) {
  if ((convert_int(get_group_id(0))) < (n >> 6)) {
    B[((((convert_int(get_group_id(0))) * 64) + (convert_int(get_local_id(0)))) * stride)] = (1.000000e+00f / (1.000000e+00f + exp((0.000000e+00f - A[((((convert_int(get_group_id(0))) * 64) + (convert_int(get_local_id(0)))) * stride_1)]))));
  } else {
    if ((((convert_int(get_group_id(0))) * 64) + (convert_int(get_local_id(0)))) < n) {
      B[((((convert_int(get_group_id(0))) * 64) + (convert_int(get_local_id(0)))) * stride)] = (1.000000e+00f / (1.000000e+00f + exp((0.000000e+00f - A[((((convert_int(get_group_id(0))) * 64) + (convert_int(get_local_id(0)))) * stride_1)]))));
    }
  }
}
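The branch structure of the generated kernel can be emulated in plain Python. This is a sketch that assumes unit strides and runs the work-items sequentially; `sigmoid_kernel` is a hypothetical name for this illustration:

```python
import math


def sigmoid_kernel(A, num_groups, local_size=64):
    """Sequential emulation of the kernel's control flow, assuming unit strides."""
    n = len(A)
    B = [None] * n
    for group_id in range(num_groups):
        for local_id in range(local_size):
            index = group_id * local_size + local_id
            if group_id < (n >> 6):
                # Group lies fully inside the data: no per-element bounds check.
                B[index] = 1.0 / (1.0 + math.exp(0.0 - A[index]))
            elif index < n:
                # Last, partial group: every element is checked against n.
                B[index] = 1.0 / (1.0 + math.exp(0.0 - A[index]))
    return B


A = [0.1 * (i - 65) for i in range(130)]
B = sigmoid_kernel(A, num_groups=(len(A) + 63) // 64)
print(B[0], B[-1])
```

`n >> 6` equals `n // 64` for non-negative `n`, so the first branch covers the full groups and only the trailing group pays for the element-wise check.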




### Creating custom implementation of the operation

Adding a new operation/function from the Python level is relatively easy, as long as the necessary computation blocks are provided (lower-level implementations of kernels need to be handled in C++).

For a new operation/function, we need to create and register it, and provide a lowering rule that converts the operation to its implementation for each supported target.

*Note: the lowering rules demonstrated here can also be used to make existing operations use our custom implementation of a certain function - its selection can be controlled with the `level` parameter, which determines priority*
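The idea of per-target rules selected by priority can be sketched with a toy registry in plain Python. Everything below is hypothetical and is not TVM's implementation: the names `register_rule` and `lower_call` are invented for this illustration, the rules just produce strings, and the sketch assumes that the rule with the highest `level` wins:

```python
# Hypothetical registry: (intrinsic name, target) -> list of (level, rule).
_rules = {}


def register_rule(name, target, rule, level=10):
    """Record a lowering rule for an intrinsic on a given target."""
    _rules.setdefault((name, target), []).append((level, rule))


def lower_call(name, target, arg):
    """Lower a call using the highest-level rule; leave it unchanged if none exists."""
    candidates = _rules.get((name, target), [])
    if not candidates:
        return f"{name}({arg})"
    _, rule = max(candidates, key=lambda item: item[0])
    return rule(arg)


# Two illustrative rules for the same intrinsic and target, with different levels.
register_rule("tir.mylog", "opencl", lambda arg: f"log({arg})", level=99)
register_rule("tir.mylog", "opencl", lambda arg: f"native_log({arg})", level=10)

print(lower_call("tir.mylog", "opencl", "x"))
print(lower_call("tir.mylog", "cuda", "x"))
```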

Let's add our custom `log` implementation:

In [24]:
def mylog(x):
    """customized log intrinsic function"""
    return tvm.tir.call_intrin(x.dtype, "tir.mylog", x)


def opencl_mylog_rule(op):
    """OpenCL lowering rule for log"""
    if op.dtype == "float32":
        return tvm.tir.call_pure_extern("float32", "log", op.args[0])
    else:
        return op


tvm.ir.register_op_attr("tir.mylog", "TCallEffectKind", tvm.tir.CallEffectKind.Pure)
tvm.ir.register_intrin_lowering(
    "tir.mylog", target="opencl", f=opencl_mylog_rule, level=99
)

<function __main__.opencl_mylog_rule(op)>

In [25]:
n = te.var("n")
A = te.placeholder((n,), name="A")
B = te.compute(A.shape, lambda i: mylog(A[i]), name="B")
schedule = te.create_schedule(B.op)
num_thread = 64
bx, tx = schedule[B].split(B.op.axis[0], factor=num_thread)
schedule[B].bind(bx, te.thread_axis("blockIdx.x"))
schedule[B].bind(tx, te.thread_axis("threadIdx.x"))

In [26]:
fopencl = tvm.build(schedule, [A, B], "opencl", name="mykernel")
print(fopencl.imported_modules[0].get_source())

// Function: mykernel_kernel
__kernel void mykernel_kernel(__global float* restrict A, __global float* restrict B, int n, int stride, int stride_1);
__kernel void mykernel_kernel(__global float* restrict A, __global float* restrict B, int n, int stride, int stride_1) {
  if ((convert_int(get_group_id(0))) < (n >> 6)) {
    B[((((convert_int(get_group_id(0))) * 64) + (convert_int(get_local_id(0)))) * stride)] = log(A[((((convert_int(get_group_id(0))) * 64) + (convert_int(get_local_id(0)))) * stride_1)]);
  } else {
    if ((((convert_int(get_group_id(0))) * 64) + (convert_int(get_local_id(0)))) < n) {
      B[((((convert_int(get_group_id(0))) * 64) + (convert_int(get_local_id(0)))) * stride)] = log(A[((((convert_int(get_group_id(0))) * 64) + (convert_int(get_local_id(0)))) * stride_1)]);
    }
  }
}




## Analyzing model's code

For building whole models from frontends, we use `relay.build`

In [27]:
import onnx
import tvm.relay as relay

onnxmodel = onnx.load("../models/test-delegate-one-input.onnx")
mod, params = relay.frontend.from_onnx(onnxmodel, freeze_params=True, dtype="float32")

with tvm.transform.PassContext(opt_level=3):
    graph, lib, params = relay.build(mod["main"], target="c")

    print(lib.get_source())

One or more operators have not been tuned. Please tune your model for better performance. Use DEBUG logging level to see more details.


// tvm target: c -keys=cpu 
#define TVM_EXPORTS
#include "tvm/runtime/c_runtime_api.h"
#include "tvm/runtime/c_backend_api.h"
#include <math.h>
#include <stdbool.h>
#ifdef __cplusplus
extern "C"
#endif
TVM_DLL int32_t tvmgen_default_fused_nn_contrib_dense_pack_add_nn_relu(void* args, int32_t* arg_type_ids, int32_t num_args, void* out_ret_value, int32_t* out_ret_tcode, void* resource_handle);
#ifdef __cplusplus
extern "C"
#endif
TVM_DLL int32_t tvmgen_default_fused_nn_contrib_dense_pack_add_nn_relu_1(void* args, int32_t* arg_type_ids, int32_t num_args, void* out_ret_value, int32_t* out_ret_tcode, void* resource_handle);
#ifdef __cplusplus
extern "C"
#endif
TVM_DLL int32_t tvmgen_default_fused_nn_contrib_dense_pack_add_nn_relu_2(void* args, int32_t* arg_type_ids, int32_t num_args, void* out_ret_value, int32_t* out_ret_tcode, void* resource_handle);
#ifdef __cplusplus
extern "C"
#endif
TVM_DLL int32_t tvmgen_default_fused_nn_contrib_dense_pack_add_nn_relu_add(void* args, int32_t* arg_type



## Model compilation, evaluation and fine-tuning

The above aspects of TVM are covered in homework tasks.

## References

* [Schedule primitives in TVM](https://tvm.apache.org/docs/how_to/work_with_schedules/schedule_primitives.html#sphx-glr-how-to-work-with-schedules-schedule-primitives-py)
* [Reduction in TVM](https://tvm.apache.org/docs/how_to/work_with_schedules/reduction.html#sphx-glr-how-to-work-with-schedules-reduction-py)
* [TVM intrinsics and math functions](https://tvm.apache.org/docs/how_to/work_with_schedules/intrin_math.html#sphx-glr-how-to-work-with-schedules-intrin-math-py)

## Useful additional resources

* [TVM User how-to guides](https://tvm.apache.org/docs/how_to/index.html)