# **Setting up environment**

In [None]:
!pip install --upgrade --quiet tensorflow-federated

In [None]:
# nest_asyncio allows nested asyncio event loops. It is applied here, presumably,
# so TFF's runtime can run inside the notebook kernel's own event loop.
import nest_asyncio
nest_asyncio.apply()

In [None]:
# Load the TensorBoard notebook extension; it provides the %tensorboard magic used below.
%load_ext tensorboard

In [None]:
import collections

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

# Seeds NumPy's global RNG only. The tf.data `.shuffle()` used in `preprocess`
# draws from TensorFlow's RNG, which this call does not seed.
np.random.seed(0)

# Smoke test: build and invoke a trivial federated computation to confirm the
# TFF runtime works before doing anything else.
tff.federated_computation(lambda: 'Hello, World!')()

# **Input Data**

## **Loading dataset**

In [None]:
# Federated EMNIST train/test splits. Each split is accessed per client through
# client_ids and create_tf_dataset_for_client (used below).
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

In [None]:
# Number of client IDs in the training split.
len(emnist_train.client_ids)

In [None]:
# Type structure of a single training element; elements carry 'label' and
# 'pixels' entries, which the cells below read.
emnist_train.element_type_structure

### **Exploring data heterogeneity**

**Example data**

In [None]:
# Dataset of the first client in client_ids (the client at index 0 of the list).
example_dataset = emnist_train.create_tf_dataset_for_client(
    emnist_train.client_ids[0])

# iter() creates an iterator over the dataset; next() pulls its first element.
example_element = next(iter(example_dataset))

# Label of that first element as a NumPy value (this is the cell's displayed output).
example_element['label'].numpy()

In [None]:
from matplotlib import pyplot as plt
# Render the example element's 'pixels' tensor as a grayscale image.
plt.imshow(example_element['pixels'].numpy(), cmap='gray', aspect='equal')
plt.grid(False)
# Assigning the result to `_` keeps the cell from echoing a return value.
_ = plt.show()

**EMNIST digits from one client**

In [None]:
# Show the first 40 examples of the example client in a 4 x 10 grid.
figure = plt.figure(figsize=(20, 4))

for j, example in enumerate(example_dataset.take(40)):
  # Subplot positions are 1-based, hence j + 1.
  plt.subplot(4, 10, j + 1)
  # Grayscale image with square pixels.
  plt.imshow(example['pixels'].numpy(), cmap='gray', aspect='equal')
  plt.axis('off')

**Number of examples per label for each client**

In [None]:
# Number of examples per label for a sample of clients (the first 6). Each label
# is drawn by its own plt.hist call, so every label gets its own color.
f = plt.figure(figsize=(12, 7))
f.suptitle('Label Counts for a Sample of Clients')
for i in range(6):
  client_dataset = emnist_train.create_tf_dataset_for_client(
      emnist_train.client_ids[i])
  # label -> list containing that label once per example. defaultdict(list)
  # yields [] for labels this client never saw, so plot_data[j] below never
  # raises KeyError.
  plot_data = collections.defaultdict(list)
  for example in client_dataset:
    label = example['label'].numpy()
    plot_data[label].append(label)
  plt.subplot(2, 3, i+1)
  plt.title('Client {}'.format(i))
  plt.xlabel('Label')
  plt.ylabel('Count')
  for j in range(10):
    # Unit-width bins [0, 1), [1, 2), ..., [9, 10]: one bar per label.
    plt.hist(
        plot_data[j],
        density=False,
        bins=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])

**Mean pixel image for each label per client**

