In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, Subset
import numpy as np
import time

# ====================== Configuration ======================
# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Optimisation hyper-parameters shared by every training stage.
batch_size = 128
learning_rate = 0.01  # initial LR, before any scheduler decay
num_epochs = 20       # epochs for the full-data and baseline runs

# Unlearning settings.
ssd_lambda = 1e-4       # strength of the SSD weight update
forget_classes = [0, 1]  # class labels to unlearn

# Checkpoint files, one per pipeline stage.
pre_unlearn_path = "resnet18_pre_unlearn.pth"
post_unlearn_path = "resnet18_post_unlearn.pth"
baseline_path = "resnet18_baseline.pth"

# ====================== Data Preparation ======================
# One Normalize instance shared by both pipelines so the channel statistics
# cannot drift apart between training and evaluation.
normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))

# Training pipeline: random crop + horizontal flip augmentation, then normalise.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])
# Test pipeline: no augmentation, same normalisation.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    normalize,
])

# Full CIFAR-10 datasets
full_train = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
full_test = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

test_loader = DataLoader(full_test, batch_size=batch_size, shuffle=False, num_workers=4)
train_loader = DataLoader(full_train, batch_size=batch_size, shuffle=True, num_workers=4)

# Forget/retain subsets from full training set.
# Read the labels straight from `full_train.targets` rather than indexing the
# dataset: indexing would decode and augment every image only to discard it.
forget_set = set(forget_classes)
train_labels = full_train.targets
forget_idx = [i for i, t in enumerate(train_labels) if t in forget_set]
retain_idx = [i for i, t in enumerate(train_labels) if t not in forget_set]
forget_loader = DataLoader(Subset(full_train, forget_idx), batch_size=batch_size, shuffle=True, num_workers=4)
retain_loader = DataLoader(Subset(full_train, retain_idx), batch_size=batch_size, shuffle=True, num_workers=4)

# ====================== Model & Utilities ======================

def get_model():
    """Return a freshly initialised torchvision ResNet-18 with 10 outputs on ``device``.

    No pretrained weights are loaded. The deprecated ``pretrained=False``
    argument is omitted: random initialisation is the default, so the
    behaviour is unchanged and newer torchvision releases no longer warn.
    """
    return models.resnet18(num_classes=10).to(device)

# Shared loss for every stage. The reduction is spelled out because the
# train/eval loops multiply each batch loss by the batch size to undo the mean.
criterion = nn.CrossEntropyLoss(reduction="mean")

# Train/eval routines

