
Generating random file_id for saving file to artifact store
jieguangzhou authored and blythed committed Feb 23, 2024
1 parent 1c0ff42 commit d4179b6
Showing 6 changed files with 120 additions and 36 deletions.
4 changes: 2 additions & 2 deletions superduperdb/backends/base/artifact.py
@@ -102,8 +102,8 @@ def save_artifact(self, r: t.Dict):
{'file_id', 'uri'}
"""
if r.get('type') == 'file':
sha1 = hashlib.sha1(str(id(r)).encode()).hexdigest()
file_id = self._save_file(r['x'], sha1)
assert 'file_id' in r, 'file_id is missing!'
file_id = self._save_file(r['x'], r['file_id'])
else:
assert 'bytes' in r, 'serialized bytes are missing!'
assert 'datatype' in r, 'no datatype specified!'
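What changes here: save_artifact no longer derives the file_id itself — the old code hashed id(r), a memory address that is not stable across runs and can be reused within a process — and instead requires the caller to supply one. A minimal sketch of the new caller-side contract (the dictionary shape is inferred from this hunk; values are illustrative):

# Hypothetical input to save_artifact after this change.
r = {
    'type': 'file',
    'x': '/path/to/artifact/on/disk',
    'file_id': '3f786850e387550fdab836ed7e6dc881de23001b',  # now mandatory
}
# artifact_store.save_artifact(r)  # asserts 'file_id' is present, then saves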
46 changes: 33 additions & 13 deletions superduperdb/backends/mongodb/artifacts.py
Expand Up @@ -4,8 +4,9 @@

import click
import gridfs
from tqdm import tqdm

from superduperdb import logging
from superduperdb import CFG, logging
from superduperdb.backends.base.artifact import ArtifactStore
from superduperdb.misc.colors import Colors

@@ -121,20 +122,39 @@ def upload_folder(path, file_id, fs, parent_path=""):

def download(file_id, fs):
"""Download file or folder from GridFS and return the path"""
temp_dir = tempfile.mkdtemp(prefix=file_id)

# try to download a file first, if it fails, assume it's a folder
file = fs.find_one({"metadata.file_id": file_id, "metadata.type": "file"})
if file is not None:
save_path = os.path.join(temp_dir, os.path.split(file.filename)[-1])
download_folder = CFG.downloads.folder

if not download_folder:
download_folder = os.path.join(
tempfile.gettempdir(), "superduperdb", "ArtifactStore"
)

save_folder = os.path.join(download_folder, file_id)
os.makedirs(save_folder, exist_ok=True)

file = fs.find_one({"metadata.file_id": file_id})
if file is None:
raise FileNotFoundError(f"File not found for file_id {file_id}")

type_ = file.metadata.get("type")
if type_ not in {"file", "dir"}:
raise ValueError(
f"Unknown type '{type_}' for file_id {file_id}, expected file or dir"
)

if type_ == 'file':
save_path = os.path.join(save_folder, os.path.split(file.filename)[-1])
logging.info(f"Downloading file_id {file_id} to {save_path}")
with open(save_path, 'wb') as f:
f.write(file.read())
return save_path

logging.info(f"Downloading folder with file_id {file_id} to {temp_dir}")
for grid_out in fs.find({"metadata.file_id": file_id, "metadata.type": "dir"}):
file_path = os.path.join(temp_dir, grid_out.filename)
logging.info(f"Downloading folder with file_id {file_id} to {save_folder}")
for grid_out in tqdm(
fs.find({"metadata.file_id": file_id, "metadata.type": "dir"})
):
file_path = os.path.join(save_folder, grid_out.filename)
if grid_out.metadata.get("is_empty_dir", False):
if not os.path.exists(file_path):
os.makedirs(file_path)
@@ -145,8 +165,8 @@ def download(file_id, fs):
with open(file_path, 'wb') as file_to_write:
file_to_write.write(grid_out.read())

folders = os.listdir(temp_dir)
folders = os.listdir(save_folder)
assert len(folders) == 1, f"Expected only one folder, got {folders}"
temp_dir = os.path.join(temp_dir, folders[0])
logging.info(f"Downloaded folder with file_id {file_id} to {temp_dir}")
return temp_dir
save_folder = os.path.join(save_folder, folders[0])
logging.info(f"Downloaded folder with file_id {file_id} to {save_folder}")
return save_folder
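Two behavioural changes in download are worth noting. First, the target directory is now deterministic and configurable rather than a fresh mkdtemp per call: CFG.downloads.folder wins if set, otherwise a stable folder under the system temp directory is used, with one subfolder per file_id. Second, for folder artifacts the trailing assert len(folders) == 1 relies on upload_folder having nested everything under a single top-level directory, which becomes the returned path. The path-resolution step in isolation (a standalone sketch, not superduperdb's public API):

import os
import tempfile

def resolve_download_folder(file_id, configured=None):
    # Mirrors the fallback in download(): an explicit config value wins,
    # otherwise use a stable per-file_id folder under the system temp dir.
    base = configured or os.path.join(
        tempfile.gettempdir(), "superduperdb", "ArtifactStore"
    )
    save_folder = os.path.join(base, file_id)
    os.makedirs(save_folder, exist_ok=True)
    return save_folder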
1 change: 0 additions & 1 deletion superduperdb/base/document.py
@@ -22,7 +22,6 @@
ItemType = t.Union[t.Dict[str, t.Any], Encodable, ObjectId]

_OUTPUTS_KEY: str = '_outputs'
# TODO: Remove this dict to map leaf types to classes
_LEAF_TYPES = {
'component': Component,
'encodable': Encodable,
37 changes: 34 additions & 3 deletions superduperdb/components/datatype.py
@@ -18,6 +18,17 @@
Encode = t.Callable[[t.Any], bytes]


def random_sha1():
"""
Generate a random SHA-1 hex digest.
Can be used to generate `file_id` and other unique values.
"""
random_data = os.urandom(256)
sha1 = hashlib.sha1()
sha1.update(random_data)
return sha1.hexdigest()


def pickle_encode(object: t.Any, info: t.Optional[t.Dict] = None) -> bytes:
return pickle.dumps(object)
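The new random_sha1 helper feeds 256 bytes of OS entropy through SHA-1, so the resulting file_id is independent of any Python object identity and collisions are practically impossible. A self-contained equivalent with a quick sanity check:

import hashlib
import os

def random_sha1() -> str:
    # 256 random bytes -> 40-character hex digest; a fresh value every call.
    return hashlib.sha1(os.urandom(256)).hexdigest()

fid = random_sha1()
assert len(fid) == 40
assert fid != random_sha1()  # equal digests are vanishingly unlikely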

@@ -96,12 +107,13 @@ class DataType(Component):
artifact: bool = False
reference: bool = False
directory: t.Optional[str] = None
encodable_cls: t.Optional[t.Type['_BaseEncodable']] = None
encodable: t.Optional[str] = None

def __call__(
self, x: t.Optional[t.Any] = None, uri: t.Optional[str] = None
) -> '_BaseEncodable':
encodable_cls = self.encodable_cls or Encodable
# get the _BaseEncodable subclass registered under the configured name
encodable_cls = _BaseEncodable.get_encodable_cls(self.encodable, Encodable)
return encodable_cls(self, x=x, uri=uri)


@@ -186,6 +198,24 @@ def unpack(self, db):
self.init(db)
return self.x

@classmethod
def get_encodable_cls(cls, name, default=None):
"""
Get the subclass of _BaseEncodable with the given name.
If no subclass matches, fall back to `default`, which must itself
be a subclass of _BaseEncodable.
"""
for sub_cls in cls.__subclasses__():
if sub_cls.__name__ == name:
return sub_cls
if default is None:
raise ValueError(f'No subclass with name "{name}" found.')
elif not issubclass(default, cls):
raise ValueError(
"The default class must be a subclass of the _BaseEncodable."
)

return default


@dc.dataclass
class Encodable(_BaseEncodable):
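Replacing the encodable_cls class reference with a plain string keeps DataType trivially serializable, at the cost of a by-name lookup whenever the datatype is called; note that __subclasses__() returns direct subclasses only, so every encodable class must inherit from _BaseEncodable directly to be discoverable. The lookup pattern in isolation, with made-up class names:

class Base:
    @classmethod
    def get_cls(cls, name, default=None):
        # Linear scan over *direct* subclasses, as in get_encodable_cls.
        for sub_cls in cls.__subclasses__():
            if sub_cls.__name__ == name:
                return sub_cls
        if default is None:
            raise ValueError(f'No subclass with name "{name}" found.')
        if not issubclass(default, cls):
            raise ValueError('The default must be a subclass of Base.')
        return default

class A(Base):
    pass

class B(Base):
    pass

assert Base.get_cls('B') is B           # resolved by name
assert Base.get_cls('missing', A) is A  # falls back to the default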
@@ -286,6 +316,7 @@ def encode(
'datatype': self.datatype.identifier,
'uri': None,
'artifact': self.artifact,
'file_id': random_sha1(),
}
}
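This one-line addition is where the commit title lands: the random file_id is minted once, at encode time, and travels inside the encoded reference — which is what satisfies the assert 'file_id' in r added to save_artifact above. Roughly, the inner payload now carries (hypothetical value; enclosing document keys omitted):

# Sketch of the fields visible in this hunk, not a verbatim payload.
content = {
    'datatype': 'file',  # self.datatype.identifier
    'uri': None,
    'artifact': True,
    'file_id': 'a94a8fe5ccb19ba61c4c0873d391e987982fbbd3',  # random_sha1()
}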

@@ -314,7 +345,7 @@ def decode(cls, r, db=None, reference: bool = False) -> '_BaseEncodable':
encoder=file_check,
decoder=file_check,
artifact=True,
encodable_cls=ReferenceEncodable,
encodable="ReferenceEncodable",
)
serializers = {
'pickle': pickle_serializer,
29 changes: 20 additions & 9 deletions test/integration/artifacts/test_mongodb.py
@@ -5,12 +5,10 @@

import pytest

from superduperdb.backends.local.artifacts import FileSystemArtifactStore
from superduperdb.components.component import Component
from superduperdb.components.datatype import (
DataType,
file_serializer,
serializers,
)


@@ -25,16 +23,29 @@ class TestComponent(Component):


@pytest.fixture
def artifact_strore(tmpdir) -> FileSystemArtifactStore:
artifact_strore = FileSystemArtifactStore(f"{tmpdir}")
artifact_strore._serializers = serializers
return artifact_strore
def random_directory(tmpdir):
tmpdir_path = os.path.join(tmpdir, "test_data")
os.makedirs(tmpdir_path, exist_ok=True)
for i in range(10):
file_name = f'{i}.txt'
file_path = os.path.join(tmpdir_path, file_name)

with open(file_path, 'w') as file:
file.write(str(i))

for j in range(10):
sub_dir = os.path.join(tmpdir_path, f'subdir_{j}')
os.makedirs(sub_dir, exist_ok=True)
sub_file_path = os.path.join(sub_dir, file_name)
with open(sub_file_path, 'w') as file:
file.write(f"{i} {j}")

return tmpdir_path


def test_save_and_load_directory(test_db):
def test_save_and_load_directory(test_db, random_directory):
# test save and load directory
directory = os.path.join(os.getcwd(), "superduperdb")
test_component = TestComponent(path=directory, identifier="test")
test_component = TestComponent(path=random_directory, identifier="test")
test_db.add(test_component)
test_component_loaded = test_db.load("TestComponent", "test")
test_component_loaded.init()
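With the random_directory fixture, the round-trip test saves and reloads a generated tree instead of the live superduperdb source directory, so the result no longer depends on repository contents or the current working directory. The hunk does not show how equality is asserted; one hedged way to compare the original and reloaded trees would be filecmp (a hypothetical helper, not from the repo):

import filecmp
import os

def trees_equal(a, b):
    # Compare entry names and files (shallow, stat-based comparison),
    # then recurse into the subdirectories the two trees share.
    cmp = filecmp.dircmp(a, b)
    if cmp.left_only or cmp.right_only or cmp.diff_files:
        return False
    return all(
        trees_equal(os.path.join(a, d), os.path.join(b, d))
        for d in cmp.common_dirs
    )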
39 changes: 31 additions & 8 deletions test/unittest/backends/local/test_artifacts.py
@@ -26,19 +26,42 @@ class TestComponent(Component):


@pytest.fixture
def artifact_strore(tmpdir) -> FileSystemArtifactStore:
artifact_strore = FileSystemArtifactStore(f"{tmpdir}")
def random_directory(tmpdir):
tmpdir_path = os.path.join(tmpdir, "test_data")
os.makedirs(tmpdir_path, exist_ok=True)
for i in range(10):
file_name = f'{i}.txt'
file_path = os.path.join(tmpdir_path, file_name)

with open(file_path, 'w') as file:
file.write(str(i))

for j in range(10):
sub_dir = os.path.join(tmpdir_path, f'subdir_{j}')
os.makedirs(sub_dir, exist_ok=True)
sub_file_path = os.path.join(sub_dir, file_name)
with open(sub_file_path, 'w') as file:
file.write(f"{i} {j}")

return tmpdir_path


@pytest.fixture
def artifact_store(tmpdir) -> FileSystemArtifactStore:
tmpdir_path = os.path.join(tmpdir, "artifact_store")
artifact_strore = FileSystemArtifactStore(f"{tmpdir_path}")
artifact_strore._serializers = serializers
return artifact_strore


@pytest.mark.parametrize("db", [DBConfig.mongodb_empty], indirect=True)
def test_save_and_load_directory(db, artifact_strore: FileSystemArtifactStore):
db.artifact_store = artifact_strore
def test_save_and_load_directory(
db, artifact_store: FileSystemArtifactStore, random_directory
):
db.artifact_store = artifact_store

# test save and load directory
directory = os.path.join(os.getcwd(), "superduperdb")
test_component = TestComponent(path=directory, identifier="test")
test_component = TestComponent(path=random_directory, identifier="test")
db.add(test_component)
test_component_loaded = db.load("TestComponent", "test")
test_component_loaded.init()
@@ -56,8 +79,8 @@ def test_save_and_load_directory(db, artifact_strore: FileSystemArtifactStore):


@pytest.mark.parametrize("db", [DBConfig.mongodb_empty], indirect=True)
def test_save_and_load_file(db, artifact_strore: FileSystemArtifactStore):
db.artifact_store = artifact_strore
def test_save_and_load_file(db, artifact_store: FileSystemArtifactStore):
db.artifact_store = artifact_store
# test save and load file
file = os.path.abspath(__file__)
test_component = TestComponent(path=file, identifier="test")
