In [1]:
import tvm
import numpy as np
import timeit

In [2]:
# (M, K) X (K, N)
M = K = N = 2**10
dtype = 'float32'
target = tvm.target.Target(target='llvm -mcpu=alderlake', host='llvm -mcpu=goldmont')
dev = tvm.device(target.kind.name, 0)

In [3]:
a = tvm.nd.array(np.random.rand(M, K).astype(dtype), dev)
b = tvm.nd.array(np.random.rand(K, N).astype(dtype), dev)
np_repeat = 100
np_running_time = timeit.timeit(
    setup='import numpy as np\n'
    'M=K=N=2**10\n'
    'dtype="float32"\n'
    'a=np.random.rand(M, K).astype(dtype)\n'
    'b=np.random.rand(K, N).astype(dtype)\n',
    stmt='answer = np.dot(a, b)',
    number=np_repeat
)

print('Numpy running time: %f' % (np_running_time / np_repeat))

answer = np.dot(a.numpy(), b.numpy())

Numpy running time: 0.002995


In [4]:
# TVM Matmul w/ TE
from tvm import te
import tvm.testing
k = te.reduce_axis((0, K), 'k')
A = te.placeholder((M, K), dtype, 'A')
B = te.placeholder((K, N), dtype, 'B')
C = te.compute((M, N), lambda x, y: te.sum(A[x, k] * B[k, y], axis=k), name='C')

s = te.create_schedule(C.op)
func = tvm.build(s, [A, B, C], target=target, name='matmul')

c = tvm.nd.array(np.zeros((M, N), dtype=dtype), dev)
func(a, b, c)

tvm.testing.assert_allclose(c.numpy(), answer, rtol=1e-5)

In [5]:
def eval_op(s, vars, tgt, name, opt, log) -> None:
    func = tvm.build(s, vars, target=tgt, name=name)
    assert func
    
    dev = tvm.device(tgt.kind.name, 0)
    c = tvm.nd.array(np.zeros((M, N), dtype), dev)
    func(a, b, c)
    tvm.testing.assert_allclose(c.numpy(), answer, rtol=1e-5)
    
    evalor = func.time_evaluator(func.entry_name, dev, number=1)
    mean_time = evalor(a, b, c).mean
    print('%s: %f' % (opt, mean_time))
    log.append((opt, mean_time))
    
log = []
eval_op(s, [A, B, C], target,'matmul', 'none', log)

none: 2.010295


In [6]:
# low level code, C here, unopt. w/ 4 loops
print(tvm.lower(s, [A, B, C], simple_mode=True))

@main = primfn(A_1: handle, B_1: handle, C_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {A: Buffer(A_2: Pointer(float32), float32, [1048576], []),
             B: Buffer(B_2: Pointer(float32), float32, [1048576], []),
             C: Buffer(C_2: Pointer(float32), float32, [1048576], [])}
  buffer_map = {A_1: A, B_1: B, C_1: C}
  preflattened_buffer_map = {A_1: A_3: Buffer(A_2, float32, [1024, 1024], []), B_1: B_3: Buffer(B_2, float32, [1024, 1024], []), C_1: C_3: Buffer(C_2, float32, [1024, 1024], [])} {
  for (x: int32, 0, 1024) {
    for (y: int32, 0, 1024) {
      C[((x*1024) + y)] = 0f32
      for (k: int32, 0, 1024) {
        let cse_var_2: int32 = (x*1024)
        let cse_var_1: int32 = (cse_var_2 + y)
        C[cse_var_1] = (C[cse_var_1] + (A[(cse_var_2 + k)]*B[((k*1024) + y)]))
      }
    }
  }
}




## Opt 1 : Blocking

In [7]:
# Opt 1 : Blocking
bn = 32
xo, yo, xi, yi = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn)
(k,) = s[C].op.reduce_axis
ko, ki = s[C].split(k, factor=4)

s[C].reorder(xo, yo, ko, ki, xi, yi)

In [8]:
eval_op(s, [A, B, C], target,'matmul', 'blking', log)

blking: 0.135766


In [9]:
# more loops for mem loc.
print(tvm.lower(s, [A, B, C], simple_mode=True))

