实现互相关运算（即深度学习中通常所说的“卷积”操作）

In [1]:
import torch
from torch import nn
from d2l import torch as d2l

In [7]:
def corr2d(X, K):
    """Compute the 2D cross-correlation of an input matrix with a kernel.

    The kernel is slid over every position where it fits entirely inside
    ``X`` (no padding, stride 1). Each output element is the sum of the
    element-wise product of the kernel and the input window beneath it.

    Args:
        X: 2D input tensor of shape (H, W).
        K: 2D kernel tensor of shape (h, w).

    Returns:
        A tensor of shape (H - h + 1, W - w + 1) allocated on the same
        device as ``X``. Floating-point (or complex) inputs keep their
        promoted dtype; any other dtype (e.g. integers) yields the default
        floating-point dtype.
    """
    h, w = K.shape  # kernel height (rows) and width (columns)

    # Allocate the output to match the inputs, so float64 inputs are not
    # silently down-cast and the result stays on the input's device.
    out_dtype = torch.result_type(X, K)
    if not (out_dtype.is_floating_point or out_dtype.is_complex):
        out_dtype = torch.get_default_dtype()
    Y = torch.zeros(
        (X.shape[0] - h + 1, X.shape[1] - w + 1),
        dtype=out_dtype,
        device=X.device,
    )

    for i in range(Y.shape[0]):
        for j in range(Y.shape[1]):
            # Slice out the window of X currently under the kernel,
            # multiply it element-wise by K, and sum the products.
            Y[i, j] = (X[i:i + h, j:j + w] * K).sum()
    return Y

实现二维卷积层

In [None]:
class Conv2D(nn.Module):
    """Minimal 2D convolutional layer built on ``corr2d``.

    The layer owns two learnable parameters: ``weight`` (the kernel,
    initialised uniformly at random in [0, 1)) and ``bias`` (a single
    scalar initialised to zero).
    """

    def __init__(self, kernel_size) -> None:
        """Create the kernel and bias parameters.

        Args:
            kernel_size: Shape of the kernel, passed to ``torch.rand``.
        """
        super().__init__()
        initial_kernel = torch.rand(kernel_size)
        initial_bias = torch.zeros(1)
        self.weight = nn.Parameter(initial_kernel)
        self.bias = nn.Parameter(initial_bias)

    def forward(self, X):
        """Cross-correlate ``X`` with the kernel, then add the bias."""
        response = corr2d(X, self.weight)  # self.weight acts as the kernel
        return response + self.bias

## 示例：实现简单的边缘检测

一、自定义卷积核（Kernel）：已知输入 X 和卷积核 K，计算输出 Y

In [14]:
# Build a 6x8 "image" of ones whose middle four columns (2..5) are zero,
# giving two vertical edges: 1 -> 0 and 0 -> 1.
X = torch.ones(6, 8)
X[:, 2:6].zero_()
X

tensor([[1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.]])

In [15]:
# 1x2 kernel: each output is the left element minus its right neighbour,
# so it is nonzero only where horizontally adjacent values differ.
K = torch.tensor([[1.0, -1.0]]) 

In [16]:
# Cross-correlate the image with the kernel to detect the edges.
Y = corr2d(X, K)
Y # The 1 and -1 entries mark the edges (the boundaries between 0 and 1)

tensor([[ 0.,  1.,  0.,  0.,  0., -1.,  0.],
        [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
        [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
        [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
        [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
        [ 0.,  1.,  0.,  0.,  0., -1.,  0.]])

二、学习卷积核：已知输入 X 和输出 Y，通过梯度下降学习卷积核 K

In [25]:
# Learn the kernel from (X, Y) with PyTorch's built-in layer:
# 1 input channel, 1 output channel, a 1x2 kernel and no bias.
conv2d = nn.Conv2d(1, 1, kernel_size=(1, 2), bias=False)
# Add leading batch and channel dimensions: (batch, channel, height, width).
X = X.reshape(1, 1, 6, 8)
Y = Y.reshape(1, 1, 6, 7)
lr = 3e-2  # learning rate for the hand-written gradient descent below
for i in range(10):
    # forward
    Y_hat = conv2d(X)
    loss = (Y_hat - Y) ** 2  # element-wise squared error
    total_loss = loss.sum()
    # backward
    conv2d.zero_grad()
    total_loss.backward()
    # No optimizer (e.g. SGD) is used; apply one plain gradient-descent step.
    # The in-place update runs under no_grad so autograd does not record it.
    with torch.no_grad():
        conv2d.weight -= lr * conv2d.weight.grad
    print(f'batch {i+1}, loss {total_loss:.3f}')

batch 1, loss 18.152
batch 2, loss 8.908
batch 3, loss 4.591
batch 4, loss 2.484
batch 5, loss 1.404
batch 6, loss 0.822
batch 7, loss 0.495
batch 8, loss 0.304
batch 9, loss 0.189
batch 10, loss 0.119


In [26]:
conv2d.weight.data # Inspect the learned kernel: the result looks reasonably good

tensor([[[[ 0.9535, -1.0236]]]])