In [1]:
# !pip install pyarango

In [2]:
# future
from __future__ import annotations

# stdlib
from typing import *
from uuid import UUID
from collections import defaultdict

# third party
import pydantic
from pydantic import BaseModel
import pyArango
from pyArango.connection import *

from nacl.signing import SigningKey

# syft absolute
import syft as sy
from syft.core.common import UID
from syft.lib.python import Dict as SyDict



In [3]:
client = Connection(arangoURL='http://127.0.0.1:51930', username="root", password="somepassword")

In [4]:
if client.hasDatabase("app"):
    db = client["app"]
else:
    db = client.createDatabase(name="app")

In [5]:
if not db.hasCollection("users"):
    db.createCollection(name="users")

In [9]:
class SyftObjectRegistry:
    __object_version_registry__: Dict[str, Dict[int, Type[SyftObject]]] = defaultdict(lambda: {})
    def __init_subclass__(cls, **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)
        if hasattr(cls, "__canonical_name__"):
            cls.__object_version_registry__[cls.__canonical_name__][int(cls.__version__)] = cls

    @classmethod
    def versioned_class(cls, name: str, version: int) -> Optional[Type[SyftObject]]:
        if name not in cls.__object_version_registry__:
            return None
        classes = cls.__object_version_registry__[name]
        if version not in classes:
            return None
        return classes[version]

In [45]:
class SyftObject(BaseModel, SyftObjectRegistry):    
    class Config:
        arbitrary_types_allowed = True

    # all objects have a UID
    id: UID = None # consistent and persistent uuid across systems
    @pydantic.validator("id", pre=True, always=True)
    def make_id(cls, v):
        return v if isinstance(v, UID) else UID()
    
    __canonical_name__: str # the name which doesn't change even when there are multiple classes
    __version__: int # data is always versioned
    __attr_state__: List[str] # persistent recursive serde keys
    __attr_searchable__: List[str] # keys which can be searched in the ORM
    __attr_unique__: List[str] # the unique keys for the particular Collection the objects will be stored in

    def to_arango(self,doc) -> Dict[str, Any]:
        for k in self.__attr_searchable__:
            doc[k] = getattr(self, k)
        blob = self.to_bytes()
        doc["_key"] = self.id.value
        doc["__canonical_name__"] = self.__canonical_name__
        doc["__version__"] = self.__version__
        doc["__blob__"] = blob
        doc.save()

    @staticmethod
    def from_arango(bson: Any) -> SyftObject:
        constructor = SyftObjectRegistry.versioned_class(
            name=bson["__canonical_name__"], version=bson["__version__"]
        )
        return constructor(**sy.deserialize(bson["__blob__"], from_bytes=True).upcast())

    def to_bytes(self) -> bytes:
        d = SyDict(**self)
        return sy.serialize(d, to_bytes=True)
    
    @staticmethod
    def from_bytes(blob: bytes) -> SyftObject:
        return sy.deserialize(blob, from_bytes=True)

    # allows splatting with **
    def keys(self) -> KeysView[str]:
        return self.__dict__.keys()

    # allows splatting with **
    def __getitem__(self, key: str) -> Any:
        return self.__dict__.__getitem__(key)
    
    def _upgrade_version(self, latest: bool = True) -> SyftObject:
        constructor = SyftObjectRegistry.versioned_class(
            name=self.__canonical_name__, version=self.__version__+1
        )
        if not constructor:
            return self
        else:
            # should we do some kind of recursive upgrades?
            upgraded = constructor._from_previous_version(self)
            if latest:
                upgraded = upgraded._upgrade_version(latest=latest)
            return upgraded

In [46]:
class SyftUser(SyftObject):
    # version
    __canonical_name__ = "SyftUser"
    __version__ = 1

    # fields
    email: str
    name: str
    bad_key: bool = False

    # serde / storage rules
    __attr_state__ = ["email", "name", "bad_key"]
    __attr_searchable__ = ["email", "name", "bad_key"]
    __attr_unique__ = ["email"]

In [47]:
class SyftUserV2(SyftObject):
    # version
    __canonical_name__ = "SyftUser"
    __version__ = 2

    # fields
    email: str
    name: str
    signing_key: bytes

    # serde / storage rules
    __attr_state__ = ["email", "name"]
    __attr_searchable__ = ["email", "name"]
    __attr_unique__ = ["email"]
    
    @classmethod
    def _from_previous_version(cls, userv1: SyftUser) -> SyftUserV2:
        kwargs = dict(**userv1)
        kwargs.update({"signing_key":bytes(SigningKey.generate())})
        return cls(**kwargs) # ignore bad_key

In [100]:
# a collection is like a table of documents but with what ever shape you like
class SyftCollection:
    _db: str
    _collection_name: str
    _collection: pyArango.collection.Collection
    _syft_object_type: Dict[int, Type[SyftObject]]
    
    def __init__(self, client: pyArango.connection.Connection) -> None:
        self._db = client[self._db]
        self._collection = self._db[self._collection_name]

    def add(self, obj: SyftObject) -> SyftObject:
        doc = self._collection.createDocument()
        obj.to_arango(doc)

    def drop(self) -> None:
        self._collection.truncate()

    def delete() -> None: pass
    def update() -> None: pass
    def find(self, search_params: Dict[str, Any]) -> List[SyftObject]:
        results = []
        res = self._collection.find(search_params)
        for d in res:
            results.append(SyftObject.from_mongo(d))
        return results
    def find_one(self, search_params: Dict[str, Any]) -> Optional[SyftObject]:
        d = self._collection.find_one(search_params)
        if d is None:
            return d
        return SyftObject.from_arango(d)

In [101]:
# a collection of SyftUsers
class SyftUserCollection(SyftCollection):
    _db = "app"
    _collection_name = "users"
    __canonical_object_name__ = "SyftUser"

In [102]:
# do some object creation and serde

In [103]:
uid = UUID('3873fc45-f513-48ab-8a47-7306bc7382b0')

In [104]:
madhava = SyftUser(email="madhava@openmined.org", name="Madhava", id=uid)

In [105]:
ser = madhava.to_bytes()

In [106]:
de = SyftUser.from_bytes(ser)

In [107]:
assert madhava == de

In [108]:
key = SigningKey.generate()

In [109]:
madhava_v2 = SyftUserV2(email="madhava@openmined.org", name="Madhava", signing_key=bytes(key))

In [110]:
madhava_v2

SyftUserV2(id=<UID: 36adec3939944288ad2f23f857a28e1d>, email='madhava@openmined.org', name='Madhava', signing_key=b'i\xeb}3;\xd0M\x88\xbbN\x85\x876\xe5\x86\xe3\x93\t?j>\xa0\x1b\xfa\xa5\xddq+\x85\xc9\xbbk')

In [111]:
assert madhava_v2.__canonical_name__ == madhava.__canonical_name__

In [112]:
# do some collection stuff

In [116]:
user_collection = SyftUserCollection(client=client)
user_collection.drop()

In [117]:
user_collection.add(madhava)

In [95]:
madhava_db = user_collection._collection[madhava.id.value]

In [121]:
SyftObject.from_arango(madhava_db)

TypeError: memoryview: a bytes-like object is required, not 'str'

In [124]:
madhava_db["__blob__"]

"b'\\n\\tprotobuf:\\x12\\x14syft.lib.python.Dict\\x1a\\xa4\\x04\\n=\\n\\tprotobuf:\\x12\\x16syft.lib.python.String\\x1a\\x18\\n\\x02id\\x12\\x12\\n\\x10\\xf2\\x82\\x9d#\\xd67C\\x8e\\x8bMQ\\x1c\\x97\\x81!W\\n@\\n\\tprotobuf:\\x12\\x16syft.lib.python.String\\x1a\\x1b\\n\\x05email\\x12\\x12\\n\\x10U\\xa8\\x88\\xaa\\xfe\\xf2H\\xd4\\xa6.,a6\\x0c\\xa3Z\\n?\\n\\tprotobuf:\\x12\\x16syft.lib.python.String\\x1a\\x1a\\n\\x04name\\x12\\x12\\n\\x10\\xaf\\xedE\\xe2\\x17\\x9bG\\xff\\xbc\\x99\\xd9\\x9dr\\xaezg\\nB\\n\\tprotobuf:\\x12\\x16syft.lib.python.String\\x1a\\x1d\\n\\x07bad_key\\x12\\x12\\n\\x10\\x99]c\\x9flpDt\\xa8\\xa9\\xb20\\xc5Ap\\x1d\\x129\\n\\tprotobuf:\\x12\\x18syft.core.common.uid.UID\\x1a\\x12\\n\\x10\\xd2\\x03%\\xc1\\xd2\\xd6J\\x95\\xa1\\xd0\\x87\\xe4\\xc4\\xfc\\x1c\\x0e\\x12P\\n\\tprotobuf:\\x12\\x16syft.lib.python.String\\x1a+\\n\\x15madhava@openmined.org\\x12\\x12\\n\\x10\\xe2\\xb3\\xb4\\xe5\\xa4\\xd2L\\x93\\xba\\xd7%;\\xdd@\\x94!\\x12B\\n\\tprotobuf:\\x12\\x16syft.lib.python.Strin

In [97]:
type(madhava_db)

pyArango.document.Document

In [98]:
user_collection.add(madhava_v2)

In [None]:
madhavas = user_collection.find({"name": "Madhava"})

In [None]:
# a collection of different versioned types
upgraded = []
for m in madhavas:
    print(m.__version__, m)
    upgraded.append(m._upgrade_version())
    
for m in upgraded:
    print(m.__version__, m)