In [13]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from bio_if.modules.mlp import MLPBlock
from tqdm import tqdm

In [14]:
# load mnist dataset as vectors
gen = torch.Generator().manual_seed(42)
transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda x: x.view(-1))])
train_data = datasets.MNIST('data', train=True, download=True, transform=transform)
test_data = datasets.MNIST('data', train=False, download=True, transform=transform)
# grab 10% as in Bae et al. 2020
train_data, _ = torch.utils.data.random_split(train_data, [6000, 54000], generator=gen)
test_data, _ = torch.utils.data.random_split(test_data, [1000, 9000], generator=gen)
train_data, val_data = torch.utils.data.random_split(train_data, [5000, 1000], generator=gen)

In [15]:
# select 20 random samples from the training set
idx = torch.randperm(len(train_data), generator=gen)[:20]

# create 20 datasets excluding these samples removed
datasets_removed = [
    torch.utils.data.Subset(train_data, torch.cat([torch.arange(i), torch.arange(i+1, len(train_data))]))
    for i in range(20)
] 
assert all(len(d) == len(train_data) - 1 for d in datasets_removed)

In [16]:
def generate_model():
    D = 128
    return nn.Sequential(
        MLPBlock(784, D),
        MLPBlock(D, D),
        MLPBlock(D, 10, use_relu=False),
    )

In [20]:
# train a model on the original dataset
DEVICE = "cuda:7"
EPOCHS = 1000
model = generate_model().to(DEVICE)
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, weight_decay=1e-2)

train_loader = torch.utils.data.DataLoader(train_data, batch_size=128, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_data, batch_size=128, shuffle=False)

val_loss, train_loss, val_acc = 0, 0, 0
for epoch in range(EPOCHS):
    model.train()
    for x, y in tqdm(train_loader, desc=f'Epoch {epoch}: train_loss={train_loss:.4f}, val_loss={val_loss:.4f}, val_acc={acc:.4f}', position=0, leave=True):
        x, y = x.to(DEVICE), y.to(DEVICE)
        optimizer.zero_grad()
        y_pred = model(x)
        loss = nn.functional.cross_entropy(y_pred, y)
        train_loss = loss.item()
        loss.backward()
        optimizer.step()
    model.eval()
    correct = 0
    val_loss = 0
    for x, y in val_loader:
        x, y = x.to(DEVICE), y.to(DEVICE)
        y_pred = model(x)
        val_loss += nn.functional.cross_entropy(y_pred, y.to(DEVICE), reduction='sum').item()
        correct += (y_pred.argmax(dim=1) == y).sum().item()
    val_loss /= len(val_data)
    acc = correct / len(val_data)

Epoch 0: train_loss=0.0000, val_loss=0.0000, val_acc=0.0980:  10%|█         | 4/40 [00:00<00:01, 32.04it/s]

Epoch 0: train_loss=0.0000, val_loss=0.0000, val_acc=0.0980: 100%|██████████| 40/40 [00:01<00:00, 39.61it/s]
Epoch 1: train_loss=2.2788, val_loss=2.3028, val_acc=0.1240: 100%|██████████| 40/40 [00:00<00:00, 48.48it/s]
Epoch 2: train_loss=2.3188, val_loss=2.3018, val_acc=0.1300: 100%|██████████| 40/40 [00:00<00:00, 48.44it/s]
Epoch 3: train_loss=2.3180, val_loss=2.3008, val_acc=0.1390: 100%|██████████| 40/40 [00:00<00:00, 49.05it/s]
Epoch 4: train_loss=2.3041, val_loss=2.2998, val_acc=0.1500: 100%|██████████| 40/40 [00:00<00:00, 48.85it/s]
Epoch 5: train_loss=2.3068, val_loss=2.2989, val_acc=0.1630: 100%|██████████| 40/40 [00:00<00:00, 48.42it/s]
Epoch 6: train_loss=2.2906, val_loss=2.2979, val_acc=0.1720: 100%|██████████| 40/40 [00:00<00:00, 48.40it/s]
Epoch 7: train_loss=2.3232, val_loss=2.2969, val_acc=0.1830: 100%|██████████| 40/40 [00:00<00:00, 48.46it/s]
Epoch 8: train_loss=2.3068, val_loss=2.2959, val_acc=0.1920: 100%|██████████| 40/40 [00:00<00:00, 48.58it/s]
Epoch 9: train_loss

In [21]:
from bio_if.modules.influence import influence

In [22]:
def dataset_to_list(dataset):
    return [(x, torch.tensor(y)) for x, y in dataset]

