In [1]:
import torch
import random
import time
import copy
import numpy as np

from torchattacks import PGD

from dataloder import *
from argument import *
from model import *
from pretrain import *
from utils import *
from parllutils import *
from modules import *

# Experiment options come from the `argument()` helper (presumably provided by
# the star import of the `argument` module above — confirm).
args = argument()
# Device string passed to every `.to(device)` call in the cells below.
device = 'cuda'

def setup_seed(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (default generator plus all CUDA devices), then switches cuDNN's
    ``deterministic`` flag on.

    Args:
        seed: Integer seed handed unchanged to each generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    for seed_torch in (torch.manual_seed, torch.cuda.manual_seed_all):
        seed_torch(seed)
    torch.backends.cudnn.deterministic = True

# random seed: every generator is seeded with `args.times`
setup_seed(args.times)

Namespace(adv='PGD', batchsize=128, dataset='binaryMnist', deletebatch=1, deletenum=0, isbatch=False, iterneumann=3, lam=0.0001, model='logistic', parllsize=128, remove_type=2, times=0)

In [2]:
# NOTE(review): the saved output of this cell shows (600, 1, 128), so it was
# last executed with delete_num = 600 — confirm which value is intended.
delete_num = 50 # delete 50 points in total
delete_batch = 1 # delete 1 point at a time
pass_batch = args.parllsize # batch size used to compute the total Hessian in parallel
delete_num, delete_batch, pass_batch

(600, 1, 128)

## Pre-processing
### 1) load data

In [3]:
# Load the train/test splits. `re_sequence` is returned alongside them but is
# not read by any cell visible in this notebook.
train_data, test_data, re_sequence = Load_Data(args, delete_num, shuffle=True)
# Wrap both splits in loaders that use the training batch size from `args`.
train_loader = make_loader(train_data, batch_size=args.batchsize)
test_loader = make_loader(test_data, batch_size=args.batchsize)
# Element 0 of each split is what gets counted here (later cells index it as the sample tensor).
print(f"total number of train data: {len(train_data[0])}, test data: {len(test_data[0])}")

train labels: tensor([1, 1, 1,  ..., 1, 7, 1])
total number of train data: 13007, test data: 2163


### 2) load adversarially trained model (original model w*)

In [4]:
# Adversarisal training
model_path = os.path.join('..', 'data', 'ATM', f"dataset_{args.dataset}_adv_{args.adv}_model_{args.model}_points_{len(train_loader.dataset)}_{args.times}.pth")
model, training_time = train(train_loader, test_loader, args, desc='Pre-Adv Training', verbose=True, model_path=model_path)
model, training_time

In [11]:
# Loader over the full training set, batched with `pass_batch` for the Hessian pass.
pass_loader = make_loader(train_data, batch_size=pass_batch)
# Calculate the hessian matrix of partial_dd
matrices = dict(MUter=None)

method = 'MUter'
isDelta = False
ssr = 'unperturbed'
# Key under which the matrix is stored in `matrices`. The original cell used
# `name` without ever defining it, which raises NameError on a fresh kernel.
name = method
# NOTE(review): `method[0]` is the first character of the string ('M'), while the
# saved output shows the full method name in the path — confirm which is intended.
filename = f'dataset_{args.dataset}_adv_{args.adv}_model_{args.model}_method_{method[0]}_sample_{ssr}_{args.times}.pt'
print(name)

start_time = time.time()
matrices[name] = load_memory_matrix(filename, model, pass_loader, method, isDelta, args)
end_time = time.time()
# Elapsed wall-clock seconds (end minus start; the original subtraction was
# reversed and always produced a negative duration).
muter_time1 = end_time - start_time


loading memory matrix : ../data/MemoryMatrix/dataset_binaryMnist_adv_PGD_model_logistic_method_MUter_sample_unperturbed.pt


  0%|          | 0/102 [00:00<?, ?it/s]

matrix shape torch.Size([784, 784]), type <class 'torch.Tensor'>
memory matrix for MUter method using un-perturb samples to calculate
saving matrix...
done!
MUter
loading memory matrix : ../data/MemoryMatrix/dataset_binaryMnist_adv_PGD_model_logistic_method_Newton_sample_perturbed.pt
matrix shape torch.Size([784, 784]), type <class 'torch.Tensor'>
done!
Newton_delta
loading memory matrix : ../data/MemoryMatrix/dataset_binaryMnist_adv_PGD_model_logistic_method_Fisher_sample_perturbed.pt
matrix shape torch.Size([784, 784]), type <class 'torch.Tensor'>
done!
Fisher_delta
loading memory matrix : ../data/MemoryMatrix/dataset_binaryMnist_adv_PGD_model_logistic_method_Influence_sample_perturbed.pt
matrix shape torch.Size([784, 784]), type <class 'torch.Tensor'>
done!
Influence_delta
loading memory matrix : ../data/MemoryMatrix/dataset_binaryMnist_adv_PGD_model_logistic_method_Newton_sample_unperturbed.pt
matrix shape torch.Size([784, 784]), type <class 'torch.Tensor'>
done!
Newton
loading mem

## Stage II: Unlearning
1) Inner level attack method;
2) Calculate the public part partial_xx and partial_xx_inv for linear model;
3) Init gradient information;

