In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, Subset
import numpy as np

# ====================== Configuration ======================
# Prefer the GPU when PyTorch can see one; otherwise everything runs on CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Optimisation hyper-parameters shared by every training stage below.
batch_size = 128
learning_rate = 0.01  # starting learning rate for every SGD optimizer
num_epochs = 20  # epochs for initial training and for distillation

# Unlearning hyper-parameters.
ssd_lambda = 1e-4  # multiplier on (Fisher * weight) in each dampening step
forget_classes = [0, 1]  # label ids whose training samples form the forget set

# Checkpoint files written / read by the stages below.
pre_unlearn_path = "resnet18_pre_unlearn.pth"
post_unlearn_path = "resnet18_post_unlearn.pth"
baseline_path = "resnet18_baseline.pth"
teacher_path = "resnet18_teacher.pth"

# ====================== Data Preparation ======================
# Per-channel mean / std handed to Normalize in both pipelines.
norm_mean = (0.4914, 0.4822, 0.4465)
norm_std = (0.247, 0.243, 0.261)

# Training pipeline: padded random crop and horizontal flip, then tensor + normalize.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])
# Test pipeline: no augmentation, only tensor conversion + normalize.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

data_root = './data'
full_train = datasets.CIFAR10(root=data_root, train=True, download=True, transform=transform_train)
full_test = datasets.CIFAR10(root=data_root, train=False, download=True, transform=transform_test)

# Subsample speed-up: 5000 train, 1000 test.
# Both permutations are drawn from the same seeded NumPy stream, train first,
# so the chosen indices are reproducible across runs.
np.random.seed(42)
shuffled_train = np.random.permutation(len(full_train))
shuffled_test = np.random.permutation(len(full_test))
train_idx = shuffled_train[:5000]
test_idx = shuffled_test[:1000]
train_dataset = Subset(full_train, train_idx)
test_dataset = Subset(full_test, test_idx)

# Only the training loader reshuffles between epochs.
loader_opts = {"batch_size": batch_size, "num_workers": 4}
train_loader = DataLoader(train_dataset, shuffle=True, **loader_opts)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_opts)

# Forget and retain subsets from sampled train
# Labels are read straight from the underlying dataset's `targets` list.
# Iterating `train_dataset` instead would load and randomly augment every
# sampled image (once per list below) only to look at its label.
forget_class_set = set(forget_classes)
train_labels = [full_train.targets[i] for i in train_idx]
# Indices are positions inside `train_dataset` (0 .. len(train_dataset) - 1),
# not indices into `full_train`, because the Subsets below wrap `train_dataset`.
forget_idx = [pos for pos, label in enumerate(train_labels) if label in forget_class_set]
retain_idx = [pos for pos, label in enumerate(train_labels) if label not in forget_class_set]
forget_loader = DataLoader(Subset(train_dataset, forget_idx), batch_size=batch_size, shuffle=True, num_workers=4)
retain_loader = DataLoader(Subset(train_dataset, retain_idx), batch_size=batch_size, shuffle=True, num_workers=4)

# ====================== Model & Utilities ======================

def get_model():
    """Build a randomly initialised ResNet-18 with a 10-class head on `device`.

    Returns:
        A new `torchvision.models.resnet18` instance moved to the global `device`.
    """
    # The `pretrained` flag is deprecated in torchvision >= 0.13 and triggers a
    # warning even when False. Leaving it out selects random initialisation on
    # both old and new torchvision releases.
    return models.resnet18(num_classes=10).to(device)


# Cross-entropy with the default (batch-mean) reduction, shared by the
# training and evaluation routines.
criterion = nn.CrossEntropyLoss()

# Train/eval routines

def train_one_epoch(model, loader, optimizer):
    """Run one optimisation pass over `loader` and report epoch statistics.

    Args:
        model: network to train; switched to train mode here.
        loader: iterable yielding `(inputs, targets)` batches.
        optimizer: optimizer stepping `model`'s parameters.

    Returns:
        Tuple `(mean_loss, accuracy)`: the per-sample average of the global
        `criterion` and the fraction of correct argmax predictions.

    Raises:
        ZeroDivisionError: if `loader` yields no samples.
    """
    model.train()
    loss_sum = 0
    num_correct = 0
    num_seen = 0
    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = criterion(logits, targets)
        batch_loss.backward()
        optimizer.step()

        # Weight the batch loss by the batch size so that a smaller final
        # batch does not skew the epoch average.
        loss_sum += batch_loss.item() * inputs.size(0)
        num_correct += (logits.argmax(dim=1) == targets).sum().item()
        num_seen += targets.size(0)
    return loss_sum / num_seen, num_correct / num_seen


