In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, Subset
import numpy as np

# ====================== Configuration ======================
# Prefer the GPU whenever CUDA is available; otherwise everything runs on CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Optimisation hyper-parameters.
batch_size = 128
learning_rate = 0.01  # initial SGD learning rate
num_epochs = 20  # epochs for both the initial and the baseline training runs

# Unlearning settings.
ssd_lambda = 1e-4  # scale of the Fisher-weighted parameter shrinkage
forget_classes = [0, 1]  # CIFAR-10 label ids whose training samples are to be forgotten

# Checkpoint files written during the run and re-read by the summary.
pre_unlearn_path = 'resnet18_pre_unlearn.pth'
post_unlearn_path = 'resnet18_post_unlearn.pth'
baseline_path = 'resnet18_baseline.pth'

# ====================== Data Preparation ======================
# Per-channel normalisation statistics shared by the train and test pipelines.
cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std = (0.247, 0.243, 0.261)

# Training images get light augmentation (padded random crop + horizontal flip);
# test images are only converted to tensors and normalised.
transform_train = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(cifar10_mean, cifar10_std),
    ]
)
transform_test = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(cifar10_mean, cifar10_std),
    ]
)

full_train = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
full_test = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

# Subsample for speed: 5000 training and 1000 test images. The seed is set once
# and the two permutations are drawn in this order, so the selected indices are
# the same on every run.
np.random.seed(42)
train_idx = np.random.permutation(len(full_train))[:5000]
test_idx = np.random.permutation(len(full_test))[:1000]
train_dataset = Subset(full_train, train_idx)
test_dataset = Subset(full_test, test_idx)

loader_kwargs = dict(batch_size=batch_size, num_workers=4)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

# Forget and retain subsets of the sampled training set. The indices are
# positions inside train_dataset (0 .. len(train_dataset) - 1), which is what
# Subset(train_dataset, ...) expects.
# Labels are read straight from the dataset's target list. Iterating
# train_dataset instead would load, augment and tensorise every image (twice)
# just to look at its label.
_train_labels = np.asarray(full_train.targets)[train_idx]
_is_forget = np.isin(_train_labels, forget_classes)
forget_idx = np.flatnonzero(_is_forget).tolist()
retain_idx = np.flatnonzero(~_is_forget).tolist()
forget_loader = DataLoader(Subset(train_dataset, forget_idx), batch_size=batch_size, shuffle=True, num_workers=4)
retain_loader = DataLoader(Subset(train_dataset, retain_idx), batch_size=batch_size, shuffle=True, num_workers=4)

# ====================== Model & Utilities ======================

def get_model():
    """Build a fresh, randomly initialised ResNet-18 with a 10-way output, on ``device``.

    Every model in this script (initial, baseline and the reloaded checkpoints)
    comes from here, so all saved state dicts share one architecture.

    NOTE(review): newer torchvision releases deprecate ``pretrained=`` in favour
    of ``weights=None``. Confirm against the installed version. The stock
    ImageNet stem (large strided first conv + max-pool) also downsamples 32x32
    inputs aggressively; a CIFAR-style stem may train better. Verify before
    changing, because it alters checkpoint shapes.
    """
    net = models.resnet18(pretrained=False, num_classes=10)
    return net.to(device)

# Shared loss for training, evaluation and the Fisher estimate; averages over the batch.
criterion = nn.CrossEntropyLoss(reduction="mean")

# Train/eval routines

def train_one_epoch(model, loader, optimizer):
    """Run one optimisation pass of ``model`` over every batch in ``loader``.

    Each batch is moved to the module-level ``device``, scored with the
    module-level ``criterion`` and followed by one ``optimizer.step()``.

    Args:
        model: network to train; it is switched to train mode first.
        loader: iterable of ``(inputs, targets)`` batches.
        optimizer: optimizer over ``model``'s parameters.

    Returns:
        ``(mean_loss, accuracy)`` over all samples seen. Both are computed from
        each batch's outputs *before* that batch's parameter update.

    Raises:
        ValueError: if ``loader`` yields no samples, since neither average is
            defined then.
    """
    model.train()
    total_loss = 0.0
    correct = total = 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        out = model(x)
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()
        batch = y.size(0)
        # Re-weight the batch-mean loss by batch size so a short final batch does
        # not skew the per-sample average (assumes criterion uses
        # reduction="mean", the nn.CrossEntropyLoss default).
        total_loss += loss.item() * batch
        correct += out.argmax(1).eq(y).sum().item()
        total += batch
    if total == 0:
        raise ValueError("train_one_epoch: loader yielded no samples")
    return total_loss / total, correct / total


