In [1]:
import torch
from torchvision import transforms
from torchvision.utils import save_image
from mnist_model_generator import Net
from PIL import Image
import math

In [2]:
# Build the network and restore the trained MNIST weights.
# map_location='cpu' lets the checkpoint load on a machine without the device
# it was saved from (the notebook feeds the model CPU tensors).
# NOTE(review): torch.load unpickles the file -- only load checkpoints from a
# trusted source.
model = Net()
model.load_state_dict(torch.load('../mnist_cnn.pt', map_location='cpu'))
model.eval();  # switch to evaluation mode; the ';' suppresses the cell's repr output

In [3]:
# Load the sample digit and turn it into a normalized (1, 1, 28, 28) batch.
three = Image.open("../3.png")
preprocess = transforms.Compose([
    # Resize to an explicit (height, width): a bare int only rescales the
    # shorter edge, so a non-square image would break the reshape below.
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    # Mean / std used to normalize the single-channel input.
    transforms.Normalize((0.1307,), (0.3081,)),
])
# Keep only the first channel and add batch and channel dimensions.
three_tensor = preprocess(three)[0].reshape(1, 1, 28, 28)

In [4]:
# Sanity check: run the unmodified image through the model and display the
# per-class output before attacking it.
model(three_tensor)

tensor([[3.8002e-37, 1.9888e-26, 1.2790e-24, 1.0000e+00, 3.6991e-38, 1.3000e-19,
         5.5068e-37, 2.9996e-23, 2.1760e-25, 7.5328e-23]],
       grad_fn=<SoftmaxBackward0>)

# Start of JSMA-M

In [5]:
def saliency_map(J, t, space, size, width):
    """Build per-pixel (alpha, beta) saliency terms for target class ``t``.

    Args:
        J: Jacobian indexed as ``J[class, row, col]``.
        t: class index whose gradient becomes ``alpha``.
        space: iterable of flat pixel indices to evaluate.
        size: length of the returned list (total number of flat pixel slots).
        width: row length used to convert a flat index into (row, col).

    Returns:
        A list of ``size`` tuples. For every pixel in ``space`` the tuple is
        ``(alpha, beta)`` where ``alpha`` is the Jacobian entry of class ``t``
        and ``beta`` is the sum of the entries of all other classes; pixels
        outside ``space`` keep the placeholder ``(0, 0)``.
    """
    other_classes = [c for c in range(J.size(0)) if c != t]
    scores = [(0, 0)] * size
    for pixel in space:
        row, col = divmod(pixel, width)
        target_grad = J[t, row, col].item()
        # Summed in ascending class order, starting from 0.
        others_grad = sum(J[c, row, col].item() for c in other_classes)
        scores[pixel] = (target_grad, others_grad)
    return scores

In [6]:
def clip(orig, val, eps, lower=0, upper=1):
    """Clamp ``val`` to the valid range and to an eps-ball around ``orig``.

    Args:
        orig: original (unperturbed) value.
        val: candidate perturbed value.
        eps: maximum allowed absolute deviation from ``orig``.
        lower: smallest valid value (default 0).
        upper: largest valid value (default 1).

    Returns:
        ``val`` limited to ``[max(lower, orig - eps), min(upper, orig + eps)]``.
        The bounds are parameters so callers working on normalized inputs can
        pass the range that corresponds to valid pixels in that space.
    """
    return min(upper, orig + eps, max(lower, orig - eps, val))

