### 2D convolution with explicit for loops

In [1]:
import torch
import numpy as np


def convolve(X, W, stride, padding):
    """Compute a 2D convolution (PyTorch ``conv2d`` semantics) with plain loops.

    Like ``torch.nn.functional.conv2d`` this is a cross-correlation: the
    filter is not flipped before being multiplied with the input window.

    Args:
        X: Input array indexed as ``(batch, in_channels, rows, cols)``.
        W: Filter array indexed as
            ``(out_channels, in_channels, k_rows, k_cols)``.
            Non-square filters are supported.
        stride: Step between consecutive filter positions, applied to both
            spatial axes.
        padding: Number of implicit zero rows/columns added on every side
            of the input.

    Returns:
        A float64 ``np.ndarray`` of shape
        ``(batch, out_channels, out_rows, out_cols)``, where
        ``out_rows = (rows + 2 * padding - k_rows) // stride + 1``
        and likewise for columns.

    Raises:
        AssertionError: If the channel counts of ``X`` and ``W`` disagree.
    """
    bs, inp_c, inp_rows, inp_cols = X.shape
    out_c, filt_c, filt_rows, filt_cols = W.shape

    assert inp_c == filt_c, (
        f"input has {inp_c} channels but the filter expects {filt_c}"
    )

    out_rows = (inp_rows + 2 * padding - filt_rows) // stride + 1
    out_cols = (inp_cols + 2 * padding - filt_cols) // stride + 1

    output = np.zeros((bs, out_c, out_rows, out_cols))

    for b in range(bs):
        for o_c in range(out_c):
            for o_row in range(out_rows):
                for o_col in range(out_cols):
                    # Top-left corner of the window in *unpadded* input
                    # coordinates. It is negative when the window overlaps
                    # the top/left zero padding. convolve_input treats every
                    # out-of-range position as zero. Clamping it to 0 would
                    # shift the window and corrupt the border outputs
                    # whenever padding > 0.
                    start_row = o_row * stride - padding
                    start_col = o_col * stride - padding

                    output[b, o_c, o_row, o_col] = convolve_input(
                        X,
                        W,
                        b,
                        o_c,
                        start_row,
                        start_col,
                        filt_rows,
                        inp_c,
                        inp_rows,
                        inp_cols,
                        k_cols=filt_cols,
                    )

    return output


def convolve_input(
    X,
    W,
    b,
    o_c,
    start_row,
    start_col,
    k_size,
    inp_c,
    inp_rows,
    inp_cols,
    k_cols=None,
):
    """Return the sum of ``W[o_c]`` multiplied element-wise with one input window.

    Args:
        X: Input array indexed as ``(batch, channel, row, col)``.
        W: Filter array indexed as ``(out_channel, in_channel, row, col)``.
        b: Batch index into ``X``.
        o_c: Output-channel index selecting the filter ``W[o_c]``.
        start_row: Input row of the window's top-left corner. It may be
            negative, or the window may run past the bottom edge, when the
            window overlaps the zero padding.
        start_col: Input column of the window's top-left corner, with the
            same out-of-range rules as ``start_row``.
        k_size: Number of filter rows. It is also the number of filter
            columns when ``k_cols`` is not given.
        inp_c: Number of input channels to sum over.
        inp_rows: Row count of the unpadded input.
        inp_cols: Column count of the unpadded input.
        k_cols: Number of filter columns. Defaults to ``k_size``, which
            means a square filter.

    Returns:
        The accumulated sum. Positions outside the input contribute zero.
    """
    if k_cols is None:
        k_cols = k_size

    result = 0.0
    for inp_c_idx in range(inp_c):
        for f_row in range(k_size):
            inp_row_idx = start_row + f_row
            # The explicit bounds check is required: NumPy would otherwise
            # silently wrap negative indices around to the far edge.
            if not 0 <= inp_row_idx < inp_rows:
                continue  # this whole filter row lies in the zero padding
            for f_col in range(k_cols):
                inp_col_idx = start_col + f_col
                if 0 <= inp_col_idx < inp_cols:
                    result += (
                        X[b, inp_c_idx, inp_row_idx, inp_col_idx]
                        * W[o_c, inp_c_idx, f_row, f_col]
                    )

    return result


# Sanity check: the loop-based implementation must agree with PyTorch's
# reference conv2d on a single-channel 5x5 input and a 3x3 filter
# (stride 1, no padding).
X = np.random.rand(1, 1, 5, 5)
W = np.random.rand(1, 1, 3, 3)
stride, padding = 1, 0

actual = convolve(X, W, stride, padding)
expected = torch.nn.functional.conv2d(
    torch.tensor(X),
    torch.tensor(W),
    stride=stride,
    padding=padding,
).numpy()
assert np.allclose(actual, expected)