In [1]:
import torch
import torch.nn as nn
from torchvision.transforms import v2
from torch.utils.data import ConcatDataset, TensorDataset
import numpy as np

from src.data import (get_train_test_datasets, get_dataloaders,
                      get_retain_forget_datasets, get_exact_surr_datasets)
from src.train import train
from src.eval import evaluate
from src.utils import set_seed
from src.forget import forget

# Fix the seed through the project helper before any stochastic step runs.
set_seed(42)

# Use the first CUDA device when one is available; otherwise stay on the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

In [2]:
class ShallowModel(nn.Module):
    """One-hidden-layer MLP split into a feature extractor and a classifier.

    ``extractor`` maps a flat 784-dimensional input to 256 ReLU-activated
    features; ``classifier`` maps those 256 features to 10 output logits.
    Any positional/keyword arguments are forwarded unchanged to
    ``nn.Module.__init__``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # The hidden layer is created before the classifier so that parameter
        # initialization consumes the RNG in the same order as registration.
        hidden_layer = nn.Linear(784, 256)
        self.extractor = nn.Sequential(hidden_layer, nn.ReLU())
        self.classifier = nn.Linear(256, 10)

    def forward(self, x):
        """Return the classifier logits for the flat input batch ``x``."""
        features = self.extractor(x)
        logits = self.classifier(features)
        return logits

In [3]:
# USPS preprocessing: single-channel image resized to 28x28, converted to a
# float32 tensor (scale=True), normalized with mean/std 0.5, then flattened
# into a 1-D vector so it can feed the 784-input linear layer.
gtransform = v2.Compose([
    v2.Grayscale(),
    v2.Resize(size=(28, 28)),
    v2.ToImage(),
    v2.ToDtype(dtype=torch.float32, scale=True),
    v2.Normalize(mean=[0.5], std=[0.5]),
    v2.Lambda(lambda img: img.view(-1)),
])

# Fetch the two USPS datasets and wrap each in a loader with batch size 256.
gtrain_dataset, gval_dataset = get_train_test_datasets('usps', gtransform)
gtrain_loader, gval_loader = get_dataloaders(
    [gtrain_dataset, gval_dataset],
    batch_size=256,
)

In [4]:
# Build the shallow network directly on the target device; Module.to returns
# the module itself, so constructing and moving can be chained.
model = ShallowModel().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

# Fit on the USPS loaders for 10 epochs, then run one more evaluation pass on
# the USPS validation loader.
train(gtrain_loader, gval_loader, model, criterion, optimizer, num_epoch=10, device=device)
evaluate(gval_loader, model, criterion, device=device)

train epoch 1: 100%|██████████| 29/29 [00:01<00:00, 21.53batch/s, loss=0.193]
eval: 100%|██████████| 8/8 [00:00<00:00, 22.51batch/s, acc=0.886, loss=0.578]
train epoch 2: 100%|██████████| 29/29 [00:01<00:00, 24.06batch/s, loss=0.193]
eval: 100%|██████████| 8/8 [00:00<00:00, 24.45batch/s, acc=0.906, loss=0.452]
train epoch 3: 100%|██████████| 29/29 [00:01<00:00, 24.03batch/s, loss=0.152]
eval: 100%|██████████| 8/8 [00:00<00:00, 24.81batch/s, acc=0.906, loss=0.35] 
train epoch 4: 100%|██████████| 29/29 [00:01<00:00, 24.22batch/s, loss=0.174]
eval: 100%|██████████| 8/8 [00:00<00:00, 24.64batch/s, acc=0.917, loss=0.459]
train epoch 5: 100%|██████████| 29/29 [00:01<00:00, 24.13batch/s, loss=0.0952]
eval: 100%|██████████| 8/8 [00:00<00:00, 24.53batch/s, acc=0.926, loss=0.579]
train epoch 6: 100%|██████████| 29/29 [00:01<00:00, 24.02batch/s, loss=0.157] 
eval: 100%|██████████| 8/8 [00:00<00:00, 24.49batch/s, acc=0.92, loss=0.301] 
train epoch 7: 100%|██████████| 29/29 [00:01<00:00, 23.99batch

In [5]:
# MNIST preprocessing: float32 tensor (scale=True), normalized with
# mean 0.1307 / std 0.3081, then flattened into a 1-D vector.
# NOTE(review): the extractor was fit on inputs normalized with mean/std 0.5
# (the USPS transform), while MNIST uses 0.1307/0.3081 here — confirm this
# mismatch is intended.
transform = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.1307], std=[0.3081]),
    v2.Lambda(lambda img: img.view(-1))
])

# Fail fast with a clear message if `model` has already been rebound to the
# bare linear head by a later cell (out-of-order execution); otherwise the
# error would only surface inside the extraction loop.
if not hasattr(model, 'extractor'):
    raise RuntimeError(
        "`model` has no `extractor`; re-run the ShallowModel training cell "
        "before extracting MNIST features."
    )

train_dataset, val_dataset = get_train_test_datasets('mnist', transform)
train_loader, val_loader = get_dataloaders([train_dataset, val_dataset], batch_size=256)


def _extract_features(loader, extractor):
    """Run every batch of ``loader`` through ``extractor`` without gradients.

    Args:
        loader: iterable yielding ``(data, label)`` batches.
        extractor: module applied to each data batch on ``device``.

    Returns:
        ``(features, labels)``: the extractor outputs (moved to the CPU) and
        the labels, each concatenated along dimension 0 in loader order.
    """
    feature_batches, label_batches = [], []
    with torch.no_grad():
        for inputs, labels in loader:
            feature_batches.append(extractor(inputs.to(device)).to('cpu'))
            label_batches.append(labels)
    return torch.cat(feature_batches, dim=0), torch.cat(label_batches, dim=0)


# Train split first, then validation split (same iteration order as before).
etrain_data, etrain_label = _extract_features(train_loader, model.extractor)
eval_data, eval_label = _extract_features(val_loader, model.extractor)

# From here on the "datasets" hold extracted 256-d features, not raw images.
train_dataset = TensorDataset(etrain_data, etrain_label)
val_dataset = TensorDataset(eval_data, eval_label)

In [6]:
# One ratio per class (10 entries each) for the exact and surrogate subsets;
# the exact subset is skewed (class 2 gets ratio 0), the surrogate is uniform.
exact_ratios = np.asarray([0.2, 0.05, 0, 0.05, 0.2, 0.1, 0.05, 0.2, 0.05, 0.1])
surr_ratios = np.asarray([0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1])

# Rebuild the full feature dataset from the extracted tensors instead of
# reading `train_dataset`: this cell rebinds `train_dataset` to the exact
# subset, so reading it back would halve the data again on every re-run.
feature_dataset = TensorDataset(etrain_data, etrain_label)
exact_size = len(feature_dataset) // 2
surr_size = len(feature_dataset) - exact_size
train_dataset, surr_dataset = get_exact_surr_datasets(feature_dataset,
                                                       target_size=exact_size, target_ratios=exact_ratios,
                                                       starget_size=surr_size, starget_ratios=surr_ratios)

# NOTE(review): 0.01 is presumably the fraction of the exact set assigned to
# the forget split — confirm against get_retain_forget_datasets.
retain_dataset, forget_dataset = get_retain_forget_datasets(train_dataset, 0.01)

train_loader, val_loader = get_dataloaders([train_dataset, val_dataset], batch_size=256)
retain_loader = get_dataloaders(retain_dataset, batch_size=256)
forget_loader = get_dataloaders(forget_dataset, batch_size=256)
surr_loader = get_dataloaders(surr_dataset, batch_size=256)

def print_eval(model_arg):
    """Evaluate ``model_arg`` on every notebook-level loader, one section each.

    For the train, val, retain, forget and surrogate loaders (in that order)
    this prints a separator line, the section heading, runs ``evaluate`` with
    the notebook's ``criterion`` and ``device``, and prints a closing
    separator line.
    """
    separator = '#######################################'
    sections = (
        ('train:', train_loader),
        ('val:', val_loader),
        ('retain:', retain_loader),
        ('forget:', forget_loader),
        ('surrogate:', surr_loader),
    )
    for heading, loader in sections:
        print(separator)
        print(heading)
        evaluate(loader, model_arg, criterion, device=device)
        print(separator)

In [7]:
# train with all
# Module.to moves the parameters in place, so no rebinding is needed before
# the name `model` is reused for the linear head below.
model.to('cpu')  # just to clear the GPU

# Bias-free linear head over the 256-d extracted features, fit on the full
# (exact) training loader for 10 epochs.
model = nn.Linear(in_features=256, out_features=10, bias=False).to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
train(train_loader, val_loader, model, criterion, optimizer, num_epoch=10, device=device)

# Report metrics on every split, then park the trained head on the CPU.
print_eval(model)
model = model.to('cpu')

train epoch 1: 100%|██████████| 118/118 [00:00<00:00, 365.61batch/s, loss=0.433]
eval: 100%|██████████| 40/40 [00:00<00:00, 556.76batch/s, acc=0.732, loss=1.24]
train epoch 2: 100%|██████████| 118/118 [00:00<00:00, 426.68batch/s, loss=0.377]
eval: 100%|██████████| 40/40 [00:00<00:00, 564.23batch/s, acc=0.785, loss=1.09]
train epoch 3: 100%|██████████| 118/118 [00:00<00:00, 423.85batch/s, loss=0.255]
eval: 100%|██████████| 40/40 [00:00<00:00, 564.01batch/s, acc=0.806, loss=1.29]
train epoch 4: 100%|██████████| 118/118 [00:00<00:00, 423.60batch/s, loss=0.178]
eval: 100%|██████████| 40/40 [00:00<00:00, 564.86batch/s, acc=0.828, loss=0.697]
train epoch 5: 100%|██████████| 118/118 [00:00<00:00, 423.30batch/s, loss=0.299]
eval: 100%|██████████| 40/40 [00:00<00:00, 562.56batch/s, acc=0.834, loss=0.304]
train epoch 6: 100%|██████████| 118/118 [00:00<00:00, 326.44batch/s, loss=0.505]
eval: 100%|██████████| 40/40 [00:00<00:00, 568.23batch/s, acc=0.854, loss=0.446]
train epoch 7: 100%|██████████|

#######################################
train:


eval: 100%|██████████| 118/118 [00:00<00:00, 503.49batch/s, acc=0.938, loss=0.151]


#######################################
#######################################
val:


eval: 100%|██████████| 40/40 [00:00<00:00, 561.84batch/s, acc=0.868, loss=0.0856]


#######################################
#######################################
retain:


eval: 100%|██████████| 117/117 [00:00<00:00, 498.12batch/s, acc=0.938, loss=0.0488]


#######################################
#######################################
forget:


eval: 100%|██████████| 2/2 [00:00<00:00, 560.40batch/s, acc=0.953, loss=0.108]


#######################################
#######################################
surrogate:


eval: 100%|██████████| 118/118 [00:00<00:00, 499.21batch/s, acc=0.869, loss=0.305]

#######################################





In [8]:
# retrain from scratch
# Fresh bias-free linear head fit only on the retain loader (10 epochs); it
# has the same shape and optimizer settings as the head trained on all data.
rmodel = nn.Linear(in_features=256, out_features=10, bias=False).to(device)
optimizer = torch.optim.Adam(params=rmodel.parameters(), lr=1e-3)
train(retain_loader, val_loader, rmodel, criterion, optimizer, num_epoch=10, device=device)

# Report metrics on every split, then park the retrained head on the CPU.
print_eval(rmodel)
rmodel = rmodel.to('cpu')

train epoch 1: 100%|██████████| 117/117 [00:00<00:00, 301.70batch/s, loss=0.883]
eval: 100%|██████████| 40/40 [00:00<00:00, 566.90batch/s, acc=0.735, loss=1.78]
train epoch 2: 100%|██████████| 117/117 [00:00<00:00, 420.21batch/s, loss=0.516]
eval: 100%|██████████| 40/40 [00:00<00:00, 559.28batch/s, acc=0.786, loss=1.64]
train epoch 3: 100%|██████████| 117/117 [00:00<00:00, 417.66batch/s, loss=0.12]
eval: 100%|██████████| 40/40 [00:00<00:00, 565.46batch/s, acc=0.797, loss=0.219]
train epoch 4: 100%|██████████| 117/117 [00:00<00:00, 420.16batch/s, loss=0.581]
eval: 100%|██████████| 40/40 [00:00<00:00, 562.03batch/s, acc=0.809, loss=0.639]
train epoch 5: 100%|██████████| 117/117 [00:00<00:00, 420.48batch/s, loss=0.609]
eval: 100%|██████████| 40/40 [00:00<00:00, 566.50batch/s, acc=0.827, loss=0.896]
train epoch 6: 100%|██████████| 117/117 [00:00<00:00, 420.31batch/s, loss=0.0433]
eval: 100%|██████████| 40/40 [00:00<00:00, 566.47batch/s, acc=0.829, loss=0.722]
train epoch 7: 100%|██████████

#######################################
train:


eval: 100%|██████████| 118/118 [00:00<00:00, 504.17batch/s, acc=0.939, loss=0.278]


#######################################
#######################################
val:


eval: 100%|██████████| 40/40 [00:00<00:00, 560.74batch/s, acc=0.866, loss=0.149]


#######################################
#######################################
retain:


eval: 100%|██████████| 117/117 [00:00<00:00, 499.54batch/s, acc=0.938, loss=0.0937]


#######################################
#######################################
forget:


eval: 100%|██████████| 2/2 [00:00<00:00, 568.64batch/s, acc=0.95, loss=0.101]


#######################################
#######################################
surrogate:


eval: 100%|██████████| 118/118 [00:00<00:00, 493.78batch/s, acc=0.865, loss=0.315]

#######################################





In [9]:
# forget with exact
# Apply `forget` to the linear head trained on all data, passing the exact
# training loader as the second argument, then report metrics for the
# returned model on every split.
# NOTE(review): `forget_loader` is passed as both the third and the fourth
# positional argument — confirm against the signature of src.forget.forget
# that this duplication is intended.
model = model.to(device)  # the head was parked on the CPU after training
fmodel = forget(model, train_loader, forget_loader, forget_loader, criterion, linear=True, num_class=10)
model = model.to('cpu')  # park the original head on the CPU again
print_eval(fmodel)
fmodel = fmodel.to('cpu')

#######################################
train:


eval: 100%|██████████| 118/118 [00:00<00:00, 500.24batch/s, acc=0.933, loss=0.296]


#######################################
#######################################
val:


eval: 100%|██████████| 40/40 [00:00<00:00, 557.30batch/s, acc=0.865, loss=0.379]


#######################################
#######################################
retain:


eval: 100%|██████████| 117/117 [00:00<00:00, 364.49batch/s, acc=0.933, loss=0.0121]


#######################################
#######################################
forget:


eval: 100%|██████████| 2/2 [00:00<00:00, 573.50batch/s, acc=0.917, loss=0.29]


#######################################
#######################################
surrogate:


eval: 100%|██████████| 118/118 [00:00<00:00, 499.82batch/s, acc=0.865, loss=0.424]

#######################################





In [10]:
# forget with surrogate
# Same call as the exact-data forgetting cell, except that the surrogate
# loader replaces the exact training loader as the second argument.
# NOTE(review): `forget_loader` is passed as both the third and the fourth
# positional argument — confirm against the signature of src.forget.forget
# that this duplication is intended.
model = model.to(device)  # the head was parked on the CPU by the previous cell
smodel = forget(model, surr_loader, forget_loader, forget_loader, criterion, linear=True, num_class=10)
model = model.to('cpu')  # park the original head on the CPU again
print_eval(smodel)
smodel = smodel.to('cpu')

#######################################
train:


eval: 100%|██████████| 118/118 [00:00<00:00, 490.43batch/s, acc=0.934, loss=0.171]


#######################################
#######################################
val:


eval: 100%|██████████| 40/40 [00:00<00:00, 560.12batch/s, acc=0.865, loss=0.526]


#######################################
#######################################
retain:


eval: 100%|██████████| 117/117 [00:00<00:00, 495.12batch/s, acc=0.934, loss=0.842]


#######################################
#######################################
forget:


eval: 100%|██████████| 2/2 [00:00<00:00, 561.15batch/s, acc=0.937, loss=0.208]


#######################################
#######################################
surrogate:


eval: 100%|██████████| 118/118 [00:00<00:00, 497.28batch/s, acc=0.867, loss=0.334]

#######################################



