## Setup

In [3]:
import os
from utils import *
from agents import *
import time
import torch
import torch.nn as nn
import copy
import torch.nn.functional as F
from copy import deepcopy
import argparse
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from models import resnet18
import random
import timm
import math
from ov_utils import *
from proto_utils import *

In [2]:
parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', type=int, default=128)
parser.add_argument('--dataset', type=str, default='cifar10', choices=['cifar10', 'cifar100'])
parser.add_argument('--model', type=str, default='resnet18')
args = parser.parse_args("")  # empty argv: always use the defaults inside the notebook


exp_path = f'checkpoints_{args.dataset}'
device = 'cuda'
seed_everything(42)  # Choose the same seed as the baseline trainer

# Per-dataset settings. The normalisation statistics presumably have to match the ones the
# checkpoints in `exp_path` were trained with (CIFAR-10 here uses the ImageNet mean/std)
# -- TODO confirm against the baseline trainer.
DATASET_CONFIG = {
    'cifar100': {
        'dataset_cls': torchvision.datasets.CIFAR100,
        'num_classes': 100,
        'mean': (0.5071, 0.4867, 0.4408),
        'std': (0.2675, 0.2565, 0.2761),
    },
    'cifar10': {
        'dataset_cls': torchvision.datasets.CIFAR10,
        'num_classes': 10,
        'mean': [0.485, 0.456, 0.406],
        'std': [0.229, 0.224, 0.225],
    },
}
if args.dataset not in DATASET_CONFIG:
    # Fail fast instead of hitting a NameError on `trainset`/`num_classes` further down.
    raise ValueError(f"Unsupported dataset {args.dataset!r}; expected one of {sorted(DATASET_CONFIG)}")
dataset_cfg = DATASET_CONFIG[args.dataset]

# No augmentation: train and test share the same deterministic preprocessing.
transform_train = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=dataset_cfg['mean'], std=dataset_cfg['std']),
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=dataset_cfg['mean'], std=dataset_cfg['std']),
])

trainset = dataset_cfg['dataset_cls'](root='./data', train=True, download=True, transform=transform_train)
testset = dataset_cfg['dataset_cls'](root='./data', train=False, download=True, transform=transform_test)


def make_loader(dataset, shuffle):
    """Build a DataLoader with the notebook-wide batch size and worker settings."""
    return DataLoader(dataset, batch_size=args.batch_size, shuffle=shuffle, num_workers=4, pin_memory=True)


train_loader = make_loader(trainset, shuffle=True)
test_loader = make_loader(testset, shuffle=False)

# Not referenced by the cells shown in this notebook; `unlearn_classes` below is fixed instead.
num_unlearn_classes = 1 if args.dataset == 'cifar10' else 10

num_classes = dataset_cfg['num_classes']
all_classes = list(range(num_classes))

unlearn_classes = [0]  # Use the same unlearn classes as the excluded labels
remain_classes = [cls for cls in all_classes if cls not in unlearn_classes]
print("unlearn classes:", unlearn_classes)
args.unlearn_class = unlearn_classes

attack_subset, unlearn_attack_loader, remain_attack_loader, counts = create_attack_loaders(
    trainset, 100, unlearn_classes, remain_classes, args.batch_size, shuffle=True, seed=42
)

# Split train/test sets by label into forget ("unlearn") and retain ("remain") subsets.
unlearn_class_set = set(unlearn_classes)
unlearn_indices = [i for i, target in enumerate(trainset.targets) if target in unlearn_class_set]
remain_indices = [i for i, target in enumerate(trainset.targets) if target not in unlearn_class_set]
test_unlearn_indices = [i for i, target in enumerate(testset.targets) if target in unlearn_class_set]
test_remain_indices = [i for i, target in enumerate(testset.targets) if target not in unlearn_class_set]

unlearn_trainset = Subset(trainset, unlearn_indices)
remain_trainset = Subset(trainset, remain_indices)
unlearn_testset = Subset(testset, test_unlearn_indices)
remain_testset = Subset(testset, test_remain_indices)

unlearn_train_loader = make_loader(unlearn_trainset, shuffle=True)
remain_train_loader = make_loader(remain_trainset, shuffle=True)
# Evaluation-only loaders: no shuffling, consistent with `test_loader` above.
unlearn_test_loader = make_loader(unlearn_testset, shuffle=False)
remain_test_loader = make_loader(remain_testset, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified
unlearn classes: [0]
Attack subset size: 50000
  → Unlearn class samples: 5000
  → Remain class samples: 45000


In [3]:
# Original Model Load
# The model trained on all classes. Later cells use it as the teacher (Spotter distillation)
# and as the reference model for the OU@epsilon comparisons.
# map_location lets the checkpoint load regardless of the device it was saved from.
model = resnet18(num_classes=num_classes)
model.load_state_dict(torch.load(f'{exp_path}/resnet18_{args.dataset}_best.pth', map_location=device))
model = model.to(device)
criterion = nn.CrossEntropyLoss()



## Perturbed Set Generation

In [4]:
# Perturbed copies of the forget-class training samples. The *_test_soft_* loaders feed the
# OU@epsilon measurements; adv_train_loader is consumed by the Spotter training loop.
# NOTE(review): adv_test_soft_loader and adv_train_loader are built with identical arguments
# from the same unlearn_train_loader, so the PGD set used for OU@epsilon comes from the very
# images Spotter trains on (the individual perturbations may still differ if PGD uses a random
# start). If a held-out measurement is intended, build the evaluation sets from
# unlearn_test_loader instead.
PGD_KWARGS = dict(epsilon=0.03, step_size=0.01, num_steps=3, num_adv_per_sample=1)

adv_test_soft_loader = create_adversarial_loader_pgd_multi(
    model,
    unlearn_train_loader,
    device=device,
    **PGD_KWARGS,
)

gaussian_test_soft_loader = create_adversarial_loader_gaussian(
    model,
    unlearn_train_loader,
    noise_std=0.01,
    num_adv_per_sample=1,
    device=device,  # previously hard-coded 'cuda'; follow the notebook-wide setting
)

adv_train_loader = create_adversarial_loader_pgd_multi(
    model,
    unlearn_train_loader,
    device=device,
    **PGD_KWARGS,
)

## Original-Eval

In [5]:
ul = 'Original'
model.eval()
evaluate(model, unlearn_train_loader, "Forget Acc.", device, criterion)
evaluate(model, remain_train_loader, "Retain Acc.", device, criterion)
evaluate(model, unlearn_test_loader, "Test-Forget Acc.", device, criterion)
evaluate(model, remain_test_loader, "Test-Retain Acc.", device, criterion)
models_Fu = {ul: model}

# Prototypical Relearning Attack
# Attack a copy. Later cells reuse `model` as the teacher/reference (Spotter distillation,
# OU@epsilon comparisons), and this notebook cannot show whether update_fc_with_prototypes
# modifies its input in place (see proto_utils). Copying keeps `model` untouched either way.
eval_model = update_fc_with_prototypes(
    copy.deepcopy(model),
    unlearn_attack_loader,
    unlearn_classes,
    num_samples_per_class=5,
    metric='cosine',
    device=device,
)

# Labels match the post-attack evaluations in the other method cells.
evaluate(eval_model, unlearn_train_loader, "Proto-Forget Acc.", device, criterion)
evaluate(eval_model, remain_train_loader, "Retain^* Acc.", device, criterion)

print(compute_dispersion_loss(eval_model, unlearn_train_loader))

[Forget Acc.] Loss: 0.001, Accuracy: 100.00% (5000/5000)
[Retain Acc.] Loss: 0.001, Accuracy: 100.00% (45000/45000)
[Test-Forget Acc.] Loss: 0.179, Accuracy: 94.70% (947/1000)
[Test-Retain Acc.] Loss: 0.238, Accuracy: 93.98% (8458/9000)
{'Original': 0.6213402444839478}
{'Original': 0.5742272691726684}
[Forget Acc.] Loss: 0.001, Accuracy: 100.00% (5000/5000)
[Retain Acc.] Loss: 0.001, Accuracy: 100.00% (45000/45000)
tensor(0.9177, device='cuda:0')


## Retrain-Eval

In [15]:
ul = 'retrain'
# Retrained baseline (its Forget Acc. is 0% in the recorded output below).
# map_location lets the checkpoint load on whatever `device` is configured.
eval_model = resnet18(num_classes=num_classes).to(device)
eval_model.load_state_dict(torch.load(f'{exp_path}/resnet18_{args.dataset}_{ul}_best.pth', map_location=device))

evaluate(eval_model, unlearn_train_loader, "Forget Acc.", device, criterion)
evaluate(eval_model, remain_train_loader, "Retain Acc.", device, criterion)
evaluate(eval_model, unlearn_test_loader, "Test-Forget Acc.", device, criterion)
evaluate(eval_model, remain_test_loader, "Test-Retain Acc.", device, criterion)
models_Fu = {ul: eval_model}

# OU@epsilon: compare against the original `model` on each perturbed forget set
# ('JS' presumably selects a Jensen-Shannon divergence -- see compare_unlearning_methods).
for perturbation_name, perturbed_loader in (("PGD", adv_test_soft_loader),
                                            ("Gaussian", gaussian_test_soft_loader)):
    overunlearning_scores = compare_unlearning_methods(model, models_Fu, perturbed_loader, unlearn_classes, 'JS', device=device)
    print(f"{perturbation_name} OU@epsilon :", overunlearning_scores[ul])

# Prototypical Relearning Attack
# NOTE(review): the attack runs with alpha=1 here, alpha=0.8 for Spotter, and the default for
# the original model; Proto-Forget Acc. is only comparable across methods if alpha matches.
eval_model = update_fc_with_prototypes(
    eval_model,
    unlearn_attack_loader,
    unlearn_classes,
    num_samples_per_class=5,
    metric='cosine',
    device=device,
    alpha=1,
)

evaluate(eval_model, unlearn_train_loader, "Proto-Forget Acc.", device, criterion)
evaluate(eval_model, remain_train_loader, "Retain^* Acc.", device, criterion)

print(compute_dispersion_loss(eval_model, unlearn_train_loader))

[Forget Acc.] Loss: 9.394, Accuracy: 0.00% (0/5000)
[Retain Acc.] Loss: 0.001, Accuracy: 100.00% (45000/45000)
[Test-Forget Acc.] Loss: 9.547, Accuracy: 0.00% (0/1000)
[Test-Retain Acc.] Loss: 0.218, Accuracy: 94.71% (8524/9000)
PGD OU@epsilon : 0.23585119132995605
Gaussian OU@epsilon : 0.17794286270141602
[Proto-Forget Acc.] Loss: 1.081, Accuracy: 56.02% (2801/5000)
[Retain^* Acc.] Loss: 0.016, Accuracy: 99.91% (44961/45000)
tensor(0.6909, device='cuda:0')


## Spotter

In [14]:
ul = 'Spotter'
# The original model is the frozen teacher; the student starts as an exact copy of it.
model.eval()
mu_model = copy.deepcopy(model)

optimizer = torch.optim.SGD(mu_model.parameters(), lr=1e-4, momentum=0.9, weight_decay=5e-4)
num_epochs = 10
lambda_1 = 0.7  # weight of the clean-sample distillation term; perturbed samples get (1 - lambda_1)
lambda_2 = 1    # weight of the dispersion term


def freeze_batchnorm(net):
    """Switch every BatchNorm1d/BatchNorm2d layer of `net` to eval mode."""
    for layer in net.modules():
        if isinstance(layer, (nn.BatchNorm2d, nn.BatchNorm1d)):
            layer.eval()


for epoch in range(num_epochs):
    # train() flips BatchNorm back to training mode, so the freeze is re-applied each epoch.
    mu_model.train()
    freeze_batchnorm(mu_model)

    # Per-epoch sums over batches (not averages).
    total_loss, ul_loss, adv_loss, disp_loss = 0.0, 0.0, 0.0, 0.0

    # Each clean forget-class batch is paired with a batch from the PGD-perturbed forget set.
    for (images, labels), (adv_images, adv_labels) in zip(unlearn_train_loader, adv_train_loader):
        images = images.to(device)
        adv_images = adv_images.to(device)
        labels = labels.to(device)  # labels of the class(es) to be forgotten
        adv_labels = adv_labels.to(device)

        optimizer.zero_grad()

        # Teacher targets \tilde{p}(x): the original model's softmax passed through
        # mask_full_target_soft (presumably removes the forget-class mass -- see utils).
        with torch.no_grad():
            clean_probs_t = F.softmax(model(images), dim=1)
            adv_probs_t = F.softmax(model(adv_images), dim=1)
            clean_target = mask_full_target_soft(clean_probs_t, unlearn_classes)
            adv_target = mask_full_target_soft(adv_probs_t, unlearn_classes)

        # Student (mu_model) log-probabilities on clean and perturbed inputs.
        student_logp = F.log_softmax(mu_model(images), dim=1)
        adv_student_logp = F.log_softmax(mu_model(adv_images), dim=1)

        loss_unlearn = F.kl_div(student_logp, clean_target, reduction='batchmean')
        adv_loss_unlearn = F.kl_div(adv_student_logp, adv_target, reduction='batchmean')

        # Dispersion term on the student's features of the clean forget samples.
        loss_disp = dispersion_loss(get_features(mu_model, images), labels, metric='cosine')

        loss = lambda_1 * loss_unlearn + (1 - lambda_1) * adv_loss_unlearn + lambda_2 * loss_disp
        loss.backward()
        optimizer.step()

        ul_loss += loss_unlearn.item()
        adv_loss += adv_loss_unlearn.item()
        disp_loss += loss_disp.item()
        total_loss += loss.item()

    print(f"Epoch {epoch+1}/{num_epochs} - Total Loss: {total_loss:.4f} - UL Loss: {ul_loss:.4f} - Adv Loss: {adv_loss:.4f} - Disp Loss: {disp_loss:.4f}")

torch.save(mu_model.state_dict(), f"{exp_path}/{args.model}_{args.dataset}_{ul}.pth")

Epoch 1/10 - Total Loss: 64.7059 - UL Loss: 30.9605 - Adv Loss: 33.1906 - Disp Loss: 33.0763
Epoch 2/10 - Total Loss: 34.4623 - UL Loss: 10.3102 - Adv Loss: 9.8214 - Disp Loss: 24.2988
Epoch 3/10 - Total Loss: 25.3921 - UL Loss: 10.0319 - Adv Loss: 7.4803 - Disp Loss: 16.1257
Epoch 4/10 - Total Loss: 21.1686 - UL Loss: 9.4888 - Adv Loss: 6.1974 - Disp Loss: 12.6672
Epoch 5/10 - Total Loss: 18.9043 - UL Loss: 9.0272 - Adv Loss: 5.5101 - Disp Loss: 10.9322
Epoch 6/10 - Total Loss: 17.3590 - UL Loss: 8.5898 - Adv Loss: 5.0265 - Disp Loss: 9.8383
Epoch 7/10 - Total Loss: 16.3054 - UL Loss: 8.2322 - Adv Loss: 4.7764 - Disp Loss: 9.1100
Epoch 8/10 - Total Loss: 15.4909 - UL Loss: 7.9031 - Adv Loss: 4.5481 - Disp Loss: 8.5943
Epoch 9/10 - Total Loss: 14.9105 - UL Loss: 7.7533 - Adv Loss: 4.4172 - Disp Loss: 8.1580
Epoch 10/10 - Total Loss: 14.3139 - UL Loss: 7.4897 - Adv Loss: 4.2707 - Disp Loss: 7.7899


In [18]:
ul = 'Spotter'
# Reload the Spotter checkpoint written by the training cell.
# map_location lets it load on whatever `device` is configured.
eval_model = resnet18(num_classes=num_classes).to(device)
eval_model.load_state_dict(torch.load(f"{exp_path}/{args.model}_{args.dataset}_{ul}.pth", map_location=device))
evaluate(eval_model, unlearn_train_loader, "Forget Acc.", device, criterion)
evaluate(eval_model, remain_train_loader, "Retain Acc.", device, criterion)
evaluate(eval_model, unlearn_test_loader, "Test-Forget Acc.", device, criterion)
evaluate(eval_model, remain_test_loader, "Test-Retain Acc.", device, criterion)
models_Fu = {ul: eval_model}

# OU@epsilon: compare against the original `model` on each perturbed forget set
# ('JS' presumably selects a Jensen-Shannon divergence -- see compare_unlearning_methods).
for perturbation_name, perturbed_loader in (("PGD", adv_test_soft_loader),
                                            ("Gaussian", gaussian_test_soft_loader)):
    overunlearning_scores = compare_unlearning_methods(model, models_Fu, perturbed_loader, unlearn_classes, 'JS', device=device)
    print(f"{perturbation_name} OU@epsilon :", overunlearning_scores[ul])

# Prototypical Relearning Attack
# NOTE(review): alpha=0.8 here but alpha=1 for the retrain baseline and the default for the
# original model; Proto-Forget Acc. is only comparable across methods if alpha matches.
eval_model = update_fc_with_prototypes(
    eval_model,
    unlearn_attack_loader,
    unlearn_classes,
    num_samples_per_class=5,
    metric='cosine',
    device=device,
    alpha=0.8,
)

evaluate(eval_model, unlearn_train_loader, "Proto-Forget Acc.", device, criterion)
evaluate(eval_model, remain_train_loader, "Retain^* Acc.", device, criterion)

print(compute_dispersion_loss(eval_model, unlearn_train_loader))

[Forget Acc.] Loss: 3.103, Accuracy: 0.00% (0/5000)
[Retain Acc.] Loss: 0.009, Accuracy: 99.98% (44989/45000)
[Test-Forget Acc.] Loss: 3.551, Accuracy: 0.00% (0/1000)
[Test-Retain Acc.] Loss: 0.210, Accuracy: 93.82% (8444/9000)
PGD OU@epsilon : 0.02864329786300659
Gaussian OU@epsilon : 0.026672714617848398
[Proto-Forget Acc.] Loss: 2.253, Accuracy: 0.10% (5/5000)
[Retain^* Acc.] Loss: 0.013, Accuracy: 99.98% (44989/45000)
tensor(0.1901, device='cuda:0')
