diff --git a/recbole_gnn/model/abstract_recommender.py b/recbole_gnn/model/abstract_recommender.py index 088f0c6..236bf53 100644 --- a/recbole_gnn/model/abstract_recommender.py +++ b/recbole_gnn/model/abstract_recommender.py @@ -12,8 +12,12 @@ class GeneralGraphRecommender(GeneralRecommender): def __init__(self, config, dataset): super(GeneralGraphRecommender, self).__init__(config, dataset) - self.edge_index, self.edge_weight = dataset.get_norm_adj_mat() - self.edge_index, self.edge_weight = self.edge_index.to(self.device), self.edge_weight.to(self.device) + self.edge_index, self.edge_weight = dataset.get_norm_adj_mat(enable_sparse=config["enable_sparse"]) + self.use_sparse = config["enable_sparse"] and dataset.is_sparse + if self.use_sparse: + self.edge_index, self.edge_weight = self.edge_index.to(self.device), None + else: + self.edge_index, self.edge_weight = self.edge_index.to(self.device), self.edge_weight.to(self.device) class SocialRecommender(GeneralRecommender): @@ -23,4 +27,4 @@ class SocialRecommender(GeneralRecommender): type = ModelType.SOCIAL def __init__(self, config, dataset): - super(SocialRecommender, self).__init__(config, dataset) \ No newline at end of file + super(SocialRecommender, self).__init__(config, dataset)