In [None]:
import copy
import pandas as pd
import torch
import torchvision
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from sklearn import linear_model, model_selection

from PyTorch_CIFAR10.cifar10_models.resnet import resnet18


# Compute device shared by every cell: the GPU when CUDA is usable, the CPU otherwise.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
def load_cifar10(root="/DATA/data", download=False):
    """Build the CIFAR-10 train and test datasets.

    Args:
      root: directory that holds the CIFAR-10 files. Defaults to the
        location this notebook has always read from.
      download: forwarded to torchvision's CIFAR10 dataset; the default
        (False) matches the value the dataset used when it was not passed.
    Returns:
      (train, test): the training split with random-crop and
        horizontal-flip augmentation, and the test split without
        augmentation. Both are converted to tensors and normalised
        per channel.
    """
    # One Normalize instance is shared by both pipelines so the statistics
    # cannot drift apart.
    # NOTE(review): confirm these mean/std values match the ones the
    # pretrained checkpoint was trained with.
    normalize = transforms.Normalize(
        (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
    )

    default_cifar10_train_transform = transforms.Compose(
        [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]
    )
    default_cifar10_eval_transform = transforms.Compose(
        [
            transforms.ToTensor(),
            normalize,
        ]
    )

    train = torchvision.datasets.CIFAR10(
        root=root,
        train=True,
        transform=default_cifar10_train_transform,
        download=download,
    )
    test = torchvision.datasets.CIFAR10(
        root=root,
        train=False,
        transform=default_cifar10_eval_transform,
        download=download,
    )
    return train, test


@torch.no_grad()
def evaluate(loader, model, device=None):
    """Compute the top-1 accuracy of ``model`` over ``loader``.

    Args:
      loader: iterable yielding ``(inputs, targets)`` batches.
      model: network to evaluate; it is switched to eval mode.
      device: device the batches are moved to. When None, the device of the
        model's first parameter is used, falling back to the notebook-wide
        ``DEVICE`` for a model without parameters.
    Returns:
      0-dim float tensor with the fraction of correctly classified samples
      (NaN when the loader yields no samples).
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = DEVICE if first_param is None else first_param.device

    model.eval()
    total = 0
    # Accumulate in a tensor so the return type is a tensor even when the
    # loader is empty (callers use ``acc.cpu()``).
    correct = torch.zeros((), device=device)
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        pred = model(x).argmax(1)
        total += len(x)
        correct += (pred == y).float().sum()

    return correct / total
    

In [None]:
def finetune(
    net, 
    loader, 
    epochs=10,
    weight_decay=5e-4,
    lr=0.001,
    momentum=0.,
    use_scheduler=True,
    ):
    """Fine-tune ``net`` in place with SGD and a cross-entropy loss.

    Args:
      net: model to train; it is modified in place and left in eval mode.
      loader: iterable of ``(inputs, targets)`` batches with a length.
      epochs: number of passes over ``loader``.
      weight_decay, lr, momentum: forwarded to ``optim.SGD``.
      use_scheduler: when True, the learning rate follows a cosine
        annealing schedule stepped once per batch, with ``T_max`` equal to
        the total number of batches.
    """
    n_steps = epochs * len(loader)
    loss_fn = nn.CrossEntropyLoss()
    sgd = optim.SGD(
        net.parameters(),
        lr=lr,
        momentum=momentum,
        weight_decay=weight_decay,
    )

    lr_schedule = None
    if use_scheduler:
        lr_schedule = torch.optim.lr_scheduler.CosineAnnealingLR(sgd, T_max=n_steps)

    net.train()
    for _ in range(epochs):
        net.train()
        for batch_x, batch_y in loader:
            batch_x = batch_x.to(DEVICE)
            batch_y = batch_y.to(DEVICE)

            sgd.zero_grad()
            batch_loss = loss_fn(net(batch_x), batch_y)
            batch_loss.backward()
            sgd.step()

            # The schedule advances per optimisation step, not per epoch.
            if lr_schedule is not None:
                lr_schedule.step()

    net.eval()

In [None]:
# Model arithmetic

def get_flat(model):
    """Concatenate every entry of ``model.state_dict()`` into one 1-D tensor.

    Entries are flattened in state-dict order, so buffers stored in the
    state dict are included alongside the parameters. ``torch.cat`` builds a
    new tensor, so the result does not alias the model's weights.

    Args:
      model: module whose state dict is flattened.
    Returns:
      1-D tensor holding all state-dict values back to back.
    """
    # reshape (unlike view) also accepts non-contiguous tensors, e.g.
    # transposed or channels-last weights.
    return torch.cat([t.reshape(-1) for t in model.state_dict().values()])

def from_flat(flat_params, model, in_place=False):
    """Load a flat vector (as produced by ``get_flat``) into a model.

    Args:
      flat_params: tensor whose elements, read in order, fill the entries of
        ``model.state_dict()`` one after the other.
      model: module providing the state-dict layout (names, shapes, dtypes).
      in_place: when False (default) a deep copy of ``model`` receives the
        values and ``model`` is left untouched; when True ``model`` itself
        is overwritten.
    Returns:
      The module that received the values (the copy, or ``model`` itself).
    Raises:
      ValueError: if ``flat_params`` does not hold exactly as many elements
        as the state dict.
    """
    flat = flat_params.reshape(-1)
    expected = sum(t.numel() for t in model.state_dict().values())
    if flat.numel() != expected:
        raise ValueError(
            f"flat_params has {flat.numel()} elements but the model state dict "
            f"holds {expected}"
        )

    if not in_place:
        model = copy.deepcopy(model)

    sd = model.state_dict()
    index = 0
    for k, p in list(sd.items()):
        n = p.numel()
        # Cast every chunk to the entry's dtype so integer buffers stored in
        # the state dict keep their type.
        sd[k] = flat[index : index + n].reshape(p.shape).to(dtype=p.dtype)
        index += n
    model.load_state_dict(sd)
    return model

def add_random_noise(model, strength=0.2):
    """Perturb ``model``'s parameters in place with freshly initialised weights.

    A deep copy of the model has ``reset_parameters()`` called on every
    sub-module that provides it; each parameter of ``model`` then receives
    ``strength`` times the matching parameter of that copy.

    Args:
      model: module whose parameters are modified in place.
      strength: scale applied to the re-initialised parameters before they
        are added.
    """
    reference = copy.deepcopy(model)
    for layer in reference.modules():
        if hasattr(layer, "reset_parameters"):
            layer.reset_parameters()

    paired_params = zip(model.parameters(), reference.parameters())
    for target, fresh in paired_params:
        target.data += strength * fresh


def model_op(net1, net2, operator=lambda x, y: x + y):
    """Combine two models entry by entry and return the result as a new model.

    Args:
      net1, net2: modules whose state dicts are walked in parallel.
      operator: callable applied to each pair of state-dict tensors;
        defaults to element-wise addition.
    Returns:
      A deep copy of ``net1`` loaded with the combined state dict (keys are
      taken from ``net1``). Neither input is modified.
    """
    result = copy.deepcopy(net1)
    left, right = net1.state_dict(), net2.state_dict()
    combined = {
        key: operator(a, b)
        for (key, a), b in zip(left.items(), right.values())
    }
    result.load_state_dict(combined)
    return result

def arithmetic_unlearning(model, model_ft_forget, power=0.1):
    """Step away from the forget-set task vector.

    The task vector is the flattened state of ``model_ft_forget`` minus the
    flattened state of ``model``; the returned model sits at
    ``model - power * task_vector``.

    Args:
      model: reference model (not modified).
      model_ft_forget: model with the same layout, fine-tuned on the forget set.
      power: scale of the step taken against the task vector.
    Returns:
      A new model built from ``model``'s layout with the shifted weights.
    """
    base = get_flat(model)
    forget_direction = get_flat(model_ft_forget) - base
    shifted = base - power * forget_direction
    return from_flat(shifted, model)

def arithmetic_unlearning_2(model, model_ft_forget, model_ft_retain, power=0.1):
    """Move along the retain task vector with its forget component removed.

    Both task vectors are measured from ``model``. The retain vector is
    made orthogonal to the forget vector, and the returned model sits at
    ``model + power * (retain - proj_forget(retain))``.

    Args:
      model: reference model (not modified).
      model_ft_forget: same layout, fine-tuned on the forget set.
      model_ft_retain: same layout, fine-tuned on the retain set.
      power: scale of the step along the orthogonalised retain vector.
    Returns:
      A new model built from ``model``'s layout with the shifted weights.
    """
    flat_1 = get_flat(model)
    vect_retain = get_flat(model_ft_retain) - flat_1
    vect_forget = get_flat(model_ft_forget) - flat_1

    forget_sq_norm = torch.dot(vect_forget, vect_forget)
    if forget_sq_norm == 0:
        # A zero forget vector has no direction to remove; dividing by its
        # squared norm would turn every weight into NaN.
        po_vr_vf = torch.zeros_like(vect_retain)
    else:
        po_vr_vf = (torch.dot(vect_retain, vect_forget) / forget_sq_norm) * vect_forget

    return from_flat(flat_1 + power * (vect_retain - po_vr_vf), model)


@torch.no_grad()
def adapt_bs(model, loader, num_epochs=2):
    """Run gradient-free forward passes over ``loader`` in training mode.

    No loss is computed and no weights are optimised; only state that
    layers update during training-mode forward passes can change
    (presumably batch-norm running statistics — confirm against the model).
    The model is left in eval mode.

    Args:
      model: module to run; modified in place.
      loader: iterable of ``(inputs, targets)`` batches.
      num_epochs: number of passes over ``loader``.
    """
    model.train()
    for _ in range(num_epochs):
        for batch in loader:
            x, y = batch
            x, y = x.to(DEVICE), y.to(DEVICE)
            model(x)
    model.eval()
        

In [None]:
# Evaluation

@torch.no_grad()
def collect_losses(loader, model, device=None):
    """Compute the per-sample cross-entropy loss of ``model`` over ``loader``.

    Args:
      loader: iterable yielding ``(inputs, targets)`` batches.
      model: network to evaluate; it is switched to eval mode.
      device: device the batches are moved to. When None, the device of the
        model's first parameter is used; a model without parameters falls
        back to CUDA when available and to the CPU otherwise.
    Returns:
      1-D tensor with one loss value per sample, in loader order (empty
      when the loader yields no batches).
    """
    if device is None:
        first_param = next(model.parameters(), None)
        if first_param is not None:
            device = first_param.device
        else:
            device = "cuda" if torch.cuda.is_available() else "cpu"

    model.eval()
    all_losses = []
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        out = model(x)
        # reduction="none" keeps one loss per sample instead of the batch mean.
        losses = F.cross_entropy(out, y, reduction="none")
        all_losses.append(losses)

    if not all_losses:
        # torch.cat rejects an empty list; report "no samples" explicitly.
        return torch.empty(0, device=device)
    return torch.cat(all_losses)
        

def simple_mia(sample_loss, members, n_splits=10, random_state=0):
    """Score a loss-based membership inference attack by cross-validation.

    A logistic regression is trained to predict membership from the loss
    values and evaluated on stratified shuffle splits.

    Args:
      sample_loss : array_like with one row per sample,
        objective function evaluated on the n samples (the caller in this
        notebook passes shape (n, 1)).
      members : array_like of shape (n,),
        whether a sample was used for training.
      n_splits: int
        number of splits to use in the cross-validation.
      random_state: seed forwarded to the splitter.
    Returns:
      scores : array_like of size (n_splits,)
        accuracy of the attack on each split.
    """
    splitter = model_selection.StratifiedShuffleSplit(
        n_splits=n_splits, random_state=random_state
    )
    classifier = linear_model.LogisticRegression()
    scores = model_selection.cross_val_score(
        classifier,
        sample_loss,
        members,
        cv=splitter,
        scoring="accuracy",
    )
    return scores

def run_attack(forget_loader, test_loader, model_to_test):
    """Run the loss-based membership inference attack: forget set vs. test set.

    Per-sample losses are collected on both loaders, the larger group is
    randomly subsampled to the size of the smaller one, and ``simple_mia``
    is cross-validated on the result (label 1 = forget sample, label 0 =
    test sample). The mean accuracy is printed and returned.

    Args:
      forget_loader: loader over the samples that should be forgotten.
      test_loader: loader over samples the model never trained on.
      model_to_test: model under attack.
    Returns:
      Mean cross-validation accuracy of the attack.
    """
    # Evaluate on the notebook-wide DEVICE instead of relying on
    # collect_losses' own default.
    ft_forget_losses = collect_losses(forget_loader, model_to_test, device=DEVICE).cpu().numpy()
    ft_test_losses = collect_losses(test_loader, model_to_test, device=DEVICE).cpu().numpy()

    # Subsampling to have class balanced (member, non member)

    if len(ft_forget_losses) > len(ft_test_losses):
        np.random.shuffle(ft_forget_losses)
        ft_forget_losses = ft_forget_losses[:len(ft_test_losses)]
    else:
        np.random.shuffle(ft_test_losses)
        ft_test_losses = ft_test_losses[:len(ft_forget_losses)]

    # scikit-learn expects a 2-D feature matrix: one column holding the loss.
    samples_mia_ft = np.concatenate((ft_test_losses, ft_forget_losses)).reshape((-1, 1))
    labels_mia = [0] * len(ft_test_losses) + [1] * len(ft_forget_losses)

    mia_scores_ft = simple_mia(samples_mia_ft, labels_mia)

    print(
        f"The MIA attack has an accuracy of {mia_scores_ft.mean():.3f} on forgotten vs unseen images"
    )

    return mia_scores_ft.mean()

In [None]:
def finetune_project(
    net,
    retain_loader,
    vect_forget,
    epochs=10,
    weight_decay=5e-4,
    lr=0.001,
    momentum=0.,
    use_scheduler=True,
    ):
    """Fine-tune a copy of ``net`` while removing the forget direction after each step.

    After every SGD step, the displacement of the flattened state from its
    starting point is projected onto the subspace orthogonal to
    ``vect_forget`` and written back into the model.

    Args:
      net: starting model; it is deep-copied and not modified.
      retain_loader: iterable of ``(inputs, targets)`` batches with a length.
      vect_forget: 1-D tensor laid out like ``get_flat(net)`` giving the
        direction to remove.
      epochs: number of passes over ``retain_loader``.
      weight_decay, lr, momentum: forwarded to ``optim.SGD``.
      use_scheduler: when True, a cosine annealing schedule is stepped once
        per batch with ``T_max`` equal to the total number of batches.
    Returns:
      The fine-tuned copy, in eval mode.
    """

    total_iter_number = len(retain_loader) * epochs

    new_net = copy.deepcopy(net)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(new_net.parameters(), lr=lr,
                      momentum=momentum, weight_decay=weight_decay)

    if use_scheduler:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_iter_number)

    flat_initial = get_flat(net)

    # The squared norm of the forget direction never changes, so compute it
    # once instead of on every step.
    forget_sq_norm = torch.dot(vect_forget, vect_forget)
    # A zero vector has no direction to remove; projecting onto it would
    # divide by zero and fill the weights with NaN.
    has_forget_direction = bool(forget_sq_norm != 0)

    new_net.train()

    for ep in range(epochs):
        new_net.train()
        for inputs, targets in retain_loader:
            inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)

            optimizer.zero_grad()
            outputs = new_net(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            # Remove forget component of the update
            if has_forget_direction:
                flat_current = get_flat(new_net) - flat_initial
                po_vr_vf = (torch.dot(flat_current, vect_forget) / forget_sq_norm) * vect_forget

                flat_current = flat_initial + flat_current - po_vr_vf
                from_flat(flat_current, new_net, in_place=True)

            if use_scheduler:
                scheduler.step()

    new_net.eval()
    return new_net

In [None]:
# Load the pretrained model, split the training set into forget / retain
# subsets and run the membership inference attack on the untouched model.

model = resnet18(pretrained=True)
model = model.to(DEVICE)
train, test = load_cifar10()

# Draw the split from a dedicated, seeded generator so that
# Restart & Run All always yields the same forget / retain partition.
SPLIT_SEED = 0
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
shuffled_indices = torch.randperm(len(train), generator=split_generator)

LEN_FORGET = 2000
LEN_RETAIN = len(train) - LEN_FORGET

# Hand Subset plain Python ints rather than 0-dim tensors.
forget_set = torch.utils.data.Subset(train, shuffled_indices[:LEN_FORGET].tolist())
retain_set = torch.utils.data.Subset(train, shuffled_indices[LEN_FORGET:].tolist())

forget_loader = torch.utils.data.DataLoader(forget_set, batch_size=64, shuffle=True)
retain_loader = torch.utils.data.DataLoader(retain_set, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test, batch_size=64, shuffle=True)

run_attack(forget_loader, test_loader, model)

In [None]:
# Method 1: perturb the weights slightly, then fine-tune on the retain set only

model_unlearned = copy.deepcopy(model)
add_random_noise(model_unlearned, strength=0.05)
finetune(
    model_unlearned,
    retain_loader,
    epochs=1,
    lr=0.001,
    momentum=0.9,
)

# Membership inference first, then test accuracy.
run_attack(forget_loader, test_loader, model_unlearned)
acc = evaluate(test_loader, model_unlearned)
print("Test Accuracy", acc.cpu().item())

In [None]:
# Method 2: fine-tune on the forget set, then subtract lambda * forget_set_task_vector from the original weights

model_ft_forget = copy.deepcopy(model)
finetune(model_ft_forget, forget_loader, epochs=50)

# Sweep the strength of the subtraction (lambda).
for power in np.linspace(0, 1.0, num=20):
    print(power)
    model_unlearned = arithmetic_unlearning(model, model_ft_forget, power)
    run_attack(forget_loader, test_loader, model_unlearned)
    acc = evaluate(test_loader, model_unlearned)
    print("Test Accuracy", acc.cpu().item())


In [None]:
# Method 3: fine-tune separately on the retain and the forget set, then move towards
# model + retain_vector - proj(retain_vector, forget_vector)

model_ft_forget = copy.deepcopy(model)
finetune(model_ft_forget, forget_loader, epochs=50)

model_ft_retain = copy.deepcopy(model)
finetune(model_ft_retain, retain_loader, epochs=10)

# Sweep the step size along the orthogonalised retain direction.
for power in np.linspace(0, 0.5, num=20):
    print(power)
    model_unlearned = arithmetic_unlearning_2(
        model, model_ft_forget, model_ft_retain, power
    )
    run_attack(forget_loader, test_loader, model_unlearned)
    acc = evaluate(test_loader, model_unlearned)
    print("Test Accuracy", acc.cpu().item())


In [None]:
# Check linear connectivity: test accuracy along the straight line in weight
# space from the original model to the retain-fine-tuned model.

import matplotlib.pyplot as plt

accuracies = []

model_ft_retain = copy.deepcopy(model)
finetune(model_ft_retain, retain_loader, epochs=10)

flat_1 = get_flat(model)
flat_2 = get_flat(model_ft_retain)
vect = flat_2 - flat_1

# power = 0 is the original model, power = 1 the fine-tuned one.
powers = np.linspace(0, 1., num=20)

for power in powers:
    model_inter = from_flat(flat_1 + power * vect, model)
    # Optional step, currently disabled: adapt_bs(model_inter, forget_loader)
    acc = evaluate(test_loader, model_inter)
    accuracies.append(acc.cpu().item())

plt.plot(powers, accuracies)

In [None]:
# Number of mini-batches in one pass over the retain loader.
print(len(retain_loader))

In [None]:
# Method 4 (TODO, not implemented yet): fine-tune on the forget set with an additional output logit and relabel the forget-set targets to that logit

In [None]:
# Method 5: fine-tune on the retain set, but remove the forget component from every update
# (project onto the subspace orthogonal to the forget vector after each step)

model_ft_forget = copy.deepcopy(model)
finetune(model_ft_forget, forget_loader, epochs=10, momentum=0.)

vect_forget = get_flat(model_ft_forget) - get_flat(model)

model_unlearned = finetune_project(
    model,
    retain_loader,
    vect_forget,
    epochs=2,
    weight_decay=5e-4,
    lr=0.001,
    momentum=0.9,
    use_scheduler=True,
    )

run_attack(forget_loader, test_loader, model_unlearned)
acc = evaluate(test_loader, model_unlearned)
# Print a plain float, as the other method cells do, rather than a tensor repr.
print("Test Accuracy: ", float(acc.cpu()))

In [None]:
# Compare the test accuracy of the original and the unlearned model.
# evaluate() is called with the two arguments its signature has always
# accepted; passing an extra device keyword raised a TypeError.
acc_model = evaluate(test_loader, model)
acc_unlearned = evaluate(test_loader, model_unlearned)

print(f"Accuracy of base model: {acc_model}")
print(f"Accuracy of unlearned model: {acc_unlearned}")