In [2]:
import numpy as np
# Sizes of the synthetic problem.
d = 64        # dimensionality of every vector
nb = 100000   # number of database vectors
nq = 10000    # number of query vectors

# Seed NumPy's global RNG so both matrices below are identical on every run.
np.random.seed(1234)

# Database: nb uniform samples from [0, 1), stored as float32.
xb = np.random.random((nb, d)).astype(np.float32)
# Shift the first coordinate by row_index / 1000 so it grows with the row index.
xb[:, 0] += np.arange(nb) / 1000.0

# Queries: drawn next from the same RNG stream, with the same first-coordinate ramp.
xq = np.random.random((nq, d)).astype(np.float32)
xq[:, 0] += np.arange(nq) / 1000.0


In [16]:
xb.shape  # (number of database vectors, vector dimension)

(100000, 64)

In [5]:
!pip install faiss-gpu

Collecting faiss-gpu
  Downloading faiss_gpu-1.7.1.post2-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (89.7 MB)
     |████████████████████████████████| 89.7 MB 77.4 MB/s 
Installing collected packages: faiss-gpu
Successfully installed faiss-gpu-1.7.1.post2


In [9]:
import faiss                   # make faiss available
# Flat L2 index over d-dimensional vectors (per the Faiss docs, an exhaustive search).
index = faiss.IndexFlatL2(d)
# Show whether the index reports itself as trained before any vectors are added.
print(index.is_trained)
# Insert every database vector, then show how many vectors the index now holds.
index.add(xb)
print(index.ntotal)

True
100000


In [10]:
k = 4  # number of nearest neighbours returned per query

# Sanity check: query with the first five database vectors themselves, so each
# row's best match is expected to be its own index at distance 0.
D, I = index.search(xb[:5], k)
print(I)
print(D)

# Real search over the whole query set; D and I are rebound to these results.
D, I = index.search(xq, k)
print(I[:5])   # neighbour ids for the first five queries
print(I[-5:])  # neighbour ids for the last five queries

[[  0 393 363  78]
 [  1 555 277 364]
 [  2 304 101  13]
 [  3 173  18 182]
 [  4 288 370 531]]
[[0.        7.175174  7.2076287 7.251163 ]
 [0.        6.323565  6.684582  6.799944 ]
 [0.        5.7964087 6.3917365 7.2815127]
 [0.        7.277905  7.5279875 7.6628447]
 [0.        6.763804  7.295122  7.368814 ]]
[[ 381  207  210  477]
 [ 526  911  142   72]
 [ 838  527 1290  425]
 [ 196  184  164  359]
 [ 526  377  120  425]]
[[ 9900 10500  9309  9831]
 [11055 10895 10812 11321]
 [11353 11103 10164  9787]
 [10571 10664 10632  9638]
 [ 9628  9554 10036  9582]]


In [None]:
def pairwise_dist(xyz1, xyz2):
    """Squared Euclidean distance between every point of ``xyz2`` and ``xyz1``.

    Uses the expansion ``|b - a|^2 = |b|^2 - 2 * b.a + |a|^2`` so the whole
    table is produced by one matrix product instead of an explicit loop.

    Args:
        xyz1: tensor of shape ``(..., N, C)`` -- N points with C coordinates.
        xyz2: tensor of shape ``(..., M, C)`` -- M points with C coordinates.
            Leading (batch) dimensions are optional and must be broadcastable
            between the two inputs; both inputs must be at least 2-D.

    Returns:
        Tensor of shape ``(..., M, N)`` where entry ``[..., m, n]`` is the
        squared distance between ``xyz2[..., m, :]`` and ``xyz1[..., n, :]``.
        Note the orientation: rows follow ``xyz2``, columns follow ``xyz1``.

    Note:
        Because the result is assembled from three separately rounded terms,
        entries for (nearly) coincident points can come out as tiny negative
        numbers in floating point; clamp at zero before taking a square root.
    """
    # Squared norms along the coordinate axis; keepdim leaves a size-1 axis to broadcast.
    sq_norm1 = torch.sum(xyz1 * xyz1, dim=-1, keepdim=True)   # (..., N, 1)
    sq_norm2 = torch.sum(xyz2 * xyz2, dim=-1, keepdim=True)   # (..., M, 1)
    # Dot product of every xyz2 point with every xyz1 point.
    cross = torch.matmul(xyz2, xyz1.transpose(-1, -2))        # (..., M, N)
    # (..., M, 1) - (..., M, N) + (..., 1, N) broadcasts to (..., M, N).
    return sq_norm2 - 2 * cross + sq_norm1.transpose(-1, -2)