In [1]:
import nest_asyncio
# Patch asyncio so the `asyncio.run(...)` calls in the training cells can run
# inside the notebook, which already has an event loop running.
nest_asyncio.apply()

In [2]:
import collections

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

# Seed NumPy's global RNG. TensorFlow's global seed is not set here; dataset
# shuffling uses an explicit op-level seed inside `preprocess`.
np.random.seed(0)

In [3]:
# Federated EMNIST train/test splits; each provides per-client datasets via
# `create_tf_dataset_for_client(client_id)`.
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

In [4]:
from tensorflow_federated.proto.v0 import computation_pb2 as pb


NUM_CLIENTS = 10
NUM_EPOCHS = 5
SHUFFLE_BUFFER = 100

def preprocess(dataset, batch_size=None):
  """Turns one client's EMNIST dataset into `[pixels, label]` training pairs.

  The dataset is repeated NUM_EPOCHS times, shuffled with a fixed seed, and
  each element is mapped to `[pixels reshaped to (-1, 784),
  label reshaped to (-1, 1)]`.

  Args:
    dataset: a `tf.data.Dataset` whose elements are dicts with 'pixels' and
      'label' entries.
    batch_size: optional int. When None (the default, and the original
      behavior), examples are not batched, so every element holds exactly one
      example with shapes (1, 784) and (1, 1). When set, examples are grouped
      into batches of that size before reshaping, so the leading dimension is
      the (variable) batch size.

  Returns:
    The preprocessed `tf.data.Dataset`.
  """

  def map_fn(element):
    return [tf.reshape(element['pixels'], [-1, 784]),
            tf.reshape(element['label'], [-1, 1])]

  dataset = dataset.repeat(NUM_EPOCHS).shuffle(SHUFFLE_BUFFER, seed=1)
  if batch_size is not None:
    # Without batching each step sees one example (training metrics report
    # num_batches == num_examples), which makes local training slow.
    dataset = dataset.batch(batch_size)
  return dataset.map(map_fn)

In [5]:
# Preprocess the first training client's data. Its `element_spec` serves as
# the input spec for the Keras-based TFF model (`model_fn`).
first_client_id = emnist_train.client_ids[0]
example_dataset = emnist_train.create_tf_dataset_for_client(first_client_id)
preprocessed_example_dataset = preprocess(example_dataset)

In [6]:
def create_keras_model():
  """Returns a single Dense(10) + Softmax classifier over 784-wide inputs.

  The Dense kernel is initialised to zeros.
  """
  layers = [
      tf.keras.layers.InputLayer(input_shape=(784,)),
      tf.keras.layers.Dense(10, kernel_initializer='zeros'),
      tf.keras.layers.Softmax(),
  ]
  return tf.keras.models.Sequential(layers)

