In [1]:
import torch
from dataloder import *
from argument import *
from model import *
from pretrain import *
from utils import *
from parllutils import *
from functorch import vmap
args = argument()
device = 'cuda'
args

Namespace(adv='PGD', batchsize=128, dataset='binaryMnist', deletebatch=1, deletenum=0, isbatch=False, iterneumann=3, lam=0.0001, model='logistic', parllsize=128, remove_type=2, times=0)

In [2]:
delete_num = 600 # 共删除600个点
delete_batch = 1 # 每次删1个点
pass_batch = args.parllsize # batch_size 并行计算 total hessian
delete_num, delete_batch, pass_batch

(600, 1, 128)

## Pre-processing
### 1) load data

In [3]:
train_data, test_data, re_sequence = Load_Data(args, delete_num, shuffle=True)
train_loader = make_loader(train_data, batch_size=args.batchsize)
test_loader = make_loader(test_data, batch_size=args.batchsize)
print(f"total number of train data: {len(train_data[0])}, test data: {len(test_data[0])}")

train labels: tensor([1, 1, 1,  ..., 1, 7, 1])
total number of train data: 13007, test data: 2163


### 2) load adversarially trained model (original model w*)

In [4]:
# # atk training (`torchattacks` package)
# model, training_time = train(train_loader, test_loader, args, verbose=True)

In [5]:
# torch.save(model, "../data/ATM/dataset_binaryMnist_adv_PGD_model_logistic_method_MUter_sample_perturbed.pth") # 保存整个模型
# # torch.save(model.state_dict(), "../../data/Lenet-5_parameters.pth") # 推荐:仅保存训练模型的参数,为以后恢复模型提供最大的灵活性

In [6]:
model = torch.load("../data/ATM/dataset_binaryMnist_adv_PGD_model_logistic_method_MUter_sample_perturbed.pth")
matrix = load_memory_matrix(args, method='MUter').to(device)

loading memory matrix : ../data/MemoryMatrix/dataset_binaryMnist_adv_PGD_model_logistic_method_MUter_sample_perturbed.pt
matrix shape torch.Size([784, 784]), type <class 'torch.Tensor'>
done!


### 3) pre-unlearning

1) Calculate the related matrix 
2) Store the matrix to memory
3) delete the matrix variable

In [7]:
pass_loader = make_loader(train_data, batch_size=pass_batch)

Fisher_matrix_perturbed = parll_calculate_memory_matrix(model, pass_loader, args, method='Fisher')
Fisher_matrix_unperturbed = parll_calculate_memory_matrix(model, pass_loader, args, method='Fisher', isDelta=False)

store_memory_matrix(Fisher_matrix_perturbed, args, method='Fisher')
store_memory_matrix(Fisher_matrix_unperturbed, args, method='Fisher', isDelta=False)

del Fisher_matrix_perturbed
del Fisher_matrix_unperturbed

  0%|          | 0/102 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

matrix shape torch.Size([784, 784]), type <class 'torch.Tensor'>
memory matrix for Fisher method using perturb samples to calculate
saving matrix...
done!
matrix shape torch.Size([784, 784]), type <class 'torch.Tensor'>
memory matrix for Fisher method using un-perturb samples to calculate
saving matrix...
done!


In [8]:
Fisher_matrix_perturbed = load_memory_matrix(args, method='Fisher').to(device)
Fisher_matrix_unperturbed = load_memory_matrix(args, method='Fisher', isDelta=False).to(device)

loading memory matrix : ../data/MemoryMatrix/dataset_binaryMnist_adv_PGD_model_logistic_method_Fisher_sample_perturbed.pt
matrix shape torch.Size([784, 784]), type <class 'torch.Tensor'>
done!
loading memory matrix : ../data/MemoryMatrix/dataset_binaryMnist_adv_PGD_model_logistic_method_Fisher_sample_unperturbed.pt
matrix shape torch.Size([784, 784]), type <class 'torch.Tensor'>
done!


## Stage II: Unlearning
1) Inner level attack method;
2) Calculate the public part partial_xx and partial_xx_inv for linear model;
3) Init gradient information;

In [9]:
from torchattacks import PGD
import copy
from utils import cg_solve, model_distance, hessian, update_w, derive_inv
import time
from torch.utils.data import DataLoader, TensorDataset

