# Machine Unlearning: Retraining and Scrubbing in a K9db-Integrated System
Final Project for DS 593 Fall 2025

Tracy Cui, Yuki Li, Yang Lu, Xin Wei

## Code Notebook 3: Deletion with Exclusive Classes

Here we implement the models on MNIST again, but this time the user to be forgotten owns an entire class exclusively: user 0 is associated with all images of the digit 0, while the remaining users share the images of digits 1–9. Annotations are omitted for code shared between notebooks 1 and 3; we focus on explaining the changes.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
import copy
import time
from tabulate import tabulate

# Hyperparameters and run settings shared by every cell in this notebook.
CONFIG = dict(
    batch_size=128,
    lr=0.01,            # SGD learning rate used by run_training (base model and M0)
    epochs_m0=5,
    unlearn_epochs=1,
    # Aggressive unlearning settings are allowed here because the gradients
    # for "0" are distinct from "1-9".
    unlearn_lr=0.02,
    fisher_samples=1000,
    alpha=0.5,          # damping added to the Fisher diagonal before inverting it
    max_scrub_steps=50,
    total_users=100,
    zipf_param=1.5,     # ignored for this exclusive-class setup
    seed=42,
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)

# Seed torch first, then NumPy, so runs are reproducible.
for _seed_fn in (torch.manual_seed, np.random.seed):
    _seed_fn(CONFIG["seed"])
print(f"Running on device: {CONFIG['device']}")

class SimpleMLP(nn.Module):
    """Three-layer fully connected classifier: 784 -> 512 -> 128 -> 10.

    Each input sample is flattened before the first layer, so a batch of
    1x28x28 images (784 values per sample) can be passed directly.
    `forward` returns raw logits (no softmax is applied).
    """

    def __init__(self):
        super().__init__()
        # Layers are created in this order on purpose: it fixes the order in
        # which random initial weights are drawn. The attribute names also
        # determine the named_parameters() keys (e.g. "fc1.weight").
        self.fc1 = nn.Linear(784, 512)
        self.fc2 = nn.Linear(512, 128)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        hidden = x.view(x.size(0), -1)  # keep the batch dimension, flatten the rest
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        return self.fc3(hidden)

def compute_fisher_diagonal(model, loader, num_samples=500):
    """Estimate a diagonal Fisher-information proxy for every parameter of `model`.

    For each batch from `loader`, the mean cross-entropy loss is
    back-propagated and the element-wise square of every parameter gradient
    is accumulated. Batches are consumed until at least `num_samples` samples
    have been seen. The last batch is used in full, so slightly more samples
    may be processed. The running sum is then divided by the number of
    samples actually seen.

    NOTE: each batch contributes the square of its *batch-mean* gradient,
    not the sum of per-sample squared gradients. The result is therefore a
    cheap, scaled approximation of the empirical Fisher diagonal, not the
    exact quantity.

    Args:
        model: classifier returning logits; it is switched to eval mode.
        loader: iterable of (inputs, labels) batches.
        num_samples: minimum number of samples to process.

    Returns:
        dict mapping each name from model.named_parameters() to a tensor of
        that parameter's shape. Entries for parameters that never receive a
        gradient stay zero.

    Raises:
        ValueError: if `loader` yields no samples. Dividing by zero samples
            would otherwise silently fill every entry with NaN.
    """
    fisher_diag = {
        name: torch.zeros_like(param) for name, param in model.named_parameters()
    }
    # Send batches to wherever the model lives, instead of relying on a global.
    device = next(model.parameters()).device

    model.eval()
    criterion = nn.CrossEntropyLoss()
    samples_seen = 0

    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        model.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()

        for name, param in model.named_parameters():
            if param.grad is not None:
                fisher_diag[name] += param.grad.detach() ** 2

        samples_seen += inputs.size(0)
        if samples_seen >= num_samples:
            break

    if samples_seen == 0:
        raise ValueError("compute_fisher_diagonal: loader yielded no samples")

    for name in fisher_diag:
        fisher_diag[name] /= samples_seen

    return fisher_diag

