In [2]:
%pylab notebook
%autocall
import scipy.sparse

Populating the interactive namespace from numpy and matplotlib
Automatic calling is: Smart


In [3]:
import tvm
from tvm import te, tir
# llc --version
target = "llvm -mcpu=znver2"  # cascadelake
ctx = tvm.context(target, 0)
vec = 8


def idxsplit(idx, dim, *dim2):
    if dim2:
        idx, *lower = idxsplit(idx, *dim2)
    else:
        lower = []
    return (idx // dim, idx % dim, *lower)


class TVMRunner:
    def __init__(self, name, params):
        self.name = name
        self.params = params
    
    def lower(self):
        return tvm.lower(*self.params, simple_mode=True)
    
    def _wrap_args(self, args):
        def _wrap_single(item, item2):
            if isinstance(item, np.ndarray):
                return tvm.nd.array(item, ctx)
            elif isinstance(item, type):
                shape = [it.value for it in item2.shape]
                return tvm.nd.array(np.zeros(shape, dtype=item), ctx)
            else:
                raise Exception('unknown arg', item)

        _, args2 = self.params
        realargs = [_wrap_single(it, it2) for it, it2 in zip(args, args2)]
        return realargs
    
    def __call__(self, *args):
        func = tvm.build(*self.params, target=target, name=self.name)
        realargs = self._wrap_args(args)
        func(*realargs)
        return realargs
    
    def time_eval(self, *args, number=10):
        func = tvm.build(*self.params, target=target, name=self.name)
        realargs = self._wrap_args(args)
        evaluator = func.time_evaluator(func.entry_name, ctx, number=number)
        print(evaluator(*realargs).mean)
        return realargs

In [4]:
weight_oihw = np.random.rand(64, 64, 3, 3).astype('float32')
weight_ohwi = np.moveaxis(weight_oihw, 1, -1)
weight_ohwi_flat = weight_ohwi.reshape((weight_ohwi.shape[0], -1))
nchw_data = np.random.randint(0, 256, (10, 64, 256, 256)).astype('float32')
nhwc_data = np.moveaxis(nchw_data, 1, -1)

del weight_oihw
del weight_ohwi
del nchw_data

In [5]:
def make_bsr_sparse(dense, sprate, blocksize):
    bsrdata = scipy.sparse.bsr_matrix(dense, blocksize=blocksize)
    # find partition value
    summed = bsrdata.data.sum((1, 2))
    idx = int(sprate * len(summed) + 0.5)
    val = np.partition(summed, idx)[idx]
    # filter the data
    data, indices, indptr, bsrWid = [], [], [], bsrdata.indptr[1]
    for idx, (block, indval) in enumerate(zip(bsrdata.data, bsrdata.indices)):
        if idx % bsrWid == 0:
            indptr.append(len(data))
        if block.sum() >= val:
            data.append(block)
            indices.append(indval)
    indptr.append(len(data))
    # convert format
    bsrdata2 = tuple([np.array(i) for i in [data, indices, indptr]])
    return scipy.sparse.bsr_matrix(bsrdata2, shape=dense.shape)


def unpack_bsr(bsrdata):
    return bsrdata.data, bsrdata.indices, bsrdata.indptr

In [6]:
def create_nhwc_im2col(data):
    N, H, W, C = inshape = data.shape
    A = te.placeholder(inshape, name='A')

    def im2col_kernel(row, col):
        jn, jh, jw = idxsplit(row, H, W)
        kh, kw, jc = idxsplit(col, 3, C)
        ih, iw = jh + kh - 1, jw + kw - 1
        return tir.if_then_else(
            tir.all(0 <= ih, ih < H, 0 <= iw, iw < W),
            A[jn, ih, iw, jc], 0)

    outshape = (N*H*W, 9*C)
    B = te.compute(outshape, im2col_kernel, name='B')

    def im2col_schedule(CC):
        s = te.create_schedule(CC.op)
        _, coldim = s[CC].op.axis
        _, chandim = s[CC].split(coldim, factor=C)
        s[CC].vectorize(chandim)
        return s
    
    s = im2col_schedule(B)
    return s, [A, B]


tr = TVMRunner('im2col', create_nhwc_im2col(nhwc_data))
_, ret = tr.time_eval(nhwc_data, np.float32)
nhwkkc_data = ret.asnumpy()
kkcnhw_data = nhwkkc_data.T
del ret

0.18924946


# Dense Trans GEMM

In [7]:
#%timeit nhwkkc_data @ weight_ohwi_flat.T
#%timeit np.tensordot(nhwkkc_data, weight_ohwi_flat, [[1], [1]])

In [73]:
def create_dense_trans_gemm(data, weight):
    M, K, N, _ = *data.shape, *weight.shape
    A = te.placeholder((M, K), name='A')
    B = te.placeholder((N, K), name='B')
    kk = te.reduce_axis((0, K // vec), name='kk')
    CC = te.compute((M, N, vec),
                    lambda m, n, v: te.sum(A[m, kk*vec + v] * B[n, kk*vec + v], axis=kk), name='CC')
    kv = te.reduce_axis((0, vec), name='kv')
    C = te.compute((M, N), lambda m, n: te.sum(CC[m, n, kv], axis=kv), name='C')

    def create_dense_schedule(C, CC):
        s = te.create_schedule(C.op)
        m, n = s[C].op.axis
        (kv,) = s[C].op.reduce_axis
        mo, no, mi, ni = s[C].tile(m, n, 16, 8)
        s[C].unroll(kv)

        s[CC].compute_at(s[C], no)
        mi, ni, v = s[CC].op.axis
        (kk,) = s[CC].op.reduce_axis
        ko, ki = s[CC].split(kk, factor=4)
        s[CC].reorder(ko, mi, ni, ki, v)
        s[CC].vectorize(v)
        s[CC].unroll(ki)
        return s

    s = create_dense_schedule(C, CC)
    return s, [A, B, C]


tr = TVMRunner('dense_trans_gemm', create_dense_trans_gemm(nhwkkc_data, weight_ohwi_flat))
print(tr.lower())
tr.time_eval(nhwkkc_data, weight_ohwi_flat, np.float32)
None

primfn(A_1: handle, B_1: handle, C_1: handle) -> ()
  attr = {"global_symbol": "main", "tir.noalias": True}
  buffers = {C: Buffer(C_2: Pointer(float32), float32, [655360, 64], []),
             A: Buffer(A_2: Pointer(float32), float32, [655360, 576], []),
             B: Buffer(B_2: Pointer(float32), float32, [64, 576], [])}
  buffer_map = {A_1: A, B_1: B, C_1: C} {
  attr [CC: Pointer(float32)] "storage_scope" = "global";
  allocate(CC, float32, [1024]);
  for (m.outer: int32, 0, 40960) {
    for (n.outer: int32, 0, 8) {
      for (m.init: int32, 0, 16) {
        for (n.init: int32, 0, 8) {
          CC[ramp(((m.init*64) + (n.init*8)), 1, 8)] = broadcast(0f32, 8)
        }
      }
      for (kk.outer: int32, 0, 18) {
        for (m: int32, 0, 16) {
          for (n: int32, 0, 8) {
            CC[ramp(((m*64) + (n*8)), 1, 8)] = ((float32x8*)CC[ramp(((m*64) + (n*8)), 1, 8)] + ((float32x8*)A_2[ramp((((m.outer*9216) + (m*576)) + (kk.outer*32)), 1, 8)]*(float32x8*)B_2[ramp((((n.outer*4608

# Sparse NonTrans GEMM

In [71]:
bsr_2x1 = make_bsr_sparse(weight_ohwi_flat, 0.5, (2, 1))
#%timeit bsr_2x1 * kkcnhw_data

In [72]:
def create_nontrans_gemm(bsr, dense):
    M, K, _, N, bsrR, bsrC = *bsr.shape, *dense.shape, *bsr.blocksize
    bsrdata, bsrindices, bsrindptr = unpack_bsr(bsr)
    Wdat = te.placeholder(bsrdata.shape, name='Wdat')
    Wind = te.placeholder(bsrindices.shape, dtype='int', name='Wind')
    Wptr = te.placeholder(bsrindptr.shape, dtype='int', name='Wptr')
    Data = te.placeholder(dense.shape, name='Data')
    
    def bsr_gemm_kernel(wrow, brow, dcol, bcol):
        row_start, row_end = Wptr[wrow], Wptr[wrow+1]
        elem_idx = te.reduce_axis((0, row_end - row_start), name='elem_idx')
        elem = row_start + elem_idx
        return te.sum(Data[Wind[elem]*bsrC + bcol, dcol] * Wdat[elem, brow, bcol], axis=elem_idx)

    CC = te.compute((M // bsrR, bsrR, N, bsrC), bsr_gemm_kernel, name='CC')
    k = te.reduce_axis((0, bsrC), name='k')
    C = te.compute((M, N), lambda m, n: te.sum(CC[m // bsrR, m % bsrR, n, k], axis=k), name='C')
    
    def create_bsr_gemm_schedule(C, CC):
        s = te.create_schedule(C.op)
        md, nd = s[C].op.axis
        kd = s[C].op.reduce_axis[0]
        md, nd1, rd, nd2 = s[C].tile(md, nd, bsrR, 16*vec)
        s[C].reorder(nd1, md, nd2, rd, kd)
        s[C].unroll(kd)
        s[C].unroll(rd)
        nd2a, nd2b = s[C].split(nd2, nparts=8)
        s[C].vectorize(nd2b)
        
        s[CC].compute_at(s[C], md)
        md, rd, nd2, cd = s[CC].op.axis
        (ed,) = s[CC].op.reduce_axis
        s[CC].reorder(md, ed, nd2, rd, cd)
        s[CC].unroll(cd)
        s[CC].unroll(rd)
        nd2a, nd2b = s[CC].split(nd2, nparts=8)
        s[CC].vectorize(nd2b)
        return s
    
    s = create_bsr_gemm_schedule(C, CC)
    return s, [Wdat, Wind, Wptr, Data, C]


tr = TVMRunner('bsr_nontrans_gemm', create_nontrans_gemm(bsr_2x1, kkcnhw_data))
print(tr.lower())
tr.time_eval(*unpack_bsr(bsr_2x1), kkcnhw_data, np.float32)
None

primfn(Wdat_1: handle, Wind_1: handle, Wptr_1: handle, Data_1: handle, C_1: handle) -> ()
  attr = {"global_symbol": "main", "tir.noalias": True}
  buffers = {Wptr: Buffer(Wptr_2: Pointer(int32), int32, [33], []),
             Wind: Buffer(Wind_2: Pointer(int32), int32, [9216], []),
             Wdat: Buffer(Wdat_2: Pointer(float32), float32, [9216, 2, 1], []),
             C: Buffer(C_2: Pointer(float32), float32, [64, 655360], []),
             Data: Buffer(Data_2: Pointer(float32), float32, [576, 655360], [])}
  buffer_map = {Wind_1: Wind, Data_1: Data, C_1: C, Wptr_1: Wptr, Wdat_1: Wdat} {
  attr [CC: Pointer(float32x16)] "storage_scope" = "global";
  allocate(CC, float32x16, [16]);
  for (n.outer: int32, 0, 5120) {
    for (m.outer: int32, 0, 32) {
      for (dcol.outer.init: int32, 0, 8) {
        CC[ramp((dcol.outer.init*16), 1, 16)] = broadcast(0f32, 16)
        CC[ramp(((dcol.outer.init*16) + 128), 1, 16)] = broadcast(0f32, 16)
      }
      for (elem_idx: int32, 0, ((int32*)W

# Sparse Trans GEMM

In [94]:
bsr_1x2 = make_bsr_sparse(weight_ohwi_flat, 0.5, (2, 2))
#%timeit nhwkkc_data * bsr_1x2.T

In [109]:
def create_trans_gemm(dense, bsr):
    M, K, N, _, bsrR, bsrC = *dense.shape, *bsr.shape, *bsr.blocksize
    bsrdata, bsrindices, bsrindptr = unpack_bsr(bsr)
    Data = te.placeholder(dense.shape, name='Data')
    Wdat = te.placeholder(bsrdata.shape, name='Wdat')
    Wind = te.placeholder(bsrindices.shape, dtype='int', name='Wind')
    Wptr = te.placeholder(bsrindptr.shape, dtype='int', name='Wptr')
    
    def bsr_gemm_kernel(drow, wrow, brow, bcol):
        row_start, row_end = Wptr[wrow], Wptr[wrow+1]
        elem_idx = te.reduce_axis((0, row_end - row_start), name='elem_idx')
        elem = row_start + elem_idx
        return te.sum(Data[drow, Wind[elem]*bsrC + bcol] * Wdat[elem, brow, bcol], axis=elem_idx)

    CC = te.compute((M, N // bsrR, bsrR, bsrC), bsr_gemm_kernel, name='CC')
    k = te.reduce_axis((0, bsrC), name='k')
    C = te.compute((M, N), lambda m, n: te.sum(CC[m, n // bsrR, n % bsrR, k], axis=k), name='C')
    
    def create_bsr_gemm_schedule(C, CC):
        s = te.create_schedule(C.op)
        md, nd = s[C].op.axis
        md1, nd, md2, rd = s[C].tile(md, nd, 16*vec, bsrR)
        cd = s[C].op.reduce_axis[0]
        s[C].unroll(cd)
        s[C].unroll(rd)
        md2a, md2b = s[C].split(md2, nparts=8)
        s[C].vectorize(md2b)
        
        s[CC].compute_at(s[C], nd)
        md2, nd0, rd, cd = s[CC].op.axis
        md2a, md2b = s[CC].split(md2, nparts=8)
        ed = s[CC].op.reduce_axis[0]
        s[CC].reorder(nd0, md2a, ed, md2b, rd, cd)
        s[CC].unroll(cd)
        s[CC].unroll(rd)
        s[CC].vectorize(md2b)
        return s
    
    s = create_bsr_gemm_schedule(C, CC)
    return s, [Data, Wdat, Wind, Wptr]


tr = TVMRunner('bsr_trans_gemm', create_trans_gemm(nhwkkc_data, bsr_1x2))
print(tr.lower())
tr.time_eval(nhwkkc_data, *unpack_bsr(bsr_1x2), np.float32)
None

primfn(Data_1: handle, Wdat_1: handle, Wind_1: handle, Wptr_1: handle) -> ()
  attr = {"global_symbol": "main", "tir.noalias": True}
  buffers = {Wind: Buffer(Wind_2: Pointer(int32), int32, [4608], []),
             Wptr: Buffer(Wptr_2: Pointer(int32), int32, [33], []),
             Data: Buffer(Data_2: Pointer(float32), float32, [655360, 576], []),
             Wdat: Buffer(Wdat_2: Pointer(float32), float32, [4608, 2, 2], [])}
  buffer_map = {Data_1: Data, Wdat_1: Wdat, Wind_1: Wind, Wptr_1: Wptr} {
  attr [C: Pointer(float32x16)] "storage_scope" = "global";
  allocate(C, float32x16, [2621440]);
  attr [CC: Pointer(float32x16)] "storage_scope" = "global";
  allocate(CC, float32x16, [32]);
  for (m.outer: int32, 0, 5120) {
    for (n.outer: int32, 0, 32) {
      for (drow.outer: int32, 0, 8) {
        CC[ramp((drow.outer*64), 4, 16)] = broadcast(0f32, 16)
        CC[ramp(((drow.outer*64) + 1), 4, 16)] = broadcast(0f32, 16)
        CC[ramp(((drow.outer*64) + 2), 4, 16)] = broadcast(0f32