From 77ec8c84b8d4711f34099e0a421de910ccb42809 Mon Sep 17 00:00:00 2001
From: JieguangZhou
Date: Wed, 21 Feb 2024 23:15:09 +0800
Subject: [PATCH] Add file datatype to support saving and reading files/folders
 in the artifact store.

---
 superduperdb/backends/base/artifact.py          | 29 ++++++--
 superduperdb/backends/local/artifacts.py        | 11 ++++
 superduperdb/backends/mongodb/artifacts.py      | 66 +++++++++++++++++++
 superduperdb/base/document.py                   |  1 +
 superduperdb/components/component.py            |  2 -
 superduperdb/components/datatype.py             | 45 +++++++++----
 test/unittest/backends/local/__init__.py        |  0
 .../unittest/backends/local/test_artifacts.py   | 44 +++++++++++++
 8 files changed, 180 insertions(+), 18 deletions(-)
 create mode 100644 test/unittest/backends/local/__init__.py
 create mode 100644 test/unittest/backends/local/test_artifacts.py

diff --git a/superduperdb/backends/base/artifact.py b/superduperdb/backends/base/artifact.py
index 67384c1c4c..166525a8fb 100644
--- a/superduperdb/backends/base/artifact.py
+++ b/superduperdb/backends/base/artifact.py
@@ -86,6 +86,10 @@ def exists(
     def _save_bytes(self, serialized: bytes, file_id: str):
         pass
 
+    @abstractmethod
+    def _save_file(self, file_path: str, file_id: str):
+        pass
+
     def save_artifact(self, r: t.Dict):
         """
         Save serialized object in the artifact store.
@@ -99,14 +103,17 @@ def save_artifact(self, r: t.Dict):
         assert 'datatype' in r, 'no datatype specified!'
         datatype = self.serializers[r['datatype']]
         uri = r.get('uri')
-        file_id = None
         if uri is not None:
             file_id = _construct_file_id_from_uri(uri)
         else:
-            file_id = hashlib.sha1(r['bytes']).hexdigest()
+            file_id = r.get('sha1') or hashlib.sha1(r['bytes']).hexdigest()
         if r.get('directory'):
             file_id = os.path.join(datatype.directory, file_id)
-        self._save_bytes(r['bytes'], file_id=file_id)
+
+        if r['datatype'] == 'file':
+            self._save_file(r['bytes'], file_id)
+        else:
+            self._save_bytes(r['bytes'], file_id=file_id)
         r['file_id'] = file_id
         return r
 
@@ -119,6 +126,15 @@ def _load_bytes(self, file_id: str) -> bytes:
         """
         pass
 
+    @abstractmethod
+    def _load_file(self, file_id: str) -> str:
+        """
+        Load a file from the artifact store and return its path.
+
+        :param file_id: Identifier of artifact in the store
+        """
+        pass
+
     def load_artifact(self, r):
         """
         Load artifact from artifact store, and deserialize.
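
Note: to make the new dispatch concrete, these are the two shapes of `r` that
save_artifact now handles (a hedged sketch; the field values below are made up;
the point is that for the 'file' datatype the 'bytes' field carries a path):

    # Bytes-backed datatype: 'bytes' holds the serialized payload.
    r_bytes = {'datatype': 'dill', 'bytes': b'\x80\x04...', 'sha1': '3f2a...'}

    # New 'file' datatype: 'bytes' holds a filesystem path, routed to _save_file.
    r_file = {'datatype': 'file', 'bytes': '/data/checkpoints/run-1', 'sha1': '9b1c...'}
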
@@ -133,8 +149,11 @@ def load_artifact(self, r):
         if file_id is None:
             assert uri is not None, '"uri" and "file_id" can\'t both be None'
             file_id = _construct_file_id_from_uri(uri)
-        bytes = self._load_bytes(file_id)
-        return datatype.decoder(bytes)
+        if r['datatype'] == 'file':
+            x = self._load_file(file_id)
+        else:
+            x = self._load_bytes(file_id)
+        return datatype.decoder(x)
 
     def save(self, r: t.Dict) -> t.Dict:
         """
diff --git a/superduperdb/backends/local/artifacts.py b/superduperdb/backends/local/artifacts.py
index 2d1779fe9f..dd9f6bf983 100644
--- a/superduperdb/backends/local/artifacts.py
+++ b/superduperdb/backends/local/artifacts.py
@@ -1,6 +1,7 @@
 import os
 import shutil
 import typing as t
+from pathlib import Path
 
 import click
 
@@ -69,6 +70,16 @@ def _load_bytes(self, file_id: str) -> bytes:
         with open(os.path.join(self.conn, file_id), 'rb') as f:
             return f.read()
 
+    def _save_file(self, file_path: str, file_id: str):
+        path = Path(file_path)
+        if path.is_dir():
+            shutil.copytree(file_path, os.path.join(self.conn, file_id))
+        else:
+            shutil.copy(file_path, os.path.join(self.conn, file_id))
+
+    def _load_file(self, file_id: str) -> str:
+        return os.path.join(self.conn, file_id)
+
     def disconnect(self):
         """
         Disconnect the client
diff --git a/superduperdb/backends/mongodb/artifacts.py b/superduperdb/backends/mongodb/artifacts.py
index 37b66325d1..10ebe0ece4 100644
--- a/superduperdb/backends/mongodb/artifacts.py
+++ b/superduperdb/backends/mongodb/artifacts.py
@@ -1,3 +1,7 @@
+import os
+import tempfile
+from pathlib import Path
+
 import click
 import gridfs
 
@@ -48,6 +52,28 @@ def _load_bytes(self, file_id: str):
             raise FileNotFoundError(f'File not found in {file_id}')
         return next(cur)
 
+    def _save_file(self, file_path: str, file_id: str):
+        path = Path(file_path)
+        if path.is_dir():
+            upload_folder(file_path, file_id, self.filesystem)
+        else:
+            self.filesystem.put(
+                open(file_path, 'rb'),
+                filename=file_path,
+                metadata={"file_id": file_id, "type": "file"},
+            )
+
+    def _load_file(self, file_id: str) -> str:
+        file = self.filesystem.find_one(
+            {"metadata.file_id": file_id, "metadata.type": "file"}
+        )
+        if file is not None:
+            with open(file.filename, 'wb') as f:
+                f.write(file.read())
+            return file.filename
+
+        return download_folder(file_id, self.filesystem)
+
     def _save_bytes(self, serialized: bytes, file_id: str):
         return self.filesystem.put(serialized, filename=file_id)
 
@@ -57,3 +83,43 @@ def disconnect(self):
         """
 
         # TODO: implement me
+
+
+def upload_folder(path, file_id, fs, parent_path=""):
+    if not os.listdir(path):
+        fs.put(
+            b'',
+            filename=os.path.join(parent_path, os.path.basename(path) + '/'),
+            metadata={"file_id": file_id, "is_empty_dir": True, 'type': 'dir'},
+        )
+    else:
+        for item in os.listdir(path):
+            item_path = os.path.join(path, item)
+            if os.path.isdir(item_path):
+                upload_folder(item_path, file_id, fs, os.path.join(parent_path, item))
+            else:
+                with open(item_path, 'rb') as file_to_upload:
+                    fs.put(
+                        file_to_upload,
+                        filename=os.path.join(parent_path, item),
+                        metadata={"file_id": file_id, "type": "dir"},
+                    )
+
+
+def download_folder(file_id, fs):
+    temp_dir = tempfile.mkdtemp()
+    logging.info(f"Downloading files to temporary directory: {temp_dir}")
+
+    for grid_out in fs.find({"metadata.file_id": file_id, "metadata.type": "dir"}):
+        file_path = os.path.join(temp_dir, grid_out.filename)
+        if grid_out.metadata.get("is_empty_dir", False):
+            if not os.path.exists(file_path):
+                os.makedirs(file_path)
+        else:
+            directory = os.path.dirname(file_path)
+            if not os.path.exists(directory):
+                os.makedirs(directory)
+            with open(file_path, 'wb') as file_to_write:
+                file_to_write.write(grid_out.read())
+
+    return temp_dir
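
Note: for context, a minimal sketch of the folder round-trip the GridFS helpers
above implement (assumes a MongoDB instance reachable at the default URI; the
database name and folder path below are made up):

    import gridfs
    from pymongo import MongoClient

    from superduperdb.backends.mongodb.artifacts import download_folder, upload_folder

    fs = gridfs.GridFS(MongoClient().demo_artifacts)  # hypothetical database name
    upload_folder('my_folder', 'abc123', fs)    # one GridFS entry per file, tagged with file_id
    local_copy = download_folder('abc123', fs)  # re-materialized under a fresh temp dir
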
diff --git a/superduperdb/base/document.py b/superduperdb/base/document.py
index c6b010fb05..051e1fa3d0 100644
--- a/superduperdb/base/document.py
+++ b/superduperdb/base/document.py
@@ -204,6 +204,7 @@ def _encode(
         return out
     # ruff: noqa: E501
     if isinstance(r, Leaf) and not isinstance(r, leaf_types_to_keep):  # type: ignore[arg-type]
+        # TODO: should this condition be (not leaf_types_to_keep or isinstance(r, leaf_types_to_keep))?
         return r.encode(
             bytes_encoding=bytes_encoding, leaf_types_to_keep=leaf_types_to_keep
         )
diff --git a/superduperdb/components/component.py b/superduperdb/components/component.py
index 5af8ffffd2..3be2416dea 100644
--- a/superduperdb/components/component.py
+++ b/superduperdb/components/component.py
@@ -50,7 +50,6 @@ def artifact_schema(self):
         from superduperdb import Schema
         from superduperdb.components.datatype import dill_serializer
 
-        e = []
         schema = {}
         lookup = dict(self._artifacts)
         if self.artifacts is not None:
@@ -60,7 +59,6 @@ def artifact_schema(self):
             if a is None:
                 continue
             if f.name in lookup:
-                e.append(f.name)
                 schema[f.name] = lookup[f.name]
             elif callable(getattr(self, f.name)) and not isinstance(
                 getattr(self, f.name), Serializable
diff --git a/superduperdb/components/datatype.py b/superduperdb/components/datatype.py
index b00d1cb154..6baf1abd50 100644
--- a/superduperdb/components/datatype.py
+++ b/superduperdb/components/datatype.py
@@ -2,6 +2,7 @@
 import dataclasses as dc
 import hashlib
 import io
+import os
 import pickle
 import typing as t
 
@@ -33,6 +34,12 @@ def dill_decode(b: bytes, info: t.Optional[t.Dict] = None) -> t.Any:
     return dill.loads(b)
 
 
+def file_check(path: t.Any, info: t.Optional[t.Dict] = None) -> str:
+    if not (isinstance(path, str) and os.path.exists(path)):
+        raise ValueError(f"Path '{path}' is not a valid filesystem path")
+    return path
+
+
 def torch_encode(object: t.Any, info: t.Optional[t.Dict] = None) -> bytes:
     import torch
 
@@ -105,10 +112,14 @@ def __call__(
 torch_serializer = DataType(
     'torch', encoder=torch_encode, decoder=torch_decode, artifact=True
 )
+file_serializer = DataType(
+    'file', encoder=file_check, decoder=file_check, artifact=True
+)
 serializers = {
     'pickle': pickle_serializer,
     'dill': dill_serializer,
     'torch': torch_serializer,
+    'file': file_serializer,
 }
 
 
@@ -185,24 +196,36 @@ def encode(
 
         def _encode(x):
             try:
-                bytes_ = self.datatype.encoder(x)
+                x = self.datatype.encoder(x)
             except Exception as e:
-                raise ArtifactSavingError from e
-            sha1 = str(hashlib.sha1(bytes_).hexdigest())
-            if (
-                CFG.bytes_encoding == BytesEncoding.BASE64
-                or bytes_encoding == BytesEncoding.BASE64
-            ):
-                bytes_ = to_base64(bytes_)
-            return bytes_, sha1
+                raise ArtifactSavingError(e) from e
+
+            if isinstance(x, str):
+                sha1 = str(hashlib.sha1(x.encode()).hexdigest())
+            elif isinstance(x, bytes):
+                sha1 = str(hashlib.sha1(x).hexdigest())
+                if (
+                    CFG.bytes_encoding == BytesEncoding.BASE64
+                    or bytes_encoding == BytesEncoding.BASE64
+                ):
+                    # TODO: the artifact store is not compatible with non-bytes types
+                    x = to_base64(x)
+            else:
+                raise ValueError(
+                    'The datatype can only encode data as bytes or str, '
+                    f'but got {type(x)}'
+                )
+
+            return x, sha1
 
         if self.datatype.encoder is None:
             return self.x
-        bytes_, sha1 = _encode(self.x)
+        x, sha1 = _encode(self.x)
+        # TODO: Use a new class to handle this
         return {
             '_content': {
-                'bytes': bytes_,
+                'bytes': x,
                 'datatype': self.datatype.identifier,
                 'leaf_type': 'encodable',
                 'sha1': sha1,
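
Note: a quick illustration of what the 'file' branch of _encode produces (a hedged
sketch; assumes file_serializer.encoder is directly callable as the DataType above
suggests, and the path below is made up):

    import hashlib

    from superduperdb.components.datatype import file_serializer

    path = file_serializer.encoder('./my_folder')   # file_check: validates existence, returns the path unchanged
    sha1 = hashlib.sha1(path.encode()).hexdigest()  # for 'file', the sha1 is computed over the path string
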
diff --git a/test/unittest/backends/local/__init__.py b/test/unittest/backends/local/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/test/unittest/backends/local/test_artifacts.py b/test/unittest/backends/local/test_artifacts.py
new file mode 100644
index 0000000000..0d4e6b56b2
--- /dev/null
+++ b/test/unittest/backends/local/test_artifacts.py
@@ -0,0 +1,44 @@
+import dataclasses as dc
+import os
+import typing as t
+from test.db_config import DBConfig
+
+import pytest
+
+from superduperdb.backends.local.artifacts import FileSystemArtifactStore
+from superduperdb.components.component import Component
+from superduperdb.components.datatype import (
+    DataType,
+    file_serializer,
+    serializers,
+)
+
+
+@dc.dataclass(kw_only=True)
+class TestComponent(Component):
+    path: str
+    type_id: t.ClassVar[str] = "TestComponent"
+
+    _artifacts: t.ClassVar[t.Sequence[t.Tuple[str, "DataType"]]] = (
+        ("path", file_serializer),
+    )
+
+
+@pytest.fixture
+def artifact_store(tmpdir) -> FileSystemArtifactStore:
+    # Use pytest's tmpdir so each run gets a fresh, isolated store directory.
+    artifact_store = FileSystemArtifactStore(f"{tmpdir}")
+    artifact_store._serializers = serializers
+    return artifact_store
+
+
+@pytest.mark.parametrize("db", [DBConfig.mongodb_empty], indirect=True)
+def test_save_and_load_file(db, artifact_store: FileSystemArtifactStore):
+    db.artifact_store = artifact_store
+    test_component = TestComponent(path="superduperdb", identifier="test")
+    db.add(test_component)
+    test_component_loaded = db.load("TestComponent", "test")
+    assert test_component.path != test_component_loaded.path
+    assert os.path.getsize(test_component.path) == os.path.getsize(
+        test_component_loaded.path
+    )
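
Note: for anyone picking this up downstream, a hedged end-to-end sketch mirroring
the test above (the Checkpoint class, 'weights/' path, and identifiers are made up;
db is assumed to be a connected superduperdb Datalayer):

    import dataclasses as dc
    import typing as t

    from superduperdb.components.component import Component
    from superduperdb.components.datatype import DataType, file_serializer

    @dc.dataclass(kw_only=True)
    class Checkpoint(Component):
        path: str  # local file or folder to persist in the artifact store
        type_id: t.ClassVar[str] = 'checkpoint'
        _artifacts: t.ClassVar[t.Sequence[t.Tuple[str, 'DataType']]] = (
            ('path', file_serializer),
        )

    def register_checkpoint(db):
        # 'db' is assumed to be a connected Datalayer instance.
        db.add(Checkpoint(path='weights/', identifier='run-1'))  # copies 'weights/' into the store
        return db.load('checkpoint', 'run-1')                    # .path resolves inside the store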