In [None]:
from collections import defaultdict

import numpy as np
import json
import pickle

from sklearn.cluster import KMeans

In [None]:
# Input: pickle file opened by the loading cell; its 'item_id' and 'embedding' entries are read.
embeddings_input_path = '../data/Beauty/content_embeddings.pkl'
# Output: JSON file that the final cell writes the item -> semantic-code mapping to.
semantic_index_output_path = '../data/Beauty/index_rqkmeans.json'

In [None]:
# NOTE(security): pickle.load can execute arbitrary code while unpickling.
# Only point embeddings_input_path at a file that comes from a trusted source.
with open(embeddings_input_path, 'rb') as f:
    data = pickle.load(f)

item_ids = np.array(data['item_id'], dtype=np.int64)
X = np.array(data['embedding'], dtype=np.float32)

# Fail fast on malformed input. The index-building cell pairs ids with rows
# via zip(), which would silently truncate on a length mismatch, and keys its
# output dict by item id, so duplicated ids would end up with corrupted codes.
assert X.ndim == 2, 'embeddings must form a 2-D (items x dims) matrix'
assert item_ids.ndim == 1 and len(item_ids) == len(X), \
    'item ids and embeddings must have the same number of entries'
assert len(np.unique(item_ids)) == len(item_ids), 'item ids must be unique'

In [None]:
class RQKMeans:
    """Residual-quantization K-Means.

    Holds a sequence of KMeans models ("codebooks"). The first is fit on the
    input vectors; every following one is fit on what remains after the
    assigned centroid of the previous codebook has been subtracted.
    """

    def __init__(self, num_clusters, num_codebooks, init='k-means++',
                 max_iter=300, tol=1e-4, verbose=0, random_state=42):
        """Create ``num_codebooks`` unfitted KMeans models.

        ``num_clusters``, ``init``, ``max_iter``, ``tol`` and ``verbose`` are
        forwarded unchanged to every KMeans instance. The codebook at
        position ``i`` is seeded with ``random_state + i`` so that the
        codebooks do not share a seed.
        """
        self.models = []
        for level in range(num_codebooks):
            self.models.append(
                KMeans(
                    n_clusters=num_clusters,
                    init=init,
                    max_iter=max_iter,
                    tol=tol,
                    verbose=verbose,
                    random_state=random_state + level,
                )
            )

    def fit(self, X, y=None):
        """Fit the codebooks one after another on successive residuals.

        ``y`` is accepted but never read. The caller's ``X`` is not modified:
        residuals are computed into new arrays. Returns ``self``.
        """
        residual = X
        for quantizer in self.models:
            labels = quantizer.fit_predict(residual)
            # Remove the part explained by this codebook before the next one.
            residual = residual - quantizer.cluster_centers_[labels]
        return self

    def predict(self, X):
        """Encode ``X`` with the fitted codebooks.

        Each codebook labels the current residual, then its assigned centroid
        is subtracted before moving on. The per-codebook label arrays are
        stacked along a new last axis and returned.
        """
        residual = X
        codes = []
        for quantizer in self.models:
            labels = quantizer.predict(residual)
            codes.append(labels)
            residual = residual - quantizer.cluster_centers_[labels]
        return np.stack(codes, axis=-1)

In [None]:
# Three residual codebooks of 256 centroids each; max_iter is raised from the class default of 300 to 1000.
rq_kmeans = RQKMeans(num_clusters=256, num_codebooks=3, max_iter=1000)

In [None]:
# Fit the codebooks in order; each one after the first is trained on the residuals left by the previous one.
rq_kmeans.fit(X)

In [None]:
# Per-codebook cluster indices for X, stacked along the last axis (one entry per codebook).
clusters = rq_kmeans.predict(X)

In [None]:
# Create semantics mapping: item id -> list of codebook indices, plus the
# reverse lookup from a code tuple to every item id that received it.
# Loop variables deliberately do not reuse the names `item_ids` / `clusters`,
# so those arrays survive this cell and the cell can be re-run safely.
inter = {}
sem_2_ids = defaultdict(list)
for item_id, item_codes in zip(item_ids, clusters):
    codes = item_codes.tolist()
    inter[int(item_id)] = codes
    sem_2_ids[tuple(codes)].append(int(item_id))

# Solve collisions: items sharing the same code tuple each get one extra,
# distinct trailing index drawn without replacement from
# [0, collision_codebook_size).
collision_codebook_size = 256
# Dedicated seeded generator so the saved index is reproducible across runs.
collision_rng = np.random.default_rng(42)
for semantics, colliding_ids in sem_2_ids.items():
    assert len(colliding_ids) <= collision_codebook_size, (
        'more than %d items share semantic id %s'
        % (collision_codebook_size, semantics)
    )
    collision_solvers = collision_rng.permutation(
        collision_codebook_size
    )[:len(colliding_ids)].tolist()
    for item_id, collision_solver in zip(colliding_ids, collision_solvers):
        inter[item_id].append(collision_solver)

# Save semantics (json.dump writes the integer item ids as string keys).
with open(semantic_index_output_path, 'w') as f:
    json.dump(inter, f)