In [1]:
import torch
from torch.nn.functional import softmax
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
import argparse
from torch.utils.data import DataLoader, Dataset, TensorDataset
from torchvision import datasets, transforms
import copy
from sklearn.metrics import accuracy_score
import numpy as np
import time 

import pandas as pd
from sklearn.preprocessing import LabelEncoder,OneHotEncoder,MinMaxScaler
from sklearn.compose import ColumnTransformer
from sklearn import preprocessing
from sklearn.model_selection import train_test_split
import xgboost as xgb
from xgboost import XGBClassifier
from sklearn.metrics import accuracy_score, precision_score, recall_score

In [2]:
# 设定超参数
class Arguments():
    def __init__(self):
        self.N_total_client = 100  # 总客户端数量
        self.N_client = 10  # 参与训练的客户端数量
        # 可选数据集，例如 (purchase, cifar10, mnist, adult)
        self.data_name = 'mnist'  
        # 全局训练周期数
        self.global_epoch = 20  
        # 局部训练周期数
        self.local_epoch = 10  
        # 批处理大小
        self.local_batch_size = 64  # 局部批处理大小
        self.test_batch_size = 64  # 测试批处理大小
        # 学习率
        self.local_lr = 0.005  # 局部学习率
        
        self.seed = 1  # 随机种子
        self.save_all_model = True  # 保存所有模型
        self.cuda_state = torch.cuda.is_available()  # 检查CUDA是否可用
        self.use_gpu = True  # 使用GPU
        self.train_with_test = False  # 用测试集训练
        
        # 联邦遗忘设置
        self.unlearn_interval= 1  # 遗忘间隔
        # 用于控制模型参数保存的轮数。1表示每轮保存一次参数，N_itv是我们论文中的表示。
        self.forget_client_idx = 2  
        # 如果想要遗忘，将None更改为客户端索引
        # 如果此参数设置为False，仅在最终训练完成后输出全局模型
        # 如果设置为True，则使用FL-Retrain函数重新训练全局模型，并丢弃与forget_client_IDx号用户对应的数据。
        self.if_retrain = False  
        # 如果设置为False，global_train_once函数在训练过程中不会跳过需要遗忘的用户；如果设置为True，global_train_once在训练过程中跳过遗忘的用户
        self.if_unlearning = False  
        self.forget_local_epoch_ratio = 0.5  
        # 当选择遗忘用户时，
        # 其他用户需要在各自的数据集中进行几轮在线训练，以获得模型收敛的一般方向，
        # 以便为模型收敛提供一般方向。
        # forget_local_epoch_ratio*local_epoch是我们需要获得每个局部模型收敛方向时的局部训练轮数
        # self.mia_oldGM = False


In [3]:
# step1. super-parameter setting
FL_params = Arguments()
torch.manual_seed(FL_params.seed)
print(60*'=')
print("Step1. Federated Learning Settings \n We use dataset: "+FL_params.data_name+(" for our Federated Unlearning experiment.\n"))

Step1. Federated Learning Settings 
 We use dataset: mnist for our Federated Unlearning experiment.



In [4]:
print(60*'=')
print("Step2. Client data loaded, testing data loaded!!!\n       Initial Model loaded!!!")

Step2. Client data loaded, testing data loaded!!!
       Initial Model loaded!!!


In [5]:
# 模型初始化
def model_init(data_name):
    if(data_name == 'mnist'):
        model = Net_mnist()
    elif(data_name == 'cifar10'):
        model = Net_cifar10()
    elif(data_name == 'purchase'):
        model = Net_purchase()
    elif(data_name == 'adult'):
        model = Net_adult()
    return model

# 以下是不同数据集的模型框架
class Net_mnist(nn.Module):
    def __init__(self):
        super(Net_mnist, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4 * 4 * 50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4 * 4 * 50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x
    
class Net_purchase(nn.Module):
    def __init__(self):
        super(Net_purchase, self).__init__()
        self.fc1 = nn.Linear(600, 300)
        self.fc2 = nn.Linear(300, 50)
        self.fc3 = nn.Linear(50, 2)
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        return x   
    

class Net_adult(nn.Module):
    def __init__(self):
        super(Net_adult, self).__init__()
        self.fc1 = nn.Linear(108, 50)
        self.fc2 = nn.Linear(50, 10)
        self.fc3 = nn.Linear(10, 2)
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        return x     



class Net_cifar10(nn.Module):
    def __init__(self):
        super(Net_cifar10, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [6]:
# 初始化模型
init_global_model = model_init(FL_params.data_name)
init_global_model

Net_mnist(
  (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(20, 50, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=800, out_features=500, bias=True)
  (fc2): Linear(in_features=500, out_features=10, bias=True)
)

In [7]:
def data_init(FL_params):
    
    kwargs = {'num_workers': 0, 'pin_memory': True} if FL_params.cuda_state else {}
    trainset, testset = data_set(FL_params.data_name) # 获得训练集、测试集
    
    # 构建测试数据加载器
    test_loader = DataLoader(testset, batch_size=FL_params.test_batch_size, shuffle=True, **kwargs)                
    
    # 将数据按照训练的trainset，均匀的分配成N-client份，所有分割得到dataset都保存在一个list中
    split_index = [int(trainset.__len__()/FL_params.N_total_client)]*(FL_params.N_total_client-1) 
    split_index.append(int(trainset.__len__() - int(trainset.__len__()/FL_params.N_total_client)*(FL_params.N_total_client-1)))
    client_dataset = torch.utils.data.random_split(trainset, split_index)
    
    # 将全局模型复制N-client次，然后构建每一个client模型的优化器，参数记录   
    client_loaders = []
    for tt in range(FL_params.N_total_client):
        client_loaders.append(DataLoader(client_dataset[tt], FL_params.local_batch_size, shuffle=True, **kwargs))
    
    return client_loaders, test_loader


def data_set(data_name):
    if not data_name in ['mnist','purchase','adult','cifar10']:
        raise TypeError('data_name should be a string, including mnist,purchase,adult,cifar10. ')
    
    # model: 2 conv. layers followed by 2 FC layers
    if data_name == 'mnist':
        trainset = datasets.MNIST('./data', train=True, download=True,
                       transform = transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ]))

        testset = datasets.MNIST('./data', train=False, download=True,
                       transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ]))
        
    # model: ResNet-50
    elif data_name == 'cifar10':
        transform = transforms.Compose(
                                        [transforms.ToTensor(),
                                        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                                        ])
        
        trainset = datasets.CIFAR10(root='./data', 
                                    train=True,
                                    download=True, 
                                    transform=transform)
        
        testset = datasets.CIFAR10(root='./data', 
                                   train=False,
                                   download=True, 
                                   transform=transform)
        
        
    # model: 2 FC layers
    elif data_name == 'purchase':
        xx = np.load("./data/purchase/purchase_xx.npy")
        yy = np.load("./data/purchase/purchase_y2.npy")
     
        X_train, X_test, y_train, y_test = train_test_split(xx, yy, test_size=0.2, random_state=42)
        
        X_train_tensor = torch.Tensor(X_train).type(torch.FloatTensor)
        X_test_tensor = torch.Tensor(X_test).type(torch.FloatTensor)
        y_train_tensor = torch.Tensor(y_train).type(torch.LongTensor)
        y_test_tensor = torch.Tensor(y_test).type(torch.LongTensor)
        
        trainset = TensorDataset(X_train_tensor,y_train_tensor)
        testset = TensorDataset(X_test_tensor,y_test_tensor)
    
    # model: 2 FC layers
    elif data_name == 'adult':
        #load data
        file_path = "./data/adult/"
        data1 = pd.read_csv(file_path + 'adult.data', header=None)
        data2 = pd.read_csv(file_path + 'adult.test', header=None)
        data2 = data2.replace(' <=50K.', ' <=50K')    
        data2 = data2.replace(' >50K.', ' >50K')
        train_num = data1.shape[0]
        data = pd.concat([data1,data2])
        #data transform: str->int
        data = np.array(data, dtype=str)
        labels = data[:,14]
        le= LabelEncoder()
        le.fit(labels)
        labels = le.transform(labels)
        data = data[:,:-1]
        
        categorical_features = [1,3,5,6,7,8,9,13]
        # categorical_names = {}
        for feature in categorical_features:
            le = LabelEncoder()
            le.fit(data[:, feature])
            data[:, feature] = le.transform(data[:, feature])
            # categorical_names[feature] = le.classes_
        data = data.astype(float)
        
        n_features = data.shape[1]
        numerical_features = list(set(range(n_features)).difference(set(categorical_features)))
        for feature in numerical_features:
            scaler = MinMaxScaler()
            sacled_data = scaler.fit_transform(data[:,feature].reshape(-1,1))
            data[:,feature] = sacled_data.reshape(-1)
        
        #OneHotLabel
        oh_encoder = ColumnTransformer(
            [('oh_enc', OneHotEncoder(sparse=False), categorical_features),], 
            remainder='passthrough' )
        oh_data = oh_encoder.fit_transform(data)
        
        xx = oh_data
        yy = labels
        #最终处理，xx进行规范化
        xx = preprocessing.scale(xx)
        yy = np.array(yy)
        xx = torch.Tensor(xx).type(torch.FloatTensor)
        yy = torch.Tensor(yy).type(torch.LongTensor)
        xx_train = xx[0:data1.shape[0],:]
        xx_test = xx[data1.shape[0]:,:]
        yy_train = yy[0:data1.shape[0]]
        yy_test = yy[data1.shape[0]:]

        trainset = TensorDataset(xx_train,yy_train)
        testset = TensorDataset(xx_test,yy_test)
        
    return trainset, testset

In [8]:
# 划分每一个client的数据集
client_all_loaders, test_loader = data_init(FL_params)

In [9]:
# 随机抽取client
selected_clients = np.random.choice(range(FL_params.N_total_client),size=FL_params.N_client, replace=False)
selected_clients

array([ 5, 58, 66, 72, 11, 14, 73, 20,  0, 99])

In [10]:
client_loaders = list()
for idx in selected_clients:
    client_loaders.append(client_all_loaders[idx])

In [11]:
print(60*'=')
print("Step3. Fedearated Learning and Unlearning Training...")

Step3. Fedearated Learning and Unlearning Training...


In [12]:
def test(model, test_loader):
    model.eval()
    test_loss = 0
    test_acc = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            criteria = nn.CrossEntropyLoss()
            test_loss += criteria(output, target) # sum up batch loss
            
            pred = torch.argmax(output,axis=1)
            test_acc += accuracy_score(pred,target)
        
    test_loss /= len(test_loader.dataset)
    test_acc = test_acc/np.ceil(len(test_loader.dataset)/test_loader.batch_size)
    print('Test set: Average loss: {:.8f}'.format(test_loss))         
    print('Test set: Average acc:  {:.4f}'.format(test_acc))    
    return (test_loss, test_acc)

In [13]:

def global_train_once(global_model, client_data_loaders, test_loader, FL_params):
    device = torch.device("cuda" if FL_params.use_gpu*FL_params.cuda_state else "cpu")
    device_cpu = torch.device("cpu")
    
    client_models = []
    client_sgds = []
    for tt in range(FL_params.N_client):
        client_models.append(copy.deepcopy(global_model))
        client_sgds.append(optim.SGD(client_models[tt].parameters(), lr=FL_params.local_lr, momentum=0.9))
    
    for client_idx in range(FL_params.N_client):
        if(((FL_params.if_retrain) and (FL_params.forget_client_idx == client_idx)) or ((FL_params.if_unlearning) and (FL_params.forget_client_idx == client_idx))):
            continue

        model = client_models[client_idx]
        optimizer = client_sgds[client_idx]
        
        model.to(device)
        model.train()
        
        #local training
        for local_epoch in range(FL_params.local_epoch):
            for batch_idx, (data, target) in enumerate(client_data_loaders[client_idx]):
                data = data.to(device)
                target = target.to(device)
                
                optimizer.zero_grad()
                pred = model(data)
                criteria = nn.CrossEntropyLoss()
                loss = criteria(pred, target)
                loss.backward()
                optimizer.step()
                
            if(FL_params.train_with_test):
                print("Local Client No. {}, Local Epoch: {}".format(client_idx, local_epoch))
                print("Loss {}, Acc{}".format(test(model, test_loader)))
        
        model.to(device_cpu)
        client_models[client_idx] = model
        
    if(((FL_params.if_retrain) and (FL_params.forget_client_idx == client_idx))):
        client_models.pop(FL_params.forget_client_idx)
        return client_models
    
    elif((FL_params.if_unlearning) and (FL_params.forget_client_idx in range(FL_params.N_client))):
        client_models.pop(FL_params.forget_client_idx)
        return client_models
    
    else:
        return client_models

In [14]:
def fedavg(local_models):
    # 创建一个新的模型实例，该实例是输入列表中第一个局部模型的深度副本
    global_model = copy.deepcopy(local_models[0])
    # 获取 global_model 的状态字典，它包含模型的参数
    avg_state_dict = global_model.state_dict()
    
    # 初始化一个空列表，用于存储所有局部模型的状态字典
    local_state_dicts = list()
    for model in local_models:
        # 将每个局部模型的状态字典添加到 local_state_dicts 列表中
        local_state_dicts.append(model.state_dict())
    
    # 遍历 avg_state_dict 的所有层（即模型参数）
    for layer in avg_state_dict.keys():
        # 将全局模型的每一层参数初始化为0
        avg_state_dict[layer] *= 0 
        for client_idx in range(len(local_models)):
            # 将每个客户端模型的相应层参数加到全局模型的相应层
            avg_state_dict[layer] += local_state_dicts[client_idx][layer]
        # 对每一层，计算所有客户端模型的该层参数的平均值
        avg_state_dict[layer] /= len(local_models)
    
    # 将计算得到的平均参数加载到 global_model 中
    global_model.load_state_dict(avg_state_dict)
    # 返回具有平均参数的全局模型
    return global_model


