In [None]:
import numpy as np
import os
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.preprocessing import LabelEncoder

from io import BytesIO
import requests

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.losses import SparseCategoricalCrossentropy
from tensorflow.keras.metrics import SparseCategoricalAccuracy
from tensorflow.keras import datasets, layers, models

import time
from collections import Counter, OrderedDict
import math
import pandas as pd

import matplotlib.pyplot as plt
import gc

class GarbageCollectorCallback(tf.keras.callbacks.Callback):
    """Keras callback that runs ``gc.collect()`` to release host memory between phases.

    Collection happens after every training epoch. ``predict`` and ``evaluate`` never
    fire ``on_epoch_end``, so the callback also collects at the end of those runs;
    otherwise passing it to ``model.predict`` would have no effect.
    """

    def on_epoch_end(self, epoch, logs=None):
        gc.collect()

    def on_predict_end(self, logs=None):
        gc.collect()

    def on_test_end(self, logs=None):
        gc.collect()

from fedartml import InteractivePlots, SplitAsFederatedData

# Ask TensorFlow's C++ runtime to suppress log messages (level "3").
# NOTE(review): this runs after `import tensorflow` earlier in this cell; if TF reads the
# variable only at import time, it should be moved above that import — confirm.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import flwr as fl
from typing import List, Tuple, Dict, Optional
from flwr.common import Metrics

from keras.datasets import cifar10
import threading
from abc import ABC, abstractmethod
from logging import INFO
from flwr.common.logger import log
import random
from flwr.server.client_proxy import ClientProxy
from flwr.server.criterion import Criterion
from logging import WARNING
from typing import Callable, Dict, List, Optional, Tuple, Union
from functools import reduce

from flwr.common import (
    FitRes,
    FitIns,
    MetricsAggregationFn,
    NDArrays,
    Parameters,
    Scalar,
    ndarrays_to_parameters,
    parameters_to_ndarrays,
)
from flwr.server.client_manager import ClientManager
from flwr.server.strategy.fedavg import FedAvg

from sklearn.cluster import OPTICS

In [None]:
def test_model(model, X_test, Y_test):
    """Evaluate a compiled Keras classifier on a held-out set.

    Args:
        model: Keras model whose ``predict`` returns per-class probabilities
            (models built by ``get_model`` end in a softmax layer).
        X_test: input samples.
        Y_test: integer class labels for ``X_test``.

    Returns:
        ``(loss, accuracy, precision, recall, f1)``: the sparse categorical
        cross-entropy plus the weighted-average scikit-learn scores.
    """
    cce = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
    # predict() yields softmax probabilities, hence from_logits=False above.
    probs = model.predict(X_test, batch_size=64, verbose=0, callbacks=[GarbageCollectorCallback()])
    y_pred = np.argmax(probs, axis=1)
    loss = cce(Y_test, probs).numpy()
    # scikit-learn metrics take (y_true, y_pred). With average='weighted' the order
    # matters: weights come from the support of y_true, and precision/recall swap roles.
    acc = accuracy_score(Y_test, y_pred)
    pre = precision_score(Y_test, y_pred, average='weighted', zero_division=0)
    rec = recall_score(Y_test, y_pred, average='weighted', zero_division=0)
    f1s = f1_score(Y_test, y_pred, average='weighted', zero_division=0)
    return loss, acc, pre, rec, f1s

def from_FedArtML_to_Flower_format(clients_dict):
    """Split FedArtML's per-client sample lists into parallel feature / label lists.

    Args:
        clients_dict: mapping client name -> sequence of samples, where element 0 of
            each sample is the feature array and element 1 is the label.

    Returns:
        ``(list_x_train, list_y_train)``: one entry per client, in ``clients_dict``
        order. Features are stacked into a single array per client. Labels stay an
        object-dtype array.
    """
    list_x_train, list_y_train = [], []
    for client_name in clients_dict:
        samples = np.array(clients_dict[client_name], dtype=object)
        # Column 0 holds the feature arrays; stacking them gives (n_samples, *feature_shape).
        list_x_train.append(np.array(list(samples[:, 0])))
        list_y_train.append(np.array(samples[:, 1]))
    return list_x_train, list_y_train

