In [153]:
#imports
import numpy as np
import torch as th
import syft as sy

In [154]:
#creation of a hook
hook = sy.TorchHook(th)



In [155]:
#imports the optimizer 
from torch import nn, optim


In [156]:
#create the workers
bob = sy.VirtualWorker(hook,id = "bob")
alice = sy.VirtualWorker(hook,id= "alice")
secure_worker = sy.VirtualWorker(hook, id = "secure_worker")

In [157]:
#make the secure workers 
bob.add_workers([alice, secure_worker])
alice.add_workers([bob,secure_worker])
secure_worker.add_workers([alice, bob])



<VirtualWorker id:secure_worker #objects:12>

In [158]:
#make the toy dataset
data = th.tensor([[0,0],[0,1],[1,0],[1,1.]], requires_grad=True)
target = th.tensor([[0],[0],[1],[1.]], requires_grad=True)

In [159]:
# get the data  from the training data on each of the worker 
# sending the trainng data to bob and alice
bobs_data = data[0:2].send(bob)
bobs_target = data[0:2].send(bob)

In [160]:
#send the data to alice
alices_data = data[2:].send(alice)
alices_target = target[2:].send(alice)


In [161]:
#initilaize a toy model 
model = nn.Linear(2,1)


In [165]:
iterations = 10
worker_iters = 5

for a_iter in range(iterations):

    bobs_model = model.copy().send(bob)
    alices_model = model.copy().send(alice)

    bobs_opt = optim.SGD(params=bobs_model.parameters(), lr=0.1)
    alices_opt = optim.SGD(params=alices_model.parameters(), lr=0.1)

    for wi in range(worker_iters):
        # Train Bob's Model
        bobs_opt.zero_grad()
        bobs_pred = bobs_model(bobs_data)
        bobs_loss = ((bobs_pred - bobs_target) ** 2).sum()
        bobs_loss.backward()

        bobs_opt.step()
        bobs_loss = bobs_loss.get().data

        # Train Alice's Model
        alices_opt.zero_grad()
        alices_pred = alices_model(alices_data)
        alices_loss = ((alices_pred - alices_target) ** 2).sum()
        alices_loss.backward()

        alices_opt.step()
        alices_loss = alices_loss.get().data

    alices_model.move(secure_worker)
    bobs_model.move(secure_worker)

    with th.no_grad():

        model.weight.set_(((alices_model.weight.data + bobs_model.weight.data) / 2).get())
        model.bias.set_(((alices_model.bias.data + bobs_model.bias.data) / 2).get())
    
    print("Bob:" + str(bobs_loss) + " Alice:" + str(alices_loss))

Bob:tensor(0.5073) Alice:tensor(0.0208)
Bob:tensor(0.5072) Alice:tensor(0.0207)
Bob:tensor(0.5072) Alice:tensor(0.0207)
Bob:tensor(0.5072) Alice:tensor(0.0207)
Bob:tensor(0.5072) Alice:tensor(0.0207)
Bob:tensor(0.5072) Alice:tensor(0.0206)
Bob:tensor(0.5072) Alice:tensor(0.0206)
Bob:tensor(0.5071) Alice:tensor(0.0206)
Bob:tensor(0.5071) Alice:tensor(0.0206)
Bob:tensor(0.5071) Alice:tensor(0.0206)
