In [1]:
%pylab notebook
%autocall
import scipy.sparse

Populating the interactive namespace from numpy and matplotlib
Automatic calling is: Smart


In [2]:
import tvm
from tvm import te, tir
# Compilation target shared by every kernel in this notebook.
# Run `llc --version` to list the -mcpu values the local LLVM supports.
target = "llvm -mcpu=znver2"  # presumably "cascadelake" is an alternative -mcpu value; confirm
# Device context (device index 0) on which arrays are placed and kernels are timed.
ctx = tvm.context(target, 0)
# Vector width, in elements, used by the GEMM compute definitions and schedules below.
vec = 8


def idxsplit(idx, dim, *dim2):
    """Split a flat index into one coordinate per extent, plus a leading quotient.

    The extents ``dim, *dim2`` are listed from slowest- to fastest-varying.
    The result has ``len((dim, *dim2)) + 1`` entries: the leading entry is what
    is left of ``idx`` after dividing by every extent, followed by one
    remainder per extent in the same order as the extents were given.
    Only ``//`` and ``%`` are applied to ``idx``.
    """
    coords = []
    # Peel coordinates off from the fastest-varying extent upwards.
    for extent in reversed((dim, *dim2)):
        coords.append(idx % extent)
        idx = idx // extent
    coords.append(idx)
    return tuple(reversed(coords))


class TVMRunner:
    """Build, run and time a TVM schedule using NumPy inputs.

    Parameters
    ----------
    name : str
        Name given to the compiled function.
    params : tuple
        ``(schedule, tensors)`` pair, forwarded to ``tvm.lower`` / ``tvm.build``.
    """

    def __init__(self, name, params):
        self.name = name
        self.params = params

    def lower(self):
        """Return the lowered form of the schedule (``simple_mode=True``)."""
        return tvm.lower(*self.params, simple_mode=True)

    def _build(self):
        """Compile the schedule for the notebook-level ``target``."""
        return tvm.build(*self.params, target=target, name=self.name)

    def _wrap_args(self, args):
        """Turn call arguments into ``tvm.nd`` arrays on ``ctx``.

        Each argument is matched positionally with a tensor of ``self.params``:

        * a ``numpy.ndarray`` is wrapped as-is;
        * a type (e.g. ``np.float32``) requests a zero-filled array of that
          dtype with the shape of the matching tensor.

        Raises
        ------
        ValueError
            If the number of arguments differs from the number of tensors.
        TypeError
            If an argument is neither an ndarray nor a type.
        """
        _, tensors = self.params
        # zip() would silently drop surplus arguments or tensors; fail loudly instead.
        if len(args) != len(tensors):
            raise ValueError(
                'expected %d arguments, got %d' % (len(tensors), len(args)))

        def _wrap_single(item, tensor):
            if isinstance(item, np.ndarray):
                return tvm.nd.array(item, ctx)
            if isinstance(item, type):
                shape = [extent.value for extent in tensor.shape]
                return tvm.nd.array(np.zeros(shape, dtype=item), ctx)
            raise TypeError('unknown arg', item)

        return [_wrap_single(item, tensor) for item, tensor in zip(args, tensors)]

    def __call__(self, *args):
        """Build the kernel, run it once and return the wrapped arguments."""
        func = self._build()
        realargs = self._wrap_args(args)
        func(*realargs)
        return realargs

    def time_eval(self, *args, number=10):
        """Build the kernel, print the mean reported by its time evaluator and
        return the wrapped arguments.

        ``number`` is forwarded to ``func.time_evaluator``.
        """
        func = self._build()
        realargs = self._wrap_args(args)
        evaluator = func.time_evaluator(func.entry_name, ctx, number=number)
        print(evaluator(*realargs).mean)
        return realargs

In [3]:
# NOTE(review): no RNG seed is set in the visible cells, so the data differ on every run.
# Random convolution weights with shape (64, 64, 3, 3); named as OIHW order.
weight_oihw = np.random.rand(64, 64, 3, 3).astype('float32')
# Move axis 1 to the end: OIHW -> OHWI.
weight_ohwi = np.moveaxis(weight_oihw, 1, -1)
# Collapse every axis after the first into one, giving one row per output channel.
weight_ohwi_flat = weight_ohwi.reshape((weight_ohwi.shape[0], -1))
# Random input batch: integer values in [0, 256) stored as float32; named as NCHW order.
nchw_data = np.random.randint(0, 256, (10, 64, 256, 256)).astype('float32')
# Move axis 1 to the end: NCHW -> NHWC.
nhwc_data = np.moveaxis(nchw_data, 1, -1)

# Only weight_ohwi_flat and nhwc_data are used by the later visible cells;
# remove the intermediate names from the namespace.
del weight_oihw
del weight_ohwi
del nchw_data

