# Retrain-from-scratch Unlearning on MNIST

## Importing Libraries and Data

In [218]:
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F

## Define The FCNN

In [219]:
class MnistNN(nn.Module):
    """Fully connected classifier: input_size -> 256 -> 128 -> 64 -> num_classes.

    ReLU follows each of the first three linear layers, and dropout (p=0.3)
    is applied after the first two activations. The final layer returns raw,
    un-normalised class scores.

    Args:
        input_size: Number of features in each flattened input sample.
        num_classes: Number of output scores produced per sample.
    """

    def __init__(self, input_size, num_classes):
        super().__init__()
        # Layer creation order determines parameter order and seeded initialisation.
        self.fc1 = nn.Linear(in_features=input_size, out_features=256)
        self.fc2 = nn.Linear(in_features=256, out_features=128)
        self.fc3 = nn.Linear(in_features=128, out_features=64)
        self.fc4 = nn.Linear(in_features=64, out_features=num_classes)
        self.dropout = nn.Dropout(p=0.3)

    def forward(self, x):
        """Return class scores for a batch ``x`` whose last dimension is ``input_size``."""
        hidden = self.dropout(F.relu(self.fc1(x)))
        hidden = self.dropout(F.relu(self.fc2(hidden)))
        hidden = F.relu(self.fc3(hidden))
        return self.fc4(hidden)

## Split Data into Training and Testing Sets

In [220]:
batch_size = 32

# Convert images to tensors, then normalise the single channel with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Test split: evaluation only, so a fixed order is fine.
mnist_testset = datasets.MNIST(root='../data', train=False,
                               download=True, transform=transform)
mnist_testloader = torch.utils.data.DataLoader(mnist_testset, batch_size=batch_size,
                                               shuffle=False, num_workers=2)

# Train split: reshuffled every epoch. The retain-subset loader used for retraining
# also shuffles, so both models are trained under the same sampling regime.
mnist_trainset = datasets.MNIST(root='../data', train=True,
                                download=True, transform=transform)
mnist_trainloader = torch.utils.data.DataLoader(mnist_trainset, batch_size=batch_size,
                                                shuffle=True, num_workers=2)

## Split Training Set Into Forgetting and Retain Sets
Split the training set into a forgetting set (90%) and a retain set (10%) for unlearning later

In [221]:
total_size = len(mnist_trainset)
retain_size = int(0.10 * total_size)
unlearn_size = total_size - retain_size

# Seed the split so the forgetting set contains the same samples on every run;
# otherwise the forget/retain metrics are not comparable between runs.
forgetting_subset, retain_subset = torch.utils.data.random_split(
    mnist_trainset,
    [unlearn_size, retain_size],
    generator=torch.Generator().manual_seed(0),
)

# Used to retrain the model without the forgotten data
retain_loader = torch.utils.data.DataLoader(retain_subset, batch_size=batch_size, shuffle=True, num_workers=2)

# Used to test whether the retrained model still recognises the forgotten data
forget_loader = torch.utils.data.DataLoader(forgetting_subset, batch_size=batch_size, shuffle=True, num_workers=2)

print(f"Total training samples: {total_size}")
print(f"Unlearn set size: {unlearn_size} ({unlearn_size/total_size*100:.1f}%)")
print(f"Retain set size: {retain_size} ({retain_size/total_size*100:.1f}%)")

Total training samples: 60000
Unlearn set size: 54000 (90.0%)
Retain set size: 6000 (10.0%)


## Create Neural Network Instance

In [222]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# 784 input features per sample and 10 output classes.
model = MnistNN(input_size=784, num_classes=10)
model = model.to(device)

## Test Untrained Model
Test the untrained model to compare later accuracy

In [223]:
def calculate_accuracy(loader, model):
    """Return the fraction of samples in ``loader`` that ``model`` classifies correctly.

    Every batch is flattened to ``(batch, -1)`` before the forward pass, and the
    predicted class is the index of the largest score along dimension 1.

    The model is switched to eval mode for the measurement and then put back in
    whichever mode (train or eval) it was in when the function was called.

    Args:
        loader: Iterable yielding ``(inputs, labels)`` tensor pairs.
        model: ``nn.Module`` to evaluate. Batches are moved to the device of its
            first parameter (CPU if the model has no parameters).

    Returns:
        0-dim tensor holding ``correct / total``.

    Raises:
        ZeroDivisionError: If ``loader`` yields no samples.
    """
    # Follow the model's own device rather than relying on a notebook-level global.
    first_param = next(model.parameters(), None)
    eval_device = first_param.device if first_param is not None else torch.device("cpu")

    number_correct = 0
    number_samples = 0

    # Remember the caller's mode so evaluation does not silently re-enable dropout.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for x, y in loader:
                x = x.to(eval_device).reshape(x.shape[0], -1)
                y = y.to(eval_device)
                scores = model(x)
                _, predictions = scores.max(1)
                number_correct += (predictions == y).sum()
                number_samples += predictions.size(0)
    finally:
        model.train(was_training)
    return number_correct / number_samples

# Baseline: test-set accuracy of the freshly initialised network.
untrained_model_accuracy = calculate_accuracy(loader=mnist_testloader, model=model)
print(f"Model accuracy before training: {untrained_model_accuracy}")

Model accuracy before training: 0.0982000008225441


## Define Model Parameters

In [224]:
# Training hyperparameters shared by the initial training run and the retraining run.
epochs = 5
lr = 5e-3

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)

