In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from torch.utils.data import Subset
from collections import Counter
import numpy as np
import random

from copy import deepcopy


# Seed every random number generator this script draws from.
def set_seed(seed):
    """Seed PyTorch, NumPy's global RNG and Python's `random` module with `seed`."""
    for seeder in (torch.manual_seed, np.random.seed, random.seed):
        seeder(seed)


# Use a fixed seed (42) so runs are repeatable.
set_seed(42)

# A small fully connected classifier for 28x28 single-channel images.
class SimpleNN(nn.Module):
    """Two-layer MLP: 784 inputs -> 200 hidden units (ReLU) -> 10 output logits."""

    def __init__(self):
        super().__init__()
        # Creation order (fc1, then fc2) fixes both RNG usage and state_dict key order.
        self.fc1 = nn.Linear(28 * 28, 200)
        self.fc2 = nn.Linear(200, 10)

    def forward(self, x):
        """Flatten `x` into rows of 784 features and return unnormalized class scores."""
        flat = x.view(-1, 28 * 28)
        hidden = self.fc1(flat).relu()
        return self.fc2(hidden)

# Factory used wherever a fresh model instance is needed.
def create_model():
    """Return a newly constructed, randomly initialized SimpleNN."""
    return SimpleNN()

# Split the training set across clients according to label.
def distribute_data_to_clients(train_dataset, num_clients=10):
    """Partition `train_dataset` into `num_clients` label-skewed subsets.

    Label k (for k in 0..num_clients-1) is split with `torch.chunk` into
    `num_clients - k` contiguous, near-equal pieces, handed to clients
    k, k+1, ..., num_clients-1 in order. Client i therefore only holds labels
    0..i. Labels >= num_clients are not assigned to any client.

    Prints the per-label counts of the whole dataset and, for every client, a
    histogram list indexed by label (covering labels 0..max label).

    Args:
        train_dataset: dataset exposing `targets` (a tensor or a sequence of
            integer class labels).
        num_clients: number of clients; also the number of labels distributed.

    Returns:
        list of `Subset`, one per client.
    """
    targets = torch.as_tensor(train_dataset.targets)

    original_label_counts = Counter(targets.tolist())
    print("Original dataset label counts:")
    for label, count in original_label_counts.items():
        print(f"Label {label}: {count} samples")

    clients = [[] for _ in range(num_clients)]
    for label in range(num_clients):
        # Indices (ascending, in dataset order) of every sample with this label.
        indices = torch.nonzero(targets == label, as_tuple=True)[0]
        num_splits = num_clients - label
        # torch.chunk can return fewer than `num_splits` pieces when the label
        # has few samples; enumerating avoids indexing past the last piece.
        for offset, piece in enumerate(torch.chunk(indices, num_splits)):
            clients[label + offset].extend(piece.tolist())

    client_subsets = [Subset(train_dataset, client_indices) for client_indices in clients]

    # Per-client label histogram over every label present in the dataset.
    num_labels = int(targets.max()) + 1 if targets.numel() else 0
    for i, subset in enumerate(client_subsets):
        client_targets = targets[torch.as_tensor(subset.indices, dtype=torch.long)]
        client_label_counts = Counter(client_targets.tolist())
        print(f"Client {i} label counts:")
        print([client_label_counts.get(k, 0) for k in range(num_labels)])

    return client_subsets

