In [1]:
import IPython
import numpy as np
import torch
import tvm
from tvm.ir.module import IRModule
from tvm.script import tir as T

In [2]:
# Problem size: batch N, input channels CI, an H x W image, CO output channels
# and a K x K kernel. The output is a "valid" convolution (stride 1, no
# padding), so each spatial dimension shrinks by K - 1.
N, CI, H, W, CO, K = 1, 1, 8, 8, 2, 3
assert K <= H and K <= W, "kernel must fit inside the input"
OUT_H, OUT_W = H - K + 1, W - K + 1

# Explicit int64: np.arange's default integer dtype is platform-dependent
# (int32 on Windows with NumPy < 2), while the TVM kernel below declares its
# buffers as int64.
data = np.arange(N * CI * H * W, dtype=np.int64).reshape(N, CI, H, W)
weight = np.arange(CO * CI * K * K, dtype=np.int64).reshape(CO, CI, K, K)

In [3]:
# torch version: reference result computed with PyTorch's conv2d.
# The convolution runs in float32. In this example every product and partial
# sum is an integer far below 2**24, so float32 represents it exactly. Still,
# round (rather than truncate) before casting back, so that any float error
# cannot shift a value down by one.
data_torch = torch.as_tensor(data, dtype=torch.float32)
weight_torch = torch.as_tensor(weight, dtype=torch.float32)
conv_torch_f32 = torch.nn.functional.conv2d(data_torch, weight_torch).numpy()
conv_torch = np.rint(conv_torch_f32).astype(np.int64)
conv_torch

array([[[[ 474,  510,  546,  582,  618,  654],
         [ 762,  798,  834,  870,  906,  942],
         [1050, 1086, 1122, 1158, 1194, 1230],
         [1338, 1374, 1410, 1446, 1482, 1518],
         [1626, 1662, 1698, 1734, 1770, 1806],
         [1914, 1950, 1986, 2022, 2058, 2094]],

        [[1203, 1320, 1437, 1554, 1671, 1788],
         [2139, 2256, 2373, 2490, 2607, 2724],
         [3075, 3192, 3309, 3426, 3543, 3660],
         [4011, 4128, 4245, 4362, 4479, 4596],
         [4947, 5064, 5181, 5298, 5415, 5532],
         [5883, 6000, 6117, 6234, 6351, 6468]]]])

In [4]:
@tvm.script.ir_module
class MyConv:
  @T.prim_func
  def conv(A: T.Buffer[(1, 1, 8, 8), "int64"],
           W: T.Buffer[(2, 1, 3, 3), "int64"],
           C: T.Buffer[(1, 2, 6, 6), "int64"]):
    # Naive 2-D convolution, stride 1, no padding:
    #   C[b, k, i, j] = sum over (di, dj, q) of A[b, q, i + di, j + dj] * W[k, q, di, dj]
    # NOTE: the buffer shapes above are hard-coded and must agree with the
    # config cell (N, CI, H, W, CO, K, OUT_H, OUT_W). Inside this function `W`
    # names the weight buffer, not the image-width constant.
    T.func_attr({"global_symbol": "conv", "tir.noalias": True})
    for b, k, i, j, di, dj, q in T.grid(N, CO, OUT_H, OUT_W, K, K, CI):
        with T.block("Y"):
            # "S" marks a spatial (output) axis and "R" a reduction axis.
            # remap binds each block axis to its loop, taking the extent from
            # the grid, so the block domain cannot drift from the config.
            b_b, b_k, b_i, b_j, b_di, b_dj, b_q = T.axis.remap("SSSSRRR", [b, k, i, j, di, dj, q])
            with T.init():
                C[b_b, b_k, b_i, b_j] = T.int64(0)
            C[b_b, b_k, b_i, b_j] = C[b_b, b_k, b_i, b_j] + A[b_b, b_q, b_i + b_di, b_j + b_dj] * W[b_k, b_q, b_di, b_dj]
        

# Compile the IRModule for the host CPU and run it on the same inputs as the
# PyTorch reference.
rt_lib = tvm.build(MyConv, target="llvm")
# The kernel's buffers are declared int64; cast defensively in case the inputs
# were created with another integer dtype.
data_tvm = tvm.nd.array(data.astype(np.int64))
weight_tvm = tvm.nd.array(weight.astype(np.int64))
conv_tvm = tvm.nd.array(np.empty((N, CO, OUT_H, OUT_W), dtype=np.int64))
rt_lib["conv"](data_tvm, weight_tvm, conv_tvm)
# Both results are int64, so require exact equality rather than a float tolerance.
np.testing.assert_array_equal(conv_tvm.numpy(), conv_torch)