diff --git a/graphgen/models/__init__.py b/graphgen/models/__init__.py
index 00a5cc56..3ef1ff69 100644
--- a/graphgen/models/__init__.py
+++ b/graphgen/models/__init__.py
@@ -33,5 +33,5 @@
 from .searcher.web.bing_search import BingSearch
 from .searcher.web.google_search import GoogleSearch
 from .splitter import ChineseRecursiveTextSplitter, RecursiveCharacterSplitter
-from .storage import JsonKVStorage, JsonListStorage, NetworkXStorage
+from .storage import JsonKVStorage, JsonListStorage, NetworkXStorage, RocksDBCache
 from .tokenizer import Tokenizer
diff --git a/graphgen/models/storage/__init__.py b/graphgen/models/storage/__init__.py
index 56338984..1e8f8341 100644
--- a/graphgen/models/storage/__init__.py
+++ b/graphgen/models/storage/__init__.py
@@ -1,2 +1,3 @@
 from .json_storage import JsonKVStorage, JsonListStorage
 from .networkx_storage import NetworkXStorage
+from .rocksdb_cache import RocksDBCache
diff --git a/graphgen/models/storage/rocksdb_cache.py b/graphgen/models/storage/rocksdb_cache.py
new file mode 100644
index 00000000..2345b5b5
--- /dev/null
+++ b/graphgen/models/storage/rocksdb_cache.py
@@ -0,0 +1,43 @@
+from pathlib import Path
+from typing import Any, Iterator, Optional
+
+# rocksdict is a lightweight C wrapper around RocksDB for Python, pylint may not recognize it
+# pylint: disable=no-name-in-module
+from rocksdict import Rdict
+
+
+class RocksDBCache:
+    def __init__(self, cache_dir: str):
+        self.db_path = Path(cache_dir)
+        self.db = Rdict(str(self.db_path))
+
+    def get(self, key: str) -> Optional[Any]:
+        return self.db.get(key)
+
+    def set(self, key: str, value: Any):
+        self.db[key] = value
+
+    def delete(self, key: str):
+        try:
+            del self.db[key]
+        except KeyError:
+            # If the key does not exist, do nothing (deletion is idempotent for caches)
+            pass
+
+    def close(self):
+        if hasattr(self, "db") and self.db is not None:
+            self.db.close()
+            self.db = None
+
+    def __del__(self):
+        # Ensure the database is closed when the object is destroyed
+        self.close()
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.close()
+
+    def __iter__(self) -> Iterator[str]:
+        return iter(self.db.keys())
diff --git a/graphgen/operators/read/parallel_file_scanner.py b/graphgen/operators/read/parallel_file_scanner.py
index 890a50a9..73b477c3 100644
--- a/graphgen/operators/read/parallel_file_scanner.py
+++ b/graphgen/operators/read/parallel_file_scanner.py
@@ -4,8 +4,7 @@
 from pathlib import Path
 from typing import Any, Dict, List, Set, Union
 
-from diskcache import Cache
-
+from graphgen.models import RocksDBCache
 from graphgen.utils import logger
 
 
@@ -13,7 +12,7 @@ class ParallelFileScanner:
     def __init__(
         self, cache_dir: str, allowed_suffix, rescan: bool = False, max_workers: int = 4
     ):
-        self.cache = Cache(cache_dir)
+        self.cache = RocksDBCache(str(Path(cache_dir) / "file_paths_cache"))
         self.allowed_suffix = set(allowed_suffix) if allowed_suffix else None
         self.rescan = rescan
         self.max_workers = max_workers
diff --git a/requirements.txt b/requirements.txt
index fa2b1efc..85fc43e3 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -20,13 +20,15 @@
 requests
 fastapi
 trafilatura
 aiohttp
-diskcache
 socksio
 leidenalg
 igraph
 python-louvain
 
+# storage
+rocksdict
+
 # KG
 rdflib