In [None]:
!pip install tensorflow
!pip install tensorflow-federated

In [None]:
import collections

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

# Fix NumPy's global random seed so that NumPy-based randomness is repeatable.
np.random.seed(0)


# Smoke test: build and immediately invoke a trivial federated computation to
# confirm that the TFF runtime is usable.
@tff.federated_computation
def hello_world():
  return "Hello World!"


hello_world()

In [None]:
# Load the dataset used in the Federated Learning process.
# This is a federated version of the MNIST dataset: the data is non-IID, as
# expected in a federated environment, because each user has their own unique
# style of writing.
# The loaded objects are later used to create one dataset per client.
# load_data() loads the simulation data together with the client ids; this
# interface works only in simulation mode.
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

In [None]:
len(emnist_train.client_ids)  # number of clients available in the training data

In [None]:
# Inspect the data of a single client: build the tf.data dataset of the first
# client id and look at the structure of its elements.
first_client_id = emnist_train.client_ids[0]
example_dataset = emnist_train.create_tf_dataset_for_client(first_client_id)

# Displays the structure (spec) of one entry of this client's dataset.
example_dataset.element_spec

In [None]:
# Pull the first element out of the example client's dataset.
example_iterator = iter(example_dataset)
example_element = next(example_iterator)

# Display the label stored with that element.
example_element['label'].numpy()

In [None]:
from matplotlib import pyplot as plt
from IPython.display import display, HTML, IFrame
from skimage import io

In [None]:
# Show the first 40 images of the example client on a 4 x 10 grid.
# Looking at a client's raw data like this is only possible because we are in a
# simulated environment.

fig = plt.figure(figsize=(20,4))

for position, sample in enumerate(example_dataset.take(40), start=1):
  plt.subplot(4, 10, position)
  plt.imshow(sample['pixels'].numpy(), cmap='gray', aspect='equal')
  plt.axis('off')

In [None]:
# Histogram of digit labels for the first six clients, one panel per client.
# This is possible only because we work in a simulated environment; in reality
# only the user can see their own data.

f = plt.figure(figsize=(12, 7))
f.suptitle("Label Counts for a Sample of Clients")

# Bin edges 0..10, i.e. one bin per digit.
histogram_bins = list(range(11))

for client_idx in range(6):
  client_dataset = emnist_train.create_tf_dataset_for_client(
      emnist_train.client_ids[client_idx])

  # Collect the labels of this client, grouped by label value.
  labels_by_digit = collections.defaultdict(list)
  for example in client_dataset:
    label = example['label'].numpy()
    labels_by_digit[label].append(label)

  plt.subplot(2, 3, client_idx + 1)
  plt.title("Client {}".format(client_idx))
  # One hist() call per digit so each digit is drawn as its own series.
  for digit_value in range(10):
    plt.hist(labels_by_digit[digit_value], density=False, bins=histogram_bins)

In [None]:
# Plot the mean image per digit for each of the first five clients, to see what
# each local model would learn about each digit from that client's data.
for i in range(5):
  client_dataset = emnist_train.create_tf_dataset_for_client(
      emnist_train.client_ids[i])

  # Group this client's images by their digit label.
  images_by_label = collections.defaultdict(list)
  for example in client_dataset:
    images_by_label[example['label'].numpy()].append(example['pixels'].numpy())

  f = plt.figure(i, figsize=(12, 5))
  f.suptitle("Client #{}'s Mean Image Per Label".format(i))
  for j in range(10):
    plt.subplot(2, 5, j + 1)
    digit_images = images_by_label[j]
    # A client may have no examples of some digit. np.mean over an empty list
    # yields a NaN scalar that cannot be reshaped to 28 x 28, so leave that
    # panel blank instead of raising.
    if digit_images:
      mn_img = np.mean(digit_images, 0)
      plt.imshow(mn_img.reshape(28, 28))
    plt.axis("off")

In [None]:
# Define the parameters of the Federated Learning setup
NUM_CLIENTS = 10  # how many client ids are selected for training
NUM_EPOCHS = 5  # how many times each client dataset is repeated in preprocess()
BATCH_SIZE = 20  # number of examples per batch in preprocess()
SHUFFLE_BUFFER = 100  # buffer size passed to Dataset.shuffle() in preprocess()
PREFETCH_BUFFER = 10  # buffer size passed to Dataset.prefetch() in preprocess()

In [None]:
# Preprocess the input data
def preprocess(dataset):
  """Turns a raw client dataset into shuffled, batched training input.

  The dataset is repeated NUM_EPOCHS times, shuffled with a buffer of
  SHUFFLE_BUFFER elements and a fixed seed of 1, grouped into batches of
  BATCH_SIZE, reformatted with `batch_format_fn`, and finally prefetched with a
  buffer of PREFETCH_BUFFER.

  Args:
    dataset: A dataset whose elements provide the keys 'pixels' and 'label'.

  Returns:
    The transformed dataset; each element is an OrderedDict with keys `x` and
    `y`.
  """

  def batch_format_fn(element):
    """Flattens a batch of pixels to rows of 784 and labels to one column."""
    flat_pixels = tf.reshape(element['pixels'], [-1, 784])
    column_labels = tf.reshape(element['label'], [-1, 1])
    return collections.OrderedDict(x=flat_pixels, y=column_labels)

  repeated = dataset.repeat(NUM_EPOCHS)
  shuffled = repeated.shuffle(SHUFFLE_BUFFER, seed=1)
  batched = shuffled.batch(BATCH_SIZE)
  formatted = batched.map(batch_format_fn)
  return formatted.prefetch(PREFETCH_BUFFER)

