In [1]:
import copy

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import ConcatDataset, TensorDataset
from torchvision.transforms import v2

from src.data import (get_train_test_datasets, get_dataloaders,
                      get_retain_forget_datasets, get_exact_surr_datasets)
from src.eval import evaluate
from src.forget import forget
from src.train import train
from src.utils import set_seed

# Seed once, before any data splitting or model initialisation, so the whole
# notebook is reproducible under Restart & Run All (set_seed lives in src.utils).
SEED = 42
set_seed(SEED)
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

In [2]:
class ShallowModel(nn.Module):
    """One-hidden-layer MLP: a ``Linear -> ReLU`` feature extractor plus a linear classifier.

    The two stages are kept as separate attributes (``extractor`` and
    ``classifier``) so the trained extractor can be reused on its own as a
    feature map.

    Args:
        *args, **kwargs: forwarded unchanged to ``nn.Module.__init__``.
        in_features: size of the flattened input vector (default 784 = 28 * 28).
        hidden_features: width of the extracted feature vector (default 256).
        num_classes: number of output logits (default 10).
    """

    def __init__(self, *args, in_features=784, hidden_features=256, num_classes=10, **kwargs):
        super().__init__(*args, **kwargs)
        self.extractor = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.ReLU()
        )

        self.classifier = nn.Linear(hidden_features, num_classes)

    def forward(self, x):
        """Map inputs of shape ``(..., in_features)`` to logits ``(..., num_classes)``."""
        return self.classifier(self.extractor(x))

In [3]:
# USPS preprocessing for the MLP:
#   - force a single channel and resize to the 28x28 grid the MLP expects,
#   - convert to a float32 image scaled to [0, 1],
#   - normalise with mean/std 0.5, i.e. map [0, 1] onto [-1, 1],
#   - flatten each image into a 784-vector for ShallowModel's first Linear.
gtransform = v2.Compose(
    [
        v2.Grayscale(),
        v2.Resize((28, 28)),
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=[0.5], std=[0.5]),
        v2.Lambda(lambda img: img.view(-1)),
    ]
)


gtrain_dataset, gval_dataset = get_train_test_datasets('usps', gtransform)
gtrain_loader, gval_loader = get_dataloaders(
    [gtrain_dataset, gval_dataset],
    batch_size=256,
)

In [4]:
# Pre-train the shallow MLP on USPS. Only its `extractor` is reused afterwards,
# as a fixed (no-grad) feature map for MNIST.
model = ShallowModel().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
train(gtrain_loader, gval_loader, model, criterion, optimizer,
      num_epoch=10, device=device)
evaluate(gval_loader, model, criterion, device=device)

train epoch 1: 100%|██████████| 29/29 [00:01<00:00, 21.18batch/s, loss=0.193]
eval: 100%|██████████| 8/8 [00:00<00:00, 22.28batch/s, acc=0.886, loss=0.578]
train epoch 2: 100%|██████████| 29/29 [00:01<00:00, 23.85batch/s, loss=0.193]
eval: 100%|██████████| 8/8 [00:00<00:00, 24.26batch/s, acc=0.906, loss=0.452]
train epoch 3: 100%|██████████| 29/29 [00:01<00:00, 23.82batch/s, loss=0.152]
eval: 100%|██████████| 8/8 [00:00<00:00, 24.29batch/s, acc=0.906, loss=0.35] 
train epoch 4: 100%|██████████| 29/29 [00:01<00:00, 23.78batch/s, loss=0.174]
eval: 100%|██████████| 8/8 [00:00<00:00, 24.30batch/s, acc=0.917, loss=0.459]
train epoch 5: 100%|██████████| 29/29 [00:01<00:00, 23.78batch/s, loss=0.0952]
eval: 100%|██████████| 8/8 [00:00<00:00, 24.26batch/s, acc=0.926, loss=0.579]
train epoch 6: 100%|██████████| 29/29 [00:01<00:00, 23.73batch/s, loss=0.157] 
eval: 100%|██████████| 8/8 [00:00<00:00, 24.11batch/s, acc=0.92, loss=0.301] 
train epoch 7: 100%|██████████| 29/29 [00:01<00:00, 23.77batch

In [5]:
# MNIST preprocessing: no Grayscale/Resize here (presumably MNIST is already
# single-channel 28x28 — confirm), so only scale to [0, 1], normalise, and
# flatten to 784-vectors.
# NOTE(review): the USPS extractor was trained with mean/std 0.5, but MNIST is
# normalised with 0.1307/0.3081 — confirm this input-distribution shift is intended.
transform = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.1307], std=[0.3081]),
    v2.Lambda(lambda img: img.view(-1))
])