In [None]:
# Mean image of each label for the first 5 clients, one figure per client.
for i in range(5):
  client_dataset = emnist_train.create_tf_dataset_for_client(
      emnist_train.client_ids[i])
  # label -> list of that label's pixel arrays for this client.
  plot_data = collections.defaultdict(list)
  for example in client_dataset:
    plot_data[example['label'].numpy()].append(example['pixels'].numpy())
  f = plt.figure(i, figsize=(12, 5))
  f.suptitle("Client #{}'s Mean Image Per Label".format(i))
  for j in range(10):
    plt.subplot(2, 5, j+1)
    # A client may have no examples of label j. np.mean over an empty list gives a
    # scalar NaN that cannot be reshaped to 28x28, so leave that panel blank.
    if plot_data[j]:
      # Pixel-wise mean over all of this client's examples of label j.
      mean_img = np.mean(plot_data[j], 0)
      plt.imshow(mean_img.reshape((28, 28)))
    plt.axis('off')

## **Preprocessing input data**

tf.data.Dataset: 
https://www.tensorflow.org/api_docs/python/tf/data/Dataset

In [None]:
NUM_CLIENTS = 10
NUM_EPOCHS = 5
BATCH_SIZE = 20
SHUFFLE_BUFFER = 100
PREFETCH_BUFFER = 10

def preprocess(dataset):
  """Turns one client's raw EMNIST dataset into batched (x, y) training data.

  Pipeline stages, applied in this order:
    1. repeat(NUM_EPOCHS): iterate over the client's data NUM_EPOCHS times.
    2. shuffle(SHUFFLE_BUFFER): sample elements at random from a buffer of
       SHUFFLE_BUFFER elements. A perfect shuffle needs a buffer at least as
       large as the dataset.
    3. batch(BATCH_SIZE): combine consecutive elements into batches.
    4. map(...): flatten each batch into an OrderedDict(x, y).
    5. prefetch(PREFETCH_BUFFER): prepare later batches while the current one
       is being consumed.

  Args:
    dataset: a tf.data.Dataset whose elements have 'pixels' and 'label' entries.

  Returns:
    A tf.data.Dataset of OrderedDict(x=[batch, 784] pixels, y=[batch, 1] labels).
  """

  def to_features_and_labels(batch):
    # 784 = 28 * 28: each image becomes one flat row vector.
    flat_pixels = tf.reshape(batch['pixels'], [-1, 784])
    column_labels = tf.reshape(batch['label'], [-1, 1])
    # OrderedDict keeps the key order x, y.
    return collections.OrderedDict(x=flat_pixels, y=column_labels)

  repeated = dataset.repeat(NUM_EPOCHS)
  shuffled = repeated.shuffle(SHUFFLE_BUFFER)
  batched = shuffled.batch(BATCH_SIZE)
  formatted = batched.map(to_features_and_labels)
  return formatted.prefetch(PREFETCH_BUFFER)

In [None]:
# The example client's data run through the same pipeline used for training.
preprocessed_example_dataset = preprocess(example_dataset)

# Take the first preprocessed batch and convert every tensor in its structure to a
# NumPy array (tf.nest.map_structure applies the lambda to each leaf).
sample_batch = tf.nest.map_structure(lambda x: x.numpy(),
                                     next(iter(preprocessed_example_dataset)))

sample_batch

## **Converting client datasets into a list**

In [None]:
def make_federated_data(client_data, client_ids):
  """Builds the per-client inputs for a federated computation.

  Args:
    client_data: object providing create_tf_dataset_for_client(client_id),
      such as emnist_train or emnist_test.
    client_ids: iterable of client IDs to include.

  Returns:
    A list with one preprocessed dataset per ID, in the order of client_ids.
  """
  federated_data = []
  for client_id in client_ids:
    raw_dataset = client_data.create_tf_dataset_for_client(client_id)
    federated_data.append(preprocess(raw_dataset))
  return federated_data

# **Choosing Clients**

In [None]:
# The first NUM_CLIENTS client IDs: a fixed prefix of client_ids, not a random sample.
sample_clients = emnist_train.client_ids[0:NUM_CLIENTS]

# One preprocessed dataset per selected client; this same list is reused in every round.
federated_train_data = make_federated_data(emnist_train, sample_clients)

