In [1]:
# P&G faiss index recall evaluation: measure R@1 of an approximate IndexIVFFlat against exact IndexFlatL2 search for several nprobe values

In [1]:
import numpy as np
from PIL import Image
import csv
import os
from os import listdir
from os.path import isfile, join, splitext
import shutil
import random
import matplotlib.pyplot as plt
from matplotlib.pyplot import imshow

import faiss

%matplotlib inline


In [2]:
# Load the precomputed embedding matrix from disk. faiss only works with
# float32 vectors, so cast explicitly whatever dtype was stored in the file.
vector_path = 'pg_np_512d_vector_data.npy'
arr = np.load(vector_path).astype(np.float32)

In [3]:
# (number of vectors, embedding dimension)
np.shape(arr)

(32217, 512)

In [4]:
# Separate the vectors into a database set that gets indexed (xb) and a
# disjoint, held-out set of query vectors (xq). Rows beyond
# N_BASE + N_QUERY are not used.
N_BASE = 30000   # number of database vectors added to the indexes
N_QUERY = 2000   # number of query vectors used for evaluation
assert N_BASE + N_QUERY <= arr.shape[0], "not enough vectors for the requested base/query split"
xb = arr[:N_BASE]
xq = arr[N_BASE:N_BASE + N_QUERY]

In [5]:
# Report the sizes of the database and query sets.
for shape_label, shape_array in (('xb.shape : ', xb), ('xq.shape', xq)):
    print(shape_label, shape_array.shape)

xb.shape :  (30000, 512)
xq.shape (2000, 512)


In [6]:
# Exact (brute-force) L2 index: every query is compared against every
# database vector, so its results serve as the ground truth for recall.
d = xb.shape[1]  # vector dimensionality, taken from the data (512 for this dataset)
indexL2 = faiss.IndexFlatL2(d)
indexL2.add(xb)
assert indexL2.ntotal == xb.shape[0], "exact index does not contain every database vector"

In [7]:
# Approximate IVF index: a flat L2 coarse quantizer partitions the database
# into `nlist` inverted lists (learned by `train`). At search time only
# `nprobe` of those lists are scanned, trading recall for speed.
nlist = 100  # number of inverted lists (coarse clusters)
quantizer = faiss.IndexFlatL2(d)  # coarse quantizer assigning vectors to inverted lists
indexIVFFlat = faiss.IndexIVFFlat(quantizer, d, nlist, faiss.METRIC_L2)
indexIVFFlat.train(xb)
assert indexIVFFlat.is_trained, "IVF index must be trained before vectors are added"
indexIVFFlat.add(xb)
assert indexIVFFlat.ntotal == xb.shape[0], "IVF index does not contain every database vector"

In [25]:
# Ground truth: exact nearest neighbour of every query from the brute-force index.
k = 1  # neighbours requested per query (reused by every search below)
I_indexL2 = indexL2.search(xq, k)[1]  # search returns (distances, labels); keep only the labels


In [26]:
# eval function
def evaluate(I_indexL2, I_IndexFlat, nprobe=None):
    """Score approximate search results against exact ground truth and print a report.

    Args:
        I_indexL2: 2-D int array of neighbour ids from the exact IndexFlatL2
            search, one row per query. Only column 0 (the true nearest
            neighbour) is used.
        I_IndexFlat: 2-D int array of shape (nq, k) of neighbour ids from the
            approximate search being evaluated. A label of -1 marks a result
            slot faiss could not fill.
        nprobe: value shown as "nprobe" in the report. When None, falls back to
            the current ``indexIVFFlat.nprobe`` global (the original behaviour).

    Returns:
        Tuple ``(recall_at_1, missing_rate)`` of floats.

    Raises:
        ValueError: if the result array is not 2-D or is empty, or if the two
            arrays disagree on the number of queries.
    """
    if I_IndexFlat.ndim != 2 or I_IndexFlat.size == 0:
        raise ValueError("I_IndexFlat must be a non-empty 2-D array of labels")
    nq, k_found = I_IndexFlat.shape
    if I_indexL2.shape[0] != nq:
        # Without this check numpy could silently broadcast mismatched arrays.
        raise ValueError("ground truth and results have different numbers of queries")

    # Fraction of all returned result slots that faiss left unfilled (-1).
    missing_rate = float((I_IndexFlat == -1).sum()) / I_IndexFlat.size
    # 1-recall@1: only the top-ranked approximate hit counts, even when k > 1.
    recall_at_1 = float((I_IndexFlat[:, :1] == I_indexL2[:, :1]).sum()) / nq

    if nprobe is None:
        nprobe = indexIVFFlat.nprobe
    print("R@1 %.4f, missing rate %.4f, k value is %.0f, nprobe is %.0f" % (recall_at_1, missing_rate, k_found, nprobe))
    return recall_at_1, missing_rate

In [27]:
# nprobe = 1: scan a single inverted list per query (fastest setting in this sweep).
indexIVFFlat.nprobe = 1
I_IndexFlat = indexIVFFlat.search(xq, k)[1]  # labels only; distances are not needed
evaluate(I_indexL2, I_IndexFlat);

R@1 0.4175, missing rate 0.0000, k value is 1, nprobe is 1


In [28]:
# nprobe = 10: scan 10 of the inverted lists per query.
indexIVFFlat.nprobe = 10
I_IndexFlat = indexIVFFlat.search(xq, k)[1]  # labels only
evaluate(I_indexL2, I_IndexFlat);

R@1 0.9185, missing rate 0.0000, k value is 1, nprobe is 10


In [29]:
# nprobe = 20: scan 20 of the inverted lists per query.
indexIVFFlat.nprobe = 20
I_IndexFlat = indexIVFFlat.search(xq, k)[1]  # labels only
evaluate(I_indexL2, I_IndexFlat);

R@1 0.9795, missing rate 0.0000, k value is 1, nprobe is 20


In [30]:
# nprobe = 30: scan 30 of the inverted lists per query (slowest setting in this sweep).
indexIVFFlat.nprobe = 30
I_IndexFlat = indexIVFFlat.search(xq, k)[1]  # labels only
evaluate(I_indexL2, I_IndexFlat);

R@1 0.9940, missing rate 0.0000, k value is 1, nprobe is 30
