In [1]:
import faiss
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

# Toy setup: three 4-dimensional float32 vectors indexed with an HNSW index using the L2 metric.
d = 4

xb = np.array([
    [0.1,0.2,0.3,0.4],
    [0.1,0.2,0.9,1.0],
    [0.5,0.6,0.7,0.8]
], dtype=np.float32)

# second argument (2) is passed in the same position that FaissHNSWIndexer below fills
# with config["faiss_hnsw_graph_neighbors"], i.e. the number of graph neighbors
index = faiss.IndexHNSWFlat(d, 2,faiss.METRIC_L2)   # build the index
print(index.is_trained)
index.add(xb)                  # add vectors to the index
print(index.ntotal)

True
3


In [2]:
# Query the L2 index with the indexed vectors themselves: every row should find itself first.
k = 2                          # we want to see the 2 nearest neighbors
# D holds the distances, I the row positions of the neighbors; the printed values are
# squared L2 distances (e.g. rows 0 and 2 differ by 0.4 in each of 4 dims -> 4 * 0.4**2 = 0.64)
D, I = index.search(xb, k) # sanity check
print(I)
print(D)

[[0 2]
 [1 2]
 [2 1]]
[[0.         0.64000005]
 [0.         0.40000004]
 [0.         0.40000004]]


In [3]:
# Reference values: pairwise cosine similarity of the (not yet normalized) vectors with themselves
cosine_similarity(xb, xb)

array([[1.        , 0.9638632 , 0.96886396],
       [0.9638632 , 1.        , 0.88938314],
       [0.96886396, 0.88938314, 1.        ]], dtype=float32)

In [4]:
# Rebuild the HNSW index with the inner-product metric on L2-normalized vectors, so that the
# returned scores can be compared with the cosine similarities computed above.
# NOTE: normalize_L2 overwrites xb in place (its return value is not used), so every later
# cell sees the normalized vectors.
faiss.normalize_L2(xb)

print(xb)

index = faiss.IndexHNSWFlat(d, 2,faiss.METRIC_INNER_PRODUCT)   # build the index
index.verbose = True           # prints the hnsw_add_vertices log shown in the output
# deliberately tiny construction / search settings for this 3-vector toy example
index.hnsw.efConstruction = 2
index.hnsw.efSearch = 2
print(index.is_trained)
index.add(xb)                  # add vectors to the index
print(index.ntotal)

[[0.18257418 0.36514837 0.5477226  0.73029673]
 [0.07332356 0.14664713 0.65991205 0.7332356 ]
 [0.37904903 0.45485887 0.5306686  0.60647845]]
True
hnsw_add_vertices: adding 3 elements on top of 0 (preset_levels=0)
  max_level = 3
Adding 2 elements at level 3
Adding 0 elements at level 2
Adding 0 elements at level 1
Adding 1 elements at level 0
3
Done in 0.082 ms


In [5]:
# Query the inner-product index with the normalized vectors; the scores in D should match
# the corresponding entries of the cosine similarity matrix.
k = 2                          # we want to see the 2 nearest neighbors
D, I = index.search(xb, k) # sanity check
print(I)
print(D)

[[0 2]
 [1 0]
 [2 0]]
[[1.         0.96886396]
 [1.         0.96386325]
 [1.         0.96886396]]


In [6]:
# Same cosine similarity check, now on the in-place normalized xb (values agree with the raw
# vectors up to float32 rounding)
cosine_similarity(xb, xb)

array([[1.        , 0.96386325, 0.96886396],
       [0.96386325, 1.0000001 , 0.8893832 ],
       [0.96886396, 0.8893832 , 1.0000001 ]], dtype=float32)

In [14]:
# Wrap a fresh inner-product HNSW index in an IndexIDMap so that vectors can be added under
# caller-chosen ids instead of their row positions.
index = faiss.IndexHNSWFlat(d, 2,faiss.METRIC_INNER_PRODUCT)   # build the index
index.verbose = True
index.hnsw.efConstruction = 2
index.hnsw.efSearch = 2

# custom ids 10, 11, 12 (row position + 10), to make them distinguishable from row positions
ids = np.arange(xb.shape[0]) + 10

index2 = faiss.IndexIDMap(index)
index2.add_with_ids(xb, ids) # works, the vectors are stored in the underlying index

hnsw_add_vertices: adding 3 elements on top of 0 (preset_levels=0)
  max_level = 3
Adding 2 elements at level 3
Adding 0 elements at level 2
Adding 0 elements at level 1
Adding 1 elements at level 0
Done in 0.547 ms


