In [1]:
%pylab notebook
import pandas as pd
import scipy.sparse

Populating the interactive namespace from numpy and matplotlib


In [2]:
import sys
import os
# Put the TVM Python package under ~/tvm/python on the import path.
# The membership test keeps re-runs of this cell from appending the same
# entry more than once.
TVMPATH = os.path.expanduser('~/tvm/python')
if TVMPATH not in sys.path:
    sys.path.append(TVMPATH)

In [3]:
def get_weights(dat_path='../dat.bin', fmt_path='../fmt.txt'):
    """Load per-layer 3x3 convolution kernels from a flat float32 dump.

    The format file is read line by line. Every line that contains a space
    describes one layer, and its first field is that layer's channel count
    ``c``. Lines without a space, and lines holding only whitespace, are
    skipped. The binary file must hold the layers back to back, with
    ``c * c * 9`` float32 values per layer.

    Parameters
    ----------
    dat_path : str
        Raw float32 weight file. Defaults to the previously hard-coded path.
    fmt_path : str
        Text file listing the layers. Defaults to the previously hard-coded
        path.

    Returns
    -------
    list of numpy.ndarray
        One ``(c, 3, 3, c)`` array per layer: each chunk is interpreted as
        ``(c, c, 3, 3)`` and its second axis is moved to the end.
        NOTE(review): presumably (out, in, kh, kw) -> (out, kh, kw, in);
        confirm against whatever wrote the dump.

    Raises
    ------
    ValueError
        If the number of values in ``dat_path`` differs from the total
        described by ``fmt_path``.
    """
    tot = np.fromfile(dat_path, dtype=np.float32)

    channs = []
    with open(fmt_path) as f:
        for case in f:
            if ' ' not in case:
                continue
            fields = case.split()
            if not fields:
                continue
            # Only the channel count is needed; the remaining fields are ignored.
            channs.append(int(fields[0]))

    sizes = [ch * ch * 9 for ch in channs]
    expected = sum(sizes)
    if tot.size != expected:
        # Without this check a mismatch only shows up as an opaque reshape error.
        raise ValueError(
            'weight file %r holds %d float32 values but %r describes %d'
            % (dat_path, tot.size, fmt_path, expected))

    weights = []
    start = 0
    for ch, size in zip(channs, sizes):
        layer = tot[start:start + size].reshape((ch, ch, 3, 3))
        weights.append(np.moveaxis(layer, 1, -1))
        start += size
    return weights

# Load every layer's kernels once; the cells below work on the first entry only.
weight_list = get_weights()
wei = weight_list[0]

In [4]:
import tvm
from tvm import te, tir
# LLVM CPU target with -mcpu=cascadelake; every kernel below is built for it.
target = "llvm -mcpu=cascadelake"
ctx = tvm.context(target, 0)  # device 0 of that target
# Block width used to split the dense reduction further down.
# NOTE(review): presumably 16 float32 lanes = one 512-bit vector register; confirm.
vec = 16


def idxsplit(idx, dim, *dim2):
    """Unflatten a row-major linear index into one coordinate per extent.

    The extents ``(dim, *dim2)`` are the sizes of the trailing axes, with
    the last extent varying fastest. The result has one more entry than
    there are extents: the leading entry is the quotient left over after
    all extents have been divided out. Only ``//`` and ``%`` are applied to
    ``idx``, so it may be a Python int or any object supporting those
    operators.
    """
    coords = []
    rest = idx
    # Peel coordinates off from the fastest-varying axis to the slowest.
    for extent in reversed((dim, *dim2)):
        coords.append(rest % extent)
        rest = rest // extent
    coords.reverse()
    return (rest, *coords)

In [5]:
chann, *_ = wei.shape
# Spatial size shrinks as the channel count grows (64 channels -> 256 x 256).
hw = 64 * 256 // chann
nbatch = 10
# Seeded local generator: the benchmark input is identical on every
# Restart & Run All, and the global NumPy RNG state is not consumed.
data = (np.random.RandomState(0)
        .randint(0, 256, (nbatch, hw, hw, chann))
        .astype('float32'))

# create compute

# Input image batch in (batch, height, width, channel) order.
A = te.placeholder((nbatch, hw, hw, chann), name='A')

def im2col_kernel(row, col):
    """Element (row, col) of the im2col matrix for a 3x3 window.

    ``row`` enumerates (batch, y, x) output positions and ``col`` enumerates
    (kernel y, kernel x, channel) taps. Taps that fall outside the
    ``hw`` x ``hw`` image read as 0, i.e. one pixel of zero padding.
    """
    batch, out_y, out_x = idxsplit(row, hw, hw)
    tap_y, tap_x, channel = idxsplit(col, 3, chann)
    # Shift by -1 so the 3x3 window is centred on the output pixel.
    src_y = out_y + tap_y - 1
    src_x = out_x + tap_x - 1
    inside = tir.all(0 <= src_y, src_y < hw, 0 <= src_x, src_x < hw)
    return tir.if_then_else(inside, A[batch, src_y, src_x, channel], 0)

