In [2]:
import torch
from torch import nn
from d2l import torch as d2l

def corr2d(X, K):
    """Compute the 2-D cross-correlation of ``X`` with kernel ``K``.

    The kernel is slid over every position where it fits entirely inside
    ``X`` (no padding, stride 1). Each output element is the sum of the
    element-wise product of the kernel with the window of ``X`` under it.

    Args:
        X: Input tensor, expected to be 2-D with shape ``(H, W)``.
        K: 2-D kernel tensor with shape ``(h, w)``.

    Returns:
        Tensor of shape ``(H - h + 1, W - w + 1)`` allocated on ``X``'s
        device. Its dtype is the promotion of both input dtypes with the
        default float dtype, so float64 inputs keep float64 precision while
        integer and float32 inputs yield the default float dtype.
    """
    h, w = K.shape
    # At least the default float dtype (so integer inputs still give a float
    # result), but wide enough that float64 inputs are not silently truncated.
    out_dtype = torch.promote_types(
        torch.promote_types(X.dtype, K.dtype), torch.get_default_dtype()
    )
    # Allocate next to the input so the result can be combined with other
    # tensors living on the same device.
    Y = torch.zeros(
        (X.shape[0] - h + 1, X.shape[1] - w + 1),
        dtype=out_dtype,
        device=X.device,
    )
    for i in range(Y.shape[0]):
        for j in range(Y.shape[1]):
            # Window of X currently covered by the kernel.
            Y[i, j] = (X[i:i + h, j:j + w] * K).sum()
    return Y

In [3]:
# Worked example: a 3x3 input holding 0..8 and a 2x2 kernel holding 0..3,
# both float32. The cell displays their cross-correlation.
X = torch.arange(9, dtype=torch.float32).reshape(3, 3)
K = torch.arange(4, dtype=torch.float32).reshape(2, 2)
corr2d(X, K)

tensor([[19., 25.],
        [37., 43.]])

# 手写Conv2d层

In [None]:
class Conv2D(nn.Module):
    """2-D cross-correlation layer with a learnable kernel and a scalar bias."""

    def __init__(self, kernel_size):
        """Create the layer's parameters.

        Args:
            kernel_size: Shape handed to ``torch.rand`` to build the kernel.
        """
        super().__init__()
        # Kernel entries are drawn uniformly from [0, 1); the bias starts at 0.
        initial_kernel = torch.rand(kernel_size)
        initial_bias = torch.zeros(1)
        self.weight = nn.Parameter(initial_kernel)
        self.bias = nn.Parameter(initial_bias)

    def forward(self, X):
        """Return the cross-correlation of ``X`` with the kernel, plus the bias."""
        response = corr2d(X, self.weight)
        return response + self.bias

In [7]:
# 6x8 "image": ones everywhere except a band of zeros in columns 2..5.
X = torch.ones(6, 8)
X[:, 2:6] = 0.0
# 1x2 kernel that subtracts each pixel's right-hand neighbour from it, so the
# response is non-zero only where horizontally adjacent values differ.
K = torch.tensor([[1.0, -1.0]], dtype=torch.float32)
Y = corr2d(X, K)
# Display the image, the detected edges and the image mean.
X, Y, X.mean()

(tensor([[1., 1., 0., 0., 0., 0., 1., 1.],
         [1., 1., 0., 0., 0., 0., 1., 1.],
         [1., 1., 0., 0., 0., 0., 1., 1.],
         [1., 1., 0., 0., 0., 0., 1., 1.],
         [1., 1., 0., 0., 0., 0., 1., 1.],
         [1., 1., 0., 0., 0., 0., 1., 1.]]),
 tensor([[ 0.,  1.,  0.,  0.,  0., -1.,  0.],
         [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
         [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
         [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
         [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
         [ 0.,  1.,  0.,  0.,  0., -1.,  0.]]),
 tensor(0.5000))

In [16]:
# Learn the 1x2 kernel from data: fit a bias-free nn.Conv2d so that its
# output on X reproduces Y.
conv2d = nn.Conv2d(1, 1, kernel_size=(1, 2), bias=False)

# nn.Conv2d works on 4-D tensors laid out as (batch, channels, height, width).
X = X.reshape((1, 1, 6, 8))
Y = Y.reshape((1, 1, 6, 7))

lr = 0.1          # gradient-descent step size
num_epochs = 400  # number of full passes over the single example

for i in range(num_epochs):
    Y_hat = conv2d(X)
    l = ((Y_hat - Y) ** 2).mean()  # mean squared error
    conv2d.zero_grad()
    l.backward()
    # Plain gradient-descent step. Updating in place under no_grad (rather
    # than through `.data`) keeps the update out of the autograd graph while
    # still letting autograd detect invalid in-place modifications.
    with torch.no_grad():
        conv2d.weight -= lr * conv2d.weight.grad
    print(f'batch {i+1} loss: {l:.7f}')

print(conv2d.weight.detach())
print(conv2d.weight.grad)

batch 1 loss: 0.4197260
batch 2 loss: 0.3578563
batch 3 loss: 0.3096138
batch 4 loss: 0.2715398
batch 5 loss: 0.2410849
batch 6 loss: 0.2163675
batch 7 loss: 0.1959972
batch 8 loss: 0.1789454
batch 9 loss: 0.1644490
batch 10 loss: 0.1519411
batch 11 loss: 0.1409986
batch 12 loss: 0.1313050
batch 13 loss: 0.1226221
batch 14 loss: 0.1147697
batch 15 loss: 0.1076105
batch 16 loss: 0.1010388
batch 17 loss: 0.0949727
batch 18 loss: 0.0893478
batch 19 loss: 0.0841127
batch 20 loss: 0.0792262
batch 21 loss: 0.0746544
batch 22 loss: 0.0703692
batch 23 loss: 0.0663465
batch 24 loss: 0.0625661
batch 25 loss: 0.0590102
batch 26 loss: 0.0556630
batch 27 loss: 0.0525105
batch 28 loss: 0.0495401
batch 29 loss: 0.0467404
batch 30 loss: 0.0441009
batch 31 loss: 0.0416118
batch 32 loss: 0.0392643
batch 33 loss: 0.0370500
batch 34 loss: 0.0349611
batch 35 loss: 0.0329904
batch 36 loss: 0.0311311
batch 37 loss: 0.0293768
batch 38 loss: 0.0277215
batch 39 loss: 0.0261596
batch 40 loss: 0.0246859
batch 41 