Add file datatype type to support saving and reading files/folders in artifact_store #1805

Merged (6 commits) on Feb 23, 2024
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Update Menu structure and renamed use-cases
- Change and simplify the contract for writing new `_Predictor` descendants (`.predict_one`, `.predict`)
- Create models directly by importing package from auto
- Add file datatype type to support saving and reading files/folders in artifact_store

#### Bug Fixes
- LLM CI random errors
64 changes: 41 additions & 23 deletions superduperdb/backends/base/artifact.py
@@ -84,6 +84,12 @@ def exists(

@abstractmethod
def _save_bytes(self, serialized: bytes, file_id: str):
"""Save bytes in artifact store""" ""
pass

@abstractmethod
def _save_file(self, file_path: str, file_id: str) -> str:
"""Save file in artifact store and return file_id"""
pass

def save_artifact(self, r: t.Dict):
@@ -95,18 +101,23 @@ def save_artifact(self, r: t.Dict):
and optional fields
{'file_id', 'uri'}
"""
assert 'bytes' in r, 'serialized bytes are missing!'
assert 'datatype' in r, 'no datatype specified!'
datatype = self.serializers[r['datatype']]
uri = r.get('uri')
file_id = None
if uri is not None:
file_id = _construct_file_id_from_uri(uri)
if r.get('type') == 'file':
assert 'file_id' in r, 'file_id is missing!'
file_id = self._save_file(r['x'], r['file_id'])
else:
file_id = hashlib.sha1(r['bytes']).hexdigest()
if r.get('directory'):
file_id = os.path.join(datatype.directory, file_id)
self._save_bytes(r['bytes'], file_id=file_id)
assert 'bytes' in r, 'serialized bytes are missing!'
assert 'datatype' in r, 'no datatype specified!'
datatype = self.serializers[r['datatype']]
uri = r.get('uri')
if uri is not None:
file_id = _construct_file_id_from_uri(uri)
else:
file_id = hashlib.sha1(r['bytes']).hexdigest()
if r.get('directory'):
file_id = os.path.join(datatype.directory, file_id)

self._save_bytes(r['bytes'], file_id=file_id)
del r['bytes']
r['file_id'] = file_id
return r
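A rough usage sketch of the two branches above; the record keys ('type', 'x', 'file_id', 'bytes', 'datatype') come from this hunk, while the store instance and serializer names are illustrative:

store: ArtifactStore = ...  # any concrete backend, e.g. the filesystem store changed below

# File/folder branch: the source path travels in 'x' and a caller-supplied file_id is required.
r = {'type': 'file', 'x': '/data/model_dir', 'file_id': 'abc123', 'datatype': 'file'}
r = store.save_artifact(r)  # dispatches to _save_file

# Bytes branch: unchanged behaviour, file_id is derived from the sha1 of the serialized payload.
r = {'bytes': b'\x00\x01', 'datatype': 'pickle'}
r = store.save_artifact(r)  # dispatches to _save_bytes and replaces 'bytes' with 'file_id'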

@@ -119,6 +130,15 @@ def _load_bytes(self, file_id: str) -> bytes:
"""
pass

@abstractmethod
def _load_file(self, file_id: str) -> str:
"""
Load file from artifact store and return path

:param file_id: Identifier of artifact in the store
"""
pass

def load_artifact(self, r):
"""
Load artifact from artifact store, and deserialize.
@@ -129,26 +149,24 @@ def load_artifact(self, r):

datatype = self.serializers[r['datatype']]
file_id = r.get('file_id')
uri = r.get('uri')
if file_id is None:
assert uri is not None, '"uri" and "file_id" can\'t both be None'
file_id = _construct_file_id_from_uri(uri)
bytes = self._load_bytes(file_id)
return datatype.decoder(bytes)
if r.get('type') == 'file':
x = self._load_file(file_id)
else:
uri = r.get('uri')
if file_id is None:
assert uri is not None, '"uri" and "file_id" can\'t both be None'
file_id = _construct_file_id_from_uri(uri)
x = self._load_bytes(file_id)
return datatype.decoder(x)

def save(self, r: t.Dict) -> t.Dict:
"""
Save list of artifacts and replace the artifacts with file reference
:param artifacts: List of ``Artifact`` instances
"""
if isinstance(r, dict):
if (
'_content' in r
and r['_content']['leaf_type'] == 'encodable'
and 'bytes' in r['_content']
):
if '_content' in r and r['_content'].get('artifact'):
self.save_artifact(r['_content'])
del r['_content']['bytes']
else:
for k in r:
self.save(r[k])
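The save change above keys off an 'artifact' flag in '_content' instead of the 'encodable' leaf type, so any encodable subclass (including the new file type) is picked up. A minimal sketch of the recursion over a hypothetical nested record:

store: ArtifactStore = ...  # any concrete backend
doc = {
    'model': {'_content': {'artifact': True, 'bytes': b'...', 'datatype': 'pickle'}},
    'meta': {'name': 'demo'},  # plain values are left untouched
}
store.save(doc)  # walks the dict recursively and calls save_artifact on the flagged leaf only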
24 changes: 24 additions & 0 deletions superduperdb/backends/local/artifacts.py
@@ -1,6 +1,7 @@
import os
import shutil
import typing as t
from pathlib import Path

import click

@@ -69,6 +70,29 @@ def _load_bytes(self, file_id: str) -> bytes:
with open(os.path.join(self.conn, file_id), 'rb') as f:
return f.read()

def _save_file(self, file_path: str, file_id: str):
"""
Save file in artifact store and return the relative path {file_id}/{name}
"""
path = Path(file_path)
name = path.name
file_id_folder = os.path.join(self.conn, file_id)
os.makedirs(file_id_folder, exist_ok=True)
save_path = os.path.join(file_id_folder, name)
logging.info(f"Copying file {file_path} to {save_path}")
if path.is_dir():
shutil.copytree(file_path, save_path)
Collaborator:

@jieguangzhou if there is a partial copy and the program crashes, maybe a rollback? Possibly in a future PR if not in this one.

Ignore if this already happens in shutil.copy*.

Collaborator (author):

Yes, we need to do this in the future, and not only for the filesystem; all artifact stores need to support it.

Collaborator:

This should be handled in db.add. We need to be able to roll back the whole add process.

Collaborator:

  • Deleting already added artifacts
  • Cancelling any running jobs
  • Removing computed outputs

Not an easy task.

Collaborator:

@jieguangzhou @blythed Since we are copying it into the artifact store as a snapshot of the directory, we need to caution users that changes made to the source directory after db.load will not be reflected in the component, although this behaviour is intended.

else:
shutil.copy(file_path, save_path)
# return the relative path {file_id}/{name}
return os.path.join(file_id, name)

def _load_file(self, file_id: str) -> str:
"""Return the path to the file in the artifact store"""
logging.info(f"Loading file {file_id} from {self.conn}")
return os.path.join(self.conn, file_id)

def disconnect(self):
"""
Disconnect the client
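A round-trip sketch for the filesystem backend; the constructor arguments are assumed, but per this diff self.conn is the store's root directory:

store = FileSystemArtifactStore(conn='/tmp/artifact_store')  # construction details assumed

rel = store._save_file('/data/model_dir', file_id='abc123')
# copies /data/model_dir to /tmp/artifact_store/abc123/model_dir and returns 'abc123/model_dir'

path = store._load_file(rel)
# returns '/tmp/artifact_store/abc123/model_dir', i.e. the stored snapshot, not the original source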
115 changes: 114 additions & 1 deletion superduperdb/backends/mongodb/artifacts.py
@@ -1,7 +1,12 @@
import os
import tempfile
from pathlib import Path

import click
import gridfs
from tqdm import tqdm

from superduperdb import logging
from superduperdb import CFG, logging
from superduperdb.backends.base.artifact import ArtifactStore
from superduperdb.misc.colors import Colors

Expand Down Expand Up @@ -48,6 +53,22 @@ def _load_bytes(self, file_id: str):
raise FileNotFoundError(f'File not found in {file_id}')
return cur.read()

def _save_file(self, file_path: str, file_id: str):
"""Save file to GridFS"""
path = Path(file_path)
if path.is_dir():
upload_folder(file_path, file_id, self.filesystem)
else:
upload_file(file_path, file_id, self.filesystem)
return file_id

def _load_file(self, file_id: str) -> str:
"""
Download file from GridFS and return the path
The path is a temporary directory, {tmp_prefix}/{file_id}/{filename or folder}
"""
return download(file_id, self.filesystem)

def _save_bytes(self, serialized: bytes, file_id: str):
return self.filesystem.put(serialized, filename=file_id)

@@ -57,3 +78,95 @@ def disconnect(self):
"""

# TODO: implement me


def upload_file(path, file_id, fs):
"""Upload file to GridFS"""
logging.info(f"Uploading file {path} to GridFS with file_id {file_id}")
path = Path(path)
with open(path, 'rb') as file_to_upload:
fs.put(
file_to_upload,
filename=path.name,
metadata={"file_id": file_id, "type": "file"},
)


def upload_folder(path, file_id, fs, parent_path=""):
"""Upload folder to GridFS"""
path = Path(path)
if not parent_path:
logging.info(f"Uploading folder {path} to GridFS with file_id {file_id}")
parent_path = os.path.basename(path)

# if the folder is empty, create an empty file
if not os.listdir(path):
fs.put(
b'',
filename=os.path.join(parent_path, os.path.basename(path)),
metadata={"file_id": file_id, "is_empty_dir": True, 'type': 'dir'},
)
else:
for item in os.listdir(path):
item_path = os.path.join(path, item)
if os.path.isdir(item_path):
upload_folder(item_path, file_id, fs, os.path.join(parent_path, item))
else:
with open(item_path, 'rb') as file_to_upload:
fs.put(
file_to_upload,
filename=os.path.join(parent_path, item),
metadata={"file_id": file_id, "type": "dir"},
)


def download(file_id, fs):
"""Download file or folder from GridFS and return the path"""

download_folder = CFG.downloads.folder

if not download_folder:
download_folder = os.path.join(
tempfile.gettempdir(), "superduperdb", "ArtifactStore"
)

save_folder = os.path.join(download_folder, file_id)
os.makedirs(save_folder, exist_ok=True)

file = fs.find_one({"metadata.file_id": file_id})
if file is None:
raise FileNotFoundError(f"File not found in {file_id}")

type_ = file.metadata.get("type")
if type_ not in {"file", "dir"}:
raise ValueError(
f"Unknown type '{type_}' for file_id {file_id}, expected file or dir"
)

if type_ == 'file':
save_path = os.path.join(save_folder, os.path.split(file.filename)[-1])
logging.info(f"Downloading file_id {file_id} to {save_path}")
with open(save_path, 'wb') as f:
f.write(file.read())
return save_path

logging.info(f"Downloading folder with file_id {file_id} to {save_folder}")
for grid_out in tqdm(
fs.find({"metadata.file_id": file_id, "metadata.type": "dir"})
):
file_path = os.path.join(save_folder, grid_out.filename)
if grid_out.metadata.get("is_empty_dir", False):
if not os.path.exists(file_path):
os.makedirs(file_path)
else:
directory = os.path.dirname(file_path)
if not os.path.exists(directory):
os.makedirs(directory)
with open(file_path, 'wb') as file_to_write:
file_to_write.write(grid_out.read())

folders = os.listdir(save_folder)
assert len(folders) == 1, f"Expected only one folder, got {folders}"
save_folder = os.path.join(save_folder, folders[0])
logging.info(f"Downloaded folder with file_id {file_id} to {save_folder}")
return save_folder
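The GridFS helpers above can also be exercised directly; a sketch assuming a local MongoDB and a placeholder database name:

import gridfs
from pymongo import MongoClient

fs = gridfs.GridFS(MongoClient()['artifact_demo'])  # database name is a placeholder

upload_folder('/data/model_dir', 'abc123', fs)  # one GridFS file per member, tagged with metadata.file_id
local = download('abc123', fs)
# e.g. <tempdir>/superduperdb/ArtifactStore/abc123/model_dir, or under CFG.downloads.folder if configured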
12 changes: 8 additions & 4 deletions superduperdb/base/datalayer.py
@@ -25,7 +25,7 @@
from superduperdb.base.superduper import superduper
from superduperdb.cdc.cdc import DatabaseChangeDataCapture
from superduperdb.components.component import Component
from superduperdb.components.datatype import DataType, Encodable, serializers
from superduperdb.components.datatype import DataType, _BaseEncodable, serializers
from superduperdb.components.model import ObjectModel
from superduperdb.components.schema import Schema
from superduperdb.jobs.job import ComponentJob, FunctionJob, Job
@@ -188,7 +188,7 @@ def backfill_vector_search(self, vi, searcher):
vi.indexing_listener.model.identifier,
version=vi.indexing_listener.model.version,
)
if isinstance(h, Encodable):
if isinstance(h, _BaseEncodable):
h = h.x

items.append(VectorItem.create(id=str(id), vector=h))
@@ -574,7 +574,7 @@ def load(
if info_only:
return info

m = serializable.Serializable.decode(info)
m = serializable.Serializable.decode(info, db=self)
m.db = self
m.on_load(self)

@@ -803,7 +803,11 @@ def _add(
if serialized is None:
leaves = object.dict().get_leaves()
leaves = leaves.values()
artifacts = [leaf for leaf in leaves if isinstance(leaf, Encodable)]
artifacts = [
leaf
for leaf in leaves
if isinstance(leaf, _BaseEncodable) and leaf.artifact
]
children = [leaf for leaf in leaves if isinstance(leaf, Component)]

for child in children:
15 changes: 8 additions & 7 deletions superduperdb/base/document.py
@@ -7,10 +7,10 @@
from superduperdb import CFG
from superduperdb.backends.base.artifact import _construct_file_id_from_uri
from superduperdb.base.config import BytesEncoding
from superduperdb.base.leaf import Leaf
from superduperdb.base.leaf import Leaf, find_leaf_cls
from superduperdb.base.serializable import Serializable
from superduperdb.components.component import Component
from superduperdb.components.datatype import DataType, Encodable
from superduperdb.components.datatype import DataType, Encodable, _BaseEncodable
from superduperdb.misc.special_dicts import MongoStyleDict

if t.TYPE_CHECKING:
@@ -129,7 +129,8 @@ def _find_leaves(r: t.Any, leaf_type: t.Optional[str] = None, pop: bool = False)
keys.extend([(f'{k}.{i}' if k else f'{i}') for k in sub_keys])
return keys, leaves
if leaf_type:
if isinstance(r, _LEAF_TYPES[leaf_type]):
leaf_cls = _LEAF_TYPES.get(leaf_type) or find_leaf_cls(leaf_type)
if isinstance(r, leaf_cls):
return [''], [r]
else:
return [], []
@@ -148,9 +149,9 @@ def _decode(
) -> t.Any:
bytes_encoding = bytes_encoding or CFG.bytes_encoding
if isinstance(r, dict) and '_content' in r:
return _LEAF_TYPES[r['_content']['leaf_type']].decode(
r, db=db, reference=reference
)
leaf_type = r['_content']['leaf_type']
leaf_cls = _LEAF_TYPES.get(leaf_type) or find_leaf_cls(leaf_type)
return leaf_cls.decode(r, db=db, reference=reference)
elif isinstance(r, list):
return [
_decode(x, db=db, bytes_encoding=bytes_encoding, reference=reference)
@@ -237,7 +238,7 @@ def _encode_with_schema(


def _unpack(item: t.Any, db=None) -> t.Any:
if isinstance(item, Encodable):
if isinstance(item, _BaseEncodable):
# TODO move logic into Encodable
if item.reference:
file_id = _construct_file_id_from_uri(item.uri)
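With the fallback above, a record's '_content.leaf_type' no longer has to be pre-registered in _LEAF_TYPES. A sketch of the lookup; the 'file' leaf type and the record shape are assumptions based on this PR, not confirmed names:

r = {'_content': {'leaf_type': 'file', 'file_id': 'abc123', 'datatype': 'file'}}  # hypothetical record

leaf_type = r['_content']['leaf_type']
leaf_cls = _LEAF_TYPES.get(leaf_type) or find_leaf_cls(leaf_type)  # class resolved by name as a fallback
obj = leaf_cls.decode(r, db=db)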