In [None]:
import numpy as np

import tensorflow as tf
import tensorflow_hub as hub

# Report library versions and runtime capabilities so runs can be reproduced.
print("Version: ", tf.__version__)
print("Eager mode: ", tf.executing_eagerly())
print("Hub version: ", hub.__version__)
gpu_devices = tf.config.list_physical_devices("GPU")
gpu_status = "available" if gpu_devices else "NOT AVAILABLE"
print("GPU is", gpu_status)

In [None]:
import tensorflow_datasets as tfds

# Load the IMDB reviews dataset from TFDS as separate train and test splits.
train_data, test_data = tfds.load(
    split=["train", "test"],
    name="imdb_reviews",
)

In [None]:
from flex.data import Dataset

# Wrap the TFDS training split as a FLEX Dataset, with "text" as the
# feature column and "label" as the target column.
flex_data = Dataset.from_tfds_text_dataset(
    train_data,
    X_columns=["text"],
    label_columns=["label"],
)

In [None]:
from flex.data import FedDatasetConfig, FedDataDistribution

# Federate the centralized training data across two clients via an explicit config.
# NOTE(review): the following cell rebinds `flex_dataset` with
# `FedDataDistribution.iid_distribution`, so the result built here is discarded.
# Keep only one of the two approaches.
config = FedDatasetConfig(seed=0)
config.n_nodes = 2
# Ensure that clients do not share any data.
config.replacement = False
# Optional: custom identifiers for the client nodes.
config.node_ids = ["client1", "client2"]
flex_dataset = FedDataDistribution.from_config(
    centralized_data=flex_data,
    config=config,
)

In [None]:
from flex.data import FedDataDistribution

# Split the training data across 2 client nodes with `iid_distribution`.
# This rebinds `flex_dataset`, replacing the config-based distribution above.
server_id = "server"
flex_dataset = FedDataDistribution.iid_distribution(flex_data, n_nodes=2)
# The server node holds the IMDB test split; it is used to evaluate the global model.
flex_dataset[server_id] = Dataset.from_tfds_text_dataset(
    test_data,
    X_columns=["text"],
    label_columns=["label"],
)

In [None]:
from flex.pool.decorators import init_server_model
from flex.pool.decorators import collect_clients_weights
from flex.pool.decorators import aggregate_weights
from flex.pool.decorators import deploy_server_model
from flex.pool.decorators import set_aggregated_weights
from flex.model import FlexModel

In [None]:
from copy import deepcopy


# Defining the model
@init_server_model
def define_model(*args):
    """Build and compile the global (server) text classifier and wrap it in a FlexModel.

    The network is a trainable TF Hub NNLM sentence embedding (50 dimensions),
    followed by a 16-unit ReLU layer and a single-logit output. It is compiled
    for binary classification of IMDB reviews.

    Args:
        *args: Unused. Accepted presumably so the ``@init_server_model``
            decorator can forward extra positional arguments.

    Returns:
        FlexModel with keys:
            "model": the compiled ``tf.keras.Sequential`` model.
            "loss": a deep copy of the compiled loss object.
            "metrics": a deep copy of the metric list passed to ``compile``.
            "optimizer": a deep copy of the compiled optimizer.
    """
    hub_url = "https://tfhub.dev/google/nnlm-en-dim50/2"
    # input_shape=[] with dtype=tf.string: one raw review string per example.
    # trainable=True lets local training fine-tune the embedding.
    hub_layer = hub.KerasLayer(hub_url, input_shape=[], dtype=tf.string, trainable=True)

    model = tf.keras.Sequential()
    model.add(hub_layer)
    model.add(tf.keras.layers.Dense(16, activation="relu"))
    # No activation: the output is a raw logit, hence from_logits=True below.
    model.add(tf.keras.layers.Dense(1))

    # Keep our own handle on the metric list, so we never need to reach into
    # the private `model.compiled_metrics._metrics` attribute.
    # threshold=0.0 on a logit corresponds to probability 0.5.
    metrics = [tf.metrics.BinaryAccuracy(threshold=0.0, name="accuracy")]
    model.compile(
        optimizer="adam",
        loss=tf.losses.BinaryCrossentropy(from_logits=True),
        metrics=metrics,
    )

    server_flex_model = FlexModel()
    server_flex_model["model"] = model
    # Deep copies keep the stored objects independent of the ones attached to
    # the server's compiled model; clients are compiled from these copies.
    server_flex_model["loss"] = deepcopy(model.loss)
    server_flex_model["metrics"] = deepcopy(metrics)
    server_flex_model["optimizer"] = deepcopy(model.optimizer)
    return server_flex_model

