In [2]:
import torch
import torch.nn as nn

def corr2d(X, K):
    """Compute the 2-D cross-correlation of a matrix with a kernel.

    The kernel is slid over the input with stride 1 and no padding; each
    output element is the sum of the element-wise product between the kernel
    and the input window it currently covers.

    Args:
        X: 2-D input tensor of shape (rows, cols).
        K: 2-D kernel tensor of shape (h, w).

    Returns:
        Tensor of shape (rows - h + 1, cols - w + 1). Its dtype is the
        promotion of the operand dtypes with the default floating dtype, so
        float32/integer inputs still yield the default float dtype while
        float64 inputs are no longer silently downcast.
    """
    h, w = K.shape
    # Promote with the default dtype so integer/half inputs behave exactly as
    # before, while wider inputs (e.g. float64) keep their full precision.
    out_dtype = torch.promote_types(torch.result_type(X, K),
                                    torch.get_default_dtype())
    Y = torch.zeros(X.shape[0] - h + 1, X.shape[1] - w + 1, dtype=out_dtype)

    for i in range(Y.shape[0]):
        for j in range(Y.shape[1]):
            # Window of X aligned with the kernel at offset (i, j).
            Y[i, j] = (X[i:i + h, j:j + w] * K).sum()

    return Y

In [23]:
def corr2d_multi_in(X, K):
    """Cross-correlate a multi-channel input with a multi-channel kernel.

    X and K are iterated in lockstep along their first dimension; each pair
    of slices is passed to ``corr2d`` and the per-channel results are added
    together. Returns the integer 0 when there is nothing to iterate over.
    """
    total = 0
    for channel, kernel in zip(X, K):
        total = total + corr2d(channel, kernel)
    return total

# Two-channel 3x3 input: channel 0 holds 0..8 and channel 1 holds 1..9 (row-major).
X = torch.stack((torch.arange(9.0).reshape(3, 3),
                 torch.arange(1.0, 10.0).reshape(3, 3)), dim=0)
# Matching two-channel 2x2 kernel: channel 0 holds 0..3 and channel 1 holds 1..4.
K = torch.stack((torch.arange(4.0).reshape(2, 2),
                 torch.arange(1.0, 5.0).reshape(2, 2)), dim=0)

print(corr2d_multi_in(X, K))

tensor([[ 56.,  72.],
        [104., 120.]])


In [24]:
def corr2d_multi_in_out(X, K):
    """Cross-correlate a multi-channel input with a bank of kernels.

    Each slice of K along its first dimension is applied to X with
    ``corr2d_multi_in``; the resulting maps are stacked along a new leading
    dimension, one entry per slice of K.
    """
    feature_maps = [corr2d_multi_in(X, kernel) for kernel in K]
    return torch.stack(feature_maps, dim=0)

# Build a bank of three output kernels from the multi-input kernel K.
# The bank gets its own name instead of overwriting K, so re-running this
# cell always starts from the same K and does not stack an already stacked
# tensor.
K_out = torch.stack((K, K + 1, K + 2), dim=0)
print(K_out.shape)

print(corr2d_multi_in_out(X, K_out))

torch.Size([3, 2, 2, 2])
tensor([[[ 56.,  72.],
         [104., 120.]],

        [[ 76., 100.],
         [148., 172.]],

        [[ 96., 128.],
         [192., 224.]]])


In [26]:
def corr2d_multi_in_out_1x1(X, K):
    """Apply a 1x1 kernel bank to a multi-channel input via one matrix product.

    X is unpacked as (c_i, h, w) and flattened to (c_i, h * w); K is reshaped
    to (c_o, c_i) where c_o is K's leading dimension. Their matrix product is
    reshaped back to (c_o, h, w) and returned.
    """
    in_channels, height, width = X.shape
    out_channels = K.shape[0]

    # One column per spatial position, one row per input channel.
    pixels = X.reshape((in_channels, height * width))
    # One row of channel-mixing weights per output channel.
    weights = K.reshape((out_channels, in_channels))

    return (weights @ pixels).reshape((out_channels, height, width))

# Seed the RNG so the random inputs, and therefore this check, are
# reproducible on a fresh Restart & Run All.
torch.manual_seed(0)
X = torch.normal(0, 1, (3, 3, 3))
K = torch.normal(0, 1, (2, 3, 1, 1))

Y1 = corr2d_multi_in_out_1x1(X, K)
Y2 = corr2d_multi_in_out(X, K)
# Compare element-wise with a tolerance: the two paths accumulate float32
# products in different ways, so a bound on the summed absolute error is
# fragile.
assert torch.allclose(Y1, Y2, atol=1e-6), \
    "1x1 matmul path disagrees with the generic cross-correlation"

print(Y1.shape, Y2.shape)

torch.Size([2, 3, 3]) torch.Size([2, 3, 3])


In [27]:
# unfold can be used to turn a convolution into a matrix operation to speed up computation

import torch
import torch.nn.functional as F

# One batch, one channel, 4x4 input holding the values 1..16
input = torch.arange(1.0, 17.0).reshape((1, 1, 4, 4))
print("Input tensor:")
print(input)

# Unfold the input using a 3x3 kernel window
output = F.unfold(input, kernel_size=3)
print("Unfolded tensor:")
print(output)


Input tensor:
tensor([[[[ 1.,  2.,  3.,  4.],
          [ 5.,  6.,  7.,  8.],
          [ 9., 10., 11., 12.],
          [13., 14., 15., 16.]]]])
Unfolded tensor:
tensor([[[ 1.,  2.,  5.,  6.],
         [ 2.,  3.,  6.,  7.],
         [ 3.,  4.,  7.,  8.],
         [ 5.,  6.,  9., 10.],
         [ 6.,  7., 10., 11.],
         [ 7.,  8., 11., 12.],
         [ 9., 10., 13., 14.],
         [10., 11., 14., 15.],
         [11., 12., 15., 16.]]])
