# FLEXible tutorial: Text classification using TensorFlow

FLEXible is a library to federate models. It offers tools to load and federate centralized data (or to load data that is already federated), and tools to create a federated environment. The user must define the model and the *communication primitives* needed to train it in a federated environment. These primitives can be expressed in the following steps:
- initialization: Initialize the model in the server.
- deploy model: Deploy the model to the clients.
- training: Define the train function.
- collect the weights: Collect the weights of the clients' models to aggregate them later.
- aggregate the weights: Use an aggregation method to aggregate the collected weights.
- deploy model: Deploy the model with the updated weights to the clients.
- evaluate: Define the evaluate function.
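The order in which these primitives run can be illustrated without FLEXible or TensorFlow. The sketch below is a toy simulation under loudly stated assumptions: the "model" is a single float, each client "trains" by moving it halfway toward the mean of its own data, and aggregation is a plain average. None of the `toy_*` names belong to FLEXible's API, and the evaluate step is omitted.

```python
# Toy simulation of the primitives above; NOT FLEXible's API.
# Assumptions: the "model" is one float, local training moves it halfway toward
# the mean of the client's data, and aggregation is a plain average.
from copy import deepcopy
from statistics import mean


def toy_initialize_server_model():
    # initialization: the server creates the global model
    return {"param": 0.0}


def toy_deploy(server_model, n_clients):
    # deploy model: each client receives an independent copy
    return [deepcopy(server_model) for _ in range(n_clients)]


def toy_train(client_model, local_data, lr=0.5):
    # training: one local step toward the mean of the client's own data
    client_model["param"] += lr * (mean(local_data) - client_model["param"])


def toy_collect_weights(client_models):
    # collect the weights: gather every client's parameter
    return [m["param"] for m in client_models]


def toy_aggregate(weights):
    # aggregate the weights: plain (unweighted) average
    return mean(weights)


def run_rounds(client_data, n_rounds):
    server_model = toy_initialize_server_model()
    clients = toy_deploy(server_model, len(client_data))
    for _ in range(n_rounds):
        for model, data in zip(clients, client_data):
            toy_train(model, data)
        server_model["param"] = toy_aggregate(toy_collect_weights(clients))
        # deploy model: send the updated global parameter back to the clients
        for model in clients:
            model["param"] = server_model["param"]
    return server_model


# Two clients that share only their parameters, never their raw data.
client_data = [[1.0, 2.0, 3.0], [5.0, 7.0]]
print(run_rounds(client_data, n_rounds=3))  # {'param': 3.5}
```

Note that the initial deployment happens once, before the loop, and only the updated weights are redeployed each round; the multi-round function at the end of this notebook follows the same pattern.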

In this notebook, we show how to implement these primitives and how to use FLEXible to federate a model using TensorFlow. We will train a model using multiple clients without sharing any data between them. We follow this [tutorial](https://www.tensorflow.org/hub/tutorials/tf2_text_classification#build_the_model) on text classification from the TensorFlow tutorials.

## Setup

In [None]:
from copy import deepcopy
import numpy as np

import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_datasets as tfds

import matplotlib.pyplot as plt

print("Version: ", tf.__version__)
print("Eager mode: ", tf.executing_eagerly())
print("Hub version: ", hub.__version__)
print("GPU is", "available" if tf.config.list_physical_devices('GPU') else "NOT AVAILABLE")

In [None]:
from flex.data import FlexDataObject, FlexDataset, FlexDatasetConfig, FlexDataDistribution
from flex.pool import FlexPool, FlexModel

In [None]:
print(tfds.__version__)

## Download the IMDB dataset

As in the TensorFlow tutorial, we will use the IMDB dataset. This dataset contains movie reviews and the *sentiment* (positive or negative) associated with each one.

In [None]:
train_data, test_data = tfds.load(name="imdb_reviews", split=["train", "test"], 
                                    batch_size=-1, as_supervised=True)

train_examples, train_labels = tfds.as_numpy(train_data)
test_examples, test_labels = tfds.as_numpy(test_data)

In [None]:
print(f"Training entries: {len(train_examples)}, test entries: {len(test_examples)}")

## Create the FlexDataObject

As we are using a centralized dataset, we have to federate it. To do so, we first need to create the basic data object of FLEXible, called **FlexDataObject**. To create a **FlexDataObject**, the data must be provided as NumPy arrays (*numpy.ndarray*).

In [None]:
flex_data = FlexDataObject(X_data=train_examples, y_data=train_labels)

To make sure that the **FlexDataObject** was created correctly, we can validate it before federating it, although this check is also performed later anyway. The validate function does not return anything; it raises an error if there is a problem with the data.
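The contract described above (return nothing on success, raise on a problem) can be illustrated with a toy class. This is a hypothetical sketch: `ToyDataObject` is not part of FLEXible, and the checks that `FlexDataObject.validate()` actually performs may differ.

```python
# Hypothetical sketch of a validate() contract; NOT FLEXible's implementation.
class ToyDataObject:
    def __init__(self, X_data, y_data=None):
        self.X_data = X_data
        self.y_data = y_data

    def validate(self):
        """Return None if the data looks consistent; raise ValueError otherwise."""
        if self.X_data is None:
            raise ValueError("X_data must be provided.")
        if self.y_data is not None and len(self.X_data) != len(self.y_data):
            raise ValueError(
                f"X_data has {len(self.X_data)} entries but y_data has {len(self.y_data)}."
            )


ToyDataObject(["good movie", "bad movie"], [1, 0]).validate()  # passes silently
try:
    ToyDataObject(["good movie", "bad movie"], [1]).validate()
except ValueError as err:
    print(err)  # X_data has 2 entries but y_data has 1.
```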

In [None]:
flex_data.validate()

## Create the FlexDataset

Once we have the FlexDataObject, we can federate the data. There are multiple ways to federate data; to learn more, check the [tutorial for FlexDataset](https://github.com/FLEXible-FL/FLEX-framework/blob/main/notebooks/flex_dataset_demo.ipynb). In this example we will use **FlexDataDistribution** to create an IID distribution with the function *from_config*, which is the recommended way to create the different kinds of federated splits.
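With `replacement = False`, each training example should end up on exactly one client, and with an IID split each client receives a random sample of the whole dataset. The idea can be sketched as "shuffle the indices, then deal them into disjoint shards". This is a minimal sketch assuming an even split; `iid_split` is a hypothetical helper, and the exact way FlexDataDistribution assigns samples may differ.

```python
# Hypothetical sketch of an IID split without replacement; not FlexDataDistribution's code.
import random


def iid_split(n_samples, client_names, seed=0):
    """Shuffle sample indices and deal them into disjoint, near-equal shards."""
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)
    n_clients = len(client_names)
    shards = {}
    for i, name in enumerate(client_names):
        # Client i takes every n_clients-th shuffled index starting at i,
        # so no index is assigned to two clients.
        shards[name] = sorted(indices[i::n_clients])
    return shards


shards = iid_split(10, ["client1", "client2"], seed=0)
print({name: len(idx) for name, idx in shards.items()})  # {'client1': 5, 'client2': 5}
```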

In [None]:
config = FlexDatasetConfig(seed=0)
config.n_clients = 2
config.replacement = False # ensure that clients do not share any data
config.client_names = ['client1', 'client2']
# config.weights = [0.2] * config.n_clients # each client has only 20% of its assigned class
config.weights = None
flex_dataset = FlexDataDistribution.from_config(cdata=flex_data, config=config)

Alternatively, we could just use the function *iid_distribution* from FlexDataDistribution, which uses the same configuration that we have just defined.

In [None]:
# flex_dataset = FlexDataDistribution.iid_distribution(flex_data, n_clients=2)

## Create the architecture

### Generating the clients and the model to train.

Once we have federated the dataset, we have to create the FlexPool. The FlexPool class simulates a real federated learning scenario, so it is in charge of the communications between the actors. FlexPool assigns each actor a role (client, aggregator, server) so that they can communicate during the training phase.

Please check the notebook about the actors (TODO: write a notebook about the actors and their relationships) to learn more about the actors and their relationships in FLEXible.

To create a pool of actors, we need a federated dataset, like the one we have just created, and the model to initialize on the server side, because the server will send the model to the clients so they can train it. As we already have the federated dataset (flex_dataset), we will now create the model.

In this case, we will use a model from TensorFlow Hub, so we don't have to worry about preprocessing the text.

In [None]:
def initialize_server_model(flex_model, *args, **kwargs):
    print("Initializing model server.")
    # model = "https://tfhub.dev/google/nnlm-en-dim50/2" # Not working right now, but it's a smaller model.
    model = "https://tfhub.dev/google/nnlm-en-dim128-with-normalization/2"
    hub_layer = hub.KerasLayer(model, input_shape=[], dtype=tf.string, trainable=True)
    model = tf.keras.Sequential()
    model.add(hub_layer)
    model.add(tf.keras.layers.Dense(16, activation='relu'))
    model.add(tf.keras.layers.Dense(1))
    model.compile(optimizer='adam',
                    loss=tf.losses.BinaryCrossentropy(from_logits=True),
                    metrics=[tf.metrics.BinaryAccuracy(threshold=0.0, name='accuracy')])
    flex_model['model'] = model

Note that we compile the model inside the initialization function. This is recommended so that the model can also be used on the server for further evaluation.
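The last `Dense(1)` layer has no activation, so the model outputs logits rather than probabilities. That is why the loss uses `from_logits=True` and `BinaryAccuracy` uses `threshold=0.0`: a logit of 0 corresponds to a sigmoid probability of 0.5. The pure-Python sketch below mirrors that accuracy computation; it illustrates the idea and is not Keras' implementation.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def binary_accuracy_from_logits(logits, labels, threshold=0.0):
    """Fraction of examples whose thresholded logit matches the 0/1 label."""
    predictions = [1 if z > threshold else 0 for z in logits]
    correct = sum(p == y for p, y in zip(predictions, labels))
    return correct / len(labels)


logits = [2.3, -0.7, 0.4, -3.1]
labels = [1, 0, 0, 0]
# Thresholding logits at 0.0 is the same as thresholding probabilities at 0.5.
assert [z > 0.0 for z in logits] == [sigmoid(z) > 0.5 for z in logits]
print(binary_accuracy_from_logits(logits, labels))  # 0.75
```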

In this tutorial we will use the client-server architecture offered by FlexPool.

In [None]:
flex_pool = FlexPool.client_server_architecture(fed_dataset=flex_dataset, init_func=initialize_server_model)

### Deploy model to clients

We have to create the function that will deploy the model to the clients. 

In [None]:
def deploy_model_to_clients(server_model, clients_model, *args, **kwargs):
    print("Initializing model at client.")
    for client_id in clients_model:
        clients_model[client_id] = deepcopy(server_model)
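The `deepcopy` matters: if each client were simply assigned `server_model`, every client would hold a reference to the same object, and training one client would silently modify the others and the server too. The toy below uses plain dictionaries as stand-ins for FlexModel objects (an assumption for illustration) to show the difference.

```python
from copy import deepcopy

server_model = {"weights": [0.0, 0.0]}

# Shared reference: both "clients" point at the very same dict as the server.
shared = {"client1": server_model, "client2": server_model}
shared["client1"]["weights"][0] = 1.0
print(server_model["weights"])  # [1.0, 0.0] -> the server was modified too

# Independent copies, as in deploy_model_to_clients.
server_model = {"weights": [0.0, 0.0]}
copied = {cid: deepcopy(server_model) for cid in ["client1", "client2"]}
copied["client1"]["weights"][0] = 1.0
print(server_model["weights"], copied["client2"]["weights"])  # [0.0, 0.0] [0.0, 0.0]
```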

To make things easier, FlexPool lets the user work with organized pools, such as clients, aggregators or servers. This makes it easier to see how the actors are connected.

In [None]:
clients = flex_pool.clients
server = flex_pool.servers

To apply all the primitives, such as the deploy step, we use the **map** function from *FlexPool*. It works as follows: the pool that calls map is the one that sends a message to the destination pool. If no destination pool is specified, the message is "sent" to the same pool that calls map. This is what we need when we want the clients to train or evaluate the model.
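These calling conventions can be mimicked with a small toy class. `ToyPool`, `toy_deploy_to_pool` and `toy_local_step` are hypothetical names; the sketch only assumes what the functions in this notebook suggest: with a destination pool, the function receives the sender's model plus a dictionary of the receivers' models keyed by actor ID; without one, it receives each actor's own model and data (None for the server). FlexPool's real implementation may differ.

```python
# Hypothetical ToyPool: mimics the calling conventions used in this notebook,
# not FlexPool's real implementation.
class ToyPool:
    def __init__(self, models, data):
        self.models = models  # actor ID -> model (a dict here)
        self.data = data      # actor ID -> local data, or None (e.g. the server)

    def map(self, func, dst_pool=None, **kwargs):
        for actor_id, model in self.models.items():
            if dst_pool is None:
                # No destination: each actor runs func on its own model and data.
                func(model, self.data[actor_id], **kwargs)
            else:
                # Destination pool: func receives the sender's model and the
                # receivers' models, keyed by actor ID.
                func(model, dst_pool.models, **kwargs)


def toy_deploy_to_pool(server_model, clients_models, **kwargs):
    for client_id in clients_models:
        clients_models[client_id]["w"] = server_model["w"]


def toy_local_step(client_model, data, **kwargs):
    client_model["w"] += kwargs["lr"] * sum(data)


toy_server = ToyPool({"server": {"w": 1.0}}, {"server": None})
toy_clients = ToyPool({"client1": {"w": 0.0}, "client2": {"w": 0.0}},
                      {"client1": [1.0], "client2": [2.0]})
toy_server.map(toy_deploy_to_pool, toy_clients)  # server pool -> clients pool
toy_clients.map(toy_local_step, lr=0.5)          # clients pool -> itself
print(toy_clients.models)  # {'client1': {'w': 1.5}, 'client2': {'w': 2.0}}
```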

In [None]:
server.map(deploy_model_to_clients, clients)

In [None]:
clients._actors.keys() # Check the clients that will participate in the training of the federated model.

### Train the clients models

Once the model is deployed on the clients, it is time to create the training function. As you can see, we use the *fit* method of the TensorFlow model, so we don't need to write our own training loop, as we might have to in PyTorch.

In [None]:
def train(client_model, data, *args, **kwargs):
    print("Training model at client.")
    model = client_model['model']
    X_data = data.X_data
    y_data = data.y_data
    history = model.fit(X_data, y_data, epochs=kwargs['epochs'], batch_size=kwargs['batch_size'],
                verbose=1)

Now we will train the model on the clients' side. We use the *map* function to tell the clients to train the model; to do so, we just call it from the clients pool.

In [None]:
clients.map(train, batch_size=512, epochs=10)


### Aggregate the models

Now that we have trained the model, we have to aggregate the weights. To do so, the clients send their weights to the aggregator, which performs the chosen aggregation. In this tutorial, we will implement the FedAvg aggregation mechanism.
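For reference, in FedAvg the new global weights are the average of the $K$ clients' weights, each weighted by its number of local training samples $n_k$:

$$
w_{t+1} = \sum_{k=1}^{K} \frac{n_k}{n} \, w_{t+1}^{k}, \qquad n = \sum_{k=1}^{K} n_k
$$

The aggregation function in this tutorial takes the plain mean, $w_{t+1} = \frac{1}{K} \sum_{k=1}^{K} w_{t+1}^{k}$, which coincides with the weighted form when all clients hold the same number of samples, as is (approximately) the case for the two-client IID split used here.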

First, we select the aggregator:

In [None]:
aggregator = flex_pool.aggregators



Before applying the FedAvg aggregation method, we have to collect all the parameters (or weights) from the clients models.

In [None]:
def collect_weights(client_model, aggregator_model, **kwargs):
    # In the client-server architecture the server also acts as the aggregator,
    # so the aggregator's model is stored under the server's actor ID, "server".
    if 'weights' not in aggregator_model["server"].keys():
        print("Collecting weights.")
        aggregator_model["server"]['weights'] = []

    # Each client appends its list of per-layer weight arrays.
    aggregator_model["server"]['weights'].append(client_model['model'].get_weights())

In [None]:
clients.map(collect_weights, aggregator)

Now we can aggregate the weights using the FedAvg method. Once the aggregator has the aggregated weights, it should send them to the server; however, as the server and the aggregator are the same actor in our architecture, we set the aggregated weights on the model directly inside the FedAvg function.

In [None]:
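def fedavg_aggregation(agg_model, *args, **kwargs):
    # get_weights() returns one array per layer, and layers have different shapes,
    # so we average layer by layer instead of building a single (ragged) array.
    client_weights = agg_model['weights']
    averaged = [np.mean(np.stack(layer_weights), axis=0) for layer_weights in zip(*client_weights)]
    agg_model["model"].set_weights(averaged)
    # Clear the collected weights so the next round starts from an empty list.
    del agg_model["weights"]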
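Because `get_weights()` returns a list with one array per layer, averaging has to happen layer by layer. The sketch below shows that layer-wise idea with nested Python lists standing in for NumPy arrays, and adds optional per-client sample counts for the weighted FedAvg variant. `weighted_fedavg` and `num_samples` are hypothetical names that are not used elsewhere in this notebook.

```python
def weighted_fedavg(client_weights, num_samples=None):
    """Average per-layer weights across clients.

    client_weights: one entry per client; each entry is a list of layers,
    and each layer is a flat list of floats (a stand-in for a NumPy array).
    num_samples: optional local dataset sizes; equal weighting if omitted.
    """
    n_clients = len(client_weights)
    if num_samples is None:
        num_samples = [1] * n_clients
    total = sum(num_samples)
    averaged = []
    # zip(*...) groups the same layer from every client together.
    for layer_from_each_client in zip(*client_weights):
        layer_avg = [
            sum(n * layer[i] for n, layer in zip(num_samples, layer_from_each_client)) / total
            for i in range(len(layer_from_each_client[0]))
        ]
        averaged.append(layer_avg)
    return averaged


# Two clients, two "layers" of different sizes (like a kernel and a bias).
client1 = [[1.0, 2.0, 3.0], [0.0]]
client2 = [[3.0, 4.0, 5.0], [1.0]]
print(weighted_fedavg([client1, client2]))          # [[2.0, 3.0, 4.0], [0.5]]
print(weighted_fedavg([client1, client2], [3, 1]))  # [[1.5, 2.5, 3.5], [0.25]]
```

With equal sample counts the result matches the plain mean used in `fedavg_aggregation`.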

In [None]:
aggregator.map(fedavg_aggregation)

### Deploy and evaluate the model.

Now it is the server's turn: it deploys the aggregated weights to the clients' models, and then we evaluate the model.

In [None]:
def deploy_global_model_to_clients(server_model, clients_models, *args, **kwargs):
    print("Deploying the global model on the clients.")
    aggregated_weights = server_model['model'].get_weights()
    for client_model in clients_models:
        clients_models[client_model]['model'].set_weights(aggregated_weights)

In [None]:
server.map(deploy_global_model_to_clients, clients)

Now we can evaluate the model on the test set that we prepared at the beginning of the notebook.

In [None]:
def evaluate_model(model, data, *args, **kwargs):
    model = model['model']
    if data is not None:
        print("Evaluating model at client.")
        results_local = model.evaluate(data.X_data, data.y_data)
        print(f"Results at client on client's data: {results_local}")
    else:
        print("Evaluating model at server")
    results = model.evaluate(kwargs['test_examples'], kwargs['test_labels'])
    print(f"Results on test data: {results}")

In [None]:
server.map(evaluate_model, test_examples=test_examples, test_labels=test_labels)

In [None]:
clients.map(evaluate_model, test_examples=test_examples, test_labels=test_labels)

## Putting it all together

You have just trained a model for one round using FLEXible. Now you can put all the steps together in a function and iterate for multiple rounds.

In [None]:
def train_n_rounds(n_rounds, batch_size, epochs):
    pool = FlexPool.client_server_architecture(fed_dataset=flex_dataset, init_func=initialize_server_model)
    pool.servers.map(deploy_model_to_clients, pool.clients)
    for i in range(n_rounds):
        print(f"\nRunning round: {i}\n")
        pool.clients.map(train, batch_size=batch_size, epochs=epochs)
        pool.clients.map(evaluate_model, test_examples=test_examples, test_labels=test_labels)
        pool.clients.map(collect_weights, pool.aggregators)
        pool.aggregators.map(fedavg_aggregation)
        pool.servers.map(deploy_global_model_to_clients, pool.clients)
        pool.servers.map(evaluate_model, test_examples=test_examples, test_labels=test_labels)

In [None]:
train_n_rounds(n_rounds=4, batch_size=512, epochs=10)

### END
Congratulations, now you know how to train a model using FLEXible for multiple rounds. Remember that it is important to first initialize the model and deploy it to the clients, so that the rounds can run without problems!