This new version brings the stride option by constructing the views with as_strided instead of sliding_window_view.  
It also gets rid of the tensordot which should make it more suitable for cupy like ~25x faster.

In [23]:
import numpy as np
import plotly.express as px
from numpy import ndarray
from numpy.lib.stride_tricks import as_strided
from numpy.lib.stride_tricks import sliding_window_view as sliding_views

from cifar_10_dataset_loading import load_cifar_10

In [24]:
x_train, y_train, x_test, y_test = load_cifar_10()

In [25]:
IMAGES_IDX = [0, 392, 351, 333]
images = x_train[IMAGES_IDX]
px.imshow(images, facet_col=0)

In [47]:
kernels = np.asarray(
[
    [
        # First kernel
        # First channel
            [
            [1, 0, -1, 0, 1], 
            [1, 0, -1, 0, 1],
            [1, 0, -1, 0, 1],
            [1, 0, -1, 0, 1],
            [1, 0, -1, 0, 1],
        ],
        # Second channel
        [
            [1, 0, 1, 1, 1], 
            [1, 0, 1, 1, 1], 
            [1, 0, 1, 1, 1],
            [1, 0, 1, 1, 1],
            [1, 0, 1, 1, 1],
        ],
        # Third channel
        [
            [0, 0, 0, 0, 1],
            [0, 0, 0, 1, 0],
            [0, 0, 1, 0, 0],
            [0, 1, 0, 0, 0],
            [1, 0, 0, 0, 0],
        ],
    ],
    [
        # Second kernel
        # First channel
        [
            [0.5, 0.5, 0.5, 0.5, 0.5], 
            [0.5, 0.5, 0.5, 0.5, 0.5], 
            [0.5, 0.5, 0.5, 0.5, 0.5], 
            [0.5, 0.5, 0.5, 0.5, 0.5],
            [0.5, 0.5, 0.5, 0.5, 0.5],
        ],
        # Second channel
        [
            [1, 2, 3, 4, 5],
            [1, 1, 1, 0, 0],
            [0, 1, 1, 1, 0],
            [0, 0, 1, 1, 1],
            [0, 1, 0, 1, 1],
        ],
        # Thrid channel
        [
            [0, 0, 1, 0, 0],
            [0, 0.5, 1, 0.5, 0],
            [1, 1, 1, 1, 1],
            [0, 0.5, 1, 0.5, 0],
            [0, 0, 1, 0, 0],
        ],
    ]
]).swapaxes(1, 3)
kernels.shape

(2, 5, 5, 3)

In [27]:
print(images.shape)
print(images.strides)

(4, 32, 32, 3)
(3072, 32, 1, 1024)


In [28]:
def sliding_window_2d(inputs:ndarray, window_shape:tuple[int, int], stride:int=1) -> ndarray:
    x_remainder = (inputs.shape[1] - window_shape[0]) % stride
    y_remainder = (inputs.shape[2] - window_shape[1]) % stride
    assert x_remainder == 0 and y_remainder == 0, "Incorrect strides."
    views_shape = (
        inputs.shape[0],
        1 + (inputs.shape[1] - window_shape[0]) // stride,
        1 + (inputs.shape[2] - window_shape[1]) // stride,
        *window_shape,
        inputs.shape[3],
    )
    strides = (
        inputs.strides[0],
        inputs.strides[1] * stride,
        inputs.strides[2] * stride,
        *inputs.strides[1:],
    )
    return as_strided(inputs, views_shape, strides)

views = sliding_window_2d(images, (5, 5))
display(views.shape)
px.imshow(views[0, 5], facet_col=0)

(4, 28, 28, 5, 5, 3)

Check that no copy of the inputs has been made to minimize memory allocation.

In [29]:
flatten_views = views.reshape(*views.shape[:3], -1)
print("flatten is not a copy:", flatten_views.base is not None)

flatten is not a copy: True


In [50]:
def new_valid_correlate_2d(inputs:ndarray, kernels:ndarray, stride:int=1) -> ndarray:
    views = sliding_window_2d(inputs, kernels.shape[1:3], stride)
    flatten_views = views.reshape(*views.shape[:3], -1)
    flatten_kernels = kernels.reshape(kernels.shape[0], -1).T
    return flatten_views @ flatten_kernels

In [51]:
im2col_correlations = new_valid_correlate_2d(images, kernels)
px.imshow(
    im2col_correlations.transpose(0, 3, 1, 2).reshape(8, 28, 28),
    facet_col=0
)

In [48]:
def legacy_valid_correlate(inputs:ndarray, k:ndarray) -> ndarray:
    views = sliding_views(inputs, k.shape[1:3], (1, 2))
    correlations = np.tensordot(views, k, axes=([3, 4, 5], [3, 1, 2]))
    return correlations

legacy_correlations = legacy_valid_correlate(images, kernels)
px.imshow(
    legacy_correlations.transpose(0, 3, 1, 2).reshape(8, 28, 28),
    facet_col=0
)

Verify results.

In [53]:
(im2col_correlations == legacy_correlations).all()

np.True_