In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import functools

import syft as sy
import torch as th
from syft import Plan
from syft.core.node.vm.plan_vm import PlanVirtualMachine

ModuleNotFoundError: No module named 'syft.core.node.common.action.common.Action'; 'syft.core.node.common.action.common' is not a package

In [None]:
# A VM named "alice" plus its root client, used later to send data to it.
# NOTE(review): this cell has no execution count ("In [None]") in the saved
# notebook, yet later cells use `alice` and `alice_client` -- run it first.
alice = PlanVirtualMachine(name="alice")
alice_client = alice.get_root_client()
# NOTE(review): `remote_torch` is not used in any visible cell.
remote_torch = alice_client.torch

# Option 1: `@make_plan` decorator (builds and sends a new plan on every call)

In [None]:
def make_plan(func):
    """Decorator: turn ``func`` into a builder that records it as a syft Plan.

    Each call of the decorated function:
      1. creates a fresh ``PlanVirtualMachine`` and a root client for it,
      2. sends every positional argument to that client,
      3. records the actions ``func`` performs on the sent inputs,
      4. wraps them in a ``Plan`` and sends the plan to the same client.

    Args:
        func: callable that receives the *sent* positional inputs and returns
            the plan output(s).

    Returns:
        A wrapper (carrying ``func``'s name and docstring) that returns
        whatever ``plan.send(client)`` returns -- used in this notebook as a
        callable pointer to the remote plan.
    """
    @functools.wraps(func)  # keep func's __name__/__doc__ instead of "func_wrapper"
    def func_wrapper(*args):
        vm = PlanVirtualMachine(name="alice")
        client = vm.get_root_client()
        inputs = [arg.send(client) for arg in args]
        vm.record_actions()
        try:
            res = func(*inputs)
        finally:
            # Always leave recording mode, even if func raises.
            vm.stop_recording()
        plan = Plan(actions=vm.recorded_actions, inputs=inputs, outputs=res)
        return plan.send(client)
    return func_wrapper

In [5]:
@make_plan
def test_plan_builder(inp):
    """Plan body: double the (sent) input by adding it to itself."""
    return inp + inp

In [6]:
# Local input tensor; the plan builders below send it to their VM themselves.
input_tensor = th.tensor([1,2,3])

In [7]:
# Display the object the @make_plan decorator bound to `test_plan_builder`.
test_plan_builder

<function __main__.make_plan.<locals>.func_wrapper(*args)>

In [8]:
# Calling the decorated function sends `input_tensor` to a fresh VM, records
# `inp + inp`, wraps it in a Plan and returns the result of `plan.send(client)`.
test_plan_pointer = test_plan_builder(input_tensor)

In [9]:
# Execute the sent plan on `input_tensor`.
res = test_plan_pointer(input_tensor)

In [10]:
# Fetch the result; `[0]` selects the first element of what `.get()` returns
# (expected: tensor([2, 4, 6]) for input [1, 2, 3]).
res.get()[0]

tensor([2, 4, 6])

# Option 2: `@make_plan2` returning a `PlanBuilder` (build explicitly with `.build(...)`)

In [12]:
class PlanBuilder:
    """Record a ``forward`` computation on a VM and ship it as a syft Plan.

    Subclasses override :meth:`forward`; alternatively ``make_plan2`` assigns
    a plain function to ``forward`` on an instance. ``forward`` receives the
    *sent* inputs (the results of ``arg.send(client)``), not the originals.

    Args:
        vm: the virtual machine on which actions are recorded and to whose
            root client the inputs and the plan are sent.
    """

    def __init__(self, vm):
        self.vm = vm

    def forward(self, *inputs):
        """Computation to record; must be overridden or replaced before build()."""
        raise NotImplementedError(
            "PlanBuilder.forward must be overridden or assigned before build()"
        )

    def build(self, *args) -> 'PointerPlan':
        """Send ``args`` to the VM, record ``forward`` on them and send the plan.

        Returns:
            Whatever ``plan.send(client)`` returns (used in this notebook as a
            callable pointer to the remote plan).

        Raises:
            NotImplementedError: if ``forward`` was never provided.
        """
        client = self.vm.get_root_client()
        inputs = [arg.send(client) for arg in args]
        self.vm.record_actions()
        try:
            res = self.forward(*inputs)
        finally:
            # self.vm may be shared (e.g. the notebook's `alice`), so never
            # leave it stuck in recording mode if forward() raises.
            self.vm.stop_recording()
        plan = Plan(actions=self.vm.recorded_actions, inputs=inputs, outputs=res)
        return plan.send(client)

