In [1]:
import numpy as np
import faiss

Brute-force k-nearest-neighbour (KNN) search using L2 (Euclidean) distance or the inner (dot) product, without building an index.

In [2]:
def search_knn(xq, xb, k, distance_type=faiss.METRIC_L2):
    """Brute-force k-nearest-neighbour search via the faiss knn functions, without an index.

    Parameters
    ----------
    xq : array-like, shape (nq, d)
        Query vectors. Converted to a C-contiguous float32 array if needed.
    xb : array-like, shape (nb, d)
        Database vectors. Converted to a C-contiguous float32 array if needed.
    k : int
        Number of neighbours to return per query.
    distance_type : int, optional
        ``faiss.METRIC_L2`` (default, dispatches to ``faiss.knn_L2sqr``) or
        ``faiss.METRIC_INNER_PRODUCT`` (dispatches to ``faiss.knn_inner_product``).

    Returns
    -------
    D : numpy.ndarray, shape (nq, k), dtype float32
        Distances / similarities written by the faiss knn function.
    I : numpy.ndarray, shape (nq, k), dtype int64
        Database ids written by the faiss knn function.

    Raises
    ------
    ValueError
        If ``distance_type`` is not one of the two supported metrics.
        (Previously this case silently returned uninitialised arrays.)
    """
    # swig_ptr hands faiss a raw pointer into the array buffer, so make sure
    # the memory layout and dtype are what the float* interface expects.
    xq = np.ascontiguousarray(xq, dtype='float32')
    xb = np.ascontiguousarray(xb, dtype='float32')

    nq, d = xq.shape
    nb, d2 = xb.shape
    assert d == d2, "xq and xb must have the same number of columns"

    # Pick the heap type and the matching faiss routine up front; everything
    # else is identical for both metrics.
    if distance_type == faiss.METRIC_L2:
        # max-heap: the largest of the kept distances sits on top to be evicted
        heaps = faiss.float_maxheap_array_t()
        knn = faiss.knn_L2sqr
    elif distance_type == faiss.METRIC_INNER_PRODUCT:
        # min-heap: the smallest of the kept similarities sits on top
        heaps = faiss.float_minheap_array_t()
        knn = faiss.knn_inner_product
    else:
        raise ValueError(
            "unsupported distance_type %r: expected faiss.METRIC_L2 or "
            "faiss.METRIC_INNER_PRODUCT" % (distance_type,)
        )

    # Output buffers; faiss writes its results directly into them.
    I = np.empty((nq, k), dtype='int64')
    D = np.empty((nq, k), dtype='float32')

    heaps.k = k
    heaps.nh = nq
    heaps.val = faiss.swig_ptr(D)
    heaps.ids = faiss.swig_ptr(I)

    knn(faiss.swig_ptr(xq), faiss.swig_ptr(xb), d, nq, nb, heaps)
    return D, I

IndexFlatL2 and IndexFlatIP both perform brute-force (exhaustive) search. The code below demonstrates this by running KNN manually with search_knn and asserting that its results match those returned by the index.

In [12]:
# Check search_knn (defined above) against IndexFlatL2.

# Use a locally seeded generator so the test vectors -- and therefore the
# neighbour ids displayed below -- are the same on every Restart & Run All.
rs = np.random.RandomState(0)
xb = rs.rand(200, 32).astype('float32')
xq = rs.rand(100, 32).astype('float32')

# IndexFlatL2 indexing is brute force
index = faiss.IndexFlatL2(32)
index.add(xb)
Dref, Iref = index.search(xq, 10)

Dnew, Inew = search_knn(xq, xb, 10)

# Same neighbours in the same order, and (numerically) the same distances.
assert np.all(Inew == Iref)
assert np.allclose(Dref, Dnew)


In [13]:
# Neighbour ids found by search_knn for the first five queries (L2 search above).
Inew[:5]

array([[ 54,   3,  57,  83, 195,  38,  76, 173,  13, 185],
       [ 54, 190,  23,   8, 158, 100, 114, 165, 107,  69],
       [ 73, 159,  47, 120, 147, 192, 100, 146, 111, 179],
       [108,  83, 126, 168, 138,  33, 160,  26, 173,   3],
       [197,  93, 105,  81, 121,  17,   2,  12, 193, 177]])

In [14]:
# IndexFlatIP is an exhaustive search as well; only the metric changes
# (inner product instead of L2). Reuses xb / xq from the previous check.
index = faiss.IndexFlatIP(xb.shape[1])
index.add(xb)
Dref, Iref = index.search(xq, 10)

Dnew, Inew = search_knn(
    xq, xb, 10,
    distance_type=faiss.METRIC_INNER_PRODUCT,
)

# The manual search must agree with the index on both ids and scores.
assert (Inew == Iref).all()
assert np.allclose(Dref, Dnew)

In [15]:
# Neighbour ids found by search_knn for the first five queries (inner-product search above).
Inew[:5]

array([[ 43, 173,   3,  17,  69,   4, 195,  83, 142,  54],
       [ 69,  43, 161,   6, 189,  65, 135,  54,   8,   1],
       [141,   4,  43, 128,  41,  98, 173,  65,   7,  73],
       [ 43, 173, 108, 178,  17, 128, 168,  65,  83,   3],
       [ 43, 142,  65, 121,  17, 197, 173,  69, 128,  14]])