In [2]:
import torch
import torch.nn as nn
import torch.autograd as autograd
import torch.optim as optim
import torch.nn.functional as functional
import torch.nn.init as init

In [3]:
class SimpleCNN(nn.Module):
    """Small convolutional network for single-channel 28x28 images.

    Architecture: conv(5x5) -> ReLU -> conv(5x5) -> ReLU -> flatten -> linear(64) -> ReLU
    -> linear(10).

    Both convolutions use 5x5 kernels with padding 2, so the 28x28 spatial size is
    preserved. The flattened feature vector therefore has ``4 * width_multiplier * 28 * 28``
    entries.

    Args:
        width_multiplier: scales the channel counts. The first conv emits
            ``2 * width_multiplier`` channels and the second ``4 * width_multiplier``.
    """

    def __init__(self, width_multiplier: int):
        super().__init__()
        self.width_multiplier = width_multiplier
        narrow = 2 * width_multiplier
        wide = 4 * width_multiplier
        # Layers are created in the same order as before, so seeded initialisation and
        # state_dict keys stay identical.
        self.conv1: nn.Conv2d = nn.Conv2d(1, narrow, 5, padding=2)
        self.conv2: nn.Conv2d = nn.Conv2d(narrow, wide, 5, padding=2)
        self.linear: nn.Linear = nn.Linear(wide * 28 * 28, 64)
        self.output: nn.Linear = nn.Linear(64, 10)

    def forward(self, x: autograd.Variable) -> autograd.Variable:
        """Map a batch of 1x28x28 images to 10 unnormalised scores per image."""
        flat_size = 4 * self.width_multiplier * 28 * 28
        hidden = functional.relu(self.conv2(functional.relu(self.conv1(x))))
        hidden = functional.relu(self.linear(hidden.view(-1, flat_size)))
        return self.output(hidden)
    
    
# Smoke test: a forward pass on random input should yield one row of 10 scores per image.
torch.manual_seed(0)  # makes the weights, the random input and the printed output reproducible
net = SimpleCNN(1)
test_output = net(autograd.Variable(torch.rand((2, 1, 28, 28))))
assert test_output.size() == (2, 10)
print(test_output)

Variable containing:
-0.0571 -0.1623 -0.0022  0.1721 -0.1188  0.0514  0.0331  0.0368  0.0556  0.0452
-0.0579 -0.1702  0.0068  0.1977 -0.0826  0.0469  0.0274  0.0131  0.0441  0.0451
[torch.FloatTensor of size 2x10]



In [4]:
import typing


def fsgm(image_batch: torch.FloatTensor,
         label_batch: torch.LongTensor,
         model: typing.Callable[[autograd.Variable], autograd.Variable],
         objective: typing.Callable[[autograd.Variable, autograd.Variable], autograd.Variable],
         eps: float):
    """Perturb every image in ``image_batch`` in place with the Fast Gradient Sign Method.

    The function keeps its historical ``fsgm`` spelling so existing callers keep working.

    For each sample ``i``, the gradient of ``objective(model(x_i), y_i)`` with respect to the
    input is computed. The sample is then moved by ``eps`` along the sign of that gradient,
    which increases the loss. The result is clipped to ``[0, 1]``.

    Args:
        image_batch: batch of images indexed along dim 0. It is overwritten in place.
        label_batch: labels aligned with ``image_batch`` along dim 0.
        model: called on a batch holding a single image.
        objective: loss called as ``objective(output, label)``.
        eps: size of the signed-gradient step (L-infinity perturbation budget).

    Returns:
        None. ``image_batch`` is modified in place. The model's parameter gradients are
        left untouched.
    """
    for i in range(image_batch.shape[0]):
        # Work on a copy and write back explicitly below, rather than relying on the
        # Variable aliasing the storage of image_batch.
        x = autograd.Variable(image_batch[i:i + 1].clone(), requires_grad=True)
        label = autograd.Variable(label_batch[i:i + 1])
        loss = objective(model(x), label)
        # Differentiate w.r.t. the input only. loss.backward() would also accumulate
        # gradients into the model parameters' .grad buffers.
        input_grad, = autograd.grad(loss, x)
        perturbed = x.data + eps * torch.sign(input_grad.data)
        # torch.clamp is out-of-place, so its result must be the value that is stored.
        image_batch[i:i + 1] = torch.clamp(perturbed, min=0.0, max=1.0)


# Smoke test for fsgm: attack one random image and show a 2x2 patch before and after.
torch.manual_seed(0)  # reproducible weights, image and printed patches
net = SimpleCNN(1)
image = torch.rand((1, 1, 28, 28))
label = torch.LongTensor([2])
clean_image = image.clone()
test_eps = 0.3
print(image[0, 0, 1:3, 1:3])
net.zero_grad()
fsgm(image, label, net, nn.CrossEntropyLoss(), test_eps)
print(image[0, 0, 1:3, 1:3])
# FGSM moves each pixel by at most eps.
assert (image - clean_image).abs().max() <= test_eps + 1e-6


 0.3063  0.3201
 0.3940  0.7272