In [13]:
def make_plan2(func):
    """Decorator: wrap ``func`` in a ``PlanBuilder`` bound to a new VM.

    The decorated name becomes a ``PlanBuilder`` instance whose ``forward`` is
    ``func``; call ``.build(*args)`` on it to record and send the plan.
    """
    fresh_vm = PlanVirtualMachine(name="alice")
    builder = PlanBuilder(vm=fresh_vm)
    # Stored as an instance attribute, so it is called without `self`.
    builder.forward = func
    return builder

In [14]:
@make_plan2
def test_plan_builder2(inp):
    """Plan body: double the (sent) input by adding it to itself."""
    return inp + inp

In [15]:
# Display the object @make_plan2 returned: a PlanBuilder instance, not a function.
test_plan_builder2

<__main__.PlanBuilder at 0x7fa5c47f1990>

In [16]:
# Unlike Option 1, building is an explicit step: `.build()` sends the inputs,
# records `forward` on them and sends the resulting Plan.
# NOTE(review): this rebinds `test_plan_pointer` from Option 1.
test_plan_pointer = test_plan_builder2.build(input_tensor)

In [17]:
# Execute the Option 2 plan on `input_tensor`.
# NOTE(review): in the saved run this call logs a CRITICAL MemoryStore
# `__delitem__ error` although the fetched result is correct -- investigate.
res = test_plan_pointer(input_tensor)

[2021-03-01T10:42:27.850586+0100][CRITICAL][logger]][3697] <class 'syft.core.store.store_memory.MemoryStore'> __delitem__ error <UID: daadeff0d2bd4079a4b5e4b7820f5794>.


In [18]:
# Fetch the result (expected: tensor([2, 4, 6])).
res.get()[0]

tensor([2, 4, 6])

# Extra use case for PlanBuilder: wrapping state (a remote tensor pointer) in a subclass

In [19]:
# State to capture: a tensor sent to alice, whose pointer the builder below
# stores. `input_tensor` is re-created with the same values as before.
model_pointer1 = th.tensor([1,2,3]).send(alice_client)
input_tensor = th.tensor([1,2,3])

In [20]:
class TestPlanBuilder(PlanBuilder):
    """PlanBuilder whose ``forward`` multiplies its input by captured state.

    Args:
        vm: VM passed through to ``PlanBuilder``.
        model_pointer: value kept on the instance and used inside ``forward``
            (in this notebook, a pointer created with ``.send(alice_client)``).
    """

    def __init__(self, vm, model_pointer):
        super().__init__(vm)
        self.model_pointer = model_pointer

    def forward(self, x):
        # Operand order kept as x * model_pointer so the same __mul__ is dispatched.
        return x * self.model_pointer

In [21]:
# Build on the shared `alice` VM (the one `model_pointer1` was sent to), unlike
# make_plan/make_plan2, which each create a fresh VM.
plan_pointer = TestPlanBuilder(alice, model_pointer1).build(input_tensor)

In [22]:
# Execute the stateful plan on `input_tensor`.
# NOTE(review): the saved run shows a CRITICAL MemoryStore `__delitem__ error`
# logged by this call -- investigate.
res = plan_pointer(input_tensor)

[2021-03-01T10:42:29.179597+0100][CRITICAL][logger]][3697] <class 'syft.core.store.store_memory.MemoryStore'> __delitem__ error <UID: 04c53d4795e44c18ad3e949fb5e38e41>.


In [23]:
# Expected: element-wise input * model state = tensor([1, 4, 9]).
res.get()[0]

tensor([1, 4, 9])