In [None]:
from copy import deepcopy
import numpy as np

import tensorflow as tf
import tensorflow_datasets as tfds

import matplotlib.pyplot as plt

# Report the runtime environment so a reader can tell what produced the outputs below.
print("Version: ", tf.__version__)
print("Eager mode: ", tf.executing_eagerly())

# An empty device list is falsy, which selects the "NOT AVAILABLE" branch.
gpu_devices = tf.config.list_physical_devices('GPU')
gpu_status = "available" if gpu_devices else "NOT AVAILABLE"
print("GPU is", gpu_status)

In [None]:
from flex.data import FedDataDistribution, Dataset
import numpy as np

# Load Federated EMNIST. With return_test=True the call yields two values,
# unpacked here as the federated dataset and a separate test set.
flex_dataset, test_data =  FedDataDistribution.FederatedEMNIST(return_test=True)

def remove_writer_from_y_data(client_dataset: Dataset):
    """Build a copy of a client's dataset whose labels contain only the digit class.

    Each original label is a tuple ``(mnist_label, writer)``; only the first
    component is kept. The feature data is passed through unchanged.

    Args:
        client_dataset: dataset exposing ``X_data`` and ``y_data``.

    Returns:
        A new ``Dataset`` with the same ``X_data`` and the writer-free labels.
    """
    digit_labels = np.asarray([label[0] for label in client_dataset.y_data])
    return Dataset(client_dataset.X_data, digit_labels)

# Run the label clean-up through flex_dataset.map and rebind flex_dataset to the result.
# The load above lives in the same cell, so re-running the cell starts from fresh labels.
flex_dataset = flex_dataset.map(remove_writer_from_y_data)

In [None]:
from flex.pool import init_server_model
from flex.pool import FlexPool, FlexModel

def build_server_model(*args):
    """Create the FlexModel that carries the Keras classifier used in this notebook.

    The network flattens 28x28 inputs, applies a 128-unit ReLU layer and ends
    in a 10-way softmax. It is compiled with the Adam optimizer, sparse
    categorical cross-entropy loss and an accuracy metric.

    Args:
        *args: accepted but not used by this function.

    Returns:
        A ``FlexModel`` whose ``model`` attribute holds the compiled network.
    """
    layers = [
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(128, activation='relu'),
        # Disabled in the original notebook: tf.keras.layers.Dropout(0.2)
        tf.keras.layers.Dense(10, activation='softmax'),
    ]

    flex_model = FlexModel()
    flex_model.model = tf.keras.models.Sequential(layers)
    flex_model.model.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'],
    )
    return flex_model

# Build the client-server pool; build_server_model is handed over as the init function.
p = FlexPool.client_server_architecture(
    flex_dataset,
    init_func=build_server_model,
)

# Role-specific views of the pool, reused by the following cells.
clients, servers, aggregators = p.clients, p.servers, p.aggregators

print(
    f"Number of nodes in the pool {len(p)}: {len(servers)} servers plus {len(clients)} clients. "
    "The server is also an aggregator"
)

We can also select a subsample of the clients to take part in the training process.

In [None]:
# Filter clients: keep only a subsample of the pool for the training steps.
# NOTE(review): node_dropout=0.9 is presumably the fraction of nodes dropped
# (about 10% kept) — confirm against the FlexPool.filter documentation.
filtered_pool = p.filter(node_dropout=0.9)

clients = filtered_pool.clients  # subsample that will actually train
total_clients = p.clients        # every client in the unfiltered pool
servers = p.servers
aggregators = p.aggregators

# Describe the full pool and the selected subset on separate, distinguishable lines;
# reporting the filtered client count as the composition of the whole pool was misleading.
print(f"Number of nodes in the pool {len(p)}: {len(servers)} servers plus {len(total_clients)} clients. The server is also an aggregator")
print(f"Clients selected for training after filtering: {len(clients)} of {len(total_clients)}")


In [None]:
import copy

from flex.pool import deploy_server_model

@deploy_server_model
def copy_server_model_to_clients(server_flex_model: FlexModel):
    """Produce an independent copy of the server's FlexModel for deployment.

    Every entry of the server model is deep-copied into a fresh ``FlexModel``,
    so the returned object shares no state with the server's entries.

    Args:
        server_flex_model: the server-side model container to replicate.

    Returns:
        A new ``FlexModel`` with the same keys and deep-copied values.
    """
    client_copy = FlexModel()
    for key, value in server_flex_model.items():
        client_copy[key] = copy.deepcopy(value)

    return client_copy

# Deployment step: map the copy function from the servers onto the selected clients.
servers.map(copy_server_model_to_clients, clients)

In [None]:
from flex.data import Dataset

def train(client_flex_model: FlexModel, client_data: Dataset):
    """Fit a client's local Keras model on that client's own data.

    ``fit`` is called with its default settings (no explicit epochs or batch
    size are passed). Nothing is returned.

    Args:
        client_flex_model: container whose ``model`` attribute is the Keras model.
        client_data: dataset exposing ``X_data`` and ``y_data``.
    """
    features, labels = client_data.X_data, client_data.y_data
    client_flex_model.model.fit(features, labels)

