# 聚类联邦学习（Clustered Federated Learning）实验

In [None]:
import sys
import os
import random
import math
from typing import List
from sklearn.cluster import AgglomerativeClustering
import numpy as np
import torch

from fedlab.utils.serialization import SerializationTool
from fedlab.utils.aggregator import Aggregators

# Make the project root (the parent of the current working directory) importable,
# so that the local `utils` / `client_utils` imports below can be resolved.
PROJECT_DIR = os.path.split(os.getcwd())[0]
sys.path.append(PROJECT_DIR)
print(f"{PROJECT_DIR}")

from utils import read_options
from client_utils import detail_evaluate

In [None]:
# Parse the experiment options: config dict, data partitioner and model constructor.
config, cluster_partitioner, model = read_options()
# Seed every RNG the experiments draw from. Python's `random` drives the client
# sampling stream below. torch drives model initialisation (and presumably DataLoader
# shuffling inside the partitioner). Seeding numpy/torch does not touch the `random`
# stream, so the client schedule is unchanged.
random.seed(config['seed'])
np.random.seed(config['seed'])
torch.manual_seed(config['seed'])

In [None]:
# Experiment-wide sizes read from the parsed configuration.
num_client, num_classes, num_round = (
    config[key] for key in ('num_client', 'num_classes', 'num_round')
)

# Use the first GPU when available, otherwise fall back to the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# Pre-draw the participating clients for every round: 20% of all clients (at least
# one), sampled without replacement. Drawing the whole schedule up front lets every
# experiment below reuse the same participation pattern.
_clients_per_round = max(1, int(num_client * 0.2))
client_sample_stream = [
    random.sample(range(num_client), _clients_per_round) for _round in range(num_round)
]

In [None]:
def model_init():
    """Build fresh per-client and global training objects.

    Returns:
        A 6-tuple ``(client_models, client_optimizers, client_criteria,
        global_model, global_optimizer, global_criterion)``. It holds one
        ``model()`` instance, one SGD optimizer (``lr=config['lr']``) and one
        cross-entropy loss per client, plus one of each for the global model.
    """
    client_models: List[torch.nn.Module] = [model() for _ in range(num_client)]
    client_optimizers: List[torch.optim.SGD] = [
        torch.optim.SGD(net.parameters(), lr=config['lr']) for net in client_models
    ]
    client_criteria: List[torch.nn.CrossEntropyLoss] = [
        torch.nn.CrossEntropyLoss() for _ in client_models
    ]

    global_model: torch.nn.Module = model()
    return (
        client_models,
        client_optimizers,
        client_criteria,
        global_model,
        torch.optim.SGD(global_model.parameters(), lr=config['lr']),
        torch.nn.CrossEntropyLoss(),
    )

def train(cid: int, model: torch.nn.Module, optimizer: torch.optim.SGD, criterion: torch.nn.CrossEntropyLoss):
    """Train ``model`` in place on client ``cid``'s local data.

    Runs ``config['local_ep']`` passes over the client's dataloader (batch size
    ``config['local_bs']``), taking one optimizer step per batch. The model is
    moved to ``device`` and left there in training mode.
    """
    model.to(device)
    model.train()
    loader = cluster_partitioner.get_dataloader(cid, config['local_bs'])
    for _epoch in range(config['local_ep']):
        for inputs, labels in loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            criterion(model(inputs), labels).backward()
            optimizer.step()

## Clustered Federated Learning: Model-Agnostic Distributed Multitask Optimization Under Privacy Constraints

CFL 算法过程测试

结果：判断是否二分的条件中包含两个人工给定的超参数（平均更新范数阈值 0.4 与最大更新范数阈值 1.6）。目前使用原论文所给代码中的原始超参数时，始终无法满足对应的 if 条件。

