# Testing the TensorFlow Federated (TFF) library

## Test #6 : complete training and (intra-client) evaluation

In [None]:
import os
import collections
import nest_asyncio
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow_federated as tff

# Restrict TensorFlow to the first physical GPU, when there is one. The device
# list is checked first because indexing it with `[0]` raises `IndexError` on a
# CPU-only host, which would abort the whole setup cell.
physical_gpus = tf.config.list_physical_devices('GPU')
if physical_gpus:
    try:
        tf.config.set_visible_devices(physical_gpus[:1], 'GPU')
    except RuntimeError as e:
        # TensorFlow refuses to change the visible devices once its runtime is
        # initialized (e.g. when this cell is re-run in a live kernel); the
        # previous device selection then simply stays in effect.
        print('Unable to change the visible GPUs : {}'.format(e))

# Patch asyncio so that an event loop can be re-entered from within the
# notebook's already-running loop.
nest_asyncio.apply()

print('Tensorflow version : {}'.format(tf.__version__))
print('Tensorflow-federated version : {}'.format(tff.__version__))
print('# GPUs : {}'.format(len(tf.config.list_logical_devices('GPU'))))

# Smoke test: build and immediately run a trivial federated computation; as the
# last expression of the cell, its result is displayed as the cell output.
tff.federated_computation(lambda: 'Hello, World!')()

2022-11-23 13:51:57.796278: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-11-23 13:51:57.890471: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2022-11-23 13:51:57.915200: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [None]:
import collections
import tensorflow as tf

from models.interfaces.base_fl_model import BaseFLModel
from models.interfaces.base_image_model import BaseImageModel

class MNISTFLClassifier(BaseImageModel, BaseFLModel):
    """
        Single-hidden-layer perceptron classifier with a softmax output, combining the
        project's image-model and federated-learning (FL) model interfaces.

        The image / FL specific behaviour (`_init_image`, `_init_fl`, `_str_*`,
        `get_config_*`, ...) comes from `BaseImageModel` / `BaseFLModel`, which are
        defined outside of this notebook.
    """
    def __init__(self, input_size = (28, 28, 1), n_labels = 10, ** kwargs):
        """
            Arguments :
                - input_size : image shape, forwarded to `_init_image`
                - n_labels   : number of classes (size of the softmax output)
                - kwargs     : forwarded to `_init_image`, `_init_fl` and the parent constructor
        """
        self._init_image(input_size = input_size, ** kwargs)
        self._init_fl(** kwargs)
        
        # Assigned before calling the parent constructor: `_build_model` reads it,
        # and the parent constructor presumably triggers the build (TODO confirm).
        self.n_labels = n_labels
        
        super().__init__(** kwargs)
    
    def _build_model(self):
        """ Delegates the build to the parent class with a fixed `perceptron` configuration """
        super()._build_model(model = {
            'architecture_name' : 'perceptron',
            'input_shape' : self.input_size,
            'units'       : 32,
            'n_dense'     : 1,
            'activation'  : 'relu',
            'drop_rate'   : 0.,
            'output_shape' : self.n_labels,
            'final_bias'   : True,
            'final_activation' : 'softmax'
        })
    
    @property
    def output_signature(self):
        """ Target signature : `int32` labels of shape `(None, 1)`, as produced by `preprocess_data` """
        return tf.TensorSpec(shape = (None, 1), dtype = tf.int32)
    
    def __str__(self):
        """ Parent description, followed by the image / FL descriptions and the number of labels """
        des = super().__str__()
        des += self._str_image()
        des += self._str_fl()
        des += '- # labels : {}\n'.format(self.n_labels)
        return des

    def compile(self, loss = 'sparse_categorical_crossentropy', metrics = ['sparse_categorical_accuracy'], ** kwargs):
        """
            Compiles the model with sparse categorical cross-entropy / accuracy by default.

            Arguments :
                - loss    : forwarded to the parent `compile`
                - metrics : forwarded to the parent `compile`
                - kwargs  : forwarded to the parent `compile`
        """
        # The default `metrics` list is a single object shared by every call. A copy
        # is forwarded so that any in-place modification made downstream can never
        # leak into the default (or into a list owned by the caller).
        if isinstance(metrics, list): metrics = list(metrics)
        super().compile(loss = loss, metrics = metrics, ** kwargs)
    
    def preprocess_data(self, data):
        """
            Converts a dataset element into an `(inputs, targets)` tuple.

            Arguments :
                - data : mapping with (at least) the `pixels` and `label` entries
            Return :
                - (pixels, labels) :
                    - pixels : `data['pixels']` with an additional trailing axis
                    - labels : `data['label']` reshaped to `[-1, 1]` and cast to `int32`
        """
        return (
            tf.expand_dims(data['pixels'], axis = -1),
            tf.cast(tf.reshape(data['label'], [-1, 1]), tf.int32)
        )
    
    def get_dataset_config(self, * args, ** kwargs):
        """ Parent dataset configuration, with `batch_before_map` forced to `True` """
        # Presumably makes `preprocess_data` receive batched elements, consistent
        # with its `[-1, 1]` label reshape (TODO confirm against the parent class).
        kwargs['batch_before_map'] = True
        return super().get_dataset_config(* args, ** kwargs)
    
    def get_config(self, * args, ** kwargs):
        """ Parent configuration, extended with the image / FL configurations and `n_labels` """
        config = super().get_config(* args, ** kwargs)
        config.update({
            ** self.get_config_image(),
            ** self.get_config_fl(),
            'n_labels' : self.n_labels
        })
        return config
    
    
# Instantiate the classifier (`nom` is presumably the model's name — confirm against
# the base class) and compile it with the defaults declared by `MNISTFLClassifier.compile`.
model = MNISTFLClassifier(nom = 'test_fl_6-8')
model.compile()
# `print` goes through `MNISTFLClassifier.__str__`; `summary` comes from the parent classes.
print(model)
model.summary()

In [None]:
# Fetch the two federated EMNIST splits; the second one is used as validation data
# by the training cell below.
emnist_train, emnist_valid = tff.simulation.datasets.emnist.load_data()

# The reported "lengths" are the number of client ids available in each split.
print('Dataset length :\n  Train length : {}\n  Valid length : {}'.format(
    * (len(split.client_ids) for split in (emnist_train, emnist_valid))
))
print('Data signature : {}'.format(emnist_train.element_type_structure))

## Training

In [None]:
from loggers import set_level

# Presumably sets the 'models' logger to the INFO level so that training progress
# is reported (confirm against `loggers.set_level`).
set_level('info', 'models')
# Federated training on the EMNIST client data, validated on the second split.
# NOTE(review): the meaning of `train_times` and `n_clients` is defined by
# `BaseFLModel.train_fl`, which is not visible here — presumably the number of local
# passes per round and the number of clients sampled per round; confirm before tuning.
h = model.train_fl(
    emnist_train, validation_data = emnist_valid,
    epochs      = 50,
    batch_size  = 32,
    train_times = 5,
    n_clients   = 100,
    server_optimizer_fn = {'name' : 'SGD', 'lr' : 0.5},
    client_optimizer_fn = {'name' : 'SGD', 'lr' : 0.1},
    use_experimental_simulation_loop = False,
    verbose = 0
)
# `h` is the object returned by `train_fl` (presumably the training history):
# show its textual form, then plot it.
print(h)
h.plot()

In [None]:
# Plot the history again, this time with a `filename` — presumably the figure is
# also written to 'example_history_2.png' (confirm against the `plot` implementation).
h.plot(filename = 'example_history_2.png')