In [1]:
import faiss
import numpy as np


class FaissKMeans:
    
    # There are several examples of faiss-kmeans.
    # Some utilize only one core, some utilize all cores (like this one).
    # Some should utilize GPU, but i didn't succeed.
    # This one is the fastest from tested by me.
    
    def __init__(self, n_clusters=8, n_init=10, max_iter=30):
        self.n_clusters = n_clusters
        self.n_init = n_init
        self.max_iter = max_iter
        self.kmeans = None
        self.cluster_centers_ = None
        self.inertia_ = None

    def fit(self, X):
        self.kmeans = faiss.Kmeans(d=X.shape[1],
                                   k=self.n_clusters,
                                   niter=self.max_iter,
                                   nredo=self.n_init)
        self.kmeans.train(X.astype(np.float32))
        self.cluster_centers_ = self.kmeans.centroids
        self.inertia_ = self.kmeans.obj[-1]

    def predict(self, X):
        return self.kmeans.index.search(X.astype(np.float32), 1)[1]

In [2]:
x = np.random.rand(1000, 768)  # 1000 vectors with size of 768
x = x.astype(np.float32)
x.shape

(1000, 768)

In [3]:
kmeans = FaissKMeans()
kmeans.fit(x)
kmeans.predict(x)

array([[0],
       [6],
       [4],
       [7],
       [2],
       [6],
       [7],
       [5],
       [7],
       [5],
       [0],
       [3],
       [7],
       [7],
       [6],
       [5],
       [1],
       [4],
       [2],
       [7],
       [1],
       [4],
       [0],
       [1],
       [5],
       [3],
       [0],
       [4],
       [0],
       [2],
       [7],
       [4],
       [5],
       [6],
       [2],
       [0],
       [5],
       [4],
       [0],
       [2],
       [6],
       [2],
       [0],
       [5],
       [0],
       [4],
       [4],
       [0],
       [1],
       [2],
       [0],
       [7],
       [1],
       [2],
       [2],
       [1],
       [0],
       [7],
       [0],
       [6],
       [5],
       [7],
       [7],
       [7],
       [2],
       [6],
       [1],
       [7],
       [5],
       [4],
       [5],
       [4],
       [2],
       [7],
       [4],
       [7],
       [4],
       [4],
       [0],
       [4],
       [6],
       [0],
       [7],
    