-
Notifications
You must be signed in to change notification settings - Fork 375
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #534 from QData/back-translation-transformation
Back translation transformation
- Loading branch information
Showing
9 changed files
with
178 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,4 +14,5 @@ | |
CheckListAugmenter, | ||
DeletionAugmenter, | ||
CLAREAugmenter, | ||
BackTranslationAugmenter, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
2 changes: 2 additions & 0 deletions
2
textattack/transformations/sentence_transformations/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .sentence_transformation import SentenceTransformation | ||
from .back_translation import BackTranslation |
145 changes: 145 additions & 0 deletions
145
textattack/transformations/sentence_transformations/back_translation.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
import random | ||
|
||
from transformers import MarianMTModel, MarianTokenizer | ||
|
||
from textattack.shared import AttackedText | ||
|
||
from .sentence_transformation import SentenceTransformation | ||
|
||
|
||
class BackTranslation(SentenceTransformation):
    """Sentence-level transformation that round-trips text through another language.

    The input text is translated from the source language into a target
    language and then translated back into the source language.

    Args:
        src_lang (string): source language code (default ``"en"``).
        target_lang (string): target language code used for single back
            translation (default ``"es"``). For the list of supported
            languages check the bottom of this page.
        src_model: identifier passed to ``MarianMTModel.from_pretrained`` /
            ``MarianTokenizer.from_pretrained`` for the model that translates
            from the target language back to the source language.
        target_model: identifier passed to ``MarianMTModel.from_pretrained`` /
            ``MarianTokenizer.from_pretrained`` for the model that translates
            from the source language to the target language.
        chained_back_translation: number of randomly chosen target languages
            to round-trip through in sequence for more perturbation (for
            example, en-es-en-fr-en). A falsy value (default ``0``) performs a
            single back translation through ``target_lang``.
    """

    def __init__(
        self,
        src_lang="en",
        target_lang="es",
        src_model="Helsinki-NLP/opus-mt-ROMANCE-en",
        target_model="Helsinki-NLP/opus-mt-en-ROMANCE",
        chained_back_translation=0,
    ):
        self.src_lang = src_lang
        self.target_lang = target_lang
        # Outbound direction: source language -> target language.
        self.target_model = MarianMTModel.from_pretrained(target_model)
        self.target_tokenizer = MarianTokenizer.from_pretrained(target_model)
        # Return direction: target language -> source language.
        self.src_model = MarianMTModel.from_pretrained(src_model)
        self.src_tokenizer = MarianTokenizer.from_pretrained(src_model)
        self.chained_back_translation = chained_back_translation

    def translate(self, input, model, tokenizer, lang="es"):
        """Translate the first element of ``input`` using ``model``.

        Args:
            input: indexable collection of strings; only ``input[0]`` is
                translated.
            model: translation model; its ``generate`` method is called with
                the tokenizer output as keyword arguments.
            tokenizer: tokenizer paired with ``model``; its
                ``prepare_seq2seq_batch`` and ``batch_decode`` methods are used.
            lang (string): language to translate into. Unless it is ``"en"``,
                it is prepended to the text as a ``>>lang<<`` marker.

        Returns:
            The value returned by ``tokenizer.batch_decode`` for the generated
            output.
        """
        # Change the text to the model's format.
        text = input[0]
        if lang != "en":
            # Wrap the code in the ">>code<<" marker only when the caller has
            # not already supplied a marker. The previous condition,
            # `">>" and "<<" not in lang`, only ever tested for "<<" because
            # the literal ">>" is always truthy.
            if ">>" not in lang and "<<" not in lang:
                lang = ">>" + lang + "<< "
            text = lang + text
        src_texts = [text]

        # Tokenize the input.
        # NOTE(review): prepare_seq2seq_batch appears to be deprecated in newer
        # transformers releases -- confirm against the pinned version.
        encoded_input = tokenizer.prepare_seq2seq_batch(src_texts, return_tensors="pt")

        # Translate the input and decode it back to text.
        translated = model.generate(**encoded_input)
        return tokenizer.batch_decode(translated, skip_special_tokens=True)

    def _back_translate(self, text, target_lang):
        """Translate ``text`` into ``target_lang`` and back to ``self.src_lang``."""
        target_language_text = self.translate(
            [text], self.target_model, self.target_tokenizer, target_lang
        )
        src_language_text = self.translate(
            target_language_text, self.src_model, self.src_tokenizer, self.src_lang
        )
        return src_language_text[0]

    def _get_transformations(self, current_text, indices_to_modify):
        """Return a one-element list holding the back-translated text.

        Args:
            current_text: object whose ``text`` attribute is the string to
                transform.
            indices_to_modify: not used by this transformation.

        Returns:
            A list containing a single ``AttackedText`` built from the
            back-translated string.
        """
        text = current_text.text

        if self.chained_back_translation:
            # Chained mode: a random list of target languages is selected from
            # the ones the target tokenizer reports as supported.
            target_langs = random.sample(
                self.target_tokenizer.supported_language_codes,
                self.chained_back_translation,
            )
        else:
            # Single back translation through the configured target language.
            target_langs = [self.target_lang]

        # Each round trip feeds its output into the next one.
        for target_lang in target_langs:
            text = self._back_translate(text, target_lang)
        return [AttackedText(text)]
|
||
|
||
""" | ||
List of supported languages | ||
['fr', | ||
'es', | ||
'it', | ||
'pt', | ||
'pt_br', | ||
'ro', | ||
'ca', | ||
'gl', | ||
'pt_BR<<', | ||
'la<<', | ||
'wa<<', | ||
'fur<<', | ||
'oc<<', | ||
'fr_CA<<', | ||
'sc<<', | ||
'es_ES', | ||
'es_MX', | ||
'es_AR', | ||
'es_PR', | ||
'es_UY', | ||
'es_CL', | ||
'es_CO', | ||
'es_CR', | ||
'es_GT', | ||
'es_HN', | ||
'es_NI', | ||
'es_PA', | ||
'es_PE', | ||
'es_VE', | ||
'es_DO', | ||
'es_EC', | ||
'es_SV', | ||
'an', | ||
'pt_PT', | ||
'frp', | ||
'lad', | ||
'vec', | ||
'fr_FR', | ||
'co', | ||
'it_IT', | ||
'lld', | ||
'lij', | ||
'lmo', | ||
'nap', | ||
'rm', | ||
'scn', | ||
'mwl'] | ||
""" |
8 changes: 8 additions & 0 deletions
8
textattack/transformations/sentence_transformations/sentence_transformation.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
"""https://github.com/makcedward/nlpaug.""" | ||
|
||
from textattack.transformations import Transformation | ||
|
||
|
||
class SentenceTransformation(Transformation):
    """Base class for sentence-level transformations.

    Subclasses are expected to override ``_get_transformations``; the
    implementation here always raises ``NotImplementedError``.
    """

    def _get_transformations(self, current_text, indices_to_modify):
        """Produce the transformed texts; must be provided by subclasses.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError()