In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, Subset
import numpy as np


In [2]:

# ====================== Configuration ======================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 128
learning_rate = 0.1
num_epochs = 5             # epochs to train initial model
ssd_lambda = 1e-3           # strength of synaptic dampening
forget_classes = [0, 1]     # CIFAR-10 classes to forget (e.g., 0: airplane, 1: automobile)


In [3]:

# ====================== Data Preparation ======================
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
])


In [4]:
full_train = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
full_test = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

# Subsample 5000 train images and 1000 test images
np.random.seed(42)
train_idx = np.random.permutation(len(full_train))[:5000]
test_idx = np.random.permutation(len(full_test))[:1000]

train_dataset = Subset(full_train, train_idx)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)

test_dataset = Subset(full_test, test_idx)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

# Create forget subset within the small training set
forget_indices = [i for i, (_, t) in enumerate(train_dataset) if t in forget_classes]
forget_subset = Subset(train_dataset, forget_indices)
forget_loader = DataLoader(forget_subset, batch_size=batch_size, shuffle=True, num_workers=4)

In [5]:
# ==================== Model, Loss, Optimizer ======================
model = models.resnet18(pretrained=False, num_classes=10).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=5e-4)




In [6]:

# ====================== Training Function ======================
def train(model, loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * inputs.size(0)
        _, preds = outputs.max(1)
        correct += preds.eq(targets).sum().item()
        total += targets.size(0)
    return running_loss / total, correct / total


In [7]:

# ====================== Evaluation Function ======================
def evaluate(model, loader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            running_loss += loss.item() * inputs.size(0)
            _, preds = outputs.max(1)
            correct += preds.eq(targets).sum().item()
            total += targets.size(0)
    return running_loss / total, correct / total


In [8]:

# ====================== Initial Training ======================
print("Starting initial training on all images...")
for epoch in range(num_epochs):
    train_loss, train_acc = train(model, train_loader, criterion, optimizer, device)
    test_loss, test_acc = evaluate(model, test_loader, criterion, device)
    print(f"Epoch {epoch+1}/{num_epochs} - Train loss: {train_loss:.4f}, "
          f"Train acc: {train_acc:.4f}, Test acc: {test_acc:.4f}")


Starting initial training on all images...


Epoch 1/5 - Train loss: 4.0945, Train acc: 0.1292, Test acc: 0.1170
Epoch 2/5 - Train loss: 3.2387, Train acc: 0.1404, Test acc: 0.1670
Epoch 3/5 - Train loss: 2.7102, Train acc: 0.1752, Test acc: 0.2050
Epoch 4/5 - Train loss: 2.4321, Train acc: 0.1596, Test acc: 0.1830
Epoch 5/5 - Train loss: 2.3879, Train acc: 0.1562, Test acc: 0.1690


In [9]:

# ====================== Compute Fisher Information ======================
def compute_fisher(model, loader, criterion, device):
    model.eval()
    fisher = {name: torch.zeros_like(param) for name, param in model.named_parameters() if param.requires_grad}
    num_samples = 0

    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        model.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        num_samples += 1
        for name, param in model.named_parameters():
            if param.requires_grad and param.grad is not None:
                fisher[name] += param.grad.data.clone().pow(2)

    for name in fisher:
        fisher[name] /= num_samples
    return fisher

print("Computing Fisher Information on forget subset...")
fisher_info = compute_fisher(model, forget_loader, criterion, device)


Computing Fisher Information on forget subset...


In [10]:

# ====================== Selective Synaptic Dampening ======================
def selective_synaptic_dampening(model, fisher_info, lmbda):
    with torch.no_grad():
        for name, param in model.named_parameters():
            if param.requires_grad:
                param.sub_(lmbda * fisher_info[name] * param)
    print("Applied SSD unlearning step.")

print("Performing SSD unlearning...")
selective_synaptic_dampening(model, fisher_info, ssd_lambda)

# ====================== Evaluate After Unlearning ======================
post_loss, post_acc = evaluate(model, test_loader, criterion, device)
print(f"After SSD unlearning - Test loss: {post_loss:.4f}, Test acc: {post_acc:.4f}")


Performing SSD unlearning...
Applied SSD unlearning step.
After SSD unlearning - Test loss: 2.1990, Test acc: 0.1690
