Requesting Guidance on Extracting User and Item Embeddings #19

Open
niluminous opened this issue Apr 29, 2024 · 1 comment

Hello,

I am currently working on predicting ratings by calculating the dot product of user embeddings and item embeddings. I am considering using the Universal Sequence Representation module for user sequences and the MoE-enhanced adaptor for item embeddings.

However, I am unsure how to properly extract these embeddings from the UniSRec code. Could anyone provide some guidance or sample code on how to access these embeddings from the mentioned modules?

Thank you in advance for your help!

hyp1231 (Member) commented Jun 16, 2024

Hi, sorry for the late reply!

1. Loading existing checkpoints

Below is an example of loading UniSRec models (either pretrained or fine-tuned).

UniSRec/finetune.py

Lines 37 to 40 in 05aa5cb

checkpoint = torch.load(pretrained_file)
logger.info(f'Loading from {pretrained_file}')
logger.info(f'Transfer [{checkpoint["config"]["dataset"]}] -> [{dataset}]')
model.load_state_dict(checkpoint['state_dict'], strict=False)
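
To make this runnable end to end, a rough sketch of building the model and loading a checkpoint outside of finetune.py could look like the following. The prop file paths, the UniSRecDataset import path, the dataset name, and the checkpoint path are assumptions about the repo layout and your local setup, so please adapt them:

import torch
from recbole.config import Config
from recbole.data import data_preparation

from unisrec import UniSRec
from data.dataset import UniSRecDataset  # UniSRec's custom RecBole dataset class (import path assumed)

# prop files and dataset name are placeholders; use the ones you trained with
props = ['props/UniSRec.yaml', 'props/finetune.yaml']
config = Config(model=UniSRec, dataset='Scientific', config_file_list=props)
dataset = UniSRecDataset(config)
train_data, valid_data, test_data = data_preparation(config, dataset)

model = UniSRec(config, train_data.dataset).to(config['device'])

# load a pretrained or fine-tuned checkpoint (same as the snippet above)
checkpoint = torch.load('saved/UniSRec-xxx.pth')  # placeholder checkpoint path
model.load_state_dict(checkpoint['state_dict'], strict=False)
model.eval()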

2. Mapping external IDs to internal IDs

Here is an example of mapping the external user/item IDs (str, as stored in the .inter files) to internal IDs (int, the remapped IDs used inside the RecBole framework).

https://github.com/RUCAIBox/RecBole/blob/2b6e209372a1a666fe7207e6c2a96c7c3d49b427/run_example/case_study_example.py#L25
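
Concretely, with the dataset object built above, the mapping would look roughly like this (the raw ID strings below are made-up placeholders):

# external (raw) IDs, as stored in the .inter / .item files
external_item_ids = ['B00A2KMBGQ', 'B001G5ZTLS']  # placeholders
internal_item_ids = dataset.token2id(dataset.iid_field, external_item_ids)  # str -> int

external_user_ids = ['A2AEVDNLZU0NFW']            # placeholder
internal_user_ids = dataset.token2id(dataset.uid_field, external_user_ids)

# and back again: int -> str
raw_item_tokens = dataset.id2token(dataset.iid_field, internal_item_ids)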

3. Extracting representations

Please refer to the following code:

UniSRec/unisrec.py

Lines 174 to 187 in 05aa5cb

def full_sort_predict(self, interaction):
    item_seq = interaction[self.ITEM_SEQ]
    item_seq_len = interaction[self.ITEM_SEQ_LEN]
    item_emb_list = self.moe_adaptor(self.plm_embedding(item_seq))
    seq_output = self.forward(item_seq, item_emb_list, item_seq_len)
    test_items_emb = self.moe_adaptor(self.plm_embedding.weight)
    if self.train_stage == 'transductive_ft':
        test_items_emb = test_items_emb + self.item_embedding.weight
    seq_output = F.normalize(seq_output, dim=-1)
    test_items_emb = F.normalize(test_items_emb, dim=-1)
    scores = torch.matmul(seq_output, test_items_emb.transpose(0, 1))  # [B n_items]
    return scores
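
So, mirroring full_sort_predict, a rough sketch for pulling out the sequence (user) representations and the adapted item embeddings, and then the dot-product ratings, would be something like this (here interaction is assumed to be one batch from a RecBole dataloader, e.g. test_data, moved to the model's device):

import torch
import torch.nn.functional as F

model.eval()
with torch.no_grad():
    item_seq = interaction[model.ITEM_SEQ]
    item_seq_len = interaction[model.ITEM_SEQ_LEN]

    # user side: Universal Sequence Representation of the interaction sequence
    item_emb_list = model.moe_adaptor(model.plm_embedding(item_seq))
    user_emb = model.forward(item_seq, item_emb_list, item_seq_len)    # [B, H]

    # item side: MoE-enhanced adaptor applied to the PLM item embeddings
    item_emb = model.moe_adaptor(model.plm_embedding.weight)           # [n_items, H]
    if model.train_stage == 'transductive_ft':
        item_emb = item_emb + model.item_embedding.weight

    # dot product; both sides are L2-normalized in UniSRec, so these are cosine scores
    user_emb = F.normalize(user_emb, dim=-1)
    item_emb = F.normalize(item_emb, dim=-1)
    scores = torch.matmul(user_emb, item_emb.transpose(0, 1))          # [B, n_items]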
