Skip to content

Commit

Permalink
Merge pull request #534 from QData/back-translation-transformation
Browse files Browse the repository at this point in the history
Back translation transformation
  • Loading branch information
qiyanjun committed Oct 15, 2021
2 parents d483b17 + 830de53 commit c5c10f5
Show file tree
Hide file tree
Showing 9 changed files with 178 additions and 5 deletions.
1 change: 1 addition & 0 deletions tests/sample_outputs/list_augmentation_recipes.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
back_trans (textattack.augmentation.BackTranslationAugmenter)
charswap (textattack.augmentation.CharSwapAugmenter)
checklist (textattack.augmentation.CheckListAugmenter)
clare (textattack.augmentation.CLAREAugmenter)
Expand Down
1 change: 1 addition & 0 deletions textattack/augment_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
"eda": "textattack.augmentation.EasyDataAugmenter",
"checklist": "textattack.augmentation.CheckListAugmenter",
"clare": "textattack.augmentation.CLAREAugmenter",
"back_trans": "textattack.augmentation.BackTranslationAugmenter",
}


Expand Down
1 change: 1 addition & 0 deletions textattack/augmentation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,5 @@
CheckListAugmenter,
DeletionAugmenter,
CLAREAugmenter,
BackTranslationAugmenter,
)
13 changes: 13 additions & 0 deletions textattack/augmentation/recipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,3 +249,16 @@ def __init__(
constraints = DEFAULT_CONSTRAINTS + [use_constraint]

super().__init__(transformation, constraints=constraints, **kwargs)


class BackTranslationAugmenter(Augmenter):
"""Sentence level augmentation that uses MarianMTModel to back-translate.
https://huggingface.co/transformers/model_doc/marian.html
"""

def __init__(self, **kwargs):
from textattack.transformations.sentence_transformations import BackTranslation

transformation = BackTranslation(chained_back_translation=5)
super().__init__(transformation, **kwargs)
3 changes: 2 additions & 1 deletion textattack/commands/augment_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ def run(self, args):
"\tAugmentation recipe name ('r' to see available recipes): "
)
if recipe == "r":
print("\n\twordnet, embedding, charswap, eda, checklist\n")
recipe_display = " ".join(AUGMENTATION_RECIPE_NAMES.keys())
print(f"\n\t{recipe_display}\n")
args.recipe = input("\tAugmentation recipe name: ")
else:
args.recipe = recipe
Expand Down
9 changes: 5 additions & 4 deletions textattack/shared/utils/install.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ def s3_url(uri):


def download_from_s3(folder_name, skip_if_cached=True):
"""Folder name will be saved as `<cache_dir>/textattack/<folder_name>`. If it
doesn't exist on disk, the zip file will be downloaded and extracted.
"""Folder name will be saved as `<cache_dir>/textattack/<folder_name>`. If
it doesn't exist on disk, the zip file will be downloaded and extracted.
Args:
folder_name (str): path to folder or file in cache
Expand Down Expand Up @@ -68,8 +68,9 @@ def download_from_s3(folder_name, skip_if_cached=True):


def download_from_url(url, save_path, skip_if_cached=True):
"""Downloaded file will be saved under `<cache_dir>/textattack/<save_path>`. If it
doesn't exist on disk, the zip file will be downloaded and extracted.
"""Downloaded file will be saved under
`<cache_dir>/textattack/<save_path>`. If it doesn't exist on disk, the zip
file will be downloaded and extracted.
Args:
url (str): URL path from which to download.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .sentence_transformation import SentenceTransformation
from .back_translation import BackTranslation
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
import random

from transformers import MarianMTModel, MarianTokenizer

from textattack.shared import AttackedText

from .sentence_transformation import SentenceTransformation


class BackTranslation(SentenceTransformation):
"""A type of sentence level transformation that takes in a text input,
translates it into target language and translates it back to source
language.
letters_to_insert (string): letters allowed for insertion into words
(used by some char-based transformations)
src_lang (string): source language
target_lang (string): target language, for the list of supported language check bottom of this page
src_model: translation model from huggingface that translates from source language to target language
target_model: translation model from huggingface that translates from target language to source language
chained_back_translation: run back translation in a chain for more perturbation (for example, en-es-en-fr-en)
"""

def __init__(
self,
src_lang="en",
target_lang="es",
src_model="Helsinki-NLP/opus-mt-ROMANCE-en",
target_model="Helsinki-NLP/opus-mt-en-ROMANCE",
chained_back_translation=0,
):
self.src_lang = src_lang
self.target_lang = target_lang
self.target_model = MarianMTModel.from_pretrained(target_model)
self.target_tokenizer = MarianTokenizer.from_pretrained(target_model)
self.src_model = MarianMTModel.from_pretrained(src_model)
self.src_tokenizer = MarianTokenizer.from_pretrained(src_model)
self.chained_back_translation = chained_back_translation

def translate(self, input, model, tokenizer, lang="es"):
# change the text to model's format
src_texts = []
if lang == "en":
src_texts.append(input[0])
else:
if ">>" and "<<" not in lang:
lang = ">>" + lang + "<< "
src_texts.append(lang + input[0])

# tokenize the input
encoded_input = tokenizer.prepare_seq2seq_batch(src_texts, return_tensors="pt")

# translate the input
translated = model.generate(**encoded_input)
translated_input = tokenizer.batch_decode(translated, skip_special_tokens=True)
return translated_input

def _get_transformations(self, current_text, indices_to_modify):
transformed_texts = []
current_text = current_text.text

# to perform chained back translation, a random list of target languages are selected from the provided model
if self.chained_back_translation:
list_of_target_lang = random.sample(
self.target_tokenizer.supported_language_codes,
self.chained_back_translation,
)
for target_lang in list_of_target_lang:
target_language_text = self.translate(
[current_text],
self.target_model,
self.target_tokenizer,
target_lang,
)
src_language_text = self.translate(
target_language_text,
self.src_model,
self.src_tokenizer,
self.src_lang,
)
current_text = src_language_text[0]
return [AttackedText(current_text)]

# translates source to target language and back to source language (single back translation)
target_language_text = self.translate(
[current_text], self.target_model, self.target_tokenizer, self.target_lang
)
src_language_text = self.translate(
target_language_text, self.src_model, self.src_tokenizer, self.src_lang
)
transformed_texts.append(AttackedText(src_language_text[0]))
return transformed_texts


"""
List of supported languages
['fr',
'es',
'it',
'pt',
'pt_br',
'ro',
'ca',
'gl',
'pt_BR<<',
'la<<',
'wa<<',
'fur<<',
'oc<<',
'fr_CA<<',
'sc<<',
'es_ES',
'es_MX',
'es_AR',
'es_PR',
'es_UY',
'es_CL',
'es_CO',
'es_CR',
'es_GT',
'es_HN',
'es_NI',
'es_PA',
'es_PE',
'es_VE',
'es_DO',
'es_EC',
'es_SV',
'an',
'pt_PT',
'frp',
'lad',
'vec',
'fr_FR',
'co',
'it_IT',
'lld',
'lij',
'lmo',
'nap',
'rm',
'scn',
'mwl']
"""
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""https://github.com/makcedward/nlpaug."""

from textattack.transformations import Transformation


class SentenceTransformation(Transformation):
def _get_transformations(self, current_text, indices_to_modify):
raise NotImplementedError()

0 comments on commit c5c10f5

Please sign in to comment.