@torch.no_grad()
def evaluate(model, loader):
    """Score `model` on `loader` without tracking gradients.

    Args:
        model: network to evaluate; switched to eval mode here.
        loader: iterable yielding `(inputs, targets)` batches.

    Returns:
        Tuple `(mean_loss, accuracy)`: the per-sample average of the global
        `criterion` and the fraction of correct argmax predictions.

    Raises:
        ZeroDivisionError: if `loader` yields no samples.
    """
    model.eval()
    loss_sum = 0
    num_correct = 0
    num_seen = 0
    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        logits = model(inputs)
        batch_loss = criterion(logits, targets)
        # Weight the batch loss by the batch size so the result is a true per-sample mean.
        loss_sum += batch_loss.item() * inputs.size(0)
        num_correct += (logits.argmax(dim=1) == targets).sum().item()
        num_seen += targets.size(0)
    return loss_sum / num_seen, num_correct / num_seen

# ====================== 1) Initial Training ======================
# Train a fresh model on the sampled training set; this checkpoint is the
# starting point for the unlearning stages that follow.
model = get_model()
optimizer = optim.SGD(
    model.parameters(),
    lr=learning_rate,
    momentum=0.9,
    weight_decay=5e-4,
)
# Cut the learning rate by 10x once half of the epochs have run.
lr_drop_epoch = num_epochs // 2
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=lr_drop_epoch, gamma=0.1)

print("[1] Initial training...")
for epoch in range(1, num_epochs + 1):
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer)
    test_loss, test_acc = evaluate(model, test_loader)
    # Advance the schedule only after this epoch's optimizer steps are done.
    scheduler.step()
    print(f"Epoch {epoch}/{num_epochs}: Train acc {train_acc:.3f}, Test acc {test_acc:.3f}")

# Persist the weights from before any unlearning is applied.
torch.save(model.state_dict(), pre_unlearn_path)

# ====================== 2) SSD Unlearning ======================
# Three steps: (a) estimate per-parameter importance on the forget set,
# (b) shrink every weight in proportion to that importance, (c) fine-tune on
# the retain set to recover accuracy.
# NOTE(review): the dampening below is applied to all parameters and never
# compares against retain-set importance -- confirm whether the selective
# variant of SSD was intended.

# compute Fisher
print("[2] Computing Fisher on forget set...")
fisher = {n: torch.zeros_like(p) for n, p in model.named_parameters() if p.requires_grad}
num_batches = 0
model.eval()
for x, y in forget_loader:
    x, y = x.to(device), y.to(device)
    model.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    num_batches += 1
    # Accumulates the squared gradient of the *batch-mean* loss, i.e. a
    # batch-level approximation rather than a per-sample Fisher estimate.
    for n, p in model.named_parameters():
        if p.requires_grad and p.grad is not None:
            fisher[n] += p.grad.detach().pow(2)
if num_batches:
    fisher = {n: v / num_batches for n, v in fisher.items()}
else:
    # Nothing to average: keep the zero Fisher so the dampening below is a
    # no-op instead of failing with a division by zero.
    print("Warning: forget loader yielded no batches; SSD dampening will have no effect.")

# apply multi-step SSD
grain_steps = 5
print(f"Applying {grain_steps} SSD steps...")
for _ in range(grain_steps):
    with torch.no_grad():
        for n, p in model.named_parameters():
            if p.requires_grad:
                # Each step scales the weight by (1 - ssd_lambda * fisher).
                p.sub_(ssd_lambda * fisher[n] * p)

# fine-tune on retain set
# A new optimizer is created here, so no momentum carries over from initial training.
opt2 = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=5e-4)
print("Fine-tuning retained classes...")
for _ in range(3):
    _, acc = train_one_epoch(model, retain_loader, opt2)
    print(f" Fine-tune acc: {acc:.3f}")

# save post-unlearn
torch.save(model.state_dict(), post_unlearn_path)
post_loss, post_acc = evaluate(model, test_loader)
print(f"Post-unlearn test acc: {post_acc:.3f}")

# ====================== 3) Incompetent Teacher Unlearning ======================
# Load Teacher Model (Pre-unlearned model)
teacher_model = get_model()
# map_location keeps the load working when the checkpoint was written on a
# different device than the one in use now.
teacher_model.load_state_dict(torch.load(pre_unlearn_path, map_location=device))
teacher_model.eval()
# The teacher is only a fixed source of targets. Freezing its weights stops the
# student's backward pass from computing and accumulating gradients for them.
for teacher_param in teacher_model.parameters():
    teacher_param.requires_grad_(False)
print("[3] Incompetent Teacher unlearning...")

# Define student model (to be trained with distillation)
student_model = get_model()
optimizer_student = optim.SGD(student_model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=5e-4)