def train_one_epoch(model, loader, optimizer):
    """Run one optimisation pass over ``loader``.

    Puts ``model`` in training mode, takes one optimiser step per batch using
    the module-level ``criterion``, and returns ``(mean_loss, accuracy)`` where
    both are averaged over the number of examples seen.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0, 0, 0
    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = criterion(logits, targets)
        batch_loss.backward()
        optimizer.step()

        # The loss is a per-batch mean; weight it by batch size so the final
        # division yields a per-example average even with a short last batch.
        loss_sum += batch_loss.item() * inputs.size(0)
        n_correct += (logits.argmax(dim=1) == targets).sum().item()
        n_seen += targets.size(0)
    return loss_sum / n_seen, n_correct / n_seen


def evaluate(model, loader):
    """Score ``model`` on ``loader`` without updating it.

    Puts ``model`` in eval mode, disables gradient tracking, and returns
    ``(mean_loss, accuracy)`` averaged over the number of examples seen.
    """
    model.eval()
    loss_sum, n_correct, n_seen = 0, 0, 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            logits = model(inputs)
            batch_loss = criterion(logits, targets)

            # Undo the per-batch mean so the result is a per-example average.
            loss_sum += batch_loss.item() * inputs.size(0)
            n_correct += (logits.argmax(dim=1) == targets).sum().item()
            n_seen += targets.size(0)
    return loss_sum / n_seen, n_correct / n_seen

# ====================== 1) Initial Training ======================
model = get_model()
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=5e-4)
# Cut the LR by 10x halfway through training. StepLR needs step_size >= 1,
# and num_epochs // 2 is 0 when num_epochs == 1, so clamp it.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=max(1, num_epochs // 2), gamma=0.1)
print("[1] Initial training on FULL dataset...")
for epoch in range(1, num_epochs + 1):
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer)
    # Test accuracy is only reported here; it does not influence training.
    test_loss, test_acc = evaluate(model, test_loader)
    scheduler.step()
    print(f"Epoch {epoch}/{num_epochs}: Train acc {train_acc:.3f}, Test acc {test_acc:.3f}")
# Save the fully trained model as the starting point for unlearning.
torch.save(model.state_dict(), pre_unlearn_path)
print(f"Saved pre-unlearn model to {pre_unlearn_path}")

# ====================== 2) SSD Unlearning (timed) ======================
print("[2] Starting SSD unlearning pipeline...")
ssd_steps = 5        # number of dampening passes over the weights
finetune_epochs = 3  # repair epochs on the retained data
start_ssd = time.time()

# --- Importance estimate on the forget set ---
# Accumulate the squared gradient of the batch-mean loss for every trainable
# parameter. The model is in eval mode for these passes and no optimiser step
# is taken, so the weights are not changed here.
fisher = {n: torch.zeros_like(p) for n, p in model.named_parameters() if p.requires_grad}
num_batches = 0
model.eval()
for x, y in forget_loader:
    x, y = x.to(device), y.to(device)
    model.zero_grad()
    out = model(x)
    loss = criterion(out, y)
    loss.backward()
    num_batches += 1
    for n, p in model.named_parameters():
        if p.requires_grad and p.grad is not None:
            fisher[n] += p.grad.detach().pow(2)
# Average over batches (the counter tracks batches, not images). The clamp
# avoids a ZeroDivisionError when the forget set is empty; the estimate is
# then all zeros and the dampening below leaves the weights unchanged.
fisher = {n: v / max(num_batches, 1) for n, v in fisher.items()}

# --- Dampening ---
# Each pass shrinks every weight by the fraction ssd_lambda * fisher, so
# weights with larger forget-set gradients are shrunk more.
with torch.no_grad():
    for _ in range(ssd_steps):
        for n, p in model.named_parameters():
            if p.requires_grad:
                p.sub_(ssd_lambda * fisher[n] * p)

# --- Fine-tune on retained classes ---
opt2 = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=5e-4)
for _ in range(finetune_epochs):
    train_one_epoch(model, retain_loader, opt2)

# Stop the clock before checkpointing and evaluation so the reported time
# covers only the unlearning work (importance estimate, dampening, fine-tune).
ssd_time = time.time() - start_ssd

# save post-unlearn
torch.save(model.state_dict(), post_unlearn_path)
post_loss, post_acc = evaluate(model, test_loader)
print(f"SSD pipeline done in {ssd_time:.1f}s | Post-unlearn test acc: {post_acc:.3f}")

# ====================== 3) Baseline Retraining (timed) ======================
print("[3] Retraining baseline on retained data...")
start_retrain = time.time()
# Train a fresh model from scratch on the retained classes only, using the
# same optimiser settings and LR schedule as the initial training.
baseline = get_model()
opt3 = optim.SGD(baseline.parameters(), lr=learning_rate, momentum=0.9, weight_decay=5e-4)
# StepLR needs step_size >= 1; num_epochs // 2 is 0 when num_epochs == 1.
sched3 = optim.lr_scheduler.StepLR(opt3, step_size=max(1, num_epochs // 2), gamma=0.1)
for _ in range(num_epochs):
    train_one_epoch(baseline, retain_loader, opt3)
    sched3.step()
# The timer covers model construction and training only; the checkpoint save
# and test evaluation below are excluded.
baseline_time = time.time() - start_retrain
# save baseline
torch.save(baseline.state_dict(), baseline_path)
base_loss, base_acc = evaluate(baseline, test_loader)
print(f"Baseline retrain done in {baseline_time:.1f}s | Test acc: {base_acc:.3f}")

# ====================== 4) Summary ======================
# Reload the pre-unlearn checkpoint into a fresh model for the reference score.
# load_state_dict() returns a report of missing/unexpected keys, not the
# module, so the model must be held in its own variable before evaluating it.
pre_model = get_model()
pre_model.load_state_dict(torch.load(pre_unlearn_path, map_location=device))
pre_acc = evaluate(pre_model, test_loader)[1]

print("\n=== Summary ===")
print(f"Pre-unlearn model       acc: {pre_acc:.3f}")
print(f"Post-unlearn model      acc: {post_acc:.3f}")
print(f"Baseline retrain model  acc: {base_acc:.3f}")
print(f"SSD unlearning time     : {ssd_time:.1f}s")
print(f"Baseline retrain time   : {baseline_time:.1f}s")


[1] Initial training on FULL dataset...




Epoch 1/20: Train acc 0.389, Test acc 0.510
Epoch 2/20: Train acc 0.510, Test acc 0.571
Epoch 3/20: Train acc 0.578, Test acc 0.632
Epoch 4/20: Train acc 0.630, Test acc 0.644
Epoch 5/20: Train acc 0.658, Test acc 0.682
Epoch 6/20: Train acc 0.681, Test acc 0.693
Epoch 7/20: Train acc 0.700, Test acc 0.712
Epoch 8/20: Train acc 0.715, Test acc 0.721


KeyboardInterrupt: 