# FLEXible tutorial: MNIST classification using TensorFlow

FLEXible is a library for federating models. We offer tools to load and federate data (or to load already federated data) and tools to create a federated environment. The user can define the model and the *communication primitives* used to train it in a federated environment, but we also offer some simple functions that let the user build an experiment quickly and easily. These primitives can be expressed as the following steps:
- initialization: Initialize the model on the server.
- deploy model: Deploy the model to the clients.
- training: Define the training function.
- collect the weights: Collect the clients' model weights so they can be aggregated later.
- aggregate the weights: Aggregate the collected weights with an aggregation method.
- deploy model: Deploy the model with the updated weights to the clients.
- evaluate: Define the evaluation function.
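
The primitives above map onto one training round. The following framework-free sketch does not use FLEXible. It only shows the order in which the steps are called. The toy model, a single float weight pulled towards each client's local mean, and every function name in it are hypothetical.

```python
import copy
import statistics

# Hypothetical toy setup: the "model" is a single float weight and each
# client holds a list of numbers.
client_data = {
    "client_0": [1.0, 2.0, 3.0],
    "client_1": [4.0, 5.0],
    "client_2": [6.0],
}

def initialize():
    # Initialization: the server creates the global model.
    return {"w": 0.0}

def deploy(server_model, client_ids):
    # Deploy: every client receives an independent copy of the server model.
    return {cid: copy.deepcopy(server_model) for cid in client_ids}

def train(model, data, lr=0.1, steps=50):
    # Training: gradient descent on mean((w - x)^2) over the local data.
    for _ in range(steps):
        grad = sum(2 * (model["w"] - x) for x in data) / len(data)
        model["w"] -= lr * grad

def collect(client_models):
    # Collect: gather each client's parameters.
    return [m["w"] for m in client_models.values()]

def aggregate(weights):
    # Aggregate: unweighted mean of the client parameters.
    return statistics.fmean(weights)

def evaluate(server_model, test_data):
    # Evaluate: mean squared error of the global model on held-out data.
    return statistics.fmean((server_model["w"] - x) ** 2 for x in test_data)

server_model = initialize()
for round_idx in range(3):
    client_models = deploy(server_model, client_data.keys())
    for cid, model in client_models.items():
        train(model, client_data[cid])
    # Set the aggregated weights on the server model.
    server_model["w"] = aggregate(collect(client_models))
    mse = evaluate(server_model, [3.0, 4.0])
    print(f"round {round_idx + 1}: w={server_model['w']:.3f}, mse={mse:.3f}")
```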

In this notebook, we show how to use these predefined primitive functions, leaving the implementation of some key functions to the user:
- Define the model to train: the server and the clients need to know which model will be trained.
- Aggregation method: in this notebook we implement FedAvg as the aggregation function.
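
FedAvg combines the locally trained weights of the clients selected in a round into new global weights. In its usual form, each client is weighted by the size of its local dataset:

$$
w^{t+1} = \sum_{k \in S_t} \frac{n_k}{n}\, w_k^{t+1}, \qquad n = \sum_{k \in S_t} n_k,
$$

Here $S_t$ is the set of clients selected in round $t$, $n_k$ is the number of training samples of client $k$, and $w_k^{t+1}$ are the weights of client $k$ after local training. The aggregator implemented in this notebook takes the unweighted mean $\frac{1}{|S_t|}\sum_{k \in S_t} w_k^{t+1}$. That is the special case in which all selected clients hold the same number of samples.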

Note that the primitive functions we offer are basic functions, since they make assumptions about how the federated training will be carried out. For a more customizable training loop, see the notebook flex_text_classification_tensorflow_demo, where we show how to implement the primitive functions from scratch. We follow this [tutorial](https://www.tensorflow.org/tutorials/quickstart/beginner?hl=es-419) from the TensorFlow 2.0 guide for beginners.

In [None]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import tensorflow as tf
print("Version: ", tf.__version__)
print("Eager mode: ", tf.executing_eagerly())
print("GPU is", "available" if tf.config.list_physical_devices('GPU') else "NOT AVAILABLE")

In [None]:
from flex.datasets import load

flex_dataset, test_data = load("federated_emnist", return_test=True, split="digits")
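
`load` returns `flex_dataset`, which is already split into one dataset per client, plus a separate `test_data` set. To illustrate what "federating" a dataset means, the stdlib sketch below groups a flat list of samples into per-client datasets by a key such as a writer ID. The samples and `federate_by_key` are hypothetical and do not reflect how FLEXible builds federated EMNIST internally.

```python
from collections import defaultdict

# Hypothetical flat dataset of (writer_id, image, label) triples. The images
# are placeholder strings; real EMNIST images are 28x28 arrays.
samples = [
    ("writer_a", "img0", 3),
    ("writer_b", "img1", 7),
    ("writer_a", "img2", 1),
    ("writer_c", "img3", 7),
    ("writer_b", "img4", 0),
]

def federate_by_key(samples):
    """Group samples into one {"X": [...], "y": [...]} dataset per writer."""
    clients = defaultdict(lambda: {"X": [], "y": []})
    for writer_id, x, y in samples:
        clients[writer_id]["X"].append(x)
        clients[writer_id]["y"].append(y)
    return dict(clients)

federated = federate_by_key(samples)
for client_id, data in federated.items():
    print(client_id, len(data["y"]), data["y"])
```

Each client ends up with only its own writer's samples. This is what makes federated data non-IID in practice.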

In [None]:
from flex.pool import init_server_model
from flex.pool import FlexPool
from flex.model import FlexModel
from copy import deepcopy

@init_server_model
def build_server_model():
    server_flex_model = FlexModel()

    server_flex_model["model"] = tf.keras.models.Sequential([
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dense(10, activation='softmax')
    ])
    server_flex_model["model"].compile(optimizer='adam',
                loss='sparse_categorical_crossentropy',
                metrics=['accuracy'])

    # Required to copy this model in later stages of the FL training process
    server_flex_model["optimizer"] = deepcopy(server_flex_model["model"].optimizer)
    server_flex_model["loss"] = deepcopy(server_flex_model["model"].loss)
    server_flex_model["metrics"] = deepcopy(server_flex_model["model"].compiled_metrics._metrics)

    return server_flex_model

flex_pool = FlexPool.client_server_architecture(flex_dataset, init_func=build_server_model)

clients = flex_pool.clients
servers = flex_pool.servers
aggregators = flex_pool.aggregators

print(f"Number of nodes in the pool {len(flex_pool)}: {len(servers)} server plus {len(clients)} clients. The server is also an aggregator")
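
What travels between the server and the clients is the list of arrays returned by `get_weights()`. As a worked example, the sketch below counts the parameters of the model defined above. The shapes follow from the layer definitions: Flatten gives 784 inputs, followed by Dense(128) and Dense(10). The byte count assumes float32 weights, Keras' default dtype.

```python
import math

# Shapes of the arrays returned by model.get_weights() for the model above:
# a kernel and a bias for each Dense layer. Flatten has no parameters; it
# only reshapes each 28x28 image into 784 inputs.
weight_shapes = [
    (28 * 28, 128),  # Dense(128) kernel
    (128,),          # Dense(128) bias
    (128, 10),       # Dense(10) kernel
    (10,),           # Dense(10) bias
]

sizes = [math.prod(shape) for shape in weight_shapes]
total = sum(sizes)
total_bytes = 4 * total  # assumes float32 weights (4 bytes each)

print(sizes)        # [100352, 128, 1280, 10]
print(total)        # 101770
print(total_bytes)  # 407080
```

So each selected client downloads and uploads roughly 0.4 MB of weights per round.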

We can also select a subsample of the clients to take part in each training round.

In [None]:
# Filter clients: keep roughly clients_per_round of them for this round
clients_per_round = 20
node_dropout = 1 - (clients_per_round / len(clients))
selected_clients_pool = clients.filter(node_dropout=node_dropout)
selected_clients = selected_clients_pool.clients

print(f"Server node is identified by key \"{servers.actor_ids[0]}\"")
print(f"Selected {len(selected_clients.actor_ids)} client nodes of a total of {len(clients.actor_ids)}")
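
The formula `node_dropout = 1 - clients_per_round / len(clients)` sets the fraction of clients to leave out so that about `clients_per_round` remain. The stdlib sketch below shows two common ways to draw such a subset. The first is an exact-size sample. The second uses independent per-client coin flips, which keep `clients_per_round` clients only on average. The client IDs are hypothetical, and the sketch is not FLEXible's implementation of `filter`.

```python
import random

client_ids = [f"client_{i}" for i in range(100)]  # hypothetical pool of 100 clients
clients_per_round = 20
node_dropout = 1 - clients_per_round / len(client_ids)  # 0.8 for 100 clients

rng = random.Random(42)  # fixed seed for reproducibility

# Option 1: sample exactly `clients_per_round` clients without replacement.
exact = rng.sample(client_ids, clients_per_round)

# Option 2: drop each client independently with probability `node_dropout`.
# The number kept follows Binomial(N, 1 - node_dropout), so it varies per round.
bernoulli = [cid for cid in client_ids if rng.random() >= node_dropout]

print(len(exact), len(bernoulli))
```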

In [None]:
from flex.pool import deploy_server_model

@deploy_server_model
def copy_server_model_to_clients(server_flex_model: FlexModel):
    flex_model = FlexModel()

    flex_model["model"] = tf.keras.models.clone_model(server_flex_model["model"])
    weights = server_flex_model["model"].get_weights()
    flex_model["model"].set_weights(weights)

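    # Give each client a fresh optimizer built from the server's config: newer
    # Keras optimizers keep per-variable state and fail if one instance is
    # used to train several different models
    optimizer = server_flex_model["optimizer"]
    flex_model["model"].compile(
        optimizer=optimizer.__class__.from_config(optimizer.get_config()),
        loss=server_flex_model["loss"],
        metrics=server_flex_model["metrics"],
    )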
    return flex_model


servers.map(copy_server_model_to_clients, selected_clients)
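
`copy_server_model_to_clients` gives each client a new model built with `tf.keras.models.clone_model`. That function creates fresh layers and weights rather than sharing the server's, so the server weights are then copied in with `set_weights`. The stdlib sketch below uses nested lists as hypothetical stand-ins for weight arrays to show why independent copies matter. If clients shared the server's objects, one client's local training would overwrite the global model and every other client's copy.

```python
import copy

# Hypothetical nested weights: [kernel values, bias values].
server_weights = [[0.0, 0.0], [0.0]]

# Wrong: every client holds a reference to the same server object.
shared = {cid: server_weights for cid in ("c0", "c1")}
shared["c0"][0][0] = 1.0  # local "training" on c0 ...
print(server_weights[0][0], shared["c1"][0][0])  # ... also changes the server and c1

# Right: each client gets an independent deep copy.
server_weights = [[0.0, 0.0], [0.0]]
copied = {cid: copy.deepcopy(server_weights) for cid in ("c0", "c1")}
copied["c0"][0][0] = 1.0
print(server_weights[0][0], copied["c1"][0][0])  # the server and c1 are unaffected
```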

In [None]:
from flex.data import Dataset

def train(client_flex_model: FlexModel, client_data: Dataset):
    # Local training: fit the client's copy of the model on its own data only
    client_flex_model["model"].fit(client_data.X_data,
                                    client_data.y_data,
                                    epochs=5,
                                    batch_size=512,
                                    verbose=False)

selected_clients.map(train)
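
In `train`, each selected client calls Keras' `fit()` on its own data, so the amount of local work depends on the client's data size. `fit()` runs ceil(n / batch_size) batches per epoch, and the last batch may be smaller. With `batch_size=512`, a client with 512 or fewer samples therefore takes a single full-batch step per epoch. The hypothetical helper below spells out that arithmetic for `epochs=5`.

```python
import math

def local_update_steps(num_samples, epochs=5, batch_size=512):
    """Number of gradient steps fit() takes on one client.

    fit() runs ceil(num_samples / batch_size) batches per epoch; the last
    batch may be smaller than batch_size.
    """
    return epochs * math.ceil(num_samples / batch_size)

# Hypothetical client dataset sizes.
for n in (100, 512, 1300):
    print(n, local_update_steps(n))
```

Clients with more data take more local steps. This is one reason their updates can drift further from the global model between aggregations.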

In [None]:
from flex.pool import collect_clients_weights

@collect_clients_weights
def get_clients_weights(client_flex_model: FlexModel):
    return client_flex_model["model"].get_weights()

aggregators.map(get_clients_weights, selected_clients)

In [None]:
from flex.pool import aggregate_weights
import numpy as np

@aggregate_weights
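def aggregate_with_fedavg(list_of_weights: list):
    # list_of_weights holds one list of layer arrays per client; average each
    # layer element-wise across clients (unweighted mean)
    return [np.mean(layer_weights, axis=0) for layer_weights in zip(*list_of_weights)]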

# Aggregate weights
aggregators.map(aggregate_with_fedavg)
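
The cell above gives every client's weights equal importance. The original FedAvg algorithm (McMahan et al., 2017) instead weights each client by its number of local samples, and it matches the plain mean when all selected clients hold the same amount of data. The stdlib sketch below shows the weighted version. `weighted_fedavg` and the toy weights are hypothetical, and plugging it into FLEXible would also require collecting each client's sample count.

```python
def weighted_fedavg(client_weights, num_samples):
    """Layer-wise weighted average of client weights.

    client_weights: one entry per client, each a list of layers, each layer
        a flat list of floats (a hypothetical stand-in for NumPy arrays).
    num_samples: number of local training samples of each client.
    """
    total = sum(num_samples)
    averaged = []
    for layers in zip(*client_weights):  # the same layer across all clients
        averaged.append([
            sum(n * v for n, v in zip(num_samples, values)) / total
            for values in zip(*layers)  # the same parameter across all clients
        ])
    return averaged

clients = [
    [[1.0, 2.0], [0.0]],  # client A: two layers
    [[3.0, 4.0], [1.0]],  # client B
]
print(weighted_fedavg(clients, [1, 1]))  # [[2.0, 3.0], [0.5]] (plain mean)
print(weighted_fedavg(clients, [3, 1]))  # [[1.5, 2.5], [0.25]] (A has 3x more data)
```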

In [None]:
from flex.pool import set_aggregated_weights

@set_aggregated_weights
def set_agreggated_weights_to_server(server_flex_model: FlexModel, aggregated_weights):
    server_flex_model["model"].set_weights(aggregated_weights)

aggregators.map(set_agreggated_weights_to_server, servers)

In [None]:
from flex.pool import evaluate_server_model

@evaluate_server_model
def evaluate_global_model(server_flex_model: FlexModel, test_data=None):
    loss, acc = server_flex_model["model"].evaluate(test_data.X_data, test_data.y_data, verbose=False)
    print(f"Test acc {acc:.4f}, test loss {loss:.4f}")

servers.map(evaluate_global_model, test_data=test_data)
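
`model.evaluate` returns the loss followed by the metrics configured in `compile`, here sparse categorical cross-entropy and accuracy. Up to numerical details, these reduce to the formulas in the stdlib sketch below. The probabilities and labels are made up for illustration.

```python
import math

def sparse_categorical_metrics(probabilities, labels):
    """Mean sparse categorical cross-entropy and accuracy.

    probabilities: one list of class probabilities (softmax outputs) per sample.
    labels: the integer class index of each sample.
    """
    n = len(labels)
    # Cross-entropy: negative log-probability assigned to the true class.
    loss = -sum(math.log(p[y]) for p, y in zip(probabilities, labels)) / n
    # Accuracy: fraction of samples whose most probable class is the true one.
    correct = sum(
        max(range(len(p)), key=p.__getitem__) == y
        for p, y in zip(probabilities, labels)
    )
    return loss, correct / n

probs = [
    [0.7, 0.2, 0.1],  # predicts class 0
    [0.1, 0.8, 0.1],  # predicts class 1
    [0.3, 0.3, 0.4],  # predicts class 2
]
labels = [0, 1, 0]
loss, acc = sparse_categorical_metrics(probs, labels)
print(f"loss {loss:.4f}, acc {acc:.4f}")
```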

### Run the federated learning experiment for a few rounds

Now we can put the steps above together and run the federated experiment for multiple rounds:

In [None]:
def train_n_rounds(n_rounds, clients_per_round=20):  
    pool = FlexPool.client_server_architecture(fed_dataset=flex_dataset, init_func=build_server_model)
    for i in range(n_rounds):
        print(f"\nRunning round: {i+1} of {n_rounds}")
        node_dropout = 1-(clients_per_round/len(pool.clients))
        selected_clients_pool = pool.clients.filter(node_dropout=node_dropout)
        selected_clients = selected_clients_pool.clients
        print(f"Selected clients for this round: {len(selected_clients)}")
        # Deploy the server model to the selected clients
        pool.servers.map(copy_server_model_to_clients, selected_clients)
        # Each selected client trains her model
        selected_clients.map(train)
        # The aggregator collects the weights of the selected clients and aggregates them
        pool.aggregators.map(get_clients_weights, selected_clients)
        pool.aggregators.map(aggregate_with_fedavg)
        # The aggregator sends the aggregated weights to the server
        pool.aggregators.map(set_agreggated_weights_to_server, pool.servers)
        # Evaluate the updated global model on the test set
        pool.servers.map(evaluate_global_model, test_data=test_data)

In [None]:
train_n_rounds(5)