<a href="https://colab.research.google.com/github/asmik12/Coursework_UCLA/blob/main/ECE-M117-Intro-to-Computer-Security/part3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# ECE 117 - Assignment 3: Part 3
The goal of this part of the assignment is to implement machine unlearning via fine-tuning.

In [None]:
import os
import requests
import numpy as np
import matplotlib.pyplot as plt
from sklearn import linear_model, model_selection

import torch
from torch import nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader

import torchvision
from torchvision import datasets, transforms
from torchvision.utils import make_grid
from torchvision.models import resnet18

import tqdm

# Select the compute device once; later cells move models and batches to it.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Running on device:", DEVICE.upper())

# Seed the *global* NumPy and PyTorch generators as well as the dedicated one
# below. Weight initialisation, DataLoaders created without `generator=` and
# `np.random.shuffle` all draw from the global generators, so without this a
# "Restart Kernel -> Run All" gives different numbers on every run.
# (GPU kernels may still introduce small non-deterministic differences.)
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Dedicated generator for the calls that accept an explicit `generator=`
# argument (the held-out split and the retain loader).
RNG = torch.Generator().manual_seed(SEED)

In [None]:
# Image pipeline: only converts images to tensors (no mean/std scaling is
# applied, despite the variable name).
normalize = transforms.Compose([transforms.ToTensor()])

# Full FashionMNIST training split.
train_set = datasets.FashionMNIST(
    root="./data", train=True, download=True, transform=normalize
)
train_loader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=2)

# The official test split is divided 50/50 into a test and a validation subset
# using the seeded generator.
held_out = datasets.FashionMNIST(
    root="./data", train=False, download=True, transform=normalize
)
test_set, val_set = torch.utils.data.random_split(held_out, [0.5, 0.5], generator=RNG)
test_loader = DataLoader(test_set, batch_size=128, shuffle=False, num_workers=2)
val_loader = DataLoader(val_set, batch_size=128, shuffle=False, num_workers=2)

# Partition the training indices by label: samples of `forget_class` form the
# forget set, every other sample goes to the retain set.
forget_class = 0
is_forget = torch.as_tensor(train_set.targets) == forget_class
forget_idx = torch.nonzero(is_forget).flatten().tolist()
retain_idx = torch.nonzero(~is_forget).flatten().tolist()

forget_set = torch.utils.data.Subset(train_set, forget_idx)
retain_set = torch.utils.data.Subset(train_set, retain_idx)

forget_loader = DataLoader(forget_set, batch_size=128, shuffle=True, num_workers=2)
retain_loader = DataLoader(
    retain_set, batch_size=128, shuffle=True, num_workers=2, generator=RNG
)

In [None]:
# This is provided as a baseline model but feel free to adjust this.
class CNN(nn.Module):
    """Small convolutional classifier producing 10 class scores per image.

    Two 5x5 convolution + ReLU + 2x2 max-pool stages are followed by three
    fully connected layers. The first linear layer expects 16 * 4 * 4
    features, which is what the convolutional stages produce for
    single-channel 28x28 inputs.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor (the pooling layer is shared by both stages).
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # Classifier head.
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Map a batch of images to unnormalised class scores (logits)."""
        features = self.pool(F.relu(self.conv1(x)))
        features = self.pool(F.relu(self.conv2(features)))
        # Collapse channel and spatial dimensions into one feature vector per sample.
        flat = features.view(-1, 16 * 4 * 4)
        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)

In [None]:
# Baseline classifier; it is moved to the selected device before the optimizer
# is created so the optimizer tracks the on-device parameters.
model = CNN()
model = model.to(DEVICE)

# Number of optimisation steps (mini-batches) for the baseline training run.
i_max = 6400

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0001)