def model_fn():
  """Wraps a freshly built Keras model as a TFF model.

  The input spec comes from `preprocessed_example_dataset`; training uses
  sparse categorical cross-entropy and reports sparse categorical accuracy.
  """
  return tff.learning.from_keras_model(
      create_keras_model(),
      input_spec=preprocessed_example_dataset.element_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

In [7]:
model = create_keras_model()

# Derive the functional model's input spec from the preprocessed data rather
# than hard-coding shapes, so the two cannot drift apart (e.g. if `preprocess`
# starts batching). For the current pipeline this is the same
# (TensorSpec((1, 784), float32), TensorSpec((1, 1), int32)) pair that was
# previously written out by hand.
input_spec = preprocessed_example_dataset.element_spec
print(input_spec)

functional_model = tff.learning.models.functional_model_from_keras(
    keras_model=model,
    loss_fn=tf.keras.losses.SparseCategoricalCrossentropy(),
    input_spec=input_spec)

def tff_model_fn() -> tff.learning.Model:
    """Builds a TFF model from the module-level `functional_model`.

    No metrics are passed to `functional_model_from_keras`, so training reports
    only the loss.
    """
    return tff.learning.models.model_from_functional(functional_model)

(TensorSpec(shape=(1, 784), dtype=tf.float32, name=None), TensorSpec(shape=(1, 1), dtype=tf.int32, name=None))
(TensorSpec(shape=(1, 784), dtype=tf.float32, name=None), TensorSpec(shape=(1, 1), dtype=tf.int32, name=None))


In [8]:
# Write the functional model to disk and load it back; `saved_tff_model_fn`
# builds its TFF models from the reloaded copy.
functional_model_path = 'tmp_dir'
tff.learning.models.save_functional_model(
    functional_model=functional_model, path=functional_model_path)
saved_functional_model = tff.learning.models.load_functional_model(
    functional_model_path)

def saved_tff_model_fn() -> tff.learning.Model:
    """Builds a TFF model from the functional model reloaded from disk."""
    # NOTE(review): federated training with this model failed with
    # "Attempting to add a duplicate function with name
    # __inference_flat_forward_pass_..."; the two definitions in the error
    # differ only in their `config_proto` attr. The cause is unconfirmed --
    # TODO investigate (e.g. whether loading inside this function helps).
    return tff.learning.models.model_from_functional(saved_functional_model)



INFO:tensorflow:Assets written to: tmp_dir/assets


INFO:tensorflow:Assets written to: tmp_dir/assets


In [9]:
import os
# Root directory for the file-based managers below (all currently disabled).
OUTPUT_DIR = 'some_dir'
# Release managers handed to `tff_train_federated`. That function currently
# only checks them against None and prints metrics itself.
train_output_managers = [tff.program.LoggingReleaseManager()]
evaluation_output_managers = [tff.program.LoggingReleaseManager()]
model_output_manager = tff.program.LoggingReleaseManager()

# TODO: enabling this raises an error; the cause has not been identified yet.
# summary_dir = os.path.join(OUTPUT_DIR, "summary")
# tensorboard_manager = tff.program.TensorBoardReleaseManager(summary_dir)
# train_output_managers.append(tensorboard_manager)

# TODO: enabling this raises an error; the cause has not been identified yet.
# csv_path = os.path.join(OUTPUT_DIR, "evaluation_metrics.csv")
# csv_manager = tff.program.CSVFileReleaseManager(csv_path)
# evaluation_output_managers.append(csv_manager)

# TODO: enabling this raises an error; the cause has not been identified yet.
# The training cells currently pass `program_state_manager=None`.
# program_state_dir = os.path.join(OUTPUT_DIR, "program_state")
# program_state_manager = tff.program.FileProgramStateManager(program_state_dir)


In [10]:
# Display the preprocessed example dataset to inspect its element_spec.
preprocessed_example_dataset

<MapDataset element_spec=(TensorSpec(shape=(1, 784), dtype=tf.float32, name=None), TensorSpec(shape=(1, 1), dtype=tf.int32, name=None))>

In [11]:
# Clients sampled per training round and for the final evaluation.
# NOTE(review): the NUM_CLIENTS constant (10) defined earlier is not used here.
number_of_clients = 3
total_rounds = 10


def _preprocess_all_clients(client_data):
    """Returns every client's preprocessed dataset, in `client_ids` order."""
    return [
        preprocess(client_data.create_tf_dataset_for_client(client_id))
        for client_id in client_data.client_ids
    ]


train_datasets = _preprocess_all_clients(emnist_train)
test_datasets = _preprocess_all_clients(emnist_test)
train_data_source = tff.program.DatasetDataSource(train_datasets)
evaluation_data_source = tff.program.DatasetDataSource(test_datasets)

In [12]:
import functools
from tensorflow_federated.python.program import value_reference


async def tff_train_federated(
    initialize: tff.Computation,
    train: tff.Computation,
    train_data_source: tff.program.FederatedDataSource,
    evaluation: tff.Computation,
    evaluation_data_source: tff.program.FederatedDataSource,
    total_rounds: int,
    number_of_clients: int,
    train_output_managers,
    evaluation_output_managers,
    model_output_manager: tff.program.ReleaseManager,
    program_state_manager: tff.program.ProgramStateManager,
) -> None:
    """Runs federated training round by round, then a single evaluation.

    Must be called inside a TFF federated context (checked up front).

    Args:
      initialize: Computation returning the initial training state.
      train: Computation called as `train(state, client_data)`; its result
        must expose `.state` and `.metrics`.
      train_data_source: Data source that training clients are selected from.
      evaluation: Computation called once as `evaluation(state, client_data)`
        after the last round.
      evaluation_data_source: Data source that evaluation clients are
        selected from.
      total_rounds: Number of the last round to run; rounds start at 1.
      number_of_clients: Clients selected per training round and for the
        evaluation.
      train_output_managers: When not None, each round's training metrics are
        materialized and printed (the managers themselves are not used yet).
      evaluation_output_managers: When not None, the evaluation metrics are
        materialized and printed (the managers themselves are not used yet).
      model_output_manager: Not used yet.
      program_state_manager: When not None, training resumes after the latest
        saved `(state, round_number)` checkpoint, and such a checkpoint is
        saved after every round with the round number as its version.
    """
    tff.program.check_in_federated_context()

    initial_state = initialize()

    program_state = None
    if program_state_manager is not None:
        # `load_latest` needs a structure matching what is saved below, i.e.
        # the pair `(state, round_number)`, not the bare state.
        structure = (await value_reference.materialize_value(initial_state), 0)
        program_state, _ = await program_state_manager.load_latest(structure)

    if program_state is not None:
        state, last_completed_round = program_state
        # The checkpointed round already finished; continue with the next one.
        start_round = last_completed_round + 1
    else:
        state = initial_state
        start_round = 1

    async with tff.async_utils.ordered_tasks() as tasks:
        train_data_iterator = train_data_source.iterator()

        for round_number in range(start_round, total_rounds + 1):
            tasks.add_callable(
                functools.partial(
                    print, f"Running round {round_number} of training"
                )
            )

            train_data = train_data_iterator.select(number_of_clients)
            output = train(state, train_data)
            state = output.state

            if train_output_managers is not None:
                # TODO: release through `train_output_managers` instead of
                # printing once those managers work in this environment.
                train_metrics = await value_reference.materialize_value(
                    output.metrics)
                tasks.add_callable(
                    functools.partial(print, str(train_metrics)))

            if program_state_manager is not None:
                # Store the completed round with the state so a restart can
                # resume from the following round.
                await program_state_manager.save(
                    (state, round_number), round_number)

        evaluation_data_iterator = evaluation_data_source.iterator()
        evaluation_data = evaluation_data_iterator.select(number_of_clients)
        # NOTE(review): computations from `tff.learning.build_federated_evaluation`
        # appear to be typed to take model weights rather than the full
        # training state passed here -- confirm, and if so pass the weights
        # (e.g. obtained via the learning process's `get_model_weights`).
        evaluation_metrics = evaluation(state, evaluation_data)

        if evaluation_output_managers is not None:
            # Materialize the evaluation result so it is awaited and shown,
            # like the per-round training metrics.
            # TODO: release through `evaluation_output_managers` (not the
            # training managers) once those managers work in this environment.
            materialized_evaluation_metrics = (
                await value_reference.materialize_value(evaluation_metrics))
            tasks.add_callable(
                functools.partial(print, str(materialized_evaluation_metrics)))

        # TODO: release the final `state` through `model_output_manager`.


In [13]:
# KERAS MODEL
# Federated training + evaluation of the Keras-based model built by `model_fn`.
import asyncio
# `tff_train_federated` requires a federated context, so wrap a local async
# execution context in `NativeFederatedContext` and install it as default.
context = tff.backends.native.create_local_async_python_execution_context()
context = tff.program.NativeFederatedContext(context)
tff.framework.set_default_context(context)

# Unweighted FedAvg: client optimizer SGD(lr=0.02), server optimizer SGD(lr=1.0).
iterative_process = tff.learning.algorithms.build_unweighted_fed_avg(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))
initialize = iterative_process.initialize
train = iterative_process.next
# NOTE(review): this evaluation computation appears to expect model weights,
# while `tff_train_federated` passes it the full training state -- confirm.
evaluation = tff.learning.build_federated_evaluation(model_fn)
import logging
logger = logging.getLogger()

