In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset, Dataset, random_split
import numpy as np
from tqdm import tqdm
import random
import torchvision.models as models
import time
import torch.optim as optim
from sklearn.metrics import accuracy_score

import torch
import torch.nn.functional as F
from torch.utils.data import RandomSampler
from tqdm import tqdm
import numpy as np

from sklearn.metrics import accuracy_score
from sklearn.metrics import log_loss

import timm
from collections import defaultdict
from transformers import set_seed
set_seed(123456)

In [2]:
# load CIFAR-10 dataset
def load_cifar10_data(batch_size=256):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])

    train_dataset = datasets.CIFAR10('./data', train=True, download=True, transform=transform)
    test_dataset = datasets.CIFAR10('./data', train=False, transform=transform)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, test_loader

# define resnet model
class OriginalModel(nn.Module):
    def __init__(self):
        super(OriginalModel, self).__init__()
        self.resnet18 = models.resnet18(pretrained=False)
        self.resnet18.fc = nn.Linear(self.resnet18.fc.in_features, 10)

    def forward(self, x):
        return self.resnet18(x)

class UnlearnModel(OriginalModel):
    pass

# train model
def train_model(model, train_loader, test_loader, epochs=40, max_lr=0.01, grad_clip=0.1, weight_decay=0, device='cuda'):
    optimizer = torch.optim.Adam(model.parameters(), lr=max_lr, weight_decay=weight_decay)
    criterion = nn.CrossEntropyLoss()
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=max_lr, steps_per_epoch=len(train_loader), epochs=epochs)

    for epoch in range(epochs):
        model.train()
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()
            scheduler.step()  # update lr

    return model

@torch.no_grad()
def actv_dist(model1, model2, dataloader, device = 'cuda'):
    sftmx = nn.Softmax(dim = 1)
    distances = []
    for batch in dataloader:
        x,_ = batch
        x = x.to(device)
        model1_out = model1(x)
        model2_out = model2(x)
        diff = torch.sqrt(torch.sum(torch.square(F.softmax(model1_out, dim = 1) - F.softmax(model2_out, dim = 1)), axis = 1))
        diff = diff.detach().cpu()
        distances.append(diff)
    distances = torch.cat(distances, axis = 0)
    return distances.mean()

def evaluate(model, test_loader, device='cuda'):
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    return 100. * correct / len(test_loader.dataset)

In [3]:
def create_single_forget_retain_loader(train_loader, size):
    # get whole dataset
    total_dataset = train_loader.dataset
    
    # random split dataset
    forget_subset, retain_subset = random_split(total_dataset, [size, len(total_dataset) - size])

    # create target dataLoader
    forget_loader = DataLoader(forget_subset, batch_size=train_loader.batch_size, shuffle=True)
    retain_loader = DataLoader(retain_subset, batch_size=train_loader.batch_size, shuffle=True)

    return forget_loader, retain_loader

In [4]:
def evaluate_instance_model_accuracy(model, test_loader, forget_loader, retain_loader, reference_loader, device):
    model.to(device)
    # Calculate accuracies on different datasets
    test_accuracy = evaluate(model, test_loader, device)
    forget_accuracy = evaluate(model, forget_loader, device)
    retain_accuracy = evaluate(model, retain_loader, device)
    reference_accuracy = evaluate(model, reference_loader, device)
    
    # Print out the accuracies
    print(f"Test Loader Accuracy: {test_accuracy:.2f}%")
    print(f"Forget Loader Accuracy: {forget_accuracy:.2f}%")
    print(f"Retain Loader Accuracy: {retain_accuracy:.2f}%")
    print(f"Reference Loader Accuracy: {reference_accuracy:.2f}%")


In [5]:

def split_dataset(train_loader, reference_ratio=0.1):
    class_indices = defaultdict(list)

    for idx, (_, target) in enumerate(train_loader.dataset):
        class_indices[target].append(idx)

    reference_indices = []
    train_indices = []

    for indices in class_indices.values():
        random.shuffle(indices)
        k = max(1, int(len(indices) * reference_ratio))  
        reference_indices.extend(indices[:k])
        train_indices.extend(indices[k:])

    reference_subset = Subset(train_loader.dataset, reference_indices)
    train_subset = Subset(train_loader.dataset, train_indices)

    reference_loader = DataLoader(reference_subset, batch_size=train_loader.batch_size, shuffle=True)
    new_train_loader = DataLoader(train_subset, batch_size=train_loader.batch_size, shuffle=True)

    return reference_loader, new_train_loader


