In [1]:
import IPython
import numpy as np
import tvm
from tvm.ir.module import IRModule
from tvm.script import tir as T

In [2]:
def lnumpy_mm_relu_v2(A: np.ndarray, B: np.ndarray, C: np.ndarray):
    """Low-level NumPy reference for a batched matmul followed by ReLU.

    For every batch index ``n`` this computes ``C[n] = max(A[n] @ B[n], 0)``
    with explicit scalar loops, mirroring the loop/block structure of the
    TensorIR ``bmm_relu`` function.

    Args:
        A: Left operand of shape ``(batch, rows, inner)``.
        B: Right operand of shape ``(batch, inner, cols)``.
        C: Output array of shape ``(batch, rows, cols)``; overwritten in place.

    Returns:
        None. The result is written into ``C``.
    """
    # Derive the loop extents from the operands instead of hard-coding
    # 16/128 so the same reference also works on small test inputs.
    batch, rows, inner = A.shape
    cols = B.shape[2]
    Y = np.empty((batch, rows, cols), dtype="float32")
    for n in range(batch):
        for i in range(rows):
            for j in range(cols):
                for k in range(inner):
                    # Zero the accumulator on the first reduction step
                    # (the NumPy counterpart of T.init()). The batch index
                    # is required: Y[i, j] would address a whole row of a
                    # different batch and go out of bounds once i >= batch.
                    if k == 0:
                        Y[n, i, j] = 0
                    Y[n, i, j] = Y[n, i, j] + A[n, i, k] * B[n, k, j]
    for n in range(batch):
        for i in range(rows):
            for j in range(cols):
                # ReLU must be applied per batch element as well.
                C[n, i, j] = max(Y[n, i, j], 0)

In [3]:
# TensorIR module holding the untransformed batched-matmul + ReLU program.
# `bmm_relu` computes C[b] = max(A[b] @ B[b], 0) for 16 batches of 128x128
# float32 matrices, using the intermediate buffer Y for the matmul result.
@tvm.script.ir_module
class MyBmmRule:
  @T.prim_func
  def bmm_relu(A: T.Buffer[(16, 128, 128), "float32"],
               B: T.Buffer[(16, 128, 128), "float32"],
               C: T.Buffer[(16, 128, 128), "float32"]):
    T.func_attr({"global_symbol": "bmm_relu", "tir.noalias": True})
    # Intermediate buffer for the matmul result before ReLU is applied.
    Y = T.alloc_buffer((16, 128, 128), dtype="float32")
    # Block "Y": batched matmul. b/i/j are spatial axes, k is the reduction axis.
    for b, i, j, k in T.grid(16, 128, 128, 128):
        with T.block("Y"):
            vb = T.axis.spatial(16, b)
            vi = T.axis.spatial(128, i)
            vj = T.axis.spatial(128, j)
            vk = T.axis.reduce(128, k)
            # Initialization of the accumulator for the reduction over vk.
            with T.init():
                Y[vb, vi, vj] = T.float32(0)
            Y[vb, vi, vj] = Y[vb, vi, vj] + A[vb, vi, vk] * B[vb, vk, vj]
    # Block "C": element-wise ReLU of Y written into the output buffer.
    for b, i, j in T.grid(16, 128, 128):
        with T.block("C"):
            vb = T.axis.spatial(16, b)
            vi = T.axis.spatial(128, i)
            vj = T.axis.spatial(128, j)
            C[vb, vi, vj] = T.max(Y[vb, vi, vj], T.float32(0))

# Wrap the untransformed module in a schedule and render its TVMScript.
# Exercise reminder: the result should also be validated numerically.
sch = tvm.tir.Schedule(MyBmmRule)
mod_script = sch.mod.script()
IPython.display.Code(mod_script, language="python")

In [4]:
# Reference TensorIR module: the loop structure the scheduled MyBmmRule is
# compared against with tvm.ir.assert_structural_equal.
@tvm.script.ir_module
class TargetModule:
    @T.prim_func
    def bmm_relu(A: T.Buffer[(16, 128, 128), "float32"], B: T.Buffer[(16, 128, 128), "float32"], C: T.Buffer[(16, 128, 128), "float32"]) -> None:
        T.func_attr({"global_symbol": "bmm_relu", "tir.noalias": True})
        Y = T.alloc_buffer([16, 128, 128], dtype="float32")
        # Batch loop is parallel; the j axis is split into 16 tiles of 8 (i2_0 * 8 + lane).
        for i0 in T.parallel(16):
            for i1, i2_0 in T.grid(128, 16):
                # Vectorized zero-initialization of one 8-wide tile of Y.
                for ax0_init in T.vectorized(8):
                    with T.block("Y_init"):
                        n, i = T.axis.remap("SS", [i0, i1])
                        j = T.axis.spatial(128, i2_0 * 8 + ax0_init)
                        Y[n, i, j] = T.float32(0)
                # Reduction axis split 32 x 4 with the inner part unrolled;
                # the 8-wide j tile is the innermost loop of the update.
                for ax1_0 in T.serial(32):
                    for ax1_1 in T.unroll(4):
                        for ax0 in T.serial(8):
                            with T.block("Y_update"):
                                n, i = T.axis.remap("SS", [i0, i1])
                                j = T.axis.spatial(128, i2_0 * 8 + ax0)
                                k = T.axis.reduce(128, ax1_0 * 4 + ax1_1)
                                Y[n, i, j] = Y[n, i, j] + A[n, i, k] * B[n, k, j]
                # Vectorized ReLU over the same tile, placed right after its matmul.
                for i2_1 in T.vectorized(8):
                    with T.block("C"):
                        n, i = T.axis.remap("SS", [i0, i1])
                        j = T.axis.spatial(128, i2_0 * 8 + i2_1)
                        C[n, i, j] = T.max(Y[n, i, j], T.float32(0))

In [5]:
sch = tvm.tir.Schedule(MyBmmRule)
# Transform MyBmmRule step by step towards TargetModule.
# Tip: `IPython.display.Code(sch.mod.script(), language="python")` or
# `print(sch.mod.script())` shows the current program at any point.

# Step 1. Look up the matmul (reduction) block.
block_Y = sch.get_block("Y", func_name="bmm_relu")

# Step 2. Fetch its loops and run the batch loop in parallel.
loop_b, loop_i, loop_j, loop_k = sch.get_loops(block_Y)
sch.parallel(loop_b)

# Step 3. Tile j by 8, move the reduction loop between the two j loops,
# and place the ReLU block under the outer j loop.
loop_j_outer, loop_j_inner = sch.split(loop_j, factors=[None, 8])
sch.reorder(loop_j_outer, loop_k, loop_j_inner)
block_C = sch.get_block("C", func_name="bmm_relu")
sch.reverse_compute_at(block_C, loop_j_outer)

# Step 4. Separate the reduction's init from its update at the k loop,
# then split the update's reduction loop into 32 x 4.
block_Y_init = sch.decompose_reduction(block_Y, loop_k)
_, _, _, init_lane = sch.get_loops(block_Y_init)
_, _, _, relu_lane = sch.get_loops(block_C)
block_Y_update = sch.get_block("Y_update", func_name="bmm_relu")
_, _, _, loop_reduce, _ = sch.get_loops(block_Y_update)
_, loop_reduce_inner = sch.split(loop_reduce, factors=[32, 4])

# Step 5. Vectorize the 8-wide init and ReLU loops; unroll the inner reduction loop.
sch.vectorize(init_lane)
sch.vectorize(relu_lane)
sch.unroll(loop_reduce_inner)

IPython.display.Code(sch.mod.script(), language="python")

In [6]:
# Assert that the scheduled module is structurally equal to TargetModule;
# "Pass" is printed only when the assertion succeeds.
tvm.ir.assert_structural_equal(sch.mod, TargetModule)
print("Pass")

Pass


In [7]:
# Compile the untransformed and the scheduled modules with the LLVM target.
before_rt_lib = tvm.build(MyBmmRule, target="llvm")
after_rt_lib = tvm.build(sch.mod, target="llvm")

# Keep NumPy copies of the inputs so the kernels can be checked against a
# NumPy reference; a fixed seed makes the run reproducible.
rng = np.random.default_rng(0)
a_np = rng.random((16, 128, 128), dtype="float32")
b_np = rng.random((16, 128, 128), dtype="float32")
c_expected = np.maximum(a_np @ b_np, 0)

a_tvm = tvm.nd.array(a_np)
b_tvm = tvm.nd.array(b_np)
# Outputs start with random contents and use separate buffers per kernel,
# so a kernel that skips elements cannot pass the check by accident.
c_before_tvm = tvm.nd.array(rng.random((16, 128, 128), dtype="float32"))
c_tvm = tvm.nd.array(rng.random((16, 128, 128), dtype="float32"))

# Validate both kernels numerically before timing them.
before_rt_lib["bmm_relu"](a_tvm, b_tvm, c_before_tvm)
np.testing.assert_allclose(c_before_tvm.numpy(), c_expected, rtol=1e-4)
after_rt_lib["bmm_relu"](a_tvm, b_tvm, c_tvm)
np.testing.assert_allclose(c_tvm.numpy(), c_expected, rtol=1e-4)

before_timer = before_rt_lib.time_evaluator("bmm_relu", tvm.cpu())
print("Before transformation:")
print(before_timer(a_tvm, b_tvm, c_tvm))

f_timer = after_rt_lib.time_evaluator("bmm_relu", tvm.cpu())
print("After transformation:")
print(f_timer(a_tvm, b_tvm, c_tvm))

Before transformation:
Execution time summary:
 mean (ms)   median (ms)    max (ms)     min (ms)     std (ms)  
  33.1833      33.1833      33.1833      33.1833       0.0000   
               
After transformation:
Execution time summary:
 mean (ms)   median (ms)    max (ms)     min (ms)     std (ms)  
   2.3514       2.3514       2.3514       2.3514       0.0000   
               
