# Tutorial: Federated learning with TrainConfig

This notebook walks through the steps to run federated learning with websocket workers. We will use federated averaging to combine the remotely trained models.
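Federated averaging (FedAvg) combines models trained on different workers by averaging their parameters, typically weighting each worker by the number of samples it trained on. The sketch below shows only that averaging step in plain PyTorch. The helper `federated_average` is hypothetical and not part of PySyft, and it assumes every model shares the same architecture (identical `state_dict` keys and shapes).

```python
import torch
from torch import nn


def federated_average(state_dicts, weights):
    """Return the weighted average of several state dicts with identical keys.

    Hypothetical helper: `weights` would typically be the number of
    training samples held by each worker.
    """
    total = float(sum(weights))
    averaged = {}
    for key in state_dicts[0]:
        # Weighted sum of this parameter across all workers, normalized by the total weight
        averaged[key] = sum(sd[key].float() * (w / total) for sd, w in zip(state_dicts, weights))
    return averaged


# Usage: two copies of the same architecture, standing in for models trained on two workers
model_a, model_b = nn.Linear(2, 1), nn.Linear(2, 1)
with torch.no_grad():
    model_a.weight.fill_(1.0)
    model_a.bias.fill_(0.0)
    model_b.weight.fill_(3.0)
    model_b.bias.fill_(2.0)

global_model = nn.Linear(2, 1)
global_model.load_state_dict(
    federated_average([model_a.state_dict(), model_b.state_dict()], weights=[1, 1])
)
print(global_model.weight.data, global_model.bias.data)
```

With equal weights this is a plain mean; passing per-worker sample counts instead gives workers with more data a proportionally larger influence on the combined model.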

Authors:
- [mari-linhares](https://github.com/mari-linhares)
- [midokura-silvia](https://github.com/midokura-silvia)

## Preparation: start the websocket server workers

Each worker is represented by two parts: a local handle (the websocket client worker) and the remote instance that holds the data and performs the computations. The remote part is called a websocket server worker.

So first, we need to create the remote workers. To do so, run the following in a terminal (this cannot be done from within the notebook):

```bash
python examples/experimental/Federated\ Learning\ with\ TrainConfig/run_websocket_server.py --port 8777 --id alice
```

## Setting up the websocket client workers (all these cells are from the FL websocket tutorial)

We first need to perform the imports and set up some arguments and variables.

```python
%load_ext autoreload
%autoreload 2
```

```python
import torch as th
import torch.nn.functional as F
from torch import nn

import syft as sy
from syft import workers
from syft.frameworks.torch import pointers
```

```python
# Use the GPU if available and fix the random seed for reproducibility
use_cuda = th.cuda.is_available()
th.manual_seed(1)
device = th.device("cuda" if use_cuda else "cpu")
```

Now let's instantiate the websocket client workers, our local access points to the remote workers.
Note that **this step will fail if the websocket server workers are not running**.
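Because the client worker cannot connect unless its server is running, it can help to check beforehand that something is listening on the expected port. The helper `server_is_listening` below is a hypothetical, standard-library-only sketch, not part of PySyft. It only confirms that the TCP port accepts connections, not that the process behind it is a websocket server worker.

```python
import socket


def server_is_listening(host: str, port: int, timeout: float = 1.0) -> bool:
    """Return True if something accepts TCP connections on host:port (hypothetical helper)."""
    try:
        # create_connection performs a full TCP handshake and raises OSError on failure
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        return False


# Usage before creating the client worker (8777 is the port used for alice in this notebook)
if not server_is_listening("localhost", 8777):
    print("No server on localhost:8777; start run_websocket_server.py first")
```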

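```python
# Hook PyTorch so tensors gain PySyft functionality; `me` is the local worker
hook = sy.TorchHook(th)
me = hook.local_worker

# Client handle for the alice server started above on port 8777
kwargs_websocket = {"host": "localhost", "hook": hook, "verbose": False}
alice = workers.WebsocketClientWorker(id="alice", port=8777, **kwargs_websocket)
```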

## A Star is (almost) born: TrainConfig

```python
# Loss function (mean squared error), TorchScript-compiled so it can be serialized and sent to the worker
@th.jit.script
def loss_fn(real, pred):
    return ((real - pred) ** 2).mean()

# Model
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(2, 3)
        self.fc2 = nn.Linear(3, 2)
        self.fc3 = nn.Linear(2, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = Net()
# Example input that th.jit.trace runs through the model to record its operations
data = th.tensor([[-10, -2.0], [1, 1.1], [11, 22.1], [-10, 1.2]])

traced_model = th.jit.trace(model, data)

# Wrap the traced model and the loss function with ids, then send both to alice
model_with_id = pointers.ObjectWrapper(id=sy.ID_PROVIDER.pop(), obj=traced_model)
loss_fn_with_id = pointers.ObjectWrapper(id=sy.ID_PROVIDER.pop(), obj=loss_fn)

model_ptr = me.send(model_with_id, alice)
loss_fn_ptr = me.send(loss_fn_with_id, alice)

# Create and send train config
train_config = sy.TrainConfig(
    model_id=model_ptr.id_at_location, loss_plan_id=loss_fn_ptr.id_at_location, batch_size=2
)
train_config.send(alice)
```

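The cell above relies on two TorchScript mechanisms: `th.jit.script` compiles `loss_fn` from its source, while `th.jit.trace` runs `model` once on the example input `data` and records the operations it executes. The plain-PyTorch sketch below (independent of PySyft, using a stand-in `nn.Sequential` with the same layer sizes as `Net`) checks that the traced module reproduces the eager model's outputs on a different input of the same shape. Keep in mind that tracing only records the path taken for the example input, so data-dependent control flow would not be captured.

```python
import torch
from torch import nn


@torch.jit.script
def mse(real, pred):
    # Same computation as loss_fn above: mean of squared differences
    return ((real - pred) ** 2).mean()


# Stand-in for Net: same layer sizes, written with nn.Sequential
model = nn.Sequential(nn.Linear(2, 3), nn.ReLU(), nn.Linear(3, 2), nn.ReLU(), nn.Linear(2, 1))
example = torch.tensor([[-10, -2.0], [1, 1.1], [11, 22.1], [-10, 1.2]])
traced = torch.jit.trace(model, example)

# The traced module is reused on new data of the same shape
new_input = torch.tensor([[-1, 2.0], [0, 1.1], [-1, 2.1], [0, 1.2]])
with torch.no_grad():
    print(torch.allclose(traced(new_input), model(new_input)))  # True
    print(mse(torch.ones(3), torch.zeros(3)).item())  # 1.0
```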

```python
data = th.tensor([[-1, 2.0], [0, 1.1], [-1, 2.1], [0, 1.2]], requires_grad=True)
target = th.tensor([[1.0], [0.0], [1.0], [0.0]], requires_grad=True)

# Evaluate the local, untrained model
print("\nEvaluation before training")
pred = model(data)
loss = loss_fn(real=target, pred=pred)
print("Loss: {}".format(loss))
```


```
Evaluation before training
Loss: 0.328726589679718
```


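Run training on the data available at the worker, then evaluate again. Because the local `data` contains the same samples as the remote dataset, the loss of the retrieved model should be close to the last loss reported by the worker. The two need not be identical: in the run below the last reported loss is 0.2458 while the final evaluation gives 0.2446, presumably because the worker reports the loss computed before its last parameter update.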

```python
# Ask alice to train on her "vectors" dataset using the TrainConfig sent above
for epoch in range(5):
    loss = alice.fit(dataset_key="vectors", return_id=88 + epoch)
    print("-" * 50)
    print("Iteration %s: alice's loss: %s" % (epoch, loss))

# Retrieve the trained model from alice
new_model = model_ptr.get()

print("\nEvaluation after training:")
pred = new_model(data)
loss = loss_fn(real=target, pred=pred)
print("Loss: {}".format(loss))
```

```
--------------------------------------------------
Iteration 0: alice's loss: tensor(0.2991, requires_grad=True)
--------------------------------------------------
Iteration 1: alice's loss: tensor(0.2680, requires_grad=True)
--------------------------------------------------
Iteration 2: alice's loss: tensor(0.2549, requires_grad=True)
--------------------------------------------------
Iteration 3: alice's loss: tensor(0.2489, requires_grad=True)
--------------------------------------------------
Iteration 4: alice's loss: tensor(0.2458, requires_grad=True)

Evaluation after training:
Loss: 0.24460731446743011
```
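For intuition about what the worker does during `fit`, the sketch below trains a local stand-in model in plain PyTorch with the same batch size on the same four samples. It is an approximation under loudly stated assumptions: the optimizer (SGD), the learning rate, the absence of shuffling, and reporting the loss of the last batch of each epoch are all guesses, not taken from PySyft's implementation, so the numbers will not match the output above. It does illustrate why a reported training loss can differ from a post-training evaluation: each reported value is computed on one batch, before that batch's parameter update.

```python
import torch
from torch import nn

torch.manual_seed(1)

# Stand-in for Net (same layer sizes)
model = nn.Sequential(nn.Linear(2, 3), nn.ReLU(), nn.Linear(3, 2), nn.ReLU(), nn.Linear(2, 1))
data = torch.tensor([[-1, 2.0], [0, 1.1], [-1, 2.1], [0, 1.2]])
target = torch.tensor([[1.0], [0.0], [1.0], [0.0]])

# Assumptions: plain SGD with lr=0.1, no shuffling, batch_size=2 as in the TrainConfig
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
batch_size = 2
reported = []
for epoch in range(5):
    for start in range(0, len(data), batch_size):
        x, y = data[start:start + batch_size], target[start:start + batch_size]
        optimizer.zero_grad()
        loss = ((y - model(x)) ** 2).mean()  # loss computed before this batch's update
        loss.backward()
        optimizer.step()
    reported.append(loss.item())  # assumption: report the last batch's loss for the epoch

with torch.no_grad():
    # Full-data evaluation after the last update, as in the notebook's final cell
    final = ((target - model(data)) ** 2).mean().item()
print(reported, final)
```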
