# Federated Learning - Model Centric MNIST Example: Train FL Model


In the "[01-Create Plan](../model-centric/02-ExecutePlan.ipynb)" notebook we created the model, training plan, and averaging plan, and then hosted all of them in PyGrid.

Such a hosted FL model can now be trained using any of the client libraries: SwiftSyft, KotlinSyft, or syft.js.

In this notebook, we'll use the FL Client included in PySyft to do the training.

### Credits:
- Original authors: 

 - Vova Manannikov - Github: [@vvmnnnkv](https://github.com/vvmnnnkv)


- Reviewers: 
 - Patrick Cason - Github: [@cereallarceny](https://github.com/cereallarceny)


- New Content tested and enriched by: 
 - Juan M. Aunon - Twitter: [@jm_aunon](https://twitter.com/jm_aunon) - Github: [@jmaunon](https://github.com/jmaunon)
 

In [1]:
%load_ext autoreload
%autoreload 2
import warnings
warnings.filterwarnings("ignore")

import torch as th
from torchvision import datasets, transforms

import numpy as np
import urllib3
import time

import syft as sy
from syft.federated.fl_client import FLClient
from syft.federated.fl_job import FLJob
from syft.grid.clients.model_centric_fl_client import ModelCentricFLClient

# Turn off urllib3's warnings for the rest of the session.
urllib3.disable_warnings()
# Hand this notebook's global namespace to syft's hook setup.
# NOTE(review): presumably this is the call that prints "Setting up Sandbox..." — confirm.
sy.make_hook(globals())

Setting up Sandbox...
Done!


In [2]:
# RSA key pair for this notebook; `private_key` signs the JWT auth token created in the next cell.
# NOTE(review): a private key stored in a notebook must be treated as public. Presumably this is a
# demo/test key pair — confirm, and load real keys from the environment or a secret store instead.
private_key = """
-----BEGIN RSA PRIVATE KEY-----
MIIEowIBAAKCAQEAzQMcI09qonB9OZT20X3Z/oigSmybR2xfBQ1YJ1oSjQ3YgV+G
FUuhEsGDgqt0rok9BreT4toHqniFixddncTHg7EJzU79KZelk2m9I2sEsKUqEsEF
lMpkk9qkPHhJB5AQoClOijee7UNOF4yu3HYvGFphwwh4TNJXxkCg69/RsvPBIPi2
9vXFQzFE7cbN6jSxiCtVrpt/w06jJUsEYgNVQhUFABDyWN4h/67M1eArGA540vyd
kYdSIEQdknKHjPW62n4dvqDWxtnK0HyChsB+LzmjEnjTJqUzr7kM9Rzq3BY01DNi
TVcB2G8t/jICL+TegMGU08ANMKiDfSMGtpz3ZQIDAQABAoIBAD+xbKeHv+BxxGYE
Yt5ZFEYhGnOk5GU/RRIjwDSRplvOZmpjTBwHoCZcmsgZDqo/FwekNzzuch1DTnIV
M0+V2EqQ0TPJC5xFcfqnikybrhxXZAfpkhtU+gR5lDb5Q+8mkhPAYZdNioG6PGPS
oGz8BsuxINhgJEfxvbVpVNWTdun6hLOAMZaH3DHgi0uyTBg8ofARoZP5RIbHwW+D
p+5vd9x/x7tByu76nd2UbMp3yqomlB5jQktqyilexCIknEnfb3i/9jqFv8qVE5P6
e3jdYoJY+FoomWhqEvtfPpmUFTY5lx4EERCb1qhWG3a7sVBqTwO6jJJBsxy3RLIS
Ic0qZcECgYEA6GsBP11a2T4InZ7cixd5qwSeznOFCzfDVvVNI8KUw+n4DOPndpao
TUskWOpoV8MyiEGdQHgmTOgGaCXN7bC0ERembK0J64FI3TdKKg0v5nKa7xHb7Qcv
t9ccrDZVn4y/Yk5PCqjNWTR3/wDR88XouzIGaWkGlili5IJqdLEvPvUCgYEA4dA+
5MNEQmNFezyWs//FS6G3lTRWgjlWg2E6BXXvkEag6G5SBD31v3q9JIjs+sYdOmwj
kfkQrxEtbs173xgYWzcDG1FI796LTlJ/YzuoKZml8vEF3T8C4Bkbl6qj9DZljb2j
ehjTv5jA256sSUEqOa/mtNFUbFlBjgOZh3TCsLECgYAc701tdRLdXuK1tNRiIJ8O
Enou26Thm6SfC9T5sbzRkyxFdo4XbnQvgz5YL36kBnIhEoIgR5UFGBHMH4C+qbQR
OK+IchZ9ElBe8gYyrAedmgD96GxH2xAuxAIW0oDgZyZgd71RZ2iBRY322kRJJAdw
Xq77qo6eXTKpni7grjpijQKBgDHWRAs5DVeZkTwhoyEW0fRfPKUxZ+ZVwUI9sxCB
dt3guKKTtoY5JoOcEyJ9FdBC6TB7rV4KGiSJJf3OXAhgyP9YpNbimbZW52fhzTuZ
bwO/ZWC40RKDVZ8f63cNsiGz37XopKvNzu36SJYv7tY8C5WvvLsrd/ZxvIYbRUcf
/dgBAoGBAMdR5DXBcOWk3+KyEHXw2qwWcGXyzxtca5SRNLPR2uXvrBYXbhFB/PVj
h3rGBsiZbnIvSnSIE+8fFe6MshTl2Qxzw+F2WV3OhhZLLtBnN5qqeSe9PdHLHm49
XDce6NV2D1mQLBe8648OI5CScQENuRGxF2/h9igeR4oRRsM1gzJN
-----END RSA PRIVATE KEY-----
""".strip()

# Public half of the pair above.
# NOTE(review): not read by any visible cell of this notebook; presumably the hosting side uses it
# to verify the token — confirm.
public_key = """
-----BEGIN PUBLIC KEY-----
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAzQMcI09qonB9OZT20X3Z
/oigSmybR2xfBQ1YJ1oSjQ3YgV+GFUuhEsGDgqt0rok9BreT4toHqniFixddncTH
g7EJzU79KZelk2m9I2sEsKUqEsEFlMpkk9qkPHhJB5AQoClOijee7UNOF4yu3HYv
GFphwwh4TNJXxkCg69/RsvPBIPi29vXFQzFE7cbN6jSxiCtVrpt/w06jJUsEYgNV
QhUFABDyWN4h/67M1eArGA540vydkYdSIEQdknKHjPW62n4dvqDWxtnK0HyChsB+
LzmjEnjTJqUzr7kM9Rzq3BY01DNiTVcB2G8t/jICL+TegMGU08ANMKiDfSMGtpz3
ZQIDAQAB
-----END PUBLIC KEY-----
""".strip()

Creating authentication token.

In [3]:
import jwt

# Sign an empty-claims JWT with the notebook's RSA private key.
# PyJWT < 2.0 returns bytes from encode(); PyJWT >= 2.0 returns str, on which
# .decode() does not exist — accept both so the cell runs on either version.
encoded_token = jwt.encode({}, private_key, algorithm='RS256')
auth_token = encoded_token.decode('ascii') if isinstance(encoded_token, bytes) else encoded_token

print(auth_token)

eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.e30.Cn_0cSjCw1QKtcYDx_mYN_q9jO2KkpcUoiVbILmKVB4LUCQvZ7YeuyQ51r9h3562KQoSas_ehbjpz2dw1Dk24hQEoN6ObGxfJDOlemF5flvLO_sqAHJDGGE24JRE4lIAXRK6aGyy4f4kmlICL6wG8sGSpSrkZlrFLOVRJckTptgaiOTIm5Udfmi45NljPBQKVpqXFSmmb3dRy_e8g3l5eBVFLgrBhKPQ1VbNfRK712KlQWs7jJ31fGpW2NxMloO1qcd6rux48quivzQBCvyK8PV5Sqrfw_OMOoNLcSvzePDcZXa2nPHSu3qQIikUdZIeCnkJX-w0t8uEFG3DfH1fVA


Define `on_accepted`, `on_rejected`, `on_error` handlers.

The main training loop is located inside the `on_accepted` routine.

In [4]:
import os
import json
from collections import defaultdict
def batch_data(data, batch_size, seed):
    """Yield consecutive mini-batches of one client's data.

    Args:
        data: mapping with keys 'x' and 'y' holding equally ordered sequences.
        batch_size: maximum number of items per batch.
        seed: accepted but not used by this implementation (no shuffling happens).

    Yields:
        (batched_x, batched_y) slices taken in the original order; the final
        batch holds the remaining items and may be shorter than batch_size.
    """
    features, labels = data['x'], data['y']
    total = len(features)

    # Walk the data in strides of batch_size, clamping the last window to the end.
    for start in range(0, total, batch_size):
        stop = min(start + batch_size, total)
        yield (features[start:stop], labels[start:stop])


def read_dir(data_dir):
    """Load every ``.json`` file in ``data_dir`` and merge their contents.

    Files are visited in sorted filename order, so the order of ``groups`` and
    the precedence between files that share a user id do not depend on the
    arbitrary order returned by ``os.listdir``.

    Args:
        data_dir: directory containing the ``.json`` data files.

    Returns:
        clients: sorted list of the keys of the merged ``'user_data'`` mappings.
        groups: concatenation of each file's ``'hierarchies'`` list, in file
            order; empty if no file has that key.
        data: ``defaultdict`` with the merged ``'user_data'``; looking up an
            unknown id yields ``None``.

    Raises:
        KeyError: if a file has no ``'user_data'`` key.
    """
    groups = []
    data = defaultdict(lambda: None)

    json_files = sorted(name for name in os.listdir(data_dir) if name.endswith('.json'))
    for file_name in json_files:
        with open(os.path.join(data_dir, file_name), 'r') as inf:
            cdata = json.load(inf)
        if 'hierarchies' in cdata:
            groups.extend(cdata['hierarchies'])
        data.update(cdata['user_data'])

    # Client ids are derived from the merged user data rather than from each
    # file's 'users' list.
    clients = sorted(data.keys())
    return clients, groups, data

In [5]:
def read_data(train_data_dir, test_data_dir):
    """Load the train and test splits and check that they describe the same clients.

    Both directories are read with ``read_dir``. The client lists and the group
    lists of the two splits must be equal, otherwise an ``AssertionError`` is
    raised.

    Args:
        train_data_dir: directory with the training ``.json`` files.
        test_data_dir: directory with the test ``.json`` files.

    Returns:
        clients: list of client ids (taken from the train split)
        groups: list of group ids (taken from the train split); empty if none found
        train_data: dictionary of train data
        test_data: dictionary of test data
    """
    train_split = read_dir(train_data_dir)
    test_split = read_dir(test_data_dir)

    clients, groups, train_data = train_split
    other_clients, other_groups, test_data = test_split

    # Both splits must cover exactly the same clients and groups.
    assert clients == other_clients
    assert groups == other_groups

    return clients, groups, train_data, test_data

In [6]:
import os
# Alternative (disabled) data location: ./data/femnist.
# train_data_dir = os.path.join('.', 'data', 'femnist', 'train')
# test_data_dir = os.path.join('.', 'data', 'femnist', 'test')
# Active data location: ./data_custom, relative to the notebook's working directory.
train_data_dir = os.path.join('.', 'data_custom', 'train')
test_data_dir = os.path.join('.', 'data_custom', 'test')
# Load both splits; read_data asserts that they contain the same clients and groups.
users, groups, train_data, test_data = read_data(train_data_dir, test_data_dir)
# print(len(train_data['f4071_32']['x']))
# Show which client ids were loaded.
print(train_data.keys())
# users = ['f4074_12', 'f4071_32']
# users = ['f4071_32']

dict_keys(['f0016_39', 'custom1'])


In [7]:
# batch_size = 20

# for batched_x, batched_y in batch_data(train_data['f4071_32'], batch_size, seed=1234):
#     print(len(batched_x))

# #     input_data = self.process_x(batched_x)
# #     target_data = self.process_y(batched_y)

# Evaluation inputs: the 'x' entries of client 'custom1' from the test split, as a float tensor.
test_X = th.tensor(np.array(test_data['custom1']['x']), dtype=th.float)
# Evaluation targets: the integer 'y' labels of the same client, one-hot encoded into 62 columns.
# NOTE(review): 62 is presumably the number of classes in the dataset — confirm.
test_y = th.nn.functional.one_hot(th.tensor(test_data['custom1']['y'], dtype=th.int64), 62)

In [8]:
# Per-cycle results; on_accepted contains the code that appends (test_loss, test_acc) tuples here.
cycles_log = []
# Plans and model parameters of the accepted job; assigned inside on_accepted.
training_plan = None
testing_plan = None
model_params = None
# Shared flag set by the event handlers; the driver loop stops once "ended" is True.
status = {
    "ended": False
}

# Called when client is accepted into FL cycle
def on_accepted(job: FLJob, train_in_cycle: bool = False):
    """Handle acceptance of this worker into an FL cycle.

    The job's training plan, evaluation plan and current model parameters are
    always stored in the module-level ``training_plan``, ``testing_plan`` and
    ``model_params`` so that later cells can use them.

    With the default ``train_in_cycle=False`` nothing is trained or reported:
    ``status["ended"]`` is set and the handler returns. With
    ``train_in_cycle=True`` the handler additionally trains on
    ``train_data['custom1']`` for ``max_updates`` passes, evaluates on
    ``test_X``/``test_y`` before and after, passes the trained parameters to
    ``job.report`` and appends ``(test_loss, test_acc)`` to ``cycles_log``.

    Args:
        job: the accepted job; its ``client_config``, ``plans``, ``model`` and
            ``report`` members are used.
        train_in_cycle: opt in to in-cycle training and reporting. Defaults to
            ``False`` so that registering this function as a one-argument event
            listener keeps the fetch-only behaviour.

    Raises:
        KeyError: if ``client_config`` lacks ``batch_size``, ``lr`` or
            ``max_updates``, or a plan is missing from ``job.plans``.
    """
    global training_plan
    global testing_plan
    global model_params

    print(f"Accepted into cycle {len(cycles_log) + 1}!")

    cycle_params = job.client_config
    batch_size = cycle_params["batch_size"]
    lr = cycle_params["lr"]
    max_updates = cycle_params["max_updates"]

    training_plan = job.plans["training_plan"]
    testing_plan = job.plans["evaluate_model_plan"]
    model_params = job.model.tensors()

    if not train_in_cycle:
        # Fetch-only mode: the plans and parameters are now available as
        # globals; stop the driver loop without reporting to the grid.
        status['ended'] = True
        return

    # Parameter checksums and test metrics before training.
    print(' '.join([str(model_param.sum()) for model_param in model_params]))
    test_acc, test_loss = testing_plan.torchscript(test_X, test_y, th.tensor(len(test_y)), model_params)
    print(f'test_acc {test_acc} test_loss {test_loss}')

    for _ in range(max_updates):
        for batched_x, batched_y in batch_data(train_data['custom1'], batch_size, seed=1234):
            X = th.tensor(np.array(batched_x), dtype=th.float)
            y = th.tensor(np.array(batched_y), dtype=th.int64)
            y_oh = th.nn.functional.one_hot(y, 62)

            # Only the updated parameters (everything after the first four
            # outputs) are carried over to the next step.
            # NOTE(review): the configured batch_size is passed even when the
            # final batch is shorter — confirm the plan tolerates that.
            _loss, _acc, _logits, _target, *model_params = training_plan.torchscript(
                X, y_oh, th.tensor(batch_size), th.tensor(lr), model_params
            )

    # Parameter checksums and test metrics after training.
    print(' '.join([str(model_param.sum()) for model_param in model_params]))
    test_acc, test_loss = testing_plan.torchscript(test_X, test_y, th.tensor(len(test_y)), model_params)
    print(f'test_acc_AFTER {test_acc} test_loss_AFTER {test_loss}')

    # NOTE(review): the trained parameters themselves are reported, not a
    # difference from the original parameters — confirm this is what the hosted
    # averaging plan expects.
    job.report(model_params)
    cycles_log.append((test_loss, test_acc))
    
# Called when the client is rejected from cycle
def on_rejected(job: FLJob, timeout):
    """Report the rejection and flag the run as ended.

    A ``timeout`` of ``None`` is reported as the end of FL training; any other
    value is printed as the timeout the worker was given.
    """
    message = (
        "Rejected from cycle without timeout (this means FL training is done)"
        if timeout is None
        else f"Rejected from cycle with timeout: {timeout}"
    )
    print(message)
    status["ended"] = True

# Called when an error occurred
def on_error(job: FLJob, error: Exception):
    """Print the error and flag the run as ended so the driver loop stops."""
    print("Error: {}".format(error))
    status["ended"] = True

In [9]:
# Scratch expression: joins two literals; no later visible cell uses the result.
''.join(['sj', 'ks'])

'sjks'

We use the same PyGrid Node where the model was hosted, together with the name and version of the hosted model.

In [10]:
# PyGrid Node address (WebSocket URL) passed to FLClient.
# Disabled alternative that targets the host named "alice":
# gridAddress = "ws://alice:5000"
gridAddress = "ws://localhost:5000"

# Hosted model name/version; used both for authentication and when creating the job.
model_name = "mnist"
model_version = "1.0"

Let's define a routine that creates an FL client and starts the FL process.

In [11]:
#client.grid_worker.get_connection_speed(client.worker_id)

In [12]:
def new_job(self, model_name, model_version) -> FLJob:
    """Build an FLJob for the given model, authenticating first if needed.

    If ``self.worker_id`` is ``None``, the grid worker is asked to authenticate
    with ``self.auth_token`` and the returned worker id is stored on ``self``.

    NOTE(review): this is a module-level function taking ``self``; the visible
    cells never attach it to a class — confirm whether it is still needed.
    """
    needs_authentication = self.worker_id is None
    if needs_authentication:
        response = self.grid_worker.authenticate(
            self.auth_token, model_name, model_version
        )
        self.worker_id = response["data"]["worker_id"]

    return FLJob(
        fl_client=self,
        grid_worker=self.grid_worker,
        model_name=model_name,
        model_version=model_version,
    )

In [13]:
def create_client_and_run_cycle():
    """Create a fresh FL client, authenticate it, and start one job for the hosted model.

    Uses the module-level ``gridAddress``, ``auth_token``, ``model_name`` and
    ``model_version``. The three module-level handlers are registered on the
    job before it is started.
    """
    client = FLClient(url=gridAddress, auth_token=auth_token, verbose=True)

    # Authenticate explicitly and remember the worker id on the client.
    auth_response = client.grid_worker.authenticate(
        client.auth_token, model_name, model_version
    )
    client.worker_id = auth_response["data"]["worker_id"]

    job = client.new_job(model_name, model_version)

    # Wire each job event to its handler.
    handlers = (
        (job.EVENT_ACCEPTED, on_accepted),
        (job.EVENT_REJECTED, on_rejected),
        (job.EVENT_ERROR, on_error),
    )
    for event_name, handler in handlers:
        job.add_listener(event_name, handler)

    # Kick off the cycle request.
    job.start()


Now we're ready to start FL training.

We're going to run multiple "workers" until the FL model is fully done and see the progress.

As we create and authenticate a new client each time,
this emulates multiple different workers requesting a cycle and working on it.

In [14]:
# Keep starting fresh clients until one of the event handlers sets status["ended"].
while not status["ended"]:
    create_client_and_run_cycle()
    print('\n\n')
    # Pause for one second before the next cycle request.
    time.sleep(1)

Accepted into cycle 1!





In [15]:
# Copy the parameter container captured by on_accepted so local training does not rebind it.
# NOTE(review): model_params is still None if no cycle was accepted, and .copy() would fail — confirm
# the driver loop above printed "Accepted into cycle" before running this cell.
model_params_copy = model_params.copy()

In [None]:
import math
# Hyper-parameters read as globals by run_epoch during local training.
batch_size = 10
lr = 0.0003
def run_epoch(model_params_copy):
    """Evaluate the given parameters, then train for one pass over ``train_data['custom1']``.

    Reads the module-level ``testing_plan``, ``training_plan``, ``test_X``,
    ``test_y``, ``batch_size`` and ``lr``.

    Args:
        model_params_copy: model parameters to start from.

    Returns:
        (params, test_acc, test_loss): the parameters after the pass, plus the
        accuracy and loss measured BEFORE the pass.
    """
    # Metrics are taken before any update made in this call.
    test_acc, test_loss = testing_plan.torchscript(
        test_X, test_y, th.tensor(len(test_y)), model_params_copy
    )
    print(f'test_acc {test_acc} test_loss {test_loss}')

    params = model_params_copy
    passes = 1
    for _ in range(math.ceil(passes)):
        for features, labels in batch_data(train_data['custom1'], batch_size, seed=1234):
            X = th.tensor(np.array(features), dtype=th.float)
            label_ids = th.tensor(np.array(labels), dtype=th.int64)
            y_oh = th.nn.functional.one_hot(label_ids, 62)

            # Keep only the updated parameters; the first four outputs are unused here.
            _loss, _acc, _logits, _target, *params = training_plan.torchscript(
                X, y_oh, th.tensor(batch_size), th.tensor(lr), params
            )
    return params, test_acc, test_loss

# Run 50 consecutive epochs, feeding each epoch's parameters into the next one.
# The recorded accuracy/loss are the values run_epoch measured before that epoch's updates.
test_accs = []
test_losses = []
for i in range(50):
    model_params_copy, test_acc, test_loss = run_epoch(model_params_copy)
    test_accs.append(test_acc)
    test_losses.append(test_loss)



test_acc 0.009999999776482582 test_loss 10.483942985534668
test_acc 0.4350000023841858 test_loss 2.0518555641174316
test_acc 0.5174999833106995 test_loss 1.7015289068222046
test_acc 0.5550000071525574 test_loss 1.5512460470199585
test_acc 0.5724999904632568 test_loss 1.479295015335083
test_acc 0.5849999785423279 test_loss 1.4390881061553955
test_acc 0.5849999785423279 test_loss 1.4247289896011353


Let's plot loss and accuracy statistics recorded from each worker.
Each worker's statistics are drawn in a different color.

It's visible that loss/accuracy improvement occurs after each `max_diffs` reports to PyGrid,
because PyGrid updates the model and creates new checkpoint after
receiving `max_diffs` updates from FL clients.

In [43]:
!pip install matplotlib




In [16]:
# Show the per-cycle (test_loss, test_acc) log; the recorded run printed an empty list.
print(cycles_log)

[]


In [17]:

import numpy as np
import pandas as pd

import visualization_utils

from baseline_constants import (
    ACCURACY_KEY,
    BYTES_READ_KEY,
    BYTES_WRITTEN_KEY,
    CLIENT_ID_KEY,
    LOCAL_COMPUTATIONS_KEY,
    NUM_ROUND_KEY,
    NUM_SAMPLES_KEY)


In [18]:
def get_accuracy_vs_round_number(stat_metrics, weighted=False):
    """Aggregate the metrics by round number.

    Args:
        stat_metrics: DataFrame with a ``NUM_ROUND_KEY`` column; the weighted
            path also reads the ``ACCURACY_KEY`` and ``NUM_SAMPLES_KEY`` columns.
        weighted: if True, compute per round the mean of ``ACCURACY_KEY``
            weighted by ``NUM_SAMPLES_KEY``; otherwise take the plain per-round
            mean of every column.

    Returns:
        (accuracies, percentile_10, percentile_90): the per-round aggregate and
        the per-round 10th and 90th percentiles of ``stat_metrics``.
    """
    if weighted:
        accuracies = stat_metrics.groupby(NUM_ROUND_KEY).apply(_weighted_mean, ACCURACY_KEY, NUM_SAMPLES_KEY)
        accuracies = accuracies.reset_index(name=ACCURACY_KEY)
    else:
        # The per-round standard deviation was computed here before but never
        # used or returned, so it is no longer calculated.
        accuracies = stat_metrics.groupby(NUM_ROUND_KEY, as_index=False).mean()

    percentile_10 = stat_metrics.groupby(NUM_ROUND_KEY, as_index=False).apply(lambda x: x.quantile(0.10))
    percentile_90 = stat_metrics.groupby(NUM_ROUND_KEY, as_index=False).apply(lambda x: x.quantile(0.90))

    return accuracies, percentile_10, percentile_90

def get_loss_vs_round_number(stat_metrics, weighted=False):
    """Aggregate the loss metrics by round number.

    Args:
        stat_metrics: DataFrame with a ``NUM_ROUND_KEY`` column; the weighted
            path also reads the ``'loss'`` and ``NUM_SAMPLES_KEY`` columns.
        weighted: if True, compute per round the mean of ``'loss'`` weighted by
            ``NUM_SAMPLES_KEY``; otherwise take the plain per-round mean of
            every column.

    Returns:
        (losses, percentile_10, percentile_90): the per-round aggregate and the
        per-round 10th and 90th percentiles of ``stat_metrics``.
    """
    if weighted:
        losses = stat_metrics.groupby(NUM_ROUND_KEY).apply(_weighted_mean, 'loss', NUM_SAMPLES_KEY)
        losses = losses.reset_index(name='loss')
    else:
        # The per-round standard deviation was computed here before but never
        # used or returned, so it is no longer calculated.
        losses = stat_metrics.groupby(NUM_ROUND_KEY, as_index=False).mean()

    percentile_10 = stat_metrics.groupby(NUM_ROUND_KEY, as_index=False).apply(lambda x: x.quantile(0.10))
    percentile_90 = stat_metrics.groupby(NUM_ROUND_KEY, as_index=False).apply(lambda x: x.quantile(0.90))

    return losses, percentile_10, percentile_90


def _weighted_mean(df, metric_name, weight_name):
    d = df[metric_name]
    w = df[weight_name]
    
    try:
        return (w * d).sum() / w.sum()
    except ZeroDivisionError:
        return np.nan
# Directory (relative to the notebook's working directory) expected to hold the metrics CSV files.
fpath = "data/"
# NOTE(review): SHOW_WEIGHTED and PLOT_CLIENTS are not read by any visible cell.
SHOW_WEIGHTED = True # show weighted accuracy instead of unweighted accuracy
PLOT_CLIENTS = True

stat_file_testbed = fpath + 'metrics_stat_testbed.csv' # change to None if desired
stat_file = fpath + 'metrics_stat.csv' # change to None if desired
# NOTE(review): sys_file is defined but never loaded in the visible cells.
sys_file = fpath + 'metrics_sys.csv' # change to None if desired

# Load both statistics files. In the recorded run this step raised FileNotFoundError because
# 'data/metrics_stat_testbed.csv' was missing, so the names below were never defined.
fstat_metrics_testbed= visualization_utils.load_data(stat_file_testbed)
fstat_metrics= visualization_utils.load_data(stat_file)

# Sample-weighted accuracy per round for each source; the percentile outputs are discarded.
faccuracies_testbed,_, _ = get_accuracy_vs_round_number(fstat_metrics_testbed, True)
faccuracies, _, _ = get_accuracy_vs_round_number(fstat_metrics, True)

# Sample-weighted loss per round for each source; the percentile outputs are discarded.
loss_testbed,_, _ = get_loss_vs_round_number(fstat_metrics_testbed, True)
loss, _, _ = get_loss_vs_round_number(fstat_metrics, True)

FileNotFoundError: [Errno 2] No such file or directory: 'data/metrics_stat_testbed.csv'

In [None]:
import matplotlib.pyplot as plt

fig = plt.figure(figsize=(8, 6))
# Unpack the per-cycle (test_loss, test_acc) tensors from cycles_log into plain Python numbers.
losses = []
accuracies = []
for i, cycle_log in enumerate(cycles_log):
    losses.append(cycle_log[0].item())
    accuracies.append(cycle_log[1].item())
    
# axs[0].plot(range(0, len(losses))[:500], losses[:500], label='PyGrid Node ')
# axs[0].plot(loss_testbed[NUM_ROUND_KEY][:500], loss_testbed['loss'][:500], label='Testbed')
# axs[0].plot(loss[NUM_ROUND_KEY][:500], loss['loss'][:500], label='LEAF')
# axs[0].legend(loc='best')

    

# Accuracy per round (first 500 rows) for the two metrics files loaded in the metrics cell.
# Requires faccuracies_testbed and faccuracies, which exist only if that cell ran successfully.
# plt.plot(range(0, len(accuracies))[:500], accuracies[:500], label='PyGrid Node ')
plt.plot(faccuracies_testbed[NUM_ROUND_KEY][:500], faccuracies_testbed[ACCURACY_KEY][:500], label='PyTorch')
plt.plot(faccuracies[NUM_ROUND_KEY][:500], faccuracies[ACCURACY_KEY][:500], label='Tensorflow')
plt.legend(loc='best')

plt.ylabel('Accuracy')
plt.xlabel('Round Number')


# Write the figure to the notebook's working directory.
plt.savefig('combined_pytorch_leaf_testbed_accuracy.png')
#     print(f"Cycle {i + 1}:\tLoss: {np.mean(losses)}\tAcc: {np.mean(accuracies)}")

In [None]:
# Per-cycle losses extracted from cycles_log in the accuracy-plot cell.
print(losses)

In [None]:
# Per-cycle accuracies extracted from cycles_log in the accuracy-plot cell.
print(accuracies)

In [None]:
fig = plt.figure(figsize=(8, 6))
 
# Loss per round (first 500 rows) for the two metrics files loaded in the metrics cell.
# Requires loss_testbed and loss, which exist only if that cell ran successfully.
# plt.plot(range(0, len(losses))[:500], losses[:500], label='PyGrid Node ')
plt.plot(loss_testbed[NUM_ROUND_KEY][:500], loss_testbed['loss'][:500], label='PyTorch')
plt.plot(loss[NUM_ROUND_KEY][:500], loss['loss'][:500], label='TensorFlow')
plt.legend(loc='best')

plt.ylabel('Loss')
plt.xlabel('Round Number')


# Write the figure to the notebook's working directory.
plt.savefig('combined_pytorch_leaf_testbed_loss.png')
#     print(f"Cycle {i + 1}:\tLoss: {np.mean(losses)}\tAcc: {np.mean