Skip to content

Commit

Permalink
Merge pull request #74 from downeykking/main
Browse files Browse the repository at this point in the history
FEA: add DirectAU and fix some bugs
  • Loading branch information
hyp1231 committed Oct 23, 2023
2 parents 5145866 + c6c28ce commit a31626a
Show file tree
Hide file tree
Showing 6 changed files with 162 additions and 6 deletions.
31 changes: 26 additions & 5 deletions recbole_gnn/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,33 @@
from tqdm import tqdm
from torch_geometric.utils import degree


from recbole.data.dataset import SequentialDataset
from recbole.data.dataset import Dataset as RecBoleDataset
from recbole.utils import set_color, FeatureSource

import recbole
import pickle
from recbole.utils import ensure_dir


class GeneralGraphDataset(RecBoleDataset):
def __init__(self, config):
super().__init__(config)

if recbole.__version__ == "1.1.1":

def save(self):
"""Saving this :class:`Dataset` object to :attr:`config['checkpoint_dir']`."""
save_dir = self.config["checkpoint_dir"]
ensure_dir(save_dir)
file = os.path.join(save_dir, f'{self.config["dataset"]}-{self.__class__.__name__}.pth')
self.logger.info(
set_color("Saving filtered dataset into ", "pink") + f"[{file}]"
)
with open(file, "wb") as f:
pickle.dump(self, f)

def get_norm_adj_mat(self):
r"""Get the normalized interaction matrix of users and items.
Construct the square matrix from the training data and normalize it
Expand Down Expand Up @@ -101,6 +119,7 @@ def build(self):
dataset.session_graph_construction()
return datasets


class MultiBehaviorDataset(SessionGraphDataset):

def session_graph_construction(self):
Expand All @@ -113,7 +132,7 @@ def session_graph_construction(self):
# To be compatible with existing datasets
item_behavior_seq = torch.tensor([0] * len(item_seq))
self.behavior_id_field = 'behavior_id'
self.field2id_token[self.behavior_id_field] = {0:'interaction'}
self.field2id_token[self.behavior_id_field] = {0: 'interaction'}
else:
item_behavior_seq = self.inter_feat[self.item_list_length_field]

Expand Down Expand Up @@ -152,6 +171,7 @@ def session_graph_construction(self):
'alias_inputs': alias_inputs
}


class LESSRDataset(SessionGraphDataset):
def session_graph_construction(self):
self.logger.info('Constructing LESSR session graphs.')
Expand Down Expand Up @@ -199,14 +219,14 @@ def reverse_session(self):
item_seq = self.inter_feat[self.item_id_list_field]
item_seq_len = self.inter_feat[self.item_list_length_field]
for i in tqdm(range(item_seq.shape[0])):
item_seq[i,:item_seq_len[i]] = item_seq[i,:item_seq_len[i]].flip(dims=[0])
item_seq[i, :item_seq_len[i]] = item_seq[i, :item_seq_len[i]].flip(dims=[0])

def bidirectional_edge(self, edge_index):
seq_len = edge_index.shape[1]
ed = edge_index.T
ed2 = edge_index.T.flip(dims=[1])
idc = ed.unsqueeze(1).expand(-1, seq_len, 2) == ed2.unsqueeze(0).expand(seq_len, -1, 2)
return torch.logical_and(idc[:,:,0], idc[:,:,1]).any(dim=-1)
return torch.logical_and(idc[:, :, 0], idc[:, :, 1]).any(dim=-1)

def session_graph_construction(self):
self.logger.info('Constructing session graphs.')
Expand Down Expand Up @@ -276,9 +296,10 @@ class SocialDataset(GeneralGraphDataset):
net_feat (pandas.DataFrame): Internal data structure stores the users' social network relations.
It's loaded from file ``.net``.
"""

def __init__(self, config):
super().__init__(config)

def _get_field_from_config(self):
super()._get_field_from_config()

Expand Down Expand Up @@ -410,4 +431,4 @@ def net_matrix(self, form='coo', value_field=None):
Returns:
scipy.sparse: Sparse matrix in form ``coo`` or ``csr``.
"""
return self._create_sparse_matrix(self.net_feat, self.net_src_field, self.net_tgt_field, form, value_field)
return self._create_sparse_matrix(self.net_feat, self.net_src_field, self.net_tgt_field, form, value_field)
1 change: 1 addition & 0 deletions recbole_gnn/model/general_recommender/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@
from recbole_gnn.model.general_recommender.lightgcl import LightGCL
from recbole_gnn.model.general_recommender.simgcl import SimGCL
from recbole_gnn.model.general_recommender.xsimgcl import XSimGCL
from recbole_gnn.model.general_recommender.directau import DirectAU
120 changes: 120 additions & 0 deletions recbole_gnn/model/general_recommender/directau.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
# r"""
# DiretAU
# ################################################
# Reference:
# Chenyang Wang et al. "Towards Representation Alignment and Uniformity in Collaborative Filtering." in KDD 2022.

