In [1]:
import tvm
import numpy as np
import timeit

In [2]:
# (M, K) X (K, N)
M = K = N = 2**10
dtype = 'float32'
target = tvm.target.Target(target='llvm -mcpu=alderlake', host='llvm -mcpu=goldmont')
dev = tvm.device(target.kind.name, 0)

In [3]:
a = tvm.nd.array(np.random.rand(M, K).astype(dtype), dev)
b = tvm.nd.array(np.random.rand(K, N).astype(dtype), dev)
np_repeat = 100
np_running_time = timeit.timeit(
    setup='import numpy as np\n'
    'M=K=N=2**10\n'
    'dtype="float32"\n'
    'a=np.random.rand(M, K).astype(dtype)\n'
    'b=np.random.rand(K, N).astype(dtype)\n',
    stmt='answer = np.dot(a, b)',
    number=np_repeat
)

print('Numpy running time: %f' % (np_running_time / np_repeat))

answer = np.dot(a.numpy(), b.numpy())

Numpy running time: 0.003225


In [4]:
# TVM Matmul w/ TE
from tvm import te
import tvm.testing
k = te.reduce_axis((0, K), 'k')
A = te.placeholder((M, K), dtype, 'A')
B = te.placeholder((K, N), dtype, 'B')
C = te.compute((M, N), lambda x, y: te.sum(A[x, k] * B[k, y], axis=k), name='C')

s = te.create_schedule(C.op)
func = tvm.build(s, [A, B, C], target=target, name='matmul')

c = tvm.nd.array(np.zeros((M, N), dtype=dtype), dev)
func(a, b, c)

tvm.testing.assert_allclose(c.numpy(), answer, rtol=1e-5)

In [5]:
def eval_op(s, vars, tgt, name, opt, log) -> None:
    func = tvm.build(s, vars, target=tgt, name=name)
    assert func
    
    dev = tvm.device(tgt.kind.name, 0)
    c = tvm.nd.array(np.zeros((M, N), dtype), dev)
    func(a, b, c)
    tvm.testing.assert_allclose(c.numpy(), answer, rtol=1e-5)
    
    evalor = func.time_evaluator(func.entry_name, dev, number=1)
    mean_time = evalor(a, b, c).mean
    print('%s: %f' % (opt, mean_time))
    log.append((opt, mean_time))
    
log = []
# eval_op(s, [A, B, C], target,'matmul', 'none', log)

In [6]:
# low level code, C here, unopt. w/ 4 loops
print(tvm.lower(s, [A, B, C], simple_mode=True))

@main = primfn(A_1: handle, B_1: handle, C_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {A: Buffer(A_2: Pointer(float32), float32, [1048576], []),
             B: Buffer(B_2: Pointer(float32), float32, [1048576], []),
             C: Buffer(C_2: Pointer(float32), float32, [1048576], [])}
  buffer_map = {A_1: A, B_1: B, C_1: C}
  preflattened_buffer_map = {A_1: A_3: Buffer(A_2, float32, [1024, 1024], []), B_1: B_3: Buffer(B_2, float32, [1024, 1024], []), C_1: C_3: Buffer(C_2, float32, [1024, 1024], [])} {
  for (x: int32, 0, 1024) {
    for (y: int32, 0, 1024) {
      C[((x*1024) + y)] = 0f32
      for (k: int32, 0, 1024) {
        let cse_var_2: int32 = (x*1024)
        let cse_var_1: int32 = (cse_var_2 + y)
        C[cse_var_1] = (C[cse_var_1] + (A[(cse_var_2 + k)]*B[((k*1024) + y)]))
      }
    }
  }
}




## Opt 1 : Blocking

In [7]:
bn = 256
xo, yo, xi, yi = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn)
(k, ) = s[C].op.reduce_axis
ko, ki = s[C].split(k, factor=4)
s[C].reorder(xo, yo, ko, ki, xi, yi)

In [8]:
eval_op(s, [A, B, C], target,'matmul', 'blking', log)

blking: 0.115955


In [9]:
# more loops for mem loc.
print(tvm.lower(s, [A, B, C], simple_mode=True))

