# Building your own FL Algorithm
https://www.tensorflow.org/federated/tutorials/building_your_own_federated_learning_algorithm

source "venv/bin/activate"

In [1]:
import nest_asyncio
nest_asyncio.apply()

import tensorflow as tf
import tensorflow_federated as tff

2022-05-24 04:09:18.008318: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2022-05-24 04:09:18.008361: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.


### Input data

In [2]:
# Load the EMNIST train and test sets from TFF's simulation datasets; the
# returned objects are indexed by client (see the `client_ids` usage below).
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()


2022-05-24 04:09:22.864326: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory
2022-05-24 04:09:22.864375: W tensorflow/stream_executor/cuda/cuda_driver.cc:269] failed call to cuInit: UNKNOWN ERROR (303)
2022-05-24 04:09:22.864400: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:156] kernel driver does not appear to be running on this host (debian): /proc/driver/nvidia/version does not exist
2022-05-24 04:09:22.864684: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [3]:
# How many client ids (taken in sorted order) are used to build the training data.
NUM_CLIENTS = 10
# Batch size applied to every dataset inside `preprocess`.
BATCH_SIZE = 20

def preprocess(dataset):
  """Batch `dataset` and turn every batch into a (features, label) tuple.

  Args:
    dataset: A dataset whose elements can be indexed by 'pixels' and 'label'.

  Returns:
    `dataset` batched by BATCH_SIZE, with each batch mapped to a tuple of the
    pixels reshaped to [-1, 784] and the labels reshaped to [-1, 1].
  """

  def _flatten_batch(batch):
    """Reshape one batch into ([-1, 784] pixels, [-1, 1] labels)."""
    flat_pixels = tf.reshape(batch['pixels'], [-1, 784])
    column_labels = tf.reshape(batch['label'], [-1, 1])
    return (flat_pixels, column_labels)

  batched = dataset.batch(BATCH_SIZE)
  return batched.map(_flatten_batch)


In [4]:
# Sorting before slicing makes the choice of the first NUM_CLIENTS clients deterministic.
client_ids = sorted(emnist_train.client_ids)[:NUM_CLIENTS]

# One preprocessed dataset per selected client, in the same order as `client_ids`.
federated_train_data = [
    preprocess(emnist_train.create_tf_dataset_for_client(client_id))
    for client_id in client_ids
]


### Preparing the model

In [5]:
def create_keras_model():
  """Build a single-layer softmax classifier over 784 input features.

  Returns:
    A `tf.keras.models.Sequential` made of an Input layer of shape (784,), a
    Dense layer with 10 units (GlorotNormal kernel initializer, seed 0) and a
    Softmax layer.
  """
  # Fixed seed so every call produces the same initial kernel values.
  glorot = tf.keras.initializers.GlorotNormal(seed=0)
  layers = [
      tf.keras.layers.Input(shape=(784,)),
      tf.keras.layers.Dense(10, kernel_initializer=glorot),
      tf.keras.layers.Softmax(),
  ]
  return tf.keras.models.Sequential(layers)


In [6]:
def model_fn():
  """Wrap a freshly built Keras model with `tff.learning.from_keras_model`.

  Returns:
    The TFF model produced from a new `create_keras_model()` instance, using
    the element spec of the first client dataset as input spec, sparse
    categorical cross-entropy as loss and sparse categorical accuracy as the
    only metric.
  """
  keras_model = create_keras_model()
  batch_spec = federated_train_data[0].element_spec
  loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()
  accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
  return tff.learning.from_keras_model(
      keras_model,
      input_spec=batch_spec,
      loss=loss_fn,
      metrics=[accuracy])


### Building FL algorithm

In [7]:
def initialize_fn():
  """Return the trainable variables of a newly constructed `model_fn()` model."""
  return model_fn().trainable_variables


In [8]:
def next_fn(server_weights, federated_dataset):
  """Outline of one round: broadcast, client update, mean, server update.

  NOTE(review): this reads as a sketch rather than runnable code. `broadcast`
  and `mean` are not defined anywhere in the visible notebook, and
  `client_update` / `server_update` are called here with fewer arguments than
  the definitions in the following cells take. A later cell defines another
  `next_fn`, which replaces this one.

  Args:
    server_weights: Weights handed to `broadcast`.
    federated_dataset: First argument passed to `client_update`
      (presumably the clients' datasets -- confirm).

  Returns:
    The value returned by `server_update` for the averaged client weights.
  """
  # Broadcast the server weights to the clients.
  server_weights_at_client = broadcast(server_weights)

  # Each client computes their updated weights.
  client_weights = client_update(federated_dataset, server_weights_at_client)

  # The server averages these updates.
  mean_client_weights = mean(client_weights)

  # The server updates its model.
  server_weights = server_update(mean_client_weights)

  return server_weights


