In [1]:
import tvm
from tvm import te

In [2]:
# Problem size: Data is (M, K) and the output is (M, N).
M, N, K = 100, 100, 64
# Weight block edge lengths along N (NI) and along K (KI).
NI, KI = 2, 2
# The blocked layout below only covers N and K fully when the block sizes divide
# them; floor division would otherwise silently drop the remainder.
assert N % NI == 0 and K % KI == 0, "block sizes must divide N and K"
# Number of weight blocks along N and along K.
NO, KO = N // NI, K // KI
# Fraction of the NO * KO weight blocks that are stored in the sparse variant.
BLOCK_DENSITY = 0.6
# Number of stored blocks (truncated towards zero).
NNZ = int(NO * KO * BLOCK_DENSITY)

In [3]:
# Dense operands: activations (M, K) and a weight pre-split into
# (NO, KO) blocks of shape (NI, KI).
LHS = te.placeholder((M, K), name='Data')
RHS = te.placeholder((NO, KO, NI, KI), name='Weight')

# Reduction over the block index along K; shared with dense_kernel below.
ko = te.reduce_axis((0, KO), name='ko')


def dense_kernel(m, no, ni, ki):
    """Partial product for output row ``m``, block-row ``no``, lane ``ni``.

    Sums over the K-blocks ``ko`` while keeping the within-block offset ``ki``
    as an output axis; the remaining ``ki`` reduction is done by C2.
    The parameter names double as the axis names of the resulting tensor.
    """
    data_elem = LHS[m, ko * KI + ki]
    weight_elem = RHS[no, ko, ni, ki]
    return te.sum(data_elem * weight_elem, axis=ko)


# Stage 1: per-(ni, ki) partial sums over the K-blocks.
C1 = te.compute((M, NO, NI, KI), dense_kernel, name='C1')

