In [None]:
import numpy as np
import pickle as pkl

In [None]:
path = "./final_data.pkl"

# NOTE(review): pickle.load can execute arbitrary code while deserializing;
# only load this file if it comes from a trusted source.
# The context manager guarantees the handle is closed even if unpickling raises.
with open(path, 'rb') as file:
    data = pkl.load(file)

In [None]:
# Typed arrays built from the loaded payload: integer item ids and float32
# embeddings (float32 is the dtype RQKMeans.check insists on).
item_ids = np.array(data['item_id'], dtype=np.int64)
X = np.array(data['embedding'], dtype=np.float32)

# Ids and embeddings are paired positionally when the index is built; without
# this check a length mismatch would be silently truncated by zip().
if item_ids.shape[0] != X.shape[0]:
    raise ValueError(
        f"item_id / embedding length mismatch: {item_ids.shape[0]} vs {X.shape[0]}"
    )

In [None]:
# Sanity check: display the dimensions of the embedding array.
X.shape

In [None]:
# Peek at the first 20 item ids.
item_ids[:20]

In [None]:
import numpy as np
from sklearn.cluster import KMeans


class RQKMeans:
    """Residual-quantization KMeans: a stack of KMeans models applied in sequence.

    Every level clusters what the previous level left over (the input minus the
    centroid it was assigned to), so each row ends up described by one cluster
    index per level.
    """

    def __init__(self, num_clusters, num_ids, max_iter=300, random_state=42):
        """Create ``num_ids`` unfitted KMeans levels.

        Args:
            num_clusters: number of clusters used by every level.
            num_ids: number of levels, i.e. cluster indices produced per row.
            max_iter: iteration cap handed to every level.
            random_state: base seed; level ``i`` is seeded with ``random_state + i``.
        """
        self.models = []
        for level in range(num_ids):
            quantizer = KMeans(
                n_clusters=num_clusters,
                init='k-means++',
                max_iter=max_iter,
                random_state=(random_state + level),
                verbose=0,
            )
            self.models.append(quantizer)

    @staticmethod
    def check(X):
        """Return ``X`` unchanged when it is a float32 ``np.ndarray``.

        Raises:
            RuntimeError: if ``X`` is not an ndarray or its dtype is not float32.
        """
        # Guard clause; the isinstance test short-circuits before .dtype is read.
        if not isinstance(X, np.ndarray) or X.dtype != np.float32:
            raise RuntimeError(f"Wrong input type ({type(X)}) or dtype!")
        return X

    def fit(self, X, y=None):
        """Fit the levels one after another on successive residuals.

        Args:
            X: float32 ndarray of rows to quantize (validated by ``check``).
            y: accepted for signature compatibility; its value is never read.

        Returns:
            ``self``.
        """
        self.check(X)
        residual = X
        for quantizer in self.models:
            assigned = quantizer.fit_predict(residual)
            # What this level could not explain becomes the next level's input.
            residual = residual - quantizer.cluster_centers_[assigned]
        return self

    def predict(self, X, return_residuals=False):
        """Assign each row one cluster index per level.

        Args:
            X: float32 ndarray of rows to encode (validated by ``check``).
            return_residuals: when true, also return what is left of ``X``
                after subtracting the assigned centroid of every level.

        Returns:
            The per-level indices stacked along the last axis, or a
            ``(indices, residuals)`` tuple when ``return_residuals`` is true.
        """
        self.check(X)
        residual = X
        per_level = []
        for quantizer in self.models:
            assigned = quantizer.predict(residual)
            per_level.append(assigned)
            residual = residual - quantizer.cluster_centers_[assigned]

        codes = np.stack(per_level, axis=-1)
        return (codes, residual) if return_residuals else codes

In [None]:
# Residual quantizer with 3 stacked KMeans levels, 256 clusters each,
# and at most 1000 iterations per level.
algo = RQKMeans(num_clusters=256, num_ids=3, max_iter=1000)

In [None]:
# Fit the levels in order; each one is trained on the residuals of the previous one.
algo.fit(X)

In [None]:
# One cluster index per level for every embedding; the residuals are not needed here.
clusters = algo.predict(X, return_residuals=False)

In [None]:
# Per-level indices are stacked along the last axis, so the last dimension
# equals the number of quantizer levels.
clusters.shape

In [None]:
# Cluster indices assigned to the first row.
clusters[0]

In [None]:
from collections import defaultdict

# inder:     item id (as str)          -> list of semantic tokens
# sem_2_ids: (level-1, level-2, level-3) cluster indices -> ids that share them
inder = {}
sem_2_ids = defaultdict(list)

for raw_id, codes in zip(item_ids, clusters):
    fst, snd, trd = codes
    key = str(raw_id)
    inder[key] = [f'<a_{fst}>', f'<b_{snd}>', f'<c_{trd}>']
    # Group ids by their full code triple so collisions can be resolved later.
    sem_2_ids[(int(fst), int(snd), int(trd))].append(key)

In [None]:
# Spot-check the tokens currently stored for item id '12100'.
inder['12100']

In [None]:
# Give every item a fourth token (<d_*>) that is unique among the items sharing
# the same semantic code, so the full token sequence identifies a single item.
NUM_COLLISION_SLOTS = 256
# Fixed seed so the exported index is reproducible across kernel restarts.
COLLISION_SEED = 42
collision_rng = np.random.default_rng(COLLISION_SEED)

# The loop variable is deliberately NOT called `item_ids`: that name is the
# global id array, and clobbering it breaks any re-run of the earlier cells.
for semantics, colliding_ids in sem_2_ids.items():
    assert len(colliding_ids) <= NUM_COLLISION_SLOTS, (
        f"{len(colliding_ids)} items share code {semantics}; "
        f"only {NUM_COLLISION_SLOTS} collision tokens are available"
    )
    # Distinct slots per group: a random permutation truncated to the group size.
    collision_solvers = collision_rng.permutation(NUM_COLLISION_SLOTS)[:len(colliding_ids)].tolist()
    for item_id, collision_solver in zip(colliding_ids, collision_solvers):
        # Drop any <d_*> token left by a previous run of this cell so that
        # re-running it replaces the token instead of stacking another one.
        semantic_tokens = [tok for tok in inder[item_id] if not tok.startswith('<d_')]
        inder[item_id] = semantic_tokens + [f'<d_{collision_solver}>']

In [None]:
import json

# Persist the item-id -> token-list mapping as a single JSON document.
with open('../data/Beauty/index.json', 'w') as index_file:
    json.dump(obj=inder, fp=index_file)