In [5]:
# Implicit-GEMM convolution for tensors laid out as [B, C, H, W]
import numpy as np
import torch
from torch import nn
import torch.nn.functional  as F

# Problem size. The input tensor is [batch, channels, height, width] and the
# kernel is [out channels, in channels, height, width].
inputB, intputC, inputH, inputW = 8, 5, 7, 7
kernelO, kernelI, kernelH, kernelW = 3, 5, 3, 3

# Both tensors hold consecutive integers 0, 1, 2, ... as int32, so every
# element is distinct and the data is the same on every run.
input = np.arange(inputB * intputC * inputH * inputW, dtype=np.int32).reshape(
    inputB, intputC, inputH, inputW
)
kernel = (
    np.arange(kernelO * kernelI * kernelH * kernelW)
    .reshape(kernelO, kernelI, kernelH, kernelW)
    .astype(np.int32)
)
# Alternative test kernel (all ones):
#   kernel = np.ones([kernelO, kernelI, kernelH, kernelW], dtype=np.int32)

In [6]:
# Shape of the input feature map: batch, channels, height, width.
inputB, inputC, inputH, inputW = input.shape
# Shape of the filter kernel: output channels, input channels, height, width.
kernelO, kernelI, kernelH, kernelW = kernel.shape
# The GEMM_K decoding below indexes the input with the kernel's channel index,
# so both channel counts have to agree.
assert kernelI == inputC, "kernel input channels must match input channels"

# Zero padding applied to every side of H and W, and the window stride.
padding = 0
stride = 1

# Output feature-map dimensions. Padding is added on both sides of each
# spatial axis, so it enters the formula twice:
#   out = floor((in + 2 * padding - kernel) / stride) + 1
outputB = inputB    # batch size
outputC = kernelO
outputH = (inputH + 2 * padding - kernelH) // stride + 1
outputW = (inputW + 2 * padding - kernelW) // stride + 1
# Start from zeros; every element is written by the loop below.
output = np.zeros((outputB, outputC, outputH, outputW), dtype=np.int32)

# Dimensions of the implicit GEMM C(M, N) = A(M, K) @ B(K, N).
GEMM_M = outputB * outputH * outputW
GEMM_N = outputC
GEMM_K = kernelI * kernelH * kernelW

# Element mapping between the virtual matrices and the tensors:
#   A(i, k) = input(n, ic, ih, iw)
#   B(k, j) = kernel(oc, ic, kh, kw)
#   C(i, j) = output(n, oc, oh, ow)
# Index mapping (inverted inside the loops to recover tensor coordinates):
#   i = n * outputH * outputW + oh * outputW + ow
#   j = oc
#   k = ic * kernelH * kernelW + kh * kernelW + kw
# The matrices are never materialised; A and B elements are fetched straight
# from `input` and `kernel`.
for i in range(GEMM_M):
    # Row i encodes only the output position (n, oh, ow), so decode it once
    # per row instead of once per multiply-accumulate.
    ob = i // (outputH * outputW)
    i_res = i % (outputH * outputW)
    oh = i_res // outputW
    ow = i_res % outputW

    for j in range(GEMM_N):
        # Column j is the output channel.
        oc = j
        acc = 0

        for k in range(GEMM_K):
            # Decode the reduction index into kernel coordinates (ic, kh, kw).
            ic = k // (kernelH * kernelW)
            k_res = k % (kernelH * kernelW)
            kh = k_res // kernelW
            kw = k_res % kernelW

            # Input coordinates in the unpadded tensor.
            ih = oh * stride - padding + kh
            iw = ow * stride - padding + kw

            # Coordinates inside the padded border read as zero and add
            # nothing. Without this check a negative index would wrap around
            # to the opposite edge of the input.
            if 0 <= ih < inputH and 0 <= iw < inputW:
                acc += input[ob, ic, ih, iw] * kernel[oc, ic, kh, kw]

        output[ob, oc, oh, ow] = acc

print("output.shape: ", output.shape)

# Reference result: run the same convolution through PyTorch so the
# hand-written implementation can be checked against it.
# The reshape calls restate the 4-D layouts before the arrays are wrapped as
# tensors.
inputTensor = torch.from_numpy(input.reshape(inputB, inputC, inputH, inputW))
weightTensor = torch.from_numpy(kernel.reshape(kernelO, kernelI, kernelH, kernelW))

pyoutput = F.conv2d(inputTensor, weightTensor, stride=stride, padding=padding)
print("pyoutput.shape: ", pyoutput.shape)

def to_numpy(tensor):
    """Convert a torch tensor to a numpy ndarray.

    A tensor that requires grad is detached first; the tensor is then moved
    to the CPU and converted.
    """
    if tensor.requires_grad:
        tensor = tensor.detach()
    return tensor.cpu().numpy()
# Both sides are built from int32 data with integer arithmetic, so they must
# agree element for element. A relative tolerance would accept large absolute
# errors at the magnitudes produced here, so compare exactly.
np.testing.assert_array_equal(to_numpy(pyoutput), output)
print("PASS!")

output.shape:  (8, 3, 5, 5)
pyoutput.shape:  torch.Size([8, 3, 5, 5])
PASS!
