# TFF on Handwritten Digit Classification

In [2]:
import collections
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

## Exploring MNIST data & preprocessing

In [4]:
# Load the EMNIST simulation dataset shipped with TFF as a (train, test) pair.
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

Downloading emnist_all.sqlite.lzma: 100%|██████████| 170507172/170507172 [00:35<00:00, 4091150.12it/s]


In [None]:
# NOTE(review): `emnist` is imported here but is not used in any visible cell.
from tensorflow_federated.python.simulation.datasets import emnist
# Number of clients in the training split
len(emnist_train.client_ids)

In [None]:
# Display the structure of a single element of the client datasets.
emnist_train.element_type_structure

In [150]:
# Create a dataset containing the samples of one particular client
sample_dataset = emnist_train.create_tf_dataset_for_client(emnist_train.client_ids[10])
# Take the first example of that client's dataset
sample_digit = next(iter(sample_dataset))
# Retrieve and print its label
print(sample_digit['label'].numpy())

4


In [None]:
from matplotlib import pyplot as plt 
# Render the sampled digit's pixel array as a grayscale image.
plt.imshow(sample_digit['pixels'].numpy(), cmap="gray", aspect="equal")
plt.show()

In [None]:
# Exploring samples from one client (Impossible in production)
# Non-iid
figure = plt.figure(figsize=(20, 4))

# Draw the first 40 examples of the client on a 4x10 grid.
for j, sample in enumerate(sample_dataset.take(40)):
  plt.subplot(4, 10, j + 1)
  # Plot the example of the current iteration. Plotting `sample_digit` here
  # (the single example fetched earlier) would repeat one image 40 times.
  plt.imshow(sample['pixels'].numpy(), cmap="gray", aspect="equal")
  plt.axis("off")

In [None]:
# Save through the Figure handle instead of `plt.savefig`: this runs in its own
# cell, and with the inline backend the pyplot "current figure" from the
# plotting cell has already been closed, so `plt.savefig` would write a blank
# image.
figure.savefig("sample_client_data.png")

In [None]:
# Number of examples per label for a sample of clients
f = plt.figure(figsize=(12, 7))
f.suptitle('Label Counts for a Sample of Clients')
client_id = 1000
for i in range(6):
  client_key = emnist_train.client_ids[client_id]
  client_dataset = emnist_train.create_tf_dataset_for_client(client_key)
  # Keep one list per label (instead of a single list of all labels) so that
  # every digit is drawn as its own histogram series with its own color.
  plot_data = collections.defaultdict(list)
  for example in client_dataset:
    label = example['label'].numpy()
    plot_data[label].append(label)
  plt.subplot(2, 3, i+1)
  plt.title('Client {}'.format(client_id))
  for j in range(10):
    # Bin edges 0..10 give one bin per digit.
    plt.hist(plot_data[j], density=False, bins=list(range(11)))
  client_id += 400

In [None]:
# Visualizing mean image per client for each label
k = 100
for i in range(5):
  client_dataset = emnist_train.create_tf_dataset_for_client(
      emnist_train.client_ids[k])
  # Group the client's images by their label.
  plot_data = collections.defaultdict(list)
  for example in client_dataset:
    plot_data[example['label'].numpy()].append(example['pixels'].numpy())
  f = plt.figure(i, figsize=(12, 5))
  f.suptitle("Client #{}'s Mean Image Per Label".format(k))
  for j in range(10):
    plt.subplot(2, 5, j+1)
    plt.axis('off')
    # A client may have no examples of some digit. Averaging an empty list
    # gives a scalar NaN that cannot be reshaped to 28x28, so leave the panel
    # blank in that case.
    if not plot_data[j]:
      continue
    mean_img = np.mean(plot_data[j], 0)
    plt.imshow(mean_img.reshape((28, 28)))
  # Move on to the next client once per client. Inside the label loop this
  # increment ran ten times per client, so the step was 500 instead of 50.
  k += 50

In [113]:
NUM_CLIENTS = 10
NUM_EPOCHS = 10
BATCH_SIZE = 64
SHUFFLE_BUFFER = 100
PREFETCH_BUFFER = 10


