In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, Subset
import numpy as np

# ====================== Configuration ======================
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# --- optimisation ---
batch_size = 128
learning_rate = 0.01  # reduced initial LR for stability
num_epochs = 20  # extend training for better convergence

# --- unlearning ---
ssd_lambda = 1e-4  # smaller dampening strength
forget_classes = [0, 1]  # CIFAR-10 classes to forget

# ====================== Data Preparation ======================
# Training images get random crop + horizontal flip augmentation before
# normalisation; test images are only converted and normalised.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
])

full_train = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
full_test = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

# Subsample 5000 train and 1000 test images for speed.
# Only NumPy's RNG is seeded here, so just the choice of subset is fixed.
np.random.seed(42)
train_idx = np.random.permutation(len(full_train))[:5000]
test_idx  = np.random.permutation(len(full_test))[:1000]
train_dataset = Subset(full_train, train_idx)
test_dataset  = Subset(full_test,  test_idx)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
test_loader  = DataLoader(test_dataset,  batch_size=batch_size, shuffle=False, num_workers=4)

# Create forget subset.
# Labels are looked up in the dataset's `targets` list rather than by indexing
# `train_dataset`: indexing would load and randomly augment every image just
# to throw the pixels away. The stored values are positions *within*
# `train_dataset`, which is what the nested `Subset` below expects.
forget_indices = [
    pos for pos, src in enumerate(train_idx)
    if full_train.targets[src] in forget_classes
]
forget_loader  = DataLoader(Subset(train_dataset, forget_indices), batch_size=batch_size, shuffle=True, num_workers=4)

# ====================== Model, Loss, Optimizer ======================
# ResNet-18 with a 10-way output layer and no pretrained weights.
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour
# of `weights=` -- confirm against the installed version before changing it.
model = models.resnet18(pretrained=False, num_classes=10)
model = model.to(device)

criterion = nn.CrossEntropyLoss()

optimizer = optim.SGD(
    model.parameters(),
    lr=learning_rate,
    momentum=0.9,
    weight_decay=5e-4,
)

# LR scheduler to decay LR midway: scale it by 0.1 after every
# `num_epochs // 2` calls to `scheduler.step()`.
scheduler = optim.lr_scheduler.StepLR(
    optimizer,
    step_size=num_epochs // 2,
    gamma=0.1,
)

# ====================== Train & Eval Functions ======================
def train_one_epoch(model, loader):
    """Run one optimisation pass over `loader` and report running metrics.

    Uses the module-level `device`, `criterion` and `optimizer`.

    Args:
        model: Network to train; put into training mode.
        loader: Iterable of `(inputs, targets)` batches.

    Returns:
        Tuple `(loss, accuracy)`: the batch losses weighted by batch size and
        divided by the number of targets seen, and the fraction of argmax
        predictions that matched their target.
    """
    model.train()
    running_loss, n_correct, n_seen = 0, 0, 0

    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = criterion(logits, targets)
        batch_loss.backward()
        optimizer.step()

        # Weight each batch loss by its batch size before summing.
        running_loss += batch_loss.item() * inputs.size(0)
        n_correct += (logits.argmax(1) == targets).sum().item()
        n_seen += targets.size(0)

    return running_loss / n_seen, n_correct / n_seen


def evaluate(model, loader):
    """Measure loss and accuracy of `model` on `loader` without training.

    Uses the module-level `device` and `criterion`; gradients are disabled
    for the whole pass.

    Args:
        model: Network to score; put into eval mode.
        loader: Iterable of `(inputs, targets)` batches.

    Returns:
        Tuple `(loss, accuracy)` computed the same way as in training: the
        batch losses weighted by batch size over the number of targets, and
        the fraction of argmax predictions equal to the target.
    """
    model.eval()
    loss_sum, hits, seen = 0, 0, 0

    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            logits = model(inputs)
            loss_sum += criterion(logits, targets).item() * inputs.size(0)
            hits += (logits.argmax(1) == targets).sum().item()
            seen += targets.size(0)

    return loss_sum / seen, hits / seen

# ====================== Initial Training ======================
# Each epoch: one training pass, one test pass, then one scheduler step.
print(f"Training for {num_epochs} epochs | init LR={learning_rate}")
for epoch in range(1, num_epochs + 1):
    train_loss, train_acc = train_one_epoch(model, train_loader)
    test_loss, test_acc = evaluate(model, test_loader)

    train_part = f"Epoch {epoch}/{num_epochs}: Train loss {train_loss:.4f}, acc {train_acc:.4f}"
    test_part = f"Test loss {test_loss:.4f}, acc {test_acc:.4f}"
    print(" | ".join((train_part, test_part)))

    # Step after the epoch's updates, then report the LR the next epoch uses.
    scheduler.step()
    current_lr = scheduler.get_last_lr()[0]
    print(f" LR now: {current_lr:.5f}")

