In [3]:
import numpy as np
import tvm
from tvm import te

# ---- Problem configuration -------------------------------------------------
# Layer hyper-parameters. The input is laid out as (height, width, channel, batch)
# and the filter as (kernel_h, kernel_w, in_channel, out_channel).
batch = 8          # number of images
in_channel = 256   # input channels
out_channel = 512  # output channels (number of filters)
in_size = 14       # spatial height == width of the input
kernel = 3         # spatial height == width of the filter
pad = 1            # zero padding added on each spatial border
stride = 1         # spatial stride

padded_size = in_size + 2 * pad
# Standard convolution output extent (floor division when stride does not divide evenly).
out_size = (padded_size - kernel) // stride + 1

# ---- Algorithm -------------------------------------------------------------
A = te.placeholder((in_size, in_size, in_channel, batch), name="A")
W = te.placeholder((kernel, kernel, in_channel, out_channel), name="W")


def _pad_value(yy, xx, cc, nn):
    """Element (yy, xx, cc, nn) of the zero-padded input.

    Inside the original extent it reads ``A`` shifted by ``pad``; in the border it is 0.0.
    The argument names become the IR loop/axis names, so they are kept stable.
    """
    inside = tvm.tir.all(yy >= pad, yy - pad < in_size, xx >= pad, xx - pad < in_size)
    return tvm.tir.if_then_else(inside, A[yy - pad, xx - pad, cc, nn], tvm.tir.const(0.0, "float32"))


Apad = te.compute((padded_size, padded_size, in_channel, batch), _pad_value, name="Apad")

# Reduction axes: input channels and the two spatial kernel taps.
rc = te.reduce_axis((0, in_channel), name="rc")
ry = te.reduce_axis((0, kernel), name="ry")
rx = te.reduce_axis((0, kernel), name="rx")


def _conv_value(yy, xx, ff, nn):
    """Element (yy, xx, ff, nn) of the output: sum over (ry, rx, rc) of padded input * filter."""
    window_elem = Apad[yy * stride + ry, xx * stride + rx, rc, nn]
    return te.sum(window_elem * W[ry, rx, rc, ff], axis=[ry, rx, rc])


B = te.compute((out_size, out_size, out_channel, batch), _conv_value, name="B")

# Default schedule over the output op; the following cells build and lower it as-is.
s = te.create_schedule(B.op)

In [4]:
# Fixed seed so the random inputs (and the numerical check below) are reproducible
# under "Restart Kernel -> Run All".
SEED = 0
rng = np.random.default_rng(SEED)

# Compile the schedule for the host CPU.
func = tvm.build(s, [A, W, B], "llvm")
dev = tvm.cpu()

a_np = rng.uniform(size=(in_size, in_size, in_channel, batch)).astype(A.dtype)
w_np = rng.uniform(size=(kernel, kernel, in_channel, out_channel)).astype(W.dtype)
a = tvm.nd.array(a_np, dev)
w = tvm.nd.array(w_np, dev)
b = tvm.nd.array(np.zeros((out_size, out_size, out_channel, batch), dtype=B.dtype), dev)

# One run: warms up the module and produces the output validated below.
func(a, w, b)

# Independent NumPy reference (float64). Zero-pad the two spatial axes, then accumulate
# one contraction per kernel tap. tensordot contracts the channel axis, giving
# (y, x, batch, out_channel), and moveaxis reorders that to the (y, x, out_channel, batch)
# layout of B. Loop names ky/kx avoid clobbering the TE reduce axes ry/rx.
a_pad_np = np.pad(a_np.astype(np.float64), ((pad, pad), (pad, pad), (0, 0), (0, 0)))
b_ref = np.zeros((out_size, out_size, out_channel, batch))
span = stride * (out_size - 1) + 1  # extent of padded input touched by one tap
for ky in range(kernel):
    for kx in range(kernel):
        window = a_pad_np[ky : ky + span : stride, kx : kx + span : stride]
        b_ref += np.moveaxis(np.tensordot(window, w_np[ky, kx], axes=([2], [0])), 2, 3)
# float32 accumulation in TVM vs float64 reference, hence a relative tolerance.
np.testing.assert_allclose(b.numpy(), b_ref, rtol=1e-4)

evaluator = func.time_evaluator(func.entry_name, dev, number=1)
print("Convolution: %f ms" % (evaluator(a, w, b).mean * 1e3))

Convolution: 3560.534293 ms


In [6]:
# Show the lowered TIR for the current schedule `s`.
lowered_mod = tvm.lower(s, [A, W, B], simple_mode=True)
print(lowered_mod)

@main = primfn(A_1: handle, W_1: handle, B_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {A: Buffer(A_2: Pointer(float32), float32, [401408], []),
             W: Buffer(W_2: Pointer(float32), float32, [1179648], []),
             B: Buffer(B_2: Pointer(float32), float32, [802816], [])}
  buffer_map = {A_1: A, W_1: W, B_1: B}
  preflattened_buffer_map = {A_1: A_3: Buffer(A_2, float32, [14, 14, 256, 8], []), W_1: W_3: Buffer(W_2, float32, [3, 3, 256, 512], []), B_1: B_3: Buffer(B_2, float32, [14, 14, 512, 8], [])} {
  allocate(Apad: Pointer(global float32), float32, [524288]), storage_scope = global {
    for (yy: int32, 0, 16) {
      for (xx: int32, 0, 16) {
        for (cc: int32, 0, 256) {
          for (nn: int32, 0, 8) {
            let cse_var_2: int32 = (xx*2048)
            let cse_var_1: int32 = (cc*8)
            Apad_1: Buffer(Apad, float32, [524288], [])[((((yy*32768) + cse_var_2) + cse_var_1) + nn)] = 