def preprocess(dataset):
  """Turn one client's raw dataset into batched model input.

  The dataset is repeated `NUM_EPOCHS` times, shuffled with a buffer of
  `SHUFFLE_BUFFER` elements and a fixed seed, grouped into batches of
  `BATCH_SIZE`, reformatted by `batch_format_fn`, and prefetched with a buffer
  of `PREFETCH_BUFFER`.

  Args:
    dataset: A dataset whose elements expose `'pixels'` and `'label'` entries
      and which supports `repeat`/`shuffle`/`batch`/`map`/`prefetch`.

  Returns:
    The transformed dataset; each element is an `OrderedDict` with keys `x`
    and `y`.
  """

  def batch_format_fn(element):
    """Return a batch as an `OrderedDict` of features `x` and labels `y`.

    `pixels` is passed through unchanged (it is NOT flattened to 784 values);
    `label` is reshaped into a column of shape `[-1, 1]`.
    """
    return collections.OrderedDict(
        x=element['pixels'],
        y=tf.reshape(element['label'], [-1, 1]))

  return dataset.repeat(NUM_EPOCHS).shuffle(SHUFFLE_BUFFER, seed=1).batch(
      BATCH_SIZE).map(batch_format_fn).prefetch(PREFETCH_BUFFER)

In [114]:
"""
Helper function that returns a list containing one preprocessed
tf.data.Dataset per client.
"""
def make_federated_data(client_data, client_ids):
  """Build one preprocessed dataset per requested client.

  Args:
    client_data: Object exposing `create_tf_dataset_for_client(client_id)`.
    client_ids: Iterable of client ids to include.

  Returns:
    A list of datasets, one per id, in the same order as `client_ids`.
  """
  federated_data = []
  for client_id in client_ids:
    raw_dataset = client_data.create_tf_dataset_for_client(client_id)
    federated_data.append(preprocess(raw_dataset))
  return federated_data

## Build Keras model

In [131]:
from tensorflow.keras.layers import Dense, Conv2D, MaxPooling2D, Flatten, Dense
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import RMSprop, SGD

In [117]:
def create_keras_model():
  """Build the digit classifier as a fresh, uncompiled Keras `Sequential`.

  Two convolution + max-pooling stages (32 then 64 filters, 3x3 kernels, 2x2
  pooling) on a (28, 28, 1) input, followed by a 64-unit ReLU dense layer and
  a 10-way softmax output.
  """
  model = Sequential()
  # Convolutional feature extractor.
  model.add(Conv2D(filters=32, kernel_size=(3, 3), activation="relu",
                   input_shape=(28, 28, 1)))
  model.add(MaxPooling2D(pool_size=(2, 2)))
  model.add(Conv2D(filters=64, kernel_size=(3, 3), activation="relu"))
  model.add(MaxPooling2D(pool_size=(2, 2)))
  # Dense classification head.
  model.add(Flatten())
  model.add(Dense(64, activation="relu"))
  model.add(Dense(10, activation="softmax"))
  return model

In [None]:
# Print a layer-by-layer summary of a freshly built model.
create_keras_model().summary()

In [118]:
"""
Keras model needs to be wrapped in tff.learning.Model()
"""
# Build one preprocessed client dataset up front so `model_fn` can read its
# element spec. Using `federated_train_data[0]` as the default argument fails
# with a NameError on a fresh kernel, because that name is only created later
# by the training loop.
example_preprocessed_dataset = preprocess(
    emnist_train.create_tf_dataset_for_client(emnist_train.client_ids[0]))


