### Imports and model specifications

In [1]:
# Import dependencies (PyTorch, PySyft, and the PyGrid client library)
import torch as th
import syft as sy
import torch.nn as nn
import torch.nn.functional as F
import grid as gr

# Hook PyTorch with PySyft and take the local ("me") worker.
hook = sy.TorchHook(th)
me = hook.local_worker
# NOTE(review): presumably required so `me` keeps the objects it creates/receives in
# `me._objects` (later cells read and write that registry directly) -- confirm for this syft version.
me.is_client_worker = False
    
# Connect to the four grid nodes, which are expected to be running locally on ports 3000-3003.
alice = gr.WebsocketGridClient(hook, "http://localhost:3001", id="Alice")
alice.connect()
bob = gr.WebsocketGridClient(hook, "http://localhost:3000", id="Bob")
# Note: the Python variable is `charlie`, but the node on port 3002 uses the worker id "James".
charlie = gr.WebsocketGridClient(hook, "http://localhost:3002", id="James")
dan = gr.WebsocketGridClient(hook, "http://localhost:3003", id="Dan")
bob.connect()
charlie.connect()
dan.connect()

# Ask the grid to connect the nodes to each other -- presumably needed for the
# Bob/Charlie/Dan secret-sharing used later; verify against the grid docs.
gr.connect_all_nodes([bob, alice, charlie, dan])

In [None]:
# Model Owner
# Support fetch plan + AST tensor
class Net(sy.Plan):
    """Small convolutional network packaged as a syft Plan with id "convnet".

    The three layers are registered as plan state via ``add_to_state`` so they
    are tracked together with the plan.
    """

    def __init__(self):
        super().__init__(id="convnet")
        self.conv1 = nn.Conv2d(3, 4, kernel_size=5, stride=1)
        # 3136 = 4 channels * 28 * 28: a 5x5, stride-1 conv turns a 32x32 image
        # into 28x28 maps, so this layer assumes 3x32x32 inputs.
        self.fc1 = nn.Linear(3136, 40)
        self.fc2 = nn.Linear(40, 1)
        self.add_to_state(["conv1", "fc1", "fc2"])

    def forward(self, x):
        """Map a batch of images to one score each, shaped ``(batch, 1)``."""
        feature_maps = F.relu(self.conv1(x))
        flattened = feature_maps.view(-1, 3136)
        hidden = F.relu(self.fc1(flattened))
        return self.fc2(hidden)

plan = Net()
    
# Build the plan with an all-zeros dummy input of the shape it will be run on,
# sanity-check it locally, then host it on Alice so it can be fetched by id ("convnet").
data_shape = (1, 3, 32, 32)
data = th.zeros(data_shape)
plan.build(data)
print(plan(data))
    
sent_plan = plan.send(alice)

tensor([[-0.0209]], grad_fn=<AddmmBackward>)


In [None]:
from IPython.display import display_html

def restart_kernel():
    """Restart this notebook's kernel from inside a cell.

    Emits a raw ``<script>`` tag that calls ``Jupyter.notebook.kernel.restart()``.
    NOTE(review): that is the classic Notebook front-end's JS global; in other
    front ends (e.g. JupyterLab) it is presumably unavailable, so restart manually there.
    """
    restart_script = "<script>Jupyter.notebook.kernel.restart()</script>"
    display_html(restart_script, raw=True)


restart_kernel()

In [1]:
# Import dependencies -- repeated because the kernel was just restarted; this
# session plays the data owner that fetches the hosted plan.
import torch as th
import syft as sy
import torch.nn as nn
import torch.nn.functional as F
import grid as gr

# Hook PyTorch with PySyft and take the local ("me") worker.
hook = sy.TorchHook(th)
me = hook.local_worker
# NOTE(review): presumably required so `me` keeps the objects it creates/receives in
# `me._objects` (later cells read and write that registry directly) -- confirm for this syft version.
me.is_client_worker = False
    
# Connect to the four grid nodes, which are expected to be running locally on ports 3000-3003.
alice = gr.WebsocketGridClient(hook, "http://localhost:3001", id="Alice")
alice.connect()
bob = gr.WebsocketGridClient(hook, "http://localhost:3000", id="Bob")
# Note: the Python variable is `charlie`, but the node on port 3002 uses the worker id "James".
charlie = gr.WebsocketGridClient(hook, "http://localhost:3002", id="James")
dan = gr.WebsocketGridClient(hook, "http://localhost:3003", id="Dan")
bob.connect()
charlie.connect()
dan.connect()

