# Part 4: Federated Learning with Remote Gradient Averaging

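In Part 2 of this tutorial, we trained a model using a very simple version of Federated Learning. In this part, we start from two identical copies of a model, compute gradients remotely on two workers (bob and alice), and inspect the resulting remote gradient pointers as a first step toward averaging them.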

In [1]:
import syft as sy
import copy
# Hook torch so tensors and modules gain the remote-execution methods used in
# this notebook (e.g. .send(worker) / .get()).
# NOTE(review): exact patching behavior depends on the installed syft version; confirm.
hook = sy.TorchHook()
from torch import nn, optim

In [15]:
# create a couple workers

bob = sy.VirtualWorker(id="bob")
alice = sy.VirtualWorker(id="alice")

# A Toy Dataset
data = sy.Var(sy.FloatTensor([[0,0],[0,1],[1,0],[1,1]]))
target = sy.Var(sy.FloatTensor([[0],[0],[1],[1]]))

# get pointers to training data on each worker by
# sending some training data to bob and alice
data_bob = data[0:2].send(bob)
target_bob = target[0:2].send(bob)

data_alice = data[2:].send(alice)
target_alice = target[2:].send(alice)

# organize pointers into a list
datasets = [(data_bob,target_bob),(data_alice,target_alice)]

# Initialize a toy model and an SGD optimizer over its parameters.
LR = 0.1  # learning rate shared by both optimizers
model = nn.Linear(2,1)
opt = optim.SGD(params=model.parameters(), lr=LR)

# Second copy of the model with identical initial weights, for the second worker.
model2 = copy.deepcopy(model)
# Build opt2 over model2's *own* parameters. copy.deepcopy(opt) would deep-copy
# the parameter tensors stored in opt.param_groups, producing an optimizer that
# tracks neither model's nor model2's parameters, so opt2.zero_grad()/step()
# would silently never touch model2.
opt2 = optim.SGD(params=model2.parameters(), lr=LR)



In [16]:
# Worker 1 (bob): send `model` to bob and run one forward/backward pass there.
# `model` is not retrieved with .get() afterwards, so its parameters (and their
# gradients) remain pointers to bob, as the cells below display.
# Note: this rebinds the notebook-level `data`/`target` (the full toy dataset)
# to bob's data pointers.
data, target = datasets[0]
model.send(data.location)
opt.zero_grad()
pred = model(data)
loss = ((pred - target)**2).sum()  # sum of squared errors over bob's batch
loss.backward()

In [17]:
# Worker 2 (alice): same as the previous cell, but with the second model copy
# `model2` and alice's data. `model2` is not retrieved afterwards, so its
# parameters remain pointers to alice.
# NOTE(review): opt2.zero_grad() only clears model2's gradients if opt2 was
# constructed over model2.parameters(); confirm in the setup cell.
data, target = datasets[1]
model2.send(data.location)
opt2.zero_grad()
pred = model2(data)
loss = ((pred - target)**2).sum()  # sum of squared errors over alice's batch
loss.backward()

In [18]:
# Materialize model's parameters (pointers to bob after the earlier cell) into a list.
params = [p for p in model.parameters()]

In [19]:
# Materialize model2's parameters (pointers to alice after the earlier cell) into a list.
params2 = [*model2.parameters()]

In [20]:
# Display model's parameters: each is a pointer whose location is bob.
params

[Parameter containing:FloatTensor[_PointerTensor - id:1524190791 owner:me loc:bob id@loc:7579059905],
 Parameter containing:FloatTensor[_PointerTensor - id:40874395336 owner:me loc:bob id@loc:3314025808]]

In [21]:
# Display model2's parameters: each is a pointer whose location is alice.
params2

[Parameter containing:FloatTensor[_PointerTensor - id:69313090251 owner:me loc:alice id@loc:92587960154],
 Parameter containing:FloatTensor[_PointerTensor - id:33238012236 owner:me loc:alice id@loc:26578966117]]

In [22]:
# Gradient of the first parameter (presumably the nn.Linear weight); it also lives on bob.
params[0].grad

Variable containing:FloatTensor[_PointerTensor - id:2135025194 owner:me loc:bob id@loc:25427736680]

In [23]:
# Gradient of the second parameter (presumably the nn.Linear bias); it also lives on bob.
params[1].grad

Variable containing:FloatTensor[_PointerTensor - id:78975381333 owner:me loc:bob id@loc:52557237590]

In [3]:

def train(n_iters=20):
    """Train the notebook-level ``model`` by visiting each worker's data in turn.

    On every pass, for each (data, target) pointer pair in ``datasets``:
    send ``model`` to the worker holding that data, run a forward/backward
    pass there, pull the model (with its gradients) back, and take a local
    ``opt`` step. Updates are applied sequentially per worker; no gradients
    are averaged.

    Uses the notebook-level ``model``, ``opt`` and ``datasets``.

    Args:
        n_iters: number of passes over all workers' datasets. Defaults to 20,
            the value previously hard-coded.

    Prints each worker's batch loss once per pass; returns None.
    """
    # NOTE(review): the In[16] cell sends `model` to bob and never calls
    # model.get(). Under Restart & Run All that cell executes before this one,
    # so `model` is already remote here; confirm re-sending works, or call
    # model.get() before training.
    for _ in range(n_iters):

        # NEW) iterate through each worker's dataset
        for data, target in datasets:

            # NEW) send model to correct worker
            model.send(data.location)

            # 1) erase previous gradients (if they exist)
            opt.zero_grad()

            # 2) make a prediction
            pred = model(data)

            # 3) calculate how much we missed
            loss = ((pred - target)**2).sum()

            # 4) figure out which weights caused us to miss
            loss.backward()

            # NEW) get model (with gradients)
            model.get()

            # 5) change those weights
            opt.step()

            # 6) print our progress
            print(loss.get().data[0])  # NEW) slight edit... need to call .get() on loss



In [4]:
train()

1.3705298900604248
0.3170836865901947
0.6483337879180908
0.2467782199382782
0.376615971326828
0.14604270458221436
0.22282710671424866
0.08502599596977234
0.13248728215694427
0.04956699162721634
0.07913585752248764
0.028994116932153702
0.0475175604224205
0.017038432881236076
0.028705758973956108
0.010072916746139526
0.017462393268942833
0.006000795401632786
0.010707026347517967
0.0036092933733016253
0.006623470224440098
0.0021964276675134897
0.004137811250984669
0.0013553204480558634
0.0026128306053578854
0.0008497933740727603
0.0016689568292349577
0.000542371766641736
0.001079025911167264
0.00035281339660286903
0.0007063847733661532
0.0002340254868613556
0.00046830164501443505
0.0001582468394190073
0.000314360047923401
0.00010896984167629853
0.0002135887771146372
7.628797175129876e-05
0.00014679577725473791
5.4188381909625605e-05