In [6]:


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')   
# load dataset and split dataset
train_loader, test_loader = load_cifar10_data()
reference_loader, new_train_loader = split_dataset(train_loader, reference_ratio=0.01)
forget_loader64, retain_loader64 = create_single_forget_retain_loader(new_train_loader, size=64)
forget_loader128, retain_loader128 = create_single_forget_retain_loader(new_train_loader, size=128)
forget_loader256, retain_loader256 = create_single_forget_retain_loader(new_train_loader, size=256)
forget_loader512, retain_loader512 = create_single_forget_retain_loader(new_train_loader,size=512)

Files already downloaded and verified


In [7]:
# select instance to unlearn
forget_loader = forget_loader64
retain_loader = retain_loader64

In [8]:
def train_and_evaluate(model, criterion, optimizer, scheduler, epochs, device,train_loader=new_train_loader):
    model.to(device)
    
    for epoch in range(epochs):
        model.train()
        for data, targets in train_loader:
            data, targets = data.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(data)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

        # Evaluate performance and update learning rate
        # Log the current learning rate
        current_lr = scheduler.optimizer.param_groups[0]['lr']
        print(f'Epoch {epoch+1}/{epochs}, Current LR: {current_lr}')

        # Evaluate performance and update learning rate
        test_accuracy = evaluate(model, test_loader, device)
        test_accuracy1 = evaluate(model, reference_loader, device)
        scheduler.step(test_accuracy)
        print(f'Epoch {epoch+1}/{epochs}, Test Accuracy: {test_accuracy:.2f}%, Reference Accuracy: {test_accuracy1:.2f}%')

    return model

In [9]:
original_model = timm.create_model("resnet18", pretrained=False)
original_model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
original_model.maxpool = nn.Identity()  # type: ignore
original_model.fc = nn.Linear(512,  10)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(original_model.parameters(), lr=0.1, momentum=0.9, weight_decay=0.0005, nesterov=True)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=3, threshold=0.001, mode='max')
original_model = train_and_evaluate(original_model, criterion, optimizer, scheduler, epochs=40, device='cuda',train_loader=new_train_loader)


Epoch 1/40, Current LR: 0.1
Epoch 1/40, Test Accuracy: 37.32%, Reference Accuracy: 36.80%
Epoch 2/40, Current LR: 0.1
Epoch 2/40, Test Accuracy: 61.62%, Reference Accuracy: 63.60%
Epoch 3/40, Current LR: 0.1
Epoch 3/40, Test Accuracy: 51.14%, Reference Accuracy: 54.80%
Epoch 4/40, Current LR: 0.1
Epoch 4/40, Test Accuracy: 67.56%, Reference Accuracy: 68.80%
Epoch 5/40, Current LR: 0.1
Epoch 5/40, Test Accuracy: 73.22%, Reference Accuracy: 74.00%
Epoch 6/40, Current LR: 0.1
Epoch 6/40, Test Accuracy: 74.62%, Reference Accuracy: 75.20%
Epoch 7/40, Current LR: 0.1
Epoch 7/40, Test Accuracy: 74.74%, Reference Accuracy: 78.20%
Epoch 8/40, Current LR: 0.1
Epoch 8/40, Test Accuracy: 80.12%, Reference Accuracy: 80.80%
Epoch 9/40, Current LR: 0.1
Epoch 9/40, Test Accuracy: 72.03%, Reference Accuracy: 72.20%
Epoch 10/40, Current LR: 0.1
Epoch 10/40, Test Accuracy: 78.75%, Reference Accuracy: 81.40%
Epoch 11/40, Current LR: 0.1
Epoch 11/40, Test Accuracy: 75.19%, Reference Accuracy: 74.60%
Epoch 

In [10]:
reference_acc = evaluate(original_model, reference_loader, device)
evaluate_instance_model_accuracy(original_model, test_loader, forget_loader, retain_loader, reference_loader, device)

Test Loader Accuracy: 86.39%
Forget Loader Accuracy: 100.00%
Retain Loader Accuracy: 100.00%
Reference Loader Accuracy: 88.80%


