In [1]:
import torch

from rqvae_data import get_data

df = get_data()

  df = torch.load("../data/df_with_embs.pt")


In [2]:
embs = torch.stack(df["embeddings"].tolist())

In [3]:
embs.shape

torch.Size([12101, 512])

In [4]:
import random

from rqvae import RQVAE

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


rqvae = RQVAE(
    input_dim=embs.shape[1],
    hidden_dim=128,
    beta=0.25,
    codebook_sizes=[256] * 4,
    should_init_codebooks=True,
    should_reinit_unused_clusters=False,
).to(device)


embs_dict = {"embedding": embs.to(device)}

rqvae.forward(embs_dict)

100%|██████████| 4/4 [00:08<00:00,  2.07s/it]


{'loss': tensor(0.0057, device='cuda:0', grad_fn=<MeanBackward0>),
 'recon_loss': tensor(0.0052, device='cuda:0'),
 'rqvae_loss': tensor(0.0005, device='cuda:0'),
 'unique/0': 256,
 'unique/1': 256,
 'unique/2': 256,
 'unique/3': 256}

In [5]:
def get_cb_tuples(embeddings):
    ind_lists = []
    for cb in rqvae.codebooks:
        dist = torch.cdist(rqvae.encoder(embeddings), cb)
        ind_lists.append(dist.argmin(dim=-1).cpu().numpy())

    return zip(*ind_lists)


def search_similar_items(items_with_tuples, clust2search, max_cnt=5):
    random.shuffle(items_with_tuples)
    cnt = 0
    similars = []
    for item, clust_tuple in items_with_tuples:
        if clust_tuple[: len(clust2search)] == clust2search:
            similars.append((item, clust_tuple))
            cnt += 1
        if cnt >= max_cnt:
            return similars
    return similars

In [6]:
cb_tuples = get_cb_tuples(embs_dict["embedding"])
items_with_tuples = list(zip(df["title"], cb_tuples))

In [10]:
for i in range(220, 230):
    sim = search_similar_items(items_with_tuples, (i,), 10)
    if len(sim) == 0:
        continue
    print(i)
    for item, clust_tuple in sim:
        print(f"{item=} {clust_tuple=}")

# TODO fix collisisons (remainder = last embedding, auto-increment 4th id)

220
item='Fairy Dust by Paris Hilton for Women - 3.4 Ounce EDP Spray' clust_tuple=(220, 212, 67, 88)
item='D &amp; G Light Blue By Dolce &amp; Gabbana For Men Eau De Toilette Spray, 4.2-Ounces' clust_tuple=(220, 212, 25, 49)
item='Halle Pure Orchid by Halle Berry Eau De Parfum Spray for Women, 1 Ounce' clust_tuple=(220, 212, 25, 88)
item='Armani Code By Giorgio Armani For Men. Eau De Toilette Spray 1.7 Ounces' clust_tuple=(220, 212, 25, 88)
item='Taj Sunset by Escada for Women, Eau de Toilette Spray, 3.4 Ounce' clust_tuple=(220, 212, 25, 200)
item='Tom Ford Black Orchid By Tom Ford For Women. Eau De Parfum Spray 3.4-Ounces' clust_tuple=(220, 212, 67, 88)
item='Beyonce Heat Rush by Beyonce, 3.4 Ounce' clust_tuple=(220, 62, 67, 200)
item='In Control Curious by Britney Spears for Women, Eau De Parfum Spray, 1.7 Ounce' clust_tuple=(220, 212, 25, 49)
item='Lacoste Style In Play By Lacoste For Men. Eau De Toilette Spray 1.6 Ounces' clust_tuple=(220, 212, 25, 88)
item='Very Irresistible Sensu

In [None]:
# 1 2 3 0
# 1 2 3 1
# 4 5 6 0/2
# 4 5 6 1/3

# Research last index aggregation

# 1) last index = KMeans(last residuals, n=|last codebook|) - collision
# 2) auto increment last index (check paper)
# 3) decoder
# 4) [(1 2 3), (1 2 3)] single item -> ok
# 4.1) several -> get embeddings -> score. softmax(collisions), torch.logsoftmax(logits) -> score -> argmax

In [None]:
# pos emb for item & codebook (000 111 222) - item
# codebook (012 012 012)
# splitting item ?

In [50]:
torch.save(df, "../data/df_with_embs.pt")

In [None]:
!ls -lh ../data