<img src="https://unlearning-challenge.github.io/Unlearning-logo.png" width="100px">

# NeurIPS 2023 Machine Unlearning Challenge Starting Kit

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/unlearning-challenge/starting-kit/blob/main/unlearning-CIFAR10.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://raw.githubusercontent.com/unlearning-challenge/starting-kit/main/unlearning-CIFAR10.ipynb)


This notebook is part of the starting kit for the [NeurIPS 2023 Machine Unlearning Challenge](https://unlearning-challenge.github.io/). This notebook explains the pipeline of the challenge and contains sample code to make a submission.


This notebook has 3 sections:

  * 💾 In the first section we'll load a sample dataset (CIFAR10) and pre-trained model (ResNet18).

  * 🎯 In the second section we'll develop the unlearning algorithm. We start by splitting the original training set into a retain set and a forget set. The goal of an unlearning algorithm is to update the pre-trained model so that it approximates as much as possible a model that has been trained on the retain set but not on the forget set. We provide a simple unlearning algorithm as a starting point for participants to develop their own unlearning algorithms.

  * 🏅 In the third section we'll score our unlearning algorithm using a simple membership inference attack (MIA).
  

We emphasize that this notebook is provided for convenience to help participants get started quickly. Submissions will be scored using a different method from the one provided in this notebook, on a different (private) dataset of human faces.

In [None]:
import requests
import numpy as np
import matplotlib.pyplot as plt
from sklearn import linear_model, model_selection

import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader

import torchvision
from torchvision import transforms
from torchvision.utils import make_grid
from torchvision.models import resnet18

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Running on device:", DEVICE.upper())

# 💾 Download dataset and pre-trained model

In this section we'll load a sample dataset (CIFAR10), a pre-trained model (ResNet18), plot some images and compute the accuracy of the model on the test set.

In [None]:
# download and pre-process CIFAR10
normalize = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ]
)

train_set = torchvision.datasets.CIFAR10(
    root="./data", train=True, download=True, transform=normalize
)
train_loader = DataLoader(train_set, batch_size=128, shuffle=False, num_workers=2)

test_set = torchvision.datasets.CIFAR10(
    root="./data", train=False, download=True, transform=normalize
)
test_loader = DataLoader(test_set, batch_size=128, shuffle=False, num_workers=2)

In [None]:
# download pre-trained weights
response = requests.get(
    "https://unlearning-challenge.s3.eu-west-1.amazonaws.com/weights_resnet18_cifar10.pth"
)
response.raise_for_status()  # fail early if the download did not succeed
with open("weights_resnet18_cifar10.pth", "wb") as f:
    f.write(response.content)
weights_pretrained = torch.load("weights_resnet18_cifar10.pth", map_location=DEVICE)

# load pre-trained weights
model = resnet18(weights=None, num_classes=10)
model.load_state_dict(weights_pretrained)
model.to(DEVICE)
model.eval();

Let us show some of the training images, for fun.

In [None]:
# a temporary data loader without normalization, just to show the images
tmp_dl = DataLoader(
    torchvision.datasets.CIFAR10(
        root="./data", train=True, download=True, transform=transforms.ToTensor()
    ),
    batch_size=16 * 5,
)
images, labels = next(iter(tmp_dl))

fig, ax = plt.subplots(figsize=(12, 6))
plt.title("Sample images from CIFAR10 dataset")
ax.set_xticks([])
ax.set_yticks([])
ax.imshow(make_grid(images, nrow=16).permute(1, 2, 0))
plt.show()

We'll now compute the model's accuracy on the train and test sets. This model has been trained without data augmentation, so its test accuracy is lower than that of state-of-the-art models.


In [None]:
with torch.no_grad():
    for info, loader in zip(("train", "test"), (train_loader, test_loader)):
        correct = 0
        total = 0
        for inputs, targets in loader:
            inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
            outputs = model(inputs)
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
        # a single "%" in an f-string; "%%" would be printed literally
        print(f"{info} set accuracy: {100.0 * correct / total:.2f}%")

# 🎯 Unlearning Algorithm

In the second section we'll develop the unlearning algorithm. We start by splitting the original training set into a retain set and a forget set. The goal of an unlearning algorithm is to update the pre-trained model so that it approximates as much as possible a model that has been trained on the retain set but not on the forget set. We provide a simple unlearning algorithm as a starting point for participants to develop their own unlearning algorithms. When making a submission to the challenge, participants will be uploading an unlearning algorithm with the same API as the unlearning method provided in this section.

Before defining the unlearning algorithm, we split the training set into two disjoint subsets: the retain set and the forget set. Typically, the retain set is much larger than the forget set. Here, we produce a split that is 20% forget set, 80% retain set.

In [None]:
forget_set, retain_set = torch.utils.data.random_split(train_set, [0.2, 0.8])
forget_loader = torch.utils.data.DataLoader(
    forget_set, batch_size=128, shuffle=False, num_workers=2
)
retain_loader = torch.utils.data.DataLoader(
    retain_set, batch_size=128, shuffle=False, num_workers=2
)
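A minimal sketch of how `random_split` with fractional lengths produces two disjoint subsets, shown on a small synthetic `TensorDataset` rather than CIFAR10. The seeded `torch.Generator` is an assumption (the split above does not use one); it makes the split reproducible across runs, which is convenient when comparing several unlearning algorithms on the same forget set.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Synthetic stand-in for the training set: 100 small "images" with 10 classes.
inputs = torch.randn(100, 3, 8, 8)
targets = torch.randint(0, 10, (100,))
dataset = TensorDataset(inputs, targets)

# Assumption: a seeded generator, so the same forget set is drawn on every run.
generator = torch.Generator().manual_seed(0)
forget_subset, retain_subset = random_split(dataset, [0.2, 0.8], generator=generator)

# Each Subset keeps the indices it selected from the original dataset.
forget_idx = set(forget_subset.indices)
retain_idx = set(retain_subset.indices)

print(len(forget_subset), len(retain_subset))  # 20 80
print(forget_idx.isdisjoint(retain_idx))  # True
```

Fractional lengths require a reasonably recent PyTorch release; with older versions, integer lengths that sum to `len(dataset)` give the same result.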

An unlearning algorithm should produce a model that is as similar as possible to a model that has never seen the forget set and has been trained solely on the retain set.

Below is a simple example of such an unlearning algorithm. The algorithm, called `fine-tuning`, takes the model trained on the full training set as a starting point and continues training it using only the retain set. This is a very simple unlearning algorithm, but it is not very computationally efficient. We'll use it here for illustration purposes.

To make a new entry in the competition, participants will submit an unlearning function with the same API as the one below.

In [None]:
def unlearning(net, retain, forget, test):
    """Unlearning by fine-tuning.

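    Fine-tuning is a very simple algorithm: starting from the pre-trained
    model, it continues training on the retain set only. The forget and
    test loaders are not used by this method.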

    Args:
      net : nn.Module.
        pre-trained model to use as base of unlearning.
      retain : torch.utils.data.DataLoader.
        Dataset loader with the retain set. This is the subset
        of the training set that we don't want to forget.  
      forget : torch.utils.data.DataLoader.
        Dataset loader with the forget set. This is the subset
        of the training set that we want to forget.
      test : torch.utils.data.DataLoader.
        Dataset loader with the test set.
    Returns:
      net : updated model
    """
    epochs = 20

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    net.train()

    for _ in range(epochs):
        for inputs, targets in retain:
            inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
        scheduler.step()

    net.eval()
    return net
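Because submissions must follow the same API as `unlearning` above, it can help to smoke-test a candidate function on a tiny problem before running it on CIFAR10. The sketch below is illustrative only: `unlearning_noop` (a do-nothing baseline) and `check_unlearning_api` are hypothetical names, not part of the challenge code, and the checks only cover what the docstring above states, namely that the function takes a model plus three data loaders and returns a model.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def unlearning_noop(net, retain, forget, test):
    """Hypothetical baseline with the same API: returns the model unchanged."""
    net.eval()
    return net


def check_unlearning_api(unlearn_fn):
    """Hypothetical helper: runs unlearn_fn on a tiny synthetic problem."""

    def make_loader(n):
        x = torch.randn(n, 3, 8, 8)
        y = torch.randint(0, 10, (n,))
        return DataLoader(TensorDataset(x, y), batch_size=8)

    # Tiny linear classifier standing in for ResNet18.
    net = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 10))
    out = unlearn_fn(net, make_loader(32), make_loader(8), make_loader(8))

    # The returned object must still be a model mapping images to 10 logits.
    assert isinstance(out, nn.Module)
    with torch.no_grad():
        logits = out(torch.randn(4, 3, 8, 8))
    assert logits.shape == (4, 10)
    return out


model_checked = check_unlearning_api(unlearning_noop)
print(type(model_checked).__name__)  # Sequential
```

Calling `check_unlearning_api(unlearning)` with the fine-tuning function should also work when `DEVICE` is `"cpu"`; on a GPU, the tiny model would first need to be moved to `DEVICE`, because `unlearning` moves each batch there.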

In [None]:
model_ft = resnet18(weights=None, num_classes=10)
model_ft.load_state_dict(weights_pretrained)
model_ft.to(DEVICE)
model_ft = unlearning(model_ft, retain_loader, forget_loader, test_loader)

# 🏅 Evaluation

## Membership Inference Attacks

In this notebook we'll examine the quality of the unlearning algorithm through a simple membership inference attack (MIA). This attack consists of a logistic regression model trained on per-sample losses to distinguish between examples from the forget set and examples from the test set, which the model never saw during training.

We provide this simple MIA for convenience, so that participants can quickly obtain a metric for their unlearning algorithm. We emphasize that submissions will be scored using a different method.

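We first compute the per-sample losses of the original model on the test set and on the forget set, and plot their histograms. A model typically achieves lower loss on examples it was trained on than on unseen ones, so the gap between the two distributions indicates how well an attack that only looks at the loss can tell whether a sample was part of the training data.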

In [None]:
def compute_losses(net, loader):
    """Auxiliary function to compute per-sample losses"""

    criterion = nn.CrossEntropyLoss(reduction="none")
    all_losses = []

    # no gradients are needed here; this also keeps memory usage low
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)

            logits = net(inputs)
            losses = criterion(logits, targets).cpu().numpy()
            all_losses.append(losses)

    return np.concatenate(all_losses)


test_losses = compute_losses(model, test_loader)
forget_losses = compute_losses(model, forget_loader)
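The attack feature is the per-sample cross-entropy loss returned by `compute_losses`. With `reduction="none"`, `nn.CrossEntropyLoss` returns one value per example, equal to the negative log of the softmax probability assigned to the true class. The short check below, on made-up logits, confirms this equivalence.

```python
import torch
from torch import nn

# Made-up logits for 3 examples and 4 classes, with their true labels.
logits = torch.tensor([[2.0, 0.5, -1.0, 0.0],
                       [0.1, 0.2, 0.3, 0.4],
                       [-1.0, 3.0, 0.0, 1.0]])
targets = torch.tensor([0, 3, 1])

# One loss per example, as in compute_losses.
per_sample = nn.CrossEntropyLoss(reduction="none")(logits, targets)

# The same quantity by hand: -log softmax(logits)[true class].
log_probs = torch.log_softmax(logits, dim=1)
manual = -log_probs[torch.arange(len(targets)), targets]

print(per_sample.shape)  # torch.Size([3])
print(torch.allclose(per_sample, manual))  # True
```

An example the model fits confidently gets a loss close to 0, which is why examples seen during training tend to sit at the low end of the loss histogram.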

In [None]:
plt.title("Losses on test and forget set")
plt.hist(test_losses, density=True, alpha=0.5, bins=50, label="loss on test set")
plt.hist(forget_losses, density=True, alpha=0.5, bins=50, label="loss on forget set")
plt.xlim((0, np.max(test_losses)))
plt.yscale("log")
plt.legend()
plt.show()

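The membership inference attack below is a logistic regression classifier whose only input feature is the per-sample loss. It is trained to predict whether a sample comes from the forget set (label 1) or from the test set (label 0), and its performance is estimated with stratified cross-validation.

We report the macro-averaged F1 score, which averages the F1 scores of the two classes. Unlike plain accuracy, it is not inflated by an attack that always predicts the majority class when the forget and test sets have different sizes.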
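As a worked illustration of the metric, using made-up labels rather than actual attack output: macro-F1 computes the F1 score of each class (test = 0, forget = 1) separately and averages them, so an attack that always gives the same answer scores low even though its F1 on one class looks reasonable.

```python
import numpy as np
from sklearn.metrics import f1_score

# Made-up ground truth: 0 = test (unseen) sample, 1 = forget-set sample.
y_true = np.array([0, 0, 1, 1])

# An attack that makes one mistake on the test class.
y_pred = np.array([0, 1, 1, 1])
# Class 1: precision 2/3, recall 1   -> F1 = 0.8
# Class 0: precision 1,   recall 1/2 -> F1 = 2/3
macro = f1_score(y_true, y_pred, average="macro")
print(round(macro, 4))  # 0.7333

# A degenerate attack that always answers "forget set".
y_const = np.ones_like(y_true)
# Class 1: F1 = 2/3; class 0 is never predicted, so its F1 is 0.
macro_const = f1_score(y_true, y_const, average="macro", zero_division=0)
print(round(macro_const, 4))  # 0.3333
```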

In [None]:
def naive_mia(sample_loss, members, n_splits=10):
    """Computes cross-validation score of a membership inference attack.

    Args:
      sample_loss : array_like of shape (n, 1).
        objective function evaluated on n samples.
      members : array_like of shape (n,).
        whether a sample was used for training (1) or not (0).
      n_splits: int
        number of splits to use in the cross-validation.
    Returns:
      score : array_like of size (n_splits,)
        macro-F1 score of the attack on each split.
    """

    unique_members = np.unique(members)
    if not np.array_equal(unique_members, [0, 1]):
        raise ValueError("members should contain both 0s and 1s, and nothing else")

    attack_model = linear_model.LogisticRegression()
    cv = model_selection.StratifiedShuffleSplit(n_splits=n_splits)
    return model_selection.cross_val_score(
        attack_model, sample_loss, members, cv=cv, scoring="f1_macro"
    )
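To build intuition for what the attack measures, the sketch below applies the same recipe as `naive_mia` (logistic regression on a single loss feature, scored by cross-validated macro-F1) to synthetic losses. The exponential loss distributions and the helper name `synthetic_mia_score` are assumptions made for illustration, not measurements from the ResNet18 model: when "member" losses are much smaller than "non-member" losses the attack scores well above chance, and when both come from the same distribution it cannot do better than chance.

```python
import numpy as np
from sklearn import linear_model, model_selection

rng = np.random.default_rng(0)


def synthetic_mia_score(member_losses, nonmember_losses):
    """Hypothetical helper: cross-validated macro-F1 of a loss-based MIA."""
    X = np.concatenate((nonmember_losses, member_losses)).reshape(-1, 1)
    y = np.array([0] * len(nonmember_losses) + [1] * len(member_losses))
    cv = model_selection.StratifiedShuffleSplit(n_splits=10, random_state=0)
    scores = model_selection.cross_val_score(
        linear_model.LogisticRegression(), X, y, cv=cv, scoring="f1_macro"
    )
    return scores.mean()


# Assumed distributions: seen examples have much lower loss than unseen ones.
nonmember = rng.exponential(scale=1.0, size=2000)
member = rng.exponential(scale=0.1, size=2000)
score_leaky = synthetic_mia_score(member, nonmember)

# Idealized unlearning: "forget" losses drawn from the same law as test losses.
score_ideal = synthetic_mia_score(rng.exponential(scale=1.0, size=2000), nonmember)

print(f"distinguishable losses:   {score_leaky:.3f}")
print(f"indistinguishable losses: {score_ideal:.3f}")
```

A good unlearning algorithm should push the real attack score toward the second case, where the forget-set losses look like test-set losses.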

In [None]:
samples_mia = np.concatenate((test_losses, forget_losses)).reshape((-1, 1))
labels_mia = [0] * len(test_losses) + [1] * len(forget_losses)

mia_scores = naive_mia(samples_mia, labels_mia)

print(
    "The MIA has a macro-F1 score of %.3f on forget vs test images"
    % mia_scores.mean()
)

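We now run the same attack on the model returned by `unlearning`. Any new unlearning method should have the same API: a function that takes the pre-trained model and the retain, forget and test data loaders, and returns the updated model. With that API, a new method can be evaluated by replacing `unlearning` in the cell that creates `model_ft` and rerunning the cells below.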

In [None]:
ft_forget_losses = compute_losses(model_ft, forget_loader)
samples_mia_ft = np.concatenate((test_losses, ft_forget_losses)).reshape((-1, 1))
labels_mia = [0] * len(test_losses) + [1] * len(ft_forget_losses)

In [None]:
mia_scores_ft = naive_mia(samples_mia_ft, labels_mia)

print(
    "The MIA has a macro-F1 score of %.3f on forget vs test images"
    % mia_scores_ft.mean()
)

In [None]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

ax1.set_title("Initial model.\nAttack macro-F1: %.3f" % mia_scores.mean())
ax1.hist(test_losses, density=True, alpha=0.5, bins=50, label="loss on test set")
ax1.hist(forget_losses, density=True, alpha=0.5, bins=50, label="loss on forget set")

ax2.set_title("Unlearned model.\nAttack macro-F1: %.3f" % mia_scores_ft.mean())
ax2.hist(test_losses, density=True, alpha=0.5, bins=50, label="loss on test set")
ax2.hist(ft_forget_losses, density=True, alpha=0.5, bins=50, label="loss on forget set")

# shared styling: log-scaled density, common x range, axis labels and legend
for ax in (ax1, ax2):
    ax.set_yscale("log")
    ax.set_xlim((0, np.max(test_losses)))
    ax.set_xlabel("Cross-entropy loss")
    ax.set_ylabel("Density")
    ax.legend(frameon=False, fontsize=14)
plt.show()