In [1]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

In [2]:
import collections

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

tf.compat.v1.enable_v2_behavior()

np.random.seed(0)

tff.federated_computation(lambda: 'Hello, World!')()

b'Hello, World!'

In [31]:
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

In [4]:
emnist_train.client_ids[:10]

['f0000_14',
 'f0001_41',
 'f0005_26',
 'f0006_12',
 'f0008_45',
 'f0011_13',
 'f0014_19',
 'f0016_39',
 'f0017_07',
 'f0022_10']

In [5]:
emnist_train.element_type_structure

OrderedDict([('label', TensorSpec(shape=(), dtype=tf.int32, name=None)),
             ('pixels',
              TensorSpec(shape=(28, 28), dtype=tf.float32, name=None))])

In [6]:
example_dataset = emnist_train.create_tf_dataset_for_client(
    emnist_train.client_ids[0])

example_element = next(iter(example_dataset))

example_element['pixels'].numpy().shape

(28, 28)

In [7]:
%matplotlib inline
from matplotlib import pyplot as plt
plt.imshow(example_element['pixels'].numpy(), cmap='gray', aspect='equal')
plt.grid(False)
plt.show()

<Figure size 640x480 with 1 Axes>

In [8]:
NUM_CLIENTS = 65
NUM_EPOCHS = 5
BATCH_SIZE = 64
SHUFFLE_BUFFER = 100
PREFETCH_BUFFER=10

def preprocess(dataset):

  def batch_format_fn(element):
    """Flatten a batch `pixels` and return the features as an `OrderedDict`."""
    return collections.OrderedDict(
        x=tf.reshape(element['pixels'],[-1, 28,28,1]),#tf.reshape(element['pixels'], [-1, 784]),
        y=tf.reshape(element['label'], [-1, 1]))

  return dataset.repeat(NUM_EPOCHS).shuffle(SHUFFLE_BUFFER).batch(
      BATCH_SIZE).map(batch_format_fn).prefetch(PREFETCH_BUFFER)

In [37]:
print(example_dataset)
print(example_dataset.repeat(2))

<TensorSliceDataset shapes: OrderedDict([(label, ()), (pixels, (28, 28))]), types: OrderedDict([(label, tf.int32), (pixels, tf.float32)])>
<RepeatDataset shapes: OrderedDict([(label, ()), (pixels, (28, 28))]), types: OrderedDict([(label, tf.int32), (pixels, tf.float32)])>


In [9]:
preprocessed_example_dataset = preprocess(example_dataset)
preprocessed_example_dataset
sample_batch = tf.nest.map_structure(lambda x: x.numpy(),
                                     next(iter(preprocessed_example_dataset)))
sample_batch['y'].shape


(64, 1)

In [10]:
def make_federated_data(client_data, client_ids):
  return [
      preprocess(client_data.create_tf_dataset_for_client(x))
      for x in client_ids
  ]

In [68]:
for i in emnist_train.client_ids[:5]:
    print(list(emnist_train._h5_file["examples"][i].items())[1])

('pixels', <HDF5 dataset "pixels": shape (93, 28, 28), type "<f4">)
('pixels', <HDF5 dataset "pixels": shape (109, 28, 28), type "<f4">)
('pixels', <HDF5 dataset "pixels": shape (73, 28, 28), type "<f4">)
('pixels', <HDF5 dataset "pixels": shape (100, 28, 28), type "<f4">)
('pixels', <HDF5 dataset "pixels": shape (105, 28, 28), type "<f4">)


In [69]:
sample_clients = emnist_train.client_ids[0:NUM_CLIENTS]

federated_train_data = make_federated_data(emnist_train, sample_clients)

print('Number of client datasets: {l}'.format(l=len(federated_train_data)))
print('First dataset: {d}'.format(d=federated_train_data[0]))
print('Second dataset: {d}'.format(d=federated_train_data[1]))

Number of client datasets: 20
First dataset: <PrefetchDataset shapes: OrderedDict([(x, (None, 28, 28, 1)), (y, (None, 1))]), types: OrderedDict([(x, tf.float32), (y, tf.int32)])>
Second dataset: <PrefetchDataset shapes: OrderedDict([(x, (None, 28, 28, 1)), (y, (None, 1))]), types: OrderedDict([(x, tf.float32), (y, tf.int32)])>