In [None]:
@torch.no_grad()
def get_accuracy(model, data_loader, device):
    """Return the fraction of samples that ``model`` classifies correctly.

    Runs without gradient tracking. The model's train/eval mode is left
    untouched, so the caller decides which mode to evaluate in.

    Args:
        model: Callable mapping a batch of inputs to per-class scores with the
            class dimension at index 1.
        data_loader: Iterable yielding ``(inputs, labels)`` tensor batches.
        device: Device the batches are moved to before the forward pass.

    Returns:
        ``correct / total`` as a Python float.

    Raises:
        ZeroDivisionError: If ``data_loader`` yields no samples.
    """
    correct = 0
    total = 0

    for inputs, labels in data_loader:
        # Honour the caller-supplied device instead of a module-level global.
        inputs = inputs.to(device)
        labels = labels.to(device)

        # The predicted class is the index of the largest score.
        predicted = model(inputs).argmax(dim=1)

        total += labels.shape[0]
        correct += int((predicted == labels).sum())

    return correct / total

In [None]:
# First, train a baseline FashionMNIST CNN on the full training set.
# Training is counted in mini-batch iterations (not epochs): the outer loop
# restarts the loader until `i_max` optimisation steps have been taken.

progress = tqdm.tqdm(total=i_max, desc="Training")

i = 0
while i < i_max:
    for inputs, labels in train_loader:
        model.train()

        inputs = inputs.to(DEVICE)
        labels = labels.to(DEVICE)

        # Standard supervised step: forward pass, loss, backward pass, update.
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        i += 1
        progress.update(1)

        # Report accuracy on the full train and test loaders every 1000 steps.
        if i % 1000 == 0:
            train_acc = get_accuracy(model, train_loader, DEVICE)
            test_acc = get_accuracy(model, test_loader, DEVICE)
            progress.write(f"Iter {i} Train Acc {train_acc:.4f} Test Acc {test_acc:.4f}")

        if i >= i_max:
            break

# Release the progress bar so it does not interfere with later bars.
progress.close()

torch.save(model.state_dict(), "./model.pth")

In [None]:
# Accuracy of `model` on the held-out test split.
test_accuracy = get_accuracy(model=model, data_loader=test_loader, device=DEVICE)
print(f"Test Accuracy: {test_accuracy}")

In [None]:
# Machine unlearning via fine-tuning: keep training the already-trained `model`
# on the retain set only, so no sample of the forget class is seen again.
# NOTE(review): the iteration budget is a starting point, not a tuned value;
# adjust it while watching the accuracies and the MIA score below.
i_max = 2000

progress = tqdm.tqdm(total=i_max, desc="Training")

i = 0
while i < i_max:
    # Iterate over mini-batches from `retain_loader`. Iterating `retain_set`
    # directly would yield one un-batched (image, int label) pair at a time.
    for inputs, labels in retain_loader:
        model.train()

        inputs = inputs.to(DEVICE)
        labels = labels.to(DEVICE)

        # Same supervised update as baseline training, restricted to retain data.
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        i += 1
        progress.update(1)

        # Train accuracy is measured on the full training loader, which still
        # contains the forget class.
        if i % 1000 == 0:
            train_acc = get_accuracy(model, train_loader, DEVICE)
            test_acc = get_accuracy(model, test_loader, DEVICE)
            progress.write(f"Iter {i} Train Acc {train_acc:.4f} Test Acc {test_acc:.4f}")

        if i >= i_max:
            break

# Release the progress bar so it does not interfere with later bars.
progress.close()

torch.save(model.state_dict(), "./model-unlearned.pth")

In [None]:
# Freshly initialised model that will only ever be trained on the retain set.
model_retain = CNN()
model_retain = model_retain.to(DEVICE)

# Number of optimisation steps (mini-batches) for the retain-only run.
i_max = 6400

criterion_retain = nn.CrossEntropyLoss()
optimizer_retain = optim.Adam(
    model_retain.parameters(), lr=0.001, weight_decay=0.0001
)

In [None]:
# Benchmark: train `model_retain` on the retain set only, for `i_max`
# mini-batch iterations.

