In [None]:
from copy import deepcopy
import numpy as np

import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_datasets as tfds

import matplotlib.pyplot as plt

# Report library versions, eager-execution status and whether TensorFlow sees a GPU.
print("Version: ", tf.__version__)
print("Eager mode: ", tf.executing_eagerly())
print("Hub version: ", hub.__version__)
if tf.config.list_physical_devices('GPU'):
    print("GPU is", "available")
else:
    print("GPU is", "NOT AVAILABLE")

In [None]:
# Load the IMDB reviews train and test splits. batch_size=-1 asks tfds for each
# split as one full batch, and as_supervised=True yields (input, label) pairs.
train_data, test_data = tfds.load(
    name="imdb_reviews",
    split=["train", "test"],
    batch_size=-1,
    as_supervised=True,
)

In [None]:
from flex.data import FlexDataObject

# Wrap the training split in a FlexDataObject; it is the centralized data that
# gets distributed among clients in the cells below.
flex_data = FlexDataObject.from_tfds_dataset(train_data)

In [None]:
from flex.data import FlexDatasetConfig, FlexDataDistribution

# Describe how the data should be partitioned: two named clients, fixed seed.
config = FlexDatasetConfig(seed=0)
config.n_clients = 2
config.replacement = False # ensure that clients do not share any data
config.client_names = ['client1', 'client2'] # Optional
# Build the federated dataset from the centralized data according to `config`.
flex_dataset = FlexDataDistribution.from_config(cdata=flex_data, config=config)

In [None]:
from flex.data import FlexDataDistribution

# Alternative to the config-based setup: distribute the data among 2 clients
# with the iid_distribution helper.
# NOTE: this rebinds `flex_dataset`, so the dataset built from `config` in the
# previous cell (including its client names) is not used by the cells below.
flex_dataset = FlexDataDistribution.iid_distribution(flex_data, n_clients=2)

# Primitive Functions

In [None]:
from flex.pool.primitive_functions import initialize_server_model
from flex.pool.primitive_functions import deploy_global_model_to_clients
from flex.pool.primitive_functions import deploy_model_to_clients
from flex.pool.primitive_functions import collect_weights
from flex.pool.primitive_functions import aggregate_weights
from flex.pool.primitive_functions import evaluate_model
from flex.pool.primitive_functions import train

In [None]:
# Defining the model

def define_model(*args):
    """Build and compile the text classifier used by every node.

    The network is a trainable TF-Hub NNLM sentence embedding that takes raw
    strings, followed by a 16-unit ReLU layer and a single-unit output layer
    that emits a logit (no activation).

    Args:
        *args: Accepted but ignored.

    Returns:
        A compiled ``tf.keras.Sequential`` model using the Adam optimizer,
        binary cross-entropy on logits, and binary accuracy thresholded at 0.0
        (reported under the name ``accuracy``).
    """
    embedding_url = "https://tfhub.dev/google/nnlm-en-dim50-with-normalization/2"
    embedding = hub.KerasLayer(
        embedding_url,
        input_shape=[],
        dtype=tf.string,
        trainable=True,
    )

    classifier = tf.keras.Sequential()
    classifier.add(embedding)
    classifier.add(tf.keras.layers.Dense(16, activation='relu'))
    # No activation on the last layer: the loss and the metric work on logits.
    classifier.add(tf.keras.layers.Dense(1))

    classifier.compile(
        optimizer='adam',
        loss=tf.losses.BinaryCrossentropy(from_logits=True),
        metrics=[tf.metrics.BinaryAccuracy(threshold=0.0, name='accuracy')],
    )
    return classifier

In [None]:
from flex.pool import FlexPool

# Create the client/server pool over the federated dataset.
# `initialize_server_model` is passed as the init function and `define_model`
# as the model builder; `model_params=[]` supplies no extra model parameters.
flex_pool = FlexPool.client_server_architecture(
    fed_dataset=flex_dataset,
    init_func=initialize_server_model,
    model=define_model,
    verbose=1,
    model_params=[],
)

In [None]:
# Take the client and server sub-pools from the pool and show their actor ids.
clients = flex_pool.clients
server = flex_pool.servers
# NOTE(review): "indentified" in the message below is a typo for "identified".
print(f"Server node is indentified by {server.actor_ids}")
print(f"Client nodes are identified by {clients.actor_ids}")

In [None]:
# Apply deploy_model_to_clients from the server pool onto the clients pool
# (presumably copies the server model to each client -- confirm in flex docs).
server.map(deploy_model_to_clients, clients, verbose=1)

In [None]:
# Apply the `train` primitive on the clients pool with batch_size=512 and
# epochs=2 (presumably each client trains on its own data -- confirm in flex docs).
clients.map(train, batch_size=512, epochs=2)

In [None]:
# Aggregator sub-pool; the cells below collect client weights into it and
# run the aggregation on it.
aggregator = flex_pool.aggregators


In [None]:
# Apply collect_weights from the clients pool onto the aggregator
# (presumably gathers every client's weights there -- confirm in flex docs).
clients.map(collect_weights, aggregator)

In [None]:
def fed_avg(agg_model):
    """Average the collected client weights element-wise (unweighted mean).

    Args:
        agg_model: Sequence with one entry per client. Each entry is either a
            single array-like, or a list of per-layer arrays (presumably what
            ``collect_weights`` gathers from each client -- confirm in flex).

    Returns:
        The mean over the client axis. When all weights share one shape the
        result is a single ``np.ndarray``, as before. When the per-layer
        arrays have different shapes, a list holding one averaged
        ``np.ndarray`` per layer is returned.

    Raises:
        ValueError: If the same layer has different shapes across clients.
    """
    try:
        # Fast path: everything stacks into one regular array.
        return np.mean(np.array(agg_model), axis=0)
    except ValueError:
        # Layers of a real network have different shapes. NumPy >= 1.24 refuses
        # to build such a ragged array, so average layer by layer instead.
        return [
            np.mean(np.stack(layer_weights), axis=0)
            for layer_weights in zip(*agg_model)
        ]

In [None]:
# Run aggregate_weights on the aggregator pool, with fed_avg as the aggregation function.
aggregator.map(aggregate_weights, verbose=1, func_aggregate=fed_avg)

In [None]:
# Apply deploy_global_model_to_clients from the server pool onto the clients pool
# (presumably pushes the aggregated model back to the clients -- confirm in flex docs).
server.map(deploy_global_model_to_clients, clients)

In [None]:
# Unpack the test split into inputs and labels; it was loaded with
# as_supervised=True, so it is an (examples, labels) pair.
test_examples, test_labels = test_data

In [None]:
# Apply evaluate_model on the server pool with the held-out test examples and labels.
server.map(evaluate_model, test_examples=test_examples, test_labels=test_labels)

In [None]:
# Apply evaluate_model on the clients pool with the same test examples and labels.
clients.map(evaluate_model, test_examples=test_examples, test_labels=test_labels)