In [1]:
import syft as sy
import torch as t
hook = sy.TorchHook(t)
from torch import nn, optim

In [2]:
bob = sy.VirtualWorker(hook,id="bob")
alice = sy.VirtualWorker(hook,id="alice")
secure_worker = sy.VirtualWorker(hook,id="secure_worker")

In [3]:
bob.add_workers([alice, secure_worker])
alice.add_workers([bob, secure_worker])
secure_worker.add_workers([bob, alice])



<VirtualWorker id:secure_worker #objects:0>

In [4]:
#Toy Dataset
data = t.tensor([[0,0],[0,1],[1,0],[1,1.]],requires_grad=True)
target = t.tensor([[0],[0],[1],[1.]], requires_grad=True)

In [5]:
#Sending data to workers
bobData = data[:2].send(bob)
bobTarget = target[2:].send(bob)

aliceData = data[2:].send(alice)
aliceTarget = target[2:].send(alice)

In [6]:
#Initialise a toy model
model = nn.Linear(2,1)

In [7]:
bobModel = model.copy().send(bob)
aliceModel = model.copy().send(alice)

In [8]:
#Diff workers diff models diff optimizers
bobOpt = optim.SGD(params=bobModel.parameters() , lr=0.1)
aliceOpt = optim.SGD(params=aliceModel.parameters() , lr=0.1)

In [9]:
bobOpt.zero_grad()
bobPred = bobModel(bobData)
bobLoss = ((bobPred - bobTarget) ** 2).sum()
bobLoss.backward()

bobOpt.step()
bobLoss = bobLoss.get().data
bobLoss
#run this cell again & again u'll observe loss's going down

tensor(3.4647)

In [10]:
aliceOpt.zero_grad()
alicePred = aliceModel(aliceData)
aliceLoss = ((alicePred - aliceTarget) ** 2).sum()
aliceLoss.backward()

aliceOpt.step()
aliceLoss = aliceLoss.get().data
aliceLoss
#run this cell again & again u'll observe loss's going down

tensor(5.1756)

# Let's Package

In [11]:
for round_iter in range(10):
    
    bobModel = model.copy().send(bob)
    aliceModel = model.copy().send(alice)
    
    bobOpt = optim.SGD(params=bobModel.parameters() , lr=0.1)
    aliceOpt = optim.SGD(params=aliceModel.parameters() , lr=0.1)
    
    #Epochs
    for i in range(10):
        bobOpt.zero_grad()
        bobPred = bobModel(bobData)
        bobLoss = ((bobPred - bobTarget) ** 2).sum()
        bobLoss.backward()
        
        bobOpt.step()
        bobLoss = bobLoss.get().data
        bobLoss
        
        aliceOpt.zero_grad()
        alicePred = aliceModel(aliceData)
        aliceLoss = ((alicePred - aliceTarget) ** 2).sum()
        aliceLoss.backward()
        
        aliceOpt.step()
        aliceLoss = aliceLoss.get().data
        aliceLoss
        
    #Here comes the role of our secure_worker
    #Trusted Aggregator
    
    aliceModel.move(secure_worker)
    bobModel.move(secure_worker)
    
    with t.no_grad():
        #Weight Aggregation while they are with Trusted aggregator
        model.weight.set_(((aliceModel.weight.data + bobModel.weight.data) / 2).get())

        #Bias Aggregation while they are with Trusted aggregator
        model.bias.set_(((aliceModel.bias.data + bobModel.bias.data) / 2).get())
    
    print('Bob Loss: {} \t Alice Loss: {}'.format(bobLoss,aliceLoss))

Bob Loss: 0.002106368774548173 	 Alice Loss: 0.0006552091799676418
Bob Loss: 0.001269206521101296 	 Alice Loss: 0.0002336803008802235
Bob Loss: 0.0004387546214275062 	 Alice Loss: 7.13843764970079e-05
Bob Loss: 0.00016431546828243881 	 Alice Loss: 7.575537892989814e-06
Bob Loss: 7.452676072716713e-05 	 Alice Loss: 5.521442858480441e-09
Bob Loss: 4.05733662773855e-05 	 Alice Loss: 9.931389968187432e-07
Bob Loss: 2.525991112634074e-05 	 Alice Loss: 1.7968086467590183e-06
Bob Loss: 1.7116466551669873e-05 	 Alice Loss: 1.8905408296632231e-06
Bob Loss: 1.2177690223325044e-05 	 Alice Loss: 1.6573189896007534e-06
Bob Loss: 8.897986845113337e-06 	 Alice Loss: 1.3427068097371375e-06
