### 2D convolution with vectorized operations

#### Computation graph

```mermaid
graph LR
  A(input_batch)
  B(U)
  C(U_permuted)
  D(U_reshaped)
  E(kernel)
  F(kernel_reshaped)
  G(kernel_permuted)
  H(output_2D)
  I(output_3D)
  J(output_4D)
  K(output)

  L((unfold))
  M((permute_U))
  N((reshape_U_permuted))
  O((reshape_kernel))
  T((permute_kernel_reshaped))
  P((compute_output_2D))
  Q((reshape_output_to_3D))
  R((reshape_output_to_4D))
  S((permute_output))

  subgraph input_batch
    A --> L --> B
    B --> M --> C
    C --> N --> D
  end

  subgraph kernel
    E --> O --> F
    F --> T --> G
  end

  D --> P
  G --> P
  P --> H
  subgraph output
    H --> Q --> I
    I --> R --> J
    J --> S --> K
  end
```



#### Custom implementation

In [None]:
import torch

# Problem sizes for the verification at the bottom of this cell.
batches, in_channels = 4, 3  # input batch size and input channel count
i_h, i_w = 8, 8              # input height and width
out_channels = 3             # number of filters
k_h, k_w = 3, 3              # kernel height and width
stride, padding = 1, 1       # applied to both spatial axes


def unfold(
    input_batch: torch.Tensor, kernel_size: tuple, stride: int, padding: int
) -> torch.Tensor:
    """Extract sliding patches from a batch of images.

    Args:
        input_batch: (b, in_channels, i_h, i_w)
        kernel_size: (kernel_h, kernel_w)
        stride: step between neighbouring patches.
        padding: implicit zero padding added on every spatial border.

    Returns:
        (b, patch_size, patches), where
        patch_size = in_channels * kernel_h * kernel_w.
    """
    # Functional form of torch.nn.Unfold; no module instance is required.
    return torch.nn.functional.unfold(
        input_batch, kernel_size, padding=padding, stride=stride
    )


def permute_U(U: torch.Tensor) -> torch.Tensor:
    """Move the patch axis in front of the patch-content axis.

    Args:
        U: (b, patch_size, patches)

    Returns:
        (b, patches, patch_size), a view of ``U``.
    """
    batch_axis, content_axis, patch_axis = 0, 1, 2
    return torch.permute(U, (batch_axis, patch_axis, content_axis))


def reshape_U_permuted(U_permuted: torch.Tensor) -> torch.Tensor:
    """Merge the batch and patch axes so that each patch becomes one matrix row.

    Args:
        U_permuted: (b, patches, patch_size)

    Returns:
        (b * patches, patch_size)
    """
    batch, n_patches, width = U_permuted.shape
    n_rows = batch * n_patches
    return torch.reshape(U_permuted, (n_rows, width))


def reshape_kernel(kernel: torch.Tensor) -> torch.Tensor:
    """Flatten every filter into a single row.

    Args:
        kernel: (out_channels, in_channels, kernel_h, kernel_w)

    Returns:
        (out_channels, in_channels * kernel_h * kernel_w)
    """
    n_filters, n_inputs, height, width = kernel.shape
    row_length = n_inputs * height * width
    return torch.reshape(kernel, (n_filters, row_length))


def permute_kernel_reshaped(kernel_reshaped: torch.Tensor) -> torch.Tensor:
    """Transpose the flattened kernel so that each filter becomes a column.

    Args:
        kernel_reshaped: (out_channels, in_channels * kernel_h * kernel_w)

    Returns:
        (in_channels * kernel_h * kernel_w, out_channels)
    """
    swapped_axes = (1, 0)
    return torch.permute(kernel_reshaped, swapped_axes)


def compute_output_2D(
    U_reshaped: torch.Tensor, kernel_permuted: torch.Tensor
) -> torch.Tensor:
    """Apply every filter to every patch with a single matrix product.

    Args:
        U_reshaped: (b * patches, patch_size)
        kernel_permuted: (patch_size, out_channels), where
            patch_size = in_channels * kernel_h * kernel_w

    Returns:
        (b * patches, out_channels)
    """
    product = torch.matmul(U_reshaped, kernel_permuted)
    return product


def reshape_output_to_3D(
    output_2D: torch.Tensor, b: int, patches: int, out_channels: int
) -> torch.Tensor:
    """Split the merged (batch * patches) row axis back into two axes.

    Args:
        output_2D: (b * patches, out_channels)
        b: batch size.
        patches: number of patches per image.
        out_channels: number of filters.

    Returns:
        (b, patches, out_channels)
    """
    target_shape = (b, patches, out_channels)
    return torch.reshape(output_2D, target_shape)


