New implementation built on top of the previous ones.  
We are still using `numpy.lib.stride_tricks.sliding_window_view`, but we will use a conventional `@` dot product.  
To make this work, we flatten the inputs and the kernels, which makes the whole thing a LOT faster than an `np.tensordot` call on unflattened arrays.  

> Once we start scaling up to 300 samples, the memory allocation caused by the sliding views + swapaxes + reshape takes a lot of time, AND the `@` and `tensordot` operations end up with the same execution time.  
> At 3000 samples, `tensordot` is actually faster than the `@` operation.

Conclusion: We'll keep the tensordot implementation... (sad)

In [1]:
from timeit import time, timeit

import numpy as np
import plotly.express as px
from numpy.lib.stride_tricks import sliding_window_view
from scipy.signal import correlate2d

from time_utils import time_to_exec, print_time_dict, reset_time_dict
from cifar_10_dataset_loading import load_cifar_10

In [2]:
x_train, y_train, x_test, y_test = load_cifar_10()
# Move the channel axis from the last position to position 1, so the images
# are laid out as (sample, channel, row, column).
x_train = np.transpose(x_train, (0, 3, 1, 2))

In [3]:
# Check the layout obtained after the transpose.
x_train.shape

(50000, 3, 32, 32)

In [4]:
# A handful of training images used for the small-scale experiments below.
IMAGES_IDX = [0, 360, 351, 333]
images = x_train[IMAGES_IDX]
# Put the channel axis back in last position for plotting only; `images`
# itself keeps the channel axis at position 1.
px.imshow(np.transpose(images, (0, 2, 3, 1)), facet_col=0)

In [5]:
# Two hand-written kernels, each made of one 5x5 slice per image channel.
# The resulting array is indexed as (kernel, channel, row, column).
kernels = np.asarray(
[
    [
        # First kernel
        # First channel
            [
            [1, 0, -1, 0, 1], 
            [1, 0, -1, 0, 1],
            [1, 0, -1, 0, 1],
            [1, 0, -1, 0, 1],
            [1, 0, -1, 0, 1],
        ],
        # Second channel
        [
            [1, 0, 1, 1, 1], 
            [1, 0, 1, 1, 1], 
            [1, 0, 1, 1, 1],
            [1, 0, 1, 1, 1],
            [1, 0, 1, 1, 1],
        ],
        # Third channel
        [
            [0, 0, 0, 0, 1],
            [0, 0, 0, 1, 0],
            [0, 0, 1, 0, 0],
            [0, 1, 0, 0, 0],
            [1, 0, 0, 0, 0],
        ],
    ],
    [
        # Second kernel
        # First channel
        [
            [0.5, 0.5, 0.5, 0.5, 0.5], 
            [0.5, 0.5, 0.5, 0.5, 0.5], 
            [0.5, 0.5, 0.5, 0.5, 0.5], 
            [0.5, 0.5, 0.5, 0.5, 0.5],
            [0.5, 0.5, 0.5, 0.5, 0.5],
        ],
        # Second channel
        [
            [1, 2, 3, 4, 5],
            [1, 1, 1, 0, 0],
            [0, 1, 1, 1, 0],
            [0, 0, 1, 1, 1],
            [0, 1, 0, 1, 1],
        ],
        # Third channel
        [
            [0, 0, 1, 0, 0],
            [0, 0.5, 1, 0.5, 0],
            [1, 1, 1, 1, 1],
            [0, 0.5, 1, 0.5, 0],
            [0, 0, 1, 0, 0],
        ],
    ]
]) 
# Show the shape to confirm the (kernel, channel, row, column) layout.
display(kernels.shape)
# px.imshow(kernels.swapaxes(1, 3) * 255, facet_col=0)

(2, 3, 5, 5)

In [6]:
# Shape of the small image batch selected with IMAGES_IDX.
images.shape

(4, 3, 32, 32)

In [7]:
# Take every kernel-sized window over the two trailing image axes, then swap
# axis 1 (channels) with axis 3 (a window-position axis) so that the channel
# axis sits right before the two window axes.
views = np.swapaxes(
    sliding_window_view(images, window_shape=kernels.shape[2:4], axis=(2, 3)),
    1,
    3,
)
display(views.shape)
# Collapse (channel, window row, window column) into one vector per window.
flatten_views = views.reshape(views.shape[:3] + (-1,))
display(flatten_views.shape)

(4, 28, 28, 3, 5, 5)

(4, 28, 28, 75)

In [8]:
# Flatten each kernel's (channel, row, column) weights into one vector and put
# the kernels along the columns, ready to be the right operand of `@`.
flatten_kernels = np.transpose(kernels.reshape(len(kernels), -1))
flatten_kernels.shape

(75, 2)

In [9]:
# Presumably clears the timings recorded by earlier cells so that only this
# cell's measurements get printed (helper from time_utils — confirm).
reset_time_dict()
# Benchmark input: the first 3000 training images.
big_x_train_subset = x_train[:3000]