In [15]:
# Search through the id-mapping wrapper: I now contains the custom ids (10..12) rather than
# row positions, the scores are the same as for the plain index.
k = 2                          # we want to see the 2 nearest neighbors
D, I = index2.search(xb, k) # sanity check
print(I)
print(D)

[[10 12]
 [11 10]
 [12 10]]
[[1.         0.96886396]
 [1.         0.96386325]
 [1.         0.96886396]]


In [17]:
from typing import List

class BaseNNIndexer:
    '''
    Common interface for our nearest neighbor indexing backends. At the moment it mainly abstracts faiss,
    but keeping the contract in one place should make it fairly easy to swap in other libraries.

    The three operations below are intentional no-ops (they return None); concrete indexers override
    the ones they need.
    '''

    def __init__(self, config):
        '''
        config: mapping that must provide "token_dim", "faiss_use_gpu" and "token_dtype"
        '''
        super().__init__()

        self.token_dim = config["token_dim"]
        self.use_gpu = config["faiss_use_gpu"]

        # half precision is requested purely via the configured dtype name
        dtype_name = config["token_dtype"]
        self.use_fp16 = dtype_name == "float16"

    def prepare(self, data_chunks: List[np.ndarray], subsample=-1):
        '''
        Train the index on (all) vectors, or on only some of them if subsample is set to a value between 0 and 1.
        The base implementation does nothing.
        '''
        return None

    def index(self, ids: List[np.ndarray], data_chunks: List[np.ndarray]):
        '''
        Add the vectors in data_chunks under the given ids.
        ids: need to be int64
        The base implementation does nothing.
        '''
        return None

    def search(self, query_vec: np.ndarray, top_n: int):
        '''
        Look up the top_n nearest neighbors.
        query_vec: can be 2d (batch search) or 1d (single search)
        The base implementation does nothing.
        '''
        return None


class FaissBaseIndexer(BaseNNIndexer):
    '''
    Code shared by all faiss-backed indexers: adding vectors with ids, searching, saving and loading.
    '''

    def __init__(self, config):
        super().__init__(config)
        # placeholder only - every concrete faiss indexer class has to create the real index itself
        self.faiss_index: faiss.Index = None

    def index(self, ids: List[np.ndarray], data_chunks: List[np.ndarray]):
        '''
        Add all chunks to the faiss index with a single add_with_ids call.

        ids: list of id arrays, concatenated and cast to int64
        data_chunks: list of vector arrays, concatenated and cast to float32
        '''
        # one single add is needed for the multi-gpu (sharded) index and for hnsw, so we do it for every
        # index type (might become a memory problem at some point, but we can come back to that)
        all_ids = np.concatenate(ids).astype(np.int64)
        all_vectors = np.concatenate(data_chunks).astype(np.float32)
        self.faiss_index.add_with_ids(all_vectors, all_ids)

    def search(self, query_vec: np.ndarray, top_n: int):
        '''
        Return (scores, ids) of the top_n results for every query vector.

        query_vec: 2d array (batch search) or 1d array (single search, treated as a batch of one)
        '''
        # faiss wants a 2d input, so a single vector is lifted to shape 1xn first
        if query_vec.ndim == 1:
            query_vec = query_vec[np.newaxis, :]

        queries = query_vec.astype(np.float32)
        scores, neighbor_ids = self.faiss_index.search(queries, top_n)
        return scores, neighbor_ids

    def save(self, path: str):
        '''
        Write the index to path; when use_gpu is set the index is converted to a cpu index first.
        '''
        cpu_index = faiss.index_gpu_to_cpu(self.faiss_index) if self.use_gpu else self.faiss_index
        faiss.write_index(cpu_index, path)

    def load(self, path: str, config_overwrites=None):
        '''
        Replace the current index with the one stored at path.
        config_overwrites is accepted but not used by this implementation.
        '''
        self.faiss_index = faiss.read_index(path)


class FaissIdIndexer(FaissBaseIndexer):
    '''
    Simple brute force (flat, inner product) nearest neighbor faiss index with id mappings,
    optional gpu usage and support for fp16
    -> if faiss_use_gpu=True all available GPUs are used in a sharded index
    '''

    def __init__(self, config):
        super().__init__(config)

        if not self.use_gpu:
            # cpu only: fp16 storage is done with a scalar quantizer, otherwise plain flat storage
            if self.use_fp16:
                storage = faiss.IndexScalarQuantizer(config["token_dim"], faiss.ScalarQuantizer.QT_fp16, faiss.METRIC_INNER_PRODUCT)
            else:
                storage = faiss.IndexFlatIP(config["token_dim"])
            self.faiss_index = faiss.IndexIDMap(storage)
            return

        # gpu: start from an id-mapped flat cpu index and clone it onto all gpus
        id_mapped_cpu_index = faiss.IndexIDMap(faiss.IndexFlatIP(config["token_dim"]))

        cloner_options = faiss.GpuMultipleClonerOptions()
        cloner_options.shard = True                     # split the vectors across the gpus
        cloner_options.useFloat16 = self.use_fp16       # fp16 on gpu is a cloner option

        self.faiss_index = faiss.index_cpu_to_all_gpus(id_mapped_cpu_index, cloner_options)