In [15]:
def FL_Train(init_global_model, client_data_loaders, test_loader, FL_params):
    if(FL_params.if_retrain == True):
        raise ValueError('FL_params.if_retrain should be set to False, if you want to train, not retrain FL model')
    
    if(FL_params.if_unlearning == True):
        raise ValueError('FL_params.if_unlearning should be set to False, if you want to train, not unlearning FL model')
    
    all_global_models = list()
    all_client_models = list()
    global_model = init_global_model
    
    all_global_models.append(copy.deepcopy(global_model))
    for epoch in range(FL_params.global_epoch):
        client_models = global_train_once(global_model, client_data_loaders, test_loader, FL_params)
        
        all_client_models += client_models
        global_model = fedavg(client_models)
        print("Global Federated Learning epoch = {}".format(epoch))

        all_global_models.append(copy.deepcopy(global_model))
        
    return all_global_models, all_client_models

In [16]:
def unlearning_step_once(old_client_models, new_client_models, global_model_before_forget, global_model_after_forget):
    # 初始化字典，用于存储旧模型和新模型的参数更新
    old_param_update = dict()  # 模型参数： oldCM - oldGM_t
    new_param_update = dict()  # 模型参数： newCM - newGM_t
    
    # 获取遗忘后的全局模型的状态字典：newGM_t
    new_global_model_state = global_model_after_forget.state_dict()  
    
    # 初始化返回模型状态字典，用于计算新的全局模型参数
    return_model_state = dict()  # newGM_t + ||oldCM - oldGM_t||*(newCM - newGM_t)/||newCM - newGM_t||
    
    # 确保旧模型列表和新模型列表的长度相同
    assert len(old_client_models) == len(new_client_models)
    
    # 遍历全局模型（遗忘前）的所有层
    for layer in global_model_before_forget.state_dict().keys():
        # 初始化旧模型和新模型的参数更新字典，以及返回模型状态字典
        old_param_update[layer] = 0*global_model_before_forget.state_dict()[layer]
        new_param_update[layer] = 0*global_model_before_forget.state_dict()[layer]
        return_model_state[layer] = 0*global_model_before_forget.state_dict()[layer]
        
        # 遍历所有新模型，累加每层的参数
        for tt in range(len(new_client_models)):
            old_param_update[layer] += old_client_models[tt].state_dict()[layer]
            new_param_update[layer] += new_client_models[tt].state_dict()[layer]
        # 计算每层参数的平均值：oldCM 和 newCM
        old_param_update[layer] /= (tt+1)  
        new_param_update[layer] /= (tt+1)  
        
        # 计算参数更新：oldCM - oldGM_t 和 newCM - newGM_t
        old_param_update[layer] = old_param_update[layer] - global_model_before_forget.state_dict()[layer]
        new_param_update[layer] = new_param_update[layer] - global_model_after_forget.state_dict()[layer]
        
        # 计算步长 ||oldCM - oldGM_t|| 和步方向 (newCM - newGM_t)/||newCM - newGM_t||
        step_length = torch.norm(old_param_update[layer])
        step_direction = new_param_update[layer]/torch.norm(new_param_update[layer])
        
        # 根据公式计算新的全局模型参数：newGM_t + ||oldCM - oldGM_t||*(newCM - newGM_t)/||newCM - newGM_t||
        return_model_state[layer] = new_global_model_state[layer] + step_length*step_direction
    
    # 创建一个新的全局模型的深拷贝，以避免修改原始模型
    return_global_model = copy.deepcopy(global_model_after_forget)
    
    # 加载计算得到的新全局模型参数
    return_global_model.load_state_dict(return_model_state)
    
    # 返回新的全局模型
    return return_global_model


