In [1]:
import torch
from torchvision import transforms
from torchvision.utils import save_image
from mnist_model_generator import Net
from PIL import Image
import math

In [2]:
model = Net()
model.load_state_dict(torch.load('../../data/weights/mnist_cnn.pt', map_location=torch.device('cpu')))
model.eval();

In [3]:
three = Image.open("../../data/pictures/3.png")
preprocess = transforms.Compose([
   transforms.Resize(28),
   transforms.ToTensor(),
   transforms.Normalize((0.1307,), (0.3081,))
])
three_tensor = preprocess(three)[0].reshape(1,1,28,28)

In [4]:
model(three_tensor)

tensor([[3.8002e-37, 1.9888e-26, 1.2790e-24, 1.0000e+00, 3.6991e-38, 1.3000e-19,
         5.5069e-37, 2.9996e-23, 2.1760e-25, 7.5329e-23]],
       grad_fn=<SoftmaxBackward0>)

# Start of JSMA

In [5]:
def saliency_map(J, t, space, size, width):
    S = [(0, 0)] * size
    for p in space:
        alpha = J[t, p // width, p % width].item()
        beta = 0
        for i in range(J.size(0)):
            if not i == t:
                beta += J[i, p // width, p % width].item()
        S[p] = (alpha, beta)
    return S

In [6]:
def jsma(original_image_tensor, target, predictor, max_dist, increase):
    img_tensor = original_image_tensor.clone()
    img_size = img_tensor.size(2) * img_tensor.size(3)
    width = img_tensor.size(3)
    search_space = list(range(img_size))
    i = 0
    max_iter = math.floor((img_size * max_dist) / (200))
    chosen_pixel_1 = -1
    chosen_pixel_2 = -1
    prediction = predictor(img_tensor)

    while not prediction.argmax().item() == target and i < max_iter and len(search_space) >= 2:
        max = 0
        J = torch.autograd.functional.jacobian(predictor, img_tensor)[0, :, 0, 0, :, :]
        S = saliency_map(J, target, search_space, img_size, width)
        for pixel1 in search_space:
            for pixel2 in search_space:
                if pixel1 == pixel2:
                    continue
                
                alpha = S[pixel1][0] + S[pixel2][0]
                beta = S[pixel1][1] + S[pixel2][1]

                sign_check = alpha > 0 and beta < 0 if increase else alpha < 0 and beta > 0
                if sign_check and -alpha * beta > max:
                    chosen_pixel_1 = pixel1
                    chosen_pixel_2 = pixel2
                    max = -alpha * beta

        if max == 0:
            break

        img_tensor[0, 0, chosen_pixel_1 // width, chosen_pixel_1 % width] = 1 if increase else 0
        img_tensor[0, 0, chosen_pixel_2 // width, chosen_pixel_2 % width] = 1 if increase else 0

        search_space.remove(chosen_pixel_1)
        search_space.remove(chosen_pixel_2)
        
        prediction = predictor(img_tensor)
        i += 1
    return img_tensor

Image of 3 seems to lead to positive not being very effective.
Perhaps a 1 would perform better on positive?

On the other hand, 3 performs pretty well on negative

The experiment on an empty image is pretty interesting

In [7]:
attacked_models_positive = []
for i in range(10):
    attacked_models_positive.append(jsma(three_tensor, i, model, 20, True))
    print(f'Classified as {model(attacked_models_positive[i]).argmax().item()} with goal {i}')
    save_image(attacked_models_positive[i][0,0], f'../../results/JSMA/positive-{i}.png')

Classified as 3 with goal 0
Classified as 3 with goal 1
Classified as 2 with goal 2
Classified as 3 with goal 3
Classified as 3 with goal 4
Classified as 3 with goal 5
Classified as 3 with goal 6
Classified as 3 with goal 7
Classified as 8 with goal 8
Classified as 3 with goal 9


In [8]:
attacked_models_negative = []
for i in range(10):
    attacked_models_negative.append(jsma(three_tensor, i, model, 20, False))
    print(f'Classified as {model(attacked_models_negative[i]).argmax().item()} with goal {i}')
    save_image(attacked_models_negative[i][0,0], f'../../results/JSMA/negative-{i}.png')

Classified as 0 with goal 0
Classified as 1 with goal 1
Classified as 2 with goal 2
Classified as 3 with goal 3
Classified as 4 with goal 4
Classified as 5 with goal 5
Classified as 3 with goal 6
Classified as 7 with goal 7
Classified as 8 with goal 8
Classified as 3 with goal 9


In [9]:
attacked_models_empty = []
for i in range(10):
    attacked_models_empty.append(jsma(torch.zeros_like(three_tensor), i, model, 20, True))
    print(f'Classified as {model(attacked_models_empty[i]).argmax().item()} with goal {i}')
    save_image(attacked_models_empty[i][0,0], f'../../results/JSMA/empty-{i}.png')

Classified as 0 with goal 0
Classified as 1 with goal 1
Classified as 2 with goal 2
Classified as 3 with goal 3
Classified as 4 with goal 4
Classified as 5 with goal 5
Classified as 6 with goal 6
Classified as 7 with goal 7
Classified as 8 with goal 8
Classified as 9 with goal 9