class EWC:
    """Elastic Weight Consolidation (EWC) penalty anchored at a model's current weights.

    On construction this estimates a diagonal importance tensor F_n for every
    trainable parameter from `dataloader` and snapshots the current parameter
    values theta*_n. `penalty(model)` then returns
    sum_n sum(F_n * (theta_n - theta*_n) ** 2).

    Note: F_n is the squared gradient of the *batch-mean* cross-entropy loss,
    summed over batches and divided by the dataset size. This is a cheap
    approximation, not the per-sample empirical Fisher.

    Args:
        model: the nn.Module to anchor. Its train/eval mode is restored and its
            gradients are cleared after the estimate.
        dataloader: iterable of (inputs, targets) batches with a `.dataset`
            attribute supporting len(); targets must be class indices accepted
            by nn.CrossEntropyLoss.
    """

    def __init__(self, model, dataloader):
        self.model = model
        self.dataloader = dataloader

        # Only trainable parameters are anchored; frozen ones are ignored
        # consistently in both the estimate and the penalty.
        self.params = {n: p for n, p in self.model.named_parameters() if p.requires_grad}
        self._means = {}
        self._precision_matrices = self._diag_fisher()

        # Snapshot theta* (the anchor point) as detached copies of current values.
        for n, p in self.params.items():
            self._means[n] = p.detach().clone()

    def _diag_fisher(self):
        """Return {name: diagonal importance tensor} for every trainable parameter."""
        precision_matrices = {n: torch.zeros_like(p.detach()) for n, p in self.params.items()}
        num_samples = len(self.dataloader.dataset)
        criterion = nn.CrossEntropyLoss()

        was_training = self.model.training
        self.model.eval()
        try:
            for inputs, targets in self.dataloader:
                self.model.zero_grad()
                loss = criterion(self.model(inputs), targets)
                loss.backward()
                for n, p in self.params.items():
                    # A parameter not reached by this loss has no gradient.
                    if p.grad is not None:
                        precision_matrices[n] += p.grad.detach() ** 2 / num_samples
        finally:
            # Leave the model as found: original mode, no stale gradients.
            self.model.train(was_training)
            self.model.zero_grad()
        return precision_matrices

    def penalty(self, model):
        """Return sum over anchored params of F * (theta - theta*) ** 2 for `model`.

        Parameters of `model` whose names were not anchored (e.g. frozen at
        construction time) contribute nothing.
        """
        loss = 0
        for n, p in model.named_parameters():
            if n not in self._means:
                continue
            loss += (self._precision_matrices[n] * (p - self._means[n]) ** 2).sum()
        return loss

# Training loop with EWC regularization
def train_with_ewc(model, data_loader, criterion, optimizer, ewc, epochs, importance):
    """Train `model` in place on `criterion` plus `importance * ewc.penalty(model)`.

    Args:
        model: module to train; switched to train mode.
        data_loader: iterable of (inputs, targets) batches.
        criterion: task loss, called as criterion(outputs, targets).
        optimizer: optimizer over `model`'s parameters.
        ewc: object providing `penalty(model)` (e.g. an EWC instance).
        epochs: number of passes over `data_loader`.
        importance: weight of the EWC penalty term.

    Returns:
        list of float: mean total loss (task + weighted penalty) per epoch,
        NaN for an epoch that saw no batches. Existing callers may ignore it.
    """
    model.train()
    epoch_losses = []
    for _ in range(epochs):
        running_loss = 0.0
        num_batches = 0
        for inputs, targets in data_loader:
            optimizer.zero_grad()
            output = model(inputs)
            loss = criterion(output, targets) + importance * ewc.penalty(model)

            # Backward pass and update weights
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1
        epoch_losses.append(running_loss / num_batches if num_batches else float('nan'))
    return epoch_losses

# Federated weighted averaging: each client counts in proportion to its data size.
def federated_weighted_avg(weights, num_samples):
    """Average per-client parameter lists, weighting each client by its sample count.

    Args:
        weights: one list per client of parameter values (tensors or numbers),
            all in the same order.
        num_samples: sample count of each client, aligned with `weights`.

    Returns:
        list where element i is sum_j weights[j][i] * num_samples[j] / sum(num_samples).

    Raises:
        ValueError: if there are no clients, the two arguments differ in length,
            or the total sample count is zero (which would otherwise silently
            produce NaN tensors).
    """
    if not weights:
        raise ValueError("weights must contain at least one client")
    if len(weights) != len(num_samples):
        raise ValueError("weights and num_samples must have the same length")
    total_samples = sum(num_samples)
    if total_samples == 0:
        raise ValueError("total number of samples must be non-zero")

    # zip(*weights) groups the i-th parameter of every client together.
    return [
        sum(client_param * n / total_samples for client_param, n in zip(layer, num_samples))
        for layer in zip(*weights)
    ]

