## Author - Sidhanta Narayan Singhdeo
## CWID  - 10465272
## Course - AAI 800 (Special Problems in AI)
## Project Advisor - Prof. Hong Man

### Implement Custom Federated Averaging
*   Understand the general structure of federated learning algorithms.
*   Explore the *Federated Core* of TFF.
*   Use the Federated Core to implement Federated Averaging directly.

```python
# Patch asyncio to allow nested event loops
import nest_asyncio
nest_asyncio.apply()
```

```python
# Load the required libraries
import tensorflow as tf
import tensorflow_federated as tff
```

### Loading the CIFAR-100 Data

```python
# emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()
cifar_train, cifar_test = tff.simulation.datasets.cifar100.load_data(cache_dir=None)
```

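### Preprocessing the Data
To feed the data into our model, each 32×32×3 image is flattened into a 3072-element vector, the examples are grouped into batches of `BATCH_SIZE`, and each batch is converted into a tuple of the form `(flattened_image_vector, label)`. The label is taken from the `coarse_label` field, i.e. one of the 20 CIFAR-100 superclasses.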

```python
NUM_CLIENTS = 10
BATCH_SIZE = 20

def preprocess(dataset):

  def batch_format_fn(element):
    """Flatten a batch of CIFAR-100 images and return a (features, label) tuple."""
    # Each 32x32x3 image becomes a 3072-element vector; labels become a [batch, 1] column.
    return (tf.reshape(element['image'], [-1, 3072]),
            tf.reshape(element['coarse_label'], [-1, 1]))

  return dataset.batch(BATCH_SIZE).map(batch_format_fn)
```
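`batch_format_fn` only reshapes tensors, so the effect of `preprocess` on shapes can be checked without TFF. Below is a framework-free NumPy sketch of the same transformation, using hypothetical random data in place of a real client dataset. The names `batch_format_np` and `batches` are illustrative, not part of TFF. Like `tf.data`'s `batch()` without `drop_remainder`, it keeps a smaller final batch.

```python
import numpy as np

BATCH_SIZE = 20

def batch_format_np(images, labels):
    """Flatten 32x32x3 images to 3072-vectors and turn labels into a [batch, 1] column."""
    return images.reshape(-1, 32 * 32 * 3), labels.reshape(-1, 1)

def batches(images, labels, batch_size=BATCH_SIZE):
    """Yield (features, label) tuples; the last batch may be smaller than batch_size."""
    for start in range(0, len(images), batch_size):
        end = start + batch_size
        yield batch_format_np(images[start:end], labels[start:end])

# Hypothetical client holding 45 random examples (coarse labels lie in 0..19).
rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(45, 32, 32, 3), dtype=np.uint8)
labels = rng.integers(0, 20, size=45)
shapes = [(x.shape, y.shape) for x, y in batches(images, labels)]
print(shapes)  # [((20, 3072), (20, 1)), ((20, 3072), (20, 1)), ((5, 3072), (5, 1))]
```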

Next, select the first `NUM_CLIENTS` client IDs (in sorted order) and build one preprocessed dataset per client.

```python
# One preprocessed tf.data.Dataset per selected client
client_ids = sorted(cifar_train.client_ids)[:NUM_CLIENTS]
federated_train_data = [preprocess(cifar_train.create_tf_dataset_for_client(x))
  for x in client_ids
]
```

### Keras Model Declaration

```python
# Candidate CNN (not used below); assumes the Keras layer classes are imported.
# model = Sequential()
#
# model.add(Conv2D(32, (3, 3), input_shape=(32, 32, 3)))
# model.add(LeakyReLU(alpha=0.1))
# model.add(BatchNormalization(axis=-1))
# model.add(Conv2D(32, (3, 3)))
# model.add(LeakyReLU(alpha=0.1))
# model.add(MaxPooling2D(pool_size=(2, 2)))
#
# model.add(BatchNormalization(axis=-1))
# model.add(Conv2D(64, (3, 3)))
# model.add(LeakyReLU(alpha=0.1))
# model.add(BatchNormalization(axis=-1))
# model.add(Conv2D(64, (3, 3)))
# model.add(LeakyReLU(alpha=0.1))
# model.add(MaxPooling2D(pool_size=(2, 2)))
#
# model.add(Flatten())
#
# model.add(BatchNormalization())
# model.add(Dense(512))
# model.add(LeakyReLU(alpha=0.1))
# model.add(BatchNormalization())
# model.add(Dropout(0.2))
# model.add(Dense(100))
# model.add(Activation('softmax'))
```

We tried to use the CNN above in other models, but ran into issues that would have delayed the project.

We may try the same experiment with this CNN in the future; for now, we use a simpler model.

```python
def create_keras_model():
  # A single dense layer followed by softmax: 3072 inputs -> 10 outputs.
  initializer = tf.keras.initializers.GlorotNormal(seed=0)
  return tf.keras.models.Sequential([
      tf.keras.layers.Input(shape=(3072,)),
      tf.keras.layers.Dense(10, kernel_initializer=initializer),
      tf.keras.layers.Softmax(),
  ])
```
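The model computes `softmax(x @ kernel + bias)` for each flattened image. The following NumPy sketch of that forward pass uses random weights rather than the Keras initializer, and `dense_softmax` is an illustrative name. It casts the `uint8` pixels to `float32` explicitly before the matrix multiplication.

```python
import numpy as np

def dense_softmax(x, kernel, bias):
    """Forward pass of Dense followed by Softmax: softmax(x @ kernel + bias), row by row."""
    logits = x.astype(np.float32) @ kernel + bias
    logits = logits - logits.max(axis=1, keepdims=True)  # shift rows for numerical stability
    exp = np.exp(logits)
    return exp / exp.sum(axis=1, keepdims=True)

NUM_FEATURES, NUM_CLASSES = 3072, 10   # same sizes as create_keras_model()
rng = np.random.default_rng(0)
kernel = rng.normal(0.0, 0.01, size=(NUM_FEATURES, NUM_CLASSES)).astype(np.float32)
bias = np.zeros(NUM_CLASSES, dtype=np.float32)
x = rng.integers(0, 256, size=(4, NUM_FEATURES)).astype(np.uint8)  # 4 flattened images
probs = dense_softmax(x, kernel, bias)
print(probs.shape)  # (4, 10): one probability per class for each image
```

The sketch also highlights a mismatch in the notebook. The labels produced by `preprocess` come from `coarse_label`, which takes 20 values (0 to 19) in CIFAR-100. The model, however, has only 10 output units, so labels 10 to 19 have no matching output. Using `Dense(20)` (or `NUM_CLASSES = 20` here) would align the model with the coarse labels. The type signatures recorded in this notebook correspond to the 10-unit version.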

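To use this model in TFF, wrap the Keras model as a `tff.learning.Model` with `tff.learning.from_keras_model`. The `input_spec` tells TFF the type and shape of the data the model will receive.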

```python
def model_fn():
  keras_model = create_keras_model()
  return tff.learning.from_keras_model(
      keras_model,
      input_spec=federated_train_data[0].element_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
```
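The loss and metric passed to `from_keras_model` operate on the model's softmax probabilities and on integer labels of shape `[batch, 1]`. The following NumPy sketch shows what they compute. It mirrors the math rather than Keras's implementation, and the function names are illustrative.

```python
import numpy as np

def sparse_categorical_crossentropy(labels, probs, eps=1e-7):
    """Mean over the batch of -log(probability assigned to the true label)."""
    idx = labels.reshape(-1)                      # [batch, 1] -> [batch]
    picked = probs[np.arange(len(idx)), idx]
    return float(np.mean(-np.log(np.clip(picked, eps, 1.0))))

def sparse_categorical_accuracy(labels, probs):
    """Fraction of examples whose highest-probability class equals the label."""
    return float(np.mean(np.argmax(probs, axis=1) == labels.reshape(-1)))

probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.1, 0.8]])
labels = np.array([[0], [1]])
loss = sparse_categorical_crossentropy(labels, probs)
acc = sparse_categorical_accuracy(labels, probs)
print(loss, acc)  # loss = (-ln 0.7 - ln 0.1) / 2, about 1.3296; acc = 0.5
```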

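A `tff.learning.Model` exposes, among others, the following attributes:

*   `trainable_variables`: an iterable of the variables corresponding to trainable layers.
*   `non_trainable_variables`: an iterable of the variables corresponding to non-trainable layers.

The initialization function below returns only the trainable variables.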

```python
def initialize_fn():
  model = model_fn()
  return model.trainable_variables
```
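`initialize_fn` returns the freshly initialized trainable variables of `create_keras_model()`. These are a 3072×10 kernel drawn by `GlorotNormal(seed=0)` and a 10-element bias, which Keras `Dense` initializes to zeros by default. The following NumPy sketch implements a truncated-normal Glorot initializer. It assumes the variance-scaling scheme Keras uses, including the 0.8796 truncation correction. It will not reproduce the exact values Keras draws for `seed=0`, and `glorot_normal` is an illustrative name.

```python
import numpy as np

def glorot_normal(fan_in, fan_out, rng):
    """Sketch of Glorot (Xavier) normal init using a truncated normal distribution.

    Target stddev is sqrt(2 / (fan_in + fan_out)). Samples beyond two stddevs of the
    sampling distribution are redrawn; dividing by 0.8796... compensates for the
    variance lost to truncation (assumed to match Keras's variance-scaling scheme).
    """
    target = np.sqrt(2.0 / (fan_in + fan_out))
    sigma = target / 0.87962566103423978
    out = rng.normal(0.0, sigma, size=(fan_in, fan_out))
    bad = np.abs(out) > 2 * sigma
    while bad.any():  # redraw out-of-range samples until none remain
        out[bad] = rng.normal(0.0, sigma, size=int(bad.sum()))
        bad = np.abs(out) > 2 * sigma
    return out.astype(np.float32), target

rng = np.random.default_rng(0)
kernel, target = glorot_normal(3072, 10, rng)
bias = np.zeros(10, dtype=np.float32)   # Keras Dense uses zeros for the bias by default
print(kernel.shape, bias.shape, target, kernel.std())
```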

This function is nearly complete, but as shown later, it needs a small modification to become a "TFF computation".

Next, let's write a sketch of the `next_fn`.

```python
# Sketch only: broadcast, client_update, mean and server_update are placeholders
# that are implemented with TensorFlow and TFF in the sections below.
def next_fn(server_weights, federated_dataset):
  # Broadcast the server weights to the clients.
  server_weights_at_client = broadcast(server_weights)

  # Each client computes their updated weights.
  client_weights = client_update(federated_dataset, server_weights_at_client)

  # The server averages these updates.
  mean_client_weights = mean(client_weights)

  # The server updates its model.
  server_weights = server_update(mean_client_weights)

  return server_weights
```
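The sketch above leaves `broadcast`, `client_update`, `mean` and `server_update` undefined. The following framework-free NumPy simulation makes the data flow concrete. Each call of `toy_next_fn` performs one round, and each hypothetical client fits a small least-squares model instead of the CIFAR-100 classifier. The `toy_*` names, the least-squares loss and the learning rate are assumptions for illustration, not part of TFF.

```python
import numpy as np

def toy_broadcast(server_weights, num_clients):
    """Server-to-client: every client gets its own copy of the server weights."""
    return [[np.copy(w) for w in server_weights] for _ in range(num_clients)]

def toy_client_update(dataset, weights, lr=0.2):
    """Local training: one gradient step per batch on 0.5 * mean((x @ w + b - y) ** 2)."""
    w, b = weights
    for x, y in dataset:
        err = x @ w + b - y                 # residuals, shape [batch]
        w = w - lr * x.T @ err / len(y)
        b = b - lr * err.mean()
    return [w, b]

def toy_mean(client_weights):
    """Client-to-server: unweighted element-wise mean of each weight tensor."""
    return [np.mean(np.stack(layer), axis=0) for layer in zip(*client_weights)]

def toy_server_update(mean_client_weights):
    """Vanilla federated averaging: the server adopts the averaged weights."""
    return mean_client_weights

def toy_next_fn(server_weights, federated_dataset):
    server_weights_at_client = toy_broadcast(server_weights, len(federated_dataset))
    client_weights = [toy_client_update(data, w)
                      for data, w in zip(federated_dataset, server_weights_at_client)]
    return toy_server_update(toy_mean(client_weights))

# Three hypothetical clients whose data all follow y = 2*x0 - x1 + 0.5.
rng = np.random.default_rng(0)
w_true, b_true = np.array([2.0, -1.0]), 0.5
federated_dataset = []
for _ in range(3):
    x = rng.normal(size=(30, 2))
    y = x @ w_true + b_true
    federated_dataset.append([(x[i:i + 10], y[i:i + 10]) for i in range(0, 30, 10)])

server_weights = [np.zeros(2), np.zeros(())]
for _ in range(300):
    server_weights = toy_next_fn(server_weights, federated_dataset)
print(server_weights)  # should be close to w_true and b_true
```

Because the toy data are noiseless and shared by one underlying model, every client pulls toward the same solution, so plain averaging converges. Real client data, as in CIFAR-100, is usually heterogeneous.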

We implement these four components separately, starting with the parts that can be written in pure TensorFlow: the client and server update steps.


## TensorFlow Blocks 

### Client update

The `tff.learning.Model` can be used for client training in essentially the same way you would train a TensorFlow model.
We use `tf.GradientTape` to compute the gradient on batches of data, then apply these gradients using a `client_optimizer`. This only involves the trainable weights.


```python
@tf.function
def client_update(model, dataset, server_weights, client_optimizer):
  """Performs training (using the server model weights) on the client's dataset."""
  # Initialize the client model with the current server weights.
  client_weights = model.trainable_variables
  tf.nest.map_structure(lambda x, y: x.assign(y),
                        client_weights, server_weights)

  # Use the client_optimizer to update the local model.
  for batch in dataset:
    with tf.GradientTape() as tape:
      # Compute a forward pass on the batch of data.
      outputs = model.forward_pass(batch)

    # Compute the gradient of the loss with respect to the trainable weights.
    grads = tape.gradient(outputs.loss, client_weights)
    grads_and_vars = zip(grads, client_weights)

    # Apply the gradient using the client optimizer.
    client_optimizer.apply_gradients(grads_and_vars)

  return client_weights
```
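For this notebook's single Dense + Softmax model, the gradient that `tape.gradient` computes has a simple closed form. The derivative of the mean cross-entropy with respect to the logits is `(probs - one_hot(labels)) / batch_size`. The following NumPy sketch reproduces the client loop: it copies the server weights, then takes one plain SGD step per batch. It assumes float64 arithmetic and a hypothetical 3-feature, 2-class dataset in place of CIFAR-100. `client_update_np` and `mean_loss` are illustrative names.

```python
import numpy as np

def softmax(logits):
    """Row-wise softmax with max-subtraction for numerical stability."""
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def mean_loss(weights, x, y):
    """Mean sparse categorical cross-entropy of Dense + Softmax on (x, y)."""
    probs = softmax(x @ weights[0] + weights[1])
    return float(-np.mean(np.log(probs[np.arange(len(x)), y.reshape(-1)])))

def client_update_np(dataset, server_weights, learning_rate=0.01):
    """Copy the server weights, then take one plain SGD step per batch."""
    kernel, bias = (np.array(w, dtype=np.float64) for w in server_weights)  # local copies
    for x, y in dataset:
        labels = y.reshape(-1)
        probs = softmax(x @ kernel + bias)
        # d(mean loss)/d(logits) = (probs - one_hot(labels)) / batch_size
        dlogits = probs.copy()
        dlogits[np.arange(len(labels)), labels] -= 1.0
        dlogits /= len(labels)
        kernel -= learning_rate * x.T @ dlogits      # dL/dkernel = x^T @ dlogits
        bias -= learning_rate * dlogits.sum(axis=0)  # dL/dbias = column sums of dlogits
    return [kernel, bias]

# Hypothetical client: 40 examples, 3 features, 2 classes, batches of 10.
rng = np.random.default_rng(0)
x = rng.normal(size=(40, 3))
y = (x[:, 0] > 0).astype(np.int64).reshape(-1, 1)
dataset = [(x[i:i + 10], y[i:i + 10]) for i in range(0, 40, 10)]

server_weights = [np.zeros((3, 2)), np.zeros(2)]
client_weights = client_update_np(dataset, server_weights, learning_rate=0.5)
print(mean_loss(server_weights, x, y), mean_loss(client_weights, x, y))
```

As in `client_update`, the server weights themselves are left untouched; only the client's local copy changes.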

### Server Update 

We implement "vanilla" federated averaging, in which the server model weights are replaced by the average of the client model weights. 

Again, this only uses the trainable weights.

```python
@tf.function
def server_update(model, mean_client_weights):
  """Updates the server model weights as the average of the client model weights."""
  model_weights = model.trainable_variables
  # Assign the mean client weights to the server model.
  tf.nest.map_structure(lambda x, y: x.assign(y),
                        model_weights, mean_client_weights)
  return model_weights
```
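Both `client_update` and `server_update` rely on `tf.nest.map_structure` to pair each model variable with the corresponding incoming tensor and call `assign` on it. The following pure-Python/NumPy sketch illustrates that pattern. It is a simplified `map_structure` that handles only lists, plain tuples and dicts. The helper `assign` is hypothetical and uses `np.copyto` as a stand-in for `tf.Variable.assign`.

```python
import numpy as np

def map_structure(fn, *structures):
    """Simplified tf.nest.map_structure: apply fn to matching leaves of nested lists, tuples or dicts."""
    first = structures[0]
    if isinstance(first, dict):
        return {key: map_structure(fn, *(s[key] for s in structures)) for key in first}
    if isinstance(first, (list, tuple)):
        if any(len(s) != len(first) for s in structures):
            raise ValueError("Structures do not match.")
        return type(first)(map_structure(fn, *leaves) for leaves in zip(*structures))
    return fn(*structures)

def assign(variable, value):
    """Stand-in for tf.Variable.assign: overwrite the array in place and return it."""
    np.copyto(variable, value)
    return variable

model_weights = [np.zeros((3, 2), dtype=np.float32), np.zeros(2, dtype=np.float32)]
mean_client_weights = [np.ones((3, 2)), np.array([0.5, -0.5])]
kernel_before = model_weights[0]
updated = map_structure(assign, model_weights, mean_client_weights)
print(updated[0] is kernel_before)  # True: the existing array was overwritten in place
```

Updating in place matters in the TensorFlow version: the model keeps using its existing variables, and `server_update` simply overwrites their values.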

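## Federated computations

As a minimal example, the following federated computation takes one `float32` value per client (for instance, a temperature reading) and returns their average on the server:

```python
@tff.federated_computation(tff.FederatedType(tf.float32, tff.CLIENTS))
def get_average_temperature(client_temperatures):
  return tff.federated_mean(client_temperatures)
```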

This `tff.federated_computation` accepts arguments of federated type `{float32}@CLIENTS` and returns values of federated type `float32@SERVER`. The curly braces denote a collection of potentially different values, one per client, while the server holds a single value. Federated computations may also go from server to client, from client to client, or from server to server. They can also be composed like normal functions, as long as their type signatures match.
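The following sketch simulates the same averaging without TFF, using two small wrapper classes to make the placement types concrete. `ClientsValue`, `ServerValue` and `simulated_federated_mean` are hypothetical names. The optional per-client weights mirror the optional `weight` argument of `tff.federated_mean`.

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass(frozen=True)
class ClientsValue:
    """Stand-in for {float32}@CLIENTS: one (possibly different) value per client."""
    members: List[float]

@dataclass(frozen=True)
class ServerValue:
    """Stand-in for float32@SERVER: a single value held by the server."""
    value: float

def simulated_federated_mean(values: ClientsValue,
                             weights: Optional[ClientsValue] = None) -> ServerValue:
    """Average the client members into one server value, optionally weighted per client."""
    ws = [1.0] * len(values.members) if weights is None else weights.members
    total = sum(w * v for w, v in zip(ws, values.members))
    return ServerValue(total / sum(ws))

temps = ClientsValue([68.5, 70.3, 69.8])
print(simulated_federated_mean(temps).value)                           # about 69.5333
print(simulated_federated_mean(temps, ClientsValue([1, 1, 2])).value)  # about 69.6
```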


## TensorFlow Federated blocks 

### Creating the initialization computation

The initialize function is simple: it creates a model using `model_fn`. However, the TensorFlow code must be wrapped separately using `tff.tf_computation`.

```python
@tff.tf_computation
def server_init():
  model = model_fn()
  return model.trainable_variables
```

We then place its result on the server inside a federated computation using `tff.federated_value`.

```python
@tff.federated_computation
def initialize_fn():
  return tff.federated_value(server_init(), tff.SERVER)
```

### Creating the `next_fn`

We now use the client and server update code to write the actual algorithm. First, we wrap `client_update` in a `tff.tf_computation` that accepts a client dataset and server weights, and outputs the updated client weights.

To decorate this function properly, we need the corresponding types. The type of the client dataset can be extracted directly from the model's `input_spec`.

```python
whimsy_model = model_fn()
tf_dataset_type = tff.SequenceType(whimsy_model.input_spec)
```

```python
# Check the structure of the dataset type
str(tf_dataset_type)
```

```
'<uint8[?,3072],int64[?,1]>*'
```

The model weights type can be extracted from the result type of the `server_init` computation defined above.

```python
model_weights_type = server_init.type_signature.result
```

```python
str(model_weights_type)
```

```
'<float32[3072,10],float32[10]>'
```
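The weights type also shows how much data each round moves. The following quick calculation assumes raw `float32` payloads, with no serialization overhead or compression, and the `NUM_CLIENTS = 10` clients selected earlier.

```python
import math

# Shapes from model_weights_type: <float32[3072,10],float32[10]>
weight_shapes = [(3072, 10), (10,)]
num_params = sum(math.prod(shape) for shape in weight_shapes)  # 30720 + 10 = 30730
bytes_per_copy = num_params * 4                                # 4 bytes per float32 -> 122920

NUM_CLIENTS = 10
broadcast_bytes = NUM_CLIENTS * bytes_per_copy  # server -> clients per round: 1229200
upload_bytes = NUM_CLIENTS * bytes_per_copy     # clients -> server per round: 1229200
print(num_params, bytes_per_copy, broadcast_bytes, upload_bytes)
```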

We then create our `tff.tf_computation` for the client update.

```python
@tff.tf_computation(tf_dataset_type, model_weights_type)
def client_update_fn(tf_dataset, server_weights):
  model = model_fn()
  client_optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)
  return client_update(model, tf_dataset, server_weights, client_optimizer)
```

The `tff.tf_computation` version of the server update is defined in a similar way:

```python
@tff.tf_computation(model_weights_type)
def server_update_fn(mean_client_weights):
  model = model_fn()
  return server_update(model, mean_client_weights)
```

We need to create the `tff.federated_computation` that brings this all together. This function will accept two *federated values*, one corresponding to the server weights (with placement `tff.SERVER`), and the other corresponding to the client datasets (with placement `tff.CLIENTS`).

We give them the proper placements by wrapping the types defined above in `tff.FederatedType`:

```python
federated_server_type = tff.FederatedType(model_weights_type, tff.SERVER)
federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)
```

Finally, we build `next_fn` itself with the Federated Core. It consists of four steps:

1. A server-to-client broadcast step.
2. A local client update step.
3. A client-to-server upload step.
4. A server update step.
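The four steps above can be mimicked in plain Python to show how each one changes placement, from SERVER to CLIENTS and back. This toy model covers only the placement semantics. It assumes a scalar "model" and a hypothetical local update that moves halfway toward each client's data mean. All class and function names are illustrative. Unlike `tff.federated_broadcast`, the simulated `broadcast` needs the number of clients passed explicitly.

```python
from dataclasses import dataclass
from typing import Any, Callable, List

@dataclass(frozen=True)
class AtServer:
    value: Any            # stand-in for T@SERVER

@dataclass(frozen=True)
class AtClients:
    members: List[Any]    # stand-in for {T}@CLIENTS, one member per client

def broadcast(x: AtServer, num_clients: int) -> AtClients:
    return AtClients([x.value] * num_clients)

def map_at_clients(fn: Callable, *args: AtClients) -> AtClients:
    return AtClients([fn(*member) for member in zip(*(a.members for a in args))])

def mean_to_server(x: AtClients) -> AtServer:
    return AtServer(sum(x.members) / len(x.members))

def map_at_server(fn: Callable, x: AtServer) -> AtServer:
    return AtServer(fn(x.value))

def toy_next(server_weight: AtServer, client_data: AtClients) -> AtServer:
    at_clients = broadcast(server_weight, len(client_data.members))     # 1. broadcast
    updated = map_at_clients(                                           # 2. local update
        lambda data, w: w + 0.5 * (sum(data) / len(data) - w), client_data, at_clients)
    averaged = mean_to_server(updated)                                  # 3. upload and mean
    return map_at_server(lambda w: w, averaged)                         # 4. server update

state = toy_next(AtServer(0.0), AtClients([[1.0, 2.0, 3.0], [10.0]]))
print(state)  # AtServer(value=3.0)
```

Here the two clients move from 0.0 to 1.0 and 5.0 respectively. Their mean, 3.0, becomes the new server value.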



```python
@tff.federated_computation(federated_server_type, federated_dataset_type)
def next_fn(server_weights, federated_dataset):
  # Broadcast the server weights to the clients.
  server_weights_at_client = tff.federated_broadcast(server_weights)

  # Each client computes their updated weights.
  client_weights = tff.federated_map(
      client_update_fn, (federated_dataset, server_weights_at_client))

  # The server averages these updates.
  mean_client_weights = tff.federated_mean(client_weights)

  # The server updates its model.
  server_weights = tff.federated_map(server_update_fn, mean_client_weights)

  return server_weights
```

### Build Iterative Process

We combine `initialize_fn` and `next_fn` into a `tff.templates.IterativeProcess`.

```python
federated_algorithm = tff.templates.IterativeProcess(
    initialize_fn=initialize_fn,
    next_fn=next_fn
)
```
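`tff.templates.IterativeProcess` packages two computations. `initialize` produces the initial server state, and `next` maps a (state, data) pair to a new state. The training loop at the end of this notebook is just repeated calls to `next`. The following plain-Python sketch illustrates that contract with a hypothetical scalar state instead of model weights. `ToyIterativeProcess` and its helpers are illustrative names, and the sketch omits the type-signature checks that TFF performs.

```python
from typing import Any, Callable

class ToyIterativeProcess:
    """Hypothetical stand-in for the initialize/next contract of an iterative process."""

    def __init__(self, initialize_fn: Callable[[], Any], next_fn: Callable[[Any, Any], Any]):
        self.initialize = initialize_fn
        self.next = next_fn

def toy_initialize():
    return 0.0

def toy_next(state, client_values):
    # Move the server state halfway toward the mean of the client values.
    return state + 0.5 * (sum(client_values) / len(client_values) - state)

process = ToyIterativeProcess(toy_initialize, toy_next)
state = process.initialize()
for round_num in range(15):
    state = process.next(state, [2.0, 6.0])
print(state)  # approaches the client mean 4.0
```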

Let's check the *type signatures* of the `initialize` and `next` functions of our iterative process.

```python
str(federated_algorithm.initialize.type_signature)
```

```
'( -> <float32[3072,10],float32[10]>@SERVER)'
```

This reflects the fact that `federated_algorithm.initialize` is a no-arg function that returns a single-layer model (with a 3072-by-10 weight matrix, and 10 bias units).

```python
str(federated_algorithm.next.type_signature)
```

```
'(<server_weights=<float32[3072,10],float32[10]>@SERVER,federated_dataset={<uint8[?,3072],int64[?,1]>*}@CLIENTS> -> <float32[3072,10],float32[10]>@SERVER)'
```

Here, one can see that `federated_algorithm.next` accepts a server model and client data, and returns an updated server model.

## Evaluating the algorithm

We evaluate the model centrally: the test data of all clients is combined into a single dataset, and the server weights are loaded into a Keras model.

```python
central_cifar_test = cifar_test.create_tf_dataset_from_all_clients()
central_cifar_test = preprocess(central_cifar_test)
```

```python
# Load the weights from the given server state into a fresh Keras model with `set_weights`.
def evaluate(server_state):
  keras_model = create_keras_model()
  keras_model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()]
  )
  keras_model.set_weights(server_state)
  keras_model.evaluate(central_cifar_test)
```

```python
# Evaluate the initial (untrained) server state
server_state = federated_algorithm.initialize()
evaluate(server_state)
```



```python
# Run 15 rounds of federated training
for round_num in range(15):
  server_state = federated_algorithm.next(server_state, federated_train_data)
```

```python
# Evaluate the final server state
evaluate(server_state)
```