In [None]:
from flex.pool import FlexPool

# Build a client/server pool over the federated dataset; `define_model`
# initialises the model held by the server node.
flex_pool = FlexPool.client_server_pool(
    init_func=define_model,
    server_id=server_id,
    fed_dataset=flex_dataset,
)

In [None]:
# Split the pool by role: the client nodes and the server node(s).
clients = flex_pool.clients
server = flex_pool.servers
print(f"Server node is identified by {server.actor_ids}")
print(f"Client nodes are identified by {clients.actor_ids}")

In [None]:
from flex.model import FlexModel


@deploy_server_model
def copy_model_to_clients(server_flex_model):
    """Build the FlexModel that is deployed to the clients from the server's model.

    The server's Keras model is cloned with ``tf.keras.models.clone_model``.
    The clone is loaded with the server's current weights, then compiled with
    the optimizer, loss and metrics stored in ``server_flex_model``.

    Args:
        server_flex_model: FlexModel with "model", "optimizer", "loss" and
            "metrics" entries (as populated by ``define_model``).

    Returns:
        A new FlexModel whose only entry is "model", the compiled clone.
    """
    source_model = server_flex_model["model"]
    cloned = tf.keras.models.clone_model(source_model)
    cloned.set_weights(source_model.get_weights())
    # NOTE(review): the optimizer/metric objects stored on the server are passed
    # to compile() on every call. Whether each client ends up with its own copy
    # depends on @deploy_server_model (not visible here). Confirm, because Keras
    # optimizers generally must not be shared across models.
    cloned.compile(
        optimizer=server_flex_model["optimizer"],
        loss=server_flex_model["loss"],
        metrics=server_flex_model["metrics"],
    )
    deployed = FlexModel()
    deployed["model"] = cloned
    return deployed

In [None]:
# Deploy a copy of the server's model to the client nodes (see copy_model_to_clients).
server.map(copy_model_to_clients, clients)

In [None]:
def fit_tf(client_flex_model, client_data, *args, **kwargs):
    """Train a client's Keras model on that client's local data.

    Args:
        client_flex_model: FlexModel (mapping) holding the compiled Keras model
            under the "model" key.
        client_data: local dataset whose ``to_numpy()`` returns a
            ``(features, labels)`` pair.
        *args, **kwargs: forwarded unchanged to ``Model.fit`` after the data
            (e.g. ``batch_size``, ``epochs``).
    """
    features, labels = client_data.to_numpy()
    keras_model = client_flex_model["model"]
    keras_model.fit(features, labels, *args, **kwargs)

In [None]:
# Local training on the clients; fit_tf forwards batch_size/epochs to Keras Model.fit.
clients.map(fit_tf, batch_size=512, epochs=2)

In [None]:
# Show which node(s) the pool uses as aggregator.
aggregator = flex_pool.aggregators
aggregator.actor_ids

In [None]:
@collect_clients_weights
def tensorflow_weights_collector(client_model):
    """Return a client's current Keras weights (``Model.get_weights()``) for aggregation."""
    keras_model = client_model["model"]
    return keras_model.get_weights()

In [None]:
# Collect the clients' trained weights on the aggregator.
aggregator.map(tensorflow_weights_collector, clients)

