In [0]:
import torch

# Build the input batch: two RGB images of size 3x3, nested as
# 2 images x 3 channels x 3 rows x 3 columns and filled with 0..53 in order.
# The literals are Python ints, so this is an integer tensor.
input_images = torch.tensor(
      [[[[0,  1,  2],
         [3,  4,  5],
         [6,  7,  8]],

        [[9, 10, 11],
         [12, 13, 14],
         [15, 16, 17]],

        [[18, 19, 20],
         [21, 22, 23],
         [24, 25, 26]]],


       [[[27, 28, 29],
         [30, 31, 32],
         [33, 34, 35]],

        [[36, 37, 38],
         [39, 40, 41],
         [42, 43, 44]],

        [[45, 46, 47],
         [48, 49, 50],
         [51, 52, 53]]]])


# Hand-written expected answer for the padding exercise: every 3x3 channel
# holding the values 0..53 is surrounded by a one-element frame of zeros,
# giving 5x5 channels. Written with float literals, so this is a
# floating-point tensor. It is kept as an explicit literal so the check does
# not depend on the code being checked.
correct_padded_images = torch.tensor(
       [[[[0.,  0.,  0.,  0.,  0.],
          [0.,  0.,  1.,  2.,  0.],
          [0.,  3.,  4.,  5.,  0.],
          [0.,  6.,  7.,  8.,  0.],
          [0.,  0.,  0.,  0.,  0.]],

         [[0.,  0.,  0.,  0.,  0.],
          [0.,  9., 10., 11.,  0.],
          [0., 12., 13., 14.,  0.],
          [0., 15., 16., 17.,  0.],
          [0.,  0.,  0.,  0.,  0.]],

         [[0.,  0.,  0.,  0.,  0.],
          [0., 18., 19., 20.,  0.],
          [0., 21., 22., 23.,  0.],
          [0., 24., 25., 26.,  0.],
          [0.,  0.,  0.,  0.,  0.]]],


        [[[0.,  0.,  0.,  0.,  0.],
          [0., 27., 28., 29.,  0.],
          [0., 30., 31., 32.,  0.],
          [0., 33., 34., 35.,  0.],
          [0.,  0.,  0.,  0.,  0.]],

         [[0.,  0.,  0.,  0.,  0.],
          [0., 36., 37., 38.,  0.],
          [0., 39., 40., 41.,  0.],
          [0., 42., 43., 44.,  0.],
          [0.,  0.,  0.,  0.,  0.]],

         [[0.,  0.,  0.,  0.,  0.],
          [0., 45., 46., 47.,  0.],
          [0., 48., 49., 50.,  0.],
          [0., 51., 52., 53.,  0.],
          [0.,  0.,  0.,  0.,  0.]]]])



In [0]:
def get_padding2d(input_images):
    """Zero-pad a batch of images by one element on every spatial side.

    Args:
        input_images: 4-D tensor indexed as (batch, channels, height, width).
            Height and width may differ.

    Returns:
        A new tensor of shape (batch, channels, height + 2, width + 2) created
        with ``torch.zeros`` (torch's default dtype). The input values are
        copied into the centre; the one-element frame around them stays zero.
    """
    batch_size, channels, height, width = input_images.shape
    # Size each spatial axis from its own dimension: reusing the height for
    # the width only happens to work for square images.
    padded_images = torch.zeros(batch_size, channels, height + 2, width + 2)
    padded_images[:, :, 1:-1, 1:-1] = input_images
    return padded_images

In [3]:
# The grader runs this check automatically by calling the code below
# (uncomment it to check your solution yourself;
#  it must be commented out in the code you submit):
print(torch.allclose(get_padding2d(input_images), correct_padded_images))

True


In [4]:
import numpy as np


