Skip to content

Commit

Permalink
Lazy-creation of output tables for ibis to enable auto-inference of output schema
Browse files Browse the repository at this point in the history
  • Loading branch information
jieguangzhou committed May 10, 2024
1 parent a9f9e34 commit b118c43
Show file tree
Hide file tree
Showing 17 changed files with 332 additions and 81 deletions.
5 changes: 4 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ lint-and-type-check: ## Perform code linting and type checking
ruff check $(DIRECTORIES)

@echo "===> Static Typing Check <==="

@if [ -d .mypy_cache ]; then rm -rf .mypy_cache; fi
mypy superduperdb
# Check for missing docstrings
# interrogate superduperdb
Expand All @@ -106,7 +108,8 @@ fix-and-check: ## Lint the code before testing
# Linter and code formatting
ruff check --fix $(DIRECTORIES)
# Linting
rm -rf .mypy_cache/

@if [ -d .mypy_cache ]; then rm -rf .mypy_cache; fi
mypy superduperdb


Expand Down
18 changes: 17 additions & 1 deletion superduperdb/backends/base/data_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@


class BaseDataBackend(ABC):
db_type = None

def __init__(self, conn: t.Any, name: str):
    """
    :param conn: Live connection/client object for the underlying database.
    :param name: Name of the database this backend hosts.
    """
    self.conn = conn
    self.name = name
Expand Down Expand Up @@ -40,12 +42,16 @@ def build_artifact_store(self):
@abstractmethod
def create_output_dest(
    self,
    predict_id: str,
    datatype: t.Union[None, DataType, FieldType],
    flatten: bool = False,
):
    """
    Create a destination (table or collection) for the outputs of a model.

    :param predict_id: Identifier of the prediction whose outputs will be
        stored (``<listener-identifier>::<version>``).
    :param datatype: Datatype of the outputs, used to derive the output
        schema; may be ``None``.
    :param flatten: Whether model outputs are flattened into multiple rows.
    :return: A table component to register, or ``None`` if nothing needs
        to be added.
    """

@abstractmethod
def check_output_dest(self, predict_id) -> bool:
    """
    Check whether an output destination for ``predict_id`` already exists.

    :param predict_id: Identifier of the prediction whose outputs are stored.
    :return: ``True`` if the destination exists.
    """
    pass

@abstractmethod
def get_table_or_collection(self, identifier):
    """
    Return the native table (SQL) or collection (MongoDB) named ``identifier``.

    :param identifier: Name of the table or collection.
    """
    pass
Expand All @@ -70,3 +76,13 @@ def list_tables_or_collections(self):
"""
List all tables or collections in the database.
"""

@staticmethod
def infer_schema(data: t.Mapping[str, t.Any], identifier: t.Optional[str] = None):
    """
    Infer a schema from a given data object.

    Default implementation is a no-op that returns ``None``; concrete
    backends override this with a backend-specific inference strategy.

    :param data: The data object
    :param identifier: The identifier for the schema, if None, it will be generated
    :return: The inferred schema
    """
24 changes: 24 additions & 0 deletions superduperdb/backends/ibis/data_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@
import pandas
from ibis.backends.base import BaseBackend
from pandas.core.frame import DataFrame
from sqlalchemy.exc import NoSuchTableError

from superduperdb.backends.base.data_backend import BaseDataBackend
from superduperdb.backends.ibis.db_helper import get_db_helper
from superduperdb.backends.ibis.field_types import FieldType, dtype
from superduperdb.backends.ibis.query import Table
from superduperdb.backends.local.artifacts import FileSystemArtifactStore
from superduperdb.backends.sqlalchemy.metadata import SQLAlchemyMetadata
from superduperdb.base.enums import DBType
from superduperdb.components.datatype import DataType
from superduperdb.components.schema import Schema

Expand All @@ -21,6 +23,8 @@


class IbisDataBackend(BaseDataBackend):
db_type = DBType.SQL

