In [None]:
import copy
import pandas as pd
import torch
import torchvision
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from sklearn import linear_model, model_selection
import os

from src.train_cifar_checkpoints import load_cifar10, evaluate
from src.resnet import resnet18

# Device string used for every model and batch in this notebook:
# 'cuda' when PyTorch can see a GPU, otherwise 'cpu'.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
def iterate_checkpoints(root, num_checkpoints, prefix="retain"):
    """
    Lazily yield a model loaded with each consecutive checkpoint found in ``root``.

    Checkpoints are looked up as ``{root}/{prefix}_{i}.pt`` for
    ``i = 0 .. num_checkpoints - 1``. If not all checkpoints are available,
    iteration stops at the first missing index, so only the leading
    contiguous run of checkpoints is yielded.

    A single model instance is reused: its weights are overwritten in place
    each time the generator advances. Copy the model (``copy.deepcopy``) if
    it must outlive the current iteration step.

    Args:
        root: directory containing the checkpoint files.
        num_checkpoints: maximum number of checkpoints to try.
        prefix: filename prefix of the checkpoints (e.g. "retain", "unlearn").

    Yields:
        The model, moved to ``DEVICE`` and switched to eval mode, holding the
        weights of the current checkpoint.
    """
    model = resnet18()
    model.linear = nn.Linear(512, 10)
    model.to(DEVICE)

    for i in range(num_checkpoints):
        fname = os.path.join(root, f"{prefix}_{i}.pt")
        if not os.path.isfile(fname):
            break

        # map_location lets checkpoints that were saved on a GPU load on a
        # CPU-only machine (DEVICE falls back to 'cpu').
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        model.load_state_dict(torch.load(fname, map_location=DEVICE))

        # Checkpoints are only used for inference here. Eval mode keeps
        # train-time layers (BatchNorm/Dropout, if the network has any) from
        # making per-sample outputs depend on the rest of the batch.
        model.eval()
        yield model

def compute_function_out(out, y, mode="loss"):
    """
    Reduce model outputs to one scalar per sample.

    Args:
        out: model outputs, indexed as ``out[sample, class]``.
        y: integer class index of each sample.
        mode: "loss" for the per-sample cross-entropy, or "confidence" for
            the raw model output at the true-class index (no softmax is
            applied here).

    Returns:
        A 1-D tensor with one value per sample.

    Raises:
        ValueError: if ``mode`` is not one of the supported values.
    """
    if mode == "loss":
        return F.cross_entropy(out, y, reduction="none")
    if mode == "confidence":
        rows = torch.arange(len(out), device=out.device)
        return out[rows, y]
    # Fail loudly instead of silently returning None for a misspelled mode.
    raise ValueError(f"Unknown mode {mode!r}; expected 'loss' or 'confidence'")


@torch.no_grad()
def collect_losses(loader, root, root2=None, num_models=100, mode="loss"):
    """
    Collect per-sample statistics from both the models trained from scratch
    on the retain set and the unlearned models.

    For every batch of ``loader``, each "retain" checkpoint found in ``root``
    and each "unlearn" checkpoint found in ``root2`` is evaluated, and the
    value returned by ``compute_function_out(out, y, mode)`` is recorded.

    Args:
        loader: iterable of ``(x, y)`` batches.
        root: directory holding the ``retain_{i}.pt`` checkpoints.
        root2: directory holding the ``unlearn_{i}.pt`` checkpoints
            (defaults to ``root``).
        num_models: maximum number of checkpoints to try per kind.
        mode: statistic to collect, forwarded to ``compute_function_out``.

    Returns:
        Four CPU tensors of shape ``[num_examples, num_checkpoints_found]``
        (retain checkpoints first, then unlearn checkpoints):
        the collected statistic, the class label of each example, the
        membership label (0 for retain checkpoints, 1 for unlearn
        checkpoints) and whether the prediction was correct (1.0 / 0.0).

    Raises:
        RuntimeError: if no checkpoint is found or the loader is empty.
    """
    if root2 is None:
        root2 = root

    # (checkpoint directory, filename prefix, membership label)
    sources = (
        (root, "retain", 0.0),
        (root2, "unlearn", 1.0),
    )

    losses = []
    labels = []
    labels_member = []
    correct = []

    # NOTE: every checkpoint is reloaded from disk for every batch; this
    # keeps memory usage flat at the cost of repeated loading.
    for x, y in loader:
        x, y = x.to(DEVICE), y.to(DEVICE)
        y_cpu = y.cpu()

        batch_losses = []
        batch_labels = []
        batch_labels_member = []
        batch_correct = []

        for ckpt_root, prefix, member_flag in sources:
            for model in iterate_checkpoints(ckpt_root, num_models, prefix=prefix):
                # Inference only: eval mode keeps train-time layers
                # (BatchNorm/Dropout, if any) from tying a sample's output to
                # the other samples of its batch.
                model.eval()
                out = model(x)
                pred = out.argmax(1)

                loss = compute_function_out(out, y, mode).cpu()

                batch_losses.append(loss)
                batch_labels.append(y_cpu)
                batch_labels_member.append(torch.full_like(loss, member_flag))
                # Moved to CPU so all four outputs live on the same device.
                batch_correct.append((pred == y).float().cpu())

        if not batch_losses:
            raise RuntimeError(
                f"No 'retain' checkpoint found in {root!r} and no 'unlearn' "
                f"checkpoint found in {root2!r}"
            )

        # Each stack is [num_checkpoints_found, batch_size]
        losses.append(torch.stack(batch_losses))
        labels.append(torch.stack(batch_labels))
        labels_member.append(torch.stack(batch_labels_member))
        correct.append(torch.stack(batch_correct))

    if not losses:
        raise RuntimeError("The loader did not yield any batch")

    # Concatenate batches along the sample axis, then put samples first.
    return tuple(
        torch.cat(chunks, dim=1).t()
        for chunks in (losses, labels, labels_member, correct)
    )
        

