In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pickle
import os

In [None]:
# Source: https://github.com/snatch59/load-cifar-10/blob/master/load_cifar_10_alt.py

def load_batch(file_path, label_key='labels'):
    """Load one CIFAR batch file and return its images and labels.

    The batch is a pickled dict whose ``data`` entry holds one flat row per
    image. Each row is reshaped to three 32x32 planes (channels-first) and then
    transposed to channels-last, giving an array of shape ``(N, 32, 32, 3)``.

    Args:
        file_path: Path to a pickled CIFAR batch file (e.g. ``data_batch_1``).
        label_key: Key of the dict entry that holds the labels
            (``'labels'`` by default).

    Returns:
        ``(data, labels)``: the reshaped image array and the labels exactly as
        stored under ``label_key`` in the batch (not converted).

    Raises:
        KeyError: if the batch has no ``'data'`` entry or no ``label_key`` entry.
    """
    # SECURITY: pickle.load can execute arbitrary code; only load batch files
    # obtained from a trusted source.
    with open(file_path, 'rb') as f:
        # encoding='bytes' makes Python 2 `str` objects load as `bytes`,
        # which is why the keys may need decoding below.
        raw = pickle.load(f, encoding='bytes')

    # Normalise keys to str. Files re-pickled under Python 3 already have str
    # keys, and calling .decode on those would raise AttributeError.
    d = {
        (k.decode('utf8') if isinstance(k, bytes) else k): v
        for k, v in raw.items()
    }
    data = d['data']
    labels = d[label_key]

    # (N, 3*32*32) -> (N, 3, 32, 32) channels-first -> (N, 32, 32, 3) channels-last.
    data = data.reshape(data.shape[0], 3, 32, 32).transpose(0, 2, 3, 1)
    return data, labels


def load_data(path):
    """Load the CIFAR-10 training and test splits from a batch directory.

    Reads the five training batches ``data_batch_1`` ... ``data_batch_5``
    (10000 images each) into preallocated arrays, then reads ``test_batch``.
    Every file is read with ``load_batch``.

    Args:
        path: Directory containing the CIFAR-10 batch files.

    Returns:
        ``(x_train, y_train), (x_test, y_test)``:
        ``x_train`` is a ``uint8`` array of shape ``(50000, 32, 32, 3)``.
        ``x_test`` is the image array returned by ``load_batch`` for ``test_batch``.
        Both label arrays are ``uint8`` column vectors of shape ``(N, 1)``.

    Raises:
        FileNotFoundError: if any batch file is missing.
        ValueError: if a training batch does not contain exactly
            ``batch_size`` images (slice assignment cannot broadcast).
    """
    num_batches = 5
    batch_size = 10000
    num_train_samples = num_batches * batch_size

    x_train_local = np.empty((num_train_samples, 32, 32, 3), dtype='uint8')
    y_train_local = np.empty((num_train_samples,), dtype='uint8')

    for i in range(num_batches):
        start, stop = i * batch_size, (i + 1) * batch_size
        # Batch files are numbered from 1.
        batch_file_path = os.path.join(path, 'data_batch_' + str(i + 1))
        x_train_local[start:stop], y_train_local[start:stop] = load_batch(batch_file_path)

    x_test_local, y_test_local = load_batch(os.path.join(path, 'test_batch'))

    # Labels as (N, 1) column vectors. Cast the test labels to uint8 as well, so
    # both splits share the dtype of the preallocated training labels.
    y_train_local = y_train_local.reshape(-1, 1)
    y_test_local = np.asarray(y_test_local, dtype='uint8').reshape(-1, 1)

    return (x_train_local, y_train_local), (x_test_local, y_test_local)

In [None]:
# Extracted CIFAR-10 batch directory, relative to the notebook's working directory.
path = 'cifar-10-batches-py'
if not os.path.isdir(path):
    raise FileNotFoundError(
        "CIFAR-10 directory {!r} not found; download and extract the "
        "'CIFAR-10 python version' archive from "
        "https://www.cs.toronto.edu/~kriz/cifar.html".format(path)
    )