# Let INFO-level records from the root logger through; the "Test" record
# confirms that log output reaches the notebook.
logger.setLevel(logging.INFO)
logging.info("Test")

# Running `asyncio.run` from a notebook cell relies on `nest_asyncio.apply()`
# in the first cell. Checkpointing is disabled (`program_state_manager=None`).
asyncio.run(
        tff_train_federated(
            initialize=initialize,
            train=train,
            train_data_source=train_data_source,
            evaluation=evaluation,
            evaluation_data_source=evaluation_data_source,
            total_rounds=total_rounds,
            number_of_clients=number_of_clients,
            train_output_managers=train_output_managers,
            evaluation_output_managers=evaluation_output_managers,
            model_output_manager=model_output_manager,
            program_state_manager=None,
            # program_state_manager=program_state_manager,
        )
    )

INFO:root:Test
  element = _create_structure_of_coro_references(element_coro, element_type)
  element = _create_structure_of_coro_references(element_coro, element_type)
  element_coro = _get_item(shared_awaitable, index)
  return _create_structure_of_coro_references(coro, type_signature.member)


Running round 1 of training
OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.101123594), ('loss', 12.856089), ('num_examples', 1335), ('num_batches', 1335)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_count', ())])), ('finalizer', ())])
Running round 2 of training
OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.12541254), ('loss', 9.309718), ('num_examples', 1515), ('num_batches', 1515)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_count', ())])), ('finalizer', ())])
Running round 3 of training
OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.16557863), ('loss', 8.364625), ('num_examples', 1685), ('num_batches', 1685)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_count', ())])), ('finalizer', ())])
Running round 4 of training
O

  result = coro.send(None)
  result = coro.send(None)
  result = coro.send(None)


