In [21]:
# main.py
import torch
import numpy as np
from system.flcore.clients.clientEraser import Client
from system.flcore.servers.servereraser import Server
from system.utils.data_preprocess import *
from system.utils.model_initiation import *

class Args:
    """Experiment configuration for the federated training / unlearning run.

    Every setting is a plain instance attribute, so it can be overridden after
    construction (e.g. ``args.global_epoch = 5``).
    """

    def __init__(self):
        settings = {
            # federation size
            "N_total_client": 100,   # number of Client objects built by the setup cell
            "N_client": 10,          # clients sampled per global round
            # training schedule
            "global_epoch": 20,      # number of global rounds in the training loop
            "local_epoch": 5,
            "local_batch_size": 64,
            "local_lr": 0.01,
            "test_batch_size": 128,
            # dataset / reproducibility
            "data_name": 'mnist',
            "seed": 42,
        }
        for attr_name, attr_value in settings.items():
            setattr(self, attr_name, attr_value)
        # True when torch reports a usable CUDA device in this process.
        self.cuda_state = torch.cuda.is_available()
        

In [22]:
args = Args()

# Seed every RNG this notebook relies on: torch, plus numpy's global RNG, which
# drives np.random.choice in the training loop (and possibly data_init's split).
# Seeding only torch left client selection non-reproducible.
torch.manual_seed(args.seed)
np.random.seed(args.seed)

# Load the per-client training loaders and the shared test loader.
client_loaders, test_loader = data_init(args)
assert len(client_loaders) >= args.N_total_client, (
    f"data_init returned {len(client_loaders)} client loaders, "
    f"expected at least {args.N_total_client}"
)

# Model owned by the server.
global_model = model_init(args.data_name)

# One Client per loader, each holding its own fresh model_init(...) instance.
clients = [
    Client(i, model_init(args.data_name), client_loaders[i], args)
    for i in range(args.N_total_client)
]

# Server coordinates distribution, aggregation and evaluation.
server = Server(global_model, clients, test_loader, args)


In [23]:
# Federated training main loop: distribute -> local training -> aggregate -> evaluate.
# The loop variable is `round_idx` (not `round`) so the builtin round() is not
# shadowed for later cells in the same kernel.
for round_idx in range(args.global_epoch):
    print(f"\n--- Global Round {round_idx+1} ---")

    # Sample client indices, then look the clients up. This consumes the numpy
    # RNG exactly like np.random.choice(clients, ...) did, but yields a plain list
    # and never asks numpy to coerce Client objects into an object array.
    selected_idx = np.random.choice(len(clients), args.N_client, replace=False)
    selected_clients = [clients[i] for i in selected_idx]

    server.distribute_model(selected_clients)
    for client in selected_clients:
        client.train()

    server.aggregate_models(selected_clients)
    server.evaluate()


--- Global Round 1 ---
[Server] Test Accuracy: 0.6548

--- Global Round 2 ---
[Server] Test Accuracy: 0.8119

--- Global Round 3 ---
[Server] Test Accuracy: 0.8606

--- Global Round 4 ---
[Server] Test Accuracy: 0.8846

--- Global Round 5 ---
[Server] Test Accuracy: 0.8988

--- Global Round 6 ---
[Server] Test Accuracy: 0.9090

--- Global Round 7 ---
[Server] Test Accuracy: 0.9193

--- Global Round 8 ---
[Server] Test Accuracy: 0.9205

--- Global Round 9 ---
[Server] Test Accuracy: 0.9313

--- Global Round 10 ---
[Server] Test Accuracy: 0.9373

--- Global Round 11 ---
[Server] Test Accuracy: 0.9404

--- Global Round 12 ---
[Server] Test Accuracy: 0.9441

--- Global Round 13 ---
[Server] Test Accuracy: 0.9483

--- Global Round 14 ---
[Server] Test Accuracy: 0.9507

--- Global Round 15 ---
[Server] Test Accuracy: 0.9537

--- Global Round 16 ---
[Server] Test Accuracy: 0.9545

--- Global Round 17 ---
[Server] Test Accuracy: 0.9547

--- Global Round 18 ---
[Server] Test Accuracy: 0.9577



In [None]:
def fed_unlearning(global_models, client_models, forget_client_idx, FL_params):
    """Approximately remove one client's contribution from a recorded training history.

    Starting from a copy of the final global model, for every round ``t`` the
    forgotten client's local update ``client_t - global_t`` is subtracted with
    weight ``1 / N_client`` (its share of an unweighted average). This is a
    first-order "subtract the contribution" heuristic. It does NOT run FedEraser's
    calibration retraining of the remaining clients, and it ignores how earlier
    rounds influenced later ones.

    NOTE(review): the training loop in this notebook samples a different random
    subset of clients each round and does not record per-round models, so a fixed
    position ``forget_client_idx`` does not identify the same client across rounds
    there. Confirm how ``global_models`` / ``client_models`` are collected.

    Args:
        global_models: list of global models. ``global_models[t]`` is the model
            at the start of round ``t``; ``global_models[-1]`` is the model the
            correction is applied to.
        client_models: flat, round-major list of client models after local
            training. ``client_models[t * N_client + k]`` is slot ``k`` in round ``t``.
        forget_client_idx: slot of the client to forget, ``0 <= idx < N_client``.
        FL_params: config object. Reads ``N_client`` and ``global_epoch``, plus
            ``use_gpu``, falling back to ``cuda_state`` (the attribute the
            notebook's ``Args`` defines).

    Returns:
        A new model on CPU with the correction applied. The input models are
        not modified.

    Raises:
        ValueError: if ``forget_client_idx`` is out of range, or the history is
            shorter than ``global_epoch`` rounds.
    """
    import copy

    num_clients = FL_params.N_client
    num_rounds = FL_params.global_epoch

    # An out-of-range slot would silently read another round's client model.
    if not 0 <= forget_client_idx < num_clients:
        raise ValueError(
            f"forget_client_idx must be in [0, {num_clients}), got {forget_client_idx}"
        )
    if len(global_models) < num_rounds or len(client_models) < num_rounds * num_clients:
        raise ValueError(
            f"history too short for {num_rounds} rounds of {num_clients} clients: "
            f"{len(global_models)} global / {len(client_models)} client models"
        )

    # Args defines `cuda_state`, not `use_gpu`; accept either, and never request
    # CUDA when no device is available.
    use_gpu = getattr(FL_params, "use_gpu", getattr(FL_params, "cuda_state", False))
    device = torch.device("cuda" if use_gpu and torch.cuda.is_available() else "cpu")

    # Deep copy so the caller's final model is left untouched.
    global_model = copy.deepcopy(global_models[-1]).to(device)

    with torch.no_grad():
        for t in range(num_rounds):
            # Build each state dict once per round, not once per parameter.
            global_prev_state = global_models[t].state_dict()
            client_state = client_models[t * num_clients + forget_client_idx].state_dict()

            for name, param in global_model.named_parameters():
                # (global_t - client_t) is minus this client's local update; adding
                # 1/N_client of it removes the client's share of the round's average.
                delta = global_prev_state[name].to(device) - client_state[name].to(device)
                param.add_(delta / num_clients)

    print(f"[FedEraser] 已移除 Client {forget_client_idx} 的影响。")
    return global_model.cpu()
