In [1]:
import logging
# Configure the root logger so that only WARNING-and-above records are emitted.
# `logging.WARNING` is the documented name; `logging.WARN` is an alias of it.
logging.basicConfig(level=logging.WARNING)

# stdlib
from functools import partial
from importlib import reload
import sys
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union as TypeUnion
import syft as sy
# third party
import pytest

# syft absolute
import syft
from syft.ast import add_dynamic_objects
from syft.ast.globals import Globals
from syft.core.node.abstract.node import AbstractNodeClient
from syft.core.node.common.client import Client
from syft.lib import lib_ast

from syft.core.test import module_test

# Make syft.core.test.module_test importable under the top-level name "module_test",
# presumably so the dotted "module_test.*" paths below can be resolved by syft's AST
# machinery as a real module — TODO confirm against syft.ast path resolution.
sys.modules["module_test"] = module_test

# (path, return_type_name) pairs that create_ast_test registers via ast.add_path:
# the first item is a dotted attribute path inside module_test, the second is the
# syft type name passed as return_type_name for that path.
module_test_methods: List[Tuple[str, str]] = [
    ("module_test.A", "module_test.A"),
    ("module_test.A.__len__", "syft.lib.python.Int"),
    ("module_test.A.__iter__", "syft.lib.python.Iterator"),
    ("module_test.A.__next__", "syft.lib.python.Int"),
    ("module_test.A.test_method", "syft.lib.python.Int"),
    ("module_test.A.test_property", "syft.lib.python.Float"),
    ("module_test.A._private_attr", "syft.lib.python.Float"),
    ("module_test.A.static_method", "syft.lib.python.Float"),
    ("module_test.A.static_attr", "syft.lib.python.Int"),
    ("module_test.B.Car", "module_test.B"),
    ("module_test.C", "module_test.C"),
    ("module_test.C.type_reload_func", "syft.lib.python._SyNone"),
    ("module_test.C.obj_reload_func", "syft.lib.python._SyNone"),
    ("module_test.C.dummy_reloadable_func", "syft.lib.python.Int"),
    ("module_test.global_value", "syft.lib.python.Int"),
    ("module_test.global_function", "syft.lib.python.Int"),
]

# Same (path, return_type_name) shape, but registered through add_dynamic_objects
# rather than add_path (see create_ast_test).
dynamic_objects: List[Tuple[str, str]] = [("module_test.C.dynamic_object", "syft.lib.python.Int")]


def update_ast_test(
    ast_or_client: TypeUnion[Globals, AbstractNodeClient],
    methods: List[Tuple[str, str]],
    dynamic_objects: Optional[List[Tuple[str, str]]] = None,
) -> None:
    """Attach a freshly built ``module_test`` subtree to an AST or a client.

    A new AST is built with ``create_ast_test`` and its ``"module_test"`` attribute
    is grafted onto the target.

    Args:
        ast_or_client: A ``Globals`` AST (e.g. ``syft.lib_ast``), which receives the
            subtree through ``add_attr``; or an ``AbstractNodeClient``, whose
            ``lib_ast.attrs["module_test"]`` entry and ``module_test`` attribute are
            both set to the subtree.
        methods: ``(path, return_type_name)`` pairs forwarded to ``create_ast_test``.
        dynamic_objects: Optional ``(path, return_type_name)`` pairs forwarded to
            ``create_ast_test``.

    Raises:
        ValueError: If ``ast_or_client`` is neither a ``Globals`` nor an
            ``AbstractNodeClient``.
    """
    target_is_ast = isinstance(ast_or_client, Globals)
    if not target_is_ast and not isinstance(ast_or_client, AbstractNodeClient):
        raise ValueError(
            f"Expected param of type (Globals, AbstractNodeClient), but got {type(ast_or_client)}"
        )

    # A Globals target gets a client-less AST; a client target gets one bound to it.
    bound_client = None if target_is_ast else ast_or_client
    built = create_ast_test(
        client=bound_client, methods=methods, dynamic_objects=dynamic_objects
    )
    subtree = built.attrs["module_test"]

    if target_is_ast:
        ast_or_client.add_attr(attr_name="module_test", attr=subtree)
    else:
        # Register in the client's own AST and expose it as `client.module_test`.
        ast_or_client.lib_ast.attrs["module_test"] = subtree
        setattr(ast_or_client, "module_test", subtree)