In [14]:
# FUNCTIONAL KERAS MODEL
# Same training run as the Keras cell, but with the in-memory functional model
# (`tff_model_fn`). It reports only `loss`, because no metrics were passed to
# `functional_model_from_keras`.
import asyncio
# Install a fresh federated context, as required by `tff_train_federated`.
context = tff.backends.native.create_local_async_python_execution_context()
context = tff.program.NativeFederatedContext(context)
tff.framework.set_default_context(context)

# Unweighted FedAvg: client optimizer SGD(lr=0.02), server optimizer SGD(lr=1.0).
iterative_process = tff.learning.algorithms.build_unweighted_fed_avg(
    tff_model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))
initialize = iterative_process.initialize
train = iterative_process.next
# NOTE(review): this evaluation computation appears to expect model weights,
# while `tff_train_federated` passes it the full training state -- confirm.
evaluation = tff.learning.build_federated_evaluation(tff_model_fn)

# Checkpointing is disabled (`program_state_manager=None`).
asyncio.run(
        tff_train_federated(
            initialize=initialize,
            train=train,
            train_data_source=train_data_source,
            evaluation=evaluation,
            evaluation_data_source=evaluation_data_source,
            total_rounds=total_rounds,
            number_of_clients=number_of_clients,
            train_output_managers=train_output_managers,
            evaluation_output_managers=evaluation_output_managers,
            model_output_manager=model_output_manager,
            program_state_manager=None,
            # program_state_manager=program_state_manager,
        )
    )