In [None]:
from torchattacks import PGD
import copy
from utils import cg_solve, model_distance, hessian, update_w, derive_inv
import time
from torch.utils.data import DataLoader, TensorDataset

# Inner level attack method
# `training_param(args)` yields three values; only the third (attack settings) is used here.
_, _, atk_info = training_param(args)
# NOTE(review): `lossfun` and `lam` are not parameters of the upstream torchattacks
# PGD, so this presumably relies on a locally modified PGD — confirm.
atk = PGD(model, atk_info[0], atk_info[1], atk_info[2], lossfun=LossFunction(args.model), lam=args.lam)

# Calculate the public part partial_xx and partial_xx_inv for linear model
# (stored below under the names public_partial_dd / public_partial_dd_inv).
feature = get_featrue(args)
# Model parameters as a single detached tensor; assumed to be a (feature, 1)
# column vector — TODO confirm against vec_param.
weight = vec_param(model.parameters()).detach()
# weight @ weight^T, detached from autograd.
public_partial_dd = (weight.mm(weight.t())).detach()
# Inverse obtained via derive_inv's 'Neumann' method with `args.iterneumann` iterations.
public_partial_dd_inv = derive_inv(public_partial_dd, method='Neumann', iter=args.iterneumann)

In [None]:
step = 1  # counts how many unlearning rounds have been performed

# Removal checkpoints used for comparison: the first five single removals,
# followed by roughly 1%, 2%, 3%, 4% and 5% of the dataset's training set.
# Any dataset not listed falls back to the splice schedule.
remove_list = {
    'binaryMnist': [1, 2, 3, 4, 5, 120, 240, 360, 480, 600],
    'phishing': [1, 2, 3, 4, 5, 100, 200, 300, 400, 500],
    'madelon': [1, 2, 3, 4, 5, 20, 40, 60, 80, 100],
    'covtype': [1, 2, 3, 4, 5, 5000, 10000, 15000, 20000, 25000],
    'epsilon': [1, 2, 3, 4, 5, 4000, 8000, 12000, 16000, 20000],
}.get(args.dataset, [1, 2, 3, 4, 5, 10, 20, 30, 40, 50])


In [None]:
# Init gradinet informations
grad = torch.zeros((feature, 1)).to(device)
clean_grad = torch.zeros((feature, 1)).to(device)
parll_partial = batch_indirect_hessian(args)

# 从1开始删600个点
for batch_delete_num in range(1, delete_num+1, 1):
    print('The {}-th delete'.format(batch_delete_num))
    # prepare work
    unlearning_model = copy.deepcopy(model).to(device)  # for MUter method
    x = train_data[0][batch_delete_num].to(device)
    y = train_data[1][batch_delete_num].to(device)
    x_delta = atk(x, y).to(device)

    update_grad(grad, clean_grad, weight, x, x_delta, y, feature, args)
    
    H_11, H_12, H_21, neg_H_22 = None, None, None, None
    
    Dww = partial_ww - (partial_wx.mm(partial_xx_inv.mm(partial_xw)))
    
    MUter(batch_delete_num, delete_loader, matrix, grad, unlearning_model)
    


