# FLEXible tutorial: Text classification using TensorFlow

FLEXible is a library for federating models. It offers tools to load and federate centralized data or to load data that is already federated, and tools to create a federated environment. Users can define the model and the *communication primitives* used to train it in a federated environment, but the library already provides some simple functions that let the user build a fast and easy experiment. These primitives can be expressed as the following steps:
- initialization: Initialize the model in the server.
- deploy model: Deploy the model to the clients.
- training: Define the train function.
- collect the weights: Collect the clients' weights so they can be aggregated later.
- aggregate the weights: Use an aggregation method to aggregate the collected weights.
- deploy model: Deploy the model with the updated weights to the clients.
- evaluate: Define the evaluate function.
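
The steps above can be sketched without any federated library. The following toy example is only an illustration under stated assumptions: it uses a one-parameter model `y = w * x` in plain Python, and the names `train`, `evaluate`, `client_data` and `server_w` are hypothetical and unrelated to the FLEXible API.

```python
# Toy federated rounds on a one-parameter model y = w * x (illustrative only).
import random

def train(w, data, lr=0.5, epochs=5):
    """Local training: plain gradient descent on the mean squared error."""
    for _ in range(epochs):
        grad = sum(2 * (w * x - y) * x for x, y in data) / len(data)
        w -= lr * grad
    return w

def evaluate(w, data):
    """Mean squared error of the model on the given data."""
    return sum((w * x - y) ** 2 for x, y in data) / len(data)

random.seed(0)
# Two clients, each holding private samples of the target function y = 3x.
client_data = {
    name: [(x, 3.0 * x) for x in (random.uniform(-1, 1) for _ in range(20))]
    for name in ("client1", "client2")
}
test_data = [(k / 10, 3.0 * k / 10) for k in range(-10, 11)]

server_w = 0.0                                            # initialization
initial_loss = evaluate(server_w, test_data)
for round_idx in range(10):
    local_w = {name: server_w for name in client_data}    # deploy model
    local_w = {name: train(w, client_data[name])          # training
               for name, w in local_w.items()}
    collected = list(local_w.values())                    # collect the weights
    server_w = sum(collected) / len(collected)            # aggregate the weights
    # The updated server_w is redeployed at the top of the next iteration.
final_loss = evaluate(server_w, test_data)                # evaluate
print(f"w = {server_w:.4f}, loss {initial_loss:.4f} -> {final_loss:.6f}")
```

The server never sees the clients' samples here; only the scalar weight travels in each direction, which is the point of the deploy, collect and aggregate steps.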

In this notebook, we show how to use the predefined primitive functions, leaving the implementation of some key functions to the user:
- Define the model to train: It's necessary to tell server and clients which model will be trained.
- Aggregator method: In this notebook we will implement FedAvg as the aggregation function.
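
For intuition, FedAvg replaces each weight with the average of that weight across clients, optionally weighting every client by its number of local samples. The sketch below is a minimal pure-Python version under that assumption; `fed_avg_sketch` is a hypothetical helper and is not the `fed_avg` aggregator shipped with FLEXible, whose implementation may differ.

```python
# Minimal FedAvg sketch in pure Python (hypothetical helper, not FLEXible's fed_avg).
def fed_avg_sketch(client_weights, client_sizes=None):
    """Average per-layer weights across clients.

    client_weights: one entry per client; each entry is a list of layers,
                    and each layer is a flat list of floats.
    client_sizes:   optional number of local samples per client; when given,
                    each client contributes proportionally to its size.
    """
    n_clients = len(client_weights)
    if client_sizes is None:
        client_sizes = [1] * n_clients  # plain (unweighted) mean
    total = float(sum(client_sizes))
    coeffs = [size / total for size in client_sizes]

    averaged = []
    # zip(*client_weights) groups the same layer from every client together.
    for layer_group in zip(*client_weights):
        averaged.append([
            sum(c * value for c, value in zip(coeffs, values))
            for values in zip(*layer_group)
        ])
    return averaged

# Two clients, each with a two-value layer and a one-value layer.
client1 = [[1.0, 2.0], [0.0]]
client2 = [[3.0, 6.0], [4.0]]
uniform = fed_avg_sketch([client1, client2])
weighted = fed_avg_sketch([client1, client2], client_sizes=[1, 3])
print(uniform)
print(weighted)
```

With real Keras models the layers would be the arrays returned by `model.get_weights()`, and the same element-wise average would typically be computed with NumPy instead of nested lists.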

Note that the primitive functions we offer are basic, since they make assumptions about how the federated training will proceed. If you want a more customizable training loop, please check the notebook flex_text_classifiication_tensorflow_demo, where we show how to implement the primitive functions from scratch. We will follow this [tutorial](https://www.tensorflow.org/hub/tutorials/tf2_text_classification#build_the_model) from the TensorFlow tutorials for text classification.

In [None]:
import numpy as np

import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_datasets as tfds

print("Version: ", tf.__version__)
print("Eager mode: ", tf.executing_eagerly())
print("Hub version: ", hub.__version__)
print("GPU is", "available" if tf.config.list_physical_devices('GPU') else "NOT AVAILABLE")

In [None]:
train_data, test_data = tfds.load(name="imdb_reviews", split=["train", "test"])

In [None]:
from flex.data import Dataset

flex_data = Dataset.from_tfds_text_dataset(train_data, X_columns='text', label_columns='label')

In [None]:
from flex.data import FedDatasetConfig, FedDataDistribution

config = FedDatasetConfig(seed=0)
config.n_clients = 2
config.replacement = False # ensure that clients do not share any data
config.client_names = ['client1', 'client2'] # Optional
flex_dataset = FedDataDistribution.from_config(cdata=flex_data, config=config)

In [None]:
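from flex.data import FedDataDistribution

# Alternative to the config-based cell above: split the data IID across 2 clients.
# Running this cell overwrites the flex_dataset created there.
flex_dataset = FedDataDistribution.iid_distribution(flex_data, n_clients=2)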
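
Both cells above hand each client a share of the centralized data, and `replacement = False` is meant to keep those shares disjoint. The sketch below shows one way such a split could work: shuffle the sample indices once and deal them out in turn. It is an assumption about the general technique, not a description of FLEXible's internals, and `iid_partition` is a hypothetical helper.

```python
import random

def iid_partition(n_samples, client_names, seed=0):
    """Split sample indices into disjoint, (almost) equally sized random shards."""
    rng = random.Random(seed)
    indices = list(range(n_samples))
    rng.shuffle(indices)
    n_clients = len(client_names)
    # Client k takes every n_clients-th shuffled index, so no index is shared.
    return {name: indices[k::n_clients] for k, name in enumerate(client_names)}

# 25,000 is the size of the IMDB training split used in this notebook.
shards = iid_partition(25_000, ["client1", "client2"])
sizes = {name: len(idx) for name, idx in shards.items()}
overlap = set(shards["client1"]) & set(shards["client2"])
print(sizes, "shared samples:", len(overlap))
```

Because the indices are shuffled before being dealt out, each shard is a uniform random sample of the whole dataset, which is what makes the split IID.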

# Primitive Functions

In [None]:
from flex.pool.primitives import init_server_model_tf
from flex.pool.primitives import deploy_server_model_tf
from flex.pool.primitives import collect_clients_weights_tf
from flex.pool.primitives import train_tf
from flex.pool.primitives import set_aggregated_weights_tf
from flex.pool.primitives import evaluate_server_model_tf
from flex.pool.aggregators import fed_avg

In [None]:
# Defining the model

def define_model(**kwargs):
    # Pre-trained text embedding from TensorFlow Hub: maps raw strings to 50-dim vectors.
    embedding_url = "https://tfhub.dev/google/nnlm-en-dim50/2"
    hub_layer = hub.KerasLayer(embedding_url, input_shape=[], dtype=tf.string, trainable=True)
    model = tf.keras.Sequential()
    model.add(hub_layer)
    model.add(tf.keras.layers.Dense(16, activation='relu'))
    model.add(tf.keras.layers.Dense(1))  # single output logit, no activation
    model.compile(optimizer='adam',
                  loss=tf.losses.BinaryCrossentropy(from_logits=True),
                  metrics=[tf.metrics.BinaryAccuracy(threshold=0.0, name='accuracy')])
    return model
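
The model above ends in a `Dense(1)` layer without an activation, so it outputs raw logits. That is why the loss uses `from_logits=True` and the accuracy metric uses `threshold=0.0` rather than `0.5`. The short check below, which uses only the standard library and a hypothetical `sigmoid` helper, illustrates that thresholding a logit at 0 gives the same decision as thresholding the corresponding probability at 0.5.

```python
import math

def sigmoid(z):
    """Map a logit to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-z))

logits = [-2.0, -0.1, 0.0, 0.3, 4.0]
by_logit = [z > 0.0 for z in logits]           # decision made on raw logits
by_prob = [sigmoid(z) > 0.5 for z in logits]   # decision made on probabilities
print(by_logit == by_prob)
```

If a sigmoid activation were added to the last layer instead, the loss would need `from_logits=False` and the metric threshold would go back to `0.5`.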

In [None]:
model = define_model()

In [None]:
from flex.pool import FlexPool

flex_pool = FlexPool.client_server_architecture(fed_dataset=flex_dataset, init_func=init_server_model_tf, model=define_model())

In [None]:
clients = flex_pool.clients
server = flex_pool.servers
print(f"Server node is identified by {server.actor_ids}")
print(f"Client nodes are identified by {clients.actor_ids}")

In [None]:
server.map(deploy_server_model_tf, clients)

In [None]:
clients.map(train_tf, batch_size=512, epochs=1)

In [None]:
aggregator = flex_pool.aggregators


In [None]:
aggregator.map(collect_clients_weights_tf, clients)

In [None]:
aggregator.map(fed_avg)

In [None]:
aggregator.map(set_aggregated_weights_tf, server)

In [None]:
server.map(deploy_server_model_tf, clients)

In [None]:
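# Use a new name so the original tfds split is not overwritten and the cell can be re-run.
flex_test_data = Dataset.from_tfds_text_dataset(test_data, X_columns='text', label_columns='label')
test_examples, test_labels = flex_test_data.X_data, flex_test_data.y_data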

In [None]:
server.map(evaluate_server_model_tf, test_data=test_examples, test_labels=test_labels)

# Putting it all together

In [None]:
def train_n_rounds(n_rounds, batch_size, epochs):
    """Run n_rounds of federated training, evaluating the server model after each round."""
    pool = FlexPool.client_server_architecture(fed_dataset=flex_dataset, init_func=init_server_model_tf, model=define_model())
    pool.servers.map(deploy_server_model_tf, pool.clients)
    for i in range(n_rounds):
        print(f"\nRunning round: {i}\n")
        pool.clients.map(train_tf, batch_size=batch_size, epochs=epochs)
        pool.aggregators.map(collect_clients_weights_tf, pool.clients)
        pool.aggregators.map(fed_avg)
        pool.aggregators.map(set_aggregated_weights_tf, pool.servers)
        pool.servers.map(deploy_server_model_tf, pool.clients)
        pool.servers.map(evaluate_server_model_tf, test_data=test_examples, test_labels=test_labels)

In [None]:
train_n_rounds(n_rounds=2, batch_size=512, epochs=5)
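
As a rough sense of scale for the call above, the arithmetic below counts local optimizer steps per client. It assumes that the 25,000 IMDB training reviews are split evenly across the two clients and that `train_tf` processes the data in mini-batches of `batch_size`, keeping the final partial batch. Both are assumptions, not confirmed behavior of the library.

```python
import math

n_train = 25_000                           # size of the IMDB training split
n_clients = 2                              # as configured earlier in this notebook
n_rounds, batch_size, epochs = 2, 512, 5   # arguments of the call above

samples_per_client = n_train // n_clients  # assumes an even split
steps_per_epoch = math.ceil(samples_per_client / batch_size)
steps_per_round = steps_per_epoch * epochs
total_steps = steps_per_round * n_rounds
print(steps_per_epoch, steps_per_round, total_steps)
```

Raising `epochs` increases the amount of local work done between aggregations, while raising `n_rounds` increases how often the clients are synchronized through the server.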

# END