# Tutorial: Federated learning with TrainConfig

This notebook walks through the steps to run federated learning with websocket workers. We will use federated averaging to combine the remotely trained models.
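
Federated averaging combines several models by taking a weighted mean of each parameter, where each worker is weighted by the number of samples it trained on. The pure-Python sketch below assumes each model's parameters arrive as a dict of flat float lists. The function name `federated_average` and this data layout are illustrative only and are not PySyft's API.

```python
from typing import Dict, List

Params = Dict[str, List[float]]  # parameter name -> flattened values


def federated_average(models: List[Params], num_samples: List[int]) -> Params:
    """FedAvg-style aggregation: weight each model by its sample count."""
    if not models or len(models) != len(num_samples):
        raise ValueError("need exactly one sample count per model")
    total = sum(num_samples)
    averaged: Params = {}
    for name in models[0]:
        size = len(models[0][name])
        averaged[name] = [
            sum(m[name][i] * n for m, n in zip(models, num_samples)) / total
            for i in range(size)
        ]
    return averaged


# Hypothetical example: alice trained on 30 samples, bob on 10.
alice_params = {"fc.weight": [1.0, 2.0], "fc.bias": [0.0]}
bob_params = {"fc.weight": [3.0, 6.0], "fc.bias": [4.0]}
print(federated_average([alice_params, bob_params], [30, 10]))
# {'fc.weight': [1.5, 3.0], 'fc.bias': [1.0]}
```

Weighting by sample count means a worker with more data pulls the average further toward its own parameters. With equal counts, this reduces to a plain mean.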

Authors:
- [mari-linhares](https://github.com/mari-linhares)

## Preparation: start the websocket server workers

Each worker consists of two parts: a local handle (the websocket client worker) and a remote instance that holds the data and performs the computations (the websocket server worker).
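
The following toy sketch makes the client/server split concrete. It runs entirely in one process: a direct method call stands in for the websocket connection, and the class names are hypothetical. Only serialized commands and results cross the boundary, while the data itself never leaves the server object.

```python
import json
from typing import Any, Callable, Dict, List


class ToyServerWorker:
    """Remote side (hypothetical): owns the data and executes commands on it."""

    def __init__(self, worker_id: str, data: Dict[str, List[float]]):
        self.id = worker_id
        self._data = data  # stays inside this object
        self._ops: Dict[str, Callable[..., Any]] = {
            "sum": lambda key: sum(self._data[key]),
            "count": lambda key: len(self._data[key]),
        }

    def handle_message(self, message: str) -> str:
        # Decode the command, run it locally and return only the result.
        request = json.loads(message)
        result = self._ops[request["op"]](*request["args"])
        return json.dumps({"result": result})


class ToyClientHandle:
    """Local side (hypothetical): serializes commands and sends them to the server."""

    def __init__(self, server: ToyServerWorker):
        # Stand-in for a websocket: call the server's message handler directly.
        self._send = server.handle_message

    def call(self, op: str, *args: Any) -> Any:
        reply = self._send(json.dumps({"op": op, "args": list(args)}))
        return json.loads(reply)["result"]


alice_server = ToyServerWorker("alice", {"vectors": [1.0, 2.0, 3.0]})
alice_handle = ToyClientHandle(alice_server)
print(alice_handle.call("sum", "vectors"), alice_handle.call("count", "vectors"))  # 6.0 3
```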

So first, we need to start the remote workers. To do this, run the following in a terminal (this cannot be done from inside the notebook):

```bash
python examples/experimental/Federated\ Learning\ with\ TrainConfig/run_websocket_server.py --port 8777 --id alice
```

## Setting up the websocket client workers (all these cells are from the FL websocket tutorial)

We first need to do the imports and set up some arguments and variables.

```python
%load_ext autoreload
%autoreload 2
```

```python
import sys

import torch as th
import torch.nn.functional as F
from torch import nn
from torchvision import datasets, transforms

import syft as sy
from syft.workers import WebsocketClientWorker
from syft.frameworks.torch.federated import utils
```

```python
use_cuda = th.cuda.is_available()
th.manual_seed(1)
device = th.device("cuda" if use_cuda else "cpu")
```

Now let's instantiate the websocket client workers, our local access points to the remote workers.
Note that **this step will fail if the websocket server workers are not running**.

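```python
# Hook PyTorch so that tensors gain PySyft functionality
hook = sy.TorchHook(th)

# Client handle for the server started with --port 8777 --id alice
kwargs_websocket = {"host": "localhost", "hook": hook, "verbose": False}
alice = WebsocketClientWorker(id="alice", port=8777, **kwargs_websocket)
```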
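
Because the client constructor fails when nothing is listening, it can help to check the port first. This is a minimal sketch using only the standard library. `server_is_listening` is a hypothetical helper, and a successful TCP connection only shows that *some* process is listening on the port, not that it is a PySyft websocket server.

```python
import socket


def server_is_listening(host: str, port: int, timeout: float = 1.0) -> bool:
    """Return True if a TCP connection to host:port can be opened."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:  # connection refused, timeout, unresolvable host, ...
        return False


if not server_is_listening("localhost", 8777):
    print("Nothing listening on localhost:8777; start run_websocket_server.py first.")
```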

## A Star is (almost) born: TrainConfig
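
In the next cell, the loss function and the model's `forward` method are turned into *plans* with `sy.func2plan` and `sy.method2plan`. Each one is then called once on example inputs ("Force build"). The underlying idea is tracing: run the function once, record the operations it performs, and replay that recording elsewhere later. The toy tracer below illustrates this idea for scalar arithmetic only. It is not how PySyft implements plans, and all names in it are hypothetical.

```python
import json
import operator
from typing import Any, Callable, Dict

# Operations the toy tracer understands (no reversed operands like `2 - x`).
OPS: Dict[str, Callable[[Any, Any], Any]] = {
    "add": operator.add,
    "sub": operator.sub,
    "mul": operator.mul,
    "pow": operator.pow,
}


class Traced:
    """Placeholder that records operations on a shared tape instead of computing."""

    def __init__(self, tape: Dict[str, Any], slot: int):
        self.tape = tape
        self.slot = slot

    def _emit(self, op: str, other: Any) -> "Traced":
        rhs = ["slot", other.slot] if isinstance(other, Traced) else ["const", other]
        self.tape["ops"].append([op, self.slot, rhs])
        new_slot = self.tape["n_slots"]
        self.tape["n_slots"] += 1
        return Traced(self.tape, new_slot)

    def __add__(self, other):
        return self._emit("add", other)

    def __sub__(self, other):
        return self._emit("sub", other)

    def __mul__(self, other):
        return self._emit("mul", other)

    def __pow__(self, other):
        return self._emit("pow", other)


def build_plan(fn: Callable[..., Any], n_args: int) -> Dict[str, Any]:
    """'Force build': call fn once on placeholders and keep the recorded tape."""
    tape: Dict[str, Any] = {"ops": [], "n_slots": n_args}
    out = fn(*[Traced(tape, i) for i in range(n_args)])
    tape["out"] = out.slot
    return tape


def run_plan(plan: Dict[str, Any], *values: Any) -> Any:
    """Replay the recorded operations on concrete inputs."""
    slots = list(values)
    for op, lhs, (kind, ref) in plan["ops"]:
        rhs = slots[ref] if kind == "slot" else ref
        slots.append(OPS[op](slots[lhs], rhs))
    return slots[plan["out"]]


def squared_error(real, pred):
    return (real - pred) ** 2


plan = build_plan(squared_error, n_args=2)
shipped = json.loads(json.dumps(plan))  # the plan is plain data, so it can be sent
print(run_plan(shipped, 3.0, 1.0))  # 4.0
```

Because the recorded plan is plain data, it can be serialized, shipped to another process and executed there without the original Python function.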

```python
# TODO: maybe by default is_client_worker could be False?
# For plans, the client worker needs to be able to register tensors,
# since it needs to build the plan.
hook.local_worker.is_client_worker = False
me = hook.local_worker

# Loss function (mean squared error), converted into a plan
@sy.func2plan
def loss_fn(real, pred):
    return ((real - pred) ** 2).mean()

# Model: a small fully connected network whose forward pass is a plan
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(2, 3)
        self.fc2 = nn.Linear(3, 2)
        self.fc3 = nn.Linear(2, 1)

    @sy.method2plan
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Force build: call the plan once on example inputs so that it gets built
# TODO: this should be done automatically
loss_fn(th.tensor([1.0]), th.tensor([1.0]))

model = Net()

# TODO: this line should not be needed at all.
model.send(me)

# Force build
# TODO: this should be done automatically
model(th.tensor([1.0, 2]))

# TODO: this line should not be needed at all.
model.get()

# Create the train config and send it to alice
train_config = sy.TrainConfig(model=model, loss_plan=loss_fn, batch_size=1)
train_config.send(alice)

# Train on alice's side, using the dataset she holds under the key "vectors"
# TODO: Returns a tensor when it should actually return a Pointer to the result Tensor.
for epoch in range(5):
    loss = alice.fit(dataset_key="vectors")
    print("-" * 50)
    print("Iteration %s: alice's loss: %s" % (epoch, loss))
```

```
--------------------------------------------------
Iteration 0: alice's loss: tensor(0.0483, requires_grad=True)
--------------------------------------------------
Iteration 1: alice's loss: tensor(0.0483, requires_grad=True)
--------------------------------------------------
Iteration 2: alice's loss: tensor(0.0483, requires_grad=True)
--------------------------------------------------
Iteration 3: alice's loss: tensor(0.0483, requires_grad=True)
--------------------------------------------------
Iteration 4: alice's loss: tensor(0.0483, requires_grad=True)
```
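
Conceptually, `alice.fit(dataset_key="vectors")` asks the worker to train on the dataset it holds under that key, using the model, loss and `batch_size` from the train config it received, and to report a loss. The pure-Python sketch below mimics that flow for a one-parameter linear model trained with mini-batch SGD on a mean-squared-error loss. `ToyTrainConfig`, `ToyWorker`, the learning rate, the data and the "one pass per `fit` call" behaviour are all assumptions. They say nothing about PySyft's actual training loop or optimizer.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

Sample = Tuple[float, float]  # one (x, y) training pair


@dataclass
class ToyTrainConfig:
    """Hypothetical stand-in for the settings a train config carries."""
    batch_size: int = 1
    lr: float = 0.05


class ToyWorker:
    """Holds datasets by key and trains a one-parameter model y = w * x."""

    def __init__(self, datasets: Dict[str, List[Sample]]):
        self.datasets = datasets
        self.config: Optional[ToyTrainConfig] = None
        self.w = 0.0  # the model's only parameter

    def receive(self, config: ToyTrainConfig) -> None:
        self.config = config

    def fit(self, dataset_key: str) -> float:
        """One pass of mini-batch SGD over the dataset; returns the last batch's loss."""
        if self.config is None:
            raise RuntimeError("no train config received")
        data = self.datasets[dataset_key]
        bs, lr = self.config.batch_size, self.config.lr
        loss = 0.0
        for start in range(0, len(data), bs):
            batch = data[start:start + bs]
            errors = [self.w * x - y for x, y in batch]
            # Mean squared error on the batch, computed before the update.
            loss = sum(e * e for e in errors) / len(batch)
            # d(loss)/dw is the batch mean of 2 * error * x.
            grad = sum(2 * e * x for e, (x, _) in zip(errors, batch)) / len(batch)
            self.w -= lr * grad
        return loss


worker = ToyWorker({"vectors": [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)]})
worker.receive(ToyTrainConfig(batch_size=1, lr=0.05))
for epoch in range(5):
    print("Iteration %s: loss %.6f" % (epoch, worker.fit("vectors")))
```

Keeping the datasets keyed on the worker means the client only ever sends a name such as `"vectors"`. The training data itself is never transferred.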
