In [1]:
import torch
import numpy as np
import torch.nn as nn
import math
from copy import deepcopy 

from src.synthetic import GaussianDataset
from src.data import get_retain_forget_datasets, get_dataloaders
from src.train import train, L2RegularizedCrossEntropyLoss
from src.eval import evaluate
from src.forget import forget, calculate_hessian, calculate_grad, calculate_update, update_model
from src.utils import get_module_device
from tqdm import tqdm

In [2]:
num_classes = 10
num_samples = 15000
dim = 100
mean = np.zeros(dim)
cov = np.eye(dim)
## surr cov
surr_cov = np.eye(dim)
surr_cov[0, -1] = 0.5
surr_cov[-1, 0] = 0.5
##
dataset = GaussianDataset(num_samples, num_classes, mean, cov)
surr_dataset = dataset.create_surr(mean, surr_cov)
train_dataset, test_dataset = get_retain_forget_datasets(dataset, 0.2)

kl_distance = surr_dataset.calculate_kl_between(dataset)
print('given kl distance:{}'.format(kl_distance))

eps = 5 * (math.e ** 3)
delta = 1

given kl distance:0.14384103622589045


In [3]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
criterion = L2RegularizedCrossEntropyLoss(l2_lambda=0.1)

In [4]:
def print_eval(model_arg):
    print('#######################################')
    print('train:')
    evaluate(train_loader, model_arg, criterion, device=device)
    print('#######################################')
    print('#######################################')
    print('val:')
    evaluate(val_loader, model_arg, criterion, device=device)
    print('#######################################')
    print('#######################################')
    print('retain:')
    evaluate(retain_loader, model_arg, criterion, device=device)
    print('#######################################')
    print('#######################################')
    print('forget:')
    evaluate(forget_loader, model_arg, criterion, device=device)
    print('#######################################')
    print('#######################################')
    print('surrogate:')
    evaluate(surr_loader, model_arg, criterion, device=device)
    print('#######################################')

In [5]:
retain_dataset, forget_dataset = get_retain_forget_datasets(train_dataset, 0.1)
train_loader, val_loader = get_dataloaders([train_dataset, test_dataset], batch_size=256)
retain_loader = get_dataloaders(retain_dataset, batch_size=256)
forget_loader = get_dataloaders(forget_dataset, batch_size=256)
surr_loader = get_dataloaders(surr_dataset, batch_size=256)
model = nn.Linear(dim, num_classes, bias=False).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

train(train_loader, val_loader, model, criterion, optimizer, num_epoch=300, device=device)
print_eval(model)
model = model.to('cpu')