def evaluate(model, loader):
    """Score ``model`` on every batch of ``loader`` without updating it.

    The model is put in eval mode and gradients are disabled. Batches are moved
    to the module-level ``device`` and scored with the module-level
    ``criterion``.

    Args:
        model: network to score.
        loader: iterable of ``(inputs, targets)`` batches.

    Returns:
        ``(mean_loss, accuracy)`` averaged over all samples in ``loader``.

    Raises:
        ValueError: if ``loader`` yields no samples, since neither average is
            defined then.
    """
    model.eval()
    total_loss = 0.0
    correct = total = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            out = model(x)
            batch = y.size(0)
            # Re-weight the batch-mean loss by batch size so a short final batch
            # does not skew the per-sample average (assumes criterion uses
            # reduction="mean", the nn.CrossEntropyLoss default).
            total_loss += criterion(out, y).item() * batch
            correct += out.argmax(1).eq(y).sum().item()
            total += batch
    if total == 0:
        raise ValueError("evaluate: loader yielded no samples")
    return total_loss / total, correct / total

# ====================== 1) Initial Training ======================
# Train on the whole sampled training set (forget + retain classes) to obtain
# the model that is unlearned in step 2.
model = get_model()
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=5e-4)
# Decay the LR 10x halfway through. StepLR needs step_size >= 1 (it takes
# epoch % step_size), and num_epochs // 2 alone is 0 whenever num_epochs < 2.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=max(1, num_epochs // 2), gamma=0.1)
print("[1] Initial training...")
for epoch in range(1, num_epochs + 1):
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer)
    test_loss, test_acc = evaluate(model, test_loader)
    scheduler.step()
    print(f"Epoch {epoch}/{num_epochs}: Train acc {train_acc:.3f}, Test acc {test_acc:.3f}")
# Checkpoint the trained model before any unlearning is applied.
torch.save(model.state_dict(), pre_unlearn_path)

# ====================== 2) SSD Unlearning ======================
# Diagonal Fisher estimate on the forget set. For every trainable parameter it
# is the mean, over forget-set *batches*, of the squared gradient of that
# batch's loss.
print("[2] Computing Fisher on forget set...")
fisher = {n: torch.zeros_like(p) for n, p in model.named_parameters() if p.requires_grad}
samples = 0  # number of forget-set batches accumulated (not images)
# eval() so layers such as BatchNorm behave as at inference time while the
# gradients are collected.
model.eval()
for x, y in forget_loader:
    x, y = x.to(device), y.to(device)
    model.zero_grad()
    out = model(x)
    loss = criterion(out, y)
    loss.backward()
    samples += 1
    for n, p in model.named_parameters():
        if p.requires_grad and p.grad is not None:
            fisher[n] += p.grad.detach().pow(2)
# With an empty forget set the average below would be 0/0 = NaN for every entry,
# and the dampening step would then turn every weight into NaN. Fail loudly.
if samples == 0:
    raise ValueError(f"forget set is empty: no sampled training images with labels in {forget_classes}")
fisher = {n: v / samples for n, v in fisher.items()}
# Drop the gradients left over from the last forget-set batch.
model.zero_grad()
# Multi-step dampening. Each pass applies p <- p - ssd_lambda * F * p, i.e. it
# scales every weight elementwise by (1 - ssd_lambda * F). Since F >= 0 (a
# squared gradient), weights with larger forget-set Fisher values shrink more.
# NOTE(review): this shrinks every parameter in proportion to its forget-set
# Fisher value. The published SSD method instead selects parameters by
# comparing forget-set importance with full-training-set importance. Confirm
# the simplification is intended. Also, with ssd_lambda = 1e-4 the factor is
# very close to 1 unless the Fisher entries are large; check that this step
# actually changes the model.
grain_steps = 5
print(f"Applying {grain_steps} SSD steps...")
with torch.no_grad():
    for _step in range(grain_steps):
        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            param.sub_(ssd_lambda * fisher[name] * param)
# Recovery phase: fine-tune the dampened model on retain-class data only, with a
# fresh optimizer (no LR schedule).
opt2 = optim.SGD(
    model.parameters(),
    lr=learning_rate,
    momentum=0.9,
    weight_decay=5e-4,
)
print("Fine-tuning retained classes...")
for _ in range(3):
    acc = train_one_epoch(model, retain_loader, opt2)[1]
    print(f" Fine-tune acc: {acc:.3f}")
# Checkpoint the unlearned model, then report its accuracy on the test subset.
torch.save(model.state_dict(), post_unlearn_path)
post_loss, post_acc = evaluate(model, test_loader)
print(f"Post-unlearn test acc: {post_acc:.3f}")

