Calculating $B = A * W^T$, where 
- $A$ is non-transposed and $W$ is transposed (NT).
- $A$ and $W$ are both fp16, while $B$ is fp32.

Reference: https://zhuanlan.zhihu.com/p/410971069 

In [1]:
import tvm
from tvm import te
n, hd, wd = 8192, 256, 512

A = te.placeholder((n, hd), name="A", dtype="float16")
W = te.placeholder((wd, hd), name="W", dtype="float16")
hdr = te.reduce_axis((0, hd), "hdr")
B = te.compute((n, wd), lambda h, w: te.sum(A[h, hdr].astype("float32") * W[w, hdr].astype("float32"), axis=hdr), name="B")

s = te.create_schedule(B.op)
AS = s.cache_read(A, "shared", [B])
WS = s.cache_read(W, "shared", [B])
AF = s.cache_read(AS, "wmma.matrix_a", [B])
WF = s.cache_read(WS, "wmma.matrix_b", [B])
BF = s.cache_write(B, "wmma.accumulator")

block_x = te.thread_axis("blockIdx.x")
block_y = te.thread_axis("blockIdx.y")
thread_x = te.thread_axis("threadIdx.x")
thread_y = te.thread_axis("threadIdx.y")

In [2]:
h, w = s[B].op.axis
ho, hi = s[B].split(h, factor=256)
wo, wi = s[B].split(w, factor=128)
hio, hii = s[B].split(hi, 64)
wio, wii = s[B].split(wi, 64)
hiio, hiii = s[B].split(hii, 16)
wiio, wiii = s[B].split(wii, 16)
s[B].reorder(ho, wo, hio, wio, hiio, wiio, hiii, wiii)

In [3]:
hwio = s[B].fuse(hio, wio)
s[B].bind(ho, block_x)
s[B].bind(wo, block_y)
s[B].bind(hwio, thread_y)

In [4]:
s[BF].compute_at(s[B], hwio)

In [5]:
hbf, wbf = s[BF].op.axis
hbfo, hbfi = s[BF].split(hbf, 16)
wbfo, wbfi = s[BF].split(wbf, 16)

(rbf,) = s[BF].op.reduce_axis
rbfo, rbfi = s[BF].split(rbf, 16)
s[BF].reorder(rbfo, hbfo, wbfo, hbfi, wbfi, rbfi)

In [6]:
rbfoo, rbfoi = s[BF].split(rbfo, 4)

# s[BF].reorder(rbfoo, rbfoi, hbfo, wbfo, hbfi, wbfi, rbfi)

s[AF].compute_at(s[BF], rbfoi)
s[WF].compute_at(s[BF], rbfoi)
s[AS].compute_at(s[BF], rbfoo)
s[WS].compute_at(s[BF], rbfoo)

In [7]:
haf, waf = s[AF].op.axis
hafo, hafi = s[AF].split(haf, 16)

hwf, wwf = s[WF].op.axis
hwfo, hwfi = s[WF].split(hwf, 16)
s[WF].reorder(hwfo, wwf, hwfi)
# wwfo, wwfi = s[WF].split(wwf, 16)
# s[WF].reorder(wwfo, hwf, wwfi)

In [8]:
has, was = s[AS].op.axis
hwas = s[AS].fuse(has, was)
hwaso, hwasi = s[AS].split(hwas, 256)

hws, wws = s[WS].op.axis
hwws = s[WS].fuse(hws, wws)
hwwso, hwwsi = s[WS].split(hwws, 256)

In [9]:
hwasoo, hwasoi = s[AS].split(hwaso, 8)
hwasio, hwasii = s[AS].split(hwasi, 8)

hwwsoo, hwwsoi = s[WS].split(hwwso, 8)
hwwsio, hwwsii = s[WS].split(hwwsi, 8)

In [10]:
s[AS].bind(hwasoi, thread_y)
s[AS].bind(hwasio, thread_x)
s[AS].vectorize(hwasii)

s[WS].bind(hwwsoi, thread_y)
s[WS].bind(hwwsio, thread_x)
s[WS].vectorize(hwwsii)

