In [24]:
import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
from torch import nn
from torch.utils.data import DataLoader, Dataset, TensorDataset
from tqdm import tqdm
import copy
import matplotlib.pyplot as plt
from collections import OrderedDict
import random
import torch.nn.functional as F
import torch.nn.functional as func
import collections
import config
from sklearn.model_selection import train_test_split
from collections import Counter
from config import *

In [25]:
import os
# Work around conflicts caused by loading the OpenMP runtime libraries more
# than once in the same process.
# NOTE(review): this only tells the runtime to tolerate the duplicate load; it
# does not remove the duplicate library itself.
os.environ.update(KMP_DUPLICATE_LIB_OK='True')

In [26]:
device = torch.device('cuda' if torch.cuda.is_available else 'cpu')
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

Tue Dec 17 17:06:55 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.183.01             Driver Version: 535.183.01   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA GeForce RTX 4090        Off | 00000000:21:00.0 Off |                  Off |
| 35%   45C    P2              74W / 350W |   2403MiB / 24564MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA GeForce RTX 4090        Off | 00000000:E1:00.0 Off |  

In [27]:
# Linear-bottleneck block of MobileNetV2 (a more complex CNN than LeNet); the
# original work applied it to CIFAR-100.
class LinearBottleNeck(nn.Module):
    """MobileNetV2 inverted-residual ("linear bottleneck") block.

    Pipeline: 1x1 expansion conv -> 3x3 depthwise conv (carries the stride)
    -> 1x1 projection conv. Each conv is followed by BatchNorm; only the first
    two get ReLU6, so the projection stays linear. The block input is added to
    the output when the stride is 1 and the channel count is unchanged.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the projection.
        stride: stride of the depthwise convolution.
        t: expansion factor; the hidden width is ``in_channels * t``.
        class_num: unused; accepted only for signature compatibility.
    """

    def __init__(self, in_channels, out_channels, stride, t=6, class_num=100):
        super().__init__()
        hidden = in_channels * t

        # Positions inside this Sequential become state_dict keys
        # ("residual.0.weight", ..., "residual.7.running_var"); keep the order
        # stable so saved weights keep loading.
        expand = [
            nn.Conv2d(in_channels, hidden, 1),
            nn.BatchNorm2d(hidden),
            nn.ReLU6(inplace=True),
        ]
        depthwise = [
            nn.Conv2d(hidden, hidden, 3, stride=stride, padding=1, groups=hidden),
            nn.BatchNorm2d(hidden),
            nn.ReLU6(inplace=True),
        ]
        project = [
            nn.Conv2d(hidden, out_channels, 1),
            nn.BatchNorm2d(out_channels),
        ]
        self.residual = nn.Sequential(*expand, *depthwise, *project)

        self.stride = stride
        self.in_channels = in_channels
        self.out_channels = out_channels

    def forward(self, x):
        out = self.residual(x)
        keeps_shape = self.stride == 1 and self.in_channels == self.out_channels
        if keeps_shape:
            out += x  # identity shortcut
        return out