# Inner level attack method
_, _, atk_info = training_param(args)
atk = PGD(model, atk_info[0], atk_info[1], atk_info[2], lossfun=LossFunction(args.model), lam=args.lam)

# Calculate the public part partial_xx and partial_xx_inv for linear model
feature = get_featrue(args)
weight = vec_param(model.parameters()).detach()

In [10]:
step = 1 # record unlearning times
## compare with removal list [1, 2, 3, 4, 5, ~1%, ~2%, ~3%, ~4%, ~5%] 
remove_list = None
if args.dataset == 'binaryMnist':
    remove_list = [1, 2, 3, 4, 5, 120, 240, 360, 480, 600]  # for mnist
elif args.dataset == 'phishing':
    remove_list = [1, 2, 3, 4, 5, 100, 200, 300, 400, 500]  # for phsihing
elif args.dataset == 'madelon':
    remove_list = [1, 2, 3, 4, 5, 20, 40, 60, 80, 100]  # for madelon
elif args.dataset == 'covtype':
    remove_list = [1, 2, 3, 4, 5, 5000, 10000, 15000, 20000, 25000]
elif args.dataset == 'epsilon':
    remove_list = [1, 2, 3, 4, 5, 4000, 8000, 12000, 16000, 20000]
else:
    remove_list = [1, 2, 3, 4, 5, 10, 20, 30, 40, 50]  # for splice


In [11]:
def test():
    retrain_loader = make_loader(train_data, batch_size=128, head=batch_delete_num)
    retrain_model, retrain_time = train(retrain_loader, test_loader, args, verbose=False)
    clean_acc, perturb_acc = Test_model(retrain_model, test_loader, args)
    

## the golden baseline: retrain_from_scratch

In [12]:
def retrain_from_scratch(batch_delete_num):
    # retrain_from_scratch
    retrain_loader = make_loader(train_data, batch_size=128, head=batch_delete_num)
    retrain_model, retrain_time = train(retrain_loader, test_loader, args, verbose=False)    
    clean_acc, perturb_acc = Test_model(retrain_model, test_loader, args)
    print()
    print('Retrain from scratch:')
    print(f"retrain_loader: train_data[{batch_delete_num}:]")
    print(f'retrain model test acc: clean_acc {clean_acc}, preturb_acc: {perturb_acc}')
    

## Newton and Newton_delta

In [13]:
def Fisher_delta(Fisher_matrix_perturbed, Fisher_delta_model, grad, test_loader):
    
    delta_w_Fisher_delta = cg_solve(Fisher_matrix_perturbed, grad.squeeze(dim=1), get_iters(args))
    update_w(delta_w_Fisher_delta, Fisher_delta_model)

    print('Fisher-delta unlearning:')
    clean_acc, perturb_acc = Test_model(Fisher_delta_model, test_loader, args)
    print(f'Fisher-delta test acc: clean_acc {clean_acc}, preturb_acc: {perturb_acc}')
#     Fisher_delta_clean_acc.append(clean_acc)
#     Fisher_delta_perturb_acc.append(perturb_acc)
    
#     model_dist = model_distance(retrain_model, Fisher_delta_model)
#     print('model norm distance: {:.4f}'.format(model_dist)
#     Fisher_delta_distance.append(model_dist)
    # Fisher_delta_original_distance.append(model_distance(model, Fisher_delta_model))
    # Fisher_delta_retrain_generating_samples_acc.append(test(Fisher_delta_model, test_loader=adv_loader))
    print()

In [14]:
def Fisher(Fisher_matrix_unperturbed, Fisher_model, clean_grad, test_loader):
    
    delta_w_Fisher = cg_solve(Fisher_matrix_unperturbed, clean_grad.squeeze(dim=1), get_iters(args))
    update_w(delta_w_Fisher, Fisher_model)

    print('Fisher unlearning')
    clean_acc, perturb_acc = Test_model(Fisher_model, test_loader, args)
    print(f'Fisher test acc: clean_acc {clean_acc}, preturb_acc: {perturb_acc}')
#     Fisher_clean_acc.append(clean_acc)
#     Fisher_perturb_acc.append(perturb_acc)
    
#     model_dist = model_distance(retrain_model, Fisher_model)
#     print('model norm distance: {:.4f}'.format(model_dist))
#     Fisher_distance.append(model_dist)
    # Fisher_original_distance.append(model_distance(model, Fisher_model))
    # Fisher_retrain_generating_samples_acc.append(test(Fisher_model, test_loader=adv_loader))
    print()