# Reference code:
# https://github.com/THUwangcy/DirectAU
# """

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from recbole.model.init import xavier_normal_initialization
from recbole.utils import InputType
from recbole.model.general_recommender import BPR
from recbole_gnn.model.general_recommender import LightGCN

from recbole_gnn.model.abstract_recommender import GeneralGraphRecommender


class DirectAU(GeneralGraphRecommender):
input_type = InputType.PAIRWISE

def __init__(self, config, dataset):
super(DirectAU, self).__init__(config, dataset)

# load parameters info
self.embedding_size = config['embedding_size']
self.gamma = config['gamma']
self.encoder_name = config['encoder']

# define encoder
if self.encoder_name == 'MF':
self.encoder = MFEncoder(config, dataset)
elif self.encoder_name == 'LightGCN':
self.encoder = LGCNEncoder(config, dataset)
else:
raise ValueError('Non-implemented Encoder.')

# storage variables for full sort evaluation acceleration
self.restore_user_e = None
self.restore_item_e = None

# parameters initialization
self.apply(xavier_normal_initialization)

def forward(self, user, item):
user_e, item_e = self.encoder(user, item)
return F.normalize(user_e, dim=-1), F.normalize(item_e, dim=-1)

@staticmethod
def alignment(x, y, alpha=2):
return (x - y).norm(p=2, dim=1).pow(alpha).mean()

@staticmethod
def uniformity(x, t=2):
return torch.pdist(x, p=2).pow(2).mul(-t).exp().mean().log()

def calculate_loss(self, interaction):
if self.restore_user_e is not None or self.restore_item_e is not None:
self.restore_user_e, self.restore_item_e = None, None

user = interaction[self.USER_ID]
item = interaction[self.ITEM_ID]

user_e, item_e = self.forward(user, item)
align = self.alignment(user_e, item_e)
uniform = self.gamma * (self.uniformity(user_e) + self.uniformity(item_e)) / 2

return align, uniform

def predict(self, interaction):
user = interaction[self.USER_ID]
item = interaction[self.ITEM_ID]
user_e = self.user_embedding(user)
item_e = self.item_embedding(item)
return torch.mul(user_e, item_e).sum(dim=1)

def full_sort_predict(self, interaction):
user = interaction[self.USER_ID]
if self.encoder_name == 'LightGCN':
if self.restore_user_e is None or self.restore_item_e is None:
self.restore_user_e, self.restore_item_e = self.encoder.get_all_embeddings()
user_e = self.restore_user_e[user]
all_item_e = self.restore_item_e
else:
user_e = self.encoder.user_embedding(user)
all_item_e = self.encoder.item_embedding.weight
score = torch.matmul(user_e, all_item_e.transpose(0, 1))
return score.view(-1)


class MFEncoder(BPR):
def __init__(self, config, dataset):
super(MFEncoder, self).__init__(config, dataset)

def forward(self, user_id, item_id):
return super().forward(user_id, item_id)

def get_all_embeddings(self):
user_embeddings = self.user_embedding.weight
item_embeddings = self.item_embedding.weight
return user_embeddings, item_embeddings


class LGCNEncoder(LightGCN):
def __init__(self, config, dataset):
super(LGCNEncoder, self).__init__(config, dataset)

def forward(self, user_id, item_id):
user_all_embeddings, item_all_embeddings = self.get_all_embeddings()
u_embed = user_all_embeddings[user_id]
i_embed = item_all_embeddings[item_id]
return u_embed, i_embed

def get_all_embeddings(self):
return super().forward()
7 changes: 7 additions & 0 deletions recbole_gnn/properties/model/DirectAU.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
embedding_size: 64
encoder: "MF" # "MF" or "lightGCN"
gamma: 0.5
weight_decay: 1e-6
train_batch_size: 256

# n_layers: 3 # needed for LightGCN
1 change: 1 addition & 0 deletions recbole_gnn/quick_start.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def objective_function(config_dict=None, config_file_list=None, saved=True):
test_result = trainer.evaluate(test_data, load_best_model=saved)

return {
'model': config['model'],
'best_valid_score': best_valid_score,
'valid_score_bigger': config['valid_metric_bigger'],
'best_valid_result': best_valid_result,
Expand Down
8 changes: 7 additions & 1 deletion tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def test_simgcl(self):
'model': 'SimGCL'
}
quick_test(config_dict)

def test_xsimgcl(self):
config_dict = {
'model': 'XSimGCL'
Expand All @@ -73,6 +73,12 @@ def test_lightgcl(self):
}
quick_test(config_dict)

def test_directau(self):
config_dict = {
'model': 'DirectAU'
}
quick_test(config_dict)


class TestSequentialRecommender(unittest.TestCase):
def test_gru4rec(self):
Expand Down

0 comments on commit a31626a

Please sign in to comment.