In [1]:
import numpy as np
import pickle as pkl

In [2]:
path = "./final_data.pkl"

# NOTE(security): unpickling can execute arbitrary code; only load this file
# if it comes from a trusted source.
# The context manager closes the handle even if pkl.load raises.
with open(path, 'rb') as file:
    data = pkl.load(file)

In [3]:
# Item ids as an int64 array so they can be used directly as row indices.
item_ids = np.array(data['item_id'], dtype=np.int64)

# Embeddings cast to float32 (the dtype RQKMeans.check requires), with rows
# gathered in the order given by item_ids: X[i] is embedding row item_ids[i].
# NOTE(review): whether position i of X corresponds to item id i downstream
# depends on how final_data.pkl was produced - confirm.
X = np.take(np.array(data['embedding'], dtype=np.float32), item_ids, axis=0)

In [4]:
# Diagnostic: index item_ids by itself and show the first 20 entries.
# NOTE(review): the recorded output is not 0..19 in order (e.g. 15, 12, 11, 14, 13),
# so item_ids is neither the identity nor a self-inverse permutation - confirm
# that the row order used to build X is the intended one.
item_ids[item_ids][:20]

array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 15, 12, 11, 14, 13, 16,
       17, 18, 19])

In [5]:
import numpy as np
from sklearn.cluster import KMeans


class RQKMeans:
    """Residual-quantisation k-means.

    Holds a chain of ``num_ids`` KMeans models. The first model clusters the
    input; every following model clusters the residuals (input minus the
    assigned centroid) left by the model before it.
    """

    def __init__(self, num_clusters, num_ids, max_iter=300, random_state=42):
        """Create one KMeans model per level.

        Args:
            num_clusters: number of centroids used at every level.
            num_ids: number of levels, i.e. codes produced per input row.
            max_iter: ``max_iter`` forwarded to each KMeans model.
            random_state: base seed; level ``i`` is seeded with
                ``random_state + i``.
        """
        self.models = []
        for level in range(num_ids):
            self.models.append(
                KMeans(
                    n_clusters=num_clusters,
                    init='k-means++',
                    max_iter=max_iter,
                    random_state=random_state + level,
                    verbose=0,
                )
            )

    @staticmethod
    def check(X):
        """Return ``X`` unchanged if it is a float32 ``np.ndarray``.

        Raises:
            RuntimeError: if ``X`` is not an ndarray or its dtype is not float32.
        """
        if not isinstance(X, np.ndarray) or X.dtype != np.float32:
            raise RuntimeError(f"Wrong input type ({type(X)}) or dtype!")
        return X

    def fit(self, X, y=None):
        """Fit the levels one after another and return ``self``.

        ``y`` is accepted but never read. The caller's ``X`` is not modified:
        each residual is a newly allocated array.
        """
        self.check(X)
        residual = X
        for model in self.models:
            labels = model.fit_predict(residual)
            # What this level could not explain becomes the next level's input.
            residual = residual - model.cluster_centers_[labels]
        return self

    def predict(self, X, return_residuals=False):
        """Assign one code per level to every row of ``X``.

        Args:
            X: float32 ndarray of inputs.
            return_residuals: when true, also return what is left of ``X``
                after subtracting the chosen centroid of every level.

        Returns:
            The per-level labels stacked along the last axis, or a tuple
            ``(codes, residuals)`` when ``return_residuals`` is true.
        """
        self.check(X)
        codes = []
        residual = X
        for model in self.models:
            level_codes = model.predict(residual)
            codes.append(level_codes)
            residual = residual - model.cluster_centers_[level_codes]

        stacked = np.stack(codes, axis=-1)
        return (stacked, residual) if return_residuals else stacked

In [6]:
# Three levels (num_ids=3) of 256 centroids each; max_iter=1000 is passed to every KMeans model.
algo = RQKMeans(num_clusters=256, num_ids=3, max_iter=1000)

In [7]:
# Fit the levels in order; each level is trained on the residuals left by the previous one.
# fit returns the estimator itself, which is why its repr is displayed below.
algo.fit(X)

<__main__.RQKMeans at 0x7fbe030bb800>

In [8]:
# Per-level codes for every row of X, plus what remains of X after subtracting
# the chosen centroid of every level.
cluster, residuals = algo.predict(X, return_residuals=True)

In [9]:
# Shapes of the code matrix (one column per level) and of the residual matrix.
cluster.shape, residuals.shape

((12101, 3), (12101, 768))

In [10]:
# Mean over rows of ||residual|| / ||x||: the relative error left after all levels.
# NOTE(review): a row of X with zero norm would make this inf/nan.
(np.linalg.norm(residuals, axis=-1) / np.linalg.norm(X, axis=-1)).mean()

np.float32(0.23402064)

In [11]:
# Codes assigned to the first row of X (one entry per level).
cluster[0]

array([165, 114,  10], dtype=int32)

In [12]:
from collections import defaultdict

# inder:     str(row index) -> ['<a_*>', '<b_*>', '<c_*>'] tokens, one per level.
# sem_2_ids: (a, b, c) code triple -> keys of all rows that received that triple.
inder = {}
sem_2_ids = defaultdict(list)

# tolist() yields plain Python ints, so the tokens and the tuple keys are the
# same as when converting each numpy scalar with int().
for idx, (fst, snd, trd) in enumerate(cluster.tolist()):
    key = str(idx)
    inder[key] = [f'<a_{fst}>', f'<b_{snd}>', f'<c_{trd}>']
    sem_2_ids[(fst, snd, trd)].append(key)

In [15]:
# Spot check: the token list stored under key '12100'.
inder['12100']

['<a_205>', '<b_8>', '<c_27>']

In [16]:
# Rows that share the same (a, b, c) triple get a distinct fourth token '<d_k>'
# so that every row ends up with a unique token sequence.
NUM_COLLISION_TOKENS = 256

# Seeded generator: without it the collision tokens, and therefore index.json,
# would differ on every run.
collision_rng = np.random.default_rng(42)

# The loop variable is deliberately not called `item_ids`: that name is the
# int64 index array defined earlier and must not be overwritten.
for semantics, colliding_ids in sem_2_ids.items():
    assert len(colliding_ids) <= NUM_COLLISION_TOKENS, (
        f"{len(colliding_ids)} rows share codes {semantics}; "
        f"only {NUM_COLLISION_TOKENS} collision tokens are available"
    )
    # Distinct values in [0, NUM_COLLISION_TOKENS), one per colliding row.
    collision_solvers = collision_rng.permutation(NUM_COLLISION_TOKENS)[:len(colliding_ids)].tolist()
    for item_id, collision_solver in zip(colliding_ids, collision_solvers):
        # Rebuild from the three level tokens instead of appending, so that
        # re-running this cell does not stack a second '<d_*>' token.
        inder[item_id] = inder[item_id][:3] + [f'<d_{collision_solver}>']

In [17]:
import json

# Serialise the row-key -> token-list mapping and write it out in one go.
index_json = json.dumps(inder)
with open('../data/Beauty/index.json', 'w') as f:
    f.write(index_json)