In [21]:
import torch
from torch import nn
from torch import functional as F
from d2l import torch as d2l

In [35]:
def conv2d(arr : torch.Tensor, k : torch.Tensor, stride = 1) -> torch.Tensor:
        assert len(arr.shape) == len(k.shape) == 2
        assert arr.shape[0] >= k.shape[0] and arr.shape[1] >= k.shape[1]

        m = int((arr.shape[0] - k.shape[0]) / stride + 1)
        n = int((arr.shape[1] - k.shape[1]) / stride + 1)

        out = torch.zeros((m, n))
        print(out.shape)

        for i in range(m):
                for j in range(n):
                        start_i = i * stride
                        end_i = i * stride + k.shape[0]
                        start_j = j * stride
                        end_j = j * stride + k.shape[0]
                        out[i, j] = (arr[start_i : end_i, start_j : end_j] * k).sum()

        return out

In [23]:
arr = torch.stack([torch.arange(9)] * 9)
k = torch.ones((3, 3))
print(arr, k)

tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8],
        [0, 1, 2, 3, 4, 5, 6, 7, 8],
        [0, 1, 2, 3, 4, 5, 6, 7, 8],
        [0, 1, 2, 3, 4, 5, 6, 7, 8],
        [0, 1, 2, 3, 4, 5, 6, 7, 8],
        [0, 1, 2, 3, 4, 5, 6, 7, 8],
        [0, 1, 2, 3, 4, 5, 6, 7, 8],
        [0, 1, 2, 3, 4, 5, 6, 7, 8],
        [0, 1, 2, 3, 4, 5, 6, 7, 8]]) tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])


In [24]:
conv2d(arr, k, 1)

torch.Size([7, 7])


tensor([[ 9., 18., 27., 36., 45., 54., 63.],
        [ 9., 18., 27., 36., 45., 54., 63.],
        [ 9., 18., 27., 36., 45., 54., 63.],
        [ 9., 18., 27., 36., 45., 54., 63.],
        [ 9., 18., 27., 36., 45., 54., 63.],
        [ 9., 18., 27., 36., 45., 54., 63.],
        [ 9., 18., 27., 36., 45., 54., 63.]])

In [36]:
class Net(nn.Module):
        def __init__(self, size : tuple):
                super().__init__()
                self.weight = nn.Parameter(torch.rand(size))
                self.bias = nn.Parameter(torch.zeros(1))

        def forward(self, X):
                return conv2d(X, self.weight) + self.bias

net = Net((3, 3))

In [37]:
net(arr)

torch.Size([7, 7])


tensor([[ 4.7082,  9.7866, 14.8649, 19.9433, 25.0216, 30.0999, 35.1783],
        [ 4.7082,  9.7866, 14.8649, 19.9433, 25.0216, 30.0999, 35.1783],
        [ 4.7082,  9.7866, 14.8649, 19.9433, 25.0216, 30.0999, 35.1783],
        [ 4.7082,  9.7866, 14.8649, 19.9433, 25.0216, 30.0999, 35.1783],
        [ 4.7082,  9.7866, 14.8649, 19.9433, 25.0216, 30.0999, 35.1783],
        [ 4.7082,  9.7866, 14.8649, 19.9433, 25.0216, 30.0999, 35.1783],
        [ 4.7082,  9.7866, 14.8649, 19.9433, 25.0216, 30.0999, 35.1783]],
       grad_fn=<AddBackward0>)

In [48]:
X = torch.rand((1, 1, 6, 8))
Y = torch.rand((1, 1, 6, 7))
print(X.shape, Y.shape)

torch.Size([1, 1, 6, 8]) torch.Size([1, 1, 6, 7])


In [59]:
conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=(1, 2), bias=False)

for i in range(10):
        conv.zero_grad()
        Y_hat = conv(X)
        loss = (Y_hat - Y) ** 2
        
        loss.sum().backward()
        conv.weight.data[:] -= 3e-2 * conv.weight.grad

        if i % 2 == 0:
                print(f"epoch {i} loss {loss.sum()}")

epoch 0 loss 62.585357666015625
epoch 2 loss 24.129060745239258
epoch 4 loss 10.812389373779297
epoch 6 loss 6.201056957244873
epoch 8 loss 4.604224681854248


In [60]:
x = torch.rand(9, 9)

In [61]:
conv2d = nn.Conv2d(1, 1, kernel_size = (3, 5), padding=(0, 1), stride = (3, 4))
conv2d(X).shape

torch.Size([1, 1, 2, 2])