<a href="https://colab.research.google.com/github/GiuliaLanzillotta/exercises/blob/master/Adversarial_defense.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Adversarial defense 

Today we'll experiment with adversarial training as a defense against adversarial examples. <br>

More specifically, we'll train on [PGD](https://arxiv.org/pdf/1706.06083.pdf) adversarial examples and with the [TRADES](https://arxiv.org/pdf/1901.08573.pdf) objective to make our network more robust. 


In [1]:
import os
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

We'll again be working with the MNIST dataset.

1. Define a shallow ReLU network 

In [2]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 200)
        self.fc2 = nn.Linear(200, 10)

    def forward(self, x):
        x = x.view((-1, 28 * 28))
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

Let's also add a normalisation layer.<br>
It will be inserted as the first "layer" of the network. This allows us to search for adversarial examples around the real image, rather than around the normalized image. 

In [3]:
class Normalize(nn.Module):
    def forward(self, x):
        return (x - 0.1307) / 0.3081
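
To make the point about pixel space concrete, here is a minimal sketch. The tensors `x` and `delta` are made-up illustration values, not part of the exercise. Because `Normalize` sits inside the model, a perturbation bounded by `eps` is applied to the raw image in [0, 1], and the network sees the same shift scaled by 1 / 0.3081.

```python
import torch
import torch.nn as nn


class Normalize(nn.Module):
    def forward(self, x):
        # 0.1307 and 0.3081 are the constants used in the notebook
        return (x - 0.1307) / 0.3081


torch.manual_seed(0)
eps = 0.1
x = torch.rand(4, 1, 28, 28)                    # fake images in [0, 1]
delta = eps * torch.sign(torch.randn_like(x))   # illustrative perturbation, |delta| <= eps
x_adv = torch.clamp(x + delta, 0.0, 1.0)        # still a valid image

norm = Normalize()
pixel_gap = (x_adv - x).abs().max().item()                  # measured on the real image
normalized_gap = (norm(x_adv) - norm(x)).abs().max().item() # what the network sees
print(pixel_gap, normalized_gap)
```

If the attack were instead run on already-normalized inputs, `eps` would have to be rescaled and the valid pixel range would no longer be simply [0, 1].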

And let's set a few hyperparameters

In [4]:
batch_size = 512
seed = 42
learning_rate = 0.01
num_epochs = 10 
eps = 0.1  # PGD parameter: radius of the L-infinity ball (maximum perturbation per pixel)
k = 7  # number of PGD steps
trades_fact = 1.0  # TRADES lambda (weight of the robustness term)

A few more lines of preparatory code ...

In [5]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.manual_seed(seed)

<torch._C.Generator at 0x7f8fb0a52768>

In [6]:
model = nn.Sequential(Normalize(), Net())
model = model.to(device)

2. Load dataset (MNIST)

In [None]:
# Warning: running this will download the data locally
train_dataset = datasets.MNIST('mnist_data/', train=True, download=True, transform=transforms.Compose([transforms.ToTensor()]))
test_dataset = datasets.MNIST('mnist_data/', train=False, download=True, transform=transforms.Compose([transforms.ToTensor()]))

In [8]:
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

3. Implement the defenses 

In [None]:
def get_PGD_adversarial_example(x, y, eps, net, k):
  """ 
  Returns an adversarial example in the L-infinity ball of radius eps
  around x, found with an untargeted PGD attack with k steps. 
  """ 
  pass

def get_PGD_Bmax(x_batch, y_batch, eps, net, k):
  """ 
  Returns, for each point in the batch, a point in its eps-ball that
  (approximately) maximises the loss of the network.
  """
  pass

def compute_adv_accuracy(x_batch, y_batch, eps, k, trades_lambda, method="PGD"):
  """
  Returns the adversarial accuracy on the given batch. 
  """
  # iterate through the batch:
  # for the correctly classified examples check whether there's 
  # an adversarial example 
  pass
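
As a starting point for the stubs above, here is one possible sketch of an untargeted L-infinity PGD attack. It is not the reference solution: the helper name `pgd_attack_sketch`, the random start, and the step size `2.5 * eps / k` are assumptions chosen for illustration, and the toy linear network only stands in for the real model.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def pgd_attack_sketch(net, x, y, eps, k, step_size=None):
    """Untargeted L-infinity PGD (hypothetical helper, not the reference solution)."""
    if step_size is None:
        step_size = 2.5 * eps / k  # assumption: a commonly used heuristic
    # Random start inside the eps-ball, clipped to the valid pixel range
    x_adv = x + torch.empty_like(x).uniform_(-eps, eps)
    x_adv = torch.clamp(x_adv, 0.0, 1.0).detach()
    for _ in range(k):
        x_adv.requires_grad_(True)
        loss = F.cross_entropy(net(x_adv), y)
        grad, = torch.autograd.grad(loss, x_adv)
        # Gradient-sign ascent step on the loss
        x_adv = x_adv.detach() + step_size * grad.sign()
        # Project back onto the eps-ball around x, then onto [0, 1]
        x_adv = torch.min(torch.max(x_adv, x - eps), x + eps)
        x_adv = torch.clamp(x_adv, 0.0, 1.0).detach()
    return x_adv


# Toy usage with random data in place of MNIST
torch.manual_seed(0)
toy_net = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
x = torch.rand(8, 1, 28, 28)
y = torch.randint(0, 10, (8,))
x_adv = pgd_attack_sketch(toy_net, x, y, eps=0.1, k=7)
```

Because the attack is already vectorised over the batch, a function like this could plausibly back both `get_PGD_adversarial_example` and `get_PGD_Bmax`.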

4. Train and evaluate 

In [9]:
opt = optim.Adam(params=model.parameters(), lr=learning_rate, weight_decay=1e-5)
scheduler = optim.lr_scheduler.StepLR(opt, step_size=15)  # decays lr every 15 epochs; never triggers with num_epochs = 10
ce_loss = torch.nn.CrossEntropyLoss()
kl_loss = torch.nn.KLDivLoss(reduction='batchmean')
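
The two losses defined here are the ingredients of a TRADES-style objective: cross-entropy on the clean input plus a KL term that pulls the adversarial prediction towards the clean one. The sketch below is one way to combine them, under assumptions: `trades_loss_sketch` is a hypothetical helper, `x_adv` is taken as given (how it is generated is left to the exercise), and `trades_fact` is used as a plain multiplicative weight. Note that `nn.KLDivLoss` expects log-probabilities as its first argument and probabilities as its second.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

ce_loss = nn.CrossEntropyLoss()
kl_loss = nn.KLDivLoss(reduction='batchmean')


def trades_loss_sketch(net, x, x_adv, y, trades_fact):
    """Clean cross-entropy plus weighted KL(clean || adversarial). Hypothetical helper."""
    logits_nat = net(x)
    logits_adv = net(x_adv)
    natural = ce_loss(logits_nat, y)
    # input: log-probabilities of the adversarial output, target: clean probabilities
    robust = kl_loss(F.log_softmax(logits_adv, dim=1), F.softmax(logits_nat, dim=1))
    return natural + trades_fact * robust


# Toy usage with random data and a crude random-sign perturbation
torch.manual_seed(0)
toy_net = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
x = torch.rand(8, 1, 28, 28)
y = torch.randint(0, 10, (8,))
x_adv = torch.clamp(x + 0.1 * torch.sign(torch.randn_like(x)), 0.0, 1.0)

loss_clean = trades_loss_sketch(toy_net, x, x, y, 1.0)    # KL term vanishes
loss_adv = trades_loss_sketch(toy_net, x, x_adv, y, 1.0)  # KL term is non-negative
```

When `x_adv` equals `x` the KL term is zero and the loss reduces to ordinary cross-entropy, which is a convenient sanity check.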

In [10]:
defense = "PGD"

In [None]:

for epoch in range(1, num_epochs + 1):
    # Training
    for batch_idx, (x_batch, y_batch) in enumerate(tqdm(train_loader)):

        x_batch, y_batch = x_batch.to(device), y_batch.to(device)
        model.train()  

        if defense == 'PGD':
            # get Bmax: the adversarial counterparts of the batch
            x_batch_max = get_PGD_Bmax(x_batch, y_batch, eps, model, k)
            # compute the loss on the adversarial batch
            out = model(x_batch_max)
            loss = ce_loss(out, y_batch)

        elif defense == 'TRADES':
            # Problem 1.2: implement TRADES training
            raise NotImplementedError
            
        elif defense == 'none':
            # standard training
            out_nat = model(x_batch)
            loss = ce_loss(out_nat, y_batch)
        
        opt.zero_grad()
        loss.backward()
        opt.step()
        
    # Testing
    model.eval()
    tot_test, tot_acc, tot_adv_acc = 0.0, 0.0, 0.0
    for batch_idx, (x_batch, y_batch) in enumerate(tqdm(test_loader)):
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)
        
        out = model(x_batch)
        pred = torch.max(out, dim=1)[1]
        acc = pred.eq(y_batch).sum().item()

        # Problem 1.1 calculate accuracy under PGD attack
        acc_adv = 0
        
        tot_acc += acc
        tot_adv_acc += acc_adv
        tot_test += x_batch.size()[0]
    scheduler.step()

    print('Epoch %d: Accuracy %.5f, Adv Accuracy %.5f' % (epoch, tot_acc / tot_test, tot_adv_acc / tot_test))
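
For Problem 1.1, the placeholder `acc_adv = 0` needs to become a count of test examples that are still classified correctly after the attack. The sketch below shows one way to structure that count. It rests on assumptions: `count_robust_correct` is a hypothetical helper that accepts any attack as a callable, and the demo uses a single-step FGSM stand-in rather than PGD purely to keep the example short.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def fgsm_stand_in(net, x, y, eps):
    """Single-step attack used here only as a short stand-in for PGD."""
    x = x.clone().requires_grad_(True)
    loss = F.cross_entropy(net(x), y)
    grad, = torch.autograd.grad(loss, x)
    return torch.clamp(x.detach() + eps * grad.sign(), 0.0, 1.0)


def count_robust_correct(net, x_batch, y_batch, attack_fn):
    """Number of examples still classified correctly after the attack."""
    x_adv = attack_fn(net, x_batch, y_batch)  # the attack itself needs gradients
    with torch.no_grad():                     # the final prediction does not
        pred_adv = net(x_adv).argmax(dim=1)
    return pred_adv.eq(y_batch).sum().item()


# Toy usage with random data in place of the MNIST test set
torch.manual_seed(0)
toy_net = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
x = torch.rand(16, 1, 28, 28)
y = torch.randint(0, 10, (16,))
acc_adv = count_robust_correct(
    toy_net, x, y, lambda net, xb, yb: fgsm_stand_in(net, xb, yb, eps=0.1)
)
```

This counts every example that survives the attack, which differs slightly from only attacking the examples that were classified correctly to begin with. A number obtained with a stronger attack such as PGD is the more meaningful robustness estimate.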


5. (Optional) save the model

In [None]:
os.makedirs("models", exist_ok=True)
torch.save(model.state_dict(), f"models/Net_{num_epochs}_{defense}")
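
Since only the `state_dict` is saved, reloading requires rebuilding the same architecture first. A minimal round-trip sketch follows, using a toy network and a temporary directory as stand-ins for the real model and the `models/` folder:

```python
import os
import tempfile

import torch
import torch.nn as nn


def make_net():
    # Stand-in architecture; the real notebook would rebuild nn.Sequential(Normalize(), Net())
    return nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))


net = make_net()
with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "Net_sketch")
    torch.save(net.state_dict(), path)
    restored = make_net()
    restored.load_state_dict(torch.load(path, map_location="cpu"))
restored.eval()  # switch to evaluation mode before testing

same = all(
    torch.equal(a, b)
    for a, b in zip(net.state_dict().values(), restored.state_dict().values())
)
```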