# Machine Unlearning

Can you unlearn something?

Your task is the following: given a network pre-trained on some data, fine-tune it to selectively forget one class and learn a new one.

As an initial approach, you may do the following.

Start with an MNIST classifier pre-trained on a subset of the digits.

Now replace one of the learned digits, say the class “6”, with a new digit, say “3”.

A possible way to proceed is to identify which weights are more involved in the prediction of class “6”, freeze all the rest, and train with a loss that favors the “3” while penalizing the “6”.
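A minimal sketch of this baseline, under explicit assumptions: "weights more involved in class 6" is approximated here by a first-order score |weight × gradient of the class-6 logit|. This is one common attribution heuristic, not necessarily the one the task intends. The top `keep_frac` fraction of each parameter tensor is left trainable and the rest is frozen via gradient hooks. The loss applies cross-entropy toward the reused output index for new-digit images, plus `w * -log(1 - p_slot)` for old-digit images. The names `class_saliency_masks`, `freeze_outside_masks`, `swap_loss`, `keep_frac` and the random stand-in data are all hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def class_saliency_masks(model, inputs, slot, keep_frac=0.05):
    """Boolean mask per parameter selecting the weights with the largest
    |weight * gradient| of the `slot` logit (a first-order attribution score)."""
    model.zero_grad()
    model(inputs)[:, slot].sum().backward()
    masks = {}
    for name, p in model.named_parameters():
        if p.grad is None:
            continue
        score = (p.detach() * p.grad).abs()
        k = max(1, int(keep_frac * score.numel()))
        threshold = torch.topk(score.flatten(), k).values[-1]  # k-th largest score
        masks[name] = score >= threshold
    model.zero_grad()
    return masks


def freeze_outside_masks(model, masks):
    """Zero the gradient of every weight outside its mask; parameters with no
    mask are frozen completely."""
    for name, p in model.named_parameters():
        if name in masks:
            p.register_hook(lambda g, m=masks[name]: g * m)
        else:
            p.requires_grad_(False)


def swap_loss(logits_new, logits_old, slot, w=0.1):
    """Favor `slot` on new-class images, penalize `slot` on old-class images."""
    target = torch.full((logits_new.shape[0],), slot, dtype=torch.long,
                        device=logits_new.device)
    learn = F.cross_entropy(logits_new, target)
    p_slot = logits_old.softmax(dim=1)[:, slot]
    forget = -torch.log1p(-p_slot.clamp(max=1 - 1e-6)).mean()  # -log(1 - p_slot)
    return learn + w * forget


torch.manual_seed(0)
model = nn.Sequential(nn.Flatten(), nn.Linear(32 * 32, 64), nn.ReLU(), nn.Linear(64, 10))
old_x = torch.randn(16, 1, 32, 32)  # hypothetical stand-in for images of "6"
new_x = torch.randn(16, 1, 32, 32)  # hypothetical stand-in for images of "3"
slot = 6                            # output index that "3" takes over from "6"

masks = class_saliency_masks(model, old_x, slot)
freeze_outside_masks(model, masks)
before = {n: p.detach().clone() for n, p in model.named_parameters()}

# Plain SGD without weight decay, so masked-out weights receive exactly zero updates.
opt = torch.optim.SGD([p for p in model.parameters() if p.requires_grad], lr=0.1)
for _ in range(20):
    opt.zero_grad()
    loss = swap_loss(model(new_x), model(old_x), slot)
    loss.backward()
    opt.step()
```

One pitfall this makes visible: the hooks only zero gradients, so an optimizer with weight decay (e.g. `SGD(weight_decay=...)` or `AdamW`) would still move the "frozen" weights. A second pitfall is that the class-6 saliency mask may not contain the weights the "3" actually needs.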

Test this baseline and see whether it gets you anywhere. Does it work? Are there any pitfalls in this idea? Use it as a first line of attack to understand the problem.

Starting from these baseline tests, devise a new unlearning procedure.

You can improve upon this baseline, make up your own idea from scratch, or check the literature to get ideas.

If you use an existing approach, you must add something new: for example, test it on a new data modality (e.g., audio), study more extreme cases, failures, or weaknesses, or make it more efficient, and so on.

In [None]:
# HOW TO TRAIN A MODEL

# select a template from source.template:
#from source.template.EMNIST import main
#main(20, 10, 256, 'data/models/EMNIST')

from source.template.Cifar10 import main
main(8, 30, 512, 'data/models/Cifar10')

## EMNIST

In [None]:
# LOAD A MODEL
import torch
import os
from source.template.EMNIST import Classifier
from source.template.EMNIST import N_CLASSES
import random

model_path = "data/models/EMNIST/5_38_10_29_7_27_26_42_9_16_2_33_3_14_39_40_32_11_45_23.pt"
original_list_classes = [int(c) for c in os.path.splitext(os.path.basename(model_path))[0].split('_')]
n_classes = len(original_list_classes)

model = Classifier(n_classes)
model.load_state_dict(torch.load(model_path, weights_only=True))

# pick a random old_class from original_list_classes
old_class = random.choice(original_list_classes)
# pick new_class: a label in range(N_CLASSES) that is not in original_list_classes
new_class = random.choice([i for i in range(N_CLASSES) if i not in original_list_classes])
# replace old_class with new_class in list_classes (new_class takes over old_class's output index)
list_classes = original_list_classes.copy()
list_classes[list_classes.index(old_class)] = new_class

print("Original classes:  ", original_list_classes)
print("Classes to retain: ", list_classes)
print("Class to forget:   ", old_class, f"(replaced by {new_class})")

In [3]:
# MAKE DATALOADER
import torch.utils.data
import torchvision
from torchvision import transforms
from source.template.EMNIST import SPLIT, FilterSet

loader = torch.utils.data.DataLoader(
    FilterSet(
        torchvision.datasets.EMNIST(
            root="data/db",
            split=SPLIT,
            train=True,
            download=True,
            transform=transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Grayscale(),
                    transforms.Pad(2),  # pad to 32x32
                    transforms.RandomAffine(
                        degrees=5,
                        translate=(0.1, 0.1),
                        scale=(0.9, 1.1),
                        shear=10
                    ),
                    transforms.ColorJitter(contrast=(0.9,1.5)),
                ]
            ),
            target_transform=lambda x: list_classes.index(x) if x in list_classes else -1
        ),
        torch.tensor(list_classes + [old_class]),
    ),
    batch_size=256,
    shuffle=True,
    num_workers=4,
    pin_memory=True
)
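With the `target_transform` above, images of `old_class` pass the `FilterSet` but get label `-1`, while retained and new classes get their index in `list_classes`. The sketch below shows one way a training loop could use that convention: split a batch into retained and forget samples, or let `F.cross_entropy(..., ignore_index=-1)` drop the forget samples from the classification loss. The helper `split_batch` and the toy tensors are hypothetical; how `source.MachineUnlearning.train` actually consumes the `-1` labels is not shown here.

```python
import torch
import torch.nn.functional as F


def split_batch(x, y):
    """Split a batch into retained samples (y >= 0) and forget samples
    (y == -1, i.e. images of old_class under the target_transform above)."""
    forget = y == -1
    return (x[~forget], y[~forget]), x[forget]


# Toy batch: three samples, the middle one belongs to old_class.
x = torch.arange(3, dtype=torch.float32).unsqueeze(1)
y = torch.tensor([0, -1, 2])
logits = torch.tensor([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0], [1.0, 1.0, 1.0]])

(x_keep, y_keep), x_forget = split_batch(x, y)
# ignore_index=-1 drops old_class samples from the averaged classification loss
ce = F.cross_entropy(logits, y, ignore_index=-1)
```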


In [4]:
# TEST LOADER

test_loader = torch.utils.data.DataLoader(
    FilterSet(
        torchvision.datasets.EMNIST(
            root="data/db",
            split=SPLIT,
            train=False,
            download=True,
            transform=transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Grayscale(),
                    transforms.Pad(2),  # pad to 32x32
                ]
            ),
            target_transform=lambda x: list_classes.index(x) if x in list_classes else -1
        ),
        torch.tensor(list_classes + [old_class]),
    ),
    batch_size=512,
    num_workers=4,
    pin_memory=True
)

## CIFAR-10

In [None]:
# LOAD A MODEL
import torch
import os
from source.template.Cifar10 import Classifier
from source.template.Cifar10 import N_CLASSES
import random

model_path = "data/models/Cifar10/7_5_8_1_9_6_2_3.pt"
original_list_classes = [int(c) for c in os.path.splitext(os.path.basename(model_path))[0].split('_')]
n_classes = len(original_list_classes)

model = Classifier(n_classes)
model.load_state_dict(torch.load(model_path, weights_only=True))

# pick a random old_class from original_list_classes
old_class = random.choice(original_list_classes)
# pick new_class: a label in range(N_CLASSES) that is not in original_list_classes
new_class = random.choice([i for i in range(N_CLASSES) if i not in original_list_classes])
# replace old_class with new_class in list_classes (new_class takes over old_class's output index)
list_classes = original_list_classes.copy()
list_classes[list_classes.index(old_class)] = new_class

print("Original classes:  ", original_list_classes)
print("Classes to retain: ", list_classes)
print("Class to forget:   ", old_class, f"(replaced by {new_class})")

In [None]:
# MAKE DATALOADER
import torch.utils.data
import torchvision
from torchvision import transforms
from source.template.Cifar10 import FilterSet

loader = torch.utils.data.DataLoader(
    FilterSet(
        torchvision.datasets.CIFAR10(
            root="data/db",
            train=True,
            download=True,
            transform=transforms.Compose([
                transforms.RandomResizedCrop(32, scale=(0.8, 1.0)),
                transforms.RandomAffine(
                    degrees=0,
                    scale=(1.0, 1.1),
                    shear=0
                ),
                transforms.ColorJitter(
                    contrast=(0.9,1.5),
                    saturation=(0.9,1.3),
                    brightness=(0.9,1.3),
                    hue=(-0.05,0.05),
                ),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]),
            ]),
            target_transform=lambda x: list_classes.index(x) if x in list_classes else -1
        ),
        torch.tensor(list_classes + [old_class]),
    ),
    batch_size=256,
    shuffle=True,
    num_workers=4,
    pin_memory=True
)


In [None]:
# TEST LOADER

test_loader = torch.utils.data.DataLoader(
    FilterSet(
        torchvision.datasets.CIFAR10(
            root="data/db",
            train=False,
            download=True,
            transform=transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]),
                ]
            ),
            target_transform=lambda x: list_classes.index(x) if x in list_classes else -1
        ),
        torch.tensor(list_classes + [old_class]),
    ),
    batch_size=512,
    num_workers=4,
    pin_memory=True
)

## SST-2

In [24]:
import torch
from torchtext.datasets import SST2

# Load the SST-2 train and test splits
train_data, test_data = SST2(
    root="data/db",
    split='train',
), SST2(
    root="data/db",
    split='test',
)

OSError: /home/stefano/Documents/GitHub/Machine_Unlearning/.venv/lib/python3.10/site-packages/torchtext/lib/libtorchtext.so: undefined symbol: _ZN5torch3jit17parseSchemaOrNameERKSs

## Training and Testing

In [None]:
from source.MachineUnlearning import train

train(
    model,
    loader,
    10,
    n_layers=2,
    w=0.1,   # weight of the unlearning term relative to the classification loss (0.1, i.e. 10%)
    device=torch.device('cuda'),
    classes=list_classes,
)
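The internals of `source.MachineUnlearning.train` are not shown in this notebook. The following is a hypothetical sketch of what a step with these arguments *could* look like. It assumes two things: `n_layers` means "keep only the last `n_layers` parameterised modules trainable", and `w` weights an unlearning term against the classification loss. The unlearning term chosen here, KL(uniform ‖ p) on `old_class` samples (label `-1`), is one common choice that pushes forgotten images toward an uninformative prediction. It is not necessarily the one used in the repository. `freeze_all_but_last` and `unlearning_step` are illustrative names.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


def freeze_all_but_last(model, n_layers):
    """Assumed reading of n_layers: only the last n_layers modules that own
    parameters stay trainable."""
    owners = [m for m in model.modules() if any(True for _ in m.parameters(recurse=False))]
    for p in model.parameters():
        p.requires_grad_(False)
    for m in owners[-n_layers:]:
        for p in m.parameters(recurse=False):
            p.requires_grad_(True)


def unlearning_step(model, optimizer, x, y, w=0.1):
    """One step of: CE on retained samples + w * KL(uniform || p) on forget samples."""
    logits = model(x)
    forget = y == -1
    loss = logits.new_zeros(())
    if (~forget).any():
        loss = loss + F.cross_entropy(logits[~forget], y[~forget])
    if forget.any():
        log_p = F.log_softmax(logits[forget], dim=1)
        # KL(uniform || p) = -log(C) - mean_c log p_c, which is >= 0 and 0 only at uniform
        kl = (-log_p.mean(dim=1) - math.log(logits.shape[1])).mean()
        loss = loss + w * kl
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


torch.manual_seed(0)
model = nn.Sequential(
    nn.Flatten(), nn.Linear(12, 16), nn.ReLU(), nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 4)
)
freeze_all_but_last(model, n_layers=2)
first_before = model[1].weight.detach().clone()
opt = torch.optim.SGD([p for p in model.parameters() if p.requires_grad], lr=0.05)
x = torch.randn(8, 12)
y = torch.tensor([0, 1, 2, 3, -1, -1, 0, 1])  # -1 marks old_class samples
loss = unlearning_step(model, opt, x, y, w=0.1)
```

Freezing by "last N parameterised modules" is the simplest interpretation. The class-specific mask from the baseline above is an alternative that targets weights by saliency rather than by depth.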

In [None]:
from source.MachineUnlearning import test

test(
    model,
    test_loader,
    device=torch.device('cuda'),
    classes=list_classes,
)
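Because the test loader also keeps `old_class` (label `-1`), it supports three separate checks. The first is accuracy on the classes that were kept. The second is accuracy on `new_class`, whose output index is `slot = list_classes.index(new_class)`. The third is how often old-class images still land in that reused index. The sketch below computes these from collected logits. `collect_logits` and `unlearning_metrics` are hypothetical helpers, independent of what `source.MachineUnlearning.test` reports.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def collect_logits(model, loader, device):
    """Run the model over a loader and gather logits and labels on the CPU."""
    model.eval()
    all_logits, all_y = [], []
    for x, y in loader:
        all_logits.append(model(x.to(device)).cpu())
        all_y.append(y)
    return torch.cat(all_logits), torch.cat(all_y)


def unlearning_metrics(logits, y, slot):
    """Retained-class accuracy, new-class accuracy, and the fraction of
    old_class samples (y == -1) still predicted as the reused index `slot`."""
    pred = logits.argmax(dim=1)
    retained = (y >= 0) & (y != slot)
    new = y == slot
    forget = y == -1

    def frac(mask, hits):
        return (mask & hits).sum().item() / max(1, mask.sum().item())

    return {
        "retained_acc": frac(retained, pred == y),
        "new_class_acc": frac(new, pred == y),
        "forget_to_slot": frac(forget, pred == slot),
    }


logits = torch.tensor([[3.0, 0, 0], [0, 3.0, 0], [0, 0, 3.0], [0, 3.0, 0], [3.0, 0, 0]])
y = torch.tensor([0, 1, 1, -1, -1])
metrics = unlearning_metrics(logits, y, slot=1)
# {'retained_acc': 1.0, 'new_class_acc': 0.5, 'forget_to_slot': 0.5}
```

In the notebook this would be called as `unlearning_metrics(*collect_logits(model, test_loader, torch.device('cuda')), list_classes.index(new_class))`. A low `forget_to_slot` together with a high `new_class_acc` indicates the swap worked. A drop in `retained_acc` shows collateral forgetting.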