In [4]:
from functools import partial
from importlib import reload
import sys
import pandas as pd
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union as TypeUnion

In [7]:
import syft as sy
# syft absolute
import syft
from syft.ast import add_dynamic_objects
from syft.ast.globals import Globals
from syft.core.node.abstract.node import AbstractNodeClient
from syft.core.node.common.client import Client
from syft.lib import lib_ast
from syft.core.common.serde.recursive import RecursiveSerde

In [8]:
import module_test
# `import module_test` already put the module in sys.modules; re-bind the same
# key explicitly so lookups by name resolve to this exact module object.
sys.modules.update(module_test=module_test)

In [9]:
# Small frame used as the `data` payload of module_test.A throughout the notebook.
data = pd.DataFrame(data={"A": [1, 2, 3]})
# Show the instance attributes of a locally constructed A.
vars(module_test.A(1, data))

{'n': 1,
 'data':    A
 0  1
 1  2
 2  3}

In [10]:
wrapped_type = getattr(module_test, "A")  # the class the serde wrapper targets

In [11]:
# Wrapper that makes module_test.A serializable: it subclasses both RecursiveSerde and module_test.A.
class Wrapper(RecursiveSerde, module_test.A):
    """Serializable stand-in for ``module_test.A``.

    Subclasses ``RecursiveSerde`` (which serializes the attributes listed in
    ``__attr_allowlist__``) and ``module_test.A`` (so a wrapper is itself an ``A``).
    """

    # Only these attributes are serialized. `obj` is NOT listed, so it is not
    # present on a wrapper rebuilt by `_proto2object`.
    __attr_allowlist__ = ["n", "data"]

    def __init__(self, value):
        """Wrap the existing ``module_test.A`` instance ``value``."""
        self.obj = value
        super().__init__(value.n, value.data)

    def upcast(self):
        """Return the ``module_test.A`` instance this wrapper stands for.

        A wrapper created through ``__init__`` returns the original object.
        A wrapper rebuilt by ``_proto2object`` only carries the allowlisted
        ``n`` and ``data`` attributes (no ``obj``). For that case a fresh ``A``
        is built from them instead of raising ``AttributeError``.
        """
        original = self.__dict__.get("obj")
        if original is not None:
            return original
        return self.wrapped_type()(self.n, self.data)

    @staticmethod
    def wrapped_type() -> type:
        """Return the class this wrapper serializes."""
        # Refer to module_test.A directly rather than the notebook global
        # `wrapped_type` (a different binding that merely shares this method's
        # name), so reassigning that global later cannot change the wrapper.
        return module_test.A

In [14]:
from syft.util import aggressive_set_attr

# relevant part of GenerateWrapper
def add_wrapper_class(Wrapper: type, wrapped_type: type, import_path: str):
    """Register a serde wrapper class under ``syft.wrappers`` and attach it to its target.

    Steps:

    * rename ``Wrapper`` to ``<ClassName>Wrapper`` and set its ``__module__`` to
      ``syft.wrappers[.<module path>]``;
    * create any missing placeholder modules along that path and store
      ``Wrapper`` in the innermost one;
    * set ``wrapped_type._sy_serializable_wrapper_type = Wrapper`` via
      ``aggressive_set_attr``.

    Args:
        Wrapper: the wrapper class (a ``RecursiveSerde`` subclass) to register.
        wrapped_type: the class being made serializable.
        import_path: dotted path of ``wrapped_type``, e.g. ``"module_test.A"``.
            The last component is the class name. The rest is the module path,
            which may be empty for a top-level name such as ``"A"``.

    Raises:
        ValueError: if the class-name component of ``import_path`` is empty.
    """
    module_type = type(syft)
    *module_parts, klass = import_path.split(".")
    if not klass:
        raise ValueError(f"import_path must end with a class name, got {import_path!r}")

    wrapper_name = f"{klass}Wrapper"
    Wrapper.__name__ = wrapper_name
    # Keep __qualname__ in sync; repr() of a class and pickle lookups use it.
    Wrapper.__qualname__ = wrapper_name
    # Join the parts instead of formatting "syft.wrappers.{...}" so a path with
    # no module component yields "syft.wrappers" rather than "syft.wrappers.".
    Wrapper.__module__ = ".".join(["syft", "wrappers", *module_parts])

    # Walk syft.wrappers.<module_parts...>, creating a placeholder module
    # (named with its fully-qualified dotted name) for every missing level.
    parent = syft
    qualified_name = "syft"
    for part in ["wrappers", *module_parts]:
        qualified_name = f"{qualified_name}.{part}"
        if part not in parent.__dict__:
            parent.__dict__[part] = module_type(qualified_name)
        parent = parent.__dict__[part]

    # Finally expose the wrapper at the end of the path.
    parent.__dict__[wrapper_name] = Wrapper

    aggressive_set_attr(
        obj=wrapped_type, name="_sy_serializable_wrapper_type", attr=Wrapper
    )

