In [2]:
import numpy as np
import plotly.express as px
from numpy import ndarray
from numpy.lib.stride_tricks import as_strided, sliding_window_view

from cifar_10_dataset_loading import load_cifar_10

In [3]:
# Unpack the four arrays returned by the project-local CIFAR-10 loader.
# NOTE(review): presumably train/test images and labels — confirm against load_cifar_10.
x_train, y_train, x_test, y_test = load_cifar_10()

In [4]:
# Indices (along the first axis of x_train) of the images used as the running example.
IMAGES_IDX = [0, 392, 351, 333]
# Indexing with a list picks those entries along axis 0.
images = x_train[IMAGES_IDX]
# facet_col=0 draws one panel per entry of axis 0, i.e. one per selected image.
px.imshow(images, facet_col=0)

In [5]:
# Two hand-written filters, each made of three 5x5 planes (one per channel).
# The nesting order is (kernel, channel, row, column); because some entries are
# 0.5, np.asarray yields a floating-point array.
kernels = np.asarray(
[
    [
        # First kernel
        # First channel
            [
            [1, 0, -1, 0, 1], 
            [1, 0, -1, 0, 1],
            [1, 0, -1, 0, 1],
            [1, 0, -1, 0, 1],
            [1, 0, -1, 0, 1],
        ],
        # Second channel
        [
            [1, 0, 1, 1, 1], 
            [1, 0, 1, 1, 1], 
            [1, 0, 1, 1, 1],
            [1, 0, 1, 1, 1],
            [1, 0, 1, 1, 1],
        ],
        # Third channel
        [
            [0, 0, 0, 0, 1],
            [0, 0, 0, 1, 0],
            [0, 0, 1, 0, 0],
            [0, 1, 0, 0, 0],
            [1, 0, 0, 0, 0],
        ],
    ],
    [
        # Second kernel
        # First channel
        [
            [0.5, 0.5, 0.5, 0.5, 0.5], 
            [0.5, 0.5, 0.5, 0.5, 0.5], 
            [0.5, 0.5, 0.5, 0.5, 0.5], 
            [0.5, 0.5, 0.5, 0.5, 0.5],
            [0.5, 0.5, 0.5, 0.5, 0.5],
        ],
        # Second channel
        [
            [1, 2, 3, 4, 5],
            [1, 1, 1, 0, 0],
            [0, 1, 1, 1, 0],
            [0, 0, 1, 1, 1],
            [0, 1, 0, 1, 1],
        ],
        # Third channel
        [
            [0, 0, 1, 0, 0],
            [0, 0.5, 1, 0.5, 0],
            [1, 1, 1, 1, 1],
            [0, 0.5, 1, 0.5, 0],
            [0, 0, 1, 0, 0],
        ],
    ]
]) # Kept channel-first here; later cells apply .swapaxes(1, 3) to put channels last.
kernels.shape

(2, 3, 5, 5)

In [13]:
# Inspect the array layout before building strided views: `strides` gives the
# number of bytes to step along each axis, which sliding_window_2d reuses.
print(images.shape)
print(images.strides)

(4, 32, 32, 3)
(3072, 32, 1, 1024)


In [19]:
def sliding_window_2d(inputs:ndarray, window_shape:tuple[int, int], stride:int=1) -> ndarray:
    """Return a read-only strided view of every 2-D window over axes 1 and 2.

    No data is copied: the result reuses the memory of ``inputs`` and only
    re-describes it with a new shape and new strides. Because neighbouring
    windows overlap in memory, the view is returned read-only; copy it before
    modifying.

    Args:
        inputs: 4-D array. Windows slide over axes 1 and 2; axes 0 and 3 are
            carried through unchanged.
        window_shape: Window extent along axis 1 and axis 2 of ``inputs``.
        stride: Step between consecutive windows, shared by both axes (>= 1).

    Returns:
        View of shape ``(inputs.shape[0], n1, n2, window_shape[0],
        window_shape[1], inputs.shape[3])`` where
        ``n = 1 + (extent - window) // stride`` for the respective axis.

    Raises:
        ValueError: If ``inputs`` is not 4-D, ``stride`` is smaller than 1, or
            the window is empty or larger than the input along either axis.
        AssertionError: If the windows do not tile the input exactly for the
            given ``stride``.
    """
    if inputs.ndim != 4:
        raise ValueError(f"inputs must be 4-D, got {inputs.ndim}-D.")
    if stride < 1:
        raise ValueError(f"stride must be >= 1, got {stride}.")
    window_1, window_2 = window_shape
    if not (1 <= window_1 <= inputs.shape[1] and 1 <= window_2 <= inputs.shape[2]):
        raise ValueError(
            f"window_shape {tuple(window_shape)} does not fit inputs of shape {inputs.shape}."
        )

    x_remainder = (inputs.shape[1] - window_1) % stride
    y_remainder = (inputs.shape[2] - window_2) % stride
    assert x_remainder == 0 and y_remainder == 0, "Incorrect strides."

    views_shape = (
        inputs.shape[0],
        1 + (inputs.shape[1] - window_1) // stride,
        1 + (inputs.shape[2] - window_2) // stride,
        window_1,
        window_2,
        inputs.shape[3],
    )
    # Moving to the next window jumps `stride` elements along the original
    # axis; moving inside a window uses the original per-element strides.
    strides = (
        inputs.strides[0],
        inputs.strides[1] * stride,
        inputs.strides[2] * stride,
        *inputs.strides[1:],
    )
    # writeable=False: windows alias each other, so an in-place write through
    # the view would silently change several windows (and `inputs`) at once.
    return as_strided(inputs, views_shape, strides, writeable=False)