def calc_out_shape(input_matrix_shape, out_channels, kernel_size, stride, padding,
                   dilation=1):
    """Compute the output shape of a 2-D convolution with a square kernel.

    Each spatial size follows the usual convolution arithmetic:

        out = (in + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

    Args:
        input_matrix_shape: sequence (batch, channels, height, width).
        out_channels: number of filters, i.e. channels in the output.
        kernel_size: side of the square kernel.
        stride: step of the kernel along both spatial axes.
        padding: number of zeros added on each side of both spatial axes.
        dilation: spacing between kernel elements; 1 is an ordinary kernel.

    Returns:
        Tuple (batch, out_channels, output_height, output_width).
    """
    batch_size, _, input_height, input_width = input_matrix_shape
    # Extent actually covered by the kernel once dilation is taken into account.
    effective_kernel = dilation * (kernel_size - 1) + 1

    def _out_size(in_size):
        return (in_size + 2 * padding - effective_kernel) // stride + 1

    return batch_size, out_channels, _out_size(input_height), _out_size(input_width)

# Self-check: a 3x3 kernel with stride 1 and no padding on a 10x10 input
# is expected to give an 8x8 output with out_channels channels.
print(np.array_equal(
    calc_out_shape(input_matrix_shape=[2, 3, 10, 10],
                   out_channels=10,
                   kernel_size=3,
                   stride=1,
                   padding=0),
    [2, 10, 8, 8]))

True


In [0]:
from abc import ABC, abstractmethod

class ABCConv2d(ABC):
    """Abstract base for 2-D convolution layers.

    Holds the layer configuration and the kernel as plain attributes;
    subclasses provide the computation by implementing ``__call__``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride):
        """Store the four configuration values on the instance."""
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride

    def set_kernel(self, kernel):
        """Store ``kernel`` on the instance as ``self.kernel``."""
        self.kernel = kernel

    @abstractmethod
    def __call__(self, input_tensor):
        """Apply the layer to ``input_tensor``; subclasses must override."""

In [0]:
class Conv2d(ABCConv2d):
    """Reference layer: a thin wrapper around ``torch.nn.Conv2d``.

    The wrapped module is created without padding and without a bias term, so
    its output depends only on the input and on the kernel given to
    ``set_kernel``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride):
        # Initialise the base class so in_channels / out_channels /
        # kernel_size / stride exist on this wrapper as well.
        super().__init__(in_channels, out_channels, kernel_size, stride)
        self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size,
                                      stride, padding=0, bias=False)

    def set_kernel(self, kernel):
        """Use ``kernel`` as the weights of the wrapped module."""
        # Also keep the kernel on the wrapper, as the base class does.
        super().set_kernel(kernel)
        self.conv2d.weight.data = kernel

    def __call__(self, input_tensor):
        """Run the wrapped ``torch.nn.Conv2d`` on ``input_tensor``."""
        return self.conv2d(input_tensor)

In [0]:
def create_and_call_conv2d_layer(conv2d_layer_class, stride, kernel, input_matrix):
    """Build a layer sized for ``kernel``, load the kernel and apply the layer.

    The first three entries of ``kernel.shape`` are read as
    (out_channels, in_channels, kernel_size) and passed, together with
    ``stride``, to ``conv2d_layer_class``. Returns whatever the resulting
    layer returns when called on ``input_matrix``.
    """
    kernel_shape = kernel.shape
    n_filters, n_inputs, side = kernel_shape[0], kernel_shape[1], kernel_shape[2]

    conv = conv2d_layer_class(n_inputs, n_filters, side, stride)
    conv.set_kernel(kernel)
    return conv(input_matrix)