In [9]:
@tf.function
def client_update(model, dataset, server_weights, client_optimizer):
  """Train `model` on `dataset`, starting from `server_weights`.

  Args:
    model: Object exposing `trainable_variables` and `forward_pass(batch)`.
    dataset: Iterable of batches accepted by `model.forward_pass`.
    server_weights: Structure matching `model.trainable_variables`; its values
      are assigned into the model before any training step.
    client_optimizer: Optimizer exposing `apply_gradients`.

  Returns:
    The model's trainable variables after one pass over `dataset`.
  """
  local_weights = model.trainable_variables

  # Start local training from the server's current weights.
  tf.nest.map_structure(lambda variable, value: variable.assign(value),
                        local_weights, server_weights)

  # One optimizer step per batch.
  for batch in dataset:
    with tf.GradientTape() as tape:
      # The forward pass runs under the tape so the loss can be differentiated.
      batch_output = model.forward_pass(batch)

    gradients = tape.gradient(batch_output.loss, local_weights)
    client_optimizer.apply_gradients(zip(gradients, local_weights))

  return local_weights


In [10]:
@tf.function
def server_update(model, mean_client_weights):
  """Overwrite the server model's weights with the averaged client weights.

  Args:
    model: Object exposing `trainable_variables`.
    mean_client_weights: Structure matching `model.trainable_variables`.

  Returns:
    The model's trainable variables after the assignment.
  """
  server_variables = model.trainable_variables
  tf.nest.map_structure(
      lambda variable, averaged: variable.assign(averaged),
      server_variables, mean_client_weights)
  return server_variables


### Federated Core

In [11]:
# A federated type: float32 member values placed at the clients.
federated_float_on_clients = tff.FederatedType(tf.float32, tff.CLIENTS)
# Final expression of the cell: display the compact string form of the type.
str(federated_float_on_clients)


'{float32}@CLIENTS'

Federated computations: code generated by tff.federated_computation is neither TensorFlow nor Python code; it is a specification of a distributed system in an internal, platform-independent glue language.

In [12]:
@tff.federated_computation(tff.FederatedType(tf.float32, tff.CLIENTS))
def get_average_temperature(client_temperatures):
  """Return the federated mean of float32 values placed at the clients."""
  average_temperature = tff.federated_mean(client_temperatures)
  return average_temperature

In [13]:
# Display the computation's type signature (parameter type -> result type).
str(get_average_temperature.type_signature)


'({float32}@CLIENTS -> float32@SERVER)'

In [14]:
# Invoke the federated computation on a plain Python list, one value per client.
get_average_temperature([68.5, 70.3, 69.8])


69.53334

- TFF computations are non-eager.
- Federated computations cannot contain TF operations directly, only federated operators. TensorFlow code must be confined to blocks decorated with tff.tf_computation.

In [15]:
@tff.tf_computation(tf.float32)
def add_half(x):
  """Return `x` plus 0.5, computed with `tf.add`."""
  incremented = tf.add(x, 0.5)
  return incremented


In [16]:
# Display the type signature; unlike the federated one above it carries no placement.
str(add_half.type_signature)


'(float32 -> float32)'

tff.federated_computation has explicit placements and tff.tf_computation doesn't.

We can use tff.tf_computation blocks in federated computations by specifying placements.

In [17]:
@tff.federated_computation(tff.FederatedType(tf.float32, tff.CLIENTS))
def add_half_on_clients(x):
  """Apply `add_half` to the client-placed float32 values with `tff.federated_map`."""
  shifted = tff.federated_map(add_half, x)
  return shifted

# Display the type signature of the mapped computation.
str(add_half_on_clients.type_signature)


'({float32}@CLIENTS -> {float32}@CLIENTS)'

tff.federated_map applies a given tff.tf_computation preserving the placement

### Summary
- TFF operates on federated values.
- Each federated value has a federated type, with a type (eg. tf.float32) and a placement (eg. tff.CLIENTS).
- Federated values can be transformed using federated computations, which must be decorated with tff.federated_computation and a federated type signature.
- TensorFlow code must be contained in blocks with tff.tf_computation decorators.
- These blocks can then be incorporated into federated computations.

## Own FL algorithm

In [18]:
@tff.tf_computation
def server_init():
  """Build a model with `model_fn` and return its trainable variables."""
  return model_fn().trainable_variables


In [19]:
@tff.federated_computation
def initialize_fn():
  """Place the output of `server_init` at the server via `tff.federated_value`."""
  initial_weights = server_init()
  return tff.federated_value(initial_weights, tff.SERVER)


Turn our client_update into a tff.tf_computation that accepts a client dataset and the server weights, and outputs the updated client weights.

In [20]:
# Throwaway model instance; in the visible notebook it is only used to read `input_spec`.
whimsy_model = model_fn()
# Sequence type whose elements follow the model's input spec.
tf_dataset_type = tff.SequenceType(whimsy_model.input_spec)
# Final expression of the cell: display the type as a string.
str(tf_dataset_type)



'<float32[?,784],int32[?,1]>*'

In [21]:
# Result type of `server_init`, i.e. the type of the values it returns.
model_weights_type = server_init.type_signature.result
# Final expression of the cell: display the type as a string.
str(model_weights_type)