def extract_features(loader, extractor, device):
    """Embed every batch of ``loader`` with ``extractor``.

    Args:
        loader: iterable yielding ``(data, label)`` batches.
        extractor: module applied to each ``data`` batch after moving it to ``device``.
        device: device the extractor lives on.

    Returns:
        ``(features, labels)``: the extractor outputs and the labels, each
        concatenated along dim 0, with features on the CPU. Both lists are
        appended batch by batch, so ``features[i]`` matches ``labels[i]`` even if
        ``loader`` shuffles.
    """
    was_training = extractor.training
    # No-op for the current Linear/ReLU extractor, but keeps the features
    # deterministic if dropout/batch-norm is ever added; the previous mode is restored below.
    extractor.eval()
    features, labels = [], []
    try:
        with torch.no_grad():
            for data, label in loader:
                features.append(extractor(data.to(device)).cpu())
                labels.append(label)
    finally:
        extractor.train(was_training)
    return torch.cat(features, dim=0), torch.cat(labels, dim=0)


train_dataset, val_dataset = get_train_test_datasets('mnist', transform)
train_loader, val_loader = get_dataloaders([train_dataset, val_dataset], batch_size=256)

etrain_data, etrain_label = extract_features(train_loader, model.extractor, device)
eval_data, eval_label = extract_features(val_loader, model.extractor, device)

# From here on the train/val "datasets" hold the 256-d extractor features,
# not raw images; the linear heads below are trained on these.
train_dataset = TensorDataset(etrain_data, etrain_label)
val_dataset = TensorDataset(eval_data, eval_label)

In [6]:
# Split a small forget set off the feature-space MNIST training data
# (0.01 — presumably the forget fraction; the forget loader reports 3 batches).
FORGET_RATIO = 0.01
retain_dataset, forget_dataset = get_retain_forget_datasets(train_dataset, FORGET_RATIO)

# Per-class target proportions (presumably indexed by label 0-9 — confirm in
# src.data.get_exact_surr_datasets). The "exact" retain subset is deliberately
# imbalanced; the surrogate subset is uniform.
exact_ratios = np.asarray([0.2, 0.05, 0.1, 0.05, 0.2, 0.1, 0.05, 0.1, 0.05, 0.1])
surr_ratios = np.full(10, 0.1)

# Share the remaining data between the exact retain subset (first half,
# rounded down) and the surrogate subset (the rest).
exact_size = len(retain_dataset) // 2
surr_size = len(retain_dataset) - exact_size
retain_dataset, surr_dataset = get_exact_surr_datasets(
    retain_dataset,
    target_size=exact_size, target_ratios=exact_ratios,
    starget_size=surr_size, starget_ratios=surr_ratios,
)

# The "train" split is exact-retain + forget; the surrogate subset is not part of it.
train_dataset = ConcatDataset([retain_dataset, forget_dataset])

BATCH_SIZE = 256
train_loader, val_loader = get_dataloaders([train_dataset, val_dataset], batch_size=BATCH_SIZE)
retain_loader = get_dataloaders(retain_dataset, batch_size=BATCH_SIZE)
forget_loader = get_dataloaders(forget_dataset, batch_size=BATCH_SIZE)
surr_loader = get_dataloaders(surr_dataset, batch_size=BATCH_SIZE)

