Skip to content

Commit

Permalink
Merge pull request #89 from QData/augmentation
Browse files Browse the repository at this point in the history
augmentation in TextAttack
  • Loading branch information
jxmorris12 committed May 8, 2020
2 parents d072fa7 + d5df80b commit 983498b
Show file tree
Hide file tree
Showing 7 changed files with 136 additions and 3 deletions.
22 changes: 21 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ TextAttack provides pretrained models and datasets for user convenience. By defa

## Usage

### Basic Usage
### Running Attacks

The [`examples/`](examples/) folder contains notebooks walking through examples of basic usage of TextAttack, including building a custom transformation and a custom constraint.

Expand All @@ -71,6 +71,26 @@ The first are for classification and entailment attacks:
The final recipe is for translation attacks:
- **seq2sick**: Greedy attack with the goal of changing every word in the output translation. Currently implemented as black-box, with plans to change to white-box as done in the paper (["Seq2Sick: Evaluating the Robustness of Sequence-to-Sequence Models with Adversarial Examples"](https://arxiv.org/abs/1803.01128)).

### Augmenting Text

Many of the components of TextAttack are useful for data augmentation. The `textattack.augmentation.Augmenter` class
uses a transformation and a list of constraints to augment data. We also offer three built-in recipes
for data augmentation:
- `textattack.augmentation.WordNetAugmenter` augments text by replacing words with WordNet synonyms
- `textattack.augmentation.EmbeddingAugmenter` augments text by replacing words with neighbors in the counter-fitted embedding space, with a constraint to ensure their cosine similarity is at least 0.8
- `textattack.augmentation.CharSwapAugmenter` augments text by substituting, deleting, and inserting characters, and by swapping adjacent characters

All `Augmenter` objects implement `augment` and `augment_many` to generate augmentations
of a string or a list of strings. Here's an example of how to use the `EmbeddingAugmenter`:

```
>>> from textattack.augmentation import EmbeddingAugmenter
>>> augmenter = EmbeddingAugmenter()
>>> s = 'What I cannot create, I do not understand.'
>>> augmenter.augment(s)
['What I notable create, I do not understand.', 'What I significant create, I do not understand.', 'What I cannot engender, I do not understand.', 'What I cannot creating, I do not understand.', 'What I cannot creations, I do not understand.', 'What I cannot create, I do not comprehend.', 'What I cannot create, I do not fathom.', 'What I cannot create, I do not understanding.', 'What I cannot create, I do not understands.', 'What I cannot create, I do not understood.', 'What I cannot create, I do not realise.']
```

## Design

### TokenizedText
Expand Down
17 changes: 16 additions & 1 deletion local_tests/python_function_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,19 @@ def import_textattack():

register_test(import_textattack, name='import textattack',
output_file='local_tests/sample_outputs/empty_file.txt',
desc='Makes sure the textattack module can be imported')
desc='Makes sure the textattack module can be imported')
#
# test: import augmenter
#
def use_embedding_augmenter():
    """ Builds an `EmbeddingAugmenter`, augments one fixed sentence, and
        asserts that a known augmentation appears among the results.
    """
    from textattack.augmentation import EmbeddingAugmenter

    original = 'There is nothing either good or bad, but thinking makes it so.'
    candidates = EmbeddingAugmenter().augment(original)
    expected = 'There is nothing either good or unfavourable, but thinking makes it so.'
    assert expected in candidates


# Register the EmbeddingAugmenter check with the local test runner.
register_test(
    use_embedding_augmenter,
    name='use EmbeddingAugmenter',
    output_file='local_tests/sample_outputs/empty_file.txt',
    desc='Imports EmbeddingAugmenter and augments a single sentence',
)
1 change: 1 addition & 0 deletions textattack/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from . import attack_recipes
from . import attack_results
from . import attack_methods
from . import augmentation
from . import constraints
from . import datasets
from . import goal_functions
Expand Down
2 changes: 2 additions & 0 deletions textattack/augmentation/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .augmenter import Augmenter
from .recipes import WordNetAugmenter, EmbeddingAugmenter, CharSwapAugmenter
50 changes: 50 additions & 0 deletions textattack/augmentation/augmenter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from textattack.shared.tokenized_text import TokenizedText

class Augmenter:
    """ A class for performing data augmentation using TextAttack.

        Returns all possible transformations for a given string.

        Args:
            transformation (textattack.Transformation): the transformation
                that suggests new texts from an input.
            constraints (list(textattack.Constraint)): constraints
                that each transformation must meet. When omitted, no
                constraints are applied.
    """
    def __init__(self, transformation, constraints=None):
        self.transformation = transformation
        # A literal `[]` default would be a single list object shared by
        # every instance created without constraints; build a fresh one.
        self.constraints = [] if constraints is None else constraints

    def _filter_transformations(self, tokenized_text, transformations):
        """ Filters a list of `TokenizedText` objects to include only the ones
            that pass `self.constraints`.

            Each constraint's `call_many` is applied in order to the survivors
            of the previous one, with `tokenized_text` passed as both the
            reference text and `original_text`.
        """
        for constraint in self.constraints:
            # Nothing left to filter; skip the remaining constraints.
            if not transformations:
                break
            transformations = constraint.call_many(
                tokenized_text, transformations, original_text=tokenized_text
            )
        return transformations

    def augment(self, text):
        """ Returns all possible augmentations of `text` according to
            `self.transformation`, as a list of strings.
        """
        tokenized_text = TokenizedText(text, DummyTokenizer())
        # Get potential transformations for text.
        transformations = self.transformation(tokenized_text)
        # Filter out transformations that don't match the constraints.
        transformations = self._filter_transformations(tokenized_text, transformations)
        return [t.clean_text() for t in transformations]

    def augment_many(self, text_list):
        """ Returns all possible augmentations of a list of strings according to
            `self.transformation`: one list of augmentations per input string,
            in input order.
        """
        return [self.augment(text) for text in text_list]

class DummyTokenizer:
    """ Placeholder tokenizer handed to `TokenizedText` during augmentation.

        Data augmentation applies a transformation without querying a model,
        so real tokenization is unnecessary; this class stands in for one.
    """
    def encode(self, _):
        """ Ignores its argument and returns a new empty list. """
        return list()
45 changes: 45 additions & 0 deletions textattack/augmentation/recipes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from . import Augmenter

import textattack

class WordNetAugmenter(Augmenter):
    """ Augmenter recipe that replaces words with synonyms from the WordNet
        thesaurus. Uses a `WordSwapWordNet` transformation and no constraints.
    """
    def __init__(self):
        # Imported inside the constructor rather than at module level.
        from textattack.transformations.black_box import WordSwapWordNet

        super().__init__(WordSwapWordNet(), constraints=[])


class EmbeddingAugmenter(Augmenter):
    """ Augmenter recipe that swaps words using word embeddings.

        Uses a `WordSwapEmbedding` transformation ('paragramcf' embeddings,
        up to 50 candidates) and a `WordEmbeddingDistance` constraint with
        a minimum cosine similarity of 0.8.
    """
    def __init__(self):
        from textattack.transformations.black_box import WordSwapEmbedding
        word_swap = WordSwapEmbedding(embedding_type='paragramcf', max_candidates=50)

        from textattack.constraints.semantics import WordEmbeddingDistance
        similarity_constraint = WordEmbeddingDistance(min_cos_sim=0.8)

        super().__init__(word_swap, constraints=[similarity_constraint])


class CharSwapAugmenter(Augmenter):
    """ Augments words by swapping characters out for other characters.

        Combines four character-level word swaps (neighboring swap,
        substitution, deletion, insertion) in a `CompositeTransformation`,
        with no constraints.
    """
    def __init__(self):
        from textattack.transformations import CompositeTransformation
        # Each name is imported exactly once (the original listed
        # WordSwapNeighboringCharacterSwap twice).
        from textattack.transformations.black_box import (
            WordSwapNeighboringCharacterSwap,
            WordSwapRandomCharacterDeletion,
            WordSwapRandomCharacterInsertion,
            WordSwapRandomCharacterSubstitution,
        )
        transformation = CompositeTransformation([
            # (1) Swap: Swap two adjacent letters in the word.
            WordSwapNeighboringCharacterSwap(),
            # (2) Substitution: Substitute a letter in the word with a random letter.
            WordSwapRandomCharacterSubstitution(),
            # (3) Deletion: Delete a random letter from the word.
            WordSwapRandomCharacterDeletion(),
            # (4) Insertion: Insert a random letter in the word.
            WordSwapRandomCharacterInsertion()
        ])
        super().__init__(transformation, constraints=[])
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def __init__(self, embedding_type='paragramcf', include_unknown_words=True,
self.cos_sim_mat = {}


def call_many(self, x, x_adv_list, original_word=None):
def call_many(self, x, x_adv_list, original_text=None):
    """ Returns each `x_adv` from `x_adv_list` where `C(x,x_adv)` is True,
        keeping the input order. `original_text` is accepted but not used
        by this method.
    """
    accepted = []
    for candidate in x_adv_list:
        if self(x, candidate):
            accepted.append(candidate)
    return accepted
Expand Down

0 comments on commit 983498b

Please sign in to comment.