print('Number of client datasets: {l}'.format(l=len(federated_train_data)))
print('First dataset: {d}'.format(d=federated_train_data[0]))

# **Creating Model with Keras**

In [None]:
def create_keras_model():
  """Returns a new, uncompiled single-layer softmax classifier.

  784 inputs -> Dense(10) with a zero-initialized kernel -> softmax, i.e.
  multinomial logistic regression over the flattened images.
  """
  layers = [
      # Accepts flattened vectors of 784 values (the `x` produced by preprocess).
      tf.keras.layers.InputLayer(input_shape=(784,)),
      # output = dot(input, kernel) + bias, with the kernel starting at all zeros.
      tf.keras.layers.Dense(10, kernel_initializer='zeros'),
      tf.keras.layers.Softmax(),
  ]
  return tf.keras.models.Sequential(layers)

from_keras_model: https://www.tensorflow.org/federated/api_docs/python/tff/learning/from_keras_model

losses: https://www.tensorflow.org/api_docs/python/tf/keras/losses

metrics: https://www.tensorflow.org/api_docs/python/tf/keras/metrics

In [None]:
def model_fn():
  """Wraps a newly built Keras model as a tff.learning.Model.

  A new Keras model must be created on every call and not captured from an
  outer scope, because TFF calls this function within different graph contexts.
  """
  keras_model = create_keras_model()
  # The model consumes batches typed like the preprocessed example dataset.
  batch_spec = preprocessed_example_dataset.element_spec
  # Cross-entropy between integer labels and the predicted class probabilities.
  loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()
  # Frequency with which the predicted class matches the label.
  accuracy_metric = tf.keras.metrics.SparseCategoricalAccuracy()
  return tff.learning.from_keras_model(
      keras_model,
      input_spec=batch_spec,
      loss=loss_fn,
      metrics=[accuracy_metric])


# **Training Model on Federated Data**

build_federated_averaging_process: https://www.tensorflow.org/federated/api_docs/python/tff/learning/build_federated_averaging_process

Optimizers: https://www.tensorflow.org/api_docs/python/tf/keras/optimizers

In [None]:
# Iterative process that trains the Keras-based model with federated averaging.
iterative_process = tff.learning.build_federated_averaging_process(
    model_fn,  # No-arg function that builds a new tff.learning.Model (defined above).
    # Each client optimizes locally with SGD, learning rate 0.02.
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    # Optimizer the server uses to apply the aggregated client update.
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))

TFF has constructed a pair of federated computations and packaged them into a `tff.templates.IterativeProcess` in which these computations are available as a pair of properties `initialize` and `next`.

### `Initialize`

In [None]:
# Type signature of `initialize`, which takes no arguments (see the call below).
str(iterative_process.initialize.type_signature)

In [None]:
# Construct the initial server state that `next` consumes and updates.
state = iterative_process.initialize()

### `Next`

In [None]:
# One round of federated training. `next` takes the current server state (here the
# one produced by `initialize`) and the clients' datasets, and returns the updated
# state together with the round's metrics.
state, metrics = iterative_process.next(state, federated_train_data)
print('round  1, metrics={}'.format(metrics))

In [None]:
# Rounds 2 .. NUM_ROUNDS - 1 (round 1 ran above), for NUM_ROUNDS - 1 = 10 rounds in total.
NUM_ROUNDS = 11
for round_num in range(2, NUM_ROUNDS):
  # The same clients are used every round, since federated_train_data never changes.
  state, metrics = iterative_process.next(state, federated_train_data)
  print('round {:2d}, metrics={}'.format(round_num, metrics))

# **Displaying model metrics in TensorBoard**

In [None]:
#@test {"skip": true}
# Directory where this run's TensorBoard event files are written.
logdir = "/tmp/logs/scalars/training/"
summary_writer = tf.summary.create_file_writer(logdir)
# Reset to a freshly initialized server state so the logged run starts from scratch.
state = iterative_process.initialize()

