# Tutorial: Asynchronous federated learning on MNIST

This notebook walks through the steps to run federated learning asynchronously with websocket workers, using [TrainConfig](https://github.com/OpenMined/PySyft/blob/dev/examples/tutorials/advanced/Federated%20Learning%20with%20TrainConfig/Introduction%20to%20TrainConfig.ipynb). We will use federated averaging to combine the remotely trained models.

Authors:
- Silvia - GitHub [@midokura-silvia](https://github.com/midokura-silvia)
- Marianne Monteiro - Twitter [@hereismari](https://twitter.com/hereismari) - GitHub [@mari-linhares](https://github.com/mari-linhares)

In [1]:
%load_ext autoreload
%autoreload 2

import inspect

## Federated Learning setup

For a Federated Learning setup with TrainConfig we need different participants:

* _Workers_ that own datasets.

* An entity that knows the workers and the dataset name that lives in each worker. We'll call this a _scheduler_.

Each worker is represented by two parts, a proxy local to the scheduler (websocket client worker) and the remote instance that holds the data and performs the computations. The remote part is called a websocket server worker.

## Preparation: Start the websocket workers
First, we need to create the remote workers. To do this, run the following in a terminal (this is not possible from the notebook):

```bash
python start_websocket_servers.py
```

#### What's going on?

The script will instantiate three workers, Alice, Bob and Charlie and prepare their local data. 
Each worker is set up to have a subset of the MNIST training dataset. 
Alice holds all images corresponding to the digits 0-3, 
Bob holds all images corresponding to the digits 4-6 and 
Charlie holds all images corresponding to the digits 7-9.

| Worker      | Digits in local dataset | Number of samples |
| ----------- | ----------------------- | ----------------- |
| Alice       | 0-3                     | 24754             |
| Bob         | 4-6                     | 17181             |
| Charlie     | 7-9                     | 18065             |
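
A minimal sketch of this kind of label-based split, using plain Python and a toy label list. The helper `partition_by_digit` and the `WORKER_DIGITS` mapping are illustrative names, not part of `start_websocket_servers.py`. The per-digit counts are the commonly reported class sizes of the 60,000-image MNIST training set; summing them per worker reproduces the table above.

```python
# Hypothetical mapping from worker name to the digits it keeps (mirrors the table above).
WORKER_DIGITS = {
    "alice": range(0, 4),     # digits 0-3
    "bob": range(4, 7),       # digits 4-6
    "charlie": range(7, 10),  # digits 7-9
}


def partition_by_digit(labels):
    """Return, per worker, the indices of the samples whose label it owns."""
    return {
        name: [i for i, label in enumerate(labels) if label in digits]
        for name, digits in WORKER_DIGITS.items()
    }


# Toy example: ten labels instead of the full training set.
toy_labels = [5, 0, 4, 1, 9, 2, 1, 3, 1, 4]
toy_split = partition_by_digit(toy_labels)
print(toy_split)

# Commonly reported number of training images per digit 0..9 in MNIST.
MNIST_TRAIN_COUNTS = [5923, 6742, 5958, 6131, 5842, 5421, 5918, 6265, 5851, 5949]
samples_per_worker = {
    name: sum(MNIST_TRAIN_COUNTS[d] for d in digits)
    for name, digits in WORKER_DIGITS.items()
}
print(samples_per_worker)
```

Because the split is by label, the three local datasets are disjoint and differ in size.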


In [2]:
# uncomment the following to see the code of the function that starts a worker
# import run_websocket_server

# print(inspect.getsource(run_websocket_server.start_websocket_server_worker))

Before continuing, we first need to import dependencies, set up the required arguments, and configure logging.

In [3]:
# Dependencies
import sys
import asyncio

import syft as sy
from syft.workers import WebsocketClientWorker
from syft.frameworks.torch.federated import utils

import torch
from torchvision import datasets, transforms

import run_websocket_client as rwc
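from packaging import version

# Compare parsed versions; a plain string comparison would misorder e.g. "1.0.10" and "1.0.2"
if version.parse(torch.__version__) >= version.parse("1.0.2"):
    raise ValueError(f"This tutorial currently does not support torch versions >= 1.0.2, you have version {torch.__version__}")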

In [4]:
# Hook torch
hook = sy.TorchHook(torch)

In [5]:
# Arguments
args = rwc.define_and_get_arguments(args=[])
use_cuda = args.cuda and torch.cuda.is_available()
torch.manual_seed(args.seed)
device = torch.device("cuda" if use_cuda else "cpu")
print(args)

Namespace(batch_size=32, cuda=False, federate_after_n_batches=10, lr=0.1, save_model=False, seed=1, test_batch_size=128, training_rounds=40, verbose=False)


In [6]:
# Configure logging
import logging

logger = logging.getLogger("run_websocket_client")

if not logger.handlers:
    FORMAT = "%(asctime)s - %(message)s"
    DATE_FMT = "%H:%M:%S"
    formatter = logging.Formatter(FORMAT, DATE_FMT)
    handler = logging.StreamHandler()
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    logger.propagate = False
LOG_LEVEL = logging.DEBUG
logger.setLevel(LOG_LEVEL)

Now let's instantiate the websocket client workers, our local proxies to the remote workers.
Note that **this step will fail if the websocket server workers are not running**.

In [7]:
kwargs_websocket = {"host": "localhost", "hook": hook, "verbose": args.verbose}
alice = WebsocketClientWorker(id="alice", port=8777, **kwargs_websocket)
bob = WebsocketClientWorker(id="bob", port=8778, **kwargs_websocket)
charlie = WebsocketClientWorker(id="charlie", port=8779, **kwargs_websocket)

worker_instances = [alice, bob, charlie]
print(worker_instances)

[<WebsocketClientWorker id:alice #objects local:0 #objects remote: 1>, <WebsocketClientWorker id:bob #objects local:0 #objects remote: 1>, <WebsocketClientWorker id:charlie #objects local:0 #objects remote: 1>]


## Setting up training

### Model
Let's instantiate the machine learning model. It is a small neural network with two convolutional and two fully connected layers.
It uses ReLU activations and max pooling.

In [8]:
print(inspect.getsource(rwc.Net))

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4 * 4 * 50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4 * 4 * 50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)



In [9]:
model = rwc.Net().to(device)
print(model)

Net(
  (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(20, 50, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=800, out_features=500, bias=True)
  (fc2): Linear(in_features=500, out_features=10, bias=True)
)
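
The `in_features=800` of `fc1` follows from the shapes of the earlier layers. The arithmetic can be checked with a short sketch, assuming 28x28 MNIST inputs, no padding, and the kernel sizes and strides shown above. `conv_out` and `pool_out` are illustrative helpers, not PyTorch functions.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output width/height of a convolution without dilation."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, kernel, stride):
    """Output width/height of max pooling without padding."""
    return (size - kernel) // stride + 1


size = 28                        # MNIST images are 28x28
size = conv_out(size, kernel=5)  # after conv1: 24
size = pool_out(size, 2, 2)      # after first max_pool2d: 12
size = conv_out(size, kernel=5)  # after conv2: 8
size = pool_out(size, 2, 2)      # after second max_pool2d: 4

flattened = size * size * 50     # conv2 produces 50 channels
print(size, flattened)
```

This is why `forward` reshapes the activations with `x.view(-1, 4 * 4 * 50)` before the first fully connected layer.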


### Test Data

The training data lives on each of the devices, but before starting the training, let's load the MNIST test data.

In [10]:
test_loader = torch.utils.data.DataLoader(
        datasets.MNIST(
            "./data",
            train=False,
            download=True,
            transform=transforms.Compose(
                [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
            ),
        ),
        batch_size=args.test_batch_size,
        shuffle=False,
        drop_last=False,
    )
# Fetch a single batch; it is used as the example input when tracing the model
data, target = next(iter(test_loader))

#### Making the model serializable

To send the model to the workers, it must be serializable. For this we use [`jit`](https://pytorch.org/docs/stable/jit.html) to trace the model with an example input batch.

In [11]:
traced_model = torch.jit.trace(model, data)

### Let's start the training

Now we are ready to start the federated training. We will perform training over a given number of batches separately on each worker and then calculate the federated average of the resulting model.
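
As a rough illustration of the averaging step, the sketch below takes the unweighted element-wise mean of the parameters returned by each worker. It uses plain Python lists in place of tensors, and `average_parameters` is an illustrative helper. That `utils.federated_avg` weights all workers equally is an assumption here, not something this tutorial states.

```python
def average_parameters(worker_params):
    """Element-wise mean of the parameters of all workers.

    worker_params maps a worker id to a dict of {parameter name: list of floats}.
    """
    worker_ids = list(worker_params)
    n_workers = len(worker_ids)
    first = worker_params[worker_ids[0]]
    averaged = {}
    for name, values in first.items():
        averaged[name] = [
            sum(worker_params[w][name][i] for w in worker_ids) / n_workers
            for i in range(len(values))
        ]
    return averaged


# Toy "models" with one weight vector and one bias each.
params = {
    "alice": {"weight": [1.0, 2.0], "bias": [0.0]},
    "bob": {"weight": [3.0, 4.0], "bias": [3.0]},
    "charlie": {"weight": [5.0, 6.0], "bias": [6.0]},
}
avg = average_parameters(params)
print(avg)
```

The averaged parameters then serve as the starting point for the next round on every worker.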

Every 10 rounds, starting with the first, and again after the final round, we will evaluate the performance of the models returned by the workers and of the model obtained by federated averaging.
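
The condition used in the training loop further down (`curr_round % 10 == 1`, or the final round) means evaluation runs on rounds 1, 11, 21, and so on, and once more at the end. A small sketch, with 40 training rounds taken from the arguments printed earlier; `evaluation_rounds` is an illustrative helper:

```python
def evaluation_rounds(training_rounds):
    """Rounds on which the loop evaluates: 1, 11, 21, ... plus the last round."""
    return [
        r
        for r in range(1, training_rounds + 1)
        if r % 10 == 1 or r == training_rounds
    ]


eval_rounds = evaluation_rounds(40)
print(eval_rounds)  # [1, 11, 21, 31, 40]
```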

The performance will be reported both as the accuracy (ratio of correct predictions) and as a histogram of predicted digits. The histogram is of interest because each worker owns only a subset of the digits. In the beginning, each worker will therefore predict only its own digits and will learn about the other digits only through the federated averaging process.
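
To make the two metrics concrete, here is a minimal sketch in plain Python. `accuracy_and_histogram` is an illustrative helper and not the implementation used by `rwc.evaluate_models_on_test_data`; the toy predictions mimic a model that has only seen the digits 0-2.

```python
def accuracy_and_histogram(predictions, targets, n_classes=10):
    """Return (fraction of correct predictions, number of predictions per class)."""
    correct = sum(p == t for p, t in zip(predictions, targets))
    hist = [0] * n_classes
    for p in predictions:
        hist[p] += 1
    return correct / len(targets), hist


# One test sample per digit; the toy model never predicts anything above 2.
targets = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
predictions = [0, 1, 2, 2, 1, 2, 0, 1, 2, 1]
acc, hist = accuracy_and_histogram(predictions, targets)
print(acc, hist)
```

A histogram with zeros for most digits, as in this toy case, is the signature of a model trained on one worker's data only.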

The training is done in an asynchronous manner. This means that the scheduler just tells the workers to train and does not block waiting for the result of one worker's training before talking to the next worker.
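
The pattern can be sketched with the standard library alone. Here `fit_on_worker` is a hypothetical stand-in for the remote training call, and the sleep durations are arbitrary values that simulate network and compute time.

```python
import asyncio


async def fit_on_worker(worker_id, seconds):
    """Stand-in for a remote training call; sleeping simulates the remote work."""
    await asyncio.sleep(seconds)
    return worker_id, f"model trained on {worker_id}"


async def run_round():
    # All three coroutines are started together. gather returns once every
    # worker has finished, with results in call order rather than finish order.
    return await asyncio.gather(
        fit_on_worker("alice", 0.03),
        fit_on_worker("bob", 0.01),
        fit_on_worker("charlie", 0.02),
    )


results = asyncio.run(run_round())
print(results)
```

Inside a Jupyter notebook an event loop is already running, so there you would write `results = await run_round()` instead of calling `asyncio.run`.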

The parameters of the training are given in the arguments.
Each worker trains on the number of batches set by `federate_after_n_batches`.
The training batch size and the initial learning rate are also configured there.

In [12]:
print("Federate_after_n_batches: " + str(args.federate_after_n_batches))
print("Batch size: " + str(args.batch_size))
print("Initial learning rate: " + str(args.lr))

Federate_after_n_batches: 10
Batch size: 32
Initial learning rate: 0.1
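
The learning rate does not stay at its initial value: the training loop that follows multiplies it by 0.98 after every round and never lets it drop below 1% of the initial rate. A small sketch of that schedule, where `learning_rate_for_round` is an illustrative helper and the defaults are the values used in this notebook:

```python
def learning_rate_for_round(curr_round, initial_lr=0.1, decay=0.98, floor_factor=0.01):
    """Learning rate used in a given round (1-based) under the loop's decay rule."""
    lr = initial_lr
    for _ in range(curr_round - 1):
        lr = max(decay * lr, initial_lr * floor_factor)
    return lr


for r in (1, 14, 28):
    print(f"round {r}: lr = {learning_rate_for_round(r):.3f}")
```

These values can be compared with the `lr = ...` entries in the training log. With only 40 rounds the floor of 0.001 is never reached.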


In [13]:
learning_rate = args.lr

for curr_round in range(1, args.training_rounds + 1):
    global traced_model
    
    logger.info("Starting training round %s/%s", curr_round, args.training_rounds)

    # Ask every worker to train the model on its local data asynchronously,
    # then gather the results once all of them have finished
    results = await asyncio.gather(
        *[
            rwc.fit_model_on_worker(
                worker=worker,
                traced_model=traced_model,
                batch_size=args.batch_size,
                curr_round=curr_round,
                max_nr_batches=args.federate_after_n_batches,
                lr=learning_rate,
            )
            for worker in worker_instances
        ]
    )
    
    models = {}
    loss_values = {}

    # Run evaluation every 10 rounds
    run_evaluation = (curr_round) % 10 == 1 or curr_round == args.training_rounds
    if run_evaluation:
        rwc.evaluate_models_on_test_data(test_loader, results)
    
    # Store models and loss_values for each worker
    for worker_id, worker_model, worker_loss in results:
        if worker_model is not None:
            models[worker_id] = worker_model
            loss_values[worker_id] = worker_loss

    # Average model
    avg_model = utils.federated_avg(models)
    if run_evaluation:
        rwc.evaluate_model("Federated model", avg_model, "cpu", test_loader)

    # decay learning rate
    learning_rate = max(0.98 * learning_rate, args.lr * 0.01)
    
    # Use averaged model in the next round
    traced_model = avg_model

# Save model
if args.save_model:
    # Save the federated model from the last round, not the untrained initial `model`
    torch.save(traced_model.state_dict(), "mnist_cnn.pt")

11:41:42 - Starting training round 1/40
11:41:42 - Training round 1, calling fit on worker: alice, lr = 0.100
11:41:42 - Training round 1, calling fit on worker: bob, lr = 0.100
11:41:42 - Training round 1, calling fit on worker: charlie, lr = 0.100
11:41:43 - Training round: 1, worker: alice, avg_loss: tensor(2.3841, grad_fn=<MeanBackward1>)
11:41:44 - Training round: 1, worker: charlie, avg_loss: tensor(1.7610, grad_fn=<MeanBackward1>)
11:41:45 - Training round: 1, worker: bob, avg_loss: tensor(1.4272, grad_fn=<MeanBackward1>)
11:41:50 - Prediction hist.: [ 473  4436  5091  0  0  0  0  0  0  0]
11:41:50 - alice: Test set: Average loss: 0.0184, Accuracy: 2261/10000 (22.61)
11:41:52 - Prediction hist.: [ 0  0  0  0  8661  0  1339  0  0  0]
11:41:52 - bob: Test set: Average loss: 0.0231, Accuracy: 1614/10000 (16.14)
11:41:55 - Prediction hist.: [ 0  0  0  0  0  0  0  0  0  10000]
11:41:55 - charlie: Test set: Average loss: 0.0224, Accuracy: 1009/10000 (10.09)
...

11:43:10 - Starting training round 14/40
11:43:10 - Training round 14, calling fit on worker: alice, lr = 0.077
11:43:10 - Training round 14, calling fit on worker: bob, lr = 0.077
11:43:11 - Training round 14, calling fit on worker: charlie, lr = 0.077
11:43:11 - Training round: 14, worker: charlie, avg_loss: tensor(0.3328, grad_fn=<MeanBackward1>)
11:43:13 - Training round: 14, worker: bob, avg_loss: tensor(0.1389, grad_fn=<MeanBackward1>)
11:43:14 - Training round: 14, worker: alice, avg_loss: tensor(0.2955, grad_fn=<MeanBackward1>)
11:43:15 - Starting training round 15/40
11:43:15 - Training round 15, calling fit on worker: alice, lr = 0.075
11:43:16 - Training round 15, calling fit on worker: bob, lr = 0.075
11:43:16 - Training round 15, calling fit on worker: charlie, lr = 0.075
11:43:16 - Training round: 15, worker: charlie, avg_loss: tensor(0.1620, grad_fn=<MeanBackward1>)
11:43:18 - Training round: 15, worker: bob, avg_loss: tensor(0.1267, grad_fn=<MeanBackward1>)
...

11:44:33 - Starting training round 28/40
11:44:33 - Training round 28, calling fit on worker: alice, lr = 0.058
11:44:33 - Training round 28, calling fit on worker: bob, lr = 0.058
11:44:33 - Training round 28, calling fit on worker: charlie, lr = 0.058
11:44:34 - Training round: 28, worker: bob, avg_loss: tensor(0.0163, grad_fn=<MeanBackward1>)
11:44:35 - Training round: 28, worker: alice, avg_loss: tensor(0.0108, grad_fn=<MeanBackward1>)
11:44:36 - Training round: 28, worker: charlie, avg_loss: tensor(0.1363, grad_fn=<MeanBackward1>)
11:44:38 - Starting training round 29/40
11:44:38 - Training round 29, calling fit on worker: alice, lr = 0.057
11:44:38 - Training round 29, calling fit on worker: bob, lr = 0.057
11:44:38 - Training round 29, calling fit on worker: charlie, lr = 0.057
11:44:39 - Training round: 29, worker: bob, avg_loss: tensor(0.0456, grad_fn=<MeanBackward1>)
11:44:40 - Training round: 29, worker: alice, avg_loss: tensor(0.2410, grad_fn=<MeanBackward1>)
...

11:46:03 - Prediction hist.: [ 1021  1127  1033  997  926  892  963  1036  963  1042]
11:46:03 - Federated model: Test set: Average loss: 0.0013, Accuracy: 9536/10000 (95.36)


After 40 rounds of training, the federated model achieves an accuracy above 95% (95.36% in this run) on the entire test dataset.
This is impressive, given that no worker has access to more than four digits!

# Congratulations!!! - Time to Join the Community!

Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the movement toward privacy-preserving, decentralized ownership of AI and the AI supply chain (data), you can do so in the following ways!

### Star PySyft on GitHub

The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we're building.

- [Star PySyft](https://github.com/OpenMined/PySyft)

### Join our Slack!

The best way to keep up to date on the latest advancements is to join our community! You can do so by filling out the form at [http://slack.openmined.org](http://slack.openmined.org)

### Join a Code Project!

The best way to contribute to our community is to become a code contributor! At any time you can go to the PySyft GitHub Issues page and filter for "Projects". This will show you all the top-level tickets, giving an overview of the projects you can join. If you don't want to join a project but would like to do a bit of coding, you can also look for "one-off" mini-projects by searching for GitHub issues marked "good first issue".

- [PySyft Projects](https://github.com/OpenMined/PySyft/issues?q=is%3Aopen+is%3Aissue+label%3AProject)
- [Good First Issue Tickets](https://github.com/OpenMined/PySyft/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)

### Donate

If you don't have time to contribute to our codebase, but would still like to lend support, you can also become a Backer on our Open Collective. All donations go toward our web hosting and other community expenses such as hackathons and meetups!

[OpenMined's Open Collective Page](https://opencollective.com/openmined)