# Local training step: run `train` through the selected clients' map.
clients.map(train)

In [None]:
from flex.pool import collect_clients_weights

@collect_clients_weights
def get_clients_weights(client_flex_model: FlexModel):
    """Return the current weights of a client's Keras model.

    Args:
        client_flex_model: container whose ``model`` attribute is the Keras model.

    Returns:
        Whatever ``model.get_weights()`` returns for that client.
    """
    local_model = client_flex_model.model
    return local_model.get_weights()

# Collection step: the aggregators gather weights from the selected clients via get_clients_weights.
aggregators.map(get_clients_weights, clients)

In [None]:
from flex.pool import aggregate_weights
import numpy as np

@aggregate_weights
def aggregate(list_of_weights: list):
    """Average the collected client weights layer by layer, with equal weight per client.

    Args:
        list_of_weights: one entry per client; each entry is expected to be the
            sequence of per-layer arrays gathered by ``get_clients_weights``
            (presumably in the same layer order for every client).

    Returns:
        list: one array per layer, each the element-wise mean of that layer
        across all clients. An empty input yields an empty list.
    """
    # Layers have different shapes, so stacking every client's full weight list
    # into a single ndarray builds a ragged array, which NumPy >= 1.24 rejects
    # with a ValueError. Averaging one layer at a time only ever stacks
    # same-shaped arrays.
    return [np.mean(layer_group, axis=0) for layer_group in zip(*list_of_weights)]

# Aggregate weights: run `aggregate` through the aggregators' map.
aggregators.map(aggregate)

In [None]:
from flex.pool import set_aggregated_weights

@set_aggregated_weights
def set_agreggated_weights_to_server(server_flex_model: FlexModel, aggregated_weights):
    """Load the aggregated weights into the server's Keras model.

    Args:
        server_flex_model: container whose ``model`` attribute is the Keras model.
        aggregated_weights: weights to pass to ``model.set_weights``.
    """
    global_model = server_flex_model.model
    global_model.set_weights(aggregated_weights)

# Set aggregated weights in the server model: map the setter from the aggregators onto the servers.
aggregators.map(set_agreggated_weights_to_server, servers)

In [None]:
from flex.pool import evaluate_server_model

from sklearn import metrics

def score_model(labels, preds):
    """Print accuracy and macro-averaged recall, precision and F1 for a set of predictions.

    Args:
        labels: true class labels.
        preds: per-class scores with classes along axis 1; the predicted class
            is the arg-max over that axis.

    Returns:
        None. The metrics are only printed.
    """
    predicted_classes = np.argmax(preds, axis=1)
    macro = {"average": 'macro'}

    accuracy = metrics.accuracy_score(labels, predicted_classes)
    recall = metrics.recall_score(labels, predicted_classes, **macro)
    precision = metrics.precision_score(labels, predicted_classes, **macro)
    f1 = metrics.f1_score(labels, predicted_classes, **macro)

    report = "Accuracy: {},\nRecall: {},\nPrecision: {},\nMacro F1: {}".format(
        accuracy, recall, precision, f1
    )

    print("-------------------------------------")
    print(report)
    print("-----------------------------------")
    

@evaluate_server_model
def evaluate_global_model(server_flex_model: FlexModel, test_data=None):
    """Predict on the test set with the server model and print its scores.

    Args:
        server_flex_model: container whose ``model`` attribute is the Keras model.
        test_data: dataset exposing ``X_data`` and ``y_data``. It must be
            supplied: the ``None`` default has neither attribute.
    """
    predictions = server_flex_model.model.predict(test_data.X_data)
    true_labels = test_data.y_data
    score_model(true_labels, predictions)

# Evaluation step: score the server model on the held-out test set loaded with the dataset.
servers.map(evaluate_global_model, test_data=test_data)

### Run the federated learning experiment for a few rounds

Now, we can summarize the steps provided above and run the federated experiment for multiple rounds:

In [None]:
n_rounds = 5  # number of federated rounds to run

# Rebuild the pool from build_server_model so this experiment does not reuse
# the models touched by the step-by-step cells above.
p = FlexPool.client_server_architecture(
    flex_dataset, init_func=build_server_model
)

# Use the same keyword as the filtering cell earlier in this notebook
# (`node_dropout`); the previous `clients_dropout` spelling was inconsistent
# with it. NOTE(review): confirm the keyword against the installed
# FlexPool.filter signature.
filtered_pool = p.filter(node_dropout=0.8)

servers = p.servers
clients = filtered_pool.clients  # subsample selected once, before the rounds start
aggregators = p.aggregators

# Each iteration is one federated round; the steps must run in this order.
# The client subset is selected before the loop, so the same clients take part in every round.
for i in range(n_rounds):
    servers.map(copy_server_model_to_clients, clients)  # 1. send the server model to the clients
    clients.map(train)  # 2. local training on each client
    aggregators.map(get_clients_weights, clients)  # 3. collect the clients' weights
    aggregators.map(aggregate)  # 4. combine the collected weights
    aggregators.map(set_agreggated_weights_to_server, servers)  # 5. load the result into the server model

# Score the server model once, after all rounds have finished.
print("Federated model scores")
servers.map(evaluate_global_model, test_data=test_data)