train epoch 1: 100%|██████████| 47/47 [00:00<00:00, 146.06batch/s, loss=2.41]
eval: 100%|██████████| 12/12 [00:00<00:00, 216.20batch/s, acc=0.114, loss=2.48]
train epoch 2: 100%|██████████| 47/47 [00:00<00:00, 381.36batch/s, loss=2.27]
eval: 100%|██████████| 12/12 [00:00<00:00, 490.43batch/s, acc=0.182, loss=2.27]
train epoch 3: 100%|██████████| 47/47 [00:00<00:00, 391.39batch/s, loss=2.19]
eval: 100%|██████████| 12/12 [00:00<00:00, 554.08batch/s, acc=0.262, loss=2.2]
train epoch 4: 100%|██████████| 47/47 [00:00<00:00, 390.45batch/s, loss=2.15]
eval: 100%|██████████| 12/12 [00:00<00:00, 486.65batch/s, acc=0.325, loss=2.15]
train epoch 5: 100%|██████████| 47/47 [00:00<00:00, 177.25batch/s, loss=2.14]
eval: 100%|██████████| 12/12 [00:00<00:00, 305.72batch/s, acc=0.361, loss=2.12]
train epoch 6: 100%|██████████| 47/47 [00:00<00:00, 374.39batch/s, loss=2.08]
eval: 100%|██████████| 12/12 [00:00<00:00, 523.51batch/s, acc=0.399, loss=2.12]
train epoch 7: 100%|██████████| 47/47 [00:00<00:00, 3

#######################################
train:


eval: 100%|██████████| 47/47 [00:00<00:00, 389.76batch/s, acc=0.644, loss=2.05]


#######################################
#######################################
val:


eval: 100%|██████████| 12/12 [00:00<00:00, 372.98batch/s, acc=0.559, loss=2.09]


#######################################
#######################################
retain:


eval: 100%|██████████| 43/43 [00:00<00:00, 450.96batch/s, acc=0.643, loss=2.01]


#######################################
#######################################
forget:


eval: 100%|██████████| 5/5 [00:00<00:00, 479.40batch/s, acc=0.654, loss=2.01]


#######################################
#######################################
surrogate:


eval: 100%|██████████| 59/59 [00:00<00:00, 457.04batch/s, acc=0.589, loss=2.04]

#######################################





In [6]:
# retrain from scratch
rmodel = nn.Linear(dim, num_classes, bias=False).to(device)
optimizer = torch.optim.Adam(rmodel.parameters(), lr=0.001)
train(retain_loader, val_loader, rmodel, criterion, optimizer, num_epoch=300, device=device)

print_eval(rmodel)
rmodel = rmodel.to('cpu')

train epoch 1: 100%|██████████| 43/43 [00:00<00:00, 246.61batch/s, loss=2.42]
eval: 100%|██████████| 12/12 [00:00<00:00, 356.87batch/s, acc=0.153, loss=2.44]
train epoch 2: 100%|██████████| 43/43 [00:00<00:00, 408.18batch/s, loss=2.26]
eval: 100%|██████████| 12/12 [00:00<00:00, 440.62batch/s, acc=0.23, loss=2.25]
train epoch 3: 100%|██████████| 43/43 [00:00<00:00, 394.00batch/s, loss=2.15]
eval: 100%|██████████| 12/12 [00:00<00:00, 451.46batch/s, acc=0.309, loss=2.17]
train epoch 4: 100%|██████████| 43/43 [00:00<00:00, 382.88batch/s, loss=2.15]
eval: 100%|██████████| 12/12 [00:00<00:00, 493.24batch/s, acc=0.368, loss=2.13]
train epoch 5: 100%|██████████| 43/43 [00:00<00:00, 385.68batch/s, loss=2.11]
eval: 100%|██████████| 12/12 [00:00<00:00, 471.18batch/s, acc=0.414, loss=2.11]
train epoch 6: 100%|██████████| 43/43 [00:00<00:00, 371.14batch/s, loss=2.1] 
eval: 100%|██████████| 12/12 [00:00<00:00, 463.75batch/s, acc=0.435, loss=2.11]
train epoch 7: 100%|██████████| 43/43 [00:00<00:00, 3

#######################################
train:


eval: 100%|██████████| 47/47 [00:00<00:00, 360.95batch/s, acc=0.635, loss=2.03]


#######################################
#######################################
val:


eval: 100%|██████████| 12/12 [00:00<00:00, 430.46batch/s, acc=0.552, loss=2.06]


#######################################
#######################################
retain:


eval: 100%|██████████| 43/43 [00:00<00:00, 432.14batch/s, acc=0.642, loss=2.02]


#######################################
#######################################
forget:


eval: 100%|██████████| 5/5 [00:00<00:00, 417.96batch/s, acc=0.572, loss=2.07]


#######################################
#######################################
surrogate:


eval: 100%|██████████| 59/59 [00:00<00:00, 449.80batch/s, acc=0.593, loss=2.05]

#######################################





In [7]:
# retrain from scratch
surrmodel = nn.Linear(dim, num_classes, bias=False).to(device)
optimizer = torch.optim.Adam(surrmodel.parameters(), lr=0.001)
train(surr_loader, val_loader, surrmodel, criterion, optimizer, num_epoch=300, device=device)

print_eval(surrmodel)
surrmodel = surrmodel.to('cpu')

train epoch 1: 100%|██████████| 59/59 [00:00<00:00, 373.33batch/s, loss=2.37]
eval: 100%|██████████| 12/12 [00:00<00:00, 475.99batch/s, acc=0.151, loss=2.35]
train epoch 2: 100%|██████████| 59/59 [00:00<00:00, 390.13batch/s, loss=2.19]
eval: 100%|██████████| 12/12 [00:00<00:00, 541.81batch/s, acc=0.245, loss=2.18]
train epoch 3: 100%|██████████| 59/59 [00:00<00:00, 399.88batch/s, loss=2.11]
eval: 100%|██████████| 12/12 [00:00<00:00, 522.36batch/s, acc=0.346, loss=2.16]
train epoch 4: 100%|██████████| 59/59 [00:00<00:00, 395.59batch/s, loss=2.07]
eval: 100%|██████████| 12/12 [00:00<00:00, 610.47batch/s, acc=0.39, loss=2.11]
train epoch 5: 100%|██████████| 59/59 [00:00<00:00, 387.38batch/s, loss=2.08]
eval: 100%|██████████| 12/12 [00:00<00:00, 582.99batch/s, acc=0.416, loss=2.1]
train epoch 6: 100%|██████████| 59/59 [00:00<00:00, 398.72batch/s, loss=2.05]
eval: 100%|██████████| 12/12 [00:00<00:00, 558.79batch/s, acc=0.439, loss=2.07]
train epoch 7: 100%|██████████| 59/59 [00:00<00:00, 22

#######################################
train:


eval: 100%|██████████| 47/47 [00:00<00:00, 398.61batch/s, acc=0.525, loss=2.05]


#######################################
#######################################
val:


eval: 100%|██████████| 12/12 [00:00<00:00, 325.71batch/s, acc=0.531, loss=2.04]


#######################################
#######################################
retain:


eval: 100%|██████████| 43/43 [00:00<00:00, 425.38batch/s, acc=0.527, loss=2.06]


#######################################
#######################################
forget:


eval: 100%|██████████| 5/5 [00:00<00:00, 536.91batch/s, acc=0.51, loss=2.08]


#######################################
#######################################
surrogate:


eval: 100%|██████████| 59/59 [00:00<00:00, 435.52batch/s, acc=0.61, loss=2.06] 


#######################################


In [8]:
# forget with exact
model = model.to(device)
# fmodel = forget(model, train_loader, forget_loader, forget_loader, criterion, save_path='tmp', eps=eps, delta=delta)
fmodel = forget(model, train_loader, forget_loader, forget_loader, criterion, save_path='tmp')
model = model.to('cpu')
print_eval(fmodel)
fmodel = fmodel.to('cpu')

Calculating Hessian: 100%|██████████| 47/47 [05:44<00:00,  7.34s/it]
Calculating Hessian: 100%|██████████| 5/5 [00:36<00:00,  7.35s/it]


#######################################
train:


eval: 100%|██████████| 47/47 [00:00<00:00, 561.34batch/s, acc=0.616, loss=2.03]


#######################################
#######################################
val:


eval: 100%|██████████| 12/12 [00:00<00:00, 558.75batch/s, acc=0.537, loss=2.03]


#######################################
#######################################
retain:


eval: 100%|██████████| 43/43 [00:00<00:00, 562.01batch/s, acc=0.622, loss=1.99]


#######################################
#######################################
forget:


eval: 100%|██████████| 5/5 [00:00<00:00, 554.16batch/s, acc=0.56, loss=2.09]


#######################################
#######################################
surrogate:


eval: 100%|██████████| 59/59 [00:00<00:00, 555.95batch/s, acc=0.575, loss=2.07]


#######################################


In [9]:
# forget with surrogate
# forget with exact
model = model.to(device)
surrmodel = surrmodel.to(device)
smodel = forget(model, surr_loader, forget_loader, forget_loader, criterion, linear=True, num_class=num_classes, eps=eps, delta=delta, surr=True, known=True, surr_loader=surr_loader, surr_model=surrmodel, kl_distance=kl_distance)
model = model.to('cpu')
print_eval(smodel)
smodel = smodel.to('cpu')

TypeError: forget() got an unexpected keyword argument 'linear'