def fisher_scrubbing_step(model, inputs, labels, fisher_diag, *, lr=None, alpha=None, clip=0.05):
    """Apply one Fisher-preconditioned step that pushes predictions toward uniform.

    The loss is KL(uniform || softmax(model(inputs))) with "batchmean"
    reduction, i.e. the model is driven to have no idea which class the
    forget-set inputs belong to. `labels` is accepted for interface
    compatibility but is not used. The target is always the uniform
    distribution over the model's output classes, whose count is taken from
    the output width.

    Each parameter is updated in place by
        param -= clamp(lr * grad / (F_ii + alpha), -clip, clip)
    so entries with a large Fisher value move less. In this notebook the
    Fisher is estimated on the retain set.

    Args:
        model: classifier returning logits; left in train mode afterwards.
        inputs: batch to scrub, already on the model's device.
        labels: unused.
        fisher_diag: dict keyed by model.named_parameters() names with
            tensors of matching shape (see compute_fisher_diagonal).
        lr: step size; defaults to CONFIG["unlearn_lr"].
        alpha: damping added to the Fisher diagonal; defaults to CONFIG["alpha"].
        clip: element-wise bound on the magnitude of each update.
    """
    if lr is None:
        lr = CONFIG["unlearn_lr"]
    if alpha is None:
        alpha = CONFIG["alpha"]

    model.train()
    model.zero_grad()

    log_probs = F.log_softmax(model(inputs), dim=1)
    # full_like places the target on the same device/dtype as the outputs.
    num_classes = log_probs.size(1)
    uniform_target = torch.full_like(log_probs, 1.0 / num_classes)

    # KLDivLoss expects log-probabilities as input and probabilities as target.
    loss = nn.KLDivLoss(reduction="batchmean")(log_probs, uniform_target)
    loss.backward()

    with torch.no_grad():
        for name, param in model.named_parameters():
            if param.grad is None:
                continue
            scale = 1.0 / (fisher_diag[name] + alpha)
            # Descend on the KL loss; the clamp keeps an aggressive lr from
            # making any single weight jump too far.
            update = (lr * scale * param.grad).clamp_(min=-clip, max=clip)
            param.sub_(update)

def run_training(model, loader, epochs):
    """Train `model` in place for `epochs` passes over `loader`.

    Minimizes cross-entropy with SGD (lr=CONFIG["lr"], momentum 0.9). A new
    optimizer is built on every call, so momentum state does not carry over
    between calls. Returns None.
    """
    loss_fn = nn.CrossEntropyLoss()
    sgd = optim.SGD(model.parameters(), lr=CONFIG["lr"], momentum=0.9)
    model.train()
    for _epoch in range(epochs):
        for batch_x, batch_y in loader:
            device = CONFIG["device"]
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            sgd.zero_grad()
            batch_loss = loss_fn(model(batch_x), batch_y)
            batch_loss.backward()
            sgd.step()

def evaluate(model, loader):
    """Return the top-1 accuracy of `model` on `loader`, as a percentage.

    The model is switched to eval mode and run without gradient tracking.
    Returns 0.0 when the loader yields no samples.
    """
    model.eval()
    hits = 0
    seen = 0
    with torch.no_grad():
        for batch_x, batch_y in loader:
            device = CONFIG["device"]
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            predicted = torch.max(model(batch_x), 1).indices
            seen += batch_y.size(0)
            hits += (predicted == batch_y).sum().item()
    return 100 * hits / seen if seen else 0.0

def get_model_dist(model_a, model_b):
    """Return the Euclidean (L2) distance between two models' weights.

    Computes sqrt(sum over all parameters of ||p_a - p_b||^2). This is the
    L2 norm of the difference between the flattened, concatenated parameter
    vectors, which is what a "Weight Dist (L2)" column should report. Summing
    per-tensor norms instead would over-state the distance and is not an L2
    distance.

    Raises:
        ValueError: if the models have a different number of parameter
            tensors, or mismatched shapes. Shape mismatches could otherwise
            broadcast silently.
    """
    params_a = list(model_a.parameters())
    params_b = list(model_b.parameters())
    if len(params_a) != len(params_b):
        raise ValueError(
            f"get_model_dist: models have {len(params_a)} vs {len(params_b)} parameter tensors"
        )

    squared_total = 0.0
    with torch.no_grad():
        for p1, p2 in zip(params_a, params_b):
            if p1.shape != p2.shape:
                raise ValueError(
                    f"get_model_dist: shape mismatch {tuple(p1.shape)} vs {tuple(p2.shape)}"
                )
            squared_total += (p1 - p2).pow(2).sum().item()
    return squared_total ** 0.5

# DATA SETUP (a special split: one user exclusively owns an entire digit class)

print("Loading MNIST...")
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
full_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

print("Constructing Class-Split (User 0 = Digit 0)...")

# Identify indices: every digit-0 image (to be forgotten) vs. all other digits.
targets = full_dataset.targets
forget_indices = (targets == 0).nonzero(as_tuple=True)[0].tolist()
retain_indices_pool = (targets != 0).nonzero(as_tuple=True)[0].tolist()

# User 0 gets ALL zeros.
data_map = {0: forget_indices}

# Distribute digits 1-9 randomly among the other users (no one-to-one mapping).
# A dedicated seeded generator keeps the split reproducible without consuming
# the global NumPy RNG state. retain_indices_pool itself stays in dataset order.
remaining_users = CONFIG["total_users"] - 1
if remaining_users < 1:
    raise ValueError(
        "CONFIG['total_users'] must be at least 2 (user 0 plus one or more retain users)"
    )
imgs_per_user = len(retain_indices_pool) // remaining_users
shuffled_pool = np.random.default_rng(CONFIG["seed"]).permutation(retain_indices_pool).tolist()

