In [45]:
import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
from torch import nn
from torch.utils.data import DataLoader, Dataset, TensorDataset
from tqdm import tqdm
import copy
import matplotlib.pyplot as plt
from collections import OrderedDict
import random
import torch.nn.functional as F
import torch.nn.functional as func
import collections
from sklearn.model_selection import train_test_split
from collections import Counter
from config import *

In [46]:
import os

# Work around the conflict caused by OpenMP-related dynamic libraries being loaded more than once.
os.environ.update({'KMP_DUPLICATE_LIB_OK': 'True'})

In [47]:
# Use the GPU only when CUDA is really usable. `is_available` has to be *called*:
# the bare function object is always truthy and would select 'cuda' on CPU-only hosts.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

Tue Dec 17 14:41:17 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.183.01             Driver Version: 535.183.01   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA GeForce RTX 4090        Off | 00000000:21:00.0 Off |                  Off |
| 30%   25C    P8              13W / 350W |   3343MiB / 24564MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA GeForce RTX 4090        Off | 00000000:E1:00.0 Off |  

In [48]:
# Linear bottleneck block of MobileNetV2 (a CNN more complex than LeNet); used for CIFAR-100 in the original work.
class LinearBottleNeck(nn.Module):
    """Bottleneck block: 1x1 expansion -> 3x3 depthwise conv -> 1x1 linear projection.

    Args:
        in_channels: number of channels of the incoming feature map.
        out_channels: number of channels produced by the final 1x1 projection.
        stride: stride of the 3x3 depthwise convolution.
        t: expansion factor; the hidden width is ``in_channels * t``.
        class_num: accepted for signature compatibility; not used by this block.
    """

    def __init__(self, in_channels, out_channels, stride, t=6, class_num=100):
        super().__init__()

        hidden = in_channels * t

        # 1x1 convolution that widens the representation.
        expand = [
            nn.Conv2d(in_channels, hidden, 1),
            nn.BatchNorm2d(hidden),
            nn.ReLU6(inplace=True),
        ]
        # 3x3 depthwise convolution (one group per channel) that carries the stride.
        depthwise = [
            nn.Conv2d(hidden, hidden, 3, stride=stride, padding=1, groups=hidden),
            nn.BatchNorm2d(hidden),
            nn.ReLU6(inplace=True),
        ]
        # 1x1 projection with no activation afterwards (the "linear" part).
        project = [
            nn.Conv2d(hidden, out_channels, 1),
            nn.BatchNorm2d(out_channels),
        ]
        self.residual = nn.Sequential(*expand, *depthwise, *project)

        self.stride = stride
        self.in_channels = in_channels
        self.out_channels = out_channels

    def forward(self, x):
        """Apply the block; the input is added back only when the shapes are unchanged."""
        out = self.residual(x)

        keeps_shape = self.stride == 1 and self.in_channels == self.out_channels
        if not keeps_shape:
            return out

        out += x
        return out

