In [None]:
!pip install --upgrade tensorflow-federated
!pip install nest_asyncio
import nest_asyncio
nest_asyncio.apply()

In [None]:
import sys
sys.path.append("../")

import tensorflow as tf
from tensorflow.keras import Model, callbacks
from tensorflow.keras.layers import Dense, Softmax
import tensorflow_federated as tff
from tensorflow_model_optimization.python.core.internal import tensor_encoding as te

from sklearn.model_selection import train_test_split
import collections

import numpy as np 
import pandas as pd
from sklearn.preprocessing import MinMaxScaler, StandardScaler
import random
import time
import statistics
from datetime import datetime
import mlflow 

from Reader import Reader
from model.FLModel import FLModel
from model.BNModel import BNModel
from Utils import Utils
from TFF_Utils import TFF_Utils

In [15]:
class Centralized_vs_Federated_Trainer:
    """Compare federated training (TFF) with centralized Keras training on one dataset.

    `main` reads the data with the project `Reader` and builds one dataset per client.
    It then runs federated training without and with batch normalization. Optionally
    it also trains a centralized baseline on the same clients. All runs log to MLflow.
    """

    def __init__(self, file_name, number_of_classes):
        """Store the data file name and set all training hyper-parameters.

        Args:
            file_name: data file below ``BASE_DIR + DATA_DIR``. Its name selects the
                column layout (``start_id`` / ``label_id``) and index handling.
            number_of_classes: number of units of the softmax output layer.
        """
        self.file = file_name
        self.NUM_CLASSES = number_of_classes
        self.BASE_DIR = "../"
        self.DATA_DIR = "data/"
        self.BATCH_SIZE = 64
        self.PREFETCH_BUFFER = 10
        self.SHUFFLE_BUFFER = 64
        self.CLIENTS = 10

        # Federated rounds / centralized epochs; also used as the dataset repeat count.
        self.EPOCHS = 10
        self.E = 1  # amount of local epochs per federated round
        self.drop_index = True

        # Column 0 holds the user (client) id; `label_id` is the label column and
        # the features are the remaining columns from `start_id` on.
        if "infected" in self.file:
            self.drop_index = False
            self.start_id = 1
            self.label_id = 9
        elif "Cosphere" in self.file:
            self.start_id = 1
            self.label_id = 4
        else:
            self.start_id = 1
            self.label_id = 3

        self.k = 1  # k of the top-k accuracy metric
        self.use_bn = False
        self.use_tff = True
        self.learning_rate = 1e-1
        self.momentum = 0.9
        self.nesterov = False

    # mlflow setup
    def init_mlflow(self):
        """Start a fresh MLflow run in the experiment named after the data file.

        Any active run is ended first. The current hyper-parameters are then logged
        as run params.
        """
        mlflow.set_experiment(self.file)
        mlflow.end_run()
        mlflow.start_run()

        # `E` is the number of local epochs per round; `k` only configures the top-k metric.
        mlflow.log_param("local_rounds", self.E)
        mlflow.log_param("top_k", self.k)
        mlflow.log_param("use_bn", self.use_bn)
        mlflow.log_param("use_tff", self.use_tff)
        # `compression` is set by `main`; treat it as disabled if this runs earlier.
        if getattr(self, "compression", False):
            mlflow.log_param("use_compression", self.compression)
            mlflow.log_param("quant_bits", self.quantization_bits)

        mlflow.log_param("lr", self.learning_rate)
        mlflow.log_param("momentum", self.momentum)
        mlflow.log_param("nesterov", self.nesterov)
        mlflow.log_param("batch_size", self.BATCH_SIZE)
        mlflow.log_param("clients", self.CLIENTS)
        mlflow.log_param("epochs", self.EPOCHS)
        mlflow.log_param("classes", self.NUM_CLASSES)

    def split_input_target(self, input, target):
        """Identity mapping of a dataset element; returns ``(input, target)`` unchanged."""
        return input, target

    def create_dataset(self, x, y, use_tff = True):
        """Build a shuffled, batched ``tf.data`` pipeline from feature and label arrays.

        Args:
            x: feature array.
            y: label array with the same number of rows as ``x``.
            use_tff: if True, repeat the data ``EPOCHS`` times and keep the last partial
                batch (per-client TFF datasets). Otherwise repeat ``BATCH_SIZE * EPOCHS``
                times and drop partial batches (centralized training).
        """
        ds = tf.data.Dataset.from_tensor_slices((x, y))

        if use_tff:
            return (ds.repeat(self.EPOCHS)
                    .shuffle(self.SHUFFLE_BUFFER)
                    .map(self.split_input_target)
                    .batch(self.BATCH_SIZE))
        return (ds.repeat(self.BATCH_SIZE * self.EPOCHS)
                .shuffle(self.SHUFFLE_BUFFER)
                .batch(self.BATCH_SIZE, drop_remainder=True))

    def get_split(self, x, y):
        """Deterministic 80/20 train/test split (``random_state=42``)."""
        return train_test_split(x, y, test_size=0.2, random_state=42)

    def _feature_columns(self, n_features):
        """Indices of the feature columns: ``start_id .. n_features-1`` without ``label_id``.

        The scaler is fitted on these columns, and the same columns are selected
        before every transform, so fit and transform always see the same features.
        """
        return [col for col in range(self.start_id, n_features) if col != self.label_id]

    def _features_and_labels(self, rows, n_features):
        """Return ``(scaled features, labels)`` for the given data rows.

        Args:
            rows: 2-D array of raw data rows (column 0 is the client id).
            n_features: number of leading columns that belong to the record.

        Returns:
            The feature columns transformed with the fitted ``self.scaler``, and the
            ``label_id`` column reshaped to ``(-1, 1)``.
        """
        client_x = self.scaler.transform(rows[:, self._feature_columns(n_features)])
        client_y = rows[:, self.label_id].reshape(-1, 1)
        return client_x, client_y

    def create_unfederated_dataset(self, x, features):
        """Split pooled client rows into train/test/validation pipelines for centralized training.

        20% of the rows go to test, and 20% of the remainder go to validation.

        Args:
            x: 2-D array of raw data rows of all selected clients.
            features: number of leading columns that belong to the record.

        Returns:
            ``(train_data, test_data, val_data)``, all built with ``use_tff=False``.
        """
        client_x, client_y = self._features_and_labels(x, features)
        X_train, X_test, y_train, y_test = self.get_split(client_x, client_y)
        X_train, X_val, y_train, y_val = self.get_split(X_train, y_train)
        train_data = self.create_dataset(X_train, y_train, use_tff=False)
        test_data = self.create_dataset(X_test, y_test, use_tff=False)
        val_data = self.create_dataset(X_val, y_val, use_tff=False)
        return train_data, test_data, val_data

    def get_not_bn_idx(self, trainable_variables):
        """Positions of the variables whose name does not contain ``'batch_normalization'``."""
        return [idx for idx, var in enumerate(trainable_variables)
                if 'batch_normalization' not in var.name]

    def get_bn_idx(self, trainable_variables):
        """Positions of the variables whose name contains ``'batch_normalization'``."""
        return [idx for idx, var in enumerate(trainable_variables)
                if 'batch_normalization' in var.name]

    def get_weights_by_idx(self, trainable_variables, var_ids):
        """Return the variables at positions ``var_ids``, in that order."""
        return [trainable_variables[i] for i in var_ids]

    def create_keras_model(self, input_dim):
        """Dense(500, relu) -> Dense(NUM_CLASSES, softmax) classifier for ``input_dim`` features."""
        return tf.keras.models.Sequential([
            tf.keras.layers.InputLayer(input_shape=(input_dim,)),
            tf.keras.layers.Dense(500, activation=tf.nn.relu),
            tf.keras.layers.Dense(self.NUM_CLASSES, activation='softmax'),
        ])

    def create_keras_bn_model(self, input_dim):
        """Same classifier as ``create_keras_model``, with BatchNormalization before each Dense layer."""
        return tf.keras.models.Sequential([
            tf.keras.layers.BatchNormalization(input_shape=(input_dim,)),
            tf.keras.layers.Dense(500, activation=tf.nn.relu),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Dense(self.NUM_CLASSES, activation='softmax'),
        ])

    # Each time the next method is called, the server model is broadcast to each client using a broadcast function.
    # For each client, one epoch of local training is performed via the tf.keras.optimizers.Optimizer.apply_gradients method of the client optimizer.
    # Each client computes the difference between the client model after training and the initial broadcast model.
    # These model deltas are then aggregated at the server using some aggregation function.
    # The aggregate model delta is applied at the server by using the tf.keras.optimizers.Optimizer.apply_gradients method of the server optimizer.
    def model_fn(self):
        """Build a fresh TFF model wrapping a new Keras model (with BN if ``use_bn``).

        We _must_ create a new model here, and _not_ capture it from an external
        scope: TFF calls this within different graph contexts.
        """
        input_dim = self.test_data[0].element_spec[0].shape[1]
        if self.use_bn:
            print("Using model with batch normalization")
            keras_model = self.create_keras_bn_model(input_dim)
        else:
            print("Using model without batch normalization")
            keras_model = self.create_keras_model(input_dim)
        return tff.learning.from_keras_model(
            keras_model,
            input_spec=self.train_data[0].element_spec,
            loss=tf.keras.losses.SparseCategoricalCrossentropy(),
            metrics=[tf.keras.metrics.SparseCategoricalAccuracy(),
                     tf.keras.metrics.SparseTopKCategoricalAccuracy(k=self.k)])

    def map_weights_ids(self):
        """Positions, within a BN model's full weight list, of its trainable non-BN weights.

        The full list is built layer by layer, matching ``model.get_weights()`` for a
        Sequential model. The returned positions follow the order of
        ``model.trainable_variables``, which is the order of the federated server state.
        """
        model = self.create_keras_bn_model(self.train_data[0].element_spec[0].shape[1])
        all_weights_name = [weight.name for layer in model.layers for weight in layer.weights]
        return [idx
                for var in model.trainable_variables
                for idx, name in enumerate(all_weights_name)
                if var.name == name and 'batch_normalization' not in name]

    def broadcast_encoder_fn(self, value):
        """Function for building encoded broadcast.

        Tensors with more than ``quantization_thresh`` elements are uniformly quantized
        to ``quantization_bits`` bits; smaller tensors are sent unchanged.
        """
        spec = tf.TensorSpec(value.shape, value.dtype)
        if value.shape.num_elements() > self.quantization_thresh:
            return te.encoders.as_simple_encoder(
                te.encoders.uniform_quantization(bits=self.quantization_bits), spec)
        return te.encoders.as_simple_encoder(te.encoders.identity(), spec)

    def mean_encoder_fn(self, tensor_spec):
        """Function for building a GatherEncoder.

        Uses the same size threshold and quantization as ``broadcast_encoder_fn``.
        """
        spec = tf.TensorSpec(tensor_spec.shape, tensor_spec.dtype)
        if tensor_spec.shape.num_elements() > self.quantization_thresh:
            return te.encoders.as_gather_encoder(
                te.encoders.uniform_quantization(bits=self.quantization_bits), spec)
        return te.encoders.as_gather_encoder(te.encoders.identity(), spec)

    def _log_communication(self, environment, round_num):
        """Log the last broadcast/aggregate bit counts reported by ``environment``.

        Returns:
            ``(broadcasted_bits, aggregated_bits)``.
        """
        size_info = environment.get_size_info()
        broadcasted_bits = size_info.broadcast_bits[-1]
        aggregated_bits = size_info.aggregate_bits[-1]
        mlflow.log_metric('cumulative_broadcasted_bits', broadcasted_bits, round_num)
        mlflow.log_metric('cumulative_aggregated_bits', aggregated_bits, round_num)
        return broadcasted_bits, aggregated_bits

    def _log_metrics(self, metrics, step):
        """Log every ``name -> value`` pair of ``metrics`` to MLflow at ``step`` and print it."""
        for name, value in metrics.items():
            mlflow.log_metric(name, value, step)
            print(step, name, value)

    def run_federated(self):
        """Run ``EPOCHS`` federated rounds on ``self.train_data`` and log the metrics to MLflow.

        Without batch normalization, TFF's FedAvg process is used, with quantized
        broadcast and aggregation when ``self.compression`` is set. With batch
        normalization, a custom FedAvg averages only the non-BN weights and keeps
        the BN weights local to each client. Compression is not applied in that path.
        """
        broadcast_process = None
        mean_factory = None

        if self.compression:
            # Equivalent to:
            # compressed_mean = tff.learning.compression_aggregator(zeroing=False, clipping=False)

            # TODO try out Quantization bits = 6 or 7 bits
            # If resources permit doing a small grid search, we would recommend that you identify the value
            # for which training becomes unstable or final model quality starts to degrade, and then increase that value by two

            # TODO Clients per round. Note that significantly increasing the number of clients per round can enable a smaller value for quantization_bits to work well,
            # because the randomized inaccuracy introduced by quantization may be evened out by averaging over more client updates
            mean_factory = tff.aggregators.MeanFactory(
                value_sum_factory=tff.aggregators.EncodedSumFactory(self.mean_encoder_fn),
                weight_sum_factory=tff.aggregators.EncodedSumFactory(self.mean_encoder_fn))
            # Stored in `broadcast_process` so that FedAvg below actually uses it.
            broadcast_process = tff.learning.framework.build_encoded_broadcast_process_from_model(
                self.model_fn, self.broadcast_encoder_fn)

        # ---custom fed avg implementation for batch normalization-----
        # --------------------------START------------------------------
        @tff.tf_computation
        def server_init():
            """Initial server state: the non-BN trainable weights of a fresh model."""
            model = self.model_fn()
            trainable_variables = model.trainable_variables
            non_bn_ids = self.get_not_bn_idx(trainable_variables)
            return self.get_weights_by_idx(trainable_variables, non_bn_ids)

        @tf.function
        def server_update(model, mean_client_weights):
            """Assign the averaged client weights to the model's non-BN variables and return them."""
            trainable_variables = model.trainable_variables
            non_bn_ids = self.get_not_bn_idx(trainable_variables)
            non_bn_weights = self.get_weights_by_idx(trainable_variables, non_bn_ids)
            tf.nest.map_structure(lambda x, y: x.assign(y),
                                  non_bn_weights, mean_client_weights)
            return non_bn_weights

        @tf.function
        def client_update(model, dataset, server_weights, client_optimizer):
            """Train locally from the server weights and return the updated non-BN weights.

            All trainable variables (including BN) are optimized for ``self.E`` local
            epochs. Only the non-BN ones are returned for averaging, so BN stays local.
            """
            trainable_variables = model.trainable_variables

            # get ids of non_batch normalization weights
            non_bn_ids = self.get_not_bn_idx(trainable_variables)
            # Python side effect: executes while the function is traced, not on every call.
            mlflow.log_param("number of updatable non-BN weights", len(non_bn_ids))
            non_bn_weights = self.get_weights_by_idx(trainable_variables, non_bn_ids)

            # Load the broadcast server weights into the non-BN variables.
            tf.nest.map_structure(lambda x, y: x.assign(y),
                                  non_bn_weights, server_weights)
            for j, weight in enumerate(non_bn_weights):
                # log shape of weight
                mlflow.log_param(f"{j} Shape of updatable non-BN weight", weight.shape)

            client_weights = list(trainable_variables)
            for _ in range(self.E):
                for batch in dataset:
                    with tf.GradientTape() as tape:
                        outputs = model.forward_pass(batch)
                    grads = tape.gradient(outputs.loss, client_weights)
                    client_optimizer.apply_gradients(zip(grads, client_weights))

            return non_bn_weights

        whimsy_model = self.model_fn()
        tf_dataset_type = tff.SequenceType(whimsy_model.input_spec)
        model_weights_type = server_init.type_signature.result
        federated_server_type = tff.FederatedType(model_weights_type, tff.SERVER)
        federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)
        print(federated_server_type)
        print(federated_dataset_type)

        @tff.tf_computation(tf_dataset_type, model_weights_type)
        def client_update_fn(tf_dataset, server_weights):
            model = self.model_fn()
            client_optimizer = tf.keras.optimizers.SGD(learning_rate=self.learning_rate, momentum=self.momentum, nesterov=self.nesterov)
            return client_update(model, tf_dataset, server_weights, client_optimizer)

        @tff.tf_computation(model_weights_type)
        def server_update_fn(mean_client_weights):
            model = self.model_fn()
            return server_update(model, mean_client_weights)

        @tff.federated_computation
        def initialize_fn():
            return tff.federated_value(server_init(), tff.SERVER)

        @tff.federated_computation(federated_server_type, federated_dataset_type)
        def next_fn(server_weights, federated_dataset):
            # Broadcast the server weights to the clients.
            server_weights_at_client = tff.federated_broadcast(server_weights)

            # Each client computes their updated weights.
            client_weights = tff.federated_map(
                client_update_fn, (federated_dataset, server_weights_at_client))

            # The server averages these updates.
            mean_client_weights = tff.federated_mean(client_weights)

            # The server updates its model.
            return tff.federated_map(server_update_fn, mean_client_weights)

        def evaluate(server_state, train_data, test_data, trainable_ids):
            """Evaluate the server's non-BN weights on every client's train and test data.

            A fresh BN Keras model receives the server weights at ``trainable_ids``. Its
            BN weights keep their initial values. Loss, accuracy and top-k accuracy are
            averaged (unweighted) over the client datasets.
            """
            model = self.create_keras_bn_model(test_data[0].element_spec[0].shape[1])
            model.compile(
                loss=tf.keras.losses.SparseCategoricalCrossentropy(),
                metrics=[tf.keras.metrics.SparseCategoricalAccuracy(),
                         tf.keras.metrics.SparseTopKCategoricalAccuracy(k=self.k)])
            weights = model.get_weights()
            for idx, weight_id in enumerate(trainable_ids):
                weights[weight_id] = np.array(server_state[idx])
            model.set_weights(weights)

            def mean_metrics(client_datasets):
                # Each evaluate() call returns [loss, accuracy, top-k accuracy].
                results = [model.evaluate(client_ds, verbose=0) for client_ds in client_datasets]
                loss, acc, k_acc = (statistics.mean(column) for column in zip(*results))
                return loss, acc, k_acc

            print("\t--Training--\t")
            train_loss, train_acc, train_k_acc = mean_metrics(train_data)

            print("\t--Testing--\t")
            test_loss, test_acc, test_k_acc = mean_metrics(test_data)
            return {"Train_acc": train_acc,
                    f"Train_{self.k}_acc": train_k_acc,
                    "Train_loss": train_loss,
                    "Test_acc": test_acc,
                    f"Test_{self.k}_acc": test_k_acc,
                    "Test_loss": test_loss}
        # -------------------END---------------------

        with tf.device('/gpu:0'):
            if self.use_bn:
                # With batch normalization, the clients' BN weights are not averaged; the
                # federated state is just the list of non-BN weights.
                trainable_ids = self.map_weights_ids()
                fed_avg = tff.templates.IterativeProcess(
                    initialize_fn=initialize_fn,
                    next_fn=next_fn)

                state = fed_avg.initialize()
                environment = self.tff_utils.set_sizing_environment()

                for round_num in range(self.EPOCHS):
                    state = fed_avg.next(state, self.train_data)
                    broadcasted_bits, aggregated_bits = self._log_communication(environment, round_num)
                    print('round {:2d}, broadcasted_bits={}, aggregated_bits={}'
                          .format(round_num,
                                  self.tff_utils.format_size(broadcasted_bits),
                                  self.tff_utils.format_size(aggregated_bits)))

                    metrics = evaluate(state, self.train_data, self.test_data, trainable_ids)
                    self._log_metrics(metrics, round_num)

            else:
                fed_avg = tff.learning.build_federated_averaging_process(
                    self.model_fn,
                    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=self.learning_rate, momentum=self.momentum, nesterov=self.nesterov),
                    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0),
                    broadcast_process=broadcast_process,
                    model_update_aggregation_factory=mean_factory)

                state = fed_avg.initialize()
                environment = self.tff_utils.set_sizing_environment()
                # Building the evaluation computation is costly; build it once, not per round.
                evaluation = tff.learning.build_federated_evaluation(self.model_fn)

                for round_num in range(self.EPOCHS):
                    print(round_num)
                    state, metrics = fed_avg.next(state, self.train_data)
                    # Note: training metrics reported by the iterative training process
                    # generally reflect the performance of the model at the beginning of the training round
                    broadcasted_bits, aggregated_bits = self._log_communication(environment, round_num)
                    print('round {:2d}, metrics={}, broadcasted_bits={}, aggregated_bits={}'
                          .format(round_num, metrics,
                                  self.tff_utils.format_size(broadcasted_bits),
                                  self.tff_utils.format_size(aggregated_bits)))

                    self._log_metrics(metrics['train'], round_num)

                    test_metrics = evaluation(state.model, self.test_data)
                    self._log_metrics(test_metrics, round_num)

    def run_unfederated(self, ds_train, ds_test, ds_val, input_dim):
        """Train a centralized Keras model for ``EPOCHS`` epochs and log train/val/test metrics.

        Args:
            ds_train, ds_test, ds_val: batched datasets from ``create_unfederated_dataset``.
            input_dim: number of input features (currently unused).
        """
        # NOTE(review): fit() runs one epoch per call, and Keras resets EarlyStopping at
        # the start of every fit() call. As a result, this callback never stops training.
        early_stopping_callback = tf.keras.callbacks.EarlyStopping(monitor='val_loss',
                                                                   min_delta=0.01,
                                                                   patience=2,
                                                                   verbose=0,
                                                                   mode='auto',
                                                                   baseline=None,
                                                                   restore_best_weights=True)
        if self.use_bn:
            model = BNModel(self.NUM_CLASSES)
        else:
            model = FLModel(self.NUM_CLASSES)
        model.compile(
            optimizer=self.client_optimizer,
            loss="sparse_categorical_crossentropy",
            metrics=[self.sparseCategoricalAcc, self.sparseTopKCategoricalAccuracy])

        def first_epoch(history, key):
            # Each fit() call trains a single epoch, so index 0 is that epoch's value.
            return round(history.history[key][0], 8)

        for epoch in range(self.EPOCHS):
            with tf.device('/gpu:0'):
                history = model.fit(
                    ds_train,
                    steps_per_epoch=64,
                    validation_data=ds_val,
                    verbose=0,
                    callbacks=[early_stopping_callback])

            loss = first_epoch(history, "loss")
            acc = first_epoch(history, "sparse_categorical_accuracy")
            k_acc = first_epoch(history, "sparse_top_k_categorical_accuracy")
            val_loss = first_epoch(history, 'val_loss')
            val_acc = first_epoch(history, 'val_sparse_categorical_accuracy')
            val_k_acc = first_epoch(history, 'val_sparse_top_k_categorical_accuracy')

            with tf.device('/gpu:0'):
                # The dataset is already batched; Keras ignores `batch_size` for datasets.
                test_loss, test_acc, test_k_acc = model.evaluate(ds_test, verbose=0)

            mlflow.log_metric("Loss/train", loss, epoch)
            mlflow.log_metric("Acc/train", acc, epoch)
            mlflow.log_metric("K_acc/train", k_acc, epoch)

            mlflow.log_metric("Loss/validation", val_loss, epoch)
            mlflow.log_metric("Acc/validation", val_acc, epoch)
            mlflow.log_metric("K_acc/validation", val_k_acc, epoch)

            mlflow.log_metric("Loss/test", test_loss, epoch)
            mlflow.log_metric("Acc/test", test_acc, epoch)
            mlflow.log_metric("K_acc/test", test_k_acc, epoch)

            print(
                f'Epoch: {epoch},\n'
                f'Train Loss:\t{loss}, '
                f'Train Accuracy:\t{acc}, '
                f'Train Top {self.k} Accuracy:\t{k_acc}\n'
                f'Validation Loss:\t{val_loss}, '
                f'Validation Accuracy:\t{val_acc}, '
                f'Validation Top {self.k} Accuracy:\t{val_k_acc}\n'
                f'Test Loss:\t{test_loss}, '
                f'Test Accuracy:\t{test_acc} '
                f'Test Top {self.k} Accuracy:\t{test_k_acc}'
                f'\n--------------------------------------------------------------------------------------------------------------------------\n'
            )

    def main(self, compare_centralized = False, compression = False):
        """Load the data, then run the experiments for 2 and 3 clients.

        For each client count, federated training is run without and then with batch
        normalization. If requested, a centralized baseline on the same clients follows.

        Args:
            compare_centralized: also train a centralized Keras model on the same clients.
            compression: quantize broadcast and aggregation in the non-BN FedAvg run.

        Raises:
            ValueError: if no client ids can be derived or no client has enough data.
        """
        self.compression = compression
        if self.compression:
            self.quantization_bits = 8  # default 8
            self.quantization_thresh = 20000  # default 20000

        self.entropy_loss = tf.keras.losses.SparseCategoricalCrossentropy()
        self.sparseCategoricalAcc = tf.keras.metrics.SparseCategoricalAccuracy()
        self.sparseTopKCategoricalAccuracy = tf.keras.metrics.SparseTopKCategoricalAccuracy(k=self.k)
        self.client_optimizer = tf.keras.optimizers.SGD(learning_rate=self.learning_rate, momentum=self.momentum, nesterov=self.nesterov)
        self.server_optimizer = tf.keras.optimizers.SGD(learning_rate=1.0)
        self.server_optimzer = self.server_optimizer  # original (misspelled) attribute name, kept for compatibility

        # read dataset
        self.client_ids = None
        self.scaler = StandardScaler()
        utils = Utils()
        reader = Reader(self.BASE_DIR + self.DATA_DIR, self.file, self.drop_index)
        data = reader.get_data()

        if "IID.csv" in self.file:
            data = utils.create_clients(data, self.CLIENTS, strict=False)
            reader.set_features(reader.get_features() + 1)
            self.client_ids = list(range(self.CLIENTS))

        features = reader.get_features()
        # Fit the scaler on exactly the columns that `_features_and_labels` transforms later.
        self.scaler.fit(data[:, self._feature_columns(features)])

        # The centralized run sets use_tff=False. Restore the caller's value each
        # iteration so the next iteration's federated datasets are built the same way.
        initial_use_tff = self.use_tff

        # TODO start loop for different clients
        for clients in [2, 3]:
            self.CLIENTS = clients
            self.use_tff = initial_use_tff
            print(f"\n {clients} CLIENTS\n")

            if ((self.file == "App_usage_trace.txt") or (self.file == "top_90_apps.csv") or ("infected" in self.file) or ("Cosphere" in self.file)):
                # for subsets: renumber the client ids
                data = utils.map_ids(data.copy())
                # create list of client ids
                num_of_users = int(np.amax(data[:, 0]) + 1)
                self.client_ids = list(range(num_of_users))
                random.shuffle(self.client_ids)
                self.client_ids = sorted(self.client_ids[:self.CLIENTS])

            if "IID_2" in self.file:
                self.client_ids = list(range(self.CLIENTS))

            if self.client_ids is None:
                raise ValueError(f"Cannot derive client ids for data file {self.file!r}")

            # federated training: create dataset per client
            self.train_data = []
            self.test_data = []

            # Iterate over a copy: clients without enough data are removed from self.client_ids.
            for client_id in list(self.client_ids):
                client_rows = data[data[:, 0] == client_id]
                print(f"Client {client_id}: {len(client_rows)} rows")

                if len(client_rows) <= 1:
                    print("Could not generate datasets for client {} as there is just one entry in X_train".format(client_id))
                    self.client_ids.remove(client_id)
                    continue

                # drop user id and label column, then scale
                client_x, client_y = self._features_and_labels(client_rows, features)
                X_train, X_test, y_train, y_test = train_test_split(client_x, client_y, test_size=0.2, random_state=42)
                self.train_data.append(self.create_dataset(X_train, y_train, self.use_tff))
                self.test_data.append(self.create_dataset(X_test, y_test, self.use_tff))
                print("Client {}: Created  dataset".format(client_id))

            if not self.train_data:
                raise ValueError("No client has enough rows to build a federated dataset.")

            # Check format for TFF: needs to be in shape(None, dim)
            # like eg:
            # (TensorSpec(shape=(None, 3), dtype=tf.float64, name=None),
            #  TensorSpec(shape=(None, 1), dtype=tf.float64, name=None)
            # Check format for unfederated Learning: shape(batchsize, dim)
            print(self.train_data[0].element_spec)
            print(self.test_data[0].element_spec)

            self.tff_utils = TFF_Utils()

            # run without batch normalization
            self.use_bn = False
            self.init_mlflow()
            self.run_federated()

            # run with batch normalization
            self.use_bn = True
            self.init_mlflow()
            self.run_federated()

            if compare_centralized:
                self.use_tff = False
                self.use_bn = False
                self.init_mlflow()

                # get same client ids, as with tff
                mask = np.isin(data[:, 0], self.client_ids)
                x = data[mask].copy()
                # create dataset for centralized training
                unfederated_train, unfederated_test, unfederated_val = self.create_unfederated_dataset(x, features)
                self.run_unfederated(unfederated_train, unfederated_test, unfederated_val, features - 2)

In [None]:
# Selectable data files: option key -> (file name, number of classes).
input_options = {
    "0": ("App_usage_trace.txt", 3996),
    "1": ("top_90_apps.csv", 3996),  # 171
    "2": ("10_infected.csv", 3),
    "3": ("num_classes_infected_shuffled.csv", 3),
    "4": ("Cosphere.csv", 9972),
    "5": ("Cosphere_cropped.csv", 9972),
    "6": ("top_90_apps_IID_2.csv", 3996),
}

used_option = "2"
selected_file, selected_num_classes = input_options[used_option]
c_vs_fed_trainer = Centralized_vs_Federated_Trainer(
    file_name=selected_file,
    number_of_classes=selected_num_classes,
)
c_vs_fed_trainer.main(compare_centralized=False, compression=True)