In [None]:
# Sanity-check the preprocessing: preprocess the example client's dataset, take
# its first batch, and convert every tensor in that batch to a NumPy array.

preprocessed_example_dataset = preprocess(example_dataset)
first_batch = next(iter(preprocessed_example_dataset))

# map_structure applies the conversion to every item of the nested batch
# structure.
sample_batch = tf.nest.map_structure(lambda tensor: tensor.numpy(), first_batch)

sample_batch

In [None]:
# To be able to feed data to each client in the TFF simulation environment, we
# create a list with each user dataset

def make_federated_data(client_data, client_ids):
  """Builds one preprocessed dataset per client id.

  Args:
    client_data: Object providing `create_tf_dataset_for_client(client_id)`.
    client_ids: Iterable of client ids to build datasets for.

  Returns:
    A list of preprocessed datasets, in the same order as `client_ids`.
  """
  federated_datasets = []
  for client_id in client_ids:
    raw_dataset = client_data.create_tf_dataset_for_client(client_id)
    federated_datasets.append(preprocess(raw_dataset))
  return federated_datasets

In [None]:
# Choose the clients for our Federated Setup. Note that this is NOT a random
# sample: we simply take the first NUM_CLIENTS client ids.

sample_clients = emnist_train.client_ids[0:NUM_CLIENTS]

federated_train_data = make_federated_data(emnist_train, sample_clients)

print("Number of client datasets: {}".format(len(federated_train_data)))
print("First dataset: {}".format(federated_train_data[0]))

In [None]:
# Define the Keras NN used in image classification
def create_keras_model():
  """Builds the Keras classifier: one zero-initialised Dense layer plus softmax.

  Returns:
    A `tf.keras.models.Sequential` taking inputs of shape (784,) and producing
    10 softmax outputs.
  """
  layers = [
      tf.keras.layers.InputLayer(input_shape=(784,)),
      tf.keras.layers.Dense(10, kernel_initializer='zeros'),
      tf.keras.layers.Softmax(),
  ]
  return tf.keras.models.Sequential(layers)

In [None]:
# Wrap the Keras model in order to be used with Tensorflow Federated
def model_fn():
  """Creates a fresh Keras model and wraps it for use with TFF.

  A new Keras model, loss and metric are constructed on every call.

  Returns:
    The result of `tff.learning.models.from_keras_model` for the new model,
    using the element spec of `preprocessed_example_dataset` as input spec.
  """
  keras_model = create_keras_model()
  batch_spec = preprocessed_example_dataset.element_spec
  loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()
  accuracy_metric = tf.keras.metrics.SparseCategoricalAccuracy()
  return tff.learning.models.from_keras_model(
      keras_model,
      input_spec=batch_spec,
      loss=loss_fn,
      metrics=[accuracy_metric],
  )

In [None]:
# Build the Federated Averaging (FedAvg) process that trains the model on
# federated data.
def client_optimizer_fn():
  """Optimizer used to compute the local model updates on each client."""
  return tf.keras.optimizers.SGD(learning_rate=0.02)


def server_optimizer_fn():
  """Optimizer that applies the averaged update to the global model on the server."""
  return tf.keras.optimizers.SGD(learning_rate=1.0)


training_process = tff.learning.algorithms.build_weighted_fed_avg(
    model_fn,  # pass the constructor so the model construction is controlled by TFF
    client_optimizer_fn=client_optimizer_fn,
    server_optimizer_fn=server_optimizer_fn,
)
training_process  # an iterative learning process that implements FedAvg

In [None]:
# Print the type signature of the `initialize` computation of the Federated
# Averaging process (this prints the signature, not the state values themselves).
print(training_process.initialize.type_signature.formatted_representation())

In [None]:
# Construct the server state and initialize the global model parameters
train_state = training_process.initialize()
# Display the freshly initialized state as the cell output.
train_state

In [None]:
# Execute one round of federated training:
# global model from the server -> clients -> training on local data ->
# collect and average the model updates -> global model update at the server.
results = training_process.next(train_state, federated_train_data)

# Keep both the new server state and the metrics reported for this round.
train_state, train_metrics = results.state, results.metrics

print('Train state: {}'.format(train_state))
print('Round 1, metrics {}'.format(train_metrics))

In [None]:
# Run the remaining rounds. Round 1 has already been executed, so continue with
# round 2 up to and including round NUM_ROUNDS. range() excludes its stop value,
# hence the `+ 1`; without it the last round would be skipped.
NUM_ROUNDS = 10
# The loop variable is named `round_num` to avoid shadowing the builtin round().
for round_num in range(2, NUM_ROUNDS + 1):
  results = training_process.next(train_state, federated_train_data)
  # Carry the updated server state over into the next round.
  train_state = results.state
  print('Round {}, train state: {}'.format(round_num, train_state))
  train_metrics = results.metrics
  print('Round {}, train metrics: {}'.format(round_num, train_metrics) + '\n')