In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import torchvision.models as models
import time
import copy
import timm
from transformers import set_seed
set_seed(123456)

In [2]:
def load_cifar10_data(batch_size=256):
    """Build the CIFAR-10 training and test ``DataLoader`` pair.

    Both splits are read from ``./data`` and share one preprocessing pipeline:
    conversion to a tensor followed by per-channel normalisation with fixed
    mean/std constants. Only the training split is requested with
    ``download=True``.

    Args:
        batch_size: Batch size used by both loaders.

    Returns:
        ``(train_loader, test_loader)``. The training loader shuffles; the test
        loader keeps dataset order.
    """
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2023, 0.1994, 0.2010)
    preprocess = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize(channel_mean, channel_std)]
    )

    def make_loader(is_train):
        # The training split is both downloaded and shuffled; the test split is neither.
        split = datasets.CIFAR10('./data', train=is_train, download=is_train, transform=preprocess)
        return DataLoader(split, batch_size=batch_size, shuffle=is_train)

    return make_loader(True), make_loader(False)


def split_dataset(train_dataset, target_forget):
    """Split a labelled dataset into "forget" and "retain" subsets by class.

    Args:
        train_dataset: Indexable dataset whose items are ``(sample, label)``
            pairs. When it exposes a ``targets`` sequence of matching length
            and has no ``target_transform``, labels are read from ``targets``
            directly so that no sample has to be loaded or transformed.
        target_forget: A single class label, or a list/tuple/set of labels,
            identifying the classes to forget.

    Returns:
        ``(forget_data, retain_data)``: two ``Subset`` views of
        ``train_dataset``. The first holds every index whose label is in
        ``target_forget``, the second holds all remaining indices; both keep
        the original index order.
    """
    if isinstance(target_forget, (list, tuple, set, frozenset)):
        forget_labels = list(target_forget)
    else:
        forget_labels = [target_forget]

    # Fast path: use the dataset's cached label list when it is guaranteed to
    # match what __getitem__ would return (i.e. no target_transform is applied).
    labels = getattr(train_dataset, 'targets', None)
    can_use_cached_labels = (
        labels is not None
        and hasattr(labels, '__len__')
        and getattr(train_dataset, 'target_transform', None) is None
        and len(labels) == len(train_dataset)
    )
    if not can_use_cached_labels:
        labels = [train_dataset[i][1] for i in range(len(train_dataset))]

    forget_indices, retain_indices = [], []
    for index, label in enumerate(labels):
        # List membership (==) is kept so tensor-valued labels still compare by value.
        if label in forget_labels:
            forget_indices.append(index)
        else:
            retain_indices.append(index)

    forget_data = Subset(train_dataset, forget_indices)
    retain_data = Subset(train_dataset, retain_indices)

    return forget_data, retain_data

class TeacherModel(nn.Module):
    """torchvision ResNet-18 whose final fully connected layer outputs 10 logits."""

    def __init__(self):
        super().__init__()
        # Use the constructor default (no pretrained weights) rather than the
        # ``pretrained=False`` keyword, which newer torchvision releases
        # deprecate in favour of ``weights=``. The default is accepted by both
        # the old and the new API and yields the same randomly initialised net.
        self.resnet18 = models.resnet18()
        # Swap the stock classification head for a 10-class one.
        self.resnet18.fc = nn.Linear(self.resnet18.fc.in_features, 10)

    def forward(self, x):
        """Return the wrapped ResNet-18's output for ``x``."""
        return self.resnet18(x)

class StudentModel(TeacherModel):
    """Student network: inherits ``TeacherModel`` without overriding anything."""

def create_timm_model(num_classes=10):
    """Create an untrained timm ResNet-18 with a modified stem and head.

    The stock first convolution is replaced by a 3x3, stride-1, padding-1
    convolution and the max-pool is replaced by an identity, so the stem does
    not downsample (presumably to suit small CIFAR-sized inputs - confirm
    against the checkpoint being loaded).

    Args:
        num_classes: Number of output logits of the new classification head.
            Defaults to 10, matching the previous hard-coded behaviour.

    Returns:
        The modified model, with randomly initialised weights.
    """
    model = timm.create_model("resnet18", pretrained=False)
    model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    model.maxpool = nn.Identity()
    # Read the feature width from the existing head instead of hard-coding 512.
    model.fc = nn.Linear(model.fc.in_features, num_classes)

    return model

# train_model
def train_model(model, train_loader, test_loader, epochs=40, max_lr=0.01, grad_clip=0.1, weight_decay=0, device='cuda'):
    """Train ``model`` with Adam, a one-cycle LR schedule and gradient-norm clipping.

    Args:
        model: Network to train; it is moved to ``device`` and updated in place.
        train_loader: Sized iterable of ``(data, target)`` batches.
        test_loader: Accepted for interface compatibility; not used here.
        epochs: Number of passes over ``train_loader``.
        max_lr: Peak learning rate of the ``OneCycleLR`` schedule.
        grad_clip: Maximum gradient norm passed to ``clip_grad_norm_``.
        weight_decay: Weight decay passed to Adam.
        device: Device on which the model and every batch are placed.

    Returns:
        The same ``model`` object after training.
    """
    # Batches are moved to `device` below, so the model has to live there too;
    # otherwise a CPU-resident model fails with a device mismatch.
    model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=max_lr, weight_decay=weight_decay)
    criterion = nn.CrossEntropyLoss()
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=max_lr, steps_per_epoch=len(train_loader), epochs=epochs
    )

    for _ in range(epochs):
        model.train()
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            loss = criterion(model(data), target)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()
            scheduler.step()  # the schedule advances once per batch, not per epoch

    return model

# Function to get class-wise accuracy
def cifar10_class_wise_accuracy(model, data_loader, device='cuda'):
    """Compute per-class top-1 accuracy (in percent) over ``data_loader``.

    An ``nn.Module`` model is switched to eval mode for the duration of the
    call and returned to its previous train/eval mode afterwards.

    Args:
        model: Callable mapping a batch to per-class scores of shape
            ``(batch, 10)``.
        data_loader: Iterable of ``(data, target)`` batches with integer class
            targets in ``[0, 10)``.
        device: Device on which batches are placed before the forward pass.

    Returns:
        Dict mapping each of the ten class names to its accuracy in percent;
        classes with no samples map to ``0``.
    """
    class_names = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
    num_classes = len(class_names)
    class_correct = torch.zeros(num_classes, dtype=torch.long)
    class_total = torch.zeros(num_classes, dtype=torch.long)

    is_module = isinstance(model, nn.Module)
    was_training = model.training if is_module else False
    if is_module:
        model.eval()
    try:
        with torch.no_grad():
            for data, target in data_loader:
                data, target = data.to(device), target.to(device)
                target = target.view(-1)
                predicted = model(data).argmax(dim=1)
                hits = predicted == target
                # One histogram per batch replaces the per-sample Python loop.
                class_total += torch.bincount(target, minlength=num_classes).cpu()
                class_correct += torch.bincount(target[hits], minlength=num_classes).cpu()
    finally:
        if is_module:
            model.train(was_training)

    accuracies = {}
    for name, correct, total in zip(class_names, class_correct.tolist(), class_total.tolist()):
        accuracies[name] = 100 * correct / total if total > 0 else 0

    return accuracies

@torch.no_grad()
def actv_dist(model1, model2, dataloader, device = 'cuda'):
    """Mean L2 distance between the two models' softmax outputs.

    For every batch both models are run on the same inputs, their outputs are
    turned into probabilities with a softmax over dim 1, and the Euclidean
    distance between the two probability vectors is taken per sample.

    Args:
        model1: First model.
        model2: Second model.
        dataloader: Iterable of ``(inputs, labels)`` batches; labels are ignored.
        device: Device on which inputs are placed.

    Returns:
        0-d CPU tensor holding the mean per-sample distance.
    """
    per_sample = []
    for inputs, _ in dataloader:
        inputs = inputs.to(device)
        probs1 = F.softmax(model1(inputs), dim=1)
        probs2 = F.softmax(model2(inputs), dim=1)
        l2 = torch.sqrt(torch.square(probs1 - probs2).sum(dim=1))
        per_sample.append(l2.detach().cpu())
    return torch.cat(per_sample, dim=0).mean()

In [3]:
def evaluate(model, test_loader, device='cuda'):
    """Return the top-1 accuracy of ``model`` on ``test_loader``, in percent.

    The model is put in eval mode and left that way. The denominator is
    ``len(test_loader.dataset)``, so the loader must expose a ``dataset``.

    Args:
        model: Network to evaluate.
        test_loader: Loader yielding ``(data, target)`` batches.
        device: Device on which batches are placed.

    Returns:
        Accuracy as a float in ``[0, 100]``.
    """
    model.eval()
    num_correct = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            predictions = model(inputs).argmax(dim=1)
            matches = predictions == labels.view_as(predictions)
            num_correct += matches.sum().item()
    return 100. * num_correct / len(test_loader.dataset)

def evaluate_instance_model_accuracy(model, test_loader, forget_loader, retain_loader, device):
    """Print the model's accuracy on the test, forget and retain loaders.

    The model is moved to ``device``, scored on each loader with ``evaluate``
    (test, then forget, then retain), and the three percentages are printed
    with two decimals. Nothing is returned.

    Args:
        model: Network to evaluate.
        test_loader: Loader for the test set.
        forget_loader: Loader for the forget set.
        retain_loader: Loader for the retain set.
        device: Device used for evaluation.
    """
    model.to(device)

    # Score all three loaders first so the report is printed in one go.
    test_accuracy, forget_accuracy, retain_accuracy = [
        evaluate(model, loader, device)
        for loader in (test_loader, forget_loader, retain_loader)
    ]

    print(f"Test Loader Accuracy: {test_accuracy:.2f}%")
    print(f"Forget Loader Accuracy: {forget_accuracy:.2f}%")
    print(f"Retain Loader Accuracy: {retain_accuracy:.2f}%")
    

In [4]:
# Run on the GPU when one is visible, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# CIFAR-10 loaders with the default batch size.
train_loader, test_loader = load_cifar10_data()

# Classes to forget: index 1 is 'car' in the class_names list used above.
forget_class_idxs = [1]

# Partition both splits into forget / retain subsets by class label.
forget_data, retain_data = split_dataset(train_loader.dataset, forget_class_idxs)
forget_test_data, retain_test_data = split_dataset(test_loader.dataset, forget_class_idxs)

# One shuffled loader per subset, all with batch size 256.
forget_loader, retain_loader, forget_test_loader, retain_test_loader = [
    DataLoader(subset, batch_size=256, shuffle=True)
    for subset in (forget_data, retain_data, forget_test_data, retain_test_data)
]

Files already downloaded and verified


In [5]:
# Load the original (pre-unlearning) model from a published checkpoint.
# NOTE(review): load_state_dict_from_url deserialises a remote file; confirm the
# source is trusted (or that weights-only loading is in effect) before running.
pretrained_state = torch.hub.load_state_dict_from_url(
    "https://huggingface.co/edadaltocg/resnet18_cifar10/resolve/main/pytorch_model.bin",
    map_location="cpu",
    file_name="resnet18_cifar10.pth",
)
original_model = create_timm_model().to(device)
original_model.load_state_dict(pretrained_state)
# Last expression of the cell: switches to eval mode and displays the model.
original_model.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (act1): ReLU(inplace=True)
  (maxpool): Identity()
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (drop_block): Identity()
      (act1): ReLU(inplace=True)
      (aa): Identity()
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act2): ReLU(inplace=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (drop_

In [None]:
# Baseline: accuracy of the original model on the test, forget and retain loaders.
evaluate_instance_model_accuracy(
    model=original_model,
    test_loader=test_loader,
    forget_loader=forget_loader,
    retain_loader=retain_loader,
    device=device,
)

Test Loader Accuracy: 94.98%
Forget Loader Accuracy: 100.00%
Retain Loader Accuracy: 100.00%


SU (Self-Unlearning): distil the model on the forget set using soft-relabelled teacher outputs

In [7]:
def distill_with_soft_relabel(original_model, student_model, forget_loader, optimizer, forget_class_idxs, epochs=1, device='cuda', distill_temperature=4):
    """Unlearn classes by distilling masked teacher predictions into the student.

    For every batch of ``forget_loader`` the teacher's logits for the classes
    in ``forget_class_idxs`` are set to ``-inf``, so the temperature-scaled
    softmax gives those classes zero probability. The student is then trained
    to match these soft labels with a KL-divergence loss (``batchmean``).
    After each epoch the student is scored on ``forget_loader`` with
    ``evaluate``; training stops early once that accuracy is exactly 0.

    Args:
        original_model: Teacher producing the soft labels. When it is a
            different object from ``student_model`` it is put in eval mode.
        student_model: Model being updated; left in train mode on return.
        forget_loader: Loader of ``(data, target)`` batches; targets are only
            used by the per-epoch evaluation.
        optimizer: Optimizer over the student's parameters.
        forget_class_idxs: Class index, or list of class indices, to suppress.
        epochs: Maximum number of passes over ``forget_loader``.
        device: Device for both models and all batches.
        distill_temperature: Divisor applied to both teacher and student logits
            before the softmax.

    Returns:
        ``(student_model, epoch_losses, total_training_time)``: the trained
        student, the mean batch loss of each completed epoch, and the summed
        training time in seconds, excluding the per-epoch evaluation.

    Raises:
        ZeroDivisionError: If ``forget_loader`` yields no batches.
    """
    original_model.to(device)
    student_model.to(device)
    # A separate teacher only provides targets, so keep it in inference mode
    # (dropout off, batch-norm statistics frozen). Skipped when teacher and
    # student are the same object, since the student must train.
    if original_model is not student_model:
        original_model.eval()

    if not isinstance(forget_class_idxs, list):
        forget_class_idxs = [forget_class_idxs]

    epoch_losses = []        # mean batch loss per epoch
    total_training_time = 0  # seconds spent training, evaluation excluded

    for _ in range(epochs):
        # perf_counter is monotonic, so durations are immune to clock adjustments.
        epoch_start = time.perf_counter()
        student_model.train()
        batch_losses = []

        for data, _ in forget_loader:
            data = data.to(device)

            with torch.no_grad():
                teacher_logits = original_model(data)
                for forget_class_idx in forget_class_idxs:
                    teacher_logits[:, forget_class_idx] = float('-inf')
                soft_labels = F.softmax(teacher_logits / distill_temperature, dim=1)

            student_logits = student_model(data)
            student_log_probs = F.log_softmax(student_logits / distill_temperature, dim=1)
            loss = F.kl_div(student_log_probs, soft_labels, reduction='batchmean')

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_losses.append(loss.item())

        epoch_losses.append(sum(batch_losses) / len(batch_losses))
        total_training_time += time.perf_counter() - epoch_start

        # Evaluation is deliberately outside the timed region.
        student_model.eval()
        accuracy = evaluate(student_model, forget_loader, device)
        student_model.train()

        # Early stopping once nothing in the forget set is classified correctly.
        if accuracy == 0:
            break

    return student_model, epoch_losses, total_training_time

In [8]:
# Initialise the unlearning model as a copy of the original model's weights.
SU_model = create_timm_model().to(device)
SU_model.load_state_dict(original_model.state_dict())
optimizer = torch.optim.Adam(SU_model.parameters(), lr=0.0001)
# Pass the `device` selected earlier instead of a hard-coded 'cuda', so this
# cell also runs on CPU-only machines.
SU_model, distillation_losses, total_training_time = distill_with_soft_relabel(
    original_model,
    SU_model,
    forget_loader,
    optimizer,
    forget_class_idxs=forget_class_idxs,
    epochs=10,
    device=device,
)
evaluate_instance_model_accuracy(SU_model, test_loader, forget_loader, retain_loader, device)

Test Loader Accuracy: 82.52%
Forget Loader Accuracy: 0.00%
Retain Loader Accuracy: 98.20%


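SULI (Self-Unlearning Layered Iteration): self-distillation over entropy-sorted chunks of the forget set, with rollback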

In [9]:
def calculate_entropy(logits):
    """Return the entropy of ``softmax(logits)`` for each row.

    Args:
        logits: Tensor with the class dimension at dim 1.

    Returns:
        Tensor of entropies (natural log) with dim 1 reduced away.
    """
    probs = F.softmax(logits, dim=1)
    log_probs = F.log_softmax(logits, dim=1)
    weighted = probs * log_probs
    return weighted.sum(dim=1).neg()

def sort_data_loader_by_entropy(teacher_model, data_loader, device, batch_size_per_loader=200):
    """Rank samples by teacher prediction entropy and chunk them into loaders.

    Every sample of ``data_loader`` is scored with ``calculate_entropy`` on the
    teacher's logits. Samples are sorted from highest to lowest entropy (ties
    keep their encounter order) and cut into consecutive chunks of
    ``batch_size_per_loader``; each chunk becomes its own single-batch,
    shuffled ``DataLoader``. The teacher's train/eval mode is not changed here.

    Args:
        teacher_model: Model used to score the samples.
        data_loader: Loader yielding ``(data, targets)`` batches.
        device: Device on which the teacher is evaluated.
        batch_size_per_loader: Number of samples per returned loader; the last
            loader may hold fewer.

    Returns:
        List of ``DataLoader`` objects, ordered from the highest-entropy chunk
        to the lowest-entropy chunk.
    """
    scored_samples = []

    with torch.no_grad():
        for data, targets in data_loader:
            logits = teacher_model(data.to(device))
            # Transfer once per batch rather than once per sample: a single
            # .tolist() and a single .cpu() replace per-item .item()/.cpu() calls.
            entropies = calculate_entropy(logits).tolist()
            samples_cpu = data.cpu()
            for sample, target, entropy in zip(samples_cpu, targets, entropies):
                scored_samples.append((sample, target, entropy))

    # Sort samples by entropy, most uncertain first.
    scored_samples.sort(key=lambda item: item[2], reverse=True)

    # Split the ranking into sub-loaders of at most batch_size_per_loader samples.
    loaders = []
    for start in range(0, len(scored_samples), batch_size_per_loader):
        chunk = [
            (sample, target)
            for sample, target, _ in scored_samples[start:start + batch_size_per_loader]
        ]
        loaders.append(DataLoader(chunk, batch_size=len(chunk), shuffle=True))

    return loaders

def SPA_Iteration_unlearning(teacher_model, student_model, forget_loader, optimizer, forget_class_idxs, epochs=10, device='cuda', distill_temperature=1):
    """Train the student towards teacher predictions with the forget classes masked.

    For each batch the teacher's logits for ``forget_class_idxs`` are set to
    ``-inf`` before a temperature-scaled softmax, which gives those classes
    zero probability. The student is updated to minimise the KL divergence
    (``batchmean``) between its own temperature-scaled log-probabilities and
    these soft labels. The mean loss of every epoch is printed.

    Args:
        teacher_model: Model that produces the soft labels (no gradients).
        student_model: Model being optimised; put in train mode every epoch.
        forget_loader: Iterable of ``(data, targets)`` batches.
        optimizer: Optimizer over the student's parameters.
        forget_class_idxs: Class index, or list of class indices, to suppress.
        epochs: Number of passes over ``forget_loader``.
        device: Device for both models and all batches.
        distill_temperature: Divisor applied to teacher and student logits.

    Returns:
        ``(student_model, epoch_losses)`` where ``epoch_losses`` holds the mean
        batch loss of each epoch.
    """
    teacher_model.to(device)
    student_model.to(device)

    if isinstance(forget_class_idxs, list):
        masked_classes = forget_class_idxs
    else:
        masked_classes = [forget_class_idxs]

    epoch_losses = []
    for epoch in range(epochs):
        student_model.train()
        running_losses = []

        for data, targets in forget_loader:
            data, targets = data.to(device), targets.to(device)

            # Teacher pass first, without gradients.
            with torch.no_grad():
                teacher_logits = teacher_model(data)
                for class_idx in masked_classes:
                    teacher_logits[:, class_idx] = float('-inf')
                soft_labels = F.softmax(teacher_logits / distill_temperature, dim=1)

            scaled_student_logits = student_model(data) / distill_temperature
            student_log_probs = F.log_softmax(scaled_student_logits, dim=1)
            loss = F.kl_div(student_log_probs, soft_labels.detach(), reduction='batchmean')

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_losses.append(loss.item())

        epoch_loss = sum(running_losses) / len(running_losses)
        epoch_losses.append(epoch_loss)
        print(f'Epoch {epoch+1}/{epochs}, Loss: {epoch_loss}')

    return student_model, epoch_losses

def SelfUnlearning_Layered_Iteration(teacher_model, sorted_loaders, forget_class_idxs, forget_loader, epochs=10, device='cuda', distill_temperature=1, lr=0.01):
    """Self-distil a copy of the model chunk by chunk, rolling back bad updates.

    A deep copy of ``teacher_model`` acts as both teacher and student. For each
    loader in ``sorted_loaders`` a fresh Adam optimizer is created and one
    epoch of ``SPA_Iteration_unlearning`` is run; the model is then scored on
    ``forget_loader`` with ``evaluate``. If that accuracy is higher than the
    last accepted accuracy, the weights from before the update are restored;
    otherwise the new accuracy becomes the accepted one. All looping stops as
    soon as the measured accuracy is exactly 0.

    Args:
        teacher_model: Starting model; it is deep-copied, not modified.
        sorted_loaders: Sequence of loaders visited in order every epoch.
        forget_class_idxs: Class index or indices passed to the inner update.
        forget_loader: Loader used to measure forget-set accuracy.
        epochs: Maximum number of sweeps over ``sorted_loaders``.
        device: Device for the model and all batches.
        distill_temperature: Temperature passed to the inner update.
        lr: Learning rate of each per-loader Adam optimizer.

    Returns:
        ``(model, total_time_elapsed)``: the updated copy and the summed
        wall-clock seconds spent in the inner updates (evaluation excluded).
    """
    model = copy.deepcopy(teacher_model).to(device)
    previous_accuracy = -1  # sentinel: no accuracy accepted yet, so the first update is never rolled back
    total_time_elapsed = 0
    reached_zero = False

    for _ in range(epochs):
        for loader in sorted_loaders:
            # Snapshot the weights so this update can be undone.
            snapshot = copy.deepcopy(model.state_dict())
            optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=0)

            started = time.time()
            model, _ = SPA_Iteration_unlearning(
                model, model, loader, optimizer, forget_class_idxs,
                epochs=1, device=device, distill_temperature=distill_temperature,
            )
            total_time_elapsed += time.time() - started

            forget_accuracy = evaluate(model, forget_loader, device=device)
            regressed = previous_accuracy != -1 and forget_accuracy > previous_accuracy
            if regressed:
                # Forget accuracy went up: discard this update.
                model.load_state_dict(snapshot)
            else:
                previous_accuracy = forget_accuracy

            if forget_accuracy == 0:
                reached_zero = True
                break

        if reached_zero:
            break

    return model, total_time_elapsed

In [10]:
# Initialise the unlearning model as a copy of the original model's weights.
SULI_model = create_timm_model().to(device)
SULI_model.load_state_dict(original_model.state_dict())
# Rank the forget samples by the original model's prediction entropy and split
# them into chunks of 500.
sorted_loader = sort_data_loader_by_entropy(original_model, forget_loader, device, batch_size_per_loader=500)
# Pass the `device` selected earlier instead of a hard-coded 'cuda', so this
# cell also runs on CPU-only machines.
SULI_model, total_time_elapsed = SelfUnlearning_Layered_Iteration(
    SULI_model,
    sorted_loaders=sorted_loader,
    forget_class_idxs=forget_class_idxs,
    forget_loader=forget_loader,
    epochs=10,
    device=device,
    lr=0.0001,
)
evaluate_instance_model_accuracy(SULI_model, test_loader, forget_loader, retain_loader, device)

Epoch 1/1, Loss: 0.1719745546579361
Epoch 1/1, Loss: 0.04842371121048927
Epoch 1/1, Loss: 0.014049886725842953
Epoch 1/1, Loss: 0.009383846074342728
Epoch 1/1, Loss: 0.0067747230641543865
Epoch 1/1, Loss: 0.004645343869924545
Epoch 1/1, Loss: 0.004092388320714235
Epoch 1/1, Loss: 0.0034357081167399883
Epoch 1/1, Loss: 0.0027459857519716024
Epoch 1/1, Loss: 0.002673782641068101
Epoch 1/1, Loss: 0.0016505284002050757
Test Loader Accuracy: 83.14%
Forget Loader Accuracy: 0.00%
Retain Loader Accuracy: 98.86%
