# Lesson: Toy Federated Learning

Let's start by training a toy model the centralized way. This is about as simple as models get. We first need:

- a toy dataset
- a model
- some basic training logic to fit the model to the data.

In [1]:
import torch as th
import syft as sy
from torch import nn, optim



In [2]:
# A Toy Dataset: four examples with two input features each, and one scalar
# target per example.
# Inputs and targets are fixed data, not learnable parameters, so they do not
# need requires_grad=True; only the model's parameters are optimized. Tracking
# gradients here would make autograd compute (and accumulate) an unused .grad
# on `data` every backward pass.
data = th.tensor([[1.0, 1.0], [0.0, 1.0], [1.0, 0.0], [0.0, 0.0]])
target = th.tensor([[1.0], [1.0], [0.0], [0.0]])

In [3]:
# A Toy Model: one linear layer mapping the 2 input features to 1 output.
model = nn.Linear(in_features=2, out_features=1)

In [4]:
# Plain stochastic gradient descent over the model's weight and bias.
opt = optim.SGD(model.parameters(), lr=0.1)

In [5]:
def train(iterations=20):
    """Fit the global ``model`` to ``data``/``target`` with full-batch SGD.

    Every step evaluates the summed squared error over the whole toy dataset,
    backpropagates it, lets ``opt`` update the parameters, and prints the
    loss value.

    Args:
        iterations: number of gradient steps to take.
    """
    for _ in range(iterations):
        opt.zero_grad()

        squared_error = (model(data) - target) ** 2
        loss = squared_error.sum()

        loss.backward()
        opt.step()
        print(loss.data)

# train the model
train()

tensor(3.8518)
tensor(0.2964)
tensor(0.0271)
tensor(0.0057)
tensor(0.0032)
tensor(0.0023)
tensor(0.0018)
tensor(0.0014)
tensor(0.0010)
tensor(0.0008)
tensor(0.0006)
tensor(0.0005)
tensor(0.0004)
tensor(0.0003)
tensor(0.0002)
tensor(0.0002)
tensor(0.0001)
tensor(9.2395e-05)
tensor(7.0634e-05)
tensor(5.3999e-05)


In [6]:
# Hook PyTorch with PySyft so tensors and modules can be sent to workers.
hook = sy.TorchHook(th)

# Two virtual workers; each will receive one shard of the toy dataset.
bob = sy.VirtualWorker(hook, id="bob")
alice = sy.VirtualWorker(hook, id="alice")

In [7]:
# Send the first two examples (inputs and targets) to bob.
data_bob, target_bob = (t[:2].send(bob) for t in (data, target))

In [8]:
# Send the remaining two examples (inputs and targets) to alice.
data_alice, target_alice = (t[2:4].send(alice) for t in (data, target))

In [9]:
# One (inputs, targets) pair per worker, visited in this order during training.
datasets = [
    (data_bob, target_bob),
    (data_alice, target_alice),
]

In [10]:
def train(iterations=20):
    """Train a fresh linear model on data that stays on the remote workers.

    On every iteration the model visits each ``(data, target)`` pair in the
    global ``datasets`` list in turn. It is sent to the worker holding that
    data, takes one SGD step there on that worker's examples, and is fetched
    back before moving on to the next worker. Each step's loss is retrieved
    and printed.

    Args:
        iterations: number of passes over all workers.

    Returns:
        The trained ``nn.Linear`` model, back on the local machine, so the
        caller can actually use the result of training.
    """
    model = nn.Linear(2, 1)
    # The optimizer is created once, locally. The loop relies on
    # model.send()/model.get() keeping the same parameter objects that `opt`
    # holds. NOTE(review): confirm this holds for the PySyft version in use.
    opt = optim.SGD(params=model.parameters(), lr=0.1)

    for _ in range(iterations):
        for _data, _target in datasets:
            # send model to the data
            model = model.send(_data.location)

            # do normal training (executed on the worker holding _data)
            opt.zero_grad()
            pred = model(_data)
            loss = ((pred - _target) ** 2).sum()
            loss.backward()
            opt.step()

            # get smarter model back
            model = model.get()

            print(loss.get())

    return model

In [11]:
# Run federated training across both workers for 20 iterations.
train(iterations=20)

tensor(5.3171, requires_grad=True)
tensor(1.6561, requires_grad=True)
tensor(0.8883, requires_grad=True)
tensor(1.0821, requires_grad=True)
tensor(0.4562, requires_grad=True)
tensor(0.6387, requires_grad=True)
tensor(0.2616, requires_grad=True)
tensor(0.3751, requires_grad=True)
tensor(0.1514, requires_grad=True)
tensor(0.2207, requires_grad=True)
tensor(0.0877, requires_grad=True)
tensor(0.1302, requires_grad=True)
tensor(0.0509, requires_grad=True)
tensor(0.0770, requires_grad=True)
tensor(0.0296, requires_grad=True)
tensor(0.0458, requires_grad=True)
tensor(0.0172, requires_grad=True)
tensor(0.0273, requires_grad=True)
tensor(0.0101, requires_grad=True)
tensor(0.0164, requires_grad=True)
tensor(0.0059, requires_grad=True)
tensor(0.0099, requires_grad=True)
tensor(0.0035, requires_grad=True)
tensor(0.0060, requires_grad=True)
tensor(0.0021, requires_grad=True)
tensor(0.0037, requires_grad=True)
tensor(0.0012, requires_grad=True)
tensor(0.0023, requires_grad=True)
tensor(0.0008, requi