progress = tqdm.tqdm(total=i_max, desc="Training")

i = 0
while i < i_max:
    for inputs, labels in retain_loader:
        model_retain.train()
        inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)

        # Clear stale gradients, then forward pass, backward pass and update.
        optimizer_retain.zero_grad()
        outputs = model_retain(inputs)
        loss = criterion_retain(outputs, labels)
        loss.backward()
        optimizer_retain.step()

        i += 1
        progress.update(1)

        # Every 1000 steps, report accuracy on the retain and test loaders.
        if i % 1000 == 0:
            train_acc = get_accuracy(model_retain, retain_loader, DEVICE)
            test_acc = get_accuracy(model_retain, test_loader, DEVICE)
            progress.write(f"Iter {i} Train Acc {train_acc:.4f} Test Acc {test_acc:.4f}")

        # Stop mid-epoch once the iteration budget is used up.
        if i >= i_max:
            break

torch.save(model_retain.state_dict(), "./model-retain.pth")

In [None]:
@torch.no_grad()
def compute_losses(net, loader, device=None):
    """Return the per-sample cross-entropy loss of ``net`` over ``loader``.

    The forward passes run without gradient tracking, since the losses are
    only used as statistics and never back-propagated.

    Args:
        net: Callable mapping a batch of inputs to class logits.
        loader: Iterable yielding ``(inputs, targets)`` tensor batches.
        device: Device the batches are moved to. When ``None`` (the default)
            the notebook-wide ``DEVICE`` is used.

    Returns:
        1-D NumPy array with one loss value per sample, in the order the
        loader produced them. An empty array if the loader yields nothing.
    """
    if device is None:
        device = DEVICE

    # reduction="none" keeps one loss value per sample instead of a batch mean.
    criterion = nn.CrossEntropyLoss(reduction="none")
    batch_losses = []

    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)

        logits = net(inputs)
        batch_losses.append(criterion(logits, targets).detach().cpu().numpy())

    if not batch_losses:
        return np.array([])
    # One concatenation instead of appending sample by sample in Python.
    return np.concatenate(batch_losses)


# Per-sample losses of `model` on the training loader and on the test split
# (evaluated in that order).
train_losses, test_losses = (
    compute_losses(model, loader) for loader in (train_loader, test_loader)
)

In [None]:
# plot losses on train and test set
fig, ax = plt.subplots()
ax.set_title("Losses on train and test set (trained model)")

# Overlaid, density-normalised histograms; the test set is drawn first.
ax.hist(test_losses, density=True, alpha=0.5, bins=50, label="Test set")
ax.hist(train_losses, density=True, alpha=0.5, bins=50, label="Train set")

ax.set_xlabel("Loss", fontsize=14)
ax.set_ylabel("Frequency", fontsize=14)
# The x-range is capped at the largest test loss; the y-axis is logarithmic.
ax.set_xlim((0, np.max(test_losses)))
ax.set_yscale("log")
ax.legend(frameon=False, fontsize=14)

# Hide the top and right frame lines.
for side in ("top", "right"):
    ax.spines[side].set_visible(False)
plt.show()

In [None]:
def simple_mia(sample_loss, members, n_splits=10, random_state=0):
    """Score a simple loss-based membership inference attack.

    A logistic regression is fitted to predict ``members`` from
    ``sample_loss`` and scored by accuracy under stratified shuffle-split
    cross-validation.

    Args:
        sample_loss: Feature array handed to scikit-learn as ``X``. The callers
            in this notebook pass per-sample losses reshaped to ``(n, 1)``.
        members: Binary labels, one per sample. Both 0 and 1 must be present
            and no other value may appear.
        n_splits: Number of cross-validation splits.
        random_state: Seed forwarded to the cross-validation splitter.

    Returns:
        Array of accuracy scores, one per cross-validation split.

    Raises:
        ValueError: If the set of distinct values in ``members`` is not
            exactly ``{0, 1}``.
    """
    unique_members = np.unique(members)
    # array_equal compares shapes as well as values, so inputs with one class,
    # extra label values or no samples all get the error below instead of a
    # NumPy broadcasting failure from an element-wise `==`.
    if not np.array_equal(unique_members, np.array([0, 1])):
        raise ValueError("members should only have 0 and 1s")

    attack_model = linear_model.LogisticRegression()
    cv = model_selection.StratifiedShuffleSplit(
        n_splits=n_splits, random_state=random_state
    )
    return model_selection.cross_val_score(
        attack_model, sample_loss, members, cv=cv, scoring="accuracy"
    )

