In [None]:
!pip install --upgrade tensorflow-federated
!pip install nest_asyncio
# Applied once, before anything else in the notebook runs.
# NOTE(review): presumably needed so TFF can drive its own asyncio work inside the
# event loop the notebook kernel already runs — confirm against the TFF setup notes.
import nest_asyncio
nest_asyncio.apply()

In [None]:
import sys
sys.path.append("../")

import tensorflow as tf
from tensorflow.keras import Model, callbacks
from tensorflow.keras.layers import Dense, Softmax
import tensorflow_federated as tff
import tensorflow.compat.v1 as tf
from tensorflow_model_optimization.python.core.internal import tensor_encoding as te


from sklearn.model_selection import train_test_split
import collections

import numpy as np 
import pandas as pd
from sklearn.preprocessing import MinMaxScaler, StandardScaler
import random
import time
import statistics
from datetime import datetime
import mlflow 

from Reader import Reader
from model.FLModel import FLModel
from model.BNModel import BNModel
from Utils import Utils
from TFF_Utils import TFF_Utils
from DefaultDenseQuantizeConfig import DefaultDenseQuantizeConfig

# `tf` refers to tensorflow.compat.v1 here: the second `import ... as tf` in this cell
# rebinds the name after `import tensorflow as tf`, so this call goes through the v1 alias.
tf.enable_eager_execution() 

