# Inverting gradient attack

In [1]:
from __future__ import annotations

from copy import deepcopy
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
import pandas as pd
import torch
from torch import nn, Tensor
from torch.nn import functional as F
from torch.nn.modules.loss import _Loss, CrossEntropyLoss
from torch.optim import Optimizer, SGD, Adam, AdamW
from torch.utils.data import Dataset, DataLoader, Subset, TensorDataset
from torchmetrics.classification import MulticlassAccuracy
from torchmetrics.functional.image import total_variation
import torchinfo
from image_classification.utils import trange

In [2]:
# Run on the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda', 0) if torch.cuda.is_available() else torch.device('cpu')
print(device)

cuda:0


## ResNet-18

For quick prototyping, also consider `ConvNet16`, a 1.6M-parameter model that is much smaller than `ResNet18`. Experiments can be run on either model or on both.

In [3]:
from image_classification.models import ResNet18, ConvNet16
from image_classification.datasets import cifar10_train_test, cifar100_train_test
from image_classification.nn import train_epoch, train_loop, train_val_loop, test_epoch

Using cuda device


In [4]:
# Set to 10 for CIFAR-10, 100 for CIFAR-100
num_classes = 10

# Both dataset loaders return images that are already normalized
get_train_test = {10: cifar10_train_test, 100: cifar100_train_test}.get(num_classes)
if get_train_test is None:
    raise ValueError(f"Can't find CIFAR dataset with {num_classes} classes")
print(f"Loading CIFAR-{num_classes}")

training_data, test_data = get_train_test(root='data')
N_test = len(test_data)
# Carve 10% of the training set out for validation and another 10% as auxiliary data
N_val = N_aux = len(training_data) // 10
N = len(training_data) - N_val - N_aux
# Splitting into consecutive chunks is only sound because the training data comes pre-shuffled
training_data, val_data, aux_data = training_data.split([N, N_val, N_aux])

batch_size = 100
N, N_val, N_aux, N_test

Loading CIFAR-10


(40000, 5000, 5000, 10000)

In [5]:
# The training loader reshuffles its samples every epoch. Without `shuffle=True` every epoch
# would see identical batches in identical order; the one-time shuffle of the raw data only
# makes the contiguous train/val/aux split valid.
# The validation and auxiliary loaders keep a fixed order.
# drop_last=True keeps every batch at exactly `batch_size` samples.
train_loader = DataLoader(training_data, batch_size, shuffle=True, drop_last=True)
val_loader = DataLoader(val_data, batch_size, drop_last=True)
aux_loader = DataLoader(aux_data, batch_size, drop_last=True)

### Hyperparameters

In [6]:
# Optimizer hyperparameters
lr = 1e-3
weight_decay = 5e-4

# Learning-rate scheduling: peak learning rate and schedule length.
# NOTE(review): these keys match `torch.optim.lr_scheduler.OneCycleLR`'s arguments — presumably
# what `lr_sched_params` is meant for; confirm at the use site.
max_lr = 0.1
epochs = 6
# Number of full batches per epoch (the training loader drops the last partial batch)
steps_per_epoch = N // batch_size
lr_sched_params = {'max_lr': max_lr, 'epochs': epochs, 'steps_per_epoch': steps_per_epoch}

criterion = CrossEntropyLoss()

# Report top-1 accuracy on CIFAR-10 and top-5 accuracy on CIFAR-100
top_k = {10: 1, 100: 5}[num_classes]
metric = MulticlassAccuracy(num_classes=num_classes, top_k=top_k)

### Optimizer (TODO: compare SGD & Adam)

In [7]:
def make_optimizer(model: nn.Module, opt_name='adamw', lr=lr, weight_decay=weight_decay, **kwargs) -> Optimizer:
    """
    Build an optimizer over all parameters of `model`.

    `opt_name` picks the optimizer class: 'sgd', 'adam' or 'adamw'; any other value raises KeyError.
    `lr` and `weight_decay` default to the notebook-level values that were set when this cell ran.
    Extra keyword arguments are forwarded unchanged to the optimizer constructor.
    """
    optimizer_classes = dict(sgd=SGD, adam=Adam, adamw=AdamW)
    optimizer_cls = optimizer_classes[opt_name]
    params = model.parameters()
    return optimizer_cls(params, lr=lr, weight_decay=weight_decay, **kwargs)

### Visualizing reconstruction attacks

