# Cross correlation implementation
In this notebook, we will implement the cross correlation of a batch of kernels with a batch of input images.  
Cross correlation (and convolution) are essentially dot products repeated over the input image.  
This is because at each correlation, the kernel and input image values correspond to a fixed "window index".  
This is a 1D index that make the kernel to the input image.  
We just need to compute that index and then repeat it for nb_correlations times with an offset in both the horizontal and vertical direction.  


## Setup

### Imports

In [2]:
from os.path import join

from scipy.signal import correlate2d
import plotly.subplots as sp
import numpy as np
import plotly.express as px
import kagglehub

from cifar_10_dataset_loading import load_cifar_10

### Data extraction

In [3]:
dataset_path = kagglehub.dataset_download("hojjatk/mnist-dataset")
train_image_path = join(dataset_path, 'train-images.idx3-ubyte')
train_labels_path = join(dataset_path, 'train-labels.idx1-ubyte')
test_image_path = join(dataset_path, 't10k-images.idx3-ubyte')
test_labels_path = join(dataset_path, 't10k-labels.idx1-ubyte')

def load_images(path) -> np.ndarray:
    with open(path, 'rb') as f:
        return (
            np.frombuffer(f.read(), dtype=np.uint8)
            [16:]
            .reshape(-1, 28**2)
            / 255
        )

def load_labels(path) -> np.ndarray:
    with open(path, 'rb') as f:
        label_idxs = np.frombuffer(f.read(), dtype=np.uint8)[8:]
        labels = np.eye(10)[label_idxs]
        return labels

train_dataset = load_images(train_image_path)
train_labels = load_labels(train_labels_path)
test_dataset = load_images(test_image_path)
test_labels = load_labels(test_labels_path)

In [4]:
px.imshow(train_dataset[0].reshape(28, 28), color_continuous_scale='Rainbow')

In [5]:
INPUT_SHAPE = (28, 28)
input = train_dataset[0]

## Single kernel, single input image, cross correlation implementation

First we declare are kernel, we will store it as a vectore this way we can use vector/matrix dot product to compute the cross correlation.

In [6]:
kernel = np.array([
    [1, 0, -1], 
    [1, 0, -1],
    [1, 0, -1],
])
kernel_weights = kernel.ravel()
kernel_weights

array([ 1,  0, -1,  1,  0, -1,  1,  0, -1])

First we will create the "window" index of the kernel.  
This is the index that will be slid over the input image to compute the cross correlation.

In [7]:
window_index = np.arange(kernel.shape[0])
window_index = np.tile(window_index, kernel.shape[1])
window_index += np.repeat(np.arange(kernel.shape[0]) * INPUT_SHAPE[0], kernel.shape[1])
window_index

array([ 0,  1,  2, 28, 29, 30, 56, 57, 58])

In [8]:
nb_horizontal_correlations = 1 + INPUT_SHAPE[0] - kernel.shape[0]
nb_vertical_correlations = 1 + INPUT_SHAPE[0] - kernel.shape[0]
nb_correlations = nb_horizontal_correlations * nb_vertical_correlations
nb_correlations

676

In [9]:
# Here we repeat the window index for each correlation and then add the horizontal and vertical offsets and store the result into a matrix.
# Each row corresponds to a correlation and each column corresponds to a window index.
correlation_indices = np.tile(window_index, (nb_correlations, 1))
# The offsets are computed the same way as the window index.
# We reshape the offsets by giving them an extra dimension of size 1 to brodcast the addition over the window indices.
horizontal_offsets = np.tile(np.arange(nb_horizontal_correlations), nb_vertical_correlations).reshape(-1, 1)
vertical_offsets = np.repeat(np.arange(nb_vertical_correlations) * INPUT_SHAPE[0], nb_vertical_correlations).reshape(-1, 1)
correlation_indices += horizontal_offsets + vertical_offsets

In [10]:
scipy_correlation = correlate2d(input.reshape(28, 28), kernel, mode="valid")
custom_correlation = (input[correlation_indices] @ kernel_weights).reshape(nb_horizontal_correlations, nb_vertical_correlations)

fig = px.imshow(np.asarray([scipy_correlation, custom_correlation]), facet_col=0)
fig.layout.annotations[0]['text'] = "Scipy Correlation"
fig.layout.annotations[1]['text'] = "Custom Correlation"
fig.show()

## Multiple kernels, single input image, cross correlation implementation

