# Tutorial: Federated learning with TrainConfig

This notebook walks through the steps to run federated learning via websocket workers. We will use federated averaging to join the remotely trained models.

Authors:
- [mari-linhares](https://github.com/mari-linhares)
- [midokura-silvia](https://github.com/midokura-silvia)

## Preparation: Start the websocket server workers

Each worker consists of two parts: a local handle (the websocket client worker) and a remote instance that holds the data and performs the computations. The remote part is called a websocket server worker.

First, we need to create the remote worker. Run the following in a terminal (this cannot be done from the notebook):

```bash
python "examples/experimental/Federated Learning with TrainConfig/run_websocket_server.py" --port 8777 --id alice
```

## Setting up the websocket client workers

We first need to perform the imports and set up some arguments and variables.

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import torch as th
import torch.nn.functional as F
from torch import nn

import syft as sy
from syft import workers

In [4]:
use_cuda = th.cuda.is_available()
th.manual_seed(1)
device = th.device("cuda" if use_cuda else "cpu")

Now let's instantiate the websocket client worker, our local access point (proxy) to the remote worker.
Note that **this step will fail if the websocket server worker is not running**.
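
Because the connection step depends on the server already listening, it can help to check the port before creating the client worker. The following is a minimal sketch using only the Python standard library; `is_port_open` is a hypothetical helper, not part of PySyft, and it only verifies that something accepts TCP connections on the port, not that it is a websocket server worker.

```python
import socket


def is_port_open(host: str, port: int, timeout: float = 1.0) -> bool:
    """Return True if a TCP connection to (host, port) succeeds within timeout.

    Hypothetical helper for illustration; it does not speak the websocket
    protocol and cannot tell which program is listening.
    """
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        # Covers refused connections, timeouts, and unreachable hosts.
        return False
```

For the setup in this tutorial, one would call `is_port_open("localhost", 8777)` and start `run_websocket_server.py` first if it returns `False`.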

In [5]:
hook = sy.TorchHook(th)
me = hook.local_worker

kwargs_websocket = {"host": "localhost", "hook": hook, "verbose": False}
alice = workers.WebsocketClientWorker(id="alice", port=8777, **kwargs_websocket)

## Define your model and loss function

In [6]:
# Loss function
@th.jit.script
def loss_fn(real, pred):
    return ((real.float() - pred.float()) ** 2).mean()


# Model
class Net(th.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(2, 3)
        self.fc2 = nn.Linear(3, 2)
        self.fc3 = nn.Linear(2, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
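
The loss function above is a mean squared error. As a quick sanity check of what it computes, here is a dependency-free sketch that uses plain Python lists in place of tensors; `mse` is an illustrative stand-in, and the prediction values are made up.

```python
def mse(real, pred):
    """Mean of squared differences, mirroring ((real - pred) ** 2).mean()."""
    if len(real) != len(pred):
        raise ValueError("real and pred must have the same length")
    return sum((float(r) - float(p)) ** 2 for r, p in zip(real, pred)) / len(real)


targets = [1, 0, 1, 0]              # binary labels
predictions = [0.5, 0.5, 0.5, 0.5]  # hypothetical model outputs
print(mse(targets, predictions))    # every error is 0.5, so the mean is 0.25
```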

## Configure your training: TrainConfig

In [7]:
model = Net()
data = th.tensor([[-10, -2.0], [1, 1.1], [11, 22.1], [-10, 1.2]])  # example input, used only for tracing

traced_model = th.jit.trace(model, data)

# Create and send train config
train_config = sy.TrainConfig(batch_size=4)
train_config.send(alice, traced_loss_fn=loss_fn, traced_model=traced_model)


<weakref at 0x7fc995f17048; to 'TrainConfig' at 0x7fc9961793c8>

In [8]:
data = th.tensor([[-1, 2.0], [0, 1.1], [-1, 2.1], [0, 1.2]], requires_grad=True)
target = th.tensor([[1], [0], [1], [0]], requires_grad=False)

print("\nEvaluation before training")
pred = model(data)
loss = loss_fn(real=target, pred=pred)
print("Loss: {}".format(loss))


Evaluation before training
Loss: 0.328726589679718


Now we can run the training on the data that is available at the worker. The data was provided to it by the Python script that started the remote instance (`run_websocket_server.py`).

In [9]:
for epoch in range(5):
    loss = alice.fit(dataset_key="vectors")
    print("-" * 50)
    print("Iteration %s: alice's loss: %s" % (epoch, loss))

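With the PySyft version used for this run, the cell above failed with the following error, so the final cell below was left unexecuted:

TypeError: fit() missing 1 required positional argument: 'return_id'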

In [None]:
new_model = train_config.model_ptr.get()

print("\nEvaluation after training:")
pred = new_model(data)
loss = loss_fn(real=target, pred=pred)
print("Loss: {}".format(loss))
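
The introduction mentions federated averaging as the way to join remotely trained models. With a single worker there is nothing to average, but with several workers the retrieved parameters could be combined roughly as follows. This is a minimal sketch under stated assumptions: plain Python lists stand in for tensors, `federated_average` is a hypothetical helper rather than a PySyft function, and the parameter values and the second worker are invented for illustration.

```python
def federated_average(worker_params, weights=None):
    """Average a list of parameter dicts element-wise.

    worker_params: list of dicts mapping parameter name -> flat list of floats.
    weights: optional per-worker weights (for example, local sample counts).
    """
    if not worker_params:
        raise ValueError("need at least one set of parameters")
    if weights is None:
        weights = [1.0] * len(worker_params)
    total = float(sum(weights))
    averaged = {}
    for name in worker_params[0]:
        length = len(worker_params[0][name])
        averaged[name] = [
            sum(w * params[name][i] for w, params in zip(weights, worker_params)) / total
            for i in range(length)
        ]
    return averaged


# Hypothetical parameters for the last layer, as returned by two workers.
alice_params = {"fc3.weight": [0.25, 0.5], "fc3.bias": [0.0]}
bob_params = {"fc3.weight": [0.75, 0.0], "fc3.bias": [1.0]}
print(federated_average([alice_params, bob_params]))
```

Passing `weights` proportional to each worker's number of training samples gives workers with more data a larger influence on the averaged model.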