In [0]:
def test_conv2d_layer(conv2d_layer_class, batch_size=2,
                      input_height=4, input_width=4, stride=2):
    """Compare a convolution implementation with the ``Conv2d`` reference.

    A fixed kernel with one filter over three input channels (3x3 each) is
    applied to an input filled with 0, 1, 2, ... by both
    ``conv2d_layer_class`` and ``Conv2d``.

    Args:
        conv2d_layer_class: layer class to check; it is constructed and
            called through ``create_and_call_conv2d_layer``.
        batch_size: number of images in the generated input.
        input_height: height of the generated input.
        input_width: width of the generated input.
        stride: stride passed to both layers.

    Returns:
        True when both outputs have the same shape and are numerically close,
        False otherwise.
    """
    kernel = torch.tensor(
        [[[[0., 1, 0],
           [1, 2, 1],
           [0, 1, 0]],

          [[1, 2, 1],
           [0, 3, 3],
           [0, 1, 10]],

          [[10, 11, 12],
           [13, 14, 15],
           [16, 17, 18]]]])

    in_channels = kernel.shape[1]

    # Ask for float32 directly instead of routing the result through a
    # legacy typed-tensor constructor passed as ``out=``.
    input_tensor = torch.arange(
        batch_size * in_channels * input_height * input_width,
        dtype=torch.float32,
    ).reshape(batch_size, in_channels, input_height, input_width)

    custom_conv2d_out = create_and_call_conv2d_layer(conv2d_layer_class,
                                                     stride, kernel,
                                                     input_tensor)
    conv2d_out = create_and_call_conv2d_layer(Conv2d, stride,
                                              kernel, input_tensor)

    # Compare shapes explicitly first: an elementwise closeness check alone
    # may broadcast a wrongly shaped result against the reference instead of
    # rejecting it.
    if custom_conv2d_out.shape != conv2d_out.shape:
        return False
    return torch.allclose(custom_conv2d_out, conv2d_out)

In [9]:
# Sanity check of the harness itself: the reference layer is compared with itself.
print(test_conv2d_layer(Conv2d))

True


In [0]:
class Conv2dLoop(ABCConv2d):
    """Convolution without padding, computed with explicit Python loops."""

    def __call__(self, input_tensor):
        """Convolve ``input_tensor`` with the kernel stored by ``set_kernel``.

        Args:
            input_tensor: tensor indexed as (batch, in_channels, height, width).

        Returns:
            Tensor of shape (batch, out_channels, output_height, output_width)
            created with ``torch.zeros`` and filled position by position.
        """
        batch_size, _, input_height, input_width = input_tensor.shape
        # The kernel is indexed as (out_channels, in_channels, height, width);
        # the spatial extents are its last two dimensions.
        _, _, kernel_height, kernel_width = self.kernel.shape
        stride = self.stride

        output_height = (input_height - kernel_height) // stride + 1
        output_width = (input_width - kernel_width) // stride + 1
        output_tensor = torch.zeros(batch_size, self.out_channels,
                                    output_height, output_width)

        for b in range(batch_size):
            for n_filter in range(self.out_channels):
                # Every output channel is produced by its own filter only.
                filter_weights = self.kernel[n_filter]
                for i in range(output_height):
                    top = i * stride
                    for j in range(output_width):
                        left = j * stride
                        window = input_tensor[b, :,
                                              top:top + kernel_height,
                                              left:left + kernel_width]
                        output_tensor[b, n_filter, i, j] = (window * filter_weights).sum()
        return output_tensor

In [11]:
# Compare the loop-based layer with the reference on the default test input.
print(test_conv2d_layer(Conv2dLoop))

True


In [0]:
class Conv2dMatrix(ABCConv2d):
    """Convolution without padding as a single matrix product.

    The kernel is unrolled into a matrix that acts on flattened images: each
    row is one filter placed at one output position on an otherwise zero image.
    """

    def _unsqueeze_kernel(self, torch_input, output_height, output_width):
        """Build the unrolled kernel matrix.

        Returns:
            Tensor of shape (out_channels * output_height * output_width,
            in_channels * height * width). Rows are ordered by filter, then
            output row, then output column.
        """
        out_channels, _, kernel_height, kernel_width = self.kernel.shape
        image_shape = torch_input.shape[1:]
        image_size = image_shape[0] * image_shape[1] * image_shape[2]
        kernel_unsqueezed = torch.zeros(
            out_channels * output_height * output_width, image_size)

        row = 0
        for n_filter in range(out_channels):
            # One row per output position: iterate over output coordinates,
            # the stride only decides where the window starts in the input.
            for i in range(output_height):
                top = i * self.stride
                for j in range(output_width):
                    left = j * self.stride
                    placed = torch.zeros(image_shape)
                    placed[:,
                           top:top + kernel_height,
                           left:left + kernel_width] = self.kernel[n_filter]
                    kernel_unsqueezed[row] = placed.reshape(-1)
                    row += 1

        return kernel_unsqueezed

    def __call__(self, torch_input):
        """Convolve ``torch_input``, indexed as (batch, in_channels, height, width).

        Returns a tensor of shape
        (batch, out_channels, output_height, output_width).
        """
        # calc_out_shape takes a single kernel size, so a square kernel is assumed here.
        batch_size, out_channels, output_height, output_width = calc_out_shape(
            input_matrix_shape=torch_input.shape,
            out_channels=self.kernel.shape[0],
            kernel_size=self.kernel.shape[2],
            stride=self.stride,
            padding=0)

        kernel_unsqueezed = self._unsqueeze_kernel(torch_input,
                                                   output_height, output_width)
        # (rows, C*H*W) @ (C*H*W, batch) -> (rows, batch)
        flat_input = torch_input.reshape(batch_size, -1)
        result = kernel_unsqueezed @ flat_input.permute(1, 0)

        # The transposed result is not contiguous in general, so use reshape
        # rather than view to split the rows into (out_channels, height, width).
        return result.permute(1, 0).reshape(batch_size, out_channels,
                                            output_height, output_width)