def model_fn(preprocessed_example_dataset=None):
  """Create a new Keras model wrapped by `tff.learning.from_keras_model`.

  Args:
    preprocessed_example_dataset: Optional dataset whose `element_spec` is
      used as the model's input spec. Defaults to the preprocessed dataset of
      the first training client.

  Returns:
    The object returned by `tff.learning.from_keras_model`, configured with
    sparse categorical cross-entropy loss and sparse categorical accuracy.
  """
  if preprocessed_example_dataset is None:
    preprocessed_example_dataset = example_preprocessed_dataset

  # A new Keras model is built on every call.
  keras_model = create_keras_model()
  return tff.learning.from_keras_model(
      keras_model,
      input_spec=preprocessed_example_dataset.element_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

## Training with Federated Averaging

In [133]:
def _client_optimizer():
  """Optimizer factory for the client-side updates."""
  return SGD(learning_rate=0.01, momentum=0.9)


def _server_optimizer():
  """Optimizer factory for the server-side update."""
  return SGD(learning_rate=0.8, momentum=0.9)


# Assemble the Federated Averaging process from the model and optimizer
# factories.
iterative_process = tff.learning.build_federated_averaging_process(
    model_fn,
    client_optimizer_fn=_client_optimizer,
    server_optimizer_fn=_server_optimizer)

In [134]:
# Construct the initial server state of the iterative process
state = iterative_process.initialize()

In [136]:
NUM_ROUNDS = 40
NUM_CLIENTS = 10
TOTAL_CLIENTS = len(emnist_train.client_ids)
for round_num in range(1, NUM_ROUNDS+1):

  # Random sampling of clients: draw NUM_CLIENTS distinct client indices.
  # Picking a random start and slicing `[start:start+NUM_CLIENTS]` always
  # yields a contiguous run of clients and returns fewer than NUM_CLIENTS when
  # the start falls near the end of the list.
  sampled_indices = np.random.choice(
      TOTAL_CLIENTS, size=min(NUM_CLIENTS, TOTAL_CLIENTS), replace=False)

  # Select a subset of NUM_CLIENTS for each round of FL lifecycle
  sample_clients = [emnist_train.client_ids[i] for i in sampled_indices]
  federated_train_data = make_federated_data(emnist_train, sample_clients)

  # FL Lifecycle (Single round of training)
  # next signature: SERVER_STATE, FEDERATED_DATA => SERVER_STATE, TRAINING_METRICS
  state, metrics = iterative_process.next(state, federated_train_data)
  print('round {:2d}, metrics={}'.format(round_num, metrics))


round  1, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.9729651), ('loss', 0.08845906), ('num_examples', 10320), ('num_batches', 165)]))])
round  2, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.95959693), ('loss', 0.15162002), ('num_examples', 10420), ('num_batches', 167)]))])
round  3, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.97955304), ('loss', 0.09288146), ('num_examples', 8950), ('num_batches', 145)]))])
round  4, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.97519684), ('loss', 0.11913582

In [None]:
# Examining federated computations
# Print the type signature of `initialize`, the computation that builds the
# initial Federated Averaging state
print(iterative_process.initialize.type_signature.formatted_representation())

## Federated Evaluation

In [137]:
# Build the federated evaluation computation used to evaluate the global model
evaluation = tff.learning.build_federated_evaluation(model_fn)

In [None]:
# Print the type signature of the evaluation computation.
print(evaluation.type_signature.formatted_representation())

In [139]:
# SIGNATURE: SERVER_MODEL, FEDERATED_DATA -> EVALUATION_METRICS
# NOTE: `federated_train_data` is whatever the last training round left
# behind, i.e. the datasets of the clients sampled in the final round.
train_metrics = evaluation(state.model, federated_train_data)

In [140]:
# Show the evaluation metrics on the training clients as a string.
str(train_metrics)

"OrderedDict([('eval', OrderedDict([('sparse_categorical_accuracy', 0.9761249), ('loss', 0.0812723), ('num_examples', 10890), ('num_batches', 175)]))])"

In [141]:
# Build test datasets for the clients in `sample_clients`, which is left over
# from the last iteration of the training loop.
federated_test_data = make_federated_data(emnist_test, sample_clients)

# Number of client datasets in the test set
len(federated_test_data)

10

In [142]:
# Evaluate the trained global model on the test datasets of those clients.
test_metrics = evaluation(state.model, federated_test_data)
str(test_metrics)

"OrderedDict([('eval', OrderedDict([('sparse_categorical_accuracy', 0.9612403), ('loss', 0.10330613), ('num_examples', 1290), ('num_batches', 28)]))])"

# Single Client Handwritten Digit Classification

In [153]:
from tensorflow import keras
# Load the MNIST train/test arrays that ship with Keras.
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

In [154]:
# Data Preprocessing
# NOTE(review): `to_categorical` is imported but not used in any visible cell.
from tensorflow.keras.utils import to_categorical
# Cast the images to float32 and divide by 255.
x_train = x_train.astype("float32") / 255
x_test = x_test.astype("float32") / 255

In [152]:
# Peek at the first ten test labels.
y_test[:10]

array([7, 2, 1, 0, 4, 1, 4, 9, 5, 9], dtype=uint8)

In [157]:
# NOTE(review): `learning` and `optimizer` are imported from TFF-internal
# module paths and are not used in any visible cell.
from tensorflow_federated.python import learning
from tensorflow_federated.python.learning.optimizers import optimizer
# Building the model and training it
# Same architecture as the federated model, compiled with the same SGD
# settings as the federated client optimizer (learning_rate=0.01, momentum=0.9).
model = create_keras_model()
model.compile(optimizer=SGD(learning_rate=0.01, momentum=0.9), 
              loss=keras.losses.SparseCategoricalCrossentropy(), 
              metrics=keras.metrics.SparseCategoricalAccuracy())
# The number of epochs is tied to the federated setup: half of NUM_ROUNDS.
model.fit(x_train, y_train, epochs=NUM_ROUNDS // 2, batch_size=BATCH_SIZE)

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


<keras.callbacks.History at 0x7f2ad65d20d0>

In [158]:
# Testing the model 
# Evaluate on the MNIST test split using the loss and metric set in `compile`.
model.evaluate(x_test, y_test)



[0.036732666194438934, 0.9907000064849854]