In [None]:
import numpy as np
import os
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, roc_curve
from collections import Counter, OrderedDict

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.losses import SparseCategoricalCrossentropy
from tensorflow.keras.metrics import SparseCategoricalAccuracy
from tensorflow.keras.optimizers import SGD
from tensorflow.keras import datasets, layers, models

import pandas as pd

import matplotlib.pyplot as plt

import gc

class GarbageCollectorCallback(tf.keras.callbacks.Callback):
    """Keras callback that triggers a garbage-collection pass after every epoch."""
    def on_epoch_end(self, epoch, logs=None):
        """Call ``gc.collect()``; ``epoch`` and ``logs`` are accepted but not used."""
        gc.collect()

from fedartml import InteractivePlots, SplitAsFederatedData

# Presumably meant to silence TensorFlow's native logging.
# NOTE(review): tensorflow is already imported above when this runs, so the
# setting may come too late for this process — confirm, and consider moving it
# before the tensorflow import.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import flwr as fl

from typing import Callable, Dict, List, Optional, Tuple, Union
from flwr.common import Metrics
import random
from keras.datasets import cifar10, mnist
from flwr.common import (
    EvaluateIns,
    EvaluateRes,
    FitIns,
    FitRes,
    MetricsAggregationFn,
    NDArrays,
    Parameters,
    Scalar,
    ndarrays_to_parameters,
    parameters_to_ndarrays,
)
from flwr.common.logger import log
from flwr.server.client_manager import ClientManager
from flwr.server.client_proxy import ClientProxy
from flwr.server.criterion import Criterion
from flwr.server.strategy.aggregate import aggregate
from flwr.server.strategy.fedavg import FedAvg

In [None]:
def test_model(model, X_test, Y_test):
    """Score ``model`` on a labelled test set.

    Args:
        model: Keras model whose ``predict`` output is one probability row per
            sample (the loss is built with ``from_logits=False``).
        X_test: Test inputs handed to ``model.predict``.
        Y_test: Integer class labels, one per test sample.

    Returns:
        Tuple ``(loss, accuracy, precision, recall, f1)`` where the loss is the
        sparse categorical cross-entropy and precision/recall/F1 are
        support-weighted averages over classes.
    """
    cce = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
    # These are probabilities, not logits: the loss above is configured with
    # from_logits=False.
    probabilities = model.predict(X_test, batch_size=64, verbose=3, callbacks=[GarbageCollectorCallback()])
    y_pred = tf.argmax(probabilities, axis=1)
    loss = cce(Y_test, probabilities).numpy()
    # scikit-learn metrics take (y_true, y_pred). Passing them the other way
    # round swaps precision with recall and weights classes by predicted
    # rather than true support.
    acc = accuracy_score(Y_test, y_pred)
    pre = precision_score(Y_test, y_pred, average='weighted', zero_division=0)
    rec = recall_score(Y_test, y_pred, average='weighted', zero_division=0)
    f1s = f1_score(Y_test, y_pred, average='weighted', zero_division=0)
    return loss, acc, pre, rec, f1s

def from_FedArtML_to_Flower_format(clients_dict):
    """Split per-client ``(features, label)`` pairs into parallel lists.

    Args:
        clients_dict: Mapping from client name to a sequence of two-item
            records, where item 0 is the feature value and item 1 the label.

    Returns:
        ``(list_x_train, list_y_train)``: one stacked feature array and one
        label array per client, in the dictionary's iteration order.
    """
    list_x_train, list_y_train = [], []
    for records in clients_dict.values():
        pairs = np.array(records, dtype=object)
        # Column 0 holds the features; re-stack them into a regular array.
        features = np.array(list(pairs[:, 0]))
        # Column 1 holds the labels (kept as a copy of the object column).
        targets = np.array(pairs[:, 1])
        list_x_train.append(features)
        list_y_train.append(targets)
    return list_x_train, list_y_train

