Skip to content

Commit

Permalink
Merge pull request #100 from QData/refactor
Browse files Browse the repository at this point in the history
Refactor
  • Loading branch information
uvafan committed May 19, 2020
2 parents 26dbdb3 + 3293d31 commit e7741da
Show file tree
Hide file tree
Showing 62 changed files with 890 additions and 564 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,6 @@ wandb/

# checkpoints
checkpoints/

# vim
*.swp
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ A `Transformation` takes as input a `TokenizedText` and returns a list of possib

### Search Methods

A search method is currently implemented in an extension of the `Attack` class, through implementing the `attack_one` method. The `get_transformations` function takes as input a `TokenizedText` object and outputs a list of possible transformations filtered by meeting all of the attack’s constraints. A search consists of successive calls to `get_transformations` until the search succeeds or is exhausted.
A search method takes as input an initial `goal_function` result and returns a final `goal_function` result. The search is given access to the `get_transformations` function, which takes as input a `TokenizedText` object and outputs a list of possible transformations filtered by meeting all of the attack’s constraints. A search consists of successive calls to `get_transformations` until the search succeeds (determined through `get_goal_results`) or is exhausted.

## Contributing to TextAttack

Expand Down

Large diffs are not rendered by default.

81 changes: 46 additions & 35 deletions docs/examples/2_Constraints.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -52,19 +52,19 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[nltk_data] Downloading package punkt to /u/jm8wx/nltk_data...\n",
"[nltk_data] Downloading package punkt to /u/edl9cy/nltk_data...\n",
"[nltk_data] Package punkt is already up-to-date!\n",
"[nltk_data] Downloading package maxent_ne_chunker to\n",
"[nltk_data] /u/jm8wx/nltk_data...\n",
"[nltk_data] /u/edl9cy/nltk_data...\n",
"[nltk_data] Package maxent_ne_chunker is already up-to-date!\n",
"[nltk_data] Downloading package words to /u/jm8wx/nltk_data...\n",
"[nltk_data] Downloading package words to /u/edl9cy/nltk_data...\n",
"[nltk_data] Package words is already up-to-date!\n"
]
},
Expand All @@ -74,7 +74,7 @@
"True"
]
},
"execution_count": 7,
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -97,7 +97,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 2,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -151,7 +151,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 3,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -188,7 +188,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -214,7 +214,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 5,
"metadata": {},
"outputs": [
{
Expand All @@ -236,7 +236,7 @@
" ('.', '.')]"
]
},
"execution_count": 11,
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
Expand Down Expand Up @@ -264,7 +264,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -273,7 +273,7 @@
"class NamedEntityConstraint(Constraint):\n",
" \"\"\" A constraint that ensures `x_adv` only substitutes named entities from `x` with other named entities.\n",
" \"\"\"\n",
" def __call__(self, x, x_adv, original_text=None):\n",
" def _check_constraint(self, x, x_adv, original_text=None):\n",
" x_entities = get_entities(x.text)\n",
" x_adv_entities = get_entities(x_adv.text)\n",
" # If there aren't named entities, let's return False (the attack\n",
Expand Down Expand Up @@ -316,17 +316,13 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[34;1mtextattack\u001b[0m: Downloading https://textattack.s3.amazonaws.com/models/classification/lstm/yelp_polarity.\n",
"100%|██████████| 297M/297M [00:06<00:00, 48.3MB/s] \n",
"\u001b[34;1mtextattack\u001b[0m: Unzipping file path_to_zip_file to unzipped_folder_path.\n",
"\u001b[34;1mtextattack\u001b[0m: Successfully saved models/classification/lstm/yelp_polarity to cache.\n",
"\u001b[34;1mtextattack\u001b[0m: Goal function <class 'textattack.goal_functions.classification.untargeted_classification.UntargetedClassification'> matches model LSTMForYelpSentimentClassification.\n"
]
}
Expand All @@ -344,43 +340,56 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"GreedyWordSwap(\n",
"Attack(\n",
" (search_method): GreedySearch\n",
" (goal_function): UntargetedClassification\n",
" (transformation): WordSwapEmbedding(\n",
" (max_candidates): 15\n",
" (embedding_type): paragramcf\n",
" (replace_stopwords): False\n",
" )\n",
" (constraints): \n",
" (0): NamedEntityConstraint\n",
" (1): RepeatModification\n",
" (2): StopwordModification\n",
" (is_black_box): True\n",
")\n"
]
}
],
"source": [
"from textattack.transformations import WordSwapEmbedding\n",
"from textattack.search_methods import GreedyWordSwap\n",
"from textattack.search_methods import GreedySearch\n",
"from textattack.constraints.pre_transformation import RepeatModification, StopwordModification\n",
"from textattack.shared import Attack\n",
"\n",
"# We're going to the `WordSwapEmbedding` transformation. Using the default settings, this\n",
"# will try substituting words with their neighbors in the counter-fitted embedding space. \n",
"transformation = WordSwapEmbedding(max_candidates=15) \n",
"# Now, let's make the attack using these parameters. And add one constraint: our \n",
"# custom NamedEntityConstraint.\n",
"attack = GreedyWordSwap(goal_function, transformation, constraints=[NamedEntityConstraint()])\n",
"\n",
"# We'll use the greedy search method again\n",
"search_method = GreedySearch()\n",
"\n",
"# Our constraints will be the same as Tutorial 1, plus the named entity constraint\n",
"constraints = [RepeatModification(),\n",
" StopwordModification(),\n",
" NamedEntityConstraint()]\n",
"\n",
"# Now, let's make the attack using these parameters. \n",
"attack = Attack(goal_function, constraints, transformation, search_method)\n",
"\n",
"print(attack)"
]
},
{
"cell_type": "code",
"execution_count": 41,
"execution_count": 9,
"metadata": {},
"outputs": [
{
Expand All @@ -389,7 +398,7 @@
"True"
]
},
"execution_count": 41,
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -409,22 +418,24 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"from tqdm import tqdm # tqdm provides us a nice progress bar.\n",
"from textattack.loggers import CSVLogger # tracks a dataframe for us.\n",
"from textattack.attack_results import FailedAttackResult, SkippedAttackResult\n",
"from textattack.attack_results import SuccessfulAttackResult\n",
"\n",
"results_iterable = attack.attack_dataset(YelpSentiment(), attack_n=True)\n",
"logger = CSVLogger(color_method='html')\n",
"\n",
"num_successes = 0\n",
"while num_successes < 10:\n",
" result = next(results_iterable)\n",
" if not (isinstance(result, FailedAttackResult) or isinstance(result, SkippedAttackResult)):\n",
" if isinstance(result, SuccessfulAttackResult):\n",
" logger.log_attack_result(result)\n",
" num_successes += 1"
" num_successes += 1\n",
" print(num_successes)"
]
},
{
Expand All @@ -446,8 +457,8 @@
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>passage_1</th>\n",
" <th>passage_2</th>\n",
" <th>original_text</th>\n",
" <th>perturbed_text</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
Expand Down Expand Up @@ -534,9 +545,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "torch",
"display_name": "Python 3",
"language": "python",
"name": "build_central"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
Expand Down
23 changes: 17 additions & 6 deletions local_tests/command_line_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,16 @@ def register_test(command, name=None, output_file=None, desc=None):
## BEGIN TESTS ##
#######################################


#
# test: run_attack --interactive
#
register_test(('printf "All that glitters is not gold\nq\n"',
'python -m textattack --recipe textfooler --model bert-imdb --interactive'),
name='interactive_mode',
output_file='local_tests/sample_outputs/interactive_mode.txt',
desc='Runs textfooler attack on BERT trained on IMDB using interactive mode')

#
# test: run_attack_parallel textfooler attack on 10 samples from BERT MR
# (takes about 81s)
Expand All @@ -29,10 +39,10 @@ def register_test(command, name=None, output_file=None, desc=None):
# test: run_attack_parallel textfooler attack on 10 samples from BERT SNLI
# (takes about 51s)
#
register_test('python -m textattack --model bert-snli --recipe textfooler --num-examples 10',
register_test('python -m textattack --model bert-snli --recipe deepwordbug --num-examples 10',
name='run_attack_textfooler_bert_snli_10',
output_file='local_tests/sample_outputs/run_attack_textfooler_bert_snli_10.txt',
desc='Runs attack using TextFooler recipe on BERT using 10 examples from the SNLI dataset')
output_file='local_tests/sample_outputs/run_attack_deepwordbug_bert_snli_10.txt',
desc='Runs attack using DeepWordBug recipe on BERT using 10 examples from the SNLI dataset')

#
# test: run_attack deepwordbug attack on 10 samples from LSTM MR
Expand All @@ -51,7 +61,7 @@ def register_test(command, name=None, output_file=None, desc=None):
#
register_test(('python -m textattack --attack-n --goal-function targeted-classification:target_class=2 '
'--enable-csv --model bert-mnli --num-examples 4 --transformation word-swap-wordnet '
'--constraints lang-tool --attack beam-search:beam_width=2'),
'--constraints lang-tool repeat stopword --search beam-search:beam_width=2'),
name='run_attack_targeted2_bertmnli_wordnet_beamwidth_2_enablecsv_attackn',
output_file='local_tests/sample_outputs/run_attack_targetedclassification2_wordnet_langtool_enable_csv_beamsearch2_attack_n_4.txt',
desc=('Runs attack using targeted classification on class 2 on BERT MNLI with'
Expand All @@ -67,10 +77,11 @@ def register_test(command, name=None, output_file=None, desc=None):
#
register_test(('python -m textattack --attack-n --goal-function non-overlapping-output '
'--model t5-en2de --num-examples 6 --transformation word-swap-random-char-substitution '
'--constraints edit-distance:12 words-perturbed:max_percent=0.75 --attack greedy-word'),
'--constraints edit-distance:12 max-words-perturbed:max_percent=0.75 repeat stopword '
'--search greedy'),
name='run_attack_nonoverlapping_t5en2de_randomcharsub_editdistance_wordsperturbed_greedyword',
output_file='local_tests/sample_outputs/run_attack_nonoverlapping_t5ende_editdistance_bleu.txt',
desc=('Runs attack using targeted classification on class 2 on BERT MNLI with'
'enable_csv and attack_n set, using the WordNet transformation and beam '
'search with beam width 2, using language tool constraint, on 10 samples')
)
)
41 changes: 41 additions & 0 deletions local_tests/sample_outputs/interactive_mode.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
Attack(
(search_method): GreedyWordSwapWIR(
(wir_method): unk
)
(goal_function): UntargetedClassification
(transformation): WordSwapEmbedding(
(max_candidates): 50
(embedding_type): paragramcf
)
(constraints):
(0): WordEmbeddingDistance(
(embedding_type): paragramcf
(min_cos_sim): 0.5
(cased): False
(include_unknown_words): True
)
(1): PartOfSpeech(
(tagset): universal
(allow_verb_noun_swap): True
)
(2): UniversalSentenceEncoder(
(metric): angular
(threshold): 0.904458599
(compare_with_original): False
(window_size): 15
(skip_text_shorter_than_window): True
)
(3): RepeatModification
(4): StopwordModification
(is_black_box): True
)

Load time: /.*/s
Running in interactive mode
----------------------------
Enter a sentence to attack or "q" to quit:
Attacking...
1-->0
All that glitters is not gold
All that glisten is not gold
Enter a sentence to attack or "q" to quit:

0 comments on commit e7741da

Please sign in to comment.