# Test Tensorflow-federated (TFF) library

## Test #3 : custom training procedure

In [1]:
import os
import collections
import nest_asyncio
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow_federated as tff

tf.config.set_visible_devices([tf.config.list_physical_devices('GPU')[0]], 'GPU')

nest_asyncio.apply()

print('Tensorflow version : {}'.format(tf.__version__))
print('Tensorflow-federated version : {}'.format(tff.__version__))
print('# GPUs : {}'.format(len(tf.config.list_logical_devices('GPU'))))

tff.federated_computation(lambda: 'Hello, World!')()

2022-11-15 11:10:46.155377: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-11-15 11:10:46.253961: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2022-11-15 11:10:46.278263: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Tensorflow version : 2.10.0
Tensorflow-federated version : 0.39.0
# GPUs : 1


2022-11-15 11:10:56.062093: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-11-15 11:10:56.464487: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1616] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 371 MB memory:  -> device: 0, name: Quadro RTX 5000, pci bus id: 0000:17:00.0, compute capability: 7.5


b'Hello, World!'

In [2]:
emnist_train, emnist_valid = tff.simulation.datasets.emnist.load_data()
print('Dataset length :\n  Train length : {}\n  Valid length : {}'.format(
    len(emnist_train.client_ids), len(emnist_valid.client_ids)
))
print('Data signature : {}'.format(emnist_train.element_type_structure))

Dataset length :
  Train length : 3383
  Valid length : 3383
Data signature : OrderedDict([('label', TensorSpec(shape=(), dtype=tf.int32, name=None)), ('pixels', TensorSpec(shape=(28, 28), dtype=tf.float32, name=None))])


In [3]:
def preprocess(dataset, epochs = 50, batch_size = 32, shuffle_size = 1024, prefetch_size = 16):
    def batch_format_fn(data):
        return (
            tf.reshape(data['pixels'], [-1, 28 * 28]),
            tf.reshape(data['label'], [-1, 1])
        )
    
    return dataset.repeat(epochs).shuffle(shuffle_size, seed = 1).batch(batch_size).map(batch_format_fn).prefetch(prefetch_size)

def make_federated_data(client_data, ids = None, n = None):
    if ids is None: ids = client_data.client_ids[:n]
    return [
        preprocess(client_data.create_tf_dataset_for_client(x)) for x in ids
    ]

train_fed_data = make_federated_data(emnist_train, n = 25)
valid_fed_data = make_federated_data(emnist_valid, n = 10)
valid_data     = preprocess(emnist_valid.create_tf_dataset_from_all_clients(), epochs = 1, batch_size = 256)

print('# datasets : train {} - valid {}'.format(len(train_fed_data), len(valid_fed_data)))

# datasets : train 25 - valid 10


## Custom training process

### `initialize()` function

In [4]:
def build_model():
    model = tf.keras.models.Sequential([
        tf.keras.layers.Input(shape = (28 * 28, ), dtype = tf.int32),
        tf.keras.layers.Dense(32, activation = 'relu'),
        tf.keras.layers.Dense(10, activation = 'softmax')
    ], name = 'simple_mlp')
    model.build((None, 28 * 28))
    return model

def build_fed_model():
    model = build_model()
    return tff.learning.from_keras_model(
        model,
        input_spec = train_fed_data[0].element_spec,
        loss       = tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics    = [tf.keras.metrics.SparseCategoricalAccuracy()]
    )

@tff.tf_computation
def server_init():
    model = build_fed_model()
    return model.trainable_variables

@tff.federated_computation
def initialize_fn():
    return tff.federated_value(server_init(), tff.SERVER)

whimsy_model = build_fed_model()
input_spec   = tff.SequenceType(whimsy_model.input_spec)
weights_spec = server_init.type_signature.result

print(input_spec)
print(weights_spec)

<float32[?,784],int32[?,1]>*
<float32[784,32],float32[32],float32[32,10],float32[10]>


### `next(server_state, federated_data)` function

In [5]:
@tf.function
def client_update(model, dataset, server_weights, client_optimizer):
    tf.nest.map_structure(
        lambda w, s_w: w.assign(s_w), model.trainable_variables, server_weights
    )
    
    for batch in dataset:
        with tf.GradientTape() as tape:
            outputs = model.forward_pass(batch)
        
        grads = tape.gradient(outputs.loss, model.trainable_variables)
        client_optimizer.apply_gradients(zip(grads, model.trainable_variables))
    
    return model.trainable_variables

@tf.function
def server_update(model, mean_client_weights):
    tf.nest.map_structure(
        lambda w, new_w: w.assign(new_w), model.trainable_variables, mean_client_weights
    )
    return model.trainable_variables


@tff.tf_computation(input_spec, weights_spec)
def client_update_fn(dataset, server_weights):
    model = build_fed_model()
    optimizer = tf.keras.optimizers.SGD(0.1)
    return client_update(model, dataset, server_weights, optimizer)

@tff.tf_computation(weights_spec)
def server_update_fn(mean_client_weights):
    model = build_fed_model()
    return server_update(model, mean_client_weights)


In [6]:
fed_server_type = tff.FederatedType(weights_spec, tff.SERVER)
fed_data_type   = tff.FederatedType(input_spec, tff.CLIENTS)

@tff.federated_computation(fed_server_type, fed_data_type)
def next_fn(server_state, federated_data):
    server_weights_at_client = tff.federated_broadcast(server_state)
    
    client_weights = tff.federated_map(
        client_update_fn, (federated_data, server_weights_at_client)
    )
    
    mean_client_weights = tff.federated_mean(client_weights)
    
    server_weights = tff.federated_map(
        server_update_fn, mean_client_weights
    )
    
    return server_weights

iterative_process = tff.templates.IterativeProcess(
    initialize_fn = initialize_fn,
    next_fn = next_fn
)
print(iterative_process.initialize.type_signature)
print(iterative_process.next.type_signature)

( -> <float32[784,32],float32[32],float32[32,10],float32[10]>@SERVER)
(<server_state=<float32[784,32],float32[32],float32[32,10],float32[10]>@SERVER,federated_data={<float32[?,784],int32[?,1]>*}@CLIENTS> -> <float32[784,32],float32[32],float32[32,10],float32[10]>@SERVER)


## Training and evaluation

In [7]:
def evaluate(server_state):
    model = build_model()
    model.compile(
        loss = tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics = [tf.keras.metrics.SparseCategoricalAccuracy()]
    )
    model.set_weights(server_state)
    return model.evaluate(valid_data)

In [8]:
state = iterative_process.initialize()
evaluate(state)



[2.9456522464752197, 0.09896649420261383]

In [9]:
state = iterative_process.next(state, train_fed_data)
evaluate(state)



[2.3014912605285645, 0.11224039644002914]

In [10]:
epochs = 10
for epoch in range(2, epochs + 1):
    print('Epoch {} / {}'.format(epoch, epochs + 1))
    state = iterative_process.next(state, train_fed_data)
evaluate(state)



[2.3019518852233887, 0.11224039644002914]