# Testing the TensorFlow Federated (TFF) library

## Test #1-2: MNIST classification with the `BaseFLModel` interface

In [1]:
import os
import collections
import nest_asyncio
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow_federated as tff

# Restrict TensorFlow to the first physical GPU. When no GPU is detected the list is
# empty: keep the default device configuration instead of failing with an IndexError.
physical_gpus = tf.config.list_physical_devices('GPU')
if physical_gpus:
    tf.config.set_visible_devices(physical_gpus[:1], 'GPU')

# Allow nested event loops, so that TFF computations can be executed from within
# the event loop that the notebook kernel is already running.
nest_asyncio.apply()

print('Tensorflow version : {}'.format(tf.__version__))
print('Tensorflow-federated version : {}'.format(tff.__version__))
print('# GPUs : {}'.format(len(tf.config.list_logical_devices('GPU'))))

# Smoke test of the TFF runtime : the cell output should be b'Hello, World!'
tff.federated_computation(lambda: 'Hello, World!')()

2022-11-23 13:09:54.970174: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-11-23 13:09:55.067456: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2022-11-23 13:09:55.091681: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Tensorflow version : 2.10.0
Tensorflow-federated version : 0.39.0
# GPUs : 1


2022-11-23 13:10:06.171857: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-11-23 13:10:06.551768: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1616] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 14783 MB memory:  -> device: 0, name: Quadro RTX 5000, pci bus id: 0000:17:00.0, compute capability: 7.5


b'Hello, World!'

In [2]:
import collections
import tensorflow as tf

from models.interfaces.base_fl_model import BaseFLModel
from models.interfaces.base_image_model import BaseImageModel

class MNISTFLClassifier(BaseImageModel, BaseFLModel):
    """
        Image classifier combining the `BaseImageModel` and `BaseFLModel` interfaces.
        
        The architecture requested in `_build_model` is a `perceptron` with a single
        32-unit `relu` dense layer, followed by a `softmax` output of `n_labels` units.
    """
    def __init__(self, input_size = (28, 28, 1), n_labels = 10, ** kwargs):
        """
            Arguments :
                - input_size : image shape, forwarded to `_init_image`
                - n_labels   : number of classes (size of the output layer)
                - kwargs     : forwarded to `_init_image`, `_init_fl` and the parent constructor
        """
        self._init_image(input_size = input_size, ** kwargs)
        self._init_fl(** kwargs)
        
        # Assigned before calling the parent constructor
        # NOTE(review): presumably because the parent constructor triggers `_build_model`,
        # which reads `self.n_labels` - confirm against the base class
        self.n_labels = n_labels
        
        super().__init__(** kwargs)
    
    def _build_model(self):
        """ Passes the perceptron configuration to the parent `_build_model` """
        super()._build_model(model = {
            'architecture_name' : 'perceptron',
            'input_shape' : self.input_size,
            'units'       : 32,
            'n_dense'     : 1,
            'activation'  : 'relu',
            'bnorm'       : 'never',
            'output_shape' : self.n_labels,
            'final_bias'   : True,
            'final_activation' : 'softmax'
        })
    
    @property
    def output_signature(self):
        """ Signature of the targets : `int32` tensor of shape `(None, 1)` (one label id per sample) """
        return tf.TensorSpec(shape = (None, 1), dtype = tf.int32)
    
    def __str__(self):
        """ Parent description, followed by the image / FL descriptions and the number of labels """
        des = super().__str__()
        des += self._str_image()
        des += self._str_fl()
        des += '- # labels : {}\n'.format(self.n_labels)
        return des

    def compile(self, loss = 'sparse_categorical_crossentropy', metrics = ['sparse_categorical_accuracy'], ** kwargs):
        """
            Compiles the model with sparse categorical defaults
            
            Arguments :
                - loss    : loss forwarded to the parent `compile`
                - metrics : metrics forwarded to the parent `compile`
                - kwargs  : forwarded to the parent `compile`
        """
        # The default `metrics` list is created once, at definition time, and is therefore shared
        # by every call : forward a copy so that nothing downstream can alter the default.
        if isinstance(metrics, list):
            metrics = list(metrics)
        super().compile(loss = loss, metrics = metrics, ** kwargs)
    
    def preprocess_data(self, data):
        """
            Converts a data structure with `pixels` and `label` entries to an `(inputs, targets)` pair
            
            Arguments :
                - data : mapping with the `pixels` and `label` entries
            Return :
                - (image, label) :
                    - image : `data['pixels']` with an additional trailing axis
                    - label : `data['label']` reshaped to `[-1, 1]` and cast to `tf.int32`
        """
        return (
            tf.expand_dims(data['pixels'], axis = -1),
            tf.cast(tf.reshape(data['label'], [-1, 1]), tf.int32)
        )
    
    def get_dataset_config(self, * args, ** kwargs):
        """ Parent dataset configuration with `batch_before_map` always set to `True` """
        # Overrides any value given by the caller
        # NOTE(review): presumably required because `preprocess_data` reshapes the labels
        # to `[-1, 1]`, i.e. it expects already-batched data - confirm
        kwargs['batch_before_map'] = True
        return super().get_dataset_config(* args, ** kwargs)
    
    def get_config(self, * args, ** kwargs):
        """ Parent configuration extended with the image / FL configurations and `n_labels` """
        config = super().get_config(* args, ** kwargs)
        config.update({
            ** self.get_config_image(),
            ** self.get_config_fl(),
            'n_labels' : self.n_labels
        })
        return config
    
    
