From faa303cfe4403bc3898bf3f749808ad2863ed7e1 Mon Sep 17 00:00:00 2001 From: Duncan Blythe Date: Tue, 21 May 2024 20:30:58 +0200 Subject: [PATCH] Refactor rest server implementation --- .github/workflows/ci_code.yml | 10 +- CHANGELOG.md | 2 + Makefile | 10 +- deploy/rest/config.yaml | 231 ++++++++++++++++++ deploy/testenv/env/rest/rest_mock.yaml | 1 + superduperdb/backends/base/artifacts.py | 19 +- superduperdb/backends/base/query.py | 23 +- superduperdb/backends/ibis/query.py | 6 +- superduperdb/backends/local/artifacts.py | 24 +- superduperdb/backends/mongodb/artifacts.py | 26 +- superduperdb/backends/mongodb/metadata.py | 2 +- superduperdb/backends/mongodb/query.py | 6 +- superduperdb/backends/sqlalchemy/metadata.py | 5 +- superduperdb/base/code.py | 8 +- superduperdb/base/config.py | 3 + superduperdb/base/datalayer.py | 58 ++++- superduperdb/base/document.py | 38 +-- superduperdb/base/leaf.py | 8 - superduperdb/cdc/cdc.py | 2 +- superduperdb/components/component.py | 19 -- superduperdb/components/datatype.py | 49 +--- superduperdb/components/graph.py | 6 - superduperdb/components/listener.py | 23 -- superduperdb/components/metric.py | 1 - superduperdb/components/model.py | 73 ------ superduperdb/components/vector_index.py | 6 - superduperdb/ext/llm/prompter.py | 5 - superduperdb/ext/openai/model.py | 6 - .../ext/sentence_transformers/model.py | 16 -- superduperdb/ext/sklearn/model.py | 5 - superduperdb/misc/annotations.py | 6 +- superduperdb/rest/app.py | 155 ++++-------- superduperdb/rest/utils.py | 97 -------- superduperdb/server/app.py | 13 +- test/rest/mock_client.py | 78 +++--- test/rest/test_rest.py | 156 +++++++----- test/unittest/base/test_datalayer.py | 27 +- test/unittest/base/test_document.py | 29 ++- 38 files changed, 661 insertions(+), 591 deletions(-) create mode 100644 deploy/rest/config.yaml delete mode 100644 superduperdb/rest/utils.py diff --git a/.github/workflows/ci_code.yml b/.github/workflows/ci_code.yml index 4610872fd0..7660bf246c 100644 --- a/.github/workflows/ci_code.yml +++ b/.github/workflows/ci_code.yml @@ -80,13 +80,9 @@ jobs: run: | make ext_testing - - name: Upload code coverage to Codecov - uses: codecov/codecov-action@v3.1.4 - with: - env_vars: RUNNER_OS,PYTHON_VERSION - file: ./coverage.xml - fail_ci_if_error: false - name: codecov-umbrella + - name: Rest Testing + run: | + make rest_testing # --------------------------------- # Integration Testing diff --git a/CHANGELOG.md b/CHANGELOG.md index 4a9c7e11d0..cc987f04df 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,8 +20,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Force load vector indices during backfill - Fix pandas database (in-memory) - Add and update docstrings in component classes and methods. 
+- Changed the rest implementation to use new serialization #### New Features & Functionality + - Add nightly image for pre-release testing in the cloud environment - Fix torch model fit and make schedule_jobs at db add - Add requires functionality for all extension modules diff --git a/Makefile b/Makefile index 30ffd59b1c..c989a72ebb 100644 --- a/Makefile +++ b/Makefile @@ -53,7 +53,7 @@ install_devkit: ## Add essential development tools python -m pip install pre-commit @echo "Download Code Quality dependencies" - python -m pip install --user black==23.3 ruff mypy types-PyYAML types-requests interrogate + python -m pip install --user black==23.3 ruff==0.4.4 mypy types-PyYAML types-requests interrogate @echo "Download Code Testing dependencies" python -m pip install --user pytest pytest-cov "nbval>=0.10.0" @@ -222,5 +222,13 @@ ext_testing: ## Execute integration testing find ./test -type f -name "*.pyc" -delete pytest $(PYTEST_ARGUMENTS) ./test/integration/ext +rest_testing: ## Execute smoke testing + echo "starting rest server" + SUPERDUPERDB_CONFIG=deploy/testenv/env/rest/rest_mock.yaml python -m superduperdb rest & + sleep 10 + SUPERDUPERDB_CONFIG=deploy/testenv/env/rest/rest_mock.yaml pytest test/rest/test_rest.py + echo "stopping rest server" + lsof -ti:8002 | xargs kill -9 + smoke_testing: ## Execute smoke testing SUPERDUPERDB_CONFIG=deploy/testenv/env/smoke/config.yaml pytest $(PYTEST_ARGUMENTS) ./test/smoke diff --git a/deploy/rest/config.yaml b/deploy/rest/config.yaml new file mode 100644 index 0000000000..7e4eb6f780 --- /dev/null +++ b/deploy/rest/config.yaml @@ -0,0 +1,231 @@ +leaves: + query: + MongoQuery: + _path: superduperdb/backends/mongodb/query/parse_query + query: + type: code + documents: + type: json + default: [] + code: + Code: + _path: superduperdb/Code + identifier: + type: str + code: + type: str + default: | + from superduperdb import code + + @code + def my_code(x): + return x + lazy_artifact: + LazyArtifact: + _path: superduperdb/components/datatype/LazyArtifact + identifier: + type: str + file_id: + type: blob + vector_index: + VectorIndex: + _path: superduperdb/VectorIndex + identifier: + type: str + measure: + type: str + choices: + - cosine + - dot + - l2 + indexing_listener: + type: listener + compatible_listener: + type: listener + optional: True + datatype: + image: + _path: superduperdb/ext/pillow/image_type + identifier: + type: str + media_type: + type: str + default: image/png + vector: + _path: superduperdb/vector + identifier: + type: str + shape: + type: int + stack: + Stack: + _path: superduperdb/Stack + identifier: + type: str + components: + type: [model, listener, vector_index] + listener: + Listener: + _path: superduperdb/Listener + identifier: + type: str + key: + type: str + select: + type: query + optional: True + model: + ObjectModel: + _path: superduperdb/ObjectModel + identifier: + type: str + object: + type: lazy_artifact + datatype: + type: datatype + optional: True + predict_kwargs: + type: json + optional: True + default: {} + signature: + type: str + optional: True + default: "*args,**kwargs" + SequentialModel: + _path: superduperdb/SequentialModel + identifier: + type: str + models: + type: model + sequence: True + QueryModel: + _path: superduperdb/QueryModel + identifier: + type: str + select: + type: query + optional: True + default: + documents: + - {"": "$my_value"} + - {"_outputs": 0, "_id": 0} + query: | + .like(documents[0], vector_index='').find({}, documents[1]).limit(10) + CodeModel: + _path: superduperdb/CodeModel + 
identifier: + type: str + object: + type: code + default: | + from superduperdb import code + + @code + def my_code(x): + return x + datatype: + type: component/datatype + optional: True + predict_kwargs: + type: json + optional: True + default: {} + signature: + type: str + optional: True + default: "*args,**kwargs" + RetrievalPrompt: + _path: superduperdb/ext/llm/prompt/RetrievalPrompt + select: + type: query + prompt_explanation: + type: str + default: | + HERE ARE SOME FACTS SEPARATED BY '---' IN OUR DATA + REPOSITORY WHICH WILL HELP YOU ANSWER THE QUESTION. + prompt_introduction: + type: str + default: | + HERE IS THE QUESTION WHICH YOU SHOULD ANSWER BASED + ONLY ON THE PREVIOUS FACTS + join: + type: str + default: "\n---\n" + SklearnEstimator: + _path: superduperdb/ext/sklearn/Estimator + identifier: + type: str + object: + type: lazy_artifact + preprocess: + type: code + optional: True + postprocess: + type: code + optional: True + OpenAIEmbedding: + _path: superduperdb/ext/openai/OpenAIEmbedding + identifier: + type: str + model: + type: str + openai_api_key: + type: str + optional: True + openai_api_base: + type: str + optional: True + OpenAIChatCompletion: + _path: superduperdb/ext/openai/OpenAIChatCompletion + identifier: + type: str + model: + type: str + openai_api_key: + type: str + optional: True + openai_api_base: + type: str + optional: True + SentenceTransformer: + _path: superduperdb/ext/sentence_transformers/SentenceTransformer + identifier: + type: str + model: + type: str + device: + type: str + default: cpu + predict_kwargs: + type: json + default: + show_progress_bar: true + postprocess: + type: code + default: | + from superduperdb import code + + @code + def my_code(x): + return x.tolist() + signature: + type: str + default: singleton + +presets: + datatype: + pickle: + _path: superduperdb/components/datatype/get_serializer + identifier: pickle_lazy + method: pickle + encodable: lazy_artifact + dill: + _path: superduperdb/components/datatype/get_serializer + identifier: dill_lazy + method: dill + encodable: lazy_artifact + image: + _path: superduperdb/ext/pillow/encoder/image_type + identifier: image + media_type: image/png diff --git a/deploy/testenv/env/rest/rest_mock.yaml b/deploy/testenv/env/rest/rest_mock.yaml index 52a65b7c85..0a937f0a19 100644 --- a/deploy/testenv/env/rest/rest_mock.yaml +++ b/deploy/testenv/env/rest/rest_mock.yaml @@ -4,6 +4,7 @@ bytes_encoding: Bytes cluster: rest: uri: http://localhost:8002 + config: deploy/rest/config.yaml data_backend: mongomock://test downloads: folder: null diff --git a/superduperdb/backends/base/artifacts.py b/superduperdb/backends/base/artifacts.py index 349b5681e5..8ef10995a6 100644 --- a/superduperdb/backends/base/artifacts.py +++ b/superduperdb/backends/base/artifacts.py @@ -97,16 +97,15 @@ def exists( return self._exists(file_id) @abstractmethod - def _save_bytes(self, serialized: bytes, file_id: str): + def put_bytes(self, serialized: bytes, file_id: str): """Save bytes in artifact store""" "" pass @abstractmethod - def _save_file(self, file_path: str, file_id: str) -> str: + def put_file(self, file_path: str, file_id: str) -> str: """Save file in artifact store and return file_id.""" pass - # def save_artifact(self, r: t.Dict): """Save serialized object in the artifact store. 
@@ -117,13 +116,13 @@ def save_artifact(self, r: t.Dict): for file_id, blob in blobs.items(): try: - self._save_bytes(blob, file_id=file_id) + self.put_bytes(blob, file_id=file_id) except FileExistsError: continue for file_id, file_path in files.items(): try: - self._save_file(file_path, file_id=file_id) + self.put_file(file_path, file_id=file_id) except FileExistsError: continue @@ -154,7 +153,7 @@ def update_artifact(self, old_r: t.Dict, new_r: t.Dict): return self.save_artifact(new_r) @abstractmethod - def _load_bytes(self, file_id: str) -> bytes: + def get_bytes(self, file_id: str) -> bytes: """ Load bytes from artifact store. @@ -163,7 +162,7 @@ def _load_bytes(self, file_id: str) -> bytes: pass @abstractmethod - def _load_file(self, file_id: str) -> str: + def get_file(self, file_id: str) -> str: """ Load file from artifact store and return path. @@ -180,14 +179,14 @@ def load_artifact(self, r): datatype = self.serializers[r['datatype']] file_id = r.get('file_id') if r.get('encodable') == 'file': - x = self._load_file(file_id) + x = self.get_file(file_id) else: - # We should always have file_id available at load time (because saved) + # TODO We should always have file_id available at load time (because saved) uri = r.get('uri') if file_id is None: assert uri is not None, '"uri" and "file_id" can\'t both be None' file_id = _construct_file_id_from_uri(uri) - x = self._load_bytes(file_id) + x = self.get_bytes(file_id) return datatype.decode_data(x) def save(self, r: t.Dict) -> t.Dict: diff --git a/superduperdb/backends/base/query.py b/superduperdb/backends/base/query.py index 0580d00fa5..0e8121db82 100644 --- a/superduperdb/backends/base/query.py +++ b/superduperdb/backends/base/query.py @@ -219,7 +219,6 @@ def _update_item(a, documents, queries): def _to_str(self): documents = {} queries = {} - # out = self.identifier[:] out = str(self.identifier) for part in self.parts: if isinstance(part, str): @@ -256,11 +255,11 @@ def _dump_query(self): def __repr__(self): output, docs = self._dump_query() - for i, doc in enumerate(docs): - doc_string = str(doc) - if isinstance(doc, Document): - doc_string = str(doc.unpack()) - output = output.replace(f'documents[{i}]', doc_string) + # for i, doc in enumerate(docs): + # doc_string = str(doc) + # if isinstance(doc, Document): + # doc_string = str(doc.unpack()) + # output = output.replace(f'documents[{i}]', doc_string) return output def __eq__(self, other): @@ -502,7 +501,7 @@ def _prepare_documents(self): return documents -def _parse_query_part(part, documents, query, builder_cls): +def _parse_query_part(part, documents, query, builder_cls, db=None): key = part.split('.') if key[0] == '_outputs': identifier = f'{key[0]}.{key[1]}' @@ -511,7 +510,7 @@ def _parse_query_part(part, documents, query, builder_cls): identifier = key[0] part = part.split('.')[1:] - current = builder_cls(identifier=identifier, parts=()) + current = builder_cls(identifier=identifier, parts=(), db=db) for comp in part: match = re.match('^([a-zA-Z0-9_]+)\((.*)\)$', comp) if match is None: @@ -537,22 +536,22 @@ def _parse_query_part(part, documents, query, builder_cls): def parse_query( query: t.Union[str, list], - documents, builder_cls, + documents: t.Sequence[t.Any] = (), db: t.Optional['Datalayer'] = None, ): """Parse a string query into a query object. :param query: The query to parse. - :param documents: The documents to query. :param builder_cls: The class to use to build the query. + :param documents: The documents to query. 
:param db: The datalayer to use to execute the query. """ - documents = [Document(r) for r in documents] + documents = [Document(r, db=db) for r in documents] if isinstance(query, str): query = [x.strip() for x in query.split('\n') if x.strip()] for i, q in enumerate(query): - query[i] = _parse_query_part(q, documents, query[:i], builder_cls) + query[i] = _parse_query_part(q, documents, query[:i], builder_cls, db=db) return query[-1] diff --git a/superduperdb/backends/ibis/query.py b/superduperdb/backends/ibis/query.py index 8650286a0b..962d6185f5 100644 --- a/superduperdb/backends/ibis/query.py +++ b/superduperdb/backends/ibis/query.py @@ -22,7 +22,9 @@ from superduperdb.base.datalayer import Datalayer -def parse_query(query, documents, db: t.Optional['Datalayer'] = None): +def parse_query( + query, documents: t.Sequence[t.Dict] = (), db: t.Optional['Datalayer'] = None +): """Parse a string query into a query object. :param query: The query to parse. @@ -31,7 +33,7 @@ def parse_query(query, documents, db: t.Optional['Datalayer'] = None): """ return _parse_query( query=query, - documents=documents, + documents=list(documents), builder_cls=IbisQuery, db=db, ) diff --git a/superduperdb/backends/local/artifacts.py b/superduperdb/backends/local/artifacts.py index fd8744ae12..66f6ac1fab 100644 --- a/superduperdb/backends/local/artifacts.py +++ b/superduperdb/backends/local/artifacts.py @@ -66,22 +66,33 @@ def drop(self, force: bool = False): shutil.rmtree(self.conn, ignore_errors=force) os.makedirs(self.conn) - def _save_bytes( + def put_bytes( self, serialized: bytes, file_id: str, ) -> t.Any: + """ + Save bytes in artifact store. + + :param serialized: The bytes to be saved. + :param file_id: The id of the file. + """ path = os.path.join(self.conn, file_id) if os.path.exists(path): logging.warn(f"File {path} already exists") with open(path, 'wb') as f: f.write(serialized) - def _load_bytes(self, file_id: str) -> bytes: + def get_bytes(self, file_id: str) -> bytes: + """ + Return the bytes from the artifact store. + + :param file_id: The id of the file. + """ with open(os.path.join(self.conn, file_id), 'rb') as f: return f.read() - def _save_file(self, file_path: str, file_id: str): + def put_file(self, file_path: str, file_id: str): """Save file in artifact store and return the relative path. return the relative path {file_id}/{name} @@ -102,8 +113,11 @@ def _save_file(self, file_path: str, file_id: str): # return the relative path {file_id}/{name} return os.path.join(file_id, name) - def _load_file(self, file_id: str) -> str: - """Return the path to the file in the artifact store.""" + def get_file(self, file_id: str) -> str: + """Return the path to the file in the artifact store. + + :param file_id: The id of the file. + """ logging.info(f"Loading file {file_id} from {self.conn}") return os.path.join(self.conn, file_id) diff --git a/superduperdb/backends/mongodb/artifacts.py b/superduperdb/backends/mongodb/artifacts.py index 530a5a8ed3..1fdf765671 100644 --- a/superduperdb/backends/mongodb/artifacts.py +++ b/superduperdb/backends/mongodb/artifacts.py @@ -56,14 +56,23 @@ def _delete_bytes(self, file_id: str): for _id in ids: self.filesystem.delete(_id) - def _load_bytes(self, file_id: str): + def get_bytes(self, file_id: str): + """ + Get the bytes of the file from GridFS. 
+
+        :param file_id: The file_id of the file to get
+        """
         cur = self.filesystem.find_one({'filename': file_id})
         if cur is None:
             raise FileNotFoundError(f'File not found in {file_id}')
         return cur.read()
 
-    def _save_file(self, file_path: str, file_id: str):
-        """Save file to GridFS."""
+    def put_file(self, file_path: str, file_id: str):
+        """Save file to GridFS.
+
+        :param file_path: The path to the file to save
+        :param file_id: The file_id of the file
+        """
         path = Path(file_path)
         if path.is_dir():
             upload_folder(file_path, file_id, self.filesystem)
@@ -71,14 +80,21 @@ def _save_file(self, file_path: str, file_id: str):
             _upload_file(file_path, file_id, self.filesystem)
         return file_id
 
-    def _load_file(self, file_id: str) -> str:
+    def get_file(self, file_id: str) -> str:
         """Download file from GridFS and return the path.
 
         The path is a temporary directory,
         `{tmp_prefix}/{file_id}/{filename or folder}`
+
+        :param file_id: The file_id of the file to download
         """
         return _download(file_id, self.filesystem)
 
-    def _save_bytes(self, serialized: bytes, file_id: str):
+    def put_bytes(self, serialized: bytes, file_id: str):
+        """
+        Save bytes in GridFS.
+
+        :param serialized: The bytes to save
+        :param file_id: The file_id of the file
+        """
         cur = self.filesystem.find_one({'filename': file_id})
         if cur is not None:
             raise FileExistsError
diff --git a/superduperdb/backends/mongodb/metadata.py b/superduperdb/backends/mongodb/metadata.py
index f9a0fc0fd1..81432a2b25 100644
--- a/superduperdb/backends/mongodb/metadata.py
+++ b/superduperdb/backends/mongodb/metadata.py
@@ -208,7 +208,7 @@ def show_component_versions(
     def list_components_in_scope(self, scope: str):
         """List components in a scope.
 
-        :param scope: scope of components
+        :param scope: scope of components
         """
         out = []
         for r in self.component_collection.find({'parent': scope}):
diff --git a/superduperdb/backends/mongodb/query.py b/superduperdb/backends/mongodb/query.py
index 0c33d9797d..ead6f70768 100644
--- a/superduperdb/backends/mongodb/query.py
+++ b/superduperdb/backends/mongodb/query.py
@@ -22,7 +22,9 @@
 from superduperdb.base.datalayer import Datalayer
 
 
-def parse_query(query, documents, db: t.Optional['Datalayer'] = None):
+def parse_query(
+    query, documents: t.Sequence[t.Dict] = (), db: t.Optional['Datalayer'] = None
+):
     """Parse a string query into a query object.
 
     :param query: The query to parse.
@@ -31,8 +33,8 @@ def parse_query(query, documents, db: t.Optional['Datalayer'] = None): """ return _parse_query( query=query, - documents=documents, builder_cls=MongoQuery, + documents=list(documents), db=db, ) diff --git a/superduperdb/backends/sqlalchemy/metadata.py b/superduperdb/backends/sqlalchemy/metadata.py index 6ddfa42c19..9444ba6672 100644 --- a/superduperdb/backends/sqlalchemy/metadata.py +++ b/superduperdb/backends/sqlalchemy/metadata.py @@ -11,7 +11,6 @@ from superduperdb.backends.base.metadata import MetaDataStore, NonExistentMetadataError from superduperdb.backends.sqlalchemy.db_helper import get_db_config from superduperdb.base.document import Document -from superduperdb.components.component import Component as _Component from superduperdb.misc.colors import Colors if t.TYPE_CHECKING: @@ -180,12 +179,12 @@ def component_version_has_parents( :param identifier: the identifier of the component :param version: the version of the component """ - unique_id = _Component.make_unique_id(type_id, identifier, version) + uuid = self._get_component_uuid(type_id, identifier, version) with self.session_context() as session: stmt = ( select(self.parent_child_association_table) .where( - self.parent_child_association_table.c.child_id == unique_id, + self.parent_child_association_table.c.child_id == uuid, ) .limit(1) ) diff --git a/superduperdb/base/code.py b/superduperdb/base/code.py index 615d7c13cf..dd54759b3a 100644 --- a/superduperdb/base/code.py +++ b/superduperdb/base/code.py @@ -1,6 +1,5 @@ import dataclasses as dc import inspect -import typing as t from superduperdb.base.leaf import Leaf from superduperdb.misc.annotations import merge_docstrings @@ -10,8 +9,6 @@ @code {definition}""" -default = template.format(definition='def my_code(x):\n return x\n') - @merge_docstrings @dc.dataclass(kw_only=True) @@ -24,7 +21,7 @@ class Code(Leaf): """ code: str - default: t.ClassVar[str] = default + identifier: str = '' @staticmethod def from_object(obj): @@ -40,7 +37,8 @@ def from_object(obj): print(mini_module) return Code(mini_module) - def __post_init__(self): + def __post_init__(self, db): + super().__post_init__(db) namespace = {} exec(self.code, namespace) remote_code = next( diff --git a/superduperdb/base/config.py b/superduperdb/base/config.py index c9d43c99f0..bbee3102b6 100644 --- a/superduperdb/base/config.py +++ b/superduperdb/base/config.py @@ -136,9 +136,12 @@ class Rest(BaseConfig): """Describes the configuration for the REST service. :param uri: The URI for the REST service + :param config: The path to the config yaml file + for the REST service """ uri: t.Optional[str] = None + config: t.Optional[str] = None @dc.dataclass diff --git a/superduperdb/base/datalayer.py b/superduperdb/base/datalayer.py index 229258031d..a1ea359cfd 100644 --- a/superduperdb/base/datalayer.py +++ b/superduperdb/base/datalayer.py @@ -62,7 +62,9 @@ class Datalayer: 'datatype': 'datatypes', 'vector_index': 'vector_indices', 'schema': 'schemas', + 'listener': 'listeners', } + cache_to_type_id_mapping = {v: k for k, v in type_id_to_cache_mapping.items()} def __init__( self, @@ -244,6 +246,7 @@ def show( type_id: t.Optional[str] = None, identifier: t.Optional[str] = None, version: t.Optional[int] = None, + include_presets: bool = False, ): """ Show available functionality which has been added using ``self.add``. 
@@ -262,11 +265,36 @@ def show( if type_id is None: nt = namedtuple('nt', ('type_id', 'identifier')) out = self.metadata.show_components() - out = list(set(nt(**x) for x in out)) + if not include_presets: + return out + subcaches = [ + self.models, + self.datatypes, + self.tables, + self.schemas, + self.vector_indices, + self.listeners, + ] + cached = sum( + [ + [nt(comp.type_id, comp.identifier) for comp in subcache.values()] + for subcache in subcaches + ], + [], + ) + out = sorted(list(set([nt(**x) for x in out] + cached))) return [x._asdict() for x in out] if identifier is None: - return self.metadata.show_components(type_id=type_id) + mapping = self.type_id_to_cache_mapping + out = self.metadata.show_components(type_id=type_id) + if not include_presets: + return sorted(out) + try: + t = mapping[type_id] + return sorted(list(set(out + list(getattr(self, t).keys())))) + except KeyError: + return sorted(out) if version is None: return sorted( @@ -579,6 +607,7 @@ def load( version: t.Optional[int] = None, allow_hidden: bool = False, uuid: t.Optional[str] = None, + include_presets: bool = False, ) -> t.Union[Component, t.Dict[str, t.Any]]: """ Load a component using uniquely identifying information. @@ -594,7 +623,22 @@ def load( :param allow_hidden: Toggle to ``True`` to allow loading of deprecated components. :param uuid: [Optional] UUID of the component to load. + :param include_presets: Include items cached in `db.models` etc. """ + if ( + include_presets + and isinstance(type_id, str) + and isinstance(identifier, str) + and version is None + and type_id in self.type_id_to_cache_mapping + ): + cache_name = self.type_id_to_cache_mapping[type_id] + cache = getattr(self, cache_name) + try: + return cache[identifier] + except KeyError: + pass + if type_id == 'encoder': logging.warn( '"encoder" has moved to "datatype" this functionality will not work' @@ -885,14 +929,16 @@ def _add( serialized = object.dict().encode(leaves_to_keep=(Component,)) children = [ - v for k, v in serialized['_leaves'].items() if isinstance(v, Component) + v for v in serialized['_leaves'].values() if isinstance(v, Component) ] jobs.extend(self._add_child_components(children, parent=object)) for k, v in serialized['_leaves'].items(): if isinstance(v, Component): - serialized['_leaves'][k] = f'%{v.id}' + serialized['_leaves'][ + k + ] = f'?db.load({v.type_id}, {v.identifier}, {v.version})' serialized = self.artifact_store.save_artifact(serialized) @@ -1093,7 +1139,9 @@ class LoadDict(dict): def __missing__(self, key: str): if self.field is not None: - value = self[key] = self.database.load(self.field, key) + value = self[key] = self.database.load( + self.field, key, include_presets=False + ) else: msg = f'callable is ``None`` for {key}' assert self.callable is not None, msg diff --git a/superduperdb/base/document.py b/superduperdb/base/document.py index 23e9285d9e..165a5a0bb5 100644 --- a/superduperdb/base/document.py +++ b/superduperdb/base/document.py @@ -1,3 +1,4 @@ +import re import typing as t from bson.objectid import ObjectId @@ -125,22 +126,25 @@ def decode( if '_leaves' in r: cache = r['_leaves'] - del r['_leaves'] if '_blobs' in r: blobs = r['_blobs'] - del r['_blobs'] if '_files' in r: files = r['_files'] - del r['_files'] schema = schema or r.get(SCHEMA_KEY) schema = get_schema(db, schema) if schema is not None: schema.init() r = schema.decode_data(r) - r = _deep_flat_decode(r, cache, blobs, files=files, db=db) + r = _deep_flat_decode( + {k: v for k, v in r.items() if k not in ('_leaves', '_blobs', 
'_files')}, + cache, + blobs, + files=files, + db=db, + ) if isinstance(r, dict): return Document(r, schema=schema) @@ -306,20 +310,24 @@ def _deep_flat_decode(r, cache, blobs, files={}, db: t.Optional['Datalayer'] = N module = '.'.join(parts[:-1]) dict_ = {k: v for k, v in r.items() if k != '_path'} dict_ = _deep_flat_decode(dict_, cache, blobs, files, db=db) - instence = _import_item(cls=cls, module=module, dict=dict_, db=db) - # TODO: Auto unpack the instence - # instence.unpack() - return instence - + instance = _import_item(cls=cls, module=module, dict=dict_, db=db) + return instance if isinstance(r, dict): return { k: _deep_flat_decode(v, cache, blobs, files, db=db) for k, v in r.items() } - if isinstance(r, str) and r.startswith('?'): + if isinstance(r, str) and r.startswith('?') and not r.startswith('?db'): return _get_leaf_from_cache(r[1:], cache, blobs, files, db=db) - if isinstance(r, str) and r.startswith('%'): - uuid = r.split('/')[-1] - if db is None: - raise ValueError(f'No database provided to decode {r}') - return db.load(uuid=uuid) + if isinstance(r, str) and re.match("^\?db\.load\((.*)\)$", r): + match = re.match("^\?db\.load\((.*)\)$", r) + assert match is not None + assert db is not None, 'db is required for ?db.load()' + args = [x.strip() for x in match.groups()[0].split(',')] + if len(args) == 1: + return db.load(uuid=args[0]) + if len(args) == 2: + return db.load(type_id=args[0], identifier=args[1], include_presets=True) + if len(args) == 3: + return db.load(type_id=args[0], identifier=args[1], version=int(args[2])) + raise ValueError(f'Invalid number of arguments for {r}') return r diff --git a/superduperdb/base/leaf.py b/superduperdb/base/leaf.py index 04d9b5b066..8b7a3b91df 100644 --- a/superduperdb/base/leaf.py +++ b/superduperdb/base/leaf.py @@ -135,14 +135,6 @@ def dict(self): ) return Document({'_path': path, **r}) - @classmethod - def handle_integration(cls, r): - """Method to handle integration. - - :param r: Encoded data. - """ - return r - @classmethod def _register_class(cls): """Register class in the class registry and set the full import path.""" diff --git a/superduperdb/cdc/cdc.py b/superduperdb/cdc/cdc.py index b6dfe5b0f2..ba18621d86 100644 --- a/superduperdb/cdc/cdc.py +++ b/superduperdb/cdc/cdc.py @@ -410,7 +410,7 @@ def __init__(self, db: 'Datalayer'): t.Union['TableOrCollection', 'IbisQuery'] ] = [] - listeners = self.db.show('listeners') + listeners = self.db.show('listener') if listeners: from superduperdb.components.listener import Listener diff --git a/superduperdb/components/component.py b/superduperdb/components/component.py index 08eb84c3e8..144284b710 100644 --- a/superduperdb/components/component.py +++ b/superduperdb/components/component.py @@ -79,7 +79,6 @@ class Component(Leaf): leaf_type: t.ClassVar[str] = 'component' _artifacts: t.ClassVar[t.Sequence[t.Tuple[str, 'DataType']]] = () set_post_init: t.ClassVar[t.Sequence] = ('version',) - ui_schema: t.ClassVar[t.List[t.Dict]] = [{'name': 'identifier', 'type': 'str'}] changed: t.ClassVar[set] = set([]) artifacts: dc.InitVar[t.Optional[t.Dict]] = None @@ -97,14 +96,6 @@ def __post_init__(self, db, artifacts): if not self.identifier: raise ValueError('identifier cannot be empty or None') - @classmethod - def handle_integration(cls, kwargs): - """Abstract method for handling integration. - - :param kwargs: Integration kwargs. 
- """ - return kwargs - @property def id(self): """Returns the component identifier.""" @@ -116,16 +107,6 @@ def id_tuple(self): """Returns an object as `ComponentTuple`.""" return ComponentTuple(self.type_id, self.identifier, self.version) - @classmethod - def get_ui_schema(cls): - """Helper method to get the UI schema.""" - out = {} - ancestors = cls.mro()[::-1] - for a in ancestors: - if hasattr(a, 'ui_schema'): - out.update({x['name']: x for x in a.ui_schema}) - return list(out.values()) - def set_variables(self, db, **kwargs): """Set free variables of self. diff --git a/superduperdb/components/datatype.py b/superduperdb/components/datatype.py index e2500f5e3a..c1d30d6c2c 100644 --- a/superduperdb/components/datatype.py +++ b/superduperdb/components/datatype.py @@ -192,36 +192,11 @@ class DataType(Component): :param encodable: The type of encodable object ('encodable', 'lazy_artifact', or 'file'). :param bytes_encoding: The encoding type for bytes ('base64' or 'bytes'). - :param intermidia_type: Type of the intermidia data - [IntermidiaType.BYTES, IntermidiaType.STRING] + :param intermediate_type: Type of the intermediate data + [IntermediateType.BYTES, IntermediateType.STRING] :param media_type: The media type. """ - ui_schema: t.ClassVar[t.List[t.Dict]] = [ - { - 'name': 'serializer', - 'type': 'string', - 'choices': ['pickle', 'dill', 'torch'], - 'default': 'dill', - }, - {'name': 'info', 'type': 'json', 'optional': True}, - {'name': 'shape', 'type': 'json', 'optional': True}, - {'name': 'directory', 'type': 'str', 'optional': True}, - { - 'name': 'encodable', - 'type': 'str', - 'choices': ['encodable', 'lazy_artifact', 'file'], - 'default': 'lazy_artifact', - }, - { - 'name': 'bytes_encoding', - 'type': 'str', - 'choices': ['base64', 'bytes'], - 'default': 'bytes', - }, - {'name': 'media_type', 'type': 'str', 'optional': True}, - ] - type_id: t.ClassVar[str] = 'datatype' encoder: t.Optional[t.Callable] = None # not necessary if encodable is file decoder: t.Optional[t.Callable] = None @@ -230,7 +205,7 @@ class DataType(Component): directory: t.Optional[str] = None encodable: str = 'encodable' bytes_encoding: t.Optional[str] = CFG.bytes_encoding - intermidia_type: t.Optional[str] = IntermediateType.BYTES + intermediate_type: t.Optional[str] = IntermediateType.BYTES media_type: t.Optional[str] = None registered_types: t.ClassVar[t.Dict[str, "DataType"]] = {} @@ -298,7 +273,7 @@ def bytes_encoding_after_encode(self, data): """ if ( self.bytes_encoding == BytesEncoding.BASE64 - and self.intermidia_type == IntermediateType.BYTES + and self.intermediate_type == IntermediateType.BYTES ): return bytes_to_base64(data) return data @@ -312,7 +287,7 @@ def bytes_encoding_before_decode(self, data): """ if ( self.bytes_encoding == BytesEncoding.BASE64 - and self.intermidia_type == IntermediateType.BYTES + and self.intermediate_type == IntermediateType.BYTES ): return base64_to_bytes(data) return data @@ -661,7 +636,7 @@ def init(self): """Initialize to load `x` with the actual file from the artifact store.""" assert self.file_id is not None if isinstance(self.x, Empty): - blob = self.db.artifact_store._load_bytes(self.file_id) + blob = self.db.artifact_store.get_bytes(self.file_id) self.datatype.init() self.x = self.datatype.decoder(blob) @@ -669,13 +644,15 @@ def _deep_flat_encode(self, cache, blobs, files, leaves_to_keep=(), schema=None) if isinstance(self, leaves_to_keep): cache[self.id] = self return f'?{self.id}' - maybe_bytes, file_id = self._encode() - self.file_id = file_id + maybe_bytes 
= None + if self.file_id is None: + maybe_bytes, self.file_id = self._encode() r = super()._deep_flat_encode( cache, blobs, files, leaves_to_keep=leaves_to_keep, schema=schema ) del r['x'] - blobs[self.file_id] = maybe_bytes + if isinstance(maybe_bytes, bytes): + blobs[self.file_id] = maybe_bytes cache[self.id] = r return f'?{self.id}' @@ -743,7 +720,7 @@ def _deep_flat_encode(self, cache, blobs, files, leaves_to_keep=(), schema=None) def init(self): """Initialize to load `x` with the actual file from the artifact store.""" if isinstance(self.x, Empty): - file = self.db.artifact_store._load_file(self.file_id) + file = self.db.artifact_store.get_file(self.file_id) if self.file_name is not None: file = os.path.join(file, self.file_name) self.x = file @@ -784,7 +761,7 @@ class LazyFile(File): decoder=json_decode, encodable='encodable', bytes_encoding=BytesEncoding.BASE64, - intermidia_type=IntermediateType.STRING, + intermediate_type=IntermediateType.STRING, ) methods: t.Dict[str, t.Dict] = { diff --git a/superduperdb/components/graph.py b/superduperdb/components/graph.py index 8561d88089..d3c765c5a0 100644 --- a/superduperdb/components/graph.py +++ b/superduperdb/components/graph.py @@ -255,12 +255,6 @@ class Graph(Model): _DEFAULT_ARG_WEIGHT: t.ClassVar[t.Tuple] = (None, 'singleton') type_id: t.ClassVar[str] = 'model' - ui_schema: t.ClassVar[t.List[t.Dict]] = [ - {'name': 'models', 'type': 'component/model', 'sequence': True}, - {'name': 'edges', 'type': 'json'}, - {'name': 'signature', 'type': 'str', 'default': '*args,**kwargs'}, - ] - models: t.List[Model] = dc.field(default_factory=list) edges: t.List[t.Tuple[str, str, t.Tuple[t.Union[int, str], str]]] = dc.field( default_factory=list diff --git a/superduperdb/components/listener.py b/superduperdb/components/listener.py index f3b9d734c4..14fe8d2962 100644 --- a/superduperdb/components/listener.py +++ b/superduperdb/components/listener.py @@ -9,7 +9,6 @@ from superduperdb.components.model import Mapping from superduperdb.misc.annotations import merge_docstrings from superduperdb.misc.server import request_server -from superduperdb.rest.utils import parse_query from ..jobs.job import Job from .component import Component, ComponentTuple @@ -38,15 +37,6 @@ class Listener(Component): :param identifier: A string used to identify the model. """ - ui_schema: t.ClassVar[t.List[t.Dict]] = [ - {'name': 'identifier', 'type': 'str', 'default': ''}, - {'name': 'key', 'type': 'json'}, - {'name': 'model', 'type': 'component/model'}, - {'name': 'select', 'type': 'json', 'default': SELECT_TEMPLATE}, - {'name': 'active', 'type': 'bool', 'default': True}, - {'name': 'predict_kwargs', 'type': 'json', 'default': {}}, - ] - key: ModelInputType model: Model select: Query @@ -56,19 +46,6 @@ class Listener(Component): type_id: t.ClassVar[str] = 'listener' - @classmethod - def handle_integration(cls, kwargs): - """Method to handle integration. - - :param kwargs: Integration keyword arguments. 
- """ - if 'select' in kwargs and isinstance(kwargs['select'], dict): - kwargs['select'] = parse_query( - query=kwargs['select']['query'], - documents=kwargs['select']['documents'], - ) - return kwargs - def __post_init__(self, db, artifacts): if self.identifier == '': self.identifier = self.id diff --git a/superduperdb/components/metric.py b/superduperdb/components/metric.py index 98e848cc3a..5b15de8bac 100644 --- a/superduperdb/components/metric.py +++ b/superduperdb/components/metric.py @@ -16,7 +16,6 @@ class Metric(Component): """ type_id: t.ClassVar[str] = 'metric' - ui_schema: t.ClassVar[t.List[t.Dict]] = [{'name': 'object', 'type': 'artifact'}] object: t.Callable diff --git a/superduperdb/components/model.py b/superduperdb/components/model.py index 506fa1838a..50ce3974f1 100644 --- a/superduperdb/components/model.py +++ b/superduperdb/components/model.py @@ -27,7 +27,6 @@ from superduperdb.components.schema import Schema from superduperdb.jobs.job import ComponentJob, Job from superduperdb.misc.annotations import merge_docstrings -from superduperdb.rest.utils import parse_query if t.TYPE_CHECKING: from superduperdb.base.datalayer import Datalayer @@ -495,12 +494,6 @@ class Model(Component): """ type_id: t.ClassVar[str] = 'model' - ui_schema: t.ClassVar[t.Dict] = [ - {'name': 'datatype', 'type': 'component/datatype', 'optional': True}, - {'name': 'predict_kwargs', 'type': 'json', 'default': {}}, - {'name': 'signature', 'type': 'str', 'default': '*args,**kwargs'}, - ] - signature: Signature = '*args,**kwargs' datatype: EncoderArg = None output_schema: t.Optional[Schema] = None @@ -1064,11 +1057,6 @@ class _ObjectModel(Model, ABC): num_workers: int = 0 object: t.Any - ui_schema: t.ClassVar[t.List[t.Dict]] = [ - {'name': 'num_workers', 'type': 'int', 'default': '0'}, - {'name': 'signature', 'type': 'str', 'default': '*args,**kwargs'}, - ] - @property def outputs(self): """Get an instance of ``IndexableNode`` to index outputs.""" @@ -1142,8 +1130,6 @@ class ObjectModel(_ObjectModel): """ - ui_schema: t.ClassVar[t.List[t.Dict]] = [{'name': 'object', 'type': 'artifact'}] - _artifacts: t.ClassVar[t.Sequence[t.Tuple[str, 'DataType']]] = ( ('object', dill_lazy), ) @@ -1157,23 +1143,8 @@ class CodeModel(_ObjectModel): :param object: Code object """ - ui_schema: t.ClassVar[t.List[t.Dict]] = [ - {'name': 'object', 'type': 'code', 'default': Code.default} - ] object: Code - @classmethod - def handle_integration(cls, kwargs): - """Handler integration from ui. - - :param kwargs: integration kwargs - """ - if isinstance(kwargs['object'], str): - kwargs['object'] = Code(kwargs['object']) - else: - assert isinstance(kwargs['object'], Code) - return kwargs - @merge_docstrings @dc.dataclass(kw_only=True) @@ -1182,7 +1153,6 @@ class APIBaseModel(Model): :param model: The Model to use, e.g. ``'text-embedding-ada-002'`` :param max_batch_size: Maximum batch size. 
- """ model: t.Optional[str] = None @@ -1263,20 +1233,6 @@ def predict_one(self, *args, **kwargs): return out -LIKE_TEMPLATE = { - 'documents': [ - {"": "$my_value"}, - {"_outputs": 0, "_id": 0}, - ], - 'query': ( - "" - ".like(_documents[0], vector_index='')" - ".find({}, _documents[1])" - ".limit(10)" - ), -} - - @merge_docstrings @dc.dataclass(kw_only=True) class QueryModel(Model): @@ -1290,11 +1246,6 @@ class QueryModel(Model): :param select: query used to find data (can include `like`) """ - ui_schema: t.ClassVar[t.List[t.Dict]] = [ - {'name': 'postprocess', 'type': 'code', 'default': Code.default}, - {'name': 'select', 'type': 'json', 'default': LIKE_TEMPLATE}, - ] - preprocess: t.Optional[t.Callable] = None postprocess: t.Optional[t.Union[t.Callable, Code]] = None select: Query @@ -1311,25 +1262,6 @@ def _replace_variables(r): r[k] = Variable(v[1:]) return r - @classmethod - def handle_integration(cls, kwargs): - """Handle integration from UI. - - :param kwargs: Integration kwargs. - """ - if 'select' in kwargs and isinstance(kwargs['select'], dict): - for i, r in enumerate(kwargs['select']['documents']): - kwargs['select']['documents'][i] = cls._replace_variables(r) - kwargs['select'] = parse_query( - query=kwargs['select']['query'], - documents=kwargs['select']['documents'], - ) - if isinstance(kwargs.get('preprocess'), str): - kwargs['preprocess'] = Code(kwargs['preprocess']) - if isinstance(kwargs.get('postprocess'), str): - kwargs['postprocess'] = Code(kwargs['postprocess']) - return kwargs - @property def inputs(self) -> Inputs: """Instance of `Inputs` to represent model params.""" @@ -1380,11 +1312,6 @@ class SequentialModel(Model): :param models: A list of models to use """ - ui_schema: t.ClassVar[t.List[t.Dict]] = [ - {'name': 'models', 'type': 'component/model', 'sequence': True}, - {'name': 'signature', 'type': 'str', 'optional': True, 'default': None}, - ] - models: t.List[Model] def __post_init__(self, db, artifacts): diff --git a/superduperdb/components/vector_index.py b/superduperdb/components/vector_index.py index 9e471d4561..58796b8ddc 100644 --- a/superduperdb/components/vector_index.py +++ b/superduperdb/components/vector_index.py @@ -34,12 +34,6 @@ class VectorIndex(Component): :param metric_values: Metric values for this index """ - ui_schema: t.ClassVar[t.List[t.Dict]] = [ - {'name': 'indexing_listener', 'type': 'component/listener'}, - {'name': 'compatible_listener', 'type': 'component/listener', 'optional': True}, - {'name': 'measure', 'type': 'str', 'choices': ['cosine', 'dot', 'l2']}, - ] - type_id: t.ClassVar[str] = 'vector_index' indexing_listener: Listener diff --git a/superduperdb/ext/llm/prompter.py b/superduperdb/ext/llm/prompter.py index 04a024466b..c81aff5962 100644 --- a/superduperdb/ext/llm/prompter.py +++ b/superduperdb/ext/llm/prompter.py @@ -69,11 +69,6 @@ class RetrievalPrompt(QueryModel): prompt_explanation: str = PROMPT_EXPLANATION prompt_introduction: str = PROMPT_INTRODUCTION join: str = "\n---\n" - ui_schema: t.ClassVar[t.List[t.Dict]] = [ - {'name': 'prompt_explanation', 'type': 'str', 'default': PROMPT_EXPLANATION}, - {'name': 'prompt_introduction', 'type': 'str', 'default': PROMPT_INTRODUCTION}, - {'name': 'join', 'type': 'str', 'default': "\n---\n"}, - ] def __post_init__(self, artifacts): assert len(self.select.variables) == 1 diff --git a/superduperdb/ext/openai/model.py b/superduperdb/ext/openai/model.py index 520246a7d7..6fac2e07a2 100644 --- a/superduperdb/ext/openai/model.py +++ b/superduperdb/ext/openai/model.py @@ -56,12 +56,6 
@@ class _OpenAI(APIBaseModel): openai_api_base: t.Optional[str] = None client_kwargs: t.Optional[dict] = dc.field(default_factory=dict) - @classmethod - def handle_integration(cls, kwargs): - if 'signature' in kwargs: - del kwargs['signature'] - return kwargs - def __post_init__(self, db, artifacts): super().__post_init__(db, artifacts) diff --git a/superduperdb/ext/sentence_transformers/model.py b/superduperdb/ext/sentence_transformers/model.py index f1ac5c4af2..fcbfa111e1 100644 --- a/superduperdb/ext/sentence_transformers/model.py +++ b/superduperdb/ext/sentence_transformers/model.py @@ -40,22 +40,6 @@ class SentenceTransformer(Model, _DeviceManaged): postprocess: t.Union[None, t.Callable, Code] = None signature: Signature = 'singleton' - ui_schema: t.ClassVar[t.List[t.Dict]] = [ - {'name': 'model', 'type': 'str', 'default': 'all-MiniLM-L6-v2'}, - {'name': 'device', 'type': 'str', 'default': 'cpu', 'choices': ['cpu', 'cuda']}, - {'name': 'predict_kwargs', 'type': 'json', 'default': DEFAULT_PREDICT_KWARGS}, - {'name': 'postprocess', 'type': 'code', 'default': Code.default}, - ] - - @classmethod - def handle_integration(cls, kwargs): - """Handle integration of the model.""" - if isinstance(kwargs.get('preprocess'), str): - kwargs['preprocess'] = Code(kwargs['preprocess']) - if isinstance(kwargs.get('postprocess'), str): - kwargs['postprocess'] = Code(kwargs['postprocess']) - return kwargs - def __post_init__(self, db, artifacts): super().__post_init__(db, artifacts) diff --git a/superduperdb/ext/sklearn/model.py b/superduperdb/ext/sklearn/model.py index 9e7ef16230..5974804132 100644 --- a/superduperdb/ext/sklearn/model.py +++ b/superduperdb/ext/sklearn/model.py @@ -106,11 +106,6 @@ class Estimator(Model, _Fittable): _artifacts: t.ClassVar[t.Sequence[t.Tuple[str, DataType]]] = ( ('object', pickle_serializer), ) - ui_schema: t.ClassVar[t.List[t.Dict]] = [ - {'name': 'object', 'type': 'artifact'}, - {'name': 'preprocess', 'type': 'artifact', 'optional': True}, - {'name': 'postprocess', 'type': 'artifact', 'optional': True}, - ] object: BaseEstimator trainer: t.Optional[SklearnTrainer] = None diff --git a/superduperdb/misc/annotations.py b/superduperdb/misc/annotations.py index 0b86a6df32..d9c9e3eb15 100644 --- a/superduperdb/misc/annotations.py +++ b/superduperdb/misc/annotations.py @@ -166,11 +166,10 @@ def _get_indent(docstring: str) -> int: return len(non_empty_lines[1]) - len(non_empty_lines[1].lstrip()) -def component(*schema: t.Dict, handle_integration: t.Callable = lambda x: x): +def component(*schema: t.Dict): """Decorator for creating a component. 
:param schema: schema for the component - :param handle_integration: function to handle integration """ def decorator(f): @@ -197,9 +196,6 @@ def _deep_flat_encode(cache, blobs, files, leaves_to_keep=(), schema=None): out._deep_flat_encode = _deep_flat_encode return out - decorated.get_ui_schema = lambda: schema - decorated.build = lambda r: f(**r) - decorated.handle_integration = handle_integration return decorated return decorator diff --git a/superduperdb/rest/app.py b/superduperdb/rest/app.py index cc60cb20d2..c521ba5a67 100644 --- a/superduperdb/rest/app.py +++ b/superduperdb/rest/app.py @@ -1,26 +1,13 @@ -import json +import hashlib import typing as t import magic +import yaml from fastapi import File, Response from superduperdb import CFG, logging +from superduperdb.base.datalayer import Datalayer from superduperdb.base.document import Document -from superduperdb.components.datatype import DataType -from superduperdb.components.listener import Listener -from superduperdb.components.model import ( - CodeModel, - ObjectModel, - QueryModel, - SequentialModel, -) -from superduperdb.components.vector_index import VectorIndex, vector -from superduperdb.ext import openai, sentence_transformers -from superduperdb.ext.llm.prompter import RetrievalPrompt -from superduperdb.ext.pillow.encoder import image_type -from superduperdb.ext.sklearn.model import Estimator -from superduperdb.ext.torch.model import TorchModel -from superduperdb.rest.utils import parse_query, strip_artifacts from superduperdb.server import app as superduperapp assert isinstance( @@ -28,52 +15,23 @@ ), "cluster.rest.uri should be set with a valid uri" port = int(CFG.cluster.rest.uri.split(':')[-1]) +assert CFG.cluster.rest.config, "cluster.rest.config should be set with a valid path" +with open(CFG.cluster.rest.config) as f: + CONFIG = yaml.safe_load(f) + app = superduperapp.SuperDuperApp('rest', port=port) -# TODO - should be a configuration -CLASSES: t.Dict[str, t.Dict[str, t.Any]] = { - 'model': { - 'ObjectModel': ObjectModel, - 'SequentialModel': SequentialModel, - 'QueryModel': QueryModel, - 'CodeModel': CodeModel, - 'RetrievalPrompt': RetrievalPrompt, - 'TorchModel': TorchModel, - 'SklearnEstimator': Estimator, - 'OpenAIEmbedding': openai.OpenAIEmbedding, - 'OpenAIChatCompletion': openai.OpenAIChatCompletion, - 'SentenceTransformer': sentence_transformers.SentenceTransformer, - }, - 'listener': { - 'Listener': Listener, - }, - 'datatype': { - 'image': image_type, - 'vector': vector, - 'DataType': DataType, - }, - 'vector-index': {'VectorIndex': VectorIndex}, -} - -FLAT_CLASSES = {} -for k in CLASSES: - for sub in CLASSES[k]: - FLAT_CLASSES[sub] = CLASSES[k][sub] - - -MODULE_LOOKUP: t.Dict[str, t.Dict[str, t.Any]] = {} -API_SCHEMAS: t.Dict[str, t.Dict[str, t.Any]] = {} -for type_id in CLASSES: - API_SCHEMAS[type_id] = {} - MODULE_LOOKUP[type_id] = {} - for cls_name in CLASSES[type_id]: - cls = CLASSES[type_id][cls_name] - API_SCHEMAS[type_id][cls_name] = cls.get_ui_schema() - MODULE_LOOKUP[type_id][cls_name] = cls.__module__ - - -logging.info(json.dumps(API_SCHEMAS, indent=2)) +def _init_hook(db: Datalayer): + for type_id in CONFIG['presets']: + for leaf in CONFIG['presets'][type_id]: + leaf = CONFIG['presets'][type_id][leaf] + leaf = Document.decode(leaf).unpack() + t = db.type_id_to_cache_mapping[type_id] + getattr(db, t)[leaf.identifier] = leaf + + +app.init_hook = _init_hook def build_app(app: superduperapp.SuperDuperApp): @@ -85,32 +43,23 @@ def build_app(app: superduperapp.SuperDuperApp): @app.add('/spec/show', 
method='get') def spec_show(): - return API_SCHEMAS - - @app.add('/spec/lookup', method='get') - def spec_lookup(): - return MODULE_LOOKUP - - @app.add('/db/artifact_store/save_artifact', method='put') - def db_artifact_store_save_artifact(datatype: str, raw: bytes = File(...)): - r = app.db.artifact_store.save_artifact({'bytes': raw, 'datatype': datatype}) - return {'file_id': r['file_id']} - - @app.add('/db/artifact_store/get_artifact', method='get') - def db_artifact_store_get_artifact(file_id: str, datatype: t.Optional[str] = None): - bytes = app.db.artifact_store._load_bytes(file_id=file_id) - - if datatype is not None: - datatype = app.db.datatypes[datatype] - if datatype is None or datatype.media_type is None: - media_type = magic.from_buffer(bytes, mime=True) - else: - media_type = datatype.media_type + return CONFIG['leaves'] + + @app.add('/db/artifact_store/put', method='put') + def db_artifact_store_put_bytes(raw: bytes = File(...)): + file_id = str(hashlib.sha1(raw).hexdigest()) + app.db.artifact_store.put_bytes(serialized=raw, file_id=file_id) + return {'file_id': file_id} + + @app.add('/db/artifact_store/get', method='get') + def db_artifact_store_get_bytes(file_id: str): + bytes = app.db.artifact_store.get_bytes(file_id=file_id) + media_type = magic.from_buffer(bytes, mime=True) return Response(content=bytes, media_type=media_type) @app.add('/db/apply', method='post') def db_apply(info: t.Dict): - component = Document.decode(info) + component = Document.decode(info).unpack() app.db.apply(component) return {'status': 'ok'} @@ -125,12 +74,12 @@ def db_show( identifier: t.Optional[str] = None, version: t.Optional[int] = None, ): - out = app.db.show(type_id=type_id, identifier=identifier, version=version) - if isinstance(out, dict) and '_id' in out: - del out['_id'] - if type_id == 'datatype' and identifier is None: - out.extend(list(app.db.datatypes.keys())) - return out + return app.db.show( + type_id=type_id, + identifier=identifier, + version=version, + include_presets=True, + ) @app.add('/db/metadata/show_jobs', method='get') def db_metadata_show_jobs(type_id: str, identifier: t.Optional[str] = None): @@ -144,33 +93,31 @@ def db_metadata_show_jobs(type_id: str, identifier: t.Optional[str] = None): @app.add('/db/execute', method='post') def db_execute( - query: str = ".(*args, **kwargs)", - documents: t.List[t.Dict] = [], + query: t.Dict, ): - query = [x for x in query.split('\n') if x.strip()] - query = parse_query(query, documents, db=app.db) + if '_path' not in query: + databackend = app.db.databackend.__module__.split('.')[-2] + query['_path'] = f'superduperdb/backends/{databackend}/query/parse_query' + + q = Document.decode(query, db=app.db).unpack() logging.info('processing this query:') - logging.info(query) + logging.info(q) - result = app.db.execute(query) + result = q.execute() - if query.type in {'insert', 'delete'}: + if q.type in {'insert', 'delete', 'update'}: return {'_base': [str(x) for x in result[0]]}, [] - logging.warn(str(query)) + logging.warn(str(q)) + if isinstance(result, Document): result = [result] - elif result is None: - result = [] - else: - result = list(result) + + result = [r.encode() for r in result] for r in result: - if '_id' in r: - del r['_id'] - result = [strip_artifacts(r.encode()) for r in result] - logging.warn(str(result)) - return result + r.pop_blobs() + return list(result) build_app(app) diff --git a/superduperdb/rest/utils.py b/superduperdb/rest/utils.py deleted file mode 100644 index dc2d4c6e93..0000000000 --- 
a/superduperdb/rest/utils.py +++ /dev/null @@ -1,97 +0,0 @@ -import re -import typing as t - -from superduperdb.backends.base.query import Model -from superduperdb.base.document import Document - - -def _parse_query_part(part, documents, query, db: t.Optional[t.Any] = None): - documents = [Document.decode(r, db=db) for r in documents] - from superduperdb.backends.mongodb.query import MongoQuery - - part = part.replace(' ', '').replace('\n', '') - part = part.replace('_documents', 'documents') - model_match = re.match('^model\([\'"]([A-Za-z0-9\.\-_])+[\'"]\)\.(.*)$', part) - model_match = re.match('^model\([\'"]([^)]+)+[\'"]\)\.(.*)$', part) - - if model_match: - current = Model(model_match.groups()[0]) - part = model_match.groups()[1].split('.') - else: - current = MongoQuery(part.split('.')[0]) - part = part.split('.')[1:] - for comp in part: - match = re.match('^([a-zA-Z0-9_]+)\((.*)\)$', comp) - if match is None: - current = getattr(current, comp) - continue - if not match.groups()[1].strip(): - current = getattr(current, match.groups()[0])() - continue - - comp = getattr(current, match.groups()[0]) - args_kwargs = [x.strip() for x in match.groups()[1].split(',')] - args = [] - kwargs = {} - for x in args_kwargs: - if '=' in x: - k, v = x.split('=') - kwargs[k] = eval(v, {'documents': documents, 'query': query}) - else: - args.append(eval(x, {'documents': documents, 'query': query})) - current = comp(*args, **kwargs) - return current - - -def parse_query(query, documents, db): - """Parse a query string into a query object. - - :param query: query string to parse - :param documents: documents to use in the query - :param db: datalayer instance - """ - if isinstance(query, str): - query = [x.strip() for x in query.split('\n') if x.strip()] - for i, q in enumerate(query): - query[i] = _parse_query_part(q, documents, query[:i], db=db) - return query[-1] - - -def strip_artifacts(r: t.Any): - """Strip artifacts for the data. 
- - :param r: the data to strip artifacts from - """ - if isinstance(r, dict): - if '_content' in r: - return f'_artifact/{r["_content"]["file_id"]}', [r["_content"]["file_id"]] - else: - out = {} - a_out = [] - for k, v in r.items(): - vv, tmp = strip_artifacts(v) - a_out.extend(tmp) - out[k] = vv - return out, a_out - elif isinstance(r, list): - out = [] - a_out = [] - for x in r: - xx, tmp = strip_artifacts(x) - out.append(xx) - a_out.extend(tmp) - return out, a_out - else: - return r, [] - - -if __name__ == '__main__': - q = parse_query( - [ - 'documents.find($documents[0], a={"b": 1}).sort(c=1).limit(1)', - ], - [], - [], - ) - - print(q) diff --git a/superduperdb/server/app.py b/superduperdb/server/app.py index bd916f421a..fe87d1b027 100644 --- a/superduperdb/server/app.py +++ b/superduperdb/server/app.py @@ -69,7 +69,13 @@ class SuperDuperApp: :param db: datalayer instance """ - def __init__(self, service='vector_search', port=8000, db: Datalayer = None): + def __init__( + self, + service='vector_search', + port=8000, + db: Datalayer = None, + init_hook: t.Optional[t.Callable] = None, + ): self.service = service self.port = port @@ -79,6 +85,7 @@ def __init__(self, service='vector_search', port=8000, db: Datalayer = None): self.router = APIRouter() self._user_startup = False self._user_shutdown = False + self.init_hook = init_hook self._app.add_middleware(ExceptionHandlerMiddleware) self._app.add_middleware( @@ -110,7 +117,7 @@ def raise_error(self, msg: str, code: int): raise HTTPException(code, detail=msg) @cached_property - def db(self): + def db(self) -> Datalayer: """Return the database instance from the app state.""" return self._app.state.pool @@ -217,6 +224,8 @@ def startup_db_client(): if function: function(db=db) self._app.state.pool = db + if self.init_hook: + self.init_hook(db=db) return diff --git a/test/rest/mock_client.py b/test/rest/mock_client.py index cb32d8b004..353f663861 100644 --- a/test/rest/mock_client.py +++ b/test/rest/mock_client.py @@ -1,32 +1,64 @@ import json import os +from urllib.parse import urlencode +from superduperdb import CFG -def curl_get(endpoint, data): - raise NotImplementedError +HOST = CFG.cluster.rest.uri +VERBOSE = os.environ.get('SUPERDUPERDB_VERBOSE', '1') -def curl_post(endpoint, data): +def make_params(params): + return '?' 
+ urlencode(params) + + +def curl_get(endpoint, params=None): + if params is not None: + params = make_params(params) + else: + params = '' + request = f"curl '{HOST}{endpoint}{params}'" + if VERBOSE == '1': + print('CURL REQUEST:') + print(request) + result = os.popen(request).read() + assert result, f'GET request to {request} returned empty response' + result = json.loads(result) + if 'msg' in result: + raise Exception('Error: ' + result['msg']) + return result + + +def curl_post(endpoint, data, params=None): + if params is not None: + params = make_params(params) + else: + params = '' data = json.dumps(data) request = f"""curl -X 'POST' \ - 'http://localhost:8002{endpoint}' \ + '{HOST}{endpoint}{params}' \ -H 'accept: application/json' \ -H 'Content-Type: application/json' \ -s \ -d '{data}'""" - print('CURL REQUEST:') - print(request) + if VERBOSE == '1': + print('CURL REQUEST:') + print(request) result = os.popen(request).read() assert result, f'POST request to {endpoint} with {data} returned empty response' - result = json.loads(result), f'Response is not valid JSON: {result}' + result = json.loads(result) + if 'msg' in result: + raise Exception('Error: ' + result['msg']) return result def curl_put(endpoint, file, media_type, params=None): if params is not None: - params = '?' + '&'.join([f'{k}={v}' for k, v in params.items()]) + params = make_params(params) + else: + params = '' request = f"""curl -X 'PUT' \ - 'http://localhost:8002{endpoint}{params}' \ + '{HOST}{endpoint}{params}' \ -H 'accept: application/json' \ -H 'Content-Type: multipart/form-data' \ -s \ @@ -37,34 +69,25 @@ def curl_put(endpoint, file, media_type, params=None): assert ( result ), f'PUT request to {endpoint} with {params} and {file} returned empty response' + result = json.loads(result) return result def insert(data): - data = { - "documents": data, - "query": ["documents.insert_many($documents)"], - "artifacts": [], - } - return curl_post('/db/execute', data) + query = {'query': 'coll.insert_many(documents)', 'documents': data} + return curl_post('/db/execute', data=query) def apply(component): - data = {'component': {component['dict']['identifier']: component}} - return curl_post('/db/apply', data) + return curl_post('/db/apply', data=component) def delete(): - data = { - "documents": [], - "query": ["documents.delete_many({})"], - "artifacts": [], - } - return curl_post('/db/execute', data) + return curl_post('/db/execute', data={'query': 'coll.delete_many({})'}) def remove(type_id, identifier): - return curl_post('/db/remove?type_id={type_id}&identifier={identifier}', {}) + return curl_post(f'/db/remove?type_id={type_id}&identifier={identifier}', {}) def setup(): @@ -73,13 +96,6 @@ def setup(): {"x": [6, 7, 8, 9, 10], "y": 'test'}, ] insert(data) - apply( - { - 'cls': 'image_type', - 'module': 'superduperdb.ext.pillow.encoder', - 'dict': {'identifier': 'image', 'media_type': 'image/png'}, - } - ) def teardown(): diff --git a/test/rest/test_rest.py b/test/rest/test_rest.py index 44256b4f27..70ad0f2fd4 100644 --- a/test/rest/test_rest.py +++ b/test/rest/test_rest.py @@ -1,11 +1,8 @@ -import json -import os - import pytest -from superduperdb import CFG +from superduperdb.base.document import Document -from .mock_client import curl_post, setup as _setup, teardown +from .mock_client import curl_get, curl_post, curl_put, setup as _setup, teardown @pytest.fixture @@ -15,76 +12,103 @@ def setup(): def test_select_data(setup): - form = { - "documents": [], - "query": "documents.find()", - "artifacts": [], - } - 
diff --git a/test/rest/test_rest.py b/test/rest/test_rest.py
index 44256b4f27..70ad0f2fd4 100644
--- a/test/rest/test_rest.py
+++ b/test/rest/test_rest.py
@@ -1,11 +1,8 @@
-import json
-import os
-
 import pytest
 
-from superduperdb import CFG
+from superduperdb.base.document import Document
 
-from .mock_client import curl_post, setup as _setup, teardown
+from .mock_client import curl_get, curl_post, curl_put, setup as _setup, teardown
 
 
 @pytest.fixture
@@ -15,76 +12,103 @@ def setup():
 
 
 def test_select_data(setup):
-    form = {
-        "documents": [],
-        "query": "documents.find()",
-        "artifacts": [],
-    }
-    result = curl_post('/db/execute', form)
+    result = curl_post('/db/execute', data={'query': 'coll.find({}, {"_id": 0})'})
     print(result)
     assert len(result) == 2
 
 
+def test_presets(setup):
+    result = curl_get(
+        '/db/show',
+        params={'type_id': 'datatype'},
+    )
+    print(result)
+    assert 'image' in result
+
+
+CODE = """
+from superduperdb import code
+
+@code
+def my_function(x):
+    return x + 1
+"""
+
+
+def test_apply(setup):
+    m = {
+        '_leaves': {
+            'function_body': {
+                '_path': 'superduperdb/base/code/Code',
+                'code': CODE,
+            },
+            'my_function': {
+                '_path': 'superduperdb/components/model/CodeModel',
+                'object': '?function_body',
+                'identifier': 'my_function',
+            },
+        },
+        '_base': '?my_function',
+    }
+
+    _ = curl_post(
+        endpoint='/db/apply',
+        data=m,
+    )
+
+    models = curl_get('/db/show', params={'type_id': 'model'})
+
+    assert models == ['my_function']
+
+
 def test_insert_image(setup):
-    request = f"""curl -X 'PUT' \
-        '{CFG.cluster.rest.uri}/db/artifact_store/save_artifact?datatype=image' \
-        -H 'accept: application/json' \
-        -H 'Content-Type: multipart/form-data' \
-        -s \
-        -F 'raw=@test/material/data/test.png;type=image/png'"""
-
-    result = os.popen(request).read()
-    result = json.loads(result)
-    assert 'file_id' in result
+    result = curl_put(
+        endpoint='/db/artifact_store/put',
+        file='test/material/data/test.png',
+        media_type='image/png',
+    )
+
+    file_id = result['file_id']
 
-    form = {
-        "documents": [
-            {
-                "img": {
-                    "_content": {
-                        "file_id": result["file_id"],
-                        "datatype": "image",
-                        "leaf_type": "lazy_artifact",
-                        "uri": None,
-                    }
-                }
-            },
-        ],
-        "query": ["documents.insert_one(_documents[0])"],
+    query = {
+        '_path': 'superduperdb/backends/mongodb/query/parse_query',
+        '_leaves': {
+            'my_artifact': {
+                '_path': 'superduperdb/components/datatype/LazyArtifact',
+                'file_id': file_id,
+                'datatype': "?db.load(datatype, image)",
+            }
+        },
+        'query': 'coll.insert_one(documents[0])',
+        'documents': [{'img': '?my_artifact'}],
     }
-    form = json.dumps(form)
-
-    request = f"""curl -X 'POST' \
-        '{CFG.cluster.rest.uri}/db/execute' \
-        -H 'accept: application/json' \
-        -H 'Content-Type: application/json' \
-        -s \
-        -d '{form}'"""
-
-    print('making request')
-    result = json.loads(os.popen(request).read())
-    if 'error' in result:
-        raise Exception(result['messages'])
-    print(result)
 
-    form = json.dumps(
-        {
-            "documents": [],
-            "query": "documents.find()",
-        }
+    result = curl_post(
+        endpoint='/db/execute',
+        data=query,
     )
 
-    request = f"""curl -X 'POST' \
-        '{CFG.cluster.rest.uri}/db/execute' \
-        -H 'accept: application/json' \
-        -H 'Content-Type: application/json' \
-        -s \
-        -d '{form}'"""
+    query = {
+        '_path': 'superduperdb/backends/mongodb/query/parse_query',
+        'query': 'coll.find(documents[0], documents[1])',
+        'documents': [{}, {'_id': 0}],
+    }
 
-    print('making request')
-    result = os.popen(request).read()
-    print(result)
-    result = json.loads(result)
-    result = next(r for r in result if 'img' in r)
-    assert result['img']['_content']['file_id'] == file_id
+    result = curl_post(
+        endpoint='/db/execute',
+        data=query,
+    )
+
+    from superduperdb import superduper
+
+    db = superduper()
+
+    result = [Document.decode(r, db=db).unpack() for r in result]
+
+    assert len(result) == 3
+
+    image_record = next(r for r in result if 'img' in r)
+
+    from PIL.PngImagePlugin import PngImageFile
+
+    assert isinstance(image_record['img'], PngImageFile)
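These tests exercise the new serialization end to end: a payload is a flat `_leaves` mapping of named components, `?name` references between them, and `_base` (or a field value) selecting the root. Composing such a payload by hand looks like the sketch below; the `double` model and its leaf names are illustrative, not part of this patch:

    CODE = '''
    from superduperdb import code

    @code
    def double(x):
        return 2 * x
    '''

    payload = {
        '_leaves': {
            'body': {'_path': 'superduperdb/base/code/Code', 'code': CODE},
            'double': {
                '_path': 'superduperdb/components/model/CodeModel',
                'identifier': 'double',
                'object': '?body',  # reference to the leaf named 'body' above
            },
        },
        '_base': '?double',  # the component that /db/apply should build
    }
    # POST this dict unchanged to /db/apply, as in test_apply above.
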
diff --git a/test/unittest/base/test_datalayer.py b/test/unittest/base/test_datalayer.py
index 48012f28f2..5b487ecc71 100644
--- a/test/unittest/base/test_datalayer.py
+++ b/test/unittest/base/test_datalayer.py
@@ -105,8 +105,7 @@ def add_fake_model(db: Datalayer):
     return listener
 
 
-# EMPTY_CASES = [DBConfig.mongodb_empty, DBConfig.sqldb_empty]
-EMPTY_CASES = [DBConfig.mongodb_empty]
+EMPTY_CASES = [DBConfig.mongodb_empty, DBConfig.sqldb_empty]
 
 
 @pytest.mark.parametrize("db", EMPTY_CASES, indirect=True)
@@ -222,7 +221,7 @@ def test_add_with_artifact(db):
 
     db.apply(m)
 
-    m = db.load('model', m.identifier)
+    m = db.load('model', m.identifier, include_presets=False)
 
     assert m.object is not None
 
@@ -385,7 +384,7 @@ def test_show(db):
         assert 'None' in str(e) and '1' in str(e)
 
     assert sorted(db.show('test-component')) == ['a1', 'a2', 'a3', 'b']
-    assert sorted(db.show('datatype')) == ['c1', 'c2']
+    assert sorted(db.show('datatype', include_presets=False)) == ['c1', 'c2']
     assert sorted(db.show('test-component', 'a1')) == [0]
     assert sorted(db.show('test-component', 'b')) == [0, 1, 2]
 
@@ -590,10 +589,24 @@ def test_reload_dataset(db):
     from superduperdb.components.dataset import Dataset
 
     if isinstance(db.databackend, MongoDataBackend):
-        select = MongoQuery(db=db, identifier='documents').find({'_fold': 'valid'})
+        select = db['documents'].find({'_fold': 'valid'})
     else:
-        table = db.load('table', 'documents')
-        select = table.select('id', 'x', 'y', 'z').filter(table._fold == 'valid')
+        db.apply(
+            Table(
+                'documents',
+                schema=Schema(
+                    'documents',
+                    fields={
+                        'id': dtype('str'),
+                        'x': dtype('int'),
+                        'y': dtype('int'),
+                        'z': dtype('int'),
+                    },
+                ),
+            )
+        )
+        condition = db['documents']._fold == 'valid'
+        select = db['documents'].select('id', 'x', 'y', 'z').filter(condition)
 
     d = Dataset(
         identifier='my_valid',
diff --git a/test/unittest/base/test_document.py b/test/unittest/base/test_document.py
index 9c51f92701..5e6eb054ff 100644
--- a/test/unittest/base/test_document.py
+++ b/test/unittest/base/test_document.py
@@ -10,7 +10,7 @@
 from superduperdb.components.model import ObjectModel
 from superduperdb.components.schema import Schema
 from superduperdb.components.table import Table
-from superduperdb.ext.pillow.encoder import pil_image
+from superduperdb.ext.pillow.encoder import image_type, pil_image
 
 try:
     import torch
@@ -198,3 +198,30 @@ def test_column_encoding(db):
     ).execute()
 
     db['test'].select("x", "y", "img").execute()
+
+
+@pytest.mark.parametrize("db", [DBConfig.mongodb_empty], indirect=True)
+def test_refer_to_system(db):
+    db.datatypes['image'] = image_type(identifier='image', encodable='artifact')
+
+    import PIL.Image
+    import PIL.PngImagePlugin
+
+    img = PIL.Image.open('test/material/data/test.png')
+
+    db.artifact_store.put_bytes(db.datatypes['image'].encoder(img), file_id='12345')
+
+    r = {
+        '_leaves': {
+            'my_artifact': {
+                '_path': 'superduperdb/components/datatype/LazyArtifact',
+                'file_id': '12345',
+                'datatype': "?db.load(datatype, image)",
+            }
+        },
+        'img': '?my_artifact',
+    }
+
+    r = Document.decode(r, db=db).unpack()
+
+    assert isinstance(r['img'], PIL.PngImagePlugin.PngImageFile)
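test_refer_to_system pins down the two reference forms the decoder resolves: `?name` looks up a leaf defined in the document's own `_leaves`, while `?db.load(type_id, identifier)` fetches a component from the connected datalayer. The same shape, annotated; a sketch assuming a connected `db` with an 'image' datatype registered and bytes already stored under the file_id:

    from superduperdb import superduper
    from superduperdb.base.document import Document

    db = superduper()  # any connected Datalayer set up as in the test above

    encoded = {
        '_leaves': {
            'my_artifact': {
                '_path': 'superduperdb/components/datatype/LazyArtifact',
                'file_id': '12345',                       # stored via artifact_store.put_bytes(..., file_id='12345')
                'datatype': '?db.load(datatype, image)',  # resolved against the live Datalayer
            }
        },
        'img': '?my_artifact',                            # resolved within _leaves
    }

    # decode() rebuilds the leaves; unpack() materializes the artifact bytes.
    record = Document.decode(encoded, db=db).unpack()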