class MobileNetV2(nn.Module):
    """MobileNetV2-style classifier assembled from `LinearBottleNeck` blocks.

    Args:
        class_num: number of output classes (channels of the final 1x1 convolution).

    The forward pass returns a ``(batch, class_num)`` tensor of logits.
    """

    def __init__(self, class_num=20):
        super().__init__()

        # NOTE(review): a 1x1 kernel combined with padding=1 enlarges each spatial side by 2
        # and makes the border pixels see only zero padding. Left unchanged so that the
        # architecture (and any saved weights) stay the same; confirm it is intentional.
        self.pre = nn.Sequential(
            nn.Conv2d(3, 32, 1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU6(inplace=True)
        )

        # _make_stage arguments: (repeat, in_channels, out_channels, stride, t)
        self.stage1 = LinearBottleNeck(32, 16, 1, 1)
        self.stage2 = self._make_stage(2, 16, 24, 2, 6)
        self.stage3 = self._make_stage(3, 24, 32, 2, 6)
        self.stage4 = self._make_stage(4, 32, 64, 2, 6)
        self.stage5 = self._make_stage(3, 64, 96, 1, 6)
        self.stage6 = self._make_stage(3, 96, 160, 1, 6)
        self.stage7 = LinearBottleNeck(160, 320, 1, 6)

        self.conv1 = nn.Sequential(
            nn.Conv2d(320, 1280, 1),
            nn.BatchNorm2d(1280),
            nn.ReLU6(inplace=True)
        )

        # The classifier is a 1x1 convolution applied after global average pooling.
        self.conv2 = nn.Conv2d(1280, class_num, 1)

    def forward(self, x):
        """Run the stem, the seven stages and the classifier head."""
        x = self.pre(x)
        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        x = self.stage5(x)
        x = self.stage6(x)
        x = self.stage7(x)
        x = self.conv1(x)
        x = F.adaptive_avg_pool2d(x, 1)  # pool every channel down to 1x1
        x = self.conv2(x)
        x = x.view(x.size(0), -1)  # (batch, class_num, 1, 1) -> (batch, class_num)

        return x

    def _make_stage(self, repeat, in_channels, out_channels, stride, t):
        """Stack `repeat` bottleneck blocks into one stage.

        The first block changes the channel count and applies `stride`; the remaining
        ``repeat - 1`` blocks keep ``out_channels`` and use stride 1.
        """
        layers = [LinearBottleNeck(in_channels, out_channels, stride, t)]

        # range() yields nothing when repeat < 1, so construction always terminates and a
        # stage always holds at least the first block.
        layers.extend(
            LinearBottleNeck(out_channels, out_channels, 1, t)
            for _ in range(repeat - 1)
        )

        return nn.Sequential(*layers)

def mobilenetv2():
    """Build a `MobileNetV2` using its default constructor arguments."""
    model = MobileNetV2()
    return model

In [49]:
def test_inference(model, test):
    """ Returns the test accuracy and loss.
    """
    tensor_x = torch.Tensor(test[0]).to(device)
    tensor_y = torch.Tensor(test[1]).to(device)
    test_dataset = TensorDataset(tensor_x, tensor_y)

    model.eval()
    loss, total, correct = 0.0, 0.0, 0.0

    criterion = nn.CrossEntropyLoss()
    testloader = DataLoader(test_dataset, batch_size=128,
                            shuffle=True)

    for batch_idx, (images, labels) in enumerate(testloader):
        with torch.no_grad():  # 在测试过程中不需要计算梯度，节省内存和加速计算
        # Inference
            outputs = model(images)
            batch_loss = criterion(outputs, labels.long())
            loss += batch_loss.item() * labels.size(0) # 计算损失值，更好反映模型输出概率分布与真实标签的差距

        # Prediction
            _, pred_labels = torch.max(outputs, 1)
            pred_labels = pred_labels.view(-1)
            correct += torch.sum(torch.eq(pred_labels, labels)).item()
            total += len(labels)
    #print(correct,"/",total)
    accuracy = correct/total
    return accuracy, loss

In [50]:
# Collapse the 100 CIFAR-100 classes into 20 classes (coarser granularity, lower task complexity).
def sparse2coarse(targets):
    """Translate CIFAR-100 fine class indices into coarse class indices.

    `targets` is used to index a 100-entry lookup table, so it may be a single index
    or a list/array of indices.

    Usage:
        trainset = torchvision.datasets.CIFAR100(path)
        trainset.targets = sparse2coarse(trainset.targets)
    """
    fine_to_coarse = np.array((
        4, 1, 14, 8, 0, 6, 7, 7, 18, 3,          # fine classes 0-9
        3, 14, 9, 18, 7, 11, 3, 9, 7, 11,        # 10-19
        6, 11, 5, 10, 7, 6, 13, 15, 3, 15,       # 20-29
        0, 11, 1, 10, 12, 14, 16, 9, 11, 5,      # 30-39
        5, 19, 8, 8, 15, 13, 14, 17, 18, 10,     # 40-49
        16, 4, 17, 4, 2, 0, 17, 4, 18, 17,       # 50-59
        10, 3, 2, 12, 12, 16, 12, 1, 9, 19,      # 60-69
        2, 10, 0, 1, 16, 12, 9, 13, 15, 13,      # 70-79
        16, 19, 2, 4, 6, 19, 5, 5, 8, 19,        # 80-89
        18, 1, 2, 15, 6, 0, 17, 8, 14, 13,       # 90-99
    ))
    return fine_to_coarse[targets]

In [51]:
# 60,000 images in total: 50,000 for training and 10,000 for testing.
def CIFAR100():
    """Load CIFAR-100 as numpy arrays, train split followed by test split.

    Returns:
        ``[images, labels]`` where `images` stacks every ToTensor-transformed sample and
        `labels` holds the fine labels mapped through `sparse2coarse`.
    """
    # Build both splits first (training, then test), downloading them if necessary.
    splits = [
        torchvision.datasets.CIFAR100(root='../data/CIFAR-100',
                                      train=is_train,
                                      transform=transforms.ToTensor(),
                                      download=True)
        for is_train in (True, False)
    ]

    total_img, total_label = [], []
    for split in splits:
        for img, fine_label in split:
            total_img.append(img.numpy())
            total_label.append(fine_label)

    return [np.array(total_img), np.array(sparse2coarse(total_label))]


In [52]:
# Simulate non-IID clients with a Dirichlet distribution. Returns a (client_num, class_num)
# probability matrix; each row is one client's distribution over the classes.
def get_prob(non_iid, client_num, class_num = 20):
    """Draw one class distribution per client from a symmetric Dirichlet.

    `non_iid` is the concentration value repeated for every class.
    """
    concentration = np.repeat(non_iid, class_num)
    return np.random.dirichlet(concentration, size=client_num)

In [53]:
def create_data(prob, size_per_client, dataset, N=20, train_ratio=0.8):
    """Split a labelled dataset into per-client train and test sets.

    Args:
        prob: array of shape (num_clients, N); ``prob[m][n] * size_per_client`` is the
            number of class-``n`` samples handed to client ``m`` (truncated to an int).
        size_per_client: nominal number of samples per client.
        dataset: pair ``(data, label)`` of numpy arrays with matching first dimension.
        N: number of classes; labels are expected to be ``0 .. N-1``.
        train_ratio: fraction of each client's per-class share used for training; the
            remainder becomes that client's test data. Defaults to 0.8.

    Returns:
        ``(clients, test)``: two lists with one ``(images, labels)`` pair of numpy
        arrays per client. No sample is given to more than one client or split.

    Raises:
        ValueError: if `train_ratio` is outside [0, 1] or a class has fewer samples
            than the clients request in total.
    """
    if not 0 <= train_ratio <= 1:
        raise ValueError("train_ratio must be between 0 and 1.")

    data, label = dataset
    total_each_class = size_per_client * np.sum(prob, 0)

    # Randomly draw, for every class, the pool of samples the clients will share.
    all_class_set = []
    for cls in range(N):
        class_mask = label == cls
        sub_data, sub_label = data[class_mask], label[class_mask]

        wanted = int(total_each_class[cls])
        if wanted > len(sub_data):
            raise ValueError(
                f"class {cls}: {wanted} samples requested but only {len(sub_data)} available"
            )

        rand_indx = np.random.choice(len(sub_data), size=wanted, replace=False).astype(int)
        all_class_set.append((sub_data[rand_indx], sub_label[rand_indx]))

    cursor = [0] * N  # per class: first position of the pool not yet handed out
    clients, test = [], []

    for m in range(prob.shape[0]):  # iterate over clients
        images, labels = [], []      # training data
        timages, tlabels = [], []    # test data

        # TODO_241216: each client's test set follows the same distribution as its training
        # set, and the final metric is the mean accuracy over all clients.
        # TODO_241216: do other FL methods evaluate this way? Should this notebook?
        for n in range(N):
            pool_data, pool_label = all_class_set[n]

            # int() truncates, so the amounts actually handed out can be slightly smaller
            # than the nominal ones.
            share = int(prob[m][n] * size_per_client)
            n_train = int(prob[m][n] * size_per_client * train_ratio)

            start = cursor[n]
            split, stop = start + n_train, start + share

            images.extend(pool_data[start:split])
            labels.extend(pool_label[start:split])
            timages.extend(pool_data[split:stop])
            tlabels.extend(pool_label[split:stop])

            # Record how far this class's pool has been consumed.
            cursor[n] = stop

        clients.append((np.array(images), np.array(labels)))
        test.append((np.array(timages), np.array(tlabels)))

    return clients, test

In [54]:

# Merge the test data of all clients (the test data is split per client earlier in the notebook).
# NOTE: this function is not called anywhere in the visible code; its intended use is unclear.
def comb_client_test_func(client_test_data):
    """Concatenate every client's test split and print the combined class distribution.

    Args:
        client_test_data: sequence of ``(images, labels)`` pairs, one per client.

    Returns:
        ``[images, labels]`` as numpy arrays covering all clients, in client order.
    """
    comb_client_test_image = []
    comb_client_test_label = []
    # Walk over exactly the clients that were passed in, so the result does not depend on
    # a separately configured client count.
    for client_split in client_test_data:
        comb_client_test_image.extend(list(client_split[0]))
        comb_client_test_label.extend(list(client_split[1]))

    # Turn the collected images and labels into numpy arrays.
    comb_client_test_image = np.array(comb_client_test_image)
    comb_client_test_label = np.array(comb_client_test_label)

    label_count = Counter(comb_client_test_label)
    print("测试集类别分布：")
    for label, count in sorted(label_count.items()):
        print(f"类别 {label}: {count} 个样本")

    return [comb_client_test_image, comb_client_test_label]

In [55]:
# Draw a subset from the dataset class by class, shrinking every class by the given
# `percentage`, and shuffle the result.
def select_subset(whole_set, percentage):
    """Return a shuffled, per-class random sample of `whole_set`.

    Args:
        whole_set: pair ``(samples, labels)`` of numpy arrays of equal length.
        percentage: fraction in [0, 1] of each class to keep (truncated to an int count).

    Returns:
        ``[samples, labels]`` for the selected items, in random order.

    Raises:
        ValueError: if the two arrays differ in length or `percentage` is outside [0, 1].
    """
    samples, labels = whole_set[0], whole_set[1]

    if len(samples) != len(labels):
        raise ValueError("Both arrays should have the same length.")
    if not 0 <= percentage <= 1:
        raise ValueError("Percentage must be between 0 and 1.")

    kept_samples, kept_labels = [], []
    for cls in np.unique(labels):
        members = np.nonzero(labels == cls)[0]
        keep = np.random.choice(members, int(len(members) * percentage), replace=False)
        kept_samples += list(samples[keep])
        kept_labels += list(labels[keep])

    kept_samples = np.array(kept_samples)
    kept_labels = np.array(kept_labels)

    # One shared permutation keeps samples and labels aligned while mixing the classes.
    order = np.random.permutation(len(kept_samples))
    return [kept_samples[order], kept_labels[order]]

In [56]:

# Prepare the datasets (this part was added by the notebook author).
cifar = CIFAR100()
prob = get_prob(non_iid, client_num, class_num=20)
client_data, client_test_data = create_data(prob, size_per_client, cifar, N=20)

# Pool the training data of all clients into one combined set.
all_images = []
all_labels = []
for client_images, client_labels in client_data:
    all_images.extend(client_images)
    all_labels.extend(client_labels)
comb_client_data = [np.array(all_images), np.array(all_labels)]


def _format_class_counts(labels, num_classes=20):
    """Return "<total> <count of class 0> ... <count of class num_classes-1>" for `labels`."""
    labels = np.array(labels)
    per_class = [0] * num_classes  # classes that never occur stay at 0
    for cls, cnt in zip(*np.unique(labels, return_counts=True)):
        per_class[cls] = cnt
    return " ".join([str(len(labels))] + [str(c) for c in per_class])


# Combined training data. Line format: total, then the count of class 0 ... class 19.
# (The label text is kept exactly as before so logs stay comparable across runs.)
print("Traning Client Total: {}".format(_format_class_counts(comb_client_data[1])))

# Training data of every client.
for i, (_, client_labels) in enumerate(client_data):
    print("Client {}: {}".format(i, _format_class_counts(client_labels)))

# Test data of every client.
for i, (_, client_labels) in enumerate(client_test_data):
    print("Client {} Test: {}".format(i, _format_class_counts(client_labels)))


Files already downloaded and verified
Files already downloaded and verified
Traning Client Total: 30217 1476 1564 1294 1598 1206 1792 1640 1641 1573 1450 1642 1491 1733 1503 1454 1411 1391 1265 1502 1591
Client 0: 151 9 0 1 3 6 0 6 22 7 28 0 10 0 10 11 20 8 4 6 0
Client 1: 152 14 8 1 3 1 8 5 16 5 0 0 24 30 26 3 0 0 1 2 5
Client 2: 149 1 0 2 1 12 0 4 6 3 34 27 0 15 0 0 15 1 21 4 3
Client 3: 152 1 0 0 0 3 10 18 0 9 0 5 2 35 24 25 2 1 11 6 0
Client 4: 151 17 0 2 9 45 41 5 0 10 2 0 4 3 3 0 0 5 0 2 3
Client 5: 153 0 0 2 21 0 12 0 6 0 22 1 0 9 3 1 9 0 32 33 2
Client 6: 152 22 5 1 0 3 3 44 24 0 8 2 1 0 20 2 1 0 10 0 6
Client 7: 149 3 9 4 10 3 0 4 0 1 3 10 7 4 57 11 3 0 16 2 2
Client 8: 153 0 0 20 0 0 1 26 11 0 0 11 0 0 25 0 4 22 33 0 0
Client 9: 152 16 0 3 2 11 0 8 6 0 5 7 12 0 5 11 44 22 0 0 0
Client 10: 153 8 11 0 0 25 15 0 7 1 0 2 3 54 4 1 1 0 2 16 3
Client 11: 152 0 4 0 9 0 4 20 6 16 1 26 5 5 5 7 0 37 7 0 0
Client 12: 152 3 24 15 0 4 1 0 26 0 5 0 9 5 5 0 14 14 14 10 3
Client 13: 151 33 30

In [57]:
# Local training: returns the updated model weights, the mean training loss and the
# gradient information captured on the first iteration.
def update_weights(model_weight, dataset, learning_rate, local_epoch):
    """Train a fresh MobileNetV2 from `model_weight` on `dataset` with SGD.

    Args:
        model_weight: state dict loaded into the model before training.
        dataset: pair ``(inputs, labels)`` of array-likes.
        learning_rate: SGD learning rate; momentum and weight decay are read from the
            module-level ``momentum`` and ``weight_decay`` settings.
        local_epoch: number of passes over `dataset`.

    Returns:
        ``(state_dict, mean_loss, first_iter_gradient)``.
        ``mean_loss`` is the average over epochs of the average per-batch loss, where
        each per-batch loss is already the mean over the samples of that batch.
        ``first_iter_gradient`` maps every parameter name to a copy of its gradient from
        the first batch of the first epoch, plus ``<name>.running_mean`` and
        ``<name>.running_var`` copies for every BatchNorm2d module taken right after
        that first optimizer step.
    """
    model = mobilenetv2().to(device)
    model.load_state_dict(model_weight)

    model.train()
    epoch_loss = []
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum, weight_decay=weight_decay)
    criterion = nn.CrossEntropyLoss()

    Tensor_set = TensorDataset(torch.Tensor(dataset[0]).to(device), torch.Tensor(dataset[1]).to(device))
    data_loader = DataLoader(Tensor_set, batch_size=128, shuffle=True)

    first_iter_gradient = None  # filled in on the very first batch

    for epoch in range(local_epoch):
        batch_loss = []
        for batch_idx, (images, labels) in enumerate(data_loader):
            model.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels.long())
            loss.backward()
            optimizer.step()
            # CrossEntropyLoss averages over the batch by default, so the value is
            # recorded as-is; dividing by the batch size again would mis-scale it.
            batch_loss.append(loss.item())

            # Keep the gradient of the first iteration.
            if epoch == 0 and batch_idx == 0:
                first_iter_gradient = {
                    name: param.grad.clone() for name, param in model.named_parameters()
                }
                # Also keep the running mean and running variance of the BatchNorm layers.
                for name, module in model.named_modules():
                    if isinstance(module, nn.BatchNorm2d):
                        first_iter_gradient[name + '.running_mean'] = module.running_mean.clone()
                        first_iter_gradient[name + '.running_var'] = module.running_var.clone()

        epoch_loss.append(sum(batch_loss)/len(batch_loss))

    return model.state_dict(), sum(epoch_loss) / len(epoch_loss), first_iter_gradient

In [58]:
# Difference between two sets of model weights, scaled by the learning rate `lr`.
def weight_differences(n_w, p_w, lr):
    """Return ``(p_w[key] - n_w[key]) * lr`` for every key of `n_w`.

    Keys containing ``'num_batches_tracked'`` are not differenced; for those the result
    keeps a deep copy of the `n_w` value. The result starts as a deep copy of `n_w`, so
    neither input is modified.
    """
    scaled = copy.deepcopy(n_w)
    differenced_keys = [key for key in scaled if 'num_batches_tracked' not in key]
    for key in differenced_keys:
        scaled[key] = (p_w[key] - n_w[key]) * lr
    return scaled

In [59]:
# Local training as well, but with the weight-correction mechanism of the method applied.
def update_weights_correction(model_weight, dataset, learning_rate, local_epoch, c_i, c_s):
    """Train locally with SGD and shift the weights by a correction after every epoch.

    After each epoch the model weights are replaced by
    ``weights - (c_s - c_i) * learning_rate`` for every key present in `c_i`.

    Args:
        model_weight: state dict loaded into the model before training.
        dataset: pair ``(inputs, labels)`` of array-likes.
        learning_rate: SGD learning rate, also used to scale the correction; momentum
            and weight decay are read from the module-level settings.
        local_epoch: number of passes over `dataset`.
        c_i: dict of per-key values for this client (subtracted in the correction).
        c_s: dict with the same keys holding the server-side values.

    Returns:
        ``(state_dict, mean_loss)`` where ``mean_loss`` is the average over epochs of the
        average per-batch loss (each per-batch loss is the mean over its samples).
    """
    model = mobilenetv2().to(device)
    model.load_state_dict(model_weight)

    model.train()
    epoch_loss = []
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum, weight_decay=weight_decay)
    criterion = nn.CrossEntropyLoss()

    Tensor_set = TensorDataset(torch.Tensor(dataset[0]).to(device), torch.Tensor(dataset[1]).to(device))
    data_loader = DataLoader(Tensor_set, batch_size=200, shuffle=True)

    # (c_s - c_i) * learning_rate is the same in every epoch, so compute it once.
    corrected_gradient = weight_differences(c_i, c_s, learning_rate)

    for epoch in range(local_epoch):
        batch_loss = []
        for batch_idx, (images, labels) in enumerate(data_loader):
            model.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels.long())
            loss.backward()
            optimizer.step()
            # CrossEntropyLoss averages over the batch by default, so the value is
            # recorded as-is; dividing by the batch size again would mis-scale it.
            batch_loss.append(loss.item())
        epoch_loss.append(sum(batch_loss)/len(batch_loss))

        # weights <- weights - correction (the scale factor here is 1).
        # NOTE(review): the corrected dict only contains the keys of `c_i`; loading it
        # relies on those keys being sufficient for load_state_dict — confirm whenever
        # the set of captured keys changes.
        original_model_weight = model.state_dict()
        corrected_model_weight = weight_differences(corrected_gradient, original_model_weight, 1)
        model.load_state_dict(corrected_model_weight)

    return model.state_dict(),  sum(epoch_loss) / len(epoch_loss)

