# **Setting up environment**

In [None]:
!pip install --quiet --upgrade tensorflow-federated-nightly
!pip install --quiet --upgrade nest-asyncio

# Apply the nest_asyncio patch before any TFF code runs.
# NOTE(review): presumably required because the notebook kernel already runs an
# asyncio event loop that TFF's execution would otherwise conflict with — confirm.
import nest_asyncio
nest_asyncio.apply()

In [None]:
import collections
import attr
import functools
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

# Seed NumPy's global RNG so the client sampling done later with np.random.choice is repeatable.
np.random.seed(0)

# **Preprocessing**

## Loading dataset

In [None]:
# Load the train and test splits of the federated EMNIST dataset. Later cells read
# `client_ids` from the train split and build per-client / pooled tf datasets from these objects.
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

## Flattening data

In [None]:
NUM_CLIENTS = 10  # Number of client ids sampled to form the simulated federation.
BATCH_SIZE = 20   # Number of examples per batch produced by `preprocess`.

def preprocess(dataset):
  """Batches `dataset` and flattens every batch into a (features, labels) tuple.

  Args:
    dataset: Dataset whose elements are mappings with 'pixels' and 'label' entries.

  Returns:
    A dataset of batches of up to BATCH_SIZE examples. Each batch is a tuple of
    the pixels reshaped to [-1, 784] and the labels reshaped to [-1, 1].
  """
  def _flatten_batch(batch):
    features = tf.reshape(batch['pixels'], [-1, 784])   # one 784-wide row per example
    labels = tf.reshape(batch['label'], [-1, 1])        # one label column per example
    return (features, labels)

  batched = dataset.batch(BATCH_SIZE)
  return batched.map(_flatten_batch)

In [None]:
# Randomly select NUM_CLIENTS distinct client ids (sampling without replacement).
# np.random.choice also accepts per-element selection probabilities; none are passed
# here, so every client id is equally likely.
client_ids = np.random.choice(
    emnist_train.client_ids, size=NUM_CLIENTS, replace=False)

# Build one preprocessed (batched and flattened) dataset per selected client.
federated_train_data = [
    preprocess(emnist_train.create_tf_dataset_for_client(client_id))
    for client_id in client_ids
]

# **Model**


## Defining Keras model

In [None]:
def create_keras_model():
  """Builds the Keras classifier used throughout the notebook.

  The model has no hidden layer: a 784-wide input feeds a single 10-unit Dense
  layer whose kernel is initialized to zeros, followed by a Softmax.

  Returns:
    A new, uncompiled `tf.keras.models.Sequential` instance.
  """
  layers = [
      tf.keras.layers.Input(shape=(784,)),
      tf.keras.layers.Dense(10, kernel_initializer='zeros'),
      tf.keras.layers.Softmax(),
  ]
  return tf.keras.models.Sequential(layers)

## Wrapping Keras model

In [None]:
def model_fn():
  """Wraps a freshly built Keras model for use with TFF.

  Every call creates a new Keras model, loss and metric object. The input
  specification is taken from the first client's preprocessed dataset.

  Returns:
    The object produced by `tff.learning.from_keras_model`.
  """
  net = create_keras_model()
  loss = tf.keras.losses.SparseCategoricalCrossentropy()
  metrics = [tf.keras.metrics.SparseCategoricalAccuracy()]
  return tff.learning.from_keras_model(
      net,
      input_spec=federated_train_data[0].element_spec,
      loss=loss,
      metrics=metrics)

# **Building Federated Learning Algorithm**

In [None]:
def initialize_fn():
  """Plain-Python sketch of initialization: trainable variables of a fresh model.

  A later cell redefines `initialize_fn` as a `tff.federated_computation`,
  which shadows this version.
  """
  return model_fn().trainable_variables

