<h1>1. Реализовать функцию свертки. Придумать тест для проверки.</h1>

In [72]:
import torch
import torch.nn.functional as F

def convolution(input_tensor, filter_tensor, stride=(1, 1), padding=(0, 0)):
    """Compute a 2-D convolution (cross-correlation) by sliding the filters.

    Args:
        input_tensor: 4-D tensor laid out as (N, C, H_in, W_in).
        filter_tensor: 4-D tensor laid out as (M, C, R, S).
        stride: (vertical, horizontal) step between neighbouring windows.
        padding: (vertical, horizontal) number of zero rows/columns added on
            each side of the input.

    Returns:
        Tensor of shape (N, M, H_out, W_out), allocated with ``torch.zeros``
        defaults, where
        ``H_out = (H_in + 2 * padding[0] - R) // stride[0] + 1`` and
        ``W_out = (W_in + 2 * padding[1] - S) // stride[1] + 1``.

    Raises:
        ValueError: if the input and the filters have different channel counts.
    """
    N, C_in, H_in, W_in = input_tensor.size()
    M, C_fl, R, S = filter_tensor.size()

    if C_in != C_fl:
        raise ValueError(
            f"Channel mismatch: input has {C_in} channels, filters expect {C_fl}"
        )

    H_out = (H_in + 2 * padding[0] - R) // stride[0] + 1
    W_out = (W_in + 2 * padding[1] - S) // stride[1] + 1

    # F.pad takes the pads for the last dimension first: (left, right, top, bottom).
    padded = F.pad(input_tensor, (padding[1], padding[1], padding[0], padding[0]))

    output_tensor = torch.zeros(N, M, H_out, W_out)

    # Broadcast shape (1, M, C, R, S) so one product covers every sample/filter pair.
    kernels = filter_tensor.unsqueeze(0)

    # Rows are bounded by H_out and columns by W_out; the output is indexed
    # [n, m, row, col], so non-square outputs are handled correctly.
    for row in range(H_out):
        top = row * stride[0]
        for col in range(W_out):
            left = col * stride[1]
            # Window of shape (N, 1, C, R, S) under the filter at this position.
            window = padded[:, :, top:top + R, left:left + S].unsqueeze(1)
            output_tensor[:, :, row, col] = (window * kernels).sum(dim=(2, 3, 4))

    return output_tensor


# Sanity check: the hand-written convolution must agree with torch's
# reference F.conv2d on a single-channel 3x3 input and a 2x2 diagonal filter.
input_tensor = torch.tensor(
    [[[[1.2, 2.1, 3.1],
       [4.8, 5.4, 6.2],
       [7.5, 8.5, 9.3]]]],
    dtype=torch.float32,
)
filter_tensor = torch.tensor([[[[1, 0], [0, 1]]]], dtype=torch.float32)

output_tensor_library = F.conv2d(
    input_tensor, filter_tensor, stride=(1, 1), padding=(0, 0)
)
output_tensor_custom = convolution(
    input_tensor, filter_tensor, stride=(1, 1), padding=(0, 0)
)

# Show the custom result first, then the library result, one per line.
print(output_tensor_custom, output_tensor_library, sep="\n")

print(
    "Outputs match"
    if torch.allclose(output_tensor_custom, output_tensor_library)
    else "Outputs don't match"
)


tensor([[[[ 6.6000,  8.3000],
          [13.3000, 14.7000]]]])
tensor([[[[ 6.6000,  8.3000],
          [13.3000, 14.7000]]]])
Outputs match


<h1>2. Написать функцию, реализующую сверточный слой через im2col. Сделать проверку результата с помощью прямой реализации свертки.</h1>

