In [67]:
import numpy as np
import torch
import torch.nn.functional as F

In [68]:
class Conv2D:
    def __init__(self, input_data, kernel_size: tuple | int, bias: float | None = None,
                 stride: int = 1, padding: tuple[int, int] | int | str = (0, 0), dilation: int = 1):
        self.input_data_numpy = input_data[0, 0].numpy()
        self.input_data_torch = input_data
        self.bias = bias

        self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
        self.stride, self.dilation = stride, dilation

        if isinstance(padding, tuple):
            self.padding = padding
        elif padding == 'same' and stride == 1:
            self.padding = (self.kernel_size[0] - 1, self.kernel_size[1] - 1)
        else:
            self.padding = (0, 0)

        self.weight_tensor_torch = torch.randn(1, 1, *self.kernel_size)
        self.weight_tensor_numpy = self.weight_tensor_torch[0, 0].detach().numpy()

    def conv2d(self):
        image_height, image_width = self.input_data_numpy.shape
        H_out = (image_height + 2 * self.padding[0] - self.dilation * (self.kernel_size[0] - 1) - 1) // self.stride + 1
        W_out = (image_width + 2 * self.padding[1] - self.dilation * (self.kernel_size[1] - 1) - 1) // self.stride + 1

        padded_input = np.pad(self.input_data_numpy, self.padding, mode='constant')
        result = np.array([
            [np.sum(padded_input[y * self.stride:y * self.stride + self.kernel_size[0],
                 x * self.stride:x * self.stride + self.kernel_size[1]] * self.weight_tensor_numpy)
             for x in range(W_out)] for y in range(H_out)
        ])

        if self.bias is not None:
            result += self.bias

        return result

    def torch_conv2d(self):
        return F.conv2d(self.input_data_torch, self.weight_tensor_torch, self.bias, self.stride, self.padding,
                        self.dilation)

    def test(self, print_flag=False):
        my_conv2d = self.conv2d()
        torch_out = np.array(self.torch_conv2d())
        if print_flag:
            print("MyConv2d:\n", my_conv2d)
            print("PyTConv2d:\n", torch_out[0, 0])
        print(np.allclose(my_conv2d, torch_out[0, 0]))


### Далее проверяем работоспособность функции: сравниваем её выход с результатом свертки из PyTorch (`F.conv2d`) и проверяем, совпадают ли они:

#### ТЕСТЫ

In [69]:
# Fix the RNG so the random images and kernels (and hence the printed results) are reproducible.
torch.manual_seed(0)

# Keyword arguments for Conv2D; every case gets a fresh random 1x1x5x5 image.
test_cases = [
    dict(kernel_size=1),
    dict(kernel_size=1, padding=(0, 0)),
    dict(kernel_size=1, padding='same'),
    dict(kernel_size=4, padding='same'),
    dict(kernel_size=1, dilation=3),
    dict(kernel_size=1, stride=4),
]

for case in test_cases:
    image = torch.randn(1, 1, 5, 5)
    Conv2D(image, **case).test()


True
True
True
True
True
True


### Функция свертки работает: её результат совпадает с результатом оригинальной свертки из PyTorch.