# Adversarial attack detection

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets
from torchvision.transforms import v2

# Seed PyTorch's global random number generator so runs are repeatable.
torch.manual_seed(42)

## Training a model to defend

We use MNIST because we can train a good model on it in no time: the model below reaches about 99% accuracy on the test set in just 5 epochs. That is more than enough for it to serve as a small target for demonstrating adversarial detection.

In [None]:
# Small convolutional classifier for single-channel 28x28 images
class CNN(nn.Module):
    """
    Two conv + max-pool stages followed by a two-layer fully connected head.

    Each 3x3 convolution uses padding 1 and each stage ends in a 2x2 max-pool,
    so a 1x28x28 input reaches the head as 64 feature maps of size 7x7.
    ``forward`` returns 10 raw logits per sample (no softmax is applied).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        features = x
        # Both stages share the same conv -> ReLU -> pool pattern.
        for conv in (self.conv1, self.conv2):
            features = self.pool(F.relu(conv(features)))
        flat = features.view(-1, 64 * 7 * 7)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)

In [None]:
# Load MNIST dataset
batch_size = 128
# v2.ToTensor() is deprecated; ToImage + ToDtype(scale=True) is its documented
# replacement and is what the VAE data pipeline later in this notebook uses.
transform = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize((0.5,), (0.5,)),  # (x - 0.5) / 0.5
])

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Initialize the model, loss function, and optimizer
clf = CNN()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(clf.parameters(), lr=0.001)

# Training the model
def train(model, train_loader, criterion, optimizer, epochs=5):
    """
    Fit ``model`` on ``train_loader`` for ``epochs`` passes over the data.

    Args:
        model (nn.Module): The model to train; switched to training mode.
        train_loader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Callable taking ``(outputs, labels)`` and returning the loss.
        optimizer: Optimizer holding the model's parameters.
        epochs (int): Number of passes over ``train_loader``.
    """
    # Make sure mode-dependent layers (dropout, batch norm) behave as in training,
    # even if the caller left the model in evaluation mode.
    model.train()

    for epoch in range(epochs):
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(train_loader):
            optimizer.zero_grad()
            loss = criterion(model(inputs), labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            # Report the mean loss over each window of 100 batches.
            if i % 100 == 99:
                print('[%d, %5d] loss: %.3f' %
                      (epoch + 1, i + 1, running_loss / 100))
                running_loss = 0.0

In [None]:
# Testing the model
def test(model, test_loader):
    correct = 0
    total = 0
    with torch.no_grad():
        for data in test_loader:
            inputs, labels = data
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print('Accuracy on the test set: %d %%' % (100 * correct / total))

In [None]:
# Fit the classifier for 5 epochs, then print its accuracy on the test set
train(clf, train_loader, criterion, optimizer, epochs=5)
test(clf, test_loader)

In [None]:
# Freeze the classifier: none of its parameters should receive gradients
# from here on.
clf.requires_grad_(False)

clf.eval()

## Defending the model

Now that we have our target model, we can start implementing the idea we outlined in the introduction.
Below you will find the VAE code from before. All that is left is updating the loss function.
Remember, we want to solve the following optimization problem:

$$\min\limits_\theta D_{KL}(M(x) \| M(AE_\theta(x)))$$

$M(x)$ is the output of the classifier on the input, and $M(AE(x))$ is the output of the classifier on the reconstructed input.
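For the KL divergence, use [`torch.nn.functional.kl_div`](https://pytorch.org/docs/stable/generated/torch.nn.functional.kl_div.html#torch.nn.functional.kl_div). For numerical stability, set `kl_div`'s argument `log_target` to `True` and use `F.log_softmax` for the classifier outputs. Mind the argument order: `kl_div(input, target)` computes $D_{KL}(\text{target} \| \text{input})$, so pass $\log M(AE_\theta(x))$ as `input` and $\log M(x)$ as `target`.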

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from dataclasses import dataclass

class VAE(nn.Module):
    """
    Variational Autoencoder (VAE) trained to preserve a classifier's predictions.

    The loss computed in ``forward`` has no pixel-reconstruction or latent-prior
    term: it is the KL divergence between the classifier's output distribution
    on the input and on the reconstruction. The classifier is read from the
    module-level name ``clf``.

    NOTE(review): the decoder ends in a Sigmoid, so reconstructions lie in
    [0, 1], while the data pipeline in this notebook normalizes images with
    mean 0.5 and std 0.5 (i.e. to [-1, 1]); confirm this mismatch is intended.

    Args:
        input_dim (int): Dimensionality of the input data.
        hidden_dim (int): Dimensionality of the hidden layer.
        latent_dim (int): Dimensionality of the latent space.
    """

    @dataclass
    class VAEOutput:
        """
        Dataclass for VAE output.

        Attributes:
            z_dist (torch.distributions.Distribution): The distribution of the latent variable z.
            z_sample (torch.Tensor): The sampled value of the latent variable z.
            x_recon (torch.Tensor): The reconstructed output from the VAE.
            loss (torch.Tensor): The loss of the VAE, or None when the forward
                pass was run with ``compute_loss=False``.
        """
        z_dist: torch.distributions.Distribution
        z_sample: torch.Tensor
        x_recon: torch.Tensor
        loss: torch.Tensor

    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()

        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.SiLU(),
            nn.Linear(hidden_dim // 2, hidden_dim // 4),
            nn.SiLU(),
            nn.Linear(hidden_dim // 4, hidden_dim // 8),
            nn.SiLU(),
            # 2x: one half is the mean, the other the pre-softplus scale.
            nn.Linear(hidden_dim // 8, 2 * latent_dim),
        )
        self.softplus = nn.Softplus()

        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim // 8),
            nn.SiLU(),
            nn.Linear(hidden_dim // 8, hidden_dim // 4),
            nn.SiLU(),
            nn.Linear(hidden_dim // 4, hidden_dim // 2),
            nn.SiLU(),
            nn.Linear(hidden_dim // 2, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid(),
        )

    def encode(self, x, eps: float = 1e-8):
        """
        Encodes the input data into the latent space.

        Args:
            x (torch.Tensor): Input data.
            eps (float): Small value added to the scale to keep it strictly positive.

        Returns:
            torch.distributions.MultivariateNormal: Normal distribution of the
            encoded data, with a diagonal ``scale_tril``.
        """
        encoded = self.encoder(x)
        mu, raw_scale = torch.chunk(encoded, 2, dim=-1)
        # Softplus maps the raw output to a positive per-dimension scale.
        scale = self.softplus(raw_scale) + eps
        scale_tril = torch.diag_embed(scale)
        return torch.distributions.MultivariateNormal(mu, scale_tril=scale_tril)

    def reparameterize(self, dist):
        """
        Reparameterizes the encoded data to sample from the latent space.

        Args:
            dist (torch.distributions.MultivariateNormal): Normal distribution of the encoded data.
        Returns:
            torch.Tensor: Sampled data from the latent space.
        """
        # rsample() keeps the sample differentiable w.r.t. the distribution parameters.
        return dist.rsample()

    def decode(self, z):
        """
        Decodes the data from the latent space to the original input space.

        Args:
            z (torch.Tensor): Data in the latent space.

        Returns:
            torch.Tensor: Reconstructed data in the original input space.
        """
        return self.decoder(z)

    def forward(self, x, compute_loss: bool = True):
        """
        Performs a forward pass of the VAE.

        Args:
            x (torch.Tensor): Input data; when the loss is computed it must be
                viewable as (-1, 1, 28, 28) for the classifier.
            compute_loss (bool): Whether to compute the loss or not.

        Returns:
            VAEOutput: VAE output dataclass. ``loss`` is
            D_KL(M(x) || M(AE(x))) averaged over the batch, or None when
            ``compute_loss`` is False.
        """
        dist = self.encode(x)
        z = self.reparameterize(dist)
        recon_x = self.decode(z)

        if not compute_loss:
            return VAE.VAEOutput(
                z_dist=dist,
                z_sample=z,
                x_recon=recon_x,
                loss=None,
            )

        # Log-probabilities of the classifier on the input, log M(x).
        log_probs = F.log_softmax(clf(x.view(-1, 1, 28, 28)), dim=-1)

        # Log-probabilities of the classifier on the reconstruction, log M(AE(x)).
        log_probs_recon = F.log_softmax(clf(recon_x.view(-1, 1, 28, 28)), dim=-1)

        # F.kl_div(input, target) computes KL(target || input), so the
        # reconstruction goes first to obtain D_KL(M(x) || M(AE(x))).
        loss = F.kl_div(log_probs_recon, log_probs, reduction='batchmean', log_target=True)

        return VAE.VAEOutput(
            z_dist=dist,
            z_sample=z,
            x_recon=recon_x,
            loss=loss,
        )

In [None]:
batch_size = 64
transform = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize((0.5,), (0.5,))
])

# Load the training data (already downloaded by the classifier section above)
train_data = datasets.MNIST(
    './data',
    download=False,
    train=True,
    transform=transform,
)
# Load the test data (already downloaded by the classifier section above)
test_data = datasets.MNIST(
    './data',
    download=False,
    train=False,
    transform=transform,
)

# Create data loaders; the training data is shuffled each epoch, as for the
# classifier's training loader.
train_loader = torch.utils.data.DataLoader(
    train_data,
    batch_size=batch_size,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    test_data,
    batch_size=batch_size,
    shuffle=False,
)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# The VAE's loss runs the classifier on batches that live on `device`, so the
# classifier has to be on the same device as the VAE.
clf.to(device)
vae = VAE(input_dim=784, hidden_dim=512, latent_dim=16).to(device)
optimizer = torch.optim.AdamW(vae.parameters(), lr=1e-3, weight_decay=1e-2)

You already know the code below: we define the train and test functions and fit the model to the training data. You can safely skip to the next section.

In [None]:
from tqdm import tqdm

def train(model, dataloader, optimizer, prev_updates):
    """
    Trains the model on the given data for one pass over the dataloader.

    Args:
        model (nn.Module): The model to train. Its forward pass must return
            an object exposing the training loss as ``.loss``.
        dataloader (torch.utils.data.DataLoader): The data loader.
        optimizer: The optimizer.
        prev_updates (int): Number of updates performed before this call;
            used to number the logged steps.

    Returns:
        int: ``prev_updates`` plus the number of batches in ``dataloader``.
    """
    model.train()  # Set the model to training mode
    # Follow the model's own device instead of relying on a notebook global.
    model_device = next(model.parameters()).device

    for batch_idx, (data, _) in enumerate(tqdm(dataloader)):
        n_upd = prev_updates + batch_idx
        data = data.view(data.size(0), -1).to(model_device)  # Flatten the data

        optimizer.zero_grad()  # Zero the gradients

        loss = model(data).loss  # Forward pass
        loss.backward()

        # Gradient clipping. clip_grad_norm_ returns the total gradient norm
        # measured before clipping, which is the value we log.
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        if n_upd % 100 == 0:
            # Approximate sample count: steps so far times the nominal batch size.
            per_batch = getattr(dataloader, 'batch_size', None) or data.size(0)
            print(f'Step {n_upd:,} (N samples: {n_upd*per_batch:,}), Loss: {loss.item():.4f}, Grad: {grad_norm.item():.4f}')

        optimizer.step()  # Update the model parameters

    return prev_updates + len(dataloader)

In [None]:
def test(model, dataloader):
    """
    Tests the model on the given data and prints the mean per-batch loss.

    Args:
        model (nn.Module): The model to test. Its forward pass must accept
            ``compute_loss=True`` and return an object exposing ``.loss``.
        dataloader (torch.utils.data.DataLoader): The data loader.
    """
    model.eval()  # Set the model to evaluation mode
    # Follow the model's own device instead of relying on a notebook global.
    model_device = next(model.parameters()).device
    test_loss = 0

    with torch.no_grad():
        for data, _ in tqdm(dataloader, desc='Testing'):
            data = data.to(model_device)
            data = data.view(data.size(0), -1)  # Flatten the data

            output = model(data, compute_loss=True)  # Forward pass

            test_loss += output.loss.item()

    # Average of the per-batch losses (each batch weighted equally).
    test_loss /= len(dataloader)
    print(f'====> Test set loss: {test_loss:.4f}')

In [None]:
num_epochs = 20
prev_updates = 0  # running update count, threaded through train() for step logging
for epoch in range(num_epochs):
    print(f'Epoch {epoch+1}/{num_epochs}')
    prev_updates = train(vae, train_loader, optimizer, prev_updates)
# Evaluated once, after the last epoch (this call sits outside the loop)
test(vae, test_loader)

# Keep the VAE in evaluation mode for the cells that follow
vae.eval()

In [None]:
# Persist only the VAE's weights (its state dict), not the whole module
torch.save(vae.state_dict(), 'adversarial_vae.pth')

## Testing our defense

To test our defense, we first have to generate adversarial examples. In a previous lab, we used the Fast Gradient Sign Method (FGSM) to generate adversarial examples. This time, we will use a few more methods, and we will use the `foolbox` library to do so. The library is very easy to use and has a lot of built-in methods to generate adversarial examples.

Installing it is as easy as running `pip install foolbox`. You can find the documentation [here](https://foolbox.readthedocs.io/en/latest/). We've already installed it for you.

The testing procedure is as follows:
1. Generate adversarial examples using the `foolbox` library.
2. Compute the accuracy of the classifier on the adversarial examples.
3. Pass the adversarial examples through the VAE.
4. Pass the reconstructed examples through the classifier and compute the accuracy.
5. Compare the accuracy of the adversarial examples and the reconstructed adversarial examples.

### 1. Generating adversarial examples

In [None]:
from foolbox import PyTorchModel, accuracy, samples
from foolbox.attacks import LinfPGD

# Wrap the classifier with Foolbox
fclf = PyTorchModel(clf, bounds=(-1, 1))

# The attack needs inputs and labels on the same device as the classifier's
# weights, so look that device up after wrapping.
clf_device = next(clf.parameters()).device

# Get a batch of test data
images, labels = next(iter(test_loader))
images, labels = images.to(clf_device), labels.to(clf_device)

# Attack the classifier
attack = LinfPGD()
raw_advs, clipped_advs, success = attack(fclf, images, labels, epsilons=0.3)

### 2. Compute the accuracy of the classifier on the adversarial examples
_Hint: You don't have to run the classifier again. Instead, you can use the `success` variable._

In [None]:
# Accuracy of the classifier on the adversarial examples: one minus the
# mean of the attack's `success` flags.
accuracy_adv = 1 - success.float().mean(dim=-1)

### 3. Pass the adversarial examples through the VAE

In [None]:
# Pass the adversarial examples through the VAE
# Flatten each image in `clipped_advs` to a vector, which is what the VAE expects
advs = clipped_advs.view(clipped_advs.size(0), -1)
advs = advs.to(device)
# Inference only, so there is no need to track gradients
with torch.no_grad():
    advs_recon = vae(advs, compute_loss=False).x_recon

### 4. Pass the reconstructed examples through the classifier and compute the accuracy

In [None]:
# Pass the reconstructed examples through the classifier and compute the accuracy
# Run on whatever device the classifier's weights live on
clf_device = next(clf.parameters()).device
with torch.no_grad():
    advs_recon = advs_recon.view(-1, 1, 28, 28)
    advs_recon = advs_recon.to(clf_device)
    outputs = clf(advs_recon)
predicted = outputs.argmax(dim=1)
# Compare on a common device; the labels may still be on the CPU
accuracy_recon = (predicted == labels.to(predicted.device)).sum().item() / len(labels)

### 5. Compare the accuracy of the adversarial examples and the reconstructed adversarial examples

In [None]:
# Report both accuracies side by side: on the raw adversarial examples and
# on their VAE reconstructions
print(f'Accuracy on adversarial examples: {accuracy_adv:.4f}')
print(f'Accuracy on reconstructed examples: {accuracy_recon:.4f}')

What conclusions can you draw from the results?