class FaissHNSWIndexer(FaissBaseIndexer):
    '''
    Graph based (HNSW) index with id mappings; cpu only, but with very low query latency
    '''

    def __init__(self, config):
        super().__init__(config)

        # HNSW does not support GPUs, so ignore whatever the config requested
        self.use_gpu = False

        dim = config["token_dim"]
        graph_neighbors = config["faiss_hnsw_graph_neighbors"]

        # fp16 -> HNSW on top of a scalar quantizer, otherwise HNSW on top of flat storage
        if self.use_fp16:
            graph_index = faiss.IndexHNSWSQ(dim, faiss.ScalarQuantizer.QT_fp16,
                                            graph_neighbors, faiss.METRIC_INNER_PRODUCT)
        else:
            graph_index = faiss.IndexHNSWFlat(dim, graph_neighbors, faiss.METRIC_INNER_PRODUCT)

        graph_index.verbose = True
        graph_index.hnsw.efConstruction = config["faiss_hnsw_efConstruction"]
        graph_index.hnsw.efSearch = config["faiss_hnsw_efSearch"]

        # the id map goes on last, after the graph parameters have been set on the hnsw index
        self.faiss_index = faiss.IndexIDMap(graph_index)

    def prepare(self, data_chunks: List[np.ndarray], subsample=-1):
        '''
        Train the index on all given chunks, but only in the fp16 case; subsample is not used here.
        '''
        if not self.use_fp16:
            return

        # training for the scalar quantizer, according to: https://github.com/facebookresearch/faiss/blob/master/benchs/bench_hnsw.py
        training_vectors = np.concatenate(data_chunks).astype(np.float32)
        self.faiss_index.train(training_vectors)

In [23]:
# Exercise the FaissHNSWIndexer wrapper with the same toy settings as the manual cells above
# (4 dims, 2 graph neighbors, efConstruction = efSearch = 2, cpu, float32).
index_config = {
    "token_dim": 4,
    "faiss_hnsw_graph_neighbors": 2,
    "faiss_hnsw_efConstruction": 2,
    "faiss_hnsw_efSearch": 2,
    "faiss_use_gpu": False,
    "token_dtype": "float32"
}

indexer = FaissHNSWIndexer(index_config)

# `ids` (10, 11, 12) and the normalized `xb` come from the earlier cells.
# The same three vectors are added twice, the second time under ids 20, 21, 22, so the index
# ends up with 6 entries and every vector has an exact duplicate under a different id.
indexer.index(ids=[ids], data_chunks=[xb])
indexer.index(ids=[ids+10], data_chunks=[xb])

hnsw_add_vertices: adding 3 elements on top of 0 (preset_levels=0)
  max_level = 3
Adding 2 elements at level 3
Adding 0 elements at level 2
Adding 0 elements at level 1
Adding 1 elements at level 0
Done in 0.388 ms
hnsw_add_vertices: adding 3 elements on top of 3 (preset_levels=0)
  max_level = 0
Adding 3 elements at level 0
Done in 0.062 ms


In [27]:
# Ask for as many neighbors (6) as there are entries in the index. Returns (scores, ids).
# In the recorded output the first two queries only get 4 real hits; the remaining slots are
# filled with id -1 and score -3.4028235e+38.
# NOTE(review): presumably a consequence of the very small efSearch / graph settings - TODO confirm
indexer.search(xb, top_n=6)

(array([[ 1.0000000e+00,  1.0000000e+00,  9.6886396e-01,  9.6386325e-01,
         -3.4028235e+38, -3.4028235e+38],
        [ 1.0000000e+00,  1.0000000e+00,  9.6386325e-01,  8.8938320e-01,
         -3.4028235e+38, -3.4028235e+38],
        [ 1.0000000e+00,  1.0000000e+00,  9.6886396e-01,  9.6886396e-01,
          8.8938320e-01,  8.8938320e-01]], dtype=float32),
 array([[10, 20, 12, 21, -1, -1],
        [11, 21, 10, 12, -1, -1],
        [22, 12, 20, 10, 11, 21]]))