def print_eval(model_arg, loaders=None):
    """Evaluate ``model_arg`` on each data split, framing every split with a banner.

    Args:
        model_arg: model passed straight to ``evaluate``.
        loaders: optional mapping ``{split_name: dataloader}``, evaluated in
            insertion order. Defaults to the notebook's train / val / retain /
            forget / surrogate loaders. The defaults are looked up at call time,
            so re-running the split cell is picked up.

    Uses the notebook-level ``criterion`` and ``device``. Returns ``None``; the
    return value of ``evaluate`` is discarded, so results appear only in what
    ``evaluate`` itself reports.
    """
    if loaders is None:
        loaders = {
            'train': train_loader,
            'val': val_loader,
            'retain': retain_loader,
            'forget': forget_loader,
            'surrogate': surr_loader,
        }
    banner = '#######################################'
    for split_name, loader in loaders.items():
        print(banner)
        print(f'{split_name}:')
        evaluate(loader, model_arg, criterion, device=device)
        print(banner)

In [7]:
# Baseline ("train with all"): fit a bias-free linear head on the frozen 256-d
# features of the full training split (exact retain + forget).
model = model.to('cpu')  # park the USPS MLP on the CPU to release its GPU memory before `model` is rebound
model = nn.Linear(in_features=256, out_features=10, bias=False).to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
train(train_loader, val_loader, model, criterion, optimizer,
      num_epoch=10, device=device)

print_eval(model)
model = model.to('cpu')  # keep the baseline off the GPU between experiments

train epoch 1: 100%|██████████| 119/119 [00:00<00:00, 257.24batch/s, loss=0.671]
eval: 100%|██████████| 40/40 [00:00<00:00, 415.10batch/s, acc=0.82, loss=1.08]
train epoch 2: 100%|██████████| 119/119 [00:00<00:00, 356.44batch/s, loss=0.471]
eval: 100%|██████████| 40/40 [00:00<00:00, 471.24batch/s, acc=0.861, loss=0.445]
train epoch 3: 100%|██████████| 119/119 [00:00<00:00, 365.84batch/s, loss=0.406]
eval: 100%|██████████| 40/40 [00:00<00:00, 478.88batch/s, acc=0.88, loss=0.562]
train epoch 4: 100%|██████████| 119/119 [00:00<00:00, 362.81batch/s, loss=0.283]
eval: 100%|██████████| 40/40 [00:00<00:00, 465.68batch/s, acc=0.895, loss=0.293]
train epoch 5: 100%|██████████| 119/119 [00:00<00:00, 281.74batch/s, loss=0.419]
eval: 100%|██████████| 40/40 [00:00<00:00, 478.52batch/s, acc=0.899, loss=0.476]
train epoch 6: 100%|██████████| 119/119 [00:00<00:00, 362.42batch/s, loss=0.214]
eval: 100%|██████████| 40/40 [00:00<00:00, 482.64batch/s, acc=0.903, loss=0.256]
train epoch 7: 100%|██████████|

#######################################
train:


eval: 100%|██████████| 119/119 [00:00<00:00, 381.31batch/s, acc=0.93, loss=0.242]


#######################################
#######################################
val:


eval: 100%|██████████| 40/40 [00:00<00:00, 472.36batch/s, acc=0.913, loss=0.261]


#######################################
#######################################
retain:


eval: 100%|██████████| 117/117 [00:00<00:00, 310.03batch/s, acc=0.93, loss=0.496]


#######################################
#######################################
forget:


eval: 100%|██████████| 3/3 [00:00<00:00, 503.20batch/s, acc=0.932, loss=0.366]


#######################################
#######################################
surrogate:


eval: 100%|██████████| 117/117 [00:00<00:00, 416.50batch/s, acc=0.912, loss=0.568]

#######################################





In [8]:
# Reference point for unlearning: retrain the same bias-free linear head from
# scratch on the retain split only, so the forget set is never seen.
rmodel = nn.Linear(in_features=256, out_features=10, bias=False)
rmodel = rmodel.to(device)
optimizer = torch.optim.Adam(params=rmodel.parameters(), lr=1e-3)
train(retain_loader, val_loader, rmodel, criterion, optimizer,
      num_epoch=10, device=device)

print_eval(rmodel)
rmodel = rmodel.to('cpu')