# ====================== Compute Fisher Information ======================
def compute_fisher(model, loader, loss_fn=None):
    """Estimate a diagonal Fisher-information score for each trainable parameter.

    For every batch the loss is back-propagated and the element-wise square
    of each parameter's gradient is accumulated; the sums are then divided by
    the number of batches. Because the gradient comes from one loss value per
    batch, this is a batch-level estimate rather than a per-sample one.

    Args:
        model: Module whose trainable parameters are scored. It is switched
            to eval mode and left there.
        loader: Iterable yielding `(inputs, targets)` batches.
        loss_fn: Optional callable `(outputs, targets) -> scalar loss`. When
            omitted, the module-level `criterion` is used.

    Returns:
        Dict mapping each trainable parameter's name to a tensor of the same
        shape as that parameter. If `loader` yields no batches, every tensor
        is zero (instead of the NaNs a 0/0 average would produce).
    """
    if loss_fn is None:
        loss_fn = criterion

    model.eval()
    trainable = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
    fisher = {n: torch.zeros_like(p) for n, p in trainable}

    # Move batches to wherever the model's weights live, so the function does
    # not depend on a global device matching the model.
    first_param = next(model.parameters(), None)
    target_device = None if first_param is None else first_param.device

    num_batches = 0
    for x, y in loader:
        if target_device is not None:
            x, y = x.to(target_device), y.to(target_device)
        model.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        num_batches += 1
        for n, p in trainable:
            if p.grad is not None:
                fisher[n] += p.grad.detach().pow(2)

    # Nothing was accumulated: return the zeros rather than dividing by zero.
    if num_batches == 0:
        return fisher
    return {n: v / num_batches for n, v in fisher.items()}

# Score every trainable parameter using only the batches of the forget set.
print("Computing Fisher on forget set...")
fisher_info = compute_fisher(model=model, loader=forget_loader)

# ====================== Selective Synaptic Dampening ======================
def selective_synaptic_dampening(model, fisher_info, lmbda):
    """Shrink trainable parameters in place in proportion to their Fisher score.

    Each trainable parameter `p` is updated as
    `p <- p - min(lmbda * F, 1) * p`, where `F` is the matching entry of
    `fisher_info`. Parameters with `requires_grad=False` are left untouched.

    Args:
        model: Module whose parameters are modified in place.
        fisher_info: Dict mapping parameter name to a tensor that broadcasts
            against that parameter.
        lmbda: Dampening strength multiplied into each Fisher value.

    Raises:
        KeyError: If `fisher_info` has no entry for a trainable parameter.
    """
    with torch.no_grad():
        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            # Cap the removed fraction at 1: a weight can at most be zeroed.
            # Uncapped, lmbda * F > 1 would flip the weight's sign and could
            # leave it larger in magnitude than before.
            shrink = (lmbda * fisher_info[name]).clamp(max=1.0)
            param.sub_(shrink * param)
    print("SSD step applied.")

# Multi-step SSD and fine-tuning
# The same `fisher_info` is reused for every step; it is not recomputed
# between steps.
grain_steps = 5
print(f"Performing {grain_steps} SSD steps at λ={ssd_lambda}")
step = 0
while step < grain_steps:
    selective_synaptic_dampening(model, fisher_info, lmbda=ssd_lambda)
    step += 1

# ====================== Post-Unlearning Fine-Tuning ======================
# Keep every training sample whose class is not being forgotten. Labels come
# from the dataset's `targets` list rather than from indexing `train_dataset`,
# which would load and randomly augment each image only to discard it. The
# stored values are positions within `train_dataset`, as `Subset` expects.
retain_idx = [
    pos for pos, src in enumerate(train_idx)
    if full_train.targets[src] not in forget_classes
]
retain_loader = DataLoader(Subset(train_dataset, retain_idx), batch_size=batch_size, shuffle=True, num_workers=4)
print("Fine-tuning on retained classes...")
# NOTE(review): train_one_epoch steps the module-level `optimizer`, so these
# epochs continue from the LR and momentum state left by the initial training
# -- confirm that is intended.
for ft in range(3):
    loss_ft, acc_ft = train_one_epoch(model, retain_loader)
    print(f" Fine-tune epoch {ft+1}: loss {loss_ft:.4f}, acc {acc_ft:.4f}")

# ====================== Final Evaluation ======================
# Score the dampened and fine-tuned model on the full test subset (all classes).
final_metrics = evaluate(model, test_loader)
final_loss, final_acc = final_metrics
print(f"Final Test loss: {final_loss:.4f}, acc: {final_acc:.4f}")




Training for 20 epochs | init LR=0.01
Epoch 1/20: Train loss 2.2486, acc 0.2008 | Test loss 2.3602, acc 0.2140
 LR now: 0.01000
Epoch 2/20: Train loss 2.0202, acc 0.2760 | Test loss 1.9485, acc 0.3140
 LR now: 0.01000
Epoch 3/20: Train loss 1.8255, acc 0.3370 | Test loss 1.8377, acc 0.3340
 LR now: 0.01000
Epoch 4/20: Train loss 1.8017, acc 0.3474 | Test loss 1.7562, acc 0.3820
 LR now: 0.01000
Epoch 5/20: Train loss 1.6966, acc 0.3894 | Test loss 1.6347, acc 0.3970
 LR now: 0.01000
Epoch 6/20: Train loss 1.6176, acc 0.4098 | Test loss 1.5772, acc 0.4330
 LR now: 0.01000
Epoch 7/20: Train loss 1.5467, acc 0.4324 | Test loss 1.6350, acc 0.4230
 LR now: 0.01000
Epoch 8/20: Train loss 1.5151, acc 0.4362 | Test loss 1.5792, acc 0.4220
 LR now: 0.01000
Epoch 9/20: Train loss 1.4976, acc 0.4486 | Test loss 1.7135, acc 0.4390
 LR now: 0.01000
Epoch 10/20: Train loss 1.4775, acc 0.4658 | Test loss 1.5977, acc 0.4350
 LR now: 0.00100
Epoch 11/20: Train loss 1.3313, acc 0.5268 | Test loss 1.3829