In [60]:
def average_weights(w):
    """
    Return the element-wise mean of a list of weight dicts.

    Entries whose key contains 'num_batches_tracked' are not averaged; they are taken
    from the first dict. The inputs are left unmodified.
    """
    averaged = copy.deepcopy(w[0])
    count = len(w)
    for key in averaged:
        if 'num_batches_tracked' in key:
            continue
        running_total = averaged[key]
        for other in w[1:]:
            running_total += other[key]
        averaged[key] = torch.div(running_total, count)
    return averaged

In [61]:
# baseline: server-only
def server_only(initial_w, global_round, gamma, E):
    """Baseline that trains on the server's data only.

    Args:
        initial_w: starting state dict (deep-copied, not modified).
        global_round: number of training rounds.
        gamma: server learning rate.
        E: number of server epochs per round.

    Returns:
        ``(test_acc, train_loss)``: one entry per round. The accuracy is the plain mean
        of the accuracies measured on each client's test split.
    """
    evaluator = mobilenetv2().to(device)
    weights = copy.deepcopy(initial_w)
    acc_history, loss_history = [], []

    # The server subset is drawn once, up front, and reused in every round
    # (a later modification; it is not re-sampled per round).
    server_data = select_subset(comb_client_data, server_percentage)

    for rnd in tqdm(range(global_round)):
        # Learning-rate decay (gamma * 0.99 per round while above 0.001) is currently disabled.

        if rnd == 0:
            # Log the server data once. Format: round, total, count of class 0 ... class 19.
            _, server_labels = server_data
            server_labels = np.array(server_labels)
            per_class = [0] * 20
            for cls, cnt in zip(*np.unique(server_labels, return_counts=True)):
                per_class[cls] = cnt
            summary = " ".join([str(len(server_labels))] + [str(c) for c in per_class])
            print("Server {}: {}".format(rnd, summary))

        # Server-side training continues from the previous round's weights.
        weights, round_loss, _ = update_weights(weights, server_data, gamma, E)
        evaluator.load_state_dict(weights)
        loss_history.append(round_loss)

        # Test accuracy: unweighted mean over the clients' test splits.
        acc_total = 0
        for client_test in client_test_data:
            acc_total = acc_total + test_inference(evaluator, client_test)[0]
        acc_history.append(acc_total / len(client_test_data))

    return acc_history, loss_history

In [62]:
def fedavg(initial_w, global_round, eta, K, M):
    """Federated averaging over randomly sampled clients.

    Args:
        initial_w: starting state dict (deep-copied, not modified).
        global_round: number of communication rounds.
        eta: client learning rate.
        K: number of local epochs per client and round.
        M: number of clients sampled (without replacement) in each round.

    Returns:
        ``(test_acc, train_loss)``: one entry per round. The accuracy is the plain mean
        over the clients' test splits; the loss is the mean of the sampled clients'
        training losses.
    """
    evaluator = mobilenetv2().to(device)
    global_w = copy.deepcopy(initial_w)
    acc_history, loss_history = [], []

    for rnd in tqdm(range(global_round)):
        # Learning-rate decay (eta * 0.99 per round while above 0.001) is currently disabled.

        # Client-side local training on M sampled clients.
        sampled_client = random.sample(range(client_num), M)
        local_weights, local_loss = [], []
        for cid in sampled_client:
            client_w, client_loss, _ = update_weights(global_w, client_data[cid], eta, K)
            local_weights.append(client_w)
            local_loss.append(client_loss)

        global_w = average_weights(local_weights)

        # Test accuracy: unweighted mean over the clients' test splits.
        evaluator.load_state_dict(global_w)
        loss_history.append(sum(local_loss) / len(local_loss))
        acc_total = 0
        for client_test in client_test_data:
            acc_total = acc_total + test_inference(evaluator, client_test)[0]
        acc_history.append(acc_total / len(client_test_data))

    return acc_history, loss_history


In [63]:
def CLG_SGD(initial_w, global_round, eta, gamma, K, E, M):
    """Client training with averaging, followed by server-side training each round.

    Args:
        initial_w: starting state dict (deep-copied, not modified).
        global_round: number of communication rounds.
        eta: client learning rate.
        gamma: server learning rate.
        K: number of local epochs per client and round.
        E: number of server epochs per round.
        M: number of clients sampled (without replacement) in each round.

    Returns:
        ``(test_acc, train_loss)``: one entry per round. The accuracy is the plain mean
        over the clients' test splits; the loss is the mean over the sampled clients'
        losses and the server's loss together.
    """
    evaluator = mobilenetv2().to(device)
    global_w = copy.deepcopy(initial_w)
    acc_history, loss_history = [], []

    # The server subset is drawn once, up front, and reused in every round
    # (a later modification).
    # TODO_241216: the earlier variant re-drew the subset every round (with the same class
    # proportions as the combined data); decide whether that is acceptable in this scenario.
    server_data = select_subset(comb_client_data, server_percentage)

    for rnd in tqdm(range(global_round)):
        # Learning-rate decay for eta and gamma (x0.99 per round while above 0.001) is
        # currently disabled.

        # Client-side local training: pick M of the client_num clients.
        sampled_client = random.sample(range(client_num), M)
        local_weights, local_loss = [], []
        for cid in sampled_client:
            client_w, client_loss, _ = update_weights(global_w, client_data[cid], eta, K)
            local_weights.append(client_w)
            local_loss.append(client_loss)
        global_w = average_weights(local_weights)

        if rnd == 0:
            # Log the server data once. Format: round, total, count of class 0 ... class 19.
            _, server_labels = server_data
            server_labels = np.array(server_labels)
            per_class = [0] * 20
            for cls, cnt in zip(*np.unique(server_labels, return_counts=True)):
                per_class[cls] = cnt
            summary = " ".join([str(len(server_labels))] + [str(c) for c in per_class])
            print("Server {}: {}".format(rnd, summary))

        # Server-side training starts from the averaged client weights.
        global_w, server_loss, _ = update_weights(global_w, server_data, gamma, E)
        local_loss.append(server_loss)

        evaluator.load_state_dict(global_w)
        # Mean loss over all sampled clients and the server together.
        loss_history.append(sum(local_loss) / len(local_loss))

        # Test accuracy: unweighted mean over the clients' test splits.
        acc_total = 0
        for client_test in client_test_data:
            acc_total = acc_total + test_inference(evaluator, client_test)[0]
        acc_history.append(acc_total / len(client_test_data))

    return acc_history, loss_history

In [64]:
def Fed_C(initial_w, global_round, eta, gamma, K, E, M):
    """Like CLG_SGD, but every client trains with the server/client correction term.

    Each round first captures a first-iteration gradient snapshot on the server data and
    on every sampled client, then runs `update_weights_correction` on the clients with
    those snapshots, averages the client weights, and finally trains on the server data.

    Args:
        initial_w: starting state dict (deep-copied, not modified).
        global_round: number of communication rounds.
        eta: client learning rate.
        gamma: server learning rate.
        K: number of corrected local epochs per client and round.
        E: number of server epochs per round.
        M: number of clients sampled (without replacement) in each round.

    Returns:
        ``(test_acc, train_loss)``: one entry per round. The accuracy is the plain mean
        over the clients' test splits; the loss is the mean over the sampled clients'
        losses and the server's loss together.
    """
    evaluator = mobilenetv2().to(device)
    global_w = copy.deepcopy(initial_w)
    acc_history, loss_history = [], []

    # The server subset is drawn once, up front, and reused in every round
    # (a later modification; it is not re-sampled per round).
    server_data = select_subset(comb_client_data, server_percentage)

    for rnd in tqdm(range(global_round)):
        # Learning-rate decay for eta and gamma (x0.99 per round while above 0.001) is
        # currently disabled.

        if rnd == 0:
            # Log the server data once. Format: round, total, count of class 0 ... class 19.
            _, server_labels = server_data
            server_labels = np.array(server_labels)
            per_class = [0] * 20
            for cls, cnt in zip(*np.unique(server_labels, return_counts=True)):
                per_class[cls] = cnt
            summary = " ".join([str(len(server_labels))] + [str(c) for c in per_class])
            print("Server {}: {}".format(rnd, summary))

        # Server gradient snapshot (one epoch; only the snapshot is used).
        _, _, g_s = update_weights(global_w, server_data, gamma, 1)

        # Client gradient snapshots for the sampled clients.
        sampled_client = random.sample(range(client_num), M)
        g_i_list = []
        for cid in sampled_client:
            _, _, g_i = update_weights(global_w, client_data[cid], eta, 1)
            g_i_list.append(g_i)

        # Client-side local training with the correction applied.
        local_weights, local_loss = [], []
        for cid, g_i in zip(sampled_client, g_i_list):
            client_w, client_loss = update_weights_correction(global_w, client_data[cid], eta, K, g_i, g_s)
            local_weights.append(client_w)
            local_loss.append(client_loss)
        global_w = average_weights(local_weights)

        # Server-side training starts from the averaged client weights.
        global_w, server_loss, _ = update_weights(global_w, server_data, gamma, E)
        local_loss.append(server_loss)

        evaluator.load_state_dict(global_w)
        loss_history.append(sum(local_loss) / len(local_loss))

        # Test accuracy: unweighted mean over the clients' test splits.
        acc_total = 0
        for client_test in client_test_data:
            acc_total = acc_total + test_inference(evaluator, client_test)[0]
        acc_history.append(acc_total / len(client_test_data))

    return acc_history, loss_history


In [65]:
def Fed_S(initial_w, global_round, eta, gamma, K, E, M):
    """Like CLG_SGD, but the averaged client weights are corrected on the server.

    Each round captures a first-iteration gradient snapshot on the server data and on
    every sampled client, trains the clients normally, averages their weights, and then
    subtracts ``(g_s - mean(g_i)) * K * eta`` from the averaged weights before the
    server-side training.

    Args:
        initial_w: starting state dict (deep-copied, not modified).
        global_round: number of communication rounds.
        eta: client learning rate.
        gamma: server learning rate.
        K: number of local epochs per client and round.
        E: number of server epochs per round.
        M: number of clients sampled (without replacement) in each round.

    Returns:
        ``(test_acc, train_loss)``: one entry per round. The accuracy is the plain mean
        over the clients' test splits; the loss is the mean over the sampled clients'
        losses and the server's loss together.
    """
    evaluator = mobilenetv2().to(device)
    global_w = copy.deepcopy(initial_w)
    acc_history, loss_history = [], []

    # The server subset is drawn once, up front, and reused in every round
    # (a later modification; it is not re-sampled per round).
    server_data = select_subset(comb_client_data, server_percentage)

    for rnd in tqdm(range(global_round)):
        # Learning-rate decay for eta and gamma (x0.99 per round while above 0.001) is
        # currently disabled.

        # Server gradient snapshot (one epoch; only the snapshot is used).
        _, _, g_s = update_weights(global_w, server_data, gamma, 1)

        # Client gradient snapshots for the sampled clients.
        sampled_client = random.sample(range(client_num), M)
        g_i_list = []
        for cid in sampled_client:
            _, _, g_i = update_weights(global_w, client_data[cid], eta, 1)
            g_i_list.append(g_i)

        # Client-side local training (uncorrected).
        local_weights, local_loss = [], []
        for cid in sampled_client:
            client_w, client_loss, _ = update_weights(global_w, client_data[cid], eta, K)
            local_weights.append(client_w)
            local_loss.append(client_loss)
        global_w = average_weights(local_weights)

        # Server aggregation correction: global_w <- global_w - (g_s - mean(g_i)) * K * eta.
        g_i_average = average_weights(g_i_list)
        correction_g = weight_differences(g_i_average, g_s, K*eta)
        global_w = weight_differences(correction_g, copy.deepcopy(global_w), 1)

        # Server-side training starts from the corrected weights.
        global_w, server_loss, _ = update_weights(global_w, server_data, gamma, E)
        local_loss.append(server_loss)

        evaluator.load_state_dict(global_w)
        loss_history.append(sum(local_loss) / len(local_loss))

        # Test accuracy: unweighted mean over the clients' test splits.
        acc_total = 0
        for client_test in client_test_data:
            acc_total = acc_total + test_inference(evaluator, client_test)[0]
        acc_history.append(acc_total / len(client_test_data))

    return acc_history, loss_history

In [None]:
# Initialise the model and parameters, then run every method in turn
# (this part was added by the notebook author).

# One entry per method: (result key, completion banner, call taking the initial weights).
experiments = [
    ("server_only", "Server only 训练完成！",
     lambda w: server_only(w, global_round, gamma, E)),
    ("fedavg", "fedavg训练完成！",
     lambda w: fedavg(w, global_round, eta, K, M)),
    ("CLG_SGD", "CLG_SGD 训练完成！",
     lambda w: CLG_SGD(w, global_round, eta, gamma, K, E, M)),
    ("Fed_C", "Fed_C 训练完成！",
     lambda w: Fed_C(w, global_round, eta, gamma, K, E, M)),
    ("Fed_S", "Fed_S 训练完成！",
     lambda w: Fed_S(w, global_round, eta, gamma, K, E, M)),
]

# Keep the curves of every method so that a later run does not discard earlier results.
results = {}

for method_name, banner, run_method in experiments:
    # Every method starts from its own freshly initialised model.
    init_model = mobilenetv2().to(device)
    initial_w = init_model.state_dict()

    test_acc, train_loss = run_method(initial_w)
    results[method_name] = {"test_acc": test_acc, "train_loss": train_loss}

    # Print the results collected during training.
    print(banner)
    print("各轮平均测试精度:", test_acc)
    print("各轮平均训练损失:", train_loss)
    print("最终测试精度:", test_acc[-1] if len(test_acc) > 0 else "无数据")

  0%|          | 0/100 [00:00<?, ?it/s]

Server 0: 1502 73 78 64 79 60 89 82 82 78 72 82 74 86 75 72 70 69 63 75 79


100%|██████████| 100/100 [03:10<00:00,  1.90s/it]


Server only 训练完成！
各轮平均测试精度: [0.13265224464017994, 0.12877315551324778, 0.16186401313318644, 0.14996105819063088, 0.15667813094547572, 0.16812398226347977, 0.16643065126473705, 0.1618994307804275, 0.18138245378515763, 0.18736571458126736, 0.18390136265324702, 0.18862464757591935, 0.19024919609866392, 0.19019814629532278, 0.198979050855133, 0.19767950050785818, 0.19749502367981253, 0.19711869094097348, 0.19285369732745608, 0.1942067198783509, 0.19595521973965568, 0.1961333114088771, 0.19618869899358454, 0.1947097639483561, 0.19442062179049452, 0.1933733377366315, 0.1946381535845219, 0.194656157948046, 0.19554356393959194, 0.19611806154586106, 0.19438931919794009, 0.1968115156764812, 0.19548866528126235, 0.1964613978063279, 0.1952186686360605, 0.19462743974735713, 0.1955719213803818, 0.19560348990021414, 0.19738010852634183, 0.19540921549603293, 0.1948584423198007, 0.19705961301076905, 0.19525082552139456, 0.19496535397118273, 0.19568772928454015, 0.19581980459037354, 0.19448614613987633,

100%|██████████| 100/100 [04:33<00:00,  2.74s/it]


fedavg训练完成！
各轮平均测试精度: [0.047130967249313656, 0.04916137708042207, 0.045509596633712214, 0.05333087829875229, 0.05806897548856905, 0.05686584208727233, 0.04765692527891381, 0.07111222570406973, 0.07536779736395742, 0.06818521928029117, 0.09228074700559173, 0.12317444402073613, 0.09660546347916105, 0.11502055877127897, 0.12164176987018205, 0.11473259354548342, 0.09604923782721039, 0.11983143810634697, 0.1054496718443265, 0.12328674025446487, 0.10063816398382171, 0.09877164377380057, 0.12997020597064626, 0.11719105438006379, 0.1264379522598393, 0.15064573143590665, 0.1205013102253265, 0.11348774011256092, 0.13322953871319015, 0.15721927643170205, 0.138298460187745, 0.13997756985514834, 0.18731017633150293, 0.13538849944218326, 0.16483779029061807, 0.14596393179624872, 0.14936200895875631, 0.15104842117834683, 0.14138767862642385, 0.1338038728527913, 0.13589744959731442, 0.14754037595336061, 0.1655289758241887, 0.15577013232853668, 0.15020080146932918, 0.17638020179856378, 0.17148523636227

  0%|          | 0/100 [00:00<?, ?it/s]

Server 0: 1502 73 78 64 79 60 89 82 82 78 72 82 74 86 75 72 70 69 63 75 79


100%|██████████| 100/100 [06:35<00:00,  3.96s/it]


CLG_SGD 训练完成！
各轮平均测试精度: [0.13981846097426223, 0.15754623076811994, 0.1516340236548575, 0.16429285560052492, 0.1628520533928033, 0.17630418198840367, 0.1869413601996034, 0.20262701882499168, 0.2023541021694848, 0.21306829610460887, 0.2039500491565229, 0.21518409745734945, 0.22468907026815932, 0.224741884435525, 0.228044405138783, 0.22214224743147476, 0.2306738323735629, 0.22722414160107074, 0.23702871563581474, 0.2386639582779386, 0.24693078753816264, 0.23582081530811633, 0.2425112985233145, 0.24253309597707612, 0.24420337230588907, 0.24327574719556083, 0.24810123902318357, 0.24673326884696936, 0.2546406313972328, 0.2528723651049422, 0.2524281376667, 0.25523343027343004, 0.2554248426704617, 0.2523562853817018, 0.25456639406936843, 0.2553654332116358, 0.2483259647173579, 0.25264543902330483, 0.2559782201506013, 0.26140885992284246, 0.25659757800421146, 0.25938935626978954, 0.2556238480985846, 0.256491698538434, 0.2630907478672018, 0.26963120790185524, 0.2572252358452577, 0.26141654880005

  0%|          | 0/100 [00:00<?, ?it/s]

Server 0: 1502 73 78 64 79 60 89 82 82 78 72 82 74 86 75 72 70 69 63 75 79


100%|██████████| 100/100 [09:14<00:00,  5.54s/it]


Fed_C 训练完成！
各轮平均测试精度: [0.13405586115609766, 0.10101851576954943, 0.12755256904694903, 0.15866613121678858, 0.13678644017460168, 0.15443742569518742, 0.15799133074851782, 0.17714098987351054, 0.19178357401409085, 0.19670741409938203, 0.20084033714551372, 0.20862034300480636, 0.2077575916222286, 0.20596564792559938, 0.21382263177142177, 0.2157057893024545, 0.21974812006065164, 0.21805033136964724, 0.2165090573134278, 0.21976208592962704, 0.2200630578975244, 0.22585673447906984, 0.22460853928009566, 0.2233614409935908, 0.22066678494928776, 0.22225284014844632, 0.22548221332426632, 0.22659773236205458, 0.22791449757351542, 0.22716918504897277, 0.22837006430355833, 0.23297387807299125, 0.2293182452034438, 0.23106647767281183, 0.23441868464300686, 0.23245189738844016, 0.23260540874198063, 0.23186123656112392, 0.23425126920938483, 0.23698230968577186, 0.23891330523901708, 0.23601200781360185, 0.24051762896084386, 0.24109553815063997, 0.24392389197670267, 0.24579622743884838, 0.245565879973174

 83%|████████▎ | 83/100 [06:53<01:24,  4.97s/it]