SU

In [11]:
def distill_with_soft_relabel(original_model, student_model, forget_loader, optimizer, epochs=1, device='cuda', distill_temperature=1):
    original_model.to(device)
    student_model.to(device)
    epoch_losses = []  
    epoch_times = []   
    epsilon = '-inf'
    
    for epoch in range(epochs):
        start_time = time.time()  # Start timing the training (excluding evaluation)
        student_model.train()
        batch_losses = []

        for batch_idx, (data, targets) in enumerate(forget_loader):
            data, targets = data.to(device), targets.to(device)

            with torch.no_grad():
                teacher_logits = original_model(data)
                # Find the index of maximum logit for each example and set it to -inf
                max_indices = torch.argmax(teacher_logits, dim=1)
                teacher_logits[torch.arange(teacher_logits.shape[0]), max_indices] = float(epsilon)
                soft_labels = F.softmax(teacher_logits / distill_temperature, dim=1)
            
            student_logits = student_model(data)
            student_log_probs = F.log_softmax(student_logits / distill_temperature, dim=1)
            loss = F.kl_div(student_log_probs, soft_labels.detach(), reduction='batchmean')

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_losses.append(loss.item())

        epoch_loss = sum(batch_losses) / len(batch_losses)
        epoch_losses.append(epoch_loss)
        
        end_time = time.time()  # End timing the training
        epoch_training_time = end_time - start_time
        epoch_times.append(epoch_training_time)
        
        student_model.eval()  # Switch to evaluation mode for accuracy checking
        accuracy = evaluate(student_model, forget_loader, device)
        student_model.train()  # Switch back to training mode

        # Optionally, evaluate model on forget_loader (if desired, add evaluation code here)
        print(f'Epoch {epoch+1}/{epochs}, Loss: {epoch_loss}, Training Time: {epoch_training_time:.2f} seconds')
        
        if accuracy <= reference_acc:
            print(f"Stopping early at epoch {epoch + 1} due to accuracy on forget_loader.")
            break


    total_training_time = sum(epoch_times)
    print(f"Total training time: {total_training_time:.2f} seconds")

    return student_model, epoch_losses, total_training_time


In [13]:
SU_model = timm.create_model("resnet18", pretrained=False).to(device)
SU_model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
SU_model.maxpool = nn.Identity() 
SU_model.fc = nn.Linear(512, 10)
SU_model.load_state_dict(original_model.state_dict())
SU_model.to(device)

original_model.eval()

optimizer = torch.optim.Adam(
    SU_model.parameters(),
    lr=0.00006,
)
SU_model, distillation_losses, total_training_time = distill_with_soft_relabel(original_model, SU_model, forget_loader, optimizer, epochs=30, device='cuda', distill_temperature=1)

evaluate_instance_model_accuracy(SU_model, test_loader, forget_loader, retain_loader, reference_loader, device)

Epoch 1/30, Loss: 8.621769905090332, Training Time: 0.26 seconds
Epoch 2/30, Loss: 7.443272590637207, Training Time: 0.04 seconds
Epoch 3/30, Loss: 6.221786022186279, Training Time: 0.04 seconds
Epoch 4/30, Loss: 5.074117660522461, Training Time: 0.04 seconds
Stopping early at epoch 4 due to accuracy on forget_loader.
Total training time excluding evaluation: 0.37 seconds
Test Loader Accuracy: 84.88%
Forget Loader Accuracy: 68.75%
Retain Loader Accuracy: 99.90%
Reference Loader Accuracy: 88.60%


SULI

In [16]:
import copy
def calculate_entropy(logits):
    probabilities = F.softmax(logits, dim=1)
    log_probabilities = F.log_softmax(logits, dim=1)
    entropy = -(probabilities * log_probabilities).sum(dim=1)
    return entropy

def sort_data(original_model, data_loader, device, batch_size_per_loader=200):
    data_with_entropy = []

    with torch.no_grad():
        for data, targets in data_loader:
            data = data.to(device)
            logits = original_model(data)
            entropy = calculate_entropy(logits)
            for i in range(len(data)):
                data_with_entropy.append((data[i].cpu(), targets[i], entropy[i].item()))

    # sort from high to low
    data_with_entropy.sort(key=lambda x: x[2], reverse=True)

    # split dataset
    loaders = []
    for i in range(0, len(data_with_entropy), batch_size_per_loader):
        batch = data_with_entropy[i:i+batch_size_per_loader]
        dataset = [(x[0], x[1]) for x in batch]  
        loader = DataLoader(dataset, batch_size=len(batch), shuffle=True)
        loaders.append(loader)

    return loaders