In [17]:
def unlearning(old_GMs, old_CMs, client_data_loaders, test_loader, FL_params):
    # 检查是否设置了遗忘选项，如果没有则抛出错误
    if(FL_params.if_unlearning == False):
        raise ValueError('FL_params.if_unlearning should be set to True, if you want to unlearning with a certain user')
    
    # 检查遗忘客户端索引是否在正确的范围内，如果不是则抛出错误
    if(not(FL_params.forget_client_idx in range(FL_params.N_client))):
        raise ValueError('FL_params.forget_client_idx is note assined correctly, forget_client_idx should in {}'.format(range(FL_params.N_client)))
    
    # 检查遗忘间隔是否设置正确，如果不是则抛出错误
    if(FL_params.unlearn_interval == 0 or FL_params.unlearn_interval >FL_params.global_epoch):
        raise ValueError('FL_params.unlearn_interval should not be 0, or larger than the number of FL_params.global_epoch')
    
    # 创建旧模型的深拷贝，以避免修改原始模型
    old_global_models = copy.deepcopy(old_GMs)
    old_client_models = copy.deepcopy(old_CMs)
    
    forget_client = FL_params.forget_client_idx  # 获取遗忘客户端索引
    
    # 遍历所有全局训练轮次
    for tt in range(FL_params.global_epoch):
        # 提取当前轮次的所有客户端模型
        temp = old_client_models[tt * FL_params.N_client : tt * FL_params.N_client + FL_params.N_client]
        # 从列表中删除遗忘客户端的模型
        temp.pop(forget_client)
        # 将修改后的模型列表追加到旧客户端模型列表中
        old_client_models.append(temp)
    # 只保留最新的客户端模型列表
    old_client_models = old_client_models[-FL_params.global_epoch:]
    
    # 创建全局和客户端模型的选定索引，以便在遗忘过程中选择模型
    GM_intv = np.arange(0,FL_params.global_epoch+1, FL_params.unlearn_interval, dtype=np.int16())
    CM_intv  = GM_intv -1
    CM_intv = CM_intv[1:]
    
    # 选择要用于遗忘过程的全局和客户端模型
    selected_GMs = [old_global_models[tt] for tt in GM_intv]
    selected_CMs = [old_client_models[jj] for jj in CM_intv]
    
    # 步骤1：从初始模型开始，完成第一轮全局训练的模型叠加
    epoch = 0
    unlearn_global_models = list()  # 初始化一个空列表，用于存储遗忘过程中的全局模型
    unlearn_global_models.append(copy.deepcopy(selected_GMs[0]))  # 将初始全局模型添加到列表中
    
    # 使用联邦平均算法计算新的全局模型
    new_global_model = fedavg(selected_CMs[epoch])
    unlearn_global_models.append(copy.deepcopy(new_global_model))
    print("Federated Unlearning Global Epoch  = {}".format(epoch))
    
    # 步骤2：以第一轮全局模型为起点，逐渐纠正模型
    # 此步骤中，将第一轮全局训练的全局模型用作新训练的起点，并通过保留用户的数据进行少量训练，以获取每个保留用户的局部模型参数的迭代方向。
    # 然后，使用标准FL训练中保存的旧客户端模型和旧全局模型，以及遗忘用户时获得的新客户端模型和新全局模型，构建下一轮的全局模型。
    # 保存原始的局部和全局训练轮次设置
    CONST_local_epoch = copy.deepcopy(FL_params.local_epoch)
    FL_params.local_epoch = np.ceil(FL_params.local_epoch*FL_params.forget_local_epoch_ratio)
    FL_params.local_epoch = np.int16(FL_params.local_epoch)
    
    CONST_global_epoch = copy.deepcopy(FL_params.global_epoch)
    FL_params.global_epoch = CM_intv.shape[0]
    
    print('Local Calibration Training epoch = {}'.format(FL_params.local_epoch))
    for epoch in range(FL_params.global_epoch):
        if(epoch == 0):
            continue  # 跳过第0轮，因为已经在步骤1中完成了
        print("Federated Unlearning Global Epoch  = {}".format(epoch))
        global_model = unlearn_global_models[epoch]  # 获取当前轮次的全局模型
        
        # 对每个客户端执行一次全局训练
        new_client_models  = global_train_once(global_model, client_data_loaders, test_loader, FL_params)
        
        # 执行一次遗忘步骤，以计算新的全局模型
        new_GM = unlearning_step_once(selected_CMs[epoch], new_client_models, selected_GMs[epoch+1], global_model)
        
        # 将新的全局模型添加到列表中
        unlearn_global_models.append(new_GM)
    
    # 恢复原始的局部和全局训练轮次设置
    FL_params.local_epoch = CONST_local_epoch
    FL_params.global_epoch = CONST_global_epoch
    
    # 返回遗忘过程中的全局模型列表
    return unlearn_global_models

In [18]:

