In [1]:
import json
import logging

import numpy as np
import torch
from time import time
from torch import optim
from tqdm import tqdm

from torch.utils.data import DataLoader

from embdatasets import EmbDataset
from models.rqvae import RQVAE
from trainer import Trainer
import os

In [5]:

dataset = "All_Beauty"
# ckpt_path = "/mnt/zhengbowen/rqvae_ckpt/BERT/mean-nokm/best_collision_model.pth"
# ckpt_path = "/mnt/zhengbowen/rqvae_ckpt/LLaMA/32d-nosk/best_collision_model.pth"
# ckpt_path = "/mnt/zhengbowen/rqvae_ckpt/LLaMA/arts-nosk/epoch_4749_collision_0.0568_model.pth"
# ckpt_path = "./results/Jul-03-2024_12-44-34/epoch_4999_collision_0.0202_model.pth"
# ckpt_path = "./results/Jul-03-2024_13-10-09/epoch_4899_collision_0.0495_model.pth"
ckpt_path = "./results_sk/Jul-04-2024_15-13-40/epoch_4949_collision_0.0888_model.pth"
output_dir = f"./ID_generation/preprocessing/processed/{dataset}/"
# output_file = "Games.bertindex.json"
output_file = f"{dataset}.index.json"
output_file = os.path.join(output_dir,output_file)
device = torch.device("cuda:1")

ckpt = torch.load(ckpt_path, map_location=torch.device('cpu'))
args = ckpt["args"]
print(args)
state_dict = ckpt["state_dict"]


data = EmbDataset(args.data_path)

model = RQVAE(in_dim=data.dim,
                  num_emb_list=args.num_emb_list,
                  e_dim=args.e_dim,
                  layers=args.layers,
                  dropout_prob=args.dropout_prob,
                  bn=args.bn,
                  loss_type=args.loss_type,
                  quant_loss_weight=args.quant_loss_weight,
                  kmeans_init=args.kmeans_init,
                  kmeans_iters=args.kmeans_iters,
                  sk_epsilons=args.sk_epsilons,
                  sk_iters=args.sk_iters,
                  )

model.load_state_dict(state_dict)
model = model.to(device)
model.eval()
# print(model)

data_loader = DataLoader(data,num_workers=1,
                             batch_size=1, shuffle=False,
                             pin_memory=True)
# print(data[0])
# print(data_loader[0][0])

indices_count = {}
all_indices = []
prefix = ["<a_{}>","<b_{}>","<c_{}>","<d_{}>","<e_{}>"]

for d in tqdm(data_loader):
    d = d.to(device)
    indices = model.get_indices(d,use_sk=False)
    print(d)
    indices = indices.view(-1, indices.shape[-1]).cpu().numpy()
    for index in indices:
        code = []
        for i, ind in enumerate(index):
            code.append(prefix[i].format(int(ind)))
        code_str = str(code)
        # print(code_str)
        if code_str in indices_count:
            code.append(prefix[-1].format(int(indices_count[code_str])))
            indices_count[code_str] += 1
        else:
            code.append(prefix[-1].format(int(0)))
            indices_count[code_str] = 1

        all_indices.append(code)
    break




Namespace(lr=0.001, epochs=5000, batch_size=1024, num_workers=4, eval_step=50, learner='AdamW', data_path='./ID_generation/preprocessing/processed/All_Beauty/All_Beauty.embeddings.npy', weight_decay=0.0001, dropout_prob=0.0, bn=False, loss_type='mse', kmeans_init=True, kmeans_iters=100, sk_epsilons=[0.0, 0.0, 0.003], sk_iters=50, device='cuda:0', num_emb_list=[256, 256, 256], e_dim=32, quant_loss_weight=1.0, layers=[2048, 1024, 512, 256, 128, 64], ckpt_dir='./results_sk')


  0%|          | 0/43978 [00:11<?, ?it/s]

tensor([[-1.9190e-02, -3.0011e-02,  1.3549e-02,  2.3046e-02, -1.2068e-02,
         -1.2244e-02, -3.1564e-02,  4.9798e-02,  5.1205e-02, -4.9068e-02,
          6.7359e-02, -7.0621e-02,  7.8102e-02,  1.8242e-02,  7.9031e-02,
          1.1171e-02,  5.8734e-02, -1.5851e-02,  3.4447e-03,  3.0281e-02,
          6.6221e-03, -5.8023e-02, -6.2288e-02,  4.4407e-03, -2.5553e-02,
          4.2189e-02,  3.9598e-02,  9.7310e-03, -5.8534e-02, -2.0238e-02,
          4.2026e-02,  4.5803e-03, -2.8158e-02,  1.0660e-01, -2.2199e-02,
          2.5602e-02, -4.3362e-03, -5.4305e-02, -1.1608e-02,  4.5118e-02,
          2.1546e-02, -5.3732e-02, -1.8426e-02,  6.4713e-02, -3.1941e-02,
          1.9826e-02,  9.3003e-03,  3.6995e-02, -4.1563e-02, -3.4768e-02,
         -7.5662e-02, -2.3196e-02, -4.2653e-02, -5.0087e-02, -2.9825e-02,
          2.8402e-03, -3.7418e-02,  2.0221e-02,  3.0557e-02, -2.3578e-02,
         -6.7398e-02,  3.0314e-04, -2.7991e-02, -9.0706e-03, -3.6811e-02,
          3.9954e-04, -2.0114e-02, -5.




In [3]:
# open ./RQ-VAE/ID_generation/preprocessing/processed/All_Beauty/All_Beauty.embeddings.npy 
embeddings = np.load(f"./ID_generation/preprocessing/processed/{dataset}/{dataset}.embeddings.npy")

In [4]:
embeddings[0]

array([-1.91897824e-02, -3.00107710e-02,  1.35494350e-02,  2.30460223e-02,
       -1.20679503e-02, -1.22436415e-02, -3.15639377e-02,  4.97984737e-02,
        5.12053519e-02, -4.90678363e-02,  6.73594326e-02, -7.06207827e-02,
        7.81023577e-02,  1.82422176e-02,  7.90307149e-02,  1.11709423e-02,
        5.87341636e-02, -1.58514753e-02,  3.44469002e-03,  3.02811619e-02,
        6.62207045e-03, -5.80227636e-02, -6.22883253e-02,  4.44067642e-03,
       -2.55533066e-02,  4.21887152e-02,  3.95983271e-02,  9.73097980e-03,
       -5.85339032e-02, -2.02379227e-02,  4.20258790e-02,  4.58025420e-03,
       -2.81580929e-02,  1.06597953e-01, -2.21989173e-02,  2.56021731e-02,
       -4.33616666e-03, -5.43050170e-02, -1.16084376e-02,  4.51182276e-02,
        2.15455145e-02, -5.37323803e-02, -1.84258986e-02,  6.47126883e-02,
       -3.19410972e-02,  1.98259577e-02,  9.30027943e-03,  3.69949639e-02,
       -4.15625460e-02, -3.47676501e-02, -7.56622329e-02, -2.31959280e-02,
       -4.26531695e-02, -

In [None]:
# print(all_indices)
print("All indices number: ",len(all_indices))
print("Max number of conflicts: ", max(indices_count.values()))

all_indices_dict = {}
for item, indices in enumerate(all_indices):
    all_indices_dict[item] = indices


ss =set()
for key in all_indices_dict:
    for t in all_indices_dict[key]:
        ss.add(t)
print(len(list(ss)))


with open(output_file, 'w') as fp:
    json.dump(all_indices_dict,fp)

# print(all_indices_dict[1])
# print(all_indices_dict[2])
# print(all_indices_dict[0])