In [23]:
influence_scores = influence(model, list(model), dataset_to_list(test_data), dataset_to_list(train_data), dataset_to_list([train_data[i] for i in idx]), DEVICE, torch.nn.CrossEntropyLoss(), aggregate_query_grads=True)

Computing EKFAC factors and pseudo gradients


100%|██████████| 5000/5000 [00:20<00:00, 240.30it/s]


Computing search gradients


100%|██████████| 20/20 [00:00<00:00, 328.38it/s]


Computing iHVP


100%|██████████| 1/1 [00:00<00:00, 201.97it/s]
100%|██████████| 1/1 [00:00<00:00, 291.23it/s]
100%|██████████| 1/1 [00:00<00:00, 283.05it/s]
100%|██████████| 1/1 [00:00<00:00, 287.03it/s]
100%|██████████| 1/1 [00:00<00:00, 297.17it/s]
100%|██████████| 1/1 [00:00<00:00, 287.22it/s]
100%|██████████| 1/1 [00:00<00:00, 291.76it/s]
100%|██████████| 1/1 [00:00<00:00, 274.32it/s]
100%|██████████| 1/1 [00:00<00:00, 262.28it/s]
100%|██████████| 1/1 [00:00<00:00, 257.70it/s]
100%|██████████| 1/1 [00:00<00:00, 296.42it/s]
100%|██████████| 1/1 [00:00<00:00, 259.66it/s]
100%|██████████| 1/1 [00:00<00:00, 253.14it/s]
100%|██████████| 1/1 [00:00<00:00, 252.64it/s]
100%|██████████| 1/1 [00:00<00:00, 246.49it/s]
100%|██████████| 1/1 [00:00<00:00, 233.76it/s]
100%|██████████| 1/1 [00:00<00:00, 227.64it/s]
100%|██████████| 1/1 [00:00<00:00, 217.90it/s]
100%|██████████| 1/1 [00:00<00:00, 209.57it/s]
100%|██████████| 1/1 [00:00<00:00, 208.65it/s]
100%|██████████| 1/1 [00:00<00:00, 198.62it/s]
100%|████████

In [24]:
influence_scores

tensor([ 1.5675e-04,  7.5060e-04,  3.0239e-05, -4.9984e-04, -1.5107e-05,
         4.0858e-05,  1.3132e-05,  5.9237e-04,  8.3166e-05, -2.0216e-04,
         2.3649e-05, -4.2508e-05, -4.7127e-06,  7.4614e-05, -1.8262e-04,
         1.4127e-05, -1.1481e-04,  1.7710e-04,  3.5198e-04, -2.8416e-04],
       device='cuda:7')

In [25]:
# save model and influence scores
torch.save(model.state_dict(), 'mnist_model.pth')
torch.save(influence_scores, 'mnist_influence_scores.pth')

In [26]:
from bio_if.modules.pbo import proximal_bregman_objective

In [27]:
pbo_test_losses = []
for i, dataset in enumerate(datasets_removed):
    model_new = generate_model().to(DEVICE)
    model_new.load_state_dict(torch.load('mnist_model.pth'))
    optimizer = torch.optim.Adam(model_new.parameters(), lr=0.001)

    train_loader = torch.utils.data.DataLoader(dataset, batch_size=256, shuffle=True)
    for epoch in range(EPOCHS // 2):
        model_new.train()
        for x, y in tqdm(train_loader, desc=f'Epoch {epoch} Dataset {i}'):
            x, y = x.to(DEVICE), y.to(DEVICE)
            optimizer.zero_grad()
            loss = proximal_bregman_objective(x, y, model_new, model, torch.nn.CrossEntropyLoss(), 0.001)
            loss.backward()
            optimizer.step()
    # calculate test loss
    model_new.eval()
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=256, shuffle=False)
    loss = 0
    for x, y in test_loader:
        x, y = x.to(DEVICE), y.to(DEVICE)
        y_pred = model_new(x)
        loss += nn.functional.cross_entropy(y_pred, y, reduction='sum').item()
    pbo_test_losses.append(loss / len(test_data))
    print(f'Dataset {i}: test_loss={pbo_test_losses[-1]}')


Epoch 0 Dataset 0:  80%|████████  | 16/20 [00:00<00:00, 20.60it/s]

