In [1]:
import nest_asyncio
# Patch asyncio to allow nested event loops, so the `asyncio.run(...)` call
# later in this notebook can be made from inside a cell (presumably because the
# Jupyter kernel already has a running event loop).
nest_asyncio.apply()

In [2]:
import collections

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff
from tensorflow_federated.python.program import value_reference

# Seed every RNG the notebook may touch, not just NumPy's.
# NOTE(review): TFF's `DatasetDataSource` iterator presumably samples clients
# with Python's `random` module, which `np.random.seed` does not affect —
# confirm against the installed TFF version.
import random
random.seed(0)
np.random.seed(0)
tf.random.set_seed(0)

In [3]:
# Federated EMNIST train/test splits. Per-client datasets are accessed below
# via `client_ids` and `create_tf_dataset_for_client` (presumably
# `tff.simulation.datasets.ClientData` objects).
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

In [4]:
from tensorflow_federated.proto.v0 import computation_pb2 as pb


# NOTE(review): NUM_CLIENTS is not referenced in the visible notebook; the
# per-round client count actually used for training is `number_of_clients`.
NUM_CLIENTS = 10
# Number of times each client's dataset is repeated by `preprocess`.
NUM_EPOCHS = 5
# Shuffle buffer size used by `preprocess`.
SHUFFLE_BUFFER = 100

def preprocess(dataset, num_epochs=None, shuffle_buffer=None, seed=1):
  """Repeats, shuffles and reformats one client's EMNIST dataset.

  Each output element is a `(pixels, label)` tuple. `pixels` is reshaped to
  `[-1, 784]` and `label` to `[-1, 1]`, which yields shapes `(1, 784)` and
  `(1, 1)` per element (see the printed `element_spec`).

  Args:
    dataset: A `tf.data.Dataset` whose elements are dicts with `'pixels'` and
      `'label'` entries.
    num_epochs: Number of times to repeat the dataset. `None` (the default)
      uses the module-level `NUM_EPOCHS`, read at call time.
    shuffle_buffer: Shuffle buffer size. `None` (the default) uses the
      module-level `SHUFFLE_BUFFER`, read at call time.
    seed: Op-level seed for the shuffle, for a reproducible order.

  Returns:
    The preprocessed `tf.data.Dataset`.
  """
  # Resolve defaults at call time so later edits to the globals still apply.
  if num_epochs is None:
    num_epochs = NUM_EPOCHS
  if shuffle_buffer is None:
    shuffle_buffer = SHUFFLE_BUFFER

  def map_fn(element):
    # Return a tuple (not a list) so the element structure explicitly matches
    # the tuple `input_spec` the model is built with.
    return (tf.reshape(element['pixels'], [-1, 784]),
            tf.reshape(element['label'], [-1, 1]))

  return (dataset
          .repeat(num_epochs)
          .shuffle(shuffle_buffer, seed=seed)
          .map(map_fn))

In [5]:
# Build a preprocessed dataset for the first training client; its
# `element_spec` serves as the model's input spec.
first_client_id = emnist_train.client_ids[0]
example_dataset = emnist_train.create_tf_dataset_for_client(first_client_id)
preprocessed_example_dataset = preprocess(example_dataset)

In [6]:
def create_keras_model():
  """Returns a zero-initialized linear softmax classifier.

  Input is a flat 784-vector; output is a probability distribution over
  10 classes (a `Dense(10)` layer followed by `Softmax`).
  """
  model_layers = [
      tf.keras.layers.InputLayer(input_shape=(784,)),
      tf.keras.layers.Dense(10, kernel_initializer='zeros'),
      tf.keras.layers.Softmax(),
  ]
  return tf.keras.models.Sequential(model_layers)
  
def model_fn():
  """Wraps a fresh Keras model as a `tff.learning` model.

  The input spec is taken from `preprocessed_example_dataset`, so the model
  accepts elements shaped like `preprocess` output.
  """
  fresh_keras_model = create_keras_model()
  spec = preprocessed_example_dataset.element_spec
  return tff.learning.from_keras_model(
      fresh_keras_model,
      input_spec=spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
  )

In [7]:
model = create_keras_model()

# Explicit input spec for the functional model: one `(1, 784)` float32 pixel
# batch and one `(1, 1)` int32 label batch per element.
input_spec = (tf.TensorSpec(shape=(1, 784), dtype=tf.float32, name=None),
              tf.TensorSpec(shape=(1, 1), dtype=tf.int32, name=None))
print(input_spec)
print(preprocessed_example_dataset.element_spec)
# Fail loudly if the hand-written spec drifts from what `preprocess` produces,
# instead of relying on comparing the two printed lines by eye.
assert input_spec == preprocessed_example_dataset.element_spec, (
    'input_spec does not match preprocessed_example_dataset.element_spec')

functional_model = tff.learning.models.functional_model_from_keras(
    keras_model=model,
    loss_fn=tf.keras.losses.SparseCategoricalCrossentropy(),
    input_spec=input_spec,
)

