From d4179b671dfaa6f21438f30fcb06c2a246ea8caa Mon Sep 17 00:00:00 2001 From: JieguangZhou Date: Fri, 23 Feb 2024 17:43:40 +0800 Subject: [PATCH] Generating random file_id for saving file to artifact store --- superduperdb/backends/base/artifact.py | 4 +- superduperdb/backends/mongodb/artifacts.py | 46 +++++++++++++------ superduperdb/base/document.py | 1 - superduperdb/components/datatype.py | 37 +++++++++++++-- test/integration/artifacts/test_mongodb.py | 29 ++++++++---- .../unittest/backends/local/test_artifacts.py | 39 ++++++++++++---- 6 files changed, 120 insertions(+), 36 deletions(-) diff --git a/superduperdb/backends/base/artifact.py b/superduperdb/backends/base/artifact.py index d6c5d9e2d5..a17086bef9 100644 --- a/superduperdb/backends/base/artifact.py +++ b/superduperdb/backends/base/artifact.py @@ -102,8 +102,8 @@ def save_artifact(self, r: t.Dict): {'file_id', 'uri'} """ if r.get('type') == 'file': - sha1 = hashlib.sha1(str(id(r)).encode()).hexdigest() - file_id = self._save_file(r['x'], sha1) + assert 'file_id' in r, 'file_id is missing!' + file_id = self._save_file(r['x'], r['file_id']) else: assert 'bytes' in r, 'serialized bytes are missing!' assert 'datatype' in r, 'no datatype specified!' 
diff --git a/superduperdb/backends/mongodb/artifacts.py b/superduperdb/backends/mongodb/artifacts.py index 36a0d2d3df..0bb66200af 100644 --- a/superduperdb/backends/mongodb/artifacts.py +++ b/superduperdb/backends/mongodb/artifacts.py @@ -4,8 +4,9 @@ import click import gridfs +from tqdm import tqdm -from superduperdb import logging +from superduperdb import CFG, logging from superduperdb.backends.base.artifact import ArtifactStore from superduperdb.misc.colors import Colors @@ -121,20 +122,39 @@ def upload_folder(path, file_id, fs, parent_path=""): def download(file_id, fs): """Download file or folder from GridFS and return the path""" - temp_dir = tempfile.mkdtemp(prefix=file_id) - # try to download a file first, if it fails, assume it's a folder - file = fs.find_one({"metadata.file_id": file_id, "metadata.type": "file"}) - if file is not None: - save_path = os.path.join(temp_dir, os.path.split(file.filename)[-1]) + download_folder = CFG.downloads.folder + + if not download_folder: + download_folder = os.path.join( + tempfile.gettempdir(), "superduperdb", "ArtifactStore" + ) + + save_folder = os.path.join(download_folder, file_id) + os.makedirs(save_folder, exist_ok=True) + + file = fs.find_one({"metadata.file_id": file_id}) + if file is None: + raise FileNotFoundError(f"File not found in {file_id}") + + type_ = file.metadata.get("type") + if type_ not in {"file", "dir"}: + raise ValueError( + f"Unknown type '{type_}' for file_id {file_id}, expected file or dir" + ) + + if type_ == 'file': + save_path = os.path.join(save_folder, os.path.split(file.filename)[-1]) logging.info(f"Downloading file_id {file_id} to {save_path}") with open(save_path, 'wb') as f: f.write(file.read()) return save_path - logging.info(f"Downloading folder with file_id {file_id} to {temp_dir}") - for grid_out in fs.find({"metadata.file_id": file_id, "metadata.type": "dir"}): - file_path = os.path.join(temp_dir, grid_out.filename) + logging.info(f"Downloading folder with file_id {file_id} to 
{save_folder}") + for grid_out in tqdm( + fs.find({"metadata.file_id": file_id, "metadata.type": "dir"}) + ): + file_path = os.path.join(save_folder, grid_out.filename) if grid_out.metadata.get("is_empty_dir", False): if not os.path.exists(file_path): os.makedirs(file_path) @@ -145,8 +165,8 @@ def download(file_id, fs): with open(file_path, 'wb') as file_to_write: file_to_write.write(grid_out.read()) - folders = os.listdir(temp_dir) + folders = os.listdir(save_folder) assert len(folders) == 1, f"Expected only one folder, got {folders}" - temp_dir = os.path.join(temp_dir, folders[0]) - logging.info(f"Downloaded folder with file_id {file_id} to {temp_dir}") - return temp_dir + save_folder = os.path.join(save_folder, folders[0]) + logging.info(f"Downloaded folder with file_id {file_id} to {save_folder}") + return save_folder diff --git a/superduperdb/base/document.py b/superduperdb/base/document.py index c2e732b298..78347ea34f 100644 --- a/superduperdb/base/document.py +++ b/superduperdb/base/document.py @@ -22,7 +22,6 @@ ItemType = t.Union[t.Dict[str, t.Any], Encodable, ObjectId] _OUTPUTS_KEY: str = '_outputs' -# TODO: Remove this dict to map leaf types to classes _LEAF_TYPES = { 'component': Component, 'encodable': Encodable, diff --git a/superduperdb/components/datatype.py b/superduperdb/components/datatype.py index 8628613828..31528a4534 100644 --- a/superduperdb/components/datatype.py +++ b/superduperdb/components/datatype.py @@ -18,6 +18,17 @@ Encode = t.Callable[[t.Any], bytes] +def random_sha1(): + """ + Generate random sha1 values + Can be used to generate file_id and other values + """ + random_data = os.urandom(256) + sha1 = hashlib.sha1() + sha1.update(random_data) + return sha1.hexdigest() + + def pickle_encode(object: t.Any, info: t.Optional[t.Dict] = None) -> bytes: return pickle.dumps(object) @@ -96,12 +107,13 @@ class DataType(Component): artifact: bool = False reference: bool = False directory: t.Optional[str] = None - encodable_cls: 
t.Optional[t.Type['_BaseEncodable']] = None + encodable: t.Optional[str] = None def __call__( self, x: t.Optional[t.Any] = None, uri: t.Optional[str] = None ) -> '_BaseEncodable': - encodable_cls = self.encodable_cls or Encodable + # get the subclass of _BaseEncodable with the given encodable name + encodable_cls = _BaseEncodable.get_encodable_cls(self.encodable, Encodable) return encodable_cls(self, x=x, uri=uri) @@ -186,6 +198,24 @@ def unpack(self, db): self.init(db) return self.x + @classmethod + def get_encodable_cls(cls, name, default=None): + """ + Get the subclass of the _BaseEncodable with the given name. + All the registered subclasses must be subclasses of the _BaseEncodable. + """ + for sub_cls in cls.__subclasses__(): + if sub_cls.__name__ == name: + return sub_cls + if default is None: + raise ValueError(f'No subclass with name "{name}" found.') + elif not issubclass(default, cls): + raise ValueError( + "The default class must be a subclass of the _BaseEncodable." + ) + + return default + @dc.dataclass class Encodable(_BaseEncodable): @@ -286,6 +316,7 @@ def encode( 'datatype': self.datatype.identifier, 'uri': None, 'artifact': self.artifact, + 'file_id': random_sha1(), } } @@ -314,7 +345,7 @@ def decode(cls, r, db=None, reference: bool = False) -> '_BaseEncodable': encoder=file_check, decoder=file_check, artifact=True, - encodable_cls=ReferenceEncodable, + encodable="ReferenceEncodable", ) serializers = { 'pickle': pickle_serializer, diff --git a/test/integration/artifacts/test_mongodb.py b/test/integration/artifacts/test_mongodb.py index 563e2a1273..31bf5872a0 100644 --- a/test/integration/artifacts/test_mongodb.py +++ b/test/integration/artifacts/test_mongodb.py @@ -5,12 +5,10 @@ import pytest -from superduperdb.backends.local.artifacts import FileSystemArtifactStore from superduperdb.components.component import Component from superduperdb.components.datatype import ( DataType, file_serializer, - serializers, ) @@ -25,16 +23,29 @@ class
TestComponent(Component): @pytest.fixture -def artifact_strore(tmpdir) -> FileSystemArtifactStore: - artifact_strore = FileSystemArtifactStore(f"{tmpdir}") - artifact_strore._serializers = serializers - return artifact_strore +def random_directory(tmpdir): + tmpdir_path = os.path.join(tmpdir, "test_data") + os.makedirs(tmpdir_path, exist_ok=True) + for i in range(10): + file_name = f'{i}.txt' + file_path = os.path.join(tmpdir_path, file_name) + with open(file_path, 'w') as file: + file.write(str(i)) -def test_save_and_load_directory(test_db): + for j in range(10): + sub_dir = os.path.join(tmpdir_path, f'subdir_{j}') + os.makedirs(sub_dir, exist_ok=True) + sub_file_path = os.path.join(sub_dir, file_name) + with open(sub_file_path, 'w') as file: + file.write(f"{i} {j}") + + return tmpdir_path + + +def test_save_and_load_directory(test_db, random_directory): # test save and load directory - directory = os.path.join(os.getcwd(), "superduperdb") - test_component = TestComponent(path=directory, identifier="test") + test_component = TestComponent(path=random_directory, identifier="test") test_db.add(test_component) test_component_loaded = test_db.load("TestComponent", "test") test_component_loaded.init() diff --git a/test/unittest/backends/local/test_artifacts.py b/test/unittest/backends/local/test_artifacts.py index 617dcb8d2c..17090e34a9 100644 --- a/test/unittest/backends/local/test_artifacts.py +++ b/test/unittest/backends/local/test_artifacts.py @@ -26,19 +26,42 @@ class TestComponent(Component): @pytest.fixture -def artifact_strore(tmpdir) -> FileSystemArtifactStore: - artifact_strore = FileSystemArtifactStore(f"{tmpdir}") +def random_directory(tmpdir): + tmpdir_path = os.path.join(tmpdir, "test_data") + os.makedirs(tmpdir_path, exist_ok=True) + for i in range(10): + file_name = f'{i}.txt' + file_path = os.path.join(tmpdir_path, file_name) + + with open(file_path, 'w') as file: + file.write(str(i)) + + for j in range(10): + sub_dir = os.path.join(tmpdir_path, 
f'subdir_{j}') + os.makedirs(sub_dir, exist_ok=True) + sub_file_path = os.path.join(sub_dir, file_name) + with open(sub_file_path, 'w') as file: + file.write(f"{i} {j}") + + return tmpdir_path + + +@pytest.fixture +def artifact_store(tmpdir) -> FileSystemArtifactStore: + tmpdir_path = os.path.join(tmpdir, "artifact_store") + artifact_strore = FileSystemArtifactStore(f"{tmpdir_path}") artifact_strore._serializers = serializers return artifact_strore @pytest.mark.parametrize("db", [DBConfig.mongodb_empty], indirect=True) -def test_save_and_load_directory(db, artifact_strore: FileSystemArtifactStore): - db.artifact_store = artifact_strore +def test_save_and_load_directory( + db, artifact_store: FileSystemArtifactStore, random_directory +): + db.artifact_store = artifact_store # test save and load directory - directory = os.path.join(os.getcwd(), "superduperdb") - test_component = TestComponent(path=directory, identifier="test") + test_component = TestComponent(path=random_directory, identifier="test") db.add(test_component) test_component_loaded = db.load("TestComponent", "test") test_component_loaded.init() @@ -56,8 +79,8 @@ def test_save_and_load_directory(db, artifact_strore: FileSystemArtifactStore): @pytest.mark.parametrize("db", [DBConfig.mongodb_empty], indirect=True) -def test_save_and_load_file(db, artifact_strore: FileSystemArtifactStore): - db.artifact_store = artifact_strore +def test_save_and_load_file(db, artifact_store: FileSystemArtifactStore): + db.artifact_store = artifact_store # test save and load file file = os.path.abspath(__file__) test_component = TestComponent(path=file, identifier="test")