# FLEXible tutorial: MNIST classification using TensorFlow

FLEXible is a library for federating models. It provides tools to load and federate data, or to load data that is already federated, and tools to create a federated environment. Users can define the model and the *communication primitives* used to train it in a federated environment, but the library already offers some simple functions that let them build a quick and easy experiment. These primitives can be expressed as the following steps:
- initialization: Initialize the model in the server.
- deploy model: Deploy the model to the clients.
- training: Define the train function.
- collect the weights: Collect the clients' weights to aggregate them later.
- aggregate the weights: Use an aggregation method to aggregate the collected weights.
- deploy model: Deploy the model with the updated weights to the clients.
- evaluate: Define the evaluate function.
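
The steps above can be sketched without FLEXible at all. The toy example below is only an illustration under stated assumptions: the "model" is a two-parameter linear regressor, the clients are plain dictionaries, and every name (`local_train`, `evaluate`, `clients`) is hypothetical rather than part of the FLEXible API.

```python
import numpy as np

rng = np.random.default_rng(0)
true_w = np.array([2.0, -1.0])

# Hypothetical clients: each one holds a private shard of the same linear problem.
clients = []
for _ in range(4):
    X = rng.normal(size=(50, 2))
    clients.append({"X": X, "y": X @ true_w, "w": None})


def local_train(w, X, y, lr=0.1, steps=20):
    """Plain gradient descent on the mean squared error, starting from w."""
    w = w.copy()
    for _ in range(steps):
        grad = 2.0 * X.T @ (X @ w - y) / len(y)
        w -= lr * grad
    return w


def evaluate(w):
    """Distance between the given weights and the ground-truth weights."""
    return float(np.linalg.norm(w - true_w))


server_w = np.zeros(2)                                  # initialization
initial_error = evaluate(server_w)

for round_idx in range(3):
    for c in clients:
        c["w"] = server_w.copy()                        # deploy model
    for c in clients:
        c["w"] = local_train(c["w"], c["X"], c["y"])    # training
    collected = [c["w"] for c in clients]               # collect the weights
    server_w = np.mean(collected, axis=0)               # aggregate the weights
    final_error = evaluate(server_w)                    # evaluate

print(f"error before: {initial_error:.3f}, after: {final_error:.6f}")
```

Deploying at the start of each round plays the role of both "deploy model" steps in the list: the first deployment sends the initial model, and later ones send the aggregated weights.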

In this notebook, we show how to use the predefined primitive functions, leaving a few key pieces to the user:
- Define the model to train: The server and the clients need to be told which model will be trained.
- Aggregator method: In this notebook we use FedAvg as the aggregation function.
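
To make the aggregation step concrete, here is a minimal sketch of federated averaging over per-layer weight lists. It is not FLEXible's `fed_avg`; the function name `fed_avg_sketch` and the optional `num_samples` argument are assumptions made for illustration. With `num_samples=None` every client counts equally, while passing sample counts gives the sample-weighted average used in the original FedAvg formulation.

```python
import numpy as np


def fed_avg_sketch(client_weights, num_samples=None):
    """Average client models layer by layer.

    client_weights: one entry per client, each a list of arrays (one per layer).
    num_samples: optional per-client sample counts used as averaging weights.
    """
    n_layers = len(client_weights[0])
    averaged = []
    for layer in range(n_layers):
        # Stack the same layer from every client along a new leading axis.
        stacked = np.stack([w[layer] for w in client_weights])
        averaged.append(np.average(stacked, axis=0, weights=num_samples))
    return averaged


# Two hypothetical clients, each with a kernel and a bias.
client_a = [np.array([[1.0, 2.0]]), np.array([0.0])]
client_b = [np.array([[3.0, 4.0]]), np.array([2.0])]

uniform = fed_avg_sketch([client_a, client_b])
weighted = fed_avg_sketch([client_a, client_b], num_samples=[1, 3])
print(uniform[0], uniform[1])    # [[2. 3.]] [1.]
print(weighted[0], weighted[1])  # [[2.5 3.5]] [1.5]
```

Averaging each layer separately avoids having to pack layers of different shapes into a single array.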

Note that the primitive functions we offer are basic, since they make assumptions about how the federated training will proceed. If you want a more customizable training loop, please check the notebook flex_text_classifiication_tensorflow_demo, where we show how to implement the primitive functions from scratch. We will follow this [tutorial](https://www.tensorflow.org/tutorials/quickstart/beginner?hl=es-419) from the TensorFlow 2.0 guide for beginners.

In [10]:
import tensorflow as tf

print("Version: ", tf.__version__)
print("Eager mode: ", tf.executing_eagerly())
print("GPU is", "available" if tf.config.list_physical_devices('GPU') else "NOT AVAILABLE")

Version:  2.10.0
Eager mode:  True
GPU is NOT AVAILABLE


In [11]:
from flex.data import FlexDataDistribution

# Load the federated dataset (FEMNIST) together with a held-out test set
flex_dataset, test_data = FlexDataDistribution.load_femnist(return_test=True)


# Primitive Functions

In [12]:
from flex.pool.flex_primitives import init_server_model_tf
from flex.pool.flex_primitives import deploy_server_model_tf
from flex.pool.flex_primitives import collect_clients_weights_tf
from flex.pool.flex_primitives import train_tf
from flex.pool.flex_primitives import set_aggregated_weights_tf
from flex.pool.flex_primitives import evaluate_server_model_tf
from flex.pool.flex_aggregators import fed_avg

In [13]:
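# Defining the model

def define_model(**kwargs):
    """Build and compile the Keras classifier shared by the server and the clients."""
    model = tf.keras.models.Sequential([
        tf.keras.layers.Flatten(input_shape=(28, 28)),    # 28x28 image -> 784-element vector
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dropout(0.2),                     # only active during training
        tf.keras.layers.Dense(10, activation='softmax')   # one probability per class
    ])

    # Integer labels, so use the sparse version of the cross-entropy loss
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    return model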
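
As a rough guide to how much data is exchanged in each round, the number of trainable parameters of this architecture can be worked out by hand. The helper `dense_params` below is a hypothetical name used only for this calculation, and the size estimate assumes 32-bit floats, which is the Keras default.

```python
def dense_params(n_inputs, n_units):
    """Parameters of a fully connected layer: one kernel entry per input/unit pair plus one bias per unit."""
    return n_inputs * n_units + n_units


flattened = 28 * 28                      # Flatten has no parameters, it only reshapes
hidden = dense_params(flattened, 128)    # Dense(128)
output = dense_params(128, 10)           # Dense(10); Dropout has no parameters
total = hidden + output
size_mb = total * 4 / 1e6                # assuming 4 bytes per float32 parameter

print(hidden, output, total)             # 100480 1290 101770
print(f"{size_mb:.2f} MB per model")     # 0.41 MB per model
```

Each client therefore contributes about 100,000 values to the aggregation step per round.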

In [14]:
model = define_model()

In [15]:
from flex.pool import FlexPool

flex_pool = FlexPool.client_server_architecture(fed_dataset=flex_dataset, init_func=init_server_model_tf, model=define_model())

In [16]:
clients = flex_pool.clients
server = flex_pool.servers
print(f"Server node is identified by {server.actor_ids}")
print(f"Client nodes are identified by {clients.actor_ids}")

Server node is identified by ['server']
Client nodes are identified by ['client_450', 'client_2113', 'client_3232', 'client_3861', 'client_396', 'client_1029', 'client_2048', 'client_3149', 'client_3502', 'client_266', 'client_2506', 'client_2560', 'client_3510', 'client_3184', 'client_1733', 'client_3203', 'client_2458', 'client_1265', 'client_997', 'client_840', 'client_1520', 'client_2130', 'client_675', 'client_2355', 'client_491', 'client_1817', 'client_1782', 'client_1361', 'client_274', 'client_542', 'client_1773', 'client_1670', 'client_2431', 'client_382', 'client_360', 'client_1401', 'client_1280', 'client_757', 'client_423', 'client_3167', 'client_1143', 'client_2234', 'client_1167', 'client_613', 'client_1302', 'client_3158', 'client_663', 'client_747', 'client_1351', 'client_394', 'client_1461', 'client_850', 'client_1859', 'client_3198', 'client_1563', 'client_3806', 'client_2137', 'client_0', 'client_2002', 'client_3328', 'client_715', 'client_1905', 'client_1772', 'cli

In [17]:
server.map(deploy_server_model_tf, clients)

In [18]:
clients.map(train_tf, batch_size=512, epochs=1)



In [19]:
aggregator = flex_pool.aggregators


In [20]:
aggregator.map(collect_clients_weights_tf, clients)

In [21]:
aggregator.map(fed_avg)



In [22]:
aggregator.map(set_aggregated_weights_tf, server)

In [23]:
server.map(deploy_server_model_tf, clients)

In [25]:
# Split the held-out test set into features and labels
test_examples, test_labels = test_data.X_data, test_data.y_data

In [26]:
server.map(evaluate_server_model_tf, test_data=test_examples, test_labels=test_labels)



[[108.77302551269531, 0.19272500276565552]]
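
The output above can be read more easily once unpacked. The snippet below assumes that each inner list is what Keras `Model.evaluate` returns for a model compiled with `metrics=['accuracy']`, that is `[loss, accuracy]`, and that the outer list holds one entry per server node; neither assumption is confirmed by the notebook itself.

```python
# Values copied from the cell output above.
results = [[108.77302551269531, 0.19272500276565552]]

for loss, accuracy in results:
    summary = f"loss={loss:.2f}, accuracy={accuracy:.2%}"
    print(summary)  # loss=108.77, accuracy=19.27%
```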

In [27]:
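def train_n_rounds(n_rounds, batch_size, epochs):
    """Run n_rounds of federated training on a fresh pool, evaluating the server model after each round."""
    # Build a new pool so that every call starts from a newly created model
    pool = FlexPool.client_server_architecture(fed_dataset=flex_dataset, init_func=init_server_model_tf, model=define_model())
    # Send the initial server model to every client
    pool.servers.map(deploy_server_model_tf, pool.clients)
    for i in range(n_rounds):
        print(f"\nRunning round: {i}\n")
        # Local training on each client's own data
        pool.clients.map(train_tf, batch_size=batch_size, epochs=epochs)
        # Gather the client weights, average them and store the result in the server
        pool.aggregators.map(collect_clients_weights_tf, pool.clients)
        pool.aggregators.map(fed_avg)
        pool.aggregators.map(set_aggregated_weights_tf, pool.servers)
        # Push the aggregated model back to the clients and evaluate it on the test set
        pool.servers.map(deploy_server_model_tf, pool.clients)
        pool.servers.map(evaluate_server_model_tf, test_data=test_examples, test_labels=test_labels)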

In [28]:
train_n_rounds(n_rounds=2, batch_size=512, epochs=5)


Running round: 0

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
Epoch 1/5
Epoch 2/5
Epoch 3/5
E