# Inverting gradient attack

In [None]:
from __future__ import annotations

from copy import deepcopy
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
import pandas as pd
import torch
from torch import nn, Tensor
from torch.nn import functional as F
from torch.nn.modules.loss import _Loss, CrossEntropyLoss
from torch.optim import Optimizer, SGD, Adam, AdamW
from torch.utils.data import Dataset, DataLoader, Subset, TensorDataset
from torchmetrics.classification import MulticlassAccuracy
from image_classification.utils import trange

In [None]:
# Run on the first CUDA GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(device)

## ResNet-18

For quick prototyping, also consider using `ShuffleNetV2`, a roughly 300k-parameter model that is much smaller than `ResNet18`. Experiments can be run on both models or on only one of them.

In [None]:
from image_classification.models import ResNet18, ShuffleNetV2
from image_classification.datasets import cifar10_train_test, cifar100_train_test
from image_classification.nn import train_loop, train_val_loop, test_epoch

In [None]:
# Set to 10 for CIFAR-10, 100 for CIFAR-100
num_classes = 10

# Pick the loader matching the requested number of classes.
# The images are already normalized by these datasets.
get_train_test = {10: cifar10_train_test, 100: cifar100_train_test}.get(num_classes)
if get_train_test is None:
    raise ValueError(f"Can't find CIFAR dataset with {num_classes} classes")
print(f"Loading CIFAR-{num_classes}")

training_data, test_data = get_train_test(root='data')

# A tenth of the original training set goes to validation, another tenth to
# the auxiliary set; the rest stays for training.
N_val = N_aux = len(training_data) // 10
N = len(training_data) - N_val - N_aux
N_test = len(test_data)
# Splitting contiguously works since training data is already shuffled
training_data, val_data, aux_data = training_data.split([N, N_val, N_aux])

batch_size = 100
# Display the split sizes as the cell output
N, N_val, N_aux, N_test

In [None]:
# One loader per split, all with the same batch size; `drop_last=True` means
# every batch has exactly `batch_size` samples.
train_loader, val_loader, aux_loader = (
    DataLoader(split, batch_size, drop_last=True)
    for split in (training_data, val_data, aux_data)
)

### Hyperparameters

In [None]:
# Optimizer settings; `make_optimizer` uses these as its default arguments.
lr, weight_decay = 1e-3, 5e-4
# For learning rate scheduling
max_lr = 0.1

epochs = 6
# Number of full batches in the training split
steps_per_epoch = N // batch_size

lr_sched_params = {
    'max_lr': max_lr,
    'epochs': epochs,
    'steps_per_epoch': steps_per_epoch,
}

criterion = CrossEntropyLoss()

# Top-1 accuracy on CIFAR-10, top-5 accuracy on CIFAR-100
top_k = {10: 1, 100: 5}[num_classes]
metric = MulticlassAccuracy(num_classes=num_classes, top_k=top_k)

### Optimizer

In [None]:
def make_optimizer(model: nn.Module, opt_name='adamw', lr=lr, weight_decay=weight_decay, **kwargs) -> Optimizer:
    """Build a torch optimizer over all parameters of `model`.

    Args:
        model: module whose `parameters()` are optimized.
        opt_name: one of `'sgd'`, `'adam'` or `'adamw'`.
        lr: learning rate; defaults to the notebook-level `lr` as it was when
            this function was defined.
        weight_decay: weight decay; defaults to the notebook-level
            `weight_decay` as it was when this function was defined.
        **kwargs: forwarded unchanged to the optimizer constructor.

    Raises:
        KeyError: if `opt_name` is not one of the supported names.
    """
    optimizer_classes = {'sgd': SGD, 'adam': Adam, 'adamw': AdamW}
    optimizer_cls = optimizer_classes[opt_name]
    params = model.parameters()
    return optimizer_cls(params, lr=lr, weight_decay=weight_decay, **kwargs)

## Inverting gradient attack

