# Example

In [79]:
import torch
from abc import ABC, abstractmethod


def calc_out_shape(input_matrix_shape, out_channels, kernel_size, stride, padding):
    batch_size, channels_count, input_height, input_width = input_matrix_shape
    output_height = (input_height + 2 * padding - (kernel_size - 1) - 1) // stride + 1
    output_width = (input_width + 2 * padding - (kernel_size - 1) - 1) // stride + 1

    return batch_size, out_channels, output_height, output_width


class ABCConv2d(ABC):
    def __init__(self, in_channels, out_channels, kernel_size, stride):
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride

    def set_kernel(self, kernel):
        self.kernel = kernel

    @abstractmethod
    def __call__(self, input_tensor):
        pass


def create_and_call_conv2d_layer(conv2d_layer_class, stride, kernel, input_matrix):
    out_channels = kernel.shape[0]
    in_channels = kernel.shape[1]
    kernel_size = kernel.shape[2]

    layer = conv2d_layer_class(in_channels, out_channels, kernel_size, stride)
    layer.set_kernel(kernel)

    return layer(input_matrix)


class Conv2d(ABCConv2d):
    def __init__(self, in_channels, out_channels, kernel_size, stride):
        self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size,
                                      stride, padding=0, bias=False)

    def set_kernel(self, kernel):
        self.conv2d.weight.data = kernel

    def __call__(self, input_tensor):
        return self.conv2d(input_tensor)


def test_conv2d_layer(conv2d_layer_class, batch_size=2,
                      input_height=4, input_width=4, stride=2):
    kernel = torch.tensor(
                      [[[[0., 1, 0],
                         [1,  2, 1],
                         [0,  1, 0]],

                        [[1, 2, 1],
                         [0, 3, 3],
                         [0, 1, 10]],

                        [[10, 11, 12],
                         [13, 14, 15],
                         [16, 17, 18]]]])

    in_channels = kernel.shape[1]

    input_tensor = torch.arange(0, batch_size * in_channels *
                                input_height * input_width,
                                out=torch.FloatTensor()) \
        .reshape(batch_size, in_channels, input_height, input_width)
    print('input_tensor.shape:',input_tensor.shape)
    # print(stride)
    # print(kernel)
    custom_conv2d_out = create_and_call_conv2d_layer(
        conv2d_layer_class, stride, kernel, input_tensor)
    
    print('custom_conv2d_out:', custom_conv2d_out)
    conv2d_out = create_and_call_conv2d_layer(
        Conv2d, stride, kernel, input_tensor)


    

    return torch.allclose(custom_conv2d_out, conv2d_out) \
             and (custom_conv2d_out.shape == conv2d_out.shape)