In [None]:
class Centralized_vs_Federated_Trainer:
    """
    A class that compares centralized training with TensorFlow Federated's FedAVG as well as its own implementation of FedBN.
    Additionally, lossy compression is compared in all federated settings.
    
    
    Parameters
    ----------
    file_name : str
        name of data file
    number_of_classes : int
        number of distinct classes (stored as NUM_CLASSES; size of the softmax output layer)
    E : int
        number of local rounds
    compression : bool 
        flag for compression
    
    Attributes
    ----------
    file : string
        name of the input file containing the data
    BASE_DIR : str
        name of the directory containing all needed subdirectories for in- and output
    DATA_DIR : str
        name of the directory containing the input data
    NUM_CLASSES : int
        number of unique labels in the input file
    client_ids : list
        random int ids of the input datafile or passed input ids
    CLIENTS : int
        number of clients in total. when client_ids is passed this is set to len(client_ids)
    EPOCHS : int
    E : int 
        number of training epochs on the clients before a global update
    k : int
        k used by the sparse top-k categorical accuracy metric
    drop_index : bool
        indicates whether to delete the index column of the input file
    start_id : int
        number of column where features starts
    label_id :
        column number of label
    use_bn : bool
        indicates if the FedBN should be used for training
    use_tff : bool
        indicates whether centralized or federated learning should be used
    compression : bool
        indicates if quantization should be used. only for federated learning
    
    BATCH_SIZE : int
    PREFETCH_BUFFER : int
    SHUFFLE_BUFFER : int
    learning_rate : float
    momentum : float
    nesterov : bool
    entropy_loss : tf.keras.losses.CategoricalCrossentropy
    sparseCategoricalAcc : tf.keras.metrics.SparseCategoricalAccuracy 
    sparseTopKCategoricalAccuracy : sparseTopKCategoricalAccuracy
    client_optimizer : tf.keras.optimizers.SGD
    server_optimizer : tf.keras.optimizers.SGD
    scaler : sklearn.preprocessing.StandardScaler
    train_data : tf.data.Dataset 
    test_data : tf.data.Dataset 
    
    Methods
    ----------
    get_summary_in_bytes( model)
    init_mlflow():
    split_input_target(input, target)
    get_split(x, y)
    create_dataset(x, y, use_tff = True)
    create_unfederated_dataset
    get_not_bn_idx
    get_bn_idx
    get_weights_by_idx
    create_keras_model
    create_keras_bn_model
    quantize(layer)
    dequantize(layer, min_range, max_range)
    map_quantization(layer)
    map_dequantization(weights, ranges, quant_idx)
    get_min(layer)
    get_max(layer)
    model_fn
    map_weights_ids
    broadcast_encoder_fn
    mean_encoder_fn
    
    run_federated()
    
        Submethods
        ----------
        server_init()
        server_update(model, mean_client_weights)
        server_update_fn(mean_client_weights)
        initialize_fn()
        next_fn(server_weights, federated_dataset)
        evaluate(server_state, train_data, test_data, trainable_ids)
        
    run_unfederated(ds_train, ds_test, ds_val, input_dim)  
    main(compare_centralized = False)
    
    """
    
    def __init__(self, file_name:str, number_of_classes:int, E:int, compression:bool = False):
        self.file = file_name
        self.NUM_CLASSES = number_of_classes
        self.BASE_DIR = "../"
        self.DATA_DIR = "data/"
        self.BATCH_SIZE = 64
        self.PREFETCH_BUFFER = 10
        self.SHUFFLE_BUFFER = 64
        self.CLIENTS = 10
        self.EPOCHS = 10
        self.E = E # amount of local rounds       
        self.drop_index = True
        self.client_ids = None
        
        
        if "infected" in self.file:
            self.drop_index = False
            self.start_id = 1
            self.label_id = 9
        elif ("Cosphere" in self.file):
            self.start_id = 1
            self.label_id = 4
        else:
            self.start_id = 1
            self.label_id = 3

        self.k = 10
        self.use_bn = True
        self.use_tff = True
        self.learning_rate = 1e-1
        self.momentum = 0.9
        self.nesterov = False
        
        self.compression = compression
        if self.compression: 
            self.quantization_bits = 8 #default 8
            self.quantization_thresh = 10000

        self.mode = 'MIN_FIRST'    
        
    def get_summary_in_bytes(self, model):
        """Print, for every layer of ``model``, the summed size of its weights in bits.

        Despite the method name the printed number is a bit count: each weight
        contributes ``number_of_elements * 32`` for float32 and
        ``number_of_elements * 8`` for int8. Weights of any other dtype are
        reported on stdout and contribute nothing.

        The element count is the product of all dimensions of the weight, so
        weights of any rank (including scalars, which count as one element) are
        measured correctly.

        Parameters
        ----------
        model : keras.Model
        """
        for layer in model.layers:
            layer_size_in_bit = 0
            for weight in layer.weights:
                # Product over every dimension; an empty shape yields 1.
                size = 1
                for dim in weight.shape:
                    size *= dim

                # Compared with ``==`` (not a dict lookup) so that dtype objects
                # which compare equal to these strings keep working.
                if weight.dtype == 'float32':
                    layer_size_in_bit += (size * 32)
                elif weight.dtype == 'int8':
                    layer_size_in_bit += (size * 8)
                else:
                    print(f"unsupported datatype: {weight.dtype}")

            print(layer, layer_size_in_bit)
            
    #mlflow setup
    def init_mlflow(self): 
        """sets up an mlflow run"""
        mlflow.set_experiment(self.file)
        mlflow.end_run()
        mlflow.start_run()
        mlflow_experiment_id = mlflow.get_experiment_by_name(self.file).experiment_id
        mlflow_run_id = mlflow.active_run().info.run_id
        log_path = "mlruns/" + str(mlflow_experiment_id) + "/" + str(mlflow_run_id) + "/" + "artifacts" + "/"
        
        mlflow.log_param("run_id", mlflow_run_id)
        mlflow.log_param("local_rounds", self.k)
        mlflow.log_param("use_bn", self.use_bn)
        mlflow.log_param("use_tff", self.use_tff)
        if self.compression: 
            mlflow.log_param("use_compression", self.compression)
            mlflow.log_param("quant_bits", self.quantization_bits)
            mlflow.log_param("threshold", self.quantization_thresh)
            mlflow.set_tags({"day":"15.09.2020"})
            mlflow.log_param("mode", self.mode)

        mlflow.log_param("lr", self.learning_rate)
        mlflow.log_param("momentum", self.momentum)
        mlflow.log_param("nesterov", self.nesterov)
        mlflow.log_param("batch_size", self.BATCH_SIZE)
        mlflow.log_param("clients", self.CLIENTS)
        mlflow.log_param("epochs", self.EPOCHS)
        mlflow.log_param("classes", self.NUM_CLASSES)
            
    def split_input_target(self, input, target):
        """Pass a (features, label) pair through unchanged.

        Used as the mapping function when the TFF dataset is assembled.

        Parameters
        ----------
        input : Tensor
            features
        target : Tensor
            label

        Returns
        -------
        tuple
            ``(input, target)`` exactly as received
        """
        features_and_label = (input, target)
        return features_and_label

    def create_dataset(self, x, y):
        """creates datasets in a specific format depending if zentralized or federated learning is used
        Parameters
        ----------
        x : numpy.array()
        y : numpy.array()
        
        Returns
        ---- 
        tf.data.Dataset
        """
        ds =  tf.data.Dataset.from_tensor_slices((x, y))

        if self.use_tff:
            return (
            ds.repeat(self.EPOCHS).shuffle(self.SHUFFLE_BUFFER)
            .map(self.split_input_target)).batch(self.BATCH_SIZE) 
        else:
            return ds.repeat(self.BATCH_SIZE * self.EPOCHS).shuffle(self.SHUFFLE_BUFFER).batch(self.BATCH_SIZE,drop_remainder = True) 

    def get_split(self, x, y):
        """Split arrays or matrices into random train and test subsets
        Parameters
        ----------
        x : numpy.array()
        y : numpy.array()
                    
        Returns
        -------
        List 
            containing train-test split of inputs.
        """
        return train_test_split(x, y, test_size=0.2, random_state=42)

    def create_unfederated_dataset(self, x, features):
        """ creates an scaled, zentralized train- and test-dataset
        
        Parameters
        ----------
        x : numpy.array()
            containing the whole dataset
        features : int
            number of input features
            
        Returns
        -------
        tf.data.Dataset
            train dataset
        tf.data.Dataset
            test dataset
        tf.data.Dataset
            validation dataset
        """
        former_shape = x[:, self.start_id:features].shape
        client_x = np.delete( x[:, self.start_id:features], self.label_id-1, 1 ).reshape(former_shape[0], former_shape[1]-1)
        client_x = self.scaler.transform(client_x)
        client_y = x[:, self.label_id].reshape(-1, 1)
        X_train, X_test, y_train, y_test = self.get_split(client_x, client_y)
        X_train, X_val, y_train, y_val = self.get_split(X_train, y_train)
        train_data = self.create_dataset(X_train, y_train)
        test_data = self.create_dataset(X_test, y_test)
        val_data = self.create_dataset(X_val, y_val)
        return train_data, test_data, val_data
    
    def get_not_bn_idx(self, trainable_variables):
        """Return the positions of all weights that do not belong to batch normalization.

        A weight counts as batch normalization when its ``name`` contains the
        substring ``'batch_normalization'``.

        Parameters
        ----------
        trainable_variables : list
            list of weight tensors

        Returns
        -------
        list
            list of indices, in ascending order
        """
        return [
            position
            for position, variable in enumerate(trainable_variables)
            if 'batch_normalization' not in variable.name
        ]

    def get_bn_idx(self, trainable_variables):
        """Return the positions of all weights that belong to batch normalization.

        A weight counts as batch normalization when its ``name`` contains the
        substring ``'batch_normalization'``.

        Parameters
        ----------
        trainable_variables: list
            list of weight tensors

        Returns
        -------
        list
            list of indices, in ascending order
        """
        return [
            position
            for position, variable in enumerate(trainable_variables)
            if 'batch_normalization' in variable.name
        ]

    def get_weights_by_idx(self, trainable_variables, var_ids):
        """Pick the weights at the given positions, in the order the positions are listed.

        Parameters
        ----------
        trainable_variables : list
            list of weight tensors
        var_ids : list
            positions of the weights to select

        Returns
        -------
        list
            new list holding the selected weight tensors
        """
        return [trainable_variables[position] for position in var_ids]

    def create_keras_model(self, input_dim):
        """ returns an instance of a keras model for FedAVG
        
        Parameters
        ----------
        input_dim: int
            number of input connections
            
        Returns
        -------
        keras.model
        """
        return tf.keras.models.Sequential([
          tf.keras.layers.InputLayer(input_shape=(input_dim,)),
          tf.keras.layers.Dense(500, activation=tf.nn.relu),
          tf.keras.layers.Dense(self.NUM_CLASSES, activation='softmax'),
        ])

    def create_keras_bn_model(self, input_dim):
        """ returns an instance of a keras model for FedBN
        Parameters
        ----------
        input_dim: int
            number of input connections
                    
        Returns
        -------
        keras.model
        """
        return tf.keras.models.Sequential([
          tf.keras.layers.BatchNormalization(input_shape=(input_dim,)),
          tf.keras.layers.Dense(500, activation=tf.nn.relu),
          tf.keras.layers.BatchNormalization(),
          tf.keras.layers.Dense(self.NUM_CLASSES, activation='softmax'),
        ])
        
    def quantize(self, layer):
        """ performs quantization on a 32-bit float weight tensor. Returns an 8-bit-integer-tensor
        
        Parameters
        ----------
        layer: tensor
            weight tensor to be quantized
            
        Returns
        -------
        tensor
        """
        min_range = self.get_min(layer)
        max_range = self.get_max(layer)
        return tf.quantization.quantize(layer, min_range, max_range, tf.qint32, mode = self.mode)
    
    def dequantize(self, layer, min_range, max_range):
        """ performs dequantization on a 8-bit int weight tensor. Returns an 32-bit-float-tensor
        
        Parameters
        ----------
        layer: tensor
            weight tensor to be dequantized      
            
        Returns
        -------
        tensor
        """
        return tf.quantization.dequantize(layer, min_range, max_range, mode=self.mode, dtype=tf.dtypes.float32)
    
    def map_quantization(self, weights):
        """ Performs quantization, if the size of an weigth tensor is bigger than a threshold 
        
        Parameters
        ----------
        weights : list
            list of weight tensors     
            
        Returns
        -------
        list 
            of all weights including the quantized ones
        list 
            list of tupel containing min and max range for dequantization
        list
            list of indicees which weights to dequantize
        
        """
        quantized_weights = []
        ranges = []
        quant_idx = []
        for idx, weight in enumerate(weights):
            if len(weight.shape) > 1:
                size = weight.shape[0] * weight.shape[1]
            elif len(weight.shape) == 1:
                size = weight.shape[0]
            if size > self.quantization_thresh:
                quantized_weight, mi, ma = self.quantize(weight)
                quantized_weights.append(quantized_weight)
                ranges.append((mi, ma))
                quant_idx.append(idx)
            else: 
                quantized_weights.append(weight)
        print(f"\nquanization: \n\n{quantized_weights, quant_idx}\n\n----")
        return quantized_weights, ranges, quant_idx
    
    def map_dequantization(self, weights, ranges, quant_idx):
        """ Perfoms dequantization 
        
        Parameters
        ----------
        weights : list
            list of all weight tensors
        ranges : list 
            list of tupel with min-range and max-range
        quant_idx : list
            list of indices of weights to be dequantified
            
        Returns
        -------
            list of float32-tensors
        """
        dequantized_weights = []
        for idx, weight in enumerate(weights):
            if idx in quant_idx:
                range = ranges[0]
                ranges.pop(0)
                dequantized_weight = self.dequantize(weight, range[0], range[1])
                dequantized_weights.append(dequantized_weight)
            else: 
                dequantized_weights.append(weight)
        
        return dequantized_weights
    
    def get_min(self, layer):
        """ Returns the smallest value of a tensor
        
        Parameters
        ----------
        layer: tensor
        """
        return tf.math.reduce_min(layer, axis=None)
    
    def get_max(self, layer):
        """ Returns the biggest value of a tensor
        
        Parameters
        ---------
        layer: tensor
        """
        return tf.math.reduce_max(layer, axis=None)
        
 
    def model_fn(self):
        """ Create instance of a keras model depending on which federated strategie (FedAVG, FedBN) to use 
        
        Returns
        ----------
        keras.model
        """
        
      # Note: We must create a new model here, and not capture it from an external scope. 
      # TFF will call this within different graph contexts.

        if self.use_bn:
          keras_model = self.create_keras_bn_model(self.test_data[0].element_spec[0].shape[1])

        else:
          keras_model = self.create_keras_model(self.test_data[0].element_spec[0].shape[1])
 
        return tff.learning.from_keras_model(
          keras_model,
          input_spec = self.train_data[0].element_spec,
          loss = tf.keras.losses.SparseCategoricalCrossentropy(),
          metrics = [tf.keras.metrics.SparseCategoricalAccuracy(), 
                   tf.keras.metrics.SparseTopKCategoricalAccuracy(k=self.k)]) 


    def map_weights_ids(self):
        """ Filter the weights for FedBN which are not from batch normalization layers

        Returns
        -------
        list
        list of weight tensors
        """
        model = self.create_keras_bn_model(self.train_data[0].element_spec[0].shape[1])
        trainable_variables = model.trainable_variables
        all_weights_name = []
        train_weights_ids = []

        for layer in model.layers:
            for weight in layer.weights:
                all_weights_name.append(weight.name)

        for var in trainable_variables:
            for idx, name in enumerate(all_weights_name):
                if var.name == name:
                    if 'batch_normalization' not in name:
                        train_weights_ids.append(idx)

        return train_weights_ids

    def broadcast_encoder_fn(self, value):
        """Broadcasting step: Function for building encoded broadcast.

        Parameters
        ----------
        value : tensor
          tensor to be quantized if bigger than a threshold

        Returns
        -------
          if tensor bigger than threshold: returns instance of quantization encoder
          else: returns instance of default encoder without any specific compression

        """
        spec = tf.TensorSpec(value.shape, value.dtype)
        if value.shape.num_elements() > self.quantization_thresh:
            return te.encoders.as_simple_encoder(
                te.encoders.uniform_quantization(bits=self.quantization_bits), spec)
        else:
            return te.encoders.as_simple_encoder(te.encoders.identity(), spec)


    def mean_encoder_fn(self, tensor_spec):
        """Aggregation step: Function for building a GatherEncoder
        
        Parameters
        ----------
        value : tensor
          tensor to be quantized if bigger than a threshold

        Returns
        -------
          if tensor bigger than threshold: returns instance of quantization encoder
          else: returns instance of default encoder without any specific compression

        """
        spec = tf.TensorSpec(tensor_spec.shape, tensor_spec.dtype)
        if tensor_spec.shape.num_elements() > self.quantization_thresh:
            return te.encoders.as_gather_encoder(
                te.encoders.uniform_quantization(bits= self.quantization_bits), spec)
        else:
            return te.encoders.as_gather_encoder(te.encoders.identity(), spec)

    def run_federated(self):
        """ Handles the federated learning. 
        
        First, this contains some inner classes for dealing with FedBN. 
        Placement outside throws errors regarding scope. 
        
        Second, this runs the training and evaluation process of FedAVG or FedBN, depending of the value of'self.use_bn'
        
        """        
        broadcast_process = None
        mean_factory = None
        if self.compression and (not self.use_bn):  
            mean_factory = tff.aggregators.MeanFactory(
                value_sum_factory = tff.aggregators.EncodedSumFactory(self.mean_encoder_fn),
                weight_sum_factory = tff.aggregators.EncodedSumFactory(self.mean_encoder_fn))   
            
            broadcast_process = tff.learning.framework.build_encoded_broadcast_process_from_model(
                                         self.model_fn, self.broadcast_encoder_fn)
            

        # ---custom fed avg implementation for batch normalization-----
        # --------------------------START------------------------------
        @tff.tf_computation
        def server_init():
            """FedBN: build a model and return its initial shareable weights.

            Returns
            -------
            list of all trainable weights that do not belong to a batch normalization layer
            """
            model = self.model_fn()
            all_trainable = model.trainable_variables
            shareable_positions = self.get_not_bn_idx(all_trainable)
            return self.get_weights_by_idx(all_trainable, shareable_positions)
        
        whimsy_model = self.model_fn()
        tf_dataset_type = tff.SequenceType(whimsy_model.input_spec)
        model_weights_type = server_init.type_signature.result
        federated_server_type = tff.FederatedType(model_weights_type, tff.SERVER)
        federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)

        @tf.function
        def server_update(model, mean_client_weights):
            """FedBN: overwrite the server's shareable weights with the client average.

            Only trainable weights that do not belong to a batch normalization
            layer are assigned; batch normalization weights are left as they are.

            If ``self.compression`` is set, the updated weights are quantized and
            immediately dequantized before being returned, so the information loss
            caused by compression can be analyzed later. The quantization settings
            are only defined when compression is enabled, which is why this step is
            gated on ``self.compression``.

            Parameters
            ----------
            model : model object exposing ``trainable_variables``
            mean_client_weights : list of tensors
                averaged client weights, one per non-batch-normalization weight

            Returns
            ------
            list of all weights which are trainable and not from a batch normalization layer
            """
            trainable_variables = model.trainable_variables
            non_bn_ids = self.get_not_bn_idx(trainable_variables)
            non_bn_weights = self.get_weights_by_idx(trainable_variables, non_bn_ids)

            # Assign the mean client weights to the server model.
            tf.nest.map_structure(lambda x, y: x.assign(y),
                                    non_bn_weights, mean_client_weights)

            #--------- quantize and dequantize---------
            if self.compression:
                quantized_weights, min_max, range_idx = self.map_quantization(non_bn_weights)
                return self.map_dequantization(quantized_weights, min_max, range_idx)
            #---------------------------

            return non_bn_weights


        @tf.function
        def client_update(model, dataset, server_weights, client_optimizer):
            """FedBN: load the global weights into a client model and train it locally.

            The server weights are assigned to the client's non-batch-normalization
            weights; the client's own batch normalization weights are kept. The
            model is then trained for ``self.E`` passes over ``dataset``, updating
            all trainable weights (batch normalization included).

            If ``self.compression`` is set, the trained shareable weights are
            quantized and immediately dequantized before being returned, so the
            information loss caused by compression can be analyzed later. The
            quantization settings are only defined when compression is enabled,
            which is why this step is gated on ``self.compression``.

            Parameters
            ----------
            model : model object exposing ``trainable_variables`` and ``forward_pass``
            dataset: tf.data.dataset
                local client dataset
            server_weights : list of tensors
                current global weights, one per non-batch-normalization weight
            client_optimizer: tf.keras.optimizers.SGD
                instance of the corresponding client optimizer

            Returns
            ------
            list of all weights which are trainable and not from a batch normalization layer
            """
            trainable_variables = model.trainable_variables

            # Positions and tensors of the weights that are shared with the server.
            non_bn_ids = self.get_not_bn_idx(trainable_variables)
            mlflow.log_param("number of updatable non-BN weights", len(non_bn_ids))
            non_bn_weights = self.get_weights_by_idx(trainable_variables, non_bn_ids)

            # Load the global (server) weights into the client model.
            tf.nest.map_structure(lambda x, y: x.assign(y),
                                    non_bn_weights, server_weights)

            # Local training updates every trainable weight, in model order.
            client_weights = list(trainable_variables)

            for epoch in range(self.E):
                for batch in dataset:
                    with tf.GradientTape() as tape:
                        outputs = model.forward_pass(batch)
                    grads = tape.gradient(outputs.loss, client_weights)
                    grads_and_vars = zip(grads, client_weights)
                    client_optimizer.apply_gradients(grads_and_vars)

            #--------- quantize and dequantize---------
            if self.compression:
                quantized_weights, min_max, range_idx = self.map_quantization(non_bn_weights)
                return self.map_dequantization(quantized_weights, min_max, range_idx)
            #---------------------------

            return non_bn_weights

        @tff.tf_computation(tf_dataset_type, model_weights_type)
        def client_update_fn(tf_dataset, server_weights):
            """FedBN: create the model and optimizer for one client step and run it.

            Parameters
            ----------
            tf_dataset : tf.data.dataset
                the client's local data
            server_weights : list of tensors
                current global weights

            Returns
            -------
            list of all weights which are trainable and not from a batch normalization layer
            """
            model = self.model_fn()
            sgd_options = {
                "learning_rate": self.learning_rate,
                "momentum": self.momentum,
                "nesterov": self.nesterov,
            }
            client_optimizer = tf.keras.optimizers.SGD(**sgd_options)
            return client_update(model, tf_dataset, server_weights, client_optimizer)

        @tff.tf_computation(model_weights_type)
        def server_update_fn(mean_client_weights):
            """FedBN: create a model and assign the averaged client weights to it.

            Parameters
            ----------
            mean_client_weights : list of tensors
                averaged client weights

            Returns
            -------
            list of all weights which are trainable and not from a batch normalization layer
            """
            fresh_model = self.model_fn()
            return server_update(fresh_model, mean_client_weights)

        @tff.federated_computation
        def initialize_fn():
            """FedBN: produce the initial server state.

            Returns
            -------
            Federated value, placed at the server, holding the weights from ``server_init``
            """
            initial_weights = server_init()
            return tff.federated_value(initial_weights, tff.SERVER)


        @tff.federated_computation(federated_server_type, federated_dataset_type)
        def next_fn(server_weights, federated_dataset):
            """FedBN: one round of federated learning.

            The server weights are broadcast to the clients, each client runs
            ``client_update_fn`` on its own dataset, the resulting client weights
            are averaged, and ``server_update_fn`` is applied to that average.

            Note: All variables of federated operations are TFF values and not directly accessible.

            Parameters
            ----------
            server_weights: list
                 current server weights
            federated_dataset
                list: of datasets, one per client

            Returns
            -------
            the new server weights (result of mapping ``server_update_fn``)
            """
            # Server -> clients.
            broadcast_weights = tff.federated_broadcast(server_weights)

            # Local training on every client.
            trained_client_weights = tff.federated_map(
                client_update_fn, (federated_dataset, broadcast_weights))

            # Clients -> server: average the trained weights.
            averaged_weights = tff.federated_mean(trained_client_weights)

            # Server applies the average to its model.
            return tff.federated_map(server_update_fn, averaged_weights)

        def evaluate(server_state, train_data, test_data, trainable_ids):
            """FedBN: evaluate the global weights on the training and the test data.

            A fresh FedBN keras model is built, the entries of ``server_state`` are
            written into its weight list at the positions given by
            ``trainable_ids``, and the model is evaluated on every entry of
            ``self.train_data`` and ``self.test_data``. Loss, accuracy and top-k
            accuracy are averaged separately for training and for testing.

            NOTE(review): the evaluation loops read ``self.train_data`` and
            ``self.test_data``; of the two data parameters only ``test_data[0]`` is
            used (for the input width). Confirm callers pass the same objects.

            Parameters
            ----------
            server_state : list of weight values
                 the current global (non-batch-normalization) weights
            train_data: list of tf.data.dataset
                not read by this function (see note above)
            test_data: list of tf.data.dataset
                only the element spec of the first entry is read
            trainable_ids: list
                positions in the model's weight list that receive ``server_state``

            Returns
            -------
            Dictionary of the specified metrics for training and testing
            """
            model = self.create_keras_bn_model(test_data[0].element_spec[0].shape[1])
            model.compile(
                    loss = tf.keras.losses.SparseCategoricalCrossentropy(),
                    metrics = [tf.keras.metrics.SparseCategoricalAccuracy(), 
                            tf.keras.metrics.SparseTopKCategoricalAccuracy(k=self.k)]  
            )

            # Overwrite only the shared weights; everything else keeps its initial value.
            weights = model.get_weights()
            for idx, weight_id in enumerate(trainable_ids):
                weights[weight_id] = np.array(server_state[idx])
            model.set_weights(weights)

            def mean_metrics(client_datasets):
                """Evaluate the model on each entry and return mean (loss, acc, top-k acc)."""
                losses, accs, k_accs = [], [], []
                for client_dataset in client_datasets:
                    loss, acc, k_acc = model.evaluate(
                        client_dataset, batch_size=self.BATCH_SIZE, verbose=0)
                    losses.append(loss)
                    accs.append(acc)
                    # Each loop iteration appends the top-k value it has just measured.
                    k_accs.append(k_acc)
                return (statistics.mean(losses),
                        statistics.mean(accs),
                        statistics.mean(k_accs))

            print("\t--Training--\t")
            train_loss, train_acc, train_k_acc = mean_metrics(self.train_data)

            print("\t--Testing--\t")
            test_loss, test_acc, test_k_acc = mean_metrics(self.test_data)

            return {"Train_acc":train_acc, 
                    f"Train_{self.k}_acc":train_k_acc, 
                    "Train_loss":train_loss, 
                    "Test_acc":test_acc, 
                    f"Test_{self.k}_acc":test_k_acc,
                    "Test_loss":test_loss,}
        # -------------------END---------------------
        
        with tf.device('/gpu:0'):
            #FedBN
            if self.use_bn:    
                trainable_ids = self.map_weights_ids()
                fed_avg = tff.templates.IterativeProcess(
                                        initialize_fn=initialize_fn,
                                        next_fn=next_fn
                                        )

                state = fed_avg.initialize()
                for round_num in range(self.EPOCHS):  
                    state = fed_avg.next(state, self.train_data)
                    
                    metrics = evaluate(state, self.train_data, self.test_data, trainable_ids)
                    for name, value in metrics.items():
                        mlflow.log_metric(name, value, round_num)
                        print(round_num, name, value)
                        
            #FedAVG
            else:
                fed_avg = tff.learning.build_federated_averaging_process(
                    self.model_fn,
                    client_optimizer_fn = lambda: tf.keras.optimizers.SGD(learning_rate= self.learning_rate, momentum=self.momentum, nesterov=self.nesterov), 
                    server_optimizer_fn = lambda: tf.keras.optimizers.SGD(learning_rate= 1.0),
                    broadcast_process = broadcast_process,
                    model_update_aggregation_factory = mean_factory
                    )

                state = fed_avg.initialize()
                
                #-----
                #just usable for FedAVG
                environment = self.tff_utils.set_sizing_environment()
                #-----
                
                for round_num in range(self.EPOCHS):
                    state, metrics = fed_avg.next(state, self.train_data)
                    # Note: training metrics reported by the iterative training process 
                    #generally reflect the performance of the model at the beginning of the training round
                    
                    #---
                    size_info = environment.get_size_info()
                    broadcasted_bits = size_info.broadcast_bits[-1]
                    aggregated_bits = size_info.aggregate_bits[-1]
                    mlflow.log_metric('cumulative_broadcasted_bits', broadcasted_bits, round_num)
                    mlflow.log_metric('cumulative_aggregated_bits', aggregated_bits, round_num)
                    print(broadcasted_bits, aggregated_bits)
                    
                    print('round {:2d}, metrics={}, broadcasted_bits={}, aggregated_bits={}'
                          .format(round_num, metrics, 
                                  self.tff_utils.format_size(broadcasted_bits), 
                                  self.tff_utils.format_size(aggregated_bits)
                            )
                         )
                    #---
                    
                    for name, value in metrics['train'].items():
                        print(round_num, name, value)

                    evaluation = tff.learning.build_federated_evaluation(self.model_fn)  
                    test_metrics = evaluation(state.model, self.test_data)
                    for name, value in test_metrics.items():
                        mlflow.log_metric(name, value, round_num)
                        print(round_num, name, value)

    def run_unfederated(self, ds_train, ds_test, ds_val, input_dim):
        """Handles centralized training with TensorFlow

        Builds a 'BNModel' or a 'FLModel' (depending on 'self.use_bn') and trains it for
        'self.EPOCHS' rounds. Every round is one 'fit' call followed by an evaluation on
        the test set. Training, validation and test metrics are logged to mlflow and
        printed for every round.

        Parameters
        ----------
        ds_train : tf.data.dataset
        ds_test : tf.data.dataset
        ds_val : tf.data.dataset
        input_dim: int
            number of input connections (currently not read by this method)

        """

        # NOTE(review): 'fit' is called once per round with the default number of epochs,
        # so the callback state appears to be reset on every call and 'patience=2' can
        # presumably never be reached. Confirm whether early stopping is intended to
        # span the rounds of the loop below.
        early_stopping_callback = tf.keras.callbacks.EarlyStopping(
            monitor='val_loss',
            min_delta=0.01,
            patience=2,
            verbose=0,
            mode='auto',
            baseline=None,
            restore_best_weights=True,
        )

        model = BNModel(self.NUM_CLASSES) if self.use_bn else FLModel(self.NUM_CLASSES)
        model.compile(
            optimizer=self.client_optimizer,
            loss="sparse_categorical_crossentropy",
            metrics=[self.sparseCategoricalAcc, self.sparseTopKCategoricalAccuracy],
        )

        for epoch in range(self.EPOCHS):
            with tf.device('/gpu:0'):
                history = model.fit(
                    ds_train,
                    steps_per_epoch=64,
                    validation_data=ds_val,
                    verbose=0,
                    callbacks=[early_stopping_callback],
                )
                test_loss, test_acc, test_k_acc = model.evaluate(ds_test, batch_size=self.BATCH_SIZE, verbose=0)

            # Each 'fit' call covers a single epoch, hence index 0 of every history entry.
            recorded = history.history
            loss = round(recorded["loss"][0], 8)
            acc = round(recorded["sparse_categorical_accuracy"][0], 8)
            k_acc = round(recorded["sparse_top_k_categorical_accuracy"][0], 8)
            val_loss = round(recorded['val_loss'][0], 8)
            val_acc = round(recorded['val_sparse_categorical_accuracy'][0], 8)
            val_k_acc = round(recorded['val_sparse_top_k_categorical_accuracy'][0], 8)

            round_metrics = (
                ("Loss/train", loss),
                ("Acc/train", acc),
                ("K_acc/train", k_acc),
                ("Loss/validation", val_loss),
                ("Acc/validation", val_acc),
                ("K_acc/validation", val_k_acc),
                ("Loss/test", test_loss),
                ("Acc/test", test_acc),
                ("K_acc/test", test_k_acc),
            )
            for metric_name, metric_value in round_metrics:
                mlflow.log_metric(metric_name, metric_value, epoch)

            # The top-k label follows 'self.k', the value the top-k metric is built with
            # in 'main', instead of a fixed "5".
            print(
              f'Epoch: {epoch},\n'
              f'Train Loss:\t{loss}, '
              f'Train Accuracy:\t{acc}, '
              f'Train Top {self.k} Accuracy:\t{k_acc}\n'
              f'Validation Loss:\t{val_loss}, '
              f'Validation Accuracy:\t{val_acc}, '
              f'Validation Top {self.k} Accuracy:\t{val_k_acc}\n'
              f'Test Loss:\t{test_loss}, '
              f'Test Accuracy:\t{test_acc} '
              f'Test Top {self.k} Accuracy:\t{test_k_acc}'
              f'\n--------------------------------------------------------------------------------------------------------------------------\n'
            )

    def main(self, compare_centralized = False):
        """Handles all possible setups on the same client ids so a comparison of the results is possible

        Reads the desired input file and selects 'self.CLIENTS' clients
        Runs loop on a range of clients
            Creates datasets from the subset of client_ids
            Setup:
            Run FedAVG without compression
            Run FedBN without compression
            Loop on different compression thresholds:
                Run FedAVG with compression and a specified threshold
                Run FedBN with compression and a specified threshold
            Run centralized learning (only if 'compare_centralized' is set)

        Note: 'self.init_mlflow()' is called before each of these runs

        Parameters
        ----------
        compare_centralized : bool
            Indicates if centralized training should be part of the setup
        """

        self.entropy_loss = tf.keras.losses.SparseCategoricalCrossentropy()
        self.sparseCategoricalAcc = tf.keras.metrics.SparseCategoricalAccuracy()
        self.sparseTopKCategoricalAccuracy = tf.keras.metrics.SparseTopKCategoricalAccuracy(k=self.k)
        self.client_optimizer =  tf.keras.optimizers.SGD(learning_rate= self.learning_rate, momentum=self.momentum, nesterov=self.nesterov)
        self.server_optimzer = tf.keras.optimizers.SGD(learning_rate= 1.0)

        #read dataset
        self.scaler = StandardScaler()
        utils = Utils()
        reader = Reader(self.BASE_DIR + self.DATA_DIR, self.file, self.drop_index)
        data = reader.get_data()

        #clean from client_ids and label_ids and pass to scaler
        if ("IID.csv" in self.file):
            data = utils.create_clients(data, self.CLIENTS, strict = False)
            reader.set_features(reader.get_features() + 1)
            self.client_ids =  [i for i in range(0, self.CLIENTS)]

        # Column 0 is removed first, so the following index addresses the already
        # shortened list (e.g. position 2 is original column 3).
        cols = [i for i in range(0, reader.get_features())]
        del cols[0]
        if "app" in self.file.lower():
            del cols[2]
        elif "infected" in self.file.lower():
            del cols[8]
        elif "cosphere" in self.file.lower():
            del cols[3]

        features = reader.get_features()
        self.scaler.fit(data[:, cols])

        for clients in [10]:
            self.CLIENTS = clients
            if ((self.file == "App_usage_trace.txt") or (self.file == "top_90_apps.csv") or ("infected" in self.file) or ("Cosphere" in self.file)):
                # for subsets: renumber the client ids
                data  = utils.map_ids(data.copy())
                #create list of client ids
                num_of_users = int((np.amax(data[:, 0]) + 1))
                self.client_ids = list(range(0, num_of_users))
                random.shuffle(self.client_ids)
                self.client_ids = self.client_ids[:self.CLIENTS]
                self.client_ids = sorted(self.client_ids)
                print(f"All client-IDs: {self.client_ids}")

            if ("IID_2" in self.file):
                self.client_ids =  [i for i in range(0, self.CLIENTS)]

            #federated training: create dataset per client
            self.train_data = []
            self.test_data = []

            # Usable clients are collected in a new list. Removing entries from
            # 'self.client_ids' while iterating over it would skip the client that
            # follows every removed one.
            kept_client_ids = []
            for client_id in self.client_ids:

                indicees = data[:, 0] == client_id
                client_rows = data[indicees, self.start_id:features]

                # Checked before scaling so that clients without enough rows are
                # skipped before the scaler is called on them.
                if len(client_rows) <= 1:
                    print("Could not generate datasets for client {} as there is just one entry in X_train".format(client_id))
                    continue

                #delete the label column; the slice already starts at 'self.start_id'
                client_x = np.delete(client_rows, self.label_id-1, 1)
                #scale
                client_x = self.scaler.transform(client_x)
                client_y = data[indicees, self.label_id].reshape(-1, 1)

                X_train, X_test, y_train, y_test = train_test_split(client_x, client_y, test_size=0.2, random_state=42)
                ds_train = self.create_dataset(X_train, y_train)
                ds_test = self.create_dataset(X_test, y_test)
                print("Client {}: Created  dataset".format(client_id))

                self.train_data.append(ds_train)
                self.test_data.append(ds_test)
                kept_client_ids.append(client_id)

            self.client_ids = kept_client_ids

            # Check format for TFF: needs to be in shape(None, dim)
            #like eg:
            # (TensorSpec(shape=(None, 3), dtype=tf.float64, name=None),
            #  TensorSpec(shape=(None, 1), dtype=tf.float64, name=None)

            #Check format for unfederated Learning: shape(batchsize, dim)
            print(self.train_data[0].element_spec)
            print(self.test_data[0].element_spec)

            self.tff_utils = TFF_Utils()

            #run without batch normalization
            self.use_tff = True
            self.use_bn = False
            self.compression = False
            self.init_mlflow()
            self.run_federated()

            #run with batch normalization
            self.use_bn = True
            self.init_mlflow()
            self.run_federated()

            #test compression
            for t in [10000, 1499, 500]:
                print(f"--------\nClient {clients} mit Threshold: {t}--------")
                self.quantization_thresh = t

                #run without batch normalization
                self.compression = True
                self.use_bn = False
                self.init_mlflow()
                self.run_federated()

                #run with batch normalization
                self.use_bn = True
                self.init_mlflow()
                self.run_federated()

            #run centralized
            if compare_centralized:

                self.use_tff = False
                self.use_bn = False
                self.compression = False
                self.init_mlflow()

                #get same client ids, as with tff
                mask = np.isin(data[:, 0], self.client_ids)
                x = data[mask].copy()
                #create dataset for centralized training
                unfederated_train, unfederated_test, unfederated_val = self.create_unfederated_dataset(x, reader.get_features())
                self.run_unfederated(unfederated_train, unfederated_test, unfederated_val,  (reader.get_features()-2))

In [None]:
"""Start training with different datasets

"""

# Available input files, keyed by option id.
# Every entry is a pair of (file name, number of unique labels).
input_options = {
    "0": ("App_usage_trace.txt", 3996),
    "1": ("top_90_apps.csv", 3996),
    "2": ("10_infected.csv", 3),
    "3": ("num_classes_infected_shuffled.csv", 3),
    "4": ("Cosphere.csv", 9972),
    "5": ("Cosphere_cropped.csv", 9972),
    "6": ("top_90_apps_IID_2.csv", 3996),
}

# Option id of the dataset used for this run.
used_option = "1"
selected_file, selected_num_classes = input_options[used_option]

c_vs_fed_trainer = Centralized_vs_Federated_Trainer(
    file_name=selected_file,
    number_of_classes=selected_num_classes,
    E=1,
    compression=True,
)

# Run all federated setups and additionally the centralized baseline.
c_vs_fed_trainer.main(compare_centralized=True)