In [None]:
# Re-import the attack module so that edits made to its source file are picked
# up without restarting the kernel. Run this before the `from ... import`
# cell below so the imported names come from the reloaded module.
from importlib import reload
import image_classification.gradient_attack
reload(image_classification.gradient_attack)

In [None]:
from image_classification.gradient_attack import (
    GradientAttack,
    GradientEstimator, OmniscientGradientEstimator, ShadowGradientEstimator,
    SampleInit, SampleInitRandomNoise,
    GradientInverter,
    Schedule, NeverUpdate
)

In [None]:
# Pretraining run: a fresh ResNet-18 optimized with Adam (weight decay falls
# back to the `make_optimizer` default).
net = ResNet18(num_classes=num_classes).to(device)
opt = make_optimizer(net, opt_name='adam', lr=lr)
# Pretrain the model to make it learn the features
# The subset is the first `N_aux` indices of `training_data` (not `aux_data`).
mini_train_set = Subset(training_data, np.arange(N_aux))
mini_train_loader = DataLoader(mini_train_set, batch_size)
train_val_loop(
    net, mini_train_loader, val_loader,
    criterion, opt,
    epochs=10, # Overfit on the first `N_aux` examples of the training data
)
# NOTE(review): later cells rebind `net` to newly constructed models, so these
# pretrained weights do not appear to be reused -- confirm this is intended.

## Poisoning

In [None]:
from image_classification.datasets import UpdatableDataset
from image_classification.nn import MetricLogger

def train_epoch_with_poisons(
        model: nn.Module,
        dataloader: DataLoader,
        criterion: _Loss,
        optimizer: Optimizer,
        inverter: GradientInverter,
        alpha_poison=0.2,
        keep_pbars=True,
    ) -> tuple[UpdatableDataset, MetricLogger]:
    """Train `model` for one epoch, injecting one crafted poison per batch.

    For every clean batch the loss is backpropagated first, so the model holds
    the clean-batch gradients when `inverter.attack` is called. The returned
    poison sample is then forwarded as a batch of one, its loss is scaled by
    `alpha_poison`, and its gradients are accumulated on top of the clean ones
    before a single optimizer step.

    Reads the notebook-level `metric` and `device`.

    Args:
        model: network to train (switched to train mode).
        dataloader: yields `(inputs, targets)` batches of clean data.
        criterion: loss applied to both clean and poison samples.
        optimizer: optimizer stepping `model`'s parameters.
        inverter: attack object whose `attack(model, criterion)` returns one
            `(sample, label)` pair.
        alpha_poison: multiplier applied to the poison loss.
        keep_pbars: forwarded to `MetricLogger`.

    Returns:
        The poisons generated during the epoch and the metric logger. The
        logged metrics cover the clean batches only.
    """
    model.train()
    tracker = MetricLogger(
        metric,
        device=device,
        desc='Train loop',
        total=len(dataloader.dataset),
        keep_pbars=keep_pbars,
    )
    poisons = UpdatableDataset()

    for batch in dataloader:
        inputs, targets = (t.to(device) for t in batch)

        # Clean forward/backward pass
        clean_logits = model(inputs)
        # TODO: handle losses that don't reduce
        clean_loss = criterion(clean_logits, targets)
        # TODO: backpropagate on each loss element (and model.zero_grad() every time)
        clean_loss.backward()

        # --- poisoning attack
        poison_x, poison_y = inverter.attack(model, criterion)
        poisons.append(poison_x, poison_y)

        # The poison is a single sample: add a batch dimension to both tensors
        poison_logits = model(poison_x.unsqueeze(0))
        poison_loss = alpha_poison * criterion(poison_logits, poison_y.unsqueeze(0))
        # Gradient accumulation adds these gradients to the clean-batch ones
        poison_loss.backward()
        # ---

        optimizer.step()
        optimizer.zero_grad()

        # FIXME: does not include X_p, y_p, logits_p, loss_p
        # TODO: log loss on poisons
        # TODO: display some poisons
        tracker.compute_metrics(inputs, targets, clean_logits, clean_loss.item())

    tracker.finish()
    return poisons, tracker

