In [4]:
import torch
import torch.nn.functional as F
from torch import nn

class Conv2D(nn.Module):
    """2-D convolution (cross-correlation, no bias) built from ``F.unfold`` and a matmul.

    Args:
        in_channels: Number of channels of the input images.
        out_channels: Number of filters, i.e. channels of the output.
        kernel_size: Kernel extent, either an int (square kernel) or an
            ``(height, width)`` pair.
        stride: Step between patches, an int or an ``(height, width)`` pair.
        padding: Zero padding added on both sides of each spatial dimension,
            an int or an ``(height, width)`` pair.

    Attributes:
        kernel: Learnable weights of shape
            ``(out_channels, in_channels, k_h, k_w)``, drawn from ``torch.randn``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=(3, 3), stride=1, padding=1):
        super().__init__()
        # Hyper-parameters are stored exactly as passed; they are normalised
        # to (height, width) pairs at the point of use.
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

        # Initialize kernel weights
        k_h, k_w = self._pair(kernel_size)
        self.kernel = nn.Parameter(torch.randn(out_channels, in_channels, k_h, k_w))

    @staticmethod
    def _pair(value):
        """Return ``value`` as a ``(height, width)`` pair, duplicating a lone int."""
        if isinstance(value, int):
            return value, value
        first, second = value
        return first, second

    def forward(self, input_batch):
        """Convolve a batch of images with ``self.kernel``.

        Args:
            input_batch: Tensor of shape ``(b, in_channels, h, w)``.

        Returns:
            Tensor of shape ``(b, out_channels, h_out, w_out)`` where
            ``h_out = (h + 2 * p_h - k_h) // s_h + 1`` and likewise for ``w_out``.
        """
        b, _, h, w = input_batch.size()
        k_h, k_w = self._pair(self.kernel_size)
        s_h, s_w = self._pair(self.stride)
        p_h, p_w = self._pair(self.padding)

        # F.pad lists the last dimension first: (left, right, top, bottom).
        input_padded = F.pad(input_batch, (p_w, p_w, p_h, p_h))

        # Unfold the input into patches -> (b, c*k_h*k_w, p)
        patches = F.unfold(input_padded, kernel_size=(k_h, k_w), stride=(s_h, s_w))

        # Flatten every filter into a row -> (out_channels, c*k_h*k_w)
        kernel_flat = self.kernel.reshape(self.out_channels, -1)

        # The matmul broadcasts over the batch -> (b, out_channels, p)
        output = torch.matmul(kernel_flat, patches)

        # Lay the p patch responses back out on the output grid.
        h_out = (h + 2 * p_h - k_h) // s_h + 1
        w_out = (w + 2 * p_w - k_w) // s_w + 1
        return output.view(b, self.out_channels, h_out, w_out)

In [5]:
class Conv2DFunc(torch.autograd.Function):
    """Convolution as a custom autograd Function with a hand-written backward.

    Call it through ``Conv2DFunc.apply(input_batch, kernel, stride, padding)``;
    ``forward`` and ``backward`` are not meant to be called directly.
    """

    @staticmethod
    def forward(ctx, input_batch, kernel, stride=1, padding=1):
        """
        Forward pass of convolution using unfold.

        Args:
            ctx: Autograd context used to stash values for ``backward``.
            input_batch: Tensor of shape ``(B, C_in, H, W)``.
            kernel: Tensor of shape ``(C_out, C_in, K_H, K_W)``.
            stride: Step between patches.
            padding: Zero padding added on every side of the image.

        Returns:
            Tensor of shape ``(B, C_out, H_out, W_out)``.
        """
        # save_for_backward must be called only once per forward pass, so
        # everything backward needs is saved here; the unfolded patches are
        # recomputed in backward instead of being stored.
        ctx.save_for_backward(input_batch, kernel)
        ctx.stride = stride
        ctx.padding = padding

        # Get input dimensions
        b, _, h, w = input_batch.size()
        c_out, k_h, k_w = kernel.size(0), kernel.size(2), kernel.size(3)

        # Pad and unfold the input
        input_padded = F.pad(input_batch, (padding, padding, padding, padding))
        patches = F.unfold(input_padded, kernel_size=(k_h, k_w), stride=stride)  # U: (B, C_in * K_H * K_W, P)

        # Reshape kernel and perform matrix multiplication
        kernel_flat = kernel.reshape(c_out, -1)  # (C_out, C_in * K_H * K_W)
        output = torch.matmul(kernel_flat, patches)  # (B, C_out, P)

        # Reshape output to image format
        h_out = (h + 2 * padding - k_h) // stride + 1
        w_out = (w + 2 * padding - k_w) // stride + 1
        return output.view(b, c_out, h_out, w_out)

    @staticmethod
    def backward(ctx, grad_output):
        """
        Backward pass for convolution.

        Args:
            ctx: Autograd context populated by ``forward``.
            grad_output: Gradient of the loss w.r.t. the forward output,
                shape ``(B, C_out, H_out, W_out)``.

        Returns:
            ``(grad_input, grad_kernel, None, None)``: gradients for
            ``input_batch`` and ``kernel`` (``None`` for whichever does not
            require a gradient) and ``None`` for ``stride`` and ``padding``.
        """
        input_batch, kernel = ctx.saved_tensors
        stride = ctx.stride
        padding = ctx.padding

        _, _, h, w = input_batch.size()
        c_out, k_h, k_w = kernel.size(0), kernel.size(2), kernel.size(3)

        # reshape (not view) because the incoming gradient may be non-contiguous.
        grad_flat = grad_output.reshape(grad_output.size(0), c_out, -1)  # Y^\triangledown: (B, C_out, P)

        grad_input = None
        grad_kernel = None

        # Gradient w.r.t. kernel: Y^\triangledown U^T, summed over the batch.
        if ctx.needs_input_grad[1]:
            # Recompute the unfolded patches U from the saved input.
            input_padded = F.pad(input_batch, (padding, padding, padding, padding))
            patches = F.unfold(input_padded, kernel_size=(k_h, k_w), stride=stride)
            grad_kernel = torch.matmul(grad_flat, patches.transpose(1, 2))  # (B, C_out, C_in * K_H * K_W)
            grad_kernel = grad_kernel.sum(dim=0).reshape(kernel.size())

        # Gradient w.r.t. input: W^T Y^\triangledown gives per-patch gradients,
        # which fold sums back onto the (H, W) image, dropping the padding.
        if ctx.needs_input_grad[0]:
            kernel_flat = kernel.reshape(c_out, -1)  # (C_out, C_in * K_H * K_W)
            grad_input_patches = torch.matmul(kernel_flat.t(), grad_flat)  # (B, C_in * K_H * K_W, P)
            grad_input = F.fold(
                grad_input_patches, (h, w), kernel_size=(k_h, k_w), stride=stride, padding=padding
            )

        return grad_input, grad_kernel, None, None


In [7]:
torch.manual_seed(0)  # make the random demo tensors reproducible across runs

input_batch = torch.randn(16, 3, 32, 32)  # Batch of 16, 3 channels, 32x32 resolution
kernel = torch.randn(64, 3, 3, 3)  # 64 output channels, 3 input channels, 3x3 kernel
stride = 1
padding = 1

# Using Conv2D Module
conv = Conv2D(3, 64, kernel_size=(3, 3), stride=stride, padding=padding)
output = conv(input_batch)
print("Output shape (Conv2D):", output.shape)

# Using Conv2DFunc
# A custom autograd Function is invoked through .apply(), which supplies the
# ctx argument. Calling Conv2DFunc.forward(input_batch, kernel, ...) directly
# would bind input_batch to ctx and raise a TypeError for the missing kernel.
output_func = Conv2DFunc.apply(input_batch, kernel, stride, padding)
print("Output shape (Conv2DFunc):", output_func.shape)


Output shape (Conv2D): torch.Size([16, 64, 32, 32])
Output shape (Conv2DFunc): torch.Size([16, 64, 32, 32])


In [2]:
def unfold_naive(input, kernel_size, stride, padding):
    """Loop-based reference implementation of ``F.unfold`` (im2col).

    Args:
        input: Tensor of shape ``(b, c, h, w)``.
        kernel_size: Patch extent, an int or a ``(k_h, k_w)`` pair.
        stride: Step between patches, an int or a ``(s_h, s_w)`` pair.
        padding: Number of zero rows/columns added on every side (int, >= 0).

    Returns:
        Tensor of shape ``(b, c*k_h*k_w, p)`` with ``p = h_out * w_out``.
        Column ``i * w_out + j`` is the flattened patch whose top-left corner
        sits at ``(i*s_h, j*s_w)`` of the padded image.

    Raises:
        ValueError: If the kernel does not fit inside the padded input.
    """
    b, c, h, w = input.shape
    k_h, k_w = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
    s_h, s_w = (stride, stride) if isinstance(stride, int) else stride

    # Add padding. new_zeros keeps the input's dtype and device, and explicit
    # end indices keep the slice valid when padding == 0 (where
    # `padding:-padding` would be the empty slice `0:0`).
    padded = input.new_zeros((b, c, h + 2 * padding, w + 2 * padding))
    padded[:, :, padding:padding + h, padding:padding + w] = input

    # Compute output dimensions
    h_out = (h + 2 * padding - k_h) // s_h + 1
    w_out = (w + 2 * padding - k_w) // s_w + 1
    if h_out <= 0 or w_out <= 0:
        raise ValueError("kernel_size is larger than the padded input")

    # Extract patches in row-major order over the output grid.
    columns = []
    for i in range(h_out):
        for j in range(w_out):
            patch = padded[:, :, i * s_h:i * s_h + k_h, j * s_w:j * s_w + k_w]
            # Flatten channel and kernel dims together -> (b, c*k_h*k_w)
            columns.append(patch.flatten(1))

    # Stacking on a new last axis already yields (b, c*k_h*k_w, p).
    return torch.stack(columns, dim=-1)
