In [1]:
import copy
import math
import os

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset
from torchvision.models import ResNet18_Weights, resnet18
from torchvision.transforms import v2
from torchvision.transforms.v2.functional import InterpolationMode

from src.data import (get_train_test_datasets, get_dataloaders,
                      get_retain_forget_datasets, get_exact_surr_datasets)
from src.train import train
from src.eval import evaluate
from src.utils import set_seed
from src.forget import forget, sample_from_exact_marginal, estimate_marginal_kl_distance
from src.loss import L2RegularizedCrossEntropyLoss

# Seed through the project's `set_seed` helper, then prefer the first GPU when one is visible.
SEED = 42
set_seed(SEED)
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

In [2]:
def log_eval(model, train_loader, val_loader, retain_loader, forget_loader, surr_loader, criterion, device):
    """Evaluate `model` on the five experiment splits and print one summary line.

    `evaluate` is called once per split, in the order train, test (val), retain, forget,
    surrogate, always with ``log=True``. Its return values are printed on a single line.

    Args:
        model: network handed to ``evaluate``.
        train_loader, val_loader, retain_loader, forget_loader, surr_loader: loaders
            for the five splits.
        criterion: loss handed to ``evaluate``.
        device: device handed to ``evaluate``.

    Returns:
        None. Results are only printed.
    """
    split_loaders = (train_loader, val_loader, retain_loader, forget_loader, surr_loader)
    scores = [
        evaluate(loader, model, criterion, device=device, log=True)
        for loader in split_loaders
    ]
    print('train: {}, test: {}, retain: {}, forget: {}, surrogate:{}'.format(*scores))

In [3]:
# Pretrained ResNet18 with its last child module (torchvision's `fc` classifier) removed.
# The remaining layers serve as a fixed feature extractor for the linear heads trained below.
# Named `backbone` so it does not collide with the linear `model` defined later.
backbone = resnet18(weights=ResNet18_Weights.DEFAULT)
# eval() makes BatchNorm normalise with its stored running statistics. In train mode (the state
# after construction) each forward pass would use per-batch statistics and also overwrite the
# pretrained running averages, so extracted features would depend on batch composition.
feature_extractor = torch.nn.Sequential(*list(backbone.children())[:-1]).eval().to(device)
criterion = L2RegularizedCrossEntropyLoss()

In [4]:
# Cached backbone features for CIFAR-10, stored as dicts with 'data' and 'label' entries.
# Set a path to None to bypass its cache. A configured path whose file does not exist yet is
# filled by re-extracting the features from the images.
td_path = './data/cifar10_linear_resnet18_train.pth'
ted_path = './data/cifar10_linear_resnet18_test.pth'


def _extract_features(loader, extractor, device):
    """Run `extractor` over every batch of `loader` with autograd disabled.

    Returns a (features, labels) pair of CPU tensors. Each sample's extractor output is
    flattened to one row, and rows follow the loader's iteration order.
    """
    feature_batches, label_batches = [], []
    with torch.no_grad():
        for data, label in loader:
            out = extractor(data.to(device)).to('cpu')
            feature_batches.append(out.view(out.shape[0], -1))
            label_batches.append(label)
    return torch.cat(feature_batches, dim=0), torch.cat(label_batches, dim=0)


def _has_cache(path):
    """True when `path` is configured and points to an existing file."""
    return path is not None and os.path.isfile(path)


if _has_cache(td_path) and _has_cache(ted_path):
    # NOTE(review): weights_only=False unpickles arbitrary objects; only load caches produced locally.
    # Caches written by a feature extractor left in train mode used per-batch BatchNorm
    # statistics. Regenerate them if their provenance is unclear.
    etrain_data = torch.load(td_path, weights_only=False)
    eval_data = torch.load(ted_path, weights_only=False)
    train_dataset = TensorDataset(etrain_data['data'], etrain_data['label'])
    val_dataset = TensorDataset(eval_data['data'], eval_data['label'])
