In [108]:
import faiss
# Load the project helper scripts into this namespace. They presumably define the
# helpers used below but not defined here (get_ann_benchmark_data, readDB,
# createIndexNumerous, measureTimeNumerous, compareElems, calculateRecall*) — confirm.
# `with` closes each file (the previous open(...).read() leaked the handles), and
# compile() with the real path makes tracebacks point at the helper file, not "<string>".
with open("../database.py") as _helperFile:
    exec(compile(_helperFile.read(), "../database.py", "exec"))
with open("../helperFunctions.py") as _helperFile:
    exec(compile(_helperFile.read(), "../helperFunctions.py", "exec"))
del _helperFile
import numpy as np
from keras.datasets import mnist
from time import perf_counter
from ast import literal_eval

In [92]:
# Benchmark configuration. `metric` only selects the ground-truth file name here;
# the FAISS index itself is built with createIndex's L2 metric.
name = 'mnist-784'
metric = 'euclidean'
runs = 1          # repetitions passed to createIndexNumerous / measureTimeNumerous
queries = 1000    # number of test vectors searched
nameFull = f'{name}-{metric}-true-labels.xlsx'
datasetTrainImages, datasetTestImages, _ = get_ann_benchmark_data(name)

trainDataset :  (60000, 784)
testDataset :  (10000, 784)


***Create the FAISS product-quantization (`IndexPQ`) index***

In [93]:
def createIndex(indexMethod, datasetImages, M=28, nbits=7, metric=None):
    """Build, train and populate a product-quantization FAISS index, timing the build.

    Parameters
    ----------
    indexMethod : callable
        Index constructor called as ``indexMethod(d, M, nbits, metric)``
        (e.g. ``faiss.IndexPQ``).
    datasetImages : array-like, shape (n, d)
        Vectors used both to train the quantizer and to populate the index.
        Converted to C-contiguous float32 (the layout FAISS works on) before
        the timer starts, so the conversion is not counted as build time.
    M : int, default 28
        Number of sub-quantizers. The vector dimension ``d`` must be divisible by it.
    nbits : int, default 7
        Bits per sub-quantizer code (NOT the total code size). Each vector is
        encoded with ``M * nbits`` bits.
    metric : int, optional
        FAISS metric constant; defaults to ``faiss.METRIC_L2``.

    Returns
    -------
    tuple
        ``(index, totalTime)``: the trained, populated index and the wall-clock
        seconds spent in construction + train + add.

    Raises
    ------
    ValueError
        If ``d`` is not divisible by ``M``.
    """
    if metric is None:
        metric = faiss.METRIC_L2
    data = np.ascontiguousarray(datasetImages, dtype='float32')
    d = data.shape[1]  # vector dimension
    if d % M != 0:
        # Fail with a clear message instead of an opaque assertion from inside FAISS.
        raise ValueError(f'dimension {d} is not divisible by M={M}')
    time_start = perf_counter()
    index = indexMethod(d, M, nbits, metric)
    index.train(data)
    index.add(data)
    time_end = perf_counter()
    totalTime = time_end - time_start
    print(f'Took {totalTime:.3f} seconds')
    return (index, totalTime)

In [94]:
# Build the PQ index `runs` times via the helper; by their names the first two
# results are the fastest/slowest build times and the third is an index to search
# with (TODO confirm which run's index createIndexNumerous returns).
minBuildTime, maxBuildTime, indexedStruct = createIndexNumerous(
    createIndex, faiss.IndexPQ, datasetTrainImages, runs
)
for _label, _value in (('minBuildTime : ', minBuildTime), ('maxBuildTime : ', maxBuildTime)):
    print(_label, _value)

Took 11.549 seconds
index  1  created
minBuildTime :  11.549
maxBuildTime :  11.549


In [95]:
indexes = []
distances = []
def measureTime(par, indexes, distances, datasetImages, indexStruct=None, k=100, baseImages=None):
    """Run ``par`` single-vector k-NN searches and return the total search time.

    Query ``i`` is row ``i`` of ``datasetImages`` (one query per ``search`` call,
    so the timing reflects per-query latency). For every query the returned
    neighbour ids are appended to ``indexes`` and their distances to
    ``distances``; both lists are mutated in place.

    Parameters
    ----------
    par : int
        Number of queries to run (the first ``par`` rows of ``datasetImages``).
    indexes, distances : list
        Output lists; one length-``k`` array is appended per query.
    datasetImages : array-like, shape (n, d)
        Query vectors.
    indexStruct : optional
        Object exposing ``search(xq, k)``. Defaults to the notebook-global
        ``indexedStruct`` so existing callers keep working, but passing it
        explicitly avoids depending on hidden notebook state.
    k : int, default 100
        Number of neighbours requested per query.
    baseImages : array-like, optional
        The vectors that were added to the index. When given, each returned id's
        distance is recomputed exactly from these vectors (outside the timed
        region); ids of -1 (no result) get ``inf``. When omitted, the square root
        of the index's own distances is stored. For an L2 index those are squared
        distances, and for a PQ index they are approximations that can differ
        from the true distance of the same id.

    Returns
    -------
    float
        Total seconds spent inside ``search``, rounded to 3 decimals.
    """
    if indexStruct is None:
        indexStruct = indexedStruct
    totalTime = 0.0
    for i in range(par):
        # Query i is row i, kept 2-D (1, d) because search expects a batch.
        xq = np.ascontiguousarray(datasetImages[i:i + 1], dtype='float32')
        time_start = perf_counter()
        distance, found = indexStruct.search(xq, k)
        time_end = perf_counter()
        totalTime += time_end - time_start
        ids = found[0]
        if baseImages is None:
            dist = np.sqrt(distance[0])
        else:
            query = np.asarray(datasetImages[i], dtype=np.float64)
            dist = np.full(ids.shape, np.inf)
            valid = ids >= 0
            diff = np.asarray(baseImages[ids[valid]], dtype=np.float64) - query
            dist[valid] = np.linalg.norm(diff, axis=1)
        distances.append(dist)
        indexes.append(ids)
    return np.round(totalTime, 3)