In [None]:
def CFL_test(eps_1: float = 0.4, eps_2: float = 1.6, warmup_rounds: int = 20):
    """Run the Clustered Federated Learning (CFL) experiment.

    All clients start from the same weights in a single cluster. In each round:

    * the sampled clients train locally. Their weight update
      ``param_after - param_before`` is recorded and the client is rolled back
      to ``param_before``;
    * a cluster is bi-partitioned (complete-linkage clustering on the negated
      pairwise cosine similarity of updates) when the norm of its mean update
      is below ``eps_1``, its largest update norm exceeds ``eps_2``, it has
      more than two members and ``current_round > warmup_rounds``;
    * every member of a cluster is advanced by the average update of the
      cluster's sampled clients, so members of one cluster stay identical.

    Args:
        eps_1: a split requires the cluster's mean-update norm to be below this.
        eps_2: a split requires the cluster's largest update norm to exceed this.
        warmup_rounds: splits are only considered once ``current_round`` exceeds this.
    """
    client_models, client_optimizers, client_criteria, global_model, _, _ = model_init()
    client_cluster: List[List[int]] = [list(range(num_client))]

    # Start every client from the same initial weights.
    for cid in range(num_client):
        client_models[cid].load_state_dict(global_model.state_dict())

    for current_round in range(num_round):
        selected_clients: List[int] = client_sample_stream[current_round]

        # Weight update of each sampled client this round; None for clients not sampled.
        updates: List[torch.Tensor] = [None] * num_client
        for client_id in selected_clients:
            local_model: torch.nn.Module = client_models[client_id]
            param_before = SerializationTool.serialize_model(local_model).detach()
            train(client_id, local_model, client_optimizers[client_id], client_criteria[client_id])
            param_after = SerializationTool.serialize_model(local_model).detach()
            updates[client_id] = param_after - param_before
            # Roll back the local step: the model only advances through the
            # cluster-average update applied at the end of the round. Keeping the
            # trained weights AND adding the update would apply it twice (or, with
            # the opposite sign, undo it).
            SerializationTool.deserialize_model(local_model, param_before)

        # Pairwise cosine similarity between sampled clients' updates (0 for other pairs).
        similarity_matrix = np.zeros((num_client, num_client), dtype=np.float32)
        for pos, cid_i in enumerate(selected_clients):
            for cid_j in selected_clients[pos + 1:]:
                similarity_score = torch.cosine_similarity(
                    updates[cid_i], updates[cid_j], dim=0, eps=1e-12
                ).item()
                similarity_matrix[cid_i, cid_j] = similarity_score
                similarity_matrix[cid_j, cid_i] = similarity_score
        print("similarity_matrix: \n{}".format(similarity_matrix[selected_clients][:, selected_clients]))

        client_cluster_new: List[List[int]] = []
        for indices in client_cluster:
            cluster_updates = [updates[cid] for cid in indices if updates[cid] is not None]
            if not cluster_updates:
                # No member of this cluster was sampled: nothing to measure this round.
                client_cluster_new.append(indices)
                continue
            max_norm = max(torch.norm(update).item() for update in cluster_updates)
            mean_norm = torch.norm(torch.mean(torch.stack(cluster_updates), dim=0)).item()
            print("max_norm: {}, mean_norm: {}, len(cluster): {}, current round: {}".format(
                round(max_norm, 4), round(mean_norm, 4), len(indices), current_round
            ))
            if mean_norm < eps_1 and max_norm > eps_2 and len(indices) > 2 and current_round > warmup_rounds:
                # NOTE(review): members not sampled this round have all-zero similarity
                # rows here. Also, scikit-learn >= 1.2 renamed `affinity` to `metric`
                # (and removed `affinity` in 1.4) - confirm against the installed version.
                clustering = AgglomerativeClustering(
                    affinity='precomputed', linkage='complete'
                ).fit(-similarity_matrix[indices][:, indices])
                # labels_ are positions within `indices`; map them back to client ids.
                cluster_1 = [indices[k] for k in np.flatnonzero(clustering.labels_ == 0)]
                cluster_2 = [indices[k] for k in np.flatnonzero(clustering.labels_ == 1)]
                print("cluster {} is split into {} and {}".format(indices, cluster_1, cluster_2))
                client_cluster_new += [cluster_1, cluster_2]
            else:
                client_cluster_new.append(indices)

        client_cluster = client_cluster_new

        for cluster in client_cluster:
            cluster_updates = [updates[cid] for cid in cluster if updates[cid] is not None]
            if not cluster_updates:
                # No sampled member, hence no update for this cluster this round.
                continue
            average_update = Aggregators.fedavg_aggregate(cluster_updates)
            for cid in cluster:
                # The in-place 'add' is done on CPU, as in the original code (presumably
                # because the serialized update tensors live on CPU).
                client_models[cid].to('cpu')
                SerializationTool.deserialize_model(client_models[cid], average_update, mode='add')
        print("round {} is finished".format(current_round))

