# Tutorial: Asynchronous federated learning on MNIST

This notebook walks through the steps to run federated learning via websocket workers asynchronously. We will use federated averaging to combine the remotely trained models.

Authors:
- Silvia - GitHub [@midokura-silvia](https://github.com/midokura-silvia)

In [1]:
%load_ext autoreload
%autoreload 2

import inspect

## Federated Learning setup

A federated learning setup needs different participants:

* _Workers_ that own datasets.

* An entity that knows the workers and the name of the dataset held by each worker. We'll call this a _scheduler_.

Each worker is represented by two parts, a proxy local to the scheduler (websocket client worker) and the remote instance that holds the data and performs the computations. The remote part is called a websocket server worker.

#### Preparation: Start the websocket workers
First, we need to create the remote workers. To do this, run the following in a terminal (this cannot be done from the notebook):

```bash
python start_websocket_servers.py
```

#### What's going on?

The script will instantiate three workers, Alice, Bob and Charlie, and prepare their local data. 
Each worker is set up to have a subset of the MNIST training dataset. 
Alice holds all images corresponding to the digits 0-3, 
Bob holds all images corresponding to the digits 4-6 and 
Charlie holds all images corresponding to the digits 7-9.

| Worker      | Digits in local dataset | Number of samples |
| ----------- | ----------------------- | ----------------- |
| Alice       | 0-3                     | 24754             |
| Bob         | 4-6                     | 17181             |
| Charlie     | 7-9                     | 18065             |
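
The split in the table can be sketched in a few lines of plain Python. This is only an illustration of the idea, not the code in `start_websocket_servers.py`; `WORKER_DIGITS`, `partition_by_digit` and the toy label list are hypothetical names introduced here.

```python
from collections import defaultdict

# Digit ranges per worker, as listed in the table above.
WORKER_DIGITS = {
    "alice": range(0, 4),    # digits 0-3
    "bob": range(4, 7),      # digits 4-6
    "charlie": range(7, 10), # digits 7-9
}


def partition_by_digit(labels):
    """Map each worker id to the indices of the samples it would hold."""
    digit_to_worker = {
        digit: worker
        for worker, digits in WORKER_DIGITS.items()
        for digit in digits
    }
    indices = defaultdict(list)
    for idx, label in enumerate(labels):
        indices[digit_to_worker[int(label)]].append(idx)
    return dict(indices)


# Toy labels standing in for the MNIST training targets.
toy_labels = [0, 7, 4, 3, 9, 5, 1, 6, 8, 2]
parts = partition_by_digit(toy_labels)
print({worker: len(idx) for worker, idx in parts.items()})
```

Because the digit ranges do not overlap, every sample lands on exactly one worker, which is why the three sample counts in the table add up to the 60000 MNIST training images.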


In [2]:
# uncomment the following to see the code of the function that starts a worker
# import run_websocket_server

# print(inspect.getsource(run_websocket_server.start_websocket_server_worker))

We first need to perform the imports and set up some arguments and variables.

In [3]:
import sys
import asyncio
import syft as sy
from syft.workers import WebsocketClientWorker
import torch
from torchvision import datasets, transforms

from syft.frameworks.torch.federated import utils

In [4]:
import run_websocket_client as rwc
hook = sy.TorchHook(torch)

In [5]:
args = rwc.define_and_get_arguments(args=[])
use_cuda = args.cuda and torch.cuda.is_available()
torch.manual_seed(args.seed)
device = torch.device("cuda" if use_cuda else "cpu")
print(args)

Namespace(batch_size=32, cuda=False, federate_after_n_batches=10, lr=0.1, save_model=False, seed=1, test_batch_size=128, training_rounds=100, verbose=False)


Now let's instantiate the websocket client workers, our local proxies to the remote workers.
Note that **this step will fail if the websocket server workers are not running**.

In [6]:
kwargs_websocket = {"host": "localhost", "hook": hook, "verbose": args.verbose}
alice = WebsocketClientWorker(id="alice", port=8777, **kwargs_websocket)
bob = WebsocketClientWorker(id="bob", port=8778, **kwargs_websocket)
charlie = WebsocketClientWorker(id="charlie", port=8779, **kwargs_websocket)

worker_instances = [alice, bob, charlie]
print(worker_instances)


[<WebsocketClientWorker id:alice #objects local:0 #objects remote: 350>, <WebsocketClientWorker id:bob #objects local:0 #objects remote: 346>, <WebsocketClientWorker id:charlie #objects local:0 #objects remote: 349>]


### Setting up a scheduler

We'll use this notebook as the scheduler. For this we'll need to:

* Have a model
* Have a loss function
* Define an optimizer
* Define hyper-parameters

#### Model
Let's instantiate the machine learning model. It is a small neural network with two convolutional and two fully connected layers. 
It uses ReLU activations and max pooling.

In [7]:
print(inspect.getsource(rwc.Net))

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4 * 4 * 50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4 * 4 * 50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)



In [8]:
model = rwc.Net().to(device)
print(model)

Net(
  (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(20, 50, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=800, out_features=500, bias=True)
  (fc2): Linear(in_features=500, out_features=10, bias=True)
)
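
The `in_features=800` of `fc1` (that is, `4 * 4 * 50`) follows from the 28x28 MNIST image size and the layer parameters above. The arithmetic can be checked with a short sketch. The helper `out_size` is introduced here for illustration, and it assumes no padding and no dilation, which matches the defaults of `nn.Conv2d` and `F.max_pool2d`.

```python
def out_size(size, kernel, stride):
    """Spatial output size of a convolution or pooling layer without padding."""
    return (size - kernel) // stride + 1


size = 28                      # MNIST images are 28x28
size = out_size(size, 5, 1)    # conv1: 5x5 kernel, stride 1
size = out_size(size, 2, 2)    # max pooling: 2x2 window, stride 2
size = out_size(size, 5, 1)    # conv2: 5x5 kernel, stride 1
size = out_size(size, 2, 2)    # max pooling: 2x2 window, stride 2

channels = 50                  # output channels of conv2
flattened = size * size * channels
print(size, flattened)
```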


Before starting the training, let's load the MNIST test data and configure the logging.

In [9]:
test_loader = torch.utils.data.DataLoader(
        datasets.MNIST(
            "../data",
            train=False,
            transform=transforms.Compose(
                [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
            ),
        ),
        batch_size=args.test_batch_size,
        shuffle=False,
        drop_last=False,
    )
# Fetch one batch; its data tensor is used later as the example input for tracing the model.
data, target = next(iter(test_loader))
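
The two constants passed to `transforms.Normalize` are the mean and standard deviation used to standardize the pixel values. As a rough sketch of what the transform pipeline does to a single pixel (the function name `normalize_pixel` is hypothetical, and real tensors are processed per channel rather than per scalar):

```python
MEAN, STD = 0.1307, 0.3081  # the constants used in the cell above


def normalize_pixel(raw_value):
    """Apply the ToTensor scaling and the Normalize step to one 8-bit pixel."""
    scaled = raw_value / 255.0       # ToTensor maps 0..255 to 0.0..1.0
    return (scaled - MEAN) / STD     # Normalize subtracts the mean, divides by the std


black = normalize_pixel(0)
white = normalize_pixel(255)
print(black, white)
```

Black background pixels end up slightly below zero and white stroke pixels well above it, so the inputs are roughly centered around zero.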

In [10]:
# configure the logging of the training process.
import logging
logger = logging.getLogger("run_websocket_client")
if not logger.handlers:
    FORMAT = "%(asctime)s - %(message)s"
    DATE_FMT = "%H:%M:%S"
    formatter = logging.Formatter(FORMAT, DATE_FMT)
    handler = logging.StreamHandler()
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    logger.propagate = False
LOG_LEVEL = logging.DEBUG
logger.setLevel(LOG_LEVEL)

## Let's start the training

Now we are ready to start the federated training. We will perform training over a given number of batches separately on each worker and then calculate the federated average of the resulting model.
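
To make the averaging step concrete, here is a minimal sketch of an unweighted, element-wise mean over the parameters returned by the workers. It uses plain Python lists instead of tensors, and `average_parameters` and the toy values are hypothetical. Whether `utils.federated_avg` weights workers by their number of samples is not shown in this notebook; the sketch assumes it does not.

```python
def average_parameters(worker_params):
    """Element-wise unweighted mean over the workers' parameter dicts."""
    worker_ids = list(worker_params)
    n_workers = len(worker_ids)
    reference = worker_params[worker_ids[0]]
    return {
        name: [
            sum(worker_params[w][name][i] for w in worker_ids) / n_workers
            for i in range(len(values))
        ]
        for name, values in reference.items()
    }


# Toy "models": one weight vector and one bias per worker.
toy_params = {
    "alice": {"fc.weight": [1.0, 2.0], "fc.bias": [0.0]},
    "bob": {"fc.weight": [3.0, 4.0], "fc.bias": [3.0]},
    "charlie": {"fc.weight": [5.0, 6.0], "fc.bias": [6.0]},
}
averaged = average_parameters(toy_params)
print(averaged)
```

Every parameter of the averaged model is the mean of the corresponding parameters of the worker models, so knowledge about digits that a worker never saw reaches it through the other workers' weights.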

We evaluate the models returned by the workers, and the model obtained by federated averaging, in the first round, every 10th round after that (rounds 11, 21, ...), and in the final round.

The performance is reported both as accuracy (the ratio of correct predictions) and as a histogram of the predicted digits. The histogram is of interest because each worker only owns a subset of the digits. Therefore, in the beginning each worker will only predict its own digits and learn about the other digits only through the federated averaging process.
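
Both metrics are easy to compute from a list of predictions. The following is a sketch with hypothetical names (`accuracy_and_histogram`) and toy data, not the evaluation code in `run_websocket_client`:

```python
from collections import Counter


def accuracy_and_histogram(predictions, targets, num_classes=10):
    """Return the ratio of correct predictions and the count of each predicted class."""
    correct = sum(int(p == t) for p, t in zip(predictions, targets))
    counts = Counter(predictions)
    hist = [counts.get(c, 0) for c in range(num_classes)]
    return correct / len(targets), hist


# A toy model that only ever predicts the digits 0-3, like Alice's model early on.
targets = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
predictions = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1]
acc, hist = accuracy_and_histogram(predictions, targets)
print(acc, hist)
```

A histogram with zeros for most digits, as in this toy case, is the signature of a model that has only seen part of the label space.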

The training is done asynchronously. This means that the scheduler just tells the workers to train and does not block waiting for one worker's training result before talking to the next worker.
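
The pattern can be illustrated with the standard library alone. In this sketch `fake_fit` is a hypothetical stand-in for a remote training call that simply sleeps, and `asyncio.run` is used because a plain script, unlike a notebook cell, cannot use `await` at the top level.

```python
import asyncio
import time


async def fake_fit(worker_id, delay):
    """Stand-in for a remote training call; sleeps instead of training."""
    await asyncio.sleep(delay)
    return worker_id, delay


async def run_round():
    # All three calls are scheduled together; none waits for another to finish.
    return await asyncio.gather(
        fake_fit("alice", 0.3),
        fake_fit("bob", 0.1),
        fake_fit("charlie", 0.2),
    )


start = time.perf_counter()
results = asyncio.run(run_round())
elapsed = time.perf_counter() - start
print(results)
print(f"elapsed: {elapsed:.2f}s")
```

`asyncio.gather` returns the results in the order the calls were passed in, even if the workers finish in a different order, and the round takes about as long as the slowest worker rather than the sum of all three.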

The training parameters are taken from the arguments. 
Each worker trains on `federate_after_n_batches` batches per round before the models are averaged.
The training batch size and the initial learning rate are configured the same way. 

In [11]:
print("Federate_after_n_batches: " + str(args.federate_after_n_batches))
print("Batch size: " + str(args.batch_size))
print("Initial learning rate: " + str(args.lr))

Federate_after_n_batches: 10
Batch size: 32
Initial learning rate: 0.1
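
The learning rate printed above is only the starting value. The training loop in the next cell multiplies it by 0.98 after every round and never lets it fall below 1% of the initial value. A small sketch of that schedule, using the hypothetical helper `lr_for_round`, shows how the rate evolves:

```python
def lr_for_round(curr_round, initial_lr=0.1, decay=0.98, floor_factor=0.01):
    """Learning rate used in a given round (rounds are counted from 1)."""
    lr = initial_lr
    for _ in range(curr_round - 1):
        # Same update rule as in the training loop: decay, but keep a lower bound.
        lr = max(decay * lr, initial_lr * floor_factor)
    return lr


for training_round in (1, 14, 28):
    print(training_round, f"{lr_for_round(training_round):.3f}")
```

For rounds 14 and 28 this gives 0.077 and 0.058, which matches the `lr` values in the training log.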


In [12]:
traced_model = torch.jit.trace(model, data)
learning_rate = args.lr
for curr_round in range(1, args.training_rounds + 1):
    logger.info("Starting training round %s/%s", curr_round, args.training_rounds)

    results = await asyncio.gather(
        *[
            rwc.fit_model_on_worker(
                worker=worker,
                traced_model=traced_model,
                batch_size=args.batch_size,
                curr_round=curr_round,
                max_nr_batches=args.federate_after_n_batches,
                lr=learning_rate,
            )
            for worker in worker_instances
        ]
    )
    models = {}
    loss_values = {}

    test_models = curr_round % 10 == 1 or curr_round == args.training_rounds
    if test_models:
        rwc.evaluate_models_on_test_data(test_loader, results)

    for worker_id, worker_model, worker_loss in results:
        if worker_model is not None:
            models[worker_id] = worker_model
            loss_values[worker_id] = worker_loss

    traced_model = utils.federated_avg(models)
    if test_models:
        rwc.evaluate_model("Federated model", traced_model, "cpu", test_loader)

    # decay learning rate
    learning_rate = max(0.98 * learning_rate, args.lr * 0.01)

if args.save_model:
    # Save the federated model; traced_model holds the averaged weights from the last round.
    torch.save(traced_model.state_dict(), "mnist_cnn.pt")

09:42:48 - Starting training round 1/100
09:42:48 - Training round 1, calling fit on worker: charlie, lr = 0.100
09:42:48 - Training round 1, calling fit on worker: alice, lr = 0.100
09:42:49 - Training round 1, calling fit on worker: bob, lr = 0.100
09:42:51 - Training round: 1, worker: bob, avg_loss: tensor(2.7445, grad_fn=<MeanBackward1>)
09:42:52 - Training round: 1, worker: charlie, avg_loss: tensor(1.5985, grad_fn=<MeanBackward1>)
09:42:54 - Training round: 1, worker: alice, avg_loss: tensor(1.1139, grad_fn=<MeanBackward1>)
09:42:58 - Prediction hist: [ 8327  1331  240  102  0  0  0  0  0  0]
09:42:58 - alice: Test set: Average loss: 0.0269, Accuracy: 2192/10000 (22)
09:43:00 - Prediction hist: [ 0  0  0  0  5097  0  4903  0  0  0]
09:43:00 - bob: Test set: Average loss: 0.0182, Accuracy: 1779/10000 (18)
09:43:02 - Prediction hist: [ 0  0  0  0  0  0  0  0  0  10000]
09:43:02 - charlie: Test set: Average loss: 0.0200, Accuracy: 1009/10000 (10)
09:43:05 - Prediction hist: [ 4556  

09:44:18 - Starting training round 14/100
09:44:18 - Training round 14, calling fit on worker: alice, lr = 0.077
09:44:18 - Training round 14, calling fit on worker: bob, lr = 0.077
09:44:19 - Training round 14, calling fit on worker: charlie, lr = 0.077
09:44:20 - Training round: 14, worker: charlie, avg_loss: tensor(0.1591, grad_fn=<MeanBackward1>)
09:44:21 - Training round: 14, worker: bob, avg_loss: tensor(0.0101, grad_fn=<MeanBackward1>)
09:44:23 - Training round: 14, worker: alice, avg_loss: tensor(0.0155, grad_fn=<MeanBackward1>)
09:44:24 - Starting training round 15/100
09:44:24 - Training round 15, calling fit on worker: bob, lr = 0.075
09:44:24 - Training round 15, calling fit on worker: charlie, lr = 0.075
09:44:24 - Training round 15, calling fit on worker: alice, lr = 0.075
09:44:25 - Training round: 15, worker: charlie, avg_loss: tensor(0.2015, grad_fn=<MeanBackward1>)
09:44:27 - Training round: 15, worker: alice, avg_loss: tensor(0.1064, grad_fn=<MeanBackward1>)
09:44:28

09:45:38 - Starting training round 28/100
09:45:38 - Training round 28, calling fit on worker: alice, lr = 0.058
09:45:38 - Training round 28, calling fit on worker: bob, lr = 0.058
09:45:38 - Training round 28, calling fit on worker: charlie, lr = 0.058
09:45:40 - Training round: 28, worker: bob, avg_loss: tensor(0.0012, grad_fn=<MeanBackward1>)
09:45:41 - Training round: 28, worker: alice, avg_loss: tensor(0.1312, grad_fn=<MeanBackward1>)
09:45:42 - Training round: 28, worker: charlie, avg_loss: tensor(0.0675, grad_fn=<MeanBackward1>)
09:45:43 - Starting training round 29/100
09:45:43 - Training round 29, calling fit on worker: charlie, lr = 0.057
09:45:44 - Training round 29, calling fit on worker: bob, lr = 0.057
09:45:44 - Training round 29, calling fit on worker: alice, lr = 0.057
09:45:45 - Training round: 29, worker: bob, avg_loss: tensor(0.0881, grad_fn=<MeanBackward1>)
09:45:46 - Training round: 29, worker: alice, avg_loss: tensor(0.0110, grad_fn=<MeanBackward1>)
09:45:48 - T

09:47:01 - Prediction hist: [ 1156  1240  1041  2305  902  193  871  853  605  834]
09:47:01 - alice: Test set: Average loss: 0.0041, Accuracy: 8259/10000 (83)
09:47:03 - Prediction hist: [ 783  1084  1003  300  1506  2377  1114  957  567  309]
09:47:03 - bob: Test set: Average loss: 0.0052, Accuracy: 7711/10000 (77)
09:47:05 - Prediction hist: [ 423  859  639  585  342  164  812  1255  3183  1738]
09:47:05 - charlie: Test set: Average loss: 0.0070, Accuracy: 6709/10000 (67)
09:47:07 - Prediction hist: [ 978  1123  1007  1014  967  906  971  1034  1006  994]
09:47:07 - Federated model: Test set: Average loss: 0.0010, Accuracy: 9619/10000 (96)
09:47:07 - Starting training round 42/100
09:47:08 - Training round 42, calling fit on worker: alice, lr = 0.044
09:47:08 - Training round 42, calling fit on worker: bob, lr = 0.044
09:47:08 - Training round 42, calling fit on worker: charlie, lr = 0.044
09:47:10 - Training round: 42, worker: charlie, avg_loss: tensor(0.0219, grad_fn=<MeanBackward

09:48:23 - Training round: 54, worker: alice, avg_loss: tensor(0.2276, grad_fn=<MeanBackward1>)
09:48:24 - Starting training round 55/100
09:48:24 - Training round 55, calling fit on worker: bob, lr = 0.034
09:48:24 - Training round 55, calling fit on worker: alice, lr = 0.034
09:48:25 - Training round 55, calling fit on worker: charlie, lr = 0.034
09:48:26 - Training round: 55, worker: alice, avg_loss: tensor(0.0086, grad_fn=<MeanBackward1>)
09:48:27 - Training round: 55, worker: charlie, avg_loss: tensor(0.0086, grad_fn=<MeanBackward1>)
09:48:28 - Training round: 55, worker: bob, avg_loss: tensor(0.0099, grad_fn=<MeanBackward1>)
09:48:29 - Starting training round 56/100
09:48:29 - Training round 56, calling fit on worker: alice, lr = 0.033
09:48:29 - Training round 56, calling fit on worker: bob, lr = 0.033
09:48:29 - Training round 56, calling fit on worker: charlie, lr = 0.033
09:48:31 - Training round: 56, worker: bob, avg_loss: tensor(0.2595, grad_fn=<MeanBackward1>)
09:48:32 - T

09:49:44 - Training round: 68, worker: charlie, avg_loss: tensor(0.0440, grad_fn=<MeanBackward1>)
09:49:45 - Starting training round 69/100
09:49:45 - Training round 69, calling fit on worker: alice, lr = 0.025
09:49:45 - Training round 69, calling fit on worker: charlie, lr = 0.025
09:49:45 - Training round 69, calling fit on worker: bob, lr = 0.025
09:49:47 - Training round: 69, worker: charlie, avg_loss: tensor(0.1417, grad_fn=<MeanBackward1>)
09:49:48 - Training round: 69, worker: bob, avg_loss: tensor(0.0074, grad_fn=<MeanBackward1>)
09:49:49 - Training round: 69, worker: alice, avg_loss: tensor(0.1743, grad_fn=<MeanBackward1>)
09:49:50 - Starting training round 70/100
09:49:50 - Training round 70, calling fit on worker: charlie, lr = 0.025
09:49:50 - Training round 70, calling fit on worker: alice, lr = 0.025
09:49:50 - Training round 70, calling fit on worker: bob, lr = 0.025
09:49:52 - Training round: 70, worker: bob, avg_loss: tensor(0.0284, grad_fn=<MeanBackward1>)
09:49:53 -

09:51:08 - charlie: Test set: Average loss: 0.0039, Accuracy: 8232/10000 (82)
09:51:10 - Prediction hist: [ 994  1133  1046  1021  972  898  942  1028  966  1000]
09:51:10 - Federated model: Test set: Average loss: 0.0007, Accuracy: 9723/10000 (97)
09:51:10 - Starting training round 82/100
09:51:10 - Training round 82, calling fit on worker: charlie, lr = 0.019
09:51:10 - Training round 82, calling fit on worker: bob, lr = 0.019
09:51:10 - Training round 82, calling fit on worker: alice, lr = 0.019
09:51:12 - Training round: 82, worker: charlie, avg_loss: tensor(0.0246, grad_fn=<MeanBackward1>)
09:51:13 - Training round: 82, worker: bob, avg_loss: tensor(0.0089, grad_fn=<MeanBackward1>)
09:51:15 - Training round: 82, worker: alice, avg_loss: tensor(0.0548, grad_fn=<MeanBackward1>)
09:51:16 - Starting training round 83/100
09:51:16 - Training round 83, calling fit on worker: bob, lr = 0.019
09:51:16 - Training round 83, calling fit on worker: alice, lr = 0.019
09:51:16 - Training round 

09:52:28 - Training round: 95, worker: charlie, avg_loss: tensor(0.0128, grad_fn=<MeanBackward1>)
09:52:30 - Training round: 95, worker: bob, avg_loss: tensor(0.0417, grad_fn=<MeanBackward1>)
09:52:31 - Training round: 95, worker: alice, avg_loss: tensor(0.0148, grad_fn=<MeanBackward1>)
09:52:32 - Starting training round 96/100
09:52:32 - Training round 96, calling fit on worker: bob, lr = 0.015
09:52:32 - Training round 96, calling fit on worker: charlie, lr = 0.015
09:52:32 - Training round 96, calling fit on worker: alice, lr = 0.015
09:52:33 - Training round: 96, worker: charlie, avg_loss: tensor(0.0039, grad_fn=<MeanBackward1>)
09:52:35 - Training round: 96, worker: bob, avg_loss: tensor(0.0060, grad_fn=<MeanBackward1>)
09:52:36 - Training round: 96, worker: alice, avg_loss: tensor(0.0096, grad_fn=<MeanBackward1>)
09:52:38 - Starting training round 97/100
09:52:38 - Training round 97, calling fit on worker: charlie, lr = 0.014
09:52:38 - Training round 97, calling fit on worker: b

After 100 rounds of training we achieve an accuracy above 95% on the entire test dataset. 
This is impressive, given that no worker has access to more than 4 digits.

# Congratulations!!! - Time to Join the Community!

Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the movement toward privacy-preserving, decentralized ownership of AI and the AI supply chain (data), you can do so in the following ways!

### Star PySyft on GitHub

The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we're building.

- [Star PySyft](https://github.com/OpenMined/PySyft)

### Join our Slack!

The best way to keep up to date on the latest advancements is to join our community! You can do so by filling out the form at [http://slack.openmined.org](http://slack.openmined.org)

### Join a Code Project!

The best way to contribute to our community is to become a code contributor! At any time you can go to the PySyft GitHub Issues page and filter for "Projects". This will show you all the top-level tickets, giving an overview of the projects you can join! If you don't want to join a project but would like to do a bit of coding, you can also look for "one-off" mini-projects by searching for GitHub issues marked "good first issue".

- [PySyft Projects](https://github.com/OpenMined/PySyft/issues?q=is%3Aopen+is%3Aissue+label%3AProject)
- [Good First Issue Tickets](https://github.com/OpenMined/PySyft/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)

### Donate

If you don't have time to contribute to our codebase, but would still like to lend support, you can also become a Backer on our Open Collective. All donations go toward our web hosting and other community expenses such as hackathons and meetups!

[OpenMined's Open Collective Page](https://opencollective.com/openmined)