In [20]:
import tensorflow as tf

# Opt in to TensorFlow 2.x behavior up front; later cells iterate over a
# dataset directly and call `.numpy()` on the resulting tensors.
tf.compat.v1.enable_v2_behavior()

In [21]:
import tensorflow_datasets as tfds

# Carve the Fashion-MNIST training split into its first and last 50% so that
# each participant (Alice, Bob) gets its own share of the examples.
# NOTE(review): `subsplit` / `tfds.percent` is the legacy TFDS split API; newer
# releases spell this as split="train[:50%]" / split="train[-50%:]" — confirm
# against the installed tensorflow_datasets version.
first_50_percent = tfds.Split.TRAIN.subsplit(tfds.percent[:50])
last_50_percent = tfds.Split.TRAIN.subsplit(tfds.percent[-50:])

# One dataset per participant, each loaded from its own half of the data.
alice_dataset = tfds.load("fashion_mnist", split=first_50_percent)
bob_dataset = tfds.load("fashion_mnist", split=last_50_percent)

In [22]:
import collections

# Number of examples per batch; consumed by preprocess() when batching.
BATCH_SIZE = 32


def cast(element):
    """Builds a new example whose image has been converted to float32.

    Args:
        element: Mapping with an "image" entry and a "label" entry.

    Returns:
        A new dict with the converted image under "image" and the original
        label, untouched, under "label". The conversion is done by
        `tf.image.convert_image_dtype`, which for integer inputs also rescales
        the values into [0, 1].
    """
    float_image = tf.image.convert_image_dtype(element["image"], dtype=tf.float32)
    return {"image": float_image, "label": element["label"]}


def flatten(element):
    """Turns an image/label example into the (x, y) pair the model consumes.

    Args:
        element: Mapping with an "image" entry and a "label" entry.

    Returns:
        An OrderedDict with, in this order:
          "x": the image reshaped to a single dimension (shape [-1]).
          "y": the label reshaped to shape [1].
    """
    features = tf.reshape(element["image"], [-1])
    target = tf.reshape(element["label"], [1])

    # Insert "x" before "y": the key order of the OrderedDict is part of the
    # structure handed to downstream code.
    example = collections.OrderedDict()
    example["x"] = features
    example["y"] = target
    return example


def preprocess(dataset):
    """Runs the full input pipeline on a dataset of image/label examples.

    The stages are applied in this order: `cast` on every element, then
    `flatten` on every element, then batching into groups of `BATCH_SIZE`.

    Args:
        dataset: Object exposing `map` and `batch` (e.g. a `tf.data.Dataset`).

    Returns:
        The batched dataset produced by the final `batch` call.
    """
    as_float = dataset.map(cast)
    flattened = as_float.map(flatten)
    return flattened.batch(BATCH_SIZE)

In [23]:
# Run the same preprocessing pipeline over each participant's dataset.
preprocessed_alice_dataset = preprocess(alice_dataset)
preprocessed_bob_dataset = preprocess(bob_dataset)

In [24]:
# One preprocessed dataset per participant; this list is what gets passed to
# the iterative process's `next` call below.
federated_data = [preprocessed_alice_dataset, preprocessed_bob_dataset]

In [25]:
from tensorflow.python.keras.optimizer_v2 import gradient_descent

# Learning rate handed to the SGD optimizer in create_compiled_keras_model().
LEARNING_RATE = 0.02


def custom_loss_function(y_true, y_pred):
    """Mean sparse categorical cross-entropy between labels and predictions.

    Args:
        y_true: Ground-truth labels, forwarded as-is to
            `tf.keras.losses.sparse_categorical_crossentropy`.
        y_pred: Model predictions, forwarded as-is to the same function.

    Returns:
        The result of `tf.reduce_mean` over the cross-entropy values.
    """
    per_example_loss = tf.keras.losses.sparse_categorical_crossentropy(y_true, y_pred)
    return tf.reduce_mean(per_example_loss)


def create_compiled_keras_model():
    """Builds and compiles a one-layer softmax classifier.

    The network is a single Dense layer with 10 units, softmax activation and a
    zero-initialized kernel, taking inputs of shape (784,). It is compiled with
    `custom_loss_function`, an SGD optimizer using `LEARNING_RATE`, and a
    sparse categorical accuracy metric.

    Returns:
        The compiled `tf.keras` Sequential model.
    """
    softmax_layer = tf.keras.layers.Dense(
        10,
        activation=tf.nn.softmax,
        kernel_initializer="zeros",
        input_shape=(784,),
    )
    keras_model = tf.keras.models.Sequential([softmax_layer])

    # Optimizer and metric are created after the model, as separate fresh
    # objects on every call, so each compiled model owns its own instances.
    compile_options = {
        "loss": custom_loss_function,
        "optimizer": gradient_descent.SGD(learning_rate=LEARNING_RATE),
        "metrics": [tf.keras.metrics.SparseCategoricalAccuracy()],
    }
    keras_model.compile(**compile_options)
    return keras_model

In [26]:
# Take the first preprocessed batch from Alice's dataset and convert every
# tensor in its (possibly nested) structure to a NumPy array. model_instance()
# passes this batch to TFF alongside the Keras model.
#
# `tf.nest.map_structure` is the public API for this; the `tf.contrib`
# namespace no longer exists in TensorFlow 2.x, whose behavior this notebook
# opts in to. The `next()` builtin is used instead of the Python 2-style
# `.next()` method so the cell works with any iterator.
batch_of_samples = tf.nest.map_structure(
    lambda x: x.numpy(), next(iter(preprocessed_alice_dataset))
)


def model_instance():
    """Creates a fresh TFF model wrapping a newly compiled Keras model.

    Every call compiles a brand-new Keras model and wraps it together with the
    module-level `batch_of_samples`.

    Note: `tff` is looked up when this function is called, not when it is
    defined, so the `tff` import (which lives in a later cell) must have run
    before the first call.

    Returns:
        Whatever `tff.learning.from_compiled_keras_model` returns for the new
        model and the sample batch.
    """
    return tff.learning.from_compiled_keras_model(
        create_compiled_keras_model(), batch_of_samples
    )

In [27]:
from tensorflow_federated import python as tff

# Build the federated averaging process from the model factory; TFF is given
# the function itself (not a model instance).
federated_learning_iterative_process = tff.learning.build_federated_averaging_process(
    model_instance
)
# Create the initial state, then run a single round over both participants'
# datasets. `next` hands back the updated state plus a second value, kept here
# as `performance` (presumably the round's training metrics — confirm against
# the TFF version in use).
state = federated_learning_iterative_process.initialize()
state, performance = federated_learning_iterative_process.next(state, federated_data)

In [28]:
# Display the second value returned by the round executed above.
performance