In [1]:
# All imports first, then one-time setup, so the cell reads top-down.
import syft as sy
import torch as th
from torch import nn, optim

# Hook torch with PySyft; later cells rely on the hooked methods
# (.send, .get, .copy, .move) on tensors and modules.
hook = sy.TorchHook(th)

W0710 18:55:05.340134  4052 secure_random.py:22] Falling back to insecure randomness since the required custom op could not be found for the installed version of TensorFlow (1.14.0). Fix this by compiling custom ops.
W0710 18:55:05.399016  4052 deprecation_wrapper.py:119] From e:\anaconda3\envs\pysyft\lib\site-packages\tf_encrypted\session.py:28: The name tf.Session is deprecated. Please use tf.compat.v1.Session instead.



In [2]:
# Two data owners plus a third party that will hold their models for averaging.
bob, alice, secure_worker = [
    sy.VirtualWorker(hook, id=worker_id)
    for worker_id in ("bob", "alice", "secure_worker")
]

In [3]:
# Toy dataset: four samples with two binary features each. The target equals
# the first feature, so a single linear layer can fit it exactly.
# Inputs and labels are constants: nothing in this notebook reads their .grad,
# so they are created without requires_grad (only the model's parameters are
# optimised). The dtype is stated explicitly instead of relying on a float literal.
data = th.tensor([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=th.float32)
target = th.tensor([[0], [0], [1], [1]], dtype=th.float32)

In [4]:
# Rows 0-1 of the dataset are sent to bob (data first, then target).
bobs_data, bobs_target = (tensor[:2].send(bob) for tensor in (data, target))

In [5]:
# The remaining rows (2-3) are sent to alice (data first, then target).
alices_data, alices_target = (tensor[2:].send(alice) for tensor in (data, target))

In [6]:
# Seed the RNG so the initial weights, and therefore the training run,
# are reproducible on Restart & Run All.
SEED = 0
th.manual_seed(SEED)

# Global model: a single linear layer mapping 2 input features to 1 output.
model = nn.Linear(2, 1)

In [9]:
# Federated averaging: every round, each data owner trains its own copy of the
# global `model` on its local shard. The two copies are then moved to
# `secure_worker`, averaged there, and the average becomes the new global model.
N_ROUNDS = 10       # number of averaging rounds
N_LOCAL_STEPS = 10  # SGD steps each owner takes per round
LR = 0.1            # local SGD learning rate


def train_on_worker(worker, data_ptr, target_ptr, n_steps=N_LOCAL_STEPS, lr=LR):
    """Train a fresh copy of the global `model` on `worker` with plain SGD.

    Args:
        worker: the worker holding ``data_ptr`` and ``target_ptr``.
        data_ptr: pointer to that worker's input tensor.
        target_ptr: pointer to that worker's label tensor.
        n_steps: number of local SGD steps to run.
        lr: SGD learning rate.

    Returns:
        ``(remote_model, last_loss)``: the trained copy, still living on
        ``worker``, and the sum-of-squared-errors loss of the final step,
        fetched to the local machine (``None`` if ``n_steps`` is 0).
    """
    remote_model = model.copy().send(worker)
    opt = optim.SGD(params=remote_model.parameters(), lr=lr)
    last_loss = None
    for _ in range(n_steps):
        opt.zero_grad()
        pred = remote_model(data_ptr)
        loss = ((pred - target_ptr) ** 2).sum()
        loss.backward()
        opt.step()
        # Pull the scalar loss back locally for reporting.
        last_loss = loss.get().data
    return remote_model, last_loss


for iter_round in range(N_ROUNDS):
    bob_model, bob_loss = train_on_worker(bob, bobs_data, bobs_target)
    alice_model, alice_loss = train_on_worker(alice, alices_data, alices_target)

    # Hand both trained copies to the third party; the average is computed there.
    alice_model.move(secure_worker)
    bob_model.move(secure_worker)

    with th.no_grad():
        # The sums run on secure_worker; only the averaged tensors are pulled back.
        avg_weight = ((alice_model.weight.data + bob_model.weight.data) / 2).get()
        avg_bias = ((alice_model.bias.data + bob_model.bias.data) / 2).get()
        # Overwrite the global model's parameters in place with the average.
        model.weight.set_(avg_weight)
        model.bias.set_(avg_bias)

    # Clear out the third party before the next round.
    secure_worker.clear_objects()

    print("Bob: ", str(bob_loss), "Alice: ", str(alice_loss))

Bob:  tensor(0.0001) Alice:  tensor(2.0943e-05)
Bob:  tensor(9.1351e-05) Alice:  tensor(1.5589e-05)
Bob:  tensor(6.8902e-05) Alice:  tensor(1.1665e-05)
Bob:  tensor(5.1909e-05) Alice:  tensor(8.7527e-06)
Bob:  tensor(3.9083e-05) Alice:  tensor(6.5765e-06)
Bob:  tensor(2.9418e-05) Alice:  tensor(4.9444e-06)
Bob:  tensor(2.2139e-05) Alice:  tensor(3.7191e-06)
Bob:  tensor(1.6660e-05) Alice:  tensor(2.7978e-06)
Bob:  tensor(1.2536e-05) Alice:  tensor(2.1050e-06)
Bob:  tensor(9.4330e-06) Alice:  tensor(1.5838e-06)
