In [1]:
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

In [2]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 200)
        self.fc2 = nn.Linear(200, 10)

    def forward(self, x):
        x = x.view((-1, 28 * 28))
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


In [3]:
class Normalize(nn.Module):
    def forward(self, x):
        return (x - 0.1307) / 0.3081


In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

batch_size = 512
seed = 42
learning_rate = 0.01
num_epochs = 10
eps = 0.1
k = 7
trades_lambda = 1.0

# Setting the random number generator
torch.manual_seed(seed)

<torch._C.Generator at 0x7fc7fdc7a9b0>

In [5]:
# Datasets
train_dataset = datasets.MNIST('mnist_data/', train=True, download=True, transform=transforms.Compose([transforms.ToTensor()]))
test_dataset = datasets.MNIST('mnist_data/', train=False, download=True, transform=transforms.Compose([transforms.ToTensor()]))

# Data loaders
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [6]:
# Add data normalization as a first "layer" to the network
# This allows us to search for adversarial examples to the real image,
# rather than to the normalized image
model = nn.Sequential(Normalize(), Net())
model = model.to(device)

opt = optim.Adam(params=model.parameters(), lr=learning_rate, weight_decay=1e-5)
scheduler = optim.lr_scheduler.StepLR(opt, 15)
ce_loss = torch.nn.CrossEntropyLoss()
kl_loss = torch.nn.KLDivLoss(reduction='batchmean')

In [7]:
def pgd(model, x_batch, target, k, eps, eps_step, kl_loss: bool = False):
    if kl_loss:
        # Loss function for the case that target is a distribution rather than a label (used for TRADES)
        loss_fn = torch.nn.KLDivLoss(reduction='sum')
    else:
        # Standard PGD
        loss_fn = torch.nn.CrossEntropyLoss(reduction='sum')
    
    # Disable gradients here
    with torch.no_grad():
        # Initialize with a random point inside the considered perturbation region
        x_adv = x_batch.detach() + eps * (2 * torch.rand_like(x_batch) - 1)
        
       # Project back to the image domain
        x_adv.clamp(min=0.0, max=1.0)

        for step in range(k):
            # Make sure we don't have a previous compute graph and enable gradient computation
            x_adv.detach_().requires_grad_()

            # Re-enable gradients
            with torch.enable_grad():
                # Run the model and obtain the loss
                out = F.log_softmax(model(x_adv), dim=1)
                model.zero_grad()

                # Compute gradient
                loss_fn(out, target).backward()
            
            # Compute step
            step = eps_step * x_adv.grad.sign()

            # Project to eps ball
            x_adv = x_batch + (x_adv + step - x_batch).clamp(min=-eps, max=eps)

            # Clamp back to image domain: we clamp at each step
            x_adv.clamp_(min=0.0, max=1.0)
    
    return x_adv.detach()

In [19]:
def train_and_test_accuracies_using_defense(defense, num_epochs, train_loader, test_loader, k, eps):
    for epoch in range(1, num_epochs + 1):
        # Training
        for _, (x_batch, y_batch) in enumerate(tqdm(train_loader)):

            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            
            if defense == 'PGD':
                # PGD attack to generate adversarial examples
                
                # Switch model to eval mode, to ensure it is deterministic
                model.eval()

                x_adv = pgd(
                    model, 
                    x_batch=x_batch, 
                    target=y_batch,
                    eps=eps, 
                    k=k, 
                    eps_step = 2.5 * eps / k
                )

                # Switch back to training mode
                model.train()
                out_pgd = model(x_adv)

                # Compute loss
                loss = ce_loss(out_pgd, y_batch)

            elif defense == 'TRADES':
                # Switch to training mode
                model.train()
                out_nat = model(x_batch)
                target = F.softmax(out_nat.detach(), dim=1)

                # Do PGD attack to generate adversarial examples
                
                # Switch network to eval mode, to ensure it is deterministic
                model.eval()

                x_adv = pgd(
                    model, 
                    x_batch=x_batch, 
                    target=target, 
                    k=k, 
                    eps=eps,
                    eps_step=2.5 * eps / k,
                    kl_loss=True
                )

                # Calculate loss
                
                # Switch to training mode
                model.train()
                out_adv = F.log_softmax(model(x_adv), dim=1)
                
                loss_nat = ce_loss(out_nat, y_batch)
                loss_adv = kl_loss(out_adv, target)
                loss = loss_nat + trades_lambda * loss_adv
                
            elif defense == 'none':                
                model.train()
                out_nat = model(x_batch)
                loss = ce_loss(out_nat, y_batch)

            opt.zero_grad()
            loss.backward()
            opt.step()

        # Testing
        model.eval()

        tot_test, tot_acc, tot_adv_acc = 0.0, 0.0, 0.0

        for _, (x_batch, y_batch) in enumerate(tqdm(test_loader)):
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)

            # Prediction by the model on each batch
            out = model(x_batch)
            pred = torch.max(out, dim=1)[1]
            acc = pred.eq(y_batch).sum().item()

            x_adv = pgd(
                model,
                x_batch=x_batch,
                target=y_batch,
                k=k,
                eps=eps,
                eps_step=2.5 * eps / k
            )

            # Prediction of the model on the adversarial batch
            out_adv = model(x_adv)
            pred_adv = torch.max(out_adv, dim=1)[1]
            acc_adv = pred_adv.eq(y_batch).sum().item()

            # Add to total accuracies for both regular and adversarial accuracies
            tot_acc += acc
            tot_adv_acc += acc_adv
            tot_test += x_batch.size()[0]

        scheduler.step()

        print('Epoch %d: Accuracy %.5lf, Adv Accuracy %.5lf' %
            (epoch, tot_acc / tot_test, tot_adv_acc / tot_test))


