In [1]:
import torch

from rqvae_data import get_data

# Load the item table via the project helper. Later cells read its
# "embeddings" and "title" columns.
df = get_data()

In [2]:
# Collect the per-row embedding tensors and stack them along a new leading
# axis, giving one row of `embs` per row of `df`.
embs = torch.stack(list(df["embeddings"]), dim=0)

In [3]:
# Sanity check: first dimension is the number of rows in `df`, second is the
# embedding width (used as the model's input_dim in a later cell).
embs.shape

torch.Size([12101, 512])

In [4]:
from rqvae import RQVAE

# Prefer the GPU when one is available; otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# The input width follows the embedding width; codebook_sizes is a list of
# four entries of 256 (presumably one codebook per quantization level).
rqvae = RQVAE(
    input_dim=embs.shape[1],
    hidden_dim=128,
    beta=0.25,
    codebook_sizes=[256] * 4,
    should_init_codebooks=True,
    should_reinit_unused_clusters=False,
).to(device)


# The model is fed a dict batch keyed by "embedding"; the tensor is moved to
# the same device as the model.
embs_dict = {"embedding": embs.to(device)}

# Call the module instance rather than .forward() directly: __call__ runs any
# registered hooks before/after forward, whereas .forward() silently skips them
# (assumes RQVAE is a torch.nn.Module — TODO confirm).
rqvae(embs_dict)

100%|██████████| 4/4 [00:07<00:00,  2.00s/it]


{'loss': tensor(0.0057, device='cuda:0', grad_fn=<MeanBackward0>),
 'recon_loss': tensor(0.0052, device='cuda:0'),
 'rqvae_loss': tensor(0.0005, device='cuda:0'),
 'unique/0': 256,
 'unique/1': 256,
 'unique/2': 256,
 'unique/3': 256}

In [5]:
from rqvae_data import get_cb_tuples


# Ask the project helper for one codebook tuple per embedding.
cb_tuples = get_cb_tuples(rqvae, embs_dict["embedding"])
# Pair every title (missing titles become "unknown") with its tuple, in row order.
items_with_tuples = [
    (title, cb_tuple)
    for title, cb_tuple in zip(df["title"].fillna("unknown"), cb_tuples)
]

In [6]:
from rqvae_data import search_similar_items


# Probe the one-element query tuples (100,) .. (109,), asking the helper for
# 10 results each, and print only the queries that return something.
# (Presumably (i,) acts as a prefix on the codebook tuple — verify against
# search_similar_items.)
for i in range(100, 110):
    sim = search_similar_items(items_with_tuples, (i,), 10)
    if len(sim) > 0:
        print(i)
        for item, clust_tuple in sim:
            print(f"{item=} {clust_tuple=}")

100
item='Axe Anti-dandruff Styling Cream, 3.2 Ounce' clust_tuple=(100, 224, 126, 160)
item='Enjoy Texture Cream, 8.8-Ounce (Packaging may Vary)' clust_tuple=(100, 199, 72, 78)
item='Dove Hair Styling Oxygen Moisture Leave In Foam, 5.1 Ounce' clust_tuple=(100, 224, 126, 109)
item='Suave Professionals Natural Infusion Seaweed and Lotus Blossom Leave-in Foam, 5 Ounce' clust_tuple=(100, 224, 126, 109)
item='Motions Naturally You, Deep Conditioning Masque, 8 Ounce' clust_tuple=(100, 224, 126, 160)
item='Axe Styling Spiked Up Look Gel, 6 Ounce' clust_tuple=(100, 224, 126, 160)
item='Just For Me Texture Softener' clust_tuple=(100, 190, 126, 160)
item='Axe Styling Messy Look Matte Gel, 6 Ounce' clust_tuple=(100, 199, 126, 160)
item='Motions At Home Oil Moisturizer Hair Lotion, 12-Ounce Bottles (Pack of 6)' clust_tuple=(100, 199, 126, 160)
item='Redken Hair Cleansing Cream Shampoo for All Hair Types, 10.1-Ounces' clust_tuple=(100, 199, 72, 195)
101
item='Alba Botanica Even Advanced Eye Makeup 