In [1]:
# Automatically reload edited modules (e.g. syft) before each cell executes.
%load_ext autoreload
%autoreload 2

In [2]:
import syft as sy
import torch as th
from syft.core.node.vm.plan_vm import PlanVirtualMachine
from syft import Plan

In [3]:
import pytest

In [4]:
# In-process VM that hosts and executes the plans sent in the cells below.
alice = PlanVirtualMachine(name="alice")
alice_client = alice.get_root_client()
# Remote torch namespace exposed by alice's root client.
remote_torch = alice_client.torch

In [5]:
from syft.core.plan.plan_builder import PLAN_BUILDER_VM, make_plan, build_plan_inputs, ROOT_CLIENT

# Option 1: build a plan with the `make_plan` decorator

In [6]:
@make_plan
def add_plan(inp=th.zeros(3)):
    """Plan that doubles its input by adding it to itself.

    The default tensor presumably serves as the example input used by
    ``make_plan`` when tracing (TODO confirm against ``build_plan_inputs``).
    """
    doubled = inp + inp
    return doubled

In [7]:
# Concrete input for the remote add_plan call.
input_tensor = th.tensor([1, 2, 3])

In [8]:
# Send the traced plan to alice; the returned pointer is used to invoke it remotely.
plan_pointer = add_plan.send(alice_client)

In [9]:
# Run the plan on alice; `res` is a remote result whose value is fetched with .get().
res = plan_pointer(inp=input_tensor)

In [10]:
# Fetch the result from alice; presumably a sequence of outputs, element 0 being the sum.
doubled = res.get()[0]
assert th.equal(doubled, th.tensor([2, 4, 6]))

In [11]:
@make_plan
def mul_plan(inp=th.zeros(3), inp2=th.zeros(3)):
    """Plan that multiplies its two inputs element-wise.

    The default tensors presumably serve as example inputs used by
    ``make_plan`` when tracing (TODO confirm against ``build_plan_inputs``).
    """
    product = inp * inp2
    return product

In [12]:
# Two independent (but equal-valued) input tensors for mul_plan.
t1 = th.tensor([1, 2, 3])
t2 = th.tensor([1, 2, 3])

In [13]:
# Send mul_plan to alice (rebinds plan_pointer, which previously pointed at add_plan).
plan_pointer = mul_plan.send(alice_client)

In [14]:
# Execute mul_plan remotely, passing both inputs by keyword.
res = plan_pointer(inp=t1, inp2=t2)

[2021-03-02T10:20:48.048809+0100][CRITICAL][logger]][2202] <class 'syft.core.store.store_memory.MemoryStore'> __delitem__ error <UID: 6425801367d24b27b333e7f67bca7867>.


In [15]:
# Fetch the result from alice; presumably a sequence of outputs, element 0 being the product.
product = res.get()[0]
assert th.equal(product, th.tensor([1, 4, 9]))

## Error case: `make_plan` on a function whose input has no default value is expected to raise `ValueError`

In [16]:
def assertRaises(exc, obj, methodname, *args, match=None, **kwargs):
    """Assert that calling ``obj.<methodname>(*args, **kwargs)`` raises ``exc``.

    Args:
        exc: Exception type (or tuple of types) that must be raised.
        obj: Object whose method is looked up by name.
        methodname: Name of the method to call on ``obj`` (e.g. ``"__call__"``).
        *args: Positional arguments forwarded to the method.
        match: Optional regex that the exception's string form must match
            (passed straight to ``pytest.raises``); ``None`` skips the check.
        **kwargs: Keyword arguments forwarded to the method.

    Raises:
        pytest's failure exception if ``exc`` is not raised or ``match`` fails.
    """
    # pytest.raises itself fails when nothing (or the wrong type) is raised;
    # `match` checks the exception message, since stringifying the
    # ExceptionInfo wrapper is never empty and so verifies nothing.
    with pytest.raises(exc, match=match):
        getattr(obj, methodname)(*args, **kwargs)

In [17]:
def test_define_plan():
    """Apply ``make_plan`` to a function whose input has no default value.

    The caller expects this to raise ``ValueError``.
    """
    def add_plan(inp):
        return inp + inp

    # Equivalent to decorating add_plan with @make_plan.
    make_plan(add_plan)

