In [1]:
import torch

In [2]:
# A batch of 2 samples x 3 channels of 3x3 int64 images whose entries are
# 0, 1, ..., 53 laid out in row-major (sample, channel, row, column) order.
input_images = torch.arange(2 * 3 * 3 * 3).reshape(2, 3, 3, 3)

In [3]:
def get_padding2d(input_images, padding=1):
  """Surround every 2-D plane of a 4-D image batch with a border of zeros.

  Args:
    input_images: tensor of shape (batch, channels, height, width); any
      numeric dtype (it is converted to float32).
    padding: non-negative number of zero rows/columns added on every side.
      Defaults to 1, a one-element border.

  Returns:
    A new float32 tensor of shape
    (batch, channels, height + 2 * padding, width + 2 * padding).
  """
  # F.pad's pad tuple lists the last dimension first:
  # (left, right, top, bottom). Only the two spatial dims are padded.
  return torch.nn.functional.pad(input_images.float(),
                                 (padding, padding, padding, padding))

In [4]:
# Show the padded batch: each 3x3 plane of input_images gains a zero border.
padded_preview = get_padding2d(input_images)
padded_preview

tensor([[[[ 0.,  0.,  0.,  0.,  0.],
          [ 0.,  0.,  1.,  2.,  0.],
          [ 0.,  3.,  4.,  5.,  0.],
          [ 0.,  6.,  7.,  8.,  0.],
          [ 0.,  0.,  0.,  0.,  0.]],

         [[ 0.,  0.,  0.,  0.,  0.],
          [ 0.,  9., 10., 11.,  0.],
          [ 0., 12., 13., 14.,  0.],
          [ 0., 15., 16., 17.,  0.],
          [ 0.,  0.,  0.,  0.,  0.]],

         [[ 0.,  0.,  0.,  0.,  0.],
          [ 0., 18., 19., 20.,  0.],
          [ 0., 21., 22., 23.,  0.],
          [ 0., 24., 25., 26.,  0.],
          [ 0.,  0.,  0.,  0.,  0.]]],


        [[[ 0.,  0.,  0.,  0.,  0.],
          [ 0., 27., 28., 29.,  0.],
          [ 0., 30., 31., 32.,  0.],
          [ 0., 33., 34., 35.,  0.],
          [ 0.,  0.,  0.,  0.,  0.]],

         [[ 0.,  0.,  0.,  0.,  0.],
          [ 0., 36., 37., 38.,  0.],
          [ 0., 39., 40., 41.,  0.],
          [ 0., 42., 43., 44.,  0.],
          [ 0.,  0.,  0.,  0.,  0.]],

         [[ 0.,  0.,  0.,  0.,  0.],
          [ 0., 45., 46., 

In [5]:
# Hand-written expected output of get_padding2d(input_images): every 3x3
# plane of input_images surrounded by a one-element border of zeros, stored
# as floats. It is kept as a literal rather than computed, so the comparison
# that follows stays independent of the implementation under test.
correct_padded_images = torch.tensor(
       [[[[0.,  0.,  0.,  0.,  0.],
          [0.,  0.,  1.,  2.,  0.],
          [0.,  3.,  4.,  5.,  0.],
          [0.,  6.,  7.,  8.,  0.],
          [0.,  0.,  0.,  0.,  0.]],

         [[0.,  0.,  0.,  0.,  0.],
          [0.,  9., 10., 11.,  0.],
          [0., 12., 13., 14.,  0.],
          [0., 15., 16., 17.,  0.],
          [0.,  0.,  0.,  0.,  0.]],

         [[0.,  0.,  0.,  0.,  0.],
          [0., 18., 19., 20.,  0.],
          [0., 21., 22., 23.,  0.],
          [0., 24., 25., 26.,  0.],
          [0.,  0.,  0.,  0.,  0.]]],


        [[[0.,  0.,  0.,  0.,  0.],
          [0., 27., 28., 29.,  0.],
          [0., 30., 31., 32.,  0.],
          [0., 33., 34., 35.,  0.],
          [0.,  0.,  0.,  0.,  0.]],

         [[0.,  0.,  0.,  0.,  0.],
          [0., 36., 37., 38.,  0.],
          [0., 39., 40., 41.,  0.],
          [0., 42., 43., 44.,  0.],
          [0.,  0.,  0.,  0.,  0.]],

         [[0.,  0.,  0.,  0.,  0.],
          [0., 45., 46., 47.,  0.],
          [0., 48., 49., 50.,  0.],
          [0., 51., 52., 53.,  0.],
          [0.,  0.,  0.,  0.,  0.]]]])

In [6]:
padded_images = get_padding2d(input_images)
# torch.allclose broadcasts its arguments, so check the shapes explicitly too.
padding_matches = (padded_images.shape == correct_padded_images.shape
                   and torch.allclose(padded_images, correct_padded_images))
print(padding_matches)
assert padding_matches, "get_padding2d does not match the hand-written reference"

True


In [7]:
import numpy as np
import math

def calc_out_shape(input_matrix_shape, out_channels, kernel_size, stride, padding):
    """Return the (N, C, H, W) output shape of a 2-D convolution as a list.

    Args:
        input_matrix_shape: [batch, channels, height, width]. Any indexable
            sequence is accepted (list, tuple, torch.Size).
        out_channels: number of filters; becomes the channel dimension.
        kernel_size: side length of the square kernel.
        stride: step between neighbouring kernel placements.
        padding: zero rows/columns added on every side of the input.

    Returns:
        A new list [batch, out_channels, out_height, out_width]. The caller's
        sequence is never modified.
    """
    # list() copies any sequence; the old .copy() only worked on lists.
    out_shape = list(input_matrix_shape)
    out_shape[1] = out_channels
    for axis in (2, 3):
        # Integer `//` already floors, so no math.floor wrapper is needed.
        out_shape[axis] = (out_shape[axis] + 2 * padding - kernel_size) // stride + 1
    return out_shape

conv_shape_ok = np.array_equal(
    calc_out_shape(input_matrix_shape=[2, 3, 10, 10],
                   out_channels=10,
                   kernel_size=3,
                   stride=1,
                   padding=0),
    [2, 10, 8, 8])
print(conv_shape_ok)
# Fail loudly on Restart & Run All instead of only printing False.
assert conv_shape_ok, "calc_out_shape returned an unexpected shape"

True


In [8]:
from abc import ABC, abstractmethod

def calc_out_shape(input_matrix_shape, out_channels, kernel_size, stride, padding):
  """Return (batch, out_channels, out_height, out_width) for a 2-D convolution.

  This redefinition shadows the list-returning calc_out_shape defined in an
  earlier cell. It unpacks any 4-element shape (list, tuple, torch.Size) and
  returns a tuple.
  """
  n, _, in_h, in_w = input_matrix_shape

  def _out_len(size):
    # Conv output length with dilation 1: floor((size + 2p - (k-1) - 1) / s) + 1.
    return (size + 2 * padding - (kernel_size - 1) - 1) // stride + 1

  return n, out_channels, _out_len(in_h), _out_len(in_w)


class ABCConv2d(ABC):
  """Shared interface for the 2-D convolution layers in this notebook.

  The constructor records the layer hyper-parameters. The weight tensor is
  attached later via ``set_kernel``, and subclasses implement ``__call__``.
  """

  def __init__(self, in_channels, out_channels, kernel_size, stride):
    # Keep the hyper-parameters verbatim; subclasses read them in __call__.
    (self.in_channels, self.out_channels,
     self.kernel_size, self.stride) = in_channels, out_channels, kernel_size, stride

  def set_kernel(self, kernel):
    """Attach the weight tensor that ``__call__`` convolves with."""
    self.kernel = kernel

  @abstractmethod
  def __call__(self, input_tensor):
    """Apply the convolution to ``input_tensor`` and return the result."""


class Conv2d(ABCConv2d):
  """Reference layer that delegates to ``torch.nn.Conv2d`` (no padding, no bias).

  ``test_conv2d_layer`` compares the hand-written layers against this one.
  """

  def __init__(self, in_channels, out_channels, kernel_size, stride):
    # Populate the attributes ABCConv2d defines, so that layer.stride,
    # layer.kernel_size, etc. exist here just like on the other layers.
    super().__init__(in_channels, out_channels, kernel_size, stride)
    self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size,
                                  stride, padding=0, bias=False)

  def set_kernel(self, kernel):
    """Use ``kernel`` as the convolution weight and remember it as ``self.kernel``."""
    super().set_kernel(kernel)
    # Assigning .data replaces the weight values without going through autograd.
    self.conv2d.weight.data = kernel

  def __call__(self, input_tensor):
    """Run the wrapped torch.nn.Conv2d on ``input_tensor``."""
    return self.conv2d(input_tensor)


def create_and_call_conv2d_layer(conv2d_layer_class, stride, kernel, input_matrix):
  """Instantiate ``conv2d_layer_class`` to fit ``kernel`` and apply it once.

  The channel counts and kernel size are read from ``kernel.shape``, laid out
  as (out_channels, in_channels, kernel_size, ...). Returns whatever the
  layer's ``__call__`` returns for ``input_matrix``.
  """
  out_channels, in_channels, kernel_size = kernel.shape[:3]
  layer = conv2d_layer_class(in_channels, out_channels, kernel_size, stride)
  layer.set_kernel(kernel)
  return layer(input_matrix)


def test_conv2d_layer(conv2d_layer_class, batch_size=2,
                      input_height=4, input_width=4, stride=2):
  """Check a custom conv layer against the torch.nn.Conv2d-backed ``Conv2d``.

  Uses a fixed 1x3x3x3 float kernel and a float32 input of shape
  (batch_size, 3, input_height, input_width) filled with 0, 1, 2, ...
  Both layers are built through create_and_call_conv2d_layer.

  Returns:
    True when both outputs have the same shape and are allclose, else False.
  """
  kernel = torch.tensor([[[[0., 1, 0],
                           [1, 2, 1],
                           [0, 1, 0]],

                          [[1, 2, 1],
                           [0, 3, 3],
                           [0, 1, 10]],

                          [[10, 11, 12],
                           [13, 14, 15],
                           [16, 17, 18]]]])

  in_channels = kernel.shape[1]

  element_count = batch_size * in_channels * input_height * input_width
  input_tensor = torch.arange(element_count, dtype=torch.float32).reshape(
      batch_size, in_channels, input_height, input_width)

  custom_conv2d_out = create_and_call_conv2d_layer(
      conv2d_layer_class, stride, kernel, input_tensor)
  conv2d_out = create_and_call_conv2d_layer(
      Conv2d, stride, kernel, input_tensor)

  # Compare shapes first. torch.allclose broadcasts, and it raises (rather
  # than returning False) when the shapes are not broadcast-compatible.
  return (custom_conv2d_out.shape == conv2d_out.shape
          and torch.allclose(custom_conv2d_out, conv2d_out))

In [9]:
class Conv2dLoop(ABCConv2d):
  """Direct 2-D convolution written as explicit Python loops (no padding, no bias).

  Every output element is an explicit sum over input channels and kernel
  rows/columns. This is slow, but it makes the arithmetic easy to follow.
  """

  def __call__(self, input_tensor):
    """Convolve ``input_tensor`` of shape (N, C, H, W) with ``self.kernel``.

    Returns a tensor of shape (N, out_channels, out_h, out_w), where
    out_h = (H - kernel_size) // stride + 1 and out_w is computed likewise.
    """
    n, c, h, w = input_tensor.shape
    k = self.kernel_size
    # Number of valid kernel placements along each axis, honouring the stride.
    out_h = max(0, (h - k) // self.stride + 1)
    out_w = max(0, (w - k) // self.stride + 1)

    # Preallocate so the result stays 4-D even when an output axis is empty.
    output_tensor = torch.zeros(n, len(self.kernel), out_h, out_w,
                                dtype=torch.result_type(input_tensor, self.kernel))

    for batch in range(n):
      for out_channel, ker in enumerate(self.kernel):
        for row in range(out_h):
          top = row * self.stride
          for col in range(out_w):
            left = col * self.stride
            acc = 0
            for channel in range(c):
              for i in range(k):
                for j in range(k):
                  acc = acc + input_tensor[batch][channel][top + i][left + j] * ker[channel][i][j]
            output_tensor[batch, out_channel, row, col] = acc

    return output_tensor

In [10]:
conv2d_loop_ok = test_conv2d_layer(Conv2dLoop)
print(conv2d_loop_ok)
assert conv2d_loop_ok, "Conv2dLoop disagrees with torch.nn.Conv2d"

True


In [11]:
class Conv2dMatrix(ABCConv2d):
  """2-D convolution as a single matrix product with an "unrolled" kernel.

  Each row of the unrolled kernel holds one filter spread over the whole
  flattened input (C * H * W values, zeros outside that row's receptive
  field) for one output position. Multiplying by the flattened input then
  produces every output element at once. Assumes a square kernel, no
  padding and no bias.
  """

  def _unsqueeze_kernel(self, torch_input, output_height, output_width):
    """Build the (out_channels * out_h * out_w, C * H * W) unrolled kernel.

    Rows are ordered by filter first, then output row, then output column.
    This matches the final reshape in ``__call__``.
    """
    _, in_channels, input_height, input_width = torch_input.shape
    kernel_size = self.kernel.shape[2]
    plane = input_height * input_width
    positions = output_height * output_width

    kernel_unsqueezed = torch.zeros(self.kernel.shape[0] * positions,
                                    in_channels * plane)
    for filter_index, ker in enumerate(self.kernel):
      for position in range(positions):
        row = filter_index * positions + position
        # Top-left input pixel under this output position (stride-aware).
        top = (position // output_width) * self.stride
        left = (position % output_width) * self.stride
        for channel in range(in_channels):
          for i in range(kernel_size):
            start = channel * plane + (top + i) * input_width + left
            kernel_unsqueezed[row, start:start + kernel_size] = ker[channel][i]
    return kernel_unsqueezed

  def __call__(self, torch_input):
    """Convolve ``torch_input`` of shape (N, C, H, W) with ``self.kernel``."""
    batch_size, out_channels, output_height, output_width = calc_out_shape(
        input_matrix_shape=torch_input.shape,
        out_channels=self.kernel.shape[0],
        kernel_size=self.kernel.shape[2],
        stride=self.stride,
        padding=0)

    kernel_unsqueezed = self._unsqueeze_kernel(torch_input, output_height, output_width)
    # (rows, C*H*W) @ (C*H*W, N) -> (rows, N). reshape() also accepts
    # non-contiguous inputs, where view() would raise.
    result = kernel_unsqueezed @ torch_input.reshape(batch_size, -1).t()
    return result.t().reshape(batch_size, out_channels,
                              output_height, output_width)

In [12]:
conv2d_matrix_ok = test_conv2d_layer(Conv2dMatrix)
print(conv2d_matrix_ok)
assert conv2d_matrix_ok, "Conv2dMatrix disagrees with torch.nn.Conv2d"

True


In [76]:
class Conv2dMatrixV2(ABCConv2d):
  """2-D convolution as flattened-kernel @ im2col columns (no padding, no bias).

  The kernel becomes an (out_channels, C * k * k) matrix. The input becomes a
  (C * k * k, N * out_h * out_w) matrix whose columns are flattened receptive
  fields. Their product contains every output element. Assumes a square
  kernel.
  """

  def _convert_kernel(self):
    """Flatten each filter to one row, ordered (channel, kernel row, kernel col)."""
    return torch.tensor([ker.reshape(-1).tolist() for ker in self.kernel])

  def _convert_input(self, torch_input, output_height, output_width):
    """Gather receptive fields into a (C * k * k, N * out_h * out_w) matrix.

    Column ``b * out_h * out_w + p`` holds the patch under output position
    ``p`` (row-major over out_h, out_w) of sample ``b``. Each patch is
    flattened in the same (channel, kernel row, kernel col) order as
    ``_convert_kernel``.
    """
    n, c, _, _ = torch_input.shape
    k = self.kernel.shape[2]
    positions = output_height * output_width
    columns = torch.zeros(c * k * k, n * positions)
    for batch in range(n):
      for position in range(positions):
        # Top-left input pixel under this output position (stride-aware).
        top = (position // output_width) * self.stride
        left = (position % output_width) * self.stride
        patch = torch_input[batch, :, top:top + k, left:left + k]
        columns[:, batch * positions + position] = patch.reshape(-1)
    return columns

  def __call__(self, torch_input):
    """Convolve ``torch_input`` of shape (N, C, H, W) with ``self.kernel``."""
    batch_size, out_channels, output_height, output_width = calc_out_shape(
        input_matrix_shape=torch_input.shape,
        out_channels=self.kernel.shape[0],
        kernel_size=self.kernel.shape[2],
        stride=self.stride,
        padding=0)

    converted_kernel = self._convert_kernel()
    converted_input = self._convert_input(torch_input, output_height, output_width)

    # (out_channels, C*k*k) @ (C*k*k, N*out_h*out_w) -> (out_channels, N*out_h*out_w)
    conv2d_out_alternative_matrix_v2 = converted_kernel @ converted_input
    # Columns are grouped by sample, so split them out and move N to the front.
    return conv2d_out_alternative_matrix_v2.reshape(
        out_channels, batch_size, output_height, output_width
    ).permute(1, 0, 2, 3).contiguous()

In [77]:
conv2d_matrix_v2_ok = test_conv2d_layer(Conv2dMatrixV2)
print(conv2d_matrix_v2_ok)
assert conv2d_matrix_v2_ok, "Conv2dMatrixV2 disagrees with torch.nn.Conv2d"

True