In [None]:
#@test {"skip": true}
# Make summary_writer the default writer for the tf.summary.scalar calls in this block.
with summary_writer.as_default():
  # Retrain for NUM_ROUNDS - 1 rounds from the reset state, logging every round.
  for round_num in range(1, NUM_ROUNDS):
    state, metrics = iterative_process.next(state, federated_train_data)
    # metrics['train'] maps metric name -> value, e.g.
    # OrderedDict([('sparse_categorical_accuracy', 0.6872428), ('loss', 1.0891807)])
    for name, value in metrics['train'].items():
      # One scalar point per metric, indexed by the round number.
      tf.summary.scalar(name, value, step=round_num)

In [None]:
#@test {"skip": true}
# List the event files written above, then open TensorBoard on that directory
# (--port=0 asks for any free port).
!ls {logdir}
%tensorboard --logdir {logdir} --port=0

In [None]:
# @test {"skip": true}
# Uncomment and run this cell to clean your directory of old output for
# future graphs from this directory. We don't run it by default so that if
# you do a "Runtime > Run all" you don't lose your results.

#!rm -R /tmp/logs/scalars/*

# **Customizing the Model Implementation**

## **Defining model variables, forward pass, and metrics**


In [None]:
 MnistVariables = collections.namedtuple(                                         # function for creating tuple subclasses with named fields
    'MnistVariables', 'weights bias num_examples loss_sum accuracy_sum')          

### Function for creating variables


In [None]:
def create_mnist_variables():
  """Creates the tf.Variables backing MnistModel.

  Returns:
    An MnistVariables tuple with
      weights: trainable float32 (784, 10) matrix, initialized to zeros;
      bias: trainable float32 vector of 10 zeros;
      num_examples, loss_sum, accuracy_sum: non-trainable float32 scalar
        accumulators starting at 0.0.
  """

  def zeros_param(shape, name):
    # The initial value is supplied as a callable producing a zero tensor.
    return tf.Variable(
        lambda: tf.zeros(dtype=tf.float32, shape=shape),
        name=name,
        trainable=True)

  def accumulator(name):
    return tf.Variable(0.0, name=name, trainable=False)

  return MnistVariables(
      weights=zeros_param((784, 10), 'weights'),
      bias=zeros_param(10, 'bias'),
      num_examples=accumulator('num_examples'),
      loss_sum=accumulator('loss_sum'),
      accuracy_sum=accumulator('accuracy_sum'))

### Function for predicting labels, computing loss and accuracy in a forward pass

In [None]:
def mnist_forward_pass(variables, batch):
  """Runs the model on one batch and accumulates the local metric sums.

  Args:
    variables: an MnistVariables holding the weights/bias and metric accumulators.
    batch: an OrderedDict with 'x' (flattened pixels, [batch, 784]) and
      'y' (integer labels, [batch, 1]), as produced by `preprocess`.

  Returns:
    (loss, predictions): the mean cross-entropy over the batch and the int32
    predicted label of each example.

  Side effects:
    Adds the batch's example count, loss * count and accuracy * count to
    variables.num_examples, variables.loss_sum and variables.accuracy_sum.
  """
  logits = tf.matmul(batch['x'], variables.weights) + variables.bias
  y = tf.nn.softmax(logits)
  # Predicted label = index of the largest probability in each row.
  predictions = tf.cast(tf.argmax(y, 1), tf.int32)

  # [batch, 1] labels -> [batch] vector.
  flat_labels = tf.reshape(batch['y'], [-1])
  # Cross-entropy: -sum_k one_hot(label)_k * log p_k, averaged over the batch.
  # log_softmax(logits) replaces log(softmax(logits)). If a probability underflows
  # to 0, log(0) = -inf and the one-hot zero times -inf gives NaN, poisoning the loss.
  loss = -tf.reduce_mean(
      tf.reduce_sum(
          tf.one_hot(flat_labels, 10) * tf.nn.log_softmax(logits), axis=[1]))
  # Fraction of examples in the batch whose prediction equals the label.
  accuracy = tf.reduce_mean(
      tf.cast(tf.equal(predictions, flat_labels), tf.float32))

  num_examples = tf.cast(tf.size(batch['y']), tf.float32)

  # Example-weighted running sums, later divided by num_examples in
  # get_local_mnist_metrics.
  variables.num_examples.assign_add(num_examples)
  variables.loss_sum.assign_add(loss * num_examples)
  variables.accuracy_sum.assign_add(accuracy * num_examples)

  return loss, predictions