def get_model(learning_rate: float = 0.01):
    """Build and compile a LeNet-5-style CNN for 32x32x3 images and 10 classes.

    Args:
        learning_rate: SGD learning rate (default 0.01).

    Returns:
        A Keras ``Sequential`` model compiled with SGD, sparse categorical
        cross-entropy and an ``accuracy`` metric. The last layer is a softmax, so
        outputs are class probabilities.
    """
    model = Sequential([
        tf.keras.layers.Conv2D(6, kernel_size=5, strides=1, activation='relu', input_shape=(32, 32, 3), padding='same'),  # C1
        tf.keras.layers.AveragePooling2D(),  # S1
        tf.keras.layers.Conv2D(16, kernel_size=5, strides=1, activation='relu', padding='valid'),  # C2
        tf.keras.layers.AveragePooling2D(),  # S2
        tf.keras.layers.Conv2D(120, kernel_size=5, strides=1, activation='relu', padding='valid'),  # C3
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(84, activation='relu'),
        tf.keras.layers.Dense(10, activation='softmax')
    ])
    # `SGD` is not imported anywhere in this notebook; reference it through tf.keras.
    model.compile(
        optimizer=tf.keras.optimizers.SGD(learning_rate=learning_rate),
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'],
    )
    return model

class FlowerClient(fl.client.NumPyClient):
    """Flower NumPy client that trains a Keras model on one client's local data.

    Args:
        model: compiled Keras model (e.g. from ``get_model``). Its weights are
            overwritten with the server's parameters at the start of ``fit`` and ``evaluate``.
        x_train, y_train: this client's local training data.
        entropy: this client's label entropy. It is reported to the server in the fit
            metrics under the key ``"entropy"``.
        x_test, y_test: optional local evaluation data. When omitted, ``evaluate``
            scores the model on the training data instead.
    """

    def __init__(self, model, x_train, y_train, entropy, x_test=None, y_test=None) -> None:
        self.model = model
        self.x_train, self.y_train = x_train, y_train
        self.entropy = entropy
        self.x_test, self.y_test = x_test, y_test

    def get_parameters(self, config):
        return self.model.get_weights()

    def fit(self, parameters, config):
        """Load the global weights, train one local epoch, and return the new weights,
        the number of training samples and ``{'entropy': ...}``."""
        self.model.set_weights(parameters)
        self.model.fit(self.x_train, self.y_train, epochs=1, verbose=0, batch_size=64,
                       callbacks=[GarbageCollectorCallback()])
        return self.model.get_weights(), len(self.x_train), {'entropy': self.entropy}

    def evaluate(self, parameters, config):
        """Score the given global weights locally; returns (loss, num_samples, {"accuracy": acc})."""
        self.model.set_weights(parameters)
        if self.x_test is not None and self.y_test is not None:
            x_eval, y_eval = self.x_test, self.y_test
        else:
            x_eval, y_eval = self.x_train, self.y_train
        # Assumes the model was compiled with exactly one metric (accuracy), as in
        # get_model, so evaluate() returns [loss, accuracy].
        loss, acc = self.model.evaluate(x_eval, y_eval, batch_size=64, verbose=0)
        return float(loss), len(x_eval), {"accuracy": float(acc)}

def plot_metric_from_history(
    hist,
    save_plot_path,
    metric_type: str,
    metric: str,
    ax=None,
    color=None,
) -> None:
    """Plot one metric of a Flower history against the communication round.

    Args:
        hist: Flower history object. The function reads ``hist.metrics_centralized``
            or ``hist.metrics_distributed``, each mapping a metric name to a list of
            ``(round, value)`` pairs.
        save_plot_path: currently unused; kept so existing calls keep working.
        metric_type: ``"centralized"`` selects server-side metrics. Any other value
            selects the distributed ones.
        metric: name of the metric to plot.
        ax: matplotlib Axes to draw on. Defaults to the current axes.
        color: line colour. Defaults to ``colors[5]`` from the notebook config cell.
    """
    metric_dict = (
        hist.metrics_centralized
        if metric_type == "centralized"
        else hist.metrics_distributed
    )
    rounds, values = zip(*metric_dict[metric])
    if ax is None:
        ax = plt.gca()
    if color is None:
        color = colors[5]
    ax.plot(np.asarray(rounds), np.asarray(values), color=color, linewidth=5, label='Test')
    ax.legend(fontsize=45)
    ax.set_xlabel('Communication round', fontsize=40)
    ax.set_ylabel(metric, fontsize=50)
    ax.set_title(metric, fontsize=60)
    ax.tick_params(axis='x', labelsize=30)
    ax.tick_params(axis='y', labelsize=30)
    # All plotted metrics (accuracy/precision/recall/F1) are fractions in [0, 1].
    ax.set_ylim(0, 1)

