From 08b12a0675b363aa94c1d566c042450b14b5f1e7 Mon Sep 17 00:00:00 2001
From: xzjin
Date: Mon, 30 Oct 2023 01:06:20 +0800
Subject: [PATCH 1/2] Fix: fix bug in ngcf of drop_adj

---
 recbole_gnn/model/general_recommender/ngcf.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/recbole_gnn/model/general_recommender/ngcf.py b/recbole_gnn/model/general_recommender/ngcf.py
index 7ede56e..37930c5 100644
--- a/recbole_gnn/model/general_recommender/ngcf.py
+++ b/recbole_gnn/model/general_recommender/ngcf.py
@@ -85,6 +85,9 @@ def forward(self):
                                       sparse_sizes=(self.n_users + self.n_items, self.n_users + self.n_items))
             edge_index = edge_index.t()
             edge_weight = None
+        else:
+            edge_index, edge_weight = dropout_adj(edge_index=edge_index, edge_attr=edge_weight,
+                                                  p=self.node_dropout, training=self.training)
 
         all_embeddings = self.get_ego_embeddings()
         embeddings_list = [all_embeddings]

From 576c7580cb3d6e19773f00c4e4cc548b543036b9 Mon Sep 17 00:00:00 2001
From: xzjin
Date: Mon, 30 Oct 2023 01:07:16 +0800
Subject: [PATCH 2/2] FEA: add SSL4REC

---
 .../model/general_recommender/__init__.py |   1 +
 .../model/general_recommender/ssl4rec.py  | 163 ++++++++++++++++++
 recbole_gnn/properties/model/SSL4REC.yaml |   6 +
 tests/test_model.py                       |   6 +
 4 files changed, 176 insertions(+)
 create mode 100644 recbole_gnn/model/general_recommender/ssl4rec.py
 create mode 100644 recbole_gnn/properties/model/SSL4REC.yaml

diff --git a/recbole_gnn/model/general_recommender/__init__.py b/recbole_gnn/model/general_recommender/__init__.py
index d687d32..8f8553c 100644
--- a/recbole_gnn/model/general_recommender/__init__.py
+++ b/recbole_gnn/model/general_recommender/__init__.py
@@ -7,3 +7,4 @@
 from recbole_gnn.model.general_recommender.simgcl import SimGCL
 from recbole_gnn.model.general_recommender.xsimgcl import XSimGCL
 from recbole_gnn.model.general_recommender.directau import DirectAU
+from recbole_gnn.model.general_recommender.ssl4rec import SSL4REC
diff --git a/recbole_gnn/model/general_recommender/ssl4rec.py b/recbole_gnn/model/general_recommender/ssl4rec.py
new file mode 100644
index 0000000..edd305e
--- /dev/null
+++ b/recbole_gnn/model/general_recommender/ssl4rec.py
@@ -0,0 +1,163 @@
+r"""
+SSL4REC
+################################################
+Reference:
+    Tiansheng Yao et al. "Self-supervised Learning for Large-scale Item Recommendations." in CIKM 2021.
+
+Reference code:
+    https://github.com/Coder-Yu/SELFRec/model/graph/SSL4Rec.py
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from recbole.model.loss import EmbLoss
+from recbole.utils import InputType
+
+from recbole.model.init import xavier_uniform_initialization
+from recbole_gnn.model.abstract_recommender import GeneralGraphRecommender
+
+
+class SSL4REC(GeneralGraphRecommender):
+    input_type = InputType.PAIRWISE
+
+    def __init__(self, config, dataset):
+        super(SSL4REC, self).__init__(config, dataset)
+
+        # load parameters info
+        self.tau = config["tau"]
+        self.reg_weight = config["reg_weight"]
+        self.cl_rate = config["ssl_weight"]
+        self.require_pow = config["require_pow"]
+
+        self.reg_loss = EmbLoss()
+
+        self.encoder = DNN_Encoder(config, dataset)
+
+        # storage variables for full sort evaluation acceleration
+        self.restore_user_e = None
+        self.restore_item_e = None
+
+        # parameters initialization
+        self.apply(xavier_uniform_initialization)
+        self.other_parameter_name = ['restore_user_e', 'restore_item_e']
+
+    def forward(self, user, item):
+        user_e, item_e = self.encoder(user, item)
+        return user_e, item_e
+
+    def calculate_batch_softmax_loss(self, user_emb, item_emb, temperature):
+        user_emb, item_emb = F.normalize(user_emb, dim=1), F.normalize(item_emb, dim=1)
+        pos_score = (user_emb * item_emb).sum(dim=-1)
+        pos_score = torch.exp(pos_score / temperature)
+        ttl_score = torch.matmul(user_emb, item_emb.transpose(0, 1))
+        ttl_score = torch.exp(ttl_score / temperature).sum(dim=1)
+        loss = -torch.log(pos_score / ttl_score + 10e-6)
+        return torch.mean(loss)
+
+    def calculate_loss(self, interaction):
+        # clear the storage variable when training
+        if self.restore_user_e is not None or self.restore_item_e is not None:
+            self.restore_user_e, self.restore_item_e = None, None
+
+        user = interaction[self.USER_ID]
+        pos_item = interaction[self.ITEM_ID]
+
+        user_embeddings, item_embeddings = self.forward(user, pos_item)
+
+        rec_loss = self.calculate_batch_softmax_loss(user_embeddings, item_embeddings, self.tau)
+        cl_loss = self.encoder.calculate_cl_loss(pos_item)
+        reg_loss = self.reg_loss(user_embeddings, item_embeddings, require_pow=self.require_pow)
+
+        loss = rec_loss + self.cl_rate * cl_loss + self.reg_weight * reg_loss
+
+        return loss
+
+    def predict(self, interaction):
+        user = interaction[self.USER_ID]
+        item = interaction[self.ITEM_ID]
+
+        user_embeddings, item_embeddings = self.forward(user, item)
+
+        u_embeddings = user_embeddings[user]
+        i_embeddings = item_embeddings[item]
+        scores = torch.mul(u_embeddings, i_embeddings).sum(dim=1)
+        return scores
+
+    def full_sort_predict(self, interaction):
+        user = interaction[self.USER_ID]
+        if self.restore_user_e is None or self.restore_item_e is None:
+            self.restore_user_e, self.restore_item_e = self.forward(torch.arange(
+                self.n_users, device=self.device), torch.arange(self.n_items, device=self.device))
+        # get user embedding from storage variable
+        u_embeddings = self.restore_user_e[user]
+
+        # dot with all item embedding to accelerate
+        scores = torch.matmul(u_embeddings, self.restore_item_e.transpose(0, 1))
+
+        return scores.view(-1)
+
+
+class DNN_Encoder(nn.Module):
+    def __init__(self, config, dataset):
+        super(DNN_Encoder, self).__init__()
+
+        self.emb_size = config["embedding_size"]
+        self.drop_ratio = config["drop_ratio"]
+        self.tau = config["tau"]
+
+        self.USER_ID = config["USER_ID_FIELD"]
+        self.ITEM_ID = config["ITEM_ID_FIELD"]
+        self.n_users = dataset.num(self.USER_ID)
+        self.n_items = dataset.num(self.ITEM_ID)
+
+        self.user_tower = nn.Sequential(
+            nn.Linear(self.emb_size, 1024),
+            nn.ReLU(True),
+            nn.Linear(1024, 128),
+            nn.Tanh()
+        )
+        self.item_tower = nn.Sequential(
+            nn.Linear(self.emb_size, 1024),
+            nn.ReLU(True),
+            nn.Linear(1024, 128),
+            nn.Tanh()
+        )
+        self.dropout = nn.Dropout(self.drop_ratio)
+
+        self.initial_user_emb = nn.Embedding(self.n_users, self.emb_size)
+        self.initial_item_emb = nn.Embedding(self.n_items, self.emb_size)
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        nn.init.xavier_uniform_(self.initial_user_emb.weight)
+        nn.init.xavier_uniform_(self.initial_item_emb.weight)
+
+    def forward(self, q, x):
+        q_emb = self.initial_user_emb(q)
+        i_emb = self.initial_item_emb(x)
+
+        q_emb = self.user_tower(q_emb)
+        i_emb = self.item_tower(i_emb)
+
+        return q_emb, i_emb
+
+    def item_encoding(self, x):
+        i_emb = self.initial_item_emb(x)
+        i1_emb = self.dropout(i_emb)
+        i2_emb = self.dropout(i_emb)
+
+        i1_emb = self.item_tower(i1_emb)
+        i2_emb = self.item_tower(i2_emb)
+
+        return i1_emb, i2_emb
+
+    def calculate_cl_loss(self, idx):
+        x1, x2 = self.item_encoding(idx)
+        x1, x2 = F.normalize(x1, dim=-1), F.normalize(x2, dim=-1)
+        pos_score = (x1 * x2).sum(dim=-1)
+        pos_score = torch.exp(pos_score / self.tau)
+        ttl_score = torch.matmul(x1, x2.transpose(0, 1))
+        ttl_score = torch.exp(ttl_score / self.tau).sum(dim=1)
+        return -torch.log(pos_score / ttl_score).mean()
diff --git a/recbole_gnn/properties/model/SSL4REC.yaml b/recbole_gnn/properties/model/SSL4REC.yaml
new file mode 100644
index 0000000..d249032
--- /dev/null
+++ b/recbole_gnn/properties/model/SSL4REC.yaml
@@ -0,0 +1,6 @@
+embedding_size: 64
+drop_ratio: 0.1
+tau: 0.1
+reg_weight: 1e-04
+ssl_weight: 1e-05
+require_pow: True
\ No newline at end of file
diff --git a/tests/test_model.py b/tests/test_model.py
index cbdc2a3..ebf578b 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -79,6 +79,12 @@ def test_directau(self):
         }
         quick_test(config_dict)
 
+    def test_ssl4rec(self):
+        config_dict = {
+            'model': 'SSL4REC'
+        }
+        quick_test(config_dict)
+
 
 class TestSequentialRecommender(unittest.TestCase):
     def test_gru4rec(self):
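
The added test only exercises the model end to end through quick_test. As a supplementary illustration, not part of the patch series itself, the sketch below shows one way to train and evaluate SSL4REC once the series is applied. It assumes RecBole-GNN's quick-start helper run_recbole_gnn in recbole_gnn.quick_start (as documented in the project README) and the ml-100k benchmark dataset; hyperparameter defaults come from the SSL4REC.yaml added above, and the config_dict keys mirror what ssl4rec.py reads from the config.

# Usage sketch, assuming recbole_gnn.quick_start.run_recbole_gnn is available
# as in the RecBole-GNN README; defaults are loaded from SSL4REC.yaml.
from recbole_gnn.quick_start import run_recbole_gnn

run_recbole_gnn(
    model='SSL4REC',
    dataset='ml-100k',
    # optional overrides of the shipped YAML defaults
    config_dict={'tau': 0.1, 'ssl_weight': 1e-05, 'drop_ratio': 0.1},
)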