In [None]:
def next_fn(server_weights, federated_dataset):
  """Plain-Python outline of one federated averaging round (illustrative only).

  This sketch is not meant to be called: `broadcast` and `mean` are not defined
  in the visible notebook, and `client_update` / `server_update` are defined
  later with additional parameters. A later cell redefines `next_fn` as a
  `tff.federated_computation`, which shadows this definition.
  """

  server_weights_at_client = broadcast(server_weights)                            # Step 1: broadcast the server weights to the clients.
  
  client_weights = client_update(federated_dataset, server_weights_at_client)     # Step 2: each client computes its updated weights.
  
  mean_client_weights = mean(client_weights)                                      # Step 3: the server averages the client weights.
  
  server_weights = server_update(mean_client_weights)                             # Step 4: the server updates its model from the average.

  return server_weights

## **TensorFlow Blocks**

### **Client Update**


In [None]:
@tf.function
def client_update(model, dataset, server_weights, client_optimizer):
  """Trains `model` on one client's `dataset`, starting from `server_weights`.

  Args:
    model: Object exposing `trainable_variables` and `forward_pass(batch)`.
    dataset: Iterable of batches passed one by one to `model.forward_pass`.
    server_weights: Structure of values matching `model.trainable_variables`.
    client_optimizer: Optimizer exposing `apply_gradients`.

  Returns:
    The model's trainable variables after a single pass over `dataset`.
  """
  local_weights = model.trainable_variables

  # Start local training from the server's current weights: each server value
  # is assigned to the corresponding local variable.
  tf.nest.map_structure(lambda variable, value: variable.assign(value),
                        local_weights, server_weights)

  # One optimizer step per batch.
  for batch in dataset:
    with tf.GradientTape() as tape:
      result = model.forward_pass(batch)                                          # Forward pass recorded on the tape

    gradients = tape.gradient(result.loss, local_weights)                         # d(loss) / d(local weights)
    client_optimizer.apply_gradients(zip(gradients, local_weights))               # Apply the (gradient, variable) pairs

  return local_weights

### **Server Update**

In [None]:
@tf.function
def server_update(model, mean_client_weights):
  """Overwrites the server model's weights with the averaged client weights.

  Args:
    model: Object exposing `trainable_variables`.
    mean_client_weights: Structure of values matching `model.trainable_variables`.

  Returns:
    The model's trainable variables after the assignment.
  """
  server_variables = model.trainable_variables

  # tf.nest.map_structure applies the function to each pair of matching entries.
  tf.nest.map_structure(lambda variable, value: variable.assign(value),
                        server_variables, mean_client_weights)
  return server_variables

## **TensorFlow Federated Blocks**

### **`initialize_fn`**

In [None]:
@tff.tf_computation
def server_init():
  """TFF TensorFlow computation returning the trainable variables of a new model."""
  return model_fn().trainable_variables

In [None]:
@tff.federated_computation
def initialize_fn():
  """Federated computation that places the initial model weights at the server."""
  initial_weights = server_init()
  return tff.federated_value(initial_weights, tff.SERVER)

In [None]:
# Show the type signature of the TFF `initialize_fn`; as the cell's last expression it is displayed as output.
str(initialize_fn.type_signature)

### **`next_fn`**

In [None]:
# Model instance used here only to read its input specification.
whimsy_model = model_fn()

# TFF types for a client dataset (a sequence of batches) and for the model weights
# (the result type of `server_init`).
tf_dataset_type = tff.SequenceType(whimsy_model.input_spec)
model_weights_type = server_init.type_signature.result

print(tf_dataset_type)
print(model_weights_type)

###  **`client_update_fn`**


In [None]:
@tff.tf_computation(tf_dataset_type, model_weights_type)
def client_update_fn(tf_dataset, server_weights):
  """TFF wrapper around `client_update` for a single client.

  Builds a new model and an SGD optimizer (learning rate 0.01) on every call,
  then trains on `tf_dataset` starting from `server_weights`.
  """
  local_model = model_fn()
  sgd = tf.keras.optimizers.SGD(learning_rate=0.01)
  return client_update(local_model, tf_dataset, server_weights, sgd)

### **`server_update_fn`**

In [None]:
@tff.tf_computation(model_weights_type)
def server_update_fn(mean_client_weights):
  """TFF wrapper around `server_update`: loads the averaged weights into a new model."""
  return server_update(model_fn(), mean_client_weights)

In [None]:
# Placed (federated) types used as the parameter types of the TFF `next_fn`:
# the model weights live at the server, the datasets live at the clients.
federated_server_type = tff.FederatedType(model_weights_type, tff.SERVER)
federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)