In [15]:
# Init gradinet informations
grad = torch.zeros((feature, 1)).to(device)
clean_grad = torch.zeros((feature, 1)).to(device)
parll_partial = batch_indirect_hessian(args)

# 从1开始删600个点
for batch_delete_num in range(1, delete_num+1, 1):
#     if args.remove_type == 2: # 删单点，删多次
    print(f"{batch_delete_num} deleted. label of cur image: {train_data[1][batch_delete_num].item()}.")
    if batch_delete_num not in remove_list: # 若当前点不在list之内，则不必“真”删
        continue
    else: # 将中间点的data influence 先从matrix中减去
        if batch_delete_num > 5: 
            index = remove_list.index(batch_delete_num)
            pre_index = index - 1
            sub_seq = re_sequence[remove_list[pre_index]:remove_list[index]-1] # 从上一个被删除的datapoint开始，到当前被删点的前一个datapoint

            # remove matrix and add grad 取两次删除的中间部分这些点，只计算aggregated grad然后更新模型即可，不用“真”的删
            temp_loader = make_loader(train_data, batch_size=pass_batch, head=remove_list[pre_index], rear=remove_list[index]-1)
            print(f"temp_loader: train_data[{remove_list[pre_index]},{remove_list[index]-1}]")
            for index, (image, label) in enumerate(temp_loader):
                image = image.to(device)
                label = label.to(device)
                image_perturbed = atk(image, label).to(device)
                # for perturbed grad
                grad += parll_loss_grad(weight, image_perturbed.view(image_perturbed.shape[0], feature), label, args).detach()
                # for clean grad
                clean_grad += parll_loss_grad(weight, image.view(image.shape[0], feature), label, args).detach()

                perturbed_fisher_matrix = batch_fisher(weight, image_perturbed.view(image_perturbed.shape[0], feature), label, args)
                clean_fisher_matrix = batch_fisher(weight, image.view(image.shape[0], feature), label, args)
                
                Fisher_matrix_perturbed = Fisher_matrix_perturbed - perturbed_fisher_matrix
                Fisher_matrix_unperturbed = Fisher_matrix_unperturbed - clean_fisher_matrix
    
    print()
    print('The {}-th delete'.format(step))
    step = step + 1
    # prepare work
    Fisher_model = copy.deepcopy(model).to(device)
    Fisher_delta_model = copy.deepcopy(model).to(device)
    delete_loader = make_loader(train_data, batch_size=pass_batch, head=(batch_delete_num-delete_batch), rear=batch_delete_num)
    print(f"delete_loader: train_data[{(batch_delete_num-delete_batch)},{batch_delete_num}]")
    
#     retrain_from_scratch(batch_delete_num)
    
    # calculate sum of grad & clean_grad
    start_time = time.time()
    for index, (image, label) in enumerate(delete_loader):
        image = image.to(device)
        label = label.to(device)
        image_perturbed = atk(image, label).to(device)
        clean_grad += parll_loss_grad(weight, image.view(image.shape[0], feature), label, args).detach()
        grad += parll_loss_grad(weight, image_perturbed.view(image_perturbed.shape[0], feature), label, args).detach()
    
    # unlearning datapoints in delete_loader
    for index, (image, label) in enumerate(delete_loader):
        image = image.to(device)
        label = label.to(device)
        image_perturbed = atk(image, label).to(device)
        
        perturbed_fisher_matrix = batch_fisher(weight, image_perturbed.view(image_perturbed.shape[0], feature), label, args)
        clean_fisher_matrix = batch_fisher(weight, image.view(image.shape[0], feature), label, args)

        Fisher_matrix_perturbed = Fisher_matrix_perturbed - perturbed_fisher_matrix
        Fisher_matrix_unperturbed = Fisher_matrix_unperturbed - clean_fisher_matrix
    
    Fisher_delta(Fisher_matrix_perturbed, Fisher_delta_model, grad, test_loader)
    Fisher(Fisher_matrix_unperturbed, Fisher_model, clean_grad, test_loader)
    

1 deleted. label of cur image: -1.

The 1-th delete
delete_loader: train_data[0,1]
Update done !
Fisher-delta unlearning:
clean test acc : 96.90%
perturb test acc : 91.12%
Fisher-delta test acc: clean_acc 0.9690245030050856, preturb_acc: 0.9112343966712899