# Run the CFL experiment defined above.
CFL_test()

## A Greedy Agglomerative Framework for Clustered Federated Learning

Federated Learning via Agglomerative Client Clustering (FLACC) 算法

效果：有一定的聚类效果，但依旧不明显。

思考：能否将合并的判断条件类比到基于密度的聚类原理中，利用其中的邻近节点等概念？

In [None]:
def FLACC_test(patience: int = 10, memory_rounds: int = 10):
    """Run the FLACC experiment (greedy agglomerative clustering of clients).

    The run has two phases.

    1. Merge phase, while fewer than ``patience`` consecutive rounds passed without
       a merge:
       * sampled clients train from the global model;
       * the pairwise cosine similarities of their updates are recorded;
       * the pair of clusters with the largest minimum known cross-similarity is
         merged when that minimum is positive and, if a within-cluster bound exists
         for that pair, the maximum cross-similarity exceeds it;
       * similarities not refreshed for more than ``memory_rounds`` rounds are
         forgotten (reset to -1);
       * the global model becomes the average of the sampled clients.
    2. Cluster phase: sampled clients train from their cluster's model, and each
       cluster model is replaced by the average of its sampled members.

    Finally each cluster model is evaluated on its clients' test data, and the
    per-cluster results are averaged and printed.

    Args:
        patience: consecutive rounds without a merge after which merging stops.
        memory_rounds: age (in rounds) after which a recorded similarity is discarded.
    """
    client_models, client_optimizers, client_criteria, global_model, _, _ = model_init()
    # -1.0 marks a client pair without a (recent) similarity measurement.
    similarity_matrix = np.full((num_client, num_client), -1.0, dtype=np.float32)
    # Round in which each pair's similarity was last measured.
    memory_matrix = np.full((num_client, num_client), 0, dtype=np.int32)
    client_cluster, model_cluster = [[i] for i in range(num_client)], [model() for i in range(num_client)]
    # NOTE(review): singleton clusters start with a fresh, untrained model() that is
    # only updated in the cluster phase when its client is sampled - confirm this is
    # intended for the final evaluation.
    client_cluster_model = {i: model_cluster[i] for i in range(num_client)}

    no_merge_count = 0

    for current_round in range(num_round):
        grads_list: List[torch.Tensor] = [None] * num_client
        selected_clients = client_sample_stream[current_round]
        print("selected clients: {}".format(sorted(selected_clients)))
        if no_merge_count < patience:
            # Merge phase: train sampled clients from the global model and record
            # their (negated) weight updates.
            for cid in selected_clients:
                local_model: torch.nn.Module = client_models[cid]
                local_model.load_state_dict(global_model.state_dict())
                param_before = SerializationTool.serialize_model(local_model).detach()
                train(cid, local_model, client_optimizers[cid], client_criteria[cid])
                param_after = SerializationTool.serialize_model(local_model).detach()
                grads_list[cid] = param_before - param_after

            for i, cid_i in enumerate(selected_clients):
                for cid_j in selected_clients[i + 1:]:
                    similarity_score = torch.cosine_similarity(
                        grads_list[cid_i], grads_list[cid_j], dim=0, eps=1e-12
                    ).item()
                    similarity_matrix[cid_i, cid_j] = similarity_score
                    similarity_matrix[cid_j, cid_i] = similarity_score
                    memory_matrix[cid_i, cid_j] = current_round
                    memory_matrix[cid_j, cid_i] = current_round

            # Pick the cluster pair whose minimum known cross-similarity is largest.
            max_min_similarity = float('-inf')
            cluster1_index, cluster2_index = -1, -1
            cross_max_similarity = None
            within_min_similarity = None
            cross_min_similarity = None
            for i, cluster_i in enumerate(client_cluster):
                for j, cluster_j in enumerate(client_cluster[i + 1:], i + 1):
                    if len(cluster_i) == 1 and len(cluster_j) == 1:
                        min_similarity = similarity_matrix[cluster_i[0], cluster_j[0]]
                    else:
                        cross = similarity_matrix[cluster_i][:, cluster_j]
                        known = cross > -1.0
                        if not np.any(known):
                            continue
                        min_similarity = np.min(cross[known])

                    if min_similarity <= max_min_similarity:
                        continue

                    # The within-cluster bound belongs to THIS candidate pair only;
                    # it stays None unless both clusters have several members.
                    candidate_within = None
                    if len(cluster_i) > 1 and len(cluster_j) > 1:
                        mask1 = similarity_matrix[cluster_i][:, cluster_i] > -1.0
                        mask2 = similarity_matrix[cluster_j][:, cluster_j] > -1.0
                        min_cluster_i = np.min(similarity_matrix[cluster_i][:, cluster_i][mask1]) if np.any(mask1) else float('inf')
                        min_cluster_j = np.min(similarity_matrix[cluster_j][:, cluster_j][mask2]) if np.any(mask2) else float('inf')
                        if math.isinf(min_cluster_i) and math.isinf(min_cluster_j):
                            continue
                        candidate_within = min(min_cluster_i, min_cluster_j)

                    max_min_similarity = min_similarity
                    cluster1_index, cluster2_index = i, j
                    cross_max_similarity = np.max(similarity_matrix[cluster_i][:, cluster_j])
                    cross_min_similarity = min_similarity
                    within_min_similarity = candidate_within

            if cross_min_similarity is None:
                # No comparable cluster pair (e.g. only one cluster is left).
                print("no pair of clusters with a measured similarity")
                should_merge = False
            else:
                print("cross_min_similarity: {}, cross_max_similarity: {}, within_min_similarity: {}".format(
                    round(float(cross_min_similarity), 4), round(float(cross_max_similarity), 4), within_min_similarity
                ))
                should_merge = cross_min_similarity > 0 and (
                    within_min_similarity is None or cross_max_similarity > within_min_similarity
                )

            if should_merge:
                cluster1 = client_cluster[cluster1_index]
                cluster2 = client_cluster[cluster2_index]
                print("merge {} and {}".format(cluster1, cluster2))
                # cluster2_index > cluster1_index: pop the later entry first so the
                # earlier index stays valid (both lists are kept parallel).
                client_cluster.pop(cluster2_index)
                client_cluster.pop(cluster1_index)
                model_cluster.pop(cluster2_index)
                model_cluster.pop(cluster1_index)

                merged = cluster1 + cluster2
                avg_param = Aggregators.fedavg_aggregate(
                    [SerializationTool.serialize_model(client_models[cid]) for cid in merged]
                )
                cluster_model = model()
                SerializationTool.deserialize_model(cluster_model, avg_param)
                client_cluster.append(merged)
                model_cluster.append(cluster_model)
                for cid in merged:
                    client_cluster_model[cid] = cluster_model

                no_merge_count = 0
            else:
                no_merge_count += 1

            # Forget similarities that were not refreshed recently.
            similarity_matrix[(current_round - memory_matrix) > memory_rounds] = -1

            selected_clients_param_list = [SerializationTool.serialize_model(client_models[cid]) for cid in selected_clients]
            selected_clients_avg_param = Aggregators.fedavg_aggregate(selected_clients_param_list)
            SerializationTool.deserialize_model(global_model, selected_clients_avg_param)

        else:
            # Cluster phase: train from the cluster model, then average per cluster.
            for cid in selected_clients:
                local_model: torch.nn.Module = client_models[cid]
                local_model.load_state_dict(client_cluster_model[cid].state_dict())
                train(cid, local_model, client_optimizers[cid], client_criteria[cid])

            selected_clients_set = set(selected_clients)
            for index, cluster in enumerate(client_cluster):
                intersection = selected_clients_set.intersection(cluster)
                if len(intersection) > 0:
                    print("update cluster {} using clients {}".format(cluster, intersection))
                    param_list = [SerializationTool.serialize_model(client_models[cid]) for cid in intersection]
                    avg_param = Aggregators.fedavg_aggregate(param_list)
                    SerializationTool.deserialize_model(model_cluster[index], avg_param)

    # Evaluate each cluster model on its clients' test data; average over clusters.
    global_eval = np.zeros(num_classes + 1)
    for cluster_cids, cluster_model in zip(client_cluster, model_cluster):
        test_loader = cluster_partitioner.get_cluster_dataloader(cluster_cids, config['local_bs'], "test")
        result = np.array(detail_evaluate(cluster_model, torch.nn.CrossEntropyLoss(), test_loader, num_classes))
        global_eval += result
    global_eval /= len(client_cluster)
    global_eval = np.round(global_eval, 4)
    print("global_eval: {}".format(global_eval.tolist()))
        

