# 二维卷积层

In [1]:
import torch 
from torch import nn

In [2]:
def corr2d(X, k):
    """Compute the 2-D cross-correlation of ``X`` with the kernel ``k``.

    The kernel slides over ``X`` with stride 1 and no padding, so for a
    kernel of shape ``(h, w)`` the output has shape
    ``(X.shape[0] - h + 1, X.shape[1] - w + 1)``.

    Args:
        X: 2-D input tensor.
        k: 2-D kernel tensor.

    Returns:
        A float tensor whose ``[i, j]`` entry is the sum of the element-wise
        product of ``k`` with the ``(h, w)`` window of ``X`` whose top-left
        corner is at ``(i, j)``.
    """
    h, w = k.shape
    Y = torch.zeros((X.shape[0] - h + 1, X.shape[1] - w + 1))
    for i in range(Y.shape[0]):
        # Walk the output's columns (shape[1]). Bounding this loop by the
        # row count instead leaves trailing columns at 0 when the output is
        # wider than tall, and errors when it is taller than wide.
        for j in range(Y.shape[1]):
            Y[i, j] = (X[i:i + h, j:j + w] * k).sum()
    return Y

In [4]:
# Build the input array X and kernel array k from Figure 5.1 (the values
# 0..8 and 0..3 laid out row by row) to check the output of the 2-D
# cross-correlation.
X = torch.arange(9).reshape(3, 3)
k = torch.arange(4).reshape(2, 2)
corr2d(X, k)

tensor([[19., 25.],
        [37., 43.]])

# 二维卷积层

In [7]:
# 6x8 input of ones with columns 2..5 set to zero, so the values change
# between columns 1/2 and between columns 5/6.
X = torch.ones((6, 8))
X[:, 2:6].zero_()  # basic slicing returns a view, so this zeroes X in place
X

tensor([[1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.]])

In [9]:
# A 1x2 kernel: each output entry is X[i, j] - X[i, j + 1], which is nonzero
# only where horizontally adjacent values differ.
K = torch.tensor([1, -1]).reshape(1, 2)
Y = corr2d(X, K)
Y

tensor([[ 0.,  1.,  0.,  0.,  0., -1.,  0.],
        [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
        [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
        [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
        [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
        [ 0.,  1.,  0.,  0.,  0., -1.,  0.]])

In [10]:
class Conv2D(nn.Module):
    """Minimal 2-D convolution layer built on ``corr2d``.

    Holds a learnable kernel ``weight`` of shape ``kernel_size`` and a
    learnable one-element ``bias``, both initialised with ``torch.randn``.
    """

    def __init__(self, kernel_size):
        super().__init__()
        # Wrapping in nn.Parameter registers the tensors with the module so
        # they appear in .parameters() and receive gradients.
        self.weight = nn.Parameter(torch.randn(kernel_size))
        self.bias = nn.Parameter(torch.randn(1))

    def forward(self, x):
        """Cross-correlate ``x`` with ``weight`` and add ``bias``."""
        correlated = corr2d(x, self.weight)
        return correlated + self.bias

# 通过数据学习核数组

In [11]:
# Build a 2-D convolution layer whose kernel array has shape (1, 2).
conv2d = Conv2D(kernel_size=(1, 2))

step = 20
lr = 0.01
for i in range(step):
    Y_hat = conv2d(X)
    # Sum-of-squares error between the layer output and the target Y.
    l = ((Y_hat - Y) ** 2).sum()
    l.backward()

    # Gradient descent. Update the parameters in place under torch.no_grad()
    # instead of through .data, which PyTorch discourages because autograd
    # cannot track changes made via .data.
    with torch.no_grad():
        conv2d.weight -= lr * conv2d.weight.grad
        conv2d.bias -= lr * conv2d.bias.grad

    # Reset gradients: backward() accumulates into .grad across iterations.
    conv2d.weight.grad.zero_()
    conv2d.bias.grad.zero_()
    if (i + 1) % 5 == 0:
        print('Step %d,loss %.3f' % (i + 1, l.item()))

Step 5,loss 7.420
Step 10,loss 2.066
Step 15,loss 0.575
Step 20,loss 0.160


In [12]:
# After 20 iterations the error has dropped to a fairly small value; now
# inspect the learned kernel parameters.
for label, param in (("weight: ", conv2d.weight), ("bias: ", conv2d.bias)):
    print(label, param.data)

weight:  tensor([[ 0.8984, -0.8982]])
bias:  tensor([-0.0001])
