In [56]:
import numpy as np
from numpy.typing import NDArray
from typing import Tuple, Any

In [57]:
def check_type(obj: Any, name: str):
    """Validate a size-like hyper-parameter and report which form it was given in.

    Accepted forms are a bare ``int``, a 1-tuple ``(int,)`` and a 2-tuple
    ``(int, int)``.

    Args:
        obj: Value to validate.
        name: Parameter name used in error messages.

    Returns:
        ``"int"`` for a bare int or a 1-tuple (a caller that expands the value
        into a pair must unwrap the 1-tuple itself), ``"tuple"`` for a 2-tuple
        of ints.

    Raises:
        ValueError: If ``obj`` is neither an int nor a tuple, is an empty tuple,
            or contains a non-int element.
        RuntimeError: If ``obj`` is a tuple with more than two elements.
    """
    if isinstance(obj, int):
        return "int"
    if not isinstance(obj, tuple):
        raise ValueError(f"{name} should be tuple[int, int] or int, give {type(obj)}")
    if len(obj) == 0:
        raise ValueError(f"{name} should have len = 2, give {len(obj)}")
    if len(obj) > 2:
        # NOTE(review): RuntimeError is inconsistent with the ValueErrors raised
        # elsewhere here; kept so callers that already catch it keep working.
        raise RuntimeError(f"{name} should have len is 2, give {len(obj)}")
    # Validate every element, including the single element of a 1-tuple,
    # which used to be accepted without any type check.
    if not all(isinstance(item, int) for item in obj):
        raise ValueError(f"{name} should be tuple[int, int], give {tuple(type(item) for item in obj)}")
    return "int" if len(obj) == 1 else "tuple"

In [63]:
class Conv2D:
    """Naive 2-D convolution layer (cross-correlation, no kernel flip) with random weights.

    The input is indexed as ``img[channel][row][col]``, i.e. shape
    ``(channel_in, height, width)``; the output has shape
    ``(channel_out, out_height, out_width)``.
    """

    def __init__(
        self,
        channel_in: int, 
        channel_out: int, 
        kernel_size: int | Tuple[int, int],
        stride: int | Tuple[int, int],
        padding: int | Tuple[int, int]
    ) -> None:
        """Validate the hyper-parameters and draw the filter weights once.

        Args:
            channel_in: Number of channels expected in the input image.
            channel_out: Number of filters, i.e. output channels.
            kernel_size: Filter size as ``k``, ``(k,)`` or ``(k_h, k_w)``; must be >= 1.
            stride: Window step as ``s``, ``(s,)`` or ``(s_h, s_w)``; must be >= 1.
            padding: Zero padding per side as ``p``, ``(p,)`` or ``(p_h, p_w)``;
                must be >= 0.

        Raises:
            ValueError: On a wrong type or an out-of-range value.
            RuntimeError: If a size tuple has more than two elements
                (raised by ``check_type``).
        """
        if not isinstance(channel_in, int): 
            raise ValueError(f"channel_in should be int, give {type(channel_in)}")
        if not isinstance(channel_out, int):
            raise ValueError(f"channel_out should be int, give {type(channel_out)}")
        kernel_size = self._as_pair(kernel_size, "kernel_size")
        stride = self._as_pair(stride, "stride")
        padding = self._as_pair(padding, "padding")
        if min(kernel_size) < 1:
            raise ValueError(f"kernel_size should be >= 1, got {kernel_size}")
        if min(stride) < 1:
            # A zero stride would divide by zero when sizing the output.
            raise ValueError(f"stride should be >= 1, got {stride}")
        if min(padding) < 0:
            raise ValueError(f"padding should be >= 0, got {padding}")
        print(f"kernel_size = {kernel_size}, stride = {stride}, padding = {padding}")
        self._channel_in = channel_in
        self._channel_out = channel_out
        self._kernel_size = kernel_size
        self._stride = stride
        self._padding = padding
        # Seeds NumPy's *global* RNG (process-wide side effect) so the weights
        # are reproducible. Drawing them here, once, keeps the layer's output
        # stable across calls and lets forward() be used directly.
        np.random.seed(42)
        self._filters = [
            np.random.rand(channel_in, kernel_size[0], kernel_size[1])
            for _ in range(channel_out)
        ]

    @staticmethod
    def _as_pair(value: int | Tuple[int, ...], name: str) -> Tuple[int, int]:
        """Normalize an int / 1-tuple / 2-tuple hyper-parameter to ``(h, w)``."""
        if check_type(value, name) != "int":
            return (value[0], value[1])
        # check_type reports both bare ints and 1-tuples as "int"; unwrap the
        # 1-tuple so it does not turn into ((k,), (k,)).
        scalar = value[0] if isinstance(value, tuple) else value
        if not isinstance(scalar, int):
            raise ValueError(f"{name} should be tuple[int, int] or int, got {type(scalar)}")
        return (scalar, scalar)

    def __call__(self, img: NDArray[np.uint8]) -> NDArray[np.float64]:
        """Apply the layer to ``img``; same as :meth:`forward`."""
        return self.forward(img)

    def forward(self, img: NDArray[np.uint8]) -> NDArray[np.float64]:
        """Convolve ``img`` with every filter.

        Args:
            img: Array of shape ``(channel_in, height, width)``.

        Returns:
            Float array of shape ``(channel_out, out_height, out_width)`` with
            ``out_height = (height + 2 * pad_h - k_h) // stride_h + 1`` and the
            analogous formula for ``out_width``.

        Raises:
            ValueError: If ``img`` is not 3-D with ``channel_in`` channels, or the
                kernel does not fit inside the padded image.
        """
        img = np.asarray(img)
        if img.ndim != 3 or img.shape[0] != self._channel_in:
            raise ValueError(
                f"img should have shape (channel_in={self._channel_in}, height, width), "
                f"got {img.shape}"
            )
        k_h, k_w = self._kernel_size
        s_h, s_w = self._stride
        p_h, p_w = self._padding

        # Zero-pad the two spatial axes only; the channel axis is untouched.
        padded = np.pad(img, ((0, 0), (p_h, p_h), (p_w, p_w)), mode="constant")

        out_height = (padded.shape[1] - k_h) // s_h + 1
        out_width = (padded.shape[2] - k_w) // s_w + 1
        if out_height < 1 or out_width < 1:
            raise ValueError(
                f"kernel_size {self._kernel_size} does not fit the padded input "
                f"of spatial size {padded.shape[1:]}"
            )

        out = np.zeros((len(self._filters), out_height, out_width))
        for n, weights in enumerate(self._filters):
            for row in range(out_height):
                top = row * s_h
                for col in range(out_width):
                    left = col * s_w
                    window = padded[:, top:top + k_h, left:left + k_w]
                    # Multiplying by the float weights upcasts uint8 input, so
                    # the sum over channels and kernel cells cannot overflow.
                    out[n, row, col] = np.sum(window * weights)
        return out