In [41]:
import numpy as np
import torch

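## Convolution

A hand-written NumPy `conv` is checked against PyTorch on two small examples. Neither implementation flips the kernel, so both compute a 2-D cross-correlation — which is what deep-learning libraries call "convolution".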

In [56]:
def conv(A, F, stride=(1, 1), padding=(0, 0)):
    """Naive 2-D convolution of ``A`` with kernel ``F`` (NumPy reference).

    Like ``torch.nn.Conv2d``, the kernel is *not* flipped, so this is
    mathematically a cross-correlation: each output element is
    ``sum(F * window)`` over the matching window of the zero-padded input.

    Args:
        A: 2-D input array of shape ``(n1, n2)``.
        F: 2-D kernel of shape ``(f1, f2)``.
        stride: ``(s1, s2)`` step between windows, or a single int used for
            both axes. Must be positive.
        padding: ``(p1, p2)`` zeros added on each side of each axis, or a
            single int used for both axes. Must be non-negative.

    Returns:
        float64 array of shape ``(d1, d2)`` where, per axis,
        ``d = (n + 2*p - f) // s + 1``.

    Raises:
        ValueError: if ``A`` or ``F`` is not 2-D, if stride/padding are
            malformed or out of range, or if the padded input is smaller
            than the kernel (no valid window position).
    """
    def _pair(value, name):
        # Accept a scalar (same value on both axes) or a length-2 sequence,
        # mirroring the forms PyTorch accepts for stride/padding.
        if np.ndim(value) == 0:
            return int(value), int(value)
        if len(value) != 2:
            raise ValueError(f"{name} must be an int or a pair, got {value!r}")
        return int(value[0]), int(value[1])

    A = np.asarray(A)
    F = np.asarray(F)
    if A.ndim != 2 or F.ndim != 2:
        raise ValueError(f"A and F must be 2-D, got shapes {A.shape} and {F.shape}")

    n1, n2 = A.shape
    f1, f2 = F.shape
    s1, s2 = _pair(stride, "stride")
    p1, p2 = _pair(padding, "padding")
    if s1 <= 0 or s2 <= 0:
        raise ValueError(f"stride must be positive, got {(s1, s2)}")
    if p1 < 0 or p2 < 0:
        raise ValueError(f"padding must be non-negative, got {(p1, p2)}")

    d1 = (n1 + 2 * p1 - f1) // s1 + 1
    d2 = (n2 + 2 * p2 - f2) // s2 + 1
    if d1 <= 0 or d2 <= 0:
        # Kernel larger than the padded input: there is no window to evaluate.
        raise ValueError(
            f"kernel {F.shape} does not fit padded input "
            f"{(n1 + 2 * p1, n2 + 2 * p2)}"
        )

    # np.pad defaults to constant (zero) padding; it returns a new array,
    # so the caller's A is untouched.
    A_padded = np.pad(A, ((p1, p1), (p2, p2)))
    result = np.zeros((d1, d2))
    for i in range(d1):
        for j in range(d2):
            window = A_padded[i * s1:i * s1 + f1, j * s2:j * s2 + f2]
            result[i, j] = np.sum(F * window)
    return result

def conv_torch(A, F, stride=(1, 1), padding=(0, 0)):
    """Reference 2-D convolution computed with PyTorch, used to check ``conv``.

    Applies ``torch.nn.functional.conv2d`` (cross-correlation, no bias) with
    ``F`` as the single input/output-channel kernel. The functional API is used
    instead of building a ``torch.nn.Conv2d`` module, because the module's
    random weight initialisation would be thrown away and would also advance
    torch's global RNG as a side effect.

    Args:
        A: 2-D input array.
        F: 2-D kernel.
        stride: int or pair, forwarded to ``conv2d``.
        padding: int or pair, forwarded to ``conv2d``.

    Returns:
        float32 NumPy array of shape ``(d1, d2)``. It is always 2-D, even when
        an output dimension is 1.
    """
    # conv2d expects input as (batch, in_channels, H, W) and
    # weight as (out_channels, in_channels, kH, kW); add two leading axes.
    inp = torch.as_tensor(np.asarray(A), dtype=torch.float32)[None, None]
    weight = torch.as_tensor(np.asarray(F), dtype=torch.float32)[None, None]
    out = torch.nn.functional.conv2d(inp, weight, stride=stride, padding=padding)
    # Index the batch/channel axes explicitly: .squeeze() would also drop a
    # size-1 spatial axis and break shape parity with conv().
    return out[0, 0].numpy()

In [57]:
# Example 1: all-ones 4x4 input and all-ones 3x3 kernel.
A = np.ones((4, 4), dtype=np.float64)
F = np.ones((3, 3), dtype=np.float64)
print(f"{A=}", f"{F=}", sep="\n")

A=array([[1., 1., 1., 1.],
       [1., 1., 1., 1.],
       [1., 1., 1., 1.],
       [1., 1., 1., 1.]])
F=array([[1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.]])


In [58]:
# With stride 1, padding 1 on each side keeps a 3x3 kernel's output the same size as the input.
output_np = conv(A, F, padding=(1, 1), stride=(1, 1))
print(f"{output_np=}")

output_np=array([[4., 6., 6., 4.],
       [6., 9., 9., 6.],
       [6., 9., 9., 6.],
       [4., 6., 6., 4.]])


In [59]:
output_torch = conv_torch(A, F, stride=(1, 1), padding=(1, 1))
print(f"{output_torch=}")
# Check that both implementations agree instead of comparing printouts by eye.
# conv_torch computes in float32, hence the relaxed tolerances.
np.testing.assert_allclose(output_torch, output_np, rtol=1e-5, atol=1e-5)

output_torch=array([[4., 6., 6., 4.],
       [6., 9., 9., 6.],
       [6., 9., 9., 6.],
       [4., 6., 6., 4.]], dtype=float32)


In [60]:
# Example 2: a 2x2 input with a non-symmetric 3x3 kernel.
A = np.array([[2, 1],
              [3, 2]], dtype=np.float64)
F = np.array([[1, 2, 1],
              [2, 0, 1],
              [0, 2, 1]], dtype=np.float64)
print(f"{A=}", f"{F=}", sep="\n")

A=array([[2., 1.],
       [3., 2.]])
F=array([[1., 2., 1.],
       [2., 0., 1.],
       [0., 2., 1.]])


In [61]:
# Same settings as example 1: stride 1 and padding 1 preserve the 2x2 spatial size.
output_np = conv(A, F, padding=(1, 1), stride=(1, 1))
print(f"{output_np=}")

output_np=array([[ 9.,  8.],
       [ 7., 10.]])


In [62]:
output_torch = conv_torch(A, F, stride=(1, 1), padding=(1, 1))
print(f"{output_torch=}")
# Check that both implementations agree instead of comparing printouts by eye.
# conv_torch computes in float32, hence the relaxed tolerances.
np.testing.assert_allclose(output_torch, output_np, rtol=1e-5, atol=1e-5)

output_torch=array([[ 9.,  8.],
       [ 7., 10.]], dtype=float32)