@main = primfn(A_1: handle, B_1: handle, C_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {A: Buffer(A_2: Pointer(float32), float32, [1048576], []),
             B: Buffer(B_2: Pointer(float32), float32, [1048576], []),
             C: Buffer(C_2: Pointer(float32), float32, [1048576], [])}
  buffer_map = {A_1: A, B_1: B, C_1: C}
  preflattened_buffer_map = {A_1: A_3: Buffer(A_2, float32, [1024, 1024], []), B_1: B_3: Buffer(B_2, float32, [1024, 1024], []), C_1: C_3: Buffer(C_2, float32, [1024, 1024], [])} {
  for (x.outer: int32, 0, 4) {
    for (y.outer: int32, 0, 4) {
      for (x.inner.init: int32, 0, 256) {
        for (y.inner.init: int32, 0, 256) {
          C[((((x.outer*262144) + (x.inner.init*1024)) + (y.outer*256)) + y.inner.init)] = 0f32
        }
      }
      for (k.outer: int32, 0, 256) {
        for (k.inner: int32, 0, 4) {
          for (x.inner: int32, 0, 256) {
            for (y.inner: int32, 0, 25

## Opt 2 : Vectorization

In [10]:
# only favored when 1st SIMD by passing multi data to mem.
s[C].vectorize(yi)

In [11]:
eval_op(s, [A, B, C], target,'matmul', 'vectorize', log)

vectorize: 0.134336


In [12]:
print(tvm.lower(s, [A, B, C], simple_mode=True))

@main = primfn(A_1: handle, B_1: handle, C_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {A: Buffer(A_2: Pointer(float32), float32, [1048576], []),
             B: Buffer(B_2: Pointer(float32), float32, [1048576], []),
             C: Buffer(C_2: Pointer(float32), float32, [1048576], [])}
  buffer_map = {A_1: A, B_1: B, C_1: C}
  preflattened_buffer_map = {A_1: A_3: Buffer(A_2, float32, [1024, 1024], []), B_1: B_3: Buffer(B_2, float32, [1024, 1024], []), C_1: C_3: Buffer(C_2, float32, [1024, 1024], [])} {
  for (x.outer: int32, 0, 4) {
    for (y.outer: int32, 0, 4) {
      for (x.inner.init: int32, 0, 256) {
        C[ramp((((x.outer*262144) + (x.inner.init*1024)) + (y.outer*256)), 1, 256)] = broadcast(0f32, 256)
      }
      for (k.outer: int32, 0, 256) {
        for (k.inner: int32, 0, 4) {
          for (x.inner: int32, 0, 256) {
            let cse_var_3: int32 = (y.outer*256)
            let cse_var_2: int32

## Opt 3 : Loop Permu.

In [13]:
s = te.create_schedule(C.op)
xo, yo, xi, yi = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn)
(k, ) = s[C].op.reduce_axis
ko, ki = s[C].split(k, factor=4)

# re-ordering, only diff. here by changing permu.
s[C].reorder(xo, yo, ko, xi, ki, yi)
s[C].vectorize(yi)

eval_op(s, [A, B, C], target,'matmul', 'Loop permu', log)

Loop permu: 0.073387


## Opt 4 : Array Packing

In [14]:
packedB = te.compute((N / bn, K, bn), lambda x, y, z: B[y, x * bn + z], name='packedB')
C = te.compute(
    (M, N),
    lambda x, y: te.sum(A[x, k] * packedB[y // bn, k, tvm.tir.indexmod(y, bn)], axis=k),
    name='C'
)

s = te.create_schedule(C.op)
xo, yo, xi, yi = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn)
(k, ) = s[C].op.reduce_axis
ko, ki = s[C].split(k, factor=4)
s[C].reorder(xo, yo, ko, xi, ki, yi)
s[C].vectorize(yi)

x, y, z = s[packedB].op.axis
s[packedB].vectorize(z)
s[packedB].parallel(x)

eval_op(s, [A, B, C], target,'matmul', 'Array packing', log)
print(tvm.lower(s, [A, B, C], simple_mode=True))

TVMError: Traceback (most recent call last):
  12: TVMFuncCall
  11: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::runtime::Module (tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target)>::AssignTypedLambda<tvm::{lambda(tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target)#6}>(tvm::{lambda(tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target)#6}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
  10: tvm::TIRToRuntime(tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target const&)
  9: tvm::SplitMixedModule(tvm::IRModule, tvm::Target const&, tvm::Target const&)
  8: tvm::ApplyPasses(tvm::IRModule, tvm::transform::Sequential)
  7: tvm::transform::Pass::operator()(tvm::IRModule) const
  6: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  5: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  4: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  3: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  2: _ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_16PackedFuncSubObjIZNS0_15TypedPackedFuncIFNS_8IRModuleES5_NS_9transform11PassContextEEE17AssignTypedLambdaIZNS_3tir9transform13MakePackedAPIEiEUlS5_S7_E_EEvT_EUlRKNS0_7TVMArgsEPNS0_11TVMRetValueEE_EEE4CallEPKS1_SF_SJ_
  1: tvm::tir::transform::MakePackedAPI(int)::{lambda(tvm::IRModule, tvm::transform::PassContext)#1}::operator()(tvm::IRModule, tvm::transform::PassContext) const [clone .isra.0]
  0: tvm::tir::MakePackedAPI(tvm::tir::PrimFunc&&, int)
  File "/home/wendell/Desktop/tvm/src/tir/transforms/make_packed_api.cc", line 329
TVMError: Not all Vars are passed in api_args:  'k'  is not bound to any variables