def __init__(self, conn: BaseBackend, name: str, in_memory: bool = False):
    """
    :param conn: ibis backend connection.
    :param name: Name of the database.
    :param in_memory: If ``True``, treat the backend as in-memory
        (stored for later use; see table-creation paths).
    """
    super().__init__(conn=conn, name=name)
    self.in_memory = in_memory
Expand Down Expand Up @@ -98,6 +102,13 @@ def create_output_dest(
schema=Schema(identifier=f'_schema/{predict_id}', fields=fields),
)

def check_output_dest(self, predict_id) -> bool:
    """
    Check whether an output table exists for ``predict_id``.

    Looking the table up on the connection either succeeds (it exists)
    or raises ``NoSuchTableError`` (it does not).
    """
    table_name = f'_outputs.{predict_id}'
    try:
        self.conn.table(table_name)
    except NoSuchTableError:
        return False
    return True

def create_table_and_schema(self, identifier: str, mapping: dict):
"""
Create a schema in the data-backend.
Expand Down Expand Up @@ -131,3 +142,16 @@ def disconnect(self):

def list_tables_or_collections(self):
    """Return the names of all tables known to the ibis connection."""
    return self.conn.list_tables()

@staticmethod
def infer_schema(data: t.Mapping[str, t.Any], identifier: t.Optional[str] = None):
    """
    Infer a schema from a given data object, using the ibis flavour of
    schema inference (``ibis=True``).

    :param data: The data object
    :param identifier: The identifier for the schema, if None, it will be generated
    :return: The inferred schema
    """
    from superduperdb.misc.auto_schema import infer_schema as _infer_schema

    return _infer_schema(data, identifier=identifier, ibis=True)
13 changes: 0 additions & 13 deletions superduperdb/backends/ibis/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,19 +668,6 @@ def pre_create(self, db: 'Datalayer'):
else:
raise e

@staticmethod
def infer_schema(data: t.Mapping[str, t.Any], identifier: t.Optional[str] = None):
    """
    Infer a schema from a given data object.

    Delegates to ``superduperdb.misc.auto_schema.infer_schema`` with
    ``ibis=True`` so the inferred fields suit a SQL backend.

    :param data: The data object
    :param identifier: The identifier for the schema, if None, it will be generated
    :return: The inferred schema
    """
    from superduperdb.misc.auto_schema import infer_schema

    return infer_schema(data, identifier=identifier, ibis=True)

@property
def table_or_collection(self):
    """Return a query-table object addressing this table's data."""
    return IbisQueryTable(self.identifier, primary_id=self.primary_id)