In [None]:
def train_loop_with_poisons(
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        criterion: _Loss,
        optimizer: Optimizer,
        epochs: int,
        inverter: GradientInverter,
        alpha_poison=0.05,
        metric=metric,
    ) -> TensorDataset:
    """Run several poisoned training epochs, evaluating after each one.

    Args:
        model: network to train.
        train_loader: clean training batches.
        val_loader: batches evaluated with `test_epoch` after every epoch.
        criterion: loss used for training and evaluation.
        optimizer: optimizer stepping `model`'s parameters.
        epochs: number of training epochs.
        inverter: attack object forwarded to `train_epoch_with_poisons`.
        alpha_poison: multiplier applied to the poison loss.
        metric: metric passed to `test_epoch`. The training epoch itself logs
            with the notebook-level `metric`.

    Returns:
        Every poison generated over all epochs, as a `TensorDataset`.
    """
    all_poisons = UpdatableDataset()
    epoch_bar = trange(epochs, desc='Train epochs', unit='epoch', leave=True)
    for _ in epoch_bar:
        epoch_poisons, _logger = train_epoch_with_poisons(
            model=model,
            dataloader=train_loader,
            criterion=criterion,
            optimizer=optimizer,
            inverter=inverter,
            alpha_poison=alpha_poison,
        )
        all_poisons.extend(epoch_poisons)
        test_epoch(model, val_loader, criterion, keep_pbars=True, metric=metric)
    return all_poisons.to_tensor_dataset()

### Orthogonal Gradient inverting attack

In [None]:
# Attack configuration for the ORTHOGONAL objective.
# NOTE(review): the exact meaning of `steps`, `tv_coef` and `lr` is defined by
# `GradientInverter`, which is not shown here -- presumably the number of
# poison optimization steps, a total-variation penalty weight, and the poison
# optimizer's learning rate; confirm against the module.
estimator = OmniscientGradientEstimator()
# Poison initializer built from the auxiliary split
sample_init = SampleInitRandomNoise(aux_data)
inverter = GradientInverter(
    GradientAttack.ORTHOGONAL,
    estimator,
    steps=5,
    sample_init=sample_init,
    tv_coef=0.0,
    lr=0.3,
)

In [None]:
# NOTE(review): the next code cell rebinds `net` to a ResNet18, so this
# ShuffleNetV2 instance is only used if that line is removed. It is also
# built without `num_classes`, unlike the ResNet18 calls -- confirm that the
# default matches the chosen dataset.
net = ShuffleNetV2().to(device)

In [None]:
# SGD is more vulnerable to gradient attacks
# Fresh ResNet-18 trained with plain SGD (weight decay disabled) while the
# current `inverter` injects one poison per batch; the poison loss is scaled
# by `alpha_poison`. The returned poison dataset is not stored here.
net = ResNet18(num_classes=num_classes).to(device)
opt = make_optimizer(net, opt_name='sgd', lr=lr, weight_decay=0.0)
train_loop_with_poisons(
    net, train_loader, val_loader,
    criterion, opt,
    epochs,
    inverter, alpha_poison=0.2,
)

Training progress is slowed down considerably; however, accuracy does not drop as much as with Gradient Ascent.

### Gradient Ascent inverting attack

In [None]:
# Attack configuration for the ASCENT objective; apart from the attack type
# the settings are the same as for the orthogonal attack.
# This rebinds `estimator`, `sample_init` and `inverter`, so every later cell
# that uses `inverter` runs the ascent attack.
estimator = OmniscientGradientEstimator()
# Poison initializer built from the auxiliary split
sample_init = SampleInitRandomNoise(aux_data)
inverter = GradientInverter(
    GradientAttack.ASCENT,
    estimator,
    steps=5,
    sample_init=sample_init,
    tv_coef=0.0,
    lr=0.3,
)

