# Test the TensorFlow Federated (TFF) library

## Test #2-2 : MNIST-like digit classification (federated EMNIST digits) with the `BaseFLModel` interface + local (per-client) layers

In [1]:
import os
import collections
import nest_asyncio
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow_federated as tff

# Pin TensorFlow to the first GPU when at least one is available. On a CPU-only machine
# `list_physical_devices('GPU')` returns an empty list, and the previous unconditional
# `[0]` indexing raised IndexError before anything else could run.
# NOTE : `set_visible_devices` must be called before the GPUs are initialized.
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    tf.config.set_visible_devices(gpus[:1], 'GPU')

# Patch asyncio so nested event loops are allowed (TFF computations are invoked from
# inside the notebook kernel's already-running loop).
nest_asyncio.apply()

print('Tensorflow version : {}'.format(tf.__version__))
print('Tensorflow-federated version : {}'.format(tff.__version__))
print('# GPUs : {}'.format(len(tf.config.list_logical_devices('GPU'))))

# Smoke test : the TFF runtime can execute a trivial federated computation
tff.federated_computation(lambda: 'Hello, World!')()

2022-11-21 13:42:08.364097: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-11-21 13:42:08.459099: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2022-11-21 13:42:08.484084: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Tensorflow version : 2.10.0
Tensorflow-federated version : 0.39.0
# GPUs : 1


2022-11-21 13:42:15.632412: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-11-21 13:42:16.012069: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1616] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 14783 MB memory:  -> device: 0, name: Quadro RTX 5000, pci bus id: 0000:17:00.0, compute capability: 7.5


b'Hello, World!'

In [2]:
import collections
import tensorflow as tf

from models.interfaces.base_fl_model import BaseFLModel
from models.interfaces.base_image_model import BaseImageModel

class MNISTFLClassifier(BaseImageModel, BaseFLModel):
    """Small dense (perceptron) digit classifier for federated training.

    Combines `BaseImageModel` (image-input handling) with `BaseFLModel` (federated-learning
    helpers). The Keras model is split into *global* layers (`global_layers`, every layer but
    the last two) and *local* layers (`local_layers`, the last two layers).

    Arguments :
        - input_size : shape of one input image (default `(28, 28, 1)`)
        - n_labels   : number of output classes (default 10)
        - kwargs     : forwarded to `_init_image`, `_init_fl` and the base-class constructor
    """
    def __init__(self, input_size = (28, 28, 1), n_labels = 10, ** kwargs):
        self._init_image(input_size = input_size, ** kwargs)
        self._init_fl(** kwargs)
        
        # Set before `super().__init__` because `_build_model` reads `self.n_labels`
        # (presumably the base constructor builds the model — TODO confirm)
        self.n_labels = n_labels
        
        super().__init__(** kwargs)
    
    def _build_model(self):
        """Builds a perceptron : 1 hidden Dense layer (32 units, ReLU, no batch-norm)
        followed by a biased output layer with a softmax over `n_labels` classes."""
        super()._build_model(model = {
            'architecture_name' : 'perceptron',
            'input_shape' : self.input_size,
            'units'       : 32,
            'n_dense'     : 1,
            'activation'  : 'relu',
            'bnorm'       : 'never',
            'output_shape' : self.n_labels,
            'final_bias'   : True,
            'final_activation' : 'softmax'
        })
    
    @property
    def output_signature(self):
        """Labels are int32 column vectors of shape `(batch, 1)` (see `preprocess_data`)."""
        return tf.TensorSpec(shape = (None, 1), dtype = tf.int32)
    
    @property
    def local_layers(self):
        """Client-local layers : the last two layers of the model.

        With the perceptron from `_build_model`, these are the final Dense layer and its
        softmax activation. Their weights do not appear in the server model state.
        """
        return self.layers[-2:]
    
    @property
    def global_layers(self):
        """Layers shared through the server : every layer except the last two."""
        return self.layers[:-2]
    
    def __str__(self):
        """Base description extended with image, FL and label-count information."""
        des = super().__str__()
        des += self._str_image()
        des += self._str_fl()
        des += '- # labels : {}\n'.format(self.n_labels)
        return des

    def compile(self, loss = 'sparse_categorical_crossentropy', metrics = ('sparse_categorical_accuracy', ), ** kwargs):
        """Compiles the model with a sparse categorical loss / accuracy by default.

        Arguments :
            - loss    : loss identifier forwarded to the base `compile`
            - metrics : list / tuple of metrics (or any value accepted by the base `compile`)
            - kwargs  : forwarded to the base `compile`
        """
        # The default used to be a mutable list shared by every call. Always hand the base
        # class a fresh list, so that in-place changes it may make cannot leak into later calls.
        if isinstance(metrics, (list, tuple)): metrics = list(metrics)
        super().compile(loss = loss, metrics = metrics, ** kwargs)
    
    def preprocess_data(self, data):
        """Maps a (batched) dataset element to `(images, labels)`.

        `get_dataset_config` forces `batch_before_map`, so `data` is presumably a batched dict
        with `'pixels'` of shape `(batch, 28, 28)` and `'label'` of shape `(batch, )`.
        Images get a trailing channel axis; labels are reshaped to int32 `(batch, 1)`.
        """
        return (
            tf.expand_dims(data['pixels'], axis = -1),
            tf.cast(tf.reshape(data['label'], [-1, 1]), tf.int32)
        )
    
    def get_dataset_config(self, * args, ** kwargs):
        """Base dataset config, with batching forced to happen before `preprocess_data`."""
        kwargs['batch_before_map'] = True
        return super().get_dataset_config(* args, ** kwargs)
    
    def get_config(self, * args, ** kwargs):
        """Serializable config : base config + image / FL settings + `n_labels`."""
        config = super().get_config(* args, ** kwargs)
        config.update({
            ** self.get_config_image(),
            ** self.get_config_fl(),
            'n_labels' : self.n_labels
        })
        return config
    
    
model = MNISTFLClassifier()
# Show the Keras layer table first, then the project-level description
# (image settings, FL settings and number of labels).
model.summary()
print(str(model))



Model: "multi_layer_perceptron"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 flatten (Flatten)           (None, 784)               0         
                                                                 
 dense_0 (Dense)             (None, 32)                25120     
                                                                 
 activation (Activation)     (None, 32)                0         
                                                                 
 dropout (Dropout)           (None, 32)                0         
                                                                 
 classification_layer (Dense  (None, 10)               330       
 )                                                               
                                                                 
 activation_1 (Activation)   (None, 10)                0         
                                          

In [3]:
# Federated EMNIST : one client per writer, identical client ids in both splits
emnist_train, emnist_valid = tff.simulation.datasets.emnist.load_data()

print('Dataset length :\n  Train length : {}\n  Valid length : {}'.format(
    * [len(split.client_ids) for split in (emnist_train, emnist_valid)]
))
print('Data signature : {}'.format(emnist_train.element_type_structure))

Dataset length :
  Train length : 3383
  Valid length : 3383
Data signature : OrderedDict([('label', TensorSpec(shape=(), dtype=tf.int32, name=None)), ('pixels', TensorSpec(shape=(28, 28), dtype=tf.float32, name=None))])


## Initialization

In [4]:
# Default loss / metrics : sparse categorical cross-entropy + sparse categorical accuracy
model.compile()

# Federated training configuration : process factories (model / loss / metrics / optimizers)
# together with the federated train / validation datasets built from sampled clients
config = model._get_fl_train_config(
    emnist_train,
    validation_data = emnist_valid,
    batch_size      = 64,
    n_train_clients = 25,
    n_valid_clients = 10
)
config

{'model_fn': functools.partial(<function BaseFLModel.get_model_fn.<locals>.reconstruction_model_fn at 0x7f8dda9dc5e0>),
 'loss_fn': <function models.interfaces.base_fl_model.BaseFLModel.get_loss_fn.<locals>.<lambda>()>,
 'metrics_fn': <function models.interfaces.base_fl_model.BaseFLModel.get_metrics_fn.<locals>.<lambda>()>,
 'server_optimizer_fn': <function models.interfaces.base_fl_model.BaseFLModel.get_optimizer_fn.<locals>.<lambda>()>,
 'client_optimizer_fn': <function models.interfaces.base_fl_model.BaseFLModel.get_optimizer_fn.<locals>.<lambda>()>,
 'reconstruction_optimizer_fn': <function models.interfaces.base_fl_model.BaseFLModel.get_optimizer_fn.<locals>.<lambda>()>,
 'train_ids': ['f1491_42',
  'f3914_30',
  'f0724_37',
  'f1546_05',
  'f3912_15',
  'f1382_07',
  'f2225_86',
  'f0932_44',
  'f1728_02',
  'f2165_54',
  'f2258_85',
  'f1234_26',
  'f1299_42',
  'f2520_64',
  'f3255_43',
  'f0072_36',
  'f3267_41',
  'f3135_37',
  'f1760_27',
  'f0553_03',
  'f0402_38',
  'f3159

In [5]:
train_fed_data, valid_fed_data = config['x'], config['validation_data']

# Training and evaluation share the model / loss / metrics / reconstruction-optimizer
# factories; only training additionally needs the server and client optimizers.
training_process = tff.learning.reconstruction.build_training_process(
    server_optimizer_fn = config['server_optimizer_fn'],
    client_optimizer_fn = config['client_optimizer_fn'],
    ** {key : config[key] for key in ('model_fn', 'loss_fn', 'metrics_fn', 'reconstruction_optimizer_fn')}
)

evaluation_process = tff.learning.reconstruction.build_federated_evaluation(
    ** {key : config[key] for key in ('model_fn', 'loss_fn', 'metrics_fn', 'reconstruction_optimizer_fn')}
)

# Server state layout : only the global layers' weights appear under `model`
print(training_process.initialize.type_signature.formatted_representation())

( -> <
  model=<
    trainable=<
      float32[784,32],
      float32[32]
    >,
    non_trainable=<>
  >,
  optimizer_state=<
    int64,
    float32[784,32],
    float32[32],
    float32[784,32],
    float32[32]
  >,
  delta_aggregate_state=<
    value_sum_process=<>,
    weight_sum_process=<>
  >,
  model_broadcast_state=<>
>@SERVER)


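## Training

Each call to `training_process.next` runs one federated round. Note: in the recorded run below, training accuracy stays around 0.1, which is chance level for 10 classes. Meanwhile the loss decreases toward $\ln(10) \approx 2.30$, so the model only learns a near-uniform prediction. **TODO**: investigate, e.g. the optimizer learning rates or the reconstruction settings.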

In [6]:
# Initial server state : global model weights, server-optimizer state, aggregation and
# broadcast states (structure given by the printed `initialize` type signature)
state = training_process.initialize()

In [7]:
# Run the first federated round on its own so the full metrics structure
# (broadcast / aggregation / train) can be inspected once
state, metrics = training_process.next(state, train_fed_data)
print('Round 1, metrics : {}'.format(metrics))

2022-11-21 13:42:25.567437: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2022-11-21 13:42:25.567502: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2022-11-21 13:42:25.614119: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2022-11-21 13:42:25.614146: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2022-11-21 13:42:25.686341: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2022-11-21 13:42:25.687404: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2022-11-21 13:42:25.799977: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2022-11-21 13:42:25.800010: I tensorflow/

Round 1, metrics : OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.10477002), ('loss', 3.426327)]))])


In [8]:
# Each `training_process.next` call is one federated round, not an epoch over the data.
# `epochs` is the total number of rounds, including round 1, which was run in the previous
# cell. Numbering therefore starts at 2 and assumes that cell was executed exactly once.
epochs = 10
for epoch in range(2, epochs + 1):
    state, metrics = training_process.next(state, train_fed_data)
    # Label rounds consistently with the "Round 1" output of the previous cell
    print('Round {} : {}'.format(epoch, metrics['train']))


2022-11-21 13:42:27.894270: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2022-11-21 13:42:27.894293: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2022-11-21 13:42:27.894544: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2022-11-21 13:42:27.894560: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2022-11-21 13:42:27.895150: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2022-11-21 13:42:27.895164: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2022-11-21 13:42:27.895247: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2022-11-21 13:42:27.895301: I tensorflow/

Epoch 2 : OrderedDict([('sparse_categorical_accuracy', 0.11968504), ('loss', 2.8183522)])


2022-11-21 13:42:28.581564: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2022-11-21 13:42:28.581590: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2022-11-21 13:42:28.581736: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2022-11-21 13:42:28.581753: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2022-11-21 13:42:28.582019: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2022-11-21 13:42:28.582036: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2022-11-21 13:42:28.582397: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2022-11-21 13:42:28.582410: I tensorflow/

Epoch 3 : OrderedDict([('sparse_categorical_accuracy', 0.10997643), ('loss', 2.6411107)])


2022-11-21 13:42:29.283453: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2022-11-21 13:42:29.283479: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2022-11-21 13:42:29.283497: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2022-11-21 13:42:29.283514: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2022-11-21 13:42:29.283778: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2022-11-21 13:42:29.283804: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2022-11-21 13:42:29.284034: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2022-11-21 13:42:29.284048: I tensorflow/

Epoch 4 : OrderedDict([('sparse_categorical_accuracy', 0.093296476), ('loss', 2.6238232)])


2022-11-21 13:42:29.976071: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2022-11-21 13:42:29.976096: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2022-11-21 13:42:29.976221: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2022-11-21 13:42:29.976240: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2022-11-21 13:42:29.976591: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2022-11-21 13:42:29.976608: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2022-11-21 13:42:29.976774: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2022-11-21 13:42:29.976786: I tensorflow/

Epoch 5 : OrderedDict([('sparse_categorical_accuracy', 0.095348835), ('loss', 2.4971082)])


2022-11-21 13:42:30.935881: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2022-11-21 13:42:30.935918: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2022-11-21 13:42:30.936109: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2022-11-21 13:42:30.936134: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2022-11-21 13:42:30.936440: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2022-11-21 13:42:30.936492: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2022-11-21 13:42:30.936748: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2022-11-21 13:42:30.936788: I tensorflow/

Epoch 6 : OrderedDict([('sparse_categorical_accuracy', 0.0992647), ('loss', 2.4235313)])


2022-11-21 13:42:31.618189: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2022-11-21 13:42:31.618221: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2022-11-21 13:42:31.618245: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2022-11-21 13:42:31.618257: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2022-11-21 13:42:31.618455: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2022-11-21 13:42:31.618487: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2022-11-21 13:42:31.618883: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2022-11-21 13:42:31.618917: I tensorflow/

Epoch 7 : OrderedDict([('sparse_categorical_accuracy', 0.08245756), ('loss', 2.4012825)])


2022-11-21 13:42:32.313102: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2022-11-21 13:42:32.313129: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2022-11-21 13:42:32.313546: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2022-11-21 13:42:32.313561: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2022-11-21 13:42:32.313842: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2022-11-21 13:42:32.313898: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2022-11-21 13:42:32.314230: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2022-11-21 13:42:32.314246: I tensorflow/

Epoch 8 : OrderedDict([('sparse_categorical_accuracy', 0.086108856), ('loss', 2.3372757)])


2022-11-21 13:42:33.011761: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2022-11-21 13:42:33.011791: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2022-11-21 13:42:33.012015: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2022-11-21 13:42:33.012070: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2022-11-21 13:42:33.012087: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2022-11-21 13:42:33.012132: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2022-11-21 13:42:33.012378: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2022-11-21 13:42:33.012420: I tensorflow/

Epoch 9 : OrderedDict([('sparse_categorical_accuracy', 0.108063176), ('loss', 2.332305)])


2022-11-21 13:42:33.690957: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2022-11-21 13:42:33.690982: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2022-11-21 13:42:33.691171: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2022-11-21 13:42:33.691206: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2022-11-21 13:42:33.691380: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2022-11-21 13:42:33.691406: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2022-11-21 13:42:33.691797: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2022-11-21 13:42:33.691811: I tensorflow/

Epoch 10 : OrderedDict([('sparse_categorical_accuracy', 0.12409347), ('loss', 2.3102055)])
