In [None]:
import numpy as np
import torch
import faiss

from misc import load_embeddings
from misc import ALL_N_CLUSTERS, ALL_ASSIGNMENTS_PATHS

In [None]:
# Load the precomputed embeddings and wrap them as a float32 tensor `X`,
# which the clustering loop further down consumes.
# torch.from_numpy shares memory with `embs`; the float32 cast only
# allocates a new buffer when `embs` is not already float32.
embs = load_embeddings()
X = torch.from_numpy(embs).to(dtype=torch.float32)

In [None]:
def _to_faiss_array(tensor: torch.Tensor) -> np.ndarray:
    """Convert a tensor into a row-L2-normalized float32 numpy array for faiss.

    faiss expects C-contiguous float32 arrays, and ``faiss.normalize_L2``
    rescales its argument in place. The data is always copied first, so the
    input tensor is never modified.

    Args:
        tensor: tensor whose rows are vectors. It is expected to be 2-D, since
            ``faiss.normalize_L2`` works row-wise. It may live on any device,
            have any real dtype (including bfloat16), and may carry autograd
            history.

    Returns:
        A new C-contiguous float32 array with the same shape as ``tensor``,
        each row L2-normalized by ``faiss.normalize_L2``.
    """
    # detach() avoids "Can't call numpy() on Tensor that requires grad", and
    # casting to float32 on the torch side also covers dtypes numpy cannot
    # represent (e.g. bfloat16).
    host = tensor.detach().to(device="cpu", dtype=torch.float32)
    # np.array(...) copies by default, so the in-place normalization below can
    # never write through to memory shared with `tensor`.
    array = np.array(host.numpy(), dtype=np.float32, order="C")
    faiss.normalize_L2(array)
    return array

def fit_spherical_kmeans(
        context: torch.Tensor,
        n_clusters: int=1024,
        **kmeans_kwargs,
):
    """Train spherical k-means on the rows of ``context`` with faiss.

    Args:
        context: tensor whose rows are the points to cluster. It is converted
            with ``_to_faiss_array`` (float32 copy, rows L2-normalized) before
            training; the tensor itself is not modified.
        n_clusters: number of centroids (faiss ``k``).
        **kmeans_kwargs: extra keyword arguments forwarded to ``faiss.Kmeans``.
            Examples are ``niter``, ``seed`` and ``max_points_per_centroid``;
            you may also override the defaults ``gpu=False`` and
            ``verbose=True``. ``spherical`` is always True for this function.

    Returns:
        ``centroids`` of the trained ``faiss.Kmeans`` object.

    Raises:
        ValueError: if ``spherical=False`` is requested.
    """
    if not kmeans_kwargs.pop("spherical", True):
        raise ValueError("fit_spherical_kmeans always uses spherical=True")
    # Named `points` rather than `X` so it does not shadow the notebook-level X.
    points = _to_faiss_array(context)
    params = {"gpu": False, "verbose": True, **kmeans_kwargs}
    km = faiss.Kmeans(d=points.shape[1], k=n_clusters, spherical=True, **params)
    km.train(points)
    return km.centroids

In [None]:
# For every configured cluster count: fit spherical k-means on X, assign each
# embedding to its nearest centroid, and save those assignments to the
# matching path.
if len(ALL_N_CLUSTERS) != len(ALL_ASSIGNMENTS_PATHS):
    # zip() would silently drop the unmatched tail and skip some runs.
    raise ValueError(
        f"ALL_N_CLUSTERS has {len(ALL_N_CLUSTERS)} entries but "
        f"ALL_ASSIGNMENTS_PATHS has {len(ALL_ASSIGNMENTS_PATHS)}"
    )

# The normalized float32 copy of X does not depend on the cluster count,
# so build it once instead of once per iteration.
X_normalized = _to_faiss_array(X)

for n_clusters, assignments_path in zip(ALL_N_CLUSTERS, ALL_ASSIGNMENTS_PATHS):
    print(f"Fitting spherical k-means with {n_clusters} clusters...")
    cluster_centers = fit_spherical_kmeans(
        X,
        n_clusters=n_clusters,
    )
    # faiss.knn uses L2 distance by default. With unit-norm rows and spherical
    # (normalized) centroids, the L2-nearest centroid is also the most
    # cosine-similar one.
    _, nearest_ids = faiss.knn(X_normalized, cluster_centers, k=1)
    # Take column 0 explicitly: .squeeze() would turn a single-row result
    # into a 0-d array.
    assignments = nearest_ids[:, 0]
    np.save(assignments_path, assignments)
print("Done.")