In [None]:
# SGD is more vulnerable to gradient attacks
# Fresh ResNet-18 trained with plain SGD (weight decay disabled) while the
# current `inverter` injects one poison per batch; the poison loss is scaled
# by `alpha_poison`. The returned poison dataset is not stored here.
net = ResNet18(num_classes=num_classes).to(device)
opt = make_optimizer(net, opt_name='sgd', lr=lr, weight_decay=0.0)
train_loop_with_poisons(
    net, train_loader, val_loader,
    criterion, opt,
    epochs,
    inverter, alpha_poison=0.2,
)

TODO: test with more poison steps or different lr for poison optimizer

#### Lower poisoning rate

In [None]:
# Fresh, untrained ResNet-18 for the lower-poisoning-rate run
net = ResNet18(num_classes=num_classes).to(device)

In [None]:
# Same SGD setup as the previous run (no weight decay), but the poison loss is
# scaled by 0.05 instead of 0.2.
opt = make_optimizer(net, opt_name='sgd', lr=lr, weight_decay=0.0)
train_loop_with_poisons(
    net, train_loader, val_loader,
    criterion, opt,
    epochs,
    inverter, alpha_poison=0.05,
)

#### Using Adam optimizer for training

In [None]:
# Fresh, untrained ResNet-18 for the Adam run
net = ResNet18(num_classes=num_classes).to(device)

In [None]:
# Adam regularizes the parameters so it is more robust to gradient attacks
# Unlike the SGD runs above (weight_decay=0.0), `weight_decay` is not passed
# here, so it falls back to the `make_optimizer` default.
opt = make_optimizer(net, opt_name='adam', lr=lr)
train_loop_with_poisons(
    net, train_loader, val_loader,
    criterion, opt,
    epochs,
    inverter, alpha_poison=0.2,
);

Interestingly, test accuracy jumps from 40 % to 60 % in one epoch. Explanation?

## Machine unlearning

In [None]:
from image_classification.unlearning import (
    gradient_descent, gradient_ascent, neg_grad_plus, unlearning_last_layers, scrub
)

In [None]:
# Train the poisoned model that the unlearning experiments below start from.
# NOTE(review): `weight_decay` is not passed here, so the `make_optimizer`
# default applies, whereas the earlier SGD poisoning runs used 0.0 -- confirm
# that the difference is intended.
net = ResNet18(num_classes=num_classes).to(device)
opt = make_optimizer(net, opt_name='sgd', lr=lr)
# The poisons generated during training form the forget set
forget_set = train_loop_with_poisons(
    net, train_loader, val_loader,
    criterion, opt,
    epochs,
    inverter, alpha_poison=0.2,
)
forget_loader = DataLoader(forget_set, batch_size)

In [None]:
from enum import Enum, unique


@unique
class Unlearning(Enum):
    """Unlearning methods selectable in `unlearn`.

    `@unique` makes the class definition fail if two members are ever given
    the same value, so a newly added method cannot silently become an alias
    of an existing one.
    """
    GRADIENT_DESCENT = 0
    GRADIENT_ASCENT = 1
    # No branch for this method in `unlearn` yet
    NOISY_GRADIENT_DESCENT = 2
    NEG_GRAD_PLUS = 3
    # No branch for this method in `unlearn` yet
    CFK = 4
    EUK = 5
    SCRUB = 6

