In [1]:
import numpy as np

import faiss

import matplotlib.pyplot as plt

# Notebook-wide plotting defaults: "bmh" style sheet, 8x6 figures and the
# "Blues" colormap for image-type plots.
plt.style.use("bmh")
plt.rcParams.update(
    {
        "figure.figsize": (8, 6),
        "image.cmap": "Blues",
    }
)

In [2]:
# Load the precomputed embedding matrix from disk.
# NOTE(review): judging by the file name, each row is presumably the mean
# word vector of one document — confirm against the code that produced it.
embedding = np.load("../../data/vast/model_50000_1/mean_doc_vec_matrix.npy")

# Print the matrix dimensions, then show the (truncated) array as the cell output.
print(embedding.shape)
embedding

(50000, 300)


array([[ 0.03320654, -0.06433736,  0.10389966, ..., -0.05451681,
        -0.04246792,  0.05446664],
       [ 0.02599476,  0.01575095,  0.10604253, ..., -0.06802855,
        -0.00031488, -0.00361113],
       [-0.05472623,  0.05246948,  0.09704033, ...,  0.03371023,
         0.01643197,  0.03046615],
       ...,
       [ 0.02957482, -0.02850348, -0.01619003, ..., -0.05450831,
        -0.0277775 ,  0.01814242],
       [ 0.02957482, -0.02850348, -0.01619003, ..., -0.05450831,
        -0.0277775 ,  0.01814242],
       [ 0.03422622,  0.0448886 , -0.01871247, ...,  0.03861113,
        -0.04534175,  0.01759292]])

In [3]:
# Number of neighbours wanted per point; one extra hit is requested below
# because every query vector is itself stored in the index.
k = 100

# FAISS works on C-contiguous float32 data, so make a private copy in that layout.
# NOTE: this cell rebinds and then mutates `embedding` in place, so it is not
# idempotent — re-run the loading cell before re-running this one.
embedding = np.array(embedding, dtype=np.float32, order="C")

# Centre every dimension on zero, then scale each row to unit L2 norm (in place).
embedding -= embedding.mean(axis=0, keepdims=True)
faiss.normalize_L2(embedding)

# Exact (brute-force) inner-product index; on unit-length vectors the inner
# product equals the cosine similarity.
index = faiss.IndexFlatIP(embedding.shape[-1])
index.add(embedding)

# Query the index with the very vectors it stores: one row of scores and one
# row of indices per point, each with k + 1 columns.
similarities, nearest_neighbors = index.search(embedding, k + 1)

In [4]:
# Inner-product scores returned by index.search: one row per query vector and
# k + 1 columns. The vectors were L2-normalised before indexing, so these are
# cosine similarities; values marginally above 1.0 are float32 rounding.
similarities

array([[1.0000001 , 0.5019864 , 0.47723287, ..., 0.39278483, 0.3927035 ,
        0.39244723],
       [1.        , 0.55589545, 0.5531961 , ..., 0.4060866 , 0.4060866 ,
        0.40562934],
       [1.0000001 , 1.0000001 , 1.0000001 , ..., 0.75792146, 0.7563087 ,
        0.75582284],
       ...,
       [1.        , 1.        , 0.7898824 , ..., 0.64848334, 0.6473886 ,
        0.64727515],
       [1.        , 1.        , 0.7898824 , ..., 0.64848334, 0.6473886 ,
        0.64727515],
       [1.0000002 , 0.83116806, 0.83116806, ..., 0.33846855, 0.33783728,
        0.33608305]], dtype=float32)

In [5]:
# Row indices of the k + 1 hits for every query vector (0-based).
# Column 0 is NOT always the query's own index: in the output shown here,
# row 2 starts with 48620 — presumably exact duplicate vectors tie at
# similarity 1.0 and may be returned ahead of the query itself.
nearest_neighbors

array([[    0, 11286, 15191, ..., 30806, 12675, 32630],
       [    1, 15107, 47843, ..., 40567, 40566, 45659],
       [48620, 10867, 16873, ..., 16606, 39231, 36205],
       ...,
       [49997, 49998, 44074, ..., 47201,  4872,  8863],
       [49997, 49998, 44074, ..., 47201,  4872,  8863],
       [49999,  7131, 29709, ..., 14802, 21093, 34563]])

In [6]:
# Load the label array from `labels.npy` (path relative to the notebook's
# working directory) and display it.
# NOTE(review): presumably one label per embedding row, in the same order — confirm.
labels = np.load("labels.npy")
labels

array([1, 1, 1, ..., 1, 1, 1])

In [7]:
from scipy.io import savemat

# Remove every point's own index from its neighbour list before exporting.
#
# Simply dropping column 0 is not enough: when the data contains exact
# duplicates, several hits tie at similarity 1.0 and the query's own index is
# not guaranteed to come first (e.g. row 2's first hit is 48620). Slicing
# `[:, 1:]` would then discard a genuine neighbour and keep the point as its
# own neighbour. Instead, locate the self index in each row and drop exactly
# that entry.
row_ids = np.arange(nearest_neighbors.shape[0])[:, None]
self_match = nearest_neighbors == row_ids

# A point with more exact duplicates than requested hits may not appear in its
# own result row at all; drop the weakest (last) hit there so that every row
# loses exactly one entry and the result stays rectangular.
self_match[~self_match.any(axis=1), -1] = True

kept_columns = nearest_neighbors.shape[1] - 1
neighbors_without_self = nearest_neighbors[~self_match].reshape(-1, kept_columns)
similarities_without_self = similarities[~self_match].reshape(-1, kept_columns)

# Write labels, neighbour indices and their similarities to a MATLAB file.
savemat(
    "vast_nearest_neighbors_100.mat",
    {
        "labels": labels,
        # Shift by one, presumably so the indices are 1-based for MATLAB.
        "nearest_neighbors": neighbors_without_self + 1,
        "similarities": similarities_without_self,
    },
)