def my_valid_correlate(inputs: np.ndarray, kernels: np.ndarray) -> np.ndarray:
    """Batched multi-channel "valid" cross-correlation done as one `@` product.

    Every kernel-sized window of ``inputs`` is flattened into a vector
    (channel axis first, then the window axes) and multiplied with the
    kernels flattened in the same order.

    Works for any number of spatial axes, not only two.

    Args:
        inputs: array of shape (batch, channels, *spatial).
        kernels: array of shape (n_kernels, channels, *window), with the same
            number of dimensions as ``inputs``.

    Returns:
        Array of shape (batch, n_kernels, *output_spatial), one output
        position per full window of ``inputs``.

    Raises:
        ValueError: if the arrays do not have the same number of dimensions
            or have no spatial axis.
    """
    if inputs.ndim != kernels.ndim or inputs.ndim < 3:
        raise ValueError(
            "inputs and kernels must have the same number of dimensions (>= 3), "
            f"got {inputs.ndim} and {kernels.ndim}"
        )
    spatial_ndim = inputs.ndim - 2

    with time_to_exec("compute flatten views"):
        # views: (batch, channels, *output_spatial, *window)
        views = sliding_window_view(inputs, kernels.shape[2:], range(2, inputs.ndim))
        # Move the channel axis right before the window axes so that
        # (channels, *window) can be flattened in the same order as the kernels.
        views = np.moveaxis(views, 1, spatial_ndim + 1)
        flatten_views = views.reshape(*views.shape[: spatial_ndim + 1], -1)
    print("views shape:", flatten_views.shape)
    with time_to_exec("compute flatten kernels"):
        # One column per kernel: (channels * window size, n_kernels).
        flatten_kernels = kernels.reshape(kernels.shape[0], -1).T
    with time_to_exec("compute correlations"):
        # (batch, *output_spatial, n_kernels) -> (batch, n_kernels, *output_spatial)
        correlations = np.moveaxis(flatten_views @ flatten_kernels, -1, 1)
    return correlations

def my_tensordot_valid_correlation(inputs: np.ndarray, kernels: np.ndarray) -> np.ndarray:
    """Batched multi-channel "valid" cross-correlation using `np.tensordot`.

    The sliding-window view of ``inputs`` is contracted with ``kernels`` over
    the channel axis and all window axes, without flattening anything first.

    Works for any number of spatial axes, not only two.

    Args:
        inputs: array of shape (batch, channels, *spatial).
        kernels: array of shape (n_kernels, channels, *window), with the same
            number of dimensions as ``inputs``.

    Returns:
        Array of shape (batch, n_kernels, *output_spatial), one output
        position per full window of ``inputs``.

    Raises:
        ValueError: if the arrays do not have the same number of dimensions
            or have no spatial axis.
    """
    if inputs.ndim != kernels.ndim or inputs.ndim < 3:
        raise ValueError(
            "inputs and kernels must have the same number of dimensions (>= 3), "
            f"got {inputs.ndim} and {kernels.ndim}"
        )
    spatial_ndim = inputs.ndim - 2

    with time_to_exec("compute views"):
        # views: (batch, channels, *output_spatial, *window)
        views = sliding_window_view(inputs, kernels.shape[2:], range(2, inputs.ndim))
    # Contract the channel axis plus the trailing window axes of the views
    # with the (channels, *window) axes of the kernels.
    view_axes = [1, *range(2 + spatial_ndim, views.ndim)]
    kernel_axes = list(range(1, kernels.ndim))
    with time_to_exec("compute tensordot correlations"):
        # tensordot yields (batch, *output_spatial, n_kernels); bring the
        # kernel axis to position 1.
        correlations = np.moveaxis(
            np.tensordot(views, kernels, (view_axes, kernel_axes)), -1, 1
        )
    return correlations

# Run both implementations on the same subset with the same kernels.
my_correlations = my_valid_correlate(big_x_train_subset, kernels)
my_tensordot_correlations = my_tensordot_valid_correlation(big_x_train_subset, kernels)

# Print the durations recorded by the time_to_exec blocks above.
print_time_dict()

views shape: (3000, 28, 28, 75)
compute flatten views: 0.5501611232757568s
compute flatten kernels: 1.71661376953125e-05s
compute correlations: 1.9372386932373047s
compute views: 0.0002963542938232422s
compute tensordot correlations: 2.064324140548706s


In [16]:
def _sum_channel_correlations(img, kernel):
    """Sum, over channels, the "valid" 2-D cross-correlations of one image with one kernel."""
    # zip() would silently ignore the extra channels on a mismatch and return
    # a wrong result, so fail loudly instead.
    if len(img) != len(kernel):
        raise ValueError(
            f"channel mismatch: image has {len(img)} channel(s), "
            f"kernel has {len(kernel)}"
        )
    return sum(
        correlate2d(img_chan, kernel_chan, "valid")
        for img_chan, kernel_chan in zip(img, kernel)
    )


def scipy_valid_correlate(inputs: np.ndarray, kernels: np.ndarray) -> np.ndarray:
    """Reference multi-channel "valid" cross-correlation built on `scipy.signal.correlate2d`.

    Each image is correlated with each kernel channel by channel, and the
    per-channel results are summed.

    Args:
        inputs: images, indexed as (batch, channel, row, column).
        kernels: kernels, indexed as (kernel, channel, row, column).

    Returns:
        Array indexed as (batch, kernel, output row, output column).

    Raises:
        ValueError: if an image and a kernel do not have the same number of
            channels.
    """
    return np.stack(
        [
            [_sum_channel_correlations(img, kernel) for kernel in kernels]
            for img in inputs
        ]
    )

# Reference result computed with scipy on the same subset and kernels.
scipy_correlations = scipy_valid_correlate(big_x_train_subset, kernels)

## Verify output

In [17]:
# Fraction of elements that are exactly equal to the scipy reference;
# 1.0 means every element matches.
# NOTE(review): this is an exact float comparison and only checks
# my_correlations — consider np.allclose, and checking
# my_tensordot_correlations as well.
(my_correlations == scipy_correlations).mean()

np.float64(1.0)