In [2]:
#
#   https://medium.com/@DataPlayer/scalable-approximate-nearest-neighbour-search-using-googles-scann-and-facebook-s-faiss-3e84df25ba
#
import numpy as np
import tensorflow as tf
import scann
import os

# Generate a reproducible synthetic dataset of 1 million 128-dimensional vectors
NUM_VECTORS  = 1_000_000
VECTOR_DIM   = 128
DATASET_SEED = 42

# A seeded Generator yields the same dataset on every Restart & Run All.
# Requesting float32 directly avoids materialising a float64 array twice the
# size (about 1 GB here) only to down-cast it afterwards.
dataset_rng = np.random.default_rng(DATASET_SEED)
dataset = dataset_rng.standard_normal((NUM_VECTORS, VECTOR_DIM), dtype=np.float32)

2023-10-18 11:30:10.805269: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2023-10-18 11:30:10.872308: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2023-10-18 11:30:10.873178: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [9]:
# Display the dataset; NumPy abbreviates the repr of a large array with "...",
# so only the corner rows/columns are shown.
dataset

array([[ 0.53273886,  0.20877197, -0.13401023, ...,  1.1321068 ,
         0.14936765,  2.0898864 ],
       [-1.1190106 ,  1.1306709 , -0.73434114, ..., -1.3780754 ,
         0.20978712,  0.5911143 ],
       [ 0.71972764,  0.57765245, -0.06945786, ...,  0.34756973,
        -0.02554608,  1.6132392 ],
       ...,
       [ 0.21866706, -0.75853896,  0.8017154 , ..., -1.8284234 ,
        -0.46315098, -1.6407605 ],
       [-0.6338245 , -2.2646642 , -0.13387364, ..., -0.2721559 ,
         0.4283607 , -0.30530217],
       [-0.05822184,  0.10939038, -1.2444814 , ...,  2.2333908 ,
        -0.48894915, -1.8492477 ]], dtype=float32)

In [4]:
# Number of vectors (rows) in the dataset; expected to equal NUM_VECTORS.
dataset.shape[0]

1000000

In [5]:
# Confirm the container type of the data that is later handed to the ScaNN builder.
type(dataset)

numpy.ndarray

In [7]:
# Peek at the first vector: VECTOR_DIM (128) float32 components.
dataset[0]

