In [1]:
from functools import partial
from importlib import reload
import sys
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union as TypeUnion

In [2]:
import syft as sy
# syft absolute
import syft
from syft.ast import add_dynamic_objects
from syft.ast.globals import Globals
from syft.core.node.abstract.node import AbstractNodeClient
from syft.core.node.common.client import Client
from syft.lib import lib_ast
from syft.core.common.serde.recursive import RecursiveSerde

In [3]:
import module_test

# `import` already inserts the module into sys.modules; the explicit entry keeps
# the "module_test" key bound to this exact module object.
sys.modules.update(module_test=module_test)

In [4]:
from module_test import A as wrapped_type  # class the serializable Wrapper is built around

In [5]:
class Wrapper(RecursiveSerde, module_test.A):
    """RecursiveSerde-backed serializable wrapper around ``module_test.A``.

    Only the attributes listed in ``__attr_allowlist__`` are serialized. An
    instance rebuilt from a proto therefore carries ``n`` but not ``obj``, and
    ``upcast`` has to reconstruct the wrapped object in that case.
    """

    # Attributes included in the serialized payload; ``obj`` is deliberately excluded.
    __attr_allowlist__ = ["n"]

    def __init__(self, value: module_test.A) -> None:
        # Keep the original object so ``upcast`` can hand back the very same instance.
        self.obj = value
        # Forward ``n`` up the MRO (presumably reaching ``module_test.A.__init__``,
        # which stores it as ``self.n``) so the wrapper itself behaves like an ``A``.
        super().__init__(value.n)

    def upcast(self) -> module_test.A:
        """Return the wrapped ``module_test.A`` instance.

        Deserialized wrappers only have the allowlisted ``n`` attribute, so the
        underlying object is rebuilt from it on first use and then cached.
        """
        obj = getattr(self, "obj", None)
        if obj is None:
            obj = module_test.A(self.n)
            self.obj = obj
        return obj

    @staticmethod
    def wrapped_type() -> type:
        """Return the class this wrapper serializes."""
        # Reference the class directly rather than the notebook-global
        # ``wrapped_type``, which is rebound by other cells and is easy to
        # misread as a self-reference to this method.
        return module_test.A

In [6]:
from syft.util import aggressive_set_attr

module_type = type(syft)
wrapped_type = module_test.A
import_path = "module_test.A"

# Reproduce the relevant part of syft's GenerateWrapper: name the wrapper
# class "<Class>Wrapper" under "syft.wrappers.<path>", hang it off a tree of
# fake modules rooted at `syft.wrappers`, then tag the wrapped class with it.
*module_parts, klass = import_path.split(".")
Wrapper.__name__ = klass + "Wrapper"
Wrapper.__module__ = "syft.wrappers." + ".".join(module_parts)

# `syft.wrappers` and each intermediate module are created on demand; any that
# already exist are reused, so re-running this cell does not replace them.
parent = syft.__dict__.setdefault("wrappers", module_type(name="wrappers"))
for n in module_parts:
    parent = parent.__dict__.setdefault(n, module_type(name=n))

# The leaf module exposes the wrapper under its generated class name.
parent.__dict__[Wrapper.__name__] = Wrapper

aggressive_set_attr(
    obj=wrapped_type, name="_sy_serializable_wrapper_type", attr=Wrapper
)

In [7]:
wrapper_cls = module_test.A._sy_serializable_wrapper_type
proto = wrapper_cls(module_test.A(1))._object2proto()
restored = wrapper_cls._proto2object(proto)
# Only allowlisted attributes are serialized, so `n` is what must survive the round trip.
assert restored.n == 1, restored.__dict__
proto, restored.__dict__

(data: "\n\024syft.lib.python.Dict\022v\n1\n\026syft.lib.python.String\022\027\n\001n\022\022\n\020\337R\274M\364\310LI\201\n\315\312\230\257\233O\022-\n\023syft.lib.python.Int\022\026\010\001\022\022\n\020=\362\036\344\215)A\000\2313\\\r\205\022\021w\032\022\n\020C\3313\344\237\246A\274\247\317\324\272\246\007\340\256"
 fully_qualified_name: "syft.wrappers.module_test.AWrapper",
 {'n': 1})

In [8]:
# AST helpers: build the module_test AST and attach it to a global AST or a client

