In [None]:
import sys

sys.path.append("..")
import argparse
import warnings
from logging import getLogger

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from torch.utils.tensorboard import SummaryWriter
from mpl_toolkits.mplot3d import Axes3D
from recbole.utils import init_seed, set_color
from sklearn.decomposition import PCA

from config.configuration import Config
from data.dataset import GeneralDataset, GeneralGraphDataset
from data.utils import data_reparation
from models.embedding import (EmbeddingHelper, EmbeddingModel, EmbeddingType,
                              TemplateType)
from trainer import Trainer
from utils.logger import init_logger
from utils.utils import get_flops, get_model


In [None]:
# All projector data (embedding matrices + metadata.tsv) is written under this
# directory; view it with `tensorboard --logdir ./embeddings` (Projector tab).
EMBEDDING_LOG_DIR = "./embeddings"
writer = SummaryWriter(EMBEDDING_LOG_DIR)

In [None]:

# Experiment selection. parse_known_args (instead of parse_args) ignores any
# extra command-line arguments the Jupyter kernel process was started with.
parser = argparse.ArgumentParser()
parser.add_argument("--dataset", "-d", help="name of datasets", type=str, default="wsdream-tp")
# NOTE(review): "XXX" looks like a placeholder model name — confirm Config accepts it.
parser.add_argument("--model", "-m", help="name of models", type=str, default="XXX")
args, _ = parser.parse_known_args()

# Build the project config, load the graph dataset, and split it into
# train/test data via data_reparation.
config = Config(dataset=args.dataset, model=args.model)
dataset = GeneralGraphDataset(config)
train_data, test_data = data_reparation(config, dataset)


In [None]:
# Inspect the per-user "country" feature that is later used as projector labels.
user_country = dataset.user_feat["country"]
user_country

In [None]:
def get_pretrained_embedding(
    dataset,
    template_type: TemplateType,
    embedding_model: EmbeddingModel = EmbeddingModel.INSTRUCTOR_BGE_SMALL,
):
    """Encode every user and item that appears in the interactions with a text embedding model.

    For each id in ``dataset.uids_in_inter_feat`` / ``dataset.iids_in_inter_feat`` the
    records returned by ``dataset.inter_data_by_type("user"|"item", id)`` are collected
    into a dict keyed by that id and passed to ``EmbeddingHelper.fit`` with
    ``auto_save=False``.

    Args:
        dataset: Dataset exposing ``uids_in_inter_feat``, ``iids_in_inter_feat`` and
            ``inter_data_by_type(kind, id)``.
        template_type: Prompt template forwarded to ``EmbeddingHelper.fit``.
        embedding_model: Encoder forwarded to ``EmbeddingHelper.fit``. Defaults to
            ``EmbeddingModel.INSTRUCTOR_BGE_SMALL``, the model this function always used.

    Returns:
        ``(user_embedding, item_embedding)`` built with ``torch.Tensor`` (float32 under
        the default dtype). Presumably one row per id in the iteration order of the
        id collections — TODO(review): confirm against ``EmbeddingHelper.fit``.
    """
    helper = EmbeddingHelper()

    user_invocations = {
        uid: dataset.inter_data_by_type("user", uid) for uid in dataset.uids_in_inter_feat
    }
    item_invocations = {
        iid: dataset.inter_data_by_type("item", iid) for iid in dataset.iids_in_inter_feat
    }

    def _encode(embedding_type, invocations):
        # Single place for the fit call so the user and item paths cannot drift apart.
        return torch.Tensor(
            helper.fit(
                embedding_type,
                template_type,
                embedding_model,
                invocations=invocations,
                auto_save=False,
            )
        )

    user_embedding = _encode(EmbeddingType.USER, user_invocations)
    item_embedding = _encode(EmbeddingType.ITEM, item_invocations)
    return user_embedding, item_embedding

u_embedding, i_embedding = get_pretrained_embedding(
    train_data.dataset, TemplateType.IMPROVED
)

# Projector labels are plain row positions 0..N-1 — not necessarily the ids used
# as keys when the embeddings were built.
uids = [*range(len(u_embedding))]
iids = [*range(len(i_embedding))]

writer.add_embedding(u_embedding, tag="User Embeddings - IMPROVED", metadata=uids)
writer.add_embedding(i_embedding, tag="Item Embeddings - IMPROVED", metadata=iids)


In [None]:
# `get_pretrained_embedding` is defined in the previous cell and is intentionally
# not redefined here: a second `def` with the same name would silently shadow
# that version and let the two copies drift apart.

# NOTE(review): this re-encodes with the same IMPROVED template as the previous
# cell; that result could be reused to skip a second encoding pass.
u_embedding, i_embedding = get_pretrained_embedding(train_data.dataset, TemplateType.IMPROVED)

# Label each row with its "country" feature. `.tolist()` yields plain Python
# scalars for torch tensors, numpy arrays and pandas Series alike, whereas
# `list(tensor)` yields 0-d tensors that end up in metadata.tsv as "tensor(...)".
uids = dataset.user_feat["country"].tolist()
iids = dataset.item_feat["country"].tolist()

# add_embedding needs exactly one label per embedding row; fail with context
# instead of a bare assertion deep inside TensorBoard.
# NOTE(review): if user_feat/item_feat contain extra rows (e.g. a padding row) or
# are ordered differently from uids_in_inter_feat/iids_in_inter_feat, the labels
# must be aligned to the embedding rows explicitly.
if len(uids) != len(u_embedding) or len(iids) != len(i_embedding):
    raise ValueError(
        f"metadata/embedding size mismatch: users {len(uids)} labels vs "
        f"{len(u_embedding)} rows, items {len(iids)} labels vs {len(i_embedding)} rows"
    )

writer.add_embedding(u_embedding, metadata=uids, tag="User Embeddings - IMPROVED_country")
writer.add_embedding(i_embedding, metadata=iids, tag="Item Embeddings - IMPROVED_country")
