In [121]:
import torch
from torchvision import transforms
from torchvision.utils import save_image
from mnist_model_generator import Net
from PIL import Image
import math

In [122]:
model = Net()
model.load_state_dict(torch.load('../mnist_cnn.pt'))
model.eval();

In [123]:
three = Image.open("../8.png")
preprocess = transforms.Compose([
   transforms.Resize(28),
   transforms.ToTensor(),
])
three_tensor = preprocess(three)[0].reshape(1,1,28,28)

In [124]:
model(three_tensor)

tensor([[6.7077e-05, 1.2055e-06, 8.3040e-05, 3.1091e-06, 1.1007e-06, 5.2987e-05,
         1.5825e-05, 7.9039e-08, 9.9973e-01, 4.1002e-05]],
       grad_fn=<SoftmaxBackward0>)

# Start of JSMA-M

In [125]:
def saliency_map(J, t, space, size, width):
    S = [(0, 0)] * size
    for p in space:
        alpha = J[t, p // width, p % width].item()
        beta = 0
        for i in range(J.size(0)):
            if not i == t:
                beta += J[i, p // width, p % width].item()
        S[p] = (alpha, beta)
    return S

In [126]:
def clip(orig, val, eps):
    return min([1, orig + eps, max([0, orig - eps, val])])

In [127]:
def jsmaM(original_image_tensor, actual_class, predictor, max_dist, perturbation, epsilon):
    img_tensor = original_image_tensor.clone()
    img_size = img_tensor.size(2) * img_tensor.size(3)
    width = img_tensor.size(3)
    search_space = list(range(img_size))
    i = 0
    max_iter = math.floor((img_size * max_dist) / (200))
    chosen_pixel_1 = -1
    chosen_pixel_2 = -1
    modifier = 0
    prediction = predictor(img_tensor)
    eta = [0] * img_size

    while prediction.argmax().item() == actual_class and i < max_iter and len(search_space) >= 2:
        max = 0
        J = torch.autograd.functional.jacobian(predictor, img_tensor)[0, :, 0, 0, :, :]
        S = [saliency_map(J, target, search_space, img_size, width) for target in range(10)]
        
        for t in range(10):
            for pixel1 in search_space:
                for pixel2 in search_space:
                    if pixel1 == pixel2:
                        continue
                    
                    alpha = S[t][pixel1][0] + S[t][pixel2][0]
                    beta = S[t][pixel1][1] + S[t][pixel2][1]

                    if -alpha * beta > max:
                        chosen_pixel_1 = pixel1
                        chosen_pixel_2 = pixel2
                        max = -alpha * beta
                        modifier = (-1 if t == actual_class else 1) * math.copysign(1, alpha) * perturbation

        if max == 0:
            break
        
        new1 = clip(
            original_image_tensor[0, 0, chosen_pixel_1 // width, chosen_pixel_1 % width].item(),
            img_tensor[0, 0, chosen_pixel_1 // width, chosen_pixel_1 % width].item() + modifier,
            epsilon
        )
        diff1 = abs(new1 - img_tensor[0, 0, chosen_pixel_1 // width, chosen_pixel_1 % width])
        img_tensor[0, 0, chosen_pixel_1 // width, chosen_pixel_1 % width] = new1
        
        new2 = clip(
            original_image_tensor[0, 0, chosen_pixel_2 // width, chosen_pixel_2 % width].item(),
            img_tensor[0, 0, chosen_pixel_2 // width, chosen_pixel_2 % width].item() + modifier,
            epsilon
        )
        diff2 = abs(new2 - img_tensor[0, 0, chosen_pixel_2 // width, chosen_pixel_2 % width])
        img_tensor[0, 0, chosen_pixel_2 // width, chosen_pixel_2 % width] = new2

        val = img_tensor[0, 0, chosen_pixel_1 // width, chosen_pixel_1 % width]
        if val <= 0 or val >= 1 or diff1 < 1e-06 or eta[chosen_pixel_1] == -1 * modifier:
            search_space.remove(chosen_pixel_1)
    
        val = img_tensor[0, 0, chosen_pixel_2 // width, chosen_pixel_2 % width]
        if val == 0 or val == 1 or diff2 < 1e-06 or eta[chosen_pixel_2] == modifier:
            search_space.remove(chosen_pixel_2)
        
        eta[chosen_pixel_1] = modifier
        eta[chosen_pixel_2] = modifier
        prediction = predictor(img_tensor)

        topPredictions = torch.topk(prediction, 2).indices[0]
        closestIndex = topPredictions[1].item() if prediction.argmax() == actual_class else topPredictions[0].item()
        print(f'Actual class: {prediction[0, actual_class]}%')
        print(f'Closest attack: {closestIndex} at {prediction[0, closestIndex]}%')

        i += 1
    return img_tensor

In [131]:
attacked = jsmaM(three_tensor, 8, model, 20, 1, 0.3)

Actual class: 0.9994644522666931%
Closest attack: 2 at 0.0002299167972523719%
Actual class: 0.9994644522666931%
Closest attack: 2 at 0.0002299167972523719%
Actual class: 0.9990882873535156%
Closest attack: 2 at 0.00044256687397137284%
Actual class: 0.998862624168396%
Closest attack: 2 at 0.0005515737575478852%
Actual class: 0.9985792636871338%
Closest attack: 2 at 0.0006578587926924229%
Actual class: 0.998210072517395%
Closest attack: 2 at 0.0007347500650212169%
Actual class: 0.9975599050521851%
Closest attack: 2 at 0.0008847895660437644%
Actual class: 0.9967947602272034%
Closest attack: 2 at 0.000931522692553699%
Actual class: 0.9965640902519226%
Closest attack: 9 at 0.000983938341960311%
Actual class: 0.9944281578063965%
Closest attack: 9 at 0.002133168512955308%
Actual class: 0.9942368268966675%
Closest attack: 9 at 0.0022463698405772448%
Actual class: 0.9909203052520752%
Closest attack: 9 at 0.00413804454728961%
Actual class: 0.9894407391548157%
Closest attack: 9 at 0.0052473028190

In [129]:
save_image(attacked[0][0], f'attack.png')

In [130]:
model(attacked).argmax()

tensor(2)