def update_ast_test(
    ast_or_client: TypeUnion[Globals, AbstractNodeClient],
    methods: List[Tuple[str, str]],
    dynamic_objects: Optional[List[Tuple[str, str]]] = None,
) -> None:
    """Attach a freshly built ``module_test`` AST to a global AST or a client.

    Args:
        ast_or_client: either a ``Globals`` AST, which receives the
            ``module_test`` node via ``add_attr``, or an ``AbstractNodeClient``,
            which receives it in ``client.lib_ast.attrs`` and as a
            ``client.module_test`` attribute.
        methods: ``(path, return_type_name)`` pairs forwarded to ``create_ast_test``.
        dynamic_objects: optional pairs forwarded to ``create_ast_test``.

    Raises:
        ValueError: if ``ast_or_client`` is neither a ``Globals`` nor an
            ``AbstractNodeClient``.
    """
    if isinstance(ast_or_client, Globals):
        fresh_ast = create_ast_test(
            client=None, methods=methods, dynamic_objects=dynamic_objects
        )
        ast_or_client.add_attr(
            attr_name="module_test", attr=fresh_ast.attrs["module_test"]
        )
        return

    if not isinstance(ast_or_client, AbstractNodeClient):
        raise ValueError(
            f"Expected param of type (Globals, AbstractNodeClient), but got {type(ast_or_client)}"
        )

    # Client path: build an AST bound to this client and expose its module node
    # both inside the client's lib AST and as a direct attribute.
    fresh_ast = create_ast_test(
        client=ast_or_client, methods=methods, dynamic_objects=dynamic_objects
    )
    module_node = fresh_ast.attrs["module_test"]
    ast_or_client.lib_ast.attrs["module_test"] = module_node
    setattr(ast_or_client, "module_test", module_node)


def create_ast_test(
    client: Optional[AbstractNodeClient],
    methods: List[Tuple[str, str]],
    dynamic_objects: Optional[List[Tuple[str, str]]] = None,
) -> Globals:
    """Build a ``Globals`` AST describing the ``module_test`` module.

    Args:
        client: the client the AST is bound to, or ``None`` for a global AST.
        methods: ``(path, return_type_name)`` pairs; each path is registered
            with ``module_test`` as its framework reference.
        dynamic_objects: optional pairs passed to ``add_dynamic_objects``.
            Defaults to ``None`` (nothing added), consistent with ``update_ast_test``.

    Returns:
        The populated AST. For every registered class it has run
        ``create_pointer_class``, ``create_send_method`` and
        ``create_storable_object_attr_convenience_methods``.
    """
    ast = Globals(client)

    for path, return_type_name in methods:
        ast.add_path(
            path=path,
            framework_reference=module_test,
            return_type_name=return_type_name,
        )

    if dynamic_objects:
        add_dynamic_objects(ast, dynamic_objects)

    # Class-level machinery is generated only after all paths are registered,
    # presumably so each pointer class picks up every registered method.
    for klass in ast.classes:
        klass.create_pointer_class()
        klass.create_send_method()
        klass.create_storable_object_attr_convenience_methods()

    return ast


In [9]:
# (attribute path, return type path) pairs registered in the module_test AST
module_test_methods = [
    ("module_test.A", "module_test.A"),
    ("module_test.A.get_n", "syft.lib.python.Int"),
]

In [10]:
# Register module_test on the global syft.lib_ast.
update_ast_test(ast_or_client=syft.lib_ast, methods=module_test_methods)

# Also register the builder as a lib constructor so that clients created after
# this point (such as the VM root client below) get their own module_test AST.
lib_ast.loaded_lib_constructors["module_test"] = partial(
    update_ast_test, methods=module_test_methods
)

client = sy.VirtualMachine().get_root_client()

In [11]:
module_test_node = lib_ast.attrs.get("module_test")
module_test_node, module_test_node.attrs.get("A").attrs

(Module:
 	.A -> <syft.ast.klass.Class object at 0x7f70dd488c40>,
 {'get_n': <syft.ast.callable.Callable at 0x7f70dd4883a0>})

In [12]:
# Calling A through the client's AST runs the constructor on the VM and returns a pointer.
remote_a_ctor = client.module_test.A
a_ptr = remote_a_ctor(1)

SaveObjectAction <Storable: 1>
<syft.core.node.common.node_service.object_action.obj_action_service.ImmediateObjectActionServiceWithoutReply object at 0x7f70dd4b75b0>
RunClassMethodAction A(IntPointer, )
<syft.core.node.common.node_service.object_action.obj_action_service.ImmediateObjectActionServiceWithoutReply object at 0x7f70dd4b75b0>
<GarbageCollectObjectAction: fdfcce4821524bc7b4a30e64d9c89b10>
<syft.core.node.common.node_service.object_action.obj_action_service.EventualObjectActionServiceWithoutReply object at 0x7f70dd4b71c0>


In [13]:
# Run get_n remotely, then fetch the result back into this process.
n_value = a_ptr.get_n().get()
assert n_value == 1, n_value  # A was constructed remotely with n=1
n_value

RunClassMethodAction APointer.get_n(, )
<syft.core.node.common.node_service.object_action.obj_action_service.ImmediateObjectActionServiceWithoutReply object at 0x7f70dd4b75b0>
<GetObjectAction: a49a02236e144dc7b2790da3b2bd9e48>
<syft.core.node.common.node_service.object_action.obj_action_service.ImmediateObjectActionServiceWithReply object at 0x7f70dd48e610>


1