# Attacking and Defending a Split Neural Network

- Introducing the problem
- Model inversion
  - White box
  - Black box
- Defences
  - NoPeekNN
  - Noise defence
  - Other defences
- Lesson objective


## Introduction

In split neural networks (SplitNNs),
there are at minimum two parties:
one data owner
and one computational server.
Previously,
we've assumed that the computational server is trustworthy.
But what if it isn't?
What if, in addition to completing training or inference
as you would expect,
the server tries to learn information about you?
A party that follows the protocol correctly
but tries to learn extra information from what it sees
is called honest-but-curious.

What are the risks?
While the data owner doesn't send raw data
to the computational server,
the data they do send still contains information about that raw data.
It has to, otherwise the model wouldn't be able to learn!
In this lesson,
we'll explore how to perform an attack known as **model inversion**,
in which the computational server tries to turn the intermediate data
back into raw data.

## Model Inversion

### White box

Model inversion in machine learning
was introduced by [Fredrikson et al. (2015)](https://www.cs.cmu.edu/~mfredrik/papers/fjr2015ccs.pdf),
who showed that you can exploit a model's overfitting
to recreate training data examples.
This is known as a "white box" attack
because the attacker uses the model's weights.
In this lesson,
we'll be trying a "black box" attack,
in which we only need query access to the model
(i.e. we can run inference on it).

### Black box

The attack we're trying was introduced
by [He, Zhang and Lee (2019)](https://personal.ntu.edu.sg/tianwei.zhang/paper/acsac19.pdf).
The process is:
1. the attacker collects a dataset of (intermediate data, raw data) pairs
1. the attacker trains an attack model to map intermediate data -> raw data
1. the attacker runs new intermediate data through the attack model to recreate the raw data of a victim
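The three steps above can be sketched end to end on toy data. This is a minimal illustration under stated assumptions, not the lesson's solution: `data_owner_model` is a hypothetical frozen stand-in for the data owner's half of a SplitNN, the "raw data" is random Gaussian vectors rather than MNIST, and the attack model is an arbitrary small MLP trained with mean squared error.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-in for the data owner's half of a SplitNN.
# It is frozen: the attacker can only query it, not change it.
data_owner_model = nn.Sequential(nn.Linear(20, 32), nn.ReLU())
for param in data_owner_model.parameters():
    param.requires_grad_(False)

# Step 1: collect (intermediate, raw) pairs from auxiliary data the attacker controls
aux_raw = torch.randn(2_000, 20)
aux_intermediate = data_owner_model(aux_raw)

# Step 2: train an attack model that maps intermediate data back to raw data
attack_model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 20))
attack_opt = torch.optim.Adam(attack_model.parameters(), lr=1e-2)
mse_loss = nn.MSELoss()
for step in range(500):
    attack_opt.zero_grad()
    loss = mse_loss(attack_model(aux_intermediate), aux_raw)
    loss.backward()
    attack_opt.step()

# Step 3: invert a victim's intermediate data (the attacker never sees victim_raw directly)
victim_raw = torch.randn(500, 20)
with torch.no_grad():
    reconstruction = attack_model(data_owner_model(victim_raw))
    attack_mse = mse_loss(reconstruction, victim_raw).item()
    # Baseline: always guess the data mean (0), ignoring the intermediate data
    baseline_mse = mse_loss(torch.zeros_like(victim_raw), victim_raw).item()

print(f"attack MSE: {attack_mse:.3f}, baseline MSE: {baseline_mse:.3f}")
```

If the attack MSE is well below the baseline, the intermediate data is leaking information about the raw inputs.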

Who is the attacker and how do they get hold of the data?
There are a few possibilities:
- **Computational server colluding with a data owner**.
The colluding data owner runs data through the model
and sends the (raw, intermediate) data pairs to the computational server.
The computational server already has access to victims' intermediate data,
so it can run the attack model with impunity.
- **The computational server acts alone**.
For this,
the computational server must source their own attack training data.
If they don't know exactly what the data is
(e.g. they are not the controlling party in the process
and have simply been paid to do computations),
they could guess and make their own dataset.
It turns out that you don't need to get the exact distribution
of the victim dataset to train a useful attack model.
- **A data owner works alone**.
Any data owner can train an attack model on their own data.
The data owner must somehow intercept intermediate data from victims
in order to run the attack.

## Defences

### NoPeek

[Vepakomma et al. (2020)](https://dam-prod.media.mit.edu/x/2020/10/29/NoPeek_SplitLearning.pdf)
introduced NoPeek as a possible defence.
A NoPeek SplitNN is trained with a two-part loss function:
a weighted combination of the objective loss
(e.g. cross-entropy for classification)
and a _Distance correlation_ (DC) loss.
The DC loss measures the statistical dependence
between the input data and the intermediate data.
As the loss function is minimised (the model is trained),
the intermediate data looks less and less like the input data,
while retaining information which is useful for performing the task.
The less correlated the input and intermediate data are,
the harder it is for an attack model to learn
how to map one back into the other.
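To make the DC loss concrete, here is a minimal, vectorised sketch of sample distance correlation in PyTorch. It assumes both arguments are batches with the same number of rows, uses `torch.cdist` for pairwise Euclidean distances, and the function name `distance_correlation` is illustrative rather than part of any library. Identical or affinely related batches give a value of about 1; unrelated batches give a much smaller (though, for a finite sample, not exactly zero) value.

```python
import torch


def distance_correlation(x, y):
    """Sample distance correlation between two batches of shape (n, ...)."""
    x = x.reshape(x.size(0), -1)
    y = y.reshape(y.size(0), -1)

    def double_centred_distances(data):
        # Pairwise Euclidean distances, shape (n, n); avoid the matmul shortcut for accuracy
        d = torch.cdist(data, data, compute_mode="donot_use_mm_for_euclid_dist")
        return d - d.mean(dim=0, keepdim=True) - d.mean(dim=1, keepdim=True) + d.mean()

    a = double_centred_distances(x)
    b = double_centred_distances(y)
    dcov2_xy = (a * b).mean().clamp_min(0)  # squared distance covariance
    dvar2_x = (a * a).mean()  # squared distance variance of x
    dvar2_y = (b * b).mean()  # squared distance variance of y
    return (dcov2_xy / (dvar2_x * dvar2_y).sqrt()).sqrt()


torch.manual_seed(0)
inputs = torch.randn(200, 5)
print(distance_correlation(inputs, inputs))                                      # about 1
print(distance_correlation(inputs, torch.tanh(inputs) + 0.1 * torch.randn(200, 5)))  # high
print(distance_correlation(inputs, torch.randn(200, 5)))                         # much lower
```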

### Noise defence

Titcombe et al. (2021) (under review)
introduced a simple noise defence,
where Laplacian noise is added to the intermediate data
_before_ being transmitted to the computational server.
The idea is similar to NoPeek: make the intermediate data dissimilar to the input data.
Crucially,
the noise is not introduced during training (although it can be, if needed).
Instead, the noise is only introduced on inference-time data.
This means the noise defence can be applied to **any** model
and it can be applied unilaterally by the data owner,
which is useful if they don't fully trust the computational server.
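A minimal sketch of the noise defence on the data owner's side, assuming a hypothetical `data_owner_model` standing in for an already-trained first half of a SplitNN and an illustrative `noise_scale` of 0.5. It samples from `torch.distributions.Laplace` and leaves the trained weights untouched, which is why the data owner can apply it unilaterally at inference time.

```python
import torch
from torch import nn
from torch.distributions import Laplace

torch.manual_seed(0)

# Hypothetical stand-in for the data owner's (already trained) half of a SplitNN
data_owner_model = nn.Sequential(nn.Flatten(), nn.Linear(784, 500), nn.ReLU())


def noisy_intermediate(model, x, noise_scale):
    """Compute intermediate data, then add Laplacian noise before it leaves the data owner."""
    with torch.no_grad():
        intermediate = model(x)
    noise = Laplace(loc=0.0, scale=noise_scale).sample(intermediate.shape)
    return intermediate + noise


x = torch.randn(8, 1, 28, 28)
clean = data_owner_model(x).detach()
noisy = noisy_intermediate(data_owner_model, x, noise_scale=0.5)

# The mean absolute value of Laplace(0, b) noise is b, so this should be close to 0.5
print((noisy - clean).abs().mean())
```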

### Other methods

[Mireshghallah et al. (2019)](https://arxiv.org/pdf/1905.11814v2.pdf)
introduced _Shredder_,
a defensive method similar to the noise defence.
However, Shredder's noise is learned
and adapts to the input data.
While this can produce better performance-privacy trade-offs,
it requires a training phase to learn the noise,
whereas the noise defence needs none.

## Lesson Objective

By the end of this lesson,
you will:
* Be able to run black box model inversion on a SplitNN (part 2)
* Be able to implement the NoPeek defence (part 3)
* Be able to implement the noise defence (part 4)
* Be able to experiment with hyperparameters to improve a defensive measure (parts 3 and 4)

Parts 1 and 2 must be completed to move on to the later parts.
Parts 3 and 4 are independent of one another,
so can be completed in either order.

---

In [None]:
import matplotlib.pyplot as plt
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data import Subset
from torchvision.datasets import MNIST
from torchvision.transforms import Compose
from torchvision.transforms import Normalize
from torchvision.transforms import ToTensor

---

## Part 1: Developing a SplitNN

First,
we need to make something to attack.
We will train a simple model on MNIST.
Importantly,
we won't use all the available training data,
so there will be some data left over for us to use in the attack.
We will develop a single PyTorch model here
just to make things simpler,
but you can think of it as two separate models
with data communicated in between.

In this first step,
all of the code you need has been
provided.
Simply run all the cells until
you get to part 2.

In [None]:
class Classifier(nn.Module):
    def __init__(self):
        super().__init__()

        # The data owner's model
        self.part1 = nn.Sequential(
            nn.Flatten(),  # (batch, 1, 28, 28) -> (batch, 784)
            nn.Linear(784, 500),
            nn.ReLU(),
            nn.Linear(500, 500),
            nn.ReLU()
        )

        # The computational server's model
        # No softmax here: nn.CrossEntropyLoss expects raw logits and applies log-softmax itself
        self.part2 = nn.Sequential(
            nn.Linear(500, 128),
            nn.ReLU(),
            nn.Linear(128, 10)
        )

    def data_owner(self, data):
        x = self.part1(data)
        return x

    def computational_server(self, intermediate):
        x = self.part2(intermediate)
        return x

    def forward(self, x):
        x = self.data_owner(x)
        x = self.computational_server(x)
        return x

In [None]:
transform = Compose([ToTensor(), Normalize((0.5,), (0.5,))])  # scales pixels to [-1, 1]

train_data = Subset(MNIST(".", download=True, train=True, transform=transform), list(range(30_000)))  # train on first 30'000 images
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)

val_data = Subset(MNIST(".", download=True, train=False, transform=transform), list(range(5_000)))
val_loader = DataLoader(val_data, batch_size=128)

In [None]:
classifier = Classifier()
opt = torch.optim.Adam(classifier.parameters(), lr=0.01)

loss_fn = nn.CrossEntropyLoss()

for epoch in range(10):
    correct = 0
    total = 0

    for data, targets in train_loader:
        opt.zero_grad()

        out = classifier(data)
        loss = loss_fn(out, targets)
        loss.backward()

        opt.step()

        correct += out.max(1)[1].eq(targets).sum()
        total += out.size(0)

    val_correct = 0
    val_total = 0

    for data_val, targets_val in val_loader:
        with torch.no_grad():
            out_val = classifier(data_val)

            val_correct += out_val.max(1)[1].eq(targets_val).sum()
            val_total += out_val.size(0)

    print(f"Train acc: {100*correct/total:.3f}")
    print(f"Val acc: {100*val_correct/val_total:.3f}")

---

## Part 2: Attacking the model

### Step 1 - Get the data

Create an "attack dataset":
a dataset on which to train an attack model.
We'll do this by collecting (input, intermediate) data pairs
from the trained model.
We'll give you the dataset class;
you have to implement the code to collect the dataset.

In [None]:
class AttackDataset(torch.utils.data.Dataset):
    def __init__(self):
        self.intermediate_data = torch.empty((0, 500))  # Inputs
        self.actual_data = torch.empty((0, 784))  # Targets

    def push(self, intermediate, actual):
        """
        Add data to the dataset

        Args:
            intermediate (torch.tensor): A tensor of data transmitted between the parts of a SplitNN (the data to attack)
            actual (torch.tensor): A tensor of input data to the SplitNN under attack, flattened to shape (batch, 784)
        """
        assert intermediate.size(0) == actual.size(0)

        self.intermediate_data = torch.cat([self.intermediate_data, intermediate])
        self.actual_data = torch.cat([self.actual_data, actual])

    def __len__(self):
        return self.intermediate_data.size(0)

    def __getitem__(self, idx):
        return self.intermediate_data[idx], self.actual_data[idx]

In [None]:
attack_train = AttackDataset()

# Input data to use in attack dataset
temp_attack_train = Subset(MNIST(".", train=True, transform=transform), list(range(30_000, 40_000)))  # use 10'000 unseen images to train the attacker
temp_attack_train_loader = DataLoader(temp_attack_train, batch_size=128)

raise NotImplementedError("Complete the code to make `attack_train` a dataset of (intermediate, input) tensors")

In [None]:
raise NotImplementedError("Follow the same process to get an attacker validation dataset using images 5'000-10'000 in the MNIST 'test' dataset")

### Step 2 - Get the model

We need to define an attack model to train.
Ideally, the attack model would mirror the architecture
of the model under attack (the target model).
In this lesson we know what the target model looks like,
and in practice the attackers would too,
as a colluding data owner has a copy of the data owner's part of the target model.
Give the attack model a similar number and size of layers
to the target model.
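One possible architecture, sketched as an assumption rather than the expected answer: reverse the data owner's layers (784 -> 500 -> 500) so the attack model maps 500 -> 500 -> 784. The class name `MirrorAttackModel` is hypothetical, and the final `nn.Tanh` is a design choice based on the lesson's `Normalize((0.5,), (0.5,))` transform, which puts target pixels in [-1, 1].

```python
import torch
from torch import nn


class MirrorAttackModel(nn.Module):
    """Hypothetical attack model: the data owner's layers in reverse order."""

    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(500, 500),
            nn.ReLU(),
            nn.Linear(500, 784),
            nn.Tanh(),  # normalised MNIST pixels lie in [-1, 1]
        )

    def forward(self, intermediate):
        # (batch, 500) intermediate data -> (batch, 784) reconstructed image
        return self.layers(intermediate)


model = MirrorAttackModel()
print(model(torch.randn(4, 500)).shape)  # torch.Size([4, 784])
```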

In [None]:
class AttackModel(nn.Module):
    def __init__(self):
        super().__init__()
        raise NotImplementedError("You need to finish __init__")

    def forward(self, x):
        raise NotImplementedError("You need to finish forward")

### Step 3 - Train the attacker

In [None]:
attack_train_loader = DataLoader(attack_train, batch_size=128)

raise NotImplementedError("Define and run a training loop for the attack model")

### Step 4 - Attack new data

Earlier,
we only used the first 5'000 "test" MNIST images
to validate the classifier.
We will use the remaining 5'000
to validate our attacker.
In practice,
this would be new data coming into the computational server.

In [None]:
def plot_images(tensors, rows):
    """
    Plot images in a grid

    Args:
        tensors (list of tensors): list of MNIST image tensors to plot, each of shape (784,) or (1, 28, 28)
        rows (int): number of images per row of the grid (passed to make_grid's `nrow`)
    """
    # Inverts Normalize((0.5,), (0.5,)): x = x_norm * 0.5 + 0.5 = (x_norm - (-1)) / 2
    unnormalise = Normalize([-1.0], [2.0])

    images = []
    for tensor in tensors:
        tensor = tensor.detach().cpu().view(1, 28, 28)  # Normalize expects (C, H, W)
        tensor = unnormalise(tensor)

        # Clip image values to [0, 1] so we can plot
        tensor = tensor.clamp(0, 1)

        images.append(tensor.unsqueeze(0))  # add batch dim

    images = torch.cat(images)
    grid_image = torchvision.utils.make_grid(images, nrow=rows).permute(1, 2, 0)

    plt.imshow(grid_image.numpy())
    plt.axis("off")

In [None]:
raise NotImplementedError("Attack the model to recreate input data")

---

## Part 3: NoPeek

### Step 1 - Create the NoPeek loss

NoPeekNN is a SplitNN which is optimised on task loss (cross entropy, for classification tasks) and _distance correlation loss_.
Distance correlation measures the statistical dependence, including non-linear dependence, between two batches of data (in PyTorch, two tensors with the same number of rows).
NoPeek loss is a weighted sum of the two terms,
governed by a hyperparameter $\alpha$:


$$L = L_{ce} + \alpha L_{dc}$$

so a higher $\alpha$ means more emphasis on minimizing distance correlation.\
**Bonus question**: what are the bounds (minimum and maximum values) of $\alpha$?

We'll provide the distance correlation loss term.
You must create the full NoPeek loss.
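One way to combine the two terms, sketched under assumptions: the class name `NoPeekLossSketch`, its argument order, and passing the DC term in as `dc_loss` (for example an instance of the `DistanceCorrelationLoss` class provided in this part) are illustrative choices, not a fixed API. The usage lines swap in a stub DC term that always returns 0.5, so the result can be checked by hand: cross-entropy of all-zero logits over 10 classes is $\ln 10 \approx 2.303$, plus $\alpha \times 0.5 = 1$.

```python
import math

import torch
from torch import nn


class NoPeekLossSketch(nn.Module):
    """Hypothetical NoPeek loss: L = L_ce + alpha * L_dc."""

    def __init__(self, alpha, dc_loss):
        super().__init__()
        self.alpha = alpha
        self.ce_loss = nn.CrossEntropyLoss()  # task loss on the classifier's logits
        self.dc_loss = dc_loss  # callable: (input_data, intermediate_data) -> scalar tensor

    def forward(self, input_data, intermediate_data, outputs, targets):
        task_loss = self.ce_loss(outputs, targets)
        dc_term = self.dc_loss(input_data, intermediate_data)
        return task_loss + self.alpha * dc_term


# Usage with a stub DC term, so the arithmetic is easy to check by hand
loss_fn = NoPeekLossSketch(alpha=2.0, dc_loss=lambda inp, inter: torch.tensor(0.5))
loss = loss_fn(torch.randn(4, 784), torch.randn(4, 500), torch.zeros(4, 10), torch.tensor([0, 1, 2, 3]))
print(loss.item(), math.log(10) + 2.0 * 0.5)
```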

In [None]:
class DistanceCorrelationLoss(nn.modules.loss._Loss):
    def forward(self, input_data, intermediate_data):
        input_data = input_data.view(input_data.size(0), -1)
        intermediate_data = intermediate_data.view(intermediate_data.size(0), -1)

        A_input = self._A_matrix(input_data)
        A_intermediate = self._A_matrix(intermediate_data)

        # Get distance variances
        input_dvar = self._distance_variance(A_input)
        intermediate_dvar = self._distance_variance(A_intermediate)

        # Get distance covariance
        dcov = self._distance_covariance(A_input, A_intermediate)

        dcorr = dcov / (input_dvar * intermediate_dvar).sqrt()

        return dcorr

    def _distance_covariance(self, a_matrix, b_matrix):
        return (a_matrix * b_matrix).sum().sqrt() / a_matrix.size(0)

    def _distance_variance(self, a_matrix):
        return (a_matrix ** 2).sum().sqrt() / a_matrix.size(0)

    def _A_matrix(self, data):
        distance_matrix = self._distance_matrix(data)

        row_mean = distance_matrix.mean(dim=0, keepdim=True)
        col_mean = distance_matrix.mean(dim=1, keepdim=True)
        data_mean = distance_matrix.mean()

        return distance_matrix - row_mean - col_mean + data_mean

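    def _distance_matrix(self, data):
        n = data.size(0)
        distance_matrix = torch.zeros((n, n), dtype=data.dtype, device=data.device)

        # Pairwise Euclidean (not squared) distances, as in the definition of distance correlation.
        # Only i < j is computed: the diagonal stays 0 and the matrix is symmetric.
        for i in range(n):
            for j in range(i + 1, n):
                row_diff = data[i] - data[j]
                # clamp_min keeps sqrt away from 0, where its gradient is infinite (giving NaNs)
                distance = (row_diff ** 2).sum().clamp_min(1e-12).sqrt()
                distance_matrix[i, j] = distance
                distance_matrix[j, i] = distance

        return distance_matrix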

In [None]:
class NoPeekLoss(nn.modules.loss._Loss):
    raise NotImplementedError("Complete")

### Bonus step

The implementation of distance correlation loss above
is quite inefficient.
Can you improve it?

### Step 2 - Train a NoPeekNN

Train a classifier using the NoPeek loss.
Because the distance correlation loss above is inefficient,
we have to keep to small batch sizes (e.g. 32).
Of course, this depends on the power and memory of your computer.
You can re-use most of the training code above
to train a NoPeekNN.

In [None]:
nopeeknn = Classifier()
raise NotImplementedError("Train NoPeekNN")

### Step 3 - Train an attacker

Use the code from earlier to train an attack model on the NoPeek classifier.

In [None]:
raise NotImplementedError("Train an attack model on NoPeekNN")

### Step 4 - Validate attacker / NoPeek

Visualize reconstructions from the NoPeek attacker.
Has the defence worked?

In [None]:
raise NotImplementedError("Visualize reconstructions")

### Step 5 - Experiment

Vary the weight of distance correlation loss term
in the NoPeek loss function.
What happens as you increase $\alpha$?

In [None]:
raise NotImplementedError("Experiment")

---

## Part 4: Noise defence

### Step 1 - Make a noisy classifier

In the noise defence,
we add noise to the intermediate data of an already-trained model.
To make the noisy classifier,
let's first define a classifier class that adds Laplacian noise to its intermediate data.
You can get noise using `torch.distributions`.
The noisy classifier should take a noise scale parameter,
which defines how much noise to add.

Next,
we must copy the parameters
of our trained classifier
into the noisy classifier,
so that it has the trained weights
as well as the new noise functionality.
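A minimal sketch of both steps on a deliberately tiny model. `TinySplitNN` is a hypothetical stand-in for `Classifier` (same `part1`/`part2` layout, smaller layers), and `NoisyTinySplitNN` shows one way to sample Laplacian noise in the forward pass. Because a `torch.distributions` object is neither a parameter nor a buffer, the noisy subclass has exactly the same `state_dict` keys as its parent, so `load_state_dict` can copy the trained weights across directly.

```python
import torch
from torch import nn
from torch.distributions import Laplace


class TinySplitNN(nn.Module):
    """Hypothetical stand-in for Classifier, with the same part1/part2 split."""

    def __init__(self):
        super().__init__()
        self.part1 = nn.Sequential(nn.Linear(4, 8), nn.ReLU())  # data owner
        self.part2 = nn.Linear(8, 3)  # computational server

    def forward(self, x):
        return self.part2(self.part1(x))


class NoisyTinySplitNN(TinySplitNN):
    def __init__(self, noise_scale):
        super().__init__()
        self.noise = Laplace(0.0, noise_scale)  # noise_scale must be > 0

    def forward(self, x):
        intermediate = self.part1(x)
        # Fresh noise on every call, added before the data "leaves" the data owner
        intermediate = intermediate + self.noise.sample(intermediate.shape)
        return self.part2(intermediate)


trained = TinySplitNN()  # pretend this has been trained
noisy = NoisyTinySplitNN(noise_scale=0.1)
noisy.load_state_dict(trained.state_dict())  # same keys, so a strict load works
```

Since `sample` draws new noise on each forward pass, the same input gives slightly different outputs every time; that randomness is what the defence relies on.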

In [None]:
class NoisyClassifier(Classifier):
    def __init__(self, noise_scale):
        super().__init__()
        self.set_noise(noise_scale)

    def set_noise(self, noise):
        raise NotImplementedError("Set a noise distribution using PyTorch")

    def forward(self, x):
        raise NotImplementedError("Define a forward pass which adds noise to the intermediate data")

In [None]:
noisy_classifier = NoisyClassifier(0.1)
raise NotImplementedError("Copy parameters from trained classifier to `noisy_classifier`")

### Step 2 - Train attacker

Train an attack model on the noisy classifier.
You can re-use the code from earlier.

In [None]:
raise NotImplementedError("Train attack model on the noisy classifier")

### Step 3 - Validate attacker / noise defence

Plot reconstructions from the attacker to see if the noise defence works.

### Step 4 - Experiment

Change aspects of the noisy classifier and attacker
to find out whether the defence is useful or not.
Some experiments to try:
- What happens when you increase the noise scale?
- How does the noise defence compare with NoPeek?
- Is the accuracy/privacy trade-off useful?
- What happens when you use a Gaussian noise distribution instead?
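When swapping Laplacian noise for Gaussian noise, it helps to decide what to hold fixed. One option, sketched here as an assumption rather than a recommendation, is to match standard deviations: Laplace noise with scale $b$ has standard deviation $b\sqrt{2}$, so a Normal distribution with standard deviation $b\sqrt{2}$ adds the same amount of noise power. The value `scale = 0.5` is illustrative.

```python
import math

import torch
from torch.distributions import Laplace, Normal

scale = 0.5
laplace = Laplace(0.0, scale)
# Laplace(0, b) has standard deviation b * sqrt(2); match it with a Gaussian
gaussian_same_std = Normal(0.0, scale * math.sqrt(2))

print(laplace.stddev, gaussian_same_std.stddev)
```

Matching the `scale` argument directly instead would give the Gaussian less noise power, which could make it look like a weaker defence for the wrong reason.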


In [None]:
raise NotImplementedError("Experiment - have fun with it")

---

## Summary

Can you:
- [ ] Run model inversion on a SplitNN
- [ ] Implement NoPeek
- [ ] Implement the noise defence
- [ ] Experiment with hyperparameters to improve a defensive measure

Model inversion is not the only attack
which can be performed on neural networks;
far from it,
in fact.
There are attacks which can steal model hyperparameters,
extract training data memorized by the model,
detect whether some data was part of the training data,
or even destroy your model's ability to make predictions.
Hopefully this lesson has given
you an insight into neural networks' vulnerabilities.
If you take one thing from this lesson,
it should be that models are computational systems,
and are therefore susceptible to attack.
You should incorporate InfoSec processes
and reviews into model development
if you want to stay ahead of embarrassing
and potentially costly attacks.
If you found yourself enjoying
developing the attack model,
remember to use this new-found power for good!