class Conv2dMatrixV2(ABCConv2d):
    # Функция преобразования кернела в нужный формат.
    def _convert_kernel(self):
        print('self.kernel.shape:',self.kernel.shape)
        converted_kernel = torch.zeros(self.kernel.shape[0], self.kernel.shape[1]*self.kernel.shape[2]*self.kernel.shape[3]) # Реализуйте преобразование кернела.
        for filter in range (self.kernel.shape[0]):
            for channel in range(self.kernel.shape[1]):
                converted_kernel[filter, channel*self.kernel.shape[2]*self.kernel.shape[3]:(channel+1)*self.kernel.shape[2]*self.kernel.shape[3]] =  \
                self.kernel[filter,channel,:].view(self.kernel.shape[2]*self.kernel.shape[3])
        print('! converted_kernel.shape:',converted_kernel.shape)
        # print('converted_kernel:\n',converted_kernel)
        return converted_kernel

    # Функция преобразования входа в нужный формат.
    def _convert_input(self, torch_input, output_height, output_width):
        print('torch_input.shape:', torch_input.shape)
        print('batch size:', torch_input.shape[0])
        print('self.kernel_size:',self.kernel_size)
        print('filters num:', self.kernel.shape[0])
        print('channels num:', self.kernel.shape[1])
        print('stride:', self.stride)
        kernels_in_frame_x = (torch_input.shape[2] - (self.kernel_size - 1) - 1) // self.stride + 1
        kernels_in_frame_y = (torch_input.shape[3] - (self.kernel_size - 1) - 1) // self.stride + 1
        print('kernels in the input frame (number of columns):', kernels_in_frame_x*kernels_in_frame_y)
        print('column height:', (self.kernel.shape[1] * self.kernel_size * self.kernel_size))

        # Инициализация массива правильного размера
        converted_input = torch.zeros((self.kernel.shape[1] * self.kernel_size * self.kernel_size),
                                      torch_input.shape[0]*((torch_input.shape[2] - (self.kernel_size - 1) - 1) // self.stride + 1)) 
        
        # Replace zeros with correct values from input tenzor
        # print('torch_input:\n',torch_input)
        for batch in range (torch_input.shape[0]):
            for filter in range(torch_input.shape[1]):
                for i in range(kernels_in_frame_x):
                    for j in range(kernels_in_frame_y):
                        kernel_cut = torch_input[batch, filter, i*self.kernel_size:(i+1)*self.kernel_size,
                                                 j*self.kernel_size:(j+1)*self.kernel_size].reshape(-1,1)
                        print('kernel_cut.shape:',kernel_cut.shape)
                        print('kernel_cut:\n',kernel_cut)
                        # put correct values
                        converted_input[filter*self.kernel_size*self.kernel_size:(filter+1)*self.kernel_size*self.kernel_size,batch] = kernel_cut[:,0]


        print('! converted_input shape:', converted_input.shape)
        return converted_input

    def __call__(self, torch_input):
        batch_size, out_channels, output_height, output_width\
            = calc_out_shape(
                input_matrix_shape=torch_input.shape,
                out_channels=self.kernel.shape[0],
                kernel_size=self.kernel.shape[2],
                stride=self.stride,
                padding=0)

        converted_kernel = self._convert_kernel()
        converted_input = self._convert_input(torch_input, output_height, output_width)
        print('output_height:', output_height)
        print('output_width:', output_width)
        conv2d_out_alternative_matrix_v2 = converted_kernel @ converted_input
        print('conv2d_out_alternative_matrix_v2 shape:', conv2d_out_alternative_matrix_v2.shape)
        print(torch_input.shape[0])
        print(self.out_channels)
        print(output_height)
        print(output_width)
        
        return conv2d_out_alternative_matrix_v2.transpose(1,0).view(torch_input.shape[0],
                                                     self.out_channels, 
                                                     output_height,
                                                     output_width)

# Проверка происходит автоматически вызовом следующего кода
# (раскомментируйте для самостоятельной проверки,
#  в коде для сдачи задания должно быть закомментировано):
print(test_conv2d_layer(Conv2dMatrixV2))

input_tensor.shape: torch.Size([2, 3, 4, 4])
self.kernel.shape: torch.Size([1, 3, 3, 3])
! converted_kernel.shape: torch.Size([1, 27])
torch_input.shape: torch.Size([2, 3, 4, 4])
batch size: 2
self.kernel_size: 3
filters num: 1
channels num: 3
stride: 2
kernels in the input frame (number of columns): 1
column height: 27
kernel_cut.shape: torch.Size([9, 1])
kernel_cut:
 tensor([[ 0.],
        [ 1.],
        [ 2.],
        [ 4.],
        [ 5.],
        [ 6.],
        [ 8.],
        [ 9.],
        [10.]])
kernel_cut.shape: torch.Size([9, 1])
kernel_cut:
 tensor([[16.],
        [17.],
        [18.],
        [20.],
        [21.],
        [22.],
        [24.],
        [25.],
        [26.]])
kernel_cut.shape: torch.Size([9, 1])
kernel_cut:
 tensor([[32.],
        [33.],
        [34.],
        [36.],
        [37.],
        [38.],
        [40.],
        [41.],
        [42.]])
kernel_cut.shape: torch.Size([9, 1])
kernel_cut:
 tensor([[48.],
        [49.],
        [50.],
        [52.],
        [5