Running round 1 of training
OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('loss', 11.786694)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_count', ())])), ('finalizer', ())])
Running round 2 of training
OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('loss', 9.795738)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_count', ())])), ('finalizer', ())])
Running round 3 of training
OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('loss', 5.985588)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_count', ())])), ('finalizer', ())])
Running round 4 of training
OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('loss', 5.388061)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_count', ())])), ('finalizer', ())])
Running round 5 of training
OrderedDict([('distributor', ()), ('client_work

In [16]:
# FUNCTIONAL KERAS MODEL FROM A SAVED MODEL
# Same training run, using the functional model reloaded from disk
# (`saved_tff_model_fn`).
# NOTE(review): this run failed in round 1 with "Attempting to add a duplicate
# function with name __inference_flat_forward_pass_..."; the two definitions
# in the error differ only in their `config_proto` attr. Cause unconfirmed.
import asyncio
# Install a fresh federated context, as required by `tff_train_federated`.
context = tff.backends.native.create_local_async_python_execution_context()
context = tff.program.NativeFederatedContext(context)
tff.framework.set_default_context(context)
import logging
logger = logging.getLogger()

# Quiet the root logger back down to warnings and above (`logging.WARN` is an
# alias of `logging.WARNING`).
logger.setLevel(logging.WARN)
# Unweighted FedAvg: client optimizer SGD(lr=0.02), server optimizer SGD(lr=1.0).
iterative_process = tff.learning.algorithms.build_unweighted_fed_avg(
    saved_tff_model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))
initialize = iterative_process.initialize
train = iterative_process.next
# NOTE(review): this evaluation computation appears to expect model weights,
# while `tff_train_federated` passes it the full training state -- confirm.
evaluation = tff.learning.build_federated_evaluation(saved_tff_model_fn)

# Checkpointing is disabled (`program_state_manager=None`).
asyncio.run(
        tff_train_federated(
            initialize=initialize,
            train=train,
            train_data_source=train_data_source,
            evaluation=evaluation,
            evaluation_data_source=evaluation_data_source,
            total_rounds=total_rounds,
            number_of_clients=number_of_clients,
            train_output_managers=train_output_managers,
            evaluation_output_managers=evaluation_output_managers,
            model_output_manager=model_output_manager,
            program_state_manager=None,
            # program_state_manager=program_state_manager,
        )
    )

Running round 1 of training


InvalidArgumentError: Attempting to add a duplicate function with name: __inference_flat_forward_pass_934 where the previous and current definitions differ. Previous definition: signature {
  name: "__inference_flat_forward_pass_934"
  input_arg {
    name: "identity"
    type: DT_FLOAT
  }
  input_arg {
    name: "identity_1"
    type: DT_FLOAT
  }
  input_arg {
    name: "batch_input"
    type: DT_FLOAT
  }
  input_arg {
    name: "batch_input_1"
    type: DT_INT32
  }
  output_arg {
    name: "identity_2"
    type: DT_FLOAT
  }
  output_arg {
    name: "identity_3"
    type: DT_FLOAT
  }
  output_arg {
    name: "identity_4"
    type: DT_INT32
  }
}
node_def {
  name: "PartitionedCall"
  op: "PartitionedCall"
  input: "identity"
  input: "identity_1"
  input: "batch_input"
  input: "batch_input_1"
  attr {
    key: "Tin"
    value {
      list {
        type: DT_FLOAT
        type: DT_FLOAT
        type: DT_FLOAT
        type: DT_INT32
      }
    }
  }
  attr {
    key: "Tout"
    value {
      list {
        type: DT_FLOAT
        type: DT_FLOAT
        type: DT_INT32
      }
    }
  }
  attr {
    key: "_collective_manager_ids"
    value {
      list {
      }
    }
  }
  attr {
    key: "_read_only_resource_inputs"
    value {
      list {
      }
    }
  }
  attr {
    key: "config"
    value {
      s: ""
    }
  }
  attr {
    key: "config_proto"
    value {
      s: "\n\007\n\003CPU\020\001\n\007\n\003GPU\020\0012\005*\0010J\0008\001\202\001\000"
    }
  }
  attr {
    key: "executor_type"
    value {
      s: ""
    }
  }
  attr {
    key: "f"
    value {
      func {
        name: "__inference_forward_pass_922"
      }
    }
  }
}
node_def {
  name: "Identity_2"
  op: "Identity"
  input: "PartitionedCall:output:0"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
node_def {
  name: "Identity_3"
  op: "Identity"
  input: "PartitionedCall:output:1"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
node_def {
  name: "Identity_4"
  op: "Identity"
  input: "PartitionedCall:output:2"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
}
ret {
  key: "identity_2"
  value: "Identity_2:output:0"
}
ret {
  key: "identity_3"
  value: "Identity_3:output:0"
}
ret {
  key: "identity_4"
  value: "Identity_4:output:0"
}
attr {
  key: "_construction_context"
  value {
    s: "kEagerRuntime"
  }
}
arg_attr {
  key: 0
  value {
    attr {
      key: "_output_shapes"
      value {
        list {
          shape {
            dim {
              size: 784
            }
            dim {
              size: 10
            }
          }
        }
      }
    }
    attr {
      key: "_user_specified_name"
      value {
        s: "Identity"
      }
    }
  }
}
arg_attr {
  key: 1
  value {
    attr {
      key: "_output_shapes"
      value {
        list {
          shape {
            dim {
              size: 10
            }
          }
        }
      }
    }
    attr {
      key: "_user_specified_name"
      value {
        s: "Identity_1"
      }
    }
  }
}
arg_attr {
  key: 2
  value {
    attr {
      key: "_output_shapes"
      value {
        list {
          shape {
            dim {
              size: 1
            }
            dim {
              size: 784
            }
          }
        }
      }
    }
    attr {
      key: "_user_specified_name"
      value {
        s: "batch_input"
      }
    }
  }
}
arg_attr {
  key: 3
  value {
    attr {
      key: "_output_shapes"
      value {
        list {
          shape {
            dim {
              size: 1
            }
            dim {
              size: 1
            }
          }
        }
      }
    }
    attr {
      key: "_user_specified_name"
      value {
        s: "batch_input"
      }
    }
  }
}
 and current definition: signature {
  name: "__inference_flat_forward_pass_934"
  input_arg {
    name: "identity"
    type: DT_FLOAT
  }
  input_arg {
    name: "identity_1"
    type: DT_FLOAT
  }
  input_arg {
    name: "batch_input"
    type: DT_FLOAT
  }
  input_arg {
    name: "batch_input_1"
    type: DT_INT32
  }
  output_arg {
    name: "identity_2"
    type: DT_FLOAT
  }
  output_arg {
    name: "identity_3"
    type: DT_FLOAT
  }
  output_arg {
    name: "identity_4"
    type: DT_INT32
  }
}
node_def {
  name: "PartitionedCall"
  op: "PartitionedCall"
  input: "identity"
  input: "identity_1"
  input: "batch_input"
  input: "batch_input_1"
  attr {
    key: "Tin"
    value {
      list {
        type: DT_FLOAT
        type: DT_FLOAT
        type: DT_FLOAT
        type: DT_INT32
      }
    }
  }
  attr {
    key: "Tout"
    value {
      list {
        type: DT_FLOAT
        type: DT_FLOAT
        type: DT_INT32
      }
    }
  }
  attr {
    key: "_collective_manager_ids"
    value {
      list {
      }
    }
  }
  attr {
    key: "_read_only_resource_inputs"
    value {
      list {
      }
    }
  }
  attr {
    key: "config"
    value {
      s: ""
    }
  }
  attr {
    key: "config_proto"
    value {
      s: "\n\007\n\003CPU\020\001\n\007\n\003GPU\020\0012\005*\0010J\0008\001R\005R\003\230\001\001\202\001\000"
    }
  }
  attr {
    key: "executor_type"
    value {
      s: ""
    }
  }
  attr {
    key: "f"
    value {
      func {
        name: "__inference_forward_pass_922"
      }
    }
  }
}
node_def {
  name: "Identity_2"
  op: "Identity"
  input: "PartitionedCall:output:0"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
node_def {
  name: "Identity_3"
  op: "Identity"
  input: "PartitionedCall:output:1"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
node_def {
  name: "Identity_4"
  op: "Identity"
  input: "PartitionedCall:output:2"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
}
ret {
  key: "identity_2"
  value: "Identity_2:output:0"
}
ret {
  key: "identity_3"
  value: "Identity_3:output:0"
}
ret {
  key: "identity_4"
  value: "Identity_4:output:0"
}
attr {
  key: "_construction_context"
  value {
    s: "kEagerRuntime"
  }
}
arg_attr {
  key: 0
  value {
    attr {
      key: "_output_shapes"
      value {
        list {
          shape {
            dim {
              size: 784
            }
            dim {
              size: 10
            }
          }
        }
      }
    }
    attr {
      key: "_user_specified_name"
      value {
        s: "Identity"
      }
    }
  }
}
arg_attr {
  key: 1
  value {
    attr {
      key: "_output_shapes"
      value {
        list {
          shape {
            dim {
              size: 10
            }
          }
        }
      }
    }
    attr {
      key: "_user_specified_name"
      value {
        s: "Identity_1"
      }
    }
  }
}
arg_attr {
  key: 2
  value {
    attr {
      key: "_output_shapes"
      value {
        list {
          shape {
            dim {
              size: 1
            }
            dim {
              size: 784
            }
          }
        }
      }
    }
    attr {
      key: "_user_specified_name"
      value {
        s: "batch_input"
      }
    }
  }
}
arg_attr {
  key: 3
  value {
    attr {
      key: "_output_shapes"
      value {
        list {
          shape {
            dim {
              size: 1
            }
            dim {
              size: 1
            }
          }
        }
      }
    }
    attr {
      key: "_user_specified_name"
      value {
        s: "batch_input"
      }
    }
  }
}