### Function for returning local metrics

In [None]:
def get_local_mnist_metrics(variables):
  """Returns this client's metrics as OrderedDict(num_examples, loss, accuracy).

  loss and accuracy are example-weighted averages: the accumulated sums divided
  by the number of examples seen. tf.math.divide_no_nan returns 0 instead of NaN
  when no examples have been seen (num_examples == 0). A NaN here could otherwise
  propagate into the num_examples-weighted federated mean of the metrics.
  """
  return collections.OrderedDict(
      num_examples=variables.num_examples,
      loss=tf.math.divide_no_nan(variables.loss_sum, variables.num_examples),
      accuracy=tf.math.divide_no_nan(variables.accuracy_sum,
                                     variables.num_examples))

### Function for aggregating local metrics

In [None]:
@tff.federated_computation
def aggregate_mnist_metrics_across_clients(metrics):
  """Combines the clients' local metrics into server-side metrics.

  num_examples is summed over clients (tff.federated_sum); loss and accuracy are
  averaged with each client weighted by its num_examples (tff.federated_mean).
  """
  client_weights = metrics.num_examples
  total_examples = tff.federated_sum(client_weights)
  weighted_loss = tff.federated_mean(metrics.loss, client_weights)
  weighted_accuracy = tff.federated_mean(metrics.accuracy, client_weights)
  return collections.OrderedDict(
      num_examples=total_examples,
      loss=weighted_loss,
      accuracy=weighted_accuracy)

### Constructing an instance of `tff.learning.Model`

In [None]:
class MnistModel(tff.learning.Model):
  """Hand-written tff.learning.Model: softmax regression on flattened EMNIST.

  All state lives in an MnistVariables tuple: trainable weights/bias plus local,
  non-trainable metric accumulators.
  """

  def __init__(self):
    self._variables = create_mnist_variables()

  @property
  def trainable_variables(self):
    """Variables updated by training: [weights, bias]."""
    return [self._variables.weights, self._variables.bias]

  @property
  def non_trainable_variables(self):
    """This model has no non-trainable model variables."""
    return []

  @property
  def local_variables(self):
    """Per-client metric accumulators: num_examples, loss_sum, accuracy_sum."""
    return [
        self._variables.num_examples, self._variables.loss_sum,
        self._variables.accuracy_sum
    ]

  @property
  def input_spec(self):
    """Type of one batch; matches the OrderedDict(x, y) produced by preprocess."""
    return collections.OrderedDict(
        x=tf.TensorSpec([None, 784], tf.float32),
        y=tf.TensorSpec([None, 1], tf.int32))

  @tf.function
  def forward_pass(self, batch, training=True):
    """Runs mnist_forward_pass on `batch` and packages the result as a BatchOutput."""
    # The forward pass is identical in training and evaluation, so the flag is unused.
    del training
    loss, predictions = mnist_forward_pass(self._variables, batch)
    num_examples = tf.shape(batch['x'])[0]
    return tff.learning.BatchOutput(
        loss=loss, predictions=predictions, num_examples=num_examples)

  @tf.function
  def report_local_outputs(self):
    """Returns this client's metrics computed from the local accumulators."""
    return get_local_mnist_metrics(self._variables)

  @property
  def federated_output_computation(self):
    """Federated computation that aggregates the clients' local metrics."""
    return aggregate_mnist_metrics_across_clients

