From 8b4bf326cd9ef5e3c61e62f6eaa8e35e90583783 Mon Sep 17 00:00:00 2001 From: xzjin Date: Mon, 23 Oct 2023 21:49:19 +0800 Subject: [PATCH] FEA: add sparse tensor support for ngcf_conv and lightgcn_conv --- recbole_gnn/data/dataset.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/recbole_gnn/data/dataset.py b/recbole_gnn/data/dataset.py index 61324a1..11273a8 100644 --- a/recbole_gnn/data/dataset.py +++ b/recbole_gnn/data/dataset.py @@ -16,11 +16,28 @@ from recbole.data.dataset import Dataset as RecBoleDataset from recbole.utils import set_color, FeatureSource +import recbole +import pickle +from recbole.utils import ensure_dir + class GeneralGraphDataset(RecBoleDataset): def __init__(self, config): super().__init__(config) + if recbole.__version__ == "1.1.1": + + def save(self): + """Saving this :class:`Dataset` object to :attr:`config['checkpoint_dir']`.""" + save_dir = self.config["checkpoint_dir"] + ensure_dir(save_dir) + file = os.path.join(save_dir, f'{self.config["dataset"]}-{self.__class__.__name__}.pth') + self.logger.info( + set_color("Saving filtered dataset into ", "pink") + f"[{file}]" + ) + with open(file, "wb") as f: + pickle.dump(self, f) + @staticmethod def edge_index_to_adj_t(edge_index, edge_weight, m_num_nodes, n_num_nodes): adj = SparseTensor(row=edge_index[0],