In [None]:
import numpy as np
import torch
import faiss

from misc import load_embeddings
from misc import ALL_N_CLUSTERS, ALL_ASSIGNMENTS_PATHS

In [2]:
# Load the embeddings through the project helper and expose them to the
# following cells as a float32 torch tensor (`from_numpy` shares memory with
# `embs` when no dtype conversion is needed).
embs = load_embeddings()
X = torch.from_numpy(embs).to(dtype=torch.float32)

In [3]:
def _to_faiss_array(tensor: torch.Tensor) -> np.ndarray:
    """Return a row-normalized float32 copy of ``tensor`` suitable for faiss.

    The result is always a fresh, C-contiguous ``float32`` array, so the
    in-place ``faiss.normalize_L2`` call below never modifies the caller's data.

    Args:
        tensor: Matrix with one vector per row (``faiss.normalize_L2`` expects
            a 2-D array). May be a ``torch.Tensor`` on any device, with or
            without ``requires_grad``; anything else is handed to ``np.array``.

    Returns:
        A new ``np.ndarray`` to which ``faiss.normalize_L2`` has been applied.
    """
    if isinstance(tensor, torch.Tensor):
        # detach() allows tensors that track gradients to be exported to numpy;
        # casting inside torch also covers dtypes numpy cannot represent
        # (e.g. bfloat16).
        tensor = tensor.detach().to(device="cpu", dtype=torch.float32).numpy()
    # np.array copies by default: one explicit C-ordered float32 copy replaces
    # the previous astype + ascontiguousarray pair and guarantees we never
    # normalize memory shared with the input.
    array = np.array(tensor, dtype=np.float32, order="C")
    faiss.normalize_L2(array)  # in place
    return array

def fit_spherical_kmeans(
        context: torch.Tensor,
        n_clusters: int=1024,
        **kmeans_kwargs,
):
    """Fit spherical k-means with faiss on CPU and return the centroids.

    The input is first passed through ``_to_faiss_array`` (float32 copy,
    L2-normalized rows) and then clustered with
    ``faiss.Kmeans(..., spherical=True)``.

    Args:
        context: Matrix of vectors to cluster, one vector per row.
        n_clusters: Number of centroids to fit.
        **kmeans_kwargs: Extra options forwarded to ``faiss.Kmeans``, e.g.
            ``niter``, ``nredo``, ``seed`` or ``max_points_per_centroid``.
            With faiss defaults, training subsamples the data (the log reads
            "Sampling a subset of 256 / 98348" for one cluster); raise
            ``max_points_per_centroid`` to train on more points. ``gpu`` and
            ``verbose`` default to ``False`` / ``True`` but may be overridden.

    Returns:
        ``km.centroids`` of the trained ``faiss.Kmeans`` object.

    Raises:
        ValueError: If ``spherical`` is passed with a falsy value; this
            function always fits spherical k-means.
    """
    if not kmeans_kwargs.get("spherical", True):
        raise ValueError("fit_spherical_kmeans always uses spherical=True")

    points = _to_faiss_array(context)
    # Caller options override the gpu/verbose defaults; spherical stays pinned.
    options = {"gpu": False, "verbose": True, **kmeans_kwargs, "spherical": True}
    km = faiss.Kmeans(d=points.shape[1], k=n_clusters, **options)
    km.train(points)
    return km.centroids

In [4]:
# Normalize the embeddings once: the converted matrix is identical for every
# cluster count, so rebuilding it inside the loop only repeats a full copy.
# (This keeps one extra float32 copy of X alive for the duration of the loop.)
X_unit = _to_faiss_array(X)

for n_clusters, assignments_path in zip(ALL_N_CLUSTERS, ALL_ASSIGNMENTS_PATHS):
    print(f"Fitting spherical k-means with {n_clusters} clusters...")
    cluster_centers = fit_spherical_kmeans(
        X,
        n_clusters=n_clusters,
    )
    # faiss.knn returns (distances, indices), each shaped (n_points, k).
    # NOTE(review): the default metric is L2; for unit-norm points and
    # centroids this should rank the same as cosine similarity - confirm the
    # centroids returned by spherical k-means are unit-norm.
    _, nearest = faiss.knn(X_unit, cluster_centers, k=1)
    # Take column 0 instead of squeeze() so the result is always 1-D, even
    # when there is only a single point.
    assignments = nearest[:, 0]
    np.save(assignments_path, assignments)
print("Done.")

Fitting spherical k-means with 1 clusters...
Sampling a subset of 256 / 98348 for training
Clustering 256 points in 1536D to 1 clusters, redo 1 times, 25 iterations
  Preprocessing in 0.04 s
  Iteration 24 (0.01 s, search 0.01 s): objective=79.5851 imbalance=1.000 nsplit=0       
Fitting spherical k-means with 2 clusters...
Sampling a subset of 512 / 98348 for training
Clustering 512 points in 1536D to 2 clusters, redo 1 times, 25 iterations
  Preprocessing in 0.04 s
  Iteration 24 (0.01 s, search 0.01 s): objective=172.977 imbalance=1.132 nsplit=0       
Fitting spherical k-means with 4 clusters...
Sampling a subset of 1024 / 98348 for training
Clustering 1024 points in 1536D to 4 clusters, redo 1 times, 25 iterations
  Preprocessing in 0.04 s
  Iteration 24 (0.02 s, search 0.02 s): objective=374.798 imbalance=1.025 nsplit=0       
Fitting spherical k-means with 8 clusters...
Sampling a subset of 2048 / 98348 for training
Clustering 2048 points in 1536D to 8 clusters, redo 1 times, 25

KeyboardInterrupt: 