# Maximal Jacobian-based Saliency Map Attack

In this notebook, we will implement M-JSMA, the Maximal Jacobian-based Saliency Map Attack.

Before reading through this notebook, please read the regular JSMA notebook in `code/JSMA/attack.ipynb` first, as some of the concepts discussed there are assumed to be known in this notebook.

As before, we will start by setting up our standard model.

In [26]:
# Importing the required packages
import torch
from torchvision import transforms
from torchvision.utils import save_image
from mnist_model_generator import Net
from PIL import Image
import math

In [27]:
# Loading our pre-trained MNIST model
model = Net()
model.load_state_dict(torch.load('../../data/models/mnist_cnn.pt', map_location=torch.device('cpu')))
model.eval();

In [28]:
# Prepare the 3 images we will be testing on.
# We will use different images, as this attack method is untargeted.
# Notice this attack is much more effective without prior normalization.
preprocess = transforms.Compose([
   transforms.Resize(28),
   transforms.ToTensor(),
])

three = Image.open("../../data/pictures/3.png")
three_tensor = preprocess(three)[0].reshape(1,1,28,28)

four = Image.open("../../data/pictures/4.png")
four_tensor = preprocess(four)[0].reshape(1,1,28,28)

eight = Image.open("../../data/pictures/8.png")
eight_tensor = preprocess(eight)[0].reshape(1,1,28,28)

The following images will be used:

<img src=../../data/pictures/3.png width=140>
<img src=../../data/pictures/4.png width=140>
<img src=../../data/pictures/8.png width=140>

Finally, let us test our model to make sure it correctly predicts these images.

In [29]:
print(f'The model predicted: {model(three_tensor).argmax().item()} with {model(three_tensor).max().item() * 100}% certainty. Should be 3')
print(f'The model predicted: {model(four_tensor).argmax().item()} with {model(four_tensor).max().item() * 100}% certainty. Should be 4')
print(f'The model predicted: {model(eight_tensor).argmax().item()} with {model(eight_tensor).max().item() * 100}% certainty. Should be 8')

The model predicted: 3 with 99.99912977218628% certainty. Should be 3
The model predicted: 4 with 91.91465973854065% certainty. Should be 4
The model predicted: 8 with 99.97344613075256% certainty. Should be 8
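
The cell above runs the model twice per image, once for the class and once for the certainty. As a small sketch, both values can also be read from a single forward pass. The helper name `predict_with_confidence` and the stand-in `toy_model` are hypothetical and only serve to keep the example self-contained. Like the print statements above, it assumes that the predictor outputs class probabilities.

```python
import torch


def predict_with_confidence(predictor, image_tensor):
    """Return (predicted class, certainty in percent) from one forward pass."""
    # No gradients are needed when we only want to inspect a prediction.
    with torch.no_grad():
        probabilities = predictor(image_tensor)
    confidence, predicted_class = probabilities.max(dim=1)
    return predicted_class.item(), confidence.item() * 100


# Hypothetical stand-in classifier, NOT the notebook's Net.
# It only needs to map a (1, 1, 28, 28) image to 10 probabilities.
torch.manual_seed(0)
toy_model = torch.nn.Sequential(
    torch.nn.Flatten(),
    torch.nn.Linear(28 * 28, 10),
    torch.nn.Softmax(dim=1),
)
toy_model.eval()

toy_image = torch.rand(1, 1, 28, 28)
label, certainty = predict_with_confidence(toy_model, toy_image)
print(f'The model predicted: {label} with {certainty}% certainty.')
```

With the real model, the call would presumably look like `predict_with_confidence(model, three_tensor)`.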


## Attacking the model

Let us start by discussing the changes that M-JSMA makes to the original JSMA.
- First of all, the attack is now non-targeted: it stops as soon as the predicted class differs from the actual class, rather than checking whether the prediction equals a chosen target class.
- The attack combines the positive and negative saliency maps. It uses the positive map when the chosen pixels boost a class other than the actual class, and the negative map when the chosen pixels boost the actual class.
- The modification value is chosen to be lower, so pixels can be adjusted more than once. This leads to adversarial images that appear 'less attacked'.
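
To make the second point concrete, here is a minimal sketch of how the direction of a pixel change can be derived. The helper name `choose_modifier` is hypothetical; the expression itself mirrors the one used in the attack code of this notebook, where `alpha` is the gradient of class `t` with respect to the chosen pixel pair.

```python
import math


def choose_modifier(alpha, t, actual_class, perturbation):
    # Following the gradient sign of alpha increases the score of class t.
    # If t is the actual class, we flip the direction to decrease its score.
    direction = -1 if t == actual_class else 1
    return direction * math.copysign(1, alpha) * perturbation


# The pixels boost another class (9): move along the gradient.
boost_other = choose_modifier(0.3, t=9, actual_class=3, perturbation=0.1)
# The pixels boost the actual class (3): move against the gradient.
suppress_actual = choose_modifier(0.3, t=3, actual_class=3, perturbation=0.1)
# A negative gradient for the actual class: increasing the pixels lowers its score.
suppress_actual_negative = choose_modifier(-0.3, t=3, actual_class=3, perturbation=0.1)

print(boost_other, suppress_actual, suppress_actual_negative)
```

In all three cases the change either raises the score of a wrong class or lowers the score of the actual class, which is what a non-targeted attack needs.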

Although these changes may seem small, they have a large effect on the attack as a whole.

Because the attack creates images that appear less attacked, and does so without guidance from a human, it becomes applicable in many more areas than previously possible.

The code for M-JSMA can be seen below.

In [30]:
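def saliency_map(J, t, space, size, width):
    # For every pixel p in the search space, store the pair (alpha, beta):
    #   alpha: the gradient of class t with respect to pixel p
    #   beta:  the summed gradient of all other classes with respect to pixel p
    S = [(0, 0)] * size
    for p in space:
        alpha = J[t, p // width, p % width].item()
        beta = 0
        for i in range(J.size(0)):
            if i != t:
                beta += J[i, p // width, p % width].item()
        S[p] = (alpha, beta)
    return S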
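
As an aside, the same two quantities can be computed without Python loops. The sketch below is an assumption-laden alternative, not the implementation used in this notebook: the name `saliency_map_vectorized` is hypothetical, it ignores the search space, and it returns two flat tensors instead of a list of tuples. It relies on the fact that the sum over all other classes equals the sum over all classes minus the entry for class `t`.

```python
import torch


def saliency_map_vectorized(J, t):
    # J is expected to have shape (num_classes, height, width).
    # alpha: gradient of class t for every pixel, flattened row by row.
    alpha = J[t].flatten()
    # beta: summed gradient of all classes except t for every pixel.
    beta = (J.sum(dim=0) - J[t]).flatten()
    return alpha, beta


# Random stand-in for a Jacobian of a 10-class model on 28x28 images.
torch.manual_seed(0)
J_demo = torch.randn(10, 28, 28)
alpha_demo, beta_demo = saliency_map_vectorized(J_demo, 3)
print(alpha_demo.shape, beta_demo.shape)
```

Pixel `p` in the flattened tensors corresponds to `J[:, p // width, p % width]`, which is the same indexing that `saliency_map` uses.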

In [31]:
def clip(orig, val, eps):
    # Keep val within eps of the original pixel value and within the valid range [0, 1].
    return min([1, orig + eps, max([0, orig - eps, val])])
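
To see what `clip` does, the following sketch repeats the function so that it runs on its own and then tries a few hand-picked values. The example values and variable names are our own and do not come from the attack itself.

```python
def clip(orig, val, eps):
    # Same definition as in the cell above, repeated to keep this example self-contained.
    return min([1, orig + eps, max([0, orig - eps, val])])


# The new value overshoots: it is pulled back to orig + eps (approximately 0.8).
capped_by_epsilon = clip(0.5, 1.5, 0.3)
# The new value is negative: it is pulled back to the lower pixel bound 0.
capped_at_zero = clip(0.1, -0.5, 0.3)
# The new value is already within eps of the original and inside [0, 1]: unchanged.
unchanged = clip(0.5, 0.55, 0.3)

print(capped_by_epsilon, capped_at_zero, unchanged)
```

This is why a perturbation of 1 combined with an epsilon of 0.5, as used later in this notebook, can never move a pixel more than 0.5 away from its original value.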

In [32]:
def jsmaM(original_image_tensor, actual_class, predictor, max_dist, perturbation, epsilon):
    img_tensor = original_image_tensor.clone()

    img_tensor = img_tensor.reshape(1,1,28,28)

    img_size = img_tensor.size(2) * img_tensor.size(3)
    width = img_tensor.size(3)
    search_space = list(range(img_size))
    i = 0
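    # Every iteration changes two pixels, so changing at most max_dist percent
    # of the pixels allows (img_size * max_dist) / (2 * 100) iterations.
    max_iter = math.floor((img_size * max_dist) / 200)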
    chosen_pixel_1 = -1
    chosen_pixel_2 = -1
    modifier = 0
    prediction = predictor(img_tensor)

    # Eta denotes the most recent change to each pixel, to make sure no cycles are formed.
    eta = [0] * img_size

    while prediction.argmax().item() == actual_class and i < max_iter and len(search_space) >= 2:
        best_score = 0

        # Calculate the Jacobian.
        J = torch.autograd.functional.jacobian(predictor, img_tensor)[0, :, 0, 0, :, :]

        # Calculate the saliency map for every class.
        S = [saliency_map(J, target, search_space, img_size, width) for target in range(10)]

        # For all possible classes
        for t in range(10):
            # Find the optimal pixel pair for class t.
            for pixel1 in search_space:
                for pixel2 in search_space:
                    if pixel1 == pixel2:
                        continue

                    alpha = S[t][pixel1][0] + S[t][pixel2][0]
                    beta = S[t][pixel1][1] + S[t][pixel2][1]

                    if -alpha * beta > best_score:
                        chosen_pixel_1 = pixel1
                        chosen_pixel_2 = pixel2
                        best_score = -alpha * beta
                        # If the most influential pixels boost the chances of the actual class, we wish to use the negative JSMA.
                        # Otherwise, we use the positive JSMA.
                        modifier = (-1 if t == actual_class else 1) * math.copysign(1, alpha) * perturbation

        # If no improvements can be made, quit.
        if best_score == 0:
            break
        
        # Make sure the change remains within some epsilon of the original value.
        # This helps to keep the image as close to the original as possible.
        new1 = clip(
            original_image_tensor[0, 0, chosen_pixel_1 // width, chosen_pixel_1 % width].item(),
            img_tensor[0, 0, chosen_pixel_1 // width, chosen_pixel_1 % width].item() + modifier,
            epsilon
        )
        diff1 = abs(new1 - img_tensor[0, 0, chosen_pixel_1 // width, chosen_pixel_1 % width])
        img_tensor[0, 0, chosen_pixel_1 // width, chosen_pixel_1 % width] = new1
        
        new2 = clip(
            original_image_tensor[0, 0, chosen_pixel_2 // width, chosen_pixel_2 % width].item(),
            img_tensor[0, 0, chosen_pixel_2 // width, chosen_pixel_2 % width].item() + modifier,
            epsilon
        )
        diff2 = abs(new2 - img_tensor[0, 0, chosen_pixel_2 // width, chosen_pixel_2 % width])
        img_tensor[0, 0, chosen_pixel_2 // width, chosen_pixel_2 % width] = new2

        # Remove the pixels from the search domain under some conditions:
        #   - The new value is 0 or 1
        #   - The new value is the same as the old value
        #   - A loop was found using eta.
        val = img_tensor[0, 0, chosen_pixel_1 // width, chosen_pixel_1 % width]
        if val <= 0 or val >= 1 or diff1 < 1e-06 or eta[chosen_pixel_1] == -1 * modifier:
            search_space.remove(chosen_pixel_1)
    
        val = img_tensor[0, 0, chosen_pixel_2 // width, chosen_pixel_2 % width]
        if val <= 0 or val >= 1 or diff2 < 1e-06 or eta[chosen_pixel_2] == -1 * modifier:
            search_space.remove(chosen_pixel_2)
        
        eta[chosen_pixel_1] = modifier
        eta[chosen_pixel_2] = modifier
        prediction = predictor(img_tensor)

        # Print some statements to show the attack's progress.
        topPredictions = torch.topk(prediction, 2).indices[0]
        closestIndex = topPredictions[1].item() if prediction.argmax() == actual_class else topPredictions[0].item()
        print(f'Actual class: {actual_class} at {prediction[0, actual_class] * 100}%')
        print(f'Closest attack: {closestIndex} at {prediction[0, closestIndex] * 100}%')

        i += 1
    return img_tensor
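
The nested pixel loops in `jsmaM` evaluate every pixel pair in pure Python, which is the main reason the attack is slow. As a hedged sketch of one possible speed-up, the pair search for a single class can be written with broadcasting. The function name `best_pair` is hypothetical, and the sketch assumes that `alpha` and `beta` are 1-D tensors holding the saliency values of the pixels that are still in the search space.

```python
import torch


def best_pair(alpha, beta):
    # Entry (i, j) of each matrix holds the summed value for the pixel pair (i, j).
    pair_alpha = alpha[:, None] + alpha[None, :]
    pair_beta = beta[:, None] + beta[None, :]
    scores = -pair_alpha * pair_beta
    # A pixel may not be paired with itself.
    scores.fill_diagonal_(float('-inf'))
    flat_index = torch.argmax(scores).item()
    n = alpha.numel()
    return flat_index // n, flat_index % n, scores.flatten()[flat_index].item()


# Random stand-in saliency values for 50 candidate pixels.
torch.manual_seed(0)
alpha_demo = torch.randn(50)
beta_demo = torch.randn(50)
pixel1, pixel2, score = best_pair(alpha_demo, beta_demo)
print(pixel1, pixel2, score)
```

As in `jsmaM`, the returned pair should only be used when its score is greater than 0. For a 28x28 image the score matrix has 784 x 784 entries per class, which comfortably fits in memory.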

In [33]:
attacked_3 = jsmaM(three_tensor, 3, model, 20, 1, 0.5)
save_image(attacked_3[0,0], '../../results/JSMA-M/attacked-3.png')
print('')
attacked_4 = jsmaM(four_tensor, 4, model, 20, 1, 0.5)
save_image(attacked_4[0,0], '../../results/JSMA-M/attacked-4.png')
print('')
attacked_8 = jsmaM(eight_tensor, 8, model, 20, 1, 0.5)
save_image(attacked_8[0,0], '../../results/JSMA-M/attacked-8.png')

Actual class: 8 at 99.92290496826172%
Closest attack: 2 at 0.03664131462574005%
Actual class: 8 at 99.86494445800781%
Closest attack: 2 at 0.08285068720579147%
Actual class: 8 at 99.81197357177734%
Closest attack: 2 at 0.1156361997127533%
Actual class: 8 at 99.40975952148438%
Closest attack: 2 at 0.5095492601394653%
Actual class: 8 at 99.40975952148438%
Closest attack: 2 at 0.5095492601394653%
Actual class: 8 at 99.1590576171875%
Closest attack: 2 at 0.7626803517341614%
Actual class: 8 at 99.04500579833984%
Closest attack: 2 at 0.8696763515472412%
Actual class: 8 at 97.74250030517578%
Closest attack: 2 at 2.0979902744293213%
Actual class: 8 at 96.52853393554688%
Closest attack: 2 at 3.3512566089630127%
Actual class: 8 at 95.88655853271484%
Closest attack: 2 at 3.9759793281555176%
Actual class: 8 at 94.82962799072266%
Closest attack: 2 at 5.0015034675598145%
Actual class: 8 at 92.57537078857422%
Closest attack: 2 at 7.203100681304932%
Actual class: 8 at 88.50079345703125%
Closest attack

In [34]:
print(f'The model predicted: {model(attacked_3).argmax().item()} with {model(attacked_3).max().item() * 100}% certainty. Should be 3')
print(f'The model predicted: {model(attacked_4).argmax().item()} with {model(attacked_4).max().item() * 100}% certainty. Should be 4')
print(f'The model predicted: {model(attacked_8).argmax().item()} with {model(attacked_8).max().item() * 100}% certainty. Should be 8')

The model predicted: 9 with 53.7250816822052% certainty. Should be 3
The model predicted: 9 with 54.69772219657898% certainty. Should be 4
The model predicted: 2 with 49.734100699424744% certainty. Should be 8


## Analyzing the results

We obtained the following three results, as can be seen in the print statements above:

<figure>
    <img src=../../results/JSMA-M/attacked-3.png width=140>
    <figcaption>Classified as a 9</figcaption>
</figure>
<figure>
    <img src=../../results/JSMA-M/attacked-4.png width=140>
    <figcaption>Classified as a 9</figcaption>
</figure>
<figure>
    <img src=../../results/JSMA-M/attacked-8.png width=140>
    <figcaption>Classified as a 2</figcaption>
</figure>

In the first image, we see the same issue we had with JSMA's generated adversarial example.
A human could very well mistake this image for a 9 as well, though less so due to the area that 'closes' the 9 being less bright than the rest of the 3.

However, we notice that the other two images would without a doubt still be classified as the correct number by human observers.
Yet the model failed to properly classify them.
This means our attack was successful!

These results are certainly an improvement over those generated by JSMA.
Additionally, this method is a non-targeted attack, which means less human interaction is required. This is a great bonus for an adversarial example generator.

Let us finish by testing M-JSMA against the robust model, as mentioned before in the JSMA notebook.

## The effectiveness of adversarial training

We have seen that this attack tricked our model into predicting the wrong class for all three of our test images.

Adversarial training is one of the methods that can be used to defend a model against incoming attacks.

This time, our attack is indeed non-targeted, so it could be used to train a model.
However, another issue comes up.

Namely, the execution time of this attack is much higher than it was for simpler attacks.
The training set consists of 60000 images. With an average time of about 80 seconds per image, generating the adversarial training set alone would take about 55.6 days.
We simply do not have the time to create this model, even though it would likely be a very robust model.
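
The estimate above can be checked with a few lines of arithmetic. The variable names are our own; the two input numbers are the ones quoted in the text.

```python
num_images = 60000        # size of the MNIST training set
seconds_per_image = 80    # approximate attack time per image, as measured above

total_seconds = num_images * seconds_per_image
total_days = total_seconds / (60 * 60 * 24)

print(f'{total_seconds} seconds, or {total_days:.2f} days')
# prints: 4800000 seconds, or 55.56 days
```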

Therefore, we once again choose to borrow the robust model that was generated in the FGSM & similar attacks notebook.
While this model was of course not trained specifically against our attack, we will explore whether it performs better than our original model anyway.

In [35]:
# Loading the robust MNIST model.
def_model = Net()
def_model.load_state_dict(torch.load('../../data/models/mnist_robust.pt', map_location=torch.device('cpu')))
def_model.eval();

Once again, let us make sure this model functions properly for regular images:

In [36]:
print(f'The model predicted: {def_model(three_tensor).argmax().item()} with {def_model(three_tensor).max().item() * 100}% certainty. Should be 3')
print(f'The model predicted: {def_model(four_tensor).argmax().item()} with {def_model(four_tensor).max().item() * 100}% certainty. Should be 4')
print(f'The model predicted: {def_model(eight_tensor).argmax().item()} with {def_model(eight_tensor).max().item() * 100}% certainty. Should be 8')

The model predicted: 3 with 100.0% certainty. Should be 3
The model predicted: 4 with 98.1884241104126% certainty. Should be 4
The model predicted: 8 with 100.0% certainty. Should be 8


Good news: on these clean images, the robust model is even more confident than our original model!

Next, let us run the exact same attack as before on this model and analyze the results.

In [37]:
defended_3 = jsmaM(three_tensor, 3, def_model, 20, 1, 0.5)
save_image(defended_3[0,0], '../../results/JSMA-M/defended-3.png')
print('')
defended_4 = jsmaM(four_tensor, 4, def_model, 20, 1, 0.5)
save_image(defended_4[0,0], '../../results/JSMA-M/defended-4.png')
print('')
defended_8 = jsmaM(eight_tensor, 8, def_model, 20, 1, 0.5)
save_image(defended_8[0,0], '../../results/JSMA-M/defended-8.png')

Actual class: 3 at 100.0%
Closest attack: 2 at 1.16762207541532e-11%
Actual class: 3 at 100.0%
Closest attack: 2 at 4.471123471461169e-11%
Actual class: 3 at 100.0%
Closest attack: 2 at 2.8891170011924316e-10%
Actual class: 3 at 100.0%
Closest attack: 2 at 1.3701034751179009e-09%
Actual class: 3 at 100.0%
Closest attack: 2 at 1.4401884129711107e-07%
Actual class: 3 at 100.0%
Closest attack: 2 at 1.4401884129711107e-07%
Actual class: 3 at 100.0%
Closest attack: 2 at 1.4657595102107734e-06%
Actual class: 3 at 99.99998474121094%
Closest attack: 2 at 6.159306394692976e-06%
Actual class: 3 at 99.99996185302734%
Closest attack: 2 at 3.345565710333176e-05%
Actual class: 3 at 99.99990844726562%
Closest attack: 2 at 9.704021067591384e-05%
Actual class: 3 at 99.99979400634766%
Closest attack: 2 at 0.00019738267292268574%
Actual class: 3 at 99.99881744384766%
Closest attack: 2 at 0.0011575249955058098%
Actual class: 3 at 99.99321746826172%
Closest attack: 2 at 0.006664319429546595%
Actual class: 

In [38]:
print(f'The model predicted: {def_model(defended_3).argmax().item()} with {def_model(defended_3).max().item() * 100}% certainty. Should be 3')
print(f'The model predicted: {def_model(defended_4).argmax().item()} with {def_model(defended_4).max().item() * 100}% certainty. Should be 4')
print(f'The model predicted: {def_model(defended_8).argmax().item()} with {def_model(defended_8).max().item() * 100}% certainty. Should be 8')

The model predicted: 2 with 68.77274513244629% certainty. Should be 3
The model predicted: 9 with 79.48253154754639% certainty. Should be 4
The model predicted: 2 with 50.52988529205322% certainty. Should be 8


The following figures are obtained:

<figure>
    <img src=../../results/JSMA-M/defended-3.png width=140>
    <figcaption>Classified as a 2</figcaption>
</figure>
<figure>
    <img src=../../results/JSMA-M/defended-4.png width=140>
    <figcaption>Classified as a 9</figcaption>
</figure>
<figure>
    <img src=../../results/JSMA-M/defended-8.png width=140>
    <figcaption>Classified as a 2</figcaption>
</figure>

We notice extremely similar results for the images of the 4 and the 8.
The robust model clearly was not prepared for an attack of this scale.

However, the adversarial example generated from the 3 looks very different from the one generated against the original model, yet it is still classified as a digit other than 3.
This suggests that the robust model resisted the perturbation found in the first attack, yet M-JSMA found another way to attack the model.

We must conclude that this model was not robust enough to stop M-JSMA, which is to be expected, as it was not trained against this attack.
Perhaps a model trained on M-JSMA examples could have defended against this attack.

Overall, adversarial training did not have much of an effect on M-JSMA.

Please see the report for a final discussion and a comparison to other attacks.