def SPA_unlearning(original_model, student_model, forget_loader, optimizer, epochs=10, device='cuda', distill_temperature=1):
    original_model.to(device)
    student_model.to(device)
    epoch_losses = []  
    
    for epoch in range(epochs):
        student_model.train()
        batch_losses = []

        for batch_idx, (data, targets) in enumerate(forget_loader):
            data, targets = data.to(device), targets.to(device)

            with torch.no_grad():
                teacher_logits = original_model(data)
                max_indices = torch.argmax(teacher_logits, dim=1)
                teacher_logits[torch.arange(teacher_logits.size(0)), max_indices] = float('-inf')
                soft_labels = F.softmax(teacher_logits / distill_temperature, dim=1)
            
            student_logits = student_model(data)
            student_log_probs = F.log_softmax(student_logits / distill_temperature, dim=1)
            loss = F.kl_div(student_log_probs, soft_labels.detach(), reduction='batchmean')

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_losses.append(loss.item())

        epoch_loss = sum(batch_losses) / len(batch_losses)
        epoch_losses.append(epoch_loss)
        print(f'Epoch {epoch+1}/{epochs}, Loss: {epoch_loss}')

    return student_model, epoch_losses

def train_with_sorted_loaders(original_model, sorted_loaders, forget_loader, epochs=10, device='cuda', distill_temperature=1, lr=0.0005):
    current_original_model = copy.deepcopy(original_model).to(device)
    student_model = copy.deepcopy(current_original_model).to(device)
    total_loss = 0  
    total_loaders = len(sorted_loaders)
    total_time_elapsed = 0  
    
    for epoch in range(epochs):
        for loader_index, loader in enumerate(sorted_loaders):

            optimizer = torch.optim.Adam(student_model.parameters(), lr=lr,)
            start = time.time()  
            student_model, epoch_losses = SPA_unlearning(current_original_model, student_model, loader, optimizer, epochs=1, device=device, distill_temperature=distill_temperature)
            end = time.time()  
            time_elapsed = end - start
            total_time_elapsed += time_elapsed  
            forget_test_accuracy = evaluate(student_model, forget_loader, device=device)

            # print(f"Loader {loader_index+1}/{total_loaders}, acc: {forget_test_accuracy}, Learning Rate: {lr}")
            
            current_original_model = student_model
            total_loss += sum(epoch_losses)  

            if forget_test_accuracy <= reference_acc:  
                # print("Early stopping triggered due to forget accuracy.")
                break
        else:
            continue
        break
    print(f"Total time elapsed: {total_time_elapsed} seconds")  

    return current_original_model, total_loss, total_time_elapsed

In [None]:
SULI_model = timm.create_model("resnet18", pretrained=False).to(device)
SULI_model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
SULI_model.maxpool = nn.Identity()  
SULI_model.fc = nn.Linear(512, 10)
SULI_model.load_state_dict(original_model.state_dict())
SULI_model.to(device)

sorted_loader = sort_data(original_model, forget_loader, device, batch_size_per_loader=32)
SULI_model, total_loss, total_time_elapsed = train_with_sorted_loaders(SULI_model, sorted_loaders=sorted_loader, forget_loader=forget_loader, epochs=10 ,device='cuda',
                                                                                 lr=0.00008)
print(f"Runtime: {total_time_elapsed}\n")
evaluate_instance_model_accuracy(SULI_model, test_loader, forget_loader, retain_loader, reference_loader, device)



Epoch 1/1, Loss: 6.582531929016113
Epoch 1/1, Loss: 8.874137878417969
Epoch 1/1, Loss: 3.642348527908325
Total time elapsed over all epochs and loaders: 0.2001972198486328 seconds
Runtime: 0.2001972198486328

Test Loader Accuracy: 85.12%
Forget Loader Accuracy: 85.94%
Retain Loader Accuracy: 99.84%
Reference Loader Accuracy: 88.20%