In [18]:
# test_define_plan() (invoked through its __call__) is expected to raise ValueError.
assertRaises(ValueError, test_define_plan, "__call__")

# Extra use case: `PlanBuilder`, a base class that builds a plan from a `forward` method and wraps model state

In [19]:
class PlanBuilder():
    """Base class that traces a subclass's ``forward`` method into a ``Plan``.

    Subclasses implement ``forward(self, **inputs)`` and must assign any state
    that ``forward`` reads *before* calling ``super().__init__()``, because the
    plan is traced eagerly inside ``__init__``.
    """

    def __init__(self, vm=None, wrapping_func=None):
        """Trace ``self.forward`` into ``self.plan``.

        Args:
            vm: VM used to record actions while tracing; falls back to
                ``PLAN_BUILDER_VM`` when ``None``.
            wrapping_func: Currently unused; accepted for interface compatibility.
        """
        # Fall back to the shared plan-builder VM only when no VM is supplied.
        self.vm = vm if vm is not None else PLAN_BUILDER_VM
        # No remote copy yet; set by send_plan() (directly or via __call__).
        self.plan_pointer = None
        self.plan = self.build()

    def build(self, *args) -> 'Plan':
        """Record the actions of ``self.forward`` on ``self.vm`` and wrap them in a ``Plan``.

        ``*args`` is accepted for compatibility but not used. The traced inputs
        come from ``build_plan_inputs(self.forward)`` (presumably built from
        ``forward``'s default argument values — TODO confirm).

        Returns:
            A ``Plan`` holding the recorded actions, the traced inputs and the
            value returned by ``forward``.
        """
        inputs = build_plan_inputs(self.forward)

        self.vm.record_actions()
        try:
            res = self.forward(**inputs)
        finally:
            # Always leave recording mode and reset the buffer, even if forward()
            # raised, so the shared VM is not left recording or holding stale actions.
            self.vm.stop_recording()
            actions = self.vm.recorded_actions
            # Rebind (rather than .clear()) so the Plan keeps its own list.
            self.vm.recorded_actions = []
        return Plan(actions=actions, inputs=inputs, outputs=res)

    def __call__(self, **kwargs):
        """Execute the plan remotely with ``kwargs`` as its inputs.

        If the plan has not been sent yet, it is first sent to a freshly
        created ``PlanVirtualMachine`` named "alice".
        """
        if self.plan_pointer is None:
            client = PlanVirtualMachine(name="alice").get_root_client()
            print("Model is not remote yet, sending to a new VM")
            self.send_plan(client)
        return self.plan_pointer(**kwargs)

    def send_plan(self, client):
        """Send ``self.plan`` to ``client`` and remember the returned pointer."""
        self.plan_pointer = self.plan.send(client)


In [20]:
class TestModel(PlanBuilder):
    """Toy model whose plan multiplies its input element-wise by a fixed tensor."""

    def __init__(self):
        # forward() reads this attribute while PlanBuilder.__init__ traces the
        # plan, so it has to be assigned before super().__init__().
        self.model_pointer = th.tensor([1, 2, 3])
        super().__init__()

    def forward(self, x=th.tensor([0, 0, 0])):
        """Return ``x * self.model_pointer``."""
        return x * self.model_pointer

In [21]:
# Constructing the model traces forward() into a Plan (see PlanBuilder.__init__).
model = TestModel()

In [22]:
# Optional: `model.send_plan(alice_client)` would send the plan to the existing alice VM.
# It is skipped here, so the first `model(...)` call sends the plan to a new VM instead.

In [23]:
res = model(x=th.tensor([4, 5, 6]))
# Index the first output explicitly, consistent with the make_plan checks above;
# `th.equal(*res.get(), ...)` raises TypeError if more than one value comes back.
assert th.equal(res.get()[0], th.tensor([4, 10, 18]))

[2021-03-02T10:20:53.074871+0100][CRITICAL][logger]][2202] <class 'syft.core.store.store_memory.MemoryStore'> __delitem__ error <UID: 45da735110b1414a9fa0a61d8a0fc534>.


Model is not remote yet, sending to a new VM