In [10]:
def display_input_image(input: Tensor, label: Tensor, title='', cmap=None, ax: plt.Axes = None, dataset=None):
    """
    Displays an input image to a neural network, titled with its class name.

    `input`: a 3D tensor. It is decoded back to a displayable image with the dataset's
        CIFAR-10 or CIFAR-100 decoder, chosen from the dataset's number of classes.
    `label`: the target, turned into a class name by `dataset.decode_target`.
    `title`: optional text appended to the class name as "<class> - <title>".
    `cmap`: forwarded to `Axes.imshow`. `None` means matplotlib's default colormap,
        and matplotlib ignores `cmap` altogether for RGB(A) images.
    `ax`: the axes to draw on. A new figure and axes are created when `None`.
    `dataset`: the dataset providing the decoders. Defaults to the notebook's
        `training_data`, looked up at call time.

    Returns the `Axes` the image was drawn on.
    """
    if dataset is None:
        dataset = training_data
    decoder_names = {10: 'decode_cifar10_image', 100: 'decode_cifar100_image'}
    # Only look up the decoder that is actually needed.
    decode_image = getattr(dataset, decoder_names[len(dataset.classes)])
    image = decode_image(input)
    class_ = dataset.decode_target(label)

    if ax is None:
        _fig, ax = plt.subplots()
    ax.imshow(image, cmap=cmap, interpolation='nearest')
    ax.set_title(f"{class_} - {title}" if title else class_, fontsize=7)
    return ax

**TODO:** use `ConvNet16` and a smaller dataset to speed up experiments.

In [11]:
from image_classification.nn import Metric, MetricLogger, _detect_device
import federated
from federated import Aggregator, Mean, Krum
import torchjd
from torchjd.aggregation import Aggregator, Mean, Krum

def train_epoch_jd(
        model: nn.Module,
        dataloader: DataLoader,
        criterion: _Loss,
        optimizer: Optimizer,
        aggregator: Aggregator,
        keep_pbars=True,
        metric: Metric = None,
    ):
    """
    Train `model` for one epoch with torchjd's Jacobian-based backward pass.

    The loss of every sample in a batch is kept separate and passed, together with
    `aggregator`, to `torchjd.backward`. torchjd combines the per-sample gradients into
    the parameters' `.grad`, and then `optimizer.step()` is applied.

    `criterion` is temporarily switched to `reduction='none'` so that it returns one
    loss per sample. Its previous reduction is restored on exit, including when the
    loop is interrupted (e.g. by a KeyboardInterrupt), so a shared criterion is never
    left in per-sample mode.

    Returns the `MetricLogger` holding the epoch's metrics.
    """
    device = _detect_device(model)
    model.train()

    logger = MetricLogger(
        metric,
        desc='Train loop', total=len(dataloader.dataset), keep_pbars=keep_pbars,
    )

    original_reduction = criterion.reduction
    criterion.reduction = 'none'
    try:
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)

            # Compute predictions and one loss per sample
            logits = model(X)
            losses = criterion(logits, y)
            mean_loss = losses.mean().item()

            optimizer.zero_grad()
            torchjd.backward(losses, aggregator)
            optimizer.step()
            optimizer.zero_grad()

            logger.compute_metrics(X, y, logits, mean_loss)
    finally:
        criterion.reduction = original_reduction

    logger.finish()
    return logger

In [12]:
from torch.optim.lr_scheduler import LRScheduler

def train_val_loop_jd(
        model: nn.Module,
        train_dataloader: DataLoader,
        val_dataloader: DataLoader,
        criterion: _Loss,
        optimizer: Optimizer,
        aggregator: Aggregator,
        epochs: int,
        *,
        lr_scheduler: LRScheduler = None,
        keep_pbars=True,
        metric: Metric = None,
        validate_every: int = 2,
        early_stopping = True,
    ):
    """
    Run the Jacobian-descent training loop (`train_epoch_jd`) with periodic validation.

    Validation runs after epochs 0, `validate_every`, 2 * `validate_every`, ...
    If `val_dataloader` is `None`, no validation is performed. Otherwise
    `validate_every` must be a positive integer, or `ValueError` is raised before
    any training happens.

    If `early_stopping` is True, the loop exits as soon as the validation loss stops
    decreasing, i.e. when it is higher than at the previous validation.

    `lr_scheduler`, if given, is stepped once per epoch (not once per batch).
    NOTE(review): a `OneCycleLR` built with `steps_per_epoch` expects one step per batch.
    """
    if val_dataloader is not None and validate_every < 1:
        # Fail fast instead of hitting `epoch % 0` only after a full (slow) training epoch.
        raise ValueError(f"validate_every must be a positive integer, got {validate_every}")

    val_loss = float('inf')
    for epoch in trange(epochs, desc='Train epochs', unit='epoch', leave=keep_pbars):
        train_epoch_jd(
            model, train_dataloader, criterion, optimizer, aggregator,
            keep_pbars=keep_pbars, metric=metric,
        )
        if lr_scheduler is not None:
            lr_scheduler.step()
        if val_dataloader is not None and epoch % validate_every == 0:
            logger = test_epoch(
                model, val_dataloader, criterion,
                keep_pbars=keep_pbars, metric=metric,
            )
            next_val_loss = logger.avg_loss.compute()
            if early_stopping and next_val_loss > val_loss:
                print(f"Epoch {epoch}: validation loss stopped improving, exiting train loop.")
                break
            val_loss = next_val_loss

In [13]:
# Re-import project modules edited on disk without restarting the kernel.
from importlib import reload
import image_classification.models
# NOTE: reload() re-executes a module in place, but it does NOT rebind names that were
# imported earlier with `from federated import ...` or `from image_classification.models import ...`.
# Those names still point to the old objects until their import lines are re-run.
reload(federated)
reload(image_classification.models)
# Re-bind ShuffleNetV2 from the freshly reloaded module.
from image_classification.models import ShuffleNetV2

In [14]:
# Train ShuffleNetV2 with the Jacobian-descent loop and the `Mean` aggregator.
# Swap in `ResNet18(num_classes=num_classes).to(device)` to run the same experiment on ResNet-18.
# NOTE(review): ShuffleNetV2 is built with its default arguments — confirm its number of
# output classes matches `num_classes`.
net = ShuffleNetV2().to(device)
opt = make_optimizer(net, opt_name='adam', lr=lr)
aggregator = Mean()

train_val_loop_jd(
    model=net,
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    criterion=criterion,
    optimizer=opt,
    aggregator=aggregator,
    epochs=epochs,
    metric=metric,
);

Train epochs:   0%|          | 0/6 [00:00<?, ?epoch/s]

Train loop:   0%|          | 0/40000 [00:00<?, ?it/s]

Train loop: 100%|██████████| 40000/40000 [08:23<00:00, 79.39it/s, MulticlassAccuracy=0.462, avg_loss=1.77]


Test epoch:   0%|          | 0/5000 [00:00<?, ?it/s]

Test epoch: 100%|██████████| 5000/5000 [00:00<00:00, 9807.75it/s, MulticlassAccuracy=0.336, avg_loss=1.57]


Train loop:   0%|          | 0/40000 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [158]:
# Baseline: the same ShuffleNetV2 setup trained with the project's `train_val_loop`.
# Swap in `ResNet18(num_classes=num_classes).to(device)` to run it on ResNet-18 instead.
net = ShuffleNetV2().to(device)
opt = make_optimizer(net, opt_name='adam', lr=lr)

train_val_loop(
    net, train_loader, val_loader,
    criterion, opt, epochs,
    metric=metric,
);

Train epochs:   0%|          | 0/6 [00:00<?, ?epoch/s]

Train loop:   0%|          | 0/40000 [00:00<?, ?it/s]

Train loop: 100%|██████████| 40000/40000 [00:08<00:00, 4598.58it/s, MulticlassAccuracy=0.496, avg_loss=1.77]


Test epoch:   0%|          | 0/5000 [00:00<?, ?it/s]

Test epoch: 100%|██████████| 5000/5000 [00:00<00:00, 11216.66it/s, MulticlassAccuracy=0.348, avg_loss=1.55]


Train loop:   0%|          | 0/40000 [00:00<?, ?it/s]

Train loop: 100%|██████████| 40000/40000 [00:08<00:00, 4493.70it/s, MulticlassAccuracy=0.528, avg_loss=1.42]


Train loop:   0%|          | 0/40000 [00:00<?, ?it/s]

Train loop: 100%|██████████| 40000/40000 [00:09<00:00, 4346.30it/s, MulticlassAccuracy=0.687, avg_loss=1.21]


Test epoch:   0%|          | 0/5000 [00:00<?, ?it/s]

Test epoch: 100%|██████████| 5000/5000 [00:00<00:00, 11070.12it/s, MulticlassAccuracy=0.591, avg_loss=1.16]


Train loop:   0%|          | 0/40000 [00:00<?, ?it/s]

Train loop: 100%|██████████| 40000/40000 [00:08<00:00, 4608.30it/s, MulticlassAccuracy=0.724, avg_loss=1.05]


Train loop:   0%|          | 0/40000 [00:00<?, ?it/s]

Train loop: 100%|██████████| 40000/40000 [00:09<00:00, 4309.65it/s, MulticlassAccuracy=0.789, avg_loss=0.939]


Test epoch:   0%|          | 0/5000 [00:00<?, ?it/s]

Test epoch: 100%|██████████| 5000/5000 [00:00<00:00, 11205.96it/s, MulticlassAccuracy=0.595, avg_loss=1.08]


Train loop:   0%|          | 0/40000 [00:00<?, ?it/s]

Train loop: 100%|██████████| 40000/40000 [00:09<00:00, 4230.92it/s, MulticlassAccuracy=0.79, avg_loss=0.854]


In [44]:
# Pretrain a fresh ResNet-18 so it learns the features: run `train_epoch` on the first
# N_aux samples of the training set.
net = ResNet18(num_classes=num_classes).to(device)
opt = make_optimizer(net, opt_name='adam', lr=lr)

mini_train_set = Subset(training_data, np.arange(N_aux))
mini_train_loader = DataLoader(mini_train_set, batch_size)

train_epoch(net, mini_train_loader, criterion, opt)

Train loop:   0%|          | 0/5000 [00:00<?, ?it/s]

KeyboardInterrupt: 