# One row per (batch, y, x) output position, one column per (kh, kw, channel) tap.
outshape = (nbatch*hw*hw, chann*9)
B = te.compute(outshape, im2col_kernel, name='B')

# create schedule & func

def im2col_schedule(C):
    """Build the schedule for the 2-D im2col tensor ``C``.

    The column axis is split into blocks of ``chann`` elements and the
    inner part of that split is vectorized; the row axis is left as is.
    Returns the schedule.
    """
    sched = te.create_schedule(C.op)
    stage = sched[C]
    _row_axis, col_axis = stage.op.axis
    _col_outer, col_inner = stage.split(col_axis, factor=chann)
    stage.vectorize(col_inner)
    return sched

# Compile the im2col kernel for the configured target.
s = im2col_schedule(B)
func = tvm.build(s, [A, B], target=target, name='im2col')

# call func
a = tvm.nd.array(data, ctx)
# Output buffer, zero-filled before the call.
b = tvm.nd.array(np.zeros(outshape, dtype='float32'), ctx)
func(a, b)
# NumPy copy of the unfolded matrix; the dense and CSR cells below consume it.
im2col = b.asnumpy()

# evaluate func
# number=10: the reported mean is taken over 10 runs of the kernel.
evaluator = func.time_evaluator(func.entry_name, ctx, number=10)
evaluator(a, b).mean

0.1727696005

