In [1]:
import torch
import torch.nn as nn
from torchvision.transforms import v2
from torch.utils.data import ConcatDataset, TensorDataset
import numpy as np

from src.data import get_train_test_datasets, get_dataloaders, get_retain_forget_datasets, get_exact_surr_datasets, get_class_ratios
from src.train import train
from src.eval import evaluate
from src.utils import set_seed

# Seed first so everything below is repeatable (set_seed comes from src.utils;
# presumably it seeds the python / numpy / torch RNGs -- confirm there).
set_seed(42)

# Run on the first CUDA device when one is visible, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [2]:
class ShallowModel(nn.Module):
    """One-hidden-layer MLP split into a feature extractor and a linear classifier.

    ``extractor`` is ``Linear -> ReLU`` and ``classifier`` is a single ``Linear``
    layer applied to the extractor's output. The two parts are separate
    attributes so the extractor can be called on its own.

    Args:
        in_features: Size of each flattened input sample (default 784).
        hidden_features: Width of the hidden representation (default 256).
        num_classes: Number of output logits (default 10).
        *args, **kwargs: Forwarded unchanged to ``nn.Module.__init__``.
    """

    def __init__(self, *args, in_features=784, hidden_features=256, num_classes=10, **kwargs):
        super().__init__(*args, **kwargs)
        # The extractor is built before the classifier so that parameter
        # initialisation draws from the RNG in the same order as before.
        self.extractor = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.ReLU()
        )

        self.classifier = nn.Linear(hidden_features, num_classes)

    def forward(self, x):
        """Return the class logits for ``x`` (last dimension ``in_features``)."""
        return self.classifier(self.extractor(x))

In [3]:
# Preprocessing for the USPS images: force a single channel, resize to 28x28,
# convert to a float32 tensor scaled by ToDtype, normalise with mean/std 0.5,
# and flatten to a 1-D vector (28 * 28 = 784 values, the ShallowModel input size).
# NOTE(review): the MNIST pipeline further down normalises with 0.1307 / 0.3081
# instead of 0.5 / 0.5 -- confirm the mismatch is intended.
gtransform = v2.Compose(
    [
        v2.Grayscale(),
        v2.Resize(size=(28, 28)),
        v2.ToImage(),
        v2.ToDtype(dtype=torch.float32, scale=True),
        v2.Normalize(mean=[0.5], std=[0.5]),
        v2.Lambda(lambda tensor_img: tensor_img.view(-1)),
    ]
)

# USPS train / test splits and their loaders (helpers live in src.data).
gtrain_dataset, gval_dataset = get_train_test_datasets('usps', gtransform)
gtrain_loader, gval_loader = get_dataloaders(
    [gtrain_dataset, gval_dataset],
    batch_size=256,
)

In [4]:
# Fit the shallow network on USPS; its `extractor` is applied to MNIST in the
# feature-extraction cell that follows.
model = ShallowModel().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
train(gtrain_loader, gval_loader, model, criterion, optimizer, device=device, num_epoch=10)
# Final pass over the USPS validation loader with the trained weights.
evaluate(gval_loader, model, criterion, device=device)

train epoch 1: 100%|██████████| 29/29 [00:01<00:00, 20.81batch/s, loss=0.193]
eval: 100%|██████████| 8/8 [00:00<00:00, 21.14batch/s, acc=0.886, loss=0.578]
train epoch 2: 100%|██████████| 29/29 [00:01<00:00, 23.53batch/s, loss=0.193]
eval: 100%|██████████| 8/8 [00:00<00:00, 24.54batch/s, acc=0.906, loss=0.452]
train epoch 3: 100%|██████████| 29/29 [00:01<00:00, 23.68batch/s, loss=0.152]
eval: 100%|██████████| 8/8 [00:00<00:00, 24.13batch/s, acc=0.906, loss=0.35] 
train epoch 4: 100%|██████████| 29/29 [00:01<00:00, 23.51batch/s, loss=0.174]
eval: 100%|██████████| 8/8 [00:00<00:00, 24.23batch/s, acc=0.917, loss=0.459]
train epoch 5: 100%|██████████| 29/29 [00:01<00:00, 23.46batch/s, loss=0.0952]
eval: 100%|██████████| 8/8 [00:00<00:00, 24.21batch/s, acc=0.926, loss=0.579]
train epoch 6: 100%|██████████| 29/29 [00:01<00:00, 23.89batch/s, loss=0.157] 
eval: 100%|██████████| 8/8 [00:00<00:00, 24.32batch/s, acc=0.92, loss=0.301] 
train epoch 7: 100%|██████████| 29/29 [00:01<00:00, 23.84batch

In [5]:
# MNIST preprocessing: tensor conversion, float32 scaling, normalisation with
# mean 0.1307 / std 0.3081, then flattening to a 1-D vector.
transform = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.1307], std=[0.3081]),
    v2.Lambda(lambda img: img.view(-1))
])

train_dataset, val_dataset = get_train_test_datasets('mnist', transform)
train_loader, val_loader = get_dataloaders([train_dataset, val_dataset], batch_size=256)


