# Basic Federated Learning Example

In this notebook, we show how PySyft can be used to train a model with Federated Learning. We train a simple linear model with stochastic gradient descent on a toy dataset that is split between two owners, Alice and Bob.

This notebook is also available on Colab: https://colab.research.google.com/drive/1F3ALlA3ogfeeVXuwQwVoX4PimzTDJhPy

This notebook was created at the end of the following video: https://www.twitch.tv/videos/275910219


# Step 1: Hook PyTorch and Create Workers

```python
from syft.core.hooks import TorchHook
from syft.core.workers import VirtualWorker
import torch
import torch.nn as nn
from torch.autograd import Variable as Var
import torch.optim as optim

# this is our hook: it overloads PyTorch so tensors and models gain methods such as .send()
hook = TorchHook()
me = hook.local_worker

# two simulated remote workers that live in this same Python process
bob = VirtualWorker(id=1, hook=hook)
alice = VirtualWorker(id=2, hook=hook)

# register both workers with the local worker
me.add_worker(bob)
me.add_worker(alice)
```

```
Hooking into Torch...
Overloading complete.
```
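The hook and the workers are what allow a tensor to be sent away and then referred to remotely. As a rough mental model only, the plain-Python sketch below imitates a worker that stores objects by id and a pointer that can fetch its object back. `ToyWorker`, `ToyPointer` and `send` are hypothetical names written for illustration; they are not PySyft's classes and leave out the messaging and serialization a real framework needs.

```python
from dataclasses import dataclass, field
from itertools import count
from typing import Any, Dict

# Hypothetical stand-ins for illustration only; NOT PySyft classes.
_ids = count()


@dataclass
class ToyWorker:
    name: str
    # objects held by this worker, keyed by id; the caller only ever holds ids
    store: Dict[int, Any] = field(default_factory=dict)


@dataclass
class ToyPointer:
    location: ToyWorker
    obj_id: int

    def get(self) -> Any:
        # in this toy version, fetching moves the object back and deletes the remote copy
        return self.location.store.pop(self.obj_id)


def send(obj: Any, worker: ToyWorker) -> ToyPointer:
    """Place obj in the worker's store and return a pointer that refers to it by id."""
    obj_id = next(_ids)
    worker.store[obj_id] = obj
    return ToyPointer(worker, obj_id)


bob = ToyWorker("bob")
ptr = send([[0, 0], [0, 1]], bob)
print(ptr.location.name, len(bob.store))  # bob 1
print(ptr.get(), len(bob.store))          # [[0, 0], [0, 1]] 0
```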


# Step 2: Initialize Dataset and Move to Alice and Bob

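```python
# create our dataset: four examples with two features each
data = Var(torch.FloatTensor([[0, 0], [0, 1], [1, 0], [1, 1]]))
target = Var(torch.FloatTensor([[0], [0], [1], [1]]))

# send the first two examples to Bob; what we keep locally now points to Bob's copy
data_bob = data[0:2].send(bob)
target_bob = target[0:2].send(bob)

# send the last two examples to Alice
data_alice = data[2:].send(alice)
target_alice = target[2:].send(alice)
```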
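The split is easy to check by hand. In this toy dataset the label always equals the first input feature, so a linear model with weights (1, 0) and bias 0 fits all four points exactly, yet Bob's half contains only label 0 and Alice's half only label 1. The sketch below restates the data and the row split as plain Python lists, with no PySyft involved, to make that visible.

```python
# The same four examples as above, as plain Python lists.
data = [[0, 0], [0, 1], [1, 0], [1, 1]]
target = [[0], [0], [1], [1]]

# The label always equals the first feature, so y = 1*x0 + 0*x1 + 0 fits exactly.
assert all(t[0] == x[0] for x, t in zip(data, target))

# Row-wise split mirroring data[0:2] -> Bob and data[2:] -> Alice.
bob_data, bob_target = data[0:2], target[0:2]
alice_data, alice_target = data[2:], target[2:]

print(bob_target)    # [[0], [0]]
print(alice_target)  # [[1], [1]]
```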

# Step 3: Create our Model

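```python
# create our model: a linear layer mapping 2 input features to 1 output
model = nn.Linear(2, 1)

# plain stochastic gradient descent with learning rate 0.1
opt = optim.SGD(params=model.parameters(), lr=0.1)
```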
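`nn.Linear(2, 1)` computes `pred = x · w + b` for each row, and with no momentum or weight decay `optim.SGD` updates every parameter as `p = p - lr * grad`. Because the training loop in Step 4 uses the summed squared error, the gradients are `sum(2 * r * x)` for the weights and `sum(2 * r)` for the bias, where `r = pred - target`. The plain-Python sketch below works one such step by hand on Alice's two examples. It assumes an all-zero starting model purely so the numbers are easy to follow (the real `nn.Linear` starts from a random initialization), and `linear` and `sgd_step` are hypothetical helper names.

```python
from typing import List, Tuple


def linear(x: List[float], w: List[float], b: float) -> float:
    """Output of a 2-in, 1-out linear layer: w . x + b."""
    return sum(wj * xj for wj, xj in zip(w, x)) + b


def sgd_step(xs: List[List[float]], ys: List[float], w: List[float], b: float,
             lr: float = 0.1) -> Tuple[List[float], float, float]:
    """One plain SGD step on the summed squared error; returns (new_w, new_b, loss)."""
    residuals = [linear(x, w, b) - y for x, y in zip(xs, ys)]
    loss = sum(r * r for r in residuals)
    # d/dw_j sum_i r_i^2 = sum_i 2 * r_i * x_ij   and   d/db = sum_i 2 * r_i
    grad_w = [sum(2 * r * x[j] for r, x in zip(residuals, xs)) for j in range(len(w))]
    grad_b = sum(2 * r for r in residuals)
    new_w = [wj - lr * gj for wj, gj in zip(w, grad_w)]
    new_b = b - lr * grad_b
    return new_w, new_b, loss


# Alice's half of the data, starting from an assumed all-zero model.
w, b, loss = sgd_step([[1, 0], [1, 1]], [1, 1], w=[0.0, 0.0], b=0.0)
print(loss, w, b)  # 2.0 [0.4, 0.2] 0.4
```

Both of Alice's predictions start at 0 against a target of 1, so every gradient is negative and the step raises both weights and the bias.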

# Step 4: Train over Distributed Dataset

(This is the Federated Learning part: rather than gathering the data in one place, we send the model to each owner's data in turn.)
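In equation form, with learning rate $\eta = 0.1$ and $D_k$ the examples held by owner $k$, each round of the loop visits Bob and then Alice and applies one gradient step computed only on that owner's data:

$$
L_k(w, b) = \sum_{i \in D_k} \left( w^\top x_i + b - y_i \right)^2,
\qquad
(w, b) \leftarrow (w, b) - \eta \left( \nabla_w L_k,\ \frac{\partial L_k}{\partial b} \right),
\qquad k \in \{\text{Bob}, \text{Alice}\}
$$

$$
\nabla_w L_k = \sum_{i \in D_k} 2 \left( w^\top x_i + b - y_i \right) x_i,
\qquad
\frac{\partial L_k}{\partial b} = \sum_{i \in D_k} 2 \left( w^\top x_i + b - y_i \right)
$$

This is sequential SGD across owners: a single model is passed from worker to worker, and no averaging of separately trained models takes place (unlike Federated Averaging).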

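```python
# each entry pairs an owner's input pointer with the matching target pointer
datasets = [(data_bob, target_bob), (data_alice, target_alice)]

for epoch in range(100):

    for data, target in datasets:
        # send the model to the worker that owns this batch
        model.send(data.owners[0])

        # update the model on that worker
        model.zero_grad()
        pred = model(data)
        loss = ((pred - target)**2).sum()
        loss.backward()
        opt.step()

        # bring the updated model back (in place), then fetch and print the loss;
        # .data[0] is the PyTorch 0.3-era way to read a scalar (newer PyTorch uses .item())
        model.get_()
        print(loss.get().data[0])
```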

```
0.20391982793807983
5.428705215454102
1.0745136737823486
0.5397716164588928
0.7883853912353516
0.26717209815979004
0.5002551078796387
0.1635095477104187
0.3177748918533325
0.10324167460203171
0.20411884784698486
0.06646480411291122
0.1327308565378189
0.043627142906188965
0.08740168064832687
0.029200267046689987
0.058281492441892624
0.019915681332349777
0.03934536501765251
0.013822791166603565
0.026877785101532936
0.009744998067617416
0.018566446378827095
0.006963507272303104
0.012957894243299961
0.0050324732437729836
0.009128624573349953
0.0036705918610095978
0.006485156714916229
0.002696977462619543
0.004641538951545954
0.001993006793782115
0.003343722550198436
0.0014792672591283917
0.0024224475491791964
0.0011016016360372305
0.0017635883996263146
0.0008223712211474776
0.001289311214350164
0.0006150172557681799
0.0009459602297283709
0.00046053665573708713
0.0006961692706681788
0.000345170235959813
0.0005136716645210981
0.0002588584029581398
0.00037985501694492996
0.0001942042144946754
```
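The printed losses alternate between Bob's batch and Alice's batch, one value per owner per round, and both fall steadily towards zero. The only parameters that give zero loss on all four examples are w = (1, 0) and b = 0, so the shrinking losses indicate the model is approaching that solution. To reproduce this behaviour without PySyft, the plain-Python sketch below keeps each owner's examples in a separate entry and only passes the parameters `(w, b)` between them. It assumes an all-zero starting model instead of `nn.Linear`'s random one, so its loss values will not match the output above; `owners`, `local_step` and `history` are hypothetical names.

```python
from typing import Dict, List, Tuple

# Each owner's examples stay in their own entry; only (w, b) moves between them.
owners: Dict[str, Tuple[List[List[float]], List[float]]] = {
    "bob": ([[0, 0], [0, 1]], [0, 0]),
    "alice": ([[1, 0], [1, 1]], [1, 1]),
}


def local_step(xs: List[List[float]], ys: List[float], w: List[float], b: float,
               lr: float = 0.1) -> Tuple[List[float], float, float]:
    """One SGD step on the summed squared error, run 'at' the owner holding xs, ys."""
    residuals = [sum(wj * xj for wj, xj in zip(w, x)) + b - y for x, y in zip(xs, ys)]
    loss = sum(r * r for r in residuals)
    grad_w = [sum(2 * r * x[j] for r, x in zip(residuals, xs)) for j in range(len(w))]
    grad_b = sum(2 * r for r in residuals)
    return [wj - lr * g for wj, g in zip(w, grad_w)], b - lr * grad_b, loss


w, b = [0.0, 0.0], 0.0  # assumed all-zero starting point
history = []            # (owner, loss) for every local step, in visiting order
for epoch in range(100):
    for name, (xs, ys) in owners.items():  # visit Bob, then Alice
        w, b, loss = local_step(xs, ys, w, b)
        history.append((name, loss))

print([round(v, 4) for v in w], round(b, 4))
```

With this starting point Bob's first loss is exactly 0 (the all-zero model already predicts his labels), so the first real update comes from Alice's batch.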