In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.models import resnet18
import time

# ============================== DATA PREPROCESSING ==============================
# Turn each image into a tensor, then shift/scale every RGB channel with
# mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# CIFAR-10 train and test splits (fetched into ./data when not already there).
train_set, test_set = (
    datasets.CIFAR10(root='./data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

# Batched loaders: only the training split is shuffled; both use 4 worker
# processes and pinned host memory.
train_loader, test_loader = (
    DataLoader(split, batch_size=128, shuffle=reshuffle, num_workers=4, pin_memory=True)
    for split, reshuffle in ((train_set, True), (test_set, False))
)

# ============================ MODEL DEFINITION ============================
# ResNet-18 without pretrained weights, adapted to small CIFAR-10 images:
# a 3x3 stride-1 stem convolution, no max-pooling, and a 10-way classifier.
model = resnet18(weights=None)
model.conv1 = nn.Conv2d(
    in_channels=3,
    out_channels=64,
    kernel_size=3,
    stride=1,
    padding=1,
    bias=False,
)
model.maxpool = nn.Identity()
model.fc = nn.Linear(in_features=512, out_features=10)

# Use the GPU when one is available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
model.to(device)

# =========================== TRAINING FUNCTION ===========================
def train(model, train_loader, criterion, optimizer, num_epochs=10):
    """Train ``model`` in place and print the mean loss and accuracy per epoch.

    Args:
        model: Network to optimise; it is switched to training mode.
        train_loader: Iterable yielding ``(inputs, labels)`` batches. It is
            iterated once per epoch and does not need to support ``len()``.
        criterion: Callable mapping ``(outputs, labels)`` to a scalar loss.
        optimizer: Optimizer whose ``zero_grad``/``step`` drive the update.
        num_epochs: Number of passes over ``train_loader``.

    Returns:
        None. Progress is reported with ``print`` only.
    """
    model.train()

    # Move batches to wherever the model lives instead of relying on a
    # module-level ``device`` variable. A model without parameters gives no
    # device hint, so its batches are left where they are.
    first_param = next(model.parameters(), None)
    target_device = None if first_param is None else first_param.device

    for epoch in range(num_epochs):
        running_loss = 0.0
        correct = 0
        total = 0
        num_batches = 0
        for inputs, labels in train_loader:
            if target_device is not None:
                inputs = inputs.to(target_device)
                labels = labels.to(target_device)
            optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1

            # Predicted class = index of the largest score along dim 1.
            _, predicted = outputs.max(1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

        # An epoch that saw no data has undefined metrics; report NaN rather
        # than dividing by zero.
        mean_loss = running_loss / num_batches if num_batches else float('nan')
        accuracy = 100 * correct / total if total else float('nan')
        print(f'Epoch {epoch+1}/{num_epochs}: Loss: {mean_loss:.4f}, '
              f'Accuracy: {accuracy:.2f}%')

# ============================ EVALUATION FUNCTION ============================
def evaluate(model, test_loader, criterion):
    """Evaluate ``model`` on ``test_loader`` and print/return loss and accuracy.

    The model is switched to evaluation mode and left in that mode.

    Args:
        model: Network to evaluate.
        test_loader: Iterable yielding ``(inputs, labels)`` batches. It does
            not need to support ``len()``.
        criterion: Callable mapping ``(outputs, labels)`` to a scalar loss.

    Returns:
        A ``(mean_loss, accuracy)`` tuple: the per-batch loss averaged over
        batches, and the fraction (0-1) of correctly classified samples.
        Both are NaN when ``test_loader`` yields no data.
    """
    model.eval()

    # Move batches to wherever the model lives instead of relying on a
    # module-level ``device`` variable. A model without parameters gives no
    # device hint, so its batches are left where they are.
    first_param = next(model.parameters(), None)
    target_device = None if first_param is None else first_param.device

    correct = 0
    total = 0
    test_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            if target_device is not None:
                inputs = inputs.to(target_device)
                labels = labels.to(target_device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            test_loss += loss.item()
            num_batches += 1

            # Predicted class = index of the largest score along dim 1.
            _, predicted = outputs.max(1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

    # Metrics are undefined without data; return NaN rather than dividing
    # by zero.
    mean_loss = test_loss / num_batches if num_batches else float('nan')
    accuracy = correct / total if total else float('nan')

    print(f'Test Loss: {mean_loss:.4f}, Accuracy: {100 * accuracy:.2f}%')
    return mean_loss, accuracy

# =============================== SSD UNLEARNING ===============================
def selective_synaptic_dampening(model, fisher_info, ssd_lambda):
    """Shrink parameters in place in proportion to their Fisher importance.

    Every trainable parameter whose name appears in ``fisher_info`` is
    multiplied element-wise by ``1 - ssd_lambda * fisher_info[name]``, with
    that multiplier clamped to ``[0, 1]``. The clamp keeps the step a pure
    dampening: without it, ``ssd_lambda * fisher > 1`` would push a weight
    past zero (flipping its sign or even growing it), and a negative
    importance would amplify the weight.

    Args:
        model: Module whose parameters are modified in place.
        fisher_info: Mapping from parameter name (as produced by
            ``model.named_parameters()``) to an importance value that
            broadcasts against the parameter.
        ssd_lambda: Scalar dampening strength.

    Returns:
        None.
    """
    with torch.no_grad():
        for name, param in model.named_parameters():
            if not param.requires_grad or name not in fisher_info:
                continue
            # Align device/dtype so importance computed elsewhere (e.g. on
            # the CPU) can still be applied to this parameter.
            importance = torch.as_tensor(
                fisher_info[name], device=param.device, dtype=param.dtype
            )
            keep_fraction = (1.0 - ssd_lambda * importance).clamp_(0.0, 1.0)
            param.mul_(keep_fraction)
    print("Applied SSD unlearning step.")

# ============================ TRAINING INITIATION ============================
def compute_fisher_information(model, data_loader, criterion, max_batches=None):
    """Estimate per-parameter importance as the mean squared loss gradient.

    For each batch the loss is back-propagated and the squared gradient of
    every trainable parameter is accumulated; the sum is then divided by the
    number of batches seen. This is a mini-batch approximation of the
    diagonal empirical Fisher information.

    Args:
        model: Network whose parameters are scored. It is put in evaluation
            mode while gradients are collected and restored afterwards.
        data_loader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Callable mapping ``(outputs, labels)`` to a scalar loss.
        max_batches: Optional cap on the number of batches to use.

    Returns:
        Dict mapping parameter name to a tensor shaped like that parameter.
    """
    was_training = model.training
    model.eval()
    fisher = {
        name: torch.zeros_like(param)
        for name, param in model.named_parameters()
        if param.requires_grad
    }
    batches_seen = 0
    for inputs, labels in data_loader:
        if max_batches is not None and batches_seen >= max_batches:
            break
        inputs, labels = inputs.to(device), labels.to(device)
        model.zero_grad()
        criterion(model(inputs), labels).backward()
        for name, param in model.named_parameters():
            if name in fisher and param.grad is not None:
                fisher[name].add_(param.grad.detach().pow(2))
        batches_seen += 1
    # Do not leave stale gradients behind for the optimizers used later.
    model.zero_grad()
    model.train(was_training)
    if batches_seen:
        for importance in fisher.values():
            importance.div_(batches_seen)
    return fisher


def build_baseline_model():
    """Return a freshly initialised CIFAR-10 ResNet-18 placed on ``device``.

    Uses the same architecture changes as the main classifier (3x3 stem
    convolution, no max-pooling, 10-way output layer).
    """
    fresh = resnet18(weights=None)
    fresh.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    fresh.maxpool = nn.Identity()
    fresh.fc = nn.Linear(512, 10)
    return fresh.to(device)


criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)

# ========================== STEP 1: INITIAL TRAINING ==========================
print("[1] Initial training of classifier on FULL dataset...")
start_time = time.time()
train(model, train_loader, criterion, optimizer, num_epochs=20)
end_time = time.time()
initial_training_time = end_time - start_time
print(f"Time taken for initial training: {initial_training_time:.2f} seconds")

# ========================== STEP 2: SSD UNLEARNING ==========================
print("[2] SSD unlearning on classifier...")
# Parameter importance must exist before dampening; previously this line was
# commented out, so the dampening call failed with a NameError.
# NOTE(review): importance is estimated on the full training set because the
# script defines no separate forget set -- confirm which data should be used.
fisher_info = compute_fisher_information(model, train_loader, criterion)

ssd_lambda = 0.0001  # SSD lambda (adjust as necessary)

start_time = time.time()
selective_synaptic_dampening(model, fisher_info, ssd_lambda)
end_time = time.time()
ssd_time = end_time - start_time
print(f"Time taken for SSD unlearning: {ssd_time:.2f} seconds")

# ========================== STEP 3: FINE-TUNING ============================
print("[3] Fine-tuning retained classes...")
# Only the final classification layer is handed to this optimizer.
optimizer_ft = optim.SGD(model.fc.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)

# Fine-tuning after SSD.
# NOTE(review): this still iterates the full train_loader; no retained-class
# subset is defined in this script -- confirm the intended data.
train(model, train_loader, criterion, optimizer_ft, num_epochs=5)

# ============================ BASELINE RETRAINING ============================
print("[4] Retraining baseline classifier...")
# The baseline must start from fresh weights. Continuing to train the already
# trained and dampened `model` would neither measure retraining-from-scratch
# time nor leave the unlearned model intact for evaluation.
baseline_model = build_baseline_model()
optimizer_base = optim.SGD(baseline_model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)

start_time = time.time()
train(baseline_model, train_loader, criterion, optimizer_base, num_epochs=20)
end_time = time.time()
retraining_time = end_time - start_time
print(f"Time taken for retraining: {retraining_time:.2f} seconds")

# ============================== FINAL EVALUATION =============================
print("[5] Evaluating final model...")
evaluate(model, test_loader, criterion)
print("[5] Evaluating retrained baseline...")
evaluate(baseline_model, test_loader, criterion)


[1] Initial training of classifier on FULL dataset...


KeyboardInterrupt: 