def retrieve_global_metrics(
    hist,
    metric_type: str,
    metric: str,
    best_metric: bool,
) -> float:
    """Return either the best or the final recorded value of one metric.

    Args:
        hist: Flower history object. The function reads ``hist.metrics_centralized``
            or ``hist.metrics_distributed``, each mapping a metric name to a list of
            ``(round, value)`` pairs.
        metric_type: ``"centralized"`` selects server-side metrics. Any other value
            selects the distributed ones.
        metric: name of the metric.
        best_metric: if truthy, return the maximum over all rounds. Otherwise return
            the value from the last recorded round.

    Raises:
        KeyError: if ``metric`` was never recorded.
        ValueError: if ``metric`` has no recorded values.
    """
    metric_dict = (
        hist.metrics_centralized
        if metric_type == "centralized"
        else hist.metrics_distributed
    )
    values = [value for _, value in metric_dict[metric]]
    if not values:
        raise ValueError(f"No recorded values for metric {metric!r} ({metric_type})")
    return max(values) if best_metric else values[-1]

In [None]:
class SimpleClientManager(ClientManager):
    """Cluster-aware client manager: picks one client per cluster, biased toward high label entropy.

    Args:
        cluster_labels: list of clusters, each a list of integer client indices. Index
            ``i`` refers to the registered proxy whose ``cid`` is ``str(i)``.
        entropies: per-client label entropy, indexable by integer client index.
        num_candidates: number of random cluster-based selections drawn per round. The
            selection with the largest sum of ``exp(entropy)`` is used.
    """

    def __init__(self, cluster_labels, entropies, num_candidates: int = 5) -> None:
        if num_candidates < 1:
            raise ValueError("num_candidates must be >= 1")
        self.clients: Dict[str, ClientProxy] = {}
        self._cv = threading.Condition()
        self.seed = 0  # NOTE(review): unused; sampling uses the global `random` module
        self.cluster_labels = cluster_labels
        self.num_cluster = len(cluster_labels)
        self.entropies = entropies
        self.num_candidates = num_candidates

    def __len__(self) -> int:
        return len(self.clients)

    def num_available(self) -> int:
        return len(self)

    def wait_for(self, num_clients: int, timeout: int = 86400) -> bool:
        """Block until at least ``num_clients`` are registered or ``timeout`` seconds pass."""
        with self._cv:
            return self._cv.wait_for(
                lambda: len(self.clients) >= num_clients, timeout=timeout
            )

    def register(self, client: ClientProxy) -> bool:
        """Add ``client``; returns False if a client with the same cid is already registered."""
        if client.cid in self.clients:
            return False

        self.clients[client.cid] = client
        with self._cv:
            self._cv.notify_all()
        return True

    def unregister(self, client: ClientProxy) -> None:
        if client.cid in self.clients:
            del self.clients[client.cid]
            with self._cv:
                self._cv.notify_all()

    def all(self) -> Dict[str, ClientProxy]:
        return self.clients

    def sample(
        self,
        num_clients: int,
        min_num_clients: Optional[int] = None,
        criterion: Optional[Criterion] = None,
    ) -> List[ClientProxy]:
        """Sample ``num_clients`` registered clients for one round.

        A single client is drawn uniformly at random. Otherwise ``num_clients`` distinct
        clusters are chosen and one client is drawn from each. Only registered clients
        that pass ``criterion`` are eligible. Logs and returns an empty list when fewer
        eligible clients exist than were requested.
        """
        if min_num_clients is None:
            min_num_clients = num_clients
        self.wait_for(min_num_clients)

        available_cids = list(self.clients)
        if criterion is not None:
            available_cids = [
                cid for cid in available_cids if criterion.select(self.clients[cid])
            ]

        if num_clients > len(available_cids):
            log(
                INFO,
                "Sampling failed: number of available clients (%s) is less than "
                "number of requested clients (%s).",
                len(available_cids),
                num_clients,
            )
            return []

        if num_clients == 1:
            sampled_cids = random.sample(available_cids, num_clients)
        else:
            sampled_cids = self._sample_by_cluster(num_clients, available_cids)
        return [self.clients[cid] for cid in sampled_cids]

    def _sample_by_cluster(self, num_clients: int, available_cids: List[str]) -> List[str]:
        """Draw one eligible client from each of ``num_clients`` random clusters,
        ``num_candidates`` times, and keep the draw with the largest sum of exp(entropy)."""
        available = set(available_cids)
        # Restrict clusters to eligible clients; clusters left empty cannot be drawn from.
        clusters = [
            [idx for idx in cluster if str(idx) in available]
            for cluster in self.cluster_labels
        ]
        clusters = [cluster for cluster in clusters if cluster]
        if len(clusters) < num_clients:
            log(
                WARNING,
                "Only %s non-empty clusters for %s requested clients; sampling uniformly.",
                len(clusters),
                num_clients,
            )
            return random.sample(available_cids, num_clients)

        candidates, scores = [], []
        for _ in range(self.num_candidates):
            chosen = [random.choice(cluster) for cluster in random.sample(clusters, num_clients)]
            candidates.append(chosen)
            scores.append(sum(np.exp(self.entropies[idx]) for idx in chosen))
        best = candidates[int(np.argmax(scores))]
        return [str(idx) for idx in best]