Update done !
Fisher unlearning
clean test acc : 96.90%
perturb test acc : 91.35%
Fisher test acc: clean_acc 0.9690245030050856, preturb_acc: 0.9135460009246417

2 deleted. label of cur image: 1.

The 2-th delete
delete_loader: train_data[1,2]
Update done !
Fisher-delta unlearning:
clean test acc : 96.90%
perturb test acc : 91.08%
Fisher-delta test acc: clean_acc 0.9690245030050856, preturb_acc: 0.9107720758206195

Update done !
Fisher unlearning
clean test acc : 96.90%
perturb test acc : 91.12%
Fisher test acc: clean_acc 0.9690245030050856, preturb_acc: 0.9112343966712899

3 deleted. label of cur image: -1.

The 3-th delete
delete_loader: train_data[2,3]


  delta_w_dict[k] = torch.tensor(update_params).to(device).view_as(p)


Update done !
Fisher-delta unlearning:
clean test acc : 96.90%
perturb test acc : 90.85%
Fisher-delta test acc: clean_acc 0.9690245030050856, preturb_acc: 0.9084604715672677

Update done !
Fisher unlearning
clean test acc : 96.90%
perturb test acc : 91.22%
Fisher test acc: clean_acc 0.9690245030050856, preturb_acc: 0.9121590383726306

4 deleted. label of cur image: 1.

The 4-th delete
delete_loader: train_data[3,4]
Update done !
Fisher-delta unlearning:
clean test acc : 96.90%
perturb test acc : 90.66%
Fisher-delta test acc: clean_acc 0.9690245030050856, preturb_acc: 0.9066111881645862

Update done !
Fisher unlearning
clean test acc : 96.90%
perturb test acc : 91.22%
Fisher test acc: clean_acc 0.9690245030050856, preturb_acc: 0.9121590383726306

5 deleted. label of cur image: -1.

The 5-th delete
delete_loader: train_data[4,5]
Update done !
Fisher-delta unlearning:
clean test acc : 96.90%
perturb test acc : 90.80%
Fisher-delta test acc: clean_acc 0.9690245030050856, preturb_acc: 0.9079

Update done !
Fisher-delta unlearning:
clean test acc : 96.95%
perturb test acc : 85.90%
Fisher-delta test acc: clean_acc 0.9694868238557559, preturb_acc: 0.8589921405455386

Update done !
Fisher unlearning
clean test acc : 96.95%
perturb test acc : 64.68%
Fisher test acc: clean_acc 0.9694868238557559, preturb_acc: 0.646786870087841

241 deleted. label of cur image: -1.
242 deleted. label of cur image: 1.
243 deleted. label of cur image: -1.
244 deleted. label of cur image: -1.
245 deleted. label of cur image: -1.
246 deleted. label of cur image: -1.
247 deleted. label of cur image: 1.
248 deleted. label of cur image: -1.
249 deleted. label of cur image: 1.
250 deleted. label of cur image: 1.
251 deleted. label of cur image: -1.
252 deleted. label of cur image: 1.
253 deleted. label of cur image: 1.
254 deleted. label of cur image: -1.
255 deleted. label of cur image: 1.
256 deleted. label of cur image: 1.
257 deleted. label of cur image: -1.
258 deleted. label of cur image: 1.
259 del


The 9-th delete
delete_loader: train_data[479,480]
Update done !
Fisher-delta unlearning:
clean test acc : 96.99%
perturb test acc : 82.15%
Fisher-delta test acc: clean_acc 0.9699491447064262, preturb_acc: 0.8215441516412391

Update done !
Fisher unlearning
clean test acc : 96.86%
perturb test acc : 22.33%
Fisher test acc: clean_acc 0.9685621821544151, preturb_acc: 0.22330097087378642

481 deleted. label of cur image: 1.
482 deleted. label of cur image: -1.
483 deleted. label of cur image: -1.
484 deleted. label of cur image: 1.
485 deleted. label of cur image: -1.
486 deleted. label of cur image: 1.
487 deleted. label of cur image: 1.
488 deleted. label of cur image: 1.
489 deleted. label of cur image: -1.
490 deleted. label of cur image: 1.
491 deleted. label of cur image: -1.
492 deleted. label of cur image: 1.
493 deleted. label of cur image: -1.
494 deleted. label of cur image: -1.
495 deleted. label of cur image: -1.
496 deleted. label of cur image: -1.
497 deleted. label of cur