In [None]:
# Datasets as returned by the project loader (transforms disabled).
train, test = load_cifar10(use_transforms=False)


# Model initialised from the "initial" checkpoint.
model = resnet18()
model.linear = nn.Linear(512, 10)
model.to(DEVICE)

# map_location lets a checkpoint saved on a GPU load when DEVICE is 'cpu'.
model.load_state_dict(torch.load("../cifar10_checkpoints/initial.pt", map_location=DEVICE))

# Indices (into the training set) of the samples that must be forgotten.
indices_forget = np.load("../cifar10_checkpoints/forget_set.npy")

forget_set = torch.utils.data.Subset(train, indices_forget)

forget_loader = torch.utils.data.DataLoader(forget_set, batch_size=64)

test_loader = torch.utils.data.DataLoader(test, batch_size=64)

losses, labels, labels_member, correct = collect_losses(forget_loader, "../cifar10_checkpoints/", root2="../unlearning_checkpoints/", num_models=100, mode="loss")

# losses = [num_examples, num_models*2]
# It will contain the individual "forget" sample losses for all unlearned models and models trained from scratch on retain set

In [None]:
import pandas as pd
import seaborn as sns

# One row per (sample, checkpoint) pair: the first level enumerates the rows
# of `losses`, the second its columns.
index = pd.MultiIndex.from_product(
    [range(size) for size in losses.shape[:2]],
    names=['Sample', 'Prediction'],
)

# Long-format table of every collected quantity, with the two index levels
# turned back into regular columns.
df = pd.DataFrame(
    {
        'Loss': losses.numpy().flatten(),
        'Labels': labels.cpu().numpy().flatten(),
        'Member_Label': labels_member.cpu().numpy().flatten(),
        'Correct': correct.cpu().numpy().flatten(),
    },
    index=index,
).reset_index()

In [None]:
# Distribution of the loss of one forget sample across checkpoints, with one
# curve per Member_Label value (0 = "retain" checkpoints, 1 = "unlearn" checkpoints)
sample_index = 30

sns.kdeplot(data=df.loc[df["Sample"] == sample_index], x="Loss", hue="Member_Label")

# Metric reproduction

The part below attempts to reproduce the metric described in the challenge description PDF "Evaluation for the NeurIPS Machine Unlearning Competition".

Some details about this particular implementation:

- In the PDF, it is mentioned that several different attacks are run on a single sample; here, only one attack is run per sample (subject to change).
- Since the exact quantity used for the computation of the metric is not revealed, we propose here to use either the loss or the "confidence level" as the quantity to compute the score. This leads to scores that cannot be compared directly with the online scores, but that attempt to give a better idea of the success chance of a method that is tested locally with this score
- In addition to the score, we report the average accuracy of the attacks. In general, the score should increase when the accuracy decreases.
- We skip the part where the score is shrunk in case the model performs poorly on the test set; this loss of performance can be estimated by looking at the test accuracy of the checkpoints and comparing it to the initial checkpoint performance.
- We also provide below a different kind of "simpler" attack, which has been provided in the starting kit at "https://github.com/unlearning-challenge/starting-kit/blob/main/unlearning-CIFAR10.ipynb". It is useful to look at both metrics: in general, a good method should have both a good score and a low accuracy for this attack. The problem with this simple attack is that it is quite "weak" and thus does not vary enough when exposed to different unlearning methods (typically, it will give 55% when not using unlearning, and 51% when using a lot of different unlearning methods that can have very different scores).
- There might be some warnings inside the score computation; they happen when the variability of the unlearned models' output is too low (this will translate into a "spike" for most forget samples in the cell above). However, the computation of the score should take that into account and will give a low score to the samples for which the variability is too low.