# ====================== 3) Baseline Training (no forget data) ======================
# Reference model trained from scratch on retain-class samples only. It has never
# seen the forget classes, which is what the unlearned model is compared against.
baseline = get_model()
opt3 = optim.SGD(baseline.parameters(), lr=learning_rate, momentum=0.9, weight_decay=5e-4)
# StepLR needs step_size >= 1 (it takes epoch % step_size), and
# num_epochs // 2 alone is 0 whenever num_epochs < 2.
sched3 = optim.lr_scheduler.StepLR(opt3, step_size=max(1, num_epochs // 2), gamma=0.1)
print("[3] Baseline training on retained classes...")
for epoch in range(1, num_epochs + 1):
    _, acc = train_one_epoch(baseline, retain_loader, opt3)
    sched3.step()
    # Log every 5th epoch, and always the last one so the final accuracy is shown
    # even when num_epochs is not a multiple of 5.
    if epoch % 5 == 0 or epoch == num_epochs:
        print(f" Epoch {epoch}: Re-trained acc {acc:.3f}")
torch.save(baseline.state_dict(), baseline_path)
base_loss, base_acc = evaluate(baseline, test_loader)
print(f"Baseline test acc: {base_acc:.3f}")

# ====================== 4) Summary Comparison ======================
print("\n=== Summary ===")

# Split the sampled test set by label so unlearning can be judged per group: an
# unlearned model should score low on forget-class images while keeping its
# accuracy on retain-class images. Labels are read from full_test.targets rather
# than by iterating test_dataset, which would load and transform every image.
_test_labels = np.asarray(full_test.targets)[test_idx]
_test_is_forget = np.isin(_test_labels, forget_classes)
_test_group_loaders = {
    group: DataLoader(
        Subset(test_dataset, np.flatnonzero(mask).tolist()),
        batch_size=batch_size,
        shuffle=False,
        num_workers=4,
    )
    for group, mask in (("forget", _test_is_forget), ("retain", ~_test_is_forget))
}


def _load_checkpoint(path):
    """Return a fresh ``get_model()`` network loaded with the state dict saved at *path*."""
    net = get_model()
    # map_location keeps the reload working when the checkpoint was written on a
    # different device than the current one (e.g. saved on GPU, reloaded on CPU).
    net.load_state_dict(torch.load(path, map_location=device))
    return net


def _print_group_accuracies(net):
    """Print *net*'s accuracy on the forget-class and retain-class test samples."""
    parts = []
    for group, loader in _test_group_loaders.items():
        if len(loader.dataset) == 0:
            # evaluate() cannot average over zero samples.
            parts.append(f"{group} n/a")
            continue
        _, group_acc = evaluate(net, loader)
        parts.append(f"{group} {group_acc:.3f}")
    print("    by class group: " + ", ".join(parts))


# Reload every checkpoint from disk so all three models are scored the same way.
t1 = _load_checkpoint(pre_unlearn_path)
_, pre_acc = evaluate(t1, test_loader)
print(f"Pre-unlearning model   test acc: {pre_acc:.3f}")
_print_group_accuracies(t1)

t2 = _load_checkpoint(post_unlearn_path)
_, post_acc = evaluate(t2, test_loader)
print(f"Post-unlearning model  test acc: {post_acc:.3f}")
_print_group_accuracies(t2)

t3 = _load_checkpoint(baseline_path)
_, base_acc = evaluate(t3, test_loader)
print(f"Retrained baseline model test acc: {base_acc:.3f}")
_print_group_accuracies(t3)


[1] Initial training...




Epoch 1/20: Train acc 0.204, Test acc 0.217
Epoch 2/20: Train acc 0.288, Test acc 0.318
Epoch 3/20: Train acc 0.337, Test acc 0.348
Epoch 4/20: Train acc 0.360, Test acc 0.392
Epoch 5/20: Train acc 0.384, Test acc 0.384
Epoch 6/20: Train acc 0.420, Test acc 0.421
Epoch 7/20: Train acc 0.427, Test acc 0.412
Epoch 8/20: Train acc 0.440, Test acc 0.432
Epoch 9/20: Train acc 0.469, Test acc 0.407
Epoch 10/20: Train acc 0.475, Test acc 0.462
Epoch 11/20: Train acc 0.537, Test acc 0.520
Epoch 12/20: Train acc 0.555, Test acc 0.541
Epoch 13/20: Train acc 0.551, Test acc 0.525
Epoch 14/20: Train acc 0.560, Test acc 0.548
Epoch 15/20: Train acc 0.566, Test acc 0.529
Epoch 16/20: Train acc 0.568, Test acc 0.534
Epoch 17/20: Train acc 0.575, Test acc 0.552
Epoch 18/20: Train acc 0.578, Test acc 0.542
Epoch 19/20: Train acc 0.578, Test acc 0.548
Epoch 20/20: Train acc 0.582, Test acc 0.553
[2] Computing Fisher on forget set...
Applying 5 SSD steps...
Fine-tuning retained classes...
 Fine-tune acc:

: 