# Implementation of "adversarial training" as a defence mechanism
The goal of this notebook is to showcase a simple implementation of adversarial training as a defence against adversarial attacks. We compare the robustness of models trained with and without adversarial training. We use MNIST as the dataset, a simple CNN as the model, and the Fast Gradient Sign Method (FGSM) as the adversarial attack.
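For reference, the FGSM perturbation used throughout the notebook can be written as below. Inputs are normalized as $\tilde x = (x - \mu)/\sigma$ with $\mu = 0.1307$ and $\sigma = 0.3081$. The step of size $\epsilon$ is taken on the normalized input $\tilde x$, then the result is mapped back to pixel space, clamped to $[0, 1]$, and re-normalized. This mirrors the expression used in `test_model` and in the training functions, where $f$ is the model and $\mathcal{L}$ the cross-entropy loss:

$$
\tilde x_{\text{adv}} = \frac{\operatorname{clip}_{[0,1]}\!\left(\mu + \sigma\left(\tilde x + \epsilon \, \operatorname{sign}\!\left(\nabla_{\tilde x} \mathcal{L}\big(f(\tilde x), y\big)\right)\right)\right) - \mu}{\sigma}
$$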

## 1. Training on augmented dataset
### 1.1 Training on regular data

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import os
import time
import torch
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
from torch import nn, optim

In [2]:
# Load MNIST and normalize it with its usual mean and standard deviation

mean, std = 0.1307, 0.3081
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((mean,), (std,)) 
])

train_dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transform) #60k images
test_dataset = datasets.MNIST(root="./data", train=False, download=True, transform=transform) #10k images

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

In [3]:
def new_model():
    return nn.Sequential(
        nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),  # 1x28x28 -> 32x28x28
        nn.ReLU(),
        nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),  # 32x28x28 -> 64x28x28
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2), # 64x28x28 -> 64x14x14
        nn.Flatten(),  # 64x14x14 -> 12544
        nn.Linear(64 * 14 * 14, 128),  # 12544 -> 128
        nn.ReLU(),
        nn.Linear(128, 10)  # 128 -> 10
    )

criterion = nn.CrossEntropyLoss()

def train_model(model, train_loader, criterion, optimizer, epochs=5):
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader):.4f}")

def test_model(model, test_loader, criterion, optimizer, epsilon=0.5):
    model.eval()
    correct, total, test_loss = 0, 0, 0
    correct_aug, test_loss_aug = 0, 0
    
    for inputs, labels in test_loader:
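        # Track gradients w.r.t. the input pixels so that FGSM can use them
        inputs.requires_grad = True
        optimizer.zero_grad()  # clear parameter gradients left over from the previous batch
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        test_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        
        # FGSM: step by epsilon (in normalized units) along the sign of the input gradient,
        # then map back to pixel space, clamp to [0, 1] and re-normalize
        loss.backward()
        with torch.no_grad():  # no graph is needed to evaluate the adversarial examples
            inputs_aug = (torch.clamp(mean + std*(inputs + epsilon * torch.sign(inputs.grad)), 0, 1) - mean) / std
            outputs_aug = model(inputs_aug)
            loss_aug = criterion(outputs_aug, labels)
        test_loss_aug += loss_aug.item()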

        _, predicted_aug = torch.max(outputs_aug, 1)
        correct_aug += (predicted_aug == labels).sum().item()
    
    print(f"Test Loss: {test_loss / len(test_loader):.4f}")
    print(f"Accuracy: {100 * correct / total:.2f}%")
    print(f"Adversarial Test Loss: {test_loss_aug / len(test_loader):.4f}")
    print(f"Adversarial Accuracy: {100 * correct_aug / total:.2f}%")
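
The shape comments in `new_model()` can be turned into a parameter count, which gives a sense of how large this "simple" CNN is for MNIST. The sketch below works the count out by hand from the layer shapes, assuming exactly the architecture defined above; the helper names `conv2d_params` and `linear_params` are hypothetical and not part of the notebook.

```python
# Parameter count of new_model(), worked out by hand from the layer shapes.
# Assumes the architecture above: two 3x3 convolutions (1->32, 32->64),
# one 2x2 max-pool, then Linear(64*14*14, 128) and Linear(128, 10).


def conv2d_params(in_ch: int, out_ch: int, k: int) -> int:
    # weights (out_ch x in_ch x k x k) plus one bias per output channel
    return out_ch * in_ch * k * k + out_ch


def linear_params(in_features: int, out_features: int) -> int:
    # weight matrix plus bias vector
    return out_features * in_features + out_features


# padding=1 keeps the 28x28 resolution; the 2x2 max-pool halves it to 14x14
flat_features = 64 * 14 * 14

layers = {
    "conv1": conv2d_params(1, 32, 3),
    "conv2": conv2d_params(32, 64, 3),
    "fc1": linear_params(flat_features, 128),
    "fc2": linear_params(128, 10),
}
total = sum(layers.values())

for name, count in layers.items():
    print(f"{name}: {count:,} ({100 * count / total:.1f}%)")
print(f"total: {total:,}")
```

The total is 1,625,866 parameters, about 98.8% of them in the first fully connected layer. The same number can be cross-checked in PyTorch with `sum(p.numel() for p in new_model().parameters())`.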


In [26]:
model = new_model()
optimizer = optim.Adam(model.parameters(), lr=0.001)
train_model(model, train_loader, criterion, optimizer, epochs=5) # 2 epochs is decent, 10 would overfit; 5 is a good compromise
test_model(model, test_loader, criterion, optimizer, epsilon=2)

# epsilon=2
# Adversarial Test Loss: 6.8014
# Adversarial Accuracy: 15.87%

#epsilon = 1:
# Adversarial Test Loss: 5.0315
# Adversarial Accuracy: 21.07%

# epsilon=0.5
# Epoch 1/5, Loss: 0.1216
# Epoch 2/5, Loss: 0.0369
# Epoch 3/5, Loss: 0.0226
# Epoch 4/5, Loss: 0.0148
# Epoch 5/5, Loss: 0.0104
# Test Loss: 0.0332
# Accuracy: 99.03%
# Adversarial Test Loss: 1.4371
# Adversarial Accuracy: 69.06%

# epsilon=0.1
# Epoch 1/5, Loss: 0.1157
# Epoch 2/5, Loss: 0.0367
# Epoch 3/5, Loss: 0.0224
# Epoch 4/5, Loss: 0.0151
# Epoch 5/5, Loss: 0.0120
# Test Loss: 0.0392
# Accuracy: 98.90%
# Adversarial Test Loss: 0.1151
# Adversarial Accuracy: 96.86%


Test Loss: 0.0392
Accuracy: 98.90%
Adversarial Test Loss: 6.6889
Adversarial Accuracy: 27.06%
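The epsilon values in these experiments are measured in normalized units, because the perturbation is added to the normalized tensor before it is mapped back to [0, 1]. The NumPy sketch below is a standalone mirror of the expression in `test_model`; the helper names `fgsm_step` and `epsilon_in_pixels` are hypothetical. It converts epsilon to pixel units and applies the same clamp-and-renormalize step to a few pixels.

```python
import numpy as np

MEAN, STD = 0.1307, 0.3081  # MNIST normalization constants used in the notebook


def fgsm_step(x_norm, grad, epsilon, mean=MEAN, std=STD):
    """NumPy mirror of the notebook's FGSM expression (illustrative helper)."""
    stepped = x_norm + epsilon * np.sign(grad)        # step in normalized units
    pixels = np.clip(mean + std * stepped, 0.0, 1.0)  # back to pixel space, clamp to [0, 1]
    return (pixels - mean) / std                      # re-normalize


def epsilon_in_pixels(epsilon, std=STD):
    # A normalized-space step of epsilon moves a pixel by std * epsilon (before clamping)
    return std * epsilon


# At or above this normalized epsilon a single step is at least 1 in pixel units,
# so it can push any pixel to either end of [0, 1]
full_range_eps = 1.0 / STD

for eps in [0.1, 0.5, 1, 2, 3]:
    print(f"epsilon={eps}: pixel step {epsilon_in_pixels(eps):.3f}")
print(f"full-range epsilon: {full_range_eps:.2f}")

# Example: a black, a mid-gray and a white pixel, all pushed upwards with epsilon=0.5
x_pixels = np.array([0.0, 0.5, 1.0])
x_norm = (x_pixels - MEAN) / STD
x_adv = fgsm_step(x_norm, grad=np.ones(3), epsilon=0.5)
print(MEAN + STD * x_adv)  # pixel values after the attack; the white pixel stays clamped at 1
```

For example, epsilon=2 corresponds to a pixel-space step of about 0.62, and epsilon=3 to about 0.92, which puts the larger epsilons used in this notebook into perspective.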


In [None]:
model_reg = new_model()
optimizer_reg = optim.Adam(model_reg.parameters(), lr=0.001, weight_decay=5e-4) #with L2 regularization
train_model(model_reg, train_loader, criterion, optimizer_reg, epochs=5)
test_model(model_reg, test_loader, criterion, optimizer_reg, epsilon=2)

# epsilon=2
# Adversarial Test Loss: 4.5166
# Adversarial Accuracy: 15.01%

# epsilon=1
# Adversarial Test Loss: 3.4740
# Adversarial Accuracy: 19.32%

# epsilon=0.5
# Epoch 1/5, Loss: 0.1243
# Epoch 2/5, Loss: 0.0510
# Epoch 3/5, Loss: 0.0421
# Epoch 4/5, Loss: 0.0371
# Epoch 5/5, Loss: 0.0305
# Test Loss: 0.0440
# Accuracy: 98.60%
# Adversarial Test Loss: 0.9340
# Adversarial Accuracy: 70.79%

# epsilon=0.1
# Adversarial Test Loss: 0.0932
# Adversarial Accuracy: 97.01%

# epsilon=0.01
# Adversarial Test Loss: 0.0470
# Adversarial Accuracy: 98.39%

Test Loss: 0.0430
Accuracy: 98.51%
Adversarial Test Loss: 4.5166
Adversarial Accuracy: 15.01%


### 1.2 Training on an augmented dataset

In [6]:
# Adversarial training: for each batch, take a step on the clean loss, then build FGSM examples
# against the current model from the input gradient of that loss and take a second step on them
def train_model_aug(model, train_loader, criterion, optimizer, epochs=5, epsilon=0.5):
    model.train()
    for epoch in range(epochs):
        running_loss, running_loss_aug = 0, 0
        for inputs, labels in train_loader:
            inputs.requires_grad = True
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            
            # FGSM examples; detached so they are treated as fixed inputs in the second update
            inputs_aug = ((torch.clamp(mean + std*(inputs + epsilon * torch.sign(inputs.grad)), 0, 1) - mean) / std).detach()
            outputs_aug = model(inputs_aug)
            loss_aug = criterion(outputs_aug, labels)
            optimizer.zero_grad()
            loss_aug.backward()
            optimizer.step()
            running_loss_aug += loss_aug.item()
        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader):.4f}, Adversarial Loss: {running_loss_aug / len(train_loader):.4f}")
        
def train_rand_epsilon(model, train_loader, criterion, optimizer, epochs=5):
    model.train()
    for epoch in range(epochs):
        running_loss, running_loss_aug = 0, 0
        for inputs, labels in train_loader:
            inputs.requires_grad = True
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            
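            # Draw epsilon from 100 log-spaced values between 10**-2 = 0.01 and 10**0.8 ≈ 6.3 (normalized units);
            # the high upper bound emphasizes large epsilons, a weak point of fixed-epsilon training
            epsilon = np.logspace(-2, 0.8, 100)[np.random.randint(0, 100)]
            inputs_aug = ((torch.clamp(mean + std*(inputs + epsilon * torch.sign(inputs.grad)), 0, 1) - mean) / std).detach()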
            outputs_aug = model(inputs_aug)
            loss_aug = criterion(outputs_aug, labels)
            optimizer.zero_grad()
            loss_aug.backward()
            optimizer.step()
            running_loss_aug += loss_aug.item()
        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader):.4f}, Adversarial Loss: {running_loss_aug / len(train_loader):.4f}")
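The line `np.logspace(-2, 0.8, 100)[np.random.randint(0, 100)]` in `train_rand_epsilon` draws epsilon from a 100-point grid spaced evenly in log10, so it behaves like a log-uniform distribution between 0.01 and 10**0.8 ≈ 6.3. The short check below counts how often large values come up. The `sample_epsilon` helper is a hypothetical continuous alternative, not something the notebook uses.

```python
import numpy as np

# The grid used in train_rand_epsilon: 100 values evenly spaced in log10 from -2 to 0.8
grid = np.logspace(-2, 0.8, 100)

n_ge_1 = int((grid >= 1.0).sum())  # grid points with epsilon >= 1
n_ge_3 = int((grid >= 3.0).sum())  # grid points with epsilon >= 3
print(f"min={grid[0]:.3f}, max={grid[-1]:.3f}")
print(f"P(epsilon >= 1) = {n_ge_1 / grid.size:.2f}, P(epsilon >= 3) = {n_ge_3 / grid.size:.2f}")


def sample_epsilon(rng, low_exp=-2.0, high_exp=0.8):
    # Hypothetical continuous log-uniform draw over [10**low_exp, 10**high_exp)
    return 10 ** rng.uniform(low_exp, high_exp)


rng = np.random.default_rng(0)
draws = np.array([sample_epsilon(rng) for _ in range(10_000)])
print(f"continuous: P(epsilon >= 1) ~ {np.mean(draws >= 1.0):.2f}")
```

About 29% of the grid draws are at least 1 and 12% are at least 3, so this scheme still spends most updates on small and moderate epsilons. Changing the exponent bounds is a simple way to shift that balance.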

In [None]:
model_aug = new_model()
optimizer_aug = optim.Adam(model_aug.parameters(), lr=0.001)
train_model_aug(model_aug, train_loader, criterion, optimizer_aug, epochs=5, epsilon=0.1)
test_model(model_aug, test_loader, criterion, optimizer_aug, epsilon=0.5)

# epsilon=0.5
# Epoch 1/5, Loss: 0.0867, Adversarial Loss: 0.3047
# Epoch 2/5, Loss: 0.0292, Adversarial Loss: 0.1468
# Epoch 3/5, Loss: 0.0181, Adversarial Loss: 0.1009
# Epoch 4/5, Loss: 0.0113, Adversarial Loss: 0.0694
# Epoch 5/5, Loss: 0.0076, Adversarial Loss: 0.0601
# Test Loss: 0.0258
# Accuracy: 99.19%
# Adversarial Test Loss: 0.1369
# Adversarial Accuracy: 95.92%

# epsilon=0.1
# Epoch 1/5, Loss: 0.0975, Adversarial Loss: 0.1284
# Epoch 2/5, Loss: 0.0326, Adversarial Loss: 0.0520
# Epoch 3/5, Loss: 0.0182, Adversarial Loss: 0.0334
# Epoch 4/5, Loss: 0.0119, Adversarial Loss: 0.0246
# Epoch 5/5, Loss: 0.0080, Adversarial Loss: 0.0189
# Test Loss: 0.0468
# Accuracy: 98.99%
# Adversarial Test Loss: 0.1020
# Adversarial Accuracy: 97.80%
# testing against epsilon=0.5
# Adversarial Test Loss: 0.8311
# Adversarial Accuracy: 84.33%


Test Loss: 0.0468
Accuracy: 98.99%
Adversarial Test Loss: 0.8311
Adversarial Accuracy: 84.33%


In [None]:
model_reg_aug = new_model()
optimizer_reg_aug = optim.Adam(model_reg_aug.parameters(), lr=0.001, weight_decay=1e-4)
#train_model_aug(model_reg_aug, train_loader, criterion, optimizer_reg_aug, epochs=5, epsilon=3)
train_rand_epsilon(model_reg_aug, train_loader, criterion, optimizer_reg_aug, epochs=10)
for epsilon in [3, 2, 1, 0.5, 0.1, 0.01, 0.001]:
    print(f"\nepsilon = {epsilon}")
    test_model(model_reg_aug, test_loader, criterion, optimizer_reg_aug, epsilon=epsilon)

# training for epsilon=3 and testing against other epsilons:
# Epoch 1/5, Loss: 0.1654, Adversarial Loss: 0.2066
# Epoch 2/5, Loss: 0.0526, Adversarial Loss: 0.0651
# Epoch 3/5, Loss: 0.0382, Adversarial Loss: 0.0909
# Epoch 4/5, Loss: 0.0322, Adversarial Loss: 0.0733
# Epoch 5/5, Loss: 0.0255, Adversarial Loss: 0.0793
# epsilon = 3
# Test Loss: 0.0434
# Accuracy: 98.49%
# Adversarial Test Loss: 0.0567
# Adversarial Accuracy: 97.98%
# epsilon = 2
# Adversarial Test Loss: 0.1918
# Adversarial Accuracy: 94.38%
# epsilon = 1
# Adversarial Test Loss: 1.6209
# Adversarial Accuracy: 50.81%
# epsilon = 0.5
# Adversarial Test Loss: 1.0487
# Adversarial Accuracy: 70.71%
# epsilon = 0.1
# Adversarial Test Loss: 0.1126
# Adversarial Accuracy: 96.66%
# epsilon = 0.01
# Adversarial Test Loss: 0.0487
# Adversarial Accuracy: 98.35%
# epsilon = 0.001
# Adversarial Test Loss: 0.0440
# Adversarial Accuracy: 98.48%

# epsilon=1
# Epoch 1/5, Loss: 0.1095, Adversarial Loss: 0.5515
# Epoch 2/5, Loss: 0.0433, Adversarial Loss: 0.2005
# Epoch 3/5, Loss: 0.0345, Adversarial Loss: 0.1508
# Epoch 4/5, Loss: 0.0309, Adversarial Loss: 0.1205
# Epoch 5/5, Loss: 0.0265, Adversarial Loss: 0.0249
# Test Loss: 0.0290
# Accuracy: 98.99%
# Adversarial Test Loss: 0.0214
# Adversarial Accuracy: 99.33%
# testing against epsilon=0.5
# Adversarial Test Loss: 0.0755
# Adversarial Accuracy: 97.51%
# against epsilon=0.1
# Adversarial Test Loss: 0.0578
# Adversarial Accuracy: 98.10%
# against epsilon=0.01
# Adversarial Test Loss: 0.0313
# Adversarial Accuracy: 98.95%
# against epsilon=2:
# Adversarial Test Loss: 1.3907
# Adversarial Accuracy: 65.46%
# against epsilon=3
# Adversarial Test Loss: 3.4081
# Adversarial Accuracy: 40.27%

# epsilon=0.5
# Epoch 1/5, Loss: 0.0958, Adversarial Loss: 0.2417
# Epoch 2/5, Loss: 0.0422, Adversarial Loss: 0.0775
# Epoch 3/5, Loss: 0.0336, Adversarial Loss: 0.0624
# Epoch 4/5, Loss: 0.0273, Adversarial Loss: 0.0674
# Epoch 5/5, Loss: 0.0217, Adversarial Loss: 0.1037
# Test Loss: 0.0283
# Accuracy: 99.12%
# Adversarial Test Loss: 0.1300
# Adversarial Accuracy: 96.03%

# epsilon=0.1
# Epoch 1/5, Loss: 0.1004, Adversarial Loss: 0.1299
# Epoch 2/5, Loss: 0.0415, Adversarial Loss: 0.0607
# Epoch 3/5, Loss: 0.0321, Adversarial Loss: 0.0487
# Epoch 4/5, Loss: 0.0238, Adversarial Loss: 0.0392
# Epoch 5/5, Loss: 0.0222, Adversarial Loss: 0.0364
# Test Loss: 0.0341
# Accuracy: 99.05%
# Adversarial Test Loss: 0.0679
# Adversarial Accuracy: 97.94%
# testing against epsilon=0.5
# Adversarial Test Loss: 0.4774
# Adversarial Accuracy: 86.31%
# testing against epsilon=1
# Accuracy: 99.05%
# Adversarial Test Loss: 2.2030
# Adversarial Accuracy: 39.60%

# training on random epsilon, 5 epochs
# Epoch 1/5, Loss: 0.1390, Adversarial Loss: 0.3017
# Epoch 2/5, Loss: 0.0492, Adversarial Loss: 0.2308
# Epoch 3/5, Loss: 0.0383, Adversarial Loss: 0.2195
# Epoch 4/5, Loss: 0.0343, Adversarial Loss: 0.2027
# Epoch 5/5, Loss: 0.0293, Adversarial Loss: 0.1626
# Test Loss: 0.0397
# Accuracy: 98.62%
# epsilon = 3
# Adversarial Test Loss: 0.2398
# Adversarial Accuracy: 92.25%
# epsilon = 2
# Adversarial Test Loss: 0.4328
# Adversarial Accuracy: 86.25%
# epsilon = 1
# Adversarial Test Loss: 0.6709
# Adversarial Accuracy: 78.08%
# epsilon = 0.5
# Adversarial Test Loss: 0.6339
# Adversarial Accuracy: 78.81%
# epsilon = 0.1
# Adversarial Test Loss: 0.1116
# Adversarial Accuracy: 96.38%
# epsilon = 0.01
# Adversarial Test Loss: 0.0451
# Adversarial Accuracy: 98.43%
# epsilon = 0.001
# Adversarial Test Loss: 0.0403
# Adversarial Accuracy: 98.60%

# with further training (5 more epochs, 10 in total)
# Epoch 1/5, Loss: 0.0278, Adversarial Loss: 0.1316
# Epoch 2/5, Loss: 0.0241, Adversarial Loss: 0.1251
# Epoch 3/5, Loss: 0.0265, Adversarial Loss: 0.1395
# Epoch 4/5, Loss: 0.0216, Adversarial Loss: 0.0960
# Epoch 5/5, Loss: 0.0206, Adversarial Loss: 0.1150
# Test Loss: 0.0504
# Accuracy: 98.34%
# epsilon = 3
# Adversarial Test Loss: 0.0751
# Adversarial Accuracy: 97.66%
# epsilon=2
# Adversarial Test Loss: 0.0884
# Adversarial Accuracy: 97.29%
# epsilon = 1
# Adversarial Test Loss: 0.1864
# Adversarial Accuracy: 94.28%
# epsilon = 0.5
# Adversarial Test Loss: 0.4242
# Adversarial Accuracy: 86.82%
# epsilon = 0.1
# Adversarial Test Loss: 0.1524
# Adversarial Accuracy: 95.27%
# epsilon = 0.01
# Adversarial Test Loss: 0.0579
# Adversarial Accuracy: 98.09%
# epsilon = 0.001
# Adversarial Test Loss: 0.0512
# Adversarial Accuracy: 98.32%

# training on random epsilon with the right distribution
# Epoch 1/10, Loss: 0.1023, Adversarial Loss: 0.4621
# Epoch 2/10, Loss: 0.0380, Adversarial Loss: 0.2354
# Epoch 3/10, Loss: 0.0282, Adversarial Loss: 0.2139
# Epoch 4/10, Loss: 0.0224, Adversarial Loss: 0.1783
# Epoch 5/10, Loss: 0.0201, Adversarial Loss: 0.1849
# Epoch 6/10, Loss: 0.0179, Adversarial Loss: 0.1596
# Epoch 7/10, Loss: 0.0174, Adversarial Loss: 0.1691
# Epoch 8/10, Loss: 0.0156, Adversarial Loss: 0.1478
# Epoch 9/10, Loss: 0.0160, Adversarial Loss: 0.1528
# Epoch 10/10, Loss: 0.0141, Adversarial Loss: 0.1320
# Test Loss: 0.0316
# Accuracy: 99.04%
# epsilon = 3
# Adversarial Test Loss: 0.6880
# Adversarial Accuracy: 78.42%
# epsilon = 2
# Adversarial Test Loss: 0.3256
# Adversarial Accuracy: 89.73%
# epsilon = 1
# Adversarial Test Loss: 0.3114
# Adversarial Accuracy: 90.27%
# epsilon = 0.5
# Adversarial Test Loss: 0.1965
# Adversarial Accuracy: 93.85%
# epsilon = 0.1
# Adversarial Test Loss: 0.0614
# Adversarial Accuracy: 97.99%
# epsilon = 0.01
# Adversarial Test Loss: 0.0345
# Adversarial Accuracy: 98.97%
# epsilon = 0.001
# Adversarial Test Loss: 0.0319
# Adversarial Accuracy: 99.04%

# training with high-biased random epsilon:
# Epoch 1/10, Loss: 0.1135, Adversarial Loss: 0.4447
# Epoch 2/10, Loss: 0.0409, Adversarial Loss: 0.2355
# Epoch 3/10, Loss: 0.0307, Adversarial Loss: 0.2273
# Epoch 4/10, Loss: 0.0251, Adversarial Loss: 0.1964
# Epoch 5/10, Loss: 0.0228, Adversarial Loss: 0.2062
# Epoch 6/10, Loss: 0.0198, Adversarial Loss: 0.2004
# Epoch 7/10, Loss: 0.0177, Adversarial Loss: 0.2125
# Epoch 8/10, Loss: 0.0173, Adversarial Loss: 0.1965
# Epoch 9/10, Loss: 0.0162, Adversarial Loss: 0.2077
# Epoch 10/10, Loss: 0.0155, Adversarial Loss: 0.2014
# Test Loss: 0.0293
# Accuracy: 98.90%
# epsilon = 3
# Adversarial Test Loss: 0.6358
# Adversarial Accuracy: 78.93%
# epsilon = 2
# Adversarial Test Loss: 0.4361
# Adversarial Accuracy: 86.04%
# epsilon = 1
# Adversarial Test Loss: 0.4720
# Adversarial Accuracy: 84.24%
# epsilon = 0.5
# Adversarial Test Loss: 0.2509
# Adversarial Accuracy: 91.87%
# epsilon = 0.1
# Adversarial Test Loss: 0.0613
# Adversarial Accuracy: 98.16%
# epsilon = 0.01
# Adversarial Test Loss: 0.0322
# Adversarial Accuracy: 98.84%
# epsilon = 0.001
# Adversarial Test Loss: 0.0296
# Adversarial Accuracy: 98.90%


Epoch 1/10, Loss: 0.1135, Adversarial Loss: 0.4447
Epoch 2/10, Loss: 0.0409, Adversarial Loss: 0.2355
Epoch 3/10, Loss: 0.0307, Adversarial Loss: 0.2273
Epoch 4/10, Loss: 0.0251, Adversarial Loss: 0.1964
Epoch 5/10, Loss: 0.0228, Adversarial Loss: 0.2062
Epoch 6/10, Loss: 0.0198, Adversarial Loss: 0.2004
Epoch 7/10, Loss: 0.0177, Adversarial Loss: 0.2125
Epoch 8/10, Loss: 0.0173, Adversarial Loss: 0.1965
Epoch 9/10, Loss: 0.0162, Adversarial Loss: 0.2077
Epoch 10/10, Loss: 0.0155, Adversarial Loss: 0.2014

epsilon = 3
Test Loss: 0.0293
Accuracy: 98.90%
Adversarial Test Loss: 0.6358
Adversarial Accuracy: 78.93%

epsilon = 2
Test Loss: 0.0293
Accuracy: 98.90%
Adversarial Test Loss: 0.4361
Adversarial Accuracy: 86.04%

epsilon = 1
Test Loss: 0.0293
Accuracy: 98.90%
Adversarial Test Loss: 0.4720
Adversarial Accuracy: 84.24%

epsilon = 0.5
Test Loss: 0.0293
Accuracy: 98.90%
Adversarial Test Loss: 0.2509
Adversarial Accuracy: 91.87%

epsilon = 0.1
Test Loss: 0.0293
Accuracy: 98.90%
Adversarial Test Loss: 0.0613
Adversarial Accuracy: 98.16%

## Conclusion

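A conventionally trained model is very vulnerable to adversarial attacks: FGSM drops its accuracy from about 99% on clean data to about 69% at epsilon=0.5, and to between 16% and 27% at epsilon=2.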

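Adding L2 regularization (weight decay) hardly improves robustness: adversarial accuracy goes from 69.06% to 70.79% at epsilon=0.5 and stays around 15% at epsilon=2.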

Training on strongly adversarial examples makes the model very robust to attacks of that strength: the model trained with epsilon=3 reaches 97.98% adversarial accuracy at epsilon=3. However, training against a single, fixed attack strength does not give robustness across strengths (the same model drops to 50.81% at epsilon=1). Training against attacks of random, variable strength achieves this fairly well: we reach roughly 90-94% accuracy on adversarial examples with epsilon between 0.5 and 2, while keeping about 99% accuracy on clean data. We remain somewhat vulnerable to the strongest attacks (about 78% at epsilon=3). However, very strong attacks are easy to detect for an appropriately trained network (see null-labeling.ipynb), and we keep about 98% accuracy or more on weak attacks (epsilon ≤ 0.1), which cannot be detected with certainty.
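One way to make the comparison between fixed and random epsilon concrete is to look at each model's worst-case adversarial accuracy over the test epsilons. The snippet below only re-tabulates numbers already logged in this notebook, restricted to the six epsilons every listed run was tested on; the run labels are shortened from the log comments and the `worst_case` helper is hypothetical.

```python
# Adversarial accuracy (%) copied from the result logs in this notebook,
# for the six test epsilons (normalized units) that every run below was evaluated on.
results = {
    "fixed epsilon=3": {3: 97.98, 2: 94.38, 1: 50.81, 0.5: 70.71, 0.1: 96.66, 0.01: 98.35},
    "fixed epsilon=1": {3: 40.27, 2: 65.46, 1: 99.33, 0.5: 97.51, 0.1: 98.10, 0.01: 98.95},
    "random epsilon, further training": {3: 97.66, 2: 97.29, 1: 94.28, 0.5: 86.82, 0.1: 95.27, 0.01: 98.09},
    "random epsilon, right distribution": {3: 78.42, 2: 89.73, 1: 90.27, 0.5: 93.85, 0.1: 97.99, 0.01: 98.97},
    "random epsilon, high-biased": {3: 78.93, 2: 86.04, 1: 84.24, 0.5: 91.87, 0.1: 98.16, 0.01: 98.84},
}


def worst_case(accs):
    """Return (epsilon, accuracy) for the test epsilon with the lowest adversarial accuracy."""
    eps = min(accs, key=accs.get)
    return eps, accs[eps]


for scheme, accs in results.items():
    eps, acc = worst_case(accs)
    print(f"{scheme:<36} worst case {acc:.2f}% at epsilon={eps}")
```

Every random-epsilon run keeps a worst case of about 78% or better, compared with roughly 40-51% for the two fixed-epsilon runs.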

Overall, adversarial training against FGSM is simple to implement, and the resulting model resists these attacks well. Notably, the cost in clean accuracy is negligible: this may be because the model is fairly large for MNIST and does not need its full capacity to classify clean digits. What would be more interesting is to work with multi-step FGSM and increase the number of steps needed to cause a misclassification. We did not explore this because of its computational cost.
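Multi-step FGSM repeats small FGSM steps and projects back into the epsilon-ball around the clean input after each one. Below is a minimal NumPy sketch, assuming an L-infinity budget in pixel space and a caller-supplied gradient function. `iterative_fgsm` and the toy logistic model used to exercise it are hypothetical and not part of this notebook.

```python
import numpy as np


def iterative_fgsm(x, grad_fn, epsilon, alpha, steps, lo=0.0, hi=1.0):
    """Multi-step FGSM under an L-infinity budget (illustrative sketch).

    x: clean input, grad_fn: callable returning dLoss/dx at a given input,
    epsilon: total perturbation budget, alpha: size of each FGSM step.
    """
    x_adv = x.copy()
    for _ in range(steps):
        x_adv = x_adv + alpha * np.sign(grad_fn(x_adv))   # one FGSM step
        x_adv = np.clip(x_adv, x - epsilon, x + epsilon)  # stay within epsilon of x
        x_adv = np.clip(x_adv, lo, hi)                    # stay a valid image
    return x_adv


# Toy check with a hypothetical binary logistic model: loss(x) = log(1 + exp(-y * w.x))
w = np.array([1.0, -2.0, 0.5, 3.0])
y = 1.0


def loss(x):
    return np.log1p(np.exp(-y * (w @ x)))


def grad_fn(x):
    # d/dx log(1 + exp(-y w.x)) = -y * sigmoid(-y w.x) * w
    return -y * w / (1.0 + np.exp(y * (w @ x)))


x = np.full(4, 0.5)
x_adv = iterative_fgsm(x, grad_fn, epsilon=0.1, alpha=0.03, steps=5)
print(x_adv, loss(x), loss(x_adv))
```

On a linear model every step points in the same direction, so this toy only checks the mechanics; the extra steps matter on nonlinear models such as the CNN used here. With the notebook's PyTorch model, `grad_fn` would be a forward and backward pass returning `inputs.grad`, and the [0, 1] clamp would be applied in pixel space as in the single-step expression.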