In [11]:
def intrin_wmma_store_matrix():
    n = 16
    # Implement the compute pattern with te
    A = te.placeholder((n, n), name="A", dtype="float32")
    BA = tvm.tir.decl_buffer(
        A.shape, 
        A.dtype, 
        scope="wmma.accumulator", 
        data_alignment=32,    # in unit of "Byte" 
        offset_factor=1, 
        strides=[64, 1]
    )
    C = te.compute((n, n), lambda i, j: A[i, j], name="C")
    BC = tvm.tir.decl_buffer(
        C.shape, 
        C.dtype, 
        scope="global", 
        data_alignment=32, 
        offset_factor=1, 
        strides=[wd, 1]
    )

    def intrin_func(ins, outs):
        ib = tvm.tir.ir_builder.create()
        BA = ins[0]
        BC = outs[0]
        ib.emit(
            # source code: tvm/src/target/source/codegen_cuda.cc
            # translated in to cpp code:
            # nvcuda::wmma::store_matrix_sync(
            #     BC_ptr, 
            #     BA_data[BA_element_offset], 
            #     wd, 
            #     nvcuda::wmma::mem_row_major
            # )
            tvm.tir.call_intrin(
                "handle",
                "tir.tvm_store_matrix_sync",
                BA.data,    # op->args[0]
                n, n, n,    # op->args[1,2,3]
                BA.elem_offset//1024 * 4 + (BA.elem_offset//16) % 4,
                BC.access_ptr("w"),    # op->args[5]: w=write
                wd,         # op->args[6]
                "row_major",# op->args[7]
            )
        )
        return ib.get()
    return te.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC})

s[B].tensorize(hiii, intrin_wmma_store_matrix())

In [12]:
print(tvm.lower(s, [A, W, B], simple_mode=True))

primfn(A_1: handle, W_1: handle, B_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {B: Buffer(B_2: Pointer(float32), float32, [8192, 512], []),
             A: Buffer(A_2: Pointer(float16), float16, [8192, 256], []),
             W: Buffer(W_2: Pointer(float16), float16, [512, 256], [])}
  buffer_map = {A_1: A, W_1: W, B_1: B} {
  attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = 32;
  allocate(B.wmma.accumulator: Pointer(wmma.accumulator float32), float32, [4096]), storage_scope = wmma.accumulator;
  allocate(A.shared: Pointer(shared float16), float16, [16384]), storage_scope = shared;
  allocate(W.shared: Pointer(shared float16), float16, [8192]), storage_scope = shared;
  allocate(A.shared.wmma.matrix_a: Pointer(wmma.matrix_a float16), float16, [1024]), storage_scope = wmma.matrix_a;
  allocate(W.shared.wmma.matrix_b: Pointer(wmma.matrix_b float16), float16, [1024]), stora

In [13]:
def intrin_wmma_gemm():
    n = 16
    # Implement the compute pattern with te
    A = te.placeholder((n, n), name="A", dtype="float16")
    B = te.placeholder((n, n), name="B", dtype="float16")
    k = te.reduce_axis((0, n), name="k")
    C = te.compute(
        (n, n),
        lambda ii, jj: te.sum(A[ii, k].astype("float") * B[jj, k].astype("float"), axis=k), name="C"
    )
    BA = tvm.tir.decl_buffer(
        A.shape, A.dtype, name="BA", scope="wmma.matrix_a", data_alignment=32, offset_factor=1, strides=[16, 1]
    )
    BB = tvm.tir.decl_buffer(
        B.shape, B.dtype, name="BB", scope="wmma.matrix_b", data_alignment=32, offset_factor=1, strides=[16, 1]
    )
    BC = tvm.tir.decl_buffer(
        C.shape, C.dtype, name="BC", scope="wmma.accumulator", data_alignment=32, offset_factor=1, strides=[64, 1]
    )

    def intrin_func(ins, outs):
        BA, BB = ins
        (BC, ) = outs

        def init():
            ib = tvm.tir.ir_builder.create()
            ib.emit(
                tvm.tir.call_intrin(
                    "handle", "tir.tvm_fill_fragment", BC.data, n, n, n, 
                    BC.elem_offset//1024*4 + (BC.elem_offset//16) % 4,
                    0.0
                )
            )
            return ib.get()
        
        def update():
            ib = tvm.tir.ir_builder.create()
            ib.emit(
                tvm.tir.call_intrin(
                    "handle", "tir.tvm_mma_sync",
                    BC.data,
                    BC.elem_offset//1024 * 4 + (BC.elem_offset//16)%4,
                    BA.data,
                    BA.elem_offset//256,
                    BB.data,
                    BB.elem_offset//256,
                    BC.data,
                    BC.elem_offset//1024 * 4 + (BC.elem_offset//16)%4,
                )
            )
            return ib.get()
        
        return update(), init(), update()
    
    return te.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, B: BB, C: BC})

s[BF].tensorize(hbfi, intrin_wmma_gemm())

In [14]:
print(tvm.lower(s, [A, W, B], simple_mode=True))

