In [1]:
import torch
import torch.nn as nn
from torchvision.transforms import v2
from torch.utils.data import ConcatDataset, TensorDataset
import numpy as np
from torchvision.models import ResNet18_Weights, resnet18
from torchvision.transforms.v2.functional import InterpolationMode

from src.data import (get_train_test_datasets, get_dataloaders,
                      get_retain_forget_datasets, get_exact_surr_datasets)
from src.train import train
from src.eval import evaluate
from src.utils import set_seed
from src.forget import forget

set_seed(42)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [16]:
model = resnet18(weights=ResNet18_Weights.DEFAULT)
feature_extractor = torch.nn.Sequential(*list(model.children())[:-1]).to(device)
criterion = nn.CrossEntropyLoss()

In [17]:
transform = v2.Compose([
    v2.Resize((224, 224), interpolation=InterpolationMode.BILINEAR),  # ResNet18 expects 224x224 input size
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # Normalization for pretrained models
])


train_dataset, val_dataset = get_train_test_datasets('cifar10', transform)
train_loader, val_loader = get_dataloaders([train_dataset, val_dataset], batch_size=256)
etrain_data, etrain_label, eval_data, eval_label = [], [], [], []
with torch.no_grad():
    for data, label in train_loader:
        data = data.to(device)
        edata = feature_extractor(data).to('cpu')
        etrain_data.append(edata.view(edata.shape[0], -1))
        etrain_label.append(label)
    for data, label in val_loader:
        data = data.to(device)
        edata = feature_extractor(data).to('cpu')
        eval_data.append(edata.view(edata.shape[0], -1))
        eval_label.append(label)
etrain_data = torch.cat(etrain_data, dim=0)
etrain_label = torch.cat(etrain_label, dim=0)
eval_data = torch.cat(eval_data, dim=0)
eval_label = torch.cat(eval_label, dim=0)
train_dataset = TensorDataset(etrain_data, etrain_label)
val_dataset = TensorDataset(eval_data, eval_label)

Files already downloaded and verified
Files already downloaded and verified


In [18]:
exact_ratios = np.asarray([0.2, 0.05, 0.05, 0.05, 0.2, 0.1, 0.05, 0.2, 0.05, 0.05])
surr_ratios = np.asarray([0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1])
exact_size = int(len(train_dataset) / 2)
surr_size = len(train_dataset) - exact_size
train_dataset, surr_dataset = get_exact_surr_datasets(train_dataset, 
                                                      target_size=exact_size, target_ratios=exact_ratios,
                                                      starget_size=surr_size, starget_ratios=surr_ratios)
retain_dataset, forget_dataset = get_retain_forget_datasets(train_dataset, 0.01)
train_loader, val_loader = get_dataloaders([train_dataset, val_dataset], batch_size=256)
retain_loader = get_dataloaders(retain_dataset, batch_size=256)
forget_loader = get_dataloaders(forget_dataset, batch_size=256)
surr_loader = get_dataloaders(surr_dataset, batch_size=256)

def print_eval(model_arg):
    print('#######################################')
    print('train:')
    evaluate(train_loader, model_arg, criterion, device=device)
    print('#######################################')
    print('#######################################')
    print('val:')
    evaluate(val_loader, model_arg, criterion, device=device)
    print('#######################################')
    print('#######################################')
    print('retain:')
    evaluate(retain_loader, model_arg, criterion, device=device)
    print('#######################################')
    print('#######################################')
    print('forget:')
    evaluate(forget_loader, model_arg, criterion, device=device)
    print('#######################################')
    print('#######################################')
    print('surrogate:')
    evaluate(surr_loader, model_arg, criterion, device=device)
    print('#######################################')

In [20]:
# train with all
feature_extractor = feature_extractor.to('cpu') # just to clear the GPU
model = nn.Linear(512, 10, bias=False).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
train(train_loader, val_loader, model, criterion, optimizer, num_epoch=10, device=device)

print_eval(model)
model = model.to('cpu')