def unlearning_without_cali(old_global_models, old_client_models, FL_params):
    if(FL_params.if_unlearning == False):
        raise ValueError('FL_params.if_unlearning should be set to True, if you want to unlearning with a certain user')
        
    if(not(FL_params.forget_client_idx in range(FL_params.N_client))):
        raise ValueError('FL_params.forget_client_idx is note assined correctly, forget_client_idx should in {}'.format(range(FL_params.N_client)))
    
    forget_client = FL_params.forget_client_idx
    
    for tt in range(FL_params.global_epoch):
        temp = old_client_models[tt*FL_params.N_client : tt*FL_params.N_client+FL_params.N_client]
        temp.pop(forget_client)
        old_client_models.append(temp)
    old_client_models = old_client_models[-FL_params.global_epoch:]
    
    uncali_global_models = list()
    uncali_global_models.append(copy.deepcopy(old_global_models[0]))
    epoch = 0
    uncali_global_model = fedavg(old_client_models[epoch])
    uncali_global_models.append(copy.deepcopy(uncali_global_model))
    print("Federated Unlearning without Clibration Global Epoch  = {}".format(epoch))
    
    old_param_update = dict()
    return_model_state = dict()
    
    for epoch in range(FL_params.global_epoch):
        if(epoch == 0):
            continue
        print("Federated Unlearning Global Epoch  = {}".format(epoch))
        
        current_global_model = uncali_global_models[epoch]
        current_client_models = old_client_models[epoch]
        old_global_model = old_global_models[epoch]  
        
        for layer in current_global_model.state_dict().keys():
            old_param_update[layer] = 0*current_global_model.state_dict()[layer]
            return_model_state[layer] = 0*current_global_model.state_dict()[layer]
            
            for tt in range(len(current_client_models)):
                old_param_update[layer] += current_client_models[tt].state_dict()[layer]
            old_param_update[layer] /= (tt+1)
            
            old_param_update[layer] = old_param_update[layer] - old_global_model.state_dict()[layer]

            return_model_state[layer] = current_global_model.state_dict()[layer] + old_param_update[layer]
            
        return_global_model = copy.deepcopy(old_global_models[0])
        return_global_model.load_state_dict(return_model_state)
            
        uncali_global_models.append(return_global_model)

    return uncali_global_models

In [19]:
def federated_learning_unlearning(init_global_model, client_loaders, test_loader, FL_params):
    """
    初始化全局模型， 
    用户端数据集
    测试集
    联邦学习参数
    """
    
    # Fed Training
    print(5*"#" + "Federated Learning Start" + 5*"#")
    std_time = time.time()
    old_GMs, old_CMs = FL_Train(init_global_model, client_loaders, test_loader, FL_params)
    end_time = time.time()
    time_learn = (std_time - end_time)
    print(5*"#"+"Federated Learning End"+5*"#")
    
    print(5*"#"+"Federated Unlearning Start"+5*"#")
    std_time = time.time()
    FL_params.if_unlearning = True
    FL_params.forget_client_idx = 2
    unlearn_GMs = unlearning(old_GMs, old_CMs, client_loaders, test_loader, FL_params) # all_global all_client
    end_time = time.time()
    time_unlearn = (std_time - end_time)
    print(5*"#"+"Federated Unlearning End"+5*"#")
    
    print(5*"#"+"Federated Unlearning without Calibration Start"+5*"#")
    std_time = time.time()
    uncali_unlearn_GMs = unlearning_without_cali(old_GMs, old_CMs, FL_params)
    end_time = time.time()
    time_unlearn_no_cali = (std_time - end_time)
    print(5*"#"+"Federated Unlearning without Calibration End"+5*"#")
    
    print("Learning time consuming = {} secods".format(-time_learn))
    print("Unlearning time consuming = {} secods".format(-time_unlearn)) 
    print("Unlearning no Cali time consuming = {} secods".format(-time_unlearn_no_cali)) 
    
    return old_GMs, unlearn_GMs, uncali_unlearn_GMs, old_CMs

In [20]:
old_GMs, unlearn_GMs, uncali_unlearn_GMs, _ = \
    federated_learning_unlearning(
                                 init_global_model, 
                                 client_loaders, 
                                 test_loader, 
                                 FL_params)