@main = primfn(A_1: handle, B_1: handle, C_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {A: Buffer(A_2: Pointer(float32), float32, [1048576], []),
             B: Buffer(B_2: Pointer(float32), float32, [1048576], []),
             C: Buffer(C_2: Pointer(float32), float32, [1048576], [])}
  buffer_map = {A_1: A, B_1: B, C_1: C}
  preflattened_buffer_map = {A_1: A_3: Buffer(A_2, float32, [1024, 1024], []), B_1: B_3: Buffer(B_2, float32, [1024, 1024], []), C_1: C_3: Buffer(C_2, float32, [1024, 1024], [])} {
  for (x.outer: int32, 0, 32) {
    for (y.outer: int32, 0, 32) {
      for (x.inner.init: int32, 0, 32) {
        for (y.inner.init: int32, 0, 32) {
          C[((((x.outer*32768) + (x.inner.init*1024)) + (y.outer*32)) + y.inner.init)] = 0f32
        }
      }
      for (k.outer: int32, 0, 256) {
        for (k.inner: int32, 0, 4) {
          for (x.inner: int32, 0, 32) {
            for (y.inner: int32, 0, 32) {

## Opt 2 : Vectorization

In [10]:
# only favored when 1st SIMD by passing multi data to mem.
def vecotorize(s, vars):
    s = blocking(vars)
    (A, B, C) = vars
    s[C].vectorize(yi)
    return s

In [11]:
eval_op(s, [A, B, C], target,'matmul', 'vectorize', log)

vectorize: 0.137350


In [12]:
print(tvm.lower(s, [A, B, C], simple_mode=True))

@main = primfn(A_1: handle, B_1: handle, C_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {A: Buffer(A_2: Pointer(float32), float32, [1048576], []),
             B: Buffer(B_2: Pointer(float32), float32, [1048576], []),
             C: Buffer(C_2: Pointer(float32), float32, [1048576], [])}
  buffer_map = {A_1: A, B_1: B, C_1: C}
  preflattened_buffer_map = {A_1: A_3: Buffer(A_2, float32, [1024, 1024], []), B_1: B_3: Buffer(B_2, float32, [1024, 1024], []), C_1: C_3: Buffer(C_2, float32, [1024, 1024], [])} {
  for (x.outer: int32, 0, 32) {
    for (y.outer: int32, 0, 32) {
      for (x.inner.init: int32, 0, 32) {
        for (y.inner.init: int32, 0, 32) {
          C[((((x.outer*32768) + (x.inner.init*1024)) + (y.outer*32)) + y.inner.init)] = 0f32
        }
      }
      for (k.outer: int32, 0, 256) {
        for (k.inner: int32, 0, 4) {
          for (x.inner: int32, 0, 32) {
            for (y.inner: int32, 0, 32) {

## Opt 3 : Loop Permu.

In [13]:
s = te.create_schedule(C.op)
xo, yo, xi, yi = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn)
(k, ) = s[C].op.reduce_axis
ko, ki = s[C].split(k, factor=4)

# re-ordering, only diff. here by changing permu.
s[C].reorder(xo, yo, ko, xi, ki, yi)
s[C].vectorize(yi)

eval_op(s, [A, B, C], target,'matmul', 'Loop permu', log)

Loop permu: 0.048924


## Opt 4 : Array Packing

In [14]:
packedB = te.compute((N / bn, K, bn), lambda x, y, z: B[y, x * bn + z], name='packedB')
C = te.compute(
    (M, N),
    lambda x, y: te.sum(A[x, k] * packedB[y // bn, k, tvm.tir.indexmod(y, bn)], axis=k),
    name='C'
)

s = te.create_schedule(C.op)
xo, yo, xi, yi = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn)
(k, ) = s[C].op.reduce_axis
ko, ki = s[C].split(k, factor=4)
s[C].reorder(xo, yo, ko, xi, ki, yi)
s[C].vectorize(yi)

x, y, z = s[packedB].op.axis
s[packedB].vectorize(z)
s[packedB].parallel(x)

eval_op(s, [A, B, C], target,'matmul', 'Array packing', log)
print(tvm.lower(s, [A, B, C], simple_mode=True))

Array packing: 0.054949
@main = primfn(A_1: handle, B_1: handle, C_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {A: Buffer(A_2: Pointer(float32), float32, [1048576], []),
             B: Buffer(B_2: Pointer(float32), float32, [1048576], []),
             C: Buffer(C_2: Pointer(float32), float32, [1048576], [])}
  buffer_map = {A_1: A, B_1: B, C_1: C}
  preflattened_buffer_map = {A_1: A_3: Buffer(A_2, float32, [1024, 1024], []), B_1: B_3: Buffer(B_2, float32, [1024, 1024], []), C_1: C_3: Buffer(C_2, float32, [1024, 1024], [])} {
  allocate(packedB: Pointer(global float32x32), float32x32, [32768]), storage_scope = global {
    for (x: int32, 0, 32) "parallel" {
      for (y: int32, 0, 1024) {
        packedB_1: Buffer(packedB, float32x32, [32768], [])[((x*1024) + y)] = B[ramp(((y*1024) + (x*32)), 1, 32)]
      }
    }
    for (x.outer: int32, 0, 32) {
      for (y.outer: int32, 0, 32) {
        for (x.inner.init: in



## Opt 5 : Write Thru

In [15]:
s = te.create_schedule(C.op)

# Allocate write cache
CC = s.cache_write(C, "global")

xo, yo, xi, yi = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn)

# Write cache is computed at yo
s[CC].compute_at(s[C], yo)

# New inner axes
xc, yc = s[CC].op.axis

(k,) = s[CC].op.reduce_axis
ko, ki = s[CC].split(k, factor=4)
s[CC].reorder(ko, xc, ki, yc)
s[CC].unroll(ki)
s[CC].vectorize(yc)

x, y, z = s[packedB].op.axis
s[packedB].vectorize(z)
s[packedB].parallel(x)

eval_op(s, [A, B, C], target,'matmul', 'Write Thru', log)

# Here is the generated IR after write cache blocking.
print(tvm.lower(s, [A, B, C], simple_mode=True))

Write Thru: 0.052422
@main = primfn(A_1: handle, B_1: handle, C_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {A: Buffer(A_2: Pointer(float32), float32, [1048576], []),
             B: Buffer(B_2: Pointer(float32), float32, [1048576], []),
             C: Buffer(C_2: Pointer(float32), float32, [1048576], [])}
  buffer_map = {A_1: A, B_1: B, C_1: C}
  preflattened_buffer_map = {A_1: A_3: Buffer(A_2, float32, [1024, 1024], []), B_1: B_3: Buffer(B_2, float32, [1024, 1024], []), C_1: C_3: Buffer(C_2, float32, [1024, 1024], [])} {
  allocate(packedB: Pointer(global float32x32), float32x32, [32768]), storage_scope = global;
  allocate(C.global: Pointer(global float32), float32, [1024]), storage_scope = global {
    for (x: int32, 0, 32) "parallel" {
      for (y: int32, 0, 1024) {
        packedB_1: Buffer(packedB, float32x32, [32768], [])[((x*1024) + y)] = B[ramp(((y*1024) + (x*32)), 1, 32)]
      }
    }
    for (x.out

In [16]:
# parallel
s[C].parallel(xo)

x, y, z = s[packedB].op.axis
s[packedB].vectorize(z)
s[packedB].parallel(x)

eval_op(s, [A, B, C], target,'matmul', 'Parallel', log)

# Here is the generated IR after parallelization.
print(tvm.lower(s, [A, B, C], simple_mode=True))

Parallel: 0.020728
@main = primfn(A_1: handle, B_1: handle, C_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {A: Buffer(A_2: Pointer(float32), float32, [1048576], []),
             B: Buffer(B_2: Pointer(float32), float32, [1048576], []),
             C: Buffer(C_2: Pointer(float32), float32, [1048576], [])}
  buffer_map = {A_1: A, B_1: B, C_1: C}
  preflattened_buffer_map = {A_1: A_3: Buffer(A_2, float32, [1024, 1024], []), B_1: B_3: Buffer(B_2, float32, [1024, 1024], []), C_1: C_3: Buffer(C_2, float32, [1024, 1024], [])} {
  allocate(packedB: Pointer(global float32x32), float32x32, [32768]), storage_scope = global {
    for (x: int32, 0, 32) "parallel" {
      for (y: int32, 0, 1024) {
        packedB_1: Buffer(packedB, float32x32, [32768], [])[((x*1024) + y)] = B[ramp(((y*1024) + (x*32)), 1, 32)]
      }
    }
    for (x.outer: int32, 0, 32) "parallel" {
      allocate(C.global: Pointer(global float32), float32, 

In [17]:
log

[('none', 2.010295384),
 ('blking', 0.135766278),
 ('vectorize', 0.137349809),
 ('Loop permu', 0.048923863),
 ('Array packing', 0.054949382),
 ('Write Thru', 0.052422102000000005),
 ('Parallel', 0.020728181)]