[torch.FloatTensor of size 2x2]




 0.6063  0.6201
 0.0940  0.4272
[torch.FloatTensor of size 2x2]



In [18]:
def pgd(image_batch: torch.FloatTensor,
        label_batch: torch.LongTensor,
        model: typing.Callable[[autograd.Variable], autograd.Variable],
        objective: typing.Callable[[autograd.Variable, autograd.Variable], autograd.Variable],
        eps: float,
        alpha: float,
        num_steps: int,
        num_restarts: int):
    """Perturb every image in ``image_batch`` in place with projected gradient ascent (PGD).

    Each sample is attacked independently:

    - Start from a point drawn uniformly from the L-infinity ball of radius ``eps`` around
      the clean image.
    - Take ``num_steps`` steps of size ``alpha`` along the sign of the input gradient of
      ``objective(model(x), label)``. This maximises the loss, equivalent to PGD on the
      negative loss.
    - After every step, project back onto the ``eps``-ball intersected with ``[0, 1]``.

    The attack runs ``num_restarts`` times from independent random starts (at least once).
    The iterate with the highest final loss is kept.

    Args:
        image_batch: batch of images indexed along dim 0. It is overwritten in place.
        label_batch: labels aligned with ``image_batch`` along dim 0.
        model: called on a batch holding a single image.
        objective: loss called as ``objective(output, label)``.
        eps: radius of the L-infinity ball around each clean image.
        alpha: size of each signed-gradient step.
        num_steps: number of ascent steps per run (0 keeps the random start).
        num_restarts: number of random starts per sample. Values below 1 are treated as 1.

    Returns:
        None. ``image_batch`` is modified in place. The model's parameter gradients are
        left untouched.
    """

    def _attack_from_random_start(clean, label):
        """Run one PGD trajectory on a single-sample batch; return (adversarial, final loss)."""
        # Feasible box: the eps-ball around the clean image, clipped to the valid pixel range.
        x_min = torch.clamp(clean - eps, min=0.0)
        x_max = torch.clamp(clean + eps, max=1.0)
        # Uniform start in [-eps, eps] around the clean image, projected into the box.
        noise = (torch.rand(clean.shape) * 2 - 1) * eps
        x_adv = torch.min(torch.max(clean + noise, x_min), x_max)
        for _ in range(num_steps):
            x = autograd.Variable(x_adv, requires_grad=True)
            loss = objective(model(x), label)
            # Differentiate w.r.t. the input only, so model parameter grads are not touched.
            input_grad, = autograd.grad(loss, x)
            x_adv = x_adv + alpha * torch.sign(input_grad.data)
            x_adv = torch.min(torch.max(x_adv, x_min), x_max)
        # Score the iterate that is actually returned. A loss computed inside the loop
        # would predate the last step.
        final_loss = objective(model(autograd.Variable(x_adv)), label)
        return x_adv, float(final_loss.data.view(-1)[0])

    for i in range(image_batch.shape[0]):
        # Snapshot the clean sample so every restart is anchored to the unperturbed image.
        clean = image_batch[i:i + 1].clone()
        label = autograd.Variable(label_batch[i:i + 1])
        best_image, best_loss = None, None
        for _ in range(max(1, num_restarts)):
            candidate, candidate_loss = _attack_from_random_start(clean, label)
            if best_loss is None or candidate_loss > best_loss:
                best_image, best_loss = candidate, candidate_loss
        image_batch[i:i + 1] = best_image

# Smoke test for pgd: attack one random image and show a 2x2 patch before and after.
torch.manual_seed(0)  # reproducible weights, image, random starts and printed patches
net = SimpleCNN(1)
image = torch.rand((1, 1, 28, 28))
label = torch.LongTensor([2])
clean_image = image.clone()
test_eps = 0.3
print(image[0, 0, 1:3, 1:3])
net.zero_grad()
pgd(image, label, net, nn.CrossEntropyLoss(), test_eps, 0.6, 4, 2)
print(image[0, 0, 1:3, 1:3])
# PGD projects onto the eps-ball around the clean image, intersected with [0, 1].
assert (image - clean_image).abs().max() <= test_eps + 1e-6
assert image.min() >= 0.0 and image.max() <= 1.0


 0.4000  0.2835
 0.3640  0.5573
[torch.FloatTensor of size 2x2]


 0.7000  0.0000
 0.6640  0.2573
[torch.FloatTensor of size 2x2]

