In [1]:
import torch
from dataloder import *
from argument import *
from model import *
from pretrain import *
from utils import *
from parllutils import *
from functorch import vmap
# Build the experiment configuration; argument() comes from one of the
# star-imports above (presumably the `argument` module — TODO confirm).
args = argument()
# All tensors created in this notebook are moved to this device, so a CUDA GPU is required.
device = 'cuda'
# Echo the parsed configuration as the cell output.
args

Namespace(adv='PGD', batchsize=128, dataset='binaryMnist', deletebatch=1, deletenum=0, isbatch=False, iterneumann=3, lam=0.0001, model='logistic', parllsize=128, remove_type=2, times=0)

In [2]:
def setup_seed(seed):
    """Seed every random number generator used by the experiment.

    Args:
        seed: integer seed applied to Python's `random`, NumPy, and PyTorch
            (CPU generator and all CUDA devices).

    Side effects:
        Switches cuDNN to deterministic kernels and disables its autotuner.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Deterministic kernels alone are not enough: with benchmark mode on, cuDNN
    # may still pick a different algorithm from run to run.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Seed all RNGs with args.times so that repeated runs with the same value match.
setup_seed(args.times)

delete_num = 600 # remove 600 points in total
delete_batch = 1 # remove 1 point per step
pass_batch = args.parllsize # batch size used when computing the total Hessian in parallel
delete_num, delete_batch, pass_batch

(600, 1, 128)

## Pre-processing
### 1) load data

In [3]:
# Load the train/test splits. train_data / test_data are indexable: element 0 is
# used for the sample count here and element 1 holds the labels (see the next cell).
# re_sequence is the third value returned by Load_Data
# (presumably the shuffled sample order — TODO confirm).
train_data, test_data, re_sequence = Load_Data(args, delete_num, shuffle=True)
# Mini-batch loaders used for adversarial training and evaluation.
train_loader = make_loader(train_data, batch_size=args.batchsize)
test_loader = make_loader(test_data, batch_size=args.batchsize)
print(f"total number of train data: {len(train_data[0])}, test data: {len(test_data[0])}")

train labels: tensor([1, 1, 1,  ..., 1, 7, 1])
total number of train data: 13007, test data: 2163


In [4]:
# Sanity check: the number of labels must match the number of samples exposed by each loader.
print(len(train_data[1]), len(test_data[1]))
print(len(train_loader.dataset), len(test_loader.dataset))

13007 2163
13007 2163


### 2) load adversarially trained model (original model w*)

In [5]:
# Adversarial training of the original model w*.
# The checkpoint path encodes dataset, attack type, model, training-set size and args.times.
model_path = os.path.join('..', 'data', 'ATM', f"dataset_{args.dataset}_adv_{args.adv}_model_{args.model}_points_{len(train_loader.dataset)}_{args.times}.pth")
# train() returns the model together with the elapsed training time in seconds
# (how model_path is used — save vs. load — is decided inside train(); not visible here).
model, training_time = train(train_loader, test_loader, args, desc='Pre-Adv Training', verbose=True, model_path=model_path)
model, training_time

model information: 
LogisticModel(
  (fc): Linear(in_features=784, out_features=1, bias=False)
)
training type: PGD, epsilon: 0.25098, alpha: 0.03137, steps: 15
training hyperparameters  lr: 0.010, epochs: 100 


Pre-Adv Training: 100%|██████████| 100/100 [01:42<00:00,  1.03s/it, adv_train_type=PGD, loss=0.354, lr=0.01, model=logistic, times=0]

traning PGD model spending 102.70 seconds





(LogisticModel(
   (fc): Linear(in_features=784, out_features=1, bias=False)
 ),
 102.70159816741943)

In [6]:
# Check consistency across all trials that share the same argument `args.times`:
# print the labels of the first test batch so they can be compared between runs.
print(next(iter(test_loader))[1].reshape(-1))

tensor([1, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1,
        1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0,
        0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0,
        0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1,
        0, 0, 0, 1, 1, 1, 1, 1])


### 3) Pre-unlearning: calculate the Hessian matrix of partial_dd

In [7]:
# Loader that sweeps the whole training set in chunks of `pass_batch`;
# it is handed to load_memory_matrix for building a memory matrix.
pass_loader = make_loader(train_data, batch_size=pass_batch)

# Method name -> memory matrix. A "<method>_delta" key selects the variant built
# from perturbed samples. Only MUter is enabled; the other baselines stay listed
# for reference.
matrices = {
    'MUter': None,
    # 'Newton_delta': None,
    # 'Fisher_delta': None,
    # 'Influence_delta': None,
    # 'Newton': None,
    # 'Fisher': None,
    # 'Influence': None,
}

# NOTE(review): the recorded output shows the matrix being loaded from a path
# ending in ".pt.pt", so load_memory_matrix presumably appends its own
# extension — confirm before changing the file name below.
for name in matrices:
    method = name.split('_')
    isDelta = len(method) > 1
    if isDelta:
        ssr = 'perturbed'
    else:
        ssr = 'unperturbed'
    filename = f'dataset_{args.dataset}_adv_{args.adv}_model_{args.model}_method_{method[0]}_sample_{ssr}_{args.times}.pt'
    matrices[name] = load_memory_matrix(filename, model, pass_loader, method[0], isDelta, args)

# Show every matrix that was loaded.
for name, matrix in matrices.items():
    print(name)
    print(matrix)

loading memory matrix of MUter method from: ../data/MemoryMatrix/dataset_binaryMnist_adv_PGD_model_logistic_method_MUter_sample_unperturbed_0.pt.pt
tensor([[-181.3839,   -6.7562,   -7.1364,  ...,   -7.0635,   -7.2264,
           -7.0113],
        [  -6.7562, -181.4017,   -7.3416,  ...,   -7.0091,   -7.3634,
           -7.0389],
        [  -7.1364,   -7.3416, -182.2356,  ...,   -7.1665,   -7.6303,
           -7.6300],
        ...,
        [  -7.0635,   -7.0091,   -7.1665,  ..., -181.1368,   -7.2876,
           -7.0035],
        [  -7.2264,   -7.3634,   -7.6303,  ...,   -7.2876, -182.3219,
           -7.5114],
        [  -7.0113,   -7.0389,   -7.6300,  ...,   -7.0035,   -7.5114,
         -181.9788]], device='cuda:0')
matrix shape torch.Size([784, 784]), type <class 'torch.Tensor'>
MUter
tensor([[-181.3839,   -6.7562,   -7.1364,  ...,   -7.0635,   -7.2264,
           -7.0113],
        [  -6.7562, -181.4017,   -7.3416,  ...,   -7.0091,   -7.3634,
           -7.0389],
        [  -7.1364,   

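## Stage II: Unlearning
1) Set up the inner-level attack method;
2) Calculate the public part partial_xx and its inverse partial_xx_inv for the linear model (`public_partial_dd` and `public_partial_dd_inv` in the code below);
3) Initialize the gradient information.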

In [8]:
from torchattacks import PGD
import copy
from utils import derive_inv
import time
from torch.utils.data import DataLoader, TensorDataset

# Inner level attack method
# Only the third value returned by training_param is needed: the attack settings.
# NOTE(review): the training log above prints epsilon, alpha and steps —
# presumably atk_info holds them in that order; confirm against training_param.
_, _, atk_info = training_param(args)
print(model)
# NOTE(review): `lossfun` and `lam` are passed as keyword arguments, so this relies on
# the PGD class of the installed torchattacks accepting them — confirm the version used.
atk = PGD(model, atk_info[0], atk_info[1], atk_info[2], lossfun=LossFunction(args.model), lam=args.lam)

# Calculate the public part partial_xx and partial_xx_inv for linear model
feature = get_featrue(args)
# Flattened, detached copy of the trained parameters
# (assumes vec_param returns a (feature, 1) column vector — TODO confirm).
weight = vec_param(model.parameters()).detach()
# w w^T depends only on the weights, not on any sample, so it is computed once here
# and reused for every removed point.
public_partial_dd = (weight.mm(weight.t())).detach()
# Inverse obtained from derive_inv with method='Neumann', using args.iterneumann as `iter`.
public_partial_dd_inv = derive_inv(public_partial_dd, method='Neumann', iter=args.iterneumann)

LogisticModel(
  (fc): Linear(in_features=784, out_features=1, bias=False)
)


In [9]:
step = 1  # counts how many unlearning rounds have been carried out

# Cumulative removal counts at which unlearning is compared with retraining:
# [1, 2, 3, 4, 5, ~1%, ~2%, ~3%, ~4%, ~5%], keyed by dataset name.
_REMOVE_SCHEDULES = {
    'binaryMnist': [1, 2, 3, 4, 5, 120, 240, 360, 480, 600],
    'phishing': [1, 2, 3, 4, 5, 100, 200, 300, 400, 500],
    'madelon': [1, 2, 3, 4, 5, 20, 40, 60, 80, 100],
    'covtype': [1, 2, 3, 4, 5, 5000, 10000, 15000, 20000, 25000],
    'epsilon': [1, 2, 3, 4, 5, 4000, 8000, 12000, 16000, 20000],
}

# Any dataset not listed above falls back to the schedule used for splice.
remove_list = _REMOVE_SCHEDULES.get(args.dataset, [1, 2, 3, 4, 5, 10, 20, 30, 40, 50])


## Unlearning process

In [None]:
from modules import *
# Unlearning driver: walk through the removal schedule, and at every checkpoint in
# `remove_list` compare MUter unlearning against retraining from scratch.

# Init gradient information.
# `step` is reset here so that re-running this cell starts counting from 1 again
# instead of continuing the count left over from a previous run.
step = 1
grad = torch.zeros((feature, 1)).to(device)
clean_grad = torch.zeros((feature, 1)).to(device)
saver = None

for batch_delete_num in range(delete_batch, delete_num+1, delete_batch):
    # The points removed at this step are train_data[batch_delete_num-delete_batch : batch_delete_num],
    # so the sample being deleted sits at index batch_delete_num-1 (index batch_delete_num
    # would be the next, still-present sample).
    print(f"{batch_delete_num} deleted. label of cur image: {train_data[1][batch_delete_num-1].item()}.")

    # Points that are not checkpoints are not "really" deleted one by one;
    # they are folded in when the next checkpoint is reached (see below).
    if batch_delete_num not in remove_list:
        continue

    # Every schedule starts with the consecutive checkpoints 1..5, so there is a gap
    # of skipped points only for the later checkpoints.
    if batch_delete_num > 5:
        checkpoint_pos = remove_list.index(batch_delete_num)
        gap_head = remove_list[checkpoint_pos - 1]  # first point after the previous checkpoint
        gap_rear = batch_delete_num - 1             # stop before the point deleted in this round

        # For the points between two checkpoints, only accumulate their gradients and
        # subtract their influence from the memory matrix; no model update is performed.
        temp_loader = make_loader(train_data, batch_size=pass_batch, head=gap_head, rear=gap_rear)
        print(f"temp_loader: train_data[{gap_head},{gap_rear}]")
        for image, label in temp_loader:
            image = image.to(device)
            label = label.to(device)
            # x + delta
            image_perturbed = atk(image, label).to(device)

            update_grad(grad, clean_grad, weight, image, image_perturbed, label, feature, args)
            update_matrix(matrices, weight, image, image_perturbed, label, feature, public_partial_dd_inv, args, flag='muter')

            # Earlier inline version of the two helper calls above, kept for reference:
#                 # for perturbed grad
#                 # :: aggregate the adversarial gradients
#                 grad = grad + \
#                     parll_loss_grad(weight, \
#                                     image_perturbed.view(image_perturbed.shape[0], feature), \
#                                     label, args)

#                 # for MUter matrix
#                 # :: delete the batch_hessian (data influence)
#                 matrix = matrix - \
#                     (batch_hessian(weight, image_perturbed.view(image_perturbed.shape[0], feature), label, args) - \
#                      parll_partial(image_perturbed.view(image_perturbed.shape[0], feature, 1), label, \
#                                    weight, public_partial_dd_inv).sum(dim=0).detach())

    print()
    print('The {}-th delete'.format(step))
    step = step + 1
    # prepare work
    # NOTE(review): unlearning_model is not passed to unlearn_muter below (it receives
    # `model`); confirm whether this copy is still needed.
    unlearning_model = copy.deepcopy(model).to(device)  # for MUter method
    # if delete_batch = 20
    # batch_delete_num = 20, delete_loader = train_data[0:20]
    # batch_delete_num = 40, delete_loader = train_data[20:40]
    delete_loader = make_loader(train_data, batch_size=pass_batch, head=(batch_delete_num-delete_batch), rear=batch_delete_num)
    print(f"delete_loader: train_data[{(batch_delete_num-delete_batch)},{batch_delete_num}]")

    ## retrain_from_scratch: baseline trained on everything after the removed prefix
    retrain_loader = make_loader(train_data, batch_size=128, head=batch_delete_num)
    retrain_model = retrain_from_scratch(retrain_loader, test_loader, args, saver)

    # calculate the aggregated grad & clean_grad for the points deleted in this round
    for image, label in delete_loader:
        image = image.to(device)
        label = label.to(device)
        image_perturbed = atk(image, label).to(device)

        update_grad(grad, clean_grad, weight, image, image_perturbed, label, feature, args)

    # unlearning stage
    ## MUter
    Dww, H_11, H_12, H_21, neg_H_22 = None, None, None, None, None
    start_time = time.time()  # the timed section includes the attack below
    for image, label in delete_loader:
        image = image.to(device)
        label = label.to(device)
        # NOTE(review): the attack is run again here; if it is randomized, these perturbed
        # images may differ from the ones used for the gradient above — confirm.
        image_perturbed = atk(image, label).to(device)

        if not args.isbatch:
            # single-point removal: the view(feature, 1) requires exactly one sample per batch
            Dww, H_11, H_12, _, H_21, neg_H_22 = partial_hessian(image_perturbed.view(feature, 1), label, weight, public_partial_dd_inv, args, isUn_inv=True, public_partial_xx=public_partial_dd)
        else: # for mini-batch
            update_matrix(matrices, weight, image, image_perturbed, label, feature, public_partial_dd_inv, args, flag='muter')

    unlearn_muter(matrices['MUter'], model, grad, Dww, H_11, H_12, H_21, neg_H_22, feature, device, start_time, retrain_model, test_loader, args, saver)
    

1 deleted. label of cur image: 0.

The 6-th delete
delete_loader: train_data[0,1]
clean test acc : 96.90%
perturb test acc : 91.22%

Retrain from scratch:
model test acc: clean_acc 0.9690, preturb_acc: 0.9122
retrain time: 94.8896

black matrix A shape torch.Size([1568, 1568]), type <class 'torch.Tensor'>
Update done !
clean test acc : 96.90%
perturb test acc : 91.22%

MUter unlearning:
model test acc: clean_acc 0.9690, preturb_acc: 0.9122
model norm distance: 0.0786
unlearning time: 0.0153

2 deleted. label of cur image: 0.

The 7-th delete
delete_loader: train_data[1,2]
clean test acc : 96.90%
perturb test acc : 91.17%

Retrain from scratch:
model test acc: clean_acc 0.9690, preturb_acc: 0.9117
retrain time: 100.5480

black matrix A shape torch.Size([1568, 1568]), type <class 'torch.Tensor'>
Update done !
clean test acc : 96.90%
perturb test acc : 91.26%

MUter unlearning:
model test acc: clean_acc 0.9690, preturb_acc: 0.9126
model norm distance: 0.0734
unlearning time: 0.0124

3 del

Re-Adv Training:  23%|██▎       | 23/100 [00:22<01:13,  1.05it/s, adv_train_type=PGD, loss=0.395, lr=0.01, model=logistic, times=0]

## the golden baseline: retrain_from_scratch

In [None]:
# def retrain_from_scratch(batch_delete_num):
#     # retrain_from_scratch
#     retrain_loader = make_loader(train_data, batch_size=128, head=batch_delete_num)
#     retrain_model, retrain_time = train(retrain_loader, test_loader, args, verbose=False)    
#     clean_acc, perturb_acc = Test_model(retrain_model, test_loader, args)
#     print()
#     print('Retrain from scratch:')
#     print(f"retrain_loader: train_data[{batch_delete_num}:]")
#     print(f'retrain model test acc: clean_acc {clean_acc}, preturb_acc: {perturb_acc}')
    

# Unlearning methods.
## 0. MUter

In [None]:
# def MUter(batch_delete_num, delete_loader, matrix, grad, unlearning_model):
#     unlearning_time = 0.0 # record one batch spending time for MUter
#     # building matrix
#     Dww = None
#     H_11 = None
#     H_12 = None
#     H_21 = None
#     neg_H_22 = None
    
#     # Unlearning
#     start_time = time.time()
#     for index, (image, label) in enumerate(delete_loader):
#         image = image.to(device)
#         label = label.to(device)
#         image_perturbed = atk(image, label).to(device)

#         if args.isbatch ==  False: # 删除单点：
#             Dww, H_11, H_12, _, H_21, neg_H_22 = partial_hessian(image_perturbed.view(feature, 1), label, 
#                                                                  weight, public_partial_dd_inv, 
#                                                                  args, isUn_inv=True, 
#                                                                  public_partial_xx=public_partial_dd)    
#         else: # for mini-batch # 删除多点
#             matrix = matrix - \
#                 (batch_hessian(weight, image_perturbed.view(image_perturbed.shape[0], feature), label, args) - \
#                  parll_partial(image_perturbed.view(image_perturbed.shape[0], feature, 1), label, \
#                                weight, public_partial_dd_inv).sum(dim=0).detach())

#         grad = grad + \
#             parll_loss_grad(weight, \
#                             image_perturbed.view(image_perturbed.shape[0], feature), \
#                             label, args)

#     if args.isbatch == False:
#         block_matrix = buliding_matrix(matrix, H_11, H_12, -neg_H_22, H_21)
#         print('block_matrix shape {}'.format(block_matrix.shape))
#         grad_cat_zero = torch.cat([grad, torch.zeros((feature, 1)).to(device)], dim=0)
#         print('grad_cat_zeor shape {}'.format(grad_cat_zero.shape))

#         delta_w_cat_alpha = cg_solve(block_matrix, grad_cat_zero.squeeze(dim=1), get_iters(args))
#         delta_w = delta_w_cat_alpha[:feature]

#         update_w(delta_w, unlearning_model)
#         matrix = matrix - Dww
#     else:
#         delta_w = cg_solve(matrix, grad.squeeze(dim=1), get_iters(args))
#         update_w(delta_w, unlearning_model)
    
#     clean_acc, perturb_acc = Test_model(unlearning_model, test_loader, args) 
#     model_dist = model_distance(retrain_model, unlearning_model)
#     print()
#     print('MUter unlearning:')
#     print(f'unlearning model test acc: clean_acc {clean_acc}, preturb_acc: {perturb_acc}')
#     print('model distance between Muter and retrain_from_scratch: {:.4f}'.format(model_dist))
    
#     print()