In [16]:
# Evaluate model using standard training, no defense
train_and_test_accuracies_using_defense(
    defense='none', num_epochs=num_epochs, train_loader=train_loader, test_loader=test_loader, k=k, eps=eps)

100%|██████████| 118/118 [00:07<00:00, 16.62it/s]
100%|██████████| 20/20 [00:06<00:00,  3.28it/s]


Epoch 1: Accuracy 0.98690, Adv Accuracy 0.54680


100%|██████████| 118/118 [00:06<00:00, 18.04it/s]
100%|██████████| 20/20 [00:05<00:00,  3.68it/s]


Epoch 2: Accuracy 0.98690, Adv Accuracy 0.54730


100%|██████████| 118/118 [00:06<00:00, 18.93it/s]
100%|██████████| 20/20 [00:05<00:00,  3.66it/s]


Epoch 3: Accuracy 0.98680, Adv Accuracy 0.53310


100%|██████████| 118/118 [00:06<00:00, 19.01it/s]
100%|██████████| 20/20 [00:05<00:00,  3.62it/s]


Epoch 4: Accuracy 0.98670, Adv Accuracy 0.53690


100%|██████████| 118/118 [00:06<00:00, 18.68it/s]
100%|██████████| 20/20 [00:05<00:00,  3.42it/s]


Epoch 5: Accuracy 0.98720, Adv Accuracy 0.52480


100%|██████████| 118/118 [00:06<00:00, 18.92it/s]
100%|██████████| 20/20 [00:05<00:00,  3.34it/s]


Epoch 6: Accuracy 0.98690, Adv Accuracy 0.52470


100%|██████████| 118/118 [00:06<00:00, 18.30it/s]
100%|██████████| 20/20 [00:05<00:00,  3.54it/s]


Epoch 7: Accuracy 0.98620, Adv Accuracy 0.52730


100%|██████████| 118/118 [00:06<00:00, 19.65it/s]
100%|██████████| 20/20 [00:05<00:00,  3.68it/s]


Epoch 8: Accuracy 0.98660, Adv Accuracy 0.53170


100%|██████████| 118/118 [00:06<00:00, 19.44it/s]
100%|██████████| 20/20 [00:05<00:00,  3.76it/s]


Epoch 9: Accuracy 0.98690, Adv Accuracy 0.52450


100%|██████████| 118/118 [00:05<00:00, 20.00it/s]
100%|██████████| 20/20 [00:05<00:00,  3.85it/s]

Epoch 10: Accuracy 0.98690, Adv Accuracy 0.52230





In [17]:
# Evaluate model using PGD defense
train_and_test_accuracies_using_defense(
    defense='PGD', num_epochs=num_epochs, train_loader=train_loader, test_loader=test_loader, k=k, eps=eps)

100%|██████████| 118/118 [00:33<00:00,  3.56it/s]
100%|██████████| 20/20 [00:05<00:00,  3.61it/s]


Epoch 1: Accuracy 0.98460, Adv Accuracy 0.80610


100%|██████████| 118/118 [00:32<00:00,  3.61it/s]
100%|██████████| 20/20 [00:05<00:00,  3.67it/s]