#####Federated Learning Start#####
Global Federated Learning epoch = 0
Global Federated Learning epoch = 1
Global Federated Learning epoch = 2
Global Federated Learning epoch = 3
Global Federated Learning epoch = 4
Global Federated Learning epoch = 5
Global Federated Learning epoch = 6
Global Federated Learning epoch = 7
Global Federated Learning epoch = 8
Global Federated Learning epoch = 9
Global Federated Learning epoch = 10
Global Federated Learning epoch = 11
Global Federated Learning epoch = 12
Global Federated Learning epoch = 13
Global Federated Learning epoch = 14
Global Federated Learning epoch = 15
Global Federated Learning epoch = 16
Global Federated Learning epoch = 17
Global Federated Learning epoch = 18
Global Federated Learning epoch = 19
#####Federated Learning End#####
#####Federated Unlearning Start#####
Federated Unlearning Global Epoch  = 0
Local Calibration Training epoch = 5
Federated Unlearning Global Epoch  = 1
Federated Unlearning Global Epoch  = 2
Federated U

In [21]:
def FL_Retrain(init_global_model, client_data_loaders, test_loader, FL_params):
    if(FL_params.if_retrain == False):
        raise ValueError('FL_params.if_retrain should be set to True, if you want to retrain FL model')
    
    if(FL_params.forget_client_idx not in range(FL_params.N_client)):
        raise ValueError('FL_params.forget_client_idx should be in [{}], if you want to use standard FL train with forget the certain client dataset.'.format(range(FL_params.N_client)))
    
    print(5*"#"+"Federated Retraining Start"+5*"#")
    print("Federated Retrain with Forget Client NO.{}".format(FL_params.forget_client_idx))
    retrain_GMs = list()
    all_client_models = list()
    retrain_GMs.append(copy.deepcopy(init_global_model))
    global_model = init_global_model
    for epoch in range(FL_params.global_epoch):
        client_models = global_train_once(global_model, client_data_loaders, test_loader, FL_params)
        global_model = fedavg(client_models)
        print("Global Retraining epoch = {}".format(epoch))
        retrain_GMs.append(copy.deepcopy(global_model))
        all_client_models += client_models
    print(5*"#"+"Federated Retraining End"+5*"#")
    
    return retrain_GMs

In [23]:
FL_params.if_retrain = True

In [24]:
if(FL_params.if_retrain == True):
    t1 = time.time()
    retrain_GMs = FL_Retrain(init_global_model, client_loaders, test_loader, FL_params)
    t2 = time.time()
    print("Time using = {} seconds".format(t2-t1))

#####  Federated Retraining Start  #####
Federated Retrain with Forget Client NO.2
Global Retraining epoch = 0
Global Retraining epoch = 1
Global Retraining epoch = 2
Global Retraining epoch = 3
Global Retraining epoch = 4
Global Retraining epoch = 5
Global Retraining epoch = 6
Global Retraining epoch = 7
Global Retraining epoch = 8
Global Retraining epoch = 9
Global Retraining epoch = 10
Global Retraining epoch = 11
Global Retraining epoch = 12
Global Retraining epoch = 13
Global Retraining epoch = 14
Global Retraining epoch = 15
Global Retraining epoch = 16
Global Retraining epoch = 17
Global Retraining epoch = 18
Global Retraining epoch = 19
#####Federated Retraining End#####
Time using = 66.27519679069519 seconds


In [25]:
def train_attack_model(shadow_old_GM, shadow_client_loaders, shadow_test_loader, FL_params):
    shadow_model = shadow_old_GM
    n_class_dict = dict()
    n_class_dict['adult'] = 2
    n_class_dict['purchase'] = 2
    n_class_dict['mnist'] = 10
    n_class_dict['cifar10'] = 10
    
    N_class = n_class_dict[FL_params.data_name]
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    shadow_model.to(device)
        
    shadow_model.eval()

    pred_4_mem = torch.zeros([1,N_class])
    pred_4_mem = pred_4_mem.to(device)
    with torch.no_grad():
        for tt in range(len(shadow_client_loaders)):
            data_loader = shadow_client_loaders[tt]
            
            for batch_idx, (data, target) in enumerate(data_loader):
                    data = data.to(device)
                    out = shadow_model(data)
                    pred_4_mem = torch.cat([pred_4_mem, out])
    pred_4_mem = pred_4_mem[1:,:]
    pred_4_mem = softmax(pred_4_mem,dim = 1)
    pred_4_mem = pred_4_mem.cpu()
    pred_4_mem = pred_4_mem.detach().numpy()
    
    pred_4_nonmem = torch.zeros([1,N_class])
    pred_4_nonmem = pred_4_nonmem.to(device)
    with torch.no_grad():
        for batch, (data, target) in enumerate(shadow_test_loader):
            data = data.to(device)
            out = shadow_model(data)
            pred_4_nonmem = torch.cat([pred_4_nonmem, out])
    pred_4_nonmem = pred_4_nonmem[1:,:]
    pred_4_nonmem = softmax(pred_4_nonmem,dim = 1)
    pred_4_nonmem = pred_4_nonmem.cpu()
    pred_4_nonmem = pred_4_nonmem.detach().numpy()
    
    
    #构建MIA 攻击模型 
    att_y = np.hstack((np.ones(pred_4_mem.shape[0]), np.zeros(pred_4_nonmem.shape[0])))
    att_y = att_y.astype(np.int16)
    
    att_X = np.vstack((pred_4_mem, pred_4_nonmem))
    att_X.sort(axis=1)
    
    X_train,X_test, y_train, y_test = train_test_split(att_X, att_y, test_size = 0.1)
    
    attacker = XGBClassifier(
                            n_estimators = 300,
                            n_jobs = -1,
                            max_depth = 30,
                            objective = 'binary:logistic',
                            booster="gbtree",
                            scale_pos_weight = pred_4_nonmem.shape[0]/pred_4_mem.shape[0]
                            )
    
    attacker.fit(X_train, y_train)

    return attacker