train epoch 1: 100%|██████████| 117/117 [00:00<00:00, 327.69batch/s, loss=1.25]
eval: 100%|██████████| 40/40 [00:00<00:00, 468.56batch/s, acc=0.815, loss=0.652]
train epoch 2: 100%|██████████| 117/117 [00:00<00:00, 368.79batch/s, loss=0.794]
eval: 100%|██████████| 40/40 [00:00<00:00, 470.94batch/s, acc=0.859, loss=0.554]
train epoch 3: 100%|██████████| 117/117 [00:00<00:00, 369.34batch/s, loss=0.278]
eval: 100%|██████████| 40/40 [00:00<00:00, 485.95batch/s, acc=0.877, loss=0.593]
train epoch 4: 100%|██████████| 117/117 [00:00<00:00, 367.89batch/s, loss=0.193]
eval: 100%|██████████| 40/40 [00:00<00:00, 480.97batch/s, acc=0.886, loss=0.237]
train epoch 5: 100%|██████████| 117/117 [00:00<00:00, 366.23batch/s, loss=0.0764]
eval: 100%|██████████| 40/40 [00:00<00:00, 491.22batch/s, acc=0.895, loss=0.402]
train epoch 6: 100%|██████████| 117/117 [00:00<00:00, 361.65batch/s, loss=0.012]
eval: 100%|██████████| 40/40 [00:00<00:00, 461.15batch/s, acc=0.902, loss=0.333]
train epoch 7: 100%|████████

#######################################
train:


eval: 100%|██████████| 119/119 [00:00<00:00, 390.34batch/s, acc=0.929, loss=0.25] 


#######################################
#######################################
val:


eval: 100%|██████████| 40/40 [00:00<00:00, 479.85batch/s, acc=0.918, loss=0.341]


#######################################
#######################################
retain:


eval: 100%|██████████| 117/117 [00:00<00:00, 421.15batch/s, acc=0.929, loss=0.812]


#######################################
#######################################
forget:


eval: 100%|██████████| 3/3 [00:00<00:00, 395.98batch/s, acc=0.922, loss=0.239]


#######################################
#######################################
surrogate:


eval: 100%|██████████| 117/117 [00:00<00:00, 422.66batch/s, acc=0.914, loss=0.243]

#######################################





In [9]:
# Forget with the exact retain split, starting from the trained baseline.
# `forget` receives a private copy, so `model` stays the untouched baseline for
# the surrogate experiment even if `forget` updates its argument in place.
fmodel = forget(copy.deepcopy(model).to(device), retain_loader, forget_loader,
                criterion, linear=True, num_class=10)
print_eval(fmodel)
fmodel = fmodel.to('cpu')

#######################################
train:


eval: 100%|██████████| 119/119 [00:00<00:00, 322.85batch/s, acc=0.926, loss=0.181]


#######################################
#######################################
val:


eval: 100%|██████████| 40/40 [00:00<00:00, 502.60batch/s, acc=0.907, loss=0.1]


#######################################
#######################################
retain:


eval: 100%|██████████| 117/117 [00:00<00:00, 426.62batch/s, acc=0.926, loss=0.0489]


#######################################
#######################################
forget:


eval: 100%|██████████| 3/3 [00:00<00:00, 666.54batch/s, acc=0.912, loss=0.351]


#######################################
#######################################
surrogate:


eval: 100%|██████████| 117/117 [00:00<00:00, 460.51batch/s, acc=0.906, loss=0.0371]


#######################################


In [10]:
# Forget with the surrogate split standing in for the real retain data,
# starting from the trained baseline. `forget` receives a private copy, so this
# run cannot modify `model` even if `forget` updates its argument in place.
smodel = forget(copy.deepcopy(model).to(device), surr_loader, forget_loader,
                criterion, linear=True, num_class=10)
print_eval(smodel)
smodel = smodel.to('cpu')

#######################################
train:


eval: 100%|██████████| 119/119 [00:00<00:00, 463.80batch/s, acc=0.925, loss=0.389]


#######################################
#######################################
val:


eval: 100%|██████████| 40/40 [00:00<00:00, 563.78batch/s, acc=0.907, loss=0.196]


#######################################
#######################################
retain:


eval: 100%|██████████| 117/117 [00:00<00:00, 480.24batch/s, acc=0.926, loss=0.533]


#######################################
#######################################
forget:


eval: 100%|██████████| 3/3 [00:00<00:00, 668.49batch/s, acc=0.905, loss=0.256]


#######################################
#######################################
surrogate:


eval: 100%|██████████| 117/117 [00:00<00:00, 465.35batch/s, acc=0.905, loss=0.832]

#######################################