array([ 0.53273886,  0.20877197, -0.13401023, -0.4969735 , -1.460248  ,
       -0.45971847, -0.5590802 , -1.4309138 ,  1.2534372 , -0.2962945 ,
       -1.4032303 ,  0.10832689,  0.7683947 , -0.6894984 ,  0.45465046,
       -1.3447602 ,  0.12526636, -0.5624886 ,  0.89817643, -0.9612799 ,
        1.3090175 , -0.26053992, -0.68672836,  1.982428  ,  0.71417433,
       -1.341766  ,  0.39680555,  0.67695   ,  0.66589975, -1.5169811 ,
        0.2532839 ,  1.4627436 ,  0.65193504, -0.02118386,  1.1397831 ,
       -0.5068141 ,  1.8237536 ,  0.9578784 , -0.7903269 ,  1.6935039 ,
       -0.37883702, -0.710244  , -0.25525364, -0.46150902,  0.67734075,
       -0.80032694, -1.0309495 ,  0.79071707,  0.6638751 ,  1.7183176 ,
       -0.13483477, -0.2653128 ,  1.4478605 , -0.16392525, -1.1000438 ,
        1.146045  ,  0.05218043,  1.033244  , -0.8441918 , -0.2220067 ,
       -0.3302139 ,  0.33130682,  1.522045  , -1.738253  , -1.0906689 ,
        0.15306094, -0.44061014, -0.4192783 ,  1.5171627 ,  0.29

In [8]:
 # Euclidean (L2) norm of the first vector. For 128 standard-normal components
 # the norm is expected to be close to sqrt(128) ~= 11.3.
 np.linalg.norm(dataset[0])

10.881542

In [13]:
#
# Build a ScaNN index over `dataset` using squared-L2 distance.
#
# scann.scann_ops_pybind.builder() returns a builder object; the searcher is
# configured by chaining calls on it and is only trained/created by the final
# .build() call. The chain below sets up three stages:
#
#   1. .tree()     - partitioning: the dataset is clustered into `num_leaves`
#                    partitions, and each query only scans the
#                    `num_leaves_to_search` closest partitions.
#   2. .score_ah() - scoring: candidates inside the selected partitions are
#                    scored approximately with asymmetric hashing.
#   3. .reorder()  - rescoring: the best approximate candidates are re-ranked
#                    with exact distances before the final results are returned.
#
# Rules of thumb from the ScaNN documentation: num_leaves is roughly
# sqrt(number of datapoints), and reordering_num_neighbors should be larger
# than the number of neighbors finally requested.
#
searcher = (
    scann.scann_ops_pybind.builder(
        dataset,                          # numpy array of the data points to index
        num_neighbors=10,                 # default number of neighbors returned per query
        distance_measure="squared_l2",    # "dot_product" or "squared_l2"
    )
    .tree(
        num_leaves=2000,                  # number of partitions the dataset is split into
        num_leaves_to_search=100,         # partitions scanned per query (100 of 2000 = 5%);
                                          # higher => better recall, slower search
        training_sample_size=250000,      # number of vectors sampled to train the partitioning
    )
    .score_ah(
        dimensions_per_block=2,           # dimensions grouped into each quantization block
                                          # (this was the unnamed positional 2)
        anisotropic_quantization_threshold=0.2,  # threshold of the anisotropic quantization
                                                 # loss, which penalises quantization error
                                                 # parallel to the datapoint more heavily
    )
    .reorder(
        reordering_num_neighbors=100,     # top approximate candidates re-scored exactly
    )
    .build()                              # trains the partitioner/quantizer and returns the searcher
)

2023-10-18 12:56:29.896255: I scann/partitioning/partitioner_factory_base.cc:59] Size of sampled dataset for training partition: 249544
2023-10-18 12:56:37.121914: I ./scann/partitioning/kmeans_tree_partitioner_utils.h:84] PartitionerFactory ran in 7.225513822s.


In [14]:
# Create a directory for the serialized index.
# exist_ok=True makes the cell idempotent (safe to re-run) and avoids the
# check-then-create race of a separate os.path.exists() test. If the path
# exists but is not a directory, this raises FileExistsError instead of
# silently continuing.
os.makedirs("scann_index", exist_ok=True)


In [4]:
# Save the index to disk, into the "scann_index" directory prepared earlier.
# NOTE(review): this cell's execution count (4) is older than the build cell's
# (13); re-run the notebook top to bottom so the files on disk correspond to
# the searcher built above.

searcher.serialize("scann_index")

In [15]:
# Load the index from disk.
# `searcher` is rebound to the deserialized index, so the queries below
# exercise the save/load round-trip rather than the in-memory object built above.
searcher = scann.scann_ops_pybind.load_searcher("scann_index/")

In [16]:
# Generate a reproducible query vector with the same dimensionality and dtype
# as the indexed vectors.
# A dedicated seeded Generator makes the query identical on every run. Keep
# this seed distinct from any seed used to generate the dataset, otherwise the
# query could coincide exactly with a dataset row.
QUERY_SEED = 7
query_rng = np.random.default_rng(QUERY_SEED)
query_vector = query_rng.standard_normal(VECTOR_DIM, dtype=np.float32)

In [17]:
# Find the top 10 nearest neighbors of the query vector.
# search() returns a pair: the dataset indices of the neighbors and their
# distances to the query. final_num_neighbors=10 matches the num_neighbors=10
# the index was built with.
neighbors, distances = searcher.search(query_vector, final_num_neighbors=10)

In [18]:
# Sanity check: the number of neighbors returned should equal the 10 requested.
neighbors.shape[0]

10

In [9]:
# `neighbors` holds the dataset row indices of the nearest neighbors;
# `distances` holds the matching squared-L2 distances, in the same order.
neighbors, distances

(array([732276, 981817, 598507, 510868, 730527,   2208, 122018, 973694,
        741261, 231291], dtype=uint32),
 array([150.53412, 154.43344, 156.3534 , 156.4636 , 158.85886, 163.04324,
        163.57863, 163.59833, 163.9052 , 164.5318 ], dtype=float32))