In [23]:
import numpy as np
import pandas as pd
from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances
from mnist import MNIST
from tqdm import tqdm

In [2]:
# Build the loader for the dataset directory, then switch on its `gz` flag
# (presumably so the gzip-compressed files are read directly -- confirm
# against the `mnist` package documentation).
mnist_dir = "../data/mnist/"
mnist_loader = MNIST(mnist_dir)
mnist_loader.gz = True

In [6]:
# Load the training split first, then the test split; each call yields an
# (images, labels) pair. A single nested unpacking avoids keeping any extra
# reference to the loaded data around.
(images, labels), (timages, tlabels) = (
    mnist_loader.load_training(),
    mnist_loader.load_testing(),
)

In [8]:
# Convert every loaded sequence into a NumPy array (train pair, then test pair)
# so the splits can be sliced by row range in the batching loops.
images, labels = np.array(images), np.array(labels)
timages, tlabels = np.array(timages), np.array(tlabels)

In [9]:
# Sanity check: display (train shape, test shape).
np.shape(images), np.shape(timages)

((60000, 784), (10000, 784))

## Cosine similarity

In [10]:
# Output location and batching parameters for the pairwise matrix.
path = '../data/mnist/cos_sim.memmap'
batch_size = 2500  # rows of the matrix computed per iteration

# Row counts of the training and test splits.
n_rows, t_n_rows = images.shape[0], timages.shape[0]

In [12]:
# Disk-backed float32 matrix: one row per training image followed by one row
# per test image, one column per training image. Mode 'w+' creates the file,
# overwriting any existing file at `path`.
fp = np.memmap(
    path,
    dtype='float32',
    mode='w+',
    shape=(n_rows + t_n_rows, n_rows),
)

In [13]:
# Fill the first n_rows rows of the memmap with train-vs-train cosine
# similarities, one slab of rows at a time so the full matrix never has to be
# held in memory.
#
# Stepping by batch_size and clamping `end` also covers a final partial batch:
# iterating over range(n_rows // batch_size) would silently leave the trailing
# rows unwritten whenever n_rows is not a multiple of batch_size.
for start in tqdm(range(0, n_rows, batch_size)):
    end = min(start + batch_size, n_rows)
    fp[start:end, :] = cosine_similarity(X=images[start:end], Y=images)

100%|██████████| 24/24 [04:11<00:00, 10.46s/it]


In [15]:
# Append test-vs-train cosine similarities below the training rows: test row i
# is stored at memmap row n_rows + i.
#
# Stepping by batch_size and clamping `end` also covers a final partial batch:
# iterating over range(t_n_rows // batch_size) would silently leave the
# trailing rows unwritten whenever t_n_rows is not a multiple of batch_size.
for start in tqdm(range(0, t_n_rows, batch_size)):
    end = min(start + batch_size, t_n_rows)
    # Shift the batch bounds past the block of training rows.
    fp_start = start + n_rows
    fp_end = end + n_rows
    fp[fp_start:fp_end, :] = cosine_similarity(X=timages[start:end], Y=images)

100%|██████████| 4/4 [00:35<00:00,  8.93s/it]


In [21]:
# Write pending changes to disk explicitly before dropping the name: `del`
# only removes this reference, so relying on it alone makes persistence depend
# on when the memmap object is actually garbage-collected.
fp.flush()
del fp

## Euclidean distance

In [22]:
# Target file for the Euclidean-distance matrix. It has the same layout as the
# array created for the cosine similarities: rows are training images followed
# by test images, columns are training images. Mode 'w+' creates the file,
# overwriting any existing file at `path`.
path = '../data/mnist/eud_dist.memmap'
fp = np.memmap(
    path,
    dtype='float32',
    mode='w+',
    shape=(n_rows + t_n_rows, n_rows),
)

In [24]:
# Fill the first n_rows rows of the memmap with train-vs-train Euclidean
# distances, one slab of rows at a time so the full matrix never has to be
# held in memory.
#
# Stepping by batch_size and clamping `end` also covers a final partial batch:
# iterating over range(n_rows // batch_size) would silently leave the trailing
# rows unwritten whenever n_rows is not a multiple of batch_size.
for start in tqdm(range(0, n_rows, batch_size)):
    end = min(start + batch_size, n_rows)
    fp[start:end, :] = euclidean_distances(X=images[start:end], Y=images)

100%|██████████| 24/24 [06:59<00:00, 17.50s/it]


In [25]:
# Append test-vs-train Euclidean distances below the training rows: test row i
# is stored at memmap row n_rows + i.
#
# Stepping by batch_size and clamping `end` also covers a final partial batch:
# iterating over range(t_n_rows // batch_size) would silently leave the
# trailing rows unwritten whenever t_n_rows is not a multiple of batch_size.
for start in tqdm(range(0, t_n_rows, batch_size)):
    end = min(start + batch_size, t_n_rows)
    # Shift the batch bounds past the block of training rows.
    fp_start = start + n_rows
    fp_end = end + n_rows
    fp[fp_start:fp_end, :] = euclidean_distances(X=timages[start:end], Y=images)

100%|██████████| 4/4 [01:19<00:00, 19.95s/it]


In [26]:
# Write pending changes to disk explicitly before dropping the name: `del`
# only removes this reference, so relying on it alone makes persistence depend
# on when the memmap object is actually garbage-collected.
fp.flush()
del fp