<a href="https://colab.research.google.com/github/GiuliaLanzillotta/exercises/blob/master/Adversarial_defense.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Adversarial defense 

Today we'll experiment with adversarial training as an adversarial defense technique. <br>

More specifically we'll employ [PGD](https://arxiv.org/pdf/1706.06083.pdf) and [TRADES](https://arxiv.org/pdf/1901.08573.pdf) attack to make our net more robust during training. 


In [45]:
import os
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

We'll be again be working with the MNIST dataset 

  
  ## 1. Define a shallow ReLU network 

In [46]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 200)
        self.fc2 = nn.Linear(200, 10)

    def forward(self, x):
        x = x.view((-1, 28 * 28))
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

Let's also add a normalisation layer.<br>
It will be inserted as a first "layer" to the network. This allows us to search for adverserial examples to the real image, rather than to the normalized image. 

In [47]:
class Normalize(nn.Module):
    def forward(self, x):
        return (x - 0.1307) / 0.3081

And let's set a few hyperparameters

In [48]:
batch_size = 512
seed = 42
learning_rate = 0.01
num_epochs = 10 
eps = 0.1 #PGD parameter (defines the magnitude of the perturbation)
k = 7 #PGD steps
trades_fact = 1.0 #TRADES lambda 

A few more lines of preparatory code ...

In [49]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.manual_seed(seed)

<torch._C.Generator at 0x7f8fb0a52768>

In [50]:
model = nn.Sequential(Normalize(), Net())
model = model.to(device)


  ## 2. Load dataset (MNIST)

In [51]:
# Warning: running this will download the data locally
train_dataset = datasets.MNIST('mnist_data/', train=True, download=True, transform=transforms.Compose([transforms.ToTensor()]))
test_dataset = datasets.MNIST('mnist_data/', train=False, download=True, transform=transforms.Compose([transforms.ToTensor()]))

In [52]:
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


## 3. Implement the defenses 

In [53]:
def fgsm_step(x,y, eps, net):
  """
  Implements an fgsm step.
  """
  input_ = x.clone().detach_()
  input_.requires_grad = True
  loss = ce_loss(net(input_), torch.tensor([y], dtype=torch.long))
  loss.backward()
  x_next = x + eps*torch.sign(input_.grad)
  return x_next

def fgsm_TRADES_step(x,x_next,eps,net):
  """
  Implements fgsm step for TRADES boundary loss optimization. 
  TRADES boundary loss encourages the network to have a smoother boundary, with 
  respect to the perturbation regions. 
  Basically we enforce the predefined perturbation regions to have a coherent 
  prediction. 
  """
  input_ = x_next.clone().detach_()
  input_.requires_grad = True
  # The loss is equivalent to the difference between the original prediction
  # and the adversarial prediction 
  logit1 = torch.nn.LogSoftmax(net(x), dim=1)
  logit2 = torch.nn.LogSoftmax(net(x_next), dim=1)
  loss = torch.nn.NLLLoss()
  out = loss(logit2,logit1)
  out.backward()
  x_next = x_next + eps*torch.sign(input_.grad)
  return x_next


def get_PGD_adversarial_example(x,y, eps, net, k, method="PGD"):
  """ 
  Returns adversarial example in epsilon infinity ball around x
  using untargeted PGD attack. 
  """ 
  eps_step = 2.5*(eps/k)
  x_next = x 
  # set the adversarial flag to signal whether 
  # we have found an adversarial example
  adv = torch.argmax(net(x), dim=1) != y
  if adv: return x_next
  for i in range(k): 
    # take an fgsm step 
    if method=="PGD":x_next = fgsm_step(x_next, y, eps_step, net)
    elif method=="TRADES":x_next = fgsm_step(x, x_next, eps_step, net)
    else: raise ValueError("Method inserted is not valid. Supported methods: PGD, TRADES.")
    # project back to L infinity ball
    delta = x - x_next 
    delta = torch.clamp(delta, min=-1*eps, max=eps)
    x_next = x - delta
    # check whether we have an adversarial example 
    adv = torch.argmax(net(x_next), dim=1) != y
    if adv: return x_next
  return None

def get_PGD_Bmax(x_batch, y_batch, eps, net, k):
  """ 
  Returns alternative set of points that maximise 
  the loss of the newtork 
  """
  Bmax = []
  for (x,y) in zip(x_batch,y_batch):
    x_adv = get_PGD_adversarial_example(x,y, eps, net, k)
    if x_adv is None: x_adv = x 
    Bmax +=[x_adv[0]]
  return torch.stack(Bmax)
    
def compute_adv_accuracy(x_batch, y_batch, eps, k, lam, net, method="PGD"):
  """
  Returns the adversarial accuracy on the given batch. 
  """
  # iterate through the batch:
  # for the correctly classified examples check whether there's 
  # an adversarial example 
  tot_acc = 0
  for (x,y) in zip(x_batch,y_batch):
    if torch.argmax(net(x), dim=1) == y and get_PGD_adversarial_example(x,y, eps, net, k, method=method) is None: 
      tot_acc+=1
  return tot_acc