In [96]:
# Time `queries` searches `runs` times; the helper also hands back the collected
# neighbour ids and distances (presumably from one of the runs — TODO confirm).
minSearchTime, maxSearchTime, indexes, distances = measureTimeNumerous(
    measureTime, runs, queries, datasetTestImages
)
for _label, _value in (('minSearchTime : ', minSearchTime), ('maxSearchTime : ', maxSearchTime)):
    print(_label, _value)

search  1  done
minSearchTime :  7.705
maxSearchTime :  7.705


In [97]:
# Stack the per-query result rows into 2-D arrays (one row per query);
# distances are converted to float64 and rounded to 4 decimals.
indexes = np.asarray(indexes)
distances = np.asarray(distances, dtype=float).round(4)

In [98]:
# Sanity check: both arrays are expected to be (queries, k).
for _label, _arr in (('indexes : ', indexes), ('distances : ', distances)):
    print(_label, _arr.shape)

indexes :  (1000, 100)
distances :  (1000, 100)


In [99]:
# Ground-truth neighbour ids and distances; expected to line up row-for-row
# with `indexes` / `distances` (one row per query).
path = f'../datasets/{nameFull}'
trueIndexes, trueDistances = readDB(path)

trueIndexes :  (1000, 100)
trueDistances :  (1000, 100)


In [100]:
# Side-by-side `approx || true` listing of neighbour ids and distances; judging by
# its output it covers the first `amount` neighbours of one query — confirm.
amount = 10
compareElems(
    amount,
    indexes, distances,
    trueIndexes, trueDistances,
)

53843 || 53843
665.4332 || 676.584
38620 || 38620
797.8483 || 793.9868
44566 || 16186
878.341 || 862.6766
16186 || 27059
894.4985 || 864.5039
21518 || 47003
913.5148 || 894.7
14563 || 14563
927.2893 || 909.7043
40368 || 44566
933.6413 || 917.6323
15260 || 15260
955.4202 || 921.6241
47003 || 40368
955.7325 || 922.147
27059 || 36395
970.2803 || 943.4972


In [109]:
# Recall at the default setting and at two looser thresholds (the printed labels
# Recall@1 / @1.01 / @1.1 suggest a distance-ratio tolerance — confirm).
# NOTE(review): if `distances` came straight from the PQ index they are
# approximations, so distance-tolerant recall compares approximate to exact values.
_recallArgs = (indexes, distances, trueIndexes, trueDistances)
calculateRecallAverage(*_recallArgs)
for _ratio in (1.01, 1.1):
    calculateRecallAverage(*_recallArgs, _ratio)
calculateNormRecall(indexes, trueIndexes)

Recall@1: 0.819
Recall@1.01: 0.8902
Recall@1.1: 0.9961
77
91
87
78
78
85
81
76
65
76
75
77
76
76
83
82
81
87
79
79
83
77
79
87
65
76
83
76
72
87
84
85
89
59
77
77
81
87
66
85
87
82
82
70
82
76
93
77
79
81
75
74
82
77
80
80
85
85
85
71
81
73
74
73
86
79
74
76
87
86
84
74
74
84
80
77
83
74
84
85
71
72
81
79
72
80
85
80
81
89
79
78
83
81
83
74
91
80
75
78
76
74
86
78
75
78
77
83
83
80
78
83
85
75
84
81
77
77
77
77
73
74
86
78
75
81
79
83
73
80
81
78
80
75
78
85
78
86
75
75
77
79
77
77
79
76
86
74
85
46
78
72
78
76
89
80
84
83
78
81
76
78
87
84
75
80
78
79
84
75
79
85
83
74
84
89
89
69
85
75
77
81
89
76
64
72
82
79
76
91
87
90
83
70
78
75
89
83
78
74
80
77
83
76
87
83
80
82
73
74
82
70
78
82
82
80
81
79
78
77
77
80
69
75
84
78
84
76
79
82
91
87
76
66
83
74
80
80
79
78
80
71
77
77
80
55
78
73
78
77
84
81
79
81
84
89
81
70
83
67
81
73
88
85
77
88
74
82
68
81
81
78
75
83
80
75
84
84
82
90
86
85
75
88
86
80
86
77
85
75
67
84
78
79
81
72
73
76
80
77
89
84
73
58
74
80
83
77
80
80
76
81
68
79
84


0.7897