In [7]:
def jsmaM(original_image_tensor, actual_class, predictor, max_dist, perturbation, epsilon, verbose=True):
    """Run the JSMA-M pixel-pair saliency attack against ``predictor``.

    Every iteration computes the Jacobian of the predictor output with respect
    to the current image, scores each unordered pair of pixels still in the
    search space for each output class, and perturbs the best-scoring pair by
    ``perturbation`` (sign derived from the saliency). The new values are
    limited by ``clip``. The loop stops when the predicted class changes, the
    iteration budget is spent, fewer than two pixels remain, or no pair has a
    positive score.

    Args:
        original_image_tensor: input indexed as ``[0, 0, row, col]``; it is
            cloned and never modified.
        actual_class: class index the prediction must move away from.
        predictor: callable mapping the image tensor to per-class scores;
            also passed to ``torch.autograd.functional.jacobian``.
        max_dist: distortion budget; the iteration limit is
            ``floor(pixels * max_dist / 200)``, i.e. ``max_dist`` read as a
            percentage of pixels with two pixels changed per iteration.
        perturbation: magnitude added to or subtracted from each chosen pixel.
        epsilon: maximum deviation from the original pixel value, as enforced
            by ``clip``.
        verbose: when True (default) print the chosen pixel values before and
            after each step.

    Returns:
        The perturbed clone of ``original_image_tensor``.

    NOTE(review): the saturation test below uses the fixed bounds 0 and 1,
    matching the default range of ``clip``; inputs in this notebook are
    normalized, so confirm that range is the intended one.
    """
    img_tensor = original_image_tensor.clone()
    width = img_tensor.size(3)
    img_size = img_tensor.size(2) * width
    search_space = list(range(img_size))
    max_iter = math.floor((img_size * max_dist) / 200)
    eta = [0] * img_size  # last perturbation applied to each pixel
    iteration = 0
    prediction = predictor(img_tensor)

    while (prediction.argmax().item() == actual_class
           and iteration < max_iter
           and len(search_space) >= 2):
        # Jacobian reduced to [class, row, col] for the single input image.
        J = torch.autograd.functional.jacobian(predictor, img_tensor)[0, :, 0, 0, :, :]
        num_classes = J.size(0)
        S = [saliency_map(J, target, search_space, img_size, width) for target in range(num_classes)]

        best_score = 0
        best_pair = None
        modifier = 0
        for t in range(num_classes):
            target_saliency = S[t]
            direction = -1 if t == actual_class else 1
            # The score is symmetric in the two pixels, so visiting each
            # unordered pair once selects the same pair as the full ordered
            # scan at half the cost.
            for idx, pixel1 in enumerate(search_space):
                alpha1, beta1 = target_saliency[pixel1]
                for pixel2 in search_space[idx + 1:]:
                    alpha = alpha1 + target_saliency[pixel2][0]
                    beta = beta1 + target_saliency[pixel2][1]
                    score = -alpha * beta
                    if score > best_score:
                        best_score = score
                        best_pair = (pixel1, pixel2)
                        modifier = direction * math.copysign(1, alpha) * perturbation

        # No pair with a positive score: nothing useful left to perturb.
        if best_pair is None:
            break

        chosen_pixel_1, chosen_pixel_2 = best_pair
        row1, col1 = divmod(chosen_pixel_1, width)
        row2, col2 = divmod(chosen_pixel_2, width)

        if verbose:
            print(f'Before: {img_tensor[0, 0, row1, col1]}, {img_tensor[0, 0, row2, col2]}')

        for row, col in ((row1, col1), (row2, col2)):
            img_tensor[0, 0, row, col] = clip(
                original_image_tensor[0, 0, row, col].item(),
                img_tensor[0, 0, row, col].item() + modifier,
                epsilon
            )

        # Drop a pixel from the search space once it saturates or once the
        # perturbation direction flips relative to its previous update.
        for pixel, row, col in ((chosen_pixel_1, row1, col1), (chosen_pixel_2, row2, col2)):
            val = img_tensor[0, 0, row, col].item()
            if val <= 0 or val >= 1 or eta[pixel] == -modifier:
                search_space.remove(pixel)
            eta[pixel] = modifier

        prediction = predictor(img_tensor)
        iteration += 1
        if verbose:
            print(f'After: {img_tensor[0, 0, row1, col1]}, {img_tensor[0, 0, row2, col2]}')
    return img_tensor

In [8]:
# Attack the sample "3": 20% distortion budget, unit perturbation step and
# unit epsilon.
attacked = jsmaM(
    three_tensor,
    actual_class=3,
    predictor=model,
    max_dist=20,
    perturbation=1,
    epsilon=1,
)

Before: 2.770573854446411, 2.770573854446411
After: 1.0, 1.0
Before: 1.701402187347412, 2.68147611618042
After: 1.0, 1.0
Before: 2.185075044631958, 2.732388973236084
After: 1.0, 1.0
Before: 1.981423258781433, 1.8668692111968994
After: 1.0, 1.0
Before: 0.04673170670866966, 0.6576869487762451
After: 0.04673170670866966, 0.6576869487762451
Before: 0.04673170670866966, 0.6576869487762451
After: 0.04673170670866966, 0.6576869487762451
Before: 0.04673170670866966, 0.6576869487762451
After: 0.04673170670866966, 0.6576869487762451
Before: 0.04673170670866966, 0.6576869487762451
After: 0.04673170670866966, 0.6576869487762451
Before: 0.04673170670866966, 0.6576869487762451
After: 0.04673170670866966, 0.6576869487762451
Before: 0.04673170670866966, 0.6576869487762451
After: 0.04673170670866966, 0.6576869487762451
Before: 0.04673170670866966, 0.6576869487762451
After: 0.04673170670866966, 0.6576869487762451
Before: 0.04673170670866966, 0.6576869487762451
After: 0.04673170670866966, 0.6576869487762

KeyboardInterrupt: 

In [None]:
# Display the dimensions of the adversarial tensor.
attacked.shape