# Adversarial attack detection

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets
from torchvision.transforms import v2

torch.manual_seed(42)

<torch._C.Generator at 0x10cc7bdf0>

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
print(f'Torch version: {torch.__version__}, Device: {device.type}')

Torch version: 2.6.0, Device: mps


## Training a model to defend

The reason for using MNIST is that we can train a good model on it in no time. The model below will reach 99% accuracy on the test set in just 5 epochs. More than enough to act as a little target to demonstrate adversarial detection.

In [3]:
# Define the CNN model
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 64 * 7 * 7)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [None]:
# Load MNIST dataset
batch_size = 128
transform = v2.Compose([
    v2.ToImage(),  # Converts PIL image or ndarray to a tensor image
    v2.ToDtype(torch.float32, scale=True),  # Normalizes pixel values to [0, 1]
    v2.Normalize((0.5,), (0.5,))
])

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Initialize the model, loss function, and optimizer
clf = CNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(clf.parameters(), lr=0.001)

# Training the model
def train(model, train_loader, criterion, optimizer, epochs=5):
    for epoch in range(epochs):
        running_loss = 0.0
        for i, data in enumerate(train_loader, 0):
            inputs, labels = data
            optimizer.zero_grad()
            outputs = model(inputs.to(device))
            loss = criterion(outputs, labels.to(device))
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            if i % 100 == 99:
                print('[%d, %5d] loss: %.3f' %
                      (epoch + 1, i + 1, running_loss / 100))
                running_loss = 0.0



In [5]:
# Testing the model
def test(model, test_loader):
    correct = 0
    total = 0
    with torch.no_grad():
        for data in test_loader:
            inputs, labels = data
            outputs = model(inputs.to(device))
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels.to(device)).sum().item()

    print('Accuracy on the test set: %d %%' % (100 * correct / total))

In [6]:
# Train the model and test it
train(clf, train_loader, criterion, optimizer, epochs=5)
test(clf, test_loader)

[1,   100] loss: 0.687
[1,   200] loss: 0.153
[1,   300] loss: 0.105
[1,   400] loss: 0.085
[2,   100] loss: 0.064
[2,   200] loss: 0.061
[2,   300] loss: 0.053
[2,   400] loss: 0.048
[3,   100] loss: 0.042
[3,   200] loss: 0.042
[3,   300] loss: 0.038
[3,   400] loss: 0.035
[4,   100] loss: 0.028
[4,   200] loss: 0.029
[4,   300] loss: 0.028
[4,   400] loss: 0.029
[5,   100] loss: 0.022
[5,   200] loss: 0.019
[5,   300] loss: 0.021
[5,   400] loss: 0.025
Accuracy on the test set: 99 %


In [7]:
for param in clf.parameters():
    param.requires_grad = False

clf.eval()

CNN(
  (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=3136, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
)

## Defending the model

Now that we have our target model, we can start implementing the idea we outlined in the introduction.
Below you will find the VAE code from before. All that is left is updating the loss function.
Remember, we want to solve the following optimization problem:

$$\min\limits_\theta D_{KL}(M(x) \| M(AE_\theta(x)))$$

$M(x)$ is the output of the classifier on the input, and $M(AE(x))$ is the output of the classifier on the reconstructed input.
For the KL divergence, use [`torch.nn.functional.kl_div`](https://pytorch.org/docs/stable/generated/torch.nn.functional.kl_div.html#torch.nn.functional.kl_div). For numerical stability, set `kl_div`'s argument `log_target` to `True` and use `F.log_softmax` for the classifier outputs.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from dataclasses import dataclass

class VAE(nn.Module):
    """
    Variational Autoencoder (VAE) class.
    
    Args:
        input_dim (int): Dimensionality of the input data.
        hidden_dim (int): Dimensionality of the hidden layer.
        latent_dim (int): Dimensionality of the latent space.
    """

    @dataclass
    class VAEOutput:
        """
        Dataclass for VAE output.
        
        Attributes:
            z_dist (torch.distributions.Distribution): The distribution of the latent variable z.
            z_sample (torch.Tensor): The sampled value of the latent variable z.
            x_recon (torch.Tensor): The reconstructed output from the VAE.
            loss (torch.Tensor): The overall loss of the VAE.
            loss_recon (torch.Tensor): The reconstruction loss component of the VAE loss.
            loss_kl (torch.Tensor): The KL divergence component of the VAE loss.
        """
        z_dist: torch.distributions.Distribution
        z_sample: torch.Tensor
        x_recon: torch.Tensor
        loss: torch.Tensor

    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(VAE, self).__init__()
                
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.SiLU(),
            nn.Linear(hidden_dim // 2, hidden_dim // 4),
            nn.SiLU(),
            nn.Linear(hidden_dim // 4, hidden_dim // 8),
            nn.SiLU(), 
            nn.Linear(hidden_dim // 8, 2 * latent_dim), # 2 for mean and variance.
        )
        self.softplus = nn.Softplus()
        
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim // 8),
            nn.SiLU(),
            nn.Linear(hidden_dim // 8, hidden_dim // 4),
            nn.SiLU(),
            nn.Linear(hidden_dim // 4, hidden_dim // 2),
            nn.SiLU(),
            nn.Linear(hidden_dim // 2, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid(),
        )
        
    def encode(self, x, eps: float = 1e-8):
        """
        Encodes the input data into the latent space.
        
        Args:
            x (torch.Tensor): Input data.
            eps (float): Small value to avoid numerical instability.
        
        Returns:
            torch.distributions.MultivariateNormal: Normal distribution of the encoded data.
        """
        x = self.encoder(x)
        mu, logvar = torch.chunk(x, 2, dim=-1)
        scale = self.softplus(logvar) + eps
        scale_tril = torch.diag_embed(scale)
        return torch.distributions.MultivariateNormal(mu, scale_tril=scale_tril)
        
    def reparameterize(self, dist):
        """
        Reparameterizes the encoded data to sample from the latent space.
        
        Args:
            dist (torch.distributions.MultivariateNormal): Normal distribution of the encoded data.
        Returns:
            torch.Tensor: Sampled data from the latent space.
        """
        return dist.rsample()
    
    def decode(self, z):
        """
        Decodes the data from the latent space to the original input space.
        
        Args:
            z (torch.Tensor): Data in the latent space.
        
        Returns:
            torch.Tensor: Reconstructed data in the original input space.
        """
        return self.decoder(z)
    
    def forward(self, x, compute_loss: bool = True):
        """
        Performs a forward pass of the VAE.
        
        Args:
            x (torch.Tensor): Input data.
            compute_loss (bool): Whether to compute the loss or not.
        
        Returns:
            VAEOutput: VAE output dataclass.
        """
        dist = self.encode(x)
        z = self.reparameterize(dist)
        recon_x = self.decode(z)
        
        if not compute_loss:
            return VAE.VAEOutput(
                z_dist=dist,
                z_sample=z,
                x_recon=recon_x,
                loss=None,
            )

        # TODO: Compute output probabilities for x, i.e. M(x)
        probs = F.log_softmax(clf(x.view(-1, 1, 28, 28)), dim=1)

        # TODO: Compute output probabilities for reconstructed x, i.e. M(AE(x))
        probs_recon = F.log_softmax(clf(recon_x.view(-1, 1, 28, 28)), dim=1)
        
        # TODO: Compute loss term here.
        loss = F.kl_div(probs, probs_recon, reduction='batchmean', log_target=True)

        return VAE.VAEOutput(
            z_dist=dist,
            z_sample=z,
            x_recon=recon_x,
            loss=loss,
        )

In [9]:
batch_size = 64
transform = v2.Compose([
    v2.ToImage(), 
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize((0.5,), (0.5,))
])

# Download and load the training data
train_data = datasets.MNIST(
    './data', 
    download=False, 
    train=True, 
    transform=transform,
)
# Download and load the test data
test_data = datasets.MNIST(
    './data', 
    download=False, 
    train=False, 
    transform=transform,
)

# Create data loaders
train_loader = torch.utils.data.DataLoader(
    train_data, 
    batch_size=batch_size, 
    shuffle=False,
)
test_loader = torch.utils.data.DataLoader(
    test_data, 
    batch_size=batch_size, 
    shuffle=False,
)

vae = VAE(input_dim=784, hidden_dim=512, latent_dim=16).to(device)
optimizer = torch.optim.AdamW(vae.parameters(), lr=1e-3, weight_decay=1e-2)

The code below you already know: we define the train and test functions, and fit the model to the training data. You can safely skip to the next section.

In [13]:
from tqdm import tqdm

def train(model, dataloader, optimizer, prev_updates):
    """
    Trains the model on the given data.
    
    Args:
        model (nn.Module): The model to train.
        dataloader (torch.utils.data.DataLoader): The data loader.
        loss_fn: The loss function.
        optimizer: The optimizer.
    """
    model.train()  # Set the model to training mode
    
    for batch_idx, (data, target) in enumerate(tqdm(dataloader)):
        n_upd = prev_updates + batch_idx
        data = data.view(data.size(0), -1)  # Flatten the data
        data = data.to(device)
        
        optimizer.zero_grad()  # Zero the gradients
        
        output = model(data)  # Forward pass
        loss = output.loss
        
        loss.backward()
        
        if n_upd % 100 == 0:
            # Calculate and log gradient norms
            total_norm = 0.0
            for p in model.parameters():
                if p.grad is not None:
                    param_norm = p.grad.data.norm(2)
                    total_norm += param_norm.item() ** 2
            total_norm = total_norm ** (1. / 2)
        
            print(f'Step {n_upd:,} (N samples: {n_upd*batch_size:,}), Loss: {loss.item():.4f}, Grad: {total_norm:.4f}')
            
        # gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)    
        
        optimizer.step()  # Update the model parameters
        
    return prev_updates + len(dataloader)

In [14]:
def test(model, dataloader):
    """
    Tests the model on the given data.
    
    Args:
        model (nn.Module): The model to test.
        dataloader (torch.utils.data.DataLoader): The data loader.
        cur_step (int): The current step.
    """
    model.eval()  # Set the model to evaluation mode
    test_loss = 0
    
    with torch.no_grad():
        for data, target in tqdm(dataloader, desc='Testing'):
            data = data.to(device)
            data = data.view(data.size(0), -1)  # Flatten the data
            
            output = model(data, compute_loss=True)  # Forward pass
            
            test_loss += output.loss.item()
            
    test_loss /= len(dataloader)
    print(f'====> Test set loss: {test_loss:.4f}')

In [15]:
num_epochs = 20
prev_updates = 0
for epoch in range(num_epochs):
    print(f'Epoch {epoch+1}/{num_epochs}')
    prev_updates = train(vae, train_loader, optimizer, prev_updates)
test(vae, test_loader)

vae.eval()

Epoch 1/20


  probs = F.log_softmax(clf(x.view(-1, 1, 28, 28)))
  probs_recon = F.log_softmax(clf(recon_x.view(-1, 1, 28, 28)))


Step 0 (N samples: 0), Loss: 13.9894, Grad: 0.1516


 12%|█▏        | 112/938 [00:02<00:13, 62.67it/s]

Step 100 (N samples: 6,400), Loss: 9.9518, Grad: 6.5296


 22%|██▏       | 210/938 [00:04<00:11, 62.45it/s]

Step 200 (N samples: 12,800), Loss: 12.8497, Grad: 2.4934


 33%|███▎      | 308/938 [00:05<00:10, 59.03it/s]

Step 300 (N samples: 19,200), Loss: 7.2079, Grad: 7.9161


 44%|████▍     | 413/938 [00:07<00:08, 62.00it/s]

Step 400 (N samples: 25,600), Loss: 5.9322, Grad: 1.9055


 54%|█████▍    | 511/938 [00:09<00:06, 61.31it/s]

Step 500 (N samples: 32,000), Loss: 7.6862, Grad: 0.8046


 65%|██████▍   | 609/938 [00:10<00:05, 61.27it/s]

Step 600 (N samples: 38,400), Loss: 5.6189, Grad: 11.9943


 75%|███████▌  | 707/938 [00:12<00:03, 60.07it/s]

Step 700 (N samples: 44,800), Loss: 7.4210, Grad: 7.2534


 86%|████████▋ | 811/938 [00:13<00:02, 59.90it/s]

Step 800 (N samples: 51,200), Loss: 6.5895, Grad: 4.9031


 97%|█████████▋| 907/938 [00:15<00:00, 57.65it/s]

Step 900 (N samples: 57,600), Loss: 5.1196, Grad: 7.1781


100%|██████████| 938/938 [00:16<00:00, 56.95it/s]


Epoch 2/20


  8%|▊         | 72/938 [00:01<00:14, 59.35it/s]

Step 1,000 (N samples: 64,000), Loss: 4.6401, Grad: 1.3983


 18%|█▊        | 170/938 [00:02<00:12, 60.30it/s]

Step 1,100 (N samples: 70,400), Loss: 4.9910, Grad: 8.8826


 29%|██▊       | 268/938 [00:04<00:11, 56.77it/s]

Step 1,200 (N samples: 76,800), Loss: 7.1417, Grad: 25.9015


 40%|███▉      | 372/938 [00:06<00:10, 53.82it/s]

Step 1,300 (N samples: 83,200), Loss: 5.5544, Grad: 2.8940


 50%|█████     | 471/938 [00:08<00:08, 56.36it/s]

Step 1,400 (N samples: 89,600), Loss: 5.7195, Grad: 0.7000


 61%|██████    | 571/938 [00:09<00:05, 62.98it/s]

Step 1,500 (N samples: 96,000), Loss: 5.1553, Grad: 0.9943


 71%|███████▏  | 669/938 [00:11<00:04, 62.14it/s]

Step 1,600 (N samples: 102,400), Loss: 4.9911, Grad: 3.3481


 82%|████████▏ | 772/938 [00:12<00:02, 58.51it/s]

Step 1,700 (N samples: 108,800), Loss: 4.3261, Grad: 14.2767


 93%|█████████▎| 873/938 [00:14<00:01, 59.77it/s]

Step 1,800 (N samples: 115,200), Loss: 4.5705, Grad: 1.0746


100%|██████████| 938/938 [00:15<00:00, 59.89it/s]


Epoch 3/20


  4%|▎         | 35/938 [00:00<00:15, 56.61it/s]

Step 1,900 (N samples: 121,600), Loss: 3.7350, Grad: 2.4948


 14%|█▍        | 132/938 [00:02<00:13, 61.38it/s]

Step 2,000 (N samples: 128,000), Loss: 4.0344, Grad: 4.2623


 25%|██▌       | 237/938 [00:03<00:10, 63.85it/s]

Step 2,100 (N samples: 134,400), Loss: 5.7955, Grad: 7.7073


 36%|███▌      | 335/938 [00:05<00:09, 64.68it/s]

Step 2,200 (N samples: 140,800), Loss: 4.2672, Grad: 6.3414


 46%|████▌     | 433/938 [00:06<00:08, 61.25it/s]

Step 2,300 (N samples: 147,200), Loss: 4.7916, Grad: 3.5212


 57%|█████▋    | 531/938 [00:08<00:06, 61.60it/s]

Step 2,400 (N samples: 153,600), Loss: 2.6690, Grad: 1.2665


 68%|██████▊   | 635/938 [00:10<00:05, 60.24it/s]

Step 2,500 (N samples: 160,000), Loss: 4.3630, Grad: 2.4722


 78%|███████▊  | 733/938 [00:11<00:03, 63.36it/s]

Step 2,600 (N samples: 166,400), Loss: 4.4012, Grad: 81.9491


 89%|████████▊ | 831/938 [00:13<00:01, 59.55it/s]

Step 2,700 (N samples: 172,800), Loss: 5.1275, Grad: 1.5782


100%|██████████| 938/938 [00:15<00:00, 62.46it/s]

Step 2,800 (N samples: 179,200), Loss: 5.8947, Grad: 0.8308





Epoch 4/20


 10%|█         | 98/938 [00:01<00:13, 60.74it/s]

Step 2,900 (N samples: 185,600), Loss: 4.9353, Grad: 0.7834


 21%|██        | 196/938 [00:03<00:12, 60.31it/s]

Step 3,000 (N samples: 192,000), Loss: 4.1315, Grad: 3.3277


 31%|███       | 293/938 [00:04<00:10, 59.10it/s]

Step 3,100 (N samples: 198,400), Loss: 4.3860, Grad: 15.4260


 42%|████▏     | 397/938 [00:06<00:08, 62.89it/s]

Step 3,200 (N samples: 204,800), Loss: 4.3867, Grad: 9.9663


 53%|█████▎    | 495/938 [00:08<00:07, 58.61it/s]

Step 3,300 (N samples: 211,200), Loss: 4.7444, Grad: 6.0187


 64%|██████▍   | 598/938 [00:09<00:05, 57.75it/s]

Step 3,400 (N samples: 217,600), Loss: 5.8735, Grad: 0.2316


 74%|███████▍  | 694/938 [00:11<00:03, 62.72it/s]

Step 3,500 (N samples: 224,000), Loss: 6.2540, Grad: 2.0540


 85%|████████▌ | 799/938 [00:12<00:02, 63.72it/s]

Step 3,600 (N samples: 230,400), Loss: 5.3923, Grad: 0.4164


 95%|█████████▍| 890/938 [00:14<00:00, 52.27it/s]

Step 3,700 (N samples: 236,800), Loss: 3.7721, Grad: 1.0561


100%|██████████| 938/938 [00:15<00:00, 61.21it/s]


Epoch 5/20


  7%|▋         | 61/938 [00:01<00:14, 61.09it/s]

Step 3,800 (N samples: 243,200), Loss: 3.5500, Grad: 8.9268


 17%|█▋        | 155/938 [00:02<00:13, 59.17it/s]

Step 3,900 (N samples: 249,600), Loss: 4.6246, Grad: 0.6913


 28%|██▊       | 260/938 [00:04<00:11, 59.91it/s]

Step 4,000 (N samples: 256,000), Loss: 5.2181, Grad: 0.3339


 38%|███▊      | 355/938 [00:05<00:09, 60.91it/s]

Step 4,100 (N samples: 262,400), Loss: 3.3302, Grad: 0.3269


 49%|████▉     | 460/938 [00:07<00:07, 59.99it/s]

Step 4,200 (N samples: 268,800), Loss: 3.4209, Grad: 3.2068


 59%|█████▉    | 557/938 [00:09<00:06, 63.41it/s]

Step 4,300 (N samples: 275,200), Loss: 2.8852, Grad: 0.2665


 70%|██████▉   | 655/938 [00:10<00:04, 62.43it/s]

Step 4,400 (N samples: 281,600), Loss: 3.8163, Grad: 0.8489


 81%|████████  | 760/938 [00:12<00:02, 60.90it/s]

Step 4,500 (N samples: 288,000), Loss: 4.4376, Grad: 6.9670


 91%|█████████▏| 857/938 [00:13<00:01, 59.51it/s]

Step 4,600 (N samples: 294,400), Loss: 2.8381, Grad: 0.8518


100%|██████████| 938/938 [00:15<00:00, 61.20it/s]


Epoch 6/20


  2%|▏         | 23/938 [00:00<00:15, 57.24it/s]

Step 4,700 (N samples: 300,800), Loss: 4.7105, Grad: 0.8453


 13%|█▎        | 118/938 [00:01<00:12, 63.10it/s]

Step 4,800 (N samples: 307,200), Loss: 3.5840, Grad: 0.8259


 24%|██▍       | 223/938 [00:03<00:11, 62.90it/s]

Step 4,900 (N samples: 313,600), Loss: 3.4100, Grad: 3.3119


 34%|███▍      | 321/938 [00:05<00:09, 63.01it/s]

Step 5,000 (N samples: 320,000), Loss: 4.5350, Grad: 3.9542


 45%|████▍     | 419/938 [00:06<00:08, 63.69it/s]

Step 5,100 (N samples: 326,400), Loss: 4.5928, Grad: 0.4453


 55%|█████▌    | 517/938 [00:08<00:06, 62.43it/s]

Step 5,200 (N samples: 332,800), Loss: 4.6300, Grad: 66.3389


 66%|██████▋   | 622/938 [00:09<00:04, 63.42it/s]

Step 5,300 (N samples: 339,200), Loss: 5.2838, Grad: 0.4460


 77%|███████▋  | 720/938 [00:11<00:03, 63.19it/s]

Step 5,400 (N samples: 345,600), Loss: 5.3280, Grad: 1.6859


 87%|████████▋ | 818/938 [00:13<00:01, 62.78it/s]

Step 5,500 (N samples: 352,000), Loss: 4.0889, Grad: 10.5286


 98%|█████████▊| 923/938 [00:14<00:00, 63.31it/s]

Step 5,600 (N samples: 358,400), Loss: 3.2968, Grad: 0.4917


100%|██████████| 938/938 [00:14<00:00, 62.95it/s]


Epoch 7/20


  9%|▉         | 84/938 [00:01<00:13, 63.34it/s]

Step 5,700 (N samples: 364,800), Loss: 3.6241, Grad: 0.6037


 19%|█▉        | 182/938 [00:02<00:12, 62.08it/s]

Step 5,800 (N samples: 371,200), Loss: 4.4905, Grad: 0.4001


 30%|██▉       | 280/938 [00:04<00:10, 62.75it/s]

Step 5,900 (N samples: 377,600), Loss: 4.6319, Grad: 2.0915


 41%|████      | 385/938 [00:06<00:08, 62.84it/s]

Step 6,000 (N samples: 384,000), Loss: 4.2956, Grad: 3.1392


 51%|█████▏    | 483/938 [00:07<00:07, 63.42it/s]

Step 6,100 (N samples: 390,400), Loss: 4.5527, Grad: 34.5696


 62%|██████▏   | 581/938 [00:09<00:05, 62.71it/s]

Step 6,200 (N samples: 396,800), Loss: 4.1521, Grad: 1.8973


 72%|███████▏  | 679/938 [00:10<00:04, 61.84it/s]

Step 6,300 (N samples: 403,200), Loss: 4.7802, Grad: 4.9007


 84%|████████▎ | 784/938 [00:12<00:02, 61.09it/s]

Step 6,400 (N samples: 409,600), Loss: 2.3331, Grad: 1.7167


 94%|█████████▍| 882/938 [00:14<00:00, 62.34it/s]

Step 6,500 (N samples: 416,000), Loss: 3.7337, Grad: 0.8958


100%|██████████| 938/938 [00:14<00:00, 62.72it/s]


Epoch 8/20


  4%|▍         | 41/938 [00:00<00:14, 61.83it/s]

Step 6,600 (N samples: 422,400), Loss: 2.8906, Grad: 19.0600


 16%|█▌        | 146/938 [00:02<00:12, 62.86it/s]

Step 6,700 (N samples: 428,800), Loss: 2.8031, Grad: 0.9288


 26%|██▌       | 244/938 [00:03<00:11, 62.72it/s]

Step 6,800 (N samples: 435,200), Loss: 3.1556, Grad: 0.4888


 36%|███▋      | 342/938 [00:05<00:09, 60.42it/s]

Step 6,900 (N samples: 441,600), Loss: 4.4667, Grad: 1.7947


 47%|████▋     | 443/938 [00:07<00:07, 62.10it/s]

Step 7,000 (N samples: 448,000), Loss: 3.1242, Grad: 1.4888


 58%|█████▊    | 548/938 [00:08<00:06, 64.58it/s]

Step 7,100 (N samples: 454,400), Loss: 3.2385, Grad: 0.5351


 69%|██████▉   | 646/938 [00:10<00:04, 63.13it/s]

Step 7,200 (N samples: 460,800), Loss: 2.4469, Grad: 0.7687


 79%|███████▉  | 744/938 [00:11<00:03, 63.79it/s]

Step 7,300 (N samples: 467,200), Loss: 1.7489, Grad: 0.5218


 90%|████████▉ | 842/938 [00:13<00:01, 64.18it/s]

Step 7,400 (N samples: 473,600), Loss: 3.2219, Grad: 2.3013


100%|██████████| 938/938 [00:14<00:00, 63.54it/s]


Step 7,500 (N samples: 480,000), Loss: 2.6736, Grad: 0.4206
Epoch 9/20


 11%|█         | 105/938 [00:01<00:13, 63.85it/s]

Step 7,600 (N samples: 486,400), Loss: 2.4367, Grad: 2.8232


 22%|██▏       | 203/938 [00:03<00:11, 64.01it/s]

Step 7,700 (N samples: 492,800), Loss: 3.6664, Grad: 0.2749


 33%|███▎      | 308/938 [00:04<00:09, 63.90it/s]

Step 7,800 (N samples: 499,200), Loss: 3.1392, Grad: 8.3068


 43%|████▎     | 406/938 [00:06<00:08, 62.85it/s]

Step 7,900 (N samples: 505,600), Loss: 2.3260, Grad: 0.0366


 54%|█████▎    | 504/938 [00:07<00:07, 60.36it/s]

Step 8,000 (N samples: 512,000), Loss: 4.3814, Grad: 27.4325


 64%|██████▍   | 602/938 [00:09<00:05, 63.07it/s]

Step 8,100 (N samples: 518,400), Loss: 3.2630, Grad: 0.3645


 75%|███████▌  | 706/938 [00:11<00:03, 61.54it/s]

Step 8,200 (N samples: 524,800), Loss: 3.9503, Grad: 1.2619


 86%|████████▌ | 804/938 [00:12<00:02, 61.12it/s]

Step 8,300 (N samples: 531,200), Loss: 2.6088, Grad: 0.1828


 97%|█████████▋| 909/938 [00:14<00:00, 64.55it/s]

Step 8,400 (N samples: 537,600), Loss: 2.9118, Grad: 0.6662


100%|██████████| 938/938 [00:14<00:00, 63.50it/s]


Epoch 10/20


  7%|▋         | 70/938 [00:01<00:14, 58.29it/s]

Step 8,500 (N samples: 544,000), Loss: 3.9944, Grad: 3.6938


 18%|█▊        | 167/938 [00:02<00:13, 58.00it/s]

Step 8,600 (N samples: 550,400), Loss: 3.3017, Grad: 1.0981


 28%|██▊       | 266/938 [00:04<00:11, 58.50it/s]

Step 8,700 (N samples: 556,800), Loss: 2.6926, Grad: 0.2002


 40%|███▉      | 371/938 [00:06<00:09, 62.48it/s]

Step 8,800 (N samples: 563,200), Loss: 3.0443, Grad: 0.3062


 50%|█████     | 469/938 [00:07<00:07, 63.64it/s]

Step 8,900 (N samples: 569,600), Loss: 3.7474, Grad: 8.7850


 60%|██████    | 567/938 [00:09<00:06, 61.04it/s]

Step 9,000 (N samples: 576,000), Loss: 2.1953, Grad: 0.9667


 72%|███████▏  | 671/938 [00:11<00:04, 57.91it/s]

Step 9,100 (N samples: 582,400), Loss: 3.6002, Grad: 2.3359


 82%|████████▏ | 767/938 [00:12<00:02, 60.25it/s]

Step 9,200 (N samples: 588,800), Loss: 2.6540, Grad: 24.4786


 92%|█████████▏| 865/938 [00:14<00:01, 58.28it/s]

Step 9,300 (N samples: 595,200), Loss: 3.0182, Grad: 0.4210


100%|██████████| 938/938 [00:15<00:00, 60.51it/s]


Epoch 11/20


  3%|▎         | 27/938 [00:00<00:17, 50.77it/s]

Step 9,400 (N samples: 601,600), Loss: 3.1483, Grad: 0.2776


 14%|█▍        | 130/938 [00:02<00:13, 60.95it/s]

Step 9,500 (N samples: 608,000), Loss: 2.7225, Grad: 5.0967


 24%|██▍       | 228/938 [00:03<00:11, 63.26it/s]

Step 9,600 (N samples: 614,400), Loss: 3.0617, Grad: 21.2308


 35%|███▌      | 331/938 [00:05<00:09, 61.87it/s]

Step 9,700 (N samples: 620,800), Loss: 1.4669, Grad: 1.2709


 46%|████▌     | 429/938 [00:07<00:08, 60.94it/s]

Step 9,800 (N samples: 627,200), Loss: 3.6008, Grad: 1.3120


 56%|█████▌    | 527/938 [00:08<00:06, 60.95it/s]

Step 9,900 (N samples: 633,600), Loss: 3.7506, Grad: 1.9387


 67%|██████▋   | 631/938 [00:10<00:05, 59.47it/s]

Step 10,000 (N samples: 640,000), Loss: 2.2190, Grad: 0.3150


 78%|███████▊  | 729/938 [00:11<00:03, 62.78it/s]

Step 10,100 (N samples: 646,400), Loss: 3.5445, Grad: 97.6344


 88%|████████▊ | 827/938 [00:13<00:01, 58.68it/s]

Step 10,200 (N samples: 652,800), Loss: 2.9959, Grad: 0.4624


 99%|█████████▉| 932/938 [00:15<00:00, 60.47it/s]

Step 10,300 (N samples: 659,200), Loss: 4.3887, Grad: 0.0681


100%|██████████| 938/938 [00:15<00:00, 61.33it/s]


Epoch 12/20


 10%|▉         | 91/938 [00:01<00:13, 63.09it/s]

Step 10,400 (N samples: 665,600), Loss: 3.8654, Grad: 0.4257


 20%|██        | 189/938 [00:02<00:11, 63.24it/s]

Step 10,500 (N samples: 672,000), Loss: 3.1204, Grad: 35.9405


 31%|███▏      | 294/938 [00:04<00:10, 62.48it/s]

Step 10,600 (N samples: 678,400), Loss: 2.4848, Grad: 0.0215


 42%|████▏     | 392/938 [00:06<00:08, 61.31it/s]

Step 10,700 (N samples: 684,800), Loss: 3.3419, Grad: 1.3992


 52%|█████▏    | 489/938 [00:07<00:07, 58.71it/s]

Step 10,800 (N samples: 691,200), Loss: 2.4560, Grad: 0.7456


 63%|██████▎   | 594/938 [00:09<00:05, 63.30it/s]

Step 10,900 (N samples: 697,600), Loss: 2.4492, Grad: 4.3200


 74%|███████▎  | 690/938 [00:10<00:04, 57.30it/s]

Step 11,000 (N samples: 704,000), Loss: 2.3339, Grad: 0.4930


 85%|████████▍ | 793/938 [00:12<00:02, 60.98it/s]

Step 11,100 (N samples: 710,400), Loss: 3.2655, Grad: 1.9979


 95%|█████████▍| 890/938 [00:14<00:00, 60.19it/s]

Step 11,200 (N samples: 716,800), Loss: 3.1895, Grad: 1.9561


100%|██████████| 938/938 [00:15<00:00, 62.37it/s]


Epoch 13/20


  6%|▌         | 55/938 [00:00<00:14, 58.98it/s]

Step 11,300 (N samples: 723,200), Loss: 2.8872, Grad: 0.2862


 16%|█▌        | 151/938 [00:02<00:13, 60.41it/s]

Step 11,400 (N samples: 729,600), Loss: 3.4096, Grad: 7.7934


 27%|██▋       | 256/938 [00:04<00:11, 60.98it/s]

Step 11,500 (N samples: 736,000), Loss: 1.9562, Grad: 1.0393


 38%|███▊      | 354/938 [00:05<00:09, 63.93it/s]

Step 11,600 (N samples: 742,400), Loss: 2.3484, Grad: 0.1339


 48%|████▊     | 452/938 [00:07<00:08, 56.86it/s]

Step 11,700 (N samples: 748,800), Loss: 2.3615, Grad: 0.1774


 59%|█████▉    | 557/938 [00:09<00:06, 60.26it/s]

Step 11,800 (N samples: 755,200), Loss: 2.2285, Grad: 0.1215


 70%|██████▉   | 655/938 [00:10<00:04, 60.09it/s]

Step 11,900 (N samples: 761,600), Loss: 3.6729, Grad: 0.4528


 80%|████████  | 752/938 [00:12<00:02, 62.74it/s]

Step 12,000 (N samples: 768,000), Loss: 2.0655, Grad: 0.2471


 91%|█████████▏| 857/938 [00:13<00:01, 63.72it/s]

Step 12,100 (N samples: 774,400), Loss: 2.4704, Grad: 2.9812


100%|██████████| 938/938 [00:15<00:00, 61.78it/s]


Epoch 14/20


  1%|          | 7/938 [00:00<00:15, 58.57it/s]

Step 12,200 (N samples: 780,800), Loss: 3.3333, Grad: 0.7258


 12%|█▏        | 112/938 [00:01<00:13, 61.57it/s]

Step 12,300 (N samples: 787,200), Loss: 2.7410, Grad: 9.5727


 23%|██▎       | 217/938 [00:03<00:11, 60.50it/s]

Step 12,400 (N samples: 793,600), Loss: 2.5773, Grad: 1.3828


 34%|███▎      | 315/938 [00:05<00:11, 54.77it/s]

Step 12,500 (N samples: 800,000), Loss: 3.2918, Grad: 0.1351


 45%|████▍     | 418/938 [00:06<00:08, 62.02it/s]

Step 12,600 (N samples: 806,400), Loss: 3.3539, Grad: 0.3422


 55%|█████▌    | 519/938 [00:08<00:07, 59.40it/s]

Step 12,700 (N samples: 812,800), Loss: 3.2837, Grad: 0.7449


 65%|██████▌   | 614/938 [00:10<00:05, 59.51it/s]

Step 12,800 (N samples: 819,200), Loss: 2.9210, Grad: 0.2575


 77%|███████▋  | 719/938 [00:11<00:03, 61.24it/s]

Step 12,900 (N samples: 825,600), Loss: 2.8849, Grad: 0.1538


 87%|████████▋ | 817/938 [00:13<00:01, 60.87it/s]

Step 13,000 (N samples: 832,000), Loss: 2.4958, Grad: 1.4110


 98%|█████████▊| 915/938 [00:14<00:00, 63.19it/s]

Step 13,100 (N samples: 838,400), Loss: 2.2572, Grad: 2.3176


100%|██████████| 938/938 [00:15<00:00, 61.43it/s]


Epoch 15/20


  8%|▊         | 77/938 [00:01<00:14, 61.35it/s]

Step 13,200 (N samples: 844,800), Loss: 3.8639, Grad: 1.0238


 19%|█▊        | 175/938 [00:02<00:12, 61.41it/s]

Step 13,300 (N samples: 851,200), Loss: 2.3085, Grad: 0.0295


 30%|██▉       | 278/938 [00:04<00:10, 60.01it/s]

Step 13,400 (N samples: 857,600), Loss: 2.8849, Grad: 23.0609


 40%|████      | 376/938 [00:06<00:08, 63.15it/s]

Step 13,500 (N samples: 864,000), Loss: 2.4419, Grad: 0.5053


 51%|█████▏    | 481/938 [00:07<00:07, 64.16it/s]

Step 13,600 (N samples: 870,400), Loss: 4.4251, Grad: 24.5119


 62%|██████▏   | 578/938 [00:09<00:06, 57.77it/s]

Step 13,700 (N samples: 876,800), Loss: 3.9162, Grad: 9.5778


 72%|███████▏  | 675/938 [00:10<00:04, 60.34it/s]

Step 13,800 (N samples: 883,200), Loss: 3.6677, Grad: 20.2209


 83%|████████▎ | 780/938 [00:12<00:02, 59.53it/s]

Step 13,900 (N samples: 889,600), Loss: 4.4148, Grad: 4.1451


 93%|█████████▎| 876/938 [00:14<00:01, 60.37it/s]

Step 14,000 (N samples: 896,000), Loss: 2.6850, Grad: 0.0898


100%|██████████| 938/938 [00:15<00:00, 61.20it/s]


Epoch 16/20


  4%|▍         | 42/938 [00:00<00:14, 61.25it/s]

Step 14,100 (N samples: 902,400), Loss: 2.7749, Grad: 0.5948


 15%|█▍        | 139/938 [00:02<00:13, 58.27it/s]

Step 14,200 (N samples: 908,800), Loss: 3.0707, Grad: 0.5846


 26%|██▌       | 244/938 [00:03<00:10, 63.60it/s]

Step 14,300 (N samples: 915,200), Loss: 2.5445, Grad: 1.4460


 36%|███▌      | 338/938 [00:05<00:09, 60.14it/s]

Step 14,400 (N samples: 921,600), Loss: 3.1997, Grad: 3.1035


 47%|████▋     | 441/938 [00:07<00:08, 58.74it/s]

Step 14,500 (N samples: 928,000), Loss: 3.9784, Grad: 1.1408


 57%|█████▋    | 539/938 [00:08<00:06, 61.03it/s]

Step 14,600 (N samples: 934,400), Loss: 2.1882, Grad: 0.2957


 68%|██████▊   | 637/938 [00:10<00:04, 61.42it/s]

Step 14,700 (N samples: 940,800), Loss: 3.6189, Grad: 19.5231


 79%|███████▉  | 741/938 [00:12<00:03, 60.19it/s]

Step 14,800 (N samples: 947,200), Loss: 3.5010, Grad: 0.5937


 89%|████████▉ | 839/938 [00:13<00:01, 62.78it/s]

Step 14,900 (N samples: 953,600), Loss: 3.6705, Grad: 2.5502


100%|██████████| 938/938 [00:15<00:00, 61.03it/s]


Step 15,000 (N samples: 960,000), Loss: 3.6347, Grad: 0.1911
Epoch 17/20


 11%|█         | 105/938 [00:01<00:12, 64.26it/s]

Step 15,100 (N samples: 966,400), Loss: 2.6127, Grad: 12.9023


 22%|██▏       | 203/938 [00:03<00:11, 63.32it/s]

Step 15,200 (N samples: 972,800), Loss: 3.2516, Grad: 11.5361


 32%|███▏      | 301/938 [00:04<00:09, 64.50it/s]

Step 15,300 (N samples: 979,200), Loss: 2.8004, Grad: 0.6304


 43%|████▎     | 399/938 [00:06<00:08, 63.29it/s]

Step 15,400 (N samples: 985,600), Loss: 4.9900, Grad: 15.1133


 54%|█████▎    | 504/938 [00:07<00:07, 56.06it/s]

Step 15,500 (N samples: 992,000), Loss: 3.3419, Grad: 0.6100


 64%|██████▍   | 600/938 [00:09<00:05, 62.30it/s]

Step 15,600 (N samples: 998,400), Loss: 3.3496, Grad: 8.1235


 75%|███████▌  | 704/938 [00:11<00:03, 61.23it/s]

Step 15,700 (N samples: 1,004,800), Loss: 2.9707, Grad: 1.3768


 85%|████████▌ | 801/938 [00:12<00:02, 54.57it/s]

Step 15,800 (N samples: 1,011,200), Loss: 3.4606, Grad: 0.7245


 96%|█████████▋| 905/938 [00:14<00:00, 61.72it/s]

Step 15,900 (N samples: 1,017,600), Loss: 2.8317, Grad: 0.1283


100%|██████████| 938/938 [00:14<00:00, 62.88it/s]


Epoch 18/20


  7%|▋         | 63/938 [00:01<00:13, 63.05it/s]

Step 16,000 (N samples: 1,024,000), Loss: 4.1043, Grad: 9.2902


 17%|█▋        | 161/938 [00:02<00:14, 52.94it/s]

Step 16,100 (N samples: 1,030,400), Loss: 4.5844, Grad: 6.8695


 28%|██▊       | 265/938 [00:04<00:11, 61.12it/s]

Step 16,200 (N samples: 1,036,800), Loss: 4.2821, Grad: 107.0639


 39%|███▊      | 363/938 [00:06<00:09, 59.48it/s]

Step 16,300 (N samples: 1,043,200), Loss: 2.2899, Grad: 0.6259


 50%|████▉     | 465/938 [00:07<00:07, 61.59it/s]

Step 16,400 (N samples: 1,049,600), Loss: 3.2669, Grad: 1.4988


 60%|█████▉    | 562/938 [00:09<00:06, 58.38it/s]

Step 16,500 (N samples: 1,056,000), Loss: 3.0604, Grad: 0.8928


 71%|███████   | 666/938 [00:11<00:04, 59.56it/s]

Step 16,600 (N samples: 1,062,400), Loss: 3.9747, Grad: 0.7877


 81%|████████▏ | 763/938 [00:12<00:02, 61.60it/s]

Step 16,700 (N samples: 1,068,800), Loss: 4.0741, Grad: 0.0385


 92%|█████████▏| 861/938 [00:14<00:01, 63.09it/s]

Step 16,800 (N samples: 1,075,200), Loss: 3.3931, Grad: 2.2872


100%|██████████| 938/938 [00:15<00:00, 60.56it/s]


Epoch 19/20


  3%|▎         | 27/938 [00:00<00:14, 61.31it/s]

Step 16,900 (N samples: 1,081,600), Loss: 4.8900, Grad: 0.1342


 13%|█▎        | 125/938 [00:02<00:13, 61.24it/s]

Step 17,000 (N samples: 1,088,000), Loss: 2.5095, Grad: 11.9438


 24%|██▍       | 228/938 [00:03<00:12, 59.05it/s]

Step 17,100 (N samples: 1,094,400), Loss: 3.3955, Grad: 0.0096


 35%|███▍      | 325/938 [00:05<00:09, 64.33it/s]

Step 17,200 (N samples: 1,100,800), Loss: 3.1330, Grad: 1.6958


 45%|████▌     | 423/938 [00:06<00:08, 60.12it/s]

Step 17,300 (N samples: 1,107,200), Loss: 1.8756, Grad: 0.0019


 56%|█████▋    | 528/938 [00:08<00:06, 61.43it/s]

Step 17,400 (N samples: 1,113,600), Loss: 2.7264, Grad: 11.1842


 67%|██████▋   | 626/938 [00:10<00:04, 64.20it/s]

Step 17,500 (N samples: 1,120,000), Loss: 4.1639, Grad: 17.1147


 77%|███████▋  | 724/938 [00:11<00:03, 62.42it/s]

Step 17,600 (N samples: 1,126,400), Loss: 2.9140, Grad: 16.8963


 88%|████████▊ | 829/938 [00:13<00:01, 62.64it/s]

Step 17,700 (N samples: 1,132,800), Loss: 2.9580, Grad: 0.9509


 99%|█████████▉| 927/938 [00:14<00:00, 62.95it/s]

Step 17,800 (N samples: 1,139,200), Loss: 3.7197, Grad: 0.0038


100%|██████████| 938/938 [00:14<00:00, 62.68it/s]


Epoch 20/20


 10%|▉         | 91/938 [00:01<00:13, 64.58it/s]

Step 17,900 (N samples: 1,145,600), Loss: 4.3778, Grad: 8.2795


 20%|██        | 189/938 [00:02<00:12, 62.29it/s]

Step 18,000 (N samples: 1,152,000), Loss: 4.3515, Grad: 41.0684


 31%|███       | 287/938 [00:04<00:10, 61.60it/s]

Step 18,100 (N samples: 1,158,400), Loss: 4.4296, Grad: 51.9968


 41%|████      | 385/938 [00:06<00:09, 60.45it/s]

Step 18,200 (N samples: 1,164,800), Loss: 3.1301, Grad: 26.9604


 52%|█████▏    | 490/938 [00:07<00:07, 63.34it/s]

Step 18,300 (N samples: 1,171,200), Loss: 3.7025, Grad: 6.7914


 62%|██████▏   | 586/938 [00:09<00:05, 59.17it/s]

Step 18,400 (N samples: 1,177,600), Loss: 3.0807, Grad: 1.1452


 74%|███████▎  | 690/938 [00:11<00:03, 62.08it/s]

Step 18,500 (N samples: 1,184,000), Loss: 2.1908, Grad: 19.7727


 84%|████████▍ | 788/938 [00:12<00:02, 60.55it/s]

Step 18,600 (N samples: 1,190,400), Loss: 3.4305, Grad: 0.0801


 95%|█████████▍| 891/938 [00:14<00:00, 59.98it/s]

Step 18,700 (N samples: 1,196,800), Loss: 2.4168, Grad: 0.0323


100%|██████████| 938/938 [00:15<00:00, 62.15it/s]
Testing: 100%|██████████| 157/157 [00:01<00:00, 83.48it/s]

====> Test set loss: 2.9867





VAE(
  (encoder): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): SiLU()
    (2): Linear(in_features=512, out_features=256, bias=True)
    (3): SiLU()
    (4): Linear(in_features=256, out_features=128, bias=True)
    (5): SiLU()
    (6): Linear(in_features=128, out_features=64, bias=True)
    (7): SiLU()
    (8): Linear(in_features=64, out_features=32, bias=True)
  )
  (softplus): Softplus(beta=1.0, threshold=20.0)
  (decoder): Sequential(
    (0): Linear(in_features=16, out_features=64, bias=True)
    (1): SiLU()
    (2): Linear(in_features=64, out_features=128, bias=True)
    (3): SiLU()
    (4): Linear(in_features=128, out_features=256, bias=True)
    (5): SiLU()
    (6): Linear(in_features=256, out_features=512, bias=True)
    (7): SiLU()
    (8): Linear(in_features=512, out_features=784, bias=True)
    (9): Sigmoid()
  )
)

In [16]:
torch.save(vae.state_dict(), 'adversarial_vae.pth')

## Testing our defense

To test our defense, we first have to generate adversarial examples. In a previous lab, we used the Fast Gradient Sign Method (FGSM) to generate adversarial examples. This time, we will use a few more methods, and we will use the `foolbox` library to do so. The library is very easy to use and has a lot of built-in methods to generate adversarial examples.

Installing it is as easy as running `pip install foolbox`. You can find the documentation [here](https://foolbox.readthedocs.io/en/latest/). We've already installed it for you.

The testing procedure is as follows:
1. Generate adversarial examples using the `foolbox` library.
2. Compute the accuracy of the classifier on the adversarial examples.
3. Pass the adversarial examples through the VAE.
4. Pass the reconstructed examples through the classifier and compute the accuracy.
5. Compare the accuracy of the adversarial examples and the reconstructed adversarial examples.

### 1. Generating adversarial examples

In [17]:
%pip install foolbox
from foolbox import PyTorchModel, accuracy, samples
from foolbox.attacks import LinfPGD

# Wrap the classifier with Foolbox
fclf = PyTorchModel(clf, bounds=(-1, 1))

# Get a batch of test data
images, labels = next(iter(test_loader))

# Attack the classifier
attack = LinfPGD()
raw_advs, clipped_advs, success = attack(fclf, images, labels, epsilons=0.3)

Note: you may need to restart the kernel to use updated packages.


### 2. Compute the accuracy of the classifier on the adversarial examples
_Hint: You don't have to run the classifier again. Instead, you can use the `success` variable._

In [18]:
# TODO: Compute the accuracy of the classifier on the adversarial examples
accuracy_adv = 1 - success.to(torch.float32).mean(axis=-1)

### 3. Pass the adversarial examples through the VAE

In [19]:
# TODO: Pass the adversarial examples through the VAE
# Hint: Use `clipped_advs`
advs = clipped_advs.view(raw_advs.size(0), -1)
advs = advs.to(device)
advs_recon = vae(advs, compute_loss=False).x_recon

### 4. Pass the reconstructed examples through the classifier and compute the accuracy

In [23]:
# TODO: Pass the reconstructed examples through the classifier and compute the accuracy
# Hint: Use `advs_reconstructed`
advs_recon = advs_recon.view(-1, 1, 28, 28)
advs_recon = advs_recon.to(device)
clf.to(device)
outputs = clf(advs_recon)
_, predicted = torch.max(outputs.data, dim=1)
accuracy_recon = (predicted == labels.to(device)).sum().item() / len(labels)

### 5. Compare the accuracy of the adversarial examples and the reconstructed adversarial examples

In [25]:
# TODO: Print the accuracy of the classifier on the adversarial examples and the reconstructed examples
print(f'Accuracy on adversarial examples: {accuracy_adv:.4f}')
print(f'Accuracy on reconstructed examples: {accuracy_recon:.4f}')

Accuracy on adversarial examples: 0.5469
Accuracy on reconstructed examples: 0.6719


What conclusions can you draw from the results?