In [None]:
@aggregate_weights
def fed_avg(agg_model):
    """Federated averaging: element-wise mean of the clients' weights, tensor by tensor.

    Args:
        agg_model: one entry per client. Each entry is a list of weight arrays
            in the order returned by ``Model.get_weights()`` (as gathered by
            ``tensorflow_weights_collector``). All clients must provide the same
            number of arrays, with matching shapes position by position.

    Returns:
        list of numpy arrays, one per weight tensor, suitable for ``Model.set_weights``.

    Raises:
        ValueError: if no client weights were collected, or if clients provide a
            different number of weight arrays.
    """
    client_weights = list(agg_model)
    if not client_weights:
        raise ValueError("fed_avg received no client weights to average")
    array_counts = {len(weights) for weights in client_weights}
    if len(array_counts) != 1:
        raise ValueError(
            f"clients returned differing numbers of weight arrays: {sorted(array_counts)}"
        )
    # Average each weight tensor across clients separately. Packing all clients'
    # nested weight lists into a single np.array(..., dtype=object) raises a
    # "could not broadcast" error whenever arrays share leading dimensions
    # (e.g. a (16, 16) kernel next to a (16,) bias).
    return [
        np.mean(np.stack(tensor_per_client), axis=0)
        for tensor_per_client in zip(*client_weights)
    ]

In [None]:
# Aggregate the collected client weights with fed_avg (federated averaging).
aggregator.map(fed_avg)

In [None]:
@set_aggregated_weights
def set_aggregated_weights_tf(server_flex_model, aggregated_weights, *args, **kwargs):
    """Overwrite the server's Keras model weights with the aggregated weights.

    Any extra positional or keyword arguments are accepted and ignored.
    """
    target_model = server_flex_model["model"]
    target_model.set_weights(aggregated_weights)

In [None]:
# Write the aggregated weights into the server's model.
aggregator.map(set_aggregated_weights_tf, server)

In [None]:
# Redeploy the updated server model to the clients.
server.map(copy_model_to_clients, clients)

In [None]:
def evaluate_model(flex_model: FlexModel, data: Dataset):
    """Evaluate a node's Keras model on the given dataset and print the scores.

    Args:
        flex_model: the node's FlexModel. Its "model" entry is a compiled Keras
            model, and ``actor_id`` labels the node in the printed line.
        data: dataset whose ``to_numpy()`` returns ``(features, labels)``.

    Returns:
        None. The output of ``Model.evaluate`` (the loss followed by the
        compiled metrics) is printed instead.
    """
    features, labels = data.to_numpy()
    keras_model = flex_model["model"]
    scores = keras_model.evaluate(features, labels, verbose=False)
    print(f"Results for node_id {flex_model.actor_id} : {scores}")

In [None]:
# Evaluate the server model on the server node's data (the IMDB test split assigned to server_id).
server.map(evaluate_model)

In [None]:
# Evaluate each client's model on its local data.
# NOTE: client partitions come from the IMDB *train* split, so these are
# training-set scores, not held-out ones.
clients.map(evaluate_model)

# Putting it all together

In [None]:
def train_n_rounds(n_rounds, batch_size, epochs):
    """Run a complete federated-averaging experiment on a freshly built pool.

    A new client/server pool is created from ``flex_dataset`` (with
    ``server_id`` as the server), and ``define_model`` initialises a new server
    model, so training done in earlier cells does not carry over.

    Each round performs these steps:
      1. local training on the clients;
      2. client-side evaluation;
      3. weight collection on the aggregator;
      4. FedAvg aggregation;
      5. updating the server model;
      6. redeploying to the clients;
      7. server-side evaluation.

    Args:
        n_rounds: number of federated rounds to run.
        batch_size: batch size forwarded to ``Model.fit`` on each client.
        epochs: local epochs per round, forwarded to ``Model.fit``.
    """
    pool = FlexPool.client_server_pool(
        fed_dataset=flex_dataset, server_id=server_id, init_func=define_model
    )
    pool.servers.map(copy_model_to_clients, pool.clients)
    for i in range(n_rounds):
        print(f"\nRunning round: {i}\n")
        pool.clients.map(fit_tf, batch_size=batch_size, epochs=epochs)
        pool.clients.map(evaluate_model)
        pool.aggregators.map(tensorflow_weights_collector, pool.clients)
        pool.aggregators.map(fed_avg)
        pool.aggregators.map(set_aggregated_weights_tf, pool.servers)
        pool.servers.map(copy_model_to_clients, pool.clients)
        # evaluate_model prints its own results and returns None, so the value
        # returned by map() is not worth printing.
        pool.servers.map(evaluate_model)

In [None]:
# Fresh end-to-end run: 2 federated rounds, 5 local epochs per round, batch size 512.
train_n_rounds(n_rounds=2, batch_size=512, epochs=5)

# END