In [None]:
from sklearn.metrics import make_scorer, confusion_matrix, accuracy_score

# Small delta term subtracted inside the logarithms when custom_scorer
# estimates epsilon from the attack's error rates.
delta = 1e-6

# This is the maximum epsilon we can observe, according to the pdf
# it depends on the number of model considered (for 512 it is 6.5)
# Here it corresponds to 100 checkpoints, you need to change it if you consider
# more or less checkpoints

epsilon_threshold = 3.89

# Width of one epsilon bucket: the score halves every `bucket_width`.
bucket_width = 0.5

def bucket_score(scalar):
    """
    Map an estimated epsilon to a score in (0, 1] that halves per bucket.

    Epsilons in [0, bucket_width) score 1, the next bucket 0.5, then 0.25...
    Values above `epsilon_threshold` are scored as the threshold itself.

    Args:
        scalar: estimated epsilon (a float, possibly inf or NaN).

    Returns:
        The bucketed score, or NaN when `scalar` is infinite or NaN.
    """
    # Undefined estimates propagate as NaN so that callers can skip them
    # with np.nanmean (int(nan) would otherwise raise a ValueError).
    if np.isinf(scalar) or np.isnan(scalar):
        return np.nan

    # A negative epsilon means the attack does no better than chance; clamp
    # it to the best bucket. Without the clamp, int() truncation toward zero
    # yields exponents <= 0 and scores above the maximum of 1.
    # Should tune epsilon_threshold anytime we change N
    clipped = min(max(scalar, 0.0), epsilon_threshold)

    return 2 / 2 ** (int(clipped / bucket_width) + 1)

def custom_scorer(y, y_pred):
    """
    Score one attack from its false-positive and false-negative rates.

    The two rates are read from the confusion matrix of `y` vs `y_pred`.
    Two epsilon candidates are derived from them (using the module-level
    `delta`) and the largest one is passed to `bucket_score`.

    Returns:
        The bucketed score. `bucket_score(np.inf)` when both rates are zero,
        and NaN when exactly one of them is zero.
    """
    tn, fp, fn, tp = confusion_matrix(y, y_pred).ravel()
    false_positive_rate = fp / (fp + tn)
    false_negative_rate = fn / (fn + tp)

    no_false_positive = false_positive_rate == 0
    no_false_negative = false_negative_rate == 0

    # Perfect attack on this fold
    if no_false_positive and no_false_negative:
        return bucket_score(np.inf)

    # One of the logarithms below would be undefined
    if no_false_positive or no_false_negative:
        return np.nan

    epsilon_candidates = [
        np.log(1 - delta - false_positive_rate) - np.log(false_negative_rate),
        np.log(1 - delta - false_negative_rate) - np.log(false_positive_rate),
    ]
    return bucket_score(np.nanmax(epsilon_candidates))

# Also can use "accuracy" as a simpler scorer

# Both metrics are evaluated on the same fitted folds, so each attack is
# cross-validated once per sample instead of once per metric.
attack_scoring = {"score": make_scorer(custom_scorer), "accuracy": "accuracy"}

scores = []
accuracy_scores = []

# One attack per forget sample: a logistic regression tries to tell, from a
# single collected value, whether it came from a "retain" (0) or an
# "unlearn" (1) checkpoint.
for sample_idx in range(len(losses)):
    loss_list = losses[sample_idx, :]
    label_list = labels_member[sample_idx, :]

    # scikit-learn expects array-likes of shape (n_samples, n_features)
    attack_features = np.asarray(loss_list).reshape(-1, 1)
    attack_targets = np.asarray(label_list)

    attack = linear_model.LogisticRegression()

    cv_results = model_selection.cross_validate(attack, attack_features, attack_targets, cv=2, scoring=attack_scoring)

    # Folds whose score is undefined (NaN) are ignored
    score = np.nanmean(cv_results["test_score"])
    acc_score = cv_results["test_accuracy"].mean()

    scores.append(score)
    accuracy_scores.append(acc_score)


In [None]:
# Samples whose score is undefined (NaN) are left out of the average score;
# the attack accuracy is averaged over every sample.
average_score, average_acc_score = np.nanmean(scores), np.mean(accuracy_scores)
print(f"Average Score: {average_score}, Average attack accuracy: {average_acc_score}")

In [None]:
# Simple attack stats

