# Tutorial: Federated learning with websockets and federated averaging

This notebook walks through the steps to run federated learning with websocket workers. We will use federated averaging to combine the remotely trained models.

Authors:
- midokura-silvia

## Preparation: start the websocket server workers
Each worker consists of two parts: a local handle (the websocket client worker) and a remote instance that holds the data and performs the computations. The remote part is called a websocket server worker.

So first, we need to create the remote workers. To do so, run the script start_websocket_servers.py in a terminal (this is not possible from within the notebook).

## Setting up the websocket client workers

We first need to perform the imports and set up some arguments and variables.

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
import syft as sy
from syft.workers import WebsocketClientWorker
import torch
from torchvision import datasets, transforms

from syft.frameworks.torch.federated import utils

In [3]:
import run_websocket_client as rwc

In [4]:
args = rwc.define_and_get_arguments(args=[])
use_cuda = args.cuda and torch.cuda.is_available()
torch.manual_seed(args.seed)
device = torch.device("cuda" if use_cuda else "cpu")
print(args)

Namespace(batch_size=64, cuda=False, epochs=2, federate_after_n_batches=50, lr=0.01, save_model=False, seed=1, test_batch_size=1000, use_virtual=False, verbose=False)


Now let's instantiate the websocket client workers, our local access point to the remote workers.
Note that this step will fail if the websocket server workers are not running.

In [5]:
hook = sy.TorchHook(torch)

kwargs_websocket = {"host": "localhost", "hook": hook, "verbose": args.verbose}
alice = WebsocketClientWorker(id="alice", port=8777, **kwargs_websocket)
bob = WebsocketClientWorker(id="bob", port=8778, **kwargs_websocket)
charlie = WebsocketClientWorker(id="charlie", port=8779, **kwargs_websocket)

workers = [alice, bob, charlie]
print(workers)


[<WebsocketClientWorker id:alice #tensors:0>, <WebsocketClientWorker id:bob #tensors:0>, <WebsocketClientWorker id:charlie #tensors:0>]
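If instantiating the client workers fails, it can help to check whether anything is listening on the expected ports at all. The following is a minimal sketch using only the Python standard library. The helper `is_port_open` is hypothetical and not part of PySyft; the ports are the ones used for alice, bob and charlie above. An open port only shows that some process accepts TCP connections there, not that it is a websocket server worker.

```python
import socket


def is_port_open(host, port, timeout=1.0):
    """Return True if a TCP connection to (host, port) can be established."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        # Covers refused connections, timeouts and name resolution errors
        return False


# Ports taken from the worker definitions above
worker_ports = {"alice": 8777, "bob": 8778, "charlie": 8779}

unreachable = [
    name for name, port in worker_ports.items() if not is_port_open("localhost", port)
]
if unreachable:
    print("No listener found for:", ", ".join(unreachable))
else:
    print("All worker ports accept connections.")
```

If any worker is reported as unreachable, start (or restart) start_websocket_servers.py in a terminal before running the cell again.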


## Prepare and distribute the training data

We will use the MNIST dataset and distribute the data randomly onto the workers. 
This is not realistic for a federated training setup, where the data would normally already be available at the remote workers.

We instantiate two data loaders: a FederatedDataLoader for the training set and a regular PyTorch DataLoader for the test set of the MNIST dataset.

If you run into BrokenPipe errors

In [6]:
federated_train_loader = sy.FederatedDataLoader(
    datasets.MNIST(
        "../data",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        ),
    ).federate(tuple(workers)),
    batch_size=args.batch_size,
    shuffle=True,
    iter_per_worker=True
)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "../data",
        train=False,
        transform=transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        ),
    ),
    batch_size=args.test_batch_size,
    shuffle=True
)
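To make the idea of randomly distributing data onto workers concrete, here is a small sketch in plain Python. It is only an illustration and does not reproduce what `.federate()` does internally; the function `split_indices` and the round-robin assignment are assumptions made for this example. The MNIST training set has 60000 samples.

```python
import random


def split_indices(n_samples, worker_ids, seed=1):
    """Shuffle sample indices and deal them out to the workers in turn."""
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)  # local RNG, does not touch global state
    shards = {worker_id: [] for worker_id in worker_ids}
    for position, index in enumerate(indices):
        shards[worker_ids[position % len(worker_ids)]].append(index)
    return shards


shards = split_indices(60000, ["alice", "bob", "charlie"])
print({worker_id: len(indices) for worker_id, indices in shards.items()})
# prints {'alice': 20000, 'bob': 20000, 'charlie': 20000}
```

Each sample index ends up on exactly one worker, so the shards are disjoint and together cover the whole training set.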


Next, we need to instantiate the machine learning model. It is a small neural network with two convolutional and two fully connected layers.
It uses ReLU activations and max pooling.

In [7]:
model = rwc.Net().to(device)
print(model)
#forward method of the model:
#    def forward(self, x):
#        x = F.relu(self.conv1(x))
#        x = F.max_pool2d(x, 2, 2)
#        x = F.relu(self.conv2(x))
#        x = F.max_pool2d(x, 2, 2)
#        x = x.view(-1, 4 * 4 * 50)
#        x = F.relu(self.fc1(x))
#        x = self.fc2(x)
#        return F.log_softmax(x, dim=1)

Net(
  (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(20, 50, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=800, out_features=500, bias=True)
  (fc2): Linear(in_features=500, out_features=10, bias=True)
)
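The value 4 * 4 * 50 in the forward method, which matches `in_features=800` of `fc1`, follows from the layer sizes printed above. The sketch below traces the spatial size of a 28x28 MNIST image through the network, assuming no padding in the convolutions (as in the printed `Conv2d` layers). The helper names `conv_out` and `pool_out` are made up for this example.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output width/height of a convolution along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, kernel, stride):
    """Output width/height of max pooling along one spatial dimension."""
    return (size - kernel) // stride + 1


size = 28                     # MNIST images are 28x28
size = conv_out(size, 5)      # conv1, kernel 5 -> 24
size = pool_out(size, 2, 2)   # max_pool2d(2, 2) -> 12
size = conv_out(size, 5)      # conv2, kernel 5 -> 8
size = pool_out(size, 2, 2)   # max_pool2d(2, 2) -> 4

flattened = size * size * 50  # conv2 has 50 output channels
print(size, flattened)        # prints 4 800
```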


Now we are ready to define the training loop. We will perform training over a given number of batches separately on each worker and then calculate the federated average of the resulting model.

In [8]:
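def train(model, device, federated_train_loader, lr, federate_after_n_batches):
    """Train in rounds of federate_after_n_batches batches per worker,
    averaging the worker models after each round, until a worker runs out of data."""
    model.train()

    nr_batches = federate_after_n_batches

    models = {}  # worker -> model trained on that worker in the current round
    loss_values = {}  # worker -> loss value returned by rwc.train_on_batches

    iter(federated_train_loader)  # initialize iterators
    batches = rwc.get_next_batches(federated_train_loader, nr_batches)
    counter = 0

    while True:
        print("Starting training round, batches [{}, {}]".format(counter, counter + nr_batches))
        data_for_all_workers = True
        # Train separately on every worker that still has batches left
        for worker in batches:
            curr_batches = batches[worker]
            if curr_batches:
                models[worker], loss_values[worker] = rwc.train_on_batches(
                    worker, curr_batches, model, device, lr
                )
            else:
                data_for_all_workers = False
        counter += nr_batches
        if not data_for_all_workers:
            # `logger` is defined in the next cell; it must exist before train() is called
            logger.debug("At least one worker ran out of data, stopping.")
            break

        # Combine the per-worker models into one model by federated averaging
        model = utils.federated_avg(models)
        batches = rwc.get_next_batches(federated_train_loader, nr_batches)
    return model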
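The averaging step in the loop above can be illustrated without PySyft or PyTorch. The sketch below computes an unweighted, element-wise mean over per-worker parameters stored as plain Python lists. It is meant to show the idea only and is not the implementation of `utils.federated_avg`, which operates on PyTorch models; the function name `federated_average` and the toy parameters are assumptions made for this example.

```python
def federated_average(worker_params):
    """Element-wise mean of parameters across workers.

    worker_params maps a worker id to a dict of parameter name -> list of floats.
    All workers are assumed to hold the same parameter names and shapes.
    """
    n_workers = len(worker_params)
    reference = next(iter(worker_params.values()))
    averaged = {}
    for name in reference:
        # Group the i-th entry of this parameter from every worker together
        columns = zip(*(params[name] for params in worker_params.values()))
        averaged[name] = [sum(column) / n_workers for column in columns]
    return averaged


toy_params = {
    "alice": {"w": [1.0, 2.0], "b": [0.0]},
    "bob": {"w": [3.0, 4.0], "b": [3.0]},
    "charlie": {"w": [5.0, 6.0], "b": [6.0]},
}
print(federated_average(toy_params))
# prints {'w': [3.0, 4.0], 'b': [3.0]}
```

Because every worker contributes with equal weight, this sketch is only a fair summary when the workers have trained on roughly the same amount of data.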

In [9]:
import logging
FORMAT = "%(asctime)s %(levelname)s %(filename)s(l:%(lineno)d) - %(message)s"
LOG_LEVEL = logging.DEBUG
logging.basicConfig(format=FORMAT, level=LOG_LEVEL)
logger = logging.getLogger("main")

## Let's start the training

Now we are ready to start the federated training.

In [10]:
for epoch in range(1, args.epochs + 1):
    print("Starting epoch {}/{}".format(epoch, args.epochs))
    model = train(model, device, federated_train_loader, args.lr, args.federate_after_n_batches)
    rwc.test(model, device, test_loader)

Starting epoch 1/2
Starting training round, batches [0, 50]
Starting training round, batches [50, 100]
Starting training round, batches [100, 150]
Starting training round, batches [150, 200]
Starting training round, batches [200, 250]
Starting training round, batches [250, 300]
Starting training round, batches [300, 350]
2019-04-02 13:28:16,078 DEBUG <ipython-input-8-997890e4f55a>(l:26) - At least one worker ran out of data, stopping.
Starting training round, batches [350, 400]
2019-04-02 13:28:17,982 INFO run_websocket_client.py(l:158) - 
2019-04-02 13:28:17,982 INFO run_websocket_client.py(l:162) - Test set: Average loss: 0.3572, Accuracy: 8959/10000 (90%)

Starting epoch 2/2
Starting training round, batches [0, 50]
Starting training round, batches [50, 100]
Starting training round, batches [100, 150]
Starting training round, batches [150, 200]
Starting training round, batches [200, 250]
Starting training round, batches [250, 300]
Starting training round, batches [300, 350]
2019-04-02 13:30:40,336 DEBUG <ipython-input-8-997890e4f55a>(l:26) - At least one worker ran out of data, stopping.
Starting training round, batches [350, 400]
2019-04-02 13:30:42,234 INFO run_websocket_client.py(l:158) - 
2019-04-02 13:30:42,234 INFO run_websocket_client.py(l:162) - Test set: Average loss: 0.2389, Accuracy: 9261/10000 (93%)
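
The number of training rounds in the output can be checked with a little arithmetic. The sketch below assumes that the 60000 MNIST training samples are split evenly across the three workers and that a final, smaller batch is kept; both are assumptions about the data loader, not confirmed behavior. The batch size (64) and the round length (50 batches) come from the arguments printed earlier.

```python
import math

n_train_samples = 60000        # size of the MNIST training set
n_workers = 3                  # alice, bob, charlie
batch_size = 64                # args.batch_size
batches_per_round = 50         # args.federate_after_n_batches

samples_per_worker = n_train_samples // n_workers                   # 20000
batches_per_worker = math.ceil(samples_per_worker / batch_size)     # 313
full_rounds, leftover = divmod(batches_per_worker, batches_per_round)

print(batches_per_worker, full_rounds, leftover)  # prints 313 6 13
```

Under these assumptions there are six full rounds plus a seventh round with 13 batches per worker, and the eighth round, starting at batch 350, finds no data. This is consistent with the log, where the round [350, 400] is the one in which a worker runs out of data.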



# Congratulations!!! - Time to Join the Community!

Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the movement toward privacy-preserving, decentralized ownership of AI and the AI supply chain (data), you can do so in the following ways!

### Star PySyft on GitHub

The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we're building.

- [Star PySyft](https://github.com/OpenMined/PySyft)

### Join our Slack!

The best way to keep up to date on the latest advancements is to join our community! You can do so by filling out the form at [http://slack.openmined.org](http://slack.openmined.org)

### Join a Code Project!

The best way to contribute to our community is to become a code contributor! At any time you can go to the PySyft GitHub Issues page and filter for "Projects". This will show you all the top-level tickets, giving an overview of the projects you can join! If you don't want to join a project but would like to do a bit of coding, you can also look for more "one off" mini-projects by searching for GitHub issues marked "good first issue".

- [PySyft Projects](https://github.com/OpenMined/PySyft/issues?q=is%3Aopen+is%3Aissue+label%3AProject)
- [Good First Issue Tickets](https://github.com/OpenMined/PySyft/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)

### Donate

If you don't have time to contribute to our codebase, but would still like to lend support, you can also become a Backer on our Open Collective. All donations go toward our web hosting and other community expenses such as hackathons and meetups!

[OpenMined's Open Collective Page](https://opencollective.com/openmined)