In [26]:
def attack(target_model, attack_model, client_loaders, test_loader, FL_params):
    n_class_dict = dict()
    n_class_dict['adult'] = 2
    n_class_dict['purchase'] = 2
    n_class_dict['mnist'] = 10
    n_class_dict['cifar10'] = 10
    
    N_class = n_class_dict[FL_params.data_name]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    target_model.to(device)
        
    target_model.eval()
    
    unlearn_X = torch.zeros([1,N_class])
    unlearn_X = unlearn_X.to(device)
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(client_loaders[FL_params.forget_client_idx]):
                    data = data.to(device)
                    out = target_model(data)
                    unlearn_X = torch.cat([unlearn_X, out])
                    
    unlearn_X = unlearn_X[1:,:]
    unlearn_X = softmax(unlearn_X,dim = 1)
    unlearn_X = unlearn_X.cpu().detach().numpy()
    
    unlearn_X.sort(axis=1)
    unlearn_y = np.ones(unlearn_X.shape[0])
    unlearn_y = unlearn_y.astype(np.int16)
    
    N_unlearn_sample = len(unlearn_y)
    
    test_X = torch.zeros([1, N_class])
    test_X = test_X.to(device)
    with torch.no_grad():
        for _, (data, target) in enumerate(test_loader):
            data = data.to(device)
            out = target_model(data)
            test_X = torch.cat([test_X, out])
            
            if(test_X.shape[0] > N_unlearn_sample):
                break
    test_X = test_X[1:N_unlearn_sample+1,:]
    test_X = softmax(test_X,dim = 1)
    test_X = test_X.cpu().detach().numpy()
    
    test_X.sort(axis=1)
    test_y = np.zeros(test_X.shape[0])
    test_y = test_y.astype(np.int16)
    
    XX = np.vstack((unlearn_X, test_X))
    YY = np.hstack((unlearn_y, test_y))
    
    pred_YY = attack_model.predict(XX)
    pre = precision_score(YY, pred_YY, pos_label=1)
    rec = recall_score(YY, pred_YY, pos_label=1)
    print("MIA Attacker precision = {:.4f}".format(pre))
    print("MIA Attacker recall = {:.4f}".format(rec))
    
    return (pre, rec)

In [27]:
print(60*'=')
print("Step4. Membership Inference Attack aganist GM...")
T_epoch = -1
old_GM = old_GMs[T_epoch]
attack_model = train_attack_model(old_GM, client_loaders, test_loader, FL_params)


print("\nEpoch  = {}".format(T_epoch))
print("Attacking against FL Standard  ")

target_model = old_GMs[T_epoch]
(ACC_old, PRE_old) = attack(target_model, attack_model, client_loaders, test_loader, FL_params)

if(FL_params.if_retrain == True):
    print("Attacking against FL Retrain  ")
    target_model = retrain_GMs[T_epoch]
    (ACC_retrain, PRE_retrain) = attack(target_model, attack_model, client_loaders, test_loader, FL_params)
    
print("Attacking against FL Unlearn  ")
target_model = unlearn_GMs[T_epoch]
(ACC_unlearn, PRE_unlearn) = attack(target_model, attack_model, client_loaders, test_loader, FL_params)

Step4. Membership Inference Attack aganist GM...

Epoch  = -1
Attacking against FL Standard  
MIA Attacker precision = 0.9699
MIA Attacker recall = 0.9133
Attacking against FL Retrain  
MIA Attacker precision = 0.6615
MIA Attacker recall = 0.4233
Attacking against FL Unlearn  
MIA Attacker precision = 0.4778
MIA Attacker recall = 0.3583
