In [1]:
# Automatically reload imported modules before each cell executes, so edits to
# library source are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import syft as sy
import torch as th
import sys

from syft.core.node.common.action.get_object_action import GetObjectAction
from syft.core.node.common.action.function_or_constructor_action import (
    RunFunctionOrConstructorAction,
)
from syft.core.node.common.action.run_class_method_action import RunClassMethodAction
from syft.core.node.common.action.garbage_collect_object_action import (
    GarbageCollectObjectAction,
)
from syft.core.node.common.action.get_enum_attribute_action import EnumAttributeAction
from syft.core.node.common.action.get_or_set_property_action import (
    GetOrSetPropertyAction,
    PropertyActions,
)
from syft.core.node.common.action.get_or_set_static_attribute_action import (
    GetSetStaticAttributeAction,
    StaticAttributeAction,
)
from syft.core.node.common.action.save_object_action import SaveObjectAction
from syft.core.store.storeable_object import StorableObject


from syft.core.node.common.action.common import Action
from syft.proto.core.plan.plan_pb2 import Plan as Plan_PB

from typing import List
from syft.core.common.uid import UID
from syft.core.io.address import Address
from syft.core.common.serde.deserialize import _deserialize
from syft.core.common.object import Serializable
from syft.proto.core.node.common.action.action_pb2 import Action as Action_PB
from syft import Plan
from syft.core.pointer.pointer import Pointer

# Plans

## Serialization

In [3]:
# Several actions below need a live pointer (e.g. for their `_self` argument),
# so a small tensor is sent to a local VirtualMachine to obtain one.
# TODO: find a shorter way to get a pointer for these examples.
t = th.tensor([1, 2, 3])
alice = sy.VirtualMachine(name="alice")
alice_client = alice.get_client()
tensor_pointer = t.send(alice_client)

In [4]:
# One instance of each Action type the Plan should be able to serialize.
# Field values are placeholders (fresh UIDs, empty Address objects, empty
# paths); these actions are only used to exercise Plan serialization below.
a1 = GetObjectAction(
    id_at_location=UID(),
    address=Address(),
    reply_to=Address(),
    msg_id=UID(),
)
a2 = RunFunctionOrConstructorAction(
    path="torch.Tensor.add",
    args=(),
    kwargs={},
    id_at_location=UID(),
    address=Address(),
    msg_id=UID(),
)
a3 = RunClassMethodAction(
    path="torch.Tensor.add",
    _self=tensor_pointer,
    args=[],
    kwargs={},
    id_at_location=UID(),
    address=Address(),
    msg_id=UID(),
)
a4 = GarbageCollectObjectAction(
    id_at_location=UID(),
    address=Address(),
)
a5 = EnumAttributeAction(
    path="",
    id_at_location=UID(),
    address=Address(),
)
a6 = GetOrSetPropertyAction(
    path="",
    _self=tensor_pointer,
    id_at_location=UID(),
    address=Address(),
    args=[],
    kwargs={},
    action=PropertyActions.GET,
)
a7 = GetSetStaticAttributeAction(
    path="",
    id_at_location=UID(),
    address=Address(),
    action=StaticAttributeAction.GET,
)
# SaveObjectAction carries the local tensor `t` wrapped in a StorableObject.
a8 = SaveObjectAction(
    obj=StorableObject(id=UID(), data=t),
    address=Address(),
)

In [5]:
# Wrap all eight actions in a Plan. No inputs or outputs are passed because
# this section only exercises serialization.
plan = Plan([a1, a2, a3, a4, a5, a6, a7, a8])

In [6]:
# Serialize the plan, then deserialize the result to rebuild it.
blob = sy.serialize(plan)
plan_reconstructed = sy.deserialize(blob=blob)

In [7]:
# Round-trip check: the deserialized object must be a Plan whose actions are
# all Actions, restored as the same concrete types and in the same order as
# the a1..a8 actions it was built from.
assert isinstance(plan_reconstructed, Plan)
assert all(isinstance(a, Action) for a in plan_reconstructed.actions)

