# Cross correlation and convolution implementation from scratch

## Motivation
For our convolution layers, we need to perform:
- a valid cross correlation for the forward pass  
- a full convolution for the backward pass  

I initially meant to use `scipy.signal`'s `correlate2d` and `convolve2d`, but unfortunately they only work on a single channel, kernel and sample.  
This is inconvenient, as I need an operation that handles multiple samples, kernels and channels.  
I could wrap them in three nested `for` loops, but that would be a very slow implementation.  
This is a "from scratch" repo after all, so let's implement a multi-channel, multi-kernel, multi-sample valid cross correlation from scratch.  
I will then implement the full convolution as a padding of the input, a flip of the kernels along both spatial axes (what this notebook calls transposing them), and a valid cross correlation.   

## High level implementation  
The valid cross correlation is essentially a dot product between the flattened kernel and a flattened subset of the input (a view), repeated along the x and y axes.  
It reduces to a repeated dot product because the view and the kernel have fixed matching indices:  
the ith element of the kernel always multiplies the ith element of the view.  
Once we have the flattened input and the window index (the flat indices of the first view), we get a single view by indexing the flattened input with the window index.  
The dot product of this view with the kernel gives a single correlation (a single value of the output activation map).  
To get the next view to the right, we add the input depth to every index of the window index.  
To get the next view below, we add the input depth times the horizontal length of the input to every index of the window index.  
To get all the views, we build a matrix where each row corresponds to a view.  
The kernels go in a second matrix where each column corresponds to a kernel.   
The cross correlation is then a single matrix product of the two.  
That covers a single input. For a batch of inputs, we take the broadcasted matrix product of a 3D array of views with the kernel matrix.  
In this 3D array, the first dimension indexes the input, the second the view and the third the value within the view.  
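
The steps above can be sketched on a toy example. This is only a minimal illustration, assuming a single-channel 4x4 input and a 2x2 kernel; every name prefixed with `toy_` is made up for this sketch and is not used elsewhere in the notebook.

```python
import numpy as np

# Toy single-channel input (4x4) and kernel (2x2); both are illustrative only.
toy_input = np.arange(16, dtype=float).reshape(4, 4)
toy_kernel = np.array([[1.0, 0.0], [0.0, -1.0]])

in_h, in_w = toy_input.shape
k_h, k_w = toy_kernel.shape
out_h, out_w = in_h - k_h + 1, in_w - k_w + 1

# Flat indices of the top-left view: row offsets plus column offsets.
toy_window_index = (np.arange(k_h)[:, None] * in_w + np.arange(k_w)[None, :]).ravel()

# Flat index of the top-left corner of every view, in row-major order.
toy_corners = (np.arange(out_h)[:, None] * in_w + np.arange(out_w)[None, :]).ravel()

# One row per view, one column per kernel element.
toy_view_indices = toy_corners[:, None] + toy_window_index[None, :]
toy_views = toy_input.ravel()[toy_view_indices]

# A single matrix/vector product gives every value of the activation map.
toy_output = (toy_views @ toy_kernel.ravel()).reshape(out_h, out_w)

# Reference: the same valid cross correlation written with explicit loops.
toy_reference = np.array([
    [np.sum(toy_input[i:i + k_h, j:j + k_w] * toy_kernel) for j in range(out_w)]
    for i in range(out_h)
])
```

With a single channel the input depth is 1, so moving one view to the right adds 1 to the indices and moving one view down adds the input width.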

## Setup

### Imports

In [1]:
from os.path import join

from scipy.signal import correlate2d, convolve2d
import numpy as np
import plotly.express as px
import kagglehub

from cifar_10_dataset_loading import load_cifar_10

### Data extraction

#### MNIST dataset for single channel implementations

In [2]:
dataset_path = kagglehub.dataset_download("hojjatk/mnist-dataset")
train_image_path = join(dataset_path, 'train-images.idx3-ubyte')
train_labels_path = join(dataset_path, 'train-labels.idx1-ubyte')
test_image_path = join(dataset_path, 't10k-images.idx3-ubyte')
test_labels_path = join(dataset_path, 't10k-labels.idx1-ubyte')

def load_images(path) -> np.ndarray:
    with open(path, 'rb') as f:
        return (
            np.frombuffer(f.read(), dtype=np.uint8)
            [16:]
            .reshape(-1, 28**2)
            / 255
        )

train_dataset = load_images(train_image_path)



In [3]:
px.imshow(train_dataset[0].reshape(28, 28))

In [4]:
INPUT_SHAPE = (28, 28)
input = train_dataset[0]

#### Cifar 10 for multi channels implementations

In [5]:
cifar_10_train_inputs, _, _, _ = load_cifar_10()

## Single kernel, input image and channel cross correlation implementation

First we declare a kernel. We store it as a vector so that a vector/matrix dot product computes all the cross correlations at once.

In [6]:
kernel = np.array([
    [1, 0, -1], 
    [1, 0, -1],
    [1, 0, -1],
])
kernel_weights = kernel.ravel()
kernel_weights

array([ 1,  0, -1,  1,  0, -1,  1,  0, -1])

Next we create the window index of the kernel.  
This is the index that will be slid over the input image to compute the cross correlation.

In [7]:
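# Column offsets within one kernel row, repeated for every kernel row.
window_index = np.tile(np.arange(kernel.shape[1]), kernel.shape[0])
# Moving down one row in the flattened input means skipping one full input row (INPUT_SHAPE[1] pixels).
window_index += np.repeat(np.arange(kernel.shape[0]) * INPUT_SHAPE[1], kernel.shape[1])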
window_index

array([ 0,  1,  2, 28, 29, 30, 56, 57, 58])
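
One way to convince oneself that this window index is right is to compare the view it selects with an ordinary 2D slice. The following is a minimal sketch that uses a random array in place of the MNIST digit; `fake_image` and the other names are placeholders introduced for the check.

```python
import numpy as np

rng = np.random.default_rng(0)
fake_image = rng.random((28, 28))  # stand-in for one 28x28 digit
flat_fake_image = fake_image.ravel()

# Same values as the window index printed above.
check_window_index = np.array([0, 1, 2, 28, 29, 30, 56, 57, 58])

# Indexing the flat image with the window index gives the top-left 3x3 patch.
top_left_view = flat_fake_image[check_window_index].reshape(3, 3)

# Adding 1 moves the view one pixel to the right, adding 28 moves it one row down.
right_view = flat_fake_image[check_window_index + 1].reshape(3, 3)
lower_view = flat_fake_image[check_window_index + 28].reshape(3, 3)
```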

In [8]:
nb_horizontal_correlations = 1 + INPUT_SHAPE[1] - kernel.shape[1]
nb_vertical_correlations = 1 + INPUT_SHAPE[0] - kernel.shape[0]
nb_correlations = nb_horizontal_correlations * nb_vertical_correlations
nb_correlations

676

In [9]:
# Repeat the window index for each correlation, then add the horizontal and vertical offsets and store the result in a matrix.
# Each row corresponds to a correlation and each column to an element of the window index.
# The offsets are computed the same way as the window index.
# Each offset array gets an extra dimension of size 1 so that the addition broadcasts over the window indices.
horizontal_offsets = np.tile(np.arange(nb_horizontal_correlations), nb_vertical_correlations).reshape(-1, 1)
vertical_offsets = np.repeat(np.arange(nb_vertical_correlations) * INPUT_SHAPE[1], nb_horizontal_correlations).reshape(-1, 1)
correlation_indices = np.tile(window_index, (nb_correlations, 1)) + horizontal_offsets + vertical_offsets
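
As a cross-check of the index matrix, NumPy ships `numpy.lib.stride_tricks.sliding_window_view` (available from NumPy 1.20), which produces the same views without any explicit indices. The sketch below is not the approach used in this notebook; it rebuilds the indices on a random 28x28 array (a stand-in for the digit) and compares both results.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

rng = np.random.default_rng(0)
fake_image = rng.random((28, 28))  # stand-in for one 28x28 digit
k = 3                              # kernel side length
n_rows = n_cols = 28 - k + 1       # 26 window positions per axis

# Index-based construction, mirroring the cell above.
window = (np.arange(k)[:, None] * 28 + np.arange(k)[None, :]).ravel()
col_offsets = np.tile(np.arange(n_cols), n_rows).reshape(-1, 1)
row_offsets = np.repeat(np.arange(n_rows) * 28, n_cols).reshape(-1, 1)
index_matrix = window + col_offsets + row_offsets  # shape (676, 9)

views_from_indices = fake_image.ravel()[index_matrix]

# sliding_window_view returns shape (26, 26, 3, 3); flatten it to one row per view.
views_from_numpy = sliding_window_view(fake_image, (k, k)).reshape(-1, k * k)
```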

### Result visualization and comparison with `scipy.signal`'s `correlate2d`

In [10]:
scipy_correlation = correlate2d(input.reshape(28, 28), kernel, mode="valid")
custom_correlation = (input[correlation_indices] @ kernel_weights).reshape(nb_vertical_correlations, nb_horizontal_correlations)

fig = px.imshow(np.asarray([scipy_correlation, custom_correlation]), facet_col=0)
fig.layout.annotations[0]['text'] = "Scipy Correlation"
fig.layout.annotations[1]['text'] = "Custom Correlation"
fig.show()

## Multiple kernels, single input image and channel cross correlation implementation  
The only difference with the previous implementation is that we perform a matrix/matrix dot product instead of a matrix/vector.   
The kernel array has been upgraded to a matrix where each column corresponds to a kernel.   

In [11]:
kernels = np.array([
    # First kernel
    [
        [1, 0, -1], 
        [1, 0, -1],
        [1, 0, -1],
    ],
    # Second kernel
    [
        [1, 0, 1], 
        [0, 0, 0],
        [-1, 0, -1],
    ],
])
kernels_weights = kernels.reshape(2, -1).T
kernels_weights

array([[ 1,  1],
       [ 0,  0],
       [-1,  1],
       [ 1,  0],
       [ 0,  0],
       [-1,  0],
       [ 1, -1],
       [ 0,  0],
       [-1, -1]])

In [12]:
# Here we can reuse the correlation indices from the previous implementation, which is pretty convenient
flatt_cross_correlation = input[correlation_indices] @ kernels_weights
flatt_cross_correlation.shape

(676, 2)
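
To make the "one column per kernel" layout concrete, the sketch below checks that each column of the matrix/matrix product equals the matrix/vector product for the corresponding kernel. It assumes random values in place of the real views; `fake_views`, `kernel_a` and `kernel_b` are illustrative names.

```python
import numpy as np

rng = np.random.default_rng(0)
fake_views = rng.random((676, 9))  # stand-in for input[correlation_indices]

kernel_a = np.array([[1, 0, -1], [1, 0, -1], [1, 0, -1]])
kernel_b = np.array([[1, 0, 1], [0, 0, 0], [-1, 0, -1]])

# Flatten each kernel into a row, then transpose: one column per kernel.
weights = np.stack([kernel_a, kernel_b]).reshape(2, -1).T  # shape (9, 2)

both_kernels = fake_views @ weights        # shape (676, 2)
only_a = fake_views @ kernel_a.ravel()     # shape (676,)
only_b = fake_views @ kernel_b.ravel()     # shape (676,)
```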

### Results visualization and comparison with `scipy.signal`'s `correlate2d`

In [13]:
custom_correlations = (input[correlation_indices] @ kernels_weights).reshape(nb_vertical_correlations, nb_horizontal_correlations, kernels.shape[0])
scipy_correlations = np.stack([correlate2d(input.reshape(28, 28), kernels[k], mode="valid") for k in range(kernels.shape[0])], axis=2)

display(np.concatenate([custom_correlations, scipy_correlations], axis=2).shape)

fig = px.imshow(
    np.concatenate([custom_correlations, scipy_correlations], axis=2),
    facet_col=2,
    facet_col_wrap=2,
)
fig.layout.annotations[0]['text'] = "Custom Correlation 0"
fig.layout.annotations[1]['text'] = "Custom Correlation 1"
fig.layout.annotations[2]['text'] = "Scipy Correlation 0"
fig.layout.annotations[3]['text'] = "Scipy Correlation 1"
fig.show()

(26, 26, 4)

## Multiple kernels and channels, single input image, cross correlation implementation
Here we need to update the correlation indices because of the new depth dimension.  

In [14]:
multi_channels_kernels = np.asarray(
[
    [
        # First kernel
        # First channel
        [
            [1, 0, -1, 0, 1], 
            [1, 0, -1, 0, 1],
            [1, 0, -1, 0, 1],
            [1, 0, -1, 0, 1],
            [1, 0, -1, 0, 1],
        ],
        # Second channel
        [
            [1, 0, 1, 1, 1], 
            [1, 0, 1, 1, 1], 
            [1, 0, 1, 1, 1],
            [1, 0, 1, 1, 1],
            [1, 0, 1, 1, 1],
        ],
        # Third channel
        [
            [0, 0, 0, 0, 1],
            [0, 0, 0, 1, 0],
            [0, 0, 1, 0, 0],
            [0, 1, 0, 0, 0],
            [1, 0, 0, 0, 0],
        ],
    ],
    [
        # Second kernel
        # First channel
        [
            [0.5, 0.5, 0.5, 0.5, 0.5], 
            [0.5, 0.5, 0.5, 0.5, 0.5], 
            [0.5, 0.5, 0.5, 0.5, 0.5], 
            [0.5, 0.5, 0.5, 0.5, 0.5],
            [0.5, 0.5, 0.5, 0.5, 0.5],
        ],
        # Second channel
        [
            [1, 2, 3, 4, 5],
            [1, 1, 1, 0, 0],
            [0, 1, 1, 1, 0],
            [0, 0, 1, 1, 1],
            [0, 1, 0, 1, 1],
        ],
        # Third channel
        [
            [0, 0, 1, 0, 0],
            [0, 0.5, 1, 0.5, 0],
            [1, 1, 1, 1, 1],
            [0, 0.5, 1, 0.5, 0],
            [0, 0, 1, 0, 0],
        ],
    ]
])
# Above, the kernels are declared with the shape (nb_kernels, nb_channels, width, height), but the inputs usually have the shape (nb_inputs, width, height, nb_channels).
# To account for this, we swap the channel axis (1) with the last axis (3) of the kernels. Note that this also swaps the two spatial axes.
multi_channels_kernels = multi_channels_kernels.swapaxes(1, 3)
multi_channels_kernels.shape

(2, 5, 5, 3)

In [15]:
# There is most likely a simpler way of doing this...
multi_channels_kernels_weights = multi_channels_kernels.reshape(2, -1).T
multi_channels_kernels_weights.shape

(75, 2)
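
One possible alternative to flattening the kernels by hand is `np.einsum`, which contracts the spatial and channel axes directly. The sketch below only checks that both routes agree on random data; `fake_kernels` and `fake_views` are placeholders with the same shapes as in this section, and this is not necessarily simpler or faster than the reshape.

```python
import numpy as np

rng = np.random.default_rng(0)
fake_kernels = rng.random((2, 5, 5, 3))   # (nb_kernels, 5, 5, nb_channels)
fake_views = rng.random((784, 5, 5, 3))   # one 5x5x3 view per output position

# Route 1: flatten each kernel into a column, flatten each view into a row.
weights = fake_kernels.reshape(fake_kernels.shape[0], -1).T  # shape (75, 2)
flat_result = fake_views.reshape(fake_views.shape[0], -1) @ weights

# Route 2: contract the two spatial axes and the channel axis with einsum.
einsum_result = np.einsum("pijc,kijc->pk", fake_views, fake_kernels)
```

Both routes rely on the views and the kernels sharing the same axis order, which is why the kernels are rearranged to a channels-last layout before being flattened.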

For testing I have chosen this mysterious frog with red and purple eyes (whoever decided to add frogs to the dataset is a genius).  

In [16]:
display(cifar_10_train_inputs.shape)
multi_channels_input = cifar_10_train_inputs[0]
flatten_multi_channels_input = multi_channels_input.ravel()
px.imshow(multi_channels_input)

(50000, 32, 32, 3)

In [19]:
window_idx = np.arange(multi_channels_kernels.shape[3])
window_idx = np.tile(window_idx, multi_channels_kernels.shape[2])
window_idx += np.repeat(np.arange(multi_channels_kernels.shape[2]) * cifar_10_train_inputs.shape[3], multi_channels_kernels.shape[3])
window_idx = np.tile(window_idx, multi_channels_kernels.shape[1])
window_idx += np.repeat(np.arange(multi_channels_kernels.shape[1]) * cifar_10_train_inputs.shape[2] * cifar_10_train_inputs.shape[3], multi_channels_kernels.shape[2] * multi_channels_kernels.shape[3])
display(window_idx.reshape(5, 5, 3))

array([[[  0,   1,   2],
        [  3,   4,   5],
        [  6,   7,   8],
        [  9,  10,  11],
        [ 12,  13,  14]],

       [[ 96,  97,  98],
        [ 99, 100, 101],
        [102, 103, 104],
        [105, 106, 107],
        [108, 109, 110]],

       [[192, 193, 194],
        [195, 196, 197],
        [198, 199, 200],
        [201, 202, 203],
        [204, 205, 206]],

       [[288, 289, 290],
        [291, 292, 293],
        [294, 295, 296],
        [297, 298, 299],
        [300, 301, 302]],

       [[384, 385, 386],
        [387, 388, 389],
        [390, 391, 392],
        [393, 394, 395],
        [396, 397, 398]]])
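
The same window index can also be obtained with `np.ravel_multi_index`, which converts (row, column, channel) coordinates into flat indices for a given array shape. This is a minimal sketch assuming a 32x32x3 input and a 5x5x3 kernel, as in this section; `alt_window_idx` is a name introduced for the comparison.

```python
import numpy as np

input_shape = (32, 32, 3)   # (rows, columns, channels) of one CIFAR-10 image
kernel_shape = (5, 5, 3)    # (rows, columns, channels) of one kernel

# np.indices gives the (row, column, channel) coordinates of every kernel element.
kernel_coords = tuple(np.indices(kernel_shape))

# Convert those coordinates into flat indices of the flattened input.
alt_window_idx = np.ravel_multi_index(kernel_coords, input_shape)  # shape (5, 5, 3)
```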

### Viewing the purple eye by adding offsets to the window indices
By hovering on the right eye of the frog in the above plot we can see that the eye is roughly at x:20 and y:5.  
Below I demonstrate that by adding the correct offsets we can move the view to that eye.  

In [38]:
x_offset = cifar_10_train_inputs.shape[3] * 20
y_offset = cifar_10_train_inputs.shape[3] * cifar_10_train_inputs.shape[2] * 5
px.imshow(flatten_multi_channels_input[window_idx + x_offset + y_offset].reshape(5, 5, 3))
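
The offsets used above can be checked without a plot: the shifted view should equal a plain slice of the image starting at row 5 and column 20. The sketch below assumes a random 32x32x3 array instead of the frog; `fake_frog` and `eye_window` are hypothetical names.

```python
import numpy as np

rng = np.random.default_rng(0)
fake_frog = rng.random((32, 32, 3))  # stand-in for one CIFAR-10 image

# Flat indices of the top-left 5x5x3 view.
eye_window = np.ravel_multi_index(tuple(np.indices((5, 5, 3))), fake_frog.shape).ravel()

# Same offsets as in the cell above: 20 pixels to the right, 5 rows down.
x_off = fake_frog.shape[2] * 20
y_off = fake_frog.shape[2] * fake_frog.shape[1] * 5

shifted_view = fake_frog.ravel()[eye_window + x_off + y_off].reshape(5, 5, 3)

# The combined offset is the flat index of pixel (row 5, column 20, channel 0).
corner_flat_index = np.ravel_multi_index((5, 20, 0), fake_frog.shape)
```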

In [39]:
def compute_correlation_inidces(input_shape:tuple, kernel_shape:tuple) -> tuple:
    # Flat indices of the first view, built from the channel axis up to the first spatial axis.
    window_idx = np.arange(kernel_shape[2])
    window_idx = np.tile(window_idx, kernel_shape[1])
    window_idx += np.repeat(np.arange(kernel_shape[1]) * input_shape[2], kernel_shape[2])
    window_idx = np.tile(window_idx, kernel_shape[0])
    window_idx += np.repeat(np.arange(kernel_shape[0]) * input_shape[1] * input_shape[2], kernel_shape[1] * kernel_shape[2])

    # nb_x_correlations counts the window positions along the first spatial axis, nb_y_correlations along the second.
    nb_x_correlations = 1 + input_shape[0] - kernel_shape[0]
    nb_y_correlations = 1 + input_shape[1] - kernel_shape[1]
    total_nb_correlations = nb_x_correlations * nb_y_correlations

    # One step along the second axis skips one pixel (input_shape[2] values);
    # one step along the first axis skips a full line of pixels.
    second_axis_stride = input_shape[2]
    first_axis_stride = input_shape[2] * input_shape[1]
    second_axis_offsets = np.tile(np.arange(nb_y_correlations) * second_axis_stride, nb_x_correlations).reshape(-1, 1)
    first_axis_offsets = np.repeat(np.arange(nb_x_correlations) * first_axis_stride, nb_y_correlations).reshape(-1, 1)
    correlation_indieces = np.tile(window_idx, (total_nb_correlations, 1)) + second_axis_offsets + first_axis_offsets

    return nb_x_correlations, nb_y_correlations, correlation_indieces

nb_x_correlations, nb_y_correlations, multi_channels_correlatin_indieces = compute_correlation_inidces(multi_channels_input.shape, multi_channels_kernels.shape[1:])

In [40]:
multi_channels_correlatin_indieces.reshape(-1, 5, 5, 3)[:4, ...].shape

(4, 5, 5, 3)

### Visualizing the window "moving" over the first two rows of correlations of the image
Feel free to compare them to the full image plot.

In [41]:
px.imshow(
    flatten_multi_channels_input[multi_channels_correlatin_indieces].reshape(-1, 5, 5, 3)[:nb_y_correlations * 2, ...],
    facet_col=0,
    facet_col_wrap=28,
    height=1000,
    facet_col_spacing=0.,
)

### Computing the correlation

In [42]:
multi_channels_flatten_correlation = flatten_multi_channels_input[multi_channels_correlatin_indieces] @ multi_channels_kernels_weights
multi_channels_flatten_correlation.shape
reshaped_mutli_channels_correlation = (
    multi_channels_flatten_correlation
    .reshape(
        nb_x_correlations,
        nb_y_correlations,
        multi_channels_kernels.shape[0],
    )
)

### Results visualization and comparison with `scipy.signal`'s `correlate2d`

In [43]:
def cross_correlate_multi_channel_2d(multi_channels_input:np.ndarray, multi_channels_kernels:np.ndarray) -> np.ndarray:
    activation_maps = []
    for kernel in multi_channels_kernels.swapaxes(3, 1):
        activation_channel_maps = [correlate2d(input_channel, kernel_channel, mode="valid") for input_channel, kernel_channel in zip(multi_channels_input.swapaxes(2, 0), kernel)]
        activation_channel_maps = np.stack(activation_channel_maps)
        multi_channel_scipy_cross_correlation = np.sum(activation_channel_maps, axis=0)
        activation_maps.append(multi_channel_scipy_cross_correlation.T)
    return np.stack(activation_maps)
activation_maps = cross_correlate_multi_channel_2d(multi_channels_input, multi_channels_kernels)
display(activation_maps.shape)
display("scipy")
px.imshow(activation_maps, facet_col=0).show()
display("custom")
px.imshow(
    reshaped_mutli_channels_correlation,
    facet_col=2
)

(2, 28, 28)

'scipy'

'custom'
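
The visual comparison can be complemented by a numeric one. The sketch below is a compact, self-contained variant of the index-based approach, checked against explicit loops on a deliberately non-square input. It assumes kernels laid out as (nb_kernels, rows, columns, channels) with no axis swapping; the function name and the `rect_` arrays are made up for this check.

```python
import numpy as np

def valid_cross_correlation_sketch(image, kernels):
    """Valid cross correlation of one channels-last image with several kernels.

    image: (height, width, channels); kernels: (nb_kernels, k_height, k_width, channels).
    Returns an array of shape (height - k_height + 1, width - k_width + 1, nb_kernels).
    """
    in_h, in_w, in_c = image.shape
    nb_k, k_h, k_w, _ = kernels.shape
    out_h, out_w = in_h - k_h + 1, in_w - k_w + 1

    # Flat indices of the top-left view, in (row, column, channel) order.
    window = np.ravel_multi_index(tuple(np.indices((k_h, k_w, in_c))), image.shape).ravel()
    # Flat index of the top-left corner of every view, row by row.
    corners = (np.arange(out_h)[:, None] * in_w * in_c + np.arange(out_w)[None, :] * in_c).ravel()

    views = image.ravel()[corners[:, None] + window[None, :]]
    return (views @ kernels.reshape(nb_k, -1).T).reshape(out_h, out_w, nb_k)

rng = np.random.default_rng(0)
rect_image = rng.random((7, 6, 3))        # non-square on purpose
rect_kernels = rng.random((2, 3, 2, 3))   # two 3x2 kernels with 3 channels
sketch_output = valid_cross_correlation_sketch(rect_image, rect_kernels)

# Reference: explicit loops over output positions and kernels.
naive_output = np.empty_like(sketch_output)
for i in range(sketch_output.shape[0]):
    for j in range(sketch_output.shape[1]):
        for k in range(rect_kernels.shape[0]):
            naive_output[i, j, k] = np.sum(rect_image[i:i + 3, j:j + 2, :] * rect_kernels[k])
```

Testing on a non-square input and kernel helps catch swapped axes that a square example would hide.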

## Multi kernels, channels, inputs cross correlation
Here we simply need to stack the flattened inputs into a 2D array (one row per input), so that indexing it yields a 3D array of views.

Let's add another frog to the input.  
Whoever added frogs to the dataset is a genius.  

In [44]:
inputs = cifar_10_train_inputs[[0, 351]]
px.imshow(inputs, facet_col=0)

In [45]:
flatten_inputs = inputs.reshape(2, -1)
multi_inputs_views = flatten_inputs[:, multi_channels_correlatin_indieces]
multi_inputs_views.shape

(2, 784, 75)

In [46]:
cross_correlations = (multi_inputs_views @ multi_channels_kernels_weights).reshape(2, 28, 28, 2)
cross_correlations.shape

(2, 28, 28, 2)
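
The batched product relies on two NumPy behaviours: indexing a 2D array with `[:, index_matrix]` yields one block of views per input, and `@` broadcasts a 2D matrix over the leading axis of a 3D array. The sketch below illustrates both on random data with the shapes used in this section; all `fake_` names are placeholders.

```python
import numpy as np

rng = np.random.default_rng(0)
fake_flat_inputs = rng.random((2, 32 * 32 * 3))            # two flattened images
fake_indices = rng.integers(0, 32 * 32 * 3, size=(784, 75))  # stand-in for the index matrix
fake_weights = rng.random((75, 2))                          # one column per kernel

# Fancy indexing along the second axis: (2, 3072) -> (2, 784, 75).
fake_views = fake_flat_inputs[:, fake_indices]

# Broadcasted matrix product: (2, 784, 75) @ (75, 2) -> (2, 784, 2).
batched = fake_views @ fake_weights

# Reference: the same product computed one input at a time.
per_sample = np.stack([sample_views @ fake_weights for sample_views in fake_views])
```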

In [47]:
display("custom")
px.imshow(cross_correlations[1], facet_col=2).show()
display("scipy")
scipy_cross_correlations = np.stack([cross_correlate_multi_channel_2d(input, multi_channels_kernels) for input in inputs])
px.imshow(scipy_cross_correlations[1], facet_col=0)

'custom'

'scipy'

## Multi kernels, channels and inputs full convolution
Here we just need to pad the input and flip the kernels along both spatial axes; we can then perform a valid cross correlation between the padded input and the flipped kernels.
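
The equivalence can be checked on a small single-channel example. The sketch below compares a loop-based full convolution with the pad, flip and valid cross correlation route. It uses `sliding_window_view` (NumPy 1.20 or later) for brevity rather than the index matrices of this notebook, and the array names are illustrative.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

rng = np.random.default_rng(0)
small_input = rng.random((4, 5))
small_kernel = rng.random((3, 2))
k_h, k_w = small_kernel.shape

# Reference full convolution: every input pixel adds a scaled copy of the kernel to the output.
reference = np.zeros((small_input.shape[0] + k_h - 1, small_input.shape[1] + k_w - 1))
for m in range(small_input.shape[0]):
    for n in range(small_input.shape[1]):
        reference[m:m + k_h, n:n + k_w] += small_input[m, n] * small_kernel

# Pad by (kernel size - 1) on every side, flip the kernel along both axes,
# then take a valid cross correlation of the padded input with the flipped kernel.
padded = np.pad(small_input, ((k_h - 1, k_h - 1), (k_w - 1, k_w - 1)), mode="constant")
flipped = np.flip(small_kernel, axis=(0, 1))
views = sliding_window_view(padded, (k_h, k_w))            # shape (6, 6, 3, 2)
pad_flip_correlate = np.einsum("ijab,ab->ij", views, flipped)
```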

In [48]:
pad_width = (
    (0, 0), # We don't want to pad along the "sample dimension" as this would add samples to the batch
    (multi_channels_kernels.shape[1] - 1, multi_channels_kernels.shape[1] - 1), # pad along the x axis
    (multi_channels_kernels.shape[2] - 1, multi_channels_kernels.shape[2] - 1), # pad along the y axis
    (0, 0), # Don't pad along the channel dimension
)
padded_inputs = np.pad(inputs, pad_width, mode="constant")
display(padded_inputs.shape)
px.imshow(padded_inputs, facet_col=0)

(2, 40, 40, 3)

In [50]:
padded_inputs.shape

(2, 40, 40, 3)

In [51]:
# `compute_correlation_inidces` is the function defined earlier; it is reused unchanged on the padded inputs.
# Flipping the kernels does not change their shape, so the kernel shape is read from `multi_channels_kernels`.
nb_x_correlations, nb_y_correlations, convolution_indices = compute_correlation_inidces(padded_inputs.shape[1:], multi_channels_kernels.shape[1:])

In [52]:
flatten_padded_inputs = padded_inputs.reshape(2, -1)
padded_inputs_views = flatten_padded_inputs[:, convolution_indices]

(2, 5, 5, 3)

In [63]:
line_idx = 8
px.imshow(
    flatten_padded_inputs[0, convolution_indices].reshape(-1, multi_channels_kernels.shape[1], multi_channels_kernels.shape[2], 3)[nb_x_correlations * line_idx:nb_x_correlations * (line_idx + 2)],
    facet_col=0,
    facet_col_wrap=nb_x_correlations
)

In [54]:
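def scipy_convolution(inputs:np.ndarray, kernels:np.ndarray) -> np.ndarray:
    activation_maps = []
    # Move the channel axis to position 1 so that iterating over a kernel or a sample yields one 2D array per channel.
    kernels = kernels.swapaxes(1, 3)
    inputs = inputs.swapaxes(1, 3)
    for kernel in kernels:
        for sample in inputs:
            # Convolve each input channel with the matching kernel channel, then sum over the channels.
            channel_activation_maps = [convolve2d(input_channel, kernel_channel, mode="full") for input_channel, kernel_channel in zip(sample, kernel)]
            activation_map = np.sum(channel_activation_maps, axis=0).T
            activation_maps.append(activation_map)
    # The maps are ordered kernel first, then sample: index = kernel_index * nb_samples + sample_index.
    return np.stack(activation_maps)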

scipy_convolutions = scipy_convolution(inputs, multi_channels_kernels)
px.imshow(
    scipy_convolutions[[1, 3]],
    facet_col=0,
    facet_col_wrap=2,
)

In [67]:
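# Flip each kernel along both spatial axes (a 180 degree rotation), then flatten it into one column per kernel.
transposed_kernels_weights = np.flip(multi_channels_kernels, axis=(1, 2)).reshape(2, -1).T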
padded_views = flatten_padded_inputs[:, convolution_indices]
flatten_convolutions = padded_views @ transposed_kernels_weights
convolutions = flatten_convolutions.reshape(2, nb_x_correlations, nb_y_correlations, 2)

px.imshow(convolutions[1], facet_col=2)