In [1]:
import nest_asyncio
nest_asyncio.apply()

In [2]:
import collections

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

np.random.seed(0)

In [3]:
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

In [4]:
from tensorflow_federated.proto.v0 import computation_pb2 as pb


NUM_CLIENTS = 10
NUM_EPOCHS = 5
SHUFFLE_BUFFER = 100

def preprocess(dataset):

  def map_fn(element):
    return collections.OrderedDict(
        x=tf.reshape(element['pixels'], [-1, 784]),
        y=tf.reshape(element['label'], [-1, 1]))

  return dataset.repeat(NUM_EPOCHS).shuffle(SHUFFLE_BUFFER, seed=1).map(map_fn)

In [5]:
example_dataset = emnist_train.create_tf_dataset_for_client(
    emnist_train.client_ids[0])
preprocessed_example_dataset = preprocess(example_dataset)

In [6]:
class TestDataBackend(tff.framework.DataBackend):

  async def materialize(self, data, type_spec):
    client_id = int(data.uri[-1])
    client_dataset = emnist_train.create_tf_dataset_for_client(
        emnist_train.client_ids[client_id])
    return preprocess(client_dataset)

In [7]:
def ex_fn(
    device: tf.config.LogicalDevice) -> tff.framework.DataExecutor:
  return tff.framework.DataExecutor(
      tff.framework.EagerTFExecutor(device),
      data_backend=TestDataBackend())
factory = tff.framework.local_executor_factory(leaf_executor_fn=ex_fn)
ctx = tff.framework.ExecutionContext(executor_fn=factory)
tff.framework.set_default_context(ctx)

In [8]:
def create_keras_model():
  return tf.keras.models.Sequential([
      tf.keras.layers.InputLayer(input_shape=(784,)),
      tf.keras.layers.Dense(10, kernel_initializer='zeros'),
      tf.keras.layers.Softmax(),
  ])
  
def model_fn():
  keras_model = create_keras_model()
  return tff.learning.from_keras_model(
      keras_model,
      input_spec=preprocessed_example_dataset.element_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

In [9]:
iterative_process = tff.learning.build_federated_averaging_process(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))

state = iterative_process.initialize()

In [10]:
element_type = tff.types.StructWithPythonType(
    preprocessed_example_dataset.element_spec,
    container_type=collections.OrderedDict)
dataset_type = tff.types.SequenceType(element_type)
dataset_type_proto = tff.framework.serialize_type(dataset_type)
# Sampling the first five clients in the EMNIST dataset.
arguments = [
  pb.Computation(data=pb.Data(uri=f'uri://{i}'), type=dataset_type_proto)
  for i in range(5)
]
data_handle = tff.framework.DataDescriptor(None, arguments, 
  tff.FederatedType(dataset_type, tff.CLIENTS), len(arguments))

In [11]:
state, metrics = iterative_process.next(state, data_handle)
print('round 1, metrics={}'.format(metrics))

round 1, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.11625), ('loss', 12.6826515), ('num_examples', 2400), ('num_batches', 2400)]))])


In [12]:
NUM_ROUNDS = 11
for round_num in range(2, NUM_ROUNDS):
  state, metrics = iterative_process.next(state, data_handle)
  print('round {:2d}, metrics={}'.format(round_num, metrics))

round  2, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.12375), ('loss', 10.283684), ('num_examples', 2400), ('num_batches', 2400)]))])
round  3, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.17791666), ('loss', 7.744127), ('num_examples', 2400), ('num_batches', 2400)]))])
round  4, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.275), ('loss', 5.888526), ('num_examples', 2400), ('num_batches', 2400)]))])
round  5, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.38708332), ('loss', 4.557935), ('num_exam

In [13]:
preprocessed_example_dataset.element_spec

OrderedDict([('x', TensorSpec(shape=(1, 784), dtype=tf.float32, name=None)),
             ('y', TensorSpec(shape=(1, 1), dtype=tf.int32, name=None))])

In [28]:
model = create_keras_model()
print(model)
input_spec = [tf.TensorSpec(shape=(1,784), dtype=tf.float32, name=None), tf.TensorSpec(shape=(1,1), dtype=tf.int32, name=None)]
functional_model = tff.learning.models.functional_model_from_keras(keras_model=model, loss_fn=tf.keras.losses.SparseCategoricalCrossentropy(),input_spec=input_spec)
print(functional_model.input_spec[0])
tff.learning.models.save_functional_model(functional_model=functional_model, path='tmp.txt')


<keras.engine.sequential.Sequential object at 0x7f3124301ac0>
TensorSpec(shape=(1, 784), dtype=tf.float32, name=None)




INFO:tensorflow:Assets written to: tmp.txt/assets


INFO:tensorflow:Assets written to: tmp.txt/assets


In [29]:
import shutil
shutil.make_archive('tmp.zip', 'zip', 'tmp.txt')

'/home/teo/PySyft/notebooks/PySyTFF/tmp.zip.zip'

In [None]:
import zipfile
with zipfile.ZipFile('tmp.zip', 'r') as zip_ref:
    zip_ref.extractall(directory_to_extract_to)

In [31]:
functional_model_reloaded = tff.learning.models.load_functional_model('tmp.txt')
functional_model_reloaded

<tensorflow_federated.python.learning.models.serialization._LoadedFunctionalModel at 0x7f3144578850>