Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds modular SERDE approach #175

Merged
merged 2 commits into from
May 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions burr/core/persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,15 +140,17 @@ def save(
class SQLLitePersister(BaseStatePersister):
"""Class for SQLLite persistence of state. This is a simple implementation."""

def __init__(self, db_path: str, table_name: str = "burr_state"):
def __init__(self, db_path: str, table_name: str = "burr_state", serde_kwargs: dict = None):
skrawcz marked this conversation as resolved.
Show resolved Hide resolved
"""Constructor

:param db_path: the path where the DB will be stored.
:param table_name: the table name to store things under.
:param table_name: the table name to store things under.
:param serde_kwargs: kwargs for state serialization/deserialization.
"""
self.db_path = db_path
self.table_name = table_name
self.connection = sqlite3.connect(db_path)
self.serde_kwargs = serde_kwargs or {}

def create_table_if_not_exists(self, table_name: str):
"""Helper function to create the table where things are stored if it doesn't exist."""
Expand Down Expand Up @@ -229,7 +231,7 @@ def load(
row = cursor.fetchone()
if row is None:
return None
_state = State(json.loads(row[1]))
_state = State.deserialize(json.loads(row[1]), **self.serde_kwargs)
return {
"partition_key": partition_key,
"app_id": row[3],
Expand Down Expand Up @@ -277,7 +279,7 @@ def save(
status,
)
cursor = self.connection.cursor()
json_state = json.dumps(state.get_all())
json_state = json.dumps(state.serialize(**self.serde_kwargs))
cursor.execute(
f"INSERT INTO {self.table_name} (partition_key, app_id, sequence_id, position, state, status) "
f"VALUES (?, ?, ?, ?, ?, ?)",
Expand Down
102 changes: 102 additions & 0 deletions burr/core/serde.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
from functools import singledispatch
from typing import Any, Union

KEY = "__burr_serde__"


class StringDispatch:
    """Registry that maps a string key to a deserializer function.

    It works like single dispatch, except that the function is selected by a
    string key value instead of by the type of an argument.

    Example usage:

    .. code-block:: python

        import pickle
        from typing import Any

        from burr.core import serde

        @serde.deserializer.register("pickle")
        def deserialize_pickle(value: dict, pickle_kwargs: dict = None, **kwargs) -> Any:
            if pickle_kwargs is None:
                pickle_kwargs = {}
            return pickle.loads(value["value"], **pickle_kwargs)

    This registers the function `deserialize_pickle` under the key "pickle".
    It should mirror the matching serialization function, because that function
    is what writes the key value the deserializer is looked up by.

    Notice that the example namespaces its kwargs (`pickle_kwargs`). This is
    important because we don't want a collision with other kwargs that might
    be passed in.
    """

    def __init__(self):
        # key -> function registered for that key
        self.func_map = {}

    def register(self, key):
        """Return a decorator that stores the decorated function under `key`.

        The decorated function is returned unchanged. Registering a key that
        is already present replaces the previously stored function.
        """

        def _remember(func):
            self.func_map[key] = func
            return func

        return _remember

    def call(self, key, *args, **kwargs):
        """Call the function registered under `key` with the given arguments.

        :param key: the key the function was registered with.
        :return: whatever the registered function returns.
        :raises ValueError: if no function is registered for `key`.
        """
        if key not in self.func_map:
            raise ValueError(f"No function registered for key: {key}")
        return self.func_map[key](*args, **kwargs)


deserializer = StringDispatch()


def deserialize(value: Any, **kwargs) -> Any:
    """Main entry point for deserializing a value.

    * A dictionary carrying the `KEY` marker is handed, as a whole, to the
      deserializer function registered under the marker's value.
    * Any other dictionary is rebuilt with each of its values deserialized.
    * A list is rebuilt with each of its items deserialized.
    * Everything else is returned unchanged.

    :param value: the value to deserialize.
    :param kwargs: forwarded unchanged to every nested call and to the
        registered deserializer function.
    :return: the deserialized value.
    """
    if isinstance(value, dict):
        registered_key = value.get(KEY, None)
        if registered_key is None:
            # Plain dictionary: no marker, so just recurse into the values.
            return {name: deserialize(item, **kwargs) for name, item in value.items()}
        return deserializer.call(registered_key, value, **kwargs)
    if isinstance(value, list):
        return [deserialize(item, **kwargs) for item in value]
    return value


@singledispatch
def serialize(value, **kwargs) -> Any:
    """Default implementation for serializing a value.

    It is used when no implementation has been registered for the type of
    `value`: `None` is passed through, anything else is converted with `str`.

    All other implementations should be registered with the
    `@serialize.register` decorator. Each of them should output a dictionary
    that includes the `KEY` & value to use for deserialization.

    :param value: The value to serialize
    :param kwargs: Any additional keyword arguments. Each implementation should namespace their kwargs.
    :return: `None` for `None`, otherwise the string form of the value
    """
    return None if value is None else str(value)


@serialize.register(bool)
@serialize.register(float)
@serialize.register(int)
@serialize.register(str)
def serialize_primitive(value, **kwargs) -> Union[str, int, float, bool]:
    """Serialize a str, int, float or bool: the value is returned as it is."""
    return value


@serialize.register(dict)
def serialize_dict(value: dict, **kwargs) -> dict[str, Any]:
    """Serialize a dictionary by serializing each value; keys are kept as they are.

    :param value: the dictionary to serialize.
    :param kwargs: forwarded unchanged to `serialize` for every value.
    :return: a new dictionary with the same keys and serialized values.
    """
    serialized = {}
    for name, item in value.items():
        serialized[name] = serialize(item, **kwargs)
    return serialized


@serialize.register(list)
def serialize_list(value: list, **kwargs) -> list[Any]:
    """Serialize a list by serializing each item, preserving order.

    :param value: the list to serialize.
    :param kwargs: forwarded unchanged to `serialize` for every item.
    :return: a new list holding the serialized items.
    """
    serialized = []
    for item in value:
        serialized.append(serialize(item, **kwargs))
    return serialized
34 changes: 34 additions & 0 deletions burr/core/state.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import abc
import copy
import dataclasses
import importlib
import logging
from typing import Any, Dict, Iterator, Mapping

from burr.core import serde

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -190,6 +193,26 @@ def get_all(self) -> Dict[str, Any]:
"""Returns the entire state, realize as a dictionary. This is a copy."""
return dict(self)

def serialize(self, **kwargs) -> dict:
    """Converts the state to a JSON serializable object.

    :param kwargs: forwarded unchanged to `serde.serialize` for every value.
    :return: a dictionary with the state's keys and the serialized values.
    """
    serialized = {}
    for field, item in self.get_all().items():
        # TODO: handle field specific custom serialization
        serialized[field] = serde.serialize(item, **kwargs)
    return serialized

@classmethod
def deserialize(cls, json_dict: dict, **kwargs) -> "State":
    """Converts a dictionary representing a JSON object back into a state.

    :param json_dict: dictionary of state keys to their serialized values.
    :param kwargs: forwarded unchanged to `serde.deserialize` for every value.
    :return: a new instance of the class this was called on, holding the
        deserialized values.
    """
    # Construct through `cls` instead of hard-coding `State`, so that calling
    # this on a subclass yields an instance of that subclass.
    return cls(
        {
            # TODO: handle field specific custom deserialization
            field: serde.deserialize(raw, **kwargs)
            for field, raw in json_dict.items()
        }
    )

def update(self, **updates: Any) -> "State":
"""Updates the state with a set of key-value pairs
Does an upsert operation (if the keys exist their value will be overwritten,
Expand Down Expand Up @@ -272,3 +295,14 @@ def __iter__(self) -> Iterator[Any]:

def __repr__(self):
return self.get_all().__repr__() # quick hack


# We register the serde plugins here that we'll automatically try to load.
# In the future if we need to reorder/replace, we'll just have some
# check here that can skip loading plugins/override which ones to load.
# Note for pickle, we require people to manually register the type for that.
for serde_plugin in ["langchain", "pydantic", "pandas"]:
    try:
        # Importing the module is what matters here; the returned module
        # object is not used.
        importlib.import_module(f"burr.integrations.serde.{serde_plugin}")
    except ImportError:
        # Best effort: a plugin that cannot be imported is skipped.
        # Lazy %-style arguments avoid building the message when debug
        # logging is disabled.
        logger.debug("Skipped registering %s serde plugin.", serde_plugin)
2 changes: 1 addition & 1 deletion burr/examples
38 changes: 23 additions & 15 deletions burr/integrations/persisters/b_mongodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,27 +12,36 @@

class MongoDBPersister(persistence.BaseStatePersister):
"""A class used to represent a MongoDB Persister.

Example usage:
persister = MongoDBPersister(uri='mongodb://user:pass@localhost:27017', db_name='mydatabase', collection_name='mystates')
persister.save(
partition_key='example_partition',
app_id='example_app',
sequence_id=1,
position='example_position',
state=state.State({'key': 'value'}),
status='completed'
)
loaded_state = persister.load(partition_key='example_partition', app_id='example_app', sequence_id=1)
print(loaded_state)

.. code-block:: python

persister = MongoDBPersister(uri='mongodb://user:pass@localhost:27017', db_name='mydatabase', collection_name='mystates')
persister.save(
partition_key='example_partition',
app_id='example_app',
sequence_id=1,
position='example_position',
state=state.State({'key': 'value'}),
status='completed'
)
loaded_state = persister.load(partition_key='example_partition', app_id='example_app', sequence_id=1)
print(loaded_state)
"""

def __init__(
self, uri="mongodb://localhost:27017", db_name="mydatabase", collection_name="mystates"
self,
uri="mongodb://localhost:27017",
db_name="mydatabase",
collection_name="mystates",
serde_kwargs: dict = None,
):
"""Initializes the MongoDBPersister class."""
self.client = MongoClient(uri)
self.db = self.client[db_name]
self.collection = self.db[collection_name]
self.serde_kwargs = serde_kwargs or {}

def list_app_ids(self, partition_key: str, **kwargs) -> list[str]:
"""List the app ids for a given partition key."""
Expand All @@ -49,7 +58,7 @@ def load(
document = self.collection.find_one(query, sort=[("sequence_id", -1)])
if not document:
return None
_state = state.State(json.loads(document["state"]))
_state = state.State.deserialize(json.loads(document["state"]), **self.serde_kwargs)
return {
"partition_key": partition_key,
"app_id": app_id,
Expand All @@ -74,7 +83,7 @@ def save(
key = {"partition_key": partition_key, "app_id": app_id, "sequence_id": sequence_id}
if self.collection.find_one(key):
raise ValueError(f"partition_key:app_id:sequence_id[{key}] already exists.")
json_state = json.dumps(state.get_all())
json_state = json.dumps(state.serialize(**self.serde_kwargs))
self.collection.insert_one(
{
"partition_key": partition_key,
Expand All @@ -84,7 +93,6 @@ def save(
"state": json_state,
"status": status,
"created_at": datetime.now(timezone.utc).isoformat(),
# "created_at": datetime.datetime.utcnow().isoformat(),
}
)

Expand Down
13 changes: 8 additions & 5 deletions burr/integrations/persisters/b_redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
except ImportError as e:
base.require_plugin(e, ["redis"], "redis")

import datetime
import json
import logging
from datetime import datetime, timezone
from typing import Literal, Optional

from burr.core import persistence, state
Expand All @@ -23,7 +23,9 @@ class RedisPersister(persistence.BaseStatePersister):
It inherits from the BaseStatePersister class.
"""

def __init__(self, host: str, port: int, db: int, password: str = None):
def __init__(
self, host: str, port: int, db: int, password: str = None, serde_kwargs: dict = None
):
"""Initializes the RedisPersister class.

:param host:
Expand All @@ -32,6 +34,7 @@ def __init__(self, host: str, port: int, db: int, password: str = None):
:param password:
"""
self.connection = redis.Redis(host=host, port=port, db=db, password=password)
self.serde_kwargs = serde_kwargs or {}

def list_app_ids(self, partition_key: str, **kwargs) -> list[str]:
"""List the app ids for a given partition key."""
Expand Down Expand Up @@ -60,7 +63,7 @@ def load(
data = self.connection.hgetall(key)
if not data:
return None
_state = state.State(json.loads(data[b"state"].decode()))
_state = state.State.deserialize(json.loads(data[b"state"].decode()), **self.serde_kwargs)
return {
"partition_key": partition_key,
"app_id": app_id,
Expand Down Expand Up @@ -100,7 +103,7 @@ def save(
key = self.create_key(app_id, partition_key, sequence_id)
if self.connection.exists(key):
raise ValueError(f"partition_key:app_id:sequence_id[{key}] already exists.")
json_state = json.dumps(state.get_all())
json_state = json.dumps(state.serialize(**self.serde_kwargs))
self.connection.hset(
key,
mapping={
Expand All @@ -110,7 +113,7 @@ def save(
"position": position,
"state": json_state,
"status": status,
"created_at": datetime.datetime.utcnow().isoformat(),
"created_at": datetime.now(timezone.utc).isoformat(),
},
)
self.connection.zadd(partition_key, {app_id: sequence_id})
Expand Down
11 changes: 8 additions & 3 deletions burr/integrations/persisters/postgresql.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,19 @@ def from_values(
)
return cls(connection, table_name)

def __init__(self, connection, table_name: str = "burr_state"):
def __init__(self, connection, table_name: str = "burr_state", serde_kwargs: dict = None):
"""Constructor

:param connection: the connection to the PostgreSQL database.
:param table_name: the table name to store things under.
"""
self.table_name = table_name
self.connection = connection
self.serde_kwargs = serde_kwargs or {}

def set_serde_kwargs(self, serde_kwargs: dict):
    """Sets the serde_kwargs for the persister.

    :param serde_kwargs: kwargs for state serialization/deserialization.
        A falsy value (e.g. ``None``) is stored as an empty dict.
    """
    # Normalise the same way the constructor does: load/save unpack
    # ``**self.serde_kwargs``, which would fail if None were stored here.
    self.serde_kwargs = serde_kwargs or {}

def create_table(self, table_name: str):
"""Helper function to create the table where things are stored."""
Expand Down Expand Up @@ -160,7 +165,7 @@ def load(
row = cursor.fetchone()
if row is None:
return None
_state = state.State(row[1])
_state = state.State.deserialize(row[1], **self.serde_kwargs)
return {
"partition_key": partition_key,
"app_id": row[3],
Expand Down Expand Up @@ -208,7 +213,7 @@ def save(
status,
)
cursor = self.connection.cursor()
json_state = json.dumps(state.get_all())
json_state = json.dumps(state.serialize(**self.serde_kwargs))
cursor.execute(
f"INSERT INTO {self.table_name} (partition_key, app_id, sequence_id, position, state, status) "
"VALUES (%s, %s, %s, %s, %s, %s)",
Expand Down
Empty file.
Loading
Loading