In [None]:
from copy import deepcopy
import numpy as np

import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_datasets as tfds

import matplotlib.pyplot as plt

# Record the runtime environment next to the results: TensorFlow version,
# whether eager execution is active, and the TensorFlow Hub version.
print("Version: ", tf.__version__)
print("Eager mode: ", tf.executing_eagerly())
print("Hub version: ", hub.__version__)
# The conditional uses the truthiness of the returned device collection:
# anything empty/falsy prints "NOT AVAILABLE".
print("GPU is", "available" if tf.config.list_physical_devices('GPU') else "NOT AVAILABLE")

In [None]:
# Load the "train" and "test" splits of the IMDB reviews dataset through
# TensorFlow Datasets.
# NOTE(review): per the TFDS documentation, batch_size=-1 returns a whole split
# as a single batch and as_supervised=True yields (input, label) pairs; the
# later `test_examples, test_labels = test_data` unpacking relies on that —
# confirm against the installed TFDS version.
train_data, test_data = tfds.load(name="imdb_reviews", split=["train", "test"], 
                                    batch_size=-1, as_supervised=True)

In [None]:
from flex.data import FlexDataObject

# Wrap the training split in a FlexDataObject; this object is what the
# distribution helpers below receive as the centralized data to split.
flex_data = FlexDataObject.from_tfds_dataset(train_data)

In [None]:
from flex.data import FlexDatasetConfig, FlexDataDistribution

# Option 1: describe the federated split explicitly through a config object.
# seed=0 is handed to the config (presumably to make the partition
# reproducible — confirm in the FlexDatasetConfig docs).
config = FlexDatasetConfig(seed=0)
config.n_clients = 2
config.replacement = False # ensure that clients do not share any data
config.client_names = ['client1', 'client2'] # Optional
# Build the federated dataset from the centralized data and the config above.
flex_dataset = FlexDataDistribution.from_config(cdata=flex_data, config=config)

In [None]:
from flex.data import FlexDataDistribution

# Option 2: shortcut that splits the data across two clients without a config.
# NOTE: this rebinds `flex_dataset`, so the config-based distribution built
# earlier (including its custom client names) is discarded; every later cell
# uses the dataset created here.
flex_dataset = FlexDataDistribution.iid_distribution(flex_data, n_clients=2)

In [None]:
from flex.pool.flex_primitives import deploy_model_to_clients
from flex.pool.flex_primitives import evaluate_model

from flex.pool.flex_decorators import init_server_model_decorator
from flex.pool.flex_decorators import collector_decorator
from flex.pool.flex_decorators import aggregator_decorator_tf
from flex.pool.flex_decorators import train_decorator
from flex.pool.flex_decorators import deploy_decorator

In [None]:
# Defining the model
@init_server_model_decorator
def define_model(*args):
    """Build and compile the text classifier used as the initial model.

    The network is a TF-Hub text-embedding layer (fine-tuned, since
    ``trainable=True``) followed by a 16-unit ReLU layer and a single-logit
    output. It is compiled with Adam and binary cross-entropy on logits;
    accuracy thresholds the raw logit at 0.0.

    Args:
        *args: accepted but not used by this function.

    Returns:
        The compiled ``tf.keras.Sequential`` model.
    """
    # NOTE(review): the public TF-Hub NNLM English handles are dim50 / dim128;
    # "dim110" looks like a typo and may fail to resolve — confirm the handle.
    embedding_url = "https://tfhub.dev/google/nnlm-en-dim110-with-normalization/2"
    # input_shape=[] with dtype=tf.string: each example is one scalar string.
    text_embedding = hub.KerasLayer(embedding_url, input_shape=[], dtype=tf.string, trainable=True)

    classifier = tf.keras.Sequential([
        text_embedding,
        tf.keras.layers.Dense(16, activation='relu'),
        tf.keras.layers.Dense(1),
    ])
    classifier.compile(
        optimizer='adam',
        loss=tf.losses.BinaryCrossentropy(from_logits=True),
        metrics=[tf.metrics.BinaryAccuracy(threshold=0.0, name='accuracy')],
    )
    return classifier

In [None]:
from flex.pool import FlexPool

# Create the pool of actors over the federated dataset, using define_model as
# the model initializer.
# NOTE(review): model_params=[] presumably means no extra arguments are passed
# on to define_model (which accepts *args) — confirm against FlexPool.
flex_pool = FlexPool.client_server_architecture(fed_dataset=flex_dataset, init_func=define_model, verbose=1, model_params=[])

In [None]:
# Keep handles to the client and server views of the pool and print the actor
# ids each one contains. Both messages use the same wording ("identified").
clients = flex_pool.clients
server = flex_pool.servers
print(f"Server node is identified by {server.actor_ids}")
print(f"Client nodes are identified by {clients.actor_ids}")

In [None]:
# Apply deploy_model_to_clients from the server pool onto the client pool
# (presumably hands the initial server model to every client — confirm in
# flex.pool.flex_primitives).
server.map(deploy_model_to_clients, clients, verbose=1)

In [None]:
@train_decorator
def fit_tf(model, X_data, y_data, *args, **kwargs):
    """Train ``model`` on one client's data by delegating to ``model.fit``.

    Args:
        model: object exposing a Keras-style ``fit`` method.
        X_data: training inputs, passed to ``fit`` as its first argument.
        y_data: training targets, passed to ``fit`` as its second argument.
        *args, **kwargs: forwarded unchanged to ``fit`` (e.g. ``batch_size``,
            ``epochs``).

    Returns:
        None. The value returned by ``fit`` is intentionally not propagated.
    """
    features, targets = X_data, y_data
    model.fit(features, targets, *args, **kwargs)

In [None]:
# Run fit_tf over the client pool; batch_size and epochs are the keyword
# arguments that fit_tf forwards to model.fit.
clients.map(fit_tf, batch_size=512, epochs=2)

In [None]:
# Keep a handle to the aggregator view of the pool. The bare expression on the
# last line makes the notebook display its actor ids as the cell output.
aggregator = flex_pool.aggregators
aggregator.actor_ids

In [None]:
@collector_decorator
def tensorflow_weights_collector(client_model):
    """Return the current weights of a client's model.

    Args:
        client_model: object exposing a Keras-style ``get_weights`` method.

    Returns:
        Whatever ``client_model.get_weights()`` returns.
    """
    current_weights = client_model.get_weights()
    return current_weights

In [None]:
# Run tensorflow_weights_collector over the clients with the aggregator pool as
# the target (presumably the collected weights are stored on the aggregator —
# confirm against collector_decorator).
clients.map(tensorflow_weights_collector, aggregator)

In [None]:
@aggregator_decorator_tf
def fed_avg(agg_model):
    """Federated averaging: element-wise mean of the clients' weights, per layer.

    Args:
        agg_model: sequence with one entry per client. Each entry is expected
            to be that client's list of weight arrays (presumably the output
            of ``get_weights()`` gathered by the collector — confirm against
            aggregator_decorator_tf). Every client must provide the same
            layers, in the same order, with matching shapes.

    Returns:
        list of ``np.ndarray``: one averaged array per layer, in layer order,
        i.e. the same layout ``get_weights()`` produces.
    """
    # Layers have different shapes, so the per-client lists are ragged. Calling
    # np.array() on such nested sequences raises ValueError on NumPy >= 1.24
    # (and only worked with a deprecation warning before), so each layer is
    # averaged separately instead of stacking everything into one array.
    # zip(*agg_model) regroups the data from per-client to per-layer.
    return [np.mean(layer_weights, axis=0) for layer_weights in zip(*agg_model)]

In [None]:
# Run the fed_avg aggregation on the aggregator pool.
aggregator.map(fed_avg, verbose=1)

In [None]:
@deploy_decorator
def deploy_global_model(server_model, clients_models):
    """Overwrite every client's model weights with the server model's weights.

    Args:
        server_model: object exposing ``get_weights()``; its weights are read
            once and reused for all clients.
        clients_models: mapping-like container; iterating it yields keys, and
            each ``clients_models[key]["model"]`` exposes ``set_weights``.

    Returns:
        None. Client models are updated in place.
    """
    global_weights = server_model.get_weights()
    for client_id in clients_models:
        client_entry = clients_models[client_id]
        client_entry["model"].set_weights(global_weights)

In [None]:
# Run deploy_global_model from the server pool onto the client pool, which
# copies the server model's weights into each client's model.
server.map(deploy_global_model, clients)

In [None]:
# Unpack the test split into inputs and labels; this requires test_data to be
# a 2-item sequence, as loaded by tfds.load above.
test_examples, test_labels = test_data

In [None]:
# Evaluate the server's model on the held-out test set.
server.map(evaluate_model, test_examples=test_examples, test_labels=test_labels)

In [None]:
# Evaluate every client's model on the same held-out test set.
clients.map(evaluate_model, test_examples=test_examples, test_labels=test_labels)

# Putting it all together

In [None]:
def train_n_rounds(n_rounds, batch_size, epochs):
    """Run ``n_rounds`` of federated training on a freshly created pool.

    A new pool is built from the notebook-level ``flex_dataset`` and
    ``define_model``, the initial model is deployed to the clients, and each
    round then performs, in order: local training, client evaluation, weight
    collection, aggregation with ``fed_avg``, deployment of the aggregated
    weights back to the clients, and server evaluation.

    Relies on the notebook globals ``flex_dataset``, ``test_examples`` and
    ``test_labels`` being defined before it is called.

    Args:
        n_rounds: number of federated rounds to run.
        batch_size: batch size forwarded to each client's training step.
        epochs: local epochs per client per round.

    Returns:
        None. Progress is reported through printed output.
    """
    pool = FlexPool.client_server_architecture(
        fed_dataset=flex_dataset,
        init_func=define_model,
        verbose=1,
        model_params=[],
    )
    pool.servers.map(deploy_model_to_clients, pool.clients, verbose=1)

    # The same held-out data is used for client and server evaluation.
    eval_kwargs = {"test_examples": test_examples, "test_labels": test_labels}

    for round_idx in range(n_rounds):
        print(f"\nRunning round: {round_idx}\n")
        # Local training, then a look at each client's model before aggregation.
        pool.clients.map(fit_tf, batch_size=batch_size, epochs=epochs, verbose=1)
        pool.clients.map(evaluate_model, **eval_kwargs)
        # Collect client weights, average them, and push the result back out.
        pool.clients.map(tensorflow_weights_collector, pool.aggregators, verbose=1)
        pool.aggregators.map(fed_avg, verbose=1)
        pool.servers.map(deploy_global_model, pool.clients)
        pool.servers.map(evaluate_model, **eval_kwargs)

In [None]:
# Full experiment: 4 federated rounds, 10 local epochs per round, batch size 512.
# train_n_rounds builds its own pool, so this run does not reuse `flex_pool`
# or the models trained in the step-by-step cells above.
train_n_rounds(n_rounds=4, batch_size=512, epochs=10)

# END