# FLEXible tutorial: MNIST classification using TensorFlow

FLEXible is a library for federating models. We offer tools to load and federate data or to load already-federated data, and tools to create a federated environment. The user can define the model and the *communication primitives* needed to train it in a federated environment, and we also offer decorators so that an advanced user can implement their own federated workflow. We provide Python decorators for the following steps of the federated learning flow:
- initialization: Initialize the model in the server.
- deploy model: Deploy the model to the clients.
- training: Define the train function.
- collect the weights: Collect the weights of the clients' models to aggregate them later.
- aggregate the weights: Use an aggregation method to aggregate the collected weights.
- deploy model: Deploy the model with the updated weights to the clients.
- evaluate: Define the evaluate function.
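
To make the decorator idea concrete, here is a minimal, framework-free sketch of how a decorator can turn a plain function into one step of a federated workflow. It is not FLEXible's implementation: `ToyNode`, `toy_deploy` and `copy_model` are hypothetical names made up for this illustration.

```python
from copy import deepcopy
from dataclasses import dataclass, field
from functools import wraps


@dataclass
class ToyNode:
    """Hypothetical stand-in for a federated node that stores a model as a dict."""
    name: str
    model: dict = field(default_factory=dict)


def toy_deploy(func):
    """Hypothetical decorator: apply `func` to the server model and
    store the result on every client."""
    @wraps(func)
    def wrapper(server, clients):
        for client in clients:
            # The user function only describes how to copy one model;
            # the decorator handles the server-to-clients communication.
            client.model = func(server.model)
    return wrapper


@toy_deploy
def copy_model(server_model):
    # Deep copy so that clients never share state with the server.
    return deepcopy(server_model)


server = ToyNode("server", {"weights": [0.0, 1.0]})
clients = [ToyNode(f"client_{i}") for i in range(3)]
copy_model(server, clients)
print([client.model for client in clients])
```

The point of the pattern is that the decorated function stays small and framework-agnostic, while the wrapper decides which nodes it runs on.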

In this notebook, we show how to use these decorators to implement advanced federated learning concepts.

If these tools are not low-level enough, try creating your own decorators or use FLEXible directly at a lower level, as shown [here](./flex_text_classification_tensorflow_demo.ipynb).

In [None]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import tensorflow as tf
print("Version: ", tf.__version__)
print("Eager mode: ", tf.executing_eagerly())
print("GPU is", "available" if tf.config.list_physical_devices('GPU') else "NOT AVAILABLE")

In [None]:
from flex.datasets import load

flex_dataset, test_data = load("federated_emnist", return_test=True, split="digits")

`@init_server_model` is a decorator designed to initialize the server model in a client-server architecture. The decorated function does not need any specific arguments.

In [None]:
from flex.pool import init_server_model
from flex.pool import FlexPool
from flex.model import FlexModel
from copy import deepcopy

@init_server_model
def build_server_model():
    server_flex_model = FlexModel()

    server_flex_model["model"] = tf.keras.models.Sequential([
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dense(10, activation='softmax')
    ])
    server_flex_model["model"].compile(optimizer='adam',
                loss='sparse_categorical_crossentropy',
                metrics=['accuracy'])

    # Required to copy this model in later stages of the FL training process
    server_flex_model["optimizer"] = deepcopy(server_flex_model["model"].optimizer)
    server_flex_model["loss"] = deepcopy(server_flex_model["model"].loss)
    server_flex_model["metrics"] = deepcopy(server_flex_model["model"].compiled_metrics._metrics)

    return server_flex_model

flex_pool = FlexPool.client_server_pool(flex_dataset, init_func=build_server_model)

clients = flex_pool.clients
servers = flex_pool.servers
aggregators = flex_pool.aggregators

print(f"Number of nodes in the pool {len(flex_pool)}: {len(servers)} server plus {len(clients)} clients. The server is also an aggregator")

We also support selecting a subsample of the clients to take part in the training process.

In [None]:
# Select clients
clients_per_round = 20
selected_clients_pool = clients.select(clients_per_round)
selected_clients = selected_clients_pool.clients

print(f"Server node is identified by key \"{servers.actor_ids[0]}\"")
print(f"Selected {len(selected_clients.actor_ids)} client nodes of a total of {len(clients.actor_ids)}")
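
The idea behind sampling a subset of clients per round can be sketched without the library. The example below is an illustration using Python's `random.sample`; it makes no claim about how `select` picks clients internally, and the `client_ids` list and the `sample_clients` helper are hypothetical.

```python
import random


def sample_clients(client_ids, clients_per_round, seed=None):
    """Pick `clients_per_round` distinct client ids uniformly at random."""
    rng = random.Random(seed)  # local RNG, so the global random state is untouched
    # Never ask for more clients than exist.
    k = min(clients_per_round, len(client_ids))
    return rng.sample(client_ids, k)


client_ids = [f"client_{i}" for i in range(100)]  # hypothetical ids
selected = sample_clients(client_ids, clients_per_round=20, seed=0)
print(f"Selected {len(selected)} of {len(client_ids)} clients")
```

Passing a seed makes the selection reproducible, which helps when comparing runs of the same experiment.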

`@deploy_server_model` is a decorator designed to copy the model from the server to the clients at each federated learning round. The decorated function must take at least one argument: the FlexModel object that stores the model at the server.

In [None]:
from flex.pool import deploy_server_model

@deploy_server_model
def copy_server_model_to_clients(server_flex_model: FlexModel):
    flex_model = FlexModel()

    flex_model["model"] = tf.keras.models.clone_model(server_flex_model["model"])
    weights = server_flex_model["model"].get_weights()
    flex_model["model"].set_weights(weights)

    flex_model["model"].compile(
        optimizer=server_flex_model["optimizer"],
        loss=server_flex_model["loss"],
        metrics=server_flex_model["metrics"],
    )
    return flex_model


servers.map(copy_server_model_to_clients, selected_clients)

Surprisingly, there is no decorator for the training process, as it can be implemented directly.

In [None]:
from flex.data import Dataset

def train(client_flex_model: FlexModel, client_data: Dataset):
    X, y = client_data.to_numpy()
    client_flex_model["model"].fit(X, 
                                    y,
                                    epochs=5,
                                    batch_size=512,
                                    verbose=False)

selected_clients.map(train)

`@collect_clients_weights`, as its name suggests, collects the weights from a set of clients. The decorated function must take at least one argument, the FlexModel of each client, and is expected to return the weights of that client's model.

In [None]:
from flex.pool import collect_clients_weights

@collect_clients_weights
def get_clients_weights(client_flex_model: FlexModel):
    return client_flex_model["model"].get_weights()

aggregators.map(get_clients_weights, selected_clients)

`@aggregate_weights` simplifies the aggregation step. The decorated function must take at least one argument: a list containing the weights collected in the previous step with `@collect_clients_weights`.

In [None]:
from flex.pool import aggregate_weights
import tensorly as tl

tl.set_backend('tensorflow')

@aggregate_weights
def aggregate_with_fedavg(list_of_weights: list):
    agg_weights = []
    for layer_index in range(len(list_of_weights[0])):
        weights_per_layer = tl.stack([weights[layer_index] for weights in list_of_weights])
        agg_layer = tl.mean(weights_per_layer, axis=0)
        agg_weights.append(agg_layer)
    return agg_weights

# Aggregate weights
aggregators.map(aggregate_with_fedavg)
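
For reference, the same layer-wise averaging can be written with plain NumPy. The sketch below also shows a common variant that weights each client by its number of training samples, which the cell above does not do; `fedavg` and `n_samples` are hypothetical names used only for this illustration.

```python
import numpy as np


def fedavg(list_of_weights, n_samples=None):
    """Average model weights layer by layer.

    list_of_weights: one entry per client, each a list of arrays (one per layer).
    n_samples: optional per-client sample counts; if given, the mean is weighted.
    """
    agg_weights = []
    for layer_index in range(len(list_of_weights[0])):
        # Stack the same layer from every client: shape (n_clients, *layer_shape).
        stacked = np.stack([w[layer_index] for w in list_of_weights])
        # With weights=None this is a plain mean over the client axis.
        agg_weights.append(np.average(stacked, axis=0, weights=n_samples))
    return agg_weights


# Two toy clients, each with a 2x2 kernel and a bias of length 2.
client_a = [np.zeros((2, 2)), np.zeros(2)]
client_b = [np.ones((2, 2)), np.ones(2)]

plain = fedavg([client_a, client_b])
weighted = fedavg([client_a, client_b], n_samples=[1, 3])
print(plain[1], weighted[1])
```

Weighting by sample count matters when clients hold very different amounts of data; with equally sized clients both variants give the same result.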

`@set_aggregated_weights` is designed as a setter: it sets the aggregated weights from the aggregator on the server. The decorated function must take at least two arguments: the FlexModel at the server and the aggregated weights returned in the previous step.

In [None]:
from flex.pool import set_aggregated_weights

@set_aggregated_weights
def set_agreggated_weights_to_server(server_flex_model: FlexModel, aggregated_weights):
    server_flex_model["model"].set_weights(aggregated_weights) 

aggregators.map(set_agreggated_weights_to_server, servers)

`@evaluate_server_model` is designed to evaluate the server model on external data. The decorated function must take at least one argument: the FlexModel at the server.

In [None]:
from flex.pool import evaluate_server_model

@evaluate_server_model
def evaluate_global_model(server_flex_model: FlexModel, test_data=None):
    X, y = test_data.to_numpy()
    return server_flex_model["model"].evaluate(X, y, verbose=False)

metrics = servers.map(evaluate_global_model, test_data=test_data)
print(metrics[0])

### Run the federated learning experiment for a few rounds

Now, we can summarize the steps provided above and run the federated experiment for multiple rounds:

In [None]:
def train_n_rounds(n_rounds, clients_per_round=20):  
    pool = FlexPool.client_server_pool(fed_dataset=flex_dataset, init_func=build_server_model)
    for i in range(n_rounds):
        print(f"\nRunning round: {i+1} of {n_rounds}")
        selected_clients_pool = pool.clients.select(clients_per_round)
        selected_clients = selected_clients_pool.clients
        print(f"Selected clients for this round: {len(selected_clients)}")
        # Deploy the server model to the selected clients
        pool.servers.map(copy_server_model_to_clients, selected_clients)
        # Each selected client trains its model
        selected_clients.map(train)
        # The aggregator collects weights from the selected clients and aggregates them
        pool.aggregators.map(get_clients_weights, selected_clients)
        pool.aggregators.map(aggregate_with_fedavg)
        # The aggregator sends its aggregated weights to the server
        pool.aggregators.map(set_agreggated_weights_to_server, pool.servers)
        metrics = pool.servers.map(evaluate_global_model, test_data=test_data)
        loss, acc = metrics[0]
        print(f"Server: Test acc: {acc:.4f}, test loss: {loss:.4f}")

In [None]:
train_n_rounds(5)
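
To see how the rounds fit together without any framework, the following toy simulation runs the same select, deploy, train, collect and aggregate cycle on a one-parameter linear model in pure Python. Everything here (the data, `make_client`, `local_train`, `run_rounds`) is hypothetical and only mirrors the structure of `train_n_rounds`, not FLEXible's internals.

```python
import random


def make_client(rng, n=50, true_w=3.0, noise=0.1):
    """Hypothetical client data: noisy samples of y = true_w * x."""
    xs = [rng.uniform(-1, 1) for _ in range(n)]
    return [(x, true_w * x + rng.gauss(0, noise)) for x in xs]


def local_train(w, data, lr=0.3, epochs=5):
    """Gradient descent on the mean squared error of the model y = w * x."""
    for _ in range(epochs):
        grad = sum(2 * (w * x - y) * x for x, y in data) / len(data)
        w -= lr * grad
    return w


def run_rounds(client_data, n_rounds, clients_per_round, seed=0):
    rng = random.Random(seed)
    server_w = 0.0  # initialize the server model
    history = []
    for _ in range(n_rounds):
        # Select a subset of clients for this round.
        selected = rng.sample(range(len(client_data)), clients_per_round)
        # Deploy the server weight, train locally, and collect the results.
        collected = [local_train(server_w, client_data[i]) for i in selected]
        # Aggregate with an unweighted mean and update the server.
        server_w = sum(collected) / len(collected)
        history.append(server_w)
    return history


data_rng = random.Random(42)
client_data = [make_client(data_rng) for _ in range(10)]

history = run_rounds(client_data, n_rounds=20, clients_per_round=4)
print(f"Server weight after {len(history)} rounds: {history[-1]:.3f}")
```

Because every client samples from the same underlying relation, the averaged weight should approach the true slope of 3; with heterogeneous client data the convergence behaviour can differ.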