In [1]:
import math
import itertools

import numpy as np
import torch 


In [2]:
SCHEMES = {"SAME", "VALID"}  # values accepted for `overflow_scheme` in `conv_2d`


def _apply_conv2d(section, filter):
    """Convolve one window with a filter bank of identical spatial size.

    Because the filter covers the whole window, the convolution reduces to a
    single dot product per (sample, output channel) pair, computed here as one
    matrix product of the flattened operands.

    Args:
        section: tensor of shape (batch, in_channels, height, width).
        filter: tensor of shape (out_channels, in_channels, height, width).

    Returns:
        Tensor of shape (batch, out_channels, 1, 1), i.e. the same shape that
        `torch.conv2d(section, filter)` produces for these operands.

    Raises:
        AssertionError: if an operand is not 4-D, if the per-sample shapes
            differ, or if the matmul result disagrees with `torch.conv2d`.
    """
    assert section.ndim == 4, section.shape
    assert filter.ndim == 4, filter.shape
    assert section.shape[1:] == filter.shape[1:], (section.shape, filter.shape)

    # (batch, C*H*W) @ (C*H*W, out_channels) -> (batch, out_channels)
    output_matmul = section.flatten(1) @ filter.flatten(1).transpose(0, 1)
    output_matmul = output_matmul.unsqueeze(-1).unsqueeze(-1)

    # Sanity check against the reference implementation. The tolerances are
    # wider than the `allclose` defaults (atol=1e-8), which would reject
    # float32 results that are close to zero purely because of rounding.
    output_conv2d = torch.conv2d(section, filter)
    assert torch.allclose(output_conv2d, output_matmul, rtol=1e-4, atol=1e-5), (
        output_conv2d.shape, output_matmul.shape)

    return output_matmul


def conv_2d(input, filter, stride, overflow_scheme="SAME"):
    """Strided 2-D convolution assembled from per-window matrix products.

    Args:
        input: tensor of shape (batch, x, y, in_channels).
        filter: tensor of shape (in_channels, x_filter, y_filter, out_channels).
        stride: length-2 sequence, array or tensor `(x_stride, y_stride)` of
            positive integers.
        overflow_scheme: one of `SCHEMES`.
            "SAME"  - one window per strided input position; the window starts
                      `filter_size // 2` before that position and reads zeros
                      wherever it extends past the input.
            "VALID" - only windows that lie entirely inside the input.

    Returns:
        Tensor of shape (batch, out_x, out_y, out_channels) with the dtype and
        device of `input`. For "SAME", `out = ceil(size / stride)`; for
        "VALID", `out = (size - filter_size) // stride + 1` (0 when the filter
        is larger than the input).

    Raises:
        AssertionError: if an argument fails validation.
        ValueError: if `overflow_scheme` is unknown and assertions are disabled.
    """
    assert input.ndim == 4, input.shape
    assert filter.ndim == 4, filter.shape
    assert getattr(stride, "ndim", 1) == 1, stride
    assert len(stride) == 2, stride
    assert isinstance(overflow_scheme, str), type(overflow_scheme).mro()
    assert input.shape[3] == filter.shape[0], (input.shape[3], filter.shape[0])
    assert overflow_scheme in SCHEMES, (overflow_scheme, SCHEMES)

    batch_size, x_size, y_size, num_channels = input.shape
    _, x_filter_size, y_filter_size, output_channels = filter.shape
    x_stride, y_stride = (int(step) for step in stride)
    assert x_stride > 0 and y_stride > 0, (x_stride, y_stride)
    assert x_filter_size > 0 and y_filter_size > 0, filter.shape

    if overflow_scheme == "SAME":
        # Zero-pad once so that every window below is a plain slice: the window
        # for input position p is padded[p : p + filter_size].
        x_pad = x_filter_size // 2
        y_pad = y_filter_size // 2
        padded = input.new_zeros((
            batch_size,
            x_size + x_filter_size - 1,
            y_size + y_filter_size - 1,
            num_channels,
        ))
        padded[:, x_pad:x_pad + x_size, y_pad:y_pad + y_size, :] = input
        x_starts = range(0, x_size, x_stride)
        y_starts = range(0, y_size, y_stride)
    elif overflow_scheme == "VALID":
        padded = input
        x_starts = range(0, max(x_size - x_filter_size + 1, 0), x_stride)
        y_starts = range(0, max(y_size - y_filter_size + 1, 0), y_stride)
    else:
        # Only reachable when assertions are stripped (python -O).
        raise ValueError(f"`overflow_scheme` must be one of {SCHEMES}, got `{overflow_scheme}`")

    # `_apply_conv2d` works in channels-first layout; convert the filter once.
    kernel = filter.permute(3, 0, 1, 2)
    output = input.new_empty((batch_size, len(x_starts), len(y_starts), output_channels))

    for x_out, x_start in enumerate(x_starts):
        for y_out, y_start in enumerate(y_starts):
            window = padded[
                :,
                x_start:x_start + x_filter_size,
                y_start:y_start + y_filter_size,
                :,
            ]
            # All batch entries are handled by one call; drop the 1x1 spatial axes.
            output[:, x_out, y_out, :] = _apply_conv2d(
                window.permute(0, 3, 1, 2),
                kernel,
            )[:, :, 0, 0]

    return output

In [None]:
# Demo: when the filter is exactly as large as the image, conv2d yields one
# number per (batch, output channel) pair, and that number is a plain dot
# product of the flattened image with the flattened filter.
b, i, o = 1, 4, 8   # batch size, input channels, output channels
h, w = 3, 5         # spatial size shared by image and filter

image = torch.rand(b, i, h, w)
filter = torch.rand(o, i, h, w)

# Reference result, shape (b, o, 1, 1).
output_conv2d = torch.conv2d(image, filter)

# Collapse everything but the leading axis: (b, i*h*w) and (o, i*h*w).
image = image.flatten(start_dim=1)
filter = filter.flatten(start_dim=1)

# --- Variant 1: einsum contracting the shared flattened axis ----------------
output_einsum = torch.einsum("bz,oz->bo", image, filter)[..., None, None]
print(f"{torch.allclose(output_conv2d, output_einsum) = }")

# --- Variant 2: a single matrix product -------------------------------------
output_matmul = torch.matmul(image, filter.t())[..., None, None]
print(f"{torch.allclose(output_conv2d, output_matmul) = }")

# --- Variant 3: one explicit dot product per output entry -------------------
output_iter = torch.empty(b, o, 1, 1)
for b_idx in range(b):
    for o_idx in range(o):
        # Rows are already 1-D after flattening, so they feed `dot` directly.
        output_iter[b_idx, o_idx] = torch.dot(image[b_idx], filter[o_idx])

print(f"{torch.allclose(output_conv2d, output_iter) = }")
