In [1]:
from timeit import timeit

import numpy as np
from numpy import ndarray
from scipy.signal import convolve2d, correlate2d
from numpy.lib.stride_tricks import sliding_window_view as sliding_views

from numba import njit, prange
from cifar_10_dataset_loading import load_cifar_10

In [2]:
x, y, x_test, y_test = load_cifar_10()

In [3]:
kernels = np.random.rand(32, 7, 7, 3)

In [4]:
def my_valid_correlate(inputs:ndarray, k:ndarray) -> ndarray:
    views = sliding_views(inputs, k.shape[1:3], (1, 2))
    correlations = np.tensordot(views, k, axes=([3, 4, 5], [3, 1, 2]))
    return correlations

def my_full_convolve(inputs:ndarray, k:ndarray) -> ndarray:
    pad = ((0, 0), (k.shape[1]-1, k.shape[1]-1), (k.shape[2]-1, k.shape[2]-1), (0, 0))
    return my_valid_correlate(np.pad(inputs, pad), np.flip(k, (1, 2)))

In [5]:
inputs = x[:10000]

In [6]:
%timeit my_correlations = my_valid_correlate(inputs, kernels)

4.48 s ± 224 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [7]:
# 1st implementation

@njit(parallel=True)
def numba_valid_correlate(inputs: np.ndarray, k: np.ndarray) -> np.ndarray:
    """Optimized correlation using Numba."""
    batch, in_h, in_w, in_c = inputs.shape
    out_c, k_h, k_w, _ = k.shape
    out_h, out_w = in_h - k_h + 1, in_w - k_w + 1
    output = np.zeros((batch, out_h, out_w, out_c))

    for b in prange(batch):  # Parallel over batches
        for i in prange(out_h):  # Parallel over height
            for j in prange(out_w):  # Parallel over width
                for c in prange(out_c):  # Parallel over output channels
                    for di in range(k_h):  # Kernel height
                        for dj in range(k_w):  # Kernel width
                            for dc in range(in_c):  # Kernel depth
                                output[b, i, j, c] += inputs[b, i + di, j + dj, dc] * k[c, di, dj, dc]
    return output


# 2nd implementation
@njit
def sliding_views_numba(inputs, k_h, k_w):
    """
    Create sliding window views of the input.
    inputs: (N, H, W, C)
    Returns: array of shape (N, out_H, out_W, k_h, k_w, C)
    where out_H = H - k_h + 1 and out_W = W - k_w + 1.
    """
    N, H, W, C = inputs.shape
    out_H = H - k_h + 1
    out_W = W - k_w + 1
    views = np.empty((N, out_H, out_W, k_h, k_w, C), dtype=inputs.dtype)
    for n in range(N):
        for i in range(out_H):
            for j in range(out_W):
                for ki in range(k_h):
                    for kj in range(k_w):
                        for c in range(C):
                            views[n, i, j, ki, kj, c] = inputs[n, i + ki, j + kj, c]
    return views

@njit
def numba_valid_correlate_2(inputs, kernels):
    """
    Compute valid correlation (without flipping kernels) using explicit loops.
    inputs: (N, H, W, C)
    kernels: (num_kernels, k_h, k_w, C)
    Returns: output of shape (N, out_H, out_W, num_kernels)
    where out_H = H - k_h + 1 and out_W = W - k_w + 1.
    """
    N, H, W, C = inputs.shape
    num_kernels, k_h, k_w, _ = kernels.shape
    out_H = H - k_h + 1
    out_W = W - k_w + 1
    # Get sliding windows from the input
    views = sliding_views_numba(inputs, k_h, k_w)  # shape (N, out_H, out_W, k_h, k_w, C)
    # Prepare output array
    out = np.empty((N, out_H, out_W, num_kernels), dtype=inputs.dtype)
    # For each sample, spatial location, and kernel, compute the correlation.
    for n in range(N):
        for i in range(out_H):
            for j in range(out_W):
                for k in range(num_kernels):
                    s = 0.0
                    for ki in range(k_h):
                        for kj in range(k_w):
                            for c in range(C):
                                s += views[n, i, j, ki, kj, c] * kernels[k, ki, kj, c]
                    out[n, i, j, k] = s
    return out


warm up

In [8]:
numba_correlations = numba_valid_correlate(inputs, kernels)

In [9]:
%timeit numba_correlations = numba_valid_correlate(inputs, kernels)

8.11 s ± 217 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


warm up

In [None]:
numba_correlations2 = numba_valid_correlate_2(inputs, kernels)

In [None]:
numba_correlations2 = numba_valid_correlate_2(inputs, kernels)