In [11]:
kernels = np.array([
    [
        [1, 0, -1], 
        [1, 0, -1],
        [1, 0, -1],
    ],
    [
        [1, 0, 1], 
        [0, 0, 0],
        [-1, 0, -1],
    ],
])
kernels_weights = kernels.reshape(2, -1).T
kernels_weights

array([[ 1,  1],
       [ 0,  0],
       [-1,  1],
       [ 1,  0],
       [ 0,  0],
       [-1,  0],
       [ 1, -1],
       [ 0,  0],
       [-1, -1]])

In [12]:
flatt_cross_correlation = input[correlation_indices] @ kernels_weights
flatt_cross_correlation.shape

(676, 2)

In [13]:
reshaped_cross_correlation = flatt_cross_correlation.reshape(nb_horizontal_correlations, nb_vertical_correlations, -1)
display(reshaped_cross_correlation.shape)
px.imshow(reshaped_cross_correlation[..., 1])

(26, 26, 2)

In [14]:
yes = np.matmul(input[correlation_indices], kernels_weights)
yes.shape

no = input[correlation_indices] @ kernels_weights

(yes == no).all()

np.True_

In [15]:
kernels_weights.shape


(9, 2)

In [16]:
custom_correlations = (input[correlation_indices] @ kernels_weights).reshape(nb_horizontal_correlations, nb_vertical_correlations, kernels.shape[0])
scipy_correlations = np.stack([correlate2d(input.reshape(28, 28), kernels[k], mode="valid") for k in range(kernels.shape[0])], axis=2)

display(np.concatenate([custom_correlations, scipy_correlations], axis=2).shape)

fig = px.imshow(
    np.concatenate([custom_correlations, scipy_correlations], axis=2),
    facet_col=2,
    facet_col_wrap=2,
)
fig.layout.annotations[0]['text'] = "Custom Correlation 0"
fig.layout.annotations[1]['text'] = "Custom Correlation 1"
fig.layout.annotations[2]['text'] = "Scipy Correlation 0"
fig.layout.annotations[3]['text'] = "Scipy Correlation 1"
fig.show()

(26, 26, 4)

## Multiple kernels, multiple channels, single input image, cross correlation implementation

In [17]:
multi_channels_kernels = np.asarray(
[
    [
        # First kernel
        # First channel
            [
            [1, 0, -1, 0, 1], 
            [1, 0, -1, 0, 1],
            [1, 0, -1, 0, 1],
            [1, 0, -1, 0, 1],
            [1, 0, -1, 0, 1],
        ],
        # Second channel
        [
            [1, 0, 1, 1, 1], 
            [1, 0, 1, 1, 1], 
            [1, 0, 1, 1, 1],
            [1, 0, 1, 1, 1],
            [1, 0, 1, 1, 1],
        ],
        # Third channel
        [
            [0, 0, 0, 0, 1],
            [0, 0, 0, 1, 0],
            [0, 0, 1, 0, 0],
            [0, 1, 0, 0, 0],
            [1, 0, 0, 0, 0],
        ],
    ],
    [
        # Second kernel
        # First channel
        [
            [0.5, 0.5, 0.5, 0.5, 0.5], 
            [0.5, 0.5, 0.5, 0.5, 0.5], 
            [0.5, 0.5, 0.5, 0.5, 0.5], 
            [0.5, 0.5, 0.5, 0.5, 0.5],
            [0.5, 0.5, 0.5, 0.5, 0.5],
        ],
        # Second channel
        [
            [1, 2, 3, 4, 5],
            [1, 1, 1, 0, 0],
            [0, 1, 1, 1, 0],
            [0, 0, 1, 1, 1],
            [0, 1, 0, 1, 1],
        ],
        # Thrid channel
        [
            [0, 0, 1, 0, 0],
            [0, 0.5, 1, 0.5, 0],
            [1, 1, 1, 1, 1],
            [0, 0.5, 1, 0.5, 0],
            [0, 0, 1, 0, 0],
        ],
    ]
]).swapaxes(1, 3)
multi_channels_kernels.shape

(2, 5, 5, 3)

In [27]:
# There is most likely a simpler way of doing this...
multi_channels_kernels_weights = multi_channels_kernels.reshape(2, -1).T
multi_channels_kernels_weights