(x_train, y_train), (x_test, y_test) = load_data(path)

# Sanity checks: every image has exactly one label in both splits.
assert x_train.shape[0] == y_train.shape[0]
assert x_test.shape[0] == y_test.shape[0]

In [None]:
def _select_class(x, y, class_id):
    """Return the rows of ``x`` and ``y`` whose label equals ``class_id``."""
    # y is typically an (N, 1) column; flatten the mask to (N,) so it indexes rows.
    mask = (np.asarray(y) == class_id).reshape(x.shape[0])
    return x[mask], y[mask]


def filter_class(class_id, train=None, test=None):
    """Keep only the samples of a single class in the train and test splits.

    Args:
        class_id: Label value to keep.
        train: Optional ``(x, y)`` pair for the training split. Defaults to the
            notebook globals ``(x_train, y_train)``.
        test: Optional ``(x, y)`` pair for the test split. Defaults to the
            notebook globals ``(x_test, y_test)``.

    Returns:
        ``(x_train_filter, y_train_filter), (x_test_filter, y_test_filter)``.
        The label arrays keep their trailing dimensions, e.g. ``(k, 1)`` for
        ``(N, 1)`` input.
    """
    # Fall back to notebook-level data only when the caller passes nothing explicit.
    if train is None:
        train = (x_train, y_train)
    if test is None:
        test = (x_test, y_test)
    return _select_class(*train, class_id), _select_class(*test, class_id)

In [None]:
def rgb2gray(rgb):
    """Collapse the colour axis of an RGB(A) array into a single grayscale value.

    Only the first three entries of the last axis are used, so any alpha channel
    is ignored. They are combined as a weighted sum with weights
    (0.2989, 0.5870, 0.1140), and the result has the input's shape minus the
    last axis.
    """
    luma_weights = [0.2989, 0.5870, 0.1140]
    colour_channels = rgb[..., :3]
    return np.dot(colour_channels, luma_weights)

In [None]:
from scipy.spatial.distance import euclidean, cityblock, cdist

# Grayscale versions of the first two test images, flattened once to 1-D vectors.
# 1-D input matters:
#   * np.linalg.norm treats a (1, N) array as a MATRIX. With ord=1 that gives the
#     max absolute column sum (= max |difference|), not the L1 distance.
#   * euclidean/cityblock only accept 1-D vectors; SciPy >= 1.9 raises on (1, N).
gray_a = rgb2gray(x_test[0]).ravel()
gray_b = rgb2gray(x_test[1]).ravel()

# Doc: https://numpy.org/doc/stable/reference/generated/numpy.linalg.norm.html
dist = np.linalg.norm(gray_a - gray_b, ord=2)  # Euclidean (L2) distance
print(dist)

dist = np.linalg.norm(gray_a - gray_b, ord=1)  # Manhattan (L1) distance
print(dist)

dist = euclidean(gray_a, gray_b)
print(dist)

dist = cityblock(gray_a, gray_b)
print(dist)

# cdist compares collections of vectors, so it needs 2-D (n_samples, n_features)
# input; it returns a (1, 1) distance matrix here.
dist = cdist(gray_a.reshape(1, -1), gray_b.reshape(1, -1), metric='euclidean')
print(dist)

dist = cdist(gray_a.reshape(1, -1), gray_b.reshape(1, -1), metric='cityblock')
print(dist)

In [None]:
# Euclidean (L2) distance from its definition: sqrt of the summed squared differences.
diff = np.subtract(rgb2gray(x_test[0]), rgb2gray(x_test[1]))
dist = np.sqrt(np.square(diff).sum())
print(dist)

# Manhattan (L1) distance from its definition: sum of the absolute differences.
diff = np.absolute(np.subtract(rgb2gray(x_test[0]), rgb2gray(x_test[1])))
dist = diff.sum()
print(dist)