print(f"Model Parameters: learning rate = {lr}, epochs = {epochs} and batch size = {batch_size}")

Model Parameters: learning rate = 0.005, epochs = 5 and batch size = 32


## Train Model

In [225]:
print(f"Training with device: {device}")

# Running totals across every batch of every epoch; used for the reported mean loss.
total_loss = 0
total_batches = 0

model.train()
for epoch in range(epochs):
    for data, targets in mnist_trainloader:
        # Move the batch to the training device and flatten each sample to a vector.
        data = data.to(device)
        data = data.reshape(data.shape[0], -1)
        targets = targets.to(device)

        # Clear stale gradients, then forward pass, backward pass and parameter update.
        optimizer.zero_grad()
        scores = model(data)
        loss = criterion(scores, targets)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_batches += 1

average_loss = total_loss / total_batches
trained_model_accuracy = calculate_accuracy(mnist_testloader, model)
print(f"Model trained with average loss: {average_loss:.4f} and accuracy: {trained_model_accuracy}")

Training with device: cuda
Model trained with average loss: 0.2849 and accuracy: 0.9601999521255493


## Test Trained Model

In [226]:
# Test-set accuracy of the model trained on the full training set.
global_accuracy_trained = calculate_accuracy(loader=mnist_testloader, model=model)
print(f"Global Model Accuracy After Training: {global_accuracy_trained}")

Global Model Accuracy After Training: 0.9601999521255493


## Discard Trained Model
Erase the trained model by overwriting the variable

In [227]:
# Rebind `model` to a freshly initialised network, dropping the trained weights.
model = MnistNN(input_size=784, num_classes=10)
model = model.to(device)

# The optimizer is rebuilt so it updates the new network's parameters.
optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)

## Test Erased Model
Test accuracy to ensure the network has been replaced with a new, untrained network

In [228]:
# Evaluate the replacement network on the test set and print it next to the
# earlier before/after-training figures for comparison.
erased_model_accuracy = calculate_accuracy(mnist_testloader, model)
print(f"Model accuracy before training: {untrained_model_accuracy}")
print(f"Model accuracy after training: {trained_model_accuracy}")
print(f"Model accuracy after erasing (current): {erased_model_accuracy}")

Model accuracy before training: 0.0982000008225441
Mode accuracy after training: 0.9601999521255493
Model accuracy after erasing (current): 0.0957999974489212


## Retrain Model Without Forgetting Set

In [229]:
print(f"Training with device: {device}")

# Running totals across every batch of every epoch; used for the reported mean loss.
total_loss = 0
total_batches = 0

model.train()
for epoch in range(epochs):
    # Only the retain subset is iterated here; the forgetting subset is never used for training.
    for data, targets in retain_loader:
        # Move the batch to the training device and flatten each sample to a vector.
        data = data.to(device)
        data = data.reshape(data.shape[0], -1)
        targets = targets.to(device)

        # Clear stale gradients, then forward pass, backward pass and parameter update.
        optimizer.zero_grad()
        scores = model(data)
        loss = criterion(scores, targets)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_batches += 1

average_loss = total_loss / total_batches
unlearned_model_testset_accuracy = calculate_accuracy(mnist_testloader, model)
print(f"Model trained with average loss: {average_loss:.4f} and accuracy: {unlearned_model_testset_accuracy}")

Training with device: cuda
Model trained with average loss: 0.8253 and accuracy: 0.9001999497413635


## Test Unlearned Model

In [230]:
# Accuracy of the retrained model on the samples it was never trained on (the forgetting subset).
unlearned_model_forgotten_set_accuracy = calculate_accuracy(loader=forget_loader, model=model)

## Results

In [231]:
# Labels describe what each figure was measured on: the forgotten-set accuracy uses
# forget_loader (the forgetting subset of the trainset) and the test-set accuracy uses
# mnist_testloader (the MNIST test split, not part of the trainset).
print(f"Model accuracy after training (Mnist full trainset): {trained_model_accuracy}")
print()
print(f"Unlearned model accuracy on unlearned set (forgetting subset of trainset): {unlearned_model_forgotten_set_accuracy}")
print(f"Unlearned model accuracy on test set (held-out MNIST testset): {unlearned_model_testset_accuracy}")
print(f"Performance drop on test set: {trained_model_accuracy - unlearned_model_testset_accuracy:.4f}")

Model accuracy after training (Mnist full trainset): 0.9601999521255493

Unlearned model accuracy on unlearned set (Other half of trainset): 0.8989999890327454
Unlearned model accuracy on test set (One half of trainset): 0.9001999497413635
Performance drop on test set: 0.0600


## Forgetting Effectiveness
Accuracy of the retrained model on the unlearned (forgotten) data

In [232]:
# Reported value is the retrained model's accuracy on the forgetting subset.
print(
    f"Forgetting Effectiveness: {unlearned_model_forgotten_set_accuracy} (lower is better)"
)

Forgetting Effectiveness: 0.8989999890327454 (lower is better)


## Utility Preservation
Accuracy retention on non-unlearned data

In [233]:
# Ratio of the retrained model's test accuracy to the fully trained model's test accuracy.
utility_preservation = (
    unlearned_model_testset_accuracy / trained_model_accuracy
)
print(f"Utility Preservation: {utility_preservation:.4f} (closer to 1.0 is better)")

Utility Preservation: 0.9375 (closer to 1.0 is better)
