Skip to content

Commit

Permalink
Merge pull request #75 from downeykking/main
Browse files Browse the repository at this point in the history
FEA: add sparse tensor support for ngcf_conv and lightgcn_conv
  • Loading branch information
hyp1231 committed Oct 23, 2023
2 parents a31626a + 8b4bf32 commit 979b219
Show file tree
Hide file tree
Showing 7 changed files with 74 additions and 16 deletions.
36 changes: 29 additions & 7 deletions recbole_gnn/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,13 @@
import pandas as pd

from tqdm import tqdm
from torch_geometric.nn.conv.gcn_conv import gcn_norm
from torch_geometric.utils import degree

try:
from torch_sparse import SparseTensor
is_sparse = True
except ImportError:
is_sparse = False

from recbole.data.dataset import SequentialDataset
from recbole.data.dataset import Dataset as RecBoleDataset
Expand Down Expand Up @@ -33,7 +38,16 @@ def save(self):
with open(file, "wb") as f:
pickle.dump(self, f)

def get_norm_adj_mat(self):
@staticmethod
def edge_index_to_adj_t(edge_index, edge_weight, m_num_nodes, n_num_nodes):
    """Build a ``SparseTensor`` from a COO edge list and return its transpose.

    Args:
        edge_index: Indexable pair; ``edge_index[0]`` is used as the row
            indices and ``edge_index[1]`` as the column indices.
        edge_weight: Values stored at the given (row, col) positions.
        m_num_nodes: Number of rows of the (un-transposed) sparse matrix.
        n_num_nodes: Number of columns of the (un-transposed) sparse matrix.

    Returns:
        The transposed ``SparseTensor`` (``adj.t()``).
    """
    # NOTE(review): ``SparseTensor`` comes from the optional ``torch_sparse``
    # import at module level; callers are expected to check availability first.
    source_nodes, target_nodes = edge_index[0], edge_index[1]
    shape = (m_num_nodes, n_num_nodes)
    adjacency = SparseTensor(
        row=source_nodes,
        col=target_nodes,
        value=edge_weight,
        sparse_sizes=shape,
    )
    return adjacency.t()

