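When the input and the convolution kernel each have $c_{i}$ channels, we perform a cross-correlation between the input's 2-D array and the kernel's 2-D array on every channel, then add the $c_{i}$ resulting 2-D outputs together across channels.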
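Written as a formula, this reads as follows. The symbols $\mathbf{X}_c$ and $\mathbf{K}_c$ (the $c$-th channel of the input and of the kernel) and $\star$ (2-D cross-correlation) are notation assumed here for illustration:

$$\mathbf{Y} = \sum_{c=1}^{c_{i}} \mathbf{X}_c \star \mathbf{K}_c$$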

In [1]:
import torch
import torch.nn as nn
import d2lzh_pytorch as d2l

# Multiple input channels

In [2]:
def corr2d_multi_in(X, K):
    # Cross-correlate channel 0 first, then accumulate the remaining channels.
    res = d2l.corr2d(X[0, :, :], K[0, :, :])
    for i in range(1, X.shape[0]):
        res += d2l.corr2d(X[i, :, :], K[i, :, :])
    return res

In [3]:
X = torch.tensor([[[0, 1, 2], [3, 4, 5], [6, 7, 8]],
                  [[1, 2, 3], [4, 5, 6], [7, 8, 9]]])
K = torch.tensor([[[0, 1], [2, 3]], [[1, 2], [3, 4]]])
corr2d_multi_in(X, K)

tensor([[ 56.,  72.],
        [104., 120.]])
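As a sanity check, the result can be compared with PyTorch's built-in `torch.nn.functional.conv2d`, which also computes a cross-correlation. The sketch below is self-contained: the local `corr2d` is an assumed stand-in for `d2l.corr2d`, written here only so the example runs without `d2lzh_pytorch`.

```python
import torch
import torch.nn.functional as F

def corr2d(X, K):
    """Plain 2-D cross-correlation (assumed stand-in for d2l.corr2d)."""
    h, w = K.shape
    Y = torch.zeros(X.shape[0] - h + 1, X.shape[1] - w + 1)
    for i in range(Y.shape[0]):
        for j in range(Y.shape[1]):
            Y[i, j] = (X[i:i + h, j:j + w] * K).sum()
    return Y

def corr2d_multi_in(X, K):
    # Sum the per-channel cross-correlations over the input-channel axis.
    return sum(corr2d(x, k) for x, k in zip(X, K))

X = torch.tensor([[[0, 1, 2], [3, 4, 5], [6, 7, 8]],
                  [[1, 2, 3], [4, 5, 6], [7, 8, 9]]], dtype=torch.float32)
K = torch.tensor([[[0, 1], [2, 3]], [[1, 2], [3, 4]]], dtype=torch.float32)

manual = corr2d_multi_in(X, K)
# F.conv2d expects input (N, c_i, h, w) and weight (c_o, c_i, k_h, k_w),
# so add a batch axis to X and an output-channel axis to K.
builtin = F.conv2d(X.unsqueeze(0), K.unsqueeze(0)).squeeze(0).squeeze(0)
print(torch.allclose(manual, builtin))
```

Both paths should agree with the tensor printed above, since summing over input channels is exactly what the built-in operator does for a single output channel.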

# Multiple output channels

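To produce an output with multiple channels, we create a kernel array of shape $c_i \times k_h \times k_w$ for each output channel and concatenate these arrays along the output-channel dimension, so the convolution kernel has shape $c_{o} \times c_{i} \times k_{h} \times k_{w}$.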

In [4]:
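def corr2d_multi_in_out(X, K):
    # Run the multi-input cross-correlation once per output-channel kernel k,
    # then stack the results along a new leading (output-channel) dimension.
    return torch.stack([corr2d_multi_in(X, k) for k in K])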

In [5]:
K = torch.stack([K, K + 1, K + 2])
K, K.shape

(tensor([[[[0, 1],
           [2, 3]],
 
          [[1, 2],
           [3, 4]]],
 
 
         [[[1, 2],
           [3, 4]],
 
          [[2, 3],
           [4, 5]]],
 
 
         [[[2, 3],
           [4, 5]],
 
          [[3, 4],
           [5, 6]]]]), torch.Size([3, 2, 2, 2]))

In [6]:
corr2d_multi_in_out(X, K)

tensor([[[ 56.,  72.],
         [104., 120.]],

        [[ 76., 100.],
         [148., 172.]],

        [[ 96., 128.],
         [192., 224.]]])
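The shape bookkeeping can be checked in isolation. The sketch below rebuilds the same `X` and stacked `K` and passes them to `torch.nn.functional.conv2d` (used here in place of `corr2d_multi_in_out` so the snippet is self-contained); with no padding and stride 1, the output should have shape $c_o \times (h - k_h + 1) \times (w - k_w + 1)$.

```python
import torch
import torch.nn.functional as F

X = torch.tensor([[[0, 1, 2], [3, 4, 5], [6, 7, 8]],
                  [[1, 2, 3], [4, 5, 6], [7, 8, 9]]], dtype=torch.float32)
K0 = torch.tensor([[[0, 1], [2, 3]], [[1, 2], [3, 4]]], dtype=torch.float32)
K = torch.stack([K0, K0 + 1, K0 + 2])  # (c_o, c_i, k_h, k_w) = (3, 2, 2, 2)

c_o, c_i, k_h, k_w = K.shape
_, h, w = X.shape

# conv2d needs a batch axis on the input; remove it again afterwards.
Y = F.conv2d(X.unsqueeze(0), K).squeeze(0)
expected_shape = (c_o, h - k_h + 1, w - k_w + 1)
print(Y.shape, expected_shape)
```

The first output channel uses the original kernel, so it reproduces the single-output result from the previous section.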

# $1 \times 1$ convolutional layer

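A $1 \times 1$ convolution loses the ability of a convolutional layer to recognize patterns formed by adjacent elements along the height and width dimensions.
In fact, its main computation happens along the channel dimension (input -> $1 \times 1$ convolution -> output), where it increases or decreases the number of channels.

It can therefore be used to control model complexity.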
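To make the channel adjustment concrete, here is a minimal sketch using `nn.Conv2d` with `kernel_size=1`. The tensor sizes (8 input channels, a 5 x 5 spatial grid) and the names `reduce` and `expand` are arbitrary choices for illustration, not values from this notebook.

```python
import torch
import torch.nn as nn

x = torch.rand(1, 8, 5, 5)  # (batch, c_i, h, w); sizes chosen arbitrarily

# A 1 x 1 convolution only mixes channels; height and width are unchanged.
reduce = nn.Conv2d(8, 2, kernel_size=1, bias=False)   # 8 -> 2 channels
expand = nn.Conv2d(8, 16, kernel_size=1, bias=False)  # 8 -> 16 channels

y_reduce = reduce(x)
y_expand = expand(x)
print(y_reduce.shape, y_expand.shape)

# Without bias, the parameter count is c_o * c_i (times 1 * 1).
n_reduce = sum(p.numel() for p in reduce.parameters())
n_expand = sum(p.numel() for p in expand.parameters())
print(n_reduce, n_expand)
```

Because the parameter count scales only with $c_o \times c_i$, shrinking the channel count before a more expensive layer is one common way such a layer is used to limit model size.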

In [10]:
def corr2d_multi_in_out_1x1(X, K):
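    # c_i: number of input channels; h, w: input height and width
    c_i, h, w = X.shape
    c_o = K.shape[0]        # number of output channels
    X = X.view(c_i, h * w)  # flatten each channel into one row
    K = K.view(c_o, c_i)    # drop the trailing 1 x 1 kernel dimensions
    Y = torch.mm(K, X)      # matrix multiply mixes channels at every pixel
    return Y.view(c_o, h, w)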

In [11]:
X = torch.rand(3, 3, 3)
K = torch.rand(2, 3, 1, 1)

print(X, X.view(3, 3 * 3))
print(K, K.view(2, 3))

Y1 = corr2d_multi_in_out_1x1(X, K)
Y2 = corr2d_multi_in_out(X, K)
print(Y1, Y2)
print((Y1 - Y2).norm().item())

tensor([[[0.5865, 0.6318, 0.8863],
         [0.9830, 0.4780, 0.2027],
         [0.6481, 0.5290, 0.6192]],

        [[0.4910, 0.3826, 0.2333],
         [0.6710, 0.0421, 0.7154],
         [0.2667, 0.7100, 0.4201]],

        [[0.7426, 0.1567, 0.6274],
         [0.9437, 0.2318, 0.2500],
         [0.5883, 0.9732, 0.0462]]]) tensor([[0.5865, 0.6318, 0.8863, 0.9830, 0.4780, 0.2027, 0.6481, 0.5290, 0.6192],
        [0.4910, 0.3826, 0.2333, 0.6710, 0.0421, 0.7154, 0.2667, 0.7100, 0.4201],
        [0.7426, 0.1567, 0.6274, 0.9437, 0.2318, 0.2500, 0.5883, 0.9732, 0.0462]])
tensor([[[[0.0917]],

         [[0.7603]],

         [[0.2748]]],


        [[[0.5891]],

         [[0.9025]],

         [[0.2920]]]]) tensor([[0.0917, 0.7603, 0.2748],
        [0.5891, 0.9025, 0.2920]])
tensor([[[0.6311, 0.3919, 0.4310],
         [0.8596, 0.1395, 0.6312],
         [0.4238, 0.8557, 0.3889]],

        [[1.0054, 0.7633, 0.9158],
         [1.4603, 0.3873, 0.8381],
         [0.7943, 1.2366, 0.7574]]]) tensor([[[0.63