In [None]:
def aggregate(results: List[Tuple[NDArrays, float]]) -> NDArrays:
    """Weighted average of several models' weights, computed layer by layer.

    Args:
        results: ``(layers, weight)`` pairs, one per client. ``layers`` is that
            client's list of weight arrays; ``weight`` is its aggregation weight.

    Returns:
        One array per layer, equal to ``sum_k(weight_k * layer_k) / sum_k(weight_k)``.
    """
    total_weight = sum(weight for _, weight in results)
    scaled = []
    for layers, weight in results:
        scaled.append([layer * weight for layer in layers])
    averaged: NDArrays = []
    # zip(*scaled) regroups the scaled arrays so each tuple holds one layer from every client.
    for same_layer in zip(*scaled):
        averaged.append(reduce(np.add, same_layer) / total_weight)
    return averaged

class FedImp(FedAvg):
    """FedAvg variant whose aggregation weights favour clients with higher label entropy.

    Client ``k`` is weighted by ``num_examples_k * exp(entropy_k / entropy_temperature)``.
    Here ``entropy_k`` is the ``"entropy"`` value that client returns in its fit metrics.
    A smaller temperature sharpens the preference for high-entropy clients.

    Args:
        entropy_temperature: positive temperature applied to client entropies (default 0.7).
        *args, **kwargs: forwarded unchanged to ``FedAvg``.
    """

    def __init__(self, *args, entropy_temperature: float = 0.7, **kwargs) -> None:
        if entropy_temperature <= 0:
            raise ValueError("entropy_temperature must be positive")
        super().__init__(*args, **kwargs)
        self.entropy_temperature = entropy_temperature

    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        """Aggregate client updates using entropy-boosted example counts as weights.

        Returns ``(None, {})`` when there are no results, or when there are failures
        and ``accept_failures`` is False. Otherwise returns the aggregated parameters
        plus the aggregated fit metrics, if a ``fit_metrics_aggregation_fn`` is set.
        """
        if not results:
            return None, {}
        if not self.accept_failures and failures:
            return None, {}

        weights_results = [
            (
                parameters_to_ndarrays(fit_res.parameters),
                fit_res.num_examples
                * np.exp(fit_res.metrics["entropy"] / self.entropy_temperature),
            )
            for _, fit_res in results
        ]
        aggregated_ndarrays = aggregate(weights_results)

        parameters_aggregated = ndarrays_to_parameters(aggregated_ndarrays)
        metrics_aggregated = {}
        if self.fit_metrics_aggregation_fn:
            fit_metrics = [(res.num_examples, res.metrics) for _, res in results]
            metrics_aggregated = self.fit_metrics_aggregation_fn(fit_metrics)
        elif server_round == 1:
            log(WARNING, "No fit_metrics_aggregation_fn provided")

        return parameters_aggregated, metrics_aggregated

