# **Setting up the environment**

In [None]:
#@test {"skip": true}
!pip install --quiet --upgrade tensorflow-federated-nightly
!pip install --quiet --upgrade nest-asyncio
  
import nest_asyncio
nest_asyncio.apply()

In [None]:
import collections

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

# TODO(b/148678573,b/148685415): must use the reference context because it
# supports unbounded references and tff.sequence_* intrinsics.
tff.backends.reference.set_reference_context()

In [None]:
@tff.federated_computation
def hello_world():
  return 'Hello, World!'

hello_world()

# **Implementing Federated Averaging**


## **Preparing federated data sets**

We have data from 10 users, and each user contributes knowledge of how to recognize a different digit. This is about as non-i.i.d. as it gets.

In [None]:
mnist_train, mnist_test = tf.keras.datasets.mnist.load_data()                     # loading the standard MNIST data

In [None]:
# The data comes as NumPy arrays: one with images and another with digit labels,
# both with the first dimension ranging over the individual examples.
[(x.dtype, x.shape) for x in mnist_train]

Next, we define a helper function that formats the data the way we feed federated sequences into TFF computations: as a list of lists, where the outer list ranges over the users (digits) and the inner lists range over the batches of data in each client's sequence. As is customary, each batch is structured as a pair of tensors named `x` and `y`, each with a leading batch dimension.

In [None]:
NUM_EXAMPLES_PER_USER = 1000
BATCH_SIZE = 100

def get_data_for_digit(source, digit):
  output_sequence = []
  all_samples = [i for i, d in enumerate(source[1]) if d == digit]                # enumerate() yields (index, label) pairs; keep the indices
                                                                                  # of all samples whose label equals `digit`
  for i in range(0, min(len(all_samples), NUM_EXAMPLES_PER_USER), BATCH_SIZE):    # Iterate over at most NUM_EXAMPLES_PER_USER samples in steps of BATCH_SIZE
    batch_samples = all_samples[i:i + BATCH_SIZE]                                 # Indices of the samples in the current batch
    output_sequence.append({                                                      # Append the current batch to `output_sequence`
        'x':
            np.array([source[0][j].flatten() / 255.0 for j in batch_samples],     # 'x': each 28x28 image flattened to a 784-vector, scaled to [0, 1]
                     dtype=np.float32),                                           # cast to float32
        'y':
            np.array([source[1][j] for j in batch_samples], dtype=np.int32)       # 'y': the labels of the samples in `batch_samples`, cast to int32
    })
  return output_sequence                                                          # output_sequence = [batch_1, batch_2, ...] for `digit`,
                                                                                  # where batch_k = {'x': [pixel_vector_1, ...], 'y': [label_1, ...]}
federated_train_data = [get_data_for_digit(mnist_train, d) for d in range(10)]    # federated_train_data = [output_sequence for 0, output_sequence for 1, ...]
                                                                                  # federated_train_data[a][b][c] holds up to BATCH_SIZE entries: a = digit (client), b = batch number, c = 'x' or 'y'
federated_test_data = [get_data_for_digit(mnist_test, d) for d in range(10)]

In [None]:
federated_train_data[5][-1]['y']                                                  # The `y` tensor in the last batch of data contributed by client 5 (the client holding digit `5`).

In [None]:
# The image corresponding to the last element of that batch.
from matplotlib import pyplot as plt

plt.imshow(federated_train_data[5][-1]['x'][-1].reshape(28, 28), cmap='gray')
plt.grid(False)
plt.show()

## **On Combining TensorFlow and TFF**

In this tutorial, for compactness we immediately decorate functions that
introduce TensorFlow logic with `tff.tf_computation`. However, for more complex logic, this is not the pattern we recommend. Debugging TensorFlow can already be a challenge, and debugging TensorFlow after it has been fully serialized and then re-imported necessarily loses some metadata and limits interactivity, making debugging even more of a challenge.

Therefore, **we strongly recommend writing complex TF logic as stand-alone
Python functions** (that is, without `tff.tf_computation` decoration). This way the TensorFlow logic can be developed and tested using TF best practices and tools (like eager mode), before serializing the computation for TFF (e.g., by invoking `tff.tf_computation` with a Python function as the argument).

## **Defining a loss function**

### Defining the type of input as a TFF named tuple

In [None]:
BATCH_SPEC = collections.OrderedDict(
    x=tf.TensorSpec(shape=[None, 784], dtype=tf.float32),                         # Since the size of data batches may vary, batch dimension is set to 
                                                                                  #`None` to indicate that the size of this dimension is unknown.
    y=tf.TensorSpec(shape=[None], dtype=tf.int32))
BATCH_TYPE = tff.to_type(BATCH_SPEC)

str(BATCH_TYPE)
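
The batch dimension has to be `None` because `get_data_for_digit` can emit a shorter final batch whenever a client has fewer than `NUM_EXAMPLES_PER_USER` samples. As a small sketch, the hypothetical helper `batch_sizes` below replays the same slicing on sample indices only (no images), with defaults mirroring 1000 examples per user and batches of 100; the sample counts passed to it are illustrative, not taken from MNIST.

```python
def batch_sizes(num_samples, max_examples=1000, batch_size=100):
  """Hypothetical helper: sizes of the batches get_data_for_digit would emit.

  Replays the loop in get_data_for_digit on a list of sample indices. The
  defaults mirror NUM_EXAMPLES_PER_USER and BATCH_SIZE.
  """
  all_samples = list(range(num_samples))
  return [len(all_samples[i:i + batch_size])
          for i in range(0, min(num_samples, max_examples), batch_size)]

print(batch_sizes(5000))  # capped at 1000 examples: ten batches of 100
print(batch_sizes(892))   # under the cap: eight batches of 100, then one of 92
```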

### Defining the model parameters
The model parameters are defined as a TFF named tuple of *weights* and *bias*.

In [None]:
MODEL_SPEC = collections.OrderedDict(
    weights=tf.TensorSpec(shape=[784, 10], dtype=tf.float32),
    bias=tf.TensorSpec(shape=[10], dtype=tf.float32))
MODEL_TYPE = tff.to_type(MODEL_SPEC)

print(MODEL_TYPE)

### Loss for the given model

Note the usage of the `@tf.function` decorator inside the `@tff.tf_computation` decorator. This allows us to write TF code with Python-like semantics even though we're inside a `tf.Graph` context created by the `tff.tf_computation` decorator.

In [None]:
# NOTE: `forward_pass` is defined separately from `batch_loss` so that it can 
# be later called from within another tf.function. Necessary because a
# @tf.function decorated method cannot invoke a @tff.tf_computation.

@tf.function
def forward_pass(model, batch):
  predicted_y = tf.nn.softmax(                                                    # softmax over the 10 class scores
      tf.matmul(batch['x'], model['weights']) + model['bias'])                    # scores = xW + b
  return -tf.reduce_mean(                                                         # cross-entropy, averaged over the batch
      tf.reduce_sum(
          tf.one_hot(batch['y'], 10) * tf.math.log(predicted_y), axis=[1]))

@tff.tf_computation(MODEL_TYPE, BATCH_TYPE)
def batch_loss(model, batch):
  return forward_pass(model, batch)

str(batch_loss.type_signature)

### Constructing an initial model

In [None]:
initial_model = collections.OrderedDict(                                          # Initialize all weights and biases to 0
    weights=np.zeros([784, 10], dtype=np.float32),
    bias=np.zeros([10], dtype=np.float32))

sample_batch = federated_train_data[5][-1]

batch_loss(initial_model, sample_batch)
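
A quick way to sanity-check the value returned above: with all-zero weights and bias every score is 0, so the softmax assigns probability 0.1 to each of the 10 classes and the cross-entropy should be -log(0.1) = ln 10 ≈ 2.3026, whatever the batch contains. The NumPy sketch below re-implements the loss of `forward_pass` outside TFF to check this; the name `numpy_forward_pass` and the random stand-in batch are assumptions made for illustration.

```python
import numpy as np

def numpy_forward_pass(model, batch):
  """Hypothetical NumPy re-implementation of forward_pass (mean cross-entropy)."""
  scores = batch['x'] @ model['weights'] + model['bias']  # shape [batch, 10]
  # Subtract the row maximum before exponentiating, for numerical stability.
  scores = scores - scores.max(axis=1, keepdims=True)
  predicted_y = np.exp(scores) / np.exp(scores).sum(axis=1, keepdims=True)
  one_hot = np.eye(10, dtype=np.float32)[batch['y']]
  return -np.mean(np.sum(one_hot * np.log(predicted_y), axis=1))

zero_model = {
    'weights': np.zeros([784, 10], dtype=np.float32),
    'bias': np.zeros([10], dtype=np.float32),
}
rng = np.random.default_rng(0)
fake_batch = {
    'x': rng.random([100, 784], dtype=np.float32),  # stand-in pixels in [0, 1)
    'y': np.full([100], 5, dtype=np.int32),          # every label is 5
}
print(numpy_forward_pass(zero_model, fake_batch), np.log(10.0))
```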

The arguments of the call to `batch_loss` aren't simply passed to the body of that function.

**What happens when we invoke `batch_loss`?**

The Python body of `batch_loss` has already been traced and serialized in the cell above where it was defined. TFF acts as the caller to `batch_loss` at computation definition time, and as the target of invocation when `batch_loss` is invoked. In both roles, TFF serves as the bridge between TFF's abstract type system and Python representation types. At invocation time, TFF accepts most standard Python container types (`dict`, `list`, `tuple`, `collections.namedtuple`, etc.) as concrete representations of abstract TFF tuples. Also, although TFF computations formally accept only a single parameter, you can use the familiar Python call syntax with positional and/or keyword arguments in cases where the type of the parameter is a tuple; it works as expected.

### Gradient descent on a single batch

Next, we define a function that performs a single step of gradient descent.

Note how in defining this function, we reuse the loss logic as a subcomponent: `_train_on_batch` calls `forward_pass`, the `tf.function` that `batch_loss` wraps. You can invoke a computation constructed with `tff.tf_computation` inside the body of another computation, though typically this is not necessary; as noted above, because serialization loses some debugging information, it is often preferable for more complex computations to write and test all the TensorFlow without the `tff.tf_computation` decorator.

In [None]:
@tff.tf_computation(MODEL_TYPE, BATCH_TYPE, tf.float32)
def batch_train(initial_model, batch, learning_rate):
  # Define a group of model variables and set them to `initial_model`.
  # The variables must be created outside the @tf.function below.
  model_vars = collections.OrderedDict([                                          # Same names as in `initial_model`,
      (name, tf.Variable(name=name, initial_value=value))                         # each variable initialized to the matching value
      for name, value in initial_model.items()])

  optimizer = tf.keras.optimizers.SGD(learning_rate)

  @tf.function
  def _train_on_batch(model_vars, batch):                                         # Performs one step of gradient descent using the loss from `forward_pass`.
    with tf.GradientTape() as tape:                                               # The tape records operations on the trainable variables
      loss = forward_pass(model_vars, batch)
    grads = tape.gradient(loss, model_vars)                                       # Gradient of `loss` with respect to each variable in `model_vars`
    optimizer.apply_gradients(                                                    # Applies the (gradient, variable) pairs to update the variables
        zip(tf.nest.flatten(grads), tf.nest.flatten(model_vars)))                 # tf.nest.flatten() flattens a nested structure, e.g. [[a, b], [c]] -> [a, b, c]
    return model_vars

  return _train_on_batch(model_vars, batch)
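
For intuition about what `_train_on_batch` computes: for this softmax model the gradient of the mean cross-entropy has a closed form, `dL/dW = x^T (p - onehot(y)) / n` and `dL/db = mean(p - onehot(y))`, where `p` are the predicted probabilities and `n` is the batch size. `tf.keras.optimizers.SGD` with its default (zero) momentum subtracts `learning_rate` times this gradient. The NumPy sketch below applies five such steps to one batch, mirroring the loop further down; the helper names and the random stand-in batch (small pixel values, every label 5, like `sample_batch`) are assumptions.

```python
import numpy as np

def softmax(scores):
  shifted = scores - scores.max(axis=1, keepdims=True)  # numerical stability
  exp = np.exp(shifted)
  return exp / exp.sum(axis=1, keepdims=True)

def mean_cross_entropy(model, batch):
  p = softmax(batch['x'] @ model['weights'] + model['bias'])
  return -np.mean(np.log(p[np.arange(len(batch['y'])), batch['y']]))

def numpy_sgd_step(model, batch, learning_rate):
  """Hypothetical helper: one plain-SGD step on the mean cross-entropy."""
  x, y = batch['x'], batch['y']
  n = x.shape[0]
  # Gradient of the loss w.r.t. the scores, one row per example (before 1/n).
  err = softmax(x @ model['weights'] + model['bias']) - np.eye(10)[y]
  return {
      'weights': model['weights'] - learning_rate * (x.T @ err) / n,
      'bias': model['bias'] - learning_rate * err.mean(axis=0),
  }

rng = np.random.default_rng(0)
toy_batch = {'x': 0.1 * rng.random((100, 784)), 'y': np.full(100, 5)}
toy_model = {'weights': np.zeros((784, 10)), 'bias': np.zeros(10)}
toy_losses = []
for _ in range(5):
  toy_model = numpy_sgd_step(toy_model, toy_batch, 0.1)
  toy_losses.append(mean_cross_entropy(toy_model, toy_batch))
print(toy_losses)  # decreasing values, all below the initial ln(10)
```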

In [None]:
str(batch_train.type_signature)

When you invoke a Python function decorated with `tff.tf_computation` within the
body of another such function, the logic of the inner TFF computation is
embedded (essentially, inlined) in the logic of the outer one. As noted above,
if you are writing both computations, it is likely preferable to make the inner
function a regular Python function or `tf.function` rather than a
`tff.tf_computation`; this is why `batch_train` calls `forward_pass` rather
than `batch_loss`. However, calling one `tff.tf_computation` inside another
basically works as expected. This may be necessary if, for example, you do not
have the Python code defining `batch_loss`, but only its serialized TFF
representation.

Now, let's apply this function a few times to the initial model to see whether
the loss decreases.

In [None]:
model = initial_model
losses = []
for _ in range(5):
  model = batch_train(model, sample_batch, 0.1)
  losses.append(batch_loss(model, sample_batch))

In [None]:
losses

### Gradient descent on a sequence of local data

Next, we define `local_train`, which consumes the entire sequence of batches
from one user instead of just a single batch. The new computation needs to
consume `tff.SequenceType(BATCH_TYPE)` instead of `BATCH_TYPE`.

In [None]:
LOCAL_DATA_TYPE = tff.SequenceType(BATCH_TYPE)                                    # A sequence whose elements are of type BATCH_TYPE

@tff.federated_computation(MODEL_TYPE, tf.float32, LOCAL_DATA_TYPE)
def local_train(initial_model, learning_rate, all_batches):
  @tff.federated_computation(MODEL_TYPE, BATCH_TYPE)
  def batch_fn(model, batch):                                                     # Reduction operator: processes one batch
    return batch_train(model, batch, learning_rate)                               # One gradient step of `batch_train` on this batch

  return tff.sequence_reduce(all_batches, initial_model, batch_fn)                # https://www.tensorflow.org/federated/api_docs/python/tff/sequence_reduce
                                                                                  # Applies `batch_fn` to each element of `all_batches` in turn, threading the model through

In [None]:
str(local_train.type_signature)

There are quite a few details buried in this short section of code; let's go
over them one by one.

*   While we could have implemented this logic entirely in TensorFlow, relying on `tf.data.Dataset.reduce` to process the sequence, we've opted this time to express the logic in the glue
language, as a `tff.federated_computation`. We've used the federated operator
`tff.sequence_reduce` to perform the reduction.

  The operator `tff.sequence_reduce` is used similarly to `tf.data.Dataset.reduce`, but inside federated computations. It is a template operator whose formal parameter is a 3-tuple:
  *   *sequence* of `T`-typed elements
  *   the initial state of the reduction, of some type `U`
  *   the *reduction operator* of type `(<U,T> -> U)` that alters the state of the reduction by processing a single element

  The result is the final state of the reduction, after processing all elements in sequential order.

*   We have again used one computation (`batch_train`) as a component within another (`local_train`), but not directly. We **can't use it directly as a reduction operator, because it takes an additional parameter: the learning rate.**

  To resolve this, we define an embedded federated computation `batch_fn` that binds to `local_train`'s parameter `learning_rate` in its body. A child computation defined this way is allowed to capture a formal parameter of its parent as long as the child computation is not invoked outside the body of its parent. You can think of this pattern as an equivalent of `functools.partial` in Python.

  The practical implication of capturing `learning_rate` this way is, of course, that the same learning rate value is used across all batches.
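
In plain Python terms, `tff.sequence_reduce(all_batches, initial_model, batch_fn)` folds the sequence left to right the way `functools.reduce(batch_fn, all_batches, initial_model)` does, and `batch_fn` plays the role of `functools.partial(batch_train, learning_rate=...)`. The sketch below shows that shape with a toy scalar "model" and a hypothetical `toy_batch_train` update rule (one gradient step on the squared distance to the batch mean); it only illustrates the control flow, not how TFF executes the computation.

```python
import functools

def toy_batch_train(model, batch, learning_rate):
  """Hypothetical stand-in for batch_train: one gradient step on a scalar model.

  The loss is (model - mean(batch))**2, whose gradient is 2 * (model - mean(batch)).
  """
  target = sum(batch) / len(batch)
  return model - learning_rate * 2.0 * (model - target)

def toy_local_train(initial_model, learning_rate, all_batches):
  # Bind learning_rate once, like batch_fn capturing local_train's parameter.
  batch_fn = functools.partial(toy_batch_train, learning_rate=learning_rate)
  # reduce(f, seq, init) computes f(...f(f(init, seq[0]), seq[1])..., seq[-1]),
  # matching the (<U,T> -> U) reduction operator of tff.sequence_reduce.
  return functools.reduce(batch_fn, all_batches, initial_model)

toy_batches = [[4.0, 6.0], [5.0, 5.0], [3.0, 7.0]]  # every batch has mean 5.0
print(toy_local_train(0.0, 0.25, toy_batches))     # 0.0 -> 2.5 -> 3.75 -> 4.375
```

Because the same `learning_rate` is bound once before the fold, every batch is processed with the same step size, just as in `local_train`.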

In [None]:
locally_trained_model = local_train(initial_model, 0.1, federated_train_data[5])
print(locally_trained_model)

## **Evaluation**

### Local evaluation

In [None]:
@tff.federated_computation(MODEL_TYPE, LOCAL_DATA_TYPE)
def local_eval(model, all_batches):
  # Computes the sum of the losses over all the batches.
  # TODO(b/120157713): Replace with `tff.sequence_average()` once implemented.
  return tff.sequence_sum(                                                        # Sums the elements of a sequence
      tff.sequence_map(                                                           # Maps the sequence `all_batches` pointwise with the given function
          tff.federated_computation(lambda b: batch_loss(model, b), BATCH_TYPE),  # Loss of each batch under the given `model`
          all_batches))

In [None]:
str(local_eval.type_signature)


1.   We have used two new federated operators for processing sequences:
  *  `tff.sequence_map` that takes a *mapping function* `T->U` and a *sequence* of `T`, and emits a sequence of `U` obtained by applying the mapping function pointwise
  *   `tff.sequence_sum` that just adds all the elements

  Note that **we could have again used `tff.sequence_reduce`, but this wouldn't be the best choice: the reduction process is, by definition, sequential, whereas the mapping and sum can be computed in parallel.**

2.   Just as in `local_train`, the component function we need
(`batch_loss`) takes more parameters than what the federated operator
(`tff.sequence_map`) expects, so we again define a partial, this time inline by directly wrapping a `lambda` as a `tff.federated_computation`. Using wrappers inline with a function as an argument is the recommended way to use
`tff.tf_computation` to embed TensorFlow logic in TFF.
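
The difference between the two formulations can be sketched in plain Python: `tff.sequence_map` followed by `tff.sequence_sum` corresponds to `sum(map(...))`, where each mapped call is independent of the others (and so could run in parallel), whereas a reduction threads an accumulator through every step. The toy per-batch loss `toy_batch_loss` below is hypothetical; `functools.partial` plays the role of the inline `lambda` wrapper. Both formulations give the same total.

```python
import functools

def toy_batch_loss(model, batch):
  """Hypothetical stand-in for batch_loss: mean squared error of a scalar model."""
  return sum((model - v) ** 2 for v in batch) / len(batch)

def map_then_sum(model, all_batches):
  # Like tff.sequence_sum(tff.sequence_map(...)): each term is independent.
  return sum(map(functools.partial(toy_batch_loss, model), all_batches))

def sequential_reduce(model, all_batches):
  # Like tff.sequence_reduce: each step depends on the previous accumulator.
  return functools.reduce(
      lambda acc, batch: acc + toy_batch_loss(model, batch), all_batches, 0.0)

eval_batches = [[4.0, 6.0], [5.0, 5.0], [3.0, 7.0]]
print(map_then_sum(5.0, eval_batches), sequential_reduce(5.0, eval_batches))  # 5.0 5.0
```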

In [None]:
print('initial_model loss =', local_eval(initial_model,
                                         federated_train_data[5]))
print('locally_trained_model loss =',
      local_eval(locally_trained_model, federated_train_data[5]))

In [None]:
print('initial_model loss =', local_eval(initial_model,
                                         federated_train_data[0]))
print('locally_trained_model loss =',
      local_eval(locally_trained_model, federated_train_data[0]))

The loss decreased for client 5, whose data the model was trained on, but not for client 0.

### Federated Evaluation
We need a pair of TFF types: one for the model, which originates at the server, and one for the data, which remains on the clients.

In [None]:
SERVER_MODEL_TYPE = tff.type_at_server(MODEL_TYPE)
CLIENT_DATA_TYPE = tff.type_at_clients(LOCAL_DATA_TYPE)

Distribute the model to clients, let each client invoke
local evaluation on its local portion of data, and then average out the loss.

In [None]:
@tff.federated_computation(SERVER_MODEL_TYPE, CLIENT_DATA_TYPE)
def federated_eval(model, data):
  return tff.federated_mean(
      tff.federated_map(local_eval, [tff.federated_broadcast(model), data]))      # Runs `local_eval` on every client in parallel:
                                                                                  # each client evaluates the broadcast model on its local data.
                                                                                  # tff.federated_broadcast() broadcasts a value from tff.SERVER to tff.CLIENTS.

1.   *let each client invoke local evaluation on its local portion of data* 

  `local_eval` has a type signature of the form `(<MODEL_TYPE, LOCAL_DATA_TYPE> -> float32)`. The federated operator `tff.federated_map` is a template that accepts as a parameter a 2-tuple consisting of:
  *   the *mapping function* of some type `T->U`
  *   a federated value of type `{T}@CLIENTS` (i.e., with member constituents of the same type as the parameter of the mapping function)

  and returns a result of type `{U}@CLIENTS`.

  The second argument should be of a federated type `{<MODEL_TYPE, LOCAL_DATA_TYPE>}@CLIENTS`, i.e., it should be a federated tuple. Each client should hold a full set of arguments for `local_eval` as a member constituent. Instead, we're feeding it a 2-element Python `list`. What's happening here?

Indeed, this is an example of an *implicit type cast* in TFF. Implicit casting is used sparingly at this point, but we plan to make it more pervasive in TFF as a way to minimize boilerplate.

The implicit cast that's applied in this case is the equivalence between federated tuples of the form `{<X,Y>}@Z` and tuples of federated values `<{X}@Z,{Y}@Z>`. While formally these two are different type signatures, from the programmer's perspective each device in `Z` holds two units of data, `X` and `Y`. What happens here is not unlike `zip` in Python, and indeed, we offer an operator `tff.federated_zip` that allows you to perform such conversions explicitly. When `tff.federated_map` encounters a tuple as its second argument, it simply invokes `tff.federated_zip` for you.

Given the above, you should now be able to recognize the expression
`tff.federated_broadcast(model)` as representing a value of TFF type
`{MODEL_TYPE}@CLIENTS`, and `data` as a value of TFF type
`{LOCAL_DATA_TYPE}@CLIENTS` (or simply `CLIENT_DATA_TYPE`), the two getting zipped together through an implicit `tff.federated_zip` to form the second argument to `tff.federated_map`.
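
To make the placement semantics concrete, here is a plain-Python simulation of `federated_eval`, with a Python list standing in for the values held at `tff.CLIENTS`. Broadcasting replicates the server value once per client, the built-in `zip` plays the role of the implicit `tff.federated_zip`, mapping applies the function on each client, and the mean is an unweighted average over clients. All names here are hypothetical; this is only a model of the semantics, not how TFF executes it.

```python
def simulated_broadcast(server_value, num_clients):
  """{T}@SERVER -> {T}@CLIENTS: every client receives the same value."""
  return [server_value] * num_clients

def simulated_map(fn, client_values):
  """Applies fn on each client; each zipped tuple is unpacked into fn's arguments."""
  return [fn(*args) for args in client_values]

def simulated_mean(client_values):
  """{float}@CLIENTS -> float@SERVER: unweighted average over clients."""
  return sum(client_values) / len(client_values)

def toy_local_eval(model, batches):
  """Hypothetical stand-in for local_eval: summed squared error of a scalar model."""
  return sum((model - v) ** 2 for batch in batches for v in batch)

def simulated_federated_eval(model, client_data):
  broadcast = simulated_broadcast(model, len(client_data))
  # <{X}@CLIENTS, {Y}@CLIENTS> -> {<X, Y>}@CLIENTS, like tff.federated_zip.
  zipped = list(zip(broadcast, client_data))
  return simulated_mean(simulated_map(toy_local_eval, zipped))

client_data = [[[1.0, 1.0]], [[2.0], [2.0]], [[3.0, 3.0]]]  # three clients
print(simulated_federated_eval(2.0, client_data))  # (2 + 0 + 2) / 3
```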

In [None]:
print('initial_model loss =', federated_eval(initial_model,
                                             federated_train_data))
print('locally_trained_model loss =',
      federated_eval(locally_trained_model, federated_train_data))

Indeed, as expected, the loss of the locally trained model is higher than that of the initial model. In order to improve the model for all users, we'll need to train it on everyone's data.

## **Federated training**

The simplest way to implement federated training is to locally train, and then average the models. 

In [None]:
SERVER_FLOAT_TYPE = tff.type_at_server(tf.float32)


@tff.federated_computation(SERVER_MODEL_TYPE, SERVER_FLOAT_TYPE,
                           CLIENT_DATA_TYPE)
def federated_train(model, learning_rate, data):                                  # Trains the broadcast model locally on each client's data, then averages the resulting models
  return tff.federated_mean(
      tff.federated_map(local_train, [
          tff.federated_broadcast(model),
          tff.federated_broadcast(learning_rate), data
      ]))

Note that in the full-featured implementation of Federated Averaging provided by
`tff.learning`, rather than averaging the models, we prefer to average model
deltas, for a number of reasons, e.g., the ability to clip the update norms,
for compression, etc.
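
Without clipping or compression, the two choices agree: when every client starts from the same broadcast model `m` and the average is unweighted, `mean(m_i) = m + mean(m_i - m)`, so averaging deltas and adding the result to the server model gives the same new model. The difference appears once something is applied to the deltas. The NumPy sketch below checks the identity and then shows how a hypothetical clipping step (rescaling each delta to an L2 norm of at most 1.0, an arbitrary threshold) changes the result; a flat vector stands in for the weights and bias, and this is not the `tff.learning` implementation.

```python
import numpy as np

def average_models(client_models):
  """Unweighted mean of the client models, as in federated_train."""
  return np.mean(client_models, axis=0)

def average_deltas(server_model, client_models, clip_norm=None):
  """Server model plus the unweighted mean of the client deltas."""
  deltas = [m - server_model for m in client_models]
  if clip_norm is not None:
    # Hypothetical clipping: rescale each delta to an L2 norm of at most clip_norm.
    deltas = [d * min(1.0, clip_norm / max(np.linalg.norm(d), 1e-12))
              for d in deltas]
  return server_model + np.mean(deltas, axis=0)

server_model = np.array([1.0, 1.0, 1.0])
client_models = [server_model + np.array([3.0, 0.0, 0.0]),
                 server_model + np.array([0.0, 0.5, 0.0]),
                 server_model + np.array([0.0, 0.0, 0.3])]
print(average_models(client_models))               # identical to the next line
print(average_deltas(server_model, client_models))
print(average_deltas(server_model, client_models, clip_norm=1.0))  # first delta clipped
```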

In [None]:
model = initial_model
learning_rate = 0.1
for round_num in range(5):                                                        # Run a few rounds of federated training
  model = federated_train(model, learning_rate, federated_train_data)
  learning_rate = learning_rate * 0.9                                             # Decay the learning rate after each round
  loss = federated_eval(model, federated_train_data)
  print('round {}, loss={}'.format(round_num, loss))

### Testing

In [None]:
print('initial_model test loss =',
      federated_eval(initial_model, federated_test_data))
print('trained_model test loss =', federated_eval(model, federated_test_data))

# **Points to remember**


*  `tf.reduce_sum` and `tf.reduce_mean` compute a sum and a mean in TensorFlow, where the computation and the final value live in the same place.
* `tff.federated_sum` and `tff.federated_mean` do the same for values contributed by the clients, but the final value is placed at the server.



*   How can we find the time taken by each local training?
*   https://github.com/tensorflow/federated/blob/master/tensorflow_federated/python/learning/federated_averaging.py
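
One way to approach the timing question above, as a minimal sketch: wrap each call in `time.perf_counter()` from the Python standard library. The helper name `timed` is hypothetical. Timing `local_train` separately for each client gives per-client wall-clock times as seen from the Python driver, whereas timing `federated_train` as a whole would include every client plus broadcast and aggregation. These are times on the local simulation (the reference context set up at the top), not measurements from real devices.

```python
import time

def timed(fn, *args, **kwargs):
  """Hypothetical helper: returns (result, elapsed_seconds) for one call."""
  start = time.perf_counter()
  result = fn(*args, **kwargs)
  return result, time.perf_counter() - start

# Possible usage in this notebook (not executed here), one timing per client:
# for digit, client_data in enumerate(federated_train_data):
#   _, seconds = timed(local_train, initial_model, 0.1, client_data)
#   print('client {}: {:.3f}s'.format(digit, seconds))

result, seconds = timed(sum, range(1000))
print(result, seconds >= 0.0)
```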



