# Cross correlation implementation
In this notebook, we will implement the cross correlation of a batch of kernels with a batch of input images.  
Cross correlation (and convolution) are essentially dot products repeated over the input image.  
This is because at each correlation, the kernel and input image values correspond to a fixed "window index".  
This is a 1D index that make the kernel to the input image.  
We just need to compute that index and then repeat it for nb_correlations times with an offset in both the horizontal and vertical direction.  


## Setup

### Imports

In [83]:
from os.path import join

from scipy.signal import correlate2d
import plotly.subplots as sp
import numpy as np
import plotly.express as px
import kagglehub

### Data extraction

In [2]:
dataset_path = kagglehub.dataset_download("hojjatk/mnist-dataset")
train_image_path = join(dataset_path, 'train-images.idx3-ubyte')
train_labels_path = join(dataset_path, 'train-labels.idx1-ubyte')
test_image_path = join(dataset_path, 't10k-images.idx3-ubyte')
test_labels_path = join(dataset_path, 't10k-labels.idx1-ubyte')

def load_images(path) -> np.ndarray:
    with open(path, 'rb') as f:
        return (
            np.frombuffer(f.read(), dtype=np.uint8)
            [16:]
            .reshape(-1, 28**2)
            / 255
        )

def load_labels(path) -> np.ndarray:
    with open(path, 'rb') as f:
        label_idxs = np.frombuffer(f.read(), dtype=np.uint8)[8:]
        labels = np.eye(10)[label_idxs]
        return labels

train_dataset = load_images(train_image_path)
train_labels = load_labels(train_labels_path)
test_dataset = load_images(test_image_path)
test_labels = load_labels(test_labels_path)

In [9]:
px.imshow(train_dataset[0].reshape(28, 28), color_continuous_scale='Rainbow')

In [27]:
INPUT_SHAPE = (28, 28)
input = train_dataset[0]

## Single kernel, single input image, cross correlation implementation

First we declare are kernel, we will store it as a vectore this way we can use vector/matrix dot product to compute the cross correlation.

In [18]:
kernel = np.array([
    [1, 0, -1], 
    [1, 0, -1],
    [1, 0, -1],
])
kernel_weights = kernel.ravel()
kernel_weights

array([ 1,  0, -1,  1,  0, -1,  1,  0, -1])

First we will create the "window" index of the kernel.  
This is the index that will be slid over the input image to compute the cross correlation.

In [30]:
window_index = np.arange(kernel.shape[0])
window_index = np.tile(window_index, kernel.shape[1])
window_index += np.repeat(np.arange(kernel.shape[0]) * INPUT_SHAPE[0], kernel.shape[1])
window_index

array([ 0,  1,  2, 28, 29, 30, 56, 57, 58])

In [49]:
nb_horizontal_correlations = 1 + INPUT_SHAPE[0] - kernel.shape[0]
nb_vertical_correlations = 1 + INPUT_SHAPE[0] - kernel.shape[0]
nb_correlations = nb_horizontal_correlations * nb_vertical_correlations
nb_correlations

676

In [66]:
# Here we repeat the window index for each correlation and then add the horizontal and vertical offsets and store the result into a matrix.
# Each row corresponds to a correlation and each column corresponds to a window index.
correlation_indices = np.tile(window_index, (nb_correlations, 1))
# The offsets are computed the same way as the window index.
# We reshape the offsets by giving them an extra dimension of size 1 to brodcast the addition over the window indices.
horizontal_offsets = np.tile(np.arange(nb_horizontal_correlations), nb_vertical_correlations).reshape(-1, 1)
vertical_offsets = np.repeat(np.arange(nb_vertical_correlations) * INPUT_SHAPE[0], nb_vertical_correlations).reshape(-1, 1)
correlation_indices += horizontal_offsets + vertical_offsets

In [101]:
scipy_correlation = correlate2d(input.reshape(28, 28), kernel, mode="valid")
custom_correlation = (input[correlation_indices] @ kernel_weights).reshape(nb_horizontal_correlations, nb_vertical_correlations)

fig = px.imshow(np.asarray([scipy_correlation, custom_correlation]), facet_col=0)
fig.layout.annotations[0]['text'] = "Scipy Correlation"
fig.layout.annotations[1]['text'] = "Custom Correlation"
fig.show()