Epoch 0 Dataset 0: 100%|██████████| 20/20 [00:01<00:00, 19.96it/s]
Epoch 1 Dataset 0: 100%|██████████| 20/20 [00:00<00:00, 22.99it/s]
Epoch 2 Dataset 0: 100%|██████████| 20/20 [00:00<00:00, 23.04it/s]
Epoch 3 Dataset 0: 100%|██████████| 20/20 [00:00<00:00, 22.88it/s]
Epoch 4 Dataset 0: 100%|██████████| 20/20 [00:00<00:00, 23.00it/s]
Epoch 5 Dataset 0: 100%|██████████| 20/20 [00:00<00:00, 22.76it/s]
Epoch 6 Dataset 0: 100%|██████████| 20/20 [00:00<00:00, 23.17it/s]
Epoch 7 Dataset 0: 100%|██████████| 20/20 [00:00<00:00, 22.77it/s]
Epoch 8 Dataset 0: 100%|██████████| 20/20 [00:00<00:00, 23.38it/s]
Epoch 9 Dataset 0: 100%|██████████| 20/20 [00:00<00:00, 22.89it/s]
Epoch 10 Dataset 0: 100%|██████████| 20/20 [00:00<00:00, 23.01it/s]
Epoch 11 Dataset 0: 100%|██████████| 20/20 [00:00<00:00, 23.28it/s]
Epoch 12 Dataset 0: 100%|██████████| 20/20 [00:00<00:00, 22.71it/s]
Epoch 13 Dataset 0: 100%|██████████| 20/20 [00:00<00:00, 22.58it/s]
Epoch 14 Dataset 0: 100%|██████████| 20/20 [00:00<00:00, 2

Dataset 0: test_loss=0.37744741821289063


Epoch 0 Dataset 1: 100%|██████████| 20/20 [00:00<00:00, 23.40it/s]
Epoch 1 Dataset 1: 100%|██████████| 20/20 [00:00<00:00, 23.94it/s]
Epoch 2 Dataset 1: 100%|██████████| 20/20 [00:00<00:00, 23.14it/s]
Epoch 3 Dataset 1: 100%|██████████| 20/20 [00:00<00:00, 23.55it/s]
Epoch 4 Dataset 1: 100%|██████████| 20/20 [00:00<00:00, 23.29it/s]
Epoch 5 Dataset 1: 100%|██████████| 20/20 [00:00<00:00, 23.62it/s]
Epoch 6 Dataset 1: 100%|██████████| 20/20 [00:01<00:00, 19.84it/s]
Epoch 7 Dataset 1: 100%|██████████| 20/20 [00:00<00:00, 23.26it/s]
Epoch 8 Dataset 1: 100%|██████████| 20/20 [00:00<00:00, 23.69it/s]
Epoch 9 Dataset 1: 100%|██████████| 20/20 [00:00<00:00, 23.24it/s]
Epoch 10 Dataset 1: 100%|██████████| 20/20 [00:00<00:00, 23.84it/s]
Epoch 11 Dataset 1: 100%|██████████| 20/20 [00:00<00:00, 23.27it/s]
Epoch 12 Dataset 1: 100%|██████████| 20/20 [00:00<00:00, 23.48it/s]
Epoch 13 Dataset 1: 100%|██████████| 20/20 [00:00<00:00, 23.60it/s]
Epoch 14 Dataset 1: 100%|██████████| 20/20 [00:00<00:00, 2

Dataset 1: test_loss=0.37744741821289063


Epoch 0 Dataset 2: 100%|██████████| 20/20 [00:00<00:00, 24.01it/s]
Epoch 1 Dataset 2: 100%|██████████| 20/20 [00:00<00:00, 24.05it/s]
Epoch 2 Dataset 2: 100%|██████████| 20/20 [00:00<00:00, 23.92it/s]
Epoch 3 Dataset 2: 100%|██████████| 20/20 [00:00<00:00, 24.17it/s]
Epoch 4 Dataset 2: 100%|██████████| 20/20 [00:00<00:00, 23.53it/s]
Epoch 5 Dataset 2: 100%|██████████| 20/20 [00:00<00:00, 24.30it/s]
Epoch 6 Dataset 2: 100%|██████████| 20/20 [00:00<00:00, 23.58it/s]
Epoch 7 Dataset 2: 100%|██████████| 20/20 [00:00<00:00, 24.77it/s]
Epoch 8 Dataset 2: 100%|██████████| 20/20 [00:00<00:00, 23.56it/s]
Epoch 9 Dataset 2: 100%|██████████| 20/20 [00:00<00:00, 24.54it/s]
Epoch 10 Dataset 2: 100%|██████████| 20/20 [00:00<00:00, 23.39it/s]
Epoch 11 Dataset 2: 100%|██████████| 20/20 [00:00<00:00, 24.47it/s]
Epoch 12 Dataset 2: 100%|██████████| 20/20 [00:00<00:00, 23.55it/s]
Epoch 13 Dataset 2: 100%|██████████| 20/20 [00:00<00:00, 24.19it/s]
Epoch 14 Dataset 2: 100%|██████████| 20/20 [00:00<00:00, 2

Dataset 2: test_loss=0.37744741821289063


Epoch 0 Dataset 3: 100%|██████████| 20/20 [00:00<00:00, 24.16it/s]
Epoch 1 Dataset 3: 100%|██████████| 20/20 [00:00<00:00, 25.33it/s]
Epoch 2 Dataset 3: 100%|██████████| 20/20 [00:00<00:00, 23.77it/s]
Epoch 3 Dataset 3: 100%|██████████| 20/20 [00:00<00:00, 24.89it/s]
Epoch 4 Dataset 3: 100%|██████████| 20/20 [00:00<00:00, 24.34it/s]
Epoch 5 Dataset 3: 100%|██████████| 20/20 [00:00<00:00, 25.27it/s]
Epoch 6 Dataset 3: 100%|██████████| 20/20 [00:00<00:00, 24.91it/s]
Epoch 7 Dataset 3: 100%|██████████| 20/20 [00:00<00:00, 25.32it/s]
Epoch 8 Dataset 3: 100%|██████████| 20/20 [00:00<00:00, 24.50it/s]
Epoch 9 Dataset 3: 100%|██████████| 20/20 [00:00<00:00, 24.61it/s]
Epoch 10 Dataset 3: 100%|██████████| 20/20 [00:00<00:00, 25.10it/s]
Epoch 11 Dataset 3: 100%|██████████| 20/20 [00:00<00:00, 24.76it/s]
Epoch 12 Dataset 3: 100%|██████████| 20/20 [00:00<00:00, 25.17it/s]
Epoch 13 Dataset 3: 100%|██████████| 20/20 [00:00<00:00, 25.04it/s]
Epoch 14 Dataset 3: 100%|██████████| 20/20 [00:00<00:00, 2

Dataset 3: test_loss=0.37744741821289063


Epoch 0 Dataset 4: 100%|██████████| 20/20 [00:00<00:00, 25.71it/s]
Epoch 1 Dataset 4: 100%|██████████| 20/20 [00:00<00:00, 24.73it/s]
Epoch 2 Dataset 4: 100%|██████████| 20/20 [00:00<00:00, 25.70it/s]
Epoch 3 Dataset 4: 100%|██████████| 20/20 [00:00<00:00, 24.88it/s]
Epoch 4 Dataset 4: 100%|██████████| 20/20 [00:00<00:00, 22.19it/s]
Epoch 5 Dataset 4: 100%|██████████| 20/20 [00:00<00:00, 25.51it/s]
Epoch 6 Dataset 4: 100%|██████████| 20/20 [00:00<00:00, 25.44it/s]
Epoch 7 Dataset 4: 100%|██████████| 20/20 [00:00<00:00, 25.68it/s]
Epoch 8 Dataset 4: 100%|██████████| 20/20 [00:00<00:00, 26.00it/s]
Epoch 9 Dataset 4: 100%|██████████| 20/20 [00:00<00:00, 25.84it/s]
Epoch 10 Dataset 4: 100%|██████████| 20/20 [00:00<00:00, 25.79it/s]
Epoch 11 Dataset 4: 100%|██████████| 20/20 [00:00<00:00, 25.31it/s]
Epoch 12 Dataset 4: 100%|██████████| 20/20 [00:00<00:00, 25.90it/s]
Epoch 13 Dataset 4: 100%|██████████| 20/20 [00:00<00:00, 25.03it/s]
Epoch 14 Dataset 4: 100%|██████████| 20/20 [00:00<00:00, 2

Dataset 4: test_loss=0.37744741821289063


Epoch 0 Dataset 5: 100%|██████████| 20/20 [00:00<00:00, 23.84it/s]
Epoch 1 Dataset 5: 100%|██████████| 20/20 [00:00<00:00, 23.68it/s]
Epoch 2 Dataset 5: 100%|██████████| 20/20 [00:00<00:00, 23.92it/s]
Epoch 3 Dataset 5: 100%|██████████| 20/20 [00:00<00:00, 23.41it/s]
Epoch 4 Dataset 5: 100%|██████████| 20/20 [00:00<00:00, 23.29it/s]
Epoch 5 Dataset 5: 100%|██████████| 20/20 [00:00<00:00, 23.54it/s]
Epoch 6 Dataset 5: 100%|██████████| 20/20 [00:00<00:00, 23.62it/s]
Epoch 7 Dataset 5: 100%|██████████| 20/20 [00:00<00:00, 23.84it/s]
Epoch 8 Dataset 5: 100%|██████████| 20/20 [00:00<00:00, 23.63it/s]
Epoch 9 Dataset 5: 100%|██████████| 20/20 [00:00<00:00, 23.96it/s]
Epoch 10 Dataset 5: 100%|██████████| 20/20 [00:00<00:00, 23.43it/s]
Epoch 11 Dataset 5: 100%|██████████| 20/20 [00:00<00:00, 23.94it/s]
Epoch 12 Dataset 5: 100%|██████████| 20/20 [00:00<00:00, 23.07it/s]
Epoch 13 Dataset 5: 100%|██████████| 20/20 [00:00<00:00, 24.08it/s]
Epoch 14 Dataset 5: 100%|██████████| 20/20 [00:00<00:00, 2

Dataset 5: test_loss=0.37744741821289063


Epoch 0 Dataset 6: 100%|██████████| 20/20 [00:00<00:00, 23.38it/s]
Epoch 1 Dataset 6: 100%|██████████| 20/20 [00:00<00:00, 22.72it/s]
Epoch 2 Dataset 6: 100%|██████████| 20/20 [00:00<00:00, 23.16it/s]
Epoch 3 Dataset 6: 100%|██████████| 20/20 [00:00<00:00, 22.86it/s]
Epoch 4 Dataset 6: 100%|██████████| 20/20 [00:00<00:00, 23.02it/s]
Epoch 5 Dataset 6: 100%|██████████| 20/20 [00:00<00:00, 22.75it/s]
Epoch 6 Dataset 6: 100%|██████████| 20/20 [00:00<00:00, 23.06it/s]
Epoch 7 Dataset 6: 100%|██████████| 20/20 [00:00<00:00, 23.06it/s]
Epoch 8 Dataset 6: 100%|██████████| 20/20 [00:00<00:00, 22.84it/s]
Epoch 9 Dataset 6: 100%|██████████| 20/20 [00:00<00:00, 23.07it/s]
Epoch 10 Dataset 6: 100%|██████████| 20/20 [00:00<00:00, 22.58it/s]
Epoch 11 Dataset 6: 100%|██████████| 20/20 [00:00<00:00, 23.17it/s]
Epoch 12 Dataset 6: 100%|██████████| 20/20 [00:00<00:00, 22.45it/s]
Epoch 13 Dataset 6: 100%|██████████| 20/20 [00:00<00:00, 23.39it/s]
Epoch 14 Dataset 6: 100%|██████████| 20/20 [00:00<00:00, 2

Dataset 6: test_loss=0.37744741821289063


Epoch 0 Dataset 7: 100%|██████████| 20/20 [00:00<00:00, 22.78it/s]
Epoch 1 Dataset 7: 100%|██████████| 20/20 [00:00<00:00, 22.71it/s]
Epoch 2 Dataset 7: 100%|██████████| 20/20 [00:00<00:00, 23.12it/s]
Epoch 3 Dataset 7: 100%|██████████| 20/20 [00:01<00:00, 19.77it/s]
Epoch 4 Dataset 7: 100%|██████████| 20/20 [00:00<00:00, 23.34it/s]
Epoch 5 Dataset 7: 100%|██████████| 20/20 [00:00<00:00, 22.44it/s]
Epoch 6 Dataset 7: 100%|██████████| 20/20 [00:00<00:00, 23.12it/s]
Epoch 7 Dataset 7: 100%|██████████| 20/20 [00:00<00:00, 22.86it/s]
Epoch 8 Dataset 7: 100%|██████████| 20/20 [00:00<00:00, 22.77it/s]
Epoch 9 Dataset 7: 100%|██████████| 20/20 [00:00<00:00, 22.39it/s]
Epoch 10 Dataset 7: 100%|██████████| 20/20 [00:00<00:00, 22.29it/s]
Epoch 11 Dataset 7: 100%|██████████| 20/20 [00:00<00:00, 22.99it/s]
Epoch 12 Dataset 7: 100%|██████████| 20/20 [00:00<00:00, 22.83it/s]
Epoch 13 Dataset 7: 100%|██████████| 20/20 [00:00<00:00, 23.27it/s]
Epoch 14 Dataset 7: 100%|██████████| 20/20 [00:00<00:00, 2

KeyboardInterrupt: 

In [None]:
print(pbo_test_losses)
torch.save(pbo_test_losses, 'mnist_pbo_test_losses.pth')

In [None]:
from scipy.stats import spearmanr, pearsonr

In [None]:
# calculate correlation between influence scores and PBO test losses
print(f"Spearman's rho: {spearmanr(influence_scores, pbo_test_losses)}")
print(f"Pearson's r: {pearsonr(influence_scores, pbo_test_losses)}")