In [4]:
def make_bsr_sparse(dense, sprate, blocksize):
    """Prune a dense matrix into block-sparse (BSR) form.

    The matrix is converted to BSR with the given ``blocksize`` and the stored
    blocks are ranked by the sum of their entries. Blocks whose sum is less
    than or equal to the sum at rank ``int(sprate * n_blocks + 0.5)`` are kept;
    every other block is dropped. Roughly a ``sprate`` fraction of the stored
    blocks therefore survives, namely the ones with the smallest sums.

    Parameters
    ----------
    dense : array_like
        2-D matrix to prune.
    sprate : float
        Approximate fraction of stored blocks to keep. A value of 1.0 keeps
        every stored block.
    blocksize : tuple of int
        ``(rows, cols)`` of one block.

    Returns
    -------
    scipy.sparse.bsr_matrix
        Matrix with the shape of ``dense`` holding only the kept blocks.
    """
    bsrdata = scipy.sparse.bsr_matrix(dense, blocksize=blocksize)
    block_sums = bsrdata.data.sum((1, 2))
    if len(block_sums) == 0:
        # No block is stored (all-zero input): nothing to rank or prune.
        return bsrdata

    # Find the partition value. The rank is clamped so that sprate == 1.0
    # selects the largest sum instead of indexing past the end.
    kth = min(int(sprate * len(block_sums) + 0.5), len(block_sums) - 1)
    threshold = np.partition(block_sums, kth)[kth]
    keep = block_sums <= threshold

    # kept_before[i] is the number of kept blocks among the first i stored
    # blocks. Sampling it at the original row pointers yields the new row
    # pointers, which stays correct when block rows hold different numbers of
    # blocks (the conversion from dense does not store all-zero blocks).
    kept_before = np.concatenate(([0], np.cumsum(keep)))
    indptr = kept_before[bsrdata.indptr]

    return scipy.sparse.bsr_matrix(
        (bsrdata.data[keep], bsrdata.indices[keep], indptr), shape=dense.shape)


def unpack_bsr(bsrdata):
    """Return the ``data``, ``indices`` and ``indptr`` attributes of ``bsrdata``
    as a 3-tuple, in that order."""
    fields = ('data', 'indices', 'indptr')
    return tuple(getattr(bsrdata, field) for field in fields)

In [5]:
def create_nhwc_im2col(data):
    """Build a TVM schedule that unfolds a 4-D tensor into im2col rows.

    Only ``data.shape`` is read; it is unpacked as ``(N, H, W, C)``. The window
    is hard-coded to 3x3 with an offset of -1 on both spatial axes, and
    positions that fall outside ``[0, H) x [0, W)`` read as 0.

    Returns
    -------
    (schedule, [A, B])
        ``A`` is the input placeholder with the shape of ``data``; ``B`` is the
        output with shape ``(N*H*W, 9*C)``.
    """
    batch, height, width, chans = in_shape = data.shape
    window, offset = 3, 1
    A = te.placeholder(in_shape, name='A')

    def im2col_kernel(row, col):
        # row -> (image, output y, output x); col -> (window y, window x, channel)
        img, out_y, out_x = idxsplit(row, height, width)
        win_y, win_x, chan = idxsplit(col, window, chans)
        src_y = out_y + win_y - offset
        src_x = out_x + win_x - offset
        in_bounds = tir.all(0 <= src_y, src_y < height, 0 <= src_x, src_x < width)
        return tir.if_then_else(in_bounds, A[img, src_y, src_x, chan], 0)

    out_shape = (batch * height * width, window * window * chans)
    B = te.compute(out_shape, im2col_kernel, name='B')

    # Schedule: split the column axis into chunks of `chans` elements and
    # vectorize the inner (channel) part.
    sched = te.create_schedule(B.op)
    _, col_axis = sched[B].op.axis
    _, chan_axis = sched[B].split(col_axis, factor=chans)
    sched[B].vectorize(chan_axis)
    return sched, [A, B]


# Build and time the im2col kernel. Passing np.float32 asks the runner to
# allocate a zero-filled float32 buffer for the output tensor.
tr = TVMRunner('im2col', create_nhwc_im2col(nhwc_data))
_, ret = tr.time_eval(nhwc_data, np.float32)
# im2col result as a NumPy array, shape (N*H*W, 9*C).
nhwkkc_data = ret.asnumpy()
# Transpose, shape (9*C, N*H*W); used as the dense operand of the sparse GEMM below.
kkcnhw_data = nhwkkc_data.T
del ret

0.18997756999999998


# Dense Trans GEMM