def get_norm_adj_mat(self, enable_sparse=False):
self.is_sparse = is_sparse
r"""Get the normalized interaction matrix of users and items.
Construct the square matrix from the training data and normalize it
using the laplace matrix.
Expand All @@ -48,11 +62,19 @@ def get_norm_adj_mat(self):
edge_index1 = torch.stack([row, col])
edge_index2 = torch.stack([col, row])
edge_index = torch.cat([edge_index1, edge_index2], dim=1)

deg = degree(edge_index[0], self.user_num + self.item_num)

norm_deg = 1. / torch.sqrt(torch.where(deg == 0, torch.ones([1]), deg))
edge_weight = norm_deg[edge_index[0]] * norm_deg[edge_index[1]]
edge_weight = torch.ones(edge_index.size(1))
num_nodes = self.user_num + self.item_num

if enable_sparse:
if not is_sparse:
self.logger.warning(
"Import `torch_sparse` error, please install corrsponding version of `torch_sparse`. Now we will use dense edge_index instead of SparseTensor in dataset.")
else:
adj_t = self.edge_index_to_adj_t(edge_index, edge_weight, num_nodes, num_nodes)
adj_t = gcn_norm(adj_t, None, num_nodes, add_self_loops=False)
return adj_t, None

edge_index, edge_weight = gcn_norm(edge_index, edge_weight, num_nodes, add_self_loops=False)

return edge_index, edge_weight

Expand Down
10 changes: 7 additions & 3 deletions recbole_gnn/model/abstract_recommender.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,12 @@ class GeneralGraphRecommender(GeneralRecommender):

def __init__(self, config, dataset):
super(GeneralGraphRecommender, self).__init__(config, dataset)
self.edge_index, self.edge_weight = dataset.get_norm_adj_mat()
self.edge_index, self.edge_weight = self.edge_index.to(self.device), self.edge_weight.to(self.device)
self.edge_index, self.edge_weight = dataset.get_norm_adj_mat(enable_sparse=config["enable_sparse"])
self.use_sparse = config["enable_sparse"] and dataset.is_sparse
if self.use_sparse:
self.edge_index, self.edge_weight = self.edge_index.to(self.device), None
else:
self.edge_index, self.edge_weight = self.edge_index.to(self.device), self.edge_weight.to(self.device)


class SocialRecommender(GeneralRecommender):
Expand All @@ -23,4 +27,4 @@ class SocialRecommender(GeneralRecommender):
type = ModelType.SOCIAL

def __init__(self, config, dataset):
super(SocialRecommender, self).__init__(config, dataset)
super(SocialRecommender, self).__init__(config, dataset)
2 changes: 1 addition & 1 deletion recbole_gnn/model/general_recommender/lightgcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,4 +130,4 @@ def full_sort_predict(self, interaction):
# dot with all item embedding to accelerate
scores = torch.matmul(u_embeddings, self.restore_item_e.transpose(0, 1))

return scores.view(-1)
return scores.view(-1)
12 changes: 11 additions & 1 deletion recbole_gnn/model/general_recommender/ngcf.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,17 @@ def forward(self):
if self.node_dropout == 0:
edge_index, edge_weight = self.edge_index, self.edge_weight
else:
edge_index, edge_weight = dropout_adj(edge_index=self.edge_index, edge_attr=self.edge_weight, p=self.node_dropout)
edge_index, edge_weight = self.edge_index, self.edge_weight
if self.use_sparse:
row, col, edge_weight = edge_index.t().coo()
edge_index = torch.stack([row, col], 0)
edge_index, edge_weight = dropout_adj(edge_index=edge_index, edge_attr=edge_weight,
p=self.node_dropout, training=self.training)
from torch_sparse import SparseTensor
edge_index = SparseTensor(row=edge_index[0], col=edge_index[1], value=edge_weight,
sparse_sizes=(self.n_users + self.n_items, self.n_users + self.n_items))
edge_index = edge_index.t()
edge_weight = None

all_embeddings = self.get_ego_embeddings()
embeddings_list = [all_embeddings]
Expand Down
13 changes: 10 additions & 3 deletions recbole_gnn/model/general_recommender/sgl.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import torch
import torch.nn.functional as F
from torch_geometric.utils import degree
from torch_geometric.nn.conv.gcn_conv import gcn_norm

from recbole.model.init import xavier_uniform_initialization
from recbole.model.loss import EmbLoss
Expand Down Expand Up @@ -53,6 +54,7 @@ def __init__(self, config, dataset):

self._user = dataset.inter_feat[dataset.uid_field]
self._item = dataset.inter_feat[dataset.iid_field]
self.dataset = dataset

# define layers and loss
self.user_embedding = torch.nn.Embedding(self.n_users, self.latent_dim)
Expand Down Expand Up @@ -111,10 +113,15 @@ def rand_sample(high, size=None, replace=True):
edge_index1 = torch.stack([row, col])
edge_index2 = torch.stack([col, row])
edge_index = torch.cat([edge_index1, edge_index2], dim=1)
edge_weight = torch.ones(edge_index.size(1))
num_nodes = self.n_users + self.n_items

deg = degree(edge_index[0], self.n_users + self.n_items)
norm_deg = 1. / torch.sqrt(torch.where(deg == 0, torch.ones([1]), deg))
edge_weight = norm_deg[edge_index[0]] * norm_deg[edge_index[1]]
if self.use_sparse:
adj_t = self.dataset.edge_index_to_adj_t(edge_index, edge_weight, num_nodes, num_nodes)
adj_t = gcn_norm(adj_t, None, num_nodes, add_self_loops=False)
return adj_t.to(self.device), None

edge_index, edge_weight = gcn_norm(edge_index, edge_weight, num_nodes, add_self_loops=False)

return edge_index.to(self.device), edge_weight.to(self.device)

Expand Down
9 changes: 8 additions & 1 deletion recbole_gnn/model/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import torch
import torch.nn as nn
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops
from torch_sparse import matmul


class LightGCNConv(MessagePassing):
Expand All @@ -16,6 +16,9 @@ def forward(self, x, edge_index, edge_weight):
def message(self, x_j, edge_weight):
    """Weight each message in ``x_j`` by the corresponding edge weight.

    ``edge_weight`` is reshaped to a single column so that it broadcasts
    across the feature dimension of ``x_j``.
    """
    # presumably x_j holds one feature row per edge; verify against caller
    column_weights = edge_weight.view(-1, 1)
    return column_weights * x_j

def message_and_aggregate(self, adj_t, x):
    """Compute messages and aggregate them in one sparse matrix product.

    Multiplies ``adj_t`` with ``x`` via ``torch_sparse.matmul``, using this
    layer's ``aggr`` setting as the reduction.
    """
    # NOTE(review): presumably invoked by MessagePassing.propagate when the
    # adjacency is passed as a SparseTensor; confirm against PyG docs.
    aggregated = matmul(adj_t, x, reduce=self.aggr)
    return aggregated

def __repr__(self):
    """Return ``ClassName(dim)`` for this layer."""
    class_name = self.__class__.__name__
    return '{}({})'.format(class_name, self.dim)

Expand All @@ -41,6 +44,7 @@ class BiGNNConv(MessagePassing):
.. math::
output = (L+I)EW_1 + LE \otimes EW_2
"""

