# Tutorial: Asynchronous federated learning on MNIST

This notebook walks through the steps to run federated learning asynchronously using websocket workers. We will use federated averaging to combine the remotely trained models.

Authors:
- Silvia - GitHub [@midokura-silvia](https://github.com/midokura-silvia)

In [16]:
%load_ext autoreload
%autoreload 2

import inspect



## Federated Learning setup

For federated learning we need two kinds of participants:

* _Workers_ that own datasets.

* An entity that knows the workers and the name of the dataset that lives on each worker. We'll call this the _scheduler_.

Each worker is represented by two parts: a proxy local to the scheduler (the websocket client worker) and the remote instance that holds the data and performs the computations, called the websocket server worker.

#### Preparation: Start the websocket workers
First, we need to create the remote workers. To do so, run the following in a terminal (this cannot be done from within the notebook):

```bash
python start_websocket_servers.py
```

#### What's going on?

The script instantiates three workers, Alice, Bob and Charlie, and prepares their local data. 
Each worker holds a subset of the MNIST training dataset: 
Alice holds all images of the digits 0-3, 
Bob holds all images of the digits 4-6, and 
Charlie holds all images of the digits 7-9.

| Worker      | Digits in local dataset | Number of samples |
| ----------- | ----------------------- | ----------------- |
| Alice       | 0-3                     | 24754             |
| Bob         | 4-6                     | 17181             |
| Charlie     | 7-9                     | 18065             |
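
The split in the table can be reproduced with a short, plain-Python sketch. It is hypothetical: `partition_by_digit` and `DIGIT_RANGES` are illustrative names rather than code from `start_websocket_servers.py`, and instead of loading images it builds a stand-in label list from the per-digit counts of the MNIST training split. Grouping those labels by digit range yields the sample counts shown in the table.

```python
# Hypothetical sketch (not the code in start_websocket_servers.py): split a
# labelled dataset into per-worker index lists by digit range.
from typing import Dict, List, Sequence

# Digits held by each worker, as in the table above.
DIGIT_RANGES: Dict[str, range] = {
    "alice": range(0, 4),     # digits 0-3
    "bob": range(4, 7),       # digits 4-6
    "charlie": range(7, 10),  # digits 7-9
}

# Number of images per digit (0-9) in the 60,000-image MNIST training split.
MNIST_TRAIN_COUNTS = [5923, 6742, 5958, 6131, 5842, 5421, 5918, 6265, 5851, 5949]


def partition_by_digit(labels: Sequence[int], digit_ranges: Dict[str, range]) -> Dict[str, List[int]]:
    """Return, for each worker, the indices of the samples whose label it owns."""
    partition: Dict[str, List[int]] = {worker: [] for worker in digit_ranges}
    for idx, label in enumerate(labels):
        for worker, digits in digit_ranges.items():
            if label in digits:
                partition[worker].append(idx)
                break
    return partition


# Stand-in label list with the same per-digit counts as MNIST (no image data needed).
labels = [digit for digit, n in enumerate(MNIST_TRAIN_COUNTS) for _ in range(n)]
partition = partition_by_digit(labels, DIGIT_RANGES)
for worker, indices in partition.items():
    print(worker, len(indices))
# alice 24754
# bob 17181
# charlie 18065
```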


In [17]:
# uncomment the following to see the code of the function that starts a worker
# import run_websocket_server

# print(inspect.getsource(run_websocket_server.start_websocket_server_worker))

We first need to perform the imports and set up some arguments and variables.

In [18]:
import sys
import asyncio
import syft as sy
from syft.workers import WebsocketClientWorker
import torch
from torchvision import datasets, transforms

from syft.frameworks.torch.federated import utils

In [19]:
import run_websocket_client as rwc
hook = sy.TorchHook(torch)



In [20]:
args = rwc.define_and_get_arguments(args=[])
use_cuda = args.cuda and torch.cuda.is_available()
torch.manual_seed(args.seed)
device = torch.device("cuda" if use_cuda else "cpu")
print(args)

Namespace(batch_size=32, cuda=False, epochs=100, federate_after_n_batches=10, lr=0.1, save_model=False, seed=1, test_batch_size=128, verbose=False)


Now let's instantiate the websocket client workers, our local proxies to the remote workers.
Note that **this step will fail if the websocket server workers are not running**.

In [21]:
kwargs_websocket = {"host": "localhost", "hook": hook, "verbose": args.verbose}
alice = WebsocketClientWorker(id="alice", port=8777, **kwargs_websocket)
bob = WebsocketClientWorker(id="bob", port=8778, **kwargs_websocket)
charlie = WebsocketClientWorker(id="charlie", port=8779, **kwargs_websocket)

worker_instances = [alice, bob, charlie]
print(worker_instances)


[<WebsocketClientWorker id:alice #objects local:0 #objects remote: 0>, <WebsocketClientWorker id:bob #objects local:0 #objects remote: 0>, <WebsocketClientWorker id:charlie #objects local:0 #objects remote: 0>]
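
If the cell above fails to connect, a quick way to see which server is missing is to test the three ports directly. The sketch below is a hypothetical helper built on Python's `socket` module, not a PySyft API; it only checks that something accepts TCP connections on `localhost` ports 8777, 8778 and 8779, not that the listener is actually a PySyft websocket server worker.

```python
# Hypothetical pre-flight check (not part of PySyft): verify that something is
# listening on each websocket server port used in the cell above.
import socket
from typing import Dict

WORKER_PORTS: Dict[str, int] = {"alice": 8777, "bob": 8778, "charlie": 8779}


def is_port_open(host: str, port: int, timeout: float = 1.0) -> bool:
    """Return True if a TCP connection to (host, port) succeeds within `timeout` seconds."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        # Connection refused, timed out, or host could not be resolved.
        return False


for worker_id, port in WORKER_PORTS.items():
    status = "listening" if is_port_open("localhost", port) else "NOT reachable"
    print(f"{worker_id} (port {port}): {status}")
```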


### Setting up a scheduler

We'll use this notebook as the scheduler. For this, we'll need to:

* Have a model
* Have a loss function
* Define an optimizer
* Define hyper-parameters

#### Model
Let's instantiate the machine learning model. It is a small neural network with two convolutional and two fully connected layers. 
It uses ReLU activations and max pooling.

In [22]:
print(inspect.getsource(rwc.Net))

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4 * 4 * 50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4 * 4 * 50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)



In [23]:
model = rwc.Net().to(device)
print(model)

Net(
  (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(20, 50, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=800, out_features=500, bias=True)
  (fc2): Linear(in_features=500, out_features=10, bias=True)
)
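
The `in_features=800` of `fc1` follows from the image size and the layer settings. Assuming the PyTorch defaults used here (no padding, dilation 1), an unpadded convolution or pooling window maps a side of length `in` to `(in - kernel) // stride + 1`, so the spatial size shrinks 28 → 24 → 12 → 8 → 4 and the flattened tensor has 4 × 4 × 50 = 800 entries. A short worked check:

```python
# Worked check of why fc1 expects 4 * 4 * 50 = 800 input features.
# Assumes no padding and dilation 1 (PyTorch defaults, as in Net):
#   out = (in - kernel) // stride + 1


def out_size(in_size: int, kernel: int, stride: int) -> int:
    """Spatial output size of an unpadded convolution or pooling window."""
    return (in_size - kernel) // stride + 1


side = 28                                  # MNIST images are 28x28 with 1 channel
side = out_size(side, kernel=5, stride=1)  # conv1 (5x5, stride 1): 28 -> 24
side = out_size(side, kernel=2, stride=2)  # max_pool2d(x, 2, 2): 24 -> 12
side = out_size(side, kernel=5, stride=1)  # conv2 (5x5, stride 1): 12 -> 8
side = out_size(side, kernel=2, stride=2)  # max_pool2d(x, 2, 2): 8 -> 4

channels = 50                              # conv2 has 50 output channels
flat_features = side * side * channels
print(side, flat_features)  # 4 800
```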


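Before starting the training, let's load the evaluation data and configure the logging. Note that `train=True` below loads the 60,000-image MNIST training split (the same images the workers hold between them), which is why the logs report accuracy out of 60000; pass `train=False` to evaluate on the separate 10,000-image MNIST test split instead.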

In [24]:
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "../data",
        train=True,
        transform=transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        ),
    ),
    batch_size=args.test_batch_size,
    shuffle=False,
    drop_last=False,
)
# Take the first batch; it serves as the example input for torch.jit.trace below.
(data, target) = next(iter(test_loader))

In [25]:
import logging
#FORMAT = "%(asctime)s %(levelname)s %(filename)s(l:%(lineno)d) - %(message)s"
FORMAT = "%(asctime)s - %(message)s"
LOG_LEVEL = logging.DEBUG
logging.basicConfig(format=FORMAT, level=LOG_LEVEL)
logger = logging.getLogger("run_websocket_client")
logger.setLevel(LOG_LEVEL)

## Let's start the training

Now we are ready to start the federated training. We will perform training over a given number of batches separately on each worker and then calculate the federated average of the resulting model.
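
Federated averaging combines the workers' models by averaging their parameters. The sketch below is a minimal, plain-Python illustration: each model is represented as a dict of parameter lists (a stand-in for PyTorch modules), and it assumes an unweighted mean, which is what we assume `utils.federated_avg` computes here. FedAvg as originally proposed weights each model by its number of training samples; since every worker trains on the same number of batches per round in this tutorial, equal weights are a reasonable choice.

```python
# Minimal sketch of unweighted federated averaging over parameter dictionaries.
# Assumption: plain (unweighted) mean, as we assume utils.federated_avg does.
# Each "model" maps parameter names to flat lists of floats, standing in for
# the PyTorch models returned by the workers.
from typing import Dict, List

Params = Dict[str, List[float]]


def federated_average(models: Dict[str, Params]) -> Params:
    """Element-wise mean of every parameter across all workers' models."""
    if not models:
        raise ValueError("need at least one model to average")
    model_list = list(models.values())
    n = len(model_list)
    averaged: Params = {}
    for name in model_list[0]:
        # zip(*...) walks the same position of this parameter in every model.
        averaged[name] = [sum(values) / n for values in zip(*(m[name] for m in model_list))]
    return averaged


models = {
    "alice": {"w": [1.0, 2.0], "b": [0.0]},
    "bob": {"w": [3.0, 4.0], "b": [1.0]},
    "charlie": {"w": [5.0, 6.0], "b": [2.0]},
}
print(federated_average(models))  # {'w': [3.0, 4.0], 'b': [1.0]}
```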

Every 10th training round, starting with the first (rounds 1, 11, 21, ...), and after the final round, we will evaluate the performance of the models returned by the workers and of the model obtained by federated averaging.
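
The evaluation condition in the training loop is `epoch % 10 == 1 or epoch == args.epochs`. With `args.epochs = 100`, a one-line check lists the rounds at which the models are evaluated:

```python
# Rounds at which the training loop evaluates the models, using the same
# condition as the loop: epoch % 10 == 1 or epoch == args.epochs.
epochs = 100  # args.epochs
eval_rounds = [epoch for epoch in range(1, epochs + 1) if epoch % 10 == 1 or epoch == epochs]
print(eval_rounds)  # [1, 11, 21, 31, 41, 51, 61, 71, 81, 91, 100]
```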

Performance is reported both as accuracy (the ratio of correct predictions) and as a histogram of the predicted digits. The histogram is of interest because each worker owns only a subset of the digits: at the beginning, each worker will predict only its own digits and will learn about the other digits only through the federated averaging process.
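
Both reported quantities are simple to compute. This hypothetical sketch (plain Python, not the helpers inside `run_websocket_client`) counts how often each digit is predicted and how many predictions are correct. The toy data mimics a model trained only on Bob's digits that predicts 5 for every image: its histogram has a single non-zero bin, and its accuracy equals the share of 5s in the data.

```python
# Hypothetical helpers (not the ones in run_websocket_client): a histogram of
# predicted digits and the accuracy, the two metrics shown in the logs.
from typing import List, Sequence, Tuple


def prediction_histogram(predictions: Sequence[int], num_classes: int = 10) -> List[int]:
    """Count how often each class index was predicted."""
    hist = [0] * num_classes
    for p in predictions:
        hist[p] += 1
    return hist


def accuracy(predictions: Sequence[int], targets: Sequence[int]) -> Tuple[int, int]:
    """Return (number of correct predictions, total number of predictions)."""
    correct = sum(int(p == t) for p, t in zip(predictions, targets))
    return correct, len(targets)


# Toy example: a model that has only seen digits 4-6 and always predicts 5.
targets = [0, 1, 4, 5, 5, 6, 9, 5]
predictions = [5] * len(targets)
print(prediction_histogram(predictions))  # [0, 0, 0, 0, 0, 8, 0, 0, 0, 0]
print(accuracy(predictions, targets))     # (3, 8)
```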

The training is done asynchronously: the scheduler tells each worker to train and does not block waiting for that worker's result before contacting the next worker.
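
The non-blocking dispatch can be illustrated with `asyncio.gather`, the same call the training loop uses. In this minimal sketch, `fake_fit` is a hypothetical stand-in for `rwc.fit_model_on_worker`, and the delays are made up to simulate workers of different speeds. All three calls start before any of them finishes, the workers complete in order of their delays, and `gather` still returns the results in the order the calls were passed. Inside a notebook, which already runs an event loop, use `results = await run_round()` instead of `asyncio.run(...)`.

```python
# Minimal asyncio sketch of non-blocking dispatch. fake_fit is a hypothetical
# stand-in for rwc.fit_model_on_worker; the delays are invented.
import asyncio
from typing import List, Tuple

completion_order: List[str] = []


async def fake_fit(worker_id: str, delay: float) -> Tuple[str, float]:
    print(f"calling fit on worker: {worker_id}")
    await asyncio.sleep(delay)  # yields control to the other workers while "training"
    completion_order.append(worker_id)
    return worker_id, delay


async def run_round() -> List[Tuple[str, float]]:
    # gather runs the coroutines concurrently and returns their results in the
    # order they were passed in, not the order in which they finished.
    return await asyncio.gather(
        fake_fit("alice", 0.15),
        fake_fit("bob", 0.05),
        fake_fit("charlie", 0.10),
    )


results = asyncio.run(run_round())
print(completion_order)                         # bob finishes first (shortest delay)
print([worker_id for worker_id, _ in results])  # ['alice', 'bob', 'charlie']
```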

The training parameters are given in the arguments. 
In each round, every worker trains on `federate_after_n_batches` batches before its model is averaged with the others.
The training batch size and the learning rate are configured there as well. 
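
The effect of `federate_after_n_batches` can be pictured as taking only the first N batches of a worker's data in each round. This is a hypothetical illustration of the batching logic alone (`batches_for_round` is not a function from the tutorial's scripts, and the training step itself is omitted). With `batch_size=32` and `federate_after_n_batches=10`, each worker processes at most 10 × 32 = 320 samples per round.

```python
# Hypothetical sketch of capping a round at max_nr_batches batches; it shows
# only the batching logic, not the actual training step on the worker.
from itertools import islice
from typing import Iterable, Iterator, List


def batches(samples: Iterable[int], batch_size: int) -> Iterator[List[int]]:
    """Yield consecutive lists of up to `batch_size` samples."""
    it = iter(samples)
    while True:
        batch = list(islice(it, batch_size))
        if not batch:
            return
        yield batch


def batches_for_round(samples: Iterable[int], batch_size: int, max_nr_batches: int) -> List[List[int]]:
    """Take at most `max_nr_batches` batches for one training round."""
    return list(islice(batches(samples, batch_size), max_nr_batches))


# Alice's 24754 samples, represented here only by their indices.
round_batches = batches_for_round(range(24754), batch_size=32, max_nr_batches=10)
print(len(round_batches), sum(len(b) for b in round_batches))  # 10 320
```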

In [29]:
print("Federate_after_n_batches: " + str(args.federate_after_n_batches))
print("Batch size: " + str(args.batch_size))
print("Initial learning rate: " + str(args.lr))

Federate_after_n_batches: 10
Batch size: 32
Initial learning rate: 0.1


In [26]:
traced_model = torch.jit.trace(model, data)
learning_rate = args.lr
for epoch in range(1, args.epochs + 1):
    logger.info("Starting epoch %s/%s", epoch, args.epochs)

    results = await asyncio.gather(
        *[
            rwc.fit_model_on_worker(
                worker=worker,
                traced_model=traced_model,
                batch_size=args.batch_size,
                curr_epoch=epoch,
                max_nr_batches=args.federate_after_n_batches,
                lr=learning_rate,
            )
            for worker in worker_instances
        ]
    )
    models = {}
    loss_values = {}

    test_models = epoch % 10 == 1 or epoch == args.epochs
    if test_models:
        rwc.evaluate_models_on_test_data(test_loader, results)

    for worker_id, worker_model, worker_loss in results:
        if worker_model is not None:
            models[worker_id] = worker_model
            loss_values[worker_id] = worker_loss

    traced_model = utils.federated_avg(models)
    if test_models:
        rwc.evaluate_model("Federated model", traced_model, "cpu", test_loader)

    # decay learning rate
    learning_rate = max(0.98 * learning_rate, args.lr * 0.01)

if args.save_model:
    # save the federated model (traced_model), not the untrained initial `model`
    torch.save(traced_model.state_dict(), "mnist_cnn.pt")

INFO:run_websocket_client:Starting epoch 1/100
INFO:run_websocket_client:Training round 1, calling fit on worker: bob, lr = 0.100
INFO:run_websocket_client:Training round 1, calling fit on worker: charlie, lr = 0.100
INFO:run_websocket_client:Training round 1, calling fit on worker: alice, lr = 0.100
INFO:run_websocket_client:Training round: 1, worker: alice, avg_loss: tensor(0.6313, grad_fn=<MeanBackward1>)
INFO:run_websocket_client:Training round: 1, worker: bob, avg_loss: tensor(0.9031, grad_fn=<MeanBackward1>)
INFO:run_websocket_client:Training round: 1, worker: charlie, avg_loss: tensor(2.0880, grad_fn=<MeanBackward1>)
INFO:run_websocket_client:Prediction hist: [ 8860  25556  0  25584  0  0  0  0  0  0]
INFO:run_websocket_client:alice: Test set: Average loss: 0.0251, Accuracy: 16598/60000 (28)
INFO:run_websocket_client:Prediction hist: [ 0  0  0  0  31  59969  0  0  0  0]
INFO:run_websocket_client:bob: Test set: Average loss: 0.0369, Accuracy: 5421/60000 (9)
...

INFO:run_websocket_client:Prediction hist: [ 6382  7127  5535  5841  3675  4677  5903  6846  5875  8139]
INFO:run_websocket_client:Federated model: Test set: Average loss: 0.0033, Accuracy: 52377/60000 (87)
INFO:run_websocket_client:Starting epoch 12/100
INFO:run_websocket_client:Training round 12, calling fit on worker: charlie, lr = 0.080
INFO:run_websocket_client:Training round 12, calling fit on worker: alice, lr = 0.080
INFO:run_websocket_client:Training round 12, calling fit on worker: bob, lr = 0.080
INFO:run_websocket_client:Training round: 12, worker: charlie, avg_loss: tensor(0.2108, grad_fn=<MeanBackward1>)
INFO:run_websocket_client:Training round: 12, worker: bob, avg_loss: tensor(0.2715, grad_fn=<MeanBackward1>)
INFO:run_websocket_client:Training round: 12, worker: alice, avg_loss: tensor(0.0978, grad_fn=<MeanBackward1>)
INFO:run_websocket_client:Starting epoch 13/100
INFO:run_websocket_client:Training round 13, calling fit on worker: charlie, lr = 0.078
...

INFO:run_websocket_client:Training round 23, calling fit on worker: alice, lr = 0.064
INFO:run_websocket_client:Training round: 23, worker: charlie, avg_loss: tensor(0.0431, grad_fn=<MeanBackward1>)
INFO:run_websocket_client:Training round: 23, worker: alice, avg_loss: tensor(0.1731, grad_fn=<MeanBackward1>)
INFO:run_websocket_client:Training round: 23, worker: bob, avg_loss: tensor(0.0113, grad_fn=<MeanBackward1>)
INFO:run_websocket_client:Starting epoch 24/100
INFO:run_websocket_client:Training round 24, calling fit on worker: bob, lr = 0.063
INFO:run_websocket_client:Training round 24, calling fit on worker: charlie, lr = 0.063
INFO:run_websocket_client:Training round 24, calling fit on worker: alice, lr = 0.063
INFO:run_websocket_client:Training round: 24, worker: alice, avg_loss: tensor(0.0104, grad_fn=<MeanBackward1>)
INFO:run_websocket_client:Training round: 24, worker: charlie, avg_loss: tensor(0.0765, grad_fn=<MeanBackward1>)
...

INFO:run_websocket_client:Starting epoch 35/100
INFO:run_websocket_client:Training round 35, calling fit on worker: bob, lr = 0.050
INFO:run_websocket_client:Training round 35, calling fit on worker: charlie, lr = 0.050
INFO:run_websocket_client:Training round 35, calling fit on worker: alice, lr = 0.050
INFO:run_websocket_client:Training round: 35, worker: bob, avg_loss: tensor(0.0292, grad_fn=<MeanBackward1>)
INFO:run_websocket_client:Training round: 35, worker: alice, avg_loss: tensor(0.0105, grad_fn=<MeanBackward1>)
INFO:run_websocket_client:Training round: 35, worker: charlie, avg_loss: tensor(0.0637, grad_fn=<MeanBackward1>)
INFO:run_websocket_client:Starting epoch 36/100
INFO:run_websocket_client:Training round 36, calling fit on worker: alice, lr = 0.049
INFO:run_websocket_client:Training round 36, calling fit on worker: charlie, lr = 0.049
INFO:run_websocket_client:Training round 36, calling fit on worker: bob, lr = 0.049
...

INFO:run_websocket_client:Training round: 46, worker: charlie, avg_loss: tensor(0.0745, grad_fn=<MeanBackward1>)
INFO:run_websocket_client:Training round: 46, worker: bob, avg_loss: tensor(0.0057, grad_fn=<MeanBackward1>)
INFO:run_websocket_client:Starting epoch 47/100
INFO:run_websocket_client:Training round 47, calling fit on worker: alice, lr = 0.039
INFO:run_websocket_client:Training round 47, calling fit on worker: charlie, lr = 0.039
INFO:run_websocket_client:Training round 47, calling fit on worker: bob, lr = 0.039
INFO:run_websocket_client:Training round: 47, worker: bob, avg_loss: tensor(0.0242, grad_fn=<MeanBackward1>)
INFO:run_websocket_client:Training round: 47, worker: alice, avg_loss: tensor(0.0284, grad_fn=<MeanBackward1>)
INFO:run_websocket_client:Training round: 47, worker: charlie, avg_loss: tensor(0.0183, grad_fn=<MeanBackward1>)
INFO:run_websocket_client:Starting epoch 48/100
INFO:run_websocket_client:Training round 48, calling fit on worker: charlie, lr = 0.039
...

INFO:run_websocket_client:Training round 58, calling fit on worker: alice, lr = 0.032
INFO:run_websocket_client:Training round: 58, worker: charlie, avg_loss: tensor(0.0554, grad_fn=<MeanBackward1>)
INFO:run_websocket_client:Training round: 58, worker: alice, avg_loss: tensor(0.0014, grad_fn=<MeanBackward1>)
INFO:run_websocket_client:Training round: 58, worker: bob, avg_loss: tensor(0.0020, grad_fn=<MeanBackward1>)
INFO:run_websocket_client:Starting epoch 59/100
INFO:run_websocket_client:Training round 59, calling fit on worker: charlie, lr = 0.031
INFO:run_websocket_client:Training round 59, calling fit on worker: alice, lr = 0.031
INFO:run_websocket_client:Training round 59, calling fit on worker: bob, lr = 0.031
INFO:run_websocket_client:Training round: 59, worker: charlie, avg_loss: tensor(0.1098, grad_fn=<MeanBackward1>)
INFO:run_websocket_client:Training round: 59, worker: alice, avg_loss: tensor(0.0215, grad_fn=<MeanBackward1>)
...

INFO:run_websocket_client:Starting epoch 70/100
INFO:run_websocket_client:Training round 70, calling fit on worker: bob, lr = 0.025
INFO:run_websocket_client:Training round 70, calling fit on worker: alice, lr = 0.025
INFO:run_websocket_client:Training round 70, calling fit on worker: charlie, lr = 0.025
INFO:run_websocket_client:Training round: 70, worker: alice, avg_loss: tensor(0.0126, grad_fn=<MeanBackward1>)
INFO:run_websocket_client:Training round: 70, worker: bob, avg_loss: tensor(0.0031, grad_fn=<MeanBackward1>)
INFO:run_websocket_client:Training round: 70, worker: charlie, avg_loss: tensor(0.0857, grad_fn=<MeanBackward1>)
INFO:run_websocket_client:Starting epoch 71/100
INFO:run_websocket_client:Training round 71, calling fit on worker: alice, lr = 0.024
INFO:run_websocket_client:Training round 71, calling fit on worker: bob, lr = 0.024
INFO:run_websocket_client:Training round 71, calling fit on worker: charlie, lr = 0.024
...

INFO:run_websocket_client:Training round: 81, worker: alice, avg_loss: tensor(0.0207, grad_fn=<MeanBackward1>)
INFO:run_websocket_client:Training round: 81, worker: charlie, avg_loss: tensor(0.0513, grad_fn=<MeanBackward1>)
INFO:run_websocket_client:Prediction hist: [ 6500  7929  7071  7468  6327  4534  5677  5781  3662  5051]
INFO:run_websocket_client:alice: Test set: Average loss: 0.0021, Accuracy: 54475/60000 (91)
INFO:run_websocket_client:Prediction hist: [ 5697  6799  5901  5220  7911  6988  6587  6287  4839  3771]
INFO:run_websocket_client:bob: Test set: Average loss: 0.0019, Accuracy: 54706/60000 (91)
INFO:run_websocket_client:Prediction hist: [ 5035  6256  4536  4217  3111  3601  5646  6761  9297  11540]
INFO:run_websocket_client:charlie: Test set: Average loss: 0.0036, Accuracy: 49989/60000 (83)
INFO:run_websocket_client:Prediction hist: [ 5909  6916  5934  5934  5919  5365  6005  6373  5688  5957]
INFO:run_websocket_client:Federated model: Test set: Average loss: 0.0008, Accu

INFO:run_websocket_client:Starting epoch 92/100
INFO:run_websocket_client:Training round 92, calling fit on worker: charlie, lr = 0.016
INFO:run_websocket_client:Training round 92, calling fit on worker: alice, lr = 0.016
INFO:run_websocket_client:Training round 92, calling fit on worker: bob, lr = 0.016
INFO:run_websocket_client:Training round: 92, worker: charlie, avg_loss: tensor(0.0276, grad_fn=<MeanBackward1>)
INFO:run_websocket_client:Training round: 92, worker: alice, avg_loss: tensor(0.0348, grad_fn=<MeanBackward1>)
INFO:run_websocket_client:Training round: 92, worker: bob, avg_loss: tensor(0.0867, grad_fn=<MeanBackward1>)
INFO:run_websocket_client:Starting epoch 93/100
INFO:run_websocket_client:Training round 93, calling fit on worker: bob, lr = 0.016
INFO:run_websocket_client:Training round 93, calling fit on worker: charlie, lr = 0.016
INFO:run_websocket_client:Training round 93, calling fit on worker: alice, lr = 0.016
...
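
The `lr = ...` values in the log follow from the decay rule at the end of the training loop, `learning_rate = max(0.98 * learning_rate, args.lr * 0.01)`, so the learning rate used in round e is max(0.1 · 0.98^(e-1), 0.001). The short sketch below replays that rule with the notebook's arguments and reproduces logged values such as 0.080 in round 12 and 0.016 in round 92. The floor of 0.001 is never reached within 100 rounds, since 0.1 · 0.98^99 ≈ 0.0135.

```python
# Replay the learning-rate decay from the training loop:
#   learning_rate = max(0.98 * learning_rate, args.lr * 0.01), starting at args.lr
initial_lr = 0.1  # args.lr
epochs = 100      # args.epochs

lr_per_epoch = {}
learning_rate = initial_lr
for epoch in range(1, epochs + 1):
    lr_per_epoch[epoch] = learning_rate  # value sent to every worker in this round
    learning_rate = max(0.98 * learning_rate, initial_lr * 0.01)

for epoch in (1, 12, 35, 92):
    print(epoch, f"{lr_per_epoch[epoch]:.3f}")
# 1 0.100
# 12 0.080
# 35 0.050
# 92 0.016
```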

After 100 rounds of training, we achieve an accuracy above 95% on the full evaluation set (all 60,000 images loaded by `test_loader`). 
This is impressive, given that no worker has access to more than four of the ten digits.

# Congratulations!!! - Time to Join the Community!

Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the movement toward privacy preserving, decentralized ownership of AI and the AI supply chain (data), you can do so in the following ways!

### Star PySyft on GitHub

The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we're building.

- [Star PySyft](https://github.com/OpenMined/PySyft)

### Join our Slack!

The best way to keep up to date on the latest advancements is to join our community! You can do so by filling out the form at [http://slack.openmined.org](http://slack.openmined.org)

### Join a Code Project!

The best way to contribute to our community is to become a code contributor! At any time you can go to the PySyft GitHub Issues page and filter for "Projects". This will show you all the top-level tickets, giving an overview of the projects you can join! If you don't want to join a project but would like to do a bit of coding, you can also look for more "one off" mini-projects by searching for GitHub issues marked "good first issue".

- [PySyft Projects](https://github.com/OpenMined/PySyft/issues?q=is%3Aopen+is%3Aissue+label%3AProject)
- [Good First Issue Tickets](https://github.com/OpenMined/PySyft/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)

### Donate

If you don't have time to contribute to our codebase, but would still like to lend support, you can also become a Backer on our Open Collective. All donations go toward our web hosting and other community expenses such as hackathons and meetups!

[OpenMined's Open Collective Page](https://opencollective.com/openmined)