# Machine Unlearning

Can you unlearn something?

Your task is the following: given a network pre-trained on some data, fine-tune it to selectively forget one class and learn a new one.

As an initial approach, you may do the following.

Start with an MNIST classifier pre-trained on a subset of the digits.

Now replace one of the learned digits, say the class “6”, with a new digit, say “3”.

A possible way to proceed is to identify the weights most involved in predicting class “6”, freeze all the others, and train with a loss that favors “3” while penalizing “6”.
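The weight-selection step of this baseline can be sketched on a toy two-layer ReLU network in NumPy. This is a minimal sketch under stated assumptions, not the procedure in `source.MachineUnlearning`: the importance score |w · ∂z_slot/∂w| (a first-order, Taylor-style saliency), the 5% budget, and the names `class_saliency`, `top_fraction_masks` and `masked_sgd_step` are all illustrative choices.

```python
import numpy as np

def forward(X, W1, W2):
    """Toy two-layer ReLU network: logits = relu(X @ W1) @ W2."""
    pre = X @ W1
    h = np.maximum(pre, 0.0)
    return pre, h, h @ W2

def class_saliency(X_old, W1, W2, slot):
    """Score every weight by |w * d logit[slot] / d w|, with the gradient averaged over X_old.

    X_old holds samples of the class to forget; `slot` is its output index.
    """
    pre, h, _ = forward(X_old, W1, W2)
    # d logit[slot] / d W2 is non-zero only in column `slot`, where it equals h.
    g2 = np.zeros_like(W2)
    g2[:, slot] = h.mean(axis=0)
    # d logit[slot] / d W1 = X^T @ (relu'(pre) * W2[:, slot]), averaged over samples.
    dh = (pre > 0) * W2[:, slot]
    g1 = X_old.T @ dh / X_old.shape[0]
    return np.abs(W1 * g1), np.abs(W2 * g2)

def top_fraction_masks(scores, frac=0.05):
    """Boolean masks keeping the top `frac` of all scores, ranked jointly across tensors."""
    flat = np.concatenate([s.ravel() for s in scores])
    k = max(1, int(frac * flat.size))
    thresh = np.partition(flat, -k)[-k]
    # ties at the threshold may select slightly more than k weights; zero scores never count
    return [(s >= thresh) & (s > 0) for s in scores]

def masked_sgd_step(params, grads, masks, lr=0.1):
    """Apply an SGD step only where the mask is True; all other weights stay frozen."""
    return [p - lr * g * m for p, g, m in zip(params, grads, masks)]

# Toy setup: 20 input features, 16 hidden units, 5 outputs; output 3 plays the role of "6".
rng = np.random.default_rng(0)
W1 = rng.normal(size=(20, 16))
W2 = rng.normal(size=(16, 5))
X_old = rng.normal(size=(64, 20))
m1, m2 = top_fraction_masks(class_saliency(X_old, W1, W2, slot=3), frac=0.05)
print("selected weights:", int(m1.sum() + m2.sum()), "of", W1.size + W2.size)
```

Freezing is done by multiplying the gradient by the mask, so unselected weights get exactly zero update. In PyTorch the analogous trick is a `Tensor.register_hook` on each parameter that multiplies the gradient by its mask; note that optimizer weight decay is applied outside the gradient and can still move "frozen" weights, which is one of the pitfalls worth checking.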

Test this baseline and see how far it gets you. Does it work? Are there any pitfalls in this idea? Use it as a first line of attack to understand the problem.

Starting from these baseline tests, devise a new unlearning procedure.

You can improve upon this baseline, make up your own idea from scratch, or check the literature to get ideas.

If you use an existing approach, you must add something new, for example by testing it on some new data modality (e.g., audio), by studying more extreme cases, failures, weaknesses, or by making it more efficient, and so on.

In [2]:
# HOW TO TRAIN A MODEL

# select a template from source.template:
from source.template.EMNIST import main
main(20, 10, 256, 'data/models/EMNIST')  # from the log below: 20 classes, 10 epochs, batch size 256, output folder

#from source.template.Cifar10 import main
#main(8, 30, 512, 'data/models/Cifar10')

Downloading https://biometrics.nist.gov/cs_links/EMNIST/gzip.zip to data/db/EMNIST/raw/gzip.zip
Extracting data/db/EMNIST/raw/gzip.zip to data/db/EMNIST/raw
 Num of parameters =  4089940
   Num of classes  =  20
    Num of data    =  48000
    Batch size     =  256


+-------+--------------------+--------------------+---------------------+-----------------------+---------------------+---------------------+
| epoch |      accuracy      |     confidence     |        loss         |         l_var         |        l_min        |        l_max        |
+-------+--------------------+--------------------+---------------------+-----------------------+---------------------+---------------------+
|   0   | 0.7818525598404256 | 6.382060844213404  | 0.6774950906475807  |  0.3503454232059944   | 0.23381371796131134 | 3.0437827110290527  |
|   1   | 0.8871550864361702 | 8.858036705788146  | 0.2903615283839246  | 0.002556033422767688  | 0.18921081721782684 | 0.42253467440605164 |
|   2   | 0.8968168218085106 | 9.437099299532301  | 0.25876873485902524 | 0.001568617451765293  | 0.1367725282907486  | 0.36834076046943665 |
|   3   | 0.9050033244680851 | 9.830474569442424  | 0.2342135840432441  | 0.0018535351910591304 | 0.15193428099155426 | 0.3837572932243347  |
|   4 

Confidence:  11.314796477556229
Accuracy:    0.8934326171875
Confusion matrix:
+----+--------+--------+--------+--------+--------+--------+--------+--------+--------+--------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+
|    | Pred_0 | Pred_1 | Pred_2 | Pred_3 | Pred_4 | Pred_5 | Pred_6 | Pred_7 | Pred_8 | Pred_9 | Pred_10 | Pred_11 | Pred_12 | Pred_13 | Pred_14 | Pred_15 | Pred_16 | Pred_17 | Pred_18 | Pred_19 |
+----+--------+--------+--------+--------+--------+--------+--------+--------+--------+--------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+
| 0  |  92%   |   3%   |   0%   |   0%   |   0%   |   0%   |   1%   |   0%   |   0%   |   0%   |   0%    |   0%    |   0%    |   0%    |   0%    |   0%    |   0%    |   0%    |   0%    |   0%    |
| 1  |   5%   |  90%   |   0%   |   0%   |   0%   |   0%   |   0%   |   0%   |   0%   |   0%   |   0%    |   0%    |   0%    |   0%  

## EMNIST

In [1]:
# LOAD A MODEL
import torch
import os
from source.template.EMNIST import Classifier
from source.template.EMNIST import N_CLASSES
import random

# the checkpoint filename encodes the list of classes the model was trained on
model_path = "data/models/EMNIST/36_41_15_32_20_40_12_14_17_13_37_18_3_5_35_11_23_16_1_39.pt"
original_list_classes = [int(c) for c in os.path.splitext(os.path.basename(model_path))[0].split('_')]
n_classes = len(original_list_classes)

model = Classifier(n_classes)
model.load_state_dict(torch.load(model_path, weights_only=True))

# pick a random old_class from original_list_classes
old_class = random.choice(original_list_classes)
# pick a random new_class: a label in range(N_CLASSES) that is not in original_list_classes
new_class = random.choice([i for i in range(N_CLASSES) if i not in original_list_classes])
# replace old_class with new_class in list_classes (new_class takes over its output slot)
list_classes = original_list_classes.copy()
list_classes[list_classes.index(old_class)] = new_class

print("Original classes:    ", original_list_classes)
print("Classes to memorize: ", list_classes)
print("Class to forget:     ", old_class, f" (replaced by {new_class})")

Original classes:     [36, 41, 15, 32, 20, 40, 12, 14, 17, 13, 37, 18, 3, 5, 35, 11, 23, 16, 1, 39]
Classes to memorize:  [36, 41, 15, 32, 20, 40, 12, 14, 17, 13, 22, 18, 3, 5, 35, 11, 23, 16, 1, 39]
Class to forget:      37  (replaced by 22)
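The cell above draws both classes with the unseeded global `random`, so every run swaps a different pair (here 37 is replaced by 22). A minimal sketch of the same swap wrapped in a function with an explicit seed, so that an experiment can be repeated; the name `swap_class` and the total of 47 labels in the usage line are hypothetical (in the notebook the total would be `N_CLASSES`).

```python
import random

def swap_class(classes, n_total, seed=None):
    """Choose a learned class to forget and an unseen class that takes over its output slot.

    Returns (old_class, new_class, swapped); `classes` itself is not modified.
    """
    rng = random.Random(seed)  # private generator, so a seed makes the swap repeatable
    old_class = rng.choice(classes)
    new_class = rng.choice([c for c in range(n_total) if c not in classes])
    swapped = list(classes)
    swapped[swapped.index(old_class)] = new_class  # same index, hence same output unit
    return old_class, new_class, swapped

# n_total=47 is only an illustrative total label count.
print(swap_class([36, 41, 15, 32, 20], n_total=47, seed=0))
```

Because the new class reuses the old class's index, the classifier head keeps its size and only the meaning of one output unit changes.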


In [2]:
# MAKE DATALOADER
import torch.utils.data
import torchvision
from torchvision import transforms
from source.template.EMNIST import SPLIT, FilterSet

loader = torch.utils.data.DataLoader(
    FilterSet(
        torchvision.datasets.EMNIST(
            root="data/db",
            split=SPLIT,
            train=True,
            download=True,
            transform=transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Grayscale(),
                    transforms.Pad(2),  # pad to 32x32
                    transforms.RandomAffine(
                        degrees=5,
                        translate=(0.1, 0.1),
                        scale=(0.9, 1.1),
                        shear=10
                    ),
                    transforms.ColorJitter(contrast=(0.9,1.5)),
                ]
            ),
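            # map each raw label to its output index; old_class is not in list_classes, so it gets -1
            # (the lambda reads the global list_classes at call time, so reassigning it later changes these labels)
            target_transform=lambda x: list_classes.index(x) if x in list_classes else -1
        ),
        # keep only the classes to memorize plus the class to forget
        torch.tensor(list_classes + [old_class]),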
    ),
    batch_size=256,
    shuffle=True,
    num_workers=4,
    pin_memory=True
)
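With this loader, samples of `old_class` arrive with label -1 and every other sample with its output index, so a training step can split each batch into a part to keep and a part to forget. The sketch below assumes that convention; `ToTarget` and `split_batch` are hypothetical helpers, not part of `source.template`. `ToTarget` is a top-level class rather than a lambda: when `DataLoader` workers are started with the "spawn" method (the default on Windows and macOS) the dataset and its transforms must be pickled, which fails for lambdas, and the class also snapshots `list_classes` instead of reading the global at call time.

```python
import pickle
import numpy as np

FORGET = -1  # label the notebook's target_transform gives to old_class

class ToTarget:
    """Picklable stand-in for the target_transform lambda.

    Snapshots list_classes at construction, so later reassignments of the
    global list_classes (for example in the Cifar10 cells) cannot change its labels.
    """
    def __init__(self, list_classes):
        self.index = {c: i for i, c in enumerate(list_classes)}

    def __call__(self, label):
        return self.index.get(label, FORGET)

def split_batch(x, y):
    """Split a batch into the retain part (y >= 0) and the forget part (y == FORGET)."""
    forget = y == FORGET
    return (x[~forget], y[~forget]), x[forget]

# Toy example: 22 has taken over the output slot of 37.
to_target = ToTarget([36, 41, 15, 22])
y = np.array([to_target(c) for c in [36, 37, 22, 15, 37]])
x = np.arange(10).reshape(5, 2)
(x_keep, y_keep), x_forget = split_batch(x, y)
clone = pickle.loads(pickle.dumps(to_target))  # unlike a lambda, this round-trips
print(y.tolist(), y_keep.tolist(), x_forget.shape, clone(22))
```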


In [3]:
# TEST LOADER

test_loader = torch.utils.data.DataLoader(
    FilterSet(
        torchvision.datasets.EMNIST(
            root="data/db",
            split=SPLIT,
            train=False,
            download=True,
            transform=transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Grayscale(),
                    transforms.Pad(2),  # pad to 32x32
                ]
            ),
            target_transform=lambda x: list_classes.index(x) if x in list_classes else -1
        ),
        torch.tensor(list_classes + [old_class]),
    ),
    batch_size=512,
    num_workers=4,
    pin_memory=True
)

## Cifar10

In [6]:
# LOAD A MODEL
import torch
import os
from source.template.Cifar10 import Classifier
from source.template.Cifar10 import N_CLASSES
import random

model_path = "data/models/Cifar10/7_5_8_1_9_6_2_3.pt"
original_list_classes = [int(c) for c in os.path.splitext(os.path.basename(model_path))[0].split('_')]
n_classes = len(original_list_classes)

model = Classifier(n_classes)
model.load_state_dict(torch.load(model_path, weights_only=True))

# pick a random old_class from original_list_classes
old_class = random.choice(original_list_classes)
# pick a random new_class: a label in range(N_CLASSES) that is not in original_list_classes
new_class = random.choice([i for i in range(N_CLASSES) if i not in original_list_classes])
# replace old_class with new_class in list_classes (new_class takes over its output slot)
list_classes = original_list_classes.copy()
list_classes[list_classes.index(old_class)] = new_class

print("Original classes:    ", original_list_classes)
print("Classes to memorize: ", list_classes)
print("Class to forget:     ", old_class, f" (replaced by {new_class})")

Original classes:     [7, 5, 8, 1, 9, 6, 2, 3]
Classes to memorize:  [7, 0, 8, 1, 9, 6, 2, 3]
Class to forget:      5  (replaced by 0)


In [7]:
# MAKE DATALOADER
import torch.utils.data
import torchvision
from torchvision import transforms
from source.template.Cifar10 import FilterSet

loader = torch.utils.data.DataLoader(
    FilterSet(
        torchvision.datasets.CIFAR10(
            root="data/db",
            train=True,
            download=True,
            transform=transforms.Compose([
                transforms.RandomResizedCrop(32, scale=(0.8, 1.0)),
                transforms.RandomAffine(
                    degrees=0,
                    scale=(1.0, 1.1),
                    shear=0
                ),
                transforms.ColorJitter(
                    contrast=(0.9,1.5),
                    saturation=(0.9,1.3),
                    brightness=(0.9,1.3),
                    hue=(-0.05,0.05),
                ),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]),
            ]),
            target_transform=lambda x: list_classes.index(x) if x in list_classes else -1
        ),
        torch.tensor(list_classes + [old_class]),
    ),
    batch_size=256,
    shuffle=True,
    num_workers=4,
    pin_memory=True
)


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/db/cifar-10-python.tar.gz
Extracting data/db/cifar-10-python.tar.gz to data/db


In [8]:
# TEST LOADER

test_loader = torch.utils.data.DataLoader(
    FilterSet(
        torchvision.datasets.CIFAR10(
            root="data/db",
            train=False,
            download=True,
            transform=transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]),
                ]
            ),
            target_transform=lambda x: list_classes.index(x) if x in list_classes else -1
        ),
        torch.tensor(list_classes + [old_class]),
    ),
    batch_size=512,
    num_workers=4,
    pin_memory=True
)

Files already downloaded and verified


## Training and Testing

In [9]:
from source.MachineUnlearning import train

train(
    model,
    loader,
    10,          # number of epochs
    n_layers=2,
    w=0.1,       # weight of the unlearning term (10% of the classification loss)
    device=torch.device('cuda'),
    classes=list_classes,
)

 Num of parameters =  26702760
    Num of data    =  45000
    Batch size     =  256
  Num of classes   =  8
     Classes       =  [7, 0, 8, 1, 9, 6, 2, 3]


Confusion matrix:
+---+-----+-----+-----+-----+-----+-----+-----+-----+
|   |  7  |  0  |  8  |  1  |  9  |  6  |  2  |  3  |
+---+-----+-----+-----+-----+-----+-----+-----+-----+
| 0 | 93% | 1%  | 0%  | 0%  | 0%  | 0%  | 1%  | 2%  |
| 1 | 4%  | 32% | 16% | 2%  | 8%  | 1%  | 21% | 13% |
| 2 | 0%  | 4%  | 91% | 0%  | 0%  | 0%  | 0%  | 1%  |
| 3 | 0%  | 0%  | 1%  | 93% | 2%  | 0%  | 0%  | 0%  |
| 4 | 0%  | 1%  | 1%  | 2%  | 91% | 0%  | 0%  | 1%  |
| 5 | 0%  | 0%  | 0%  | 0%  | 0%  | 92% | 2%  | 3%  |
| 6 | 1%  | 4%  | 0%  | 0%  | 0%  | 2%  | 86% | 4%  |
| 7 | 2%  | 3%  | 0%  | 0%  | 0%  | 3%  | 3%  | 86% |
+---+-----+-----+-----+-----+-----+-----+-----+-----+
+-------+--------------------+--------------------+---------------------+
| epoch |      accuracy      |     confidence     |       hiding        |
+-------+--------------------+--------------------+---------------------+
|   0   | 0.8031563450171926 | 12.737132256681269 | 0.16080016469244252 |
|   1   | 0.8092331497417213 | 12.5261
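The `w` argument weights the unlearning term against the classification loss. How `source.MachineUnlearning.train` actually combines the two is not shown here, so the following is only a hedged NumPy sketch of one common form, L = CE(retained samples) + w · penalty(forget samples). The penalty −log(1 − p_slot), which grows while images of `old_class` are still sent to the output slot now owned by `new_class`, is a hypothetical choice, as is the name `unlearning_loss`.

```python
import numpy as np

def log_softmax(z):
    """Numerically stable row-wise log-softmax."""
    z = z - z.max(axis=1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=1, keepdims=True))

def unlearning_loss(logits, y, slot, w=0.1):
    """Cross-entropy on retained samples plus w times a penalty on forget samples.

    y: output index for retained samples, -1 for samples of old_class.
    slot: output index old_class used to occupy (now assigned to new_class).
    Penalty (hypothetical): -log(1 - p[slot]) averaged over the forget samples.
    """
    logp = log_softmax(logits)
    keep = y >= 0
    ce = -logp[keep, y[keep]].mean() if keep.any() else 0.0
    forget = ~keep
    if forget.any():
        p_slot = np.exp(logp[forget, slot])
        penalty = -np.log1p(-np.clip(p_slot, 0.0, 1.0 - 1e-7)).mean()
    else:
        penalty = 0.0
    return ce + w * penalty

logits = np.array([[2.0, 0.0, 0.0],   # retained sample with label 0
                   [0.0, 0.0, 3.0]])  # forget sample still sent to slot 2
y = np.array([0, -1])
print(unlearning_loss(logits, y, slot=2, w=0.1))
```

With `w=0.1` the penalty counts one tenth as much as the classification term at equal magnitude, which is the trade-off the `train` call above tunes.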

In [11]:
from source.MachineUnlearning import test

test(
    model,
    test_loader,
    device=torch.device('cuda'),
    classes=list_classes,
)

 Num of parameters =  26702760
    Num of data    =  9000
    Batch size     =  512
  Num of classes   =  8
     Classes       =  [7, 0, 8, 1, 9, 6, 2, 3]


Accuracy:  0.8495
Confidence:  11.602653563022614
Confusion matrix:
+---+-----+-----+-----+-----+-----+-----+-----+-----+
|   |  7  |  0  |  8  |  1  |  9  |  6  |  2  |  3  |
+---+-----+-----+-----+-----+-----+-----+-----+-----+
| 0 | 86% | 2%  | 0%  | 0%  | 0%  | 1%  | 3%  | 3%  |
| 1 | 2%  | 58% | 8%  | 2%  | 3%  | 1%  | 14% | 8%  |
| 2 | 0%  | 7%  | 84% | 1%  | 1%  | 0%  | 1%  | 1%  |
| 3 | 0%  | 0%  | 2%  | 90% | 4%  | 0%  | 0%  | 0%  |
| 4 | 1%  | 2%  | 2%  | 6%  | 84% | 1%  | 0%  | 1%  |
| 5 | 0%  | 0%  | 0%  | 0%  | 0%  | 89% | 4%  | 4%  |
| 6 | 3%  | 8%  | 0%  | 0%  | 0%  | 7%  | 73% | 5%  |
| 7 | 5%  | 6%  | 1%  | 0%  | 1%  | 12% | 8%  | 63% |
+---+-----+-----+-----+-----+-----+-----+-----+-----+
Hiding - distribution:  0.05895945802330971
Hiding - covariance:  [0.73, 0.09, 0.83, 2.04, 1.04, 0.86, 0.34, 0.29]
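In the confusion matrix above, rows are output indices 0 to 7 and columns are the class labels in `list_classes`; row 1 is the new class 0, which has the lowest diagonal value (58%). The 9000 test images correspond to the 8 kept classes plus `old_class` at 1000 CIFAR-10 test images each, so forgotten-class images are also in the test set, and how `test` folds them into "Accuracy" is not shown. A minimal sketch that reports the two sides separately, assuming logits and the -1 label convention of the loaders; `unlearning_report` is a hypothetical helper.

```python
import numpy as np

def unlearning_report(logits, y, slot, n_classes):
    """Summarise an unlearning run from test logits.

    y: output index for kept classes, -1 for samples of the forgotten class.
    Returns (accuracy on kept classes, fraction of forgotten-class samples still
    predicted into `slot`, row-normalised confusion matrix over kept classes).
    """
    pred = logits.argmax(axis=1)
    keep = y >= 0
    retain_acc = float((pred[keep] == y[keep]).mean())
    forget_to_slot = float((pred[~keep] == slot).mean()) if (~keep).any() else float("nan")
    cm = np.zeros((n_classes, n_classes))
    np.add.at(cm, (y[keep], pred[keep]), 1)          # count (true, predicted) pairs
    cm = cm / np.maximum(cm.sum(axis=1, keepdims=True), 1)  # rows sum to 1 (or stay 0)
    return retain_acc, forget_to_slot, cm

logits = np.array([[3.0, 0.0, 0.0], [0.0, 3.0, 0.0], [3.0, 0.0, 0.0],
                   [0.0, 0.0, 3.0], [0.0, 3.0, 0.0]])
y = np.array([0, 1, 1, -1, -1])
acc, hit, cm = unlearning_report(logits, y, slot=2, n_classes=3)
print(acc, hit)
print(cm)
```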