else:
    transform = v2.Compose([
        v2.Resize((224, 224), interpolation=InterpolationMode.BILINEAR),  # ResNet18 expects 224x224 input size
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # Normalization for pretrained models
    ])
    image_train_dataset, image_val_dataset = get_train_test_datasets('cifar10', transform)
    image_train_loader, image_val_loader = get_dataloaders([image_train_dataset, image_val_dataset],
                                                           batch_size=256)

    # BatchNorm must use its pretrained running statistics, not per-batch ones.
    feature_extractor.eval()
    etrain_feats, etrain_label = _extract_features(image_train_loader, feature_extractor, device)
    eval_feats, eval_label = _extract_features(image_val_loader, feature_extractor, device)

    # Write the features back in the same format the loading branch reads, so the next run is fast.
    for cache_path, feats, labels in ((td_path, etrain_feats, etrain_label),
                                      (ted_path, eval_feats, eval_label)):
        if cache_path is not None:
            os.makedirs(os.path.dirname(cache_path) or '.', exist_ok=True)
            torch.save({'data': feats, 'label': labels}, cache_path)

    train_dataset = TensorDataset(etrain_feats, etrain_label)
    val_dataset = TensorDataset(eval_feats, eval_label)

In [5]:
# Per-class ratios of the "exact" split (trains the original model) and of the surrogate split
# (uniform over the 10 classes).
exact_ratios = np.asarray([0.2, 0.05, 0.05, 0.05, 0.2, 0.1, 0.05, 0.2, 0.05, 0.05])
surr_ratios = np.asarray([0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1])
assert np.isclose(exact_ratios.sum(), 1.0), 'exact_ratios must sum to 1'
assert np.isclose(surr_ratios.sum(), 1.0), 'surr_ratios must sum to 1'

# Give half of the extracted pool to the exact split and the rest to the surrogate split.
# `train_dataset` itself is left untouched, so re-running this cell splits the full pool again
# instead of halving an earlier split.
# NOTE(review): with the 50k-image CIFAR-10 train set, exact_ratios requests 0.2 * 25000 = 5000
# images each of classes 0, 4 and 7 (all of them), while surr_ratios requests 2500 more of each.
# Confirm how get_exact_surr_datasets resolves this (overlap / resampling).
exact_size = len(train_dataset) // 2
surr_size = len(train_dataset) - exact_size
exact_dataset, surr_dataset = get_exact_surr_datasets(train_dataset,
                                                      target_size=exact_size, target_ratios=exact_ratios,
                                                      starget_size=surr_size, starget_ratios=surr_ratios)
# A 0.05 fraction of the exact split becomes the forget set; the rest is the retain set.
retain_dataset, forget_dataset = get_retain_forget_datasets(exact_dataset, 0.05)
train_loader, val_loader = get_dataloaders([exact_dataset, val_dataset], batch_size=256)
retain_loader = get_dataloaders(retain_dataset, batch_size=256)
forget_loader = get_dataloaders(forget_dataset, batch_size=256)
surr_loader = get_dataloaders(surr_dataset, batch_size=256)

In [6]:
# Train the original linear head on the whole exact split (retain + forget samples).
# The backbone is not needed on the GPU any more. Moving it frees its tensors, and empty_cache()
# returns the now-unused cached blocks to the driver.
feature_extractor = feature_extractor.to('cpu')
if torch.cuda.is_available():
    torch.cuda.empty_cache()
# 512 = backbone feature width, 10 = CIFAR-10 classes.
# No bias term, matching every other linear head in this notebook.
model = nn.Linear(512, 10, bias=False).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
train(train_loader, val_loader, model, criterion, optimizer, num_epoch=10, device=device)

log_eval(model, train_loader, val_loader, retain_loader, forget_loader, surr_loader, criterion, device)