In [70]:
def create_keras_model():
  return tf.keras.models.Sequential([
      tf.keras.layers.Input(shape = (28,28,1)),
      tf.keras.layers.Conv2D(filters = 32, kernel_size = (5,5)),
      tf.keras.layers.MaxPool2D(),
      tf.keras.layers.Conv2D(filters = 64, kernel_size = (5,5)),
      tf.keras.layers.MaxPool2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(512, activation = 'relu'),
      tf.keras.layers.Dense(10),
      tf.keras.layers.Softmax(),
  ])

#       tf.keras.layers.Input(shape=(784,)),
#       tf.keras.layers.Dense(10, kernel_initializer='zeros'),
#       tf.keras.layers.Softmax(),

In [71]:
type(preprocessed_example_dataset.element_spec)
preprocessed_example_dataset.element_spec

OrderedDict([('x',
              TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32, name=None)),
             ('y', TensorSpec(shape=(None, 1), dtype=tf.int32, name=None))])

In [72]:
def model_fn():
  # We _must_ create a new model here, and _not_ capture it from an external
  # scope. TFF will call this within different graph contexts.
  keras_model = create_keras_model()
  return tff.learning.from_keras_model(
      keras_model,
      input_spec=preprocessed_example_dataset.element_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])


In [79]:
keras_model = create_keras_model()
keras_model.summary()

Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_4 (Conv2D)            (None, 24, 24, 32)        832       
_________________________________________________________________
max_pooling2d_4 (MaxPooling2 (None, 12, 12, 32)        0         
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 8, 8, 64)          51264     
_________________________________________________________________
max_pooling2d_5 (MaxPooling2 (None, 4, 4, 64)          0         
_________________________________________________________________
flatten_2 (Flatten)          (None, 1024)              0         
_________________________________________________________________
dense_4 (Dense)              (None, 512)               524800    
_________________________________________________________________
dense_5 (Dense)              (None, 10)               

In [74]:
iterative_process = tff.learning.build_federated_averaging_process(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))

In [75]:
str(iterative_process.initialize.type_signature)

'( -> <model=<trainable=<float32[5,5,1,32],float32[32],float32[5,5,32,64],float32[64],float32[1024,512],float32[512],float32[512,10],float32[10]>,non_trainable=<>>,optimizer_state=<int64>,delta_aggregate_state=<>,model_broadcast_state=<>>@SERVER)'

In [80]:
state = iterative_process.initialize()

In [81]:
state, metrics = iterative_process.next(state, federated_train_data)
print('round  1, metrics={}'.format(metrics))

round  1, metrics=<sparse_categorical_accuracy=0.12074883282184601,loss=2.304137706756592>


In [82]:
NUM_ROUNDS = 200
for round_num in range(2, NUM_ROUNDS):
  state, metrics = iterative_process.next(state, federated_train_data)
  print('round {:2d}, metrics={}'.format(round_num, metrics))

round  2, metrics=<sparse_categorical_accuracy=0.16006240248680115,loss=2.2709944248199463>
round  3, metrics=<sparse_categorical_accuracy=0.18668746948242188,loss=2.254620313644409>
round  4, metrics=<sparse_categorical_accuracy=0.20842432975769043,loss=2.239936113357544>
round  5, metrics=<sparse_categorical_accuracy=0.24035361409187317,loss=2.222548007965088>
round  6, metrics=<sparse_categorical_accuracy=0.2627145051956177,loss=2.207554817199707>
round  7, metrics=<sparse_categorical_accuracy=0.27041080594062805,loss=2.1893606185913086>
round  8, metrics=<sparse_categorical_accuracy=0.2860114276409149,loss=2.168400526046753>
round  9, metrics=<sparse_categorical_accuracy=0.31419655680656433,loss=2.138439416885376>
round 10, metrics=<sparse_categorical_accuracy=0.3407176434993744,loss=2.1103899478912354>
round 11, metrics=<sparse_categorical_accuracy=0.36224648356437683,loss=2.073728084564209>
round 12, metrics=<sparse_categorical_accuracy=0.37670305371284485,loss=2.041407823562622>