In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import time
import copy

# --- Hyperparameters ---
batch_size = 128
learning_rate = 0.01
num_epochs = 20
ssd_lambda = 0.0001
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
forget_classes = [0, 1]  # example classes to forget

# --- 1. Load CIFAR-10 ---
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
])

train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False)

# --- 2. Load pretrained ResNet18 ---
resnet = torchvision.models.resnet18(weights='IMAGENET1K_V1')
resnet.fc = nn.Linear(resnet.fc.in_features, 10)  # replace final layer
resnet = resnet.to(device)

# Freeze feature extractor
for param in resnet.parameters():
    param.requires_grad = False
for param in resnet.fc.parameters():
    param.requires_grad = True

# --- 3. Train only classifier ---
optimizer = optim.SGD(resnet.fc.parameters(), lr=learning_rate, momentum=0.9)
criterion = nn.CrossEntropyLoss()

print("[1] Initial training of classifier on FULL dataset...")
best_model = copy.deepcopy(resnet.state_dict())

for epoch in range(num_epochs):
    resnet.train()
    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = resnet(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

    # Eval
    resnet.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = resnet(inputs)
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    acc = correct / total
    print(f"Epoch {epoch+1}/{num_epochs}: Test acc {acc:.3f}")

torch.save(best_model, 'resnet18_pre_unlearn_fc.pth')
print("Saved pre-unlearn model.")

# --- 4. Fisher Information for fc layer only ---
def compute_fisher(model, data_loader):
    fisher = torch.zeros_like(model.fc.weight)
    model.eval()
    for inputs, targets in data_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        fisher += model.fc.weight.grad.data.pow(2)
        model.zero_grad()
    fisher /= len(data_loader)
    return fisher

print("[2] SSD unlearning on classifier...")
start_ssd = time.time()

fisher = compute_fisher(resnet, train_loader)

# Apply SSD steps
for _ in range(5):
    with torch.no_grad():
        resnet.fc.weight -= ssd_lambda * fisher * resnet.fc.weight

ssd_time = time.time() - start_ssd
print(f"SSD time: {ssd_time:.2f} seconds")

# --- 5. Fine-tune fc on retained classes ---
def is_retained(targets):
    mask = torch.ones_like(targets, dtype=torch.bool)
    for c in forget_classes:
        mask &= (targets != c)
    return mask

optimizer2 = optim.SGD(resnet.fc.parameters(), lr=learning_rate, momentum=0.9)

for fine_epoch in range(3):
    resnet.train()
    for inputs, targets in train_loader:
        mask = is_retained(targets)
        if mask.sum() == 0:
            continue
        inputs, targets = inputs[mask], targets[mask]
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer2.zero_grad()
        outputs = resnet(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer2.step()

# Post-unlearning test acc
resnet.eval()
correct = 0
total = 0
with torch.no_grad():
    for inputs, targets in test_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = resnet(inputs)
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
post_unlearn_acc = correct / total

print(f"Post-unlearning test acc: {post_unlearn_acc:.3f}")

# --- 6. Retrain classifier baseline ---
print("[3] Baseline retraining...")

retrain_model = torchvision.models.resnet18(weights='IMAGENET1K_V1')
retrain_model.fc = nn.Linear(retrain_model.fc.in_features, 10)
retrain_model = retrain_model.to(device)

for param in retrain_model.parameters():
    param.requires_grad = False
for param in retrain_model.fc.parameters():
    param.requires_grad = True

optimizer3 = optim.SGD(retrain_model.fc.parameters(), lr=learning_rate, momentum=0.9)

start_retrain = time.time()

for epoch in range(num_epochs):
    retrain_model.train()
    for inputs, targets in train_loader:
        mask = is_retained(targets)
        if mask.sum() == 0:
            continue
        inputs, targets = inputs[mask], targets[mask]
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer3.zero_grad()
        outputs = retrain_model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer3.step()

retrain_time = time.time() - start_retrain

# Evaluate retrained
retrain_model.eval()
correct = 0
total = 0
with torch.no_grad():
    for inputs, targets in test_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = retrain_model(inputs)
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
retrain_acc = correct / total

print(f"Retrained test acc: {retrain_acc:.3f}")
print(f"Retrain time: {retrain_time:.2f} seconds")

# --- 7. Final Summary ---
print("\n=== Summary ===")
print(f"Pre-unlearning model   test acc: {acc:.3f}")
print(f"Post-unlearning model  test acc: {post_unlearn_acc:.3f}")
print(f"Retrained baseline model test acc: {retrain_acc:.3f}")
print(f"SSD time: {ssd_time:.2f} sec | Retrain time: {retrain_time:.2f} sec")


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /home/sameer/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:03<00:00, 12.7MB/s]


[1] Initial training of classifier on FULL dataset...
Epoch 1/20: Test acc 0.792


KeyboardInterrupt: 