# Unlearning methods.
## MUter stage 2

In [None]:
def stage2(H_11, H_12, H_21, neg_H_22):
    """Build the partial_ww / partial_wx / partial_xx_inv / partial_xw / partial_xx matrices.

    The matrices are derived from ``z = sigmoid(y * x^T weight)`` and
    ``D = z * (1 - z)``; they appear to be the second-derivative blocks of a
    regularised logistic loss.

    NOTE(review): the four parameters are accepted but never read. The body
    instead reads the notebook globals ``x``, ``y``, ``weight``,
    ``weight_size``, ``x_size``, ``public_partial_xx_inv``,
    ``public_partial_xx``, ``args`` and ``device``; several of these are not
    defined in the visible cells — confirm before calling.

    Returns:
        tuple: ``(partial_ww, partial_wx, partial_xx_inv, partial_xw,
        partial_xx)``, each detached from autograd.
    """
    z = torch.sigmoid(y*(x.t().mm(weight)))
    D = z * (1 - z)
    partial_ww = (D * (x.mm(x.t()))) + (args.lam * torch.eye(weight_size)).to(device)
    partial_wx = (D * (x.mm(weight.t()))) + ((z-1) * y * torch.eye(x_size).to(device))
    # TODO(review): the original author left a note to verify whether this
    # scaling should be (1/D) or D.
    partial_xx_inv = (1/D) * public_partial_xx_inv
    partial_xw = (D * (weight.mm(x.t()))) + ((z-1) * y * torch.eye(weight_size).to(device))

    # Bind the device copy to a fresh local name: assigning to
    # `public_partial_xx` here would make that name local to the function, and
    # reading it on the right-hand side would raise UnboundLocalError.
    public_xx_on_device = public_partial_xx.to(device)
    partial_xx = D * public_xx_on_device

    return partial_ww.detach(), partial_wx.detach(), partial_xx_inv.detach(), partial_xw.detach(), partial_xx.detach()
    

In [None]:
def logistic_partial_hessian(x, y, weight, public_partial_xx_inv):
    """Per-sample indirect Hessian term for the logistic loss.

    Computes ``partial_wx @ partial_xx_inv @ partial_xw`` for one sample. It is
    written for a single sample so that it can be vectorised over a batch
    (e.g. with a vmap-style transform).

    Args:
        x: Sample as a column tensor; ``x.t().mm(weight)`` must be valid, so it
            is expected to have the same shape as ``weight``.
        y: Label that multiplies the logit.
        weight: Parameter column tensor; ``weight.shape[0]`` sets the size of
            the identity matrices.
        public_partial_xx_inv: Sample-independent inverse factor, rescaled
            here by ``1 / D``.

    Returns:
        Tensor: the product ``partial_wx @ partial_xx_inv @ partial_xw``.

    Note:
        ``D = z * (1 - z)`` tends to zero when the sigmoid saturates, in which
        case ``1 / D`` grows without bound; callers are assumed to avoid this.
    """
    size = weight.shape[0]
    # Create the identity on the same device as `weight` rather than on a
    # hard-coded 'cuda', so the function also works for CPU tensors.
    identity = torch.eye(size, device=weight.device)

    z = torch.sigmoid(y * (x.t().mm(weight)))
    D = z * (1 - z)
    partial_wx = (D * (x.mm(weight.t()))) + ((z - 1) * y * identity)
    partial_xx_inv = (1 / D) * public_partial_xx_inv
    partial_xw = (D * (weight.mm(x.t()))) + ((z - 1) * y * identity)

    return partial_wx.mm(partial_xx_inv.mm(partial_xw))