for u in range(1, CONFIG["total_users"]):
    start = (u - 1) * imgs_per_user
    if u == CONFIG["total_users"] - 1:
        # The last user also absorbs the remainder of the integer division.
        data_map[u] = shuffled_pool[start:]
    else:
        data_map[u] = shuffled_pool[start:start + imgs_per_user]

print(f"User 0 (Forget Target) has {len(data_map[0])} images (All Zeros)")
print(f"Other Users share {len(retain_indices_pool)} images (Digits 1-9)")

# Generate result tables.
print("\n--- [Phase 1] Generating Table 1 ---")

# 1. Train Base Model (All Digits 0-9)
print("Training Base Model (M_All)...")
base_loader = DataLoader(full_dataset, batch_size=CONFIG["batch_size"], shuffle=True)
base_model = SimpleMLP().to(CONFIG["device"])
run_training(base_model, base_loader, CONFIG["epochs_m0"])

# Setup Loaders
target_user = 0
forget_idx = data_map[target_user]
# Test membership against a set. `i not in forget_idx` on the list is a linear
# scan, which makes this comprehension O(len(dataset) * len(forget_idx)):
# 60,000 x 5,923 comparisons for the current split. The resulting list is identical.
forget_idx_set = set(forget_idx)
retain_idx = [i for i in range(len(full_dataset)) if i not in forget_idx_set]

retain_loader = DataLoader(Subset(full_dataset, retain_idx), batch_size=CONFIG["batch_size"], shuffle=True)
forget_loader = DataLoader(Subset(full_dataset, forget_idx), batch_size=CONFIG["batch_size"], shuffle=True)

# 2. Train M0 (Retrain on only Digits 1-9)
print("Running Retrain Benchmark (M0 - No Zeros)...")
start = time.time()
m0_model = SimpleMLP().to(CONFIG["device"])
run_training(m0_model, retain_loader, CONFIG["epochs_m0"])
time_m0 = time.time() - start

# 3. Train M2 (Fisher Scrubbing to remove Zeros)
# The measured M2 runtime includes the Fisher-diagonal estimate on the retain set.
print("Running Fisher Scrub (M2)...")
start = time.time()
m2_model = copy.deepcopy(base_model)
fisher_diag = compute_fisher_diagonal(m2_model, retain_loader, CONFIG["fisher_samples"])

scrub_steps = 0
for _ in range(CONFIG["unlearn_epochs"]):
    for inputs, labels in forget_loader:
        fisher_scrubbing_step(m2_model, inputs.to(CONFIG["device"]), labels.to(CONFIG["device"]), fisher_diag)
        scrub_steps += 1
        if scrub_steps >= CONFIG["max_scrub_steps"]:
            break
    if scrub_steps >= CONFIG["max_scrub_steps"]:
        break
time_m2 = time.time() - start

# Metrics
# Note: M0 Acc on Forget Set should be ~0.0% (It never saw a zero)
acc_m0_r = evaluate(m0_model, retain_loader)
acc_m2_r = evaluate(m2_model, retain_loader)
acc_m0_f = evaluate(m0_model, forget_loader)
acc_m2_f = evaluate(m2_model, forget_loader)
# NOTE: m0_model starts from its own random initialization, while m2_model is a
# copy of base_model, which was initialized separately. This distance therefore
# also reflects the differing initializations, not only the effect of unlearning.
w_dist = get_model_dist(m0_model, m2_model)

table_data = [
    ["Metric", "Full Retrain (M0)", "Fisher Scrub (M2)"],
    ["Retain Acc (%) (Digits 1-9)", f"{acc_m0_r:.2f}", f"{acc_m2_r:.2f}"],
    ["Forget Acc (%) (Digit 0)", f"{acc_m0_f:.2f}", f"{acc_m2_f:.2f}"],
    ["Weight Dist (L2)", "0.0 (Ref)", f"{w_dist:.2f}"],
    ["Runtime (s)", f"{time_m0:.2f}", f"{time_m2:.2f}"]
]
print("\n" + tabulate(table_data, headers="firstrow", tablefmt="grid"))

Running on device: cuda
Loading MNIST...
Constructing Class-Split (User 0 = Digit 0)...
User 0 (Forget Target) has 5923 images (All Zeros)
Other Users share 54077 images (Digits 1-9)

--- [Phase 1] Generating Table 1 ---
Training Base Model (M_All)...
Running Retrain Benchmark (M0 - No Zeros)...
Running Fisher Scrub (M2)...

+-----------------------------+---------------------+---------------------+
| Metric                      | Full Retrain (M0)   |   Fisher Scrub (M2) |
| Retain Acc (%) (Digits 1-9) | 98.91               |               95.63 |
+-----------------------------+---------------------+---------------------+
| Forget Acc (%) (Digit 0)    | 0.00                |                0.03 |
+-----------------------------+---------------------+---------------------+
| Weight Dist (L2)            | 0.0 (Ref)           |               39.82 |
+-----------------------------+---------------------+---------------------+
| Runtime (s)                 | 53.53               |            