In [1]:
import torch
import torch.nn as nn

In [8]:
# Convolution implemented from scratch (strictly a 2-D cross-correlation: the kernel is not flipped)

def corr2d(X, K):
    """Compute the 2-D cross-correlation of ``X`` with kernel ``K`` ("valid" mode).

    The kernel is slid over ``X`` with stride 1 and no padding; each output entry
    is the sum of the element-wise product of ``K`` with the window of ``X``
    beneath it.

    Args:
        X: 2-D input tensor of shape (H, W).
        K: 2-D kernel tensor of shape (h, w).

    Returns:
        Tensor of shape (H - h + 1, W - w + 1), allocated on ``X``'s device.
        Its dtype is the promotion of the inputs' dtype with the default float
        dtype: float32 / integer / half inputs give float32 (as before), while
        float64 inputs keep float64 instead of being silently downcast.
    """
    h, w = K.shape
    # Allocate on X's device so this (and Conv2D built on it) also works for GPU
    # tensors; promoting with the default float dtype keeps float64 precision
    # without changing the result dtype for the usual float32/int inputs.
    out_dtype = torch.promote_types(torch.result_type(X, K), torch.get_default_dtype())
    Y = torch.zeros(X.shape[0] - h + 1, X.shape[1] - w + 1,
                    dtype=out_dtype, device=X.device)

    for i in range(Y.shape[0]):
        for j in range(Y.shape[1]):
            Y[i, j] = (X[i:i + h, j:j + w] * K).sum()

    return Y

In [9]:
# Convolution layer implemented from scratch

class Conv2D(nn.Module):
    """Minimal 2-D convolution layer built on ``corr2d``.

    Holds a learnable kernel and a learnable scalar bias.

    Args:
        kernel_size: shape of the kernel, e.g. ``(1, 2)``; passed straight to
            ``torch.rand``.
    """

    def __init__(self, kernel_size):
        super(Conv2D, self).__init__()
        # Kernel entries drawn uniformly from [0, 1) (torch.rand; torch.randn
        # would draw from a standard normal instead).
        initial_kernel = torch.rand(kernel_size)
        self.weight = nn.Parameter(initial_kernel)
        # One bias value of shape (1,); it broadcasts over the whole output map.
        initial_bias = torch.zeros(1)
        self.bias = nn.Parameter(initial_bias)

    def forward(self, X):
        """Cross-correlate ``X`` with the kernel, then add the bias."""
        feature_map = corr2d(X, self.weight)
        return feature_map + self.bias

In [10]:
# Hand-designed kernel for edge detection

# A 6x8 "image": columns 2-5 are 0, all other columns are 1, so the image
# contains two vertical edges.
X = torch.ones(6, 8)
X[:, 2:6] = 0

# Along each row the kernel computes X[i, j] - X[i, j + 1]: the output is 1 at a
# 1 -> 0 transition, -1 at a 0 -> 1 transition, and 0 everywhere else.
K = torch.tensor([[1.0, -1.0]])
Y = corr2d(X, K)

# Transposing X turns its vertical edges into horizontal ones. Every row of X.t()
# is constant, so this 1x2 kernel finds nothing and Y_t is all zeros: the kernel
# only detects edges that run vertically.
Y_t = corr2d(X.t(), K)

for shown in (X, Y, Y_t):
    print(shown)

tensor([[1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.]])
tensor([[ 0.,  1.,  0.,  0.,  0., -1.,  0.],
        [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
        [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
        [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
        [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
        [ 0.,  1.,  0.,  0.,  0., -1.,  0.]])
tensor([[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]])


In [15]:
# Learn the kernel from data: starting from a random kernel, recover K from the
# (X, Y) pair produced above using plain gradient descent.

# Fix the random kernel initialisation so the printed training log is reproducible.
torch.manual_seed(0)

# A 2-D conv layer with one input channel, one output channel and a (1, 2)
# kernel; no bias, since the target Y was produced by a bias-free kernel.
conv2d = nn.Conv2d(1, 1, kernel_size=(1, 2), bias=False)

# nn.Conv2d expects 4-D tensors (batch, channel, height, width). Bind new names
# instead of overwriting X / Y, so the 2-D tensors from the previous cell are
# preserved and this cell can be re-run safely.
X_batched = X.reshape(1, 1, *X.shape)
Y_batched = Y.reshape(1, 1, *Y.shape)

lr = 3e-2
num_epochs = 20

for i in range(num_epochs):
    conv2d.zero_grad()
    Y_hat = conv2d(X_batched)
    loss = ((Y_hat - Y_batched) ** 2).sum()  # sum of squared errors
    loss.backward()
    # SGD step. Update under no_grad (rather than mutating `.data`) so the
    # in-place change is not recorded by autograd.
    with torch.no_grad():
        conv2d.weight -= lr * conv2d.weight.grad

    print(f'epoch {i + 1: d}, loss {loss.item(): .4f}')

epoch  1, loss  3.8791
epoch  2, loss  1.6955
epoch  3, loss  0.7627
epoch  4, loss  0.3561
epoch  5, loss  0.1738
epoch  6, loss  0.0891
epoch  7, loss  0.0479
epoch  8, loss  0.0270
epoch  9, loss  0.0157
epoch  10, loss  0.0094
epoch  11, loss  0.0058
epoch  12, loss  0.0036
epoch  13, loss  0.0023
epoch  14, loss  0.0014
epoch  15, loss  0.0009
epoch  16, loss  0.0006
epoch  17, loss  0.0004
epoch  18, loss  0.0002
epoch  19, loss  0.0002
epoch  20, loss  0.0001


In [16]:
# The learned kernel should be close to the hand-designed K = [[1, -1]].
learned_kernel = conv2d.weight.detach()
print(learned_kernel)
torch.testing.assert_close(learned_kernel.reshape(K.shape), K, atol=0.05, rtol=0)

tensor([[[[ 0.9989, -1.0009]]]])