class MobileNetV2(nn.Module):
    """MobileNetV2 classifier assembled from ``LinearBottleNeck`` stages.

    Args:
        class_num: number of output logits (default 20, matching the 20
            coarse CIFAR-100 labels used in this notebook).

    ``forward(x)`` returns logits of shape ``(batch, class_num)``.
    """

    def __init__(self, class_num=20):
        super().__init__()

        # NOTE(review): a 1x1 conv with padding=1 only adds a one-pixel zero
        # border (H and W grow by 2). Kept unchanged so existing weights and
        # results stay comparable.
        self.pre = nn.Sequential(
            nn.Conv2d(3, 32, 1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU6(inplace=True)
        )

        # Attribute names below define state_dict keys; do not rename them.
        self.stage1 = LinearBottleNeck(32, 16, 1, 1)
        self.stage2 = self._make_stage(2, 16, 24, 2, 6)
        self.stage3 = self._make_stage(3, 24, 32, 2, 6)
        self.stage4 = self._make_stage(4, 32, 64, 2, 6)
        self.stage5 = self._make_stage(3, 64, 96, 1, 6)
        self.stage6 = self._make_stage(3, 96, 160, 1, 6)
        self.stage7 = LinearBottleNeck(160, 320, 1, 6)

        self.conv1 = nn.Sequential(
            nn.Conv2d(320, 1280, 1),
            nn.BatchNorm2d(1280),
            nn.ReLU6(inplace=True)
        )

        # 1x1 conv on the pooled features acts as the final classifier.
        self.conv2 = nn.Conv2d(1280, class_num, 1)

    def forward(self, x):
        x = self.pre(x)
        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        x = self.stage5(x)
        x = self.stage6(x)
        x = self.stage7(x)
        x = self.conv1(x)
        x = F.adaptive_avg_pool2d(x, 1)   # -> (batch, 1280, 1, 1)
        x = self.conv2(x)                 # -> (batch, class_num, 1, 1)
        x = x.view(x.size(0), -1)         # -> (batch, class_num)

        return x

    def _make_stage(self, repeat, in_channels, out_channels, stride, t):
        """Stack ``repeat`` bottlenecks; only the first changes channels/stride.

        Raises:
            ValueError: if ``repeat < 1``. The previous ``while repeat - 1``
                loop never terminated for such values.
        """
        if repeat < 1:
            raise ValueError(f"repeat must be >= 1, got {repeat}")

        layers = [LinearBottleNeck(in_channels, out_channels, stride, t)]
        for _ in range(repeat - 1):
            layers.append(LinearBottleNeck(out_channels, out_channels, 1, t))

        return nn.Sequential(*layers)

def mobilenetv2():
    """Factory used throughout the notebook: a fresh MobileNetV2 with 20 outputs."""
    net = MobileNetV2(class_num=20)
    return net

In [28]:
def test_inference(model, test):
    """ Returns the test accuracy and loss.
    """
    tensor_x = torch.Tensor(test[0]).to(device)
    tensor_y = torch.Tensor(test[1]).to(device)
    test_dataset = TensorDataset(tensor_x, tensor_y)

    model.eval()
    loss, total, correct = 0.0, 0.0, 0.0

    criterion = nn.CrossEntropyLoss()
    testloader = DataLoader(test_dataset, batch_size=128,
                            shuffle=True)

    for batch_idx, (images, labels) in enumerate(testloader):
        with torch.no_grad():  # 在测试过程中不需要计算梯度，节省内存和加速计算
        # Inference
            outputs = model(images)
            batch_loss = criterion(outputs, labels.long())
            loss += batch_loss.item() * labels.size(0) # 计算损失值，更好反映模型输出概率分布与真实标签的差距

        # Prediction
            _, pred_labels = torch.max(outputs, 1)
            pred_labels = pred_labels.view(-1)
            correct += torch.sum(torch.eq(pred_labels, labels)).item()
            total += len(labels)
    #print(correct,"/",total)
    accuracy = correct/total
    return accuracy, loss

In [29]:
# Collapse CIFAR-100's 100 fine classes into its 20 coarse classes (coarser
# granularity makes the task easier).
def sparse2coarse(targets):
    """Convert PyTorch CIFAR-100 fine ("sparse") targets to coarse targets.

    Entry ``k`` of the lookup table is the coarse class of fine class ``k``;
    ``targets`` is used directly as a NumPy index into that table.

    Usage:
        trainset = torchvision.datasets.CIFAR100(path)
        trainset.targets = sparse2coarse(trainset.targets)
    """
    fine_to_coarse = (
        4, 1, 14, 8, 0, 6, 7, 7, 18, 3,        # fine 0-9
        3, 14, 9, 18, 7, 11, 3, 9, 7, 11,      # fine 10-19
        6, 11, 5, 10, 7, 6, 13, 15, 3, 15,     # fine 20-29
        0, 11, 1, 10, 12, 14, 16, 9, 11, 5,    # fine 30-39
        5, 19, 8, 8, 15, 13, 14, 17, 18, 10,   # fine 40-49
        16, 4, 17, 4, 2, 0, 17, 4, 18, 17,     # fine 50-59
        10, 3, 2, 12, 12, 16, 12, 1, 9, 19,    # fine 60-69
        2, 10, 0, 1, 16, 12, 9, 13, 15, 13,    # fine 70-79
        16, 19, 2, 4, 6, 19, 5, 5, 8, 19,      # fine 80-89
        18, 1, 2, 15, 6, 0, 17, 8, 14, 13,     # fine 90-99
    )
    lookup = np.array(fine_to_coarse)
    return lookup[targets]

In [30]:
# CIFAR-100 has 60k images: 50k train + 10k test. Both splits are merged here
# and re-partitioned among clients later.
def CIFAR100(root='../data/CIFAR-100'):
    """Return the whole of CIFAR-100 (train followed by test) as ``[images, labels]``.

    Args:
        root: dataset directory; the data is downloaded there if missing.

    Returns:
        ``[images, labels]`` where ``images`` is a C-contiguous float32 array of
        shape (N, 3, 32, 32) scaled to [0, 1]. These are the same values
        ``transforms.ToTensor()`` produces. ``labels`` holds the 20 coarse
        classes from ``sparse2coarse``.
    """
    splits = [
        torchvision.datasets.CIFAR100(root=root, train=is_train, download=True)
        for is_train in (True, False)
    ]

    # Convert the raw uint8 HWC arrays in one vectorized step instead of
    # pushing 60k images one at a time through PIL + ToTensor:
    # HWC -> CHW, uint8 -> float32, divide by 255 (as ToTensor does).
    raw = np.concatenate([ds.data for ds in splits])
    total_img = raw.transpose(0, 3, 1, 2).astype(np.float32, order='C') / 255

    fine_labels = [target for ds in splits for target in ds.targets]
    total_label = np.array(sparse2coarse(fine_labels))

    cifar = [total_img, total_label]
    return cifar


In [31]:
# Simulate non-IID clients with a Dirichlet distribution. Returns a
# (client_num, class_num) probability matrix; row m is client m's class mix.
def get_prob(non_iid, client_num, class_num = 20):
    """Sample per-client class proportions from a symmetric Dirichlet.

    Args:
        non_iid: concentration parameter shared by every class (smaller
            values give more skewed client distributions).
        client_num: number of rows (clients) to sample.
        class_num: number of classes (columns).

    Returns:
        ndarray of shape (client_num, class_num) whose rows each sum to 1.
    """
    # Optional reseed of NumPy's global RNG (added later by the author) so the
    # sampled split is reproducible.
    if data_random_fix:
        np.random.seed(seed_num)

    alpha = np.full(class_num, non_iid)
    return np.random.dirichlet(alpha, client_num)

In [32]:
def create_data(prob, size_per_client, dataset, N=20):
    """Partition ``dataset`` into per-client train/test splits following ``prob``.

    Args:
        prob: (num_clients, N) matrix; row m holds client m's class proportions.
        size_per_client: nominal number of samples per client (train + test).
        dataset: ``(images, labels)`` arrays; only classes 0..N-1 are used.
        N: number of classes.

    Returns:
        (clients, test): two lists holding one ``(images, labels)`` tuple per
        client. For each class, about 80% of a client's share goes to its
        training split and the remainder to its test split.

    Raises:
        ValueError (from ``np.random.choice``) if a class has fewer samples
        than its total quota.
    """
    all_images, all_labels = dataset
    # Total number of samples each class has to supply across all clients.
    per_class_quota = size_per_client * np.sum(prob, 0)

    # Optional reseed (added later by the author) so the sampled split is reproducible.
    if data_random_fix:
        np.random.seed(seed_num)
        random.seed(seed_num)

    # For every class, draw (without replacement) the pool it will hand out.
    class_pools = []
    for cls in range(N):
        in_cls = all_labels == cls
        cls_images, cls_labels = all_images[in_cls], all_labels[in_cls]
        picks = np.random.choice(len(cls_images), size=int(per_class_quota[cls]), replace=False).astype(int)
        class_pools.append((cls_images[picks], cls_labels[picks]))

    cursor = [0] * N  # next unused position inside each class pool
    clients, test = [], []

    # TODO_241216: each client's test split follows the same class mix as its
    # training split, and the final metric averages accuracy over all clients.
    # TODO_241216: is this what other FL methods do? Should this work do the same?
    for client_prob in prob:
        train_x, train_y = [], []
        test_x, test_y = [], []

        for cls in range(N):
            pool_x, pool_y = class_pools[cls]
            start = cursor[cls]
            # int() truncates, so realised sizes end up slightly below nominal.
            split = start + int(client_prob[cls] * size_per_client * 0.8)
            stop = start + int(client_prob[cls] * size_per_client)

            train_x.extend(pool_x[start:split])
            train_y.extend(pool_y[start:split])
            test_x.extend(pool_x[split:stop])
            test_y.extend(pool_y[split:stop])

            cursor[cls] = stop  # consume this client's share of the class

        clients.append((np.array(train_x), np.array(train_y)))
        test.append((np.array(test_x), np.array(test_y)))

    return clients, test

In [33]:

# Merge the per-client test splits into one dataset (create_data splits the
# test data per client). Author's note: currently unused; purpose unclear.
def comb_client_test_func(client_test_data):
    """Concatenate per-client test splits and print the merged class histogram.

    Args:
        client_test_data: sequence of ``(images, labels)`` pairs, one per client.

    Returns:
        ``[images, labels]`` as NumPy arrays, with clients in input order.
    """
    comb_client_test_image = []
    comb_client_test_label = []
    # Walk the sequence actually passed in rather than range(client_num): a
    # mismatch with that global either raised IndexError or silently dropped
    # clients.
    for client_images, client_labels in client_test_data:
        comb_client_test_image.extend(list(client_images))
        comb_client_test_label.extend(list(client_labels))

    # Stack the merged test images and labels into NumPy arrays.
    comb_client_test_image = np.array(comb_client_test_image)
    comb_client_test_label = np.array(comb_client_test_label)

    label_count = Counter(comb_client_test_label)
    print("测试集类别分布：")
    for label, count in sorted(label_count.items()):
        print(f"类别 {label}: {count} 个样本")

    return [comb_client_test_image, comb_client_test_label]

In [34]:
# Class-balanced subsampling: keep `percentage` of every class, then shuffle.
def select_subset(whole_set, percentage):
    """Return ``[features, labels]`` holding ``percentage`` of each class, shuffled.

    Args:
        whole_set: indexable with ``[0]`` = features and ``[1]`` = labels
            (NumPy arrays of equal length).
        percentage: fraction in [0, 1] kept per class (rounded down).

    Raises:
        ValueError: on length mismatch or a percentage outside [0, 1].
    """
    # Optional reseed (added later by the author) so the sampled subset is reproducible.
    if data_random_fix:
        np.random.seed(seed_num)
        random.seed(seed_num)

    features, targets = whole_set[0], whole_set[1]
    if len(features) != len(targets):
        raise ValueError("Both arrays should have the same length.")
    if not 0 <= percentage <= 1:
        raise ValueError("Percentage must be between 0 and 1.")

    picked_x, picked_y = [], []
    for cls in np.unique(targets):
        members = np.where(targets == cls)[0]
        keep = np.random.choice(members, int(len(members) * percentage), replace=False)
        picked_x.extend(features[keep])
        picked_y.extend(targets[keep])

    picked_x = np.array(picked_x)
    picked_y = np.array(picked_y)

    # Shuffle so samples are no longer grouped by class.
    order = np.random.permutation(len(picked_x))
    return [picked_x[order], picked_y[order]]

In [35]:
# Prepare the federated datasets (this cell was added by the author).

cifar = CIFAR100()
prob = get_prob(non_iid, client_num, class_num=20)
client_data, client_test_data = create_data(prob, size_per_client, cifar, N=20)


def _class_count_line(labels, num_classes=20):
    """Format labels as "<total> <count of class 0> ... <count of class num_classes-1>"."""
    labels = np.array(labels)
    class_counts = [0] * num_classes
    unique_classes, counts = np.unique(labels, return_counts=True)
    for cls, cnt in zip(unique_classes, counts):
        class_counts[cls] = cnt
    return " ".join([str(len(labels))] + [str(c) for c in class_counts])


# Pool every client's training data; the server subset is drawn from this pool.
all_images = []
all_labels = []
for client_images, client_labels in client_data:
    all_images.extend(client_images)
    all_labels.extend(client_labels)
comb_client_data = [np.array(all_images), np.array(all_labels)]

# Output format: "<total> <class 0 count> ... <class 19 count>".
print("Traning Client Total: {}".format(_class_count_line(comb_client_data[1])))

# Per-client training / test class histograms (first 10 clients only).
for i, (_, lbls) in enumerate(client_data[:10]):
    print("Client {}: {}".format(i, _class_count_line(lbls)))

for i, (_, lbls) in enumerate(client_test_data[:10]):
    print("Client {} Test: {}".format(i, _class_count_line(lbls)))

# Draw the fixed server dataset once, up front (a later change by the author;
# the training loops' per-round re-draw is commented out).
server_data = select_subset(comb_client_data, server_percentage)

# Bug fix: the old line formatted the builtin `round` function into the output
# ("Server <built-in function round>: ..."); there is no round counter here.
print("Server: {}".format(_class_count_line(server_data[1])))

10
False
Files already downloaded and verified
Files already downloaded and verified
Traning Client Total: 30228 1777 1490 1282 1448 1526 1590 1675 1443 1436 1555 1438 1543 1522 1447 1545 1304 1528 1563 1581 1535
Client 0: 153 0 17 1 6 22 0 5 0 29 0 0 1 1 14 0 10 12 4 21 10
Client 1: 152 0 1 0 3 0 1 0 12 10 3 0 40 3 11 39 0 26 0 3 0
Client 2: 151 1 2 0 0 1 8 1 18 7 13 0 2 6 2 18 18 10 1 42 1
Client 3: 152 0 45 4 3 4 0 3 2 0 9 2 3 0 1 1 0 19 26 26 4
Client 4: 150 10 0 2 1 27 1 0 1 9 0 5 12 12 38 2 8 0 0 3 19
Client 5: 151 12 1 2 10 20 5 5 3 41 17 0 8 3 2 0 1 3 0 8 10
Client 6: 152 0 5 0 9 1 66 1 3 9 0 2 12 0 24 1 1 15 0 1 2
Client 7: 155 0 1 23 5 1 5 15 0 32 2 0 35 1 0 20 0 2 0 13 0
Client 8: 152 5 1 8 2 1 26 13 21 0 0 5 2 0 0 21 0 0 13 34 0
Client 9: 152 4 8 6 3 33 2 35 0 0 8 2 0 0 10 10 0 0 9 12 10
Client 0 Test: 38 0 4 0 1 5 0 1 0 7 0 1 1 0 4 0 2 3 1 6 2
Client 1 Test: 39 0 0 1 0 0 1 0 4 3 1 0 10 0 2 10 0 6 0 1 0
Client 2 Test: 39 0 0 1 0 0 3 0 4 2 4 0 0 2 1 4 5 3 0 10 0
Client 3 Tes

In [36]:
# Local training: returns the updated weights, the mean training loss and the
# gradient information captured on the first mini-batch.
def update_weights(model_weight, dataset, learning_rate, local_epoch):
    """Train a fresh MobileNetV2 initialised from ``model_weight`` on ``dataset``.

    Args:
        model_weight: state_dict to start from.
        dataset: ``(images, labels)`` array pair.
        learning_rate: SGD step size (momentum / weight_decay come from config).
        local_epoch: number of passes over ``dataset``.

    Returns:
        (state_dict, mean_loss, first_iter_gradient):
        - mean_loss is the average over epochs of the mean per-batch
          cross-entropy.
        - first_iter_gradient maps every parameter name to its gradient on the
          very first mini-batch. It also holds BatchNorm
          ``<module>.running_mean`` / ``<module>.running_var`` as they stand
          right after that step.
    """
    model = mobilenetv2().to(device)
    model.load_state_dict(model_weight)
    model.train()

    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum, weight_decay=weight_decay)
    criterion = nn.CrossEntropyLoss()

    tensor_set = TensorDataset(torch.Tensor(dataset[0]).to(device), torch.Tensor(dataset[1]).to(device))
    data_loader = DataLoader(tensor_set, batch_size=128, shuffle=True)

    first_iter_gradient = None  # filled on the first mini-batch of the first epoch
    epoch_loss = []

    for epoch in range(local_epoch):
        batch_loss = []
        for batch_idx, (images, labels) in enumerate(data_loader):
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels.long())
            loss.backward()
            optimizer.step()
            # CrossEntropyLoss already averages over the batch (reduction='mean');
            # the previous extra `/ images.shape[0]` shrank the reported loss
            # by the batch size.
            batch_loss.append(loss.item())

            if epoch == 0 and batch_idx == 0:
                # optimizer.step() does not clear .grad, so these are still the
                # first batch's gradients.
                first_iter_gradient = {
                    name: param.grad.clone() for name, param in model.named_parameters()
                }
                # Also keep the BatchNorm running statistics.
                for name, module in model.named_modules():
                    if isinstance(module, nn.BatchNorm2d):
                        first_iter_gradient[name + '.running_mean'] = module.running_mean.clone()
                        first_iter_gradient[name + '.running_var'] = module.running_var.clone()

        epoch_loss.append(sum(batch_loss) / len(batch_loss))

    return model.state_dict(), sum(epoch_loss) / len(epoch_loss), first_iter_gradient