Expand Down
21 changes: 20 additions & 1 deletion superduperdb/backends/mongodb/data_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from superduperdb.backends.ibis.field_types import FieldType
from superduperdb.backends.mongodb.artifacts import MongoArtifactStore
from superduperdb.backends.mongodb.metadata import MongoMetaDataStore
from superduperdb.base.enums import DBType
from superduperdb.base.serializable import Serializable
from superduperdb.components.datatype import DataType
from superduperdb.misc.colors import Colors
Expand All @@ -23,6 +24,8 @@ class MongoDataBackend(BaseDataBackend):
:param name: Name of database to host filesystem
"""

db_type = DBType.MONGODB

id_field = '_id'

def __init__(self, conn: pymongo.MongoClient, name: str):
Expand Down Expand Up @@ -129,8 +132,24 @@ def disconnect(self):

def create_output_dest(
    self,
    predict_id: str,
    datatype: t.Union[None, DataType, FieldType],
    flatten: bool = False,
):
    """
    Create a destination for model outputs.

    No-op for MongoDB: collections are created implicitly on first
    write, so nothing needs to be provisioned up front.

    :param predict_id: Identifier of the prediction whose outputs are stored.
    :param datatype: Datatype of the outputs (unused for MongoDB).
    :param flatten: Whether outputs are flattened (unused for MongoDB).
    """

def check_output_dest(self, predict_id) -> bool:
    # MongoDB needs no pre-created destination (``create_output_dest``
    # is a no-op), so the destination is always reported as present.
    return True

@staticmethod
def infer_schema(data: t.Mapping[str, t.Any], identifier: t.Optional[str] = None):
    """
    Infer a schema from a given data object.

    :param data: The data object
    :param identifier: The identifier for the schema, if None, it will be generated
    :return: The inferred schema
    """
    from superduperdb.misc.auto_schema import infer_schema as _infer_schema

    return _infer_schema(data, identifier)
13 changes: 0 additions & 13 deletions superduperdb/backends/mongodb/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -956,19 +956,6 @@ def model_update(
collection.insert_many([Document(**doc) for doc in bulk_writes])
)

@staticmethod
def infer_schema(data: t.Mapping[str, t.Any], identifier: t.Optional[str] = None):
    """
    Infer a schema from a given data object.

    Delegates to ``superduperdb.misc.auto_schema.infer_schema``.

    :param data: The data object
    :param identifier: The identifier for the schema, if None, it will be generated
    :return: The inferred schema
    """
    from superduperdb.misc.auto_schema import infer_schema

    return infer_schema(data, identifier)


def _get_decode_function(db) -> t.Callable[[t.Any], t.Any]:
def decode(output):
Expand Down
13 changes: 13 additions & 0 deletions superduperdb/base/datalayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from superduperdb.cdc.cdc import DatabaseChangeDataCapture
from superduperdb.components.component import Component
from superduperdb.components.datatype import DataType, _BaseEncodable, serializers
from superduperdb.components.schema import Schema
from superduperdb.jobs.job import ComponentJob, FunctionJob, Job
from superduperdb.jobs.task_workflow import TaskWorkflow
from superduperdb.misc.annotations import deprecated
Expand Down Expand Up @@ -1103,6 +1104,18 @@ def _add_component_to_cache(self, component: Component):
getattr(self, cm)[component.identifier] = component
component.on_load(self)

def infer_schema(
    self, data: t.Mapping[str, t.Any], identifier: t.Optional[str] = None
) -> Schema:
    """
    Infer a schema from a given data object.

    Delegates to the configured data-backend's ``infer_schema``.

    :param data: The data object
    :param identifier: The identifier for the schema, if None, it will be generated
    :return: The inferred schema
    """
    return self.databackend.infer_schema(data, identifier)


@dc.dataclass
class LoadDict(dict):
Expand Down
10 changes: 10 additions & 0 deletions superduperdb/base/enums.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from enum import Enum


class DBType(str, Enum):
    """
    Enumeration of the database families a data-backend can expose.

    The ``str`` mixin makes members compare equal to their string
    values, e.g. ``DBType.SQL == "SQL"``.
    """

    SQL = "SQL"
    MONGODB = "MONGODB"
2 changes: 1 addition & 1 deletion superduperdb/components/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def artifact_schema(self):
return Schema(f'serializer/{self.identifier}', fields=schema)

@property
def db(self):
def db(self) -> Datalayer:
"""
Datalayer instance.
"""
Expand Down
71 changes: 57 additions & 14 deletions superduperdb/components/listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from superduperdb import CFG
from superduperdb.backends.base.query import CompoundSelect
from superduperdb.base.datalayer import Datalayer
from superduperdb.base.document import _OUTPUTS_KEY
from superduperdb.base.serializable import Variable
from superduperdb.components.model import Mapping
Expand All @@ -17,6 +16,10 @@
from .component import Component, ComponentTuple
from .model import Model, ModelInputType

if t.TYPE_CHECKING:
from superduperdb.base.datalayer import Datalayer


SELECT_TEMPLATE = {'documents': [], 'query': '<collection_name>.find()'}


Expand Down Expand Up @@ -113,7 +116,19 @@ def outputs_select(self):
"""
Get query reference to model outputs.
"""
return self.select.table_or_collection.outputs(self.predict_id)
if self.select.DB_TYPE == "SQL":
return self.select.table_or_collection.outputs(self.predict_id)

else:
from superduperdb.backends.mongodb.query import Collection

model_update_kwargs = self.model.model_update_kwargs or {}
if model_update_kwargs.get('document_embedded', True):
collection_name = self.select.table_or_collection.identifier
else:
collection_name = self.outputs

return Collection(collection_name).find()

@property
def outputs_key(self):
Expand All @@ -126,7 +141,7 @@ def outputs_key(self):
return self.outputs

@override
def pre_create(self, db: Datalayer) -> None:
def pre_create(self, db: "Datalayer") -> None:
"""
Pre-create hook.
Expand All @@ -152,19 +167,13 @@ def _set_key(db, key, **kwargs):
return key

@override
def post_create(self, db: Datalayer) -> None:
def post_create(self, db: "Datalayer") -> None:
"""
Post-create hook.
:param db: Data layer instance.
"""
output_table = db.databackend.create_output_dest(
f'{self.identifier}::{self.version}',
self.model.datatype,
flatten=self.model.flatten,
)
if output_table is not None:
db.add(output_table)
self.create_output_dest(db, self.predict_id, self.model)
if self.select is not None and self.active and not db.server_mode:
if CFG.cluster.cdc.uri:
request_server(
Expand All @@ -176,6 +185,25 @@ def post_create(self, db: Datalayer) -> None:
else:
db.cdc.add(self)

@classmethod
def create_output_dest(cls, db: "Datalayer", predict_id, model: Model):
    """
    Create the output destination for ``predict_id``, if the model needs one.

    :param db: Data layer instance.
    :param predict_id: Predict ID.
    :param model: Model instance.
    """
    if model.datatype is None:
        # No declared output datatype: nothing to provision.
        return
    table = db.databackend.create_output_dest(
        predict_id,
        model.datatype,
        flatten=model.flatten,
    )
    if table is not None:
        db.add(table)

@property
def dependencies(self) -> t.List[ComponentTuple]:
"""
Expand All @@ -198,6 +226,18 @@ def predict_id(self):
"""
return f'{self.identifier}::{self.version}'

@classmethod
def from_predict_id(cls, db: "Datalayer", predict_id) -> 'Listener':
    """
    Load the listener that produced ``predict_id``.

    A predict ID has the form ``<identifier>::<version>`` (see
    ``predict_id``); it is split on the last ``::`` and the matching
    listener version is loaded from the database.

    :param db: Data layer instance.
    :param predict_id: Predict ID of the form ``<identifier>::<version>``.
    :raises ValueError: If ``predict_id`` contains no ``::`` separator or
        the version part is not an integer.
    """
    identifier, separator, version = predict_id.rpartition('::')
    if not separator:
        # Fail with a clear message instead of an opaque unpacking error.
        raise ValueError(
            f'Invalid predict_id: {predict_id!r}; '
            'expected "<identifier>::<version>"'
        )
    return t.cast(Listener, db.load('listener', identifier, version=int(version)))

@property
def id_key(self) -> str:
"""
Expand Down Expand Up @@ -235,7 +275,10 @@ def depends_on(self, other: Component):

@override
def schedule_jobs(
self, db: Datalayer, dependencies: t.Sequence[Job] = (), overwrite: bool = False
self,
db: "Datalayer",
dependencies: t.Sequence[Job] = (),
overwrite: bool = False,
) -> t.Sequence[t.Any]:
"""
Schedule jobs for the listener.
Expand All @@ -251,7 +294,7 @@ def schedule_jobs(
self.model.predict_in_db_job(
X=self.key,
db=db,
predict_id=f'{self.identifier}::{self.version}',
predict_id=self.predict_id,
select=self.select.copy(),
dependencies=dependencies,
overwrite=overwrite,
Expand All @@ -260,7 +303,7 @@ def schedule_jobs(
]
return out

def cleanup(self, database: Datalayer) -> None:
def cleanup(self, database: "Datalayer") -> None:
"""Clean up when the listener is deleted.
:param database: Data layer instance to process.
Expand Down
Loading

0 comments on commit b118c43

Please sign in to comment.