# Instantiate the classifier with its default arguments (28x28x1 inputs, 10 labels)
model = MNISTFLClassifier()
# `print` goes through `MNISTFLClassifier.__str__`
print(model)
model.summary()


Sub model model
- Inputs 	: (None, 28, 28, 1)
- Outputs 	: (None, 10)
- Number of layers 	: 6
- Number of parameters 	: 0.025 Millions
- Model not compiled

Already trained on 0 epochs (0 steps)

- Image size : (28, 28, 1)
- Normalization style : None
- # labels : 10



Model: "multi_layer_perceptron"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 flatten (Flatten)           (None, 784)               0         
                                                                 
 dense_0 (Dense)             (None, 32)                25120     
                                                                 
 activation (Activation)     (None, 32)                0         
                                                                 
 dropout (Dropout)           (None, 32)                0         
                                                                 
 classification_layer (Dense  (None,

In [3]:
# Load the federated EMNIST splits provided by the TFF simulation datasets
emnist_train, emnist_valid = tff.simulation.datasets.emnist.load_data()

# The reported "lengths" are numbers of client ids, not numbers of examples
print('Dataset length :\n  Train length : {}\n  Valid length : {}'.format(
    * map(len, (emnist_train.client_ids, emnist_valid.client_ids))
))
# Structure of a single element of a client dataset
print('Data signature : {}'.format(emnist_train.element_type_structure))

Dataset length :
  Train length : 3383
  Valid length : 3383
Data signature : OrderedDict([('label', TensorSpec(shape=(), dtype=tf.int32, name=None)), ('pixels', TensorSpec(shape=(28, 28), dtype=tf.float32, name=None))])


## Initialization

In [4]:
# Compile with the class defaults (sparse categorical cross-entropy loss / accuracy metric)
model.compile()

# Build the federated training configuration. The returned dict is read in the next cells
# through the `model_fn`, `server_optimizer_fn`, `client_optimizer_fn`, `x` and
# `validation_data` keys.
config = model._get_fl_train_config(
    emnist_train,
    validation_data = emnist_valid,
    
    # NOTE(review): presumably the number of clients selected in each split - confirm in `BaseFLModel`
    n_train_clients = 50,
    n_valid_clients = 25,
    
    batch_size = 64
)

# Displayed as the cell output (the whole dict, including the selected client ids)
config

{'model_fn': <function models.interfaces.base_fl_model.BaseFLModel.get_model_fn.<locals>.vanilla_model_fn()>,
 'loss_fn': <function models.interfaces.base_fl_model.BaseFLModel.get_loss_fn.<locals>.<lambda>()>,
 'metrics_fn': <function models.interfaces.base_fl_model.BaseFLModel.get_metrics_fn.<locals>.<lambda>()>,
 'server_optimizer_fn': <function models.interfaces.base_fl_model.BaseFLModel.get_optimizer_fn.<locals>.<lambda>()>,
 'client_optimizer_fn': <function models.interfaces.base_fl_model.BaseFLModel.get_optimizer_fn.<locals>.<lambda>()>,
 'reconstruction_optimizer_fn': <function models.interfaces.base_fl_model.BaseFLModel.get_optimizer_fn.<locals>.<lambda>()>,
 'train_ids': ['f1491_42',
  'f3914_30',
  'f0724_37',
  'f1546_05',
  'f3912_15',
  'f1382_07',
  'f2225_86',
  'f0932_44',
  'f1728_02',
  'f2165_54',
  'f2258_85',
  'f1234_26',
  'f1299_42',
  'f2520_64',
  'f3255_43',
  'f0072_36',
  'f3267_41',
  'f3135_37',
  'f1760_27',
  'f0553_03',
  'f0402_38',
  'f3159_16',
  'f

In [5]:
# Training / validation inputs stored in the training configuration
train_fed_data, valid_fed_data = config['x'], config['validation_data']

# Weighted Federated Averaging process, built from the model / optimizer factories of the configuration
iterative_process = tff.learning.algorithms.build_weighted_fed_avg(
    config['model_fn'],
    ** {name : config[name] for name in ('server_optimizer_fn', 'client_optimizer_fn')}
)

# Show the structure of the server state returned by `initialize`
print(iterative_process.initialize.type_signature.formatted_representation())

( -> <
  global_model_weights=<
    trainable=<
      float32[784,32],
      float32[32],
      float32[32,10],
      float32[10]
    >,
    non_trainable=<>
  >,
  distributor=<>,
  client_work=<>,
  aggregator=<
    value_sum_process=<>,
    weight_sum_process=<>
  >,
  finalizer=<
    int64,
    float32[784,32],
    float32[32],
    float32[32,10],
    float32[10],
    float32[784,32],
    float32[32],
    float32[32,10],
    float32[10]
  >
>@SERVER)


## Training

In [6]:
# Create the initial state of the iterative process (its type signature is printed by the previous cell)
state = iterative_process.initialize()

In [7]:
# Run one round on the training data : `next` returns a result holding the updated state and the round metrics
result = iterative_process.next(state, train_fed_data)
# `state` is overwritten : re-running this cell performs an additional round instead of repeating the first one
state, metrics = result.state, result.metrics
print('Round 1, metrics : {}'.format(metrics))

Round 1, metrics : OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.09654094), ('loss', 2.7454767), ('num_examples', 5117), ('num_batches', 98)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])


In [9]:
# Total number of rounds, including the round already executed by the 'Round 1' cell
# NOTE(review): the execution counts jump from [7] to [9] - check that the notebook still
# reproduces these results with "Restart Kernel -> Run All"
epochs = 25
# Numbering starts at 2 because round 1 was executed by the previous cell. Re-running this
# cell continues from the current `state` while the printed numbering restarts at 2.
for epoch in range(2, epochs + 1):
    result = iterative_process.next(state, train_fed_data)
    state, metrics = result.state, result.metrics
    # Only the client-side training metrics are reported
    print('Epoch {} : {}'.format(epoch, metrics['client_work']['train']))


Epoch 2 : OrderedDict([('sparse_categorical_accuracy', 0.11842877), ('loss', 2.292113), ('num_examples', 5117), ('num_batches', 98)])
Epoch 3 : OrderedDict([('sparse_categorical_accuracy', 0.12096932), ('loss', 2.2899954), ('num_examples', 5117), ('num_batches', 98)])
Epoch 4 : OrderedDict([('sparse_categorical_accuracy', 0.124682434), ('loss', 2.2850769), ('num_examples', 5117), ('num_batches', 98)])
Epoch 5 : OrderedDict([('sparse_categorical_accuracy', 0.12038304), ('loss', 2.2873216), ('num_examples', 5117), ('num_batches', 98)])
Epoch 6 : OrderedDict([('sparse_categorical_accuracy', 0.14012116), ('loss', 2.2858226), ('num_examples', 5117), ('num_batches', 98)])
Epoch 7 : OrderedDict([('sparse_categorical_accuracy', 0.1397303), ('loss', 2.282901), ('num_examples', 5117), ('num_batches', 98)])


2022-11-23 13:11:24.700261: W tensorflow/core/data/root_dataset.cc:266] Optimization loop failed: CANCELLED: Operation was cancelled