# Run the FLACC experiment defined above.
FLACC_test()

## 基于 FLACC 的改进
在每轮合并之前，先将簇内与其他成员余弦相似度为负的客户端拆分为单独的簇（类似二分），以进一步验证划分的可靠性。

In [None]:
def evaluate_clusters(client_cluster, model_cluster):
    """Evaluate every cluster model on its clients' test split and average the results.

    Args:
        client_cluster: clusters, each a list of client ids.
        model_cluster: cluster models, parallel to ``client_cluster``.

    Returns:
        The element-wise mean (over clusters) of ``detail_evaluate``'s
        ``num_classes + 1`` values, rounded to 4 decimals. It is also printed.
    """
    print("evaluate_clusters:")
    n_classes = config['num_classes']
    totals = np.zeros(n_classes + 1)
    for member_ids, cluster_net in zip(client_cluster, model_cluster):
        print(f"cluster_cids: {member_ids}")
        loader = cluster_partitioner.get_cluster_dataloader(member_ids, config['local_bs'], "test")
        totals += np.array(detail_evaluate(cluster_net, torch.nn.CrossEntropyLoss(), loader, n_classes))
    averaged = np.round(totals / len(client_cluster), 4)
    print(f"global_eval: {averaged.tolist()}")
    return averaged

def create_cluster_model(cluster: List[int], client_models: List[torch.nn.Module]):
    """Return a model representing ``cluster``.

    A single-client cluster reuses that client's model object itself (no copy).
    Otherwise a new ``model()`` is built and loaded with the members' parameters
    averaged by ``Aggregators.fedavg_aggregate`` (no weights passed).
    """
    if len(cluster) != 1:
        averaged = Aggregators.fedavg_aggregate(
            [SerializationTool.serialize_model(client_models[cid]) for cid in cluster]
        )
        merged_model = model()
        SerializationTool.deserialize_model(merged_model, averaged)
        return merged_model
    return client_models[cluster[0]]

