In [1]:
import torch
from torch import nn
from d2l import torch as d2l

In [2]:
def corr2d(X, k):
    '''Compute the 2-D cross-correlation of `X` with kernel `k`.

    This is the operation deep-learning "convolution" layers actually
    perform (the kernel is not flipped). Stride is 1 and there is no
    padding, so the output has shape
    (X.shape[0] - kh + 1, X.shape[1] - kw + 1) for a (kh, kw) kernel.

    Args:
        X: 2-D input tensor.
        k: 2-D kernel tensor.

    Returns:
        A new 2-D tensor (default float dtype) with the correlation result.
    '''
    kh, kw = k.shape  # use the kernel that was passed in, not a global
    Y = torch.zeros((X.shape[0] - kh + 1, X.shape[1] - kw + 1))
    for i in range(Y.shape[0]):
        for j in range(Y.shape[1]):
            # Element-wise product of the window and the kernel, then sum.
            Y[i, j] = (X[i:i + kh, j:j + kw] * k).sum()
    return Y

In [4]:
# A 3x3 input holding 0..8 in row-major order and a 2x2 kernel holding 0..3.
X = torch.arange(9.).reshape(3, 3)
K = torch.arange(4.).reshape(2, 2)
corr2d(X, K)

tensor([[19., 25.],
        [37., 43.]])

In [5]:
class Con2D(nn.Module):
    '''A minimal 2-D convolution (cross-correlation) layer built on `corr2d`.

    Holds a learnable kernel `weight` and a single learnable `bias`; the
    forward pass is `corr2d(x, weight) + bias` (stride 1, no padding).

    Args:
        kernel_size: kernel shape as a `(height, width)` pair, or a single
            int for a square kernel (mirroring `nn.Conv2d`).
    '''
    def __init__(self, kernel_size):
        super().__init__()
        if isinstance(kernel_size, int):
            # An int would otherwise create a 1-D weight, which corr2d
            # cannot unpack into (height, width).
            kernel_size = (kernel_size, kernel_size)
        # Kernel initialized uniformly at random in [0, 1).
        self.weight = nn.Parameter(torch.rand(kernel_size))
        # One bias value, broadcast over every output position.
        self.bias = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        '''Correlate the 2-D input `x` with the learned kernel and add the bias.'''
        return corr2d(x, self.weight) + self.bias

## Edge detection

In [27]:
# 6x8 image of ones with a block of zeros in rows 1-3, columns 2-5.
X = torch.full((6, 8), 1.0)
X[1:4, 2:6] = 0.0
X

tensor([[1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])

In [28]:
K = torch.tensor([1.0, -1.0]).unsqueeze(0)  # 1x2 kernel: pixel minus its right neighbour

In [29]:
# Correlate the image with the 1x2 kernel [1, -1]: each output is
# X[i, j] - X[i, j+1]. It is 1 where a 1 is followed by a 0 (left side of the
# zero block), -1 where a 0 is followed by a 1 (right side), and 0 elsewhere.
Y = corr2d(X, K)
Y

tensor([[ 0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
        [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
        [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.]])

## Learning the kernel

In [60]:
# One input channel, one output channel, a 1x2 kernel and no bias: the same
# shape as the hand-written edge kernel K, so training should recover it.
conv2d = nn.Conv2d(1, 1, kernel_size=(1, 2), bias=False)

# nn.Conv2d expects 4-D input: (batch_size, channels, height, width).
# Prepend the batch and channel axes to the trailing 2-D spatial shape; this
# stays a no-op if the cell is re-run after X and Y are already 4-D.
X = X.reshape((1, 1) + tuple(X.shape[-2:]))
Y = Y.reshape((1, 1) + tuple(Y.shape[-2:]))

LR = 0.012       # learning rate for plain gradient descent
NUM_STEPS = 20   # number of full-batch gradient steps

for i in range(NUM_STEPS):
    Y_hat = conv2d(X)
    loss = ((Y_hat - Y) ** 2).sum()  # squared error summed over all outputs

    conv2d.zero_grad()
    loss.backward()
    # Update the weight outside autograd; torch.no_grad() is the supported
    # alternative to mutating `.data` directly.
    with torch.no_grad():
        conv2d.weight -= LR * conv2d.weight.grad
    if (i + 1) % 2 == 0:
        print(f'batch {i+1}, loss {loss.item():.3f}')

batch 2, loss 4.630
batch 4, loss 2.424
batch 6, loss 1.779
batch 8, loss 1.319
batch 10, loss 0.978
batch 12, loss 0.726
batch 14, loss 0.538
batch 16, loss 0.399
batch 18, loss 0.296
batch 20, loss 0.220
