### Imports and model specifications

In [15]:
# Support fetching a plan together with AST (AdditiveSharingTensor) tensors
import torch as th
import syft as sy
import torch.nn as nn
import torch.nn.functional as F

# Toggle: True builds the plan from a decorated function, False from a
# sy.Plan subclass.
plan_func = True

# Create the syft hook around torch and keep a handle on its local worker.
hook = sy.TorchHook(th)
me = hook.local_worker
me.is_client_worker = False

# Four virtual workers, created in the same order as their ids are listed.
alice, bob, charlie, dan = (
    sy.VirtualWorker(hook, id=worker_id)
    for worker_id in ("alice", "bob", "charlie", "dan")
)

# Build `plan_mult_3` in one of two ways, selected by `plan_func`.
if plan_func:
    # NOTE: despite the "mult_3" name, this variant ADDS the stored bias
    # (3.0) to its input, so an input of 1 yields 4.
    @sy.func2plan(args_shape=[(1,)], state={"bias": th.tensor([3.0])})
    def plan_mult_3(x, state):
        """Return `x` plus the "bias" tensor read from the plan state."""
        bias = state.read("bias")
        return x + bias
else:
    class Net(sy.Plan):
        """Plan holding a single 1-in / 1-out linear layer in its state."""

        def __init__(self):
            super().__init__(id="net")
            self.fc1 = nn.Linear(1, 1)
            self.add_to_state(["fc1"])

        def forward(self, x):
            """Apply the linear layer to `x`."""
            return self.fc1(x)

    plan_mult_3 = Net()
    # Build with a float tensor of shape (1,): nn.Linear(1, 1) needs a
    # floating-point input whose last dimension is 1, which also matches the
    # args_shape=[(1,)] used by the function-based variant. The previous
    # th.tensor(1) was a 0-dim integer tensor.
    plan_mult_3.build(th.tensor([1.]))
    
# Show the plan's parameters (materialised as a list for printing).
print(list(plan_mult_3.parameters()))
# Send the plan to alice, then apply fix_prec() and share() with bob and
# charlie as share holders and dan as the crypto provider.
sent_plan = plan_mult_3.send(alice).fix_prec().share(bob, charlie, crypto_provider=dan)

# Fetch plan: ask alice for the plan registered under the sent plan's id.
fetched_plan = alice.fetch_plan(sent_plan.id)
# The input gets the same fix_prec() + share() treatment as the plan.
x = th.tensor([1.])
x_ptr = x.fix_prec().share(bob, charlie, crypto_provider=dan)

# TODO: this should be stored automatically
# Manually register the shared input in the local worker's object store.
me._objects[x_ptr.id] = x_ptr



[]


In [19]:
# TODO: this should be done internally
# Re-share every state tensor of the fetched plan and collect the ids of the
# resulting objects, in the same order as fetched_plan.state_ids.
new_state_ids = []
for state_id in fetched_plan.state_ids:
    # TODO: we should not have direct access to the weights
    # Look the state tensor up in the local object store, apply fix_prec(),
    # share it between bob and charlie (crypto provider: dan) and call .get()
    # on the result.
    # NOTE(review): what .get() returns at this point depends on syft
    # internals that are not visible in this notebook — confirm.
    a_sh = me._objects[state_id].fix_prec().share(bob, charlie, crypto_provider=dan).get()
    # TODO: this should be stored automatically
    # Register the new object locally so it can be looked up by its id.
    me._objects[a_sh.id] = a_sh
    new_state_ids.append(a_sh.id)

In [20]:
# Hand the old and new state id lists to replace_ids (presumably rewriting the
# references stored inside the plan — confirm), then make the new ids the
# plan's recorded state ids.
fetched_plan.replace_ids(fetched_plan.state_ids, new_state_ids)
fetched_plan.state_ids = new_state_ids

In [21]:
# Bare expression used for inspection: evaluates the plan's `forward`
# attribute without calling it.
fetched_plan.forward

In [22]:
# Run the fetched plan on the shared input, then call .get() and
# .float_prec() on the result before printing it.
print(fetched_plan(x_ptr).get().float_prec())

tensor([4.])


In [24]:
# Support fetching a plan

import torch as th
import syft as sy
import torch.nn as nn
import torch.nn.functional as F

# Toggle: True builds the plan from a decorated function, False from a
# sy.Plan subclass.
plan_func = True

# Create the syft hook around torch and keep a handle on its local worker.
hook = sy.TorchHook(th)
me = hook.local_worker
me.is_client_worker = False

# Single virtual worker that will hold the sent plan.
alice = sy.VirtualWorker(hook, id="alice")

# Build `plan_mult_3` in one of two ways, selected by `plan_func`.
if plan_func:
    @sy.func2plan(args_shape=[(1,)], state={"bias": th.tensor([3.0])})
    def plan_mult_3(x, state):
        """Return `x` multiplied by the "bias" tensor read from the plan state."""
        bias = state.read("bias")
        return x * bias
else:
    class Net(sy.Plan):
        """Plan holding a single 1-in / 1-out linear layer in its state."""

        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(1, 1)
            self.add_to_state(["fc1"])

        def forward(self, x):
            """Apply the linear layer to `x`."""
            return self.fc1(x)

    plan_mult_3 = Net()
    # Build with a float tensor of shape (1,): nn.Linear(1, 1) needs a
    # floating-point input whose last dimension is 1, which also matches the
    # args_shape=[(1,)] used by the function-based variant. The previous
    # th.tensor(1) was a 0-dim integer tensor.
    plan_mult_3.build(th.tensor([1.]))

# Send the plan to alice; no fix_prec() or share() step is applied in this cell.
sent_plan = plan_mult_3.send(alice)

# Fetch plan: ask alice for the plan registered under the sent plan's id.
fetched_plan = alice.fetch_plan(sent_plan.id)

# Run the fetched plan on a plain local tensor and print the result.
x = th.tensor([1.])
print(fetched_plan(x))



tensor([3.])
