# In [68]:
import torch
from typing import Tuple

def compute_output_shape(
        dimension: int, kernel_size: int, padding: int, stride: int
) -> int:
    """Return the number of sliding-window positions along one spatial axis.

    Uses the standard convolution/unfold size formula (dilation 1):
    ``(dimension + 2 * padding - kernel_size) // stride + 1``.

    Args:
        dimension: Size of the input along this axis, before padding.
        kernel_size: Window size along this axis.
        padding: Zero-padding added to *each* side of this axis.
        stride: Step between consecutive window positions.

    Returns:
        Number of window positions (non-positive if the window cannot fit).
    """
    return (dimension + 2 * padding - kernel_size) // stride + 1


def custom_unfold(
        input: torch.Tensor,
        kernel_size: Tuple[int, int],
        stride: int,
        padding: int
) -> torch.Tensor:
    """Loop-based re-implementation of ``torch.nn.Unfold`` (im2col).

    Extracts every ``k_h x k_w`` sliding window from the zero-padded input
    and lays each one out as a column.

    Args:
        input: 4-D tensor unpacked as ``(B, C, H, W)``.
        kernel_size: ``(k_h, k_w)`` window size.
        stride: Step between windows, applied to both spatial axes.
        padding: Zero-padding added to each side of both spatial axes.

    Returns:
        Tensor of shape ``(B, C * k_h * k_w, out_h * out_w)`` with the same
        dtype and device as ``input``. Rows are ordered channel-major, then
        kernel row, then kernel column; columns enumerate window positions in
        row-major order -- the same layout ``torch.nn.Unfold`` produces.
    """
    b, c, h, w = input.shape
    k_h, k_w = kernel_size

    out_h = compute_output_shape(h, k_h, padding, stride)
    out_w = compute_output_shape(w, k_w, padding, stride)

    # The output size above already accounts for padding, so the windows must
    # be taken from a zero-padded copy of the input, not the raw tensor.
    padded = torch.nn.functional.pad(input, (padding, padding, padding, padding))

    # new_zeros keeps the input's dtype and device (plain torch.zeros would
    # always produce float32 on CPU).
    unfolded_matrix = input.new_zeros(b, c * k_h * k_w, out_h * out_w)
    rows = c * k_h * k_w

    for o_h in range(out_h):
        top = o_h * stride
        for o_w in range(out_w):
            left = o_w * stride
            # (B, C, k_h, k_w) window for all batch elements and channels.
            patch = padded[:, :, top:top + k_h, left:left + k_w]
            # Flattening C, k_h, k_w in that order matches Unfold's row layout;
            # an explicit row count (not -1) keeps B == 0 working.
            unfolded_matrix[:, :, o_h * out_w + o_w] = patch.reshape(b, rows)

    return unfolded_matrix


########## variables ##########
input_height = 3
input_width = 3
input_channels = 5
batches = 2
kernel_height = 3
kernel_width = 3
# Derived from the per-axis sizes so the tuple cannot drift out of sync with
# kernel_height / kernel_width (it used to be a separate hard-coded literal).
kernel_size = (kernel_height, kernel_width)
stride = 1
padding = 1

X = torch.randn(batches, input_channels, input_height, input_width)

########## custom implementation ##########
unfolded_tensor: torch.Tensor = custom_unfold(
    input=X,
    kernel_size=kernel_size,
    stride=stride,
    padding=padding,
)

########## torch implementation ##########
torch_unfold = torch.nn.Unfold(
    kernel_size=kernel_size,
    padding=padding,
    stride=stride,
)
torch_unfolded_tensor = torch_unfold(X)

########## test ##########
# NOTE: this compares output shapes only; element values are not checked here.
assert unfolded_tensor.size() == torch_unfolded_tensor.size(), (
    f"shape mismatch: custom {tuple(unfolded_tensor.size())} "
    f"vs torch {tuple(torch_unfolded_tensor.size())}"
)