Add file datatype to support saving and reading files/folders in the artifact store.
jieguangzhou committed Feb 21, 2024
1 parent baec125 commit 77ec8c8
Showing 8 changed files with 180 additions and 18 deletions.
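
The gist of the change: a component field can be declared with the new 'file' datatype, and the artifact store then copies the referenced file or folder instead of pickling it. Below is a minimal usage sketch, modelled on the unit test added in this commit; superduperdb with this change is assumed to be installed, and `MyComponent` and its path are illustrative.

```python
# Minimal usage sketch, modelled on the unit test added in this commit.
# Assumes superduperdb (with this change) is installed; MyComponent and the
# path passed to it are illustrative, not part of the diff.
import dataclasses as dc
import typing as t

from superduperdb.components.component import Component
from superduperdb.components.datatype import DataType, file_serializer


@dc.dataclass(kw_only=True)
class MyComponent(Component):
    path: str  # points at a file or folder on disk
    type_id: t.ClassVar[str] = "MyComponent"

    # Declaring the field with the 'file' datatype makes the artifact store
    # copy the file/folder rather than serialize it to bytes.
    _artifacts: t.ClassVar[t.Sequence[t.Tuple[str, "DataType"]]] = (
        ("path", file_serializer),
    )


# With a connected database handle `db` (not shown here):
# db.add(MyComponent(path="./assets", identifier="assets"))
# loaded = db.load("MyComponent", "assets")  # loaded.path points into the store
```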
29 changes: 24 additions & 5 deletions superduperdb/backends/base/artifact.py
@@ -86,6 +86,10 @@ def exists(
def _save_bytes(self, serialized: bytes, file_id: str):
pass

@abstractmethod
def _save_file(self, file_path: str, file_id: str):
pass

def save_artifact(self, r: t.Dict):
"""
Save serialized object in the artifact store.
@@ -99,14 +103,17 @@ def save_artifact(self, r: t.Dict):
assert 'datatype' in r, 'no datatype specified!'
datatype = self.serializers[r['datatype']]
uri = r.get('uri')
file_id = None
if uri is not None:
file_id = _construct_file_id_from_uri(uri)
else:
file_id = hashlib.sha1(r['bytes']).hexdigest()
file_id = r.get('sha1') or hashlib.sha1(r['bytes']).hexdigest()
if r.get('directory'):
file_id = os.path.join(datatype.directory, file_id)
self._save_bytes(r['bytes'], file_id=file_id)

if r['datatype'] == 'file':
self._save_file(r['bytes'], file_id)
else:
self._save_bytes(r['bytes'], file_id=file_id)
r['file_id'] = file_id
return r
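
With this hunk, save_artifact branches on the datatype identifier: for 'file' the value in r['bytes'] is a validated path string handed to _save_file, otherwise the raw bytes go to _save_bytes; either way the artifact is keyed by a file_id derived from a SHA-1 digest (or from the URI when one is given). A stdlib-only sketch of that key derivation follows; the record is illustrative, since in the real flow it is produced by the encoding step in components/datatype.py.

```python
# Stdlib-only sketch of how the new save_artifact path derives file_id.
# The record below is illustrative; for the 'file' datatype 'bytes' holds a
# path string rather than serialized bytes.
import hashlib

r = {'datatype': 'file', 'bytes': '/data/model_assets', 'sha1': None, 'uri': None}

payload = r['bytes']
file_id = r.get('sha1') or hashlib.sha1(
    payload.encode() if isinstance(payload, str) else payload
).hexdigest()

# 'file' datatype -> the store copies the path via _save_file(payload, file_id);
# anything else   -> the store writes raw bytes via _save_bytes(payload, file_id).
print(file_id)
```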

@@ -119,6 +126,15 @@ def _load_bytes(self, file_id: str) -> bytes:
"""
pass

@abstractmethod
def _load_file(self, file_id: str) -> str:
"""
Load file from artifact store and return path
:param file_id: Identifier of artifact in the store
"""
pass

def load_artifact(self, r):
"""
Load artifact from artifact store, and deserialize.
@@ -133,8 +149,11 @@ def load_artifact(self, r):
if file_id is None:
assert uri is not None, '"uri" and "file_id" can\'t both be None'
file_id = _construct_file_id_from_uri(uri)
bytes = self._load_bytes(file_id)
return datatype.decoder(bytes)
if r['datatype'] == 'file':
x = self._load_file(file_id)
else:
x = self._load_bytes(file_id)
return datatype.decoder(x)

def save(self, r: t.Dict) -> t.Dict:
"""
11 changes: 11 additions & 0 deletions superduperdb/backends/local/artifacts.py
@@ -1,6 +1,7 @@
import os
import shutil
import typing as t
from pathlib import Path

import click

@@ -69,6 +70,16 @@ def _load_bytes(self, file_id: str) -> bytes:
with open(os.path.join(self.conn, file_id), 'rb') as f:
return f.read()

def _save_file(self, file_path: str, file_id: str):
path = Path(file_path)
if path.is_dir():
shutil.copytree(file_path, os.path.join(self.conn, file_id))
else:
shutil.copy(file_path, os.path.join(self.conn, file_id))

def _load_file(self, file_id: str) -> str:
return os.path.join(self.conn, file_id)

def disconnect(self):
"""
Disconnect the client
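
For the local filesystem backend, saving a file artifact is a plain copy into the store directory keyed by file_id, and loading simply returns that path. A stdlib-only sketch of the same copy logic, with temporary directories standing in for `self.conn` and the user's data:

```python
# Stdlib-only sketch of the copy logic behind FileSystemArtifactStore._save_file:
# directories are copied with copytree, single files with copy.
# Temporary directories stand in for self.conn and for the data being saved.
import os
import shutil
import tempfile
from pathlib import Path

store_root = tempfile.mkdtemp()                       # plays the role of self.conn
source = tempfile.mkdtemp()                           # folder being persisted
Path(source, 'weights.bin').write_bytes(b'\x00\x01')  # illustrative content

file_id = 'abc123'                                    # normally a sha1 hex digest
destination = os.path.join(store_root, file_id)

if Path(source).is_dir():
    shutil.copytree(source, destination)
else:
    shutil.copy(source, destination)

print(os.listdir(destination))  # ['weights.bin']; _load_file would return `destination`
```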
66 changes: 66 additions & 0 deletions superduperdb/backends/mongodb/artifacts.py
@@ -1,3 +1,7 @@
import os
import tempfile
from pathlib import Path

import click
import gridfs

@@ -48,6 +52,28 @@ def _load_bytes(self, file_id: str):
raise FileNotFoundError(f'File not found in {file_id}')
return next(cur)

def _save_file(self, file_path: str, file_id: str):
path = Path(file_path)
if path.is_dir():
upload_folder(file_path, file_id, self.filesystem)
else:
self.filesystem.put(
open(file_path, 'rb'),
filename=file_path,
metadata={"file_id": file_id, "type": "file"},
)

def _load_file(self, file_id: str) -> str:
file = self.filesystem.find_one(
{"metadata.file_id": file_id, "metadata.type": "file"}
)
if file is not None:
with open(file.filename, 'wb') as f:
f.write(file.read())
return file.filename

return download_folder(file_id, self.filesystem)

def _save_bytes(self, serialized: bytes, file_id: str):
return self.filesystem.put(serialized, filename=file_id)

@@ -57,3 +83,43 @@ def disconnect(self):
"""

# TODO: implement me


def upload_folder(path, file_id, fs, parent_path=""):
if not os.listdir(path):
fs.put(
b'',
filename=os.path.join(parent_path, os.path.basename(path) + '/'),
metadata={"file_id": file_id, "is_empty_dir": True, 'type': 'dir'},
)
else:
for item in os.listdir(path):
item_path = os.path.join(path, item)
if os.path.isdir(item_path):
upload_folder(item_path, file_id, fs, os.path.join(parent_path, item))
else:
with open(item_path, 'rb') as file_to_upload:
fs.put(
file_to_upload,
filename=os.path.join(parent_path, item),
metadata={"file_id": file_id, "type": "dir"},
)


def download_folder(file_id, fs):
temp_dir = tempfile.mkdtemp()
logging.info(f"Downloading files to temporary directory: {temp_dir}")

for grid_out in fs.find({"metadata.file_id": file_id, "metadata.type": "dir"}):
file_path = os.path.join(temp_dir, grid_out.filename)
if grid_out.metadata.get("is_empty_dir", False):
if not os.path.exists(file_path):
os.makedirs(file_path)
else:
directory = os.path.dirname(file_path)
if not os.path.exists(directory):
os.makedirs(directory)
with open(file_path, 'wb') as file_to_write:
file_to_write.write(grid_out.read())

return temp_dir
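
GridFS stores flat blobs, so a folder is uploaded as one GridFS document per file, all tagged with the same metadata.file_id (empty directories get a placeholder entry), and download_folder re-assembles them into a temporary directory. A hedged sketch of driving these helpers directly follows; it assumes pymongo is installed, a MongoDB server is reachable at localhost:27017, and the helpers are importable as added here. The database name is illustrative.

```python
# Hedged sketch of exercising upload_folder/download_folder directly.
# Assumes pymongo is installed, MongoDB is reachable at localhost:27017, and
# the helpers are importable from superduperdb.backends.mongodb.artifacts
# as added in this commit; the database name 'artifact_demo' is illustrative.
import os
import tempfile
from pathlib import Path

import gridfs
from pymongo import MongoClient

from superduperdb.backends.mongodb.artifacts import download_folder, upload_folder

fs = gridfs.GridFS(MongoClient('localhost', 27017)['artifact_demo'])

source = tempfile.mkdtemp()
Path(source, 'config.json').write_text('{}')       # illustrative content

upload_folder(source, 'demo-file-id', fs)           # one GridFS document per file
restored = download_folder('demo-file-id', fs)      # re-assembled in a temp dir
print(os.listdir(restored))                         # ['config.json']
```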
1 change: 1 addition & 0 deletions superduperdb/base/document.py
@@ -204,6 +204,7 @@ def _encode(
return out
# ruff: noqa: E501
if isinstance(r, Leaf) and not isinstance(r, leaf_types_to_keep): # type: ignore[arg-type]
# TODO: (not leaf_types_to_keep or isinstance(r, leaf_types_to_keep)) ?
return r.encode(
bytes_encoding=bytes_encoding, leaf_types_to_keep=leaf_types_to_keep
)
2 changes: 0 additions & 2 deletions superduperdb/components/component.py
@@ -50,7 +50,6 @@ def artifact_schema(self):
from superduperdb import Schema
from superduperdb.components.datatype import dill_serializer

e = []
schema = {}
lookup = dict(self._artifacts)
if self.artifacts is not None:
@@ -60,7 +59,6 @@
if a is None:
continue
if f.name in lookup:
e.append(f.name)
schema[f.name] = lookup[f.name]
elif callable(getattr(self, f.name)) and not isinstance(
getattr(self, f.name), Serializable
45 changes: 34 additions & 11 deletions superduperdb/components/datatype.py
@@ -2,6 +2,7 @@
import dataclasses as dc
import hashlib
import io
import os
import pickle
import typing as t

@@ -33,6 +34,12 @@ def dill_decode(b: bytes, info: t.Optional[t.Dict] = None) -> t.Any:
return dill.loads(b)


def file_check(path: t.Any, info: t.Optional[t.Dict] = None) -> str:
if not (isinstance(path, str) and os.path.exists(path)):
raise ValueError(f"Path '{path}' does not exist")
return path


def torch_encode(object: t.Any, info: t.Optional[t.Dict] = None) -> bytes:
import torch

@@ -105,10 +112,14 @@ def __call__(
torch_serializer = DataType(
'torch', encoder=torch_encode, decoder=torch_decode, artifact=True
)
file_serializer = DataType(
'file', encoder=file_check, decoder=file_check, artifact=True
)
serializers = {
'pickle': pickle_serializer,
'dill': dill_serializer,
'torch': torch_serializer,
'file': file_serializer,
}
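
The file_serializer registered above uses file_check as both encoder and decoder, so the "encoded" value is just a validated path; the actual copy is delegated to the artifact store. A small sketch, assuming superduperdb with this change is importable and using a temporary directory as the illustrative path:

```python
# Small sketch of the 'file' datatype's encoder/decoder behaviour: file_check
# validates that the path exists and passes it through unchanged.
# Assumes superduperdb (with this commit) is importable; the temporary
# directory is illustrative.
import tempfile

from superduperdb.components.datatype import file_check, file_serializer, serializers

path = tempfile.mkdtemp()

assert serializers['file'] is file_serializer
assert file_serializer.encoder(path) == path     # an existing path passes through

try:
    file_check('/no/such/path')
except ValueError as exc:
    print(exc)                                   # Path '/no/such/path' does not exist
```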


@@ -185,24 +196,36 @@ def encode(

def _encode(x):
try:
bytes_ = self.datatype.encoder(x)
x = self.datatype.encoder(x)
except Exception as e:
raise ArtifactSavingError from e
sha1 = str(hashlib.sha1(bytes_).hexdigest())
if (
CFG.bytes_encoding == BytesEncoding.BASE64
or bytes_encoding == BytesEncoding.BASE64
):
bytes_ = to_base64(bytes_)
return bytes_, sha1
raise ArtifactSavingError(e) from e

if isinstance(x, str):
sha1 = str(hashlib.sha1(x.encode()).hexdigest())
elif isinstance(x, bytes):
sha1 = str(hashlib.sha1(x).hexdigest())
if (
CFG.bytes_encoding == BytesEncoding.BASE64
or bytes_encoding == BytesEncoding.BASE64
):
# TODO: artifact stores is not compatible with non-bytes types
x = to_base64(x)
else:
raise ValueError(
'The datatype can only encode data as [bytes, str], ',
f'but now it is encoded as {type(x)}',
)

return x, sha1

if self.datatype.encoder is None:
return self.x

bytes_, sha1 = _encode(self.x)
x, sha1 = _encode(self.x)
# TODO: Use a new class to handle this
return {
'_content': {
'bytes': bytes_,
'bytes': x,
'datatype': self.datatype.identifier,
'leaf_type': 'encodable',
'sha1': sha1,
Empty file.
44 changes: 44 additions & 0 deletions test/unittest/backends/local/test_artifacts.py
@@ -0,0 +1,44 @@
import dataclasses as dc
import os
import typing as t
from test.db_config import DBConfig

import pytest

from superduperdb.backends.local.artifacts import FileSystemArtifactStore
from superduperdb.components.component import Component
from superduperdb.components.datatype import (
DataType,
file_serializer,
serializers,
)


@dc.dataclass(kw_only=True)
class TestComponent(Component):
path: str
type_id: t.ClassVar[str] = "TestComponent"

_artifacts: t.ClassVar[t.Sequence[t.Tuple[str, "DataType"]]] = (
("path", file_serializer),
)


@pytest.fixture
def artifact_strore(tmpdir) -> FileSystemArtifactStore:
tmpdir = "uttest"
artifact_strore = FileSystemArtifactStore(f"{tmpdir}")
artifact_strore._serializers = serializers
return artifact_strore


@pytest.mark.parametrize("db", [DBConfig.mongodb_empty], indirect=True)
def test_save_and_load_file(db, artifact_strore: FileSystemArtifactStore):
db.artifact_store = artifact_strore
test_component = TestComponent(path="superduperdb", identifier="test")
db.add(test_component)
test_component_loaded = db.load("TestComponent", "test")
assert test_component.path != test_component_loaded.path
assert os.path.getsize(test_component.path) == os.path.getsize(
test_component_loaded.path
)