In [37]:
# Weight difference between two state-dict-like mappings, scaled by `lr`.
def weight_differences(n_w, p_w, lr):
    """Return ``(p_w - n_w) * lr`` computed key-by-key over the keys of ``n_w``.

    Keys containing 'num_batches_tracked' keep ``n_w``'s value unchanged. The
    result is a deep copy of ``n_w`` (same mapping type), so neither input is
    modified. Every key of ``n_w`` must also exist in ``p_w``.
    """
    result = copy.deepcopy(n_w)
    for name, base in n_w.items():
        if 'num_batches_tracked' not in name:
            result[name] = (p_w[name] - base) * lr
    return result

In [38]:
# Local training with the paper's weight-correction mechanism (used by Fed_C).
def update_weights_correction(model_weight, dataset, learning_rate, local_epoch, c_i, c_s):
    """Local SGD that, after every epoch, shifts the weights by ``learning_rate * (c_i - c_s)``.

    Args:
        model_weight: state_dict to start from.
        dataset: ``(images, labels)`` array pair.
        learning_rate: SGD step size, also used to scale the correction.
        local_epoch: number of passes over ``dataset``.
        c_i: this client's reference gradient dict (as returned third by
            ``update_weights``).
        c_s: the server's reference gradient dict with the same keys.

    Returns:
        (state_dict, mean_loss): mean_loss is the average over epochs of the
        mean per-batch cross-entropy.
    """
    model = mobilenetv2().to(device)
    model.load_state_dict(model_weight)
    model.train()

    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum, weight_decay=weight_decay)
    criterion = nn.CrossEntropyLoss()

    tensor_set = TensorDataset(torch.Tensor(dataset[0]).to(device), torch.Tensor(dataset[1]).to(device))
    # NOTE(review): batch size 200 here vs 128 in update_weights; confirm intended.
    data_loader = DataLoader(tensor_set, batch_size=200, shuffle=True)

    # The correction depends only on c_i, c_s and the learning rate, so it is
    # computed once instead of every epoch: (c_s - c_i) * learning_rate.
    corrected_gradient = weight_differences(c_i, c_s, learning_rate)

    epoch_loss = []
    for epoch in range(local_epoch):
        batch_loss = []
        for images, labels in data_loader:
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels.long())
            loss.backward()
            optimizer.step()
            # CrossEntropyLoss already returns the batch mean; dividing by the
            # batch size again under-reported the loss.
            batch_loss.append(loss.item())
        epoch_loss.append(sum(batch_loss) / len(batch_loss))

        # new_w = current_w - corrected_gradient (scale 1), over the keys of c_i.
        original_model_weight = model.state_dict()
        corrected_model_weight = weight_differences(corrected_gradient, original_model_weight, 1)
        model.load_state_dict(corrected_model_weight)

    return model.state_dict(), sum(epoch_loss) / len(epoch_loss)