# Attach Wrapper to module_test.A and expose it as syft.wrappers.module_test.AWrapper.
add_wrapper_class(
    Wrapper,
    wrapped_type=module_test.A,
    import_path="module_test.A",
)

In [15]:
# Serialize a wrapped A to protobuf, rebuild it, and show the restored attributes.
proto = module_test.A._sy_serializable_wrapper_type(
    module_test.A(1, data)
)._object2proto()
vars(module_test.A._sy_serializable_wrapper_type._proto2object(proto))

{'n': 1,
 'data':    A
 0  1
 1  2
 2  3}

In [16]:
# AST helpers: build a syft AST for module_test and attach it to a Globals AST or a client

def update_ast_test(
    ast_or_client: TypeUnion[Globals, AbstractNodeClient],
    methods: List[Tuple[str, str]],
    dynamic_objects: Optional[List[Tuple[str, str]]] = None,
) -> None:
    """Graft a freshly built ``module_test`` AST node onto a Globals AST or a client.

    Test counterpart of a library ``update_ast``; the node is produced by
    ``create_ast_test``.

    Args:
        ast_or_client: a ``Globals`` AST, which receives the node via
            ``add_attr``; or an ``AbstractNodeClient``, which receives it in
            ``client.lib_ast.attrs`` and as the ``client.module_test`` attribute.
        methods: ``(path, return_type_name)`` pairs forwarded to ``create_ast_test``.
        dynamic_objects: optional extra objects forwarded to ``create_ast_test``.

    Raises:
        ValueError: if ``ast_or_client`` is neither a ``Globals`` nor an
            ``AbstractNodeClient``.
    """
    lib_name = "module_test"

    if isinstance(ast_or_client, Globals):
        built = create_ast_test(
            client=None, methods=methods, dynamic_objects=dynamic_objects
        )
        ast_or_client.add_attr(attr_name=lib_name, attr=built.attrs[lib_name])
        return

    if not isinstance(ast_or_client, AbstractNodeClient):
        raise ValueError(
            f"Expected param of type (Globals, AbstractNodeClient), but got {type(ast_or_client)}"
        )

    built = create_ast_test(
        client=ast_or_client, methods=methods, dynamic_objects=dynamic_objects
    )
    lib_node = built.attrs[lib_name]
    ast_or_client.lib_ast.attrs[lib_name] = lib_node
    setattr(ast_or_client, lib_name, lib_node)


def create_ast_test(
    client: Optional[AbstractNodeClient],
    methods: List[Tuple[str, str]],
    dynamic_objects: Optional[List[Tuple[str, str]]],
) -> Globals:
    """Build a new ``Globals`` AST populated with ``module_test`` paths.

    Args:
        client: client the AST is bound to; ``update_ast_test`` passes ``None``
            when building for a ``Globals`` target.
        methods: ``(path, return_type_name)`` pairs, each added with
            ``module_test`` as the framework reference.
        dynamic_objects: extra entries handed to ``add_dynamic_objects`` when
            non-empty; skipped when ``None`` or empty.

    Returns:
        The populated AST. On every class it contains, ``create_pointer_class``,
        ``create_send_method`` and
        ``create_storable_object_attr_convenience_methods`` have been called.
    """
    new_ast = Globals(client)

    for path, return_type_name in methods:
        new_ast.add_path(
            path=path,
            framework_reference=module_test,
            return_type_name=return_type_name,
        )

    if dynamic_objects:
        add_dynamic_objects(new_ast, dynamic_objects)

    for ast_class in new_ast.classes:
        ast_class.create_pointer_class()
        ast_class.create_send_method()
        ast_class.create_storable_object_attr_convenience_methods()

    return new_ast


