Merge pull request #82 from QData/seq2sick
seq2sick
jxmorris12 committed May 3, 2020
2 parents a8eeb73 + 6ce0ab1 commit c7eea91
Showing 55 changed files with 686 additions and 220 deletions.
11 changes: 5 additions & 6 deletions scripts/benchmark_models.py
@@ -7,9 +7,9 @@

import textattack.models as models

-def _cb(s): return textattack.shared.utils.color(str(s), color='blue', method='stdout')
-def _cg(s): return textattack.shared.utils.color(str(s), color='green', method='stdout')
-def _cr(s): return textattack.shared.utils.color(str(s), color='red', method='stdout')
+def _cb(s): return textattack.shared.utils.color_text(str(s), color='blue', method='stdout')
+def _cg(s): return textattack.shared.utils.color_text(str(s), color='green', method='stdout')
+def _cr(s): return textattack.shared.utils.color_text(str(s), color='red', method='stdout')
def _pb(): print(_cg('-' * 60))

from collections import Counter
@@ -39,7 +39,7 @@ def test_model_on_dataset(model, dataset, batch_size=16, num_examples=100):
batch_labels = []
all_true_labels = []
all_guess_labels = []
-for i, (label, text) in enumerate(dataset):
+for i, (text, label) in enumerate(dataset):
if i >= num_examples: break
ids = model.tokenizer.encode(text)
batch_ids.append(ids)
@@ -69,10 +69,9 @@ def test_model_on_dataset(model, dataset, batch_size=16, num_examples=100):
def test_all_models(num_examples):
_pb()
for model_name in MODEL_CLASS_NAMES:
-if model_name != 'bert-mr': continue
model = eval(MODEL_CLASS_NAMES[model_name])()
dataset = DATASET_BY_MODEL[model_name]()
-print(f'\nTesting {_cr(model_name)} on {_cr(type(dataset))}...')
+print(f'Testing {_cr(model_name)} on {_cr(type(dataset))}...')
test_model_on_dataset(model, dataset, num_examples=num_examples)
_pb()
# @TODO print the grid of models/dataset names with results in a nice table :)
2 changes: 0 additions & 2 deletions scripts/run_attack.py
@@ -2,8 +2,6 @@
A command line parser to run an attack from user specifications.
"""

-
-
from run_attack_args_helper import get_args
from run_attack_parallel import run as run_parallel
from run_attack_single_threaded import run as run_single_threaded
20 changes: 15 additions & 5 deletions scripts/run_attack_args_helper.py
@@ -10,6 +10,7 @@
'alzantot': 'textattack.attack_recipes.Alzantot2018GeneticAlgorithm',
'alz-adjusted': 'textattack.attack_recipes.Alzantot2018GeneticAlgorithmAdjusted',
'deepwordbug': 'textattack.attack_recipes.Gao2018DeepWordBug',
+'seq2sick': 'textattack.attack_recipes.Cheng2018Seq2SickBlackBox',
'textfooler': 'textattack.attack_recipes.Jin2019TextFooler',
'tf-adjusted': 'textattack.attack_recipes.Jin2019TextFoolerAdjusted',
}
@@ -18,7 +19,6 @@
#
# Text classification models
#
-
# BERT models - default uncased
'bert-ag-news': 'textattack.models.classification.bert.BERTForAGNewsClassification',
'bert-imdb': 'textattack.models.classification.bert.BERTForIMDBSentimentClassification',
@@ -34,21 +34,28 @@
'lstm-imdb': 'textattack.models.classification.lstm.LSTMForIMDBSentimentClassification',
'lstm-mr': 'textattack.models.classification.lstm.LSTMForMRSentimentClassification',
'lstm-yelp-sentiment': 'textattack.models.classification.lstm.LSTMForYelpSentimentClassification',
-
#
# Textual entailment models
#
-
# BERT models
'bert-mnli': 'textattack.models.entailment.bert.BERTForMNLI',
'bert-snli': 'textattack.models.entailment.bert.BERTForSNLI',
+#
+# Translation models
+#
+'t5-en2fr': 'textattack.models.translation.t5.T5EnglishToFrench',
+'t5-en2de': 'textattack.models.translation.t5.T5EnglishToGerman',
+'t5-en2ro': 'textattack.models.translation.t5.T5EnglishToRomanian',
+#
+# Summarization models
+#
+'t5-summ': 'textattack.models.summarization.T5Summarization',
}

DATASET_BY_MODEL = {
#
# Text classification datasets
#
-
# AG News
'bert-ag-news': textattack.datasets.classification.AGNews,
'cnn-ag-news': textattack.datasets.classification.AGNews,
@@ -65,12 +72,15 @@
'bert-yelp-sentiment': textattack.datasets.classification.YelpSentiment,
'cnn-yelp-sentiment': textattack.datasets.classification.YelpSentiment,
'lstm-yelp-sentiment': textattack.datasets.classification.YelpSentiment,
-
#
# Textual entailment datasets
#
'bert-mnli': textattack.datasets.entailment.MNLI,
'bert-snli': textattack.datasets.entailment.SNLI,
+#
+# Translation datasets
+#
+'t5-en2de': textattack.datasets.translation.NewsTest2013EnglishToGerman,
}

TRANSFORMATION_CLASS_NAMES = {
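
The tables above are plain name-to-class lookup dicts. A minimal sketch of how they are resolved at attack time, mirroring the eval-based instantiation in scripts/benchmark_models.py above (model_name is an illustrative choice; the surrounding argument-parsing code is not shown in this diff):

import textattack  # the dotted paths below eval against this package

# Resolve a model name the way benchmark_models.py does: eval the dotted
# class path from the table, then call the class to instantiate it.
model_name = 't5-en2de'
model = eval(MODEL_CLASS_NAMES[model_name])()  # T5EnglishToGerman
dataset = DATASET_BY_MODEL[model_name]()       # NewsTest2013EnglishToGerman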
4 changes: 2 additions & 2 deletions scripts/run_attack_single_threaded.py
@@ -48,10 +48,10 @@ def run(args):

tokenized_text = textattack.shared.tokenized_text.TokenizedText(text, goal_function.model.tokenizer)

-result = goal_function.get_results([tokenized_text], goal_function.get_output(tokenized_text))[0]
+result = goal_function.get_result(tokenized_text, goal_function.get_output(tokenized_text))
print('Attacking...')

-result = next(attack.attack_dataset([(result.output, text)]))
+result = next(attack.attack_dataset([(text, result.output)]))
print(result.__str__(color_method='stdout'))

else:
5 changes: 3 additions & 2 deletions textattack/__init__.py
@@ -5,9 +5,10 @@
from . import attack_methods
from . import constraints
from . import datasets
+from . import goal_functions
+from . import goal_function_results
from . import loggers
from . import models
from . import shared
from . import tokenizers
-from . import goal_functions
-from . import transformations
+from . import transformations
36 changes: 20 additions & 16 deletions textattack/attack_methods/attack.py
@@ -13,9 +13,9 @@ class Attack:
An attack generates adversarial examples on text.
This is an abstract class that contains main helper functionality for
-attacks. An attack is comprised of a search method and a transformation, as
-well as one or more linguistic constraints that successful examples must
-meet.
+attacks. An attack is comprised of a search method, a goal function, and a
+transformation, as well as one or more linguistic constraints that
+successful examples must meet.
Args:
goal_function: A function for determining how well a perturbation is doing at achieving the attack's goal.
@@ -29,7 +29,7 @@ def __init__(self, goal_function, transformation, constraints=[], is_black_box=T
"""
self.goal_function = goal_function
if not self.goal_function:
-raise NameError('Cannot instantiate attack without self.goal_function for prediction scores')
+raise NameError('Cannot instantiate attack without self.goal_function for predictions')
if not hasattr(self, 'tokenizer'):
if hasattr(self.goal_function.model, 'tokenizer'):
self.tokenizer = self.goal_function.model.tokenizer
@@ -123,33 +123,33 @@ def _get_examples_from_dataset(self, dataset, num_examples=None, shuffle=False,
Gets examples from a dataset and tokenizes them.
Args:
-dataset: An iterable of (label, text) pairs
+dataset: An iterable of (text, ground_truth_output) pairs
num_examples (int): the number of examples to return
shuffle (:obj:`bool`, optional): Whether to shuffle the data
attack_n (bool): If `True`, returns `num_examples` non-skipped
examples. If `False`, returns `num_examples` total examples.
Returns:
-results (List[Tuple[Int, TokenizedText, Boolean]]): a list of
-objects containing (label, text, was_skipped)
+results (Iterable[Tuple[GoalFunctionResult, Boolean]]): a list of
+objects containing (text, ground_truth_output, was_skipped)
"""
examples = []
n = 0

if shuffle:
random.shuffle(dataset.examples)

-for true_output, text in dataset:
+for text, ground_truth_output in dataset:
tokenized_text = TokenizedText(text, self.tokenizer)
-goal_function_result = self.goal_function.get_result(tokenized_text, true_output)
+goal_function_result = self.goal_function.get_result(tokenized_text, ground_truth_output)
# We can skip examples for which the goal is already succeeded,
# unless `attack_skippable_examples` is True.
if (not attack_skippable_examples) and (goal_function_result.succeeded):
if not attack_n:
n += 1
# Store the true output on the goal function so that the
# SkippedAttackResult has the correct output, not the incorrect.
-goal_function_result.output = true_output
+goal_function_result.output = ground_truth_output
yield (goal_function_result, True)
else:
n += 1
Expand All @@ -159,13 +159,14 @@ def _get_examples_from_dataset(self, dataset, num_examples=None, shuffle=False,

def attack_dataset(self, dataset, num_examples=None, shuffle=False, attack_n=False):
"""
-Runs an attack on the given dataset and outputs the results to the console and the output file.
+Runs an attack on the given dataset and outputs the results to the
+console and the output file.
Args:
-dataset: An iterable of (label, text) pairs
+dataset: An iterable of (text, ground_truth_output) pairs
shuffle (:obj:`bool`, optional): Whether to shuffle the data. Defaults to False.
"""

examples = self._get_examples_from_dataset(dataset,
num_examples=num_examples, shuffle=shuffle, attack_n=attack_n)

@@ -203,9 +204,12 @@ def __repr__(self):
)
# self.constraints
constraints_lines = []
-for i, constraint in enumerate(self.constraints):
-constraints_lines.append(utils.add_indent(f'({i}): {constraint}', 2))
-constraints_str = utils.add_indent('\n' + '\n'.join(constraints_lines), 2)
+if len(self.constraints):
+for i, constraint in enumerate(self.constraints):
+constraints_lines.append(utils.add_indent(f'({i}): {constraint}', 2))
+constraints_str = utils.add_indent('\n' + '\n'.join(constraints_lines), 2)
+else:
+constraints_str = 'None'
lines.append(utils.add_indent(f'(constraints): {constraints_str}', 2))
# self.is_black_box
lines.append(utils.add_indent(f'(is_black_box): {self.is_black_box}', 2))
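
With this change, attack_dataset consumes any iterable of (text, ground_truth_output) pairs in the new order. A minimal sketch, assuming attack is an already-constructed Attack around a binary sentiment classifier (the example pairs are illustrative, not a real dataset):

# Illustrative (text, ground_truth_output) pairs in the new order.
custom_examples = [
    ('the movie was terrific', 1),
    ('a dull, lifeless script', 0),
]
for result in attack.attack_dataset(custom_examples, num_examples=2):
    print(result.__str__(color_method='stdout'))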
14 changes: 10 additions & 4 deletions textattack/attack_methods/greedy_word_swap_wir.py
@@ -44,7 +44,8 @@ def attack_one(self, tokenized_text, correct_output):

leave_one_texts = \
[tokenized_text.replace_word_at_index(i,self.replacement_str) for i in range(len_text)]
-leave_one_scores = np.array([result.score for result in self.goal_function.get_results(leave_one_texts, correct_output)])
+leave_one_scores = np.array([result.score for result in \
+self.goal_function.get_results(leave_one_texts, correct_output)])
index_order = (-leave_one_scores).argsort()

new_tokenized_text = None
@@ -59,7 +60,8 @@ def attack_one(self, tokenized_text, correct_output):
if len(transformed_text_candidates) == 0:
continue
num_words_changed += 1
-results = sorted(self.goal_function.get_results(transformed_text_candidates, correct_output), key=lambda x: -x.score)
+results = sorted(self.goal_function.get_results(transformed_text_candidates, correct_output),
+key=lambda x: -x.score)
# Skip swaps which don't improve the score
if results[0].score > cur_score:
cur_score = results[0].score
@@ -89,6 +91,10 @@ def attack_one(self, tokenized_text, correct_output):
original_result,
best_result
)
-tokenized_text = results[0].tokenized_text
+else:
+tokenized_text = results[0].tokenized_text

-return FailedAttackResult(original_result, results[0])
+if len(results):
+return FailedAttackResult(original_result, results[0])
+else:
+return FailedAttackResult(original_result)
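
The leave-one-out scoring above is the "Word Importance Ranking" this class is named for: each word is masked with a replacement token, the goal function scores every masked variant, and indices are then attacked in descending score order. A standalone sketch of just that ranking step, with score_fn standing in for the goal function (not textattack's actual API):

import numpy as np

def word_importance_order(words, score_fn, replacement_str='[UNK]'):
    # Score each leave-one-out variant; a higher score means masking that
    # word moved the model further toward the attack goal.
    leave_one_scores = np.array([
        score_fn(words[:i] + [replacement_str] + words[i + 1:])
        for i in range(len(words))
    ])
    # Attack the highest-impact indices first, as GreedyWordSwapWIR does.
    return (-leave_one_scores).argsort()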
1 change: 1 addition & 0 deletions textattack/attack_recipes/__init__.py
@@ -1,5 +1,6 @@
from .alzantot_2018_genetic_algorithm import Alzantot2018GeneticAlgorithm
from .alzantot_2018_genetic_algorithm_adjusted import Alzantot2018GeneticAlgorithmAdjusted
+from .cheng_2018_seq2sick_blackbox import Cheng2018Seq2SickBlackBox
from .jin_2019_textfooler import Jin2019TextFooler
from .jin_2019_textfooler_adjusted import Jin2019TextFoolerAdjusted
from .gao_2018_deepwordbug import Gao2018DeepWordBug
38 changes: 38 additions & 0 deletions textattack/attack_recipes/cheng_2018_seq2sick_blackbox.py
@@ -0,0 +1,38 @@
"""
Cheng, Minhao, et al.
Seq2Sick: Evaluating the Robustness of Sequence-to-Sequence Models with
Adversarial Examples
ArXiv, abs/1803.01128.
This is a greedy re-implementation of the seq2sick attack method. It does
not use gradient descent.
"""

from textattack.attack_methods import GreedyWordSwapWIR
from textattack.constraints.overlap import LevenshteinEditDistance
from textattack.goal_functions import NonOverlappingOutput
from textattack.transformations import WordSwapEmbedding

def Cheng2018Seq2SickBlackBox(model, goal_function='non_overlapping'):
#
# Goal is non-overlapping output.
#
goal_function = NonOverlappingOutput(model)
# @TODO implement transformation / search method just like they do in
# seq2sick.
transformation = WordSwapEmbedding(max_candidates=50)
#
# In these experiments, we hold the maximum difference
# on edit distance (ε) to a constant 30 for each sample.
#
#
# Greedily swap words with "Word Importance Ranking".
#
attack = GreedyWordSwapWIR(goal_function, transformation=transformation,
constraints=[], max_depth=10)

return attack
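
A sketch of driving the new recipe end to end, pairing it with the t5-en2de model and dataset registered in run_attack_args_helper.py above (usage inferred from the recipe's signature; not shown in this diff):

import textattack
from textattack.attack_recipes import Cheng2018Seq2SickBlackBox

model = textattack.models.translation.t5.T5EnglishToGerman()
attack = Cheng2018Seq2SickBlackBox(model)
dataset = textattack.datasets.translation.NewsTest2013EnglishToGerman()
for result in attack.attack_dataset(dataset, num_examples=5):
    print(result.__str__(color_method='stdout'))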
2 changes: 1 addition & 1 deletion textattack/attack_results/attack_result.py
@@ -1,4 +1,4 @@
-from textattack.goal_functions import GoalFunctionResult
+from textattack.goal_function_results import GoalFunctionResult
from textattack.shared import utils

class AttackResult:
3 changes: 2 additions & 1 deletion textattack/datasets/__init__.py
@@ -1,4 +1,5 @@
from .dataset import TextAttackDataset

from . import classification
-from . import entailment
+from . import entailment
+from . import translation
6 changes: 3 additions & 3 deletions textattack/datasets/classification/ag_news.py
@@ -1,7 +1,7 @@
from textattack.shared import utils
-from textattack.datasets import TextAttackDataset
+from .classification_dataset import ClassificationDataset

-class AGNews(TextAttackDataset):
+class AGNews(ClassificationDataset):
"""
Loads samples from the AG News Dataset.
@@ -35,4 +35,4 @@ class AGNews(TextAttackDataset):
DATA_PATH = 'datasets/classification/ag_news.txt'
def __init__(self, offset=0):
""" Loads a full dataset from disk. """
-self._load_text_file(AGNews.DATA_PATH, offset=offset)
+self._load_classification_text_file(AGNews.DATA_PATH, offset=offset)
10 changes: 10 additions & 0 deletions textattack/datasets/classification/classification_dataset.py
@@ -0,0 +1,10 @@
from textattack.datasets import TextAttackDataset

class ClassificationDataset(TextAttackDataset):
""" A generic class for loading classification data
"""
def _process_example_from_file(self, raw_line):
tokens = raw_line.strip().split()
label = int(tokens[0])
text = ' '.join(tokens[1:])
return (text, label)
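
_process_example_from_file assumes one example per line: an integer label, whitespace, then the raw text. A worked example of the parse on a hypothetical input line; note the (text, label) return order, matching this commit's switch to (text, ground_truth_output) everywhere:

raw_line = '1 a gripping , well-acted thriller'
tokens = raw_line.strip().split()
label = int(tokens[0])        # 1
text = ' '.join(tokens[1:])   # 'a gripping , well-acted thriller'
assert (text, label) == ('a gripping , well-acted thriller', 1)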
6 changes: 3 additions & 3 deletions textattack/datasets/classification/imdb_sentiment.py
@@ -1,6 +1,6 @@
-from textattack.datasets import TextAttackDataset
+from .classification_dataset import ClassificationDataset

-class IMDBSentiment(TextAttackDataset):
+class IMDBSentiment(ClassificationDataset):
"""
Loads samples from the IMDB Movie Review Sentiment dataset.
@@ -16,4 +16,4 @@ class IMDBSentiment(TextAttackDataset):
DATA_PATH = 'datasets/classification/imdb.txt'
def __init__(self, offset=0):
""" Loads a full dataset from disk. """
-self._load_text_file(IMDBSentiment.DATA_PATH, offset=offset)
+self._load_classification_text_file(IMDBSentiment.DATA_PATH, offset=offset)
6 changes: 3 additions & 3 deletions textattack/datasets/classification/kaggle_fake_news.py
@@ -1,6 +1,6 @@
-from textattack.datasets import TextAttackDataset
+from .classification_dataset import ClassificationDataset

-class KaggleFakeNews(TextAttackDataset):
+class KaggleFakeNews(ClassificationDataset):
"""
Loads samples from the Kaggle Fake News dataset. https://www.kaggle.com/mrisdal/fake-news
@@ -16,4 +16,4 @@ class KaggleFakeNews(TextAttackDataset):
DATA_PATH = 'datasets/classification/fake'
def __init__(self, offset=0):
""" Loads a full dataset from disk. """
-self._load_text_file(KaggleFakeNews.DATA_PATH, offset=offset)
+self._load_classification_text_file(KaggleFakeNews.DATA_PATH, offset=offset)
6 changes: 3 additions & 3 deletions textattack/datasets/classification/movie_review_sentiment.py
@@ -1,6 +1,6 @@
-from textattack.datasets import TextAttackDataset
+from .classification_dataset import ClassificationDataset

-class MovieReviewSentiment(TextAttackDataset):
+class MovieReviewSentiment(ClassificationDataset):
"""
Loads samples from the Movie Review Dataset. The "MR" dataset is comprised
of sentence-level sentiment classification on positive and negative movie
@@ -18,4 +18,4 @@ class MovieReviewSentiment(TextAttackDataset):
DATA_PATH = 'datasets/classification/mr.txt'
def __init__(self, offset=0):
""" Loads a full dataset from disk. """
-self._load_text_file(MovieReviewSentiment.DATA_PATH, offset=offset)
+self._load_classification_text_file(MovieReviewSentiment.DATA_PATH, offset=offset)