train epoch 1: 100%|██████████| 98/98 [00:00<00:00, 334.04batch/s, loss=0.737]
eval: 100%|██████████| 40/40 [00:00<00:00, 448.49batch/s, acc=0.666, loss=0.886]
train epoch 2: 100%|██████████| 98/98 [00:00<00:00, 428.29batch/s, loss=0.677]
eval: 100%|██████████| 40/40 [00:00<00:00, 575.12batch/s, acc=0.733, loss=0.8]
train epoch 3: 100%|██████████| 98/98 [00:00<00:00, 431.66batch/s, loss=0.657]
eval: 100%|██████████| 40/40 [00:00<00:00, 577.50batch/s, acc=0.749, loss=0.675]
train epoch 4: 100%|██████████| 98/98 [00:00<00:00, 426.01batch/s, loss=0.72] 
eval: 100%|██████████| 40/40 [00:00<00:00, 577.47batch/s, acc=0.756, loss=0.612]
train epoch 5: 100%|██████████| 98/98 [00:00<00:00, 289.91batch/s, loss=0.472]
eval: 100%|██████████| 40/40 [00:00<00:00, 580.61batch/s, acc=0.758, loss=0.703]
train epoch 6: 100%|██████████| 98/98 [00:00<00:00, 424.27batch/s, loss=0.562]
eval: 100%|██████████| 40/40 [00:00<00:00, 544.56batch/s, acc=0.767, loss=0.973]
train epoch 7: 100%|██████████| 98/98 [00:

#######################################
train:


eval: 100%|██████████| 98/98 [00:00<00:00, 481.24batch/s, acc=0.829, loss=0.476]


#######################################
#######################################
val:


eval: 100%|██████████| 40/40 [00:00<00:00, 541.03batch/s, acc=0.769, loss=1.06]


#######################################
#######################################
retain:


eval: 100%|██████████| 97/97 [00:00<00:00, 478.10batch/s, acc=0.829, loss=0.543]


#######################################
#######################################
forget:


eval: 100%|██████████| 1/1 [00:00<00:00, 449.45batch/s, acc=0.828, loss=0.488]


#######################################
#######################################
surrogate:


eval: 100%|██████████| 98/98 [00:00<00:00, 481.51batch/s, acc=0.782, loss=0.72] 

#######################################





In [22]:
# retrain from scratch
rmodel = nn.Linear(512, 10, bias=False).to(device)
optimizer = torch.optim.Adam(rmodel.parameters(), lr=0.001)
train(retain_loader, val_loader, rmodel, criterion, optimizer, num_epoch=10, device=device)

print_eval(rmodel)
rmodel = rmodel.to('cpu')

train epoch 1: 100%|██████████| 97/97 [00:00<00:00, 337.87batch/s, loss=0.758]
eval: 100%|██████████| 40/40 [00:00<00:00, 553.16batch/s, acc=0.654, loss=1.13]
train epoch 2: 100%|██████████| 97/97 [00:00<00:00, 408.94batch/s, loss=0.69] 
eval: 100%|██████████| 40/40 [00:00<00:00, 542.11batch/s, acc=0.72, loss=0.676]
train epoch 3: 100%|██████████| 97/97 [00:00<00:00, 406.68batch/s, loss=0.54] 
eval: 100%|██████████| 40/40 [00:00<00:00, 544.69batch/s, acc=0.737, loss=0.587]
train epoch 4: 100%|██████████| 97/97 [00:00<00:00, 404.86batch/s, loss=0.575]
eval: 100%|██████████| 40/40 [00:00<00:00, 544.79batch/s, acc=0.755, loss=1.16]
train epoch 5: 100%|██████████| 97/97 [00:00<00:00, 278.07batch/s, loss=0.445]
eval: 100%|██████████| 40/40 [00:00<00:00, 545.00batch/s, acc=0.76, loss=0.503]
train epoch 6: 100%|██████████| 97/97 [00:00<00:00, 403.53batch/s, loss=0.58] 
eval: 100%|██████████| 40/40 [00:00<00:00, 547.18batch/s, acc=0.761, loss=0.575]
train epoch 7: 100%|██████████| 97/97 [00:00

#######################################
train:


eval: 100%|██████████| 98/98 [00:00<00:00, 479.50batch/s, acc=0.834, loss=0.445]


#######################################
#######################################
val:


eval: 100%|██████████| 40/40 [00:00<00:00, 565.41batch/s, acc=0.767, loss=0.877]


#######################################
#######################################
retain:


eval: 100%|██████████| 97/97 [00:00<00:00, 473.75batch/s, acc=0.834, loss=0.508]


#######################################
#######################################
forget:


eval: 100%|██████████| 1/1 [00:00<00:00, 374.73batch/s, acc=0.836, loss=0.536]


#######################################
#######################################
surrogate:


eval: 100%|██████████| 98/98 [00:00<00:00, 479.16batch/s, acc=0.777, loss=0.704]

#######################################





In [23]:
# forget with exact
model = model.to(device)
fmodel = forget(model, train_loader, forget_loader, forget_loader, criterion, linear=True, num_class=10)
model = model.to('cpu')
print_eval(fmodel)
fmodel = fmodel.to('cpu')

#######################################
train:


eval: 100%|██████████| 98/98 [00:00<00:00, 518.70batch/s, acc=0.824, loss=0.558]


#######################################
#######################################
val:


eval: 100%|██████████| 40/40 [00:00<00:00, 565.11batch/s, acc=0.765, loss=0.723]


#######################################
#######################################
retain:


eval: 100%|██████████| 97/97 [00:00<00:00, 474.05batch/s, acc=0.825, loss=0.45] 


#######################################
#######################################
forget:


eval: 100%|██████████| 1/1 [00:00<00:00, 384.52batch/s, acc=0.712, loss=1.04]


#######################################
#######################################
surrogate:


eval: 100%|██████████| 98/98 [00:00<00:00, 491.61batch/s, acc=0.776, loss=0.533]


#######################################


In [24]:
# forget with surrogate
# forget with exact
model = model.to(device)
smodel = forget(model, surr_loader, forget_loader, forget_loader, criterion, linear=True, num_class=10)
model = model.to('cpu')
print_eval(smodel)
smodel = smodel.to('cpu')

#######################################
train:


eval: 100%|██████████| 98/98 [00:00<00:00, 319.89batch/s, acc=0.825, loss=0.472]


#######################################
#######################################
val:


eval: 100%|██████████| 40/40 [00:00<00:00, 545.33batch/s, acc=0.765, loss=1.31]


#######################################
#######################################
retain:


eval: 100%|██████████| 97/97 [00:00<00:00, 474.81batch/s, acc=0.826, loss=0.523]


#######################################
#######################################
forget:


eval: 100%|██████████| 1/1 [00:00<00:00, 384.83batch/s, acc=0.716, loss=1.04]


#######################################
#######################################
surrogate:


eval: 100%|██████████| 98/98 [00:00<00:00, 479.85batch/s, acc=0.777, loss=0.63] 


#######################################
