In [1]:
import torch

from rqvae_data import get_data

# Load the dataset through the project helper. Later cells read the
# "embeddings", "asin" and "title" columns from it.
# NOTE(review): presumably a pandas DataFrame (columns are indexed by name and
# use .tolist()/.fillna()) — confirm against rqvae_data.get_data.
df = get_data()

In [10]:
# Collect the per-row entries of the "embeddings" column and stack them into a
# single tensor along a new leading dimension (one row per entry).
embs = torch.stack(list(df["embeddings"]))

In [None]:
# Display the dimensions of the stacked embedding tensor.
embs.size()

In [None]:
from rqvae import RQVAE

# Use the GPU when torch can see one, otherwise stay on the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)

# Build the model. The input width is read from the second dimension of the
# stacked embeddings; the remaining values are fixed hyperparameters.
rqvae = RQVAE(
    input_dim=embs.shape[1],
    hidden_dim=128,
    beta=0.25,
    codebook_sizes=[256, 256, 256, 256],
    should_init_codebooks=True,
    should_reinit_unused_clusters=False,
)
# Move the model to the selected device and keep whatever .to() hands back.
rqvae = rqvae.to(device)


# The model is fed a dict keyed by "embedding"; move the data to the same
# device the model was placed on.
embs_dict = {"embedding": embs.to(device)}

# Invoke the module itself instead of .forward(): calling the instance runs
# the same forward pass but also triggers any registered hooks, which a direct
# .forward() call silently skips.
# NOTE(review): assumes RQVAE is a torch.nn.Module (it is moved with
# .to(device)) — confirm.
rqvae(embs_dict)

In [4]:
from collisions import dedup
from rqvae_data import get_cb_tuples


# Materialise everything get_cb_tuples yields for the embedding tensor.
cb_tuples = [cb for cb in get_cb_tuples(rqvae, embs_dict["embedding"])]

# Missing titles are replaced by the placeholder "unknown" before pairing.
titles = df["title"].fillna("unknown")

# Pair each asin and title with its tuple positionally, then hand the list to
# the project's dedup helper.
# NOTE(review): zip stops at the shortest input, so this presumes one tuple per
# dataframe row — confirm against get_cb_tuples.
items_with_tuples = dedup(list(zip(df["asin"], titles, cb_tuples)))

In [None]:
from rqvae_data import search_similar_items


# Probe the one-element tuples (230,) .. (239,): query the helper for each and,
# when anything comes back, print the probed value followed by every returned
# item whose title mentions "nail".
# NOTE(review): the trailing 10 is presumably a result limit — confirm against
# search_similar_items.
for i in range(230, 240):
    sim = search_similar_items(items_with_tuples, (i,), 10)
    if len(sim) > 0:
        print(i)
        for asin, item, clust_tuple in sim:
            if "nail" not in item.lower():
                continue
            # The `=` specifier prints the variable names, so `item` and
            # `clust_tuple` must keep these exact names.
            print(f"{item=} {clust_tuple=}")

In [None]:
from collections import Counter
import matplotlib.pyplot as plt


# Count how many items share each tuple prefix, i.e. an item's last field with
# its final element dropped.
prefix_counts = Counter(item[-1][:-1] for item in items_with_tuples)

# Explicit figure/axes so the plot can be labelled and stands alone when the
# notebook is skimmed.
fig, ax = plt.subplots()
# Materialise the dict view so hist receives a plain sequence of ints.
ax.hist(list(prefix_counts.values()), bins=100)
ax.set(
    title="Items per tuple prefix",
    xlabel="Items sharing a prefix (last tuple element dropped)",
    ylabel="Number of prefixes",
)
plt.show()

In [None]:
# Number of distinct values found in the last field across all items.
len({item[-1] for item in items_with_tuples})

In [2]:
from sklearn import preprocessing

# Encode every asin as an integer id. The fitted encoder is kept in `le` so the
# mapping stays available to later cells.
le = preprocessing.LabelEncoder()
labels = df["asin"]
targets = le.fit_transform(labels)

# Adds the encoded ids as a new column on the shared dataframe (in place).
df["asin_numeric"] = targets

In [5]:
# Persist the dataframe object to ./all_data.pt via torch.save.
# NOTE(review): torch.save pickles non-tensor objects, and recent PyTorch
# releases default torch.load to weights_only=True, which will likely refuse a
# pickled DataFrame — readers may need weights_only=False; confirm against the
# installed version.
torch.save(obj=df, f="./all_data.pt")