In [73]:
def convolution_im2col(input_tensor, filter_tensor, stride=(1, 1), padding=(0, 0)):
    """Compute a 2-D convolution (cross-correlation) via im2col + matmul.

    Every receptive field of the zero-padded input is unrolled into one
    column, the filters are flattened into rows, and the convolution becomes
    a single batched matrix product.

    Args:
        input_tensor: 4-D tensor laid out as (N, C, H_in, W_in).
        filter_tensor: 4-D tensor laid out as (M, C, R, S).
        stride: (vertical, horizontal) step between neighbouring windows.
        padding: (vertical, horizontal) number of zero rows/columns added on
            each side of the input.

    Returns:
        float32 tensor of shape (N, M, H_out, W_out), where
        ``H_out = (H_in + 2 * padding[0] - R) // stride[0] + 1`` and
        ``W_out = (W_in + 2 * padding[1] - S) // stride[1] + 1``.

    Raises:
        ValueError: if the input and the filters have different channel counts.
    """
    N, C_in, H_in, W_in = input_tensor.size()
    M, C_fl, R, S = filter_tensor.size()

    if C_in != C_fl:
        raise ValueError(
            f"Channel mismatch: input has {C_in} channels, filters expect {C_fl}"
        )

    # F.pad takes the pads for the last dimension first: (left, right, top, bottom).
    input_padded = F.pad(input_tensor, (padding[1], padding[1], padding[0], padding[0]))

    H_out = (H_in + 2 * padding[0] - R) // stride[0] + 1
    W_out = (W_in + 2 * padding[1] - S) // stride[1] + 1

    # (N, C*R*S, H_out*W_out): one column per output position.
    input_im2col = _im2col(input_padded, (R, S), stride, (H_out, W_out)).float()

    # (M, C*R*S): row-major flattening gives the same (channel, row, column)
    # ordering that _im2col uses for the rows of the column matrix.
    filter_reshaped = filter_tensor.reshape(M, C_in * R * S).float()

    # (M, K) @ (N, K, P) broadcasts over the batch and yields (N, M, P).
    output_im2col = torch.matmul(filter_reshaped, input_im2col)
    output_tensor = output_im2col.reshape(N, M, H_out, W_out)

    return output_tensor


def _im2col(padded, kernel_size, stride, out_size):
    """Unroll the sliding windows of an already padded (N, C, H, W) tensor.

    Returns a tensor of shape (N, C*R*S, H_out*W_out) whose rows are ordered
    channel-major, then kernel row, then kernel column.
    """
    R, S = kernel_size
    H_out, W_out = out_size
    N, C = padded.shape[:2]

    taps = []
    for i in range(R):
        # Last input row touched by kernel row i, plus one (exclusive slice end).
        row_stop = i + stride[0] * (H_out - 1) + 1
        for j in range(S):
            col_stop = j + stride[1] * (W_out - 1) + 1
            # For kernel offset (i, j): the input value seen at every output
            # position, shape (N, C, H_out, W_out).
            taps.append(padded[:, :, i:row_stop:stride[0], j:col_stop:stride[1]])

    # Stack the R*S offsets right after the channel axis so that flattening
    # gives rows indexed by c * R * S + i * S + j.
    stacked = torch.stack(taps, dim=2)
    return stacked.reshape(N, C * R * S, H_out * W_out)

In [74]:
# Cross-check: the im2col-based layer must reproduce the direct convolution
# on a single-channel 3x3 input with two 2x2 filters (diagonal, anti-diagonal).
input_tensor = torch.tensor(
    [[[[1.2, 2.1, 3.1],
       [4.8, 5.4, 6.2],
       [7.5, 8.5, 9.3]]]],
    dtype=torch.float32,
)
filter1 = torch.tensor([[[1, 0], [0, 1]]], dtype=torch.float32)
filter2 = torch.tensor([[[0, 1], [1, 0]]], dtype=torch.float32)

# Stacking along a new leading axis yields the (M, C, R, S) filter bank.
filter_tensor = torch.stack((filter1, filter2), dim=0)

output_im2col = convolution_im2col(input_tensor, filter_tensor)
output_convolution = convolution(input_tensor, filter_tensor)

print(
    "Outputs match"
    if torch.allclose(output_im2col, output_convolution)
    else "Outputs don't match"
)

Outputs match
