From 31837e423ef3cf842ea285e96ff637e397201c90 Mon Sep 17 00:00:00 2001 From: xzjin Date: Fri, 20 Oct 2023 14:41:49 +0800 Subject: [PATCH 1/4] FEA: add DirectAU --- .../model/general_recommender/__init__.py | 1 + .../model/general_recommender/directau.py | 120 ++++++++++++++++++ recbole_gnn/properties/model/DirectAU.yaml | 7 + tests/test_model.py | 8 +- 4 files changed, 135 insertions(+), 1 deletion(-) create mode 100644 recbole_gnn/model/general_recommender/directau.py create mode 100644 recbole_gnn/properties/model/DirectAU.yaml diff --git a/recbole_gnn/model/general_recommender/__init__.py b/recbole_gnn/model/general_recommender/__init__.py index 2841ffb..d687d32 100644 --- a/recbole_gnn/model/general_recommender/__init__.py +++ b/recbole_gnn/model/general_recommender/__init__.py @@ -6,3 +6,4 @@ from recbole_gnn.model.general_recommender.lightgcl import LightGCL from recbole_gnn.model.general_recommender.simgcl import SimGCL from recbole_gnn.model.general_recommender.xsimgcl import XSimGCL +from recbole_gnn.model.general_recommender.directau import DirectAU diff --git a/recbole_gnn/model/general_recommender/directau.py b/recbole_gnn/model/general_recommender/directau.py new file mode 100644 index 0000000..3fc3be5 --- /dev/null +++ b/recbole_gnn/model/general_recommender/directau.py @@ -0,0 +1,120 @@ +# r""" +# DirectAU +# ################################################ +# Reference: +# Chenyang Wang et al. "Towards Representation Alignment and Uniformity in Collaborative Filtering." in KDD 2022. 
+ +# Reference code: +# https://github.com/THUwangcy/DirectAU +# """ + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from recbole.model.init import xavier_normal_initialization +from recbole.utils import InputType +from recbole.model.general_recommender import BPR +from recbole_gnn.model.general_recommender import LightGCN + +from recbole_gnn.model.abstract_recommender import GeneralGraphRecommender + + +class DirectAU(GeneralGraphRecommender): + input_type = InputType.POINTWISE + + def __init__(self, config, dataset): + super(DirectAU, self).__init__(config, dataset) + + # load parameters info + self.embedding_size = config['embedding_size'] + self.gamma = config['gamma'] + self.encoder_name = config['encoder'] + + # define encoder + if self.encoder_name == 'MF': + self.encoder = MFEncoder(config, dataset) + elif self.encoder_name == 'LightGCN': + self.encoder = LGCNEncoder(config, dataset) + else: + raise ValueError('Non-implemented Encoder.') + + # storage variables for full sort evaluation acceleration + self.restore_user_e = None + self.restore_item_e = None + + # parameters initialization + self.apply(xavier_normal_initialization) + + def forward(self, user, item): + user_e, item_e = self.encoder(user, item) + return F.normalize(user_e, dim=-1), F.normalize(item_e, dim=-1) + + @staticmethod + def alignment(x, y, alpha=2): + return (x - y).norm(p=2, dim=1).pow(alpha).mean() + + @staticmethod + def uniformity(x, t=2): + return torch.pdist(x, p=2).pow(2).mul(-t).exp().mean().log() + + def calculate_loss(self, interaction): + if self.restore_user_e is not None or self.restore_item_e is not None: + self.restore_user_e, self.restore_item_e = None, None + + user = interaction[self.USER_ID] + item = interaction[self.ITEM_ID] + + user_e, item_e = self.forward(user, item) + align = self.alignment(user_e, item_e) + uniform = self.gamma * (self.uniformity(user_e) + self.uniformity(item_e)) / 2 + + return align, uniform + + 
def predict(self, interaction): + user = interaction[self.USER_ID] + item = interaction[self.ITEM_ID] + # embeddings are reached through self.encoder (as in forward/full_sort_predict), not attributes set on DirectAU + user_e, item_e = self.encoder(user, item) + return torch.mul(user_e, item_e).sum(dim=1) + + def full_sort_predict(self, interaction): + user = interaction[self.USER_ID] + if self.encoder_name == 'LightGCN': + if self.restore_user_e is None or self.restore_item_e is None: + self.restore_user_e, self.restore_item_e = self.encoder.get_all_embeddings() + user_e = self.restore_user_e[user] + all_item_e = self.restore_item_e + else: + user_e = self.encoder.user_embedding(user) + all_item_e = self.encoder.item_embedding.weight + score = torch.matmul(user_e, all_item_e.transpose(0, 1)) + return score.view(-1) + + +class MFEncoder(BPR): + def __init__(self, config, dataset): + super(MFEncoder, self).__init__(config, dataset) + + def forward(self, user_id, item_id): + return super().forward(user_id, item_id) + + def get_all_embeddings(self): + user_embeddings = self.user_embedding.weight + item_embeddings = self.item_embedding.weight + return user_embeddings, item_embeddings + + +class LGCNEncoder(LightGCN): + def __init__(self, config, dataset): + super(LGCNEncoder, self).__init__(config, dataset) + + def forward(self, user_id, item_id): + user_all_embeddings, item_all_embeddings = self.get_all_embeddings() + u_embed = user_all_embeddings[user_id] + i_embed = item_all_embeddings[item_id] + return u_embed, i_embed + + def get_all_embeddings(self): + return super().forward() diff --git a/recbole_gnn/properties/model/DirectAU.yaml b/recbole_gnn/properties/model/DirectAU.yaml new file mode 100644 index 0000000..ba704f3 --- /dev/null +++ b/recbole_gnn/properties/model/DirectAU.yaml @@ -0,0 +1,7 @@ +embedding_size: 64 +encoder: "MF" # "MF" or "LightGCN" +gamma: 0.5 +weight_decay: 1e-6 +train_batch_size: 256 + +# n_layers: 3 # needed for LightGCN \ No newline at end of file diff --git a/tests/test_model.py b/tests/test_model.py index 
ebb04b7..cbdc2a3 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -60,7 +60,7 @@ def test_simgcl(self): 'model': 'SimGCL' } quick_test(config_dict) - + def test_xsimgcl(self): config_dict = { 'model': 'XSimGCL' @@ -73,6 +73,12 @@ def test_lightgcl(self): } quick_test(config_dict) + def test_directau(self): + config_dict = { + 'model': 'DirectAU' + } + quick_test(config_dict) + class TestSequentialRecommender(unittest.TestCase): def test_gru4rec(self): From 393927d87867b291eb52b611c14a5bfc8a4d1831 Mon Sep 17 00:00:00 2001 From: xzjin Date: Fri, 20 Oct 2023 14:44:30 +0800 Subject: [PATCH 2/4] FIX: fix hypertuning --- recbole_gnn/quick_start.py | 1 + 1 file changed, 1 insertion(+) diff --git a/recbole_gnn/quick_start.py b/recbole_gnn/quick_start.py index 0f61812..51897f6 100644 --- a/recbole_gnn/quick_start.py +++ b/recbole_gnn/quick_start.py @@ -80,6 +80,7 @@ def objective_function(config_dict=None, config_file_list=None, saved=True): test_result = trainer.evaluate(test_data, load_best_model=saved) return { + 'model': config['model'], 'best_valid_score': best_valid_score, 'valid_score_bigger': config['valid_metric_bigger'], 'best_valid_result': best_valid_result, From f7624600250927b8d2e4f030d99645f87a1ae562 Mon Sep 17 00:00:00 2001 From: xzjin Date: Fri, 20 Oct 2023 15:05:01 +0800 Subject: [PATCH 3/4] FIX: fix dataset saving dir error associated with recbole-1.1.1 --- recbole_gnn/data/dataset.py | 31 ++++++++++++++++++++++++++----- 1 file changed, 26 insertions(+), 5 deletions(-) diff --git a/recbole_gnn/data/dataset.py b/recbole_gnn/data/dataset.py index c9680c5..9dd3f7b 100644 --- a/recbole_gnn/data/dataset.py +++ b/recbole_gnn/data/dataset.py @@ -6,15 +6,33 @@ from tqdm import tqdm from torch_geometric.utils import degree + from recbole.data.dataset import SequentialDataset from recbole.data.dataset import Dataset as RecBoleDataset from recbole.utils import set_color, FeatureSource +import recbole +import pickle +from recbole.utils import ensure_dir 
+ class GeneralGraphDataset(RecBoleDataset): def __init__(self, config): super().__init__(config) + if recbole.__version__ == "1.1.1": + + def save(self): + """Saving this :class:`Dataset` object to :attr:`config['checkpoint_dir']`.""" + save_dir = self.config["checkpoint_dir"] + ensure_dir(save_dir) + file = os.path.join(save_dir, f'{self.config["dataset"]}-{self.__class__.__name__}.pth') + self.logger.info( + set_color("Saving filtered dataset into ", "pink") + f"[{file}]" + ) + with open(file, "wb") as f: + pickle.dump(self, f) + def get_norm_adj_mat(self): r"""Get the normalized interaction matrix of users and items. Construct the square matrix from the training data and normalize it @@ -101,6 +119,7 @@ def build(self): dataset.session_graph_construction() return datasets + class MultiBehaviorDataset(SessionGraphDataset): def session_graph_construction(self): @@ -113,7 +132,7 @@ def session_graph_construction(self): # To be compatible with existing datasets item_behavior_seq = torch.tensor([0] * len(item_seq)) self.behavior_id_field = 'behavior_id' - self.field2id_token[self.behavior_id_field] = {0:'interaction'} + self.field2id_token[self.behavior_id_field] = {0: 'interaction'} else: item_behavior_seq = self.inter_feat[self.item_list_length_field] @@ -152,6 +171,7 @@ def session_graph_construction(self): 'alias_inputs': alias_inputs } + class LESSRDataset(SessionGraphDataset): def session_graph_construction(self): self.logger.info('Constructing LESSR session graphs.') @@ -199,14 +219,14 @@ def reverse_session(self): item_seq = self.inter_feat[self.item_id_list_field] item_seq_len = self.inter_feat[self.item_list_length_field] for i in tqdm(range(item_seq.shape[0])): - item_seq[i,:item_seq_len[i]] = item_seq[i,:item_seq_len[i]].flip(dims=[0]) + item_seq[i, :item_seq_len[i]] = item_seq[i, :item_seq_len[i]].flip(dims=[0]) def bidirectional_edge(self, edge_index): seq_len = edge_index.shape[1] ed = edge_index.T ed2 = edge_index.T.flip(dims=[1]) idc = 
ed.unsqueeze(1).expand(-1, seq_len, 2) == ed2.unsqueeze(0).expand(seq_len, -1, 2) - return torch.logical_and(idc[:,:,0], idc[:,:,1]).any(dim=-1) + return torch.logical_and(idc[:, :, 0], idc[:, :, 1]).any(dim=-1) def session_graph_construction(self): self.logger.info('Constructing session graphs.') @@ -276,9 +296,10 @@ class SocialDataset(GeneralGraphDataset): net_feat (pandas.DataFrame): Internal data structure stores the users' social network relations. It's loaded from file ``.net``. """ + def __init__(self, config): super().__init__(config) - + def _get_field_from_config(self): super()._get_field_from_config() @@ -410,4 +431,4 @@ def net_matrix(self, form='coo', value_field=None): Returns: scipy.sparse: Sparse matrix in form ``coo`` or ``csr``. """ - return self._create_sparse_matrix(self.net_feat, self.net_src_field, self.net_tgt_field, form, value_field) \ No newline at end of file + return self._create_sparse_matrix(self.net_feat, self.net_src_field, self.net_tgt_field, form, value_field) From c6c28cee32a1d4867ab66326ada4c8c538f90f2e Mon Sep 17 00:00:00 2001 From: xzjin Date: Mon, 23 Oct 2023 11:02:08 +0800 Subject: [PATCH 4/4] Fix: fix bug in DirectAU --- recbole_gnn/model/general_recommender/directau.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/recbole_gnn/model/general_recommender/directau.py b/recbole_gnn/model/general_recommender/directau.py index 3fc3be5..99ff5da 100644 --- a/recbole_gnn/model/general_recommender/directau.py +++ b/recbole_gnn/model/general_recommender/directau.py @@ -22,7 +22,7 @@ class DirectAU(GeneralGraphRecommender): - input_type = InputType.POINTWISE + input_type = InputType.PAIRWISE def __init__(self, config, dataset): super(DirectAU, self).__init__(config, dataset)