def create_ast_test(
    client: Optional[AbstractNodeClient],
    methods: List[Tuple[str, str]],
    dynamic_objects: Optional[List[Tuple[str, str]]],
) -> Globals:
    """Build a ``Globals`` AST that exposes the given ``module_test`` paths.

    Args:
        client: Passed straight to ``Globals``; ``None`` when the AST is meant for
            ``syft.lib_ast`` rather than a specific client.
        methods: ``(path, return_type_name)`` pairs, each added with ``module_test``
            as the framework reference.
        dynamic_objects: Optional ``(path, return_type_name)`` pairs handed to
            ``add_dynamic_objects``; skipped when ``None`` or empty.

    Returns:
        The populated ``Globals`` AST, with a pointer class, a send method and the
        storable-object convenience methods created for each of its classes.
    """
    tree = Globals(client)

    for path, return_type_name in methods:
        tree.add_path(
            path=path,
            framework_reference=module_test,
            return_type_name=return_type_name,
        )

    if dynamic_objects:
        add_dynamic_objects(tree, dynamic_objects)

    # Finish every class node in the same order: pointer class, send method,
    # then storable-object convenience methods.
    for node_class in tree.classes:
        node_class.create_pointer_class()
        node_class.create_send_method()
        node_class.create_storable_object_attr_convenience_methods()

    return tree


def register_module_test() -> None:
    """Register ``module_test`` with syft's global AST and with clients created later.

    First grafts the ``module_test`` subtree onto ``syft.lib_ast`` immediately, then
    stores an ``update_ast_test`` partial under ``"module_test"`` in
    ``lib_ast.loaded_lib_constructors`` so that newly registered clients also get
    the subtree. Call this before creating any client that should see
    ``module_test``.
    """
    shared_kwargs = dict(methods=module_test_methods, dynamic_objects=dynamic_objects)

    # Patch the already-loaded global AST right away.
    update_ast_test(ast_or_client=syft.lib_ast, **shared_kwargs)

    # Deferred hook, presumably invoked by syft with each newly registered client.
    lib_ast.loaded_lib_constructors["module_test"] = partial(
        update_ast_test, **shared_kwargs
    )



def custom_client() -> Client:
    """Create a fresh ``VirtualMachine`` named "alice" and return its root client.

    Not called by any of this notebook's visible cells; kept as a helper for
    getting an isolated client.
    """
    vm_node = syft.VirtualMachine(name="alice")
    return vm_node.get_root_client()

# Register module_test before creating a client so that, presumably via the
# constructor stored in lib_ast.loaded_lib_constructors, the new root client also
# exposes `module_test` (used as root_client.module_test in the next cell).
register_module_test()

domain = syft.Domain("me")
root_client = domain.get_root_client()

# Bare last expression: IPython displays the domain's store (shown as [] in the
# recorded output, i.e. nothing stored yet).
domain.store

INITIALIZING IN MEMORY STORE!!!
INITIALIZING IN MEMORY STORE!!!
Getting all of those values


[]

In [2]:
# Value written to the remote object's property here and, in a later cell, to a
# local instance for comparison.
value_to_set = 7.5

# Construct module_test.A on the domain (the log shows the A instance being stored);
# a_ptr is the client-side handle used to act on it.
a_ptr = root_client.module_test.A()

Storing object:<UID: dbd9f0b44c194a2487c5d2e707add8b2> -> <Storable: <syft.core.test.module_test.A object at 0x7fb60023a400>>


In [3]:
# Remote property assignment: the value is sent and stored on the domain, the action
# runs there, and the updated A instance is re-stored under its existing UID.
a_ptr.test_property = value_to_set

Storing object:<UID: df2d20f4e78744f1a2e71464d01dfa70> -> <Storable: 7.5>
Begin executing action...
Storing object:<UID: dbd9f0b44c194a2487c5d2e707add8b2> -> <Storable: <syft.core.test.module_test.A object at 0x7fb5f1967100>>
Deleting object:<UID: dbd9f0b44c194a2487c5d2e707add8b2>
Storing object:<UID: 8bc90c2e0a8e435e903b5111a6d9f757> -> <Storable: <syft.lib.python._SyNone object at 0x7fb60091f430>>
Deleting object:<UID: df2d20f4e78744f1a2e71464d01dfa70>
Deleting object:<UID: 8bc90c2e0a8e435e903b5111a6d9f757>


In [5]:
# Reading the property through the pointer does not return the value directly; the
# result (7.5) is stored on the domain and result_ptr refers to it (fetched via .get()).
result_ptr = a_ptr.test_property

Begin executing action...
Storing object:<UID: c330ec611b964874879343cd3511972f> -> <Storable: 7.5>


In [7]:
# Local reference: repeat the same set/get on a plain, non-remote A instance to get
# the value the remote result should match.
a = module_test.A()
a.test_property = value_to_set
result = a.test_property


result

7.5

In [8]:
# Fetch the remote value (the log shows .get() also deletes it from the domain's
# store), check it against the local reference from the previous cell, and display it.
remote_result = result_ptr.get()
assert remote_result == result, f"remote {remote_result!r} != local {result!r}"
remote_result

Deleting object:<UID: c330ec611b964874879343cd3511972f>


7.5