In [6]:
# dense
# Flatten each first-axis slice of the kernel into one row: (chann, 9 * chann).
wei2 = wei.reshape((chann, -1))
M, K = im2col.shape
N, _ = wei2.shape
# The two-stage reduction below walks K in blocks of `vec`; any remainder
# would be silently left out of the sum, so fail loudly instead.
assert K % vec == 0, 'K=%d must be a multiple of vec=%d' % (K, vec)
A = te.placeholder((M, K), name='A')
B = te.placeholder((N, K), name='B')
# Stage 1: for every output element keep `vec` partial sums, lane z
# accumulating the products at reduction positions k*vec + z.
k = te.reduce_axis((0, K // vec), name='k')
CC = te.compute((M, N, vec),
                lambda m, n, z: te.sum(A[m, k*vec + z] * B[n, k*vec + z], axis=k), name='CC')
# Stage 2: fold the `vec` partial sums into the final C[m, n].
kk = te.reduce_axis((0, vec), name='kk')
C = te.compute((M, N), lambda m, n: te.sum(CC[m, n, kk], axis=kk), name='C')

def create_dense_schedule(C):
    """Build the schedule for the two-stage dense product ending in ``C``.

    ``C`` is tiled 16 x 8 and its reduction axis is unrolled. Its single
    input tensor (the partial-sum stage) is computed at the outer column
    tile of ``C``; that stage's reduction is split by 4, the loops are
    reordered so the split-outer reduction is outermost, the last spatial
    axis is vectorized and the split-inner reduction is unrolled.
    Returns the schedule.
    """
    sched = te.create_schedule(C.op)

    # Final stage: tile the output and unroll the short lane reduction.
    out_stage = sched[C]
    row, col = out_stage.op.axis
    (lane_sum,) = out_stage.op.reduce_axis
    _row_outer, col_outer, _row_inner, _col_inner = out_stage.tile(row, col, 16, 8)
    out_stage.unroll(lane_sum)

    # Partial-sum stage: produced per output tile.
    (partial,) = out_stage.op.input_tensors
    part_stage = sched[partial]
    part_stage.compute_at(out_stage, col_outer)
    p_row, p_col, lane = part_stage.op.axis
    (block,) = part_stage.op.reduce_axis
    block_outer, block_inner = part_stage.split(block, factor=4)
    part_stage.reorder(block_outer, p_row, p_col, block_inner, lane)
    part_stage.vectorize(lane)
    part_stage.unroll(block_inner)
    return sched

# Compile the dense kernel for the configured target.
s = create_dense_schedule(C)
func = tvm.build(s, [A, B, C], target=target, name='dense')

# Inputs are the im2col matrix and the flattened kernels; the output starts zeroed.
a = tvm.nd.array(im2col, ctx)
b = tvm.nd.array(wei2, ctx)
c = tvm.nd.array(np.zeros((M, N), dtype='float32'), ctx)
func(a, b, c)

# 20 single-run measurements; the median of them is displayed rather than the mean.
evaluator = func.time_evaluator(func.entry_name, ctx, number=1, repeat=20)
#evaluator(a, b, c).mean
pd.Series(evaluator(a, b, c).results).median()

0.6164905199999999

尴尬：NHWC 布局下计算的是 $data \times weight$，CSR 确实没法利用向量化。把数据整体转置吧。

In [29]:
# CSR
# Magnitude pruning: the threshold is the middle element of the sorted
# absolute weights, and every weight strictly below it is set to zero.
wvals = np.sort(np.abs(wei2), axis=None)
thres = wvals[len(wvals) // 2]
spwei = scipy.sparse.csr_matrix(np.where(np.abs(wei2) < thres, 0, wei2))
# Number of values stored in the CSR matrix; sizes the placeholders below.
cntvals = spwei.data.size
# Transposed view so the im2col row index becomes the last axis.
im2row = im2col.T

In [30]:
# Dense operand, transposed: reduction index first, im2col row index last.
Data = te.placeholder((K, M), name='Data')
# CSR arrays of the pruned weights: stored values, their column indices,
# and the N+1 row pointers delimiting each row's slice of the first two.
# dtype='int' shows up as int32 in the lowered IR printed further down.
Wdat = te.placeholder((cntvals,), name='Wdat')
Wind = te.placeholder((cntvals,), name='Wind', dtype='int')
Wptr = te.placeholder((N+1,), name='Wptr', dtype='int')

def csr_dense_kernel(wrow, drow):
    """Output element (wrow, drow) of the CSR-weights times dense-data product.

    Sums, over the entries stored for weight row ``wrow``, the stored value
    times the ``Data`` element in the row named by that entry's column
    index and in column ``drow``. The reduction length is the difference of
    consecutive row pointers, so it depends on ``wrow``.
    """
    first = Wptr[wrow]
    past_last = Wptr[wrow + 1]
    elem_idx = te.reduce_axis((0, past_last - first), name='elem_idx')
    # Position of the current entry inside the value/index arrays.
    pos = first + elem_idx
    return te.sum(Data[Wind[pos], drow] * Wdat[pos], axis=elem_idx)

# Result with one row per weight row and one column per im2col row.
C = te.compute((N, M), csr_dense_kernel, name='C')

def csr_dense_schedule(C):
    """Build the schedule for the CSR x dense product ``C``.

    The second output axis is split into blocks of 32. The loops are
    ordered as weight row, outer block, sparse-entry reduction, inner
    block, and the inner block is vectorized. Returns the schedule.
    """
    sched = te.create_schedule(C.op)
    stage = sched[C]
    w_row, d_col = stage.op.axis
    (entry,) = stage.op.reduce_axis
    col_outer, col_inner = stage.split(d_col, factor=32)
    stage.reorder(w_row, col_outer, entry, col_inner)
    stage.vectorize(col_inner)
    return sched

# Compile the CSR kernel for the configured target.
s = csr_dense_schedule(C)
func = tvm.build(s, [Data, Wdat, Wind, Wptr, C], target=target, name='csr_dense')

# NOTE(review): `data` is rebound here, from the raw NumPy input of the
# im2col cell to a TVM array holding the transposed im2col matrix.
data = tvm.nd.array(im2row, ctx)
# The three CSR arrays of the pruned weights: values, column indices, row pointers.
wdat = tvm.nd.array(spwei.data, ctx)
wind = tvm.nd.array(spwei.indices, ctx)
wptr = tvm.nd.array(spwei.indptr, ctx)
ret = tvm.nd.array(np.zeros((N, M), dtype='float32'), ctx)
func(data, wdat, wind, wptr, ret)

# Only two single-run measurements are taken here.
evaluator = func.time_evaluator(func.entry_name, ctx, number=1, repeat=2)
pd.Series(evaluator(data, wdat, wind, wptr, ret).results).median()

6.3932800525

看起来需要一种新的稀疏储存格式，若干行为一段储存，每段内部按列储存。这样A的连续若干行在迭代时只需要做一轮B的行遍历。

一种分段的CSC储存方式？

啊哈哈好像不行。TVM做不了随机写入，CSC类的方案都有问题。只能像之前的方案那样做全对齐。

In [31]:
# Print the lowered IR of the current schedule `s` to inspect the generated loop nest.
print(tvm.lower(s, [Data, Wdat, Wind, Wptr, C]))

primfn(Data_1: handle, Wdat_1: handle, Wind_1: handle, Wptr_1: handle, C_1: handle) -> ()
  attr = {"global_symbol": "main", "tir.noalias": True}
  buffers = {Data: Buffer(Data_2: Pointer(float32), float32, [576, 655360], []),
             Wptr: Buffer(Wptr_2: Pointer(int32), int32, [65], []),
             Wind: Buffer(Wind_2: Pointer(int32), int32, [18432], []),
             C: Buffer(C_2: Pointer(float32), float32, [64, 655360], []),
             Wdat: Buffer(Wdat_2: Pointer(float32), float32, [18432], [])}
  buffer_map = {Wptr_1: Wptr, Wdat_1: Wdat, C_1: C, Data_1: Data, Wind_1: Wind} {
  for (wrow: int32, 0, 64) {
    for (drow.outer: int32, 0, 20480) {
      C_2[ramp(((wrow*655360) + (drow.outer*32)), 1, 32)] = broadcast(0f32, 32)
      for (elem_idx: int32, 0, ((int32*)Wptr_2[(wrow + 1)] - (int32*)Wptr_2[wrow])) {
        C_2[ramp(((wrow*655360) + (drow.outer*32)), 1, 32)] = ((float32x32*)C_2[ramp(((wrow*655360) + (drow.outer*32)), 1, 32)] + ((float32x32*)Data_2[ramp((((int32*)Wi