def reshape_output_to_4D(
    output_3D: torch.Tensor, b: int, o_h: int, o_w: int, out_channels: int
) -> torch.Tensor:
    """Unflatten the patch axis into the (o_h, o_w) output grid.

    Args:
        output_3D: (b, patches, out_channels), with patches == o_h * o_w
        b: batch size.
        o_h: output height.
        o_w: output width.
        out_channels: number of filters.

    Returns:
        (b, o_h, o_w, out_channels)
    """
    grid_shape = (b, o_h, o_w, out_channels)
    return torch.reshape(output_3D, grid_shape)


def permute_output(output: torch.Tensor) -> torch.Tensor:
    """Convert channels-last output to the channels-first layout of conv2d.

    Args:
        output: (b, o_h, o_w, out_channels)

    Returns:
        (b, out_channels, o_h, o_w)
    """
    channels_first = (0, 3, 1, 2)
    return torch.permute(output, channels_first)


def get_output_dim(input_dim, filter_size, stride, padding):
    """Return the number of window positions along one spatial axis.

    Computes ``(input_dim + 2 * padding - filter_size) // stride + 1``; the
    floor division discards a trailing partial window.
    """
    padded_length = input_dim + 2 * padding
    full_steps = (padded_length - filter_size) // stride
    return full_steps + 1