In [6]:
# NumPy baselines for the dense product, written two equivalent ways:
# both contract axis 1 of the im2col matrix with axis 1 of the flattened weights.
%timeit nhwkkc_data @ weight_ohwi_flat.T
%timeit np.tensordot(nhwkkc_data, weight_ohwi_flat, [[1], [1]])

2.15 s ± 44.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.16 s ± 53.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [13]:
def create_dense_trans_gemm(data, weight):
    """Build a TVM schedule computing ``C[m, n] = sum_k A[m, k] * B[n, k]``.

    Only the shapes of the arguments are read: ``data`` is ``(M, K)`` and
    ``weight`` is ``(N, K)``. The reduction is done in two stages: ``CC`` keeps
    ``vec`` partial sums per output element (one per position inside a chunk of
    ``vec`` consecutive k values), and ``C`` adds those partial sums together.

    Returns
    -------
    (schedule, [A, B, C])

    Raises
    ------
    ValueError
        If the two K extents differ, or if K is not a multiple of ``vec``
        (the chunked reduction would otherwise skip the trailing columns).
    """
    M, K = data.shape
    N, weight_K = weight.shape
    if weight_K != K:
        raise ValueError(
            'inner dimensions differ: data has K=%d, weight has K=%d' % (K, weight_K))
    if K % vec != 0:
        raise ValueError(
            'K=%d must be a multiple of the vector width %d' % (K, vec))

    A = te.placeholder((M, K), name='A')
    B = te.placeholder((N, K), name='B')
    # Stage 1: partial sums, one per position z inside each chunk of `vec` columns.
    k = te.reduce_axis((0, K // vec), name='k')
    CC = te.compute((M, N, vec),
                    lambda m, n, z: te.sum(A[m, k*vec + z] * B[n, k*vec + z], axis=k), name='CC')
    # Stage 2: add the `vec` partial sums of each output element.
    kk = te.reduce_axis((0, vec), name='kk')
    C = te.compute((M, N), lambda m, n: te.sum(CC[m, n, kk], axis=kk), name='C')

    def create_dense_schedule(C):
        s = te.create_schedule(C.op)
        x, y = s[C].op.axis
        (kk,) = s[C].op.reduce_axis
        # Tile the output 16 x 8 and unroll the final reduction over the partial sums.
        xo, yo, xi, yi = s[C].tile(x, y, 16, 8)
        s[C].unroll(kk)

        # Compute the partial sums per output tile, at the outer column-tile loop.
        (CC,) = s[C].op.input_tensors
        s[CC].compute_at(s[C], yo)
        x, y, z = s[CC].op.axis
        (k,) = s[CC].op.reduce_axis
        # Split the chunk loop by 4, unroll the inner part and vectorize over z.
        ko, ki = s[CC].split(k, factor=4)
        s[CC].reorder(ko, x, y, ki, z)
        s[CC].vectorize(z)
        s[CC].unroll(ki)
        return s

    s = create_dense_schedule(C)
    return s, [A, B, C]


# Build and time the TVM dense GEMM on the im2col matrix and the flattened weights;
# np.float32 requests a zero-filled output buffer from the runner.
tr = TVMRunner('dense_trans_gemm', create_dense_trans_gemm(nhwkkc_data, weight_ohwi_flat))
tr.time_eval(nhwkkc_data, weight_ohwi_flat, np.float32)
# End the cell with None so the list returned by time_eval is not displayed.
None

0.76986363


# Sparse NonTrans GEMM

In [44]:
# Prune the flattened weights with sprate=0.25 and a block size of (1, 1).
# NOTE(review): the variable is named bsr_2x1 but the block size passed is (1, 1) - confirm which is intended.
bsr_2x1 = make_bsr_sparse(weight_ohwi_flat, 0.25, (1, 1))
# SciPy baseline for the sparse product, left disabled by the author:
#%timeit bsr_2x1 * kkcnhw_data

In [45]:
def create_nontrans_gemm(bsr, dense):
    """Build a TVM schedule computing ``C = bsr @ dense`` for a BSR matrix.

    Parameters
    ----------
    bsr : scipy.sparse.bsr_matrix
        Sparse left operand of shape ``(M, K)``; its ``shape``, ``blocksize``,
        ``data``, ``indices`` and ``indptr`` are read.
    dense : array_like
        Dense right operand of shape ``(K, N)``; only its shape is read.

    Returns
    -------
    (schedule, [Wdat, Wind, Wptr, Data, C])
        Placeholders for the three BSR arrays and the dense operand, followed
        by the ``(M, N)`` output tensor.

    Raises
    ------
    ValueError
        If the number of rows of ``dense`` differs from the number of columns
        of ``bsr``; the kernel indexes ``Data`` by sparse column, so a mismatch
        would address rows that do not exist.
    """
    (M, K), (dense_rows, N), (bsrR, bsrC) = bsr.shape, dense.shape, bsr.blocksize
    if dense_rows != K:
        raise ValueError(
            'shape mismatch: bsr has %d columns but dense has %d rows' % (K, dense_rows))

    bsrdata, bsrindices, bsrindptr = unpack_bsr(bsr)
    Wdat = te.placeholder(bsrdata.shape, name='Wdat')
    Wind = te.placeholder(bsrindices.shape, dtype='int', name='Wind')
    Wptr = te.placeholder(bsrindptr.shape, dtype='int', name='Wptr')
    Data = te.placeholder(dense.shape, name='Data')

    def bsr_gemm_kernel(wrow, brow, dcol, bcol):
        # Reduce over the blocks stored for block row `wrow`; the extent of the
        # reduction is read from the row-pointer tensor.
        row_start, row_end = Wptr[wrow], Wptr[wrow+1]
        elem_idx = te.reduce_axis((0, row_end - row_start), name='elem_idx')
        elem = row_start + elem_idx
        # Wind[elem] is the block column, so the matching dense row is
        # Wind[elem]*bsrC + bcol.
        return te.sum(Data[Wind[elem]*bsrC + bcol, dcol] * Wdat[elem, brow, bcol], axis=elem_idx)

    # Stage 1: one partial sum per (block row, row in block, output column, column in block).
    CC = te.compute((M // bsrR, bsrR, N, bsrC), bsr_gemm_kernel, name='CC')
    # Stage 2: add the partial sums over the in-block column.
    k = te.reduce_axis((0, bsrC), name='k')
    C = te.compute((M, N), lambda m, n: te.sum(CC[m // bsrR, m % bsrR, n, k], axis=k), name='C')

    def create_bsr_gemm_schedule(C, CC):
        s = te.create_schedule(C.op)
        (kd,) = s[C].op.reduce_axis
        s[C].unroll(kd)

        md, rd, nd, cd = s[CC].op.axis
        (ed,) = s[CC].op.reduce_axis
        # Process output columns in chunks of 16*vec, vectorized over the chunk,
        # with the in-block loops unrolled.
        nd1, nd2 = s[CC].split(nd, factor=16*vec)
        s[CC].reorder(nd1, md, ed, rd, nd2, cd)
        s[CC].unroll(cd)
        s[CC].vectorize(nd2)
        s[CC].unroll(rd)
        return s

    s = create_bsr_gemm_schedule(C, CC)
    return s, [Wdat, Wind, Wptr, Data, C]


# Build the sparse GEMM, print its lowered form, then time it on the three
# BSR arrays and the transposed im2col matrix (np.float32 requests a
# zero-filled output buffer from the runner).
tr = TVMRunner('bsr_nontrans_gemm', create_nontrans_gemm(bsr_2x1, kkcnhw_data))
print(tr.lower())
tr.time_eval(*unpack_bsr(bsr_2x1), kkcnhw_data, np.float32)
# End the cell with None so the list returned by time_eval is not displayed.
None

primfn(Wdat_1: handle, Wind_1: handle, Wptr_1: handle, Data_1: handle, C_1: handle) -> ()
  attr = {"global_symbol": "main", "tir.noalias": True}
  buffers = {Wptr: Buffer(Wptr_2: Pointer(int32), int32, [65], []),
             Data: Buffer(Data_2: Pointer(float32), float32, [576, 655360], []),
             C: Buffer(C_2: Pointer(float32), float32, [64, 655360], []),
             Wdat: Buffer(Wdat_2: Pointer(float32), float32, [9217, 1, 1], []),
             Wind: Buffer(Wind_2: Pointer(int32), int32, [9217], [])}
  buffer_map = {Wdat_1: Wdat, Wind_1: Wind, C_1: C, Data_1: Data, Wptr_1: Wptr} {
  attr [CC: Pointer(float32)] "storage_scope" = "global";
  allocate(CC, float32, [41943040]) {
    for (dcol.outer: int32, 0, 5120) {
      for (wrow: int32, 0, 64) {
        CC[ramp(((wrow*655360) + (dcol.outer*128)), 1, 128)] = broadcast(0f32, 128)
        for (elem_idx: int32, 0, ((int32*)Wptr_2[(wrow + 1)] - (int32*)Wptr_2[wrow])) {
          CC[ramp(((wrow*655360) + (dcol.outer*128)), 1, 12

In [37]:
# Shape of the stored block array: (stored blocks, block rows, block cols).
# NOTE(review): this cell's execution count (37) predates the cells above, and
# the recorded output disagrees with the Wdat extent in the lowered IR -
# re-run the notebook top to bottom to refresh it.
bsr_2x1.data.shape

(18433, 1, 1)