def get_model():
    """Build and compile the convolutional classifier used by every client.

    The network takes 32x32x3 inputs, stacks three 5x5 convolutions with two
    average-pooling stages in between, and ends in an 84-unit dense layer
    followed by a 10-way softmax. It is compiled with SGD (learning rate 0.01),
    sparse categorical cross-entropy and an accuracy metric.

    Returns:
        The compiled ``Sequential`` model.
    """
    keras_layers = tf.keras.layers
    layer_stack = [
        keras_layers.Conv2D(6, kernel_size=5, strides=1, activation='relu', input_shape=(32, 32, 3), padding='same'),  # C1
        keras_layers.AveragePooling2D(),  # S1
        keras_layers.Conv2D(16, kernel_size=5, strides=1, activation='relu', padding='valid'),  # C2
        keras_layers.AveragePooling2D(),  # S2
        keras_layers.Conv2D(120, kernel_size=5, strides=1, activation='relu', padding='valid'),  # C3
        keras_layers.Flatten(),
        keras_layers.Dense(84, activation='relu'),
        keras_layers.Dense(10, activation='softmax'),
    ]
    model = Sequential(layer_stack)
    optimizer = SGD(learning_rate=0.01)
    model.compile(optimizer=optimizer, loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    return model

class FlowerClient(fl.client.NumPyClient):
    """Flower NumPy client that trains a Keras model on one client's data.

    Besides the updated weights, ``fit`` reports this client's current value
    from the notebook-level ``entropies`` list under the ``"entropy"`` metric.
    """

    def __init__(self, model, x_train, y_train, cid, x_test=None, y_test=None) -> None:
        """Store the model and data, and decay this client's entropy entry.

        Args:
            model: Compiled Keras model.
            x_train: Training inputs for this client.
            y_train: Training labels for this client.
            cid: Client id; converted to ``int`` and used to index ``entropies``.
            x_test: Optional evaluation inputs.
            y_test: Optional evaluation labels. When either of ``x_test`` /
                ``y_test`` is missing, ``evaluate`` falls back to the training
                data.
        """
        self.model = model
        self.x_train, self.y_train = x_train, y_train
        if x_test is None or y_test is None:
            # No held-out split was supplied: evaluate on the local training
            # data rather than failing on undefined attributes.
            self.x_test, self.y_test = x_train, y_train
        else:
            self.x_test, self.y_test = x_test, y_test
        self.cid = int(cid)
        # Every construction shrinks this client's entry in the notebook-level
        # ``entropies`` list by a factor of 0.995.
        entropies[self.cid] *= 0.995

    def get_parameters(self, config):
        """Return the model's current weights as a list of arrays."""
        return self.model.get_weights()

    def fit(self, parameters, config):
        """Train for one local epoch starting from ``parameters``.

        ``config["learning_rate"]`` sets the SGD learning rate for this round.

        Returns:
            ``(weights, num_examples, {"entropy": value})``.
        """
        self.model.set_weights(parameters)
        # Recompile so the learning rate sent by the server takes effect.
        self.model.compile(optimizer=SGD(learning_rate=config["learning_rate"]), loss='sparse_categorical_crossentropy', metrics=['accuracy'])
        self.model.fit(self.x_train, self.y_train, epochs=1, verbose=3, batch_size=64, callbacks=[GarbageCollectorCallback()])
        return self.model.get_weights(), len(self.x_train), {"entropy": entropies[self.cid]}

    def evaluate(self, parameters, config):
        """Evaluate ``parameters`` on this client's evaluation data.

        Returns:
            ``(loss, num_examples, {"accuracy": value})``.
        """
        self.model.set_weights(parameters)
        # The model is compiled with a single 'accuracy' metric, so evaluate()
        # yields the loss followed by the accuracy.
        loss, acc = self.model.evaluate(self.x_test, self.y_test, batch_size=64, verbose=0)
        return float(loss), len(self.x_test), {"accuracy": float(acc)}

In [None]:
# Seed forwarded to train_test_split and SplitAsFederatedData.
random_state = 0
# Colour palette; not referenced in the visible cells of this notebook.
colors = ["#00cfcc","#e6013b","#007f88","#00cccd","#69e0da","darkblue","#FFFFFF"]
# Number of simulated clients passed to fl.simulation.start_simulation.
local_nodes_glob = 40
# Dirichlet alpha used when partitioning the first (30-client) group.
Alpha1 = 100
# Dirichlet alpha used when partitioning the second (10-client) group.
Alpha2 = 0.1

In [None]:
# Load CIFAR-10 through the Keras dataset helper.
(train_images, train_labels), (test_images, test_labels) = cifar10.load_data()

# Rescale pixel values by 255.
train_images = train_images / 255
test_images = test_images / 255
# Join the per-sample label rows into flat arrays
# (assumes labels load with shape (n, 1) — TODO confirm).
train_labels, test_labels = np.concatenate(train_labels), np.concatenate(test_labels)

# Split the TRAINING set into two disjoint parts (75% / 25%).
# NOTE: despite their names, test1/test2 are used below as the label arrays
# that pair with the image arrays train1/train2 (see label_list=...).
train1, train2, test1, test2 = train_test_split(train_images, train_labels, test_size=0.25, random_state=random_state)

my_federater = SplitAsFederatedData(random_state = random_state)

# First group: 30 clients, Dirichlet partition with alpha = Alpha1.
clients_glob_dic1, list_ids_sampled_dic1, miss_class_per_node1, distances1 = my_federater.create_clients(image_list = train1, label_list = test1,
                                                             num_clients = 30, prefix_cli='client', method = "dirichlet", alpha = Alpha1)

clients_glob1 = clients_glob_dic1['with_class_completion']
list_ids_sampled1 = list_ids_sampled_dic1['with_class_completion']

# Second group: 10 clients, Dirichlet partition with alpha = Alpha2.
clients_glob_dic2, list_ids_sampled_dic2, miss_class_per_node2, distances2 = my_federater.create_clients(image_list = train2, label_list = test2,
                                                             num_clients = 10, prefix_cli='client', method = "dirichlet", alpha = Alpha2)

clients_glob2 = clients_glob_dic2['with_class_completion']
list_ids_sampled2 = list_ids_sampled_dic2['with_class_completion']

# Convert each group into per-client feature/label lists.
list_x_train1, list_y_train1 = from_FedArtML_to_Flower_format(clients_dict=clients_glob1)
list_x_train2, list_y_train2 = from_FedArtML_to_Flower_format(clients_dict=clients_glob2)

# Concatenate the groups: indices 0-29 are group 1, indices 30-39 are group 2.
list_x_train, list_y_train = (list_x_train1+list_x_train2), (list_y_train1+list_y_train2)

In [None]:
import math
def calculate_entropy(y_train):
    """Return the base-10 Shannon entropy of the label frequencies in ``y_train``.

    Args:
        y_train: Iterable of hashable labels.

    Returns:
        ``-sum(p * log10(p))`` over the observed labels, where ``p`` is each
        label's relative frequency; ``0.0`` for an empty input.
    """
    frequencies = list(Counter(y_train).values())
    total = sum(frequencies)
    entropy = 0.0
    for frequency in frequencies:
        if frequency != 0:
            entropy += -frequency / total * math.log(frequency / total, 10)
    return entropy

# One label-entropy value per client, in client-index order.
entropies = [calculate_entropy(np.array(client_labels, dtype=int)) for client_labels in list_y_train]
# One minus the (smoothed) ratio of the entropies' standard deviation to their mean.
# NOTE(review): the visible cells never pass this value to FedImp, which
# therefore runs with its default ``tem`` — confirm whether that is intended.
tem = 1 - (np.std(entropies) + 0.01) / (np.mean(entropies) + 0.01)
print(entropies)
print(tem)

In [None]:
def weighted_srs_wr(elements, weights, sample_size):
    """Draw ``sample_size`` distinct elements with size/usage-based weights.

    Element ``i`` is drawn with probability proportional to
    ``len(elements[i]) / (weights[i] * 40)``: larger elements are favoured and
    elements with a larger ``weights`` value (e.g. a usage counter) are
    penalised. Each drawn element is removed from the pool, so despite the
    function's name no element is returned twice.

    The inputs are not modified. The previous implementation wrote back into
    ``weights`` using an index into the shrinking working copy, which updated
    the wrong entry and replaced the caller's counters with float weights.
    Callers that track usage must update their own counters.

    Args:
        elements: List of sized items (each needs ``len()``).
        weights: Positive numbers, one per element.
        sample_size: Number of elements to draw; at most ``len(elements)``.

    Returns:
        List of the drawn elements, in draw order.

    Raises:
        AssertionError: If the lengths differ or ``sample_size`` is too large.
    """
    # Message: "The element list and the weight list must have the same length."
    assert len(elements) == len(weights), "Danh sách phần tử và trọng số phải có cùng độ dài."
    # Message: "The sample size must be less than or equal to the number of elements."
    assert sample_size <= len(elements), "Kích thước mẫu phải nhỏ hơn hoặc bằng số phần tử."

    pool = list(elements)
    # The constant 40 scales every weight equally, so it does not change the
    # selection probabilities; it is kept to leave the weight values unchanged.
    pool_weights = [len(element) / (weight * 40) for element, weight in zip(elements, weights)]

    sample = []
    for _ in range(sample_size):
        threshold = random.uniform(0, sum(pool_weights))
        cumulative_weight = 0
        # Fall back to the last candidate if floating-point rounding leaves the
        # running total just below the threshold.
        chosen = len(pool_weights) - 1
        for position, weight in enumerate(pool_weights):
            cumulative_weight += weight
            if cumulative_weight >= threshold:
                chosen = position
                break
        sample.append(pool.pop(chosen))
        pool_weights.pop(chosen)
    return sample

In [None]:
import threading

class SimpleClientManager(ClientManager):
    """Client manager that samples one client from each of several clusters.

    Clusters are picked with ``weighted_srs_wr`` using per-cluster usage
    counters, then one member of every picked cluster is chosen at random.
    """

    def __init__(self, cluster_labels, entropies) -> None:
        """Set up the registry and per-cluster bookkeeping.

        Args:
            cluster_labels: List of clusters, each a list of integer client ids.
            entropies: Per-client entropy values, indexed by integer client id.
        """
        self.clients: Dict[str, ClientProxy] = {}
        self._cv = threading.Condition()
        self.cluster_labels = cluster_labels
        self.num_cluster = len(cluster_labels)
        self.entropies = entropies
        # Usage counter per cluster, starting at 1 so it can act as a divisor.
        self.choosen = [1 for _ in range(len(cluster_labels))]

    def __len__(self) -> int:
        """Return the number of registered clients."""
        return len(self.clients)

    def num_available(self) -> int:
        """Return the number of registered clients."""
        return len(self)

    def wait_for(self, num_clients: int, timeout: int = 86400) -> bool:
        """Block until ``num_clients`` are registered or ``timeout`` seconds pass.

        Returns:
            ``True`` if enough clients registered in time, else ``False``.
        """
        with self._cv:
            return self._cv.wait_for(
                lambda: len(self.clients) >= num_clients, timeout=timeout
            )

    def register(self, client: ClientProxy) -> bool:
        """Add ``client`` to the registry; return ``False`` if its cid is taken."""
        if client.cid in self.clients:
            return False

        self.clients[client.cid] = client
        # Wake any thread blocked in wait_for().
        with self._cv:
            self._cv.notify_all()
        return True

    def unregister(self, client: ClientProxy) -> None:
        """Remove ``client`` from the registry if it is present."""
        if client.cid in self.clients:
            del self.clients[client.cid]
            with self._cv:
                self._cv.notify_all()

    def all(self) -> Dict[str, ClientProxy]:
        """Return the mapping of cid to client proxy."""
        return self.clients

    def sample(
        self,
        num_clients: int,
        min_num_clients: Optional[int] = None,
        criterion: Optional[Criterion] = None,
    ) -> List[ClientProxy]:
        """Pick ``num_clients`` client proxies.

        When ``num_clients`` is 1, one registered client (filtered by
        ``criterion`` if given) is drawn uniformly at random. Otherwise
        ``num_clients`` clusters are drawn with ``weighted_srs_wr``, their
        usage counters are incremented, and one random member of each drawn
        cluster is returned.

        NOTE(review): the cluster branch does not consult ``criterion`` or the
        set of registered clients; a cluster member that is not registered
        raises ``KeyError`` on lookup.
        """
        if min_num_clients is None:
            min_num_clients = num_clients
        self.wait_for(min_num_clients)

        available_cids = list(self.clients)
        if criterion is not None:
            available_cids = [
                cid for cid in available_cids if criterion.select(self.clients[cid])
            ]
        sampled_cids = []

        if num_clients == 1:
            sampled_cids = random.sample(available_cids, num_clients)

        else:
            sample_choices = []
            sum_entropy = []
            sample_cluster = weighted_srs_wr(self.cluster_labels, self.choosen, num_clients)
            for cluster in sample_cluster:
                self.choosen[self.cluster_labels.index(cluster)] += 1
            # A single candidate set is drawn; the argmax below would keep the
            # highest-entropy set if more candidates were generated.
            for _ in range(1):
                candidate = []
                candidate_entropy = 0
                for cluster in sample_cluster:
                    client = random.choice(cluster)
                    candidate.append(client)
                    # Use the values given to the constructor rather than a
                    # notebook-level global of the same name.
                    candidate_entropy += self.entropies[client]
                sum_entropy.append(candidate_entropy)
                sample_choices.append(candidate)
            sampled_cids = sample_choices[np.argmax(sum_entropy)]
            # The registry is keyed by string cids.
            sampled_cids = [str(cid) for cid in sampled_cids]
        return [self.clients[cid] for cid in sampled_cids]

In [None]:
class FedImp(FedAvg):
    """FedAvg variant with entropy-scaled aggregation and a decaying learning rate.

    Each round the current learning rate is sent to the sampled clients and
    then multiplied by ``decay``. During aggregation a client's update is
    weighted by ``num_examples * exp(entropy / tem)``, where ``entropy`` is the
    value the client reports in its fit metrics.
    """

    def __init__(
        self,
        *,
        fraction_fit: float = 1.0,
        fraction_evaluate: float = 1.0,
        min_fit_clients: int = 2,
        min_evaluate_clients: int = 2,
        min_available_clients: int = 2,
        evaluate_fn: Optional[
            Callable[
                [int, NDArrays, Dict[str, Scalar]],
                Optional[Tuple[float, Dict[str, Scalar]]],
            ]
        ] = None,
        on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
        on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
        accept_failures: bool = True,
        initial_parameters: Optional[Parameters] = None,
        fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
        evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
        tem: float = 0.7,
    ) -> None:
        """Store the FedAvg settings plus the FedImp-specific ones.

        Args:
            tem: Temperature dividing the reported entropy inside the
                exponential aggregation weight. The remaining keyword arguments
                are stored on the instance under the same names.
        """
        super().__init__()
        self.fraction_fit = fraction_fit
        self.fraction_evaluate = fraction_evaluate
        self.min_fit_clients = min_fit_clients
        self.min_evaluate_clients = min_evaluate_clients
        self.min_available_clients = min_available_clients
        self.evaluate_fn = evaluate_fn
        self.on_fit_config_fn = on_fit_config_fn
        self.on_evaluate_config_fn = on_evaluate_config_fn
        self.accept_failures = accept_failures
        self.initial_parameters = initial_parameters
        self.fit_metrics_aggregation_fn = fit_metrics_aggregation_fn
        self.evaluate_metrics_aggregation_fn = evaluate_metrics_aggregation_fn
        # Initial client learning rate and its per-round multiplicative decay.
        self.learning_rate = 0.1
        self.decay = 0.995
        self.tem = tem

    def configure_fit(
        self, server_round: int, parameters: Parameters, client_manager: ClientManager
    ) -> List[Tuple[ClientProxy, FitIns]]:
        """Build the fit instructions for this round and sample the clients.

        The config always carries ``"learning_rate"``. Values returned by
        ``on_fit_config_fn`` are merged on top, so a custom function can add
        keys (or override the learning rate) without removing it.
        """
        config = {"learning_rate": self.learning_rate}

        if self.on_fit_config_fn is not None:
            # Merge instead of replace: clients read config["learning_rate"].
            config.update(self.on_fit_config_fn(server_round))
        fit_ins = FitIns(parameters, config)
        # Decay after building the instructions so this round uses the current rate.
        self.learning_rate *= self.decay
        # Sample clients
        sample_size, min_num_clients = self.num_fit_clients(
            client_manager.num_available()
        )
        clients = client_manager.sample(
            num_clients=sample_size, min_num_clients=min_num_clients
        )

        return [(client, fit_ins) for client in clients]

    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        """Aggregate client updates with entropy-scaled weights.

        Returns:
            ``(parameters, metrics)``. Parameters are ``None`` when there are no
            results, or when failures occurred and are not accepted. Metrics
            come from ``fit_metrics_aggregation_fn`` when one was provided and
            are empty otherwise.
        """
        if not results:
            return None, {}
        if not self.accept_failures and failures:
            return None, {}
        weights_results = [
            (
                parameters_to_ndarrays(fit_res.parameters),
                fit_res.num_examples * np.exp(fit_res.metrics["entropy"] / self.tem),
            )
            for _, fit_res in results
        ]

        aggregated_ndarrays = aggregate(weights_results)
        parameters_aggregated = ndarrays_to_parameters(aggregated_ndarrays)

        metrics_aggregated = {}
        if self.fit_metrics_aggregation_fn:
            # Honour the user-supplied aggregation function, which was
            # previously stored but never called.
            fit_metrics = [(fit_res.num_examples, fit_res.metrics) for _, fit_res in results]
            metrics_aggregated = self.fit_metrics_aggregation_fn(fit_metrics)
        return parameters_aggregated, metrics_aggregated

In [None]:
from sklearn.cluster import OPTICS

def split_clusters(labels):
    """Group positions by their cluster label.

    Args:
        labels: Sequence of hashable cluster labels, one per item.

    Returns:
        Dict mapping each label to the list of indices carrying it; keys
        appear in order of first occurrence and indices in ascending order.
    """
    clusters = {}
    for position, cluster_id in enumerate(labels):
        clusters.setdefault(cluster_id, []).append(position)
    return clusters

def hellinger_distance(p, q):
    """Return the Hellinger distance ``sqrt(0.5 * sum((sqrt(p) - sqrt(q))**2))``.

    Args:
        p: Array-like of non-negative values.
        q: Array-like of non-negative values, broadcastable against ``p``.
    """
    root_gap = np.sqrt(p) - np.sqrt(q)
    squared_gap_total = (root_gap ** 2).sum()
    return np.sqrt(0.5 * squared_gap_total)

def compute_hellinger_distance_matrix(distributions):
    """Build the symmetric matrix of pairwise Hellinger distances.

    Args:
        distributions: Sequence of distributions accepted by
            ``hellinger_distance``.

    Returns:
        ``(n, n)`` array with zeros on the diagonal and entry ``[i, j]`` equal
        to the distance between distributions ``i`` and ``j``.
    """
    size = len(distributions)
    distances = np.zeros((size, size))
    for row in range(size):
        # Only the upper triangle is computed; the lower one is mirrored.
        for col in range(row + 1, size):
            distances[row, col] = hellinger_distance(distributions[row], distributions[col])
            distances[col, row] = distances[row, col]
    return distances

def to_prob_dist(data):
    """Divide every row of a 2-D ``data`` by its row sum.

    Returns:
        Array of the same shape whose rows each sum to 1.
    """
    row_totals = np.sum(data, axis=1, keepdims=True)
    return data / row_totals

def kmeans_no_small_clusters(data):
    """Cluster clients by label distribution, giving each noise point its own cluster.

    Despite the name, clustering is done with OPTICS on the pairwise Hellinger
    distances between the clients' label distributions. Clients that OPTICS
    marks as noise (label ``-1``) are each placed in a new single-member
    cluster.

    Args:
        data: Sequence with one iterable of class labels per client.

    Returns:
        Dict mapping cluster id to the list of client indices in it. Clusters
        found by OPTICS come first, ordered by increasing size, followed by the
        single-member clusters created for noise points.
    """
    if len(data) == 0:
        return {}

    # Use the argument rather than a notebook-level global, and align every
    # client's counts on the same sorted class list so a class missing from
    # one client becomes a 0 instead of shifting or shortening its row.
    per_client_counts = [Counter(client_labels) for client_labels in data]
    classes = sorted(set().union(*per_client_counts))
    counts = np.array(
        [[client_counts.get(cls, 0) for cls in classes] for client_counts in per_client_counts],
        dtype=float,
    )
    distributions = to_prob_dist(counts)
    distance_matrix = compute_hellinger_distance_matrix(distributions)

    clustering = OPTICS(min_samples=2,
                  metric="precomputed").fit(distance_matrix)
    clusters = split_clusters(clustering.labels_)
    clusters = dict(sorted(clusters.items(), key=lambda item: len(item[1])))

    # OPTICS may report no noise at all, or nothing but noise; handle both
    # instead of failing on a missing key or an empty max().
    noise_points = clusters.pop(-1, [])
    next_label = max(clusters, default=-1) + 1
    for client_index in noise_points:
        clusters[next_label] = [client_index]
        next_label += 1
    return clusters

# Cluster the clients by label distribution and keep only the member lists.
labels = kmeans_no_small_clusters(list_y_train).values()

print(labels)

# Materialise the dict view as a plain list of clusters (lists of client indices).
new_label = list(labels)

print(new_label)

In [None]:
def evaluate_DNN_CL(
    server_round: int,
    parameters: fl.common.NDArrays,
    config: Dict[str, fl.common.Scalar],
) -> Optional[Tuple[float, Dict[str, fl.common.Scalar]]]:
    """Score the given global weights on the held-out test set.

    A fresh model is built, loaded with ``parameters`` and evaluated with
    ``test_model`` on ``test_images`` / ``test_labels``. ``server_round`` and
    ``config`` are accepted but not used.

    Returns:
        ``(loss, metrics)`` where ``metrics`` has the keys ``"accuracy"``,
        ``"precision"``, ``"recall"`` and ``"f1score"``.
    """
    global_model = get_model()
    global_model.set_weights(parameters)
    loss, *scores = test_model(global_model, test_images, test_labels)
    metric_names = ("accuracy", "precision", "recall", "f1score")
    return loss, dict(zip(metric_names, scores))

In [None]:
# Number of communication rounds, passed to fl.server.ServerConfig(num_rounds=...).
comms_round = 1000

def client_fn(cid: str) -> fl.client.Client:
    """Create the simulated client for id ``cid``.

    ``cid`` is converted to an integer index into ``list_x_train`` /
    ``list_y_train``; the selected data is cast to float features and integer
    labels and wrapped, together with a freshly built model, in a
    ``FlowerClient``.
    """
    model = get_model()
    client_index = int(cid)
    features = np.array(list_x_train[client_index], dtype=float)
    targets = np.array(list_y_train[client_index], dtype=int)
    return FlowerClient(model, features, targets, cid)

# Strategy: FedImp with centralized evaluation only (fraction_evaluate=0).
# NOTE(review): ``tem`` computed in an earlier cell is not passed here, so
# FedImp uses its default temperature — confirm whether that is intended.
# NOTE(review): min_available_clients is hard-coded to 40, the same value as
# local_nodes_glob — presumably they are meant to stay in sync.
strategy=FedImp(
        fraction_fit=0.2,  
        fraction_evaluate=0,  
        min_fit_clients=8,
        min_available_clients = 40,
        evaluate_fn=evaluate_DNN_CL,
)

# Client manager that samples by cluster, using the clusters and entropies
# computed in earlier cells.
clientmanager = SimpleClientManager(cluster_labels = new_label, entropies=entropies)

# Run the simulation; each client is given 1 CPU and no GPU.
history = fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=local_nodes_glob,
    config=fl.server.ServerConfig(num_rounds=comms_round),
    strategy=strategy,
    client_manager = clientmanager,
    client_resources = {'num_cpus': 1, 'num_gpus': 0},
)

In [None]:
import json
# Persist the centralized accuracy history as JSON.
with open(f'HICFL30.txt', 'w') as results_file:
    centralized_accuracy = history.metrics_centralized["accuracy"]
    json.dump(centralized_accuracy, results_file)

In [None]:
# Show the last recorded centralized value of each remaining metric.
for metric_name in ("precision", "recall", "f1score"):
    print(history.metrics_centralized[metric_name][-1])