# FLEXible tutorial: MNIST classification using TensorFlow

FLEXible is a library for federating models. It provides tools to load and federate data, or to load already-federated data, and tools to create a federated environment. Users can define their own model and the *communication primitives* used to train it in a federated environment, but the library already offers some simple functions for building a quick and easy experiment. These primitives correspond to the following steps:
- initialization: Initialize the model in the server.
- deploy model: Deploy the model to the clients.
- training: Define the train function.
- collect the weights: Collect the clients' weights so they can be aggregated later.
- aggregate the weights: Use an aggregation method to aggregate the collected weights.
- deploy model: Deploy the model with the updated weights to the clients.
- evaluate: Define the evaluate function.
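
To make the steps above concrete, here is a minimal sketch of the same round structure in plain NumPy, using a toy linear model instead of a neural network. It does not use FLEXible at all: `make_client_data`, `train`, and `evaluate` are hypothetical helpers written only for this illustration, and the aggregation is a plain unweighted mean.

```python
import numpy as np

rng = np.random.default_rng(0)
true_w = np.array([2.0, -1.0])  # weights that generated the toy data


def make_client_data(n):
    """Hypothetical helper: synthetic (features, targets) for one node."""
    x = rng.normal(size=(n, 2))
    return x, x @ true_w


def train(weights, x, y, lr=0.1, steps=20):
    """Local training: a few gradient-descent steps on mean squared error."""
    w = weights.copy()
    for _ in range(steps):
        grad = 2 * x.T @ (x @ w - y) / len(y)
        w -= lr * grad
    return w


def evaluate(weights, x, y):
    """Mean squared error of the linear model on (x, y)."""
    return float(np.mean((x @ weights - y) ** 2))


clients = [make_client_data(50) for _ in range(5)]
x_test, y_test = make_client_data(200)

server_w = np.zeros(2)                                 # initialization
initial_loss = evaluate(server_w, x_test, y_test)

for round_idx in range(3):
    client_ws = [server_w.copy() for _ in clients]     # deploy model
    client_ws = [train(w, x, y)                        # training
                 for w, (x, y) in zip(client_ws, clients)]
    collected = np.stack(client_ws)                    # collect the weights
    server_w = collected.mean(axis=0)                  # aggregate the weights
    final_loss = evaluate(server_w, x_test, y_test)    # evaluate
    print(f"Round {round_idx + 1}: test loss {final_loss:.6f}")
```

The next round's deploy step plays the role of "deploy model with the updated weights". In the rest of this notebook each of these lines is replaced by a FLEXible primitive applied with `map`.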

In this notebook, we show how to use the predefined primitive functions, leaving a few key pieces to the user:
- Define the model to train: It's necessary to tell the server and the clients which model will be trained.
- Aggregator method: In this notebook we will use FedAvg as the aggregation function.
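
As a rough illustration of what FedAvg computes, the sketch below averages per-layer weight arrays across clients in NumPy. It is not FLEXible's `fed_avg`: the name `fed_avg_sketch` and the optional `num_samples` argument are assumptions made for this example, and whether the library weights clients equally or by dataset size is not stated in this notebook, so the sketch supports both.

```python
import numpy as np


def fed_avg_sketch(client_weights, num_samples=None):
    """Average model weights across clients, layer by layer.

    client_weights: list with one entry per client; each entry is a list
        of NumPy arrays (one per layer), with matching shapes across clients.
    num_samples: optional per-client sample counts (hypothetical argument);
        if given, clients are weighted proportionally, otherwise equally.
    """
    n_clients = len(client_weights)
    if num_samples is None:
        coeffs = np.full(n_clients, 1.0 / n_clients)
    else:
        coeffs = np.asarray(num_samples, dtype=float)
        coeffs = coeffs / coeffs.sum()

    aggregated = []
    for layer_group in zip(*client_weights):
        stacked = np.stack(layer_group)  # shape: (n_clients, *layer_shape)
        # Weighted sum over the client axis.
        aggregated.append(np.tensordot(coeffs, stacked, axes=1))
    return aggregated


# Two toy clients, each with a 2x2 "kernel" and a length-2 "bias".
client_a = [np.ones((2, 2)), np.zeros(2)]
client_b = [3 * np.ones((2, 2)), 2 * np.ones(2)]

uniform = fed_avg_sketch([client_a, client_b])
weighted = fed_avg_sketch([client_a, client_b], num_samples=[30, 10])
```

With equal weighting the kernel averages to 2 and the bias to 1. With sample counts of 30 and 10, the first client contributes three quarters of the result.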

Note that the primitive functions we offer are basic ones, since they assume a particular structure for the federated training. If you want a more customizable training loop, please check the notebook "Federated MNIST TF example with flexible decorators", where we show how to implement the primitive functions from scratch. We will follow this [tutorial](https://www.tensorflow.org/tutorials/quickstart/beginner?hl=es-419) from the TensorFlow 2.0 guide for beginners.

In [None]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import tensorflow as tf
print("Version: ", tf.__version__)
print("Eager mode: ", tf.executing_eagerly())
print("GPU is", "available" if tf.config.list_physical_devices('GPU') else "NOT AVAILABLE")

In [None]:
from flex.datasets import load

flex_dataset, test_data = load("federated_emnist", return_test=True, split="digits")

# Primitive Functions

In [None]:
from flex.pool import FlexPool
from flex.pool import init_server_model_tf

# Defining the model
def get_model():
    model = tf.keras.models.Sequential([
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dense(10, activation='softmax')
    ])

    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    return model

flex_pool = FlexPool.client_server_pool(fed_dataset=flex_dataset, init_func=init_server_model_tf, model=get_model())

clients = flex_pool.clients
server = flex_pool.servers
aggregator = flex_pool.aggregators

print(f"Number of nodes in the pool {len(flex_pool)}: {len(server)} server plus {len(clients)} clients. The server is also an aggregator")

In [None]:
# Select the clients that take part in this round
clients_per_round = 20
selected_clients_pool = clients.select(clients_per_round)
selected_clients = selected_clients_pool.clients

print(f"Server node is identified by key \"{server.actor_ids[0]}\"")
print(f"Selected {len(selected_clients.actor_ids)} client nodes of a total of {len(clients.actor_ids)}")

In [None]:
from flex.pool import deploy_server_model_tf

server.map(deploy_server_model_tf, selected_clients)

In [None]:
from flex.pool import train_tf

selected_clients.map(train_tf, batch_size=512, epochs=1, verbose=False)

In [None]:
from flex.pool import collect_clients_weights_tf

aggregator = flex_pool.aggregators
aggregator.map(collect_clients_weights_tf, selected_clients)

In [None]:
from flex.pool.aggregators import fed_avg

aggregator.map(fed_avg)

In [None]:
from flex.pool import set_aggregated_weights_tf

aggregator.map(set_aggregated_weights_tf, server)

In [None]:
from flex.pool import deploy_server_model_tf

server.map(deploy_server_model_tf, selected_clients)

In [None]:
from flex.pool import evaluate_server_model_tf

test_examples, test_labels = test_data.to_numpy()
metrics = server.map(evaluate_server_model_tf, test_data=test_examples, test_labels=test_labels)
loss, acc = metrics[0]
print(f"Test acc: {acc:.4f}, test loss: {loss:.4f}")

# Run the federated learning experiment for a few rounds

Now, we can summarize the steps provided above and run the federated experiment for multiple rounds:

In [None]:
def train_n_rounds(n_rounds, batch_size, epochs, clients_per_round=20):
    pool = FlexPool.client_server_pool(fed_dataset=flex_dataset, init_func=init_server_model_tf, model=get_model())
    for i in range(n_rounds):
        print(f"\nRunning round: {i+1} of {n_rounds}")
        selected_clients_pool = pool.clients.select(clients_per_round)
        selected_clients = selected_clients_pool.clients
        print(f"Selected clients for this round: {len(selected_clients)}")
        # Deploy the server model to the selected clients
        pool.servers.map(deploy_server_model_tf, selected_clients)
        # Each selected client trains its model
        selected_clients.map(train_tf, batch_size=batch_size, epochs=epochs, verbose=False)
        # The aggregator collects weights from the selected clients and aggregates them
        pool.aggregators.map(collect_clients_weights_tf, selected_clients)
        pool.aggregators.map(fed_avg)
        # The aggregator sends its aggregated weights to the server
        pool.aggregators.map(set_aggregated_weights_tf, pool.servers)
        # The server evaluates the updated model on the test set
        metrics = pool.servers.map(evaluate_server_model_tf, test_data=test_examples, test_labels=test_labels)
        loss, acc = metrics[0]
        print(f"Server: Test acc: {acc:.4f}, test loss: {loss:.4f}")

In [None]:
train_n_rounds(n_rounds=5, batch_size=512, epochs=5, clients_per_round=20)