In [None]:
def MUter(batch_delete_num, delete_loader, matrix, grad, unlearning_model):
    """Apply one MUter unlearning update to `unlearning_model` and print its evaluation.

    For every batch in `delete_loader` the samples are adversarially perturbed
    with the global `atk`, the gradient of the perturbed samples is added to
    `grad`, and the second-order terms are either collected (single-point
    mode, `args.isbatch == False`) or subtracted from `matrix` (mini-batch
    mode). A conjugate-gradient solve then yields the weight update that is
    applied to `unlearning_model`.

    Args:
        batch_delete_num: Not read within the visible body.
        delete_loader: Iterable of (image, label) batches to forget.
        matrix: Memory matrix used in the linear solve.
        grad: Running gradient tensor that the per-batch gradients are added to.
        unlearning_model: Model updated in place through `update_w`.

    NOTE(review): `matrix` and `grad` are rebound locally (`matrix = matrix - ...`,
    `grad = grad + ...`), so the caller's objects are not changed by the visible
    code — confirm this is intended.
    NOTE(review): the body reads notebook globals (`atk`, `args`, `feature`,
    `weight`, `public_partial_dd`, `public_partial_dd_inv`, `parll_partial`,
    `test_loader`, `retrain_model`, `device`); `retrain_model` is not defined in
    any visible cell.
    """
    # NOTE(review): `unlearning_time` and `start_time` are assigned but not read in the visible lines.
    unlearning_time = 0.0 # record one batch spending time for MUter
    # building matrix
    Dww = None
    H_11 = None
    H_12 = None
    H_21 = None
    neg_H_22 = None
    
    # Unlearning
    start_time = time.time()
    for index, (image, label) in enumerate(delete_loader):
        image = image.to(device)
        label = label.to(device)
        # Perturb the samples to be forgotten with the inner-level attack.
        image_perturbed = atk(image, label).to(device)

        if args.isbatch ==  False: # single-point deletion
            # Only the values from the last batch of the loader survive this loop.
            Dww, H_11, H_12, _, H_21, neg_H_22 = partial_hessian(image_perturbed.view(feature, 1), label, 
                                                                 weight, public_partial_dd_inv, 
                                                                 args, isUn_inv=True, 
                                                                 public_partial_xx=public_partial_dd)    
        else: # for mini-batch # multi-point deletion
            # Subtract the batch's direct Hessian minus its summed indirect term.
            matrix = matrix - \
                (batch_hessian(weight, image_perturbed.view(image_perturbed.shape[0], feature), label, args) - \
                 parll_partial(image_perturbed.view(image_perturbed.shape[0], feature, 1), label, \
                               weight, public_partial_dd_inv).sum(dim=0).detach())

        # Add the gradient of the perturbed batch to the running gradient.
        grad = grad + \
            parll_loss_grad(weight, \
                            image_perturbed.view(image_perturbed.shape[0], feature), \
                            label, args)

    if args.isbatch == False:
        # Assemble the block system from the memory matrix and the per-sample blocks.
        block_matrix = buliding_matrix(matrix, H_11, H_12, -neg_H_22, H_21)
        print('block_matrix shape {}'.format(block_matrix.shape))
        # Pad the right-hand side with `feature` zeros below the gradient.
        grad_cat_zero = torch.cat([grad, torch.zeros((feature, 1)).to(device)], dim=0)
        print('grad_cat_zeor shape {}'.format(grad_cat_zero.shape))

        delta_w_cat_alpha = cg_solve(block_matrix, grad_cat_zero.squeeze(dim=1), get_iters(args))
        # Only the first `feature` entries of the solution are the weight update.
        delta_w = delta_w_cat_alpha[:feature]

        update_w(delta_w, unlearning_model)
        matrix = matrix - Dww
    else:
        delta_w = cg_solve(matrix, grad.squeeze(dim=1), get_iters(args))
        update_w(delta_w, unlearning_model)
    
    # Evaluate the unlearned model and compare it with `retrain_model`.
    clean_acc, perturb_acc = Test_model(unlearning_model, test_loader, args) 
    model_dist = model_distance(retrain_model, unlearning_model)
    print()
    print('MUter unlearning:')
    print(f'unlearning model test acc: clean_acc {clean_acc}, preturb_acc: {perturb_acc}')
    print('model distance between Muter and retrain_from_scratch: {:.4f}'.format(model_dist))
    
    print()