### Simulating federated training with the new model

In [None]:
# Federated averaging again, this time built from the custom MnistModel class
# (the class itself serves as the no-arg model constructor).
iterative_process = tff.learning.build_federated_averaging_process(
    MnistModel,
    # server_optimizer_fn is not passed here, so the library's default server
    # optimizer is used (the Keras run above set SGD(learning_rate=1.0) explicitly).
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02))

In [None]:
# Fresh server state for the custom-model process.
state = iterative_process.initialize()

In [None]:
# Train the custom model for the same number of rounds as the Keras model
# (rounds 1 .. NUM_ROUNDS - 1), on the same clients every round.
for round_num in range(1, NUM_ROUNDS):
  state, metrics = iterative_process.next(state, federated_train_data)
  print('round {:2d}, metrics={}'.format(round_num, metrics))

In [None]:
#@test {"skip": true}
# Write this model's summaries to a separate directory. Reusing the Keras run's
# "/tmp/logs/scalars/training/" would mix both runs' event files (including their
# shared 'loss' tag) into one TensorBoard run. The cleanup cell's
# /tmp/logs/scalars/* pattern still covers this directory.
logdir = "/tmp/logs/scalars/custom_model/"
summary_writer = tf.summary.create_file_writer(logdir)
# Reset to a freshly initialized server state so the logged run starts from scratch.
state = iterative_process.initialize()

In [None]:
#@test {"skip": true}
# Make summary_writer the default writer for the tf.summary.scalar calls in this block.
with summary_writer.as_default():
  # Retrain for NUM_ROUNDS - 1 rounds from the reset state, logging every round.
  for round_num in range(1, NUM_ROUNDS):
    state, metrics = iterative_process.next(state, federated_train_data)
    # For MnistModel, metrics['train'] is presumably the output of
    # aggregate_mnist_metrics_across_clients, with keys num_examples, loss and
    # accuracy; confirm against the printed metrics above.
    for name, value in metrics['train'].items():
      # One scalar point per metric, indexed by the round number.
      tf.summary.scalar(name, value, step=round_num)

In [None]:
#@test {"skip": true}
# List the event files written above, then open TensorBoard on that directory
# (--port=0 asks for any free port).
!ls {logdir}
%tensorboard --logdir {logdir} --port=0

In [None]:
# @test {"skip": true}
# Uncomment and run this cell to clean your directory of old output for
# future graphs from this directory. We don't run it by default so that if
# you do a "Runtime > Run all" you don't lose your results.

#!rm -R /tmp/logs/scalars/*

### Evaluation

In [None]:
# Federated evaluation computation for MnistModel (called below with the trained
# server model and a list of client datasets).
evaluation = tff.learning.build_federated_evaluation(MnistModel)

In [None]:
# Type signature of the evaluation computation.
str(evaluation.type_signature)

In [None]:
# Evaluate the trained server model (state.model) on the training clients' data.
train_metrics = evaluation(state.model, federated_train_data)
str(train_metrics)

### Evaluation of test dataset

In [None]:
# Test data of the same clients used in training, preprocessed identically.
# NOTE(review): preprocess() repeats every dataset NUM_EPOCHS times (and shuffles
# it), so evaluation visits each test example NUM_EPOCHS times. The reported
# num_examples is inflated NUM_EPOCHS-fold and the work is repeated; the
# example-weighted loss/accuracy averages are unaffected.
federated_test_data = make_federated_data(emnist_test, sample_clients)
len(federated_test_data), federated_test_data[0]

In [None]:
# Evaluate the trained server model on the same clients' held-out test data.
test_metrics = evaluation(state.model, federated_test_data)
str(test_metrics)