# 8x8 windows taken every 4 elements along both sliding axes of the selected images.
views = sliding_window_2d(images, (8, 8), 4)

(4, 7, 7, 8, 8, 3)
(3072, 128, 4, 32, 1, 1024)


In [22]:
# For the first selected image, show the windows at index 1 along the first
# window-grid axis, one facet per window along the second window-grid axis.
px.imshow(views[0, 1], facet_col=0)

In [31]:
# Axes: (image, window index along input axis 1, window index along input axis 2,
# window extent along axis 1, window extent along axis 2, channel).
views.shape

(4, 7, 7, 8, 8, 3)

In [29]:
# Put the channel axis last, then zero-pad the two 5-long kernel axes by 3 at
# their end so every kernel has the same 8x8 extent as the windows in `views`.
kernels_channels_last = kernels.swapaxes(1, 3)
pad_widths = ((0, 0), (0, 3), (0, 3), (0, 0))
padded_kernels = np.pad(kernels_channels_last, pad_widths)
padded_kernels.shape

(2, 8, 8, 3)

In [32]:
# Zero-pad the channels-last kernels to the 8x8 window size.
padded_kernels = np.pad(
    kernels.swapaxes(1, 3),
    pad_width=((0, 0), (0, 3), (0, 3), (0, 0)),
)
# Contract every window with every kernel over (channel, window axis 1,
# window axis 2); the axes that remain are (image, grid axis 1, grid axis 2, kernel).
views_axes, kernel_axes = [5, 3, 4], [3, 1, 2]
correlations = np.tensordot(views, padded_kernels, axes=(views_axes, kernel_axes))

In [33]:
# Response maps of the selected image at index 2; facet_col=2 splits on the
# last axis, giving one panel per kernel.
px.imshow(correlations[2], facet_col=2)

In [67]:
def valid_correlate(inputs:np.ndarray, kernels:np.ndarray) -> np.ndarray:
    """Cross-correlate a batch of channels-last inputs with a bank of kernels.

    Only positions where the kernel fits entirely inside the input are
    evaluated ("valid" mode), and every kernel is summed over all channels.

    Args:
        inputs: Array of shape ``(batch, size_1, size_2, channels)``.
        kernels: Array of shape ``(n_kernels, k_1, k_2, channels)``; ``k_1``
            and ``k_2`` may differ.

    Returns:
        Array of shape ``(batch, size_1 - k_1 + 1, size_2 - k_2 + 1, n_kernels)``.
    """
    # Take the window extent from both kernel axes so non-square kernels work.
    window_shape = (kernels.shape[1], kernels.shape[2])
    # sliding_window_view appends the window axes at the end, so the windows
    # are laid out as (batch, i, j, channel, x, y) -- hence "bijcxy" below.
    windows = sliding_window_view(inputs, window_shape, (1, 2))
    return np.einsum("bijcxy, kxyc -> bijk", windows, kernels)

def full_convolve(inputs:np.ndarray, k:np.ndarray) -> np.ndarray:
    """Convolve channels-last ``inputs`` with the kernel bank ``k`` in "full" mode.

    The inputs are zero-padded by ``kernel extent - 1`` on both sides of axes
    1 and 2, and the kernels are flipped along their axes 1 and 2, so that the
    valid cross-correlation of the two equals the full convolution.

    Args:
        inputs: Array whose axes 1 and 2 are convolved over; axes 0 and 3 are
            left unpadded.
        k: Kernel bank whose axes 1 and 2 are the kernel extents.

    Returns:
        The result of ``valid_correlate`` on the padded inputs and flipped kernels.
    """
    margin_1 = k.shape[1] - 1
    margin_2 = k.shape[2] - 1
    padding = ((0, 0), (margin_1, margin_1), (margin_2, margin_2), (0, 0))
    padded_inputs = np.pad(inputs, padding, "constant")
    flipped_kernels = np.flip(k, (1, 2))
    return valid_correlate(padded_inputs, flipped_kernels)

# `kernels` is stored channel-first as (kernel, channel, 5, 5), whereas
# full_convolve / valid_correlate index kernels as (kernel, x, y, channel).
# Swap to channels-last the same way the padded_kernels cells do; passing the
# channel-first array makes the einsum fail with mismatched axis lengths.
# NOTE(review): swapaxes(1, 3) also exchanges the two spatial kernel axes; use
# np.moveaxis(kernels, 1, -1) instead if the written row/column orientation
# of the kernels must be preserved.
convolutions = full_convolve(images, kernels.swapaxes(1, 3))
# Full-convolution output of the selected image at index 2, one panel per kernel.
px.imshow(convolutions[2], facet_col=2)