# Stage 2: collapse the within-block K offset and flatten (no, ni) back to n.
ki = te.reduce_axis((0, KI), name='ki')
C2 = te.compute(
    (M, N),
    lambda m, n: te.sum(C1[m, n // NI, n % NI, ki], axis=ki),
    name='C2',
)

In [4]:
# Schedule for the dense blocked matmul. Axis handles use stage-prefixed names
# (c2_*, c1_*) so they do not rebind the module-level `ko` / `ki` reduce axes
# that the compute definitions of C1 and C2 refer to.
s = te.create_schedule(C2.op)

# --- Output stage C2 ---
c2_m, c2_n = s[C2].op.axis
# Tile rows by 10 and columns by NI; tile() returns
# (row_outer, col_outer, row_inner, col_inner).
c2_mo, c2_no, c2_mi, c2_ni = s[C2].tile(c2_m, c2_n, 10, NI)
c2_ki, = s[C2].op.reduce_axis
s[C2].unroll(c2_ki)
s[C2].vectorize(c2_ni)
# Left disabled in the original experiment: s[C2].unroll(c2_mi)

# --- Partial-sum stage C1, computed inside each column tile of C2 ---
s[C1].compute_at(s[C2], c2_no)
c1_m, c1_no, c1_ni, c1_ki = s[C1].op.axis
c1_ko, = s[C1].op.reduce_axis
# Move the K-block reduction outside the (ni, ki) block axes.
s[C1].reorder(c1_m, c1_no, c1_ko, c1_ni, c1_ki)
s[C1].unroll(c1_ki)
s[C1].vectorize(c1_ni)
# Left disabled in the original experiment: s[C1].unroll(c1_m)

# Show the lowered IR, then check that the schedule actually builds.
print(tvm.lower(s, [LHS, RHS, C2], simple_mode=True))
func = tvm.build(s, [LHS, RHS, C2])

primfn(Data_1: handle, Weight_1: handle, C2_1: handle) -> ()
  attr = {"global_symbol": "main", "tir.noalias": True}
  buffers = {Data: Buffer(Data_2: Pointer(float32), float32, [100, 64], []),
             C2: Buffer(C2_2: Pointer(float32), float32, [100, 100], []),
             Weight: Buffer(Weight_2: Pointer(float32), float32, [50, 32, 2, 2], [])}
  buffer_map = {Data_1: Data, Weight_1: Weight, C2_1: C2} {
  attr [C1: Pointer(float32x2)] "storage_scope" = "global";
  allocate(C1, float32x2, [20]);
  for (m.outer: int32, 0, 10) {
    for (n.outer: int32, 0, 50) {
      for (m: int32, 0, 10) {
        C1[ramp((m*4), 2, 2)] = broadcast(0f32, 2)
        C1[ramp(((m*4) + 1), 2, 2)] = broadcast(0f32, 2)
        for (ko: int32, 0, 32) {
          C1[ramp((m*4), 2, 2)] = ((float32x2*)C1[ramp((m*4), 2, 2)] + (broadcast((float32*)Data_2[(((m.outer*640) + (m*64)) + (ko*2))], 2)*(float32x2*)Weight_2[ramp(((n.outer*128) + (ko*4)), 2, 2)]))
          C1[ramp(((m*4) + 1), 2, 2)] = ((float32x2*)C1

In [5]:
# Block-sparse weight: only NNZ blocks of shape (NI, KI) are stored.
RHS2 = te.placeholder((NNZ, NI, KI), name='Weight2')
# Indptr[no] .. Indptr[no + 1] delimits the stored blocks of block-row `no`.
Indptr = te.placeholder((NO+1,), name='Indptr', dtype='int')
# Indices[j] is multiplied by KI to locate stored block j along K.
Indices = te.placeholder((NNZ,), name='Indices', dtype='int')


def bsr_kernel(m, no, ni, ki):
    """Partial product over the stored blocks of block-row ``no``.

    The reduction extent is data dependent: it is the number of stored blocks
    in the row, read from ``Indptr``. As in the dense variant, ``ki`` stays an
    output axis and is reduced by the following stage.
    """
    row_start = Indptr[no]
    row_end = Indptr[no + 1]
    k = te.reduce_axis((0, row_end - row_start), name='k')
    block = k + row_start  # position of the current block in RHS2 / Indices
    data_elem = LHS[m, Indices[block] * KI + ki]
    return te.sum(data_elem * RHS2[block, ni, ki], axis=k)


# Stage 1: per-(ni, ki) partial sums over the stored blocks of each block-row.
C3 = te.compute((M, NO, NI, KI), bsr_kernel, name='C3')

# Stage 2: collapse the within-block K offset and flatten (no, ni) back to n.
ki = te.reduce_axis((0, KI), name='ki')
C4 = te.compute(
    (M, N),
    lambda m, n: te.sum(C3[m, n // NI, n % NI, ki], axis=ki),
    name='C4',
)

In [6]:
# Default (unscheduled) build of the block-sparse computation.
s = te.create_schedule(C4.op)
bsr_args = [LHS, RHS2, Indptr, Indices, C4]
# To inspect the lowered IR: print(tvm.lower(s, bsr_args, simple_mode=True))
func = tvm.build(s, bsr_args)

In [7]:
# Small experiment: attach producer B under consumer C's `y` axis, then fuse
# B's first two axes, and look at the lowered loop nest.
A = te.placeholder((10, 10), name='A')
B = te.compute((10, 10, 10), lambda x, y, k: A[x, y] * k, name='B')
C = te.compute((10, 10, 10), lambda x, y, k: B[x, y, k] + 1, name='C')

s = te.create_schedule(C.op)

# First two axes (x, y) of the consumer stage.
c_x, c_y = s[C].op.axis[:2]
s[B].compute_at(s[C], c_y)

# Fuse the producer's x and y axes; the fused axis is not used further.
b_x, b_y = s[B].op.axis[:2]
s[B].fuse(b_x, b_y)

print(tvm.lower(s, [A, B, C], simple_mode=True))

primfn(A_1: handle, B_1: handle, C_1: handle) -> ()
  attr = {"global_symbol": "main", "tir.noalias": True}
  buffers = {C: Buffer(C_2: Pointer(float32), float32, [10, 10, 10], []),
             A: Buffer(A_2: Pointer(float32), float32, [10, 10], []),
             B: Buffer(B_2: Pointer(float32), float32, [10, 10, 10], [])}
  buffer_map = {A_1: A, B_1: B, C_1: C} {
  for (x: int32, 0, 10) {
    for (y: int32, 0, 10) {
      for (k: int32, 0, 10) {
        B_2[(((x*100) + (y*10)) + k)] = ((float32*)A_2[((x*10) + y)]*cast(float32, k))
      }
      for (k_1: int32, 0, 10) {
        C_2[(((x*100) + (y*10)) + k_1)] = ((float32*)B_2[(((x*100) + (y*10)) + k_1)] + 1f32)
      }
    }
  }
}