In [39]:
def average_weights(w):
    """Return the element-wise mean of a non-empty list of state-dict-like mappings.

    Keys containing 'num_batches_tracked' are not averaged; they keep the value
    from ``w[0]``. Inputs are left untouched because ``w[0]`` is deep-copied
    before accumulating.
    """
    averaged = copy.deepcopy(w[0])
    n_models = len(w)
    for name in averaged.keys():
        if 'num_batches_tracked' in name:
            continue
        running = averaged[name]
        for other in w[1:]:
            running += other[name]
        averaged[name] = torch.div(running, n_models)
    return averaged

In [40]:
# Baseline: server-only training.
def server_only(initial_w, global_round, gamma, E):
    """Baseline that trains the global model on ``server_data`` only.

    Args:
        initial_w: initial state_dict.
        global_round: number of rounds.
        gamma: server learning rate.
        E: local epochs on the server per round.

    Returns:
        (test_acc, train_loss): per-round unweighted mean of the per-client test
        accuracies, and the server training loss.
    """
    evaluator = mobilenetv2().to(device)
    weights = copy.deepcopy(initial_w)
    test_acc, train_loss = [], []

    for _ in tqdm(range(global_round)):
        # server_data is a class-balanced subset of the pooled client data,
        # drawn once in the data-preparation cell and reused every round.
        weights, server_loss, _grad = update_weights(weights, server_data, gamma, E)
        evaluator.load_state_dict(weights)
        train_loss.append(server_loss)

        accs = [test_inference(evaluator, split)[0] for split in client_test_data]
        test_acc.append(sum(accs) / len(client_test_data))

    return test_acc, train_loss

In [41]:
def fedavg(initial_w, global_round, eta, K, M):
    """FedAvg: each round, M random clients train K local epochs and their weights are averaged.

    Args:
        initial_w: initial global state_dict.
        global_round: number of rounds.
        eta: client learning rate.
        K: local epochs per client.
        M: clients sampled per round.

    Returns:
        (test_acc, train_loss): per-round mean client test accuracy and mean
        client training loss.
    """
    evaluator = mobilenetv2().to(device)
    weights = copy.deepcopy(initial_w)
    test_acc, train_loss = [], []

    for _ in tqdm(range(global_round)):
        chosen = random.sample(range(client_num), M)

        round_w, round_l = [], []
        for c in chosen:
            w_c, loss_c, _grad = update_weights(weights, client_data[c], eta, K)
            round_w.append(w_c)
            round_l.append(loss_c)

        weights = average_weights(round_w)

        evaluator.load_state_dict(weights)
        train_loss.append(sum(round_l) / len(round_l))
        accs = [test_inference(evaluator, split)[0] for split in client_test_data]
        test_acc.append(sum(accs) / len(client_test_data))

    return test_acc, train_loss


In [42]:
def hybridFL(initial_w, global_round, eta, K, M):
    """HybridFL: FedAvg in which the server also trains as one more ordinary participant.

    Args:
        initial_w: initial global state_dict.
        global_round: number of global rounds.
        eta: learning rate (clients and server).
        K: local epochs (clients and server).
        M: clients sampled per round.

    Returns:
        (test_acc, train_loss): per-round mean client test accuracy, and the
        mean training loss over the sampled clients plus the server.
    """
    evaluator = mobilenetv2().to(device)
    weights = copy.deepcopy(initial_w)
    test_acc, train_loss = [], []

    for _ in tqdm(range(global_round)):
        chosen = random.sample(range(client_num), M)

        participant_w, participant_loss = [], []
        for c in chosen:
            w_c, loss_c, _grad = update_weights(weights, client_data[c], eta, K)
            participant_w.append(w_c)
            participant_loss.append(loss_c)

        # The server starts from the same global weights and is averaged in
        # with the same weight as each client.
        w_s, loss_s, _grad = update_weights(weights, server_data, eta, K)
        participant_w.append(w_s)
        participant_loss.append(loss_s)

        weights = average_weights(participant_w)

        evaluator.load_state_dict(weights)
        train_loss.append(sum(participant_loss) / len(participant_loss))
        accs = [test_inference(evaluator, split)[0] for split in client_test_data]
        test_acc.append(sum(accs) / len(client_test_data))

    return test_acc, train_loss