class Conv2DFunc(torch.autograd.Function):
    """2D convolution (no bias, no dilation) expressed as unfold + matmul.

    ``forward`` follows the computation graph above: the input is unfolded
    into patches, the patches and the kernel are reshaped/permuted into 2D
    matrices, one matmul produces the output, and the result is reshaped back
    to ``(b, out_channels, o_h, o_w)``. ``backward`` walks the same graph in
    reverse.

    NOTE: ``forward`` returns ``(ctx, output)`` so that callers can invoke
    ``Conv2DFunc.backward(ctx, grad_output)`` by hand; ``backward`` is written
    so that it also works when autograd calls it.
    """

    @staticmethod
    def forward(ctx, input_batch, kernel, stride=1, padding=1):
        """Convolve ``input_batch`` with ``kernel``.

        Args:
            input_batch: (b, in_channels, i_h, i_w)
            kernel: (out_channels, in_channels, kernel_h, kernel_w)
            stride: stride along both spatial axes.
            padding: zero padding on every spatial border.

        Returns:
            ``(ctx, output)`` with output of shape (b, out_channels, o_h, o_w).
        """
        ctx.stride = stride
        ctx.padding = padding

        b, _, i_h, i_w = input_batch.shape
        out_channels, _, kernel_h, kernel_w = kernel.shape

        o_h = get_output_dim(i_h, kernel_h, stride, padding)
        o_w = get_output_dim(i_w, kernel_w, stride, padding)

        # (b, patch_size, patches), patch_size = in_channels * kernel_h * kernel_w
        U = unfold(
            input_batch=input_batch,
            kernel_size=(kernel_h, kernel_w),
            padding=padding,
            stride=stride,
        )
        b, _, patches = U.shape

        U_permuted = permute_U(U)
        U_reshaped = reshape_U_permuted(U_permuted)  # (b * patches, patch_size)

        kernel_reshaped = reshape_kernel(kernel)
        # (patch_size, out_channels)
        kernel_permuted = permute_kernel_reshaped(kernel_reshaped)

        output_2D = compute_output_2D(U_reshaped, kernel_permuted)
        output_3D = reshape_output_to_3D(output_2D, b, patches, out_channels)
        output_4D = reshape_output_to_4D(output_3D, b, o_h, o_w, out_channels)
        output = permute_output(output_4D)

        # The matmul operands are reused in backward; input_batch and kernel
        # are only needed there for their shapes.
        ctx.save_for_backward(input_batch, kernel, U_reshaped, kernel_permuted)
        return ctx, output

    @staticmethod
    def backward(ctx, grad_output, *more_grads):
        """Gradients of the convolution with respect to its inputs.

        ``forward`` returns two outputs, ``(ctx, output)``. When autograd
        drives the backward pass it calls ``backward(ctx, None, grad_output)``
        (``None`` for the non-tensor ``ctx`` output). The manual call
        ``backward(ctx, grad_output)`` passes a single gradient. In both cases
        the output gradient is the last positional argument, so both forms
        are accepted.

        Args:
            grad_output: (b, out_channels, o_h, o_w)

        Returns:
            ``(input_batch_grad, kernel_grad, None, None)``; ``stride`` and
            ``padding`` receive no gradient.
        """
        if more_grads:
            # Called by autograd: the first gradient belongs to ``ctx``.
            grad_output = more_grads[-1]

        input_batch, kernel, U_reshaped, kernel_permuted = ctx.saved_tensors
        _, _, i_h, i_w = input_batch.shape
        out_channels, in_channels, kernel_h, kernel_w = kernel.shape
        b, _, o_h, o_w = grad_output.shape
        patches = o_h * o_w
        patch_size = in_channels * kernel_h * kernel_w

        stride, padding = ctx.stride, ctx.padding

        # backward for `permute_output`
        # (b, out_channels, o_h, o_w) -> (b, o_h, o_w, out_channels)
        output_4D_grad = grad_output.permute(0, 2, 3, 1)

        # backward for `reshape_output_to_4D`
        # (b, o_h, o_w, out_channels) -> (b, patches, out_channels)
        output_3D_grad = output_4D_grad.reshape(b, patches, out_channels)

        # backward for `reshape_output_to_3D`
        # (b, patches, out_channels) -> (b * patches, out_channels)
        output_2D_grad = output_3D_grad.reshape(b * patches, out_channels)

        # ----------------- kernel gradient ----------------------------------
        # output_2D = U_reshaped @ kernel_permuted
        #   => d kernel_permuted = U_reshaped^T @ d output_2D
        kernel_permuted_grad = torch.matmul(U_reshaped.t(), output_2D_grad)
        assert kernel_permuted_grad.shape == (patch_size, out_channels)

        # backward of `permute_kernel_reshaped`
        # (patch_size, out_channels) -> (out_channels, patch_size)
        kernel_reshaped_grad = kernel_permuted_grad.permute(1, 0)

        # backward of `reshape_kernel`
        # (out_channels, patch_size) -> (out_channels, in_channels, kernel_h, kernel_w)
        kernel_grad = kernel_reshaped_grad.reshape(
            out_channels, in_channels, kernel_h, kernel_w
        )

        # ----------------- input_batch gradient -----------------------------
        # output_2D = U_reshaped @ kernel_permuted
        #   => d U_reshaped = d output_2D @ kernel_permuted^T
        U_reshaped_grad = torch.matmul(output_2D_grad, kernel_permuted.t())
        assert U_reshaped_grad.shape == (b * patches, patch_size)

        # backward of `reshape_U_permuted`
        # (b * patches, patch_size) -> (b, patches, patch_size)
        U_permuted_grad = U_reshaped_grad.reshape(b, patches, patch_size)

        # backward of `permute_U`
        # (b, patches, patch_size) -> (b, patch_size, patches)
        U_grad = U_permuted_grad.permute(0, 2, 1)

        # backward of `unfold`: Fold sums each patch entry back into the
        # input position it was copied from (overlapping patches accumulate).
        torch_fold = torch.nn.Fold(
            output_size=(i_h, i_w),
            kernel_size=(kernel_h, kernel_w),
            padding=padding,
            stride=stride,
        )
        input_batch_grad = torch_fold(U_grad)

        return input_batch_grad, kernel_grad, None, None


# Initialize input tensors --------
torch.manual_seed(42)
input_batch = torch.randn(batches, in_channels, i_h, i_w).requires_grad_(True)
kernel = torch.randn(out_channels, in_channels, k_h, k_w).requires_grad_(True)

# Verify output --------
# Conv2DFunc.forward returns (ctx, output) so that ctx can be passed to
# Conv2DFunc.backward by hand below.
ctx, custom_conv2d_output = Conv2DFunc.apply(input_batch, kernel, stride, padding)

torch_conv2d_output = torch.nn.functional.conv2d(
    input=input_batch, weight=kernel, stride=stride, padding=padding
)
# assert_close replaces the deprecated torch.testing.assert_allclose.
torch.testing.assert_close(custom_conv2d_output, torch_conv2d_output)

# Verify gradients --------
# The upstream gradient is a constant; it does not need to track gradients.
grad_output = torch.randn_like(custom_conv2d_output)
# Reference gradients from PyTorch's conv2d, stored in input_batch.grad / kernel.grad.
torch_conv2d_output.backward(grad_output)

custom_input_batch_grad, custom_kernel_grad, _, _ = Conv2DFunc.backward(
    ctx, grad_output
)

torch.testing.assert_close(custom_kernel_grad, kernel.grad)
torch.testing.assert_close(custom_input_batch_grad, input_batch.grad)