# TensorFlow and TextAttack

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/QData/TextAttack/blob/master/docs/2notebook/Example_0_tensorflow.ipynb)

[![View Source on GitHub](https://img.shields.io/badge/github-view%20source-black.svg)](https://github.com/QData/TextAttack/blob/master/docs/2notebook/Example_0_tensorflow.ipynb)

Please remember to run **pip3 install textattack[tensorflow]** in your notebook environment before running the following code:

## Run TextAttack on a trained TensorFlow model

### First: Training

The following code trains a text classification model using TensorFlow and its Keras API. It is adapted from the TensorFlow documentation ([see here](https://www.tensorflow.org/tutorials/keras/text_classification_with_hub)).

This cell loads the IMDB movie-review dataset (using `tensorflow_datasets`, not `datasets`), builds a small classifier on top of a pretrained TF Hub text embedding, and trains it using Keras.

```python
import numpy as np

import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_datasets as tfds

import matplotlib.pyplot as plt

print("Version: ", tf.__version__)
print("Eager mode: ", tf.executing_eagerly())
print("Hub version: ", hub.__version__)
print(
    "GPU is", "available" if tf.config.list_physical_devices("GPU") else "NOT AVAILABLE"
)

# Load the full IMDB train and test splits in memory (batch_size=-1)
# as (review text, label) pairs.
train_data, test_data = tfds.load(
    name="imdb_reviews", split=["train", "test"], batch_size=-1, as_supervised=True
)

train_examples, train_labels = tfds.as_numpy(train_data)
test_examples, test_labels = tfds.as_numpy(test_data)

# Pretrained 20-dimensional text embedding from TF Hub, fine-tuned during training.
embedding_url = "https://tfhub.dev/google/tf2-preview/gnews-swivel-20dim/1"
hub_layer = hub.KerasLayer(
    embedding_url, output_shape=[20], input_shape=[], dtype=tf.string, trainable=True
)
hub_layer(train_examples[:3])  # sanity check: embed the first three reviews

# Embedding -> 16-unit hidden layer -> one raw logit per review.
model = tf.keras.Sequential()
model.add(hub_layer)
model.add(tf.keras.layers.Dense(16, activation="relu"))
model.add(tf.keras.layers.Dense(1))

model.summary()

# Hold out the first 10,000 training reviews for validation.
x_val = train_examples[:10000]
partial_x_train = train_examples[10000:]

y_val = train_labels[:10000]
partial_y_train = train_labels[10000:]

# from_logits=True because the last layer outputs logits, not probabilities.
model.compile(
    optimizer="adam",
    loss=tf.losses.BinaryCrossentropy(from_logits=True),
    metrics=["accuracy"],
)

history = model.fit(
    partial_x_train,
    partial_y_train,
    epochs=40,
    batch_size=512,
    validation_data=(x_val, y_val),
    verbose=1,
)
```

```
Version:  2.3.2
Eager mode:  True
Hub version:  0.12.0
GPU is NOT AVAILABLE
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
keras_layer (KerasLayer)     (None, 20)                400020    
_________________________________________________________________
dense (Dense)                (None, 16)                336       
_________________________________________________________________
dense_1 (Dense)              (None, 1)                 17        
Total params: 400,373
Trainable params: 400,373
Non-trainable params: 0
_________________________________________________________________
```

### Attacking

For each input, our classifier outputs a single number (a logit) that indicates how positive or negative the model finds the input. For binary classification, TextAttack expects two numbers for each input: one score per class (negative and positive). We therefore have to post-process each output into this format. To add this post-processing, we implement a custom model wrapper class instead of using the built-in `textattack.models.wrappers.TensorFlowModelWrapper`.
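Concretely, the post-processing is a sigmoid followed by stacking the two class probabilities. Below is a minimal NumPy sketch, assuming the model emits one raw logit per input (as the final `Dense(1)` layer trained with `from_logits=True` does); the function name `logits_to_two_class_probs` is hypothetical.

```python
import numpy as np


def logits_to_two_class_probs(logits):
    """Convert one-logit binary outputs into [P(negative), P(positive)] rows.

    `logits` has shape (batch, 1) or (batch,): one raw score per input.
    """
    logits = np.asarray(logits, dtype=np.float64).reshape(-1)
    p_pos = 1.0 / (1.0 + np.exp(-logits))  # sigmoid: logit -> P(positive)
    # Column 0 is P(negative), column 1 is P(positive); each row sums to 1.
    return np.stack([1.0 - p_pos, p_pos], axis=1)


probs = logits_to_two_class_probs([[2.0], [-1.0], [0.0]])
print(probs.round(3))
```

A logit of 0 maps to an even split, positive logits favor the positive class, and negative logits favor the negative class.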

Each `ModelWrapper` must implement a single method, `__call__`, which takes a list of strings and returns a `List`, `np.ndarray`, or `torch.Tensor` of predictions.
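To see this contract in isolation, here is a toy wrapper around a hypothetical keyword-counting "model". It deliberately does not subclass `textattack.models.wrappers.ModelWrapper`, so it runs without TextAttack installed; the keyword sets and the class name are illustrative assumptions only. In real use you would subclass `ModelWrapper`, as the next cell does.

```python
import numpy as np


class KeywordModelWrapper:
    """Toy stand-in for a ModelWrapper (hypothetical, for illustration only).

    __call__ takes a list of strings and returns an np.ndarray of shape
    (len(text_input_list), 2): one [negative, positive] score row per input.
    """

    POSITIVE = {"love", "great", "clever"}
    NEGATIVE = {"hate", "boring", "awful"}

    def __call__(self, text_input_list):
        rows = []
        for text in text_input_list:
            words = text.lower().split()
            pos = sum(w in self.POSITIVE for w in words)
            neg = sum(w in self.NEGATIVE for w in words)
            # Add-one smoothing keeps both scores nonzero; rows sum to 1.
            total = pos + neg + 2
            rows.append([(neg + 1) / total, (pos + 1) / total])
        return np.array(rows)


wrapper = KeywordModelWrapper()
print(wrapper(["I hate you so much", "I love you"]))
```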

```python
import numpy as np
import torch

from textattack.models.wrappers import ModelWrapper


class CustomTensorFlowModelWrapper(ModelWrapper):
    def __init__(self, model):
        self.model = model

    def __call__(self, text_input_list):
        text_array = np.array(text_input_list)
        # The model outputs one raw logit per input, with shape (batch, 1).
        logits = torch.tensor(self.model(text_array).numpy())
        # The sigmoid turns each logit into P(positive), a value between 0 and 1.
        probs = torch.sigmoid(logits).squeeze(dim=-1)
        # Since this model has only a single output, we add the second class:
        # column 0 is P(negative), column 1 is P(positive).
        final_preds = torch.stack((1 - probs, probs), dim=1)
        return final_preds
```

Let's test our model wrapper out to make sure it can use our model to return predictions in the correct format.
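For this wrapper, "correct format" means a `(batch, 2)` array of probabilities whose rows sum to 1. The sketch below turns that into explicit checks; `check_binary_predictions` is a hypothetical helper, not part of TextAttack, and the sum-to-1 test reflects this wrapper's choice to return probabilities.

```python
import numpy as np


def check_binary_predictions(preds, num_inputs, atol=1e-5):
    """Raise ValueError unless preds is a (num_inputs, 2) probability array.

    Accepts anything np.asarray understands: lists, np.ndarray, or a CPU
    torch.Tensor.
    """
    arr = np.asarray(preds, dtype=np.float64)
    if arr.shape != (num_inputs, 2):
        raise ValueError(f"expected shape ({num_inputs}, 2), got {arr.shape}")
    if np.any((arr < 0) | (arr > 1)):
        raise ValueError("scores should be probabilities in [0, 1]")
    if not np.allclose(arr.sum(axis=1), 1.0, atol=atol):
        raise ValueError("each row should sum to 1")
    return True
```

For a list of strings `texts`, this would be called as `check_binary_predictions(CustomTensorFlowModelWrapper(model)(texts), len(texts))`.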

```python
CustomTensorFlowModelWrapper(model)(["I hate you so much", "I love you"])
```

```
tensor([[0.1409, 0.8591],
        [0.0213, 0.9787]])
```

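The output has the expected format: one row per input, holding the negative-class and positive-class probabilities, which sum to 1. (This check covers the format only; note that the model scores "I hate you so much" as mostly positive.) Now we can initialize our model wrapper with the model we trained, use it to build an instance of `textattack.attack.Attack`, and run that attack over a dataset with `textattack.Attacker`.

We'll use the `PWWSRen2019` recipe as our attack and run it on 10 examples from the Rotten Tomatoes test set.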

```python
from textattack import Attacker, AttackArgs
from textattack.attack_recipes import PWWSRen2019
from textattack.datasets import HuggingFaceDataset

model_wrapper = CustomTensorFlowModelWrapper(model)

# Shuffled test split of the Rotten Tomatoes movie-review dataset.
dataset = HuggingFaceDataset("rotten_tomatoes", None, "test", shuffle=True)
attack = PWWSRen2019.build(model_wrapper)

# Attack 10 examples from the dataset.
attack_args = AttackArgs(num_examples=10)
attacker = Attacker(attack, dataset, attack_args)
attacker.attack_dataset()
```

```
textattack: Loading datasets dataset rotten_tomatoes, split test.
textattack: Unknown if model of class <class 'tensorflow.python.keras.engine.sequential.Sequential'> compatible with goal function <class 'textattack.goal_functions.classification.untargeted_classification.UntargetedClassification'>.

Attack(
  (search_method): GreedyWordSwapWIR(
    (wir_method):  weighted-saliency
  )
  (goal_function):  UntargetedClassification
  (transformation):  WordSwapWordNet
  (constraints): 
    (0): RepeatModification
    (1): StopwordModification
  (is_black_box):  True
) 

--------------------------------------------- Result 1 ---------------------------------------------
Negative (90%) --> [SKIPPED]

lovingly photographed in the manner of a golden book sprung to life , stuart little 2 manages sweetness largely without stickiness .


--------------------------------------------- Result 2 ---------------------------------------------
Positive (52%) --> Negative (97%)

consistently clever and [[suspenseful]] .

consistently clever and [[cliff-hanging]] .


--------------------------------------------- Result 3 ---------------------------------------------
Positive (89%) --> Negative (86%)

it's like a " big chill " reunion of


--------------------------------------------- Result 6 ---------------------------------------------
Positive (99%) --> Negative (85%)

fresnadillo has something serious to say about the [[ways]] in which [[extravagant]] chance can distort our perspective and throw us off the path of [[good]] sense .

fresnadillo has something serious to say about the [[manner]] in which [[exuberant]] chance can distort our perspective and throw us off the path of [[ripe]] sense .


--------------------------------------------- Result 7 ---------------------------------------------
Positive (99%) --> Negative (73%)

[[throws]] in enough clever and unexpected [[twists]] to make the formula feel fresh .

[[flip]] in enough clever and unexpected [[construction]] to make the formula feel fresh .


--------------------------------------------- Result 8 ---------------------------------------------
Positive (96%) --> Negative (93%)

weighty and ponderous but every [[bit]] as filling as the [[treat]] of the title .

weighty and ponderous but every [[bite]] as filling as the [[cover]] of the title .


--------------------------------------------- Result 9 ---------------------------------------------
Positive (84%) --> Negative (70%)

a [[real]] audience-pleaser that will strike a chord with anyone who's ever waited in a doctor's office , emergency room , hospital [[bed]] or insurance company office .

a [[material]] audience-pleaser that will strike a chord with anyone who's ever waited in a doctor's office , emergency room , hospital [[screw]] or insurance company office .


--------------------------------------------- Result 10 ---------------------------------------------
Negative (99%) --> [SKIPPED]

[Succeeded / Failed / Skipped / Total] 6 / 0 / 4 / 10: 100%|██████████| 10/10 [00:00<00:00, 17.90it/s]
```

`attack_dataset()` returns one result object per attacked example:

```
[<textattack.attack_results.skipped_attack_result.SkippedAttackResult at 0x7f74758a3e50>,
 <textattack.attack_results.successful_attack_result.SuccessfulAttackResult at 0x7f74758a3c70>,
 <textattack.attack_results.successful_attack_result.SuccessfulAttackResult at 0x7f74758a3490>,
 <textattack.attack_results.skipped_attack_result.SkippedAttackResult at 0x7f74758a3fd0>,
 <textattack.attack_results.skipped_attack_result.SkippedAttackResult at 0x7f74758a39d0>,
 <textattack.attack_results.successful_attack_result.SuccessfulAttackResult at 0x7f7475903100>,
 <textattack.attack_results.successful_attack_result.SuccessfulAttackResult at 0x7f7475284460>,
 <textattack.attack_results.successful_attack_result.SuccessfulAttackResult at 0x7f74758a3910>,
 <textattack.attack_results.successful_attack_result.SuccessfulAttackResult at 0x7f74758a3790>,
 <textattack.attack_results.skipped_attack_result.SkippedAttackResult at 0x7f7489297c40>]
```
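The list returned by `attack_dataset()` can be tallied directly instead of read by eye. A minimal sketch; `summarize_results` is a hypothetical helper, and it matches on class names so that it runs without importing TextAttack.

```python
from collections import Counter


def summarize_results(results):
    """Tally attack result objects by class name.

    Intended for the list returned by Attacker.attack_dataset(), whose items
    are classes such as SuccessfulAttackResult or SkippedAttackResult.
    """
    return Counter(type(result).__name__ for result in results)
```

Usage: `summarize_results(attacker.attack_dataset())`. For the run shown above, this would count 6 `SuccessfulAttackResult` and 4 `SkippedAttackResult` objects, matching the final progress line. Matching on names keeps the sketch dependency-free; `isinstance` checks against the classes in `textattack.attack_results` would be the more robust choice in real code.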

## Conclusion 

We trained a TensorFlow model, adapted it to TextAttack's `ModelWrapper` interface, and used the wrapped model in an attack. The same approach works for models from TensorFlow or any other library: write a wrapper whose `__call__` maps a list of strings to per-class scores, then pass it to an attack recipe.