def _extract_features(loader, extractor, device):
    """Run ``extractor`` over every batch of ``loader`` without tracking gradients.

    Args:
        loader: Iterable yielding ``(data, label)`` batches.
        extractor: Callable applied to each data batch after it is moved to ``device``.
        device: Device on which the extractor is evaluated.

    Returns:
        A ``(features, labels)`` pair of tensors concatenated along dim 0, in the
        order the loader produced the batches. Features are moved to the CPU;
        labels are kept exactly as yielded.
    """
    features, labels = [], []
    with torch.no_grad():
        for data, label in loader:
            features.append(extractor(data.to(device)).to('cpu'))
            # Labels are stored alongside their own batch, so features and labels
            # stay aligned whatever order the loader iterates in.
            labels.append(label)
    return torch.cat(features, dim=0), torch.cat(labels, dim=0)


# Inference only. eval() does not change the output of the Linear/ReLU layers in
# ShallowModel; it keeps this cell correct if dropout or batch-norm is added later.
model.eval()
etrain_data, etrain_label = _extract_features(train_loader, model.extractor, device)
eval_data, eval_label = _extract_features(val_loader, model.extractor, device)

# From here on train_dataset / val_dataset hold extracted features rather than
# images; later cells rely on these names.
train_dataset = TensorDataset(etrain_data, etrain_label)
val_dataset = TensorDataset(eval_data, eval_label)

In [6]:
# Split the extracted-feature training set into retain / forget parts.
# NOTE(review): 0.01 is presumably the forget fraction -- confirm in src.data.
# NOTE(review): this cell rebinds retain_dataset and train_dataset, so running it
# twice without re-running the previous cell starts from the already-reduced
# train_dataset.
retain_dataset, forget_dataset = get_retain_forget_datasets(train_dataset, 0.01)

# Per-class target ratios (10 entries each, both summing to 1): the "exact" split
# is skewed, the surrogate split is uniform.
exact_ratios = np.array([0.2, 0.05, 0.1, 0.05, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1])
surr_ratios = np.full(10, 0.1)

# Half of the retain set (rounded down) goes to the exact split, the rest to the surrogate.
exact_size = len(retain_dataset) // 2
surr_size = len(retain_dataset) - exact_size

retain_dataset, surr_dataset = get_exact_surr_datasets(
    retain_dataset,
    target_size=exact_size,
    target_ratios=exact_ratios,
    starget_size=surr_size,
    starget_ratios=surr_ratios,
)

# The training set is now the exact retain split plus the forget set.
train_dataset = ConcatDataset([retain_dataset, forget_dataset])
train_loader, val_loader = get_dataloaders(
    [train_dataset, val_dataset],
    batch_size=256,
)

In [8]:
# Move the previous network off the GPU before replacing it.
# NOTE(review): `model` is rebound on the next line, so after this cell the name
# refers to a plain nn.Linear and the earlier feature-extraction cell can no
# longer be re-run without retraining -- confirm this is acceptable.
model = model.to('cpu')

# Bias-free linear classifier on the 256-dimensional extracted features.
model = nn.Linear(in_features=256, out_features=10, bias=False).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
train(train_loader, val_loader, model, criterion, optimizer, device=device, num_epoch=10)
# NOTE(review): the final evaluation uses train_loader, not val_loader -- confirm intended.
evaluate(train_loader, model, criterion, device=device)

train epoch 1: 100%|██████████| 119/119 [00:00<00:00, 280.82batch/s, loss=0.645]
eval: 100%|██████████| 40/40 [00:00<00:00, 470.23batch/s, acc=0.839, loss=0.58]
train epoch 2: 100%|██████████| 119/119 [00:00<00:00, 356.37batch/s, loss=0.339]
eval: 100%|██████████| 40/40 [00:00<00:00, 477.38batch/s, acc=0.879, loss=0.438]
train epoch 3: 100%|██████████| 119/119 [00:00<00:00, 349.00batch/s, loss=0.336]
eval: 100%|██████████| 40/40 [00:00<00:00, 469.78batch/s, acc=0.891, loss=0.347]
train epoch 4: 100%|██████████| 119/119 [00:00<00:00, 349.05batch/s, loss=0.271]
eval: 100%|██████████| 40/40 [00:00<00:00, 448.58batch/s, acc=0.901, loss=0.369]
train epoch 5: 100%|██████████| 119/119 [00:00<00:00, 348.03batch/s, loss=0.307]
eval: 100%|██████████| 40/40 [00:00<00:00, 481.34batch/s, acc=0.904, loss=0.24]
train epoch 6: 100%|██████████| 119/119 [00:00<00:00, 267.65batch/s, loss=0.345]
eval: 100%|██████████| 40/40 [00:00<00:00, 474.42batch/s, acc=0.911, loss=0.566]
train epoch 7: 100%|██████████

In [None]:
# TODO: compute the gradient, the Hessian, and the update, then compare performances (be done with this part early)
# TODO: write the code to find out the noise
## TODO: derive the upper bound on the KL divergence (due Monday) --> show everything you have
## TODO: work through the rest of the assumptions