In [2]:
import numpy as np
import torch
import scipy.signal

In [17]:
def conv2d(X, K, padding=(0, 0), stride=(1, 1)):
    """2D cross-correlation of a single-channel input with a kernel (no kernel flip).

    This is the operation ``torch.conv2d`` computes, as opposed to the true
    (flipped-kernel) convolution computed by ``scipy.signal.convolve2d``.

    Args:
        X: 2D input tensor of shape (H, W).
        K: 2D kernel tensor of shape (kh, kw).
        padding: (ph, pw) zeros added on both sides of the rows (dim 0) and of
            the columns (dim 1), matching ``torch.conv2d``'s ``padding``.
        stride: (sh, sw) step between neighbouring windows along rows / columns.

    Returns:
        Tensor of shape (floor((H + 2*ph - kh) / sh) + 1, floor((W + 2*pw - kw) / sw) + 1)
        (clamped at 0) on X's device. Integer inputs give the default floating
        dtype; floating inputs keep their promoted precision. Gradients flow back
        to X and K when they require grad.
    """
    kh, kw = K.shape
    ph, pw = padding
    sh, sw = stride
    # Promote to at least the default float dtype so float64 inputs are not
    # silently downcast, while integer inputs still produce float output.
    out_dtype = torch.promote_types(torch.promote_types(X.dtype, K.dtype), torch.get_default_dtype())
    # F.pad lists pads starting from the LAST dim: (left, right, top, bottom).
    X = torch.nn.functional.pad(X, (pw, pw, ph, ph), 'constant', 0)
    out_h = max((X.shape[0] - kh) // sh + 1, 0)
    out_w = max((X.shape[1] - kw) // sw + 1, 0)
    Y = torch.zeros((out_h, out_w), dtype=out_dtype, device=X.device)
    for i in range(out_h):
        for j in range(out_w):
            r, c = i * sh, j * sw  # top-left corner of the current window
            Y[i, j] = (X[r:r + kh, c:c + kw] * K).sum()
    return Y

def conv2d_raschka(X, W, p=(0, 0), s=(1, 1)):
    """True 2D convolution (kernel rotated by 180 degrees), after Raschka's version.

    With the default padding/stride this matches
    ``scipy.signal.convolve2d(X, W, mode='valid')``.

    Args:
        X: 2D input tensor of shape (H, W).
        W: 2D kernel tensor of shape (kh, kw).
        p: (ph, pw) zeros added on both sides of the rows (dim 0) and of the
            columns (dim 1).
        s: (sh, sw) step between neighbouring windows along rows / columns.

    Returns:
        Tensor of shape (floor((H + 2*ph - kh) / sh) + 1, floor((W + 2*pw - kw) / sw) + 1)
        (clamped at 0) on X's device. Gradients flow back to X and W when they
        require grad.
    """
    W_rot = torch.rot90(W, 2, (0, 1))  # 180-degree rotation turns correlation into convolution
    kh, kw = W_rot.shape
    ph, pw = p
    sh, sw = s
    # Promote to at least the default float dtype (integer inputs -> float output).
    out_dtype = torch.promote_types(torch.promote_types(X.dtype, W.dtype), torch.get_default_dtype())

    X_padded = torch.zeros((X.shape[0] + 2 * ph, X.shape[1] + 2 * pw), dtype=out_dtype, device=X.device)
    X_padded[ph:ph + X.shape[0], pw:pw + X.shape[1]] = X

    # Window start positions: every s-th position whose window fits inside the
    # padded input. (The original bounded these by the output *count*, which
    # dropped windows whenever the stride was > 1.)
    row_starts = range(0, X_padded.shape[0] - kh + 1, sh)
    col_starts = range(0, X_padded.shape[1] - kw + 1, sw)

    # Preallocated output (instead of torch.tensor(list)) keeps the autograd graph.
    out = torch.zeros((len(row_starts), len(col_starts)), dtype=out_dtype, device=X.device)
    for i, r in enumerate(row_starts):
        for j, c in enumerate(col_starts):
            out[i, j] = torch.sum(X_padded[r:r + kh, c:c + kw] * W_rot)
    return out


# Sanity check against reference implementations:
# - conv2d is a cross-correlation (no kernel flip), i.e. what torch.conv2d computes;
# - conv2d_raschka flips the kernel, i.e. a true convolution as in scipy.signal.convolve2d.
# float32 is used because torch.conv2d on int64 CPU tensors is not implemented in
# some PyTorch versions (NOTE(review): confirm for the version in use).
X = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]], dtype=torch.float32)
K = torch.tensor([[0, 1], [2, 3]], dtype=torch.float32)

conv_ref = scipy.signal.convolve2d(X.numpy(), K.numpy(), mode='valid')
xcorr_ref = torch.conv2d(X.unsqueeze(0).unsqueeze(0), K.unsqueeze(0).unsqueeze(0))

print(conv2d(X, K))
print(conv2d_raschka(X, K))
print(conv_ref)
print(xcorr_ref)

np.testing.assert_allclose(conv2d(X, K).numpy(), xcorr_ref[0, 0].numpy())
np.testing.assert_allclose(conv2d_raschka(X, K).numpy(), conv_ref)

tensor([[19., 25.],
        [37., 43.]])
tensor([[ 5., 11.],
        [23., 29.]])
[[ 5 11]
 [23 29]]
tensor([[[[19, 25],
          [37, 43]]]])


In [19]:
# Testing autodiff: gradients through the hand-rolled conv2d must match torch.conv2d's.

def _make_leaves():
    """Return fresh leaf tensors (X, K) so each backward pass starts with empty .grad."""
    X = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]], dtype=torch.float32, requires_grad=True)
    K = torch.tensor([[0, 1], [2, 3]], dtype=torch.float32, requires_grad=True)
    return X, K

# Hand-rolled implementation
X, K = _make_leaves()
Y = conv2d(X, K)
Y.sum().backward()
X_grad_ours, K_grad_ours = X.grad, K.grad
print(X.grad)

# Testing autodiff with PyTorch
X, K = _make_leaves()
Y = torch.conv2d(X.unsqueeze(0).unsqueeze(0), K.unsqueeze(0).unsqueeze(0))
Y.sum().backward()
print(X.grad)

# Both the input and the kernel gradients must agree with the reference.
assert torch.allclose(X.grad, X_grad_ours)
assert torch.allclose(K.grad, K_grad_ours)

tensor([[0., 1., 1.],
        [2., 6., 4.],
        [2., 5., 3.]])
tensor([[0., 1., 1.],
        [2., 6., 4.],
        [2., 5., 3.]])
