In [2]:
# !pip install -U transformers 
# !pip install mosestokenizer

In [15]:
def translate(texts, model, tokenizer, language="fr"):
    """Translate a batch of texts with a seq2seq model/tokenizer pair.

    Args:
        texts: Iterable of source strings to translate.
        model: Object exposing ``generate(**encoded)`` and returning token ids
            (a ``MarianMTModel`` in this notebook).
        tokenizer: Callable tokenizer that also exposes ``batch_decode``
            (a ``MarianTokenizer`` in this notebook).
        language: Target language code. Unless it is ``"en"``, every text is
            prefixed with the ``>>{language}<<`` target-language token.

    Returns:
        A list of translated strings, one per input text and in the same
        order. An empty input yields an empty list without calling the model.
    """
    texts = list(texts)
    if not texts:
        # Nothing to translate; skip tokenisation and generation entirely.
        return []

    # Prepare the text data into the format expected by the model: the
    # target-language token is only added when the target is not English.
    prefix = "" if language == "en" else f">>{language}<< "
    src_texts = [f"{prefix}{text}" for text in texts]
    print(src_texts)

    # Tokenize the texts. Calling the tokenizer directly replaces the
    # deprecated ``prepare_seq2seq_batch``; the padding/truncation settings
    # mirror that helper's defaults and tensors are requested explicitly so
    # the batch can be fed straight to ``generate``.
    encoded = tokenizer(
        src_texts, return_tensors="pt", padding=True, truncation=True
    )

    # Generate translation using model
    translated = model.generate(**encoded)

    # Convert the generated token indices back into text
    return tokenizer.batch_decode(translated, skip_special_tokens=True)

In [32]:
def prep_model_tokenizer(romance_lang: str = "es"):
    """Load the pair of Helsinki-NLP OPUS-MT checkpoints for one language.

    Two checkpoints are loaded: ``opus-mt-en-{romance_lang}`` (English to the
    target language) and ``opus-mt-{romance_lang}-en`` (target language back
    to English).

    Args:
        romance_lang: Language code inserted into both checkpoint names.

    Returns:
        A 4-tuple ``(en_model, en_tokenizer, target_model, target_tokenizer)``
        where the ``en_*`` pair translates into English and the ``target_*``
        pair translates out of English.
    """

    def _load(checkpoint):
        # Tokenizer first, then model, for a single checkpoint name.
        tokenizer = MarianTokenizer.from_pretrained(checkpoint)
        model = MarianMTModel.from_pretrained(checkpoint)
        return model, tokenizer

    target_model, target_tokenizer = _load(f"Helsinki-NLP/opus-mt-en-{romance_lang}")
    en_model, en_tokenizer = _load(f"Helsinki-NLP/opus-mt-{romance_lang}-en")

    return en_model, en_tokenizer, target_model, target_tokenizer

In [30]:
def back_translate(texts, source_lang="en", target_lang="fr"):
    """Paraphrase texts via back translation through a pivot language.

    The texts are translated into ``target_lang`` and the result is then
    translated back, which usually yields a reworded version of the input.

    Args:
        texts: Source strings to paraphrase.
        source_lang: Language code passed to ``translate`` for the return
            trip. It only controls the text prefix; the checkpoints loaded by
            ``prep_model_tokenizer`` always have English on the other side, so
            values other than ``"en"`` are presumably unsupported.
        target_lang: Pivot language code, used both to pick the checkpoints
            and as the language of the outbound translation.

    Returns:
        The list of back-translated strings returned by the second
        ``translate`` call.
    """
    to_source_model, to_source_tokenizer, to_pivot_model, to_pivot_tokenizer = (
        prep_model_tokenizer(romance_lang=target_lang)
    )

    # Outbound leg: source language -> pivot language.
    pivot_texts = translate(
        texts, to_pivot_model, to_pivot_tokenizer, language=target_lang
    )

    # Return leg: pivot language -> source language.
    return translate(
        pivot_texts, to_source_model, to_source_tokenizer, language=source_lang
    )

In [36]:
# Sample English utterances to augment.
en_texts = [
    "cancel my card please",
    "What is my account balance?",
    "Where is my refund?",
]

# Chain two round trips: first pivot through Chinese, then feed that result
# through a second round trip that pivots through Spanish.
zh_round_trip = back_translate(en_texts, source_lang="en", target_lang="zh")
aug_texts = back_translate(zh_round_trip, source_lang="en", target_lang="es")
print(aug_texts)

['>>zh<< cancel my card please', '>>zh<< What is my account balance?', '>>zh<< Where is my refund?']
['請取消我的名片', '我的账户余额是多少?', '我的退款呢?']
['>>es<< Please cancel my card.', ">>es<< What's my account balance?", ">>es<< Where's my refund?"]
['Por favor, cancela mi tarjeta.', '¿Cuál es el saldo de mi cuenta?', '¿Dónde está mi reembolso?']
['Please cancel my card.', "What's the balance on my account?", "Where's my refund?"]
