-
Notifications
You must be signed in to change notification settings - Fork 375
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #89 from QData/augmentation
augmentation in TextAttack
- Loading branch information
Showing
7 changed files
with
136 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .augmenter import Augmenter | ||
from .recipes import WordNetAugmenter, EmbeddingAugmenter, CharSwapAugmenter |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
from textattack.shared.tokenized_text import TokenizedText | ||
|
||
class Augmenter:
    """Performs data augmentation using a TextAttack transformation.

    Produces the transformed versions of an input string that pass every
    configured constraint.

    Args:
        transformation (textattack.Transformation): callable that suggests
            new texts from a tokenized input.
        constraints (list(textattack.Constraint), optional): constraints
            that each transformed text must meet. Defaults to no
            constraints.
    """

    def __init__(self, transformation, constraints=None):
        self.transformation = transformation
        # Build a fresh list per instance: a mutable `[]` default would be
        # shared by every Augmenter created without explicit constraints.
        self.constraints = [] if constraints is None else constraints

    def _filter_transformations(self, tokenized_text, transformations):
        """Return the members of `transformations` that pass `self.constraints`.

        Constraints are applied in order through their `call_many` method;
        filtering stops early once no candidates remain.

        Args:
            tokenized_text: the original text, passed to each constraint both
                as the reference text and as `original_text`.
            transformations: candidate `TokenizedText` objects to filter.
        """
        for constraint in self.constraints:
            if len(transformations) == 0:
                break
            transformations = constraint.call_many(
                tokenized_text, transformations, original_text=tokenized_text
            )
        return transformations

    def augment(self, text):
        """Return all augmentations of `text` produced by `self.transformation`.

        Args:
            text (str): the string to augment.

        Returns:
            list: the `clean_text()` of every transformation that passed the
            constraints.
        """
        # No model is queried here, so a tokenizer that does nothing suffices.
        tokenized_text = TokenizedText(text, DummyTokenizer())
        # Get potential transformations for the text.
        candidates = self.transformation(tokenized_text)
        # Drop the candidates that don't satisfy the constraints.
        candidates = self._filter_transformations(tokenized_text, candidates)
        return [candidate.clean_text() for candidate in candidates]

    def augment_many(self, text_list):
        """Return the augmentations of each string in `text_list`.

        Args:
            text_list: iterable of strings to augment.

        Returns:
            list: one list of augmentations per input string, in input order.
        """
        return [self.augment(text) for text in text_list]
|
||
class DummyTokenizer:
    """Placeholder tokenizer that encodes everything to an empty list.

    Data augmentation applies a transformation without querying a model,
    so real tokenization is unnecessary. An instance of this class is
    handed to `TokenizedText` in place of a real tokenizer.
    """

    def encode(self, _):
        """Ignore the input and return a new empty list."""
        return list()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
from . import Augmenter | ||
|
||
import textattack | ||
|
||
class WordNetAugmenter(Augmenter):
    """Augmenter that replaces words with synonyms from the WordNet thesaurus."""

    def __init__(self):
        # Imported inside the constructor, so the transformation module is
        # only loaded when this augmenter is instantiated.
        from textattack.transformations.black_box import WordSwapWordNet

        # No constraints: every WordNet swap is kept.
        super().__init__(WordSwapWordNet(), constraints=[])
|
||
|
||
class EmbeddingAugmenter(Augmenter):
    """Augmenter that swaps words using their word embeddings."""

    def __init__(self):
        from textattack.transformations.black_box import WordSwapEmbedding

        embedding_swap = WordSwapEmbedding(
            max_candidates=50,
            embedding_type='paragramcf',
        )

        from textattack.constraints.semantics import WordEmbeddingDistance

        # Single constraint, configured with a minimum cosine similarity of 0.8.
        similarity_constraint = WordEmbeddingDistance(min_cos_sim=0.8)

        super().__init__(embedding_swap, constraints=[similarity_constraint])
|
||
|
||
class CharSwapAugmenter(Augmenter):
    """Augments words by swapping characters out for other characters.

    Combines four character-level word transformations (neighboring swap,
    random substitution, random deletion, random insertion) into a single
    composite transformation, with no constraints.
    """

    def __init__(self):
        from textattack.transformations import CompositeTransformation
        # Each name is imported exactly once; the original listed
        # `WordSwapNeighboringCharacterSwap` twice.
        from textattack.transformations.black_box import (
            WordSwapNeighboringCharacterSwap,
            WordSwapRandomCharacterDeletion,
            WordSwapRandomCharacterInsertion,
            WordSwapRandomCharacterSubstitution,
        )

        transformation = CompositeTransformation([
            # (1) Swap: Swap two adjacent letters in the word.
            WordSwapNeighboringCharacterSwap(),
            # (2) Substitution: Substitute a letter in the word with a random letter.
            WordSwapRandomCharacterSubstitution(),
            # (3) Deletion: Delete a random letter from the word.
            WordSwapRandomCharacterDeletion(),
            # (4) Insertion: Insert a random letter in the word.
            WordSwapRandomCharacterInsertion(),
        ])
        super().__init__(transformation, constraints=[])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters