In [1]:
# !pip install pymongo

In [2]:
# future
from __future__ import annotations

# stdlib
from typing import *
from uuid import UUID
from collections import defaultdict

# third party
import pydantic
from pydantic import BaseModel
import pymongo
from pymongo import MongoClient
from nacl.signing import SigningKey

# syft absolute
import syft as sy
from syft.core.common import UID
from syft.lib.python import Dict as SyDict

In [3]:
client = MongoClient(host='localhost', port=57100, username="root", password="example", uuidRepresentation="standard")

In [4]:
db = client["app"]

In [5]:
db.list_collection_names()

['users']

In [6]:
class SyftObjectRegistry:
    __object_version_registry__: Dict[str, Dict[int, Type[SyftObject]]] = defaultdict(lambda: {})
    def __init_subclass__(cls, **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)
        if hasattr(cls, "__canonical_name__"):
            cls.__object_version_registry__[cls.__canonical_name__][int(cls.__version__)] = cls

    @classmethod
    def versioned_class(cls, name: str, version: int) -> Optional[Type[SyftObject]]:
        if name not in cls.__object_version_registry__:
            return None
        classes = cls.__object_version_registry__[name]
        if version not in classes:
            return None
        return classes[version]

In [7]:
class SyftObject(BaseModel, SyftObjectRegistry):    
    class Config:
        arbitrary_types_allowed = True

    # all objects have a UID
    id: UID = None # consistent and persistent uuid across systems
    @pydantic.validator("id", pre=True, always=True)
    def make_id(cls, v):
        return v if isinstance(v, UID) else UID()
    
    __canonical_name__: str # the name which doesn't change even when there are multiple classes
    __version__: int # data is always versioned
    __attr_state__: List[str] # persistent recursive serde keys
    __attr_searchable__: List[str] # keys which can be searched in the ORM
    __attr_unique__: List[str] # the unique keys for the particular Collection the objects will be stored in

    def to_mongo(self) -> Dict[str, Any]:
        d = {}
        for k in self.__attr_searchable__:
            d[k] = getattr(self, k)
        blob = self.to_bytes()
        d["_id"] = self.id.value
        d["__canonical_name__"] = self.__canonical_name__
        d["__version__"] = self.__version__
        d["__blob__"] = blob

        return d

    @staticmethod
    def from_mongo(bson: Any) -> SyftObject:
        constructor = SyftObjectRegistry.versioned_class(
            name=bson["__canonical_name__"], version=bson["__version__"]
        )
        return constructor(**sy.deserialize(bson["__blob__"], from_bytes=True).upcast())

    def to_bytes(self) -> bytes:
        d = SyDict(**self)
        return sy.serialize(d, to_bytes=True)
    
    @staticmethod
    def from_bytes(blob: bytes) -> SyftObject:
        return sy.deserialize(blob, from_bytes=True)

    # allows splatting with **
    def keys(self) -> KeysView[str]:
        return self.__dict__.keys()

    # allows splatting with **
    def __getitem__(self, key: str) -> Any:
        return self.__dict__.__getitem__(key)
    
    def _upgrade_version(self, latest: bool = True) -> SyftObject:
        constructor = SyftObjectRegistry.versioned_class(
            name=self.__canonical_name__, version=self.__version__+1
        )
        if not constructor:
            return self
        else:
            # should we do some kind of recursive upgrades?
            upgraded = constructor._from_previous_version(self)
            if latest:
                upgraded = upgraded._upgrade_version(latest=latest)
            return upgraded

In [8]:
class SyftUser(SyftObject):
    # version
    __canonical_name__ = "SyftUser"
    __version__ = 1

    # fields
    email: str
    name: str
    bad_key: bool = False

    # serde / storage rules
    __attr_state__ = ["email", "name", "bad_key"]
    __attr_searchable__ = ["email", "name", "bad_key"]
    __attr_unique__ = ["email"]

In [9]:
class SyftUserV2(SyftObject):
    # version
    __canonical_name__ = "SyftUser"
    __version__ = 2

    # fields
    email: str
    name: str
    signing_key: bytes

    # serde / storage rules
    __attr_state__ = ["email", "name"]
    __attr_searchable__ = ["email", "name"]
    __attr_unique__ = ["email"]
    
    @classmethod
    def _from_previous_version(cls, userv1: SyftUser) -> SyftUserV2:
        kwargs = dict(**userv1)
        kwargs.update({"signing_key":bytes(SigningKey.generate())})
        return cls(**kwargs) # ignore bad_key

In [10]:
# a collection is like a table of documents but with what ever shape you like
class SyftCollection:
    _db: str
    _collection_name: str
    _collection: pymongo.collection.Collection
    _syft_object_type: Dict[int, Type[SyftObject]]
    
    def __init__(self, client: pymongo.mongo_client.MongoClient) -> None:
        self._db = client[self._db]
        self._collection = self._db[self._collection_name]

    def add(self, obj: SyftObject) -> SyftObject:
        self._collection.insert_one(obj.to_mongo())

    def drop(self) -> None:
        self._collection.drop()

    def delete() -> None: pass
    def update() -> None: pass
    def find(self, search_params: Dict[str, Any]) -> List[SyftObject]:
        results = []
        res = self._collection.find(search_params)
        for d in res:
            results.append(SyftObject.from_mongo(d))
        return results
    def find_one(self, search_params: Dict[str, Any]) -> Optional[SyftObject]:
        d = self._collection.find_one(search_params)
        if d is None:
            return d
        return SyftObject.from_mongo(d)

In [11]:
# a collection of SyftUsers
class SyftUserCollection(SyftCollection):
    _db = "app"
    _collection_name = "users"
    __canonical_object_name__ = "SyftUser"

In [12]:
# do some object creation and serde

In [13]:
uid = UUID('3873fc45-f513-48ab-8a47-7306bc7382b0')

In [14]:
madhava = SyftUser(email="madhava@openmined.org", name="Madhava", id=uid)

In [15]:
ser = madhava.to_bytes()

In [16]:
de = SyftUser.from_bytes(ser)

In [17]:
assert madhava == de

In [18]:
key = SigningKey.generate()

In [19]:
madhava_v2 = SyftUserV2(email="madhava@openmined.org", name="Madhava", signing_key=bytes(key))

In [20]:
madhava_v2

SyftUserV2(id=<UID: 381f96a1451347d3b49b021ee5e9b770>, email='madhava@openmined.org', name='Madhava', signing_key=b'\x96h \xad\xcd\x88\xa9y\x82f\xdc\xfa\xce\xd6\xf5IV.\xfe\x00\xdb%\xff\xd3\xd2_\x19\xbe%e\x17\xb9')

In [21]:
assert madhava_v2.__canonical_name__ == madhava.__canonical_name__

In [22]:
# do some collection stuff

In [23]:
user_collection = SyftUserCollection(client=client)
user_collection.drop()

In [24]:
user_collection.add(madhava)

In [25]:
try:
    user_collection.add(madhava)
except pymongo.errors.DuplicateKeyError as e:
    print("Duplicate key")

Duplicate key


In [26]:
madhava = user_collection.find_one({"name": "Madhava"})

In [27]:
type(madhava)

__main__.SyftUser

In [28]:
user_collection.add(madhava_v2)

In [29]:
madhavas = user_collection.find({"name": "Madhava"})

In [30]:
# a collection of different versioned types
upgraded = []
for m in madhavas:
    print(m.__version__, m)
    upgraded.append(m._upgrade_version())
    
for m in upgraded:
    print(m.__version__, m)

1 id=<UID: 9cc591d7c101493c870672f946d14482> email='madhava@openmined.org' name='Madhava' bad_key=False
2 id=<UID: 381f96a1451347d3b49b021ee5e9b770> email='madhava@openmined.org' name='Madhava' signing_key=b'\x96h \xad\xcd\x88\xa9y\x82f\xdc\xfa\xce\xd6\xf5IV.\xfe\x00\xdb%\xff\xd3\xd2_\x19\xbe%e\x17\xb9'
2 id=<UID: 9cc591d7c101493c870672f946d14482> email='madhava@openmined.org' name='Madhava' signing_key=b'=\x86\xcd\xa2Y{\xfc+\xd5Vs:\xb0q\xe0\xb7\x04\xb0\x11H\x8e\x13+\x1b\x87;\x8e\xa1T\xba\xeca'
2 id=<UID: 381f96a1451347d3b49b021ee5e9b770> email='madhava@openmined.org' name='Madhava' signing_key=b'\x96h \xad\xcd\x88\xa9y\x82f\xdc\xfa\xce\xd6\xf5IV.\xfe\x00\xdb%\xff\xd3\xd2_\x19\xbe%e\x17\xb9'
