# ***Project Machine Learning***

*SMIA 2024, project on 'Machine Unlearning'*
> #### ***Author:***  *Iacopo Scandale*
> #### ***Number***: *2085989*

# *5. Machine Unlearning*
**Can you unlearn something?**  
Your task here is the following: given a learning model
(for example, an MLP, or some ensemble model, or
linear regression, your choice!) pre-trained on some
data, you want to modify it to selectively forget a class,
and learn a new class.

Here’s one possible approach which uses a MLP, and
you may want to start from here (but not necessarily --
again, it’s your choice). Start with a MNIST classifier
pre-trained on a subset of the digits. Now replace one
of the learned digits, say the class “6”, with a new
digit, say “3”. A possible way to proceed is to identify which weights are more involved in the
prediction of class “6”, freeze all the rest, and train with a loss that favors the “3” while
penalizing the “6”. Test this baseline and see whether it brings you anywhere. Are there any
pitfalls in this idea? Does it work? Use it as a first line of attack to understand the problem.

Starting from these baseline tests, devise a new unlearning procedure. You can improve upon
this baseline, make up your own idea from scratch, or check the literature to get ideas. If you
use an existing approach, you must add something new, for example by testing it on some
new data modality (e.g., audio), by studying more extreme cases, failures, weaknesses, or by
making it more efficient, and so on.


## ***Imports and Reproducibility***

In [None]:
import numpy as np
import torch
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
import torch.nn as nn
import torch.optim as optim
from sklearn.linear_model import LogisticRegression
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import classification_report, confusion_matrix
from matplotlib import pyplot as plt
from tqdm import tqdm

import random
np.random.seed(23)
random.seed(0)
torch.manual_seed(0)  # also seed PyTorch (weight init, shuffling, random labels)

## ***Torchvision MNIST Dataset***

We will use the Torchvision MNIST training and validation datasets. To avoid any bias, we normalize each dataset separately.

For normalization, we subtract the mean (over pixels) of the per-pixel means and divide by the mean of the per-pixel standard deviations. The resulting dataset has an average per-pixel mean of 0 and an average per-pixel standard deviation of 1.
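As a minimal sketch of how such constants could be obtained: the per-pixel mean and standard deviation are averaged over pixels, and the normalized data is then checked in the same way. The random array is only a stand-in for the flattened images, and `normalization_constants` is a hypothetical helper, not part of the notebook.

```python
import numpy as np

def normalization_constants(X):
    """Return (mean of per-pixel means, mean of per-pixel stds) for X of shape (N, D)."""
    mean = X.mean(axis=0).mean()
    std = X.std(axis=0, ddof=1).mean()  # ddof=1 mirrors the default of torch.std
    return mean, std

rng = np.random.default_rng(0)
X = rng.random((1000, 16))  # stand-in for N flattened images (not MNIST)

m, s = normalization_constants(X)
X_norm = (X - m) / s

# After normalization the same two statistics are 0 and 1 (up to rounding)
new_m, new_s = normalization_constants(X_norm)
print(f"mean={new_m:.2f}  std={new_s:.2f}")
```

Note that this normalizes the average per-pixel standard deviation, which is in general not the same as the standard deviation over all pixel values pooled together.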

### *Training Set and Validation Set* 

In [None]:
# Training
mnist_train = datasets.MNIST(
    root='./',
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.1932,))
    ])
)

# Validation
mnist_valid = datasets.MNIST(
    root='./',
    train=False,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1325,),(0.1924,))
    ])
)

Let's load the data through a DataLoader; otherwise the `ToTensor` and `Normalize` transforms would not be applied.

> **Note:** I computed the normalization values to center the training and validation sets (separately!). Each set now has mean 0 and standard deviation 1.

In [None]:
train_dataloader = DataLoader(mnist_train, 60000)
valid_dataloader = DataLoader(mnist_valid, 10000)

X_train,Y_train = next(iter(train_dataloader))
X_train = X_train.reshape(60000,-1)

X_valid,Y_valid = next(iter(valid_dataloader))
X_valid = X_valid.reshape(-1,28*28)

Et = torch.mean(torch.mean(X_train,dim=0))
Vt = torch.mean(torch.std(X_train,dim=0))

Ev = torch.mean(torch.mean(X_valid,dim=0))
Vv = torch.mean(torch.std(X_valid,dim=0))

print(f"Training set:\t mean={Et:.2f}\t std={Vt:.2f}")
print(f"Validation set:\t mean={Ev:.2f}\t std={Vv:.2f}")

In [None]:
# plot some training images
r,c = 3,8
plt.figure(figsize=(10,4))
for i in range(r*c):
    plt.subplot(r,c,i+1)
    plt.imshow(X_train[i].reshape(28,28), cmap='gray')
    plt.title(int(Y_train[i]),fontsize=8)
    plt.axis('off')

### *Same Dataset without Nines*

For our purposes, we will also need the same dataset but without the 9s. In fact, this will be the class we will try to 'forget' during our experiments.

In [None]:
X_train_no_9 = X_train[Y_train != 9]
Y_train_no_9 = Y_train[Y_train != 9]

X_valid_no_9 = X_valid[Y_valid != 9]
Y_valid_no_9 = Y_valid[Y_valid != 9]

n_train_9 = len(Y_train[Y_train == 9])
n_valid_9 = len(Y_valid[Y_valid == 9])

print(f"There are {n_train_9} training nines\t ({100*n_train_9/60_000:.2f}% of the training set)")
print(f"There are {n_valid_9} validation nines\t ({100*n_valid_9/10_000:.2f}% of the validation set)")

## ***Machine Unlearning on Logistic Regression***

We will try different approaches to forget the 9s class.

Let's start with logistic regression on the MNIST dataset.

In [None]:
reg = LogisticRegression(max_iter=1000)
_ = reg.fit(X_train,Y_train)

Accuracy of the Model

In [None]:
def print_accuracy(model, X_train, Y_train, X_valid, Y_valid):
    """
    Use this function to print accuracy results both on training and validation sets
    """
    print(f" Training Accuracy\t {100*model.score(X_train, Y_train):.2f}%")
    print(f" Validation Accuracy\t {100*model.score(X_valid, Y_valid):.2f}%")

In [None]:
print_accuracy(reg, X_train, Y_train, X_valid, Y_valid)

Some misclassified validation images (title: prediction, with the true label in parentheses)

In [None]:
preds = reg.predict(X_valid)
wrong_idx = np.where(preds != Y_valid.numpy())[0]

plt.figure(figsize=(10,4))
r,c = 4,12
for i in range(r*c):
  plt.subplot(r, c, i + 1)
  plt.imshow(X_valid[wrong_idx[i]].reshape(28, 28), cmap='gray')
  plt.title(f"{preds[wrong_idx[i]]} (y={Y_valid[wrong_idx[i]]})",fontsize=8)
  plt.axis("off")

### *Interpreting Weights*

As we saw in class, we can interpret the weights by reshaping them to 28*28, the same dimensions as a training datum, to visualize them. For example, the first 'weight image' can be interpreted as the 'image' responsible for classifying data with label 0, and so on.

In [None]:
# reshaped weights
plt.figure(figsize=(10,4))
plt.suptitle("Logistic Regression: Reshaped Weights", fontsize=16)
for i in range(10):
  plt.subplot(2, 5, i + 1)
  plt.imshow(reg.coef_[i].reshape(28, 28))
  plt.title(f"{i}")
  plt.axis("off")

> Note: Look at the weights of 9s and 6s. They seem to be the same images rotated by 180°!

Given this interpretation, we can convince ourselves that it is true by training another logistic regression model, but this time without the nines. If the intuition is correct, then the weights associated with classes 0, 1, ..., 8 should be visibly very similar, if not identical.

This logistic regression model is the one we want to achieve for our purposes. We aim to convert the initial model, trained on the entire dataset, into a model that has never seen nines during training and does not recognize it as a class.

### *Logistic Regression Without Nines*

In [None]:
reg_no_9 = LogisticRegression(max_iter=2000)
_ = reg_no_9.fit(X_train_no_9, Y_train_no_9)

Accuracy on training and validation sets (with and without nines!)

In [None]:
print("Without Nines:")
print_accuracy(reg_no_9, X_train_no_9, Y_train_no_9, X_valid_no_9, Y_valid_no_9)
print("\nWith Nines:")
print_accuracy(reg_no_9, X_train, Y_train, X_valid, Y_valid)

In [None]:
# reshaped weights without nines
plt.figure(figsize=(10,4))
plt.suptitle("Logistic Regression: Reshaped Weights Without Nines", fontsize=16)
for i in range(9):
  plt.subplot(2, 5, i + 1)
  plt.imshow(reg_no_9.coef_[i].reshape(28, 28))
  plt.title(f"{i}")
  plt.axis("off")

# reshaped weights
plt.figure(figsize=(10,4))
plt.suptitle("Logistic Regression: Reshaped Weights (With Nines)", fontsize=16)
for i in range(10):
  plt.subplot(2, 5, i + 1)
  plt.imshow(reg.coef_[i].reshape(28, 28))
  plt.title(f"{i}")
  plt.axis("off")

They all look practically identical, except that the model trained without nines has no weights for the nines class.

So, after checking how the model without nines classifies the nines, we will try to create a similar model starting from the one trained on the entire dataset, by manually removing all the weights of the nines.

In [None]:
preds = reg_no_9.predict(X_valid)

wrong_9_idx = np.where(preds != Y_valid.numpy())[0]
wrong_9_idx = wrong_9_idx[Y_valid[wrong_9_idx] == 9]

plt.figure(figsize=(10,8))
r,c = 8,12
for i in range(r*c):
  plt.subplot(r, c, i + 1)
  plt.imshow(X_valid[wrong_9_idx[i]].reshape(28, 28), cmap='gray')
  plt.title(f"{preds[wrong_9_idx[i]]},(y={Y_valid[wrong_9_idx[i]]})",fontsize=9)
  plt.axis("off")

Let's check the guess distribution on the class nine. 

In [None]:
numbers = ["Zero:","One:","Two:","Three:","Four:","Five:","Six:","Seven:","Eight:","Nine:"]
distrib = np.zeros(10, dtype=int)  # uint8 would wrap around above 255 counts

for idx in wrong_9_idx:
    distrib[preds[idx]] += 1

print("Predictions on 9-labeled validation images:")
for num,n in zip(numbers, distrib):
    print("· ",num,"\t",n)

print(f"There were {len(wrong_9_idx)} wrong 9-labeled images in a total of {n_valid_9} 9-labeled images")
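The same tally can also be computed without an explicit loop. The sketch below uses `np.bincount` on a few made-up labels and predictions (stand-ins, not the notebook's data); the counts come back in a platform integer type, so they do not wrap around for large classes.

```python
import numpy as np

# Hypothetical stand-ins: true labels and predictions for a handful of images
y_true = np.array([9, 9, 9, 9, 3, 9, 7])
y_pred = np.array([4, 7, 4, 4, 3, 7, 7])

# Count how the 9-labeled images were predicted, one bin per digit 0..9
nine_mask = y_true == 9
counts = np.bincount(y_pred[nine_mask], minlength=10)

for digit, n in enumerate(counts):
    print(digit, n)
```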

So none of the 1009 nines is classified as a nine, and the predictions are spread 'randomly' among the other known classes. We can run this evaluation for every class using sklearn's `confusion_matrix` and `classification_report`.

In [None]:
preds = reg_no_9.predict(X_valid)

# Evaluation
print(confusion_matrix(Y_valid, preds))
print(classification_report(Y_valid, preds, zero_division=0))

In the confusion matrix, rows represent the actual labels (true classes) and columns represent the labels predicted by the classifier. We observe high numbers on the main diagonal because the digits from 0 to 8 are well classified, and the last column is full of zeros: the classifier never outputs the nine class. The last row, in turn, shows how the true nines are spread over the other classes.

This is the result we want: a model that does not 'know' nines. We aim for a model equivalent to one trained without the class we want to unlearn.

Clearly, validation accuracy (as we can see from the classification report) drops a little, because every prediction on a nine is wrong.
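To make the orientation of the matrix concrete, here is a toy 3-class sketch (made-up labels, not MNIST) in which the classifier never outputs class 2. The column of the missing class is all zeros, while its row still counts the true samples of that class.

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Toy problem: class 2 exists in the data but is never predicted
y_true = np.array([0, 0, 1, 1, 2, 2, 2])
y_pred = np.array([0, 0, 1, 0, 0, 1, 1])

cm = confusion_matrix(y_true, y_pred, labels=[0, 1, 2])
print(cm)
print("times class 2 was predicted:", cm[:, 2].sum())  # column = predicted label
print("true class-2 samples:", cm[2, :].sum())         # row = true label
```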

### *Train with nines and delete "9th" weights*

Restarting from the initial `reg` model, let's remove the last row of 28*28 weights (and the corresponding bias).

In [None]:
new_coef = np.delete(reg.coef_, 9, axis=0)
new_intercept = np.delete(reg.intercept_, 9)

reg.coef_ = new_coef
reg.intercept_ = new_intercept
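Why this works can be sketched with plain NumPy. A linear classifier predicts the class with the largest score, so deleting the row of class 9 leaves an argmax over the remaining nine scores: inputs that were not predicted as 9 keep their prediction, and former 9s fall back to their second-best class. The weights and inputs below are random stand-ins, not the fitted `reg` model.

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.normal(size=(10, 784))   # hypothetical weights, one row per class
b = rng.normal(size=10)          # hypothetical intercepts
X = rng.normal(size=(100, 784))  # hypothetical inputs

full_pred = (X @ W.T + b).argmax(axis=1)

# Drop the row and the intercept of class 9
W9, b9 = np.delete(W, 9, axis=0), np.delete(b, 9)
pruned_pred = (X @ W9.T + b9).argmax(axis=1)

kept = full_pred != 9
print("class 9 ever predicted:", bool((pruned_pred == 9).any()))
print("other predictions unchanged:", bool((pruned_pred[kept] == full_pred[kept]).all()))
```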

In [None]:
# same weights of initial reg, but nines are deleted
plt.figure(figsize=(10,4))
plt.suptitle("Logistic Regression: Reshaped Weights", fontsize=16)
for i in range(9):
  plt.subplot(2, 5, i + 1)
  plt.imshow(reg.coef_[i].reshape(28, 28))
  plt.title(f"{i}")
  plt.axis("off")

In [None]:
print("Accuracy of regression on complete datasets")
print_accuracy(reg,X_train,Y_train,X_valid,Y_valid)

In [None]:
preds = reg.predict(X_valid)

# Evaluate the model
print(confusion_matrix(Y_valid, preds))
print(classification_report(Y_valid, preds, zero_division=0))

Fantastic, the accuracy is very high despite misclassifying all the nines. Moreover, no image is ever classified as a nine. In a sense, the nines class has been effectively erased!

As before, let's show some errors.

In [None]:
preds = reg.predict(X_valid)

wrong_9_idx = np.where(preds != Y_valid.numpy())[0]
wrong_9_idx = wrong_9_idx[Y_valid[wrong_9_idx] == 9]

plt.figure(figsize=(10,8))
r,c = 8,12
for i in range(r*c):
  plt.subplot(r, c, i + 1)
  plt.imshow(X_valid[wrong_9_idx[i]].reshape(28, 28), cmap='gray')
  plt.title(f"{preds[wrong_9_idx[i]]},(y={Y_valid[wrong_9_idx[i]]})",fontsize=9)
  plt.axis("off")


And check the distribution of the predictions on the nines

In [None]:
numbers = ["Zero:","One:","Two:","Three:","Four:","Five:","Six:","Seven:","Eight:","Nine:"]
distrib = np.zeros(10, dtype=int)  # uint8 would wrap around above 255 counts

for idx in wrong_9_idx:
    distrib[preds[idx]] += 1

print("Predictions on 9-labeled validation images:")
for num,n in zip(numbers, distrib):
    print("· ",num,"\t",n)

print(f"There were {len(wrong_9_idx)} wrong 9-labeled images in a total of {n_valid_9} 9-labeled images")

Although the accuracy has decreased, removing the class results in forgetting all the nines. Moreover, eliminating the weights associated with the 9s did not cause a significant loss in accuracy: the model still correctly classifies 83% of the images (despite misclassifying all the 9s, which constitute about 10% of the dataset). The result is excellent as it stands.

To confirm, let's check the accuracy on the datasets without the 9s.

In [None]:
print("Accuracy on datasets without nines:")
print_accuracy(reg,X_train_no_9,Y_train_no_9,X_valid_no_9,Y_valid_no_9)  # pruned model from this section

Great result!

## ***Machine Unlearning on Multi Layer Perceptron***

### *Sklearn MLP*

This time we will use the following fully connected multilayer perceptron:

$$
f_5 \circ \sigma \circ f_4 \circ \sigma \circ f_3 \circ \sigma \circ f_2 \circ \sigma \circ f_1
$$

with affine maps $f_1:\mathbb{R}^{28\times 28} \to \mathbb{R}^{50}$, $f_2:\mathbb{R}^{50} \to \mathbb{R}^{40}$, $f_3:\mathbb{R}^{40} \to \mathbb{R}^{30}$, $f_4:\mathbb{R}^{30} \to \mathbb{R}^{20}$ and $f_5:\mathbb{R}^{20} \to \mathbb{R}^{10}$, where $\sigma(x) := \max\{0, x\}$ is the ReLU activation, applied elementwise.

In [None]:
mlp = MLPClassifier(hidden_layer_sizes=(50, 40, 30, 20), max_iter=1000, random_state=42)
_ = mlp.fit(X_train, Y_train)

In [None]:
print_accuracy(mlp, X_train, Y_train, X_valid, Y_valid)

In [None]:
preds = mlp.predict(X_valid)

# Evaluate the model
print(confusion_matrix(Y_valid, preds))
print(classification_report(Y_valid, preds))

This was just to check whether the MLP architecture is effective for our classification problem. Now, since we want to experiment with the weights, freezing some of them and retraining the rest, it is better to use a PyTorch MLP.

### *Pytorch MLP*

#### *MLP Model*

We use the same model as before: an MLP with ReLU activations and the same dimensions

In [None]:
mlp = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28*28,50),
    nn.ReLU(),
    nn.Linear(50,40),
    nn.ReLU(),
    nn.Linear(40,30),
    nn.ReLU(),
    nn.Linear(30,20),
    nn.ReLU(),
    nn.Linear(20,10),
)

We train our model with the following hyperparameters: the cross-entropy loss for multiclass classification, the Adam optimizer, and training and validation dataloaders with a batch size of 200.

In [None]:
lr = 0.001
epochs = 5
bs = 200
loss_fn = nn.functional.cross_entropy
opt = optim.Adam(mlp.parameters(), lr=lr)

train_dl = DataLoader(dataset=mnist_train, batch_size=bs, shuffle=True)
valid_dl = DataLoader(dataset=mnist_valid, batch_size=bs, shuffle=False)

# Training
for epoch in tqdm(range(epochs)):
    mlp.train()

    for xb, yb in train_dl:
        preds = mlp(xb)
        loss = loss_fn(preds, yb)

        loss.backward()
        opt.step()
        opt.zero_grad()
        
    # Validation
    mlp.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        valid_loss = sum(loss_fn(mlp(xb), yb) for xb, yb in valid_dl) / len(valid_dl)
        for xb, yb in valid_dl:
            preds = mlp(xb)
            _, predicted = torch.max(preds, dim=1)
            total += yb.size(0)
            correct += (predicted == yb).sum().item()
    
        accuracy = 100 * correct / total
        print(f"\tValidation Loss: {valid_loss.item():.4f}")
        print(f'\tValidation Accuracy: {accuracy:.2f}%')
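The validation step above is repeated in every experiment that follows, so it could be factored into a helper. The sketch below is one possible version: `evaluate` is a hypothetical name, it computes loss and accuracy in a single pass over the dataloader, and the usage example runs on random stand-in data rather than MNIST.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

def evaluate(model, dataloader, loss_fn=nn.functional.cross_entropy):
    """Return (mean batch loss, accuracy in %) of `model` over `dataloader`."""
    model.eval()
    total_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for xb, yb in dataloader:
            logits = model(xb)
            total_loss += loss_fn(logits, yb).item()
            correct += (logits.argmax(dim=1) == yb).sum().item()
            total += yb.size(0)
    return total_loss / len(dataloader), 100 * correct / total

# Usage on random stand-in data (not MNIST)
torch.manual_seed(0)
toy_data = TensorDataset(torch.randn(64, 1, 28, 28), torch.randint(0, 10, (64,)))
toy_dl = DataLoader(toy_data, batch_size=16)
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))

loss, acc = evaluate(toy_model, toy_dl)
print(f"loss={loss:.4f}  accuracy={acc:.2f}%")
```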

In [None]:
# evaluation
with torch.no_grad():
    preds = mlp(X_valid)
    _, predicted = torch.max(preds, dim=1)
    
    print(confusion_matrix(Y_valid, predicted))
    print(classification_report(Y_valid, predicted))

The model is correct and accurate. It is ready to unlearn nines.

#### *Creating copies of this model for experiments*

It is now time to use this model for experiments. To avoid retraining it every time, we will deep copy it using the following function.

In [None]:
def copy_my_pretrained_mlp_model():
    """
    Use to create mlp from previous trained model (mlp) on mnist
    """
    new_model = nn.Sequential(
        nn.Flatten(),
        nn.Linear(28*28,50),
        nn.ReLU(),
        nn.Linear(50,40),
        nn.ReLU(),
        nn.Linear(40,30),
        nn.ReLU(),
        nn.Linear(30,20),
        nn.ReLU(),
        nn.Linear(20,10),
    )

    # copies all parameters from pretrained mlp model (deep copy)
    for new_param, model_param in zip(new_model.parameters(), mlp.parameters()):
        new_param.data = model_param.data.clone()
        
    return new_model
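As an aside, a shorter alternative (not the function used in this notebook) is Python's `copy.deepcopy`, which duplicates a module together with its parameters. A minimal sketch on a small stand-in model, checking that the clone starts identical but is independent of the original:

```python
import copy
import torch
import torch.nn as nn

# Stand-in for the pretrained network (hypothetical small model)
pretrained = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 20), nn.ReLU(), nn.Linear(20, 10))

clone = copy.deepcopy(pretrained)

# The clone starts with identical parameters...
same_before = all(
    torch.equal(p, q) for p, q in zip(pretrained.parameters(), clone.parameters())
)

# ...but modifying it leaves the original untouched
with torch.no_grad():
    clone[1].weight.add_(1.0)
same_after = torch.equal(pretrained[1].weight, clone[1].weight)

print(same_before, same_after)
```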

#### *MLP Model trained without nines*

This is the result we are aiming for: a net that has never seen nines during training

In [None]:
def filter_nines(dataset):
    indices = [i for i, (_, label) in enumerate(dataset) if label != 9]
    return Subset(dataset, indices)

In [None]:
mlp_no_9 = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28*28,50),
    nn.ReLU(),
    nn.Linear(50,40),
    nn.ReLU(),
    nn.Linear(40,30),
    nn.ReLU(),
    nn.Linear(30,20),
    nn.ReLU(),
    nn.Linear(20,10),
)

In [None]:
lr = 0.001
epochs = 5
bs = 200
loss_fn = nn.functional.cross_entropy
opt = optim.Adam(mlp_no_9.parameters(), lr=lr)

train_no9_dl = DataLoader(dataset=filter_nines(mnist_train), batch_size=bs, shuffle=True)
valid_no9_dl = DataLoader(dataset=filter_nines(mnist_valid), batch_size=bs, shuffle=False)

# Training
for epoch in tqdm(range(epochs)):
    mlp_no_9.train()

    for xb, yb in train_no9_dl:
        preds = mlp_no_9(xb)
        loss = loss_fn(preds, yb)

        loss.backward()
        opt.step()
        opt.zero_grad()
        
    # Validation loop
    mlp_no_9.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        valid_loss = sum(loss_fn(mlp_no_9(xb), yb) for xb, yb in valid_no9_dl) / len(valid_no9_dl)
        for xb, yb in valid_no9_dl:
            preds = mlp_no_9(xb)
            _, predicted = torch.max(preds, dim=1)
            total += yb.size(0)
            correct += (predicted == yb).sum().item()
    
    accuracy = 100 * correct / total
    print(f"\tValidation Loss: {valid_loss.item():.4f}")
    print(f'\tValidation Accuracy: {accuracy:.2f}%')

In [None]:
# evaluation without nines
with torch.no_grad():
    preds = mlp_no_9(X_valid_no_9)
    _, predicted = torch.max(preds, dim=1)
    
    print(confusion_matrix(Y_valid_no_9, predicted))
    print(classification_report(Y_valid_no_9, predicted))

In [None]:
# evaluation with nines
with torch.no_grad():
    preds = mlp_no_9(X_valid)
    _, predicted = torch.max(preds, dim=1)
    
    print(confusion_matrix(Y_valid, predicted))
    print(classification_report(Y_valid, predicted, zero_division=0))

This is our theoretical target. Accuracy drops by about 10% when the nines are included, since they constitute about 10% of the dataset and all predictions on them are wrong. Let's see if we can get close to this.

#### *Retrain Entire Net with Random Labels on Nines*

Now let's retrain with only a few nines, assigning them random labels from 0 to 8. This time, we'll impose a penalty on class 9 and retrain the entire network for a few epochs.

We'll start from the previously pretrained model `mlp` and adjust all the weights for 3 epochs using only 200 nines with random labels.

In [None]:
mlp_copy = copy_my_pretrained_mlp_model()

In [None]:
# choose 200 nines
nine_indices = np.where(Y_train == 9)[0]

xb_del_9 = X_train[nine_indices[:200]].reshape((200, 1, 28, 28))
yb_del_9 = torch.randint(0, 8+1, (200,))

In [None]:
lr = 0.001
epochs = 3
bs = 200
loss_fn = nn.functional.cross_entropy
opt = optim.Adam(mlp_copy.parameters(), lr=lr)

# penalty on 9
class_weights = torch.ones(10)
class_weights[9] = 1e6

# Training on a single batch of 200 nines with random labels
for epoch in tqdm(range(epochs)):
    mlp_copy.train()
    
    preds = mlp_copy(xb_del_9)
    loss = loss_fn(preds, yb_del_9, weight=class_weights)

    loss.backward()
    opt.step()
    opt.zero_grad()
        
    # Validation
    mlp_copy.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        valid_loss = sum(loss_fn(mlp_copy(xb), yb) for xb, yb in valid_dl) / len(valid_dl)
        for xb, yb in valid_dl:
            preds = mlp_copy(xb)
            _, predicted = torch.max(preds, dim=1)
            total += yb.size(0)
            correct += (predicted == yb).sum().item()
    
        accuracy = 100 * correct / total
        print(f"\tValidation Loss: {valid_loss.item():.4f}")
        print(f'\tValidation Accuracy: {accuracy:.2f}%')
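One caveat about the penalty: the `weight` argument of `cross_entropy` rescales each sample's loss by the weight of its *target* class. With random targets in 0..8, the entry for class 9 is never selected, so it does not change the loss. The sketch below checks this on random stand-in logits and shows one possible alternative, which is an assumption and not the method used in this notebook: adding the mean probability assigned to class 9 as an explicit penalty. `forget_loss` and `penalty` are hypothetical names.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(200, 10)          # stand-in for model outputs
targets = torch.randint(0, 9, (200,))  # random labels in 0..8

class_weights = torch.ones(10)
class_weights[9] = 1e6

# With no target equal to 9, weighted and unweighted losses coincide
plain = F.cross_entropy(logits, targets)
weighted = F.cross_entropy(logits, targets, weight=class_weights)
print(torch.allclose(plain, weighted))

def forget_loss(logits, targets, forget_class=9, penalty=1.0):
    """Cross entropy plus a penalty on the probability given to `forget_class`."""
    p_forget = F.softmax(logits, dim=1)[:, forget_class].mean()
    return F.cross_entropy(logits, targets) + penalty * p_forget

# The extra term is positive whenever class 9 receives any probability mass
print(forget_loss(logits, targets).item() > plain.item())
```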

In [None]:
# evaluation
with torch.no_grad():
    preds = mlp_copy(X_valid)
    _, predicted = torch.max(preds, dim=1)
    
    print(confusion_matrix(Y_valid, predicted))
    print(classification_report(Y_valid, predicted))

According to the report, only a few nines are correctly classified. The rest are randomly distributed among other classes. Unfortunately, the accuracy has dropped, and the nines have not been completely forgotten.

This approach is too expensive because we are retraining the whole model, and it does not give great results. Maybe the problem is that we are changing all the parameters, including those learned for the classification of the other digits.

Let's see if we can achieve a better result by retraining only a few parameters of the net.

#### *Fine Tuning*

In this experiment we freeze all the parameters except those of the last layer and proceed as before: retraining with the same single batch of nines with random labels

In [None]:
mlp_copy = copy_my_pretrained_mlp_model()

Find the parameter indices and set `requires_grad = False`

In [None]:
# Weights layers indices
for i,w in enumerate(mlp_copy.parameters()):
    print(i,w.shape)

In [None]:
# freeze all the parameters except the last layer
for i, weights in enumerate(mlp_copy.parameters()):
    if i<=7:
        weights.requires_grad = False
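Freezing by hard-coded parameter index works, but it is tied to this exact architecture. A minimal sketch of a more general variant follows; `freeze_all_but_last` is a hypothetical helper, and the network is an untrained stand-in with the same layer sizes. It also passes only the trainable parameters to the optimizer.

```python
import torch.nn as nn
import torch.optim as optim

def freeze_all_but_last(model, n_trainable_layers=1):
    """Freeze every nn.Linear in `model` except the last `n_trainable_layers`."""
    linears = [m for m in model if isinstance(m, nn.Linear)]
    for layer in linears[:-n_trainable_layers]:
        for p in layer.parameters():
            p.requires_grad = False
    return [p for p in model.parameters() if p.requires_grad]

# Untrained stand-in with the same layer sizes as the network above
net = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 50), nn.ReLU(),
    nn.Linear(50, 40), nn.ReLU(),
    nn.Linear(40, 30), nn.ReLU(),
    nn.Linear(30, 20), nn.ReLU(),
    nn.Linear(20, 10),
)

trainable = freeze_all_but_last(net, n_trainable_layers=1)
opt = optim.Adam(trainable, lr=0.001)  # only the unfrozen parameters are optimized
print(sum(p.numel() for p in trainable))  # 20*10 weights + 10 biases = 210
```

Calling it with `n_trainable_layers=2` would leave the last two linear layers trainable instead.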

In [None]:
lr = 0.001
epochs = 3
bs = 200
loss_fn = nn.functional.cross_entropy
opt = optim.Adam(mlp_copy.parameters(), lr=lr)

# penalty on 9
class_weights = torch.ones(10)
class_weights[9] = 1e6

# Training
for epoch in tqdm(range(epochs)):
    mlp_copy.train()
    
    preds = mlp_copy(xb_del_9)
    loss = loss_fn(preds, yb_del_9, weight=class_weights)

    loss.backward()
    opt.step()
    opt.zero_grad()
        
    # Validation
    mlp_copy.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        valid_loss = sum(loss_fn(mlp_copy(xb), yb) for xb, yb in valid_dl) / len(valid_dl)
        for xb, yb in valid_dl:
            preds = mlp_copy(xb)
            _, predicted = torch.max(preds, dim=1)
            total += yb.size(0)
            correct += (predicted == yb).sum().item()
    
        accuracy = 100 * correct / total
        print(f"\tValidation Loss: {valid_loss.item():.4f}")
        print(f'\tValidation Accuracy: {accuracy:.2f}%')

As we can see, validation accuracy decreases slowly, so the result stays very close to the pretrained model. Increasing the number of epochs is not effective either, because accuracy decreases too slowly: even with 30 epochs, accuracy stays close to 93% and the nines are not unlearned.

Now let's at least check whether the parameters of the last layer changed with respect to the initial pretrained `mlp`.

In [None]:
with torch.no_grad():
    for i, (mlp_param, mlp_copy_param) in enumerate(zip(mlp.parameters(), mlp_copy.parameters())):
        print(i, torch.equal(mlp_param, mlp_copy_param))  # compare tensors directly

In [None]:
# evaluation
with torch.no_grad():
    preds = mlp_copy(X_valid)
    _, predicted = torch.max(preds, dim=1)
    
    print(confusion_matrix(Y_valid, predicted))
    print(classification_report(Y_valid, predicted))

This is not working. Retraining only the last layer is not enough to unlearn a class. We have two options:
1. Find the minimum number of layers to change for unlearning
2. Add another layer with 9 classes instead of 10 and hope that this is enough

#### *Fine Tuning on last two layers with bigger batch and more epochs*

Same as before. This time I choose a single batch of 500 nines with random labels and up to 24 epochs (stopping early if validation accuracy falls to 80%), and the MLP is frozen except for the last two layers

In [None]:
mlp_copy = copy_my_pretrained_mlp_model()

In [None]:
# choose 500 nines
nine_indices = np.where(Y_train == 9)[0]

xb_del_9 = X_train[nine_indices[:500]].reshape((500, 1, 28, 28))
yb_del_9 = torch.randint(0, 8+1, (500,))

In [None]:
# freeze all the parameters except the last layer and the second last layer
for i, weights in enumerate(mlp_copy.parameters()):
    if i<=5:
        weights.requires_grad = False

In [None]:
lr = 0.001
epochs = 24
loss_fn = nn.functional.cross_entropy
opt = optim.Adam(mlp_copy.parameters(), lr=lr)

# penalty on 9
class_weights = torch.ones(10)
class_weights[9] = 1e6

# Training
for epoch in tqdm(range(epochs)):
    mlp_copy.train()
    
    preds = mlp_copy(xb_del_9)
    loss = loss_fn(preds, yb_del_9, weight=class_weights)

    loss.backward()
    opt.step()
    opt.zero_grad()
        
    # Validation
    mlp_copy.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        valid_loss = sum(loss_fn(mlp_copy(xb), yb) for xb, yb in valid_dl) / len(valid_dl)
        for xb, yb in valid_dl:
            preds = mlp_copy(xb)
            _, predicted = torch.max(preds, dim=1)
            total += yb.size(0)
            correct += (predicted == yb).sum().item()
    
        accuracy = 100 * correct / total
        if accuracy <= 80:
            break
        print(f"\tValidation Loss: {valid_loss.item():.4f}")
        print(f'\tValidation Accuracy: {accuracy:.2f}%')

In [None]:
# evaluation
with torch.no_grad():
    preds = mlp_copy(X_valid)
    _, predicted = torch.max(preds, dim=1)
    
    print(confusion_matrix(Y_valid, predicted))
    print(classification_report(Y_valid, predicted, zero_division=0))

We did not get any result until we increased the number of epochs. Validation accuracy decreased slowly compared with the previous experiment, where we changed all the parameters of the MLP. Now, changing only the last two layers, we get zero predictions of the nine class.

This is better than before. When we retrained all the parameters, validation accuracy decreased rapidly, which is why 3 epochs were a good trade-off between high validation accuracy and few nine predictions. Now we can retrain the last layers for more epochs and get a better result: in my runs (and I hope in all runs :) the MLP never answers with class nine.

This is a good result, but I think it is not perfect. This is an MLP trained for a classification task, and I cannot prove that it will never output "nine" for every 28*28 image I feed into the net. To be sure, we can add a layer that eliminates the nine class for good.

> Note: Sometimes, I don't get any nines correctly guessed, while other times, the model works perfectly. I understand that this largely depends on the initial MLP model. In fact, simply retraining that model can result in flawless performance.

> Note: The final accuracy of the model that "completely unlearns nines" is near 80%: sometimes it reaches 85%, sometimes it does not. Again, this seems to be linked to the performance of the initial model.

#### *Adding a Layer*

In [None]:
mlp_copy = copy_my_pretrained_mlp_model()

# add a linear layer
mlp_copy = nn.Sequential(
    *list(mlp_copy.children()),
    nn.Linear(10, 9)
)

Now let's use a trick. Assume we have a pretrained model. In our case the model returns a vector $x \in \mathbb{R}^{10}$, because it is trained on 10 classes, from zero to nine. Given that the model works, we want the additional layer to leave the predictions for classes 0 to 8 unchanged. Composing our MLP with a linear function $f:\mathbb{R}^{10} \to \mathbb{R}^{9}$ amounts to computing:

$$
Wx + b = y
$$

where $W \in \mathbb{R}^{9\times10}$ and $b \in \mathbb{R}^{9}$ are the parameters of the map $f$ and $y \in \mathbb{R}^{9}$ is the final output of the MLP.

Setting the bias $b$ to 0, we can choose the parameters in $W$ such that $(Wx)_0 = x_0, \dots, (Wx)_8 = x_8$, indexing components by class. The matrix we are looking for is the $9 \times 10$ rectangular identity $W = [\,I_9 \mid 0\,]$ (ones on the main diagonal, zeros elsewhere).

Doing this, we completely ignore what the MLP does for the nine class. It seems like cheating, but it works better than the previous experiments and is frighteningly similar to the result of the MLP trained without nines.

In reality we are not unlearning here: we are hiding a class. The final result is fantastic, but our MLP continues to "know that nines exist".

In [None]:
# change parameters
for i, weights in enumerate(mlp_copy.parameters()):
    if i == 10:
        weights.data = torch.zeros_like(weights.data)
        weights.data[torch.arange(9),torch.arange(9)] = 1

    if i == 11:
        weights.data = torch.zeros_like(weights.data)
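The same initialization can be sketched on a standalone layer, without relying on parameter indices. `torch.eye(9, 10)` builds the rectangular identity, and the output then equals the first nine inputs. Here `hide_layer` is a hypothetical name and `x` is a random stand-in for the ten logits of the pretrained net.

```python
import torch
import torch.nn as nn

hide_layer = nn.Linear(10, 9)
with torch.no_grad():
    hide_layer.weight.copy_(torch.eye(9, 10))  # W = [I_9 | 0]
    hide_layer.bias.zero_()                    # b = 0

x = torch.randn(5, 10)  # stand-in for the 10 logits of the pretrained net
y = hide_layer(x)

# The layer passes logits 0..8 through unchanged and drops the logit of class 9
print(torch.allclose(y, x[:, :9]))
```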

In [None]:
# evaluation
with torch.no_grad():
    preds = mlp_copy(X_valid)
    _, predicted = torch.max(preds, dim=1)
    
    print(confusion_matrix(Y_valid, predicted))
    print(classification_report(Y_valid, predicted, zero_division=0))

This requires no training, making it the fastest and most accurate approach in this case. However, we are not properly 'unlearning': the MLP structure is not the same, and the original weights do not change.

This approach can be useful if we intend to reintroduce the hidden class later: we simply remove the additional layer.