def remove_clients_from_cluster(client_cluster, model_cluster, similarity_matrix, client_models):
    """Move clients that disagree with the rest of their cluster into singleton clusters.

    For every cluster with at least two members, each client that appears in a
    pair with a measured (> -1.0) negative similarity inside the cluster is
    treated as an outlier. Both members of such a pair are removed. The
    remaining members, if any, form a new cluster whose model is rebuilt with
    ``create_cluster_model``. Each outlier becomes its own cluster backed by its
    own client model. Clusters with fewer than two members, or without any
    negative pair, are kept unchanged together with their model.

    Args:
        client_cluster: clusters, each a list of global client ids.
        model_cluster: cluster models, parallel to ``client_cluster``.
        similarity_matrix: square array indexed by client id; -1.0 means "not measured".
        client_models: per-client models indexed by client id.

    Returns:
        ``(client_cluster_new, model_cluster_new)``, again parallel lists.
    """
    client_cluster_new = []
    model_cluster_new = []
    for index, cluster in enumerate(client_cluster):
        if len(cluster) < 2:
            client_cluster_new.append(cluster)
            model_cluster_new.append(model_cluster[index])
            continue
        cluster_similarity_matrix = similarity_matrix[cluster, :][:, cluster]
        # Measured (> -1.0) but negative similarities inside this cluster.
        mask = (cluster_similarity_matrix > -1.0) & (cluster_similarity_matrix < 0.0)
        if not mask.any():
            client_cluster_new.append(cluster)
            model_cluster_new.append(model_cluster[index])
            continue

        # np.nonzero yields positions inside `cluster`, not client ids: map them back.
        rows, cols = np.nonzero(mask)
        outlier_ids = {cluster[k] for k in np.concatenate((rows, cols))}
        outlier_clients = [cid for cid in cluster if cid in outlier_ids]
        remainder_clients = [cid for cid in cluster if cid not in outlier_ids]
        print("remove {} from {}".format(outlier_clients, cluster))

        if remainder_clients:
            client_cluster_new.append(remainder_clients)
            model_cluster_new.append(create_cluster_model(remainder_clients, client_models))
        for cid in outlier_clients:
            client_cluster_new.append([cid])
            model_cluster_new.append(client_models[cid])

        # Alternative strategy (disabled): bisect the cluster with complete-linkage
        # clustering instead of removing individual clients.
        # min_similarity = np.min(cluster_similarity_matrix[mask])
        # if min_similarity < 0:
        #     clustering = AgglomerativeClustering(
        #         affinity="precomputed", linkage="complete"
        #     ).fit(-cluster_similarity_matrix)
        #     cluster_1 = [cluster[i] for i in np.argwhere(clustering.labels_ == 0).flatten()]
        #     cluster_2 = [cluster[i] for i in np.argwhere(clustering.labels_ == 1).flatten()]
        #     if len(cluster_1) == 0 or len(cluster_2) == 0:
        #         client_cluster_new.append(cluster)
        #         model_cluster_new.append(model_cluster[index])
        #     else:
        #         print("split {} into {} and {}".format(cluster, cluster_1, cluster_2))
        #         model_1 = create_cluster_model(cluster_1, client_models)
        #         model_2 = create_cluster_model(cluster_2, client_models)
        #         client_cluster_new += [cluster_1, cluster_2]
        #         model_cluster_new += [model_1, model_2]
        # else:
        #     client_cluster_new.append(cluster)
        #     model_cluster_new.append(model_cluster[index])
    return client_cluster_new, model_cluster_new

In [None]:
def improve_test(patience: int = 10, memory_rounds: int = 10):
    """Run the improved FLACC experiment and evaluate the resulting clusters.

    It is identical to FLACC's two-phase loop (merge phase while fewer than
    ``patience`` rounds passed without a merge, then a cluster phase), except
    for one addition. Before each merge decision, clients with a measured
    negative similarity to another member of their cluster are split off into
    singleton clusters via ``remove_clients_from_cluster``. At the end the
    clusters are evaluated with ``evaluate_clusters``.

    Args:
        patience: consecutive rounds without a merge after which merging stops.
        memory_rounds: age (in rounds) after which a recorded similarity is discarded.
    """
    client_models, client_optimizers, client_criteria, global_model, _, _ = model_init()
    # -1.0 marks a client pair without a (recent) similarity measurement.
    similarity_matrix = np.full((num_client, num_client), -1.0, dtype=np.float32)
    # Round in which each pair's similarity was last measured.
    memory_matrix = np.full((num_client, num_client), 0, dtype=np.int32)
    client_cluster = [[i] for i in range(num_client)]
    # NOTE(review): singleton clusters start with a fresh, untrained model() - confirm
    # this is intended for clients never sampled in the cluster phase.
    model_cluster = [model() for i in range(num_client)]
    client_cluster_model = {i: model_cluster[i] for i in range(num_client)}

    no_merge_count = 0

    for current_round in range(num_round):
        grads_list: List[torch.Tensor] = [None] * num_client
        selected_clients = client_sample_stream[current_round]
        if no_merge_count < patience:
            # Merge phase: train sampled clients from the global model and record
            # their (negated) weight updates.
            for cid in selected_clients:
                local_model: torch.nn.Module = client_models[cid]
                local_model.load_state_dict(global_model.state_dict())
                param_before = SerializationTool.serialize_model(local_model).detach()
                train(cid, local_model, client_optimizers[cid], client_criteria[cid])
                param_after = SerializationTool.serialize_model(local_model).detach()
                grads_list[cid] = param_before - param_after

            for i, cid_i in enumerate(selected_clients):
                for cid_j in selected_clients[i + 1:]:
                    similarity_score = torch.cosine_similarity(
                        grads_list[cid_i], grads_list[cid_j], dim=0, eps=1e-12
                    ).item()
                    similarity_matrix[cid_i, cid_j] = similarity_score
                    similarity_matrix[cid_j, cid_i] = similarity_score
                    memory_matrix[cid_i, cid_j] = current_round
                    memory_matrix[cid_j, cid_i] = current_round

            # Split off clients that disagree with their cluster (negative similarity).
            client_cluster, model_cluster = remove_clients_from_cluster(
                client_cluster, model_cluster, similarity_matrix, client_models
            )

            # Pick the cluster pair whose minimum known cross-similarity is largest.
            max_min_similarity = float("-inf")
            cluster1_index, cluster2_index = -1, -1
            cross_max_similarity = None
            within_min_similarity = None
            cross_min_similarity = None
            for i, cluster_i in enumerate(client_cluster):
                for j, cluster_j in enumerate(client_cluster[i + 1:], i + 1):
                    if len(cluster_i) == 1 and len(cluster_j) == 1:
                        min_similarity = similarity_matrix[cluster_i[0], cluster_j[0]]
                    else:
                        cross = similarity_matrix[cluster_i][:, cluster_j]
                        known = cross > -1.0
                        if not np.any(known):
                            continue
                        min_similarity = np.min(cross[known])

                    if min_similarity <= max_min_similarity:
                        continue

                    # The within-cluster bound belongs to THIS candidate pair only.
                    candidate_within = None
                    if len(cluster_i) > 1 and len(cluster_j) > 1:
                        mask1 = similarity_matrix[cluster_i][:, cluster_i] > -1.0
                        mask2 = similarity_matrix[cluster_j][:, cluster_j] > -1.0
                        min_cluster_i = np.min(similarity_matrix[cluster_i][:, cluster_i][mask1]) if np.any(mask1) else float('inf')
                        min_cluster_j = np.min(similarity_matrix[cluster_j][:, cluster_j][mask2]) if np.any(mask2) else float('inf')
                        if math.isinf(min_cluster_i) and math.isinf(min_cluster_j):
                            continue
                        candidate_within = min(min_cluster_i, min_cluster_j)

                    max_min_similarity = min_similarity
                    cluster1_index, cluster2_index = i, j
                    cross_max_similarity = np.max(similarity_matrix[cluster_i][:, cluster_j])
                    cross_min_similarity = min_similarity
                    within_min_similarity = candidate_within

            if cross_min_similarity is None:
                # No comparable cluster pair (e.g. only one cluster is left).
                print("no pair of clusters with a measured similarity")
                should_merge = False
            else:
                print("cross_min_similarity: {}, cross_max_similarity: {}, within_min_similarity: {}".format(
                    round(float(cross_min_similarity), 4), round(float(cross_max_similarity), 4), within_min_similarity
                ))
                should_merge = cross_min_similarity > 0 and (
                    within_min_similarity is None or cross_max_similarity > within_min_similarity
                )

            if should_merge:
                cluster1 = client_cluster[cluster1_index]
                cluster2 = client_cluster[cluster2_index]
                print("merge {} and {}".format(cluster1, cluster2))
                # cluster2_index > cluster1_index: pop the later entry first so the
                # earlier index stays valid (both lists are kept parallel).
                client_cluster.pop(cluster2_index)
                client_cluster.pop(cluster1_index)
                model_cluster.pop(cluster2_index)
                model_cluster.pop(cluster1_index)
                merged = cluster1 + cluster2
                client_cluster.append(merged)
                model_cluster.append(create_cluster_model(merged, client_models))
                no_merge_count = 0
            else:
                no_merge_count += 1

            # Re-derive the client -> cluster-model lookup: both the removal step and
            # the merge may have replaced cluster models this round.
            client_cluster_model = {
                cid: cluster_model
                for members, cluster_model in zip(client_cluster, model_cluster)
                for cid in members
            }

            # Forget similarities that were not refreshed recently.
            similarity_matrix[(current_round - memory_matrix) > memory_rounds] = -1

            selected_clients_param_list = [SerializationTool.serialize_model(client_models[cid]) for cid in selected_clients]
            selected_clients_avg_param = Aggregators.fedavg_aggregate(selected_clients_param_list)
            SerializationTool.deserialize_model(global_model, selected_clients_avg_param)

        else:
            # Cluster phase: train from the cluster model, then average per cluster.
            for cid in selected_clients:
                local_model: torch.nn.Module = client_models[cid]
                local_model.load_state_dict(client_cluster_model[cid].state_dict())
                train(cid, local_model, client_optimizers[cid], client_criteria[cid])

            selected_clients_set = set(selected_clients)
            for index, cluster in enumerate(client_cluster):
                intersection = selected_clients_set.intersection(cluster)
                if len(intersection) > 0:
                    print("update cluster {} using clients {}".format(cluster, intersection))
                    param_list = [SerializationTool.serialize_model(client_models[cid]) for cid in intersection]
                    avg_param = Aggregators.fedavg_aggregate(param_list)
                    SerializationTool.deserialize_model(model_cluster[index], avg_param)

    # Report the final clustering quality, comparable to FLACC_test's global_eval.
    evaluate_clusters(client_cluster, model_cluster)

# Run the improved FLACC experiment defined above.
improve_test()