Calculating $B = A * W$, where both $A$ and $W$ are non-transposed (NN).

Refer to the [Conv2d CUDA code](https://tvm.apache.org/docs/how_to/optimize_operators/opt_conv_cuda.html#sphx-glr-how-to-optimize-operators-opt-conv-cuda-py) to implement GEMM: GEMM is equivalent to Conv2d when $Ix=Iy=Kx=Ky=1$.

In [2]:
import torch
import torch.nn as nn
import math, time
import sys

torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

import tvm
import tvm.testing
from tvm import te
import numpy
import timeit

In [28]:
# Problem size: B is (m, n), A is (m, k), W is (k, n); all float32.
m = n = k = 8192
A = te.placeholder((m, k), dtype="float32", name="A")
W = te.placeholder((k, n), dtype="float32", name="W")
rk = te.reduce_axis((0, k), name="rk")


def _gemm_nn(i, j):
    """Return the reduction over rk of A[i, rk] * W[rk, j] for output (i, j)."""
    lhs = A[i, rk].astype("float32")
    rhs = W[rk, j].astype("float32")
    return te.sum(lhs * rhs, axis=[rk])


B = te.compute((m, n), _gemm_nn, name="B")

s = te.create_schedule(B.op)

# Operand staging, created in the same order as before:
# first the "shared" copies of A and W, then "local" copies of those.
AA, WW = (s.cache_read(src, "shared", [B]) for src in (A, W))
AL, WL = (s.cache_read(src, "local", [B]) for src in (AA, WW))

# The result is accumulated in a "local" stage before being written to B.
BL = s.cache_write(B, "local")

# Tiling parameters used by the schedule below.
tile, num_thread = 8, 8
block_factor = num_thread * tile  # extent of one block along each output axis
step, vthread = 8, 2

# Thread axes: grid indices, per-block thread indices, and virtual threads.
block_x, block_y = (
    te.thread_axis(tag) for tag in ("blockIdx.x", "blockIdx.y")
)
thread_x, thread_y = (
    te.thread_axis((0, num_thread), tag) for tag in ("threadIdx.x", "threadIdx.y")
)
thread_xz, thread_yz = (
    te.thread_axis((0, vthread), "vthread", name=label) for label in ("vx", "vy")
)

# Tile the output stage B over the CUDA grid.
# Each output axis is first cut into chunks of block_factor
# (= tile * num_thread); the outer parts are bound to the block indices.
mi, ni = s[B].op.axis
by, mi = s[B].split(mi, factor=block_factor)
bx, ni = s[B].split(ni, factor=block_factor)
s[B].bind(bx, block_x)
s[B].bind(by, block_y)

# Inside a block, split each axis into `vthread` outer parts and then into
# `num_thread` parts; whatever extent remains stays in mi / ni.
tyz, mi = s[B].split(mi, nparts=vthread)
txz, ni = s[B].split(ni, nparts=vthread)
ty, mi = s[B].split(mi, nparts=num_thread)
tx, ni = s[B].split(ni, nparts=num_thread)

# Alternative loop order, kept for reference:
# s[B].reorder(by, bx, tyz, txz, ty, tx, ni, mi)
s[B].reorder(bx, by, txz, tyz, tx, ty, mi, ni)

# Bind the virtual-thread and thread axes produced by the splits above.
s[B].bind(txz, thread_xz)
s[B].bind(tyz, thread_yz)
s[B].bind(tx, thread_x)
s[B].bind(ty, thread_y)
# Debug aid: dump the lowered IR before the cache stages are scheduled.
print(tvm.lower(s, [A, W, B], simple_mode=True))

# Attach the local accumulator stage BL under B's `ty` axis, which is the
# innermost thread-bound axis in the reorder applied to B.
# Alternative attachment point, kept for reference:
# s[BL].compute_at(s[B], tx)
s[BL].compute_at(s[B], ty)
mi, ni = s[BL].op.axis
# NOTE: this rebinds the module-level name `rk` to BL's reduce axis.
rk, = s[BL].op.reduce_axis

# Split the reduction into chunks of `step`. The "shared" copies of A and W
# are attached at the outer reduction loop (rko) and the "local" copies at
# the inner one (rki).
rko, rki = s[BL].split(rk, factor=step)
s[BL].reorder(rko, rki, mi, ni)
s[AA].compute_at(s[BL], rko)
s[WW].compute_at(s[BL], rko)
s[AL].compute_at(s[BL], rki)
s[WL].compute_at(s[BL], rki)

# Earlier variant without the reduction split, kept for reference:
# s[BL].reorder(rk, mi, ni)
# s[AA].compute_at(s[BL], rk)
# s[WW].compute_at(s[BL], rk)
# s[AL].compute_at(s[BL], rk)
# s[WL].compute_at(s[BL], rk)

# Schedule the "shared" copy of A: both of its axes are split into
# `num_thread` parts that are bound to threadIdx.x / threadIdx.y, and the
# remaining `mi` extent is split by 4 so the inner part can be vectorized.
# NOTE(review): `mi` is the first (row) axis of A, so the vectorized axis is
# presumably not the contiguous one of a row-major A -- TODO confirm intent.
mi, ki = s[AA].op.axis
tx, mi = s[AA].split(mi, nparts=num_thread)
ty, ki = s[AA].split(ki, nparts=num_thread)
_, mi = s[AA].split(mi, factor=4)
s[AA].reorder(tx, ty, ki, mi)
s[AA].bind(tx, thread_x)
s[AA].bind(ty, thread_y)
s[AA].vectorize(mi)

# Same pattern for the "shared" copy of W.
# NOTE(review): W is declared with shape (k, n), so the first axis unpacked
# here is presumably the k axis even though it is named `ni` (and `ki` is
# presumably the n axis) -- TODO confirm and rename if so.
ni, ki = s[WW].op.axis
tx, ni = s[WW].split(ni, nparts=num_thread)
ty, ki = s[WW].split(ki, nparts=num_thread)
_, ni = s[WW].split(ni, factor=4)
s[WW].reorder(tx, ty, ki, ni)
s[WW].bind(tx, thread_x)
s[WW].bind(ty, thread_y)
s[WW].vectorize(ni)
# Debug aid: dump the lowered IR of the complete schedule.
print(tvm.lower(s, [A, W, B], simple_mode=True))


import numpy as np

# Build the scheduled GEMM for CUDA and run it on GPU 0.
func = tvm.build(s, [A, W, B], "cuda")
dev = tvm.cuda(0)

# Random operands drawn from [0, 1); the output buffer starts zeroed.
a_np = np.random.uniform(size=(m, k)).astype(A.dtype)
w_np = np.random.uniform(size=(k, n)).astype(W.dtype)
b_np = np.zeros((m, n)).astype(B.dtype)

a = tvm.nd.array(a_np, dev)
w = tvm.nd.array(w_np, dev)
b = tvm.nd.array(b_np, dev)

# Untimed call before measuring.
func(a, w, b)

# Time the kernel once and reuse that single measurement for both reported
# figures, so the latency and the throughput describe the same run and the
# large GEMM is not benchmarked a second time just for printing.
evaluator = func.time_evaluator(func.entry_name, dev, number=1)
latency = evaluator(a, w, b).mean  # seconds
tput = (m * n * k) * 2 / latency  # one multiply and one add per (i, j, rk)
print("MatMul: %f ms, %.2f TFLOPS" % (latency * 1e3, tput / 1e12))

# Check against NumPy using the host copies of the inputs; this avoids
# copying both operands back from the device.
tvm.testing.assert_allclose(b.numpy(), np.matmul(a_np, w_np), rtol=1e-2)

primfn(A_1: handle, W_1: handle, B_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {A: Buffer(A_2: Pointer(float32), float32, [8192, 8192], []),
             B: Buffer(B_2: Pointer(float32), float32, [8192, 8192], []),
             W: Buffer(W_2: Pointer(float32), float32, [8192, 8192], [])}
  buffer_map = {A_1: A, W_1: W, B_1: B} {
  allocate(A.shared: Pointer(shared float32), float32, [32768]), storage_scope = shared;
  allocate(A.shared.local: Pointer(local float32), float32, [32768]), storage_scope = local;
  allocate(W.shared.local: Pointer(local float32), float32, [32768]), storage_scope = local;
  allocate(B.local: Pointer(local float32), float32, [16]), storage_scope = local {
    for (ax0: int32, 0, 4) {
      for (ax1: int32, 0, 8192) {
        A.shared[((ax0*8192) + ax1)] = (float32*)A_2[(((((blockIdx.y: int32*524288) + (vy: int32*262144)) + (threadIdx.y: int32*32768)) + (ax0*8192)) + ax1)]
      }
    }
 