In [43]:
def CLG_SGD(initial_w, global_round, eta, gamma, K, E, M):
    """CLG-SGD: FedAvg over M sampled clients, then E epochs of server training per round.

    Args:
        initial_w: initial global state_dict.
        global_round: number of rounds.
        eta: client learning rate.
        gamma: server learning rate.
        K: local epochs per client.
        E: server epochs per round.
        M: clients sampled per round (out of ``client_num``).

    Returns:
        (test_acc, train_loss): per-round mean client test accuracy, and the
        mean loss over all sampled clients plus the server.
    """
    evaluator = mobilenetv2().to(device)
    weights = copy.deepcopy(initial_w)
    test_acc, train_loss = [], []

    for _ in tqdm(range(global_round)):
        # Client side: M of the client_num clients train locally.
        chosen = random.sample(range(client_num), M)
        round_w, round_l = [], []
        for c in chosen:
            w_c, loss_c, _grad = update_weights(weights, client_data[c], eta, K)
            round_w.append(w_c)
            round_l.append(loss_c)
        weights = average_weights(round_w)

        # Server side: refine the aggregate on server_data.
        # TODO_241216 (author): server_data used to be re-drawn every round (a
        # class-balanced subset of the pooled client data); it is now drawn once
        # up front. Confirm this fits the intended scenario.
        weights, server_loss, _grad = update_weights(weights, server_data, gamma, E)
        round_l.append(server_loss)

        evaluator.load_state_dict(weights)
        train_loss.append(sum(round_l) / len(round_l))

        # Unweighted mean of the per-client test accuracies.
        accs = [test_inference(evaluator, split)[0] for split in client_test_data]
        test_acc.append(sum(accs) / len(client_test_data))

    return test_acc, train_loss

In [44]:
def Fed_C(initial_w, global_round, eta, gamma, K, E, M):
    """Fed-C: client-side gradient correction followed by server fine-tuning.

    Each round:
      1. Collect first-mini-batch gradients at the current global weights, on
         ``server_data`` (lr ``gamma``) and on M sampled clients (lr ``eta``).
         Each comes from a one-epoch ``update_weights`` call.
      2. Each sampled client trains K epochs with ``update_weights_correction``,
         given its own gradient and the server's.
      3. Average the client weights, then train E epochs on ``server_data``
         (lr ``gamma``).

    Returns:
        (test_acc, train_loss): per-round mean client test accuracy, and the
        mean of the client losses plus the server loss.
    """
    evaluator = mobilenetv2().to(device)
    weights = copy.deepcopy(initial_w)
    test_acc, train_loss = [], []

    for _ in tqdm(range(global_round)):
        # Reference gradients. NOTE(review): a full local epoch is run just to
        # read the first mini-batch gradient; only the gradient is kept.
        server_grad = update_weights(weights, server_data, gamma, 1)[2]

        chosen = random.sample(range(client_num), M)
        client_grads = [update_weights(weights, client_data[c], eta, 1)[2] for c in chosen]

        # Corrected client training.
        round_w, round_l = [], []
        for c, grad_c in zip(chosen, client_grads):
            w_c, loss_c = update_weights_correction(weights, client_data[c], eta, K, grad_c, server_grad)
            round_w.append(w_c)
            round_l.append(loss_c)
        weights = average_weights(round_w)

        # Server-side training on the aggregate.
        weights, server_loss, _grad = update_weights(weights, server_data, gamma, E)
        round_l.append(server_loss)

        evaluator.load_state_dict(weights)
        train_loss.append(sum(round_l) / len(round_l))
        accs = [test_inference(evaluator, split)[0] for split in client_test_data]
        test_acc.append(sum(accs) / len(client_test_data))

    return test_acc, train_loss


In [45]:
def Fed_S(initial_w, global_round, eta, gamma, K, E, M):
    """Train a global MobileNetV2 with server-side aggregation correction (Fed_S).

    Each global round:
      1. ``g_s`` = gradient returned by ``update_weights`` on ``server_data``
         (learning rate ``gamma``, count 1) at the current global weights.
      2. ``M`` clients are drawn with ``random.sample``. For each, ``g_i`` is the
         gradient returned by ``update_weights`` on its ``client_data`` (learning
         rate ``eta``, count 1) at the current global weights.
      3. Each sampled client trains uncorrected with ``update_weights``
         (learning rate ``eta``, count ``K``). The resulting weights are
         combined with ``average_weights``.
      4. The combined model is corrected on the server:
         ``correction_g = weight_differences(average_weights(g_i_list), g_s, K * eta)``
         and then ``train_w = weight_differences(correction_g, <copy of train_w>, 1)``.
      5. The corrected model is trained further on ``server_data``
         (learning rate ``gamma``, count ``E``).
      6. The mean training loss and the mean test accuracy are recorded.

    Relies on notebook globals: ``server_data``, ``client_data``, ``client_num``,
    ``client_test_data``, ``device``, ``mobilenetv2``, ``update_weights``,
    ``average_weights``, ``weight_differences``, ``test_inference``.

    Args:
        initial_w: state_dict of the starting global model. It is deep-copied,
            so this function never trains on the caller's dict directly.
        global_round: number of communication rounds.
        eta: client learning rate.
        gamma: server learning rate.
        K: local-training count per client, passed to ``update_weights``.
            It is also part of the correction scale ``K * eta``.
        E: server-side training count, passed to ``update_weights``.
        M: number of clients sampled per round. ``random.sample`` raises
            ``ValueError`` if ``M > client_num``.

    Returns:
        ``(test_acc, train_loss)``, each a list with one entry per round:

        - ``test_acc``: mean of ``test_inference(model, d)[0]`` over
          ``client_test_data``.
        - ``train_loss``: mean of the client losses and the server loss from
          that round.
    """
    test_model = mobilenetv2().to(device)
    train_w = copy.deepcopy(initial_w)
    test_acc = []
    train_loss = []

    for _ in tqdm(range(global_round)):
        # Server gradient at the current global weights.
        _, _, g_s = update_weights(train_w, server_data, gamma, 1)

        # Gradient of each sampled client at the current global weights.
        sampled_clients = random.sample(range(client_num), M)
        g_i_list = []
        for client in sampled_clients:
            _, _, g_i = update_weights(train_w, client_data[client], eta, 1)
            g_i_list.append(g_i)

        # Client-side local training (no client-side correction in Fed_S).
        local_weights, local_loss = [], []
        for client in sampled_clients:
            client_w, client_loss, _ = update_weights(train_w, client_data[client], eta, K)
            local_weights.append(client_w)
            local_loss.append(client_loss)
        train_w = average_weights(local_weights)

        # Server aggregation correction. The exact arithmetic lives in
        # weight_differences, whose definition is not shown here. The deepcopy
        # guards train_w in case that helper mutates its second argument
        # (NOTE(review): confirm whether it does).
        g_i_average = average_weights(g_i_list)
        correction_g = weight_differences(g_i_average, g_s, K * eta)
        train_w = weight_differences(correction_g, copy.deepcopy(train_w), 1)

        # Server-side local training on top of the corrected model.
        train_w, server_loss, _ = update_weights(train_w, server_data, gamma, E)
        local_loss.append(server_loss)

        # Evaluation: mean loss over all participants this round and mean
        # accuracy over every client test set.
        test_model.load_state_dict(train_w)
        train_loss.append(sum(local_loss) / len(local_loss))

        total_acc = sum(test_inference(test_model, test_data)[0]
                        for test_data in client_test_data)
        test_acc.append(total_acc / len(client_test_data))
    return test_acc, train_loss