In [None]:
random_state = 0
colors = ["#00cfcc","#e6013b","#007f88","#00cccd","#69e0da","darkblue","#FFFFFF"]
local_nodes_glob = 50  # number of federated clients
Alpha = 1  # Dirichlet concentration used for the non-IID client split

# Seed the RNGs used in this process so Restart & Run All is reproducible:
# `random` drives client selection in SimpleClientManager.sample, and TensorFlow's
# seed governs Keras weight initialisation.
# NOTE(review): Flower simulation may build clients in separate worker processes,
# which these calls do not seed — confirm if full determinism is required.
random.seed(random_state)
np.random.seed(random_state)
tf.random.set_seed(random_state)

In [None]:
(train_images, train_labels), (test_images, test_labels) = cifar10.load_data()

# Scale pixel intensities by 1/255 and flatten the per-sample label arrays to 1-D.
train_images, test_images = train_images / 255, test_images / 255
train_labels = np.concatenate(train_labels)
test_labels = np.concatenate(test_labels)

# Partition the training set across `local_nodes_glob` clients with a Dirichlet(Alpha) split.
my_federater = SplitAsFederatedData(random_state=random_state)
clients_glob_dic, list_ids_sampled_dic, miss_class_per_node, distances = my_federater.create_clients(
    image_list=train_images,
    label_list=train_labels,
    num_clients=local_nodes_glob,
    prefix_cli='client',
    method="dirichlet",
    alpha=Alpha,
)

# Use the "with_class_completion" variant of the split.
# NOTE(review): presumably it guarantees every client holds every class — confirm in FedArtML docs.
clients_glob = clients_glob_dic['with_class_completion']
list_ids_sampled = list_ids_sampled_dic['with_class_completion']

list_x_train, list_y_train = from_FedArtML_to_Flower_format(clients_dict=clients_glob)

In [None]:
def calculate_entropy(y_train, base=10):
    """Shannon entropy of the label distribution in ``y_train``.

    Args:
        y_train: iterable of hashable labels.
        base: logarithm base. The default of 10 means that with 10 classes the
            result lies in [0, 1].

    Returns:
        The entropy as a float; 0.0 for empty input or a single class.
    """
    counts = Counter(y_train).values()
    total = sum(counts)  # hoisted: previously recomputed twice per class
    entropy = 0.0
    for count in counts:
        if count:
            p = count / total
            entropy -= p * math.log(p, base)
    return entropy

# Label entropy of every client, in client-index order.
entropies = [calculate_entropy(np.array(client_labels, dtype=int)) for client_labels in list_y_train]
print(entropies)

[0.8140706749336993, 0.7814997270035137, 0.8007472529181611, 0.8869961625657947, 0.7223111043410904, 0.9456407216451947, 0.8371830227020712, 0.6400505782434818, 0.8553229668681934, 0.8474513248299018, 0.8562247907372218, 0.8689586586794313, 0.6883632928438981, 0.8735449107572161, 0.8693551902686224, 0.6109504488631721, 0.7927503114997031, 0.7350948658039949, 0.7047959241521686, 0.8127940558408161, 0.9088224052763686, 0.8523460377551788, 0.8718265328486052, 0.8099585578440226, 0.8238575186118464, 0.8938044653703818, 0.760999301480173, 0.6635587676972069, 0.7941163902013841, 0.755376368485666, 0.8498132703507625, 0.7354360457254969, 0.9231674957066108, 0.8467628816855375, 0.7859911355340558, 0.8541979540120902, 0.9562638223990053, 0.7284278698166272, 0.707804706099066, 0.8053255114202941, 0.735337802090503, 0.785640571122668, 0.8170941481445381, 0.861795218164332, 0.8474287977345513, 0.7866107201447122, 0.878673082139249, 0.8383190241982307, 0.7932174109028591, 0.7041119561001181]