def tff_model_fn() -> tff.learning.Model:
    """Builds a TFF model from the in-memory `functional_model`."""
    wrapped_model = tff.learning.models.model_from_functional(functional_model)
    return wrapped_model

(TensorSpec(shape=(1, 784), dtype=tf.float32, name=None), TensorSpec(shape=(1, 1), dtype=tf.int32, name=None))
(TensorSpec(shape=(1, 784), dtype=tf.float32, name=None), TensorSpec(shape=(1, 1), dtype=tf.int32, name=None))


In [8]:
# NOTE(review): this disables TensorFlow's graph meta-optimizer globally for
# the rest of the session. The reason is not recorded here — presumably a
# workaround needed for saving/loading the functional model. Confirm before
# removing.
tf.config.optimizer.set_experimental_options({'disable_meta_optimizer': True})

# Directory the functional model is written to and read back from. It is
# named once so the save and load paths cannot disagree.
FUNCTIONAL_MODEL_DIR = 'tmp_dir'

tff.learning.models.save_functional_model(
    functional_model=functional_model, path=FUNCTIONAL_MODEL_DIR)
saved_functional_model = tff.learning.models.load_functional_model(
    FUNCTIONAL_MODEL_DIR)

def saved_tff_model_fn() -> tff.learning.Model:
    """Builds a TFF model from the functional model reloaded from disk."""
    model_from_disk = tff.learning.models.model_from_functional(
        saved_functional_model)
    return model_from_disk



INFO:tensorflow:Assets written to: tmp_dir/assets


INFO:tensorflow:Assets written to: tmp_dir/assets


In [9]:
import os
OUTPUT_DIR = 'some_dir'

# Every release target currently just logs. The file-backed managers below
# are disabled because each raised an error when tried (cause not yet
# diagnosed). TODO: investigate and re-enable.
model_output_manager = tff.program.LoggingReleaseManager()
train_output_managers = [tff.program.LoggingReleaseManager()]
evaluation_output_managers = [tff.program.LoggingReleaseManager()]

# Disabled: TensorBoard summaries of training metrics.
# summary_dir = os.path.join(OUTPUT_DIR, "summary")
# tensorboard_manager = tff.program.TensorBoardReleaseManager(summary_dir)
# train_output_managers.append(tensorboard_manager)

# Disabled: evaluation metrics written to CSV.
# csv_path = os.path.join(OUTPUT_DIR, "evaluation_metrics.csv")
# csv_manager = tff.program.CSVFileReleaseManager(csv_path)
# evaluation_output_managers.append(csv_manager)

# Disabled: on-disk program state, used to resume interrupted training.
# program_state_dir = os.path.join(OUTPUT_DIR, "program_state")
# program_state_manager = tff.program.FileProgramStateManager(program_state_dir)


In [10]:
# Display the preprocessed example dataset; its repr includes the element spec.
preprocessed_example_dataset

<MapDataset element_spec=(TensorSpec(shape=(1, 784), dtype=tf.float32, name=None), TensorSpec(shape=(1, 1), dtype=tf.int32, name=None))>

In [11]:
# Clients sampled per training round, and how many rounds to train.
number_of_clients = 3
total_rounds = 10


def _preprocess_every_client(client_data):
    """Returns one preprocessed dataset per client id in `client_data`."""
    return [
        preprocess(client_data.create_tf_dataset_for_client(client_id))
        for client_id in client_data.client_ids
    ]


train_datasets = _preprocess_every_client(emnist_train)
test_datasets = _preprocess_every_client(emnist_test)
train_data_source = tff.program.DatasetDataSource(train_datasets)
evaluation_data_source = tff.program.DatasetDataSource(test_datasets)

In [15]:
import functools
from typing import Any, Optional

from tensorflow_federated.python.program import value_reference


