In [1]:
import torch
import syft as sy
from torch import nn, optim
import torch.nn.functional as f



In [2]:
class network(nn.Module):
    """Small two-layer MLP: Linear -> ReLU -> Linear.

    The defaults reproduce the original 2 -> 2 -> 1 architecture, so
    ``network()`` behaves exactly as before. The output has no final
    activation; the training loop applies ``torch.sigmoid`` itself.

    Args:
        in_features: number of features in each input row.
        hidden_features: width of the hidden layer.
        out_features: number of outputs per row.
    """

    def __init__(self, in_features=2, hidden_features=2, out_features=1):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, out_features)

    def forward(self, x):
        """Map ``(..., in_features)`` input to ``(..., out_features)`` raw scores."""
        hidden = f.relu(self.fc1(x))
        return self.fc2(hidden)

# Seed torch's global RNG before building the model so the randomly
# initialised weights are the same on every Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
final_model = network()

In [3]:
# Hook torch with PySyft, then create the three virtual workers used below:
# two data owners (bob, alice) and the aggregator that averages their models.
hook = sy.TorchHook(torch)

bob, alice, trusted_aggregrator = (
    sy.VirtualWorker(hook, id=worker_id)
    for worker_id in ("bob", "alice", "trusted_aggregrator")
)

In [4]:
# Toy dataset: four samples with two float features each, and one
# binary label per sample stored as a (4, 1) column.
x = torch.tensor([[1.0, 2.0], [2.0, 3.0], [4.0, 5.0], [2.0, 4.0]])
y = torch.tensor([1.0, 0.0, 0.0, 1.0]).view(-1, 1)

In [5]:
# Split the toy dataset between the two data owners:
# rows 0-1 are sent to bob, rows 2-3 are sent to alice.
bob_x, bob_y = x[0:2].send(bob), y[0:2].send(bob)
alice_x, alice_y = x[2:].send(alice), y[2:].send(alice)

# Each worker receives its own copy of the same global model.
model_bob = final_model.copy().send(bob)
model_alice = final_model.copy().send(alice)

LEARNING_RATE = 0.003  # SGD step size shared by both workers
bob_optim = optim.SGD(params=model_bob.parameters(), lr=LEARNING_RATE)
alice_optim = optim.SGD(params=model_alice.parameters(), lr=LEARNING_RATE)

# Parallel per-worker lists (index 0 = bob, index 1 = alice).
data_x, data_y = [bob_x, alice_x], [bob_y, alice_y]
models, optims = [model_bob, model_alice], [bob_optim, alice_optim]

In [6]:
# Train each worker's model on that worker's own shard of the data.
# The loop variables are deliberately NOT named `x`/`y`: reusing those names
# would overwrite the global full dataset with remote pointers, so re-running
# the data-splitting cell afterwards would silently slice the wrong tensors.
N_EPOCHS = 50           # one loss printout per epoch
STEPS_PER_EPOCH = 1000  # full-batch SGD steps between printouts

for worker_x, worker_y, worker_model, worker_optim in zip(data_x, data_y, models, optims):
    for epoch in range(N_EPOCHS):
        for _ in range(STEPS_PER_EPOCH):
            worker_optim.zero_grad()
            pred = worker_model(worker_x)
            # The model outputs raw scores, so squash them before BCE.
            loss = f.binary_cross_entropy(torch.sigmoid(pred), worker_y)
            loss.backward()
            worker_optim.step()
        # Fetch only the last scalar loss of the epoch for logging.
        print("Model: " + str(worker_x.location) + " Loss: " + str(loss.get().data))
        

Model: <VirtualWorker id:bob #tensors:8> Loss: tensor(0.6910)
Model: <VirtualWorker id:bob #tensors:8> Loss: tensor(0.6816)
Model: <VirtualWorker id:bob #tensors:8> Loss: tensor(0.6687)
Model: <VirtualWorker id:bob #tensors:8> Loss: tensor(0.6460)
Model: <VirtualWorker id:bob #tensors:8> Loss: tensor(0.6069)
Model: <VirtualWorker id:bob #tensors:8> Loss: tensor(0.5471)
Model: <VirtualWorker id:bob #tensors:8> Loss: tensor(0.4697)
Model: <VirtualWorker id:bob #tensors:8> Loss: tensor(0.3858)
Model: <VirtualWorker id:bob #tensors:8> Loss: tensor(0.3097)
Model: <VirtualWorker id:bob #tensors:8> Loss: tensor(0.2479)
Model: <VirtualWorker id:bob #tensors:8> Loss: tensor(0.2004)
Model: <VirtualWorker id:bob #tensors:8> Loss: tensor(0.1648)
Model: <VirtualWorker id:bob #tensors:8> Loss: tensor(0.1375)
Model: <VirtualWorker id:bob #tensors:8> Loss: tensor(0.1166)
Model: <VirtualWorker id:bob #tensors:8> Loss: tensor(0.1009)
Model: <VirtualWorker id:bob #tensors:8> Loss: tensor(0.0880)
Model: <

In [7]:
# Global model parameters before aggregation, in registration order:
# fc1.weight, fc1.bias, fc2.weight, fc2.bias.
for param in final_model.parameters():
    print(param.data)

tensor([[-0.0929, -0.0340],
        [ 0.4715, -0.4788]])
tensor([0.1327, 0.5191])
tensor([[0.3127, 0.3289]])
tensor([0.3383])


In [8]:
# Federated averaging: move every worker model to the trusted aggregator,
# average each parameter there, and copy only the averaged tensor into
# final_model. The individual worker models are never retrieved locally.
worker_models = [model_bob, model_alice]
for worker_model in worker_models:
    worker_model.move(trusted_aggregrator)

# All worker models are copies of final_model, so their parameters line up
# one-to-one with final_model.parameters() in registration order.
for global_param, *worker_params in zip(
    final_model.parameters(), *(m.parameters() for m in worker_models)
):
    # Sum the remote tensors (all on trusted_aggregrator), divide by the
    # number of workers, and fetch just the mean.
    param_sum = worker_params[0].data
    for param in worker_params[1:]:
        param_sum = param_sum + param.data
    global_param.data = (param_sum / len(worker_params)).get()

In [9]:
# Global model parameters after federated averaging, one tensor per line:
# fc1.weight, fc1.bias, fc2.weight, fc2.bias.
print(*(param.data for param in final_model.parameters()), sep="\n")

tensor([[-0.0929, -0.0340],
        [-0.8923, -0.1419]])
tensor([0.1327, 2.2199])
tensor([[0.3127, 2.3920]])
tensor([-1.8908])