def federated_learning(global_model, client_subsets, criterion, hyperparams, test_loader=None):
    """Run federated averaging with an EWC penalty during each client's local training.

    In every round, each client:
    - starts from a copy of the global weights;
    - builds an EWC anchor (importance estimate on its own data at those weights);
    - trains locally with SGD;
    - reports its parameters.

    The global parameters are then replaced by the sample-weighted average of
    the client parameters. Non-parameter buffers (if any) keep the global
    model's values.

    Args:
        global_model: model updated in place each round; must share
            state_dict keys with `create_model()`.
        client_subsets: per-client datasets yielding (input, target) pairs.
        criterion: loss used for local training and evaluation.
        hyperparams: dict. Recognized keys, with defaults:
            'num_rounds' (5), 'learning_rate' (0.01), 'batch_size' (64),
            'epochs_per_client' (1), 'lambda_ewc' (0.1).
        test_loader: optional loader evaluated after every round.

    Returns:
        (loss_history, accuracy_history, final_accuracy). The histories are
        empty and final_accuracy is None when `test_loader` is None.
    """
    num_rounds = hyperparams.get('num_rounds', 5)
    learning_rate = hyperparams.get('learning_rate', 0.01)
    batch_size = hyperparams.get('batch_size', 64)
    epochs_per_client = hyperparams.get('epochs_per_client', 1)
    lambda_ewc = hyperparams.get('lambda_ewc', 0.1)

    # Parameter names in the same order as model.parameters(). Averaged weights
    # are written back under these names; zipping with state_dict().keys()
    # would misalign as soon as the model has buffers (e.g. BatchNorm stats).
    param_names = [name for name, _ in global_model.named_parameters()]

    loss_history = []
    accuracy_history = []

    for round_num in range(num_rounds):
        client_weights = []
        num_samples = []

        for client_data in client_subsets:
            client_loader = torch.utils.data.DataLoader(client_data, batch_size=batch_size, shuffle=True)
            model = create_model()
            model.load_state_dict(global_model.state_dict())

            # Anchor local training to the current global weights.
            ewc = EWC(model, client_loader)
            optimizer = optim.SGD(model.parameters(), lr=learning_rate)
            train_with_ewc(model, client_loader, criterion, optimizer, ewc,
                           epochs=epochs_per_client, importance=lambda_ewc)

            client_weights.append([param.detach().clone() for param in model.parameters()])
            num_samples.append(len(client_data))

        new_weights = federated_weighted_avg(client_weights, num_samples)
        # Start from the full global state so buffers are preserved, then
        # overwrite only the parameters with their averaged values.
        new_state_dict = global_model.state_dict()
        new_state_dict.update(zip(param_names, new_weights))
        global_model.load_state_dict(new_state_dict)

        if test_loader is not None:
            test_loss, accuracy = test(global_model, test_loader, criterion)
            loss_history.append(test_loss)
            accuracy_history.append(accuracy)
            print(f'Round {round_num + 1} Test Loss: {test_loss:.4f} and Test Accuracy: {accuracy * 100:.2f} %')

    if test_loader is not None and loss_history and accuracy_history:
        plot_loss_accuracy_history(loss_history, accuracy_history)

    final_accuracy = accuracy_history[-1] if accuracy_history else None

    return loss_history, accuracy_history, final_accuracy

# 测试函数
def test(model, data_loader, criterion):
    model.eval()
    correct = 0
    total = 0
    test_loss = 0
    with torch.no_grad():
        for data, targets in data_loader:
            outputs = model(data)
            loss = criterion(outputs, targets)
            test_loss += loss.item() * data.size(0)
            _, predicted = torch.max(outputs.data, 1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()
    accuracy = correct / total
    test_loss /= total
    return test_loss, accuracy

def plot_loss_accuracy_history(loss_history, accuracy_history):
    """Plot test loss and test accuracy against the federated learning round.

    Draws two side-by-side panels (loss, and accuracy in percent), shows the
    figure, then prints the final accuracy.

    Args:
        loss_history: test loss recorded after each round.
        accuracy_history: test accuracy (a fraction) recorded after each
            round; must be non-empty because its last entry is printed.
    """
    rounds = list(range(1, len(loss_history) + 1))
    panels = (
        (loss_history, 'Test Loss', 'Test Loss vs Federated Learning Rounds'),
        ([acc * 100 for acc in accuracy_history], 'Test Accuracy (%)',
         'Test Accuracy vs Federated Learning Rounds'),
    )

    plt.figure(figsize=(12, 6))
    for position, (values, y_label, title) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(rounds, values)
        plt.xlabel('Federated Learning Rounds')
        plt.ylabel(y_label)
        plt.title(title)
        plt.grid(True)

    plt.tight_layout()
    plt.show()

    print(f'Final Test Accuracy: {accuracy_history[-1] * 100:.2f} %')

# Images are only converted to tensors (no normalization).
transform = transforms.Compose([transforms.ToTensor()])

# Load the MNIST train and test splits (downloaded into ./data if missing).
mnist_kwargs = dict(root='./data', download=True, transform=transform)
train_dataset = datasets.MNIST(train=True, **mnist_kwargs)
test_dataset = datasets.MNIST(train=False, **mnist_kwargs)

# Split the training set across 10 label-skewed clients.
client_subsets = distribute_data_to_clients(train_dataset, num_clients=10)

# Fixed-order loader for evaluating the global model.
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)

# Hyperparameters for the federated run.
hyperparams = dict(
    num_rounds=10,
    learning_rate=0.01,
    batch_size=64,
    epochs_per_client=5,
    lambda_ewc=0.01,
)

loss_history, accuracy_history, final_accuracy = federated_learning(
    global_model=create_model(),
    client_subsets=client_subsets,
    criterion=nn.CrossEntropyLoss(),
    hyperparams=hyperparams,
    test_loader=test_loader,
)


Original dataset label counts:
Label 5: 5421 samples
Label 0: 5923 samples
Label 4: 5842 samples
Label 1: 6742 samples
Label 9: 5949 samples
Label 2: 5958 samples
Label 3: 6131 samples
Label 6: 5918 samples
Label 7: 6265 samples
Label 8: 5851 samples
Client 0 label counts:
[593, 0, 0, 0, 0, 0, 0, 0, 0, 0]
Client 1 label counts:
[593, 750, 0, 0, 0, 0, 0, 0, 0, 0]
Client 2 label counts:
[593, 750, 745, 0, 0, 0, 0, 0, 0, 0]
Client 3 label counts:
[593, 750, 745, 876, 0, 0, 0, 0, 0, 0]
Client 4 label counts:
[593, 750, 745, 876, 974, 0, 0, 0, 0, 0]
Client 5 label counts:
[593, 750, 745, 876, 974, 1085, 0, 0, 0, 0]
Client 6 label counts:
[593, 750, 745, 876, 974, 1085, 1480, 0, 0, 0]
Client 7 label counts:
[593, 750, 745, 876, 974, 1085, 1480, 2089, 0, 0]
Client 8 label counts:
[593, 750, 745, 876, 974, 1085, 1480, 2089, 2926, 0]
Client 9 label counts:
[586, 742, 743, 875, 972, 1081, 1478, 2087, 2925, 5949]
Round 1 Test Loss: 1.2427 and Test Accuracy: 66.58 %
Round 2 Test Loss: 0.9051 and T