# Part 4: Federated Learning with Remote Gradient Averaging

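In Part 2 of this tutorial, we trained a model using a very simple version of Federated Learning. This required each worker's gradients to be pulled back to the model owner after every step (see `train()` below, which calls `model.get()` before `opt.step()`). In this part, the gradients are instead averaged on a remote worker before the model is retrieved.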

In [1]:
import syft as sy
import copy
hook = sy.TorchHook()
from torch import nn, optim

In [2]:
# Two virtual workers play the role of remote machines.
bob = sy.VirtualWorker(id="bob")
alice = sy.VirtualWorker(id="alice")

# Register each worker with the other
# (presumably needed for the worker-to-worker `.move()` used later -- TODO confirm).
bob.add_worker(alice)
alice.add_worker(bob)

# Toy dataset: four rows of two features, one target value per row.
data = sy.Var(sy.FloatTensor([[0,0],[0,1],[1,0],[1,1]]))
target = sy.Var(sy.FloatTensor([[0],[0],[1],[1]]))

# Split the rows between the workers: the first two go to bob, the rest to alice.
# Each `.send()` result is kept as the handle (pointer) to the remote copy.
data_bob, target_bob = data[0:2].send(bob), target[0:2].send(bob)
data_alice, target_alice = data[2:].send(alice), target[2:].send(alice)

# One (data, target) pair of pointers per worker.
datasets = [
    (data_bob, target_bob),
    (data_alice, target_alice),
]

# Initialize a toy model (2 inputs -> 1 output) and its SGD optimizer.
model = nn.Linear(2, 1)
opt = optim.SGD(params=model.parameters(), lr=0.1)

In [3]:
# Step size for the manually applied averaged-gradient update below.
# NOTE(review): this differs from the lr=0.1 given to `opt`; `opt.step()` is never
# called in this loop, so only this value affects training here.
AVG_GRAD_STEP_SIZE = 0.01

# Each round: run one forward/backward pass per worker on its own copy of the
# model, bring both sets of gradients onto the same worker, average them there,
# apply the averaged update remotely, and only then fetch the model back.
for epoch in range(100):  # named `epoch` so the builtin `iter` is not shadowed

    # Second copy of the current model (and optimizer) for the second worker.
    model2 = copy.deepcopy(model)
    # NOTE(review): `opt2` is deep-copied independently of `model2`, so it
    # presumably holds its own parameter copies rather than model2's; confirm
    # that `opt2.zero_grad()` below has any effect on model2.
    opt2 = copy.deepcopy(opt)

    # --- pass on the first dataset, using `model` ---
    data, target = datasets[0]
    model.send(data.location)
    opt.zero_grad()
    pred = model(data)
    loss = ((pred - target)**2).sum()
    loss.backward()

    # Only the scalar loss is fetched here; the model stays on the worker.
    local_loss = loss.get()

    # --- pass on the second dataset, using `model2` ---
    data, target = datasets[1]
    model2.send(data.location)
    opt2.zero_grad()
    pred = model2(data)
    loss = ((pred - target)**2).sum()
    loss.backward()

    local_loss += loss.get()

    # Index 0 is paired with `model.weight` and index 1 with `model.bias` below.
    params = list(model.parameters())
    params2 = list(model2.parameters())

    # Move model2's gradients to bob so both operands of the sums below live on
    # the same worker. `bob` must be the worker holding datasets[0], which is
    # where `model` was sent above.
    params2[0].grad.move(bob)
    params2[1].grad.move(bob)

    # Average the two workers' gradients and apply the update to the remote model.
    weight_update = (params2[0].grad.data + params[0].grad.data) / 2
    model.weight.data -= weight_update * AVG_GRAD_STEP_SIZE

    bias_update = (params2[1].grad.data + params[1].grad.data) / 2
    model.bias.data -= bias_update * AVG_GRAD_STEP_SIZE

    # Fetch the updated model back; model2 is left on its worker and replaced
    # by a fresh copy at the start of the next round.
    model.get()

    # Sum of the two workers' losses for this round.
    print(local_loss.data[0])

4.003336429595947
3.572225570678711
3.1930911540985107
2.859544515609741
2.565984010696411
2.3074986934661865
2.0797829627990723
1.8790628910064697
1.7020283937454224
1.545778512954712
1.4077693223953247
1.2857705354690552
1.1778264045715332
1.08222234249115
0.9974549412727356
0.9222058057785034
0.8553187847137451
0.7957802414894104
0.7427010536193848
0.6953015327453613
0.652897834777832
0.6148902177810669
0.5807523727416992
0.550023078918457
0.5222972631454468
0.4972197711467743
0.47447872161865234
0.453800767660141
0.43494582176208496
0.41770297288894653
0.401887446641922
0.38733649253845215
0.3739076852798462
0.36147552728652954
0.34992992877960205
0.3391740322113037
0.3291229009628296
0.31970152258872986
0.3108441233634949
0.30249279737472534
0.29459643363952637
0.28711003065109253
0.2799941301345825
0.2732137143611908
0.26673775911331177
0.2605392336845398
0.2545939087867737
0.248880535364151
0.24338021874427795
0.23807621002197266
0.23295360803604126
0.22799929976463318
0.2232014

In [53]:
# Scratch inspection left over from development: ask for the shape of model2's
# first parameter. NOTE(review): this cell's execution count is out of order with
# its neighbours, so it may not reproduce under Restart & Run All.
params2[0].get_shape()

torch.Size([1, 2])

In [27]:
# Scratch cell: call .get() on model2, which the averaging loop sends to a worker
# and never retrieves. NOTE(review): leftover debugging; depends on `model2`
# still existing from the loop cell.
model2.get()

In [28]:
# Scratch cell: display model2 to inspect it after the .get() call in the previous cell.
model2

Linear(in_features=2, out_features=1, bias=True)

Variable containing:FloatTensor[_PointerTensor - id:86598874295 owner:me loc:alice id@loc:49122764949]

Variable containing:FloatTensor[_PointerTensor - id:30386184361 owner:me loc:alice id@loc:7485891957]

In [3]:

def train(iterations=20):
    """Train the global `model` by visiting each worker's dataset in turn.

    For every iteration and every (data, target) pair in the global `datasets`,
    the model is sent to the worker that holds the data, one forward/backward
    pass is run there, the model is fetched back together with its gradients,
    and the global optimizer `opt` applies the step locally. The summed squared
    error of each pass is printed.

    Args:
        iterations: number of passes over `datasets`. Defaults to 20, the value
            that was previously hard-coded.
    """
    for _ in range(iterations):

        # NEW) iterate through each worker's dataset
        for data, target in datasets:

            # NEW) send model to the worker that holds this data
            model.send(data.location)

            # 1) erase previous gradients (if they exist)
            opt.zero_grad()

            # 2) make a prediction
            pred = model(data)

            # 3) calculate how much we missed
            loss = ((pred - target)**2).sum()

            # 4) figure out which weights caused us to miss
            loss.backward()

            # NEW) get model (with gradients) back before stepping
            model.get()

            # 5) change those weights
            opt.step()

            # 6) print our progress
            # NEW) the loss was computed remotely, so .get() it before reading the value
            print(loss.get().data[0])



In [4]:
# Run the sequential training loop defined in train(); it prints one loss value
# per dataset per iteration.
train()

1.3705298900604248
0.3170836865901947
0.6483337879180908
0.2467782199382782
0.376615971326828
0.14604270458221436
0.22282710671424866
0.08502599596977234
0.13248728215694427
0.04956699162721634
0.07913585752248764
0.028994116932153702
0.0475175604224205
0.017038432881236076
0.028705758973956108
0.010072916746139526
0.017462393268942833
0.006000795401632786
0.010707026347517967
0.0036092933733016253
0.006623470224440098
0.0021964276675134897
0.004137811250984669
0.0013553204480558634
0.0026128306053578854
0.0008497933740727603
0.0016689568292349577
0.000542371766641736
0.001079025911167264
0.00035281339660286903
0.0007063847733661532
0.0002340254868613556
0.00046830164501443505
0.0001582468394190073
0.000314360047923401
0.00010896984167629853
0.0002135887771146372
7.628797175129876e-05
0.00014679577725473791
5.4188381909625605e-05


In [None]:
# Work-in-progress cell: give every worker that holds part of the batch its own
# remote copy of the model, then run a forward pass on each worker's data.
batch = datasets
# Not used in this cell. NOTE(review): presumably the worker intended to
# aggregate results -- confirm.
agg_worker = bob

# Distinct workers holding at least one (data, target) pair of the batch.
workers = {worker_data.location for worker_data, _worker_target in batch}

# One independent copy of the current model per worker, keyed by worker id.
remote_models = {
    worker.id: copy.deepcopy(model).send(worker) for worker in workers
}

# Feed each worker's own data to the model copy living on that worker.
# (Previously the stale global `data` was passed, not the loop's own tensor,
# and the loop variable shadowed the builtin `input`.)
# NOTE(review): the predictions are currently discarded; the cell looks unfinished.
for worker_data, _worker_target in batch:
    remote_models[worker_data.location.id](worker_data)