primfn(A_1: handle, W_1: handle, B_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {B: Buffer(B_2: Pointer(float32), float32, [8192, 512], []),
             A: Buffer(A_2: Pointer(float16), float16, [8192, 256], []),
             W: Buffer(W_2: Pointer(float16), float16, [512, 256], [])}
  buffer_map = {A_1: A, W_1: W, B_1: B} {
  attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = 32;
  allocate(B.wmma.accumulator: Pointer(wmma.accumulator float32), float32, [4096]), storage_scope = wmma.accumulator;
  allocate(A.shared: Pointer(shared float16), float16, [16384]), storage_scope = shared;
  allocate(W.shared: Pointer(shared float16), float16, [8192]), storage_scope = shared;
  allocate(A.shared.wmma.matrix_a: Pointer(wmma.matrix_a float16), float16, [1024]), storage_scope = wmma.matrix_a;
  allocate(W.shared.wmma.matrix_b: Pointer(wmma.matrix_b float16), float16, [1024]), stora

In [15]:
def intrin_wmma_load_matrix_a(scope):
    n = 16
    A = te.placeholder((n, n), name="A", dtype="float16")
    BA = tvm.tir.decl_buffer(A.shape, A.dtype, scope="shared", data_alignment=32, offset_factor=1, strides=[64, 1])
    C = te.compute((n, n), lambda i, j: A[i, j], name="C")
    BC = tvm.tir.decl_buffer(C.shape, C.dtype, scope=scope, data_alignment=32, offset_factor=1, strides=[16, 1])

    def intrin_func(ins, outs):
        ib = tvm.tir.ir_builder.create()
        BA = ins[0]
        BC = outs[0]
        ib.emit(
            tvm.tir.call_intrin(
                "handle", "tir.tvm_load_matrix_sync",
                BC.data, n, n, n,
                BC.elem_offset // 256,
                BA.access_ptr("r"),
                64,
                "row_major"
            )
        )
        return ib.get()
    return te.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC})

def intrin_wmma_load_matrix_b(scope):
    n = 16
    A = te.placeholder((n, n), name="A", dtype="float16")
    BA = tvm.tir.decl_buffer(A.shape, A.dtype, scope="shared", data_alignment=32, offset_factor=1, strides=[64, 1])
    C = te.compute((n, n), lambda i, j: A[i, j], name="C")
    BC = tvm.tir.decl_buffer(C.shape, C.dtype, scope=scope, data_alignment=32, offset_factor=1, strides=[16, 1])

    def intrin_func(ins, outs):
        ib = tvm.tir.ir_builder.create()
        BA = ins[0]
        BC = outs[0]
        ib.emit(
            tvm.tir.call_intrin(
                "handle", "tir.tvm_load_matrix_sync",
                BC.data, n, n, n,
                BC.elem_offset // 256,
                BA.access_ptr("r"),
                64,
                "col_major"
            )
        )
        return ib.get()
    return te.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC})

s[AF].tensorize(hafi, intrin_wmma_load_matrix_a("wmma.matrix_a"))
s[WF].tensorize(wwf, intrin_wmma_load_matrix_b("wmma.matrix_b"))

In [16]:
print(tvm.lower(s, [A, W, B], simple_mode=True))

primfn(A_1: handle, W_1: handle, B_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {B: Buffer(B_2: Pointer(float32), float32, [8192, 512], []),
             A: Buffer(A_2: Pointer(float16), float16, [8192, 256], []),
             W: Buffer(W_2: Pointer(float16), float16, [512, 256], [])}
  buffer_map = {A_1: A, W_1: W, B_1: B} {
  attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = 32;
  allocate(B.wmma.accumulator: Pointer(wmma.accumulator float32), float32, [4096]), storage_scope = wmma.accumulator;
  allocate(A.shared: Pointer(shared float16), float16, [16384]), storage_scope = shared;
  allocate(W.shared: Pointer(shared float16), float16, [8192]), storage_scope = shared;
  allocate(A.shared.wmma.matrix_a: Pointer(wmma.matrix_a float16), float16, [1024]), storage_scope = wmma.matrix_a;
  allocate(W.shared.wmma.matrix_b: Pointer(wmma.matrix_b float16), float16, [1024]), stora

In [20]:
import numpy as np
import tvm.testing
dev = tvm.cuda(0)
func = tvm.build(s, [A, W, B], "cuda")

a_np = np.random.uniform(size=(n, hd)).astype(A.dtype)
w_np = np.random.uniform(size=(wd, hd)).astype(W.dtype)

a = tvm.nd.array(a_np, dev)
w = tvm.nd.array(w_np, dev)

b = tvm.nd.array(np.zeros((n, wd), dtype=B.dtype), dev)

func(a, w, b)

evaluator = func.time_evaluator(func.entry_name, dev, repeat=20, number=300)
print(evaluator(a, w, b))
print("Matmul: %f ms" % (evaluator(a, w, b).mean * 1e3))

tvm.testing.assert_allclose(b.numpy(), np.matmul(a.numpy(), np.transpose(w.numpy())), rtol = 0.001)


Execution time summary:
 mean (ms)   median (ms)    max (ms)     min (ms)     std (ms)  
   0.1358       0.1332       0.1447       0.1331       0.0045   
               
Matmul: 0.133138 ms