In [None]:
def unlearn(
        net: nn.Module,
        # the train loader is not poisoned (poisons are generated continuously)
        train_loader: DataLoader,
        forget_loader: DataLoader,
        criterion: _Loss,
        method: Unlearning,
    ):
    unlearner = deepcopy(net)
    
    match method:
        case Unlearning.GRADIENT_DESCENT:
            opt = make_optimizer(unlearner, opt_name='sgd', lr=lr)
            gradient_descent(
                unlearner, train_loader, val_loader,
                criterion, opt, epochs=1, keep_pbars=False
            )
        case Unlearning.GRADIENT_ASCENT:
            opt = make_optimizer(unlearner, opt_name='sgd', lr=1e-5)
            gradient_ascent(
                unlearner, train_loader, val_loader,
                criterion, opt, epochs=1, keep_pbars=False
            )
        case Unlearning.NEG_GRAD_PLUS:
            opt = make_optimizer(unlearner, opt_name='sgd', lr=lr)
            for epoch in trange(10, desc='NegGrad+ epochs', unit='epoch', leave=True):
                neg_grad_plus(
                    unlearner, train_loader, forget_loader,
                    criterion, opt, keep_pbars=False
                )
        case Unlearning.EUK:
            opt = make_optimizer(unlearner, opt_name='adam', lr=lr)
            with unlearning_last_layers(unlearner, 6, 'euk'):
                train_loop(unlearner, train_loader, criterion, opt, epochs=1)
        case Unlearning.SCRUB:
            opt = make_optimizer(unlearner, opt_name='adam', lr=lr)
            scrub(
                net, unlearner, train_loader, forget_loader, criterion, opt,
                max_steps=1, steps=1, keep_pbars=False,
            )
    
    return unlearner

#### No poisoning

TODO: quantify effect of poisoning in terms of loss recovery effort (epochs). Can the model ever recover from poisoning with enough steps?

In [None]:
# Baseline: the same architecture trained for the same number of epochs on
# clean data only (no poisons), for comparison with the poisoned runs.
clean_net = ResNet18(num_classes=num_classes).to(device)
opt = make_optimizer(clean_net, opt_name='sgd', lr=lr)
train_val_loop(
    clean_net, train_loader, val_loader,
    criterion, opt,
    epochs=epochs,
    metric=metric,
);
# Final evaluation on the validation loader
test_epoch(clean_net, val_loader, criterion, keep_pbars=True, metric=metric);

#### No unlearning

In [None]:
# Evaluate `net` as-is, without unlearning. Assuming the cells ran in order,
# this is the model trained with poisons in the "Machine unlearning" section.
test_epoch(net, val_loader, criterion, keep_pbars=True, metric=metric);

#### Gradient descent

In [None]:
# Unlearn on a copy of `net` with the gradient-descent method, then evaluate
# the copy on the validation loader.
unlearner = unlearn(net, train_loader, forget_loader, criterion, Unlearning.GRADIENT_DESCENT)
test_epoch(unlearner, val_loader, criterion, keep_pbars=True, metric=metric);

#### NegGrad+

In [None]:
# Unlearn on a copy of `net` with NegGrad+, then evaluate the copy on the
# validation loader.
unlearner = unlearn(net, train_loader, forget_loader, criterion, Unlearning.NEG_GRAD_PLUS)
test_epoch(unlearner, val_loader, criterion, keep_pbars=True, metric=metric);

#### EUk ($k = 6$)

In [None]:
# Unlearn on a copy of `net` with EUk, then evaluate the copy on the
# validation loader.
unlearner = unlearn(net, train_loader, forget_loader, criterion, Unlearning.EUK)
test_epoch(unlearner, val_loader, criterion, keep_pbars=True, metric=metric);

Although more extensive testing is required, unlearning methods do not fully restore accuracy, at least not with only one epoch. Why is gradient descent enough?

### Remaining tasks

- Refactoring
- Compare with gradient attacks
- Compare with results from article
- Use `Conv16` model for prototyping
- _Little is Enough_ attack (requires efficient gradient stddev estimation)
- Mean gradient estimation with auxiliary dataset
- Testing against unlearning
- Testing with different configs (optimizer, number of epochs, batch size, models)
- Quantify results of data poisoning in terms of slowdown (x% -> y% accuracy = z training epochs)
- Suggest other Hessian-based attacks