In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import sampler

import torchvision.datasets as dset
import torchvision.transforms as T
import torch.nn.functional as F

import numpy as np

# Run configuration: device/dtype selection, logging cadence and RNG seed.
USE_GPU = True
dtype = torch.float32
print_every = 100

# Seed torch's global RNG so the randomly initialised weights and random
# inputs in later cells are reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

if USE_GPU and torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

print('using device:', device)

using device: cuda


In [4]:
class Conv2d(nn.Module):
    """Naive 2-D convolution (cross-correlation) written with explicit loops.

    Reference implementation aimed at clarity: it slides the kernel over every
    output position and takes the dot product between the input patch and each
    filter. Results agree with ``torch.nn.functional.conv2d`` given the same
    weight, bias, stride and padding (up to float round-off).

    Args:
        in_channels: number of channels expected in the input.
        out_channels: number of filters, i.e. output channels.
        kernel_size: kernel size; an int for a square kernel or ``(kh, kw)``.
        stride: step between kernel applications; int or ``(sh, sw)``.
        padding: zero padding added on each spatial side; int or ``(ph, pw)``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

        kh, kw = self._pair(kernel_size)
        # NOTE: unit-variance Gaussian init (nn.Conv2d uses a fan-in-scaled
        # Kaiming-uniform init); fine for correctness checks, large for training.
        self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kh, kw))
        self.bias = nn.Parameter(torch.zeros(out_channels))

    @staticmethod
    def _pair(value):
        """Expand an int into an ``(h, w)`` pair; unpack 2-sequences as-is."""
        if isinstance(value, int):
            return value, value
        h, w = value
        return h, w

    def forward(self, x):
        """Convolve ``x`` with the layer's filters.

        Args:
            x: input tensor of shape ``(B, C, H, W)`` with ``C == in_channels``.

        Returns:
            Tensor of shape ``(B, out_channels, H_out, W_out)`` where
            ``H_out = (H + 2*ph - kh) // sh + 1`` (likewise for ``W_out``).

        Raises:
            ValueError: if the channel count of ``x`` differs from
                ``in_channels``. Without this check, broadcasting would
                silently produce wrong values (e.g. for 1-channel input).
        """
        B, C, H, W = x.shape
        if C != self.in_channels:
            raise ValueError(
                f'expected input with {self.in_channels} channels, got {C}'
            )

        kh, kw = self._pair(self.kernel_size)
        sh, sw = self._pair(self.stride)
        ph, pw = self._pair(self.padding)

        H_out = (H + 2 * ph - kh) // sh + 1
        W_out = (W + 2 * pw - kw) // sw + 1

        # F.pad lists padding from the last dim backwards: (left, right, top, bottom).
        x = F.pad(x, (pw, pw, ph, ph))

        out = torch.zeros(B, self.out_channels, H_out, W_out, device=x.device, dtype=x.dtype)

        # (1, O, C, kh, kw): lets all filters broadcast against the whole batch.
        weight = self.weight.unsqueeze(0)

        for i in range(H_out):
            h_start = i * sh
            for j in range(W_out):
                w_start = j * sw
                # (B, 1, C, kh, kw) receptive field for every sample in the batch.
                patch = x[:, None, :, h_start:h_start + kh, w_start:w_start + kw]
                # Dot product over (C, kh, kw) for each (sample, filter) pair -> (B, O).
                out[:, :, i, j] = (patch * weight).sum(dim=(2, 3, 4)) + self.bias

        return out
    
# Sanity-check the naive layer: the output shape follows (H + 2p - k) // s + 1,
# and the values must agree with PyTorch's built-in convolution.
in_channels = 3

conv = Conv2d(in_channels, 2, 4, 1, 1)
x = torch.randn(1, in_channels, 64, 64)
with torch.no_grad():  # no autograd graph needed for a forward-only check
    y_naive = conv(x)
    y_ref = F.conv2d(x, conv.weight, conv.bias, stride=1, padding=1)
assert torch.allclose(y_naive, y_ref, atol=1e-3), 'naive Conv2d disagrees with F.conv2d'
y_naive.shape

torch.Size([1, 2, 63, 63])

In [3]:
class Conv2dFast(nn.Module):
    """Vectorised 2-D convolution via im2col (``F.unfold``) and one matmul.

    Every receptive field is flattened into a column, so the whole convolution
    becomes ``(O, C*kh*kw) @ (B, C*kh*kw, L)``, with ``L = H_out * W_out``.
    Results agree with ``torch.nn.functional.conv2d`` given the same weight,
    bias, stride and padding (up to float round-off).

    Args:
        in_channels: number of channels expected in the input.
        out_channels: number of filters, i.e. output channels.
        kernel_size: kernel size; an int for a square kernel or ``(kh, kw)``.
        stride: step between kernel applications; int or ``(sh, sw)``.
        padding: zero padding added on each spatial side; int or ``(ph, pw)``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

        kh, kw = self._pair(kernel_size)
        # NOTE: unit-variance Gaussian init (nn.Conv2d uses a fan-in-scaled
        # Kaiming-uniform init); fine for correctness checks, large for training.
        self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kh, kw))
        self.bias = nn.Parameter(torch.zeros(out_channels))

    @staticmethod
    def _pair(value):
        """Expand an int into an ``(h, w)`` pair; unpack 2-sequences as-is."""
        if isinstance(value, int):
            return value, value
        h, w = value
        return h, w

    def forward(self, x):
        """Convolve ``x`` with the layer's filters.

        Args:
            x: input tensor of shape ``(B, C, H, W)`` with ``C == in_channels``.

        Returns:
            Tensor of shape ``(B, out_channels, H_out, W_out)`` where
            ``H_out = (H + 2*ph - kh) // sh + 1`` (likewise for ``W_out``).

        Raises:
            ValueError: if the channel count of ``x`` differs from ``in_channels``.
        """
        B, C, H, W = x.shape
        if C != self.in_channels:
            raise ValueError(
                f'expected input with {self.in_channels} channels, got {C}'
            )

        kh, kw = self._pair(self.kernel_size)
        sh, sw = self._pair(self.stride)
        ph, pw = self._pair(self.padding)

        H_out = (H + 2 * ph - kh) // sh + 1
        W_out = (W + 2 * pw - kw) // sw + 1

        # (B, C*kh*kw, H_out*W_out): each column is one flattened receptive field.
        cols = F.unfold(x, (kh, kw), padding=(ph, pw), stride=(sh, sw))

        # (O, C*kh*kw): filters flattened in the same (C, kh, kw) order unfold uses.
        # Named w_mat so the input width W is not shadowed.
        w_mat = self.weight.reshape(self.out_channels, -1)

        out = w_mat @ cols + self.bias[None, :, None]

        # Columns are laid out row-major over output positions, so this restores (H_out, W_out).
        return out.reshape(B, self.out_channels, H_out, W_out)
    
# Sanity-check the unfold-based layer on a batch: same shape rule as the naive
# version, and values must agree with PyTorch's built-in convolution.
in_channels = 3

conv = Conv2dFast(in_channels, 2, 4, 1, 1)
x = torch.randn(5, in_channels, 64, 64)
with torch.no_grad():  # forward-only check, no autograd graph needed
    y_fast = conv(x)
    y_ref = F.conv2d(x, conv.weight, conv.bias, stride=1, padding=1)
assert torch.allclose(y_fast, y_ref, atol=1e-3), 'Conv2dFast disagrees with F.conv2d'
y_fast.shape

torch.Size([5, 2, 63, 63])