In [1]:
import torch

def get_device():
    """Pick the preferred torch device available on this machine.

    Preference order is CUDA, then the MPS backend, then CPU.

    Returns:
        torch.device: ``cuda`` if CUDA is available, else ``mps`` if the MPS
        backend is available, else ``cpu``.
    """
    if torch.cuda.is_available():
        return torch.device('cuda')

    # Older torch builds ship without `torch.backends.mps`, so look the
    # backend up defensively instead of assuming the attribute exists.
    mps_backend = getattr(torch.backends, 'mps', None)
    if mps_backend is not None and mps_backend.is_available():
        # The device type string is 'mps'; the former 'mos' spelling made
        # torch.device() raise whenever this branch was taken.
        return torch.device('mps')

    return torch.device('cpu')


class Conv2d_by_hand(torch.nn.Module):
    """2-D convolution (cross-correlation) built from ``unfold`` + ``matmul``.

    The forward pass slides the kernel over the input with stride 1 and no
    padding, and applies no bias, so the output spatial size is
    ``(H - kernel_height + 1, W - kernel_width + 1)``.

    Args:
        weights: 4-D kernel tensor laid out as
            ``(out_channels, in_channels, kernel_height, kernel_width)``.
    """

    def __init__(self, weights):
        super().__init__()
        # Stored exactly as given. A plain tensor assigned this way is not
        # registered as a parameter or buffer, so the caller is responsible
        # for placing it on the same device as the inputs.
        self.weights = weights

    def forward(self, x):
        """Convolve ``x`` with the stored kernel.

        Args:
            x: 4-D input tensor ``(batch, in_channels, height, width)``.

        Returns:
            Contiguous tensor of shape
            ``(batch, out_channels, height - kernel_height + 1,
            width - kernel_width + 1)``.
        """
        # get dimensions -> set return dimensions
        batch, _, input_height, input_width = x.shape
        ret_chn, _, kernel_height, kernel_width = self.weights.shape

        ret_height = input_height - kernel_height + 1
        ret_width = input_width - kernel_width + 1

        # Extract every kernel-sized window:
        # (batch, in_chn, ret_height, ret_width, kernel_height, kernel_width)
        patches = x.unfold(2, kernel_height, 1).unfold(3, kernel_width, 1)
        # Bring channels next to the kernel dims, then flatten each window
        # into one row so the convolution becomes a single matmul.
        patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(
            batch, ret_height, ret_width, -1
        )

        # reshape (not view) so non-contiguous kernels, e.g. a transposed
        # tensor, are accepted instead of raising a RuntimeError.
        w_flat = self.weights.reshape(ret_chn, -1).t()

        # (batch, ret_height, ret_width, K) @ (K, ret_chn)
        #   -> (batch, ret_height, ret_width, ret_chn)
        ret = torch.matmul(patches, w_flat)

        # Back to channels-first layout.
        return ret.permute(0, 3, 1, 2).contiguous()

In [10]:
# Seed the global RNG so the random input/kernel, and therefore the printed
# error, are reproducible across kernel restarts.
torch.manual_seed(0)

device = get_device()
print(device)

# Sampled on CPU and then moved, so a given seed yields the same values
# regardless of which device is selected.
inp = torch.randn(1, 3, 10, 12).to(device)  # (batch, in_channels, height, width)
w = torch.randn(2, 3, 4, 5).to(device)  # (out_channels, in_channels, kernel_h, kernel_w)

custom_conv2d_layer = Conv2d_by_hand(weights=w)
out = custom_conv2d_layer(inp)

# Largest absolute deviation from PyTorch's reference conv2d.
print((torch.nn.functional.conv2d(inp, w) - out).abs().max())

cpu
tensor(3.8147e-06)