# Ask the grid to connect the nodes to each other -- presumably needed for the
# Bob/Charlie/Dan secret-sharing used below; verify against the grid docs.
gr.connect_all_nodes([bob, alice, charlie, dan])

In [2]:
# Fetch plan: download the "convnet" plan that the model owner hosted on Alice.
fetched_plan = alice.fetch_plan("convnet")
# Dummy all-zeros input with the same shape the model owner built the plan with.
data_shape = (1, 3, 32, 32)
data = th.zeros(data_shape)
# Encode as fixed precision and secret-share between Bob and Charlie,
# with Dan acting as the crypto provider.
x_ptr = data.fix_prec().share(bob, charlie, crypto_provider=dan)

# TODO: this should be stored automatically
# Register the shared input in the local object registry under its own id
# (presumably so plan execution can resolve it by id).
me._objects[x_ptr.id] = x_ptr

In [3]:
# Debug view of the local worker's object registry (large output); it includes
# entries for the fetched plan's state ids.
me._objects

{26521377524: (Wrapper)>[PointerTensor | me:56353713837 -> Alice:64965785327],
 56353713837: (Wrapper)>[PointerTensor | me:56353713837 -> Alice:64965785327],
 12032377909: (Wrapper)>[PointerTensor | me:69960128432 -> Alice:92969008696],
 69960128432: (Wrapper)>[PointerTensor | me:69960128432 -> Alice:92969008696],
 96250629554: (Wrapper)>[PointerTensor | me:715510378 -> Alice:84862814872],
 715510378: (Wrapper)>[PointerTensor | me:715510378 -> Alice:84862814872],
 30122723808: (Wrapper)>[PointerTensor | me:35644423231 -> Alice:1698717535],
 35644423231: (Wrapper)>[PointerTensor | me:35644423231 -> Alice:1698717535],
 18164706485: (Wrapper)>[PointerTensor | me:51752319607 -> Alice:20180829760],
 51752319607: (Wrapper)>[PointerTensor | me:51752319607 -> Alice:20180829760],
 67972504464: (Wrapper)>[PointerTensor | me:16002842495 -> Alice:94412804842],
 16002842495: (Wrapper)>[PointerTensor | me:16002842495 -> Alice:94412804842],
 97739968624: tensor(4611686018427387904),
 48332848142: (Wr

In [4]:
# TODO: this should be done internally
# Replace every plan state tensor with a secret-shared copy split between Bob and
# Charlie (Dan as crypto provider), collecting the ids of the new shared tensors.
new_state_ids = []
for state_id in fetched_plan.state_ids:
    # TODO: we should not have direct access to the weights
    # The registry entry for each state id is (per the `me._objects` listing) a
    # pointer to Alice; fix_prec/share are applied through it and .get() brings the
    # resulting shared tensor back to `me`.
    a_sh = me._objects[state_id].fix_prec().share(bob, charlie, crypto_provider=dan).get()
    # TODO: this should be stored automatically
    me._objects[a_sh.id] = a_sh
    new_state_ids.append(a_sh.id)

In [5]:
# State ids of the fetched plan, still the original (pre-sharing) ones at this point.
fetched_plan.state_ids

[56353713837, 69960128432, 715510378, 35644423231, 51752319607, 16002842495]

In [6]:
# Point the plan at the shared state: replace the old state ids with the new ones
# (presumably rewriting references inside the plan), then record the new id list.
# Order matters: replace_ids must still see the old `state_ids`.
fetched_plan.replace_ids(fetched_plan.state_ids, new_state_ids)
fetched_plan.state_ids = new_state_ids

In [7]:
%%time
# Run the plan on the secret-shared input, reconstruct the result (.get) and decode
# it from fixed precision (.float_prec). Expected to approximate the plaintext output
# the model owner printed for the same all-zeros input.
print(fetched_plan(x_ptr).get().float_prec())

tensor([[-0.0200]])
CPU times: user 16.7 s, sys: 4.87 s, total: 21.6 s
Wall time: 20.6 s


In [8]:
# Support fetching a plan: host a second, tiny plan on Alice under the id "net2".
# Set `plan_func = True` to build it from a function via sy.func2plan instead of a Plan subclass.
plan_func = False

if plan_func:
    @sy.func2plan(args_shape=[(1,)], state={"bias": th.tensor([3.0])})
    def plan_mult_3(x, state):
        """Multiply ``x`` by the ``bias`` value stored in the plan state (3.0)."""
        bias = state.read("bias")
        x = x * bias
        return x
else:
    class Net(sy.Plan):
        """Single ``nn.Linear(1, 1)`` layer wrapped as a syft Plan with id "net2"."""

        def __init__(self):
            super().__init__(id="net2")
            self.fc1 = nn.Linear(1, 1)
            self.add_to_state(["fc1"])

        def forward(self, x):
            return self.fc1(x)
    
    plan_mult_3 = Net()
    # Build with a float tensor of shape (1,). This matches `args_shape=[(1,)]` in the
    # func2plan branch and the `th.tensor([1.])` input the fetched plan is run on.
    # The previous `th.tensor(1)` was a 0-d int64 tensor, which nn.Linear cannot consume.
    plan_mult_3.build(th.tensor([1.0]))

sent_plan = plan_mult_3.send(alice)
print(sent_plan.id)

net2


In [None]:
from IPython.display import display_html

def restart_kernel():
    """Restart this notebook's kernel from inside a cell.

    Emits a raw ``<script>`` tag that calls ``Jupyter.notebook.kernel.restart()``.
    NOTE(review): that is the classic Notebook front-end's JS global; in other
    front ends (e.g. JupyterLab) it is presumably unavailable, so restart manually there.
    """
    restart_script = "<script>Jupyter.notebook.kernel.restart()</script>"
    display_html(restart_script, raw=True)


restart_kernel()

In [1]:
# Import dependencies -- repeated because the kernel was just restarted; this
# session fetches and runs the "net2" plan.
import torch as th
import syft as sy
import torch.nn as nn
import torch.nn.functional as F
import grid as gr

# Hook PyTorch with PySyft and take the local ("me") worker.
hook = sy.TorchHook(th)
me = hook.local_worker
# NOTE(review): presumably required so `me` keeps the objects it creates/receives in
# `me._objects` (later cells read and write that registry directly) -- confirm for this syft version.
me.is_client_worker = False
    
# Connect to the four grid nodes, which are expected to be running locally on ports 3000-3003.
alice = gr.WebsocketGridClient(hook, "http://localhost:3001", id="Alice")
alice.connect()
bob = gr.WebsocketGridClient(hook, "http://localhost:3000", id="Bob")
# Note: the Python variable is `charlie`, but the node on port 3002 uses the worker id "James".
charlie = gr.WebsocketGridClient(hook, "http://localhost:3002", id="James")
dan = gr.WebsocketGridClient(hook, "http://localhost:3003", id="Dan")
bob.connect()
charlie.connect()
dan.connect()

# Ask the grid to connect the nodes to each other -- presumably needed for
# inter-node traffic; verify against the grid docs.
gr.connect_all_nodes([bob, alice, charlie, dan])

In [9]:
# Fetch plan: download the "net2" plan that was hosted on Alice.
fetched_plan = alice.fetch_plan("net2")

In [10]:
# TODO: this should be done internally
# Unlike the "convnet" example, the state is not secret-shared here: each state
# tensor is retrieved with .get() (presumably from a pointer to Alice's copy) as plaintext.
new_state_ids = []
for state_id in fetched_plan.state_ids:
    # TODO: we should not have direct access to the weights
    # NOTE(review): despite its name, `a_sh` is a plain (unshared) tensor in this cell.
    a_sh = me._objects[state_id].get()
    # TODO: this should be stored automatically
    me._objects[a_sh.id] = a_sh
    new_state_ids.append(a_sh.id)

In [11]:
# Point the plan at the locally retrieved state: replace the old state ids with the new
# ones, then record the new id list. Order matters: replace_ids must still see the old ids.
fetched_plan.replace_ids(fetched_plan.state_ids, new_state_ids)
fetched_plan.state_ids = new_state_ids

In [12]:
# Run the fetched plan locally on a plaintext float input of shape (1,).
x = th.tensor([1.])
print(fetched_plan(x))

tensor([0.2805], requires_grad=True)
