diff --git a/recbole_gnn/data/dataset.py b/recbole_gnn/data/dataset.py index 9dd3f7b..11273a8 100644 --- a/recbole_gnn/data/dataset.py +++ b/recbole_gnn/data/dataset.py @@ -4,8 +4,13 @@ import pandas as pd from tqdm import tqdm +from torch_geometric.nn.conv.gcn_conv import gcn_norm from torch_geometric.utils import degree - +try: + from torch_sparse import SparseTensor + is_sparse = True +except ImportError: + is_sparse = False from recbole.data.dataset import SequentialDataset from recbole.data.dataset import Dataset as RecBoleDataset @@ -33,7 +38,16 @@ def save(self): with open(file, "wb") as f: pickle.dump(self, f) - def get_norm_adj_mat(self): + @staticmethod + def edge_index_to_adj_t(edge_index, edge_weight, m_num_nodes, n_num_nodes): + adj = SparseTensor(row=edge_index[0], + col=edge_index[1], + value=edge_weight, + sparse_sizes=(m_num_nodes, n_num_nodes)) + return adj.t() + + def get_norm_adj_mat(self, enable_sparse=False): + self.is_sparse = is_sparse r"""Get the normalized interaction matrix of users and items. Construct the square matrix from the training data and normalize it using the laplace matrix. @@ -48,11 +62,19 @@ def get_norm_adj_mat(self): edge_index1 = torch.stack([row, col]) edge_index2 = torch.stack([col, row]) edge_index = torch.cat([edge_index1, edge_index2], dim=1) - - deg = degree(edge_index[0], self.user_num + self.item_num) - - norm_deg = 1. / torch.sqrt(torch.where(deg == 0, torch.ones([1]), deg)) - edge_weight = norm_deg[edge_index[0]] * norm_deg[edge_index[1]] + edge_weight = torch.ones(edge_index.size(1)) + num_nodes = self.user_num + self.item_num + + if enable_sparse: + if not is_sparse: + self.logger.warning( + "Import `torch_sparse` error, please install corrsponding version of `torch_sparse`. Now we will use dense edge_index instead of SparseTensor in dataset.") + else: + adj_t = self.edge_index_to_adj_t(edge_index, edge_weight, num_nodes, num_nodes) + adj_t = gcn_norm(adj_t, None, num_nodes, add_self_loops=False) + return adj_t, None + + edge_index, edge_weight = gcn_norm(edge_index, edge_weight, num_nodes, add_self_loops=False) return edge_index, edge_weight diff --git a/recbole_gnn/model/abstract_recommender.py b/recbole_gnn/model/abstract_recommender.py index 088f0c6..236bf53 100644 --- a/recbole_gnn/model/abstract_recommender.py +++ b/recbole_gnn/model/abstract_recommender.py @@ -12,8 +12,12 @@ class GeneralGraphRecommender(GeneralRecommender): def __init__(self, config, dataset): super(GeneralGraphRecommender, self).__init__(config, dataset) - self.edge_index, self.edge_weight = dataset.get_norm_adj_mat() - self.edge_index, self.edge_weight = self.edge_index.to(self.device), self.edge_weight.to(self.device) + self.edge_index, self.edge_weight = dataset.get_norm_adj_mat(enable_sparse=config["enable_sparse"]) + self.use_sparse = config["enable_sparse"] and dataset.is_sparse + if self.use_sparse: + self.edge_index, self.edge_weight = self.edge_index.to(self.device), None + else: + self.edge_index, self.edge_weight = self.edge_index.to(self.device), self.edge_weight.to(self.device) class SocialRecommender(GeneralRecommender): @@ -23,4 +27,4 @@ class SocialRecommender(GeneralRecommender): type = ModelType.SOCIAL def __init__(self, config, dataset): - super(SocialRecommender, self).__init__(config, dataset) \ No newline at end of file + super(SocialRecommender, self).__init__(config, dataset) diff --git a/recbole_gnn/model/general_recommender/lightgcn.py b/recbole_gnn/model/general_recommender/lightgcn.py index 4d3ad7c..525a8da 100644 --- a/recbole_gnn/model/general_recommender/lightgcn.py +++ b/recbole_gnn/model/general_recommender/lightgcn.py @@ -130,4 +130,4 @@ def full_sort_predict(self, interaction): # dot with all item embedding to accelerate scores = torch.matmul(u_embeddings, self.restore_item_e.transpose(0, 1)) - return scores.view(-1) \ No newline at end of file + return scores.view(-1) diff --git a/recbole_gnn/model/general_recommender/ngcf.py b/recbole_gnn/model/general_recommender/ngcf.py index f243fbe..7ede56e 100644 --- a/recbole_gnn/model/general_recommender/ngcf.py +++ b/recbole_gnn/model/general_recommender/ngcf.py @@ -74,7 +74,17 @@ def forward(self): if self.node_dropout == 0: edge_index, edge_weight = self.edge_index, self.edge_weight else: - edge_index, edge_weight = dropout_adj(edge_index=self.edge_index, edge_attr=self.edge_weight, p=self.node_dropout) + edge_index, edge_weight = self.edge_index, self.edge_weight + if self.use_sparse: + row, col, edge_weight = edge_index.t().coo() + edge_index = torch.stack([row, col], 0) + edge_index, edge_weight = dropout_adj(edge_index=edge_index, edge_attr=edge_weight, + p=self.node_dropout, training=self.training) + from torch_sparse import SparseTensor + edge_index = SparseTensor(row=edge_index[0], col=edge_index[1], value=edge_weight, + sparse_sizes=(self.n_users + self.n_items, self.n_users + self.n_items)) + edge_index = edge_index.t() + edge_weight = None all_embeddings = self.get_ego_embeddings() embeddings_list = [all_embeddings] diff --git a/recbole_gnn/model/general_recommender/sgl.py b/recbole_gnn/model/general_recommender/sgl.py index a5377ba..985feea 100644 --- a/recbole_gnn/model/general_recommender/sgl.py +++ b/recbole_gnn/model/general_recommender/sgl.py @@ -16,6 +16,7 @@ import torch import torch.nn.functional as F from torch_geometric.utils import degree +from torch_geometric.nn.conv.gcn_conv import gcn_norm from recbole.model.init import xavier_uniform_initialization from recbole.model.loss import EmbLoss @@ -53,6 +54,7 @@ def __init__(self, config, dataset): self._user = dataset.inter_feat[dataset.uid_field] self._item = dataset.inter_feat[dataset.iid_field] + self.dataset = dataset # define layers and loss self.user_embedding = torch.nn.Embedding(self.n_users, self.latent_dim) @@ -111,10 +113,15 @@ def rand_sample(high, size=None, replace=True): edge_index1 = torch.stack([row, col]) edge_index2 = torch.stack([col, row]) edge_index = torch.cat([edge_index1, edge_index2], dim=1) + edge_weight = torch.ones(edge_index.size(1)) + num_nodes = self.n_users + self.n_items - deg = degree(edge_index[0], self.n_users + self.n_items) - norm_deg = 1. / torch.sqrt(torch.where(deg == 0, torch.ones([1]), deg)) - edge_weight = norm_deg[edge_index[0]] * norm_deg[edge_index[1]] + if self.use_sparse: + adj_t = self.dataset.edge_index_to_adj_t(edge_index, edge_weight, num_nodes, num_nodes) + adj_t = gcn_norm(adj_t, None, num_nodes, add_self_loops=False) + return adj_t.to(self.device), None + + edge_index, edge_weight = gcn_norm(edge_index, edge_weight, num_nodes, add_self_loops=False) return edge_index.to(self.device), edge_weight.to(self.device) diff --git a/recbole_gnn/model/layers.py b/recbole_gnn/model/layers.py index f8697df..5584a69 100644 --- a/recbole_gnn/model/layers.py +++ b/recbole_gnn/model/layers.py @@ -2,7 +2,7 @@ import torch import torch.nn as nn from torch_geometric.nn import MessagePassing -from torch_geometric.utils import add_self_loops +from torch_sparse import matmul class LightGCNConv(MessagePassing): @@ -16,6 +16,9 @@ def forward(self, x, edge_index, edge_weight): def message(self, x_j, edge_weight): return edge_weight.view(-1, 1) * x_j + def message_and_aggregate(self, adj_t, x): + return matmul(adj_t, x, reduce=self.aggr) + def __repr__(self): return '{}({})'.format(self.__class__.__name__, self.dim) @@ -41,6 +44,7 @@ class BiGNNConv(MessagePassing): .. math:: output = (L+I)EW_1 + LE \otimes EW_2 """ + def __init__(self, in_channels, out_channels): super().__init__(aggr='add') self.in_channels, self.out_channels = in_channels, out_channels @@ -56,6 +60,9 @@ def forward(self, x, edge_index, edge_weight): def message(self, x_j, edge_weight): return edge_weight.view(-1, 1) * x_j + def message_and_aggregate(self, adj_t, x): + return matmul(adj_t, x, reduce=self.aggr) + def __repr__(self): return '{}({},{})'.format(self.__class__.__name__, self.in_channels, self.out_channels) diff --git a/recbole_gnn/quick_start.py b/recbole_gnn/quick_start.py index 51897f6..712c070 100644 --- a/recbole_gnn/quick_start.py +++ b/recbole_gnn/quick_start.py @@ -18,6 +18,10 @@ def run_recbole_gnn(model=None, dataset=None, config_file_list=None, config_dict """ # configurations initialization config = Config(model=model, dataset=dataset, config_file_list=config_file_list, config_dict=config_dict) + try: + assert config["enable_sparse"] in [True, False, None] + except AssertionError: + raise ValueError("Your config `enable_sparse` must be `True` or `False` or `None`") init_seed(config['seed'], config['reproducibility']) # logger initialization init_logger(config) @@ -69,6 +73,10 @@ def objective_function(config_dict=None, config_file_list=None, saved=True): """ config = Config(config_dict=config_dict, config_file_list=config_file_list) + try: + assert config["enable_sparse"] in [True, False, None] + except AssertionError: + raise ValueError("Your config `enable_sparse` must be `True` or `False` or `None`") init_seed(config['seed'], config['reproducibility']) logging.basicConfig(level=logging.ERROR) dataset = create_dataset(config)