Epoch 8 : OrderedDict([('sparse_categorical_accuracy', 0.1475474), ('loss', 2.2817528), ('num_examples', 5117), ('num_batches', 98)])
Epoch 9 : OrderedDict([('sparse_categorical_accuracy', 0.13875318), ('loss', 2.2813594), ('num_examples', 5117), ('num_batches', 98)])
Epoch 10 : OrderedDict([('sparse_categorical_accuracy', 0.12526871), ('loss', 2.2836208), ('num_examples', 5117), ('num_batches', 98)])
Epoch 11 : OrderedDict([('sparse_categorical_accuracy', 0.12038304), ('loss', 2.283988), ('num_examples', 5117), ('num_batches', 98)])
Epoch 12 : OrderedDict([('sparse_categorical_accuracy', 0.12820011), ('loss', 2.2795305), ('num_examples', 5117), ('num_batches', 98)])
Epoch 13 : OrderedDict([('sparse_categorical_accuracy', 0.14207543), ('loss', 2.280367), ('num_examples', 5117), ('num_batches', 98)])
Epoch 14 : OrderedDict([('sparse_categorical_accuracy', 0.1389486), ('loss', 2.2772293), ('num_examples', 5117), ('num_batches', 98)])
Epoch 15 : OrderedDict([('sparse_categorical_accuracy'