CarliniLInfMethod differs significantly from the original, resulting in lower attack success rate #1374

Closed · Fixed by #1380
kaitokishi opened this issue Oct 20, 2021 · 6 comments
Labels: improvement (Improve implementation)

@kaitokishi
Contributor

Describe the bug
The loss used in CarliniLInfMethod differs from the paper. As a result, CarliniLInfMethod fails to generate adversarial images at small eps values for which Carlini's original implementation succeeds.

To reproduce

  1. Make a simple neural network to classify MNIST.
  2. Generate adversarial images with the original implementation. It finds perturbations with small L-infinity norms, around 0.1 or less.
  3. Generate adversarial images with CarliniLInfMethod using eps=0.1, which is sufficient for the original implementation, and observe a lower attack success rate. On 10 random MNIST samples I got a 100% attack success rate (10 out of 10) with the original implementation but only 60% (6 out of 10, on the same samples) with CarliniLInfMethod.

Expected behavior
CarliniLInfMethod should achieve an attack success rate comparable to Carlini's original implementation, even for eps values as small as those the original implementation reaches.

Cause of the bug
The loss of Carlini & Wagner is `c * f(x + delta) + sum_i max(0, |delta_i| - tau)` (see p.10 of the paper), while the loss implemented in CarliniLInfMethod is only `f(x + delta)` with delta clipped to eps, so they are inconsistent: the term that minimises the perturbation (and the shrinking threshold tau) is missing.
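
To make the difference concrete, below is an illustrative NumPy sketch of the two objectives (the function names are mine, not taken from either codebase):

import numpy as np

def cw_f(logits, y_onehot):
    # Untargeted C&W f: positive as long as the correct class still wins.
    z_correct = np.sum(logits * y_onehot)
    z_other = np.max(logits - 1e9 * y_onehot)  # max logit over the other classes
    return max(z_correct - z_other, 0.0)

def paper_objective(logits, y_onehot, delta, c, tau):
    # Paper, p.10: c * f(x + delta) + sum_i max(0, |delta_i| - tau);
    # tau is shrunk in an outer loop, which drives the perturbation down.
    return c * cw_f(logits, y_onehot) + np.sum(np.maximum(0.0, np.abs(delta) - tau))

def art_objective(logits, y_onehot, delta, eps):
    # ART 1.8.0 minimises only the classification term, with delta clipped
    # to [-eps, eps]; nothing rewards perturbations smaller than eps.
    return cw_f(logits, y_onehot)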

System information (please complete the following information):

  • OS: Ubuntu 18.04.5 LTS
  • Python version: 3.6.9
  • ART version or commit number: 1.8.0
  • To test Carlini's original implementation:
    • tensorflow-gpu version: 1.14.0 / Keras version: 2.3.0
  • To test CarliniLInfMethod:
    • TensorFlow version: 2.3.0
@beat-buesser
Collaborator

Hi @kaitokishi Thank you very much for raising this issue! I would be very interested to take a closer look. Would you be able to share the code you used for the comparison? That way I could start from exactly the same experiments.

@kaitokishi
Contributor Author

Thank you for your reply. I refactored my code; the numbers changed slightly, but the main result is unchanged.

My code consists of two files: train_model.py, which trains a model to classify MNIST, and attack.py, which attacks the model with either Carlini's original implementation or ART's CarliniLInfMethod. Two virtual environments are used: venv-original and venv-art.

1. Make a simple neural network to classify MNIST

Use venv-original with requirements-original.txt.

requirements-original.txt
absl-py==0.13.0
astor==0.8.1
cached-property==1.5.2
cycler==0.10.0
dataclasses==0.8
gast==0.5.2
google-pasta==0.2.0
grpcio==1.40.0
h5py==2.10.0
importlib-metadata==4.8.1
Keras==2.3.0
Keras-Applications==1.0.8
Keras-Preprocessing==1.1.2
kiwisolver==1.3.1
Markdown==3.3.4
matplotlib==3.3.4
numpy==1.19.5
Pillow==8.3.2
pkg_resources==0.0.0
protobuf==3.17.3
pyparsing==2.4.7
python-dateutil==2.8.2
PyYAML==5.4.1
scipy==1.5.4
six==1.16.0
tensorboard==1.14.0
tensorflow-estimator==1.14.0
tensorflow-gpu==1.14.0
termcolor==1.1.0
tqdm==4.62.2
typing-extensions==3.10.0.2
Werkzeug==2.0.1
wrapt==1.12.1
zipp==3.5.0

Run train_model.py to save a model to classify MNIST.

train_model.py
import numpy as np

import keras
from keras import layers
from tensorflow.keras.datasets import mnist


# Load a dataset
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train[:,:,:,np.newaxis] / 255. - 0.5 # x_train.shape is (60000, 28, 28, 1)
x_test = x_test[:,:,:,np.newaxis] / 255. - 0.5 # x_test.shape is (10000, 28, 28, 1)


# Make a model
model = keras.Sequential()
model.add(layers.Flatten(input_shape=(28, 28, 1))) # declare the input shape on the first layer
model.add(layers.Dense(units=512, activation='relu'))
model.add(layers.Dense(units=512, activation='relu'))
model.add(layers.Dense(units=10)) # logits; softmax is applied inside the loss

loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(loss=loss, optimizer=keras.optimizers.SGD(), metrics=['accuracy'])
model.fit(x_train, y_train, batch_size=128, epochs=3)

model.save('mnist_model.h5')

2. Generate adversarial images by the original implementation

Download li_attack.py from Carlini's GitHub repository. Then, run the command below (still in venv-original)

python attack.py original

to attack the model by Carlini's original implementation.

attack.py
import sys

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.datasets import mnist


def get_correct_predicted_indices(model):
    indices = []
    for i in range(len(x_test)):
        if y_test[i] == np.argmax(model.predict(np.array([x_test[i]]))):
            indices.append(i)
            if len(indices) == 10:
                break
    print(f'Selected indices: {indices}')
    return indices


# Load a dataset
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train[:,:,:,np.newaxis] / 255. - 0.5 # x_train.shape is (60000, 28, 28, 1)
x_test = x_test[:,:,:,np.newaxis] / 255. - 0.5 # x_test.shape is (10000, 28, 28, 1)


if sys.argv[1] == 'original':
    import keras
    from li_attack import CarliniLi
    class Model:
        def __init__(self, restore):
            self.num_channels = 1
            self.image_size = 28
            self.num_labels = 10
            self.model = restore # use the passed-in model rather than the global

        def predict(self, data):
            return self.model(data)

    with tf.Session() as sess:
        init = tf.global_variables_initializer()
        sess.run(init)
        # load model
        model = keras.models.load_model('mnist_model.h5')
        indices = get_correct_predicted_indices(model)
        attack = CarliniLi(sess, Model(model), targeted=False, learning_rate=0.01, max_iterations=1000)
        x_adv = attack.attack(x_test[indices], np.identity(10)[y_test[indices]]) # one-hot encoding is required
        print(f'original, ASR: {np.sum(y_test[indices] != np.argmax(model.predict(x_adv), axis=1)) / 10}')
        print(f'original, LInf norm: {np.max(np.abs(x_adv - x_test[indices]), axis=(1,2,3))}')

        #for i, each_x_adv in zip(indices, x_adv): # Optional: save the images to the `img` directory.
        #    plt.imshow(each_x_adv[:, :, 0], cmap='gray') # drop the channel axis for imshow
        #    plt.savefig(f'img/original-{i}.png')
elif sys.argv[1] == 'art':
    from tensorflow import keras
    # load model
    model = keras.models.load_model('mnist_model.h5')
    indices = get_correct_predicted_indices(model)
    from art.attacks.evasion import CarliniLInfMethod
    from art.estimators.classification import TensorFlowV2Classifier
    attack = CarliniLInfMethod(TensorFlowV2Classifier(model=model,
                                                      nb_classes=10,
                                                      input_shape=(28, 28, 1),
                                                      loss_object=model.loss,
                                                      clip_values=(-0.5, 0.5)),
                               targeted=False, learning_rate=0.01, max_iter=1000, eps=0.16)
    x_adv = attack.generate(x_test[indices], y_test[indices])
    print(f'ART, ASR: {np.sum(y_test[indices] != np.argmax(model.predict(x_adv), axis=1)) / 10}')
    print(f'ART, LInf norm: {np.max(np.abs(x_adv - x_test[indices]), axis=(1,2,3))}')

    #for i, each_x_adv in zip(indices, x_adv): # Optional: save the images to the `img` directory.
    #    plt.imshow(each_x_adv[:, :, 0], cmap='gray') # drop the channel axis for imshow
    #    plt.savefig(f'img/art-{i}.png')

I got the following result.

Selected indices: [0, 1, 2, 3, 4, 5, 6, 7, 9, 10]
...
original, ASR: 1.0
original, LInf norm: [0.13257101 0.06283894 0.07840255 0.15271175 0.04800587 0.09540931
 0.06005499 0.03468984 0.04028573 0.07987416]

This means the attack success rate is 100%, and every perturbation found by Carlini's original implementation is smaller than 0.16.
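
As a quick illustration, eps=0.16 upper-bounds every norm printed above:

import numpy as np

linf_norms = np.array([0.13257101, 0.06283894, 0.07840255, 0.15271175, 0.04800587,
                       0.09540931, 0.06005499, 0.03468984, 0.04028573, 0.07987416])
print(linf_norms.max())                 # 0.15271175, so eps=0.16 covers every sample
print(bool((linf_norms < 0.16).all()))  # True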

3. Generate adversarial images by CarliniLInfMethod with eps=0.16

Use venv-art with requirements-art.txt.

requirements-art.txt
absl-py==0.12.0
adversarial-robustness-toolbox==1.8.0
astor==0.8.1
astunparse==1.6.3
cached-property==1.5.2
cachetools==4.2.2
certifi==2021.5.30
chardet==4.0.0
cycler==0.10.0
dataclasses==0.8
flatbuffers==1.12
gast==0.3.3
google-auth==1.30.1
google-auth-oauthlib==0.4.4
google-pasta==0.2.0
grpcio==1.34.1
h5py==2.10.0
idna==2.10
importlib-metadata==4.5.0
joblib==1.0.1
Keras-Applications==1.0.8
keras-nightly==2.5.0.dev2021032900
Keras-Preprocessing==1.1.2
kiwisolver==1.3.1
llvmlite==0.36.0
Markdown==3.3.4
matplotlib==3.3.4
numba==0.53.1
numpy==1.18.5
oauthlib==3.1.1
opencv-python==4.5.2.54
opt-einsum==3.3.0
Pillow==8.2.0
pkg-resources==0.0.0
protobuf==3.17.2
pyasn1==0.4.8
pyasn1-modules==0.2.8
pyparsing==2.4.7
python-dateutil==2.8.1
PyYAML==5.4.1
requests==2.25.1
requests-oauthlib==1.3.0
rsa==4.7.2
scikit-learn==0.24.2
scipy==1.4.1
six==1.15.0
tensorboard==2.5.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.0
tensorflow==2.3.0
tensorflow-addons==0.11.2
tensorflow-estimator==2.3.0
termcolor==1.1.0
threadpoolctl==2.1.0
tqdm==4.61.0
typeguard==2.12.1
typing-extensions==3.7.4.3
urllib3==1.26.5
Werkzeug==2.0.1
wrapt==1.12.1
zipp==3.4.1

Run the command below

python attack.py art

to attack the model with CarliniLInfMethod in ART. I got the following result.

Selected indices: [0, 1, 2, 3, 4, 5, 6, 7, 9, 10]
...
ART, ASR: 0.7
ART, LInf norm: [1.44921097e-08 1.47258534e-08 1.60000011e-01 1.60000011e-01
 1.47258534e-08 1.60000010e-01 1.60000011e-01 1.59978892e-01
 1.59454521e-01 1.60000011e-01]

This means the attack success rate is only 70% for CarliniLInfMethod with eps=0.16.
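
Note the pattern of the norms: instead of the spread of sizes found by the original attack, each sample either stays essentially unperturbed or saturates the eps=0.16 budget. A quick illustrative check:

import numpy as np

art_norms = np.array([1.44921097e-08, 1.47258534e-08, 1.60000011e-01, 1.60000011e-01,
                      1.47258534e-08, 1.60000010e-01, 1.60000011e-01, 1.59978892e-01,
                      1.59454521e-01, 1.60000011e-01])
print(np.isclose(art_norms, 0.16, atol=1e-3))  # True where the eps budget is saturated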

@beat-buesser
Collaborator

Hi @kaitokishi Thank you very much! This is really great documentation of your experiment! I will take a look immediately.

@beat-buesser
Collaborator

Hi @kaitokishi I have been able to reproduce your results and you are right that we should add the missing minimisation of the perturbation. I would propose that we support this with ART 1.9.

We'll definitely work on a solution, but if you would be interested to start working on it please let us know, you would be most welcome!
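
For reference, the missing piece is roughly the outer loop of li_attack.py: re-run the inner optimisation and shrink tau for as long as the attack keeps succeeding. A minimal sketch, where inner_optimise is a placeholder for the gradient descent on c * f(x + delta) + sum(max(0, |delta| - tau)):

import numpy as np

def linf_outer_loop(x, y, inner_optimise, decrease_factor=0.9, min_tau=1.0 / 256):
    tau = 1.0
    best_adv = None
    while tau > min_tau:
        adv, success = inner_optimise(x, y, tau)
        if not success:
            break  # cannot succeed at this tau; keep the last success
        best_adv = adv
        actual = np.max(np.abs(adv - x))
        if actual < tau:
            tau = actual  # jump down to the L-inf norm actually achieved
        tau *= decrease_factor  # then shrink further and try again
    return best_adv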

@beat-buesser beat-buesser added the improvement label Oct 21, 2021
@beat-buesser beat-buesser added this to the ART 1.9.0 milestone Oct 21, 2021
@beat-buesser beat-buesser self-assigned this Oct 23, 2021
@beat-buesser beat-buesser linked a pull request Oct 24, 2021 that will close this issue
@beat-buesser
Collaborator

Hi @kaitokishi

I have pushed updates for CarliniLInfMethod to branch development_issue_1374 and have obtained the following results with the most recent commit 8c76c29:

ART, ASR: 1.0
ART, LInf norm: [0.14354881 0.04386727 0.09895719 0.15535574 0.05742465 0.110842
 0.07077612 0.0431579  0.06195071 0.0781028 ]
ART, adversarial: [ True  True  True  True  True  True  True  True  True  True]

The perturbations and success rates are now similar to the original implementation. The attack runs slower, but it supports all deep learning frameworks. We will test whether we can reduce the number of gradient and loss calculations, apply Numba, etc. to accelerate the attack. I also think it would be good to use your test case to compare the L2 and L0 attacks.

I have used the modified version of attack.py below. I changed the arguments of the ART classifier to preprocessing=(0.5, 1.0) and clip_values=(0.0, 1.0) to provide input images in the range [0, 1]. I don't think this was necessary, since your setup was correct, but it was one of my debugging steps.
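
As a side note on that change: as far as I know, ART's preprocessing=(mean, std) makes the classifier compute (x - mean) / std before the forward pass, so [0, 1] inputs with preprocessing=(0.5, 1.0) reach the model as the same [-0.5, 0.5] values it was trained on:

import numpy as np

x01 = np.random.rand(28, 28, 1)                   # pixels in [0, 1], as now provided to ART
assert np.allclose((x01 - 0.5) / 1.0, x01 - 0.5)  # same values as the original [-0.5, 0.5] setup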

attack.py

import sys

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.datasets import mnist

import logging
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
handler = logging.StreamHandler()
formatter = logging.Formatter("[%(levelname)s] %(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)


def get_correct_predicted_indices(model):
    indices = []
    for i in range(len(x_test)):
        if y_test[i] == np.argmax(model.predict(np.array([x_test[i]]))):
            indices.append(i)
            if len(indices) == 10:
                break
    print(f'Selected indices: {indices}')
    return indices


# Load a dataset
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# x_train = x_train[:,:,:,np.newaxis] / 255. - 0.5 # x_train.shape is (60000, 28, 28, 1)
# x_test = x_test[:,:,:,np.newaxis] / 255. - 0.5 # x_test.shape is (10000, 28, 28, 1)
x_train = x_train[:,:,:,np.newaxis] / 255. #- 0.5 # x_train.shape is (60000, 28, 28, 1)
x_test = x_test[:,:,:,np.newaxis] / 255. #- 0.5 # x_test.shape is (10000, 28, 28, 1)


if sys.argv[1] == 'original':
    import keras
    from li_attack import CarliniLi

    x_train = x_train - 0.5
    x_test = x_test - 0.5

    class Model:
        def __init__(self, restore):
            self.num_channels = 1
            self.image_size = 28
            self.num_labels = 10
            self.model = restore # use the passed-in model rather than the global

        def predict(self, data):
            return self.model(data)

    with tf.Session() as sess:
        init = tf.global_variables_initializer()
        sess.run(init)
        # load model
        model = keras.models.load_model('mnist_model.h5')
        indices = get_correct_predicted_indices(model)
        attack = CarliniLi(sess, Model(model), targeted=False, learning_rate=0.01, max_iterations=1000)
        x_adv = attack.attack(x_test[indices], np.identity(10)[y_test[indices]]) # one-hot encoding is required
        print(f'original, ASR: {np.sum(y_test[indices] != np.argmax(model.predict(x_adv), axis=1)) / 10}')
        print(f'original, LInf norm: {np.max(np.abs(x_adv - x_test[indices]), axis=(1,2,3))}')

        #for i, each_x_adv in zip(indices, x_adv): # Optional: save the images to the `img` directory.
        #    plt.imshow(each_x_adv[:, :, 0], cmap='gray') # drop the channel axis for imshow
        #    plt.savefig(f'img/original-{i}.png')
elif sys.argv[1] == 'art':
    from tensorflow import keras
    # load model
    model = keras.models.load_model('mnist_model.h5')
    indices = get_correct_predicted_indices(model)
    from art.attacks.evasion import CarliniLInfMethod
    from art.estimators.classification import TensorFlowV2Classifier
    classifier = TensorFlowV2Classifier(model=model,
                                        nb_classes=10,
                                        input_shape=(28, 28, 1),
                                        loss_object=model.loss,
                                        clip_values=(0.0, 1.0),
                                        preprocessing=(0.5, 1.0))
    attack = CarliniLInfMethod(classifier=classifier,
                               targeted=False,
                               learning_rate=0.01,
                               max_iter=1000,
                               confidence=0.0)

    x_adv = attack.generate(x_test[indices], y_test[indices])
    print(f'ART, ASR: {np.sum(y_test[indices] != np.argmax(classifier.predict(x_adv), axis=1)) / 10}')
    print(f'ART, LInf norm: {np.max(np.abs(x_adv - x_test[indices]), axis=(1,2,3))}')
    print(f'ART, adversarial: {y_test[indices] != np.argmax(classifier.predict(x_adv), axis=1)}')

What do you think?

@kaitokishi
Contributor Author

The results and the modified code look great! The original implementation only supports TensorFlow v1, so supporting all deep learning frameworks through ART is brilliant.

I agree with using this attack.py to compare the L2 and L0 attacks as well.