In [None]:
@tff.federated_computation(federated_server_type, federated_dataset_type)
def next_fn(server_weights, federated_dataset):
  """Runs one round: broadcast, local client training, averaging, server update.

  Args:
    server_weights: Model weights placed at the server.
    federated_dataset: Client datasets placed at the clients.

  Returns:
    The updated model weights placed at the server.
  """
  # Server -> clients: every client receives the current server weights.
  broadcast_weights = tff.federated_broadcast(server_weights)

  # Each client trains locally on its own dataset.
  per_client_weights = tff.federated_map(
      client_update_fn, (federated_dataset, broadcast_weights))

  # Clients -> server: average the locally trained weights.
  averaged_weights = tff.federated_mean(per_client_weights)

  # The server adopts the averaged weights.
  return tff.federated_map(server_update_fn, averaged_weights)

In [None]:
# Bundle the two federated computations into one iterative process. Later cells
# drive it through `federated_algorithm.initialize()` and `federated_algorithm.next(...)`.
federated_algorithm = tff.templates.IterativeProcess(
    initialize_fn=initialize_fn,
    next_fn=next_fn
)

In [None]:
# Print the type signatures of the process's `initialize` and `next` computations.
print(str(federated_algorithm.initialize.type_signature))
print(str(federated_algorithm.next.type_signature))

# **Evaluating Algorithm**

In [None]:
# Pool the test examples of all clients, keep only the first 1000 examples, then
# batch and flatten them with the same `preprocess` used for the training data.
central_emnist_test = preprocess(
    emnist_test.create_tf_dataset_from_all_clients().take(1000))

## **Evaluation on test dataset**

In [None]:
def evaluate(server_state):
  """Evaluates the given server weights on the centralized test dataset.

  A fresh Keras model is built and compiled with the same loss and metric that
  `model_fn` uses, loaded with `server_state`, and evaluated on
  `central_emnist_test`.

  Args:
    server_state: Weights accepted by `keras_model.set_weights`, e.g. the value
      returned by `federated_algorithm.initialize()` or `.next(...)`.

  Returns:
    The value returned by `keras_model.evaluate` (the loss followed by the
    compiled metric values), so callers can log or compare results across
    rounds instead of only reading the progress output.
  """
  keras_model = create_keras_model()
  keras_model.compile(                                                            # Only loss and metrics are configured; no training happens here
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()]
  )
  keras_model.set_weights(server_state)                                           # Load the server weights into the model
  # `central_emnist_test` is already batched by `preprocess`, so no batch size is passed.
  return keras_model.evaluate(central_emnist_test)

In [None]:
# Baseline: evaluate the freshly initialized server weights before any training round.
server_state = federated_algorithm.initialize()
evaluate(server_state)

In [None]:
NUM_ROUNDS = 15  # Number of federated training rounds to run.

# The loop variable is deliberately not called `round`: that would shadow the
# built-in `round()` for every later cell of the notebook.
for round_num in range(NUM_ROUNDS):
  server_state = federated_algorithm.next(server_state, federated_train_data)

In [None]:
# Evaluate the server weights again after the training rounds, for comparison with the baseline.
evaluate(server_state)

# **Challenge:**

1. Implement a version of `server_update` that updates the server weights to be the midpoint of model_weights and mean_client_weights. (Note: This kind of "midpoint" approach is analogous to recent work on the [Lookahead optimizer](https://arxiv.org/abs/1907.08610)!).  
2. Add [gradient clipping](https://towardsdatascience.com/what-is-gradient-clipping-b8e815cdfb48) to the `client_update` function.
3. Implement Federated Averaging with learning rate decay on the clients.
  
  We could have the server store and broadcast more data. For example, the server could also store the client learning rate, and make it decay over time! Note that this will require changes to the type signatures used in the `tff.tf_computation` calls above.

For ideas (including the answer to the harder challenge above), you can see the source code for [`tff.learning.build_federated_averaging_process`](https://www.tensorflow.org/federated/api_docs/python/tff/learning/build_federated_averaging_process), or check out various [research projects](https://github.com/google-research/federated) using TFF.
