In [2]:
import torch
import numpy as np
import matplotlib

In [3]:
def corr2d(input_tensor: torch.Tensor, kernel: torch.Tensor):
    """Compute the 2-D cross-correlation of a single-channel input with a kernel.

    This is the "valid" cross-correlation used by convolution layers: no
    padding, stride 1, and the kernel is not flipped.

    Args:
        input_tensor: 2-D tensor of shape (m, n).
        kernel: 2-D tensor of shape (a, b).

    Returns:
        Tensor of shape (m - a + 1, n - b + 1). Its dtype is the promotion of
        the input/kernel dtype with the default floating dtype. Integer inputs
        therefore give a float result, and float64 inputs keep full precision.
        The result is allocated on the input's device.
    """
    m, n = input_tensor.shape
    a, b = kernel.shape
    # Promote with the default float dtype: integer inputs still yield floats,
    # but float64 inputs are no longer truncated to float32.
    out_dtype = torch.promote_types(
        torch.result_type(input_tensor, kernel), torch.get_default_dtype()
    )
    y = torch.zeros(m - a + 1, n - b + 1, dtype=out_dtype, device=input_tensor.device)
    for i in range(y.shape[0]):
        for j in range(y.shape[1]):
            # Elementwise product of the a x b window with the kernel, summed.
            y[i, j] = torch.sum(input_tensor[i:i + a, j:j + b] * kernel)
    return y

In [4]:
def corr2d_multi_in(input_tensor: torch.Tensor, kernel: torch.Tensor):
    """Multi-input-channel 2-D cross-correlation.

    Cross-correlates each input channel with its matching kernel channel
    (via ``corr2d``) and sums the per-channel results into one 2-D output.

    Args:
        input_tensor: tensor of shape (c_i, h, w).
        kernel: tensor of shape (c_i, k_h, k_w). It must have the same number
            of channels as ``input_tensor``.

    Returns:
        2-D tensor of shape (h - k_h + 1, w - k_w + 1).

    Raises:
        ValueError: if the input and kernel channel counts differ.
    """
    # zip() would silently drop the extra channels on a mismatch and return a
    # wrong result, so reject mismatched channel counts explicitly.
    if input_tensor.shape[0] != kernel.shape[0]:
        raise ValueError(
            f"channel mismatch: input has {input_tensor.shape[0]} channels, "
            f"kernel has {kernel.shape[0]}"
        )
    # Simply iterate over the channel pairs and add up the results.
    return sum(corr2d(x, k) for x, k in zip(input_tensor, kernel))

In [5]:
# Two 3x3 input channels and a matching two-channel 2x2 kernel.
x = torch.from_numpy(np.stack([np.arange(0, 9), np.arange(1, 10)]).reshape(2, 3, 3))
k = torch.from_numpy(np.stack([np.arange(0, 4), np.arange(1, 5)]).reshape(2, 2, 2))
corr2d_multi_in(x, k)

tensor([[ 56.,  72.],
        [104., 120.]])

In [17]:
# Inspect the raw (2, 9) array that the In[5] cell reshapes into two 3x3 channels.
np.stack([np.arange(9), np.arange(1, 10)])

array([[0, 1, 2, 3, 4, 5, 6, 7, 8],
       [1, 2, 3, 4, 5, 6, 7, 8, 9]])

### 多输出通道

这里每个输出通道都对应一个形状为 $c_i \times k_h \times k_w$ 的卷积核。把它们沿第一维堆叠起来，卷积核就成为一个形状为 $c_o \times c_i \times k_h \times k_w$ 的四维张量（与下面 `k2.shape` 的维度顺序一致）。

In [6]:
def corr2d_multi_out(input_tensor: torch.Tensor, kernel: torch.Tensor):
    """Multi-output-channel 2-D cross-correlation.

    ``kernel`` is iterated along its first dimension. Each slice is a
    multi-input-channel kernel that produces one output channel through
    ``corr2d_multi_in``. The per-channel outputs are stacked along a new
    leading dimension.
    """
    per_channel = []
    for out_kernel in kernel:
        per_channel.append(corr2d_multi_in(input_tensor, out_kernel))
    return torch.stack(per_channel)

In [11]:
# Three output channels: the original kernel k shifted by 0, 1 and 2.
k2 = torch.stack([k + shift for shift in range(3)])
k2.shape

torch.Size([3, 2, 2, 2])

In [12]:
# Output channel 0 equals corr2d_multi_in(x, k), because k2[0] is k + 0.
corr2d_multi_out(input_tensor=x, kernel=k2)

tensor([[[ 56.,  72.],
         [104., 120.]],

        [[ 76., 100.],
         [148., 172.]],

        [[ 96., 128.],
         [192., 224.]]])

### 1 x 1 卷积核

它不融合空间上相邻的像素，只在通道维度上做线性组合（相当于在每个像素位置上应用同一个全连接层），因此常用来改变通道数。

卷积核的形状为 $c_o \times c_i \times 1 \times 1$，等价于一个 $c_o \times c_i$ 的权重矩阵。

In [13]:
def corr2d_1x1(x: torch.Tensor, k: torch.Tensor):
    """1x1 convolution expressed as a single matrix multiplication.

    ``x`` is unpacked as (c_i, h, w). ``k`` is reshaped to a (c_o, c_i) weight
    matrix, where c_o is its first dimension. Every pixel's channel vector is
    multiplied by that matrix, and the result has shape (c_o, h, w).
    """
    n_in, height, width = x.shape
    n_out = k.shape[0]
    pixels = x.reshape(n_in, height * width)  # one column per pixel
    weight = k.reshape(n_out, n_in)
    return (weight @ pixels).reshape(n_out, height, width)

In [15]:
# Fix the RNG seed so the random inputs (and the outputs below) are reproducible.
torch.manual_seed(0)
X = torch.normal(0, 1, (3, 3, 3))
K = torch.normal(0, 1, (2, 3, 1, 1))

Y1 = corr2d_1x1(X, K)
Y2 = corr2d_multi_out(X, K)

# The matrix-multiplication form must agree with the explicit
# cross-correlation up to float32 rounding.
assert torch.allclose(Y1, Y2, atol=1e-6)

Y1, Y2

(tensor([[[ 4.5196,  5.6132, -4.0586],
          [ 1.9676,  1.9140,  0.7529],
          [-4.1370,  0.8033,  1.8129]],
 
         [[ 0.6334, -2.3787,  1.4956],
          [-2.3898, -1.9273, -0.5111],
          [ 4.5911,  0.2217, -2.0058]]]),
 tensor([[[ 4.5196,  5.6132, -4.0586],
          [ 1.9676,  1.9140,  0.7529],
          [-4.1370,  0.8033,  1.8129]],
 
         [[ 0.6334, -2.3787,  1.4956],
          [-2.3898, -1.9273, -0.5111],
          [ 4.5911,  0.2217, -2.0058]]]))