# Compared against a literal list (not `plan.actions` or a1..a8) because those
# names are reassigned in the batched-execution section; this keeps the cell
# valid when re-run out of order.
expected_action_types = [
    GetObjectAction,
    RunFunctionOrConstructorAction,
    RunClassMethodAction,
    GarbageCollectObjectAction,
    EnumAttributeAction,
    GetOrSetPropertyAction,
    GetSetStaticAttributeAction,
    SaveObjectAction,
]
assert [type(a) for a in plan_reconstructed.actions] == expected_action_types, (
    "serialize/deserialize round-trip changed the plan's actions"
)

## Batched Execution

In [8]:
# Start a new VirtualMachine and client for this section, rebinding the
# `alice` / `alice_client` names created in the serialization section.
alice = sy.VirtualMachine(name="alice")
alice_client = alice.get_client()

In [9]:
# Pre-allocate every tensor the plan touches on alice. Zero-valued tensors are
# placeholders whose pointers / ids are referenced by the plan's actions.

# placeholders for the plan inputs (x, y)
input_tensor_pointer1, input_tensor_pointer2 = (
    th.tensor([0, 0]).send(alice_client) for _ in range(2)
)

# model tensors
model_tensor_pointer1 = th.tensor([1, 2]).send(alice_client)
model_tensor_pointer2 = th.tensor([3, 4]).send(alice_client)

# placeholders for intermediate results and the final output; their ids are
# used as the actions' `id_at_location`
result_tensor_pointer1, result_tensor_pointer2, result_tensor_pointer3 = (
    th.tensor([0, 0]).send(alice_client) for _ in range(3)
)
result1_uid, result2_uid, result3_uid = (
    pointer.id_at_location
    for pointer in (
        result_tensor_pointer1,
        result_tensor_pointer2,
        result_tensor_pointer3,
    )
)

In [10]:
# The plan computes torch.eq(x * w1 + w2, y), where x / y are the plan inputs
# and w1 / w2 the model tensors. Each step targets the id of a pre-allocated
# result placeholder, and the next step reads it through that placeholder's pointer.

# step 1: result1 = x * w1
a1 = RunClassMethodAction(
    path="torch.Tensor.mul",
    _self=input_tensor_pointer1,
    args=[model_tensor_pointer1],
    kwargs={},
    id_at_location=result1_uid,
    address=Address(),
    msg_id=UID(),
)
# step 2: result2 = result1 + w2
a2 = RunClassMethodAction(
    path="torch.Tensor.add",
    _self=result_tensor_pointer1,
    args=[model_tensor_pointer2],
    kwargs={},
    id_at_location=result2_uid,
    address=Address(),
    msg_id=UID(),
)
# step 3: result3 = torch.eq(result2, y)
a3 = RunFunctionOrConstructorAction(
    path="torch.eq",
    args=[result_tensor_pointer2, input_tensor_pointer2],
    kwargs={},
    id_at_location=result3_uid,
    address=Address(),
    msg_id=UID(),
)

# The input placeholders are exposed under the names "x" and "y";
# result3 is the plan's output.
plan = Plan(
    [a1, a2, a3],
    inputs={"x": input_tensor_pointer1, "y": input_tensor_pointer2},
    outputs=result_tensor_pointer3,
)

In [11]:
# Send the plan to alice. The returned pointer is called later with keyword
# arguments named after the plan's inputs (x, y).
plan_pointer = plan.send(alice_client)

In [12]:
# Deterministic input batches x_i = [1, 1] + i, paired with the expected model
# output y_i = x_i * [1, 2] + [3, 4] (the same weights as the model tensors).
x_batches = [(th.tensor([1, 1]) + offset).send(alice_client) for offset in range(2)]
y_batches = [
    ((th.tensor([1, 1]) + offset) * th.tensor([1, 2]) + th.tensor([3, 4]))
    .send(alice_client)
    for offset in range(2)
]

for x, y in zip(x_batches, y_batches):
    res = plan_pointer(x=x, y=y)
    # The plan returns torch.eq(model(x), y); every element must be True.
    # NOTE(review): the `*` unpacking assumes `res.get()` yields a one-element
    # container wrapping the comparison tensor — confirm against Plan.__call__.
    assert all(*res.get())