array([[ 1. ,  0.5],
       [ 1. ,  1. ],
       [ 0. ,  0. ],
       [ 1. ,  0.5],
       [ 1. ,  1. ],
       [ 0. ,  0. ],
       [ 1. ,  0.5],
       [ 1. ,  0. ],
       [ 0. ,  1. ],
       [ 1. ,  0.5],
       [ 1. ,  0. ],
       [ 0. ,  0. ],
       [ 1. ,  0.5],
       [ 1. ,  0. ],
       [ 1. ,  0. ],
       [ 0. ,  0.5],
       [ 0. ,  2. ],
       [ 0. ,  0. ],
       [ 0. ,  0.5],
       [ 0. ,  1. ],
       [ 0. ,  0.5],
       [ 0. ,  0.5],
       [ 0. ,  1. ],
       [ 0. ,  1. ],
       [ 0. ,  0.5],
       [ 0. ,  0. ],
       [ 1. ,  0.5],
       [ 0. ,  0.5],
       [ 0. ,  1. ],
       [ 0. ,  0. ],
       [-1. ,  0.5],
       [ 1. ,  3. ],
       [ 0. ,  1. ],
       [-1. ,  0.5],
       [ 1. ,  1. ],
       [ 0. ,  1. ],
       [-1. ,  0.5],
       [ 1. ,  1. ],
       [ 1. ,  1. ],
       [-1. ,  0.5],
       [ 1. ,  1. ],
       [ 0. ,  1. ],
       [-1. ,  0.5],
       [ 1. ,  0. ],
       [ 0. ,  1. ],
       [ 0. ,  0.5],
       [ 1. ,  4. ],
       [ 0. ,

In [19]:
cifar_10_train_inputs, cifar_10_train_labels, cifar_10_test_inputs, cifar_10_test_labels = load_cifar_10()
display(cifar_10_train_inputs.shape)
multi_channels_input = cifar_10_train_inputs[0]
flatten_multi_channels_input = multi_channels_input.ravel()
px.imshow(multi_channels_input[:, ])

(50000, 32, 32, 3)

In [20]:
multi_channels_window_index = np.arange(multi_channels_kernels.shape[3])
multi_channels_window_index = np.tile(multi_channels_window_index, multi_channels_kernels.shape[2])
multi_channels_window_index += np.repeat(np.arange(multi_channels_kernels.shape[2]) * cifar_10_train_inputs.shape[3], multi_channels_kernels.shape[3])
multi_channels_window_index = np.tile(multi_channels_window_index, multi_channels_kernels.shape[1])
multi_channels_window_index += np.repeat(np.arange(multi_channels_kernels.shape[1]) * cifar_10_train_inputs.shape[2] * cifar_10_train_inputs.shape[3], multi_channels_kernels.shape[2] * multi_channels_kernels.shape[3])
display(multi_channels_window_index.reshape(5, 5, 3))

array([[[  0,   1,   2],
        [  3,   4,   5],
        [  6,   7,   8],
        [  9,  10,  11],
        [ 12,  13,  14]],

       [[ 96,  97,  98],
        [ 99, 100, 101],
        [102, 103, 104],
        [105, 106, 107],
        [108, 109, 110]],

       [[192, 193, 194],
        [195, 196, 197],
        [198, 199, 200],
        [201, 202, 203],
        [204, 205, 206]],

       [[288, 289, 290],
        [291, 292, 293],
        [294, 295, 296],
        [297, 298, 299],
        [300, 301, 302]],

       [[384, 385, 386],
        [387, 388, 389],
        [390, 391, 392],
        [393, 394, 395],
        [396, 397, 398]]])

In [21]:
px.imshow(flatten_multi_channels_input[multi_channels_window_index + (cifar_10_train_inputs.shape[3] * 20) + (cifar_10_train_inputs.shape[3] * cifar_10_train_inputs.shape[2] * 5)].reshape(5, 5, 3))

In [22]:
multi_channels_nb_horizontal_correlations = 1 + cifar_10_train_inputs.shape[1] - multi_channels_kernels.shape[1]
multi_channels_nb_vertical_correlations = 1 + cifar_10_train_inputs.shape[2] - multi_channels_kernels.shape[2]
mutli_channels_nb_correlations = multi_channels_nb_horizontal_correlations * multi_channels_nb_vertical_correlations
mutli_channels_nb_correlations

784

In [59]:
mutli_channels_x_offset_multiplicator = cifar_10_train_inputs.shape[3]
mutli_channels_y_offset_multiplicator = cifar_10_train_inputs.shape[3] * cifar_10_train_inputs.shape[2]
mutli_channels_horizontal_offsets = np.tile(np.arange(multi_channels_nb_horizontal_correlations) * mutli_channels_x_offset_multiplicator, multi_channels_nb_vertical_correlations).reshape(-1, 1)
multi_channels_vertical_offsets = np.repeat(np.arange(multi_channels_nb_vertical_correlations) * mutli_channels_y_offset_multiplicator, multi_channels_nb_vertical_correlations).reshape(-1, 1)
multi_channels_correlatin_indieces = np.tile(multi_channels_window_index, (mutli_channels_nb_correlations, 1)) + mutli_channels_horizontal_offsets + multi_channels_vertical_offsets

multi_channels_correlatin_indieces

array([[   0,    1,    2, ...,  396,  397,  398],
       [   3,    4,    5, ...,  399,  400,  401],
       [   6,    7,    8, ...,  402,  403,  404],
       ...,
       [2667, 2668, 2669, ..., 3063, 3064, 3065],
       [2670, 2671, 2672, ..., 3066, 3067, 3068],
       [2673, 2674, 2675, ..., 3069, 3070, 3071]], shape=(784, 75))

In [55]:
multi_channels_correlatin_indieces.reshape(-1, 5, 5, 3)[:4, ...].shape

(4, 5, 5, 3)

In [64]:
px.imshow(
    flatten_multi_channels_input[multi_channels_correlatin_indieces].reshape(-1, 5, 5, 3)[:56, ...],
    facet_col=0,
    facet_col_wrap=28,
    height=1000,
    facet_col_spacing=0.,
    labels={"facet_col": ""}
)

In [65]:
multi_channels_flatten_correlation = flatten_multi_channels_input[multi_channels_correlatin_indieces] @ multi_channels_kernels_weights
multi_channels_flatten_correlation.shape

(784, 2)

In [76]:
multi_channels_kernels.shape

(2, 5, 5, 3)

In [77]:
for kernel in multi_channels_kernels:
    print(kernel)
    for kernel_channel in kernel:
        print(kernel_channel)
        print("==========")

[[[ 1.  1.  0.]
  [ 1.  1.  0.]
  [ 1.  1.  0.]
  [ 1.  1.  0.]
  [ 1.  1.  1.]]

 [[ 0.  0.  0.]
  [ 0.  0.  0.]
  [ 0.  0.  0.]
  [ 0.  0.  1.]
  [ 0.  0.  0.]]

 [[-1.  1.  0.]
  [-1.  1.  0.]
  [-1.  1.  1.]
  [-1.  1.  0.]
  [-1.  1.  0.]]

 [[ 0.  1.  0.]
  [ 0.  1.  1.]
  [ 0.  1.  0.]
  [ 0.  1.  0.]
  [ 0.  1.  0.]]

 [[ 1.  1.  1.]
  [ 1.  1.  0.]
  [ 1.  1.  0.]
  [ 1.  1.  0.]
  [ 1.  1.  0.]]]
[[1. 1. 0.]
 [1. 1. 0.]
 [1. 1. 0.]
 [1. 1. 0.]
 [1. 1. 1.]]
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 1.]
 [0. 0. 0.]]
[[-1.  1.  0.]
 [-1.  1.  0.]
 [-1.  1.  1.]
 [-1.  1.  0.]
 [-1.  1.  0.]]
[[0. 1. 0.]
 [0. 1. 1.]
 [0. 1. 0.]
 [0. 1. 0.]
 [0. 1. 0.]]
[[1. 1. 1.]
 [1. 1. 0.]
 [1. 1. 0.]
 [1. 1. 0.]
 [1. 1. 0.]]
[[[0.5 1.  0. ]
  [0.5 1.  0. ]
  [0.5 0.  1. ]
  [0.5 0.  0. ]
  [0.5 0.  0. ]]

 [[0.5 2.  0. ]
  [0.5 1.  0.5]
  [0.5 1.  1. ]
  [0.5 0.  0.5]
  [0.5 1.  0. ]]

 [[0.5 3.  1. ]
  [0.5 1.  1. ]
  [0.5 1.  1. ]
  [0.5 1.  1. ]
  [0.5 0.  1. ]]

 [[0.5 4.  0. ]
  [0.5 0

In [72]:
reshaped_mutli_channels_correlation = multi_channels_flatten_correlation.reshape(multi_channels_nb_horizontal_correlations, multi_channels_nb_vertical_correlations, -1, )
scipy_mutli_channels_cross_correlatin = np.stack([np.stack([correlate2d(multi_channels_input, kernel_channel, mode="valid") for kernel_channel in kernel]) for kernel in multi_channels_kernels])
px.imshow(
    reshaped_mutli_channels_correlation,
    facet_col=2
)

ValueError: correlate2d inputs must both be 2-D arrays