In [46]:
# Initialise models and hyper-parameters, then train every algorithm in turn.
# (This part was added by the notebook author.)
#
# Each algorithm starts from a freshly initialised MobileNetV2, exactly as the
# previous copy-pasted runs did. Every run's curves are kept in
# `experiment_results`, so earlier algorithms are no longer overwritten by
# later ones. `init_model`, `initial_w`, `test_acc` and `train_loss` are still
# left bound to the last run (Fed_S) for any later cells that use them.

# (result key, completion message, training call taking initial_w).
# The lambdas read the hyper-parameters (global_round, eta, gamma, K, E, M)
# at call time, just like the original sequential calls did.
experiments = [
    ("server_only", "Server only 训练完成！",
     lambda w: server_only(w, global_round, gamma, E)),
    ("fedavg", "fedavg训练完成！",
     lambda w: fedavg(w, global_round, eta, K, M)),
    ("hybridFL", "hybridFL训练完成！",
     lambda w: hybridFL(w, global_round, eta, K, M)),
    ("CLG_SGD", "CLG_SGD 训练完成！",
     lambda w: CLG_SGD(w, global_round, eta, gamma, K, E, M)),
    ("Fed_C", "Fed_C 训练完成！",
     lambda w: Fed_C(w, global_round, eta, gamma, K, E, M)),
    ("Fed_S", "Fed_S 训练完成！",
     lambda w: Fed_S(w, global_round, eta, gamma, K, E, M)),
]

experiment_results = {}
for name, done_msg, run_training in experiments:
    # Fresh random initialisation for this algorithm.
    init_model = mobilenetv2().to(device)
    initial_w = init_model.state_dict()

    test_acc, train_loss = run_training(initial_w)
    experiment_results[name] = {"test_acc": test_acc, "train_loss": train_loss}

    # Report the per-round curves and the final accuracy for this algorithm.
    print(done_msg)
    print("各轮平均测试精度:", test_acc)
    print("各轮平均训练损失:", train_loss)
    print("最终测试精度:", test_acc[-1] if len(test_acc) > 0 else "无数据")

  0%|          | 0/100 [00:00<?, ?it/s]

100%|██████████| 100/100 [01:45<00:00,  1.05s/it]