def __init__(self, in_channels, out_channels):
super().__init__(aggr='add')
self.in_channels, self.out_channels = in_channels, out_channels
Expand All @@ -56,6 +60,9 @@ def forward(self, x, edge_index, edge_weight):
def message(self, x_j, edge_weight):
    """Scale the incoming messages ``x_j`` by their edge weights.

    The weights are viewed as an ``(N, 1)`` column so the product broadcasts
    over the trailing dimension of ``x_j``.
    """
    # presumably one weight per row of x_j; verify against caller
    scale = edge_weight.view(-1, 1)
    weighted = scale * x_j
    return weighted

def message_and_aggregate(self, adj_t, x):
    """Fuse message computation and aggregation into one sparse product.

    Calls ``torch_sparse.matmul`` on ``adj_t`` and ``x`` with
    ``reduce=self.aggr``.
    """
    # NOTE(review): presumably the SparseTensor code path of
    # MessagePassing.propagate; confirm against PyG docs.
    result = matmul(adj_t, x, reduce=self.aggr)
    return result

def __repr__(self):
    """Return ``ClassName(in_channels,out_channels)`` for this layer."""
    class_name = self.__class__.__name__
    channels = (self.in_channels, self.out_channels)
    return '{}({},{})'.format(class_name, *channels)

Expand Down
8 changes: 8 additions & 0 deletions recbole_gnn/quick_start.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ def run_recbole_gnn(model=None, dataset=None, config_file_list=None, config_dict
"""
# configurations initialization
config = Config(model=model, dataset=dataset, config_file_list=config_file_list, config_dict=config_dict)
try:
assert config["enable_sparse"] in [True, False, None]
except AssertionError:
raise ValueError("Your config `enable_sparse` must be `True` or `False` or `None`")
init_seed(config['seed'], config['reproducibility'])
# logger initialization
init_logger(config)
Expand Down Expand Up @@ -69,6 +73,10 @@ def objective_function(config_dict=None, config_file_list=None, saved=True):
"""

config = Config(config_dict=config_dict, config_file_list=config_file_list)
try:
assert config["enable_sparse"] in [True, False, None]
except AssertionError:
raise ValueError("Your config `enable_sparse` must be `True` or `False` or `None`")
init_seed(config['seed'], config['reproducibility'])
logging.basicConfig(level=logging.ERROR)
dataset = create_dataset(config)
Expand Down

0 comments on commit 979b219

Please sign in to comment.