@torch.no_grad()
def collect_losses_simple(loader, model, device=None):
    """
    Compute the per-sample cross-entropy loss of `model` over `loader`.

    The model is switched to eval mode (and left in that mode).

    Args:
        loader: iterable of ``(x, y)`` batches.
        model: the network to evaluate.
        device: device the batches are moved to. Defaults to the device of
            the model's first parameter (CPU if the model has none), so the
            function also works on machines without a GPU.

    Returns:
        A 1-D tensor with one loss per sample, in loader order (empty when
        the loader yields no batch).
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    all_losses = []
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        out = model(x)
        all_losses.append(F.cross_entropy(out, y, reduction="none"))

    # torch.cat rejects an empty list
    if not all_losses:
        return torch.empty(0, device=device)
    return torch.cat(all_losses)


def simple_mia(sample_loss, members, n_splits=10, random_state=0):
    """Computes cross-validation score of a membership inference attack.

    A logistic regression is trained to predict `members` from
    `sample_loss`, and its accuracy is measured on stratified shuffle splits.

    Args:
      sample_loss : array_like of shape (n,) or (n, 1).
        objective function evaluated on n samples.
      members : array_like of shape (n,),
        whether a sample was used for training.
      n_splits: int
        number of splits to use in the cross-validation.
      random_state: int
        seed of the shuffle splits.
    Returns:
      scores : array_like of size (n_splits,)
    """
    # scikit-learn estimators need 2-D features: turn a flat vector of
    # losses into a single-feature column.
    if np.ndim(sample_loss) == 1:
        sample_loss = np.asarray(sample_loss).reshape(-1, 1)

    attack_model = linear_model.LogisticRegression()
    cv = model_selection.StratifiedShuffleSplit(
        n_splits=n_splits, random_state=random_state
    )
    return model_selection.cross_val_score(
        attack_model, sample_loss, members, cv=cv, scoring="accuracy"
    )


def run_attack(forget_loader, test_loader, model_to_test, random_state=None):
    """
    Run the simple membership inference attack against one model.

    The per-sample losses of the forget set (label 1) and of the test set
    (label 0) are collected, the larger group is randomly subsampled so that
    both classes have the same size, and `simple_mia` is cross-validated on
    the result.

    Args:
        forget_loader: loader over the forget samples.
        test_loader: loader over unseen (test) samples.
        model_to_test: the model under attack.
        random_state: optional seed for the subsampling. When None, NumPy's
            global random state is used.

    Returns:
        The mean cross-validated accuracy of the attack.
    """
    # Evaluate on the device the model lives on rather than relying on a
    # hard-coded default, so the attack also runs without a GPU.
    first_param = next(model_to_test.parameters(), None)
    device = first_param.device if first_param is not None else DEVICE

    ft_forget_losses = collect_losses_simple(forget_loader, model_to_test, device=device).cpu().numpy()
    ft_test_losses = collect_losses_simple(test_loader, model_to_test, device=device).cpu().numpy()

    # Subsampling to have class balanced (member, non member)
    rng = np.random if random_state is None else np.random.RandomState(random_state)

    if len(ft_forget_losses) > len(ft_test_losses):
        rng.shuffle(ft_forget_losses)
        ft_forget_losses = ft_forget_losses[:len(ft_test_losses)]
    else:
        rng.shuffle(ft_test_losses)
        ft_test_losses = ft_test_losses[:len(ft_forget_losses)]

    # Single-feature design matrix: test losses first, then forget losses
    samples_mia_ft = np.concatenate((ft_test_losses, ft_forget_losses)).reshape((-1, 1))
    labels_mia = [0] * len(ft_test_losses) + [1] * len(ft_forget_losses)

    mia_scores_ft = simple_mia(samples_mia_ft, labels_mia)

    print(
        f"The MIA attack has an accuracy of {mia_scores_ft.mean():.3f} on forgotten vs unseen images"
    )
    return mia_scores_ft.mean()
    

def run_multiple_attacks(forget_loader, test_loader, root="../unlearning_checkpoints/", num_checkpoints=100, prefix="unlearn"):
    """
    Average the simple attack accuracy over a series of checkpoints.

    Args:
        forget_loader: loader over the forget samples.
        test_loader: loader over unseen (test) samples.
        root: directory holding the checkpoints to attack.
        num_checkpoints: maximum number of checkpoints to try.
        prefix: filename prefix of the checkpoints.

    Returns:
        The mean of the per-checkpoint attack accuracies (NaN if no
        checkpoint is found).
    """
    accuracies = [
        run_attack(forget_loader, test_loader, model)
        for model in iterate_checkpoints(root, num_checkpoints, prefix=prefix)
    ]
    return np.mean(accuracies)

In [None]:
# Mean accuracy of the simple attack, taken over every unlearned checkpoint

average_simple_acc = run_multiple_attacks(forget_loader=forget_loader, test_loader=test_loader)
print(average_simple_acc)