In [None]:
from sklearn.cluster import OPTICS

def split_clusters(labels):
    """Group positions by label.

    Returns a dict ``{label: [indices carrying that label]}``. Keys appear in order of
    first occurrence, and indices within each list are ascending.
    """
    clusters = {}
    for index, label in enumerate(labels):
        clusters.setdefault(label, []).append(index)
    return clusters

def hellinger_distance(p, q):
    """Hellinger distance between discrete distributions: sqrt(0.5 * sum((sqrt(p) - sqrt(q))**2))."""
    root_diff = np.sqrt(p) - np.sqrt(q)
    return np.sqrt(np.sum(root_diff ** 2) * 0.5)

def compute_hellinger_distance_matrix(distributions):
    """Return the symmetric (n, n) matrix of pairwise Hellinger distances.

    Each unordered pair is computed once and mirrored. The diagonal stays zero.
    """
    size = len(distributions)
    matrix = np.zeros((size, size))
    for row in range(size):
        for col in range(row + 1, size):
            matrix[row, col] = matrix[col, row] = hellinger_distance(distributions[row], distributions[col])
    return matrix

def to_prob_dist(data):
    """Normalise each row of a 2-D count matrix (array or nested list) so it sums to 1."""
    row_totals = np.sum(data, axis=1, keepdims=True)
    return data / row_totals

def softmax(x):
    """Numerically stable softmax over all elements of ``x``.

    No axis is used: the whole array sums to 1. The maximum is subtracted before
    exponentiating, which gives the same result mathematically but avoids overflow
    to inf/NaN for large inputs. Empty input returns an empty array.
    """
    x = np.asarray(x, dtype=float)
    if x.size == 0:
        return x
    exp_x = np.exp(x - np.max(x))
    return exp_x / np.sum(exp_x)

def kmeans_no_small_clusters(data, min_samples=2, num_classes=None):
    """Cluster clients by label distribution with OPTICS and keep noise points as singletons.

    Despite the name, this uses OPTICS (density-based) on a precomputed matrix of
    pairwise Hellinger distances between the clients' label distributions.

    Args:
        data: sequence of per-client label arrays (non-negative integer class ids).
        min_samples: OPTICS ``min_samples``.
        num_classes: length of each client's label histogram. If None, it is
            inferred as ``max label + 1`` over all clients.

    Returns:
        dict mapping cluster id -> list of client indices. OPTICS clusters come first,
        ordered by ascending size. Each noise point (OPTICS label -1) follows as its
        own singleton cluster under a fresh key.
    """
    label_arrays = [np.asarray(labels, dtype=int) for labels in data]
    if num_classes is None:
        num_classes = 1 + max((int(a.max()) for a in label_arrays if a.size), default=-1)
    # Fixed-length histograms keep class k in column k even when a client lacks some
    # classes. Collecting only the counts of classes present would misalign columns.
    counts = np.stack([np.bincount(a, minlength=num_classes) for a in label_arrays])
    distance_matrix = compute_hellinger_distance_matrix(to_prob_dist(counts))

    clustering = OPTICS(min_samples=min_samples, metric="precomputed").fit(distance_matrix)
    clusters = split_clusters(clustering.labels_)
    noise = clusters.pop(-1, [])  # OPTICS marks unclustered points with -1
    result = dict(sorted(clusters.items(), key=lambda item: len(item[1])))

    next_key = max(result, default=-1) + 1
    for client_idx in noise:
        result[next_key] = [client_idx]
        next_key += 1
    return result

labels = kmeans_no_small_clusters(list_y_train).values()
print(labels)


def sortf(item):
    """Sort key: the label entropy of client index ``item``."""
    return entropies[item]


# Within every cluster, list clients from highest to lowest label entropy.
new_label = [sorted(members, key=sortf, reverse=True) for members in labels]
print(new_label)

