In [3]:
import os
import sys
import faiss
sys.path.append(os.path.join(os.path.dirname(os.getcwd()), "vec_db"))
from utilities import compute_recall_at_k
import numpy as np

In [9]:
def read_fvecs(filename):
    """
    Reads a .fvecs file and returns it as a NumPy array.
    
    Parameters:
        filename (str): Path to the .fvecs file.
    
    Returns:
        np.ndarray: Array of vectors with shape (num_vectors, dimension).
    """
    with open(filename, 'rb') as f:
        data = np.fromfile(f, dtype=np.float32)
        # Reshape based on the first value in each vector (dimension)
        dimension = data.view(np.int32)[0]  # The first value is the dimension
        return data.reshape(-1, dimension + 1)[:, 1:]  # Skip the first value (dimension)

# Read the input vectors xb and the query vectors xq
xb = read_fvecs('../Databases/sift_base.fvecs')
xq = read_fvecs('../Databases/sift_query.fvecs')

xq = xq[:50].reshape(50, xq.shape[1])

In [17]:
# Create and train the FAISS index
index = faiss.index_factory(128, "IVF256,PQ32")
index.train(xb)
index.add(xb)

imi = faiss.extract_index_ivf(index)
imi.nprobe = 8

# Perform a search with query vectors
k = 5  # Number of nearest neighbors to retrieve
D, I = index.search(xq, k)

print("Nearest neighbors (indices):")
print(I)
print("Distances:")
print(D)

# Evaluate recall
# Create exact search index for ground truth
ground_truth_index = faiss.IndexFlatL2(128)
ground_truth_index.add(xb)

# Perform an exact search to get ground truth
D_exact, I_exact = ground_truth_index.search(xq, k)

recall = compute_recall_at_k(I_exact, I, k)
print(f"Recall@{k}: {recall:.4f}")

Nearest neighbors (indices):
[[932085 934876 561813 708177 908244]
 [880462 706838 249062 413071 413247]
 [408764 551661 861882 408462 239766]
 [ 48044 970797 125539 190692 191115]
 [340871 748397 748193 175336 716433]
 [220473 669622  27746 187470  67875]
 [652078 880346 982409 982379  60092]
 [207868 618842 107468 906750 807599]
 [323464 724549 323160 325865 724587]
 [178811 177646 821938 716433 292077]
 [320547 316210 992617 753170 297239]
 [  4697 762035 672308 910776 824256]
 [191546 923810 931688 390776 563997]
 [522078 906497 652476 134090 990875]
 [860352 601274 454220 433429 368880]
 [969865 970094 823579 329989 237312]
 [796217 796844 798456 217222 986823]
 [394163 601678 579460 203958 268236]
 [434884 247250 701657 545483  63379]
 [394974 193620 272436 222535 258797]
 [ 82246 587404 448233 666481 929092]
 [  4490   4457 214449 337027 566609]
 [426465  31629 483146 536469 281276]
 [686808 219735 363583 526788 350527]
 [  4630   4554 804007 784790 279489]
 [719083  87978 27963

In [18]:
%timeit index.search(xq, k)

9.57 ms ± 656 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