train epoch 1: 100%|██████████| 98/98 [00:00<00:00, 191.16batch/s, loss=0.908]
eval: 100%|██████████| 40/40 [00:00<00:00, 334.16batch/s, acc=0.657, loss=0.886]
train epoch 2: 100%|██████████| 98/98 [00:00<00:00, 334.26batch/s, loss=0.718]
eval: 100%|██████████| 40/40 [00:00<00:00, 449.11batch/s, acc=0.724, loss=0.627]
train epoch 3: 100%|██████████| 98/98 [00:00<00:00, 334.73batch/s, loss=0.65] 
eval: 100%|██████████| 40/40 [00:00<00:00, 401.33batch/s, acc=0.733, loss=0.864]
train epoch 4: 100%|██████████| 98/98 [00:00<00:00, 292.39batch/s, loss=0.689]
eval: 100%|██████████| 40/40 [00:00<00:00, 401.54batch/s, acc=0.734, loss=0.48]
train epoch 5: 100%|██████████| 98/98 [00:00<00:00, 299.68batch/s, loss=0.624]
eval: 100%|██████████| 40/40 [00:00<00:00, 211.14batch/s, acc=0.749, loss=0.863]
train epoch 6: 100%|██████████| 98/98 [00:00<00:00, 328.41batch/s, loss=0.703]
eval: 100%|██████████| 40/40 [00:00<00:00, 454.85batch/s, acc=0.764, loss=0.622]
train epoch 7: 100%|██████████| 98/98 [00

train: 0.82416, test: 0.7567, retain: 0.8245894736842105, forget: 0.816, surrogate:0.76616





In [7]:
# Generate samples of size [512] from the full-data model with sample_from_exact_marginal.
# The positional 300 and 256 are presumably sample count and batch size (TODO confirm).
# NOTE(review): ResNet18 avgpool features are non-negative (post-ReLU). Confirm that
# input_range=[-1, 1] is the intended sampling range.
egensample_loader = sample_from_exact_marginal(
    model, 300, [512], 256,
    input_range=[-1, 1],
    max_iter=400,
)
model = model.to('cpu')  # move the full-data model off the GPU

Generating samples with size [512]...
#########################################
sample 1 generated
#########################################
#########################################
sample 2 generated
#########################################
#########################################
sample 3 generated
#########################################
#########################################
sample 4 generated
#########################################
#########################################
sample 5 generated
#########################################
#########################################
sample 6 generated
#########################################
#########################################
sample 7 generated
#########################################
#########################################
sample 8 generated
#########################################
#########################################
sample 9 generated
#########################################
###################################

In [8]:
# Surrogate model: the same bias-free linear head, trained only on the uniform-class surrogate split.
smodel = nn.Linear(in_features=512, out_features=10, bias=False).to(device)
optimizer = torch.optim.Adam(params=smodel.parameters(), lr=1e-3)
train(
    surr_loader, val_loader, smodel, criterion, optimizer,
    num_epoch=10, device=device,
)
log_eval(smodel, train_loader, val_loader, retain_loader, forget_loader, surr_loader, criterion, device)

train epoch 1: 100%|██████████| 98/98 [00:00<00:00, 221.27batch/s, loss=0.943]
eval: 100%|██████████| 40/40 [00:00<00:00, 402.31batch/s, acc=0.747, loss=0.928]
train epoch 2: 100%|██████████| 98/98 [00:00<00:00, 325.40batch/s, loss=0.79] 
eval: 100%|██████████| 40/40 [00:00<00:00, 429.82batch/s, acc=0.775, loss=0.677]
train epoch 3: 100%|██████████| 98/98 [00:00<00:00, 318.90batch/s, loss=0.769]
eval: 100%|██████████| 40/40 [00:00<00:00, 414.28batch/s, acc=0.782, loss=0.623]
train epoch 4: 100%|██████████| 98/98 [00:00<00:00, 328.35batch/s, loss=0.795]
eval: 100%|██████████| 40/40 [00:00<00:00, 429.95batch/s, acc=0.786, loss=0.854]
train epoch 5: 100%|██████████| 98/98 [00:00<00:00, 310.94batch/s, loss=0.73] 
eval: 100%|██████████| 40/40 [00:00<00:00, 428.25batch/s, acc=0.788, loss=0.833]
train epoch 6: 100%|██████████| 98/98 [00:00<00:00, 321.85batch/s, loss=0.749]
eval: 100%|██████████| 40/40 [00:00<00:00, 445.46batch/s, acc=0.795, loss=1.09]
train epoch 7: 100%|██████████| 98/98 [00

train: 0.78684, test: 0.7969, retain: 0.7872842105263158, forget: 0.7784, surrogate:0.81236





In [9]:
# Same sampling procedure as for the full-data model, now driven by the surrogate model.
# NOTE(review): as above, confirm input_range=[-1, 1] suits non-negative backbone features.
sgensample_loader = sample_from_exact_marginal(
    smodel, 300, [512], 256,
    input_range=[-1, 1],
    max_iter=400,
)
smodel = smodel.to('cpu')  # move the surrogate model off the GPU

Generating samples with size [512]...
#########################################
sample 1 generated
#########################################
#########################################
sample 2 generated
#########################################
#########################################
sample 3 generated
#########################################
#########################################
sample 4 generated
#########################################
#########################################
sample 5 generated
#########################################
#########################################
sample 6 generated
#########################################
#########################################
sample 7 generated
#########################################
#########################################
sample 8 generated
#########################################
#########################################
sample 9 generated
#########################################
###################################

In [10]:
# Estimate the KL distance between the surrogate-model and full-data-model sample sets.
# Only the scalar estimate is kept. The first return value gets a real name instead of `_`,
# which IPython reserves for the last cell output.
# NOTE(review): in the recorded run the DV KL estimate was still rising at the final epoch,
# so the result depends on the estimator's training length. Confirm the stopping rule.
kl_estimator, kl_distance_sgen = estimate_marginal_kl_distance(sgensample_loader, egensample_loader, device)
# Presumably an nn.Module (the DV critic); move it off the GPU before dropping it (TODO confirm type).
kl_estimator.to('cpu')
del kl_estimator

Epoch 1, DV KL: -0.0179
Epoch 2, DV KL: -0.0008
Epoch 3, DV KL: 0.0362
Epoch 4, DV KL: 0.0673
Epoch 5, DV KL: 0.0959
Epoch 6, DV KL: 0.1342
Epoch 7, DV KL: 0.1340
Epoch 8, DV KL: 0.1779
Epoch 9, DV KL: 0.1741
Epoch 10, DV KL: 0.2312
Epoch 11, DV KL: 0.2547
Epoch 12, DV KL: 0.2784
Epoch 13, DV KL: 0.2861
Epoch 14, DV KL: 0.3248
Epoch 15, DV KL: 0.3545
Epoch 16, DV KL: 0.3617
Epoch 17, DV KL: 0.3970
Epoch 18, DV KL: 0.3964
Epoch 19, DV KL: 0.4606
Epoch 20, DV KL: 0.4876
Epoch 21, DV KL: 0.5013
Epoch 22, DV KL: 0.5248
Epoch 23, DV KL: 0.5658
Epoch 24, DV KL: 0.5961
Epoch 25, DV KL: 0.6067
Epoch 26, DV KL: 0.6404
Epoch 27, DV KL: 0.6903
Epoch 28, DV KL: 0.7045
Epoch 29, DV KL: 0.7323
Epoch 30, DV KL: 0.7929
Epoch 31, DV KL: 0.8091
Epoch 32, DV KL: 0.8374
Epoch 33, DV KL: 0.8675
Epoch 34, DV KL: 0.9126
Epoch 35, DV KL: 0.9323
Epoch 36, DV KL: 1.0218
Epoch 37, DV KL: 1.0506
Epoch 38, DV KL: 1.0802
Epoch 39, DV KL: 1.1440
Epoch 40, DV KL: 1.1765
Epoch 41, DV KL: 1.1922
Epoch 42, DV KL: 1.2746

In [11]:
# Reference baseline: a fresh linear head trained from scratch on the retain split only.
rmodel = nn.Linear(in_features=512, out_features=10, bias=False).to(device)
optimizer = torch.optim.Adam(params=rmodel.parameters(), lr=1e-3)
train(
    retain_loader, val_loader, rmodel, criterion, optimizer,
    num_epoch=10, device=device,
)

log_eval(rmodel, train_loader, val_loader, retain_loader, forget_loader, surr_loader, criterion, device)
rmodel = rmodel.to('cpu')  # keep the baseline for later comparison without holding GPU memory

train epoch 1: 100%|██████████| 93/93 [00:00<00:00, 319.65batch/s, loss=0.885]
eval: 100%|██████████| 40/40 [00:00<00:00, 432.03batch/s, acc=0.654, loss=0.923]
train epoch 2: 100%|██████████| 93/93 [00:00<00:00, 313.46batch/s, loss=0.851]
eval: 100%|██████████| 40/40 [00:00<00:00, 400.70batch/s, acc=0.71, loss=0.919]
train epoch 3: 100%|██████████| 93/93 [00:00<00:00, 309.94batch/s, loss=0.687]
eval: 100%|██████████| 40/40 [00:00<00:00, 380.33batch/s, acc=0.729, loss=1.04] 
train epoch 4: 100%|██████████| 93/93 [00:00<00:00, 329.33batch/s, loss=0.768]
eval: 100%|██████████| 40/40 [00:00<00:00, 225.08batch/s, acc=0.752, loss=0.509]
train epoch 5: 100%|██████████| 93/93 [00:00<00:00, 336.82batch/s, loss=0.744]
eval: 100%|██████████| 40/40 [00:00<00:00, 427.82batch/s, acc=0.749, loss=0.773]
train epoch 6: 100%|██████████| 93/93 [00:00<00:00, 330.34batch/s, loss=0.618]
eval: 100%|██████████| 40/40 [00:00<00:00, 452.27batch/s, acc=0.751, loss=0.665]
train epoch 7: 100%|██████████| 93/93 [00

train: 0.82568, test: 0.7617, retain: 0.8266947368421053, forget: 0.8064, surrogate:0.77196





In [12]:
# Unlearn the forget split from the full-data model, using the real training loader.
# forget() receives a deep copy, so `model` stays the original for later cells even if forget()
# updates parameters in place (its implementation is not visible here).
# NOTE(review): eps = 5 * e**3 ≈ 100.4 and delta=1. Under the usual (eps, delta) definition,
# delta=1 gives a vacuous guarantee. Confirm these are the intended parameters.
umodel = forget(copy.deepcopy(model), train_loader, forget_loader, forget_loader, criterion, device,
                eps=5 * (math.e ** 3), delta=1, linear=True)
log_eval(umodel, train_loader, val_loader, retain_loader, forget_loader, surr_loader, criterion, device)
umodel = umodel.to('cpu')

Calculating Hessian: 100%|██████████| 98/98 [06:20<00:00,  3.89s/it]
Calculating Hessian: 100%|██████████| 5/5 [00:17<00:00,  3.41s/it]
eval: 100%|██████████| 98/98 [00:00<00:00, 278.75batch/s, acc=0.822, loss=0.648]
eval: 100%|██████████| 40/40 [00:00<00:00, 451.18batch/s, acc=0.754, loss=1.09]
eval: 100%|██████████| 93/93 [00:00<00:00, 408.28batch/s, acc=0.823, loss=0.784]
eval: 100%|██████████| 5/5 [00:00<00:00, 401.86batch/s, acc=0.793, loss=0.687]
eval: 100%|██████████| 98/98 [00:00<00:00, 441.26batch/s, acc=0.762, loss=0.87] 

train: 0.82184, test: 0.7539, retain: 0.8233684210526315, forget: 0.7928, surrogate:0.76244





In [14]:
# Unlearn using surrogate data. surr_loader takes the place of train_loader, and the surrogate
# model plus the estimated KL distance are supplied. forget() again gets a deep copy of `model`,
# so re-running this cell always starts from the original full-data model.
smodel = smodel.to(device)
usmodel = forget(copy.deepcopy(model), surr_loader, forget_loader, forget_loader, criterion, device,
                 eps=5 * (math.e ** 3), delta=1, surr=True,
                 known=True, surr_loader=surr_loader, surr_model=smodel, kl_distance=kl_distance_sgen, linear=True)
log_eval(usmodel, train_loader, val_loader, retain_loader, forget_loader, surr_loader, criterion, device)
usmodel = usmodel.to('cpu')
smodel = smodel.to('cpu')

Calculating Hessian: 100%|██████████| 98/98 [06:24<00:00,  3.92s/it]
Calculating Hessian: 100%|██████████| 5/5 [00:17<00:00,  3.45s/it]


1.999959657791455


eval: 100%|██████████| 98/98 [00:00<00:00, 339.99batch/s, acc=0.819, loss=0.713]
eval: 100%|██████████| 40/40 [00:00<00:00, 485.52batch/s, acc=0.74, loss=0.719]
eval: 100%|██████████| 93/93 [00:00<00:00, 431.25batch/s, acc=0.82, loss=0.604] 
eval: 100%|██████████| 5/5 [00:00<00:00, 395.23batch/s, acc=0.791, loss=0.758]
eval: 100%|██████████| 98/98 [00:00<00:00, 422.96batch/s, acc=0.751, loss=0.814]


train: 0.81884, test: 0.7396, retain: 0.8202947368421053, forget: 0.7912, surrogate:0.751