def distillation_loss(student_output, teacher_output, temperature=3.0, alpha=0.7, labels=None):
    """Blend a temperature-softened KL term with a hard cross-entropy term.

    Args:
        student_output: student logits, classes along dim 1.
        teacher_output: teacher logits with the same layout as `student_output`.
        temperature: divisor applied to both sets of logits before the softmax.
        alpha: weight of the soft (KL) term; the hard term gets `1 - alpha`.
        labels: optional ground-truth class indices for the hard term. When
            omitted, the student's own argmax is used as the label (the legacy
            behaviour), which rewards confidence rather than correctness.

    Returns:
        Scalar loss tensor:
        `alpha * temperature**2 * KL(teacher || student) + (1 - alpha) * CE`.
    """
    functional = nn.functional
    soft_target = functional.softmax(teacher_output / temperature, dim=1)
    # log_softmax avoids log(0) = -inf when a softened student probability
    # underflows to zero.
    student_log_probs = functional.log_softmax(student_output / temperature, dim=1)
    soft_loss = functional.kl_div(student_log_probs, soft_target, reduction='batchmean')
    if labels is None:
        labels = student_output.argmax(1)
    hard_loss = functional.cross_entropy(student_output, labels)
    return soft_loss * (temperature ** 2) * alpha + hard_loss * (1. - alpha)

# Train student model on retained data
# NOTE(review): only the retain set is used in this stage and the teacher is
# the pre-unlearning model -- confirm this matches the intended
# "incompetent teacher" scheme.
for epoch in range(1, num_epochs + 1):
    student_model.train()
    total_loss = 0
    seen = 0
    for x, y in retain_loader:
        x, y = x.to(device), y.to(device)
        optimizer_student.zero_grad()
        student_out = student_model(x)
        # The teacher is a fixed reference, so no autograd graph is needed for it.
        with torch.no_grad():
            teacher_out = teacher_model(x)
        # Supervise the hard term with the true labels, not the student's own guesses.
        loss = distillation_loss(student_out, teacher_out, labels=y)
        loss.backward()
        optimizer_student.step()
        total_loss += loss.item() * x.size(0)
        seen += x.size(0)
    # total_loss is a sample-weighted sum, so normalise by samples, not batches.
    print(f"Epoch {epoch}/{num_epochs} - Student training loss: {total_loss / max(seen, 1):.4f}")

# Save the student model
# Note: the student's weights are written to `teacher_path`.
torch.save(student_model.state_dict(), teacher_path)
student_loss, student_acc = evaluate(student_model, test_loader)
print(f"Student model test acc after distillation: {student_acc:.3f}")

# ====================== 4) Summary Comparison ======================
def _summarize_checkpoint(prefix, path):
    """Evaluate the checkpoint at `path` on the test set and print one summary line.

    Args:
        prefix: text printed before the accuracy, without the trailing colon.
        path: file holding a `state_dict` compatible with `get_model()`.

    Returns:
        Test accuracy as a fraction, or None when the file does not exist.
    """
    try:
        state = torch.load(path, map_location=device)
    except FileNotFoundError:
        # A missing checkpoint should not abort the rest of the summary.
        print(f"{prefix}: n/a (checkpoint '{path}' not found)")
        return None
    net = get_model()
    net.load_state_dict(state)
    _, acc = evaluate(net, test_loader)
    print(f"{prefix}: {acc:.3f}")
    return acc


print("\n=== Summary ===")
# reload to ensure fairness
pre_acc = _summarize_checkpoint("Pre-unlearning model   test acc", pre_unlearn_path)
post_acc = _summarize_checkpoint("Post-unlearning model  test acc", post_unlearn_path)
# NOTE(review): nothing visible in this file writes `baseline_path`; the line
# below reports "n/a" unless a retrained baseline was saved elsewhere.
base_acc = _summarize_checkpoint("Retrained baseline model test acc", baseline_path)
student_acc = _summarize_checkpoint("Student (Incompetent Teacher) model test acc", teacher_path)


[1] Initial training...




Epoch 1/20: Train acc 0.218, Test acc 0.247
Epoch 2/20: Train acc 0.283, Test acc 0.323
Epoch 3/20: Train acc 0.310, Test acc 0.346
Epoch 4/20: Train acc 0.361, Test acc 0.412
Epoch 5/20: Train acc 0.391, Test acc 0.357
Epoch 6/20: Train acc 0.392, Test acc 0.449
Epoch 7/20: Train acc 0.426, Test acc 0.404
Epoch 8/20: Train acc 0.436, Test acc 0.433
Epoch 9/20: Train acc 0.469, Test acc 0.437
Epoch 10/20: Train acc 0.481, Test acc 0.451
Epoch 11/20: Train acc 0.529, Test acc 0.495
Epoch 12/20: Train acc 0.544, Test acc 0.507
Epoch 13/20: Train acc 0.557, Test acc 0.513
Epoch 14/20: Train acc 0.557, Test acc 0.518
Epoch 15/20: Train acc 0.573, Test acc 0.521
Epoch 16/20: Train acc 0.576, Test acc 0.527
Epoch 17/20: Train acc 0.577, Test acc 0.528
Epoch 18/20: Train acc 0.587, Test acc 0.524
Epoch 19/20: Train acc 0.588, Test acc 0.530
Epoch 20/20: Train acc 0.589, Test acc 0.535
[2] Computing Fisher on forget set...
Applying 5 SSD steps...
Fine-tuning retained classes...
 Fine-tune acc: