In [1]:
import os

import numpy as np
import tvm
from tvm import te, auto_scheduler

In [2]:
@auto_scheduler.register_workload  # the decorator registers this function as an auto_scheduler workload
def matmul_add(N, L, M, dtype):
    """Describe ``out = matmul(A, B) + C`` as a tensor-expression workload.

    Parameters
    ----------
    N, L, M : int
        Matrix extents: A is (N, L), B is (L, M), C and the output are (N, M).
    dtype : str
        Element type given to all three placeholders.

    Returns
    -------
    list
        ``[A, B, C, out]``: the three input placeholders followed by the
        output tensor.
    """
    lhs = te.placeholder((N, L), name="A", dtype=dtype)
    rhs = te.placeholder((L, M), name="B", dtype=dtype)
    addend = te.placeholder((N, M), name="C", dtype=dtype)

    # Reduction over the shared inner dimension of A and B.
    k = te.reduce_axis((0, L), name="k")

    # Argument names (i, j) are kept as-is: they become the axis names of the stage.
    def _dot(i, j):
        return te.sum(lhs[i, k] * rhs[k, j], axis=k)

    # enable automatic layout transform for tensor B
    matmul_attrs = {"layout_free_placeholders": [rhs]}
    matmul = te.compute((N, M), _dot, name="matmul", attrs=matmul_attrs)

    def _add(i, j):
        return matmul[i, j] + addend[i, j]

    out = te.compute((N, M), _add, name="out")

    return [lhs, rhs, addend, out]

In [3]:
# Target string selects the LLVM backend with -mcpu=core-avx2.
target = tvm.target.Target("llvm -mcpu=core-avx2")
# All three matrix extents share the same size.
N, L, M = 1024, 1024, 1024
task = auto_scheduler.SearchTask(
    func=matmul_add,
    args=(N, L, M, "float32"),
    target=target,
)

# Show the compute DAG that the task derived from the workload
print("Computational DAG:")
print(task.compute_dag)

Computational DAG:
A = PLACEHOLDER [1024, 1024]
B = PLACEHOLDER [1024, 1024]
matmul(i, j) += (A[i, k]*B[k, j])
C = PLACEHOLDER [1024, 1024]
out(i, j) = (matmul[i, j] + C[i, j])



In [4]:
# File that the RecordToFile callback writes to; later cells read it back.
log_file = "matmul.json"
record_to_log = auto_scheduler.RecordToFile(log_file)
tune_option = auto_scheduler.TuningOptions(
    num_measure_trials=10,  # number of measurement trials for the search
    verbose=2,
    measure_callbacks=[record_to_log],
)

In [5]:
# Launch the schedule search with the options configured above
task.tune(tune_option)
# Read the log file back and unpack the best result into a schedule and its tensors
best_result = task.apply_best(log_file)
sch, args = best_result

----------------------------------------------------------------------
------------------------------  [ Search ]
----------------------------------------------------------------------
Generate Sketches		#s: 3
Sample Initial Population	#s: 2020	fail_ct: 1	Time elapsed: 0.28
GA Iter: 0	Max score: 0.9997	Min score: 0.9464	#Pop: 128	#M+: 0	#M-: 0
GA Iter: 4	Max score: 0.9999	Min score: 0.9868	#Pop: 128	#M+: 1369	#M-: 76
EvolutionarySearch		#s: 128	Time elapsed: 1.17
----------------------------------------------------------------------
------------------------------  [ Measure ]
----------------------------------------------------------------------
Get 10 programs to measure:
..........



**********
No: 1	GFLOPS: 64.05 / 64.05	results: MeasureResult(cost:[0.0335], error_no:0, all_cost:0.42, Tstamp:1658202492.87)
Placeholder: A, B, C
parallel i.0@j.0@i.1@j.1@ (0,64)
  matmul auto_unroll: 16
  for k.0 (0,32)
    for i.2 (0,4)
      for j.2 (0,256)
        for k.1 (0,32)
          for i.3 (0,8)
            vectorize j.3 (0,2)
              matmul = ...
  for i.2 (0,32)
    for j.2 (0,512)
      out = ...

No: 2	GFLOPS: 110.46 / 110.46	results: MeasureResult(cost:[0.0195], error_no:0, all_cost:0.42, Tstamp:1658202493.14)
Placeholder: A, B, C
parallel i.0@j.0@ (0,256)
  for j.1 (0,2)
    matmul auto_unroll: 16
    for k.0 (0,16)
      for i.2 (0,8)
        for j.2 (0,32)
          for k.1 (0,64)
            vectorize j.3 (0,8)
              matmul = ...
    for i.2 (0,8)
      for j.2 (0,256)
        out = ...

No: 3	GFLOPS: 101.34 / 110.46	results: MeasureResult(cost:[0.0212], error_no:0, all_cost:0.41, Tstamp:1658202493.41)
Placeholder: A, B, C
parallel i.0 (0,4)
  for j.0

In [6]:
# Header first, then lower the tuned schedule and print the result.
print("Lowered TIR:")
lowered = tvm.lower(sch, args, simple_mode=True)
print(lowered)

Lowered TIR:
@main = primfn(A_1: handle, B_1: handle, C_1: handle, out_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {A: Buffer(A_2: Pointer(float32), float32, [1048576], []),
             B: Buffer(B_2: Pointer(float32), float32, [1048576], []),
             C: Buffer(C_2: Pointer(float32), float32, [1048576], []),
             out: Buffer(out_2: Pointer(float32), float32, [1048576], [])}
  buffer_map = {A_1: A, B_1: B, C_1: C, out_1: out}
  preflattened_buffer_map = {A_1: A_3: Buffer(A_2, float32, [1024, 1024], []), B_1: B_3: Buffer(B_2, float32, [1024, 1024], []), C_1: C_3: Buffer(C_2, float32, [1024, 1024], []), out_1: out_3: Buffer(out_2, float32, [1024, 1024], [])} {
  allocate(auto_scheduler_layout_transform: Pointer(global float32), float32, [1048576]), storage_scope = global;
  allocate(matmul: Pointer(global float32), float32, [1048576]), storage_scope = global {
    for (ax0.ax1.fused.ax2.fused: int32, 0

In [7]:
# Compile the tuned schedule for the selected target.
func = tvm.build(sch, args, target)

# Use a seeded generator so the correctness check runs on the same inputs
# every time the notebook is re-executed (the global NumPy RNG is unseeded).
rng = np.random.default_rng(0)
a_np = rng.uniform(size=(N, L)).astype(np.float32)
b_np = rng.uniform(size=(L, M)).astype(np.float32)
c_np = rng.uniform(size=(N, M)).astype(np.float32)
# NumPy reference result that the compiled operator is compared against.
out_np = a_np.dot(b_np) + c_np

dev = tvm.cpu()
a_tvm = tvm.nd.array(a_np, device=dev)
b_tvm = tvm.nd.array(b_np, device=dev)
c_tvm = tvm.nd.array(c_np, device=dev)
# Output buffer filled in by the compiled function.
out_tvm = tvm.nd.empty(out_np.shape, device=dev)
func(a_tvm, b_tvm, c_tvm, out_tvm)

# Check results
np.testing.assert_allclose(out_np, out_tvm.numpy(), rtol=1e-3)

# Evaluate execution time; report the median of the collected timings.
evaluator = func.time_evaluator(func.entry_name, dev, min_repeat_ms=500)
print(
    "Execution time of this operator: %f s"
    % (np.median(evaluator(a_tvm, b_tvm, c_tvm, out_tvm).results))
)



Execution time of this operator: 0.007376 s


In [8]:
# Header first, then fetch the best schedule from the log as Python code and print it.
print("Equivalent python schedule:")
best_schedule_code = task.print_best(log_file)
print(best_schedule_code)

Equivalent python schedule:
matmul_i, matmul_j, matmul_k = tuple(matmul.op.axis) + tuple(matmul.op.reduce_axis)
out_i, out_j = tuple(out.op.axis) + tuple(out.op.reduce_axis)
matmul_i_o_i, matmul_i_i = s[matmul].split(matmul_i, factor=2)
matmul_i_o_o_i, matmul_i_o_i = s[matmul].split(matmul_i_o_i, factor=1)
matmul_i_o_o_o, matmul_i_o_o_i = s[matmul].split(matmul_i_o_o_i, factor=512)
matmul_j_o_i, matmul_j_i = s[matmul].split(matmul_j, factor=8)
matmul_j_o_o_i, matmul_j_o_i = s[matmul].split(matmul_j_o_i, factor=4)
matmul_j_o_o_o, matmul_j_o_o_i = s[matmul].split(matmul_j_o_o_i, factor=1)
matmul_k_o, matmul_k_i = s[matmul].split(matmul_k, factor=2)
s[matmul].reorder(matmul_i_o_o_o, matmul_j_o_o_o, matmul_i_o_o_i, matmul_j_o_o_i, matmul_k_o, matmul_i_o_i, matmul_j_o_i, matmul_k_i, matmul_i_i, matmul_j_i)
s[out].compute_root()
matmul_i_o_o_o_j_o_o_o_fused_i_o_o_i_fused = s[matmul].fuse(matmul_i_o_o_o, matmul_j_o_o_o, matmul_i_o_o_i)
s[matmul].parallel(matmul_i_o_o_o_j_o_o_o_fused_i_o_o_i_f