dict_values([[5, 36], [8, 13], [9, 48], [16, 21], [22, 23], [25, 35], [45, 49], [10, 33, 34], [14, 43, 46], [29, 31, 39, 40], [0], [1], [2], [3], [4], [6], [7], [11], [12], [15], [17], [18], [19], [20], [24], [26], [27], [28], [30], [32], [37], [38], [41], [42], [44], [47]])
[[36, 5], [13, 8], [9, 48], [21, 16], [22, 23], [25, 35], [45, 49], [10, 33, 34], [46, 14, 43], [39, 29, 31, 40], [0], [1], [2], [3], [4], [6], [7], [11], [12], [15], [17], [18], [19], [20], [24], [26], [27], [28], [30], [32], [37], [38], [41], [42], [44], [47]]


In [None]:
def evaluate_DNN_CL(
    server_round: int,
    parameters: fl.common.NDArrays,
    config: Dict[str, fl.common.Scalar],
) -> Optional[Tuple[float, Dict[str, fl.common.Scalar]]]:
    """Server-side evaluation hook: score the global weights on the CIFAR-10 test set.

    Builds a fresh model, loads ``parameters`` into it, and returns
    ``(loss, {"accuracy", "precision", "recall", "f1score"})`` as computed by ``test_model``.
    ``server_round`` and ``config`` are required by Flower's signature but unused here.
    """
    model = get_model()
    model.set_weights(parameters)
    loss, accuracy, precision, recall, f1score = test_model(model, test_images, test_labels)
    print(f"@@@@@@ Server-side evaluation loss {loss} / accuracy {accuracy} / f1score {f1score} @@@@@@")
    metrics = {
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1score": f1score,
    }
    return loss, metrics

In [None]:
comms_round = 1000  # number of federated training rounds


def client_fn(cid: str) -> fl.client.Client:
    """Build the Flower client for simulated client ``cid`` (a stringified client index)."""
    idx = int(cid)
    return FlowerClient(
        get_model(),
        np.array(list_x_train[idx], dtype=float),
        np.array(list_y_train[idx], dtype=int),
        entropies[idx],
    )

strategy = FedImp(
    fraction_fit=0.2,
    fraction_evaluate=0,  # no client-side evaluation; the server uses evaluate_fn
    min_fit_clients=10,
    min_available_clients=local_nodes_glob,
    evaluate_fn=evaluate_DNN_CL,
)

clientmanager = SimpleClientManager(cluster_labels=new_label, entropies=entropies)

# Run the simulation once and time it. The run must use the cluster/entropy-aware
# client manager: without `client_manager=`, Flower falls back to its default uniform
# sampler, and the reported metrics would not reflect this method.
start_time = time.time()

commun_metrics_history = fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=local_nodes_glob,
    config=fl.server.ServerConfig(num_rounds=comms_round),
    strategy=strategy,
    client_manager=clientmanager,
    client_resources={'num_cpus': 1, 'num_gpus': 0},
)

training_time = time.time() - start_time
history = commun_metrics_history  # alias kept for any cell that reads `history`

# Final-round (not best-round) server-side metrics.
global_acc_test, global_pre_test, global_rec_test, global_f1s_test = (
    retrieve_global_metrics(commun_metrics_history, "centralized", metric_name, False)
    for metric_name in ("accuracy", "precision", "recall", "f1score")
)

print("\n\nFINAL RESULTS: ===========================================================================================================================================================================================")
print('Test: commun_round: {} | global_acc: {:} | global_pre: {} | global_rec: {} | global_f1s: {}'.format(comms_round, global_acc_test, global_pre_test, global_rec_test, global_f1s_test))
print("Training time: %s seconds" % (training_time))

In [None]:
metrics_show = ["accuracy", "precision", "recall", "f1score"]

# One panel per metric, laid out side by side.
f, axs = plt.subplots(1, len(metrics_show), figsize=(70, 15))

for panel, metric_name in zip(axs, metrics_show):
    plt.sca(panel)  # make this panel current before plotting into it
    plot_metric_from_history(commun_metrics_history, "any", "centralized", metric_name)