In [17]:
# (AST path, return type name) pairs registered for module_test.
module_test_methods = [
    ("module_test.A", "module_test.A"),
    ("module_test.A.get_n", "syft.lib.python.Int"),
    ("module_test.A.get_data", "pandas.DataFrame"),
]

In [18]:
# Add module_test to the global syft AST right away.
update_ast_test(ast_or_client=syft.lib_ast, methods=module_test_methods)

# Register a constructor so that when a new client is created, its specific
# AST is updated with module_test as well.
lib_ast.loaded_lib_constructors["module_test"] = partial(
    update_ast_test,
    methods=module_test_methods,
)

client = sy.VirtualMachine().get_root_client()

In [19]:
# Show the registered module_test node and the attributes of its class A.
(
    lib_ast.attrs.get("module_test"),
    lib_ast.attrs.get("module_test").attrs.get("A").attrs,
)

(Module:
 	.A -> <syft.ast.klass.Class object at 0x7f8557043ac0>,
 {'get_n': <syft.ast.callable.Callable at 0x7f8557043b80>,
  'get_data': <syft.ast.callable.Callable at 0x7f8557043be0>})

In [20]:
# Construct A on the client's node; `a_ptr` is a pointer to the remote object.
a_ptr = client.module_test.A(1, data)

SaveObjectAction <Storable: 1>
<syft.core.node.common.node_service.object_action.obj_action_service.ImmediateObjectActionServiceWithoutReply object at 0x7f855704dc10>
SaveObjectAction <Storable:   A0 11 22 3>
<syft.core.node.common.node_service.object_action.obj_action_service.ImmediateObjectActionServiceWithoutReply object at 0x7f855704dc10>
RunClassMethodAction A(IntPointer,DataFramePointer, )
<syft.core.node.common.node_service.object_action.obj_action_service.ImmediateObjectActionServiceWithoutReply object at 0x7f855704dc10>
<GarbageCollectObjectAction: d873b7c143f847c4b1bfacfdd1c7de83>
<syft.core.node.common.node_service.object_action.obj_action_service.EventualObjectActionServiceWithoutReply object at 0x7f855704d640>
<GarbageCollectObjectAction: 8d68b28718524e6588059ad8fff1650e>
<syft.core.node.common.node_service.object_action.obj_action_service.EventualObjectActionServiceWithoutReply object at 0x7f855704d640>


In [21]:
# Run get_n remotely, then fetch the result back with .get().
a_ptr.get_n().get()

RunClassMethodAction APointer.get_n(, )
<syft.core.node.common.node_service.object_action.obj_action_service.ImmediateObjectActionServiceWithoutReply object at 0x7f855704dc10>
<GetObjectAction: 83563230ea6e433c9b2897c312158bdf>
<syft.core.node.common.node_service.object_action.obj_action_service.ImmediateObjectActionServiceWithReply object at 0x7f855704dee0>


1

In [22]:
# Run get_data remotely, then fetch the resulting DataFrame back with .get().
a_ptr.get_data().get()

RunClassMethodAction APointer.get_data(, )
<syft.core.node.common.node_service.object_action.obj_action_service.ImmediateObjectActionServiceWithoutReply object at 0x7f855704dc10>
<GetObjectAction: 931a19542e304635820d8b9ba1fd67c4>
<syft.core.node.common.node_service.object_action.obj_action_service.ImmediateObjectActionServiceWithReply object at 0x7f855704dee0>


Unnamed: 0,A
0,1
1,2
2,3