'<float32[784,10],float32[10]>'

tff.tf_computation for the client update

In [22]:
@tff.tf_computation(tf_dataset_type, model_weights_type)
def client_update_fn(tf_dataset, server_weights):
  """Run `client_update` with a fresh model and a fresh SGD optimizer.

  Args:
    tf_dataset: Value of type `tf_dataset_type`.
    server_weights: Value of type `model_weights_type`.

  Returns:
    The result of `client_update` for a new `model_fn()` model trained with
    `tf.keras.optimizers.SGD(learning_rate=0.01)`.
  """
  local_model = model_fn()
  sgd = tf.keras.optimizers.SGD(learning_rate=0.01)
  return client_update(local_model, tf_dataset, server_weights, sgd)


tff.tf_computation for the server update

In [23]:
@tff.tf_computation(model_weights_type)
def server_update_fn(mean_client_weights):
  """Run `server_update` on a fresh `model_fn()` model.

  Args:
    mean_client_weights: Value of type `model_weights_type`.

  Returns:
    The result of `server_update` for the new model and the given weights.
  """
  return server_update(model_fn(), mean_client_weights)


Last, but not least, we need to create the tff.federated_computation that brings this all together. 

In [24]:
# Model weights placed at the server.
federated_server_type = tff.FederatedType(model_weights_type, tff.SERVER)
# Dataset (sequence) values placed at the clients.
federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)


In [25]:
@tff.federated_computation(federated_server_type, federated_dataset_type)
def next_fn(server_weights, federated_dataset):
  """Run one round: broadcast, local training, averaging, server update.

  Args:
    server_weights: Value of type `federated_server_type`.
    federated_dataset: Value of type `federated_dataset_type`.

  Returns:
    The result of mapping `server_update_fn` over the federated mean of the
    clients' updated weights.
  """
  # Server -> clients: hand the current server weights to every client.
  broadcast_weights = tff.federated_broadcast(server_weights)

  # Clients: run the local update on (dataset, broadcast weights).
  client_inputs = (federated_dataset, broadcast_weights)
  trained_client_weights = tff.federated_map(client_update_fn, client_inputs)

  # Clients -> server: average the locally trained weights.
  averaged_weights = tff.federated_mean(trained_client_weights)

  # Server: install the averaged weights as the new server weights.
  return tff.federated_map(server_update_fn, averaged_weights)


We now have a tff.federated_computation for both the algorithm initialization, and for running one step of the algorithm. To finish our algorithm, we pass these into tff.templates.IterativeProcess.

In [26]:
# Bundle the initialize and next computations into one iterative process.
federated_algorithm = tff.templates.IterativeProcess(
    initialize_fn=initialize_fn,
    next_fn=next_fn
)
# Final expression of the cell: display the type signature of `initialize`.
str(federated_algorithm.initialize.type_signature)



'( -> <float32[784,10],float32[10]>@SERVER)'

## Algorithm evaluation

We first create a centralized evaluation dataset, and then apply the same preprocessing we used for the training data.

In [27]:
# Pool the test examples of all clients into one dataset and run it through
# the same `preprocess` step that the training data went through.
central_emnist_test = preprocess(
    emnist_test.create_tf_dataset_from_all_clients())


Write a function that accepts a server state, and uses Keras to evaluate on the test dataset.

In [37]:
def evaluate(server_state):
  """Evaluate the weights in `server_state` on the centralized test dataset.

  A new Keras model is built, compiled with sparse categorical cross-entropy
  loss and sparse categorical accuracy, loaded with `server_state` through
  `set_weights`, and evaluated on `central_emnist_test`.

  Args:
    server_state: Weights accepted by the Keras model's `set_weights`.

  Returns:
    None; the result of the Keras `evaluate` call is not returned.
  """
  eval_model = create_keras_model()
  loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()
  eval_metrics = [tf.keras.metrics.SparseCategoricalAccuracy()]
  eval_model.compile(loss=loss_fn, metrics=eval_metrics)
  eval_model.set_weights(server_state)
  eval_model.evaluate(central_emnist_test)


Initialize our algorithm and evaluate on the test set.

In [38]:
# Create the initial server state and evaluate it on the centralized test set
# before any training round has run.
server_state = federated_algorithm.initialize()
evaluate(server_state)



In [39]:
# Number of federated training rounds to run before the final evaluation.
NUM_ROUNDS = 15

# The loop variable is named `round_num` (not `round`) so the builtin `round`
# is not shadowed in the kernel namespace for later cells.
for round_num in range(NUM_ROUNDS):
  # Each round feeds the previous server state and the client datasets into
  # the iterative process and keeps the returned state.
  server_state = federated_algorithm.next(server_state, federated_train_data)

# Evaluate the trained weights on the centralized test set.
evaluate(server_state)


2022-05-24 04:13:24.829943: W tensorflow/core/data/root_dataset.cc:200] Optimization loop failed: CANCELLED: Operation was cancelled




(Add gradient clipping and learning rate decay on the clients; server stores it)