In [None]:
# Per-sample losses of `model` on the forget set.
forget_losses = compute_losses(model, forget_loader)

# Since we have more forget losses than test losses, sub-sample them, to have a class-balanced dataset.
# The shuffle is in place; the slice then keeps as many values as there are test losses.
np.random.shuffle(forget_losses)
n_test = len(test_losses)
forget_losses = forget_losses[:n_test]

# One feature column (the loss); label 0 = unseen test sample, 1 = forget sample.
samples_mia = np.concatenate((test_losses, forget_losses)).reshape((-1, 1))
labels_mia = [0] * n_test + [1] * len(forget_losses)

mia_scores = simple_mia(samples_mia, labels_mia)

print(
    f"The MIA has an accuracy of {mia_scores.mean():.3f} on forgotten vs unseen images"
)

In [None]:
# Benchmark model purely on the retain set.

ft_forget_losses = compute_losses(model_retain, forget_loader)
ft_test_losses = compute_losses(model_retain, test_loader)

# The forget set (one whole FashionMNIST training class, 6,000 images) is
# larger than the test split (half of the 10,000 held-out images). Sub-sample
# the forget losses to the test size, the same way as for the unlearned model;
# otherwise the balance check below always fails.
np.random.shuffle(ft_forget_losses)
ft_forget_losses = ft_forget_losses[: len(ft_test_losses)]

# make sure we have a balanced dataset for the MIA
assert len(ft_test_losses) == len(ft_forget_losses)

# One feature column (the loss); label 0 = unseen test sample, 1 = forget sample.
ft_samples_mia = np.concatenate((ft_test_losses, ft_forget_losses)).reshape((-1, 1))
labels_mia = [0] * len(ft_test_losses) + [1] * len(ft_forget_losses)

In [None]:
# Run the same loss-based attack against the retain-only benchmark model.
ft_mia_scores = simple_mia(sample_loss=ft_samples_mia, members=labels_mia)

print(
    f"The MIA has an accuracy of {ft_mia_scores.mean():.3f} on forgotten vs unseen images"
)

In [None]:
# Compare the results to determine the efficacy of the machine-unlearning implementation
# Left panel: losses under `model`; right panel: losses under `model_retain`.

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

ax1.set_title(f"Unlearned by fine-tuning model.\nAttack accuracy: {mia_scores.mean():0.2f}")
ax2.set_title(
    f"Retained model performance.\nAttack accuracy: {ft_mia_scores.mean():0.2f}"
)

# Both panels share the x-range given by the largest test loss of `model`.
shared_xmax = np.max(test_losses)
panels = (
    (ax1, test_losses, forget_losses),
    (ax2, ft_test_losses, ft_forget_losses),
)
for ax, unseen_losses, forgotten_losses in panels:
    ax.hist(unseen_losses, density=True, alpha=0.5, bins=50, label="Test set")
    ax.hist(forgotten_losses, density=True, alpha=0.5, bins=50, label="Forget set")
    ax.set_xlabel("Loss")
    ax.set_yscale("log")
    ax.set_xlim((0, shared_xmax))
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)

# Only the left panel carries the y-label and the legend.
ax1.set_ylabel("Frequency")
ax1.legend(frameon=False, fontsize=14)
plt.show()