ART does not work with Keras Embedding layers #33

Closed
peck94 opened this issue Feb 7, 2019 · 22 comments

@peck94

peck94 commented Feb 7, 2019

Describe the bug
I am unable to create any instances of art.classifiers.KerasClassifier whenever the underlying Keras model contains an Embedding layer. Using the TensorFlow backend, this invariably leads to a TypeError: Can not convert a NoneType into a Tensor or Operation.

To Reproduce
Steps to reproduce the behavior:

  1. Create any Keras model with an Embedding layer on the TensorFlow backend.
  2. Attempt to instantiate art.classifiers.KerasClassifier on it.
  3. Watch it fail.

Expected behavior
I expected ART to simply return an instance of KerasClassifier as it usually does.

Screenshots
N/A, but here's a minimal non-working example:

from keras.models import Sequential
from keras.layers import Dense, Activation, Dropout, Embedding, LSTM
from art.classifiers import KerasClassifier

model = Sequential()
model.add(Embedding(100, 128, input_length=50))
model.add(LSTM(128))
model.add(Dropout(0.5))
model.add(Dense(1))
model.add(Activation('sigmoid'))

model.compile(loss='binary_crossentropy', optimizer='rmsprop', metrics=['binary_accuracy'])

classifier = KerasClassifier((0, 1), model=model)
classifier.fit(x_train, y_train, nb_epochs=10, batch_size=128)

System information (please complete the following information):

  • Ubuntu 18.04 LTS
  • Python version 3.6.5
  • ART version 0.5.0
  • TensorFlow version 1.12.0
  • Keras version 2.2.4
@ririnicolae
Collaborator

@peck94 From your MWE, I take it you work with sequential data, not images? ART currently supports only image models, but you can expect an extension to other types of data in the upcoming months. FYI, embedding layers are not differentiable with respect to their integer inputs, so even if you could wrap the model in a KerasClassifier, you would not be able to apply most evasion attacks out of the box (as they need gradients). The changes required to make them work on embeddings are the reason it will take us a bit longer to support this.

I take your point that you should still be able to train your Keras model through the ART wrapper. I will look into it and see if we can provide some rapid support, at least for training.
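
To illustrate the differentiability point, here is a minimal pure-Python sketch (not ART or Keras code). The table values, the `WEIGHTS` vector and the function names are all made up for illustration. It shows that an embedding is a lookup on a discrete id, so no gradient exists with respect to the id, while the gradient with respect to the embedded vector is well defined.

```python
# Toy embedding table: 4 token ids, 3-dimensional vectors (values are made up).
EMBEDDINGS = [
    [0.2, -1.0, 0.5],
    [1.5, 0.3, -0.7],
    [-0.4, 0.8, 0.1],
    [0.0, 0.0, 2.0],
]
# Hypothetical downstream layer: a single linear score.
WEIGHTS = [0.5, -2.0, 1.0]


def embed(token_id):
    """Embedding layer: a table lookup indexed by an integer."""
    return EMBEDDINGS[token_id]


def score(vector):
    """Everything after the embedding (here just a dot product)."""
    return sum(w * v for w, v in zip(WEIGHTS, vector))


def grad_score_wrt_vector(vector):
    """Gradient of the score with respect to the embedded vector."""
    return list(WEIGHTS)  # for a linear score, d(w . v)/dv = w


# The input is a discrete id: there is no "token 1.01", so the derivative of
# the score with respect to token_id is undefined; the score simply jumps.
scores = [score(embed(t)) for t in range(len(EMBEDDINGS))]

# With respect to the embedded vector, the gradient exists and could drive
# a gradient-based attack.
grad = grad_score_wrt_vector(embed(1))
print(scores, grad)
```

This is why gradient-based attacks can only start from the embedded representation, not from the token ids themselves.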

@ririnicolae ririnicolae self-assigned this Feb 7, 2019
@cr019283

cr019283 commented May 2, 2019

I came across exactly the same problem and would be interested in finding out if there is a solution (or workaround) for models with Embedding layers.

@ririnicolae
Collaborator

@cr019283 The workaround would be to encapsulate in the model only the layers after the embedding. You would be able to run an attack end to end between the representation provided by the embedding and the output. For this setup, your embeddings would have to be treated as feature vectors. This is not currently supported in ART, so you would either have to tinker with the library code a bit or wait for feature vectors support from our side (WIP, see issue #49, expected completion date next week).
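
As a rough sketch of that workaround, independent of ART and Keras: the toy "layers" below are hypothetical plain-Python functions, and `compose`, `TABLE` and `split_at` are names invented for this illustration. The idea is to split the pipeline right after the embedding, keep the head outside the wrapped model, and treat the head's output as the feature vectors that an attack perturbs.

```python
def compose(layers):
    """Return a function that applies the given layers in order."""
    def run(x):
        for layer in layers:
            x = layer(x)
        return x
    return run


# Stand-ins for real layers (toy functions, not Keras layers).
TABLE = {0: [0.1, 0.9], 1: [0.7, 0.2], 2: [0.4, 0.4]}


def embedding(token_ids):
    """Discrete ids -> vectors (not differentiable in the ids)."""
    return [TABLE[t] for t in token_ids]


def pooling(vectors):
    """Average the sequence of vectors into a single vector."""
    return [sum(column) / len(vectors) for column in zip(*vectors)]


def dense(vector):
    """A fixed linear read-out."""
    return 2.0 * vector[0] - 1.0 * vector[1]


layers = [embedding, pooling, dense]
split_at = 1  # index of the first layer after the embedding

head = compose(layers[:split_at])  # stays outside the wrapped classifier
tail = compose(layers[split_at:])  # the part an attack can differentiate through
full = compose(layers)

tokens = [0, 2, 1]
features = head(tokens)  # embeddings, treated as feature vectors
assert tail(features) == full(tokens)  # the split does not change predictions

# An attack perturbs `features`, never `tokens`.
perturbed = [[value + 0.05 for value in vec] for vec in features]
print(full(tokens), tail(perturbed))
```

Note that the perturbed features need not correspond to any real token sequence, so the result is an attack on the representation rather than on the raw input.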

@step8p
Contributor

step8p commented May 9, 2019

Hello, any news about this?

@ririnicolae
Collaborator

@step8p The PR for issue #49 will be in tomorrow; that will give you access to the same workaround that I suggested to @cr019283.

@ririnicolae
Collaborator

@step8p @cr019283 Feature vectors support is now on dev branch, will be on master and pip release next Monday.

@step8p
Contributor

step8p commented May 28, 2019

I saw that this should now be supported. However, I'm still trying to figure out exactly how to do it. Is there an example? Moreover, if I understood correctly, this means that I can attack any of the layers, right?

@ririnicolae
Collaborator

@step8p I'll try to put together an example for this. And yes, in principle you can attack any of the layers.

@step8p
Contributor

step8p commented May 29, 2019

Given an already trained Keras model "model" and a test dataset "x_test, y_test", I wrote this (and it seems to work):

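import numpy as np
from keras import Model, Input

from art.attacks.fast_gradient import FastGradientMethod
from art.classifiers import KerasClassifier

# Split the model into two parts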
HL_model = Model(inputs=model.input, outputs=model.layers[1].output)

DL_input = Input(model.layers[2].input_shape[1:])
DL_model = DL_input
for layer in model.layers[2:]:
    DL_model = layer(DL_model)
DL_model = Model(inputs=DL_input, outputs=DL_model)

# Create the input dataset for DL_model by running x_test through HL_model
pre_features = HL_model.predict(x_test)

# Evaluate the classifier on the test set
preds = np.argmax(DL_model.predict(pre_features), axis=1)
acc = np.sum(preds == np.argmax(y_test, axis=1)) / y_test.shape[0]
print("\nTest accuracy: %.2f%%" % (acc * 100))

# Craft adversarial samples with FGSM
ART_DL_model = KerasClassifier(clip_values=(0, 1), model=DL_model)
epsilon = .1  # Maximum perturbation
adv_crafter = FastGradientMethod(ART_DL_model, eps=epsilon)
pre_features_adv = adv_crafter.generate(x=pre_features)

# Evaluate the classifier on the adversarial examples
preds = np.argmax(DL_model.predict(pre_features_adv), axis=1)
acc = np.sum(preds == np.argmax(y_test, axis=1)) / y_test.shape[0]
print("\nTest accuracy on adversarial sample: %.2f%%" % (acc * 100))

This gives:

Test accuracy: 98.41%
Test accuracy on adversarial sample: 20.39%

Does this sound correct to you, or do you see something strange?

@ririnicolae
Collaborator

@step8p Example looks good to me. The only thing I would change is to remove clip_values from the KerasClassifier completely: it's hard to tell whether the features in the embedding space actually lie in the (0, 1) range, and adding this constraint will clip them during the attack. clip_values were mandatory so far in ART, but since the feature vectors extension you can omit them altogether.
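
To make the clipping concern concrete, here is a small pure-Python sketch of a generic FGSM step. It is not ART's implementation; `fgsm_step` and all numbers are invented for illustration. When the clean features already lie outside (0, 1), clipping moves them much further than the intended `eps`.

```python
def fgsm_step(features, gradient, eps, clip_values=None):
    """One FGSM step: x + eps * sign(grad), optionally clipped to a range."""
    def sign(g):
        return (g > 0) - (g < 0)

    adv = [x + eps * sign(g) for x, g in zip(features, gradient)]
    if clip_values is not None:
        low, high = clip_values
        adv = [min(max(a, low), high) for a in adv]
    return adv


# Made-up embedding-space features; several lie outside [0, 1].
features = [-1.3, 0.4, 2.7, 0.9]
gradient = [0.2, -0.5, 0.1, 0.0]  # made-up loss gradient
eps = 0.1

clipped = fgsm_step(features, gradient, eps, clip_values=(0, 1))
unclipped = fgsm_step(features, gradient, eps)

# Largest change applied to any single feature under each setting.
max_change_clipped = max(abs(a - x) for a, x in zip(clipped, features))
max_change_unclipped = max(abs(a - x) for a, x in zip(unclipped, features))
print(max_change_clipped, max_change_unclipped)
```

In this toy case the clipped result differs from the clean features by far more than `eps`, so any accuracy drop would partly reflect the clipping rather than the attack.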

Would you be interested in submitting your full example in a PR to be included in the ART examples folder?

@step8p
Contributor

step8p commented May 29, 2019

I think you are totally right about the clip_values. I tried removing them and I confirm that it works. About the example, I cannot share the very same one because it's based on private data, but if this works for you, I can put together an example showing how to adversarially perturb one of the conv layers in the middle of the MNIST example.

@step8p
Contributor

step8p commented May 31, 2019

This should do the job

# -*- coding: utf-8 -*-
"""Trains a convolutional neural network on the MNIST dataset, then attacks it with the FGSM attack."""
from __future__ import absolute_import, division, print_function, unicode_literals

import sys
from os.path import abspath

sys.path.append(abspath('.'))

import keras.backend as k
from keras.models import Model, Sequential
from keras.layers import Input, Dense, Flatten, Conv2D, MaxPooling2D, Dropout
import numpy as np

from art.attacks.fast_gradient import FastGradientMethod
from art.classifiers import KerasClassifier
from art.utils import load_dataset

# Read MNIST dataset
(x_train, y_train), (x_test, y_test), min_, max_ = load_dataset(str('mnist'))

# Create Keras convolutional neural network - basic architecture from Keras examples
# Source here: https://github.com/keras-team/keras/blob/master/examples/mnist_cnn.py
k.set_learning_phase(1)
model = Sequential()
model.add(Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=x_train.shape[1:]))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(10, activation='softmax'))

model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

classifier = KerasClassifier(clip_values=(min_, max_), model=model)
classifier.fit(x_train, y_train, nb_epochs=5, batch_size=128)

# Attack one of the inner layers instead of the input layer. In this example
# we perturb the output of the max-pooling layer that follows the second
# convolutional layer. To do so, we split the network into two sub-networks, so
# that the second one takes the attacked layer's output as its input

HL_model = Model(inputs=model.input, outputs=model.layers[2].output)

DL_input = Input(model.layers[3].input_shape[1:])
DL_model = DL_input
for layer in model.layers[3:]:
    DL_model = layer(DL_model)
DL_model = Model(inputs=DL_input, outputs=DL_model)

classifier = KerasClassifier(model=DL_model)

# Now we need to create the dataset for DL_model, since the original one is
# suited only to the "model" network (and thus to HL_model). Note that the
# labels do not need to change
x_test_inner = HL_model.predict(x_test)

# Evaluate the classifier on the test set
preds = np.argmax(classifier.predict(x_test_inner), axis=1)
acc = np.sum(preds == np.argmax(y_test, axis=1)) / y_test.shape[0]
print("\nTest accuracy: %.2f%%" % (acc * 100))

# Craft adversarial samples with FGSM
epsilon = .1  # Maximum perturbation
adv_crafter = FastGradientMethod(classifier, eps=epsilon)
x_test_adv = adv_crafter.generate(x=x_test_inner)

# Evaluate the classifier on the adversarial examples
preds = np.argmax(classifier.predict(x_test_adv), axis=1)
acc = np.sum(preds == np.argmax(y_test, axis=1)) / y_test.shape[0]
print("\nTest accuracy on adversarial sample: %.2f%%" % (acc * 100))

@ririnicolae
Collaborator

@step8p Thanks for taking the time to adapt the example. Do you want to PR it yourself or do you prefer I do so?

@step8p
Contributor

step8p commented Jun 10, 2019

Sorry @ririnicolae, I read this just now. Yes, I can do it

@step8p
Contributor

step8p commented Jun 10, 2019

I'm not totally sure of how to do it. Should I first create the file in the examples folder and then PR it, or open a PR in the main folder?

@ririnicolae
Collaborator

@step8p The easiest is probably to fork the repo and add your commits to the dev branch on your fork. Once you push them to your fork on GitHub, you'll get the option to open a PR on the main repo. Use the GitHub interface to open your PR to the dev branch. :) Oh, and don't forget to sign your commits with -s. Let me know how it goes!

@step8p
Contributor

step8p commented Jun 10, 2019

Did it (but not sure if in the correct way XD)

@ririnicolae
Collaborator

@step8p Did you submit the PR? I can't say I see it. :)

@step8p
Contributor

step8p commented Jun 11, 2019

Probably I did it in the wrong way. This is the link. Can I fix it somehow? Sorry for that :(

@ririnicolae
Collaborator

@step8p I can see the PR through the link that you provided, but it goes from a feature branch of your fork to the dev branch of that same fork. Could you change the target of the PR from your dev branch to the dev branch of this repo?

@step8p
Contributor

step8p commented Jun 13, 2019

Did it (hope in the correct way this time).

Btw, going back to the "clip values" point, a colleague of mine (IMHO correctly) suggested that the lower bound for the perturbed values should be 0: since the perturbation is injected (at least in the example) after a relu, it should not produce negative values.

@ririnicolae
Collaborator

@step8p PR looks good, please take a look at the review comments! About the clip_values, you would normally not know much about the values in the embedding space, but I guess for the given example it makes sense to set the lower bound on clip_values at 0 if this comes just after the relu.
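
A small pure-Python sketch of the lower-bound argument, with made-up numbers and helper names (this is not ART code): an unconstrained FGSM step on post-ReLU activations can yield negative values that the upstream ReLU could never emit, and clamping from below at 0 removes them without assuming any upper bound.

```python
def relu(values):
    """Element-wise ReLU."""
    return [max(0.0, v) for v in values]


def sign(g):
    return (g > 0) - (g < 0)


# Made-up pre-activations and loss gradient at the attacked layer.
pre_activations = [-0.8, 0.03, 1.2, 0.0]
gradient = [-0.4, -0.9, 0.6, -0.1]
eps = 0.1

activations = relu(pre_activations)  # [0.0, 0.03, 1.2, 0.0]
raw_adv = [a + eps * sign(g) for a, g in zip(activations, gradient)]

# Negative entries could never be produced by the ReLU that feeds this layer.
infeasible = [v for v in raw_adv if v < 0]

# Clamp from below only; no upper bound is assumed.
feasible_adv = [max(0.0, v) for v in raw_adv]
print(raw_adv, feasible_adv)
```

Whether a matching upper bound exists depends on the layers before the split, which is why only the lower bound is applied here.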

imolloy pushed a commit to imolloy/adversarial-robustness-toolbox that referenced this issue Aug 5, 2019
Optimize FGSM for minimal perturbation (closes Trusted-AI#26 )