## 4. Train and evaluate 

In [54]:
opt = optim.Adam(params=model.parameters(), lr=learning_rate, weight_decay=1e-5)
scheduler = optim.lr_scheduler.StepLR(opt, 15)
ce_loss = torch.nn.CrossEntropyLoss()
kl_loss = torch.nn.KLDivLoss(reduction='batchmean')

In [55]:
defense = "PGD"

In [None]:

for epoch in range(1, num_epochs + 1):
    # Training
    for batch_idx, (x_batch, y_batch) in enumerate(tqdm(train_loader)):

        x_batch, y_batch = x_batch.to(device), y_batch.to(device)
        model.train()  

        if defense == 'PGD':
          # get Bmax 
          x_batch_max = get_PGD_Bmax(x_batch, y_batch, eps, model, k)
          # compute the loss 
          out = model(x_batch_max)
          loss = ce_loss(out, y_batch)

        elif defense == 'TRADES':
          # get Bmax 
          x_batch_max = get_PGD_Bmax(x_batch, y_batch, eps, model, k, method="TRADES")
          # compute the loss 
          out = model(x_batch)
          out_max = model(x_batch_max)
          # boundary loss 
          b_loss = torch.nn.NLLLoss()
          b_out = b_loss(torch.nn.LogSoftmax(out_max, dim=1),
                         torch.nn.LogSoftmax(out, dim=1))
          loss = ce_loss(out, y_batch) + trades_fact*b_out

        elif defense == 'none':
            # standard training
            out_nat = model(x_batch)
            loss = ce_loss(out_nat, y_batch)
        
        opt.zero_grad()
        loss.backward()
        opt.step()
        
    # Testing
    model.eval()
    tot_test, tot_acc, tot_adv_acc = 0.0, 0.0, 0.0
    for batch_idx, (x_batch, y_batch) in enumerate(tqdm(test_loader)):
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)
        
        out = model(x_batch)
        pred = torch.max(out, dim=1)[1]
        acc = pred.eq(y_batch).sum().item()

        
        acc_adv = compute_adv_accuracy(x_batch, y_batch, eps, k, trades_fact, model, method="PGD")
        
        tot_acc += acc
        tot_adv_acc += acc_adv
        tot_test += x_batch.size()[0]
    scheduler.step()
    print()
    print('Epoch %d: Accuracy %.5lf, Adv Accuracy %.5lf' % (epoch, tot_acc / tot_test, tot_adv_acc / tot_test))


100%|██████████| 118/118 [03:28<00:00,  1.77s/it]
100%|██████████| 20/20 [00:42<00:00,  2.11s/it]
  0%|          | 0/118 [00:00<?, ?it/s]

Epoch 1: Accuracy 0.94200, Adv Accuracy 0.59200


100%|██████████| 118/118 [04:21<00:00,  2.22s/it]
100%|██████████| 20/20 [00:46<00:00,  2.30s/it]
  0%|          | 0/118 [00:00<?, ?it/s]

Epoch 2: Accuracy 0.96210, Adv Accuracy 0.67770


100%|██████████| 118/118 [04:28<00:00,  2.27s/it]
100%|██████████| 20/20 [00:46<00:00,  2.30s/it]
  0%|          | 0/118 [00:00<?, ?it/s]

Epoch 3: Accuracy 0.96510, Adv Accuracy 0.66470


100%|██████████| 118/118 [04:35<00:00,  2.34s/it]
100%|██████████| 20/20 [00:47<00:00,  2.39s/it]
  0%|          | 0/118 [00:00<?, ?it/s]

Epoch 4: Accuracy 0.96780, Adv Accuracy 0.71550


100%|██████████| 118/118 [04:38<00:00,  2.36s/it]
100%|██████████| 20/20 [00:46<00:00,  2.33s/it]
  0%|          | 0/118 [00:00<?, ?it/s]

Epoch 5: Accuracy 0.96690, Adv Accuracy 0.69060


100%|██████████| 118/118 [04:39<00:00,  2.37s/it]
100%|██████████| 20/20 [00:47<00:00,  2.38s/it]
  0%|          | 0/118 [00:00<?, ?it/s]

Epoch 6: Accuracy 0.96980, Adv Accuracy 0.72700


100%|██████████| 118/118 [04:37<00:00,  2.35s/it]
100%|██████████| 20/20 [00:47<00:00,  2.36s/it]
  0%|          | 0/118 [00:00<?, ?it/s]

Epoch 7: Accuracy 0.96770, Adv Accuracy 0.70520


100%|██████████| 118/118 [04:40<00:00,  2.37s/it]
100%|██████████| 20/20 [00:47<00:00,  2.39s/it]
  0%|          | 0/118 [00:00<?, ?it/s]

Epoch 8: Accuracy 0.96920, Adv Accuracy 0.72520


 50%|█████     | 59/118 [02:21<02:19,  2.37s/it]


## 5. (Optional) save the model

In [None]:
os.makedirs("models", exist_ok=True)
torch.save(model.state_dict(), f"models/Net_{num_epochs}_{defense}")