In [13]:
# Compare the unrolled-kernel layer with the reference on the default test input.
print(test_conv2d_layer(Conv2dMatrix))

True


In [0]:
class Conv2dMatrixV2(ABCConv2d):
    """Convolution without padding as a matrix product over unrolled input patches.

    Every input window is flattened into a row and every filter into a row;
    the product of the two gives all output values at once.
    """

    def _convert_kernel(self):
        """Flatten each filter into one row.

        Returns:
            Tensor of shape (out_channels, in_channels * kernel_height * kernel_width).
        """
        return self.kernel.reshape(self.kernel.shape[0], -1)

    def _convert_input(self, torch_input, output_height, output_width):
        """Flatten every input window into one row.

        Returns:
            Tensor of shape (batch, output_height * output_width,
            in_channels * kernel_height * kernel_width); windows are ordered
            by output row, then output column.
        """
        batch_size = torch_input.shape[0]
        _, in_channels, kernel_height, kernel_width = self.kernel.shape
        patch_size = in_channels * kernel_height * kernel_width
        converted_input = torch.zeros(batch_size,
                                      output_height * output_width,
                                      patch_size)

        patch = 0
        # Iterate over output coordinates; the stride only decides where each
        # window starts in the input.
        for i in range(output_height):
            rows = slice(i * self.stride, i * self.stride + kernel_height)
            for j in range(output_width):
                cols = slice(j * self.stride, j * self.stride + kernel_width)
                # The window position is the same for every image in the
                # batch, so fill the whole batch at once.
                converted_input[:, patch] = \
                    torch_input[:, :, rows, cols].reshape(batch_size, patch_size)
                patch += 1

        return converted_input

    def __call__(self, torch_input):
        """Convolve ``torch_input``, indexed as (batch, in_channels, height, width).

        Returns a tensor of shape
        (batch, out_channels, output_height, output_width).
        """
        # calc_out_shape takes a single kernel size, so a square kernel is assumed here.
        batch_size, out_channels, output_height, output_width = calc_out_shape(
            input_matrix_shape=torch_input.shape,
            out_channels=self.kernel.shape[0],
            kernel_size=self.kernel.shape[2],
            stride=self.stride,
            padding=0
        )

        converted_kernel = self._convert_kernel()
        converted_input = self._convert_input(torch_input, output_height, output_width)
        # (batch, positions, patch) @ (patch, out_channels)
        #   -> (batch, positions, out_channels)
        conv2d_out_alternative_matrix_v2 = converted_input @ converted_kernel.permute(1, 0)

        # Move channels in front of the positions before unfolding the
        # positions back into (output_height, output_width).
        return conv2d_out_alternative_matrix_v2.permute(0, 2, 1).reshape(
            batch_size, out_channels, output_height, output_width)

In [24]:
# Compare the unrolled-input layer with the reference on the default test input.
print(test_conv2d_layer(Conv2dMatrixV2))

True
