### Imports and model specifications

In [1]:
import torch as th
import syft as sy
import torch.nn as nn
import torch.nn.functional as F
import grid as gr

# Hook PyTorch with PySyft and keep a handle on the hook's local worker.
hook = sy.TorchHook(th)
me = hook.local_worker
# NOTE(review): presumably turns off client-only mode so the local worker
# keeps objects in `me._objects`, which later cells read directly — confirm.
me.is_client_worker = False

In [2]:
# Grid client for the node served at localhost:3001, created with the id
# "Alice"; connect() is called right after construction.
alice = gr.WebsocketGridClient(hook, "http://localhost:3001", id="Alice")
alice.connect()

In [3]:
# Remaining grid clients, one per local node.
bob = gr.WebsocketGridClient(hook, "http://localhost:3000", id="Bob")
# NOTE(review): the variable is `charlie` but the id is "James" — presumably
# this matches the id the node on port 3002 was started with; confirm.
charlie = gr.WebsocketGridClient(hook, "http://localhost:3002", id="James")
dan = gr.WebsocketGridClient(hook, "http://localhost:3003", id="Dan")

# Connect only after all three clients exist, in the same order as before.
for _client in (bob, charlie, dan):
    _client.connect()

In [4]:
# NOTE(review): presumably links every listed grid node to the others so
# they can exchange data directly (needed for the sharing below) — confirm.
gr.connect_all_nodes([bob, alice, charlie, dan])

In [5]:
# Support fetch plan + AST tensor
# (AST presumably stands for AdditiveSharingTensor: the plan sent to Alice is
# converted with fix_prec() and shared between Bob and Charlie, with Dan
# passed as crypto provider.)

# Toggle: True builds the plan from a decorated function, False from a
# sy.Plan subclass wrapping a linear layer.
plan_func = False


if plan_func:
    # NOTE(review): despite "mult" in the name, this variant *adds* the bias
    # read from the plan state — confirm which operation is intended.
    @sy.func2plan(args_shape=[(1,)], state={"bias": th.tensor([3.0])})
    def plan_mult_3(x, state):
        """Return ``x`` plus the ``"bias"`` tensor read from the plan state."""
        bias = state.read("bias")
        return x + bias
else:
    class Net(sy.Plan):
        """Plan with a single 1 -> 1 linear layer registered in its state."""

        def __init__(self):
            super().__init__(id="net")
            self.fc1 = nn.Linear(1, 1)
            self.add_to_state(["fc1"])

        def forward(self, x):
            """Apply the linear layer to ``x``."""
            return self.fc1(x)

    plan_mult_3 = Net()
    # Build with an input of the same shape and dtype the plan is called with
    # below (a float tensor of shape (1,)); the previous 0-dim integer tensor
    # is not a valid input for nn.Linear(1, 1).
    plan_mult_3.build(th.tensor([1.]))

print(list(plan_mult_3.parameters()))
sent_plan = plan_mult_3.send(alice).fix_prec().share(bob, charlie, crypto_provider=dan)

# Fetch plan
fetched_plan = alice.fetch_plan(sent_plan.id)
x = th.tensor([1.])
x_ptr = x.fix_prec().share(bob, charlie, crypto_provider=dan)

# TODO: this should be stored automatically
me._objects[x_ptr.id] = x_ptr

[Parameter containing:
tensor([[-0.2847]], requires_grad=True), Parameter containing:
tensor([-0.3792], requires_grad=True)]


In [6]:
# TODO: this should be done internally
# The fetched plan is expected to carry exactly two state ids.
id0, id1 = fetched_plan.state_ids

# TODO: we should not have direct access to the weights
# Look each state tensor up in the local object store and share it the same
# way as the input; the generator yields them in (id0, id1) order.
a_sh, b_sh = (
    me._objects[state_id].fix_prec().share(bob, charlie, crypto_provider=dan)
    for state_id in (id0, id1)
)

# TODO: this should be stored automatically
me._objects[a_sh.id] = a_sh
me._objects[b_sh.id] = b_sh

In [7]:
# Ids of the shared state tensors, in the same order as the original state.
new_state_ids = [shared.id for shared in (a_sh, b_sh)]

In [8]:
# Point the fetched plan at the shared state tensors instead of the originals.
# NOTE(review): presumably replace_ids rewrites the old ids inside the
# recorded plan; state_ids is then updated by hand to match — confirm.
fetched_plan.replace_ids(fetched_plan.state_ids, new_state_ids)
fetched_plan.state_ids = new_state_ids

In [9]:
# Call the fetched plan on the shared input, then apply .get() and
# .float_prec() to the result before printing (presumably reconstructing the
# shares and converting back from fixed precision — confirm).
print(fetched_plan(x_ptr).get().float_prec())

tensor([-0.6630])


In [10]:
# Support fetching a plan
# (plain tensors here: nothing in this cell calls fix_prec() or share().)

# Toggle: True builds the plan from a decorated function, False from a
# sy.Plan subclass wrapping a linear layer.
plan_func = False

if plan_func:
    @sy.func2plan(args_shape=[(1,)], state={"bias": th.tensor([3.0])})
    def plan_mult_3(x, state):
        """Return ``x`` times the ``"bias"`` tensor read from the plan state."""
        bias = state.read("bias")
        return x * bias
else:
    class Net(sy.Plan):
        """Plan with a single 1 -> 1 linear layer registered in its state."""

        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(1, 1)
            self.add_to_state(["fc1"])

        def forward(self, x):
            """Apply the linear layer to ``x``."""
            return self.fc1(x)

    plan_mult_3 = Net()
    # Build with an input of the same shape and dtype the plan is called with
    # below (a float tensor of shape (1,)); the previous 0-dim integer tensor
    # is not a valid input for nn.Linear(1, 1).
    plan_mult_3.build(th.tensor([1.]))

sent_plan = plan_mult_3.send(alice)

print(sent_plan.id)

# Fetch plan
fetched_plan = alice.fetch_plan(sent_plan.id)

x = th.tensor([1.])
print(fetched_plan(x))

73459554738
tensor([-0.2022], requires_grad=True)