Server only 训练完成！
各轮平均测试精度: [0.051321207209223285, 0.07065446713833504, 0.08132523626374187, 0.07808900201600659, 0.06066246275048179, 0.09995242988167347, 0.09419918188635905, 0.0889215587149879, 0.10124859812935971, 0.12836012757608717, 0.13476190945474378, 0.13689548453519837, 0.13278557602390317, 0.11921261716998273, 0.1119176960323629, 0.11854880776249095, 0.13355632591087618, 0.126508432379652, 0.10489293154666161, 0.10674947088316929, 0.1020418970236527, 0.11097783710336317, 0.13245024320939316, 0.14275119447809154, 0.13324867607028315, 0.1485055149079688, 0.14441928303323348, 0.1036288656048447, 0.14609107211119687, 0.1335654797295428, 0.12809083789805556, 0.13459022435871845, 0.13467406205877924, 0.1430563821854608, 0.1468702590040395, 0.15036390653557102, 0.1507943689518456, 0.12810086282706573, 0.1292130491815761, 0.13542941963439303, 0.13783568270407748, 0.14296703858779553, 0.13812855042473726, 0.13378420299273644, 0.1219911272055083, 0.13085083771002268, 0.105339719551074

100%|██████████| 100/100 [04:37<00:00,  2.77s/it]


fedavg训练完成！
各轮平均测试精度: [0.0495069447487313, 0.049310456570278355, 0.057000228374489474, 0.04503536199658987, 0.06786833844015079, 0.05851623730684539, 0.07060765565905569, 0.11663307173764821, 0.09354851331488781, 0.1123673301773242, 0.11383894662944644, 0.127842802055668, 0.13516116022423294, 0.10490938008293281, 0.08846518355548266, 0.12240233075226493, 0.10628212633056709, 0.1119781575612171, 0.15461525131882162, 0.11640766639410166, 0.128201801577003, 0.11233377093823917, 0.12150088930811073, 0.1394102576881481, 0.12139291414860112, 0.12920232993412667, 0.15744837330665898, 0.15832157626880275, 0.12897370574022957, 0.13487779157327545, 0.13299627607898853, 0.13990295077309567, 0.15825538538078823, 0.15724061217609506, 0.1174989885905042, 0.15176421837453785, 0.16027532600051766, 0.15839904113372247, 0.1353648819267269, 0.16421094717648135, 0.1610197186066554, 0.18259923571576472, 0.1915946457102679, 0.13230555026828192, 0.14125469159669962, 0.13264232589824052, 0.15589146048414776, 

100%|██████████| 100/100 [04:58<00:00,  2.98s/it]


hrbridFL训练完成！
各轮平均测试精度: [0.04741424121514503, 0.05559922580665109, 0.04881679539094752, 0.04807227784439549, 0.06455100950079228, 0.10187515302090853, 0.08884738914420671, 0.07824645374692757, 0.10894948014535949, 0.13914332097885837, 0.15376996298528808, 0.12250669499553726, 0.1395941687510895, 0.1322224136987995, 0.14327155258486407, 0.1319636890159212, 0.147028086623539, 0.16512483896940697, 0.1524205112909025, 0.11463166161048427, 0.1442770288395381, 0.17843431109132496, 0.14869077674956543, 0.15524257494580204, 0.16274725606701956, 0.19205664113573012, 0.18850089675922949, 0.16427613118658577, 0.17185218178730277, 0.18510141684809922, 0.19419353942160464, 0.16984112312350652, 0.17643714387750775, 0.18460573693904656, 0.1688610682156558, 0.19862759088436355, 0.20765864643339482, 0.20101286380594582, 0.1685401689362267, 0.2107580498138233, 0.2004957575903966, 0.21077734385756364, 0.21701228434575456, 0.19228671652201348, 0.185327252187538, 0.19958097894794832, 0.19152357451326746, 0

100%|██████████| 100/100 [04:58<00:00,  2.98s/it]


CLG_SGD 训练完成！
各轮平均测试精度: [0.05109638361879235, 0.06323004287864196, 0.12176141015060533, 0.12827690110862477, 0.14606381481800906, 0.16879372952144006, 0.13642248544046084, 0.15615884564486285, 0.1510988602518253, 0.1812600124392597, 0.19451626995377544, 0.171365737786712, 0.1792969236768233, 0.19071875988984976, 0.1844938674057251, 0.18273410121081907, 0.1892109254419646, 0.19785577235010013, 0.19819995112553396, 0.1840203878341246, 0.2026926432283851, 0.20555804075035644, 0.2031645278735949, 0.2096209327137581, 0.19881761110440915, 0.21261680104084643, 0.22703357744613295, 0.23017716814129924, 0.21604492002635506, 0.22769735956770984, 0.23904051805906812, 0.2351788824349425, 0.22482808098623677, 0.2548487022212245, 0.24849554584939082, 0.2510210564897962, 0.25372188626743886, 0.2507870138558782, 0.25384685300168025, 0.2616656522140669, 0.24116794527234262, 0.26506905737780523, 0.2691175427880954, 0.23743445603818916, 0.26016644561269436, 0.2679540060369874, 0.26209089526637364, 0.2648

100%|██████████| 100/100 [07:25<00:00,  4.45s/it]


Fed_C 训练完成！
各轮平均测试精度: [0.05109638361879235, 0.09275787258845809, 0.09555794114292629, 0.07419276604152852, 0.11729794801989693, 0.08479980684008652, 0.11470389343408188, 0.11170224560214331, 0.1445872109165276, 0.12941011716698703, 0.15438648844335534, 0.14315900045495222, 0.15177498624366245, 0.17242636646326517, 0.17447906640223848, 0.16863243262568564, 0.1897973457444528, 0.18370671368884628, 0.1859520535270899, 0.17397713939474005, 0.20324265502186642, 0.2051722024485442, 0.21401632610522575, 0.2140653629779894, 0.13130312781316966, 0.2024689148327347, 0.15869206809576084, 0.2092460664734487, 0.22085133824254866, 0.20735704821413228, 0.21074072047639825, 0.22183291404229838, 0.18321236069428806, 0.2065631734150482, 0.22850771635381073, 0.22204475374386476, 0.22716622161384378, 0.21361360500839038, 0.22521419631196604, 0.23269691009691962, 0.22458463893593517, 0.2322894425402263, 0.23109153236505298, 0.21753445567491217, 0.2283690893999242, 0.22803295308685342, 0.23789115403976802, 

100%|██████████| 100/100 [06:27<00:00,  3.88s/it]

Fed_S 训练完成！
各轮平均测试精度: [0.051321207209223285, 0.06803169238889226, 0.10042575763885499, 0.10132348573782118, 0.09962111318576444, 0.10900401256087207, 0.14390348320345764, 0.12919548045026816, 0.14506491384769643, 0.1361086021725518, 0.12406543859987068, 0.1362936964287569, 0.1455499250695871, 0.14104294305775023, 0.1489386496037165, 0.17437976015089252, 0.1568148575811723, 0.18065803528387483, 0.17993707853628127, 0.17932434553082666, 0.19601073034787236, 0.1711604500016671, 0.1853167596262088, 0.20317377844184337, 0.20700515843174125, 0.1964356718113247, 0.20625164985508448, 0.2113757153703827, 0.20814715842889625, 0.2147264669134762, 0.21707778189630667, 0.21153885874661235, 0.21421303868871913, 0.21284359395234603, 0.20003939284965938, 0.2098357022100865, 0.2162566525592915, 0.22727053111352064, 0.20194919346150791, 0.2178997384647032, 0.22923347294345056, 0.2321312477223902, 0.23235846986461212, 0.22396406081795067, 0.2334544573574001, 0.21925563681758253, 0.2213766134017164, 0.230


