# Basic Federated Learning Example

In this notebook, we show how PySyft can be used to train a model using Federated Learning. We train a simple linear model with stochastic gradient descent on a toy dataset that is split between two different owners, Alice and Bob.

In [1]:
from syft.core.hooks import TorchHook
from syft.core.workers import VirtualWorker
import torch
import torch.nn as nn
from torch.autograd import Variable as Var
import torch.optim as optim
# Install the PySyft hook into torch (it prints the "Hooking into Torch..."
# messages below); `me` is the hook's local worker, i.e. this notebook.
hook = TorchHook()
me = hook.local_worker

# Two simulated data owners, each a VirtualWorker bound to the same hook.
bob = VirtualWorker(id=1, hook=hook)
alice = VirtualWorker(id=2, hook=hook)

# Make both owners known to the local worker (Bob first, then Alice).
me.add_worker(bob)
me.add_worker(alice)

Hooking into Torch...
Overloading complete.


In [2]:
# Toy dataset: four 2-feature inputs whose target equals the first feature.
data = Var(torch.FloatTensor([[0, 0],
                              [0, 1],
                              [1, 0],
                              [1, 1]]))
target = Var(torch.FloatTensor([[0], [0], [1], [1]]))

# Partition the rows between the two owners: Bob holds rows 0-1 and
# Alice holds rows 2-3. `.send` ships each slice to that owner's worker;
# the returned objects are what the training loop works with locally.
data_bob, target_bob = data[:2].send(bob), target[:2].send(bob)
data_alice, target_alice = data[2:].send(alice), target[2:].send(alice)

In [4]:
# Create our model: one linear layer mapping 2 input features to 1 output.
# Its weights are randomly initialised, so seed torch first; otherwise
# Restart & Run All produces different losses every time.
torch.manual_seed(0)
model = nn.Linear(2, 1)

In [6]:
# Plain SGD over the model's parameters, learning rate 0.1.
# NOTE(review): built once, before the model is ever sent to a worker; this
# presumably relies on model.send / model.get_ keeping the same Parameter
# objects - confirm against the PySyft version in use.
opt = optim.SGD(model.parameters(), lr=0.1)

In [11]:
# One (inputs, targets) pair per data owner, visited in this order (Bob, then Alice).
datasets = [
    (data_bob, target_bob),
    (data_alice, target_alice),
]

In [None]:
# Federated training: on every epoch the model visits each owner in turn,
# takes one SGD step on that owner's slice of data on the owner's worker,
# and is then brought back before moving on to the next owner.
N_EPOCHS = 100  # passes over all owners

# Loop names are deliberately distinct: the original `iter` shadowed the
# builtin, and `data`/`target` clobbered the full dataset from the dataset cell.
for epoch in range(N_EPOCHS):
    for owner_data, owner_target in datasets:
        # Ship the model to the worker that holds this slice of data.
        model.send(owner_data.owners[0])

        # One gradient step computed where the data lives (sum of squared errors).
        model.zero_grad()
        pred = model(owner_data)
        loss = ((pred - owner_target) ** 2).sum()
        loss.backward()
        opt.step()

        # Pull the updated model back to the local worker.
        model.get_()
        # The loss was computed remotely; .get() retrieves it so it can be printed.
        print(loss.get().data[0])

0.2021646648645401
0.13839808106422424
0.11899920552968979
0.08499595522880554
0.07074058055877686
0.05278114974498749
0.042621225118637085
0.033168911933898926
0.02611139416694641
0.021115796640515327
0.016316626220941544
0.013629190623760223
0.010426655411720276
0.008923706598579884
0.006824550684541464
0.0059279147535562515
0.004576891660690308
0.003994321916252375
0.003141997382044792
0.0027284545358270407
0.002203211886808276
0.0018877169350162148
0.0015735100023448467
0.0013213929487392306
0.001140911364927888
0.0009346912847831845
0.0008372107404284179
0.0006672536255791783
0.0006199877243489027
0.00048013869673013687
0.0004622222622856498
0.0003478446160443127
0.0003462526365183294
0.0002534514351282269
0.000260226777754724
0.00018556122086010873
0.00019598835206124932
0.00013639700773637742
0.0001477954792790115
0.00010059244959848002
0.00011152803926961496
7.438792090397328e-05
8.418006473220885e-05
5.5134201829787344e-05
6.353509525069967e-05
4.093891038792208e-05
4.79417831