Merge pull request #515 from QData/keras-parallel-fix
[CODE] Keras parallel attack fix - Issue #499
qiyanjun committed Aug 21, 2021
2 parents f64df48 + f7a2218 commit a3d8742
Showing 3 changed files with 150 additions and 4 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -81,7 +81,7 @@ The [documentation website](https://textattack.readthedocs.io/en/latest) contain

The easiest way to try out an attack is via the command-line interface, `textattack attack`.

> **Tip:** If your machine has multiple GPUs, you can distribute the attack across them using the `--parallel` option. For some attacks, this can really help performance.
> **Tip:** If your machine has multiple GPUs, you can distribute the attack across them using the `--parallel` option. For some attacks, this can really help performance. (If you want to attack Keras models in parallel, please check out `examples/attack/attack_keras_parallel.py` instead.)
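For instance, a parallel run could look like the sketch below (the recipe, model, and example count are illustrative choices from the standard CLI examples):

```bash
textattack attack --recipe textfooler --model bert-base-uncased-mr --num-examples 100 --parallel
```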
Here are some concrete examples:

10 changes: 7 additions & 3 deletions docs/1start/FAQ.md
@@ -58,14 +58,12 @@ textattack train --model-name-or-path lstm --dataset yelp_polarity --epochs 50
```


*Fine-Tune `bert-base` on the `CoLA` dataset for 5 epochs**:
*Fine-Tune `bert-base` on the `CoLA` dataset for 5 epochs*:
```bash
textattack train --model-name-or-path bert-base-uncased --dataset glue^cola --per-device-train-batch-size 8 --epochs 5
```




### 2. Use Custom Models

TextAttack is model-agnostic! You can use `TextAttack` to analyze any model that outputs IDs, tensors, or strings. To help users, TextAttack includes pre-trained models for different common NLP tasks. This makes it easier for
@@ -138,3 +136,9 @@ This modular design unifies adversarial attack methods into one system, enables



### 6. The attack is too slow


- **Tip:** If your machine has multiple GPUs, you can distribute the attack across them using the `--parallel` option. For some attacks, this can really help performance.

- If you want to attack Keras models in parallel, use `examples/attack/attack_keras_parallel.py` instead; it applies a hotfix for issues caused by a recent Keras update within TensorFlow (see the example below).
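As a rough sketch of both approaches (the recipe and model names are illustrative; run from the repository root):

```bash
# Non-Keras models: let TextAttack spread the attack over all visible GPUs.
textattack attack --recipe pwws --model bert-base-uncased-mr --num-examples 100 --parallel

# Keras models: use the dedicated script, which applies a pickling hotfix first.
python examples/attack/attack_keras_parallel.py
```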
142 changes: 142 additions & 0 deletions examples/attack/attack_keras_parallel.py
@@ -0,0 +1,142 @@
"""
Recent upgrade of keras versions in TF 2.5+, keras has been moved to tf.keras
This has resulted in certain exceptions when keras models are attacked in parallel
This script fixes this behavior by adding an official hotfix for this situation detailed here:
https://github.com/tensorflow/tensorflow/issues/34697
All models/dataset are similar to keras attack tutorial at :
https://textattack.readthedocs.io/en/latest/2notebook/Example_3_Keras.html#
NOTE: This fix might be deprecated in future TF releases
NOTE: This script is not designed to run in a Jupyter notebook due to conflicting namespace issues
We recommend running it as a script only
"""

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.utils import to_categorical
from tensorflow.python.keras.layers import deserialize, serialize
from tensorflow.python.keras.saving import saving_utils
import torch

from textattack import AttackArgs, Attacker
from textattack.attack_recipes import PWWSRen2019
from textattack.datasets import HuggingFaceDataset
from textattack.models.wrappers import ModelWrapper

NUM_WORDS = 1000


def unpack(model, training_config, weights):
    # Rebuild a Keras model from its serialized config, training config, and weights.
    restored_model = deserialize(model)
    if training_config is not None:
        restored_model.compile(
            **saving_utils.compile_args_from_training_config(training_config)
        )
    restored_model.set_weights(weights)
    return restored_model


# Hotfix: monkey-patch keras Model.__reduce__ so that models can be pickled,
# which torch.multiprocessing requires when attacking in parallel.
def make_keras_picklable():
    def __reduce__(self):
        model_metadata = saving_utils.model_metadata(self)
        training_config = model_metadata.get("training_config", None)
        model = serialize(self)
        weights = self.get_weights()
        return (unpack, (model, training_config, weights))

    cls = Model
    cls.__reduce__ = __reduce__


# Apply the hotfix so Keras models can be shipped to worker processes.
make_keras_picklable()
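# Optional sanity check (illustrative, not part of the original hotfix): after the patch,
# a compiled Keras model should survive a pickle round-trip, which is exactly what the
# parallel attack workers rely on. Uncomment to verify locally:
# import pickle
# _check = Sequential([Dense(4, activation="relu", input_dim=8)])
# _check.compile(optimizer="adam", loss="mse")
# assert pickle.loads(pickle.dumps(_check)).get_weights()[0].shape == (8, 4)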


def transform(x):
    # Convert each list of word indices into a bag-of-words vector of length NUM_WORDS.
    x_transform = []
    for i, word_indices in enumerate(x):
        BoW_array = np.zeros((NUM_WORDS,))
        for index in word_indices:
            if index < len(BoW_array):
                BoW_array[index] += 1
        x_transform.append(BoW_array)
    return np.array(x_transform)


class CustomKerasModelWrapper(ModelWrapper):
    """Wrap the Keras model so TextAttack can query it with raw text.

    Each input string is converted to a bag-of-words vector before prediction.
    """

    def __init__(self, model):
        self.model = model

    def __call__(self, text_input_list):
        x_transform = []
        for i, review in enumerate(text_input_list):
            tokens = [x.strip(",") for x in review.split()]
            BoW_array = np.zeros((NUM_WORDS,))
            for word in tokens:
                if word in vocabulary:
                    if vocabulary[word] < len(BoW_array):
                        BoW_array[vocabulary[word]] += 1
            x_transform.append(BoW_array)
        x_transform = np.array(x_transform)
        prediction = self.model.predict(x_transform)
        return prediction


# A small bag-of-words sentiment classifier over the NUM_WORDS most frequent IMDB tokens.
model = Sequential()
model.add(Dense(512, activation="relu", input_dim=NUM_WORDS))
model.add(Dropout(0.3))
model.add(Dense(100, activation="relu"))
model.add(Dense(2, activation="sigmoid"))
opt = tf.keras.optimizers.Adam(learning_rate=0.00001)

model.compile(optimizer=opt, loss="binary_crossentropy", metrics=["accuracy"])


(x_train_tokens, y_train), (x_test_tokens, y_test) = tf.keras.datasets.imdb.load_data(
    path="imdb.npz",
    num_words=NUM_WORDS,
    skip_top=0,
    maxlen=None,
    seed=113,
    start_char=1,
    oov_char=2,
    index_from=3,
)


# Train on the first 90% of the training split; use the tail of the test split for validation.
index = int(0.9 * len(x_train_tokens))
x_train = transform(x_train_tokens)[:index]
x_test = transform(x_test_tokens)[index:]
y_train = np.array(y_train[:index])
y_test = np.array(y_test[index:])
y_train = to_categorical(y_train)
y_test = to_categorical(y_test)

vocabulary = tf.keras.datasets.imdb.get_word_index(path="imdb_word_index.json")

results = model.fit(
    x_train, y_train, epochs=1, batch_size=512, validation_data=(x_test, y_test)
)


if __name__ == "__main__":
torch.multiprocessing.freeze_support()

model_wrapper = CustomKerasModelWrapper(model)
dataset = HuggingFaceDataset("rotten_tomatoes", None, "test", shuffle=True)

attack = PWWSRen2019.build(model_wrapper)

attack_args = AttackArgs(
num_examples=10,
checkpoint_dir="checkpoints",
parallel=True,
num_workers_per_device=2,
)

attacker = Attacker(attack, dataset, attack_args)

attacker.attack_dataset()