async def tff_train_federated(
    initialize: tff.Computation,
    train: tff.Computation,
    train_data_source: tff.program.FederatedDataSource,
    evaluation: tff.Computation,
    evaluation_data_source: tff.program.FederatedDataSource,
    total_rounds: int,
    number_of_clients: int,
    train_output_managers,
    evaluation_output_managers,
    model_output_manager: Optional[tff.program.ReleaseManager],
    program_state_manager: Optional[tff.program.ProgramStateManager],
) -> Any:
    """Runs federated training rounds and returns the final server state.

    Must run inside a federated context; the first statement calls
    `tff.program.check_in_federated_context()`.

    Args:
      initialize: Computation returning the initial server state.
      train: Computation `(state, client_data) -> output`, where `output` has
        a `.state` attribute. Its result type is unpacked as
        `(state_type, metrics_type)`.
      train_data_source: Source of per-client training datasets.
      evaluation: Federated evaluation computation. Currently unused (TODO).
      evaluation_data_source: Evaluation data source. Currently unused (TODO).
      total_rounds: Last round to run (rounds are numbered from 1, inclusive).
      number_of_clients: Clients selected from `train_data_source` per round.
      train_output_managers: Release managers for per-round training metrics.
        Currently unused (TODO).
      evaluation_output_managers: Release managers for evaluation metrics.
        Currently unused (TODO).
      model_output_manager: If not None, the final state is released to it.
      program_state_manager: If not None, training resumes from the latest
        saved `(state, round_number)` checkpoint, and a checkpoint is saved
        after every round.

    Returns:
      The final server state after `value_reference.materialize_value`.
    """
    tff.program.check_in_federated_context()

    # `initialize` is called exactly once. Its result is both the fresh
    # starting state and the structure template for loading a checkpoint.
    initial_state = initialize()
    state = initial_state
    start_round = 1

    if program_state_manager is not None:
        # The template must mirror what is saved below: a
        # `(state, round_number)` pair, not a bare state.
        program_state, _ = await program_state_manager.load_latest(
            (initial_state, 0))
        if program_state is not None:
            state, last_completed_round = program_state
            # The checkpoint holds the state *after* that round, so resume at
            # the next round instead of repeating it.
            start_round = last_completed_round + 1

    async with tff.async_utils.ordered_tasks() as tasks:
        train_data_iterator = train_data_source.iterator()

        for round_number in range(start_round, total_rounds + 1):
            # Scheduled through `tasks` so the message is emitted in order
            # with the other queued work.
            tasks.add_callable(
                functools.partial(
                    print, f"Running round {round_number} of training"))

            train_data = train_data_iterator.select(number_of_clients)
            output = train(state, train_data)
            state = output.state

            # TODO: release per-round metrics once the release error seen
            # earlier is resolved, e.g.
            #   _, metrics_type = train.type_signature.result
            #   tasks.add_all(*[
            #       m.release(output.metrics, metrics_type, round_number)
            #       for m in train_output_managers])

            if program_state_manager is not None:
                # Checkpoint the post-round state, tagged with this round's
                # number (also used as the version).
                tasks.add(program_state_manager.save(
                    (state, round_number), round_number))

        # TODO: evaluate the final state and release the result to
        # `evaluation_output_managers` (not `train_output_managers`), e.g.
        #   evaluation_data = evaluation_data_source.iterator().select(
        #       number_of_clients)
        #   evaluation_metrics = evaluation(state, evaluation_data)
        #   tasks.add_all(*[
        #       m.release(evaluation_metrics,
        #                 evaluation.type_signature.result, total_rounds)
        #       for m in evaluation_output_managers])

        if model_output_manager is not None:
            state_type, _ = train.type_signature.result
            tasks.add(model_output_manager.release(state, state_type))

    # Convert the state from a value reference into concrete values.
    state = await value_reference.materialize_value(state)
    return state


In [16]:
# FUNCTIONAL KERAS MODEL FROM A SAVED MODEL
import asyncio

# The raw local async execution context and the federated-program context that
# wraps it get distinct names; `tff_train_federated` checks that a federated
# context is installed (`check_in_federated_context`).
execution_context = (
    tff.backends.native.create_local_async_python_execution_context())
federated_context = tff.program.NativeFederatedContext(execution_context)
tff.framework.set_default_context(federated_context)

# Unweighted FedAvg over the model reloaded from disk (`saved_tff_model_fn`).
iterative_process = tff.learning.algorithms.build_unweighted_fed_avg(
    saved_tff_model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))
initialize = iterative_process.initialize
train = iterative_process.next
evaluation = tff.learning.build_federated_evaluation(saved_tff_model_fn)

# Calling `asyncio.run` from a notebook cell relies on `nest_asyncio.apply()`
# having been run. Program-state checkpointing is turned off (`None`).
state = asyncio.run(
    tff_train_federated(
        initialize=initialize,
        train=train,
        train_data_source=train_data_source,
        evaluation=evaluation,
        evaluation_data_source=evaluation_data_source,
        total_rounds=total_rounds,
        number_of_clients=number_of_clients,
        train_output_managers=train_output_managers,
        evaluation_output_managers=evaluation_output_managers,
        model_output_manager=model_output_manager,
        program_state_manager=None,
    ))

  metrics = output.metrics
  metrics = output.metrics


Running round 1 of training
Running round 2 of training
Running round 3 of training
Running round 4 of training
Running round 5 of training
Running round 6 of training
Running round 7 of training
Running round 8 of training
Running round 9 of training
Running round 10 of training


  state = asyncio.run(


In [17]:
from tensorflow_federated.python.learning.model_utils import ModelWeights 

# Copy the trained global weights into a ModelWeights built from plain lists.
global_weights = state.global_model_weights
new_weights = ModelWeights(
    list(global_weights.trainable),
    list(global_weights.non_trainable),
)

In [26]:

# Reload the saved functional model from disk and wrap it as a fresh TFF model.
new_saved_functional_model = tff.learning.models.load_functional_model("tmp_dir")
new_model = tff.learning.models.model_from_functional(new_saved_functional_model)

# Push the trained global weights into that model, then display the
# non-trainable part of those weights.
trained_weights = state.global_model_weights
trained_weights.assign_weights_to(new_model)
trained_weights.non_trainable

()