Epoch 2: Accuracy 0.98520, Adv Accuracy 0.82600


100%|██████████| 118/118 [00:33<00:00,  3.56it/s]
100%|██████████| 20/20 [00:05<00:00,  3.59it/s]


Epoch 3: Accuracy 0.98520, Adv Accuracy 0.83650


100%|██████████| 118/118 [00:34<00:00,  3.45it/s]
100%|██████████| 20/20 [00:05<00:00,  3.59it/s]


Epoch 4: Accuracy 0.98540, Adv Accuracy 0.84520


100%|██████████| 118/118 [00:33<00:00,  3.48it/s]
100%|██████████| 20/20 [00:05<00:00,  3.61it/s]


Epoch 5: Accuracy 0.98550, Adv Accuracy 0.85170


100%|██████████| 118/118 [00:33<00:00,  3.51it/s]
100%|██████████| 20/20 [00:05<00:00,  3.65it/s]


Epoch 6: Accuracy 0.98530, Adv Accuracy 0.85380


100%|██████████| 118/118 [00:33<00:00,  3.55it/s]
100%|██████████| 20/20 [00:05<00:00,  3.67it/s]


Epoch 7: Accuracy 0.98520, Adv Accuracy 0.85780


100%|██████████| 118/118 [00:32<00:00,  3.60it/s]
100%|██████████| 20/20 [00:05<00:00,  3.60it/s]


Epoch 8: Accuracy 0.98500, Adv Accuracy 0.85950


100%|██████████| 118/118 [00:35<00:00,  3.34it/s]
100%|██████████| 20/20 [00:06<00:00,  3.05it/s]


Epoch 9: Accuracy 0.98540, Adv Accuracy 0.86080


100%|██████████| 118/118 [00:33<00:00,  3.52it/s]
100%|██████████| 20/20 [00:05<00:00,  3.70it/s]

Epoch 10: Accuracy 0.98530, Adv Accuracy 0.86350





In [20]:

# Evaluate model using PGD defense
train_and_test_accuracies_using_defense(
    defense='TRADES', num_epochs=num_epochs, train_loader=train_loader, test_loader=test_loader, k=k, eps=eps)


100%|██████████| 118/118 [00:33<00:00,  3.50it/s]
100%|██████████| 20/20 [00:05<00:00,  3.65it/s]


Epoch 1: Accuracy 0.98520, Adv Accuracy 0.86570


100%|██████████| 118/118 [00:33<00:00,  3.56it/s]
100%|██████████| 20/20 [00:05<00:00,  3.67it/s]


Epoch 2: Accuracy 0.98520, Adv Accuracy 0.86630


100%|██████████| 118/118 [00:35<00:00,  3.34it/s]
100%|██████████| 20/20 [00:05<00:00,  3.58it/s]


Epoch 3: Accuracy 0.98540, Adv Accuracy 0.86600


100%|██████████| 118/118 [00:34<00:00,  3.46it/s]
100%|██████████| 20/20 [00:05<00:00,  3.62it/s]


Epoch 4: Accuracy 0.98530, Adv Accuracy 0.86730


100%|██████████| 118/118 [00:33<00:00,  3.51it/s]
100%|██████████| 20/20 [00:05<00:00,  3.45it/s]


Epoch 5: Accuracy 0.98570, Adv Accuracy 0.86640


100%|██████████| 118/118 [00:35<00:00,  3.35it/s]
100%|██████████| 20/20 [00:05<00:00,  3.58it/s]


Epoch 6: Accuracy 0.98550, Adv Accuracy 0.86770


100%|██████████| 118/118 [00:42<00:00,  2.77it/s]
100%|██████████| 20/20 [00:07<00:00,  2.60it/s]


Epoch 7: Accuracy 0.98540, Adv Accuracy 0.86610


100%|██████████| 118/118 [00:39<00:00,  2.97it/s]
100%|██████████| 20/20 [00:06<00:00,  3.15it/s]


Epoch 8: Accuracy 0.98540, Adv Accuracy 0.86780


100%|██████████| 118/118 [00:35<00:00,  3.29it/s]
100%|██████████| 20/20 [00:05<00:00,  3.34it/s]


Epoch 9: Accuracy 0.98530, Adv Accuracy 0.86740


100%|██████████| 118/118 [00:35<00:00,  3.37it/s]
100%|██████████| 20/20 [00:05<00:00,  3.82it/s]

Epoch 10: Accuracy 0.98530, Adv Accuracy 0.86790



