In [6]:
import warnings
# Silence every Python warning for the rest of the kernel session.
# NOTE(review): with no category/module filter this also hides library
# deprecation warnings, so API removals will not be announced.
warnings.filterwarnings("ignore")

Translation Functions

In [4]:
def translate(texts, model, tokenizer, language="fr"):
    """Translate a batch of texts with a seq2seq model/tokenizer pair.

    Args:
        texts: Iterable of source texts. A single bare string is treated as
            a one-element batch rather than being split into characters.
        model: Object exposing ``generate(**encoded)``.
        tokenizer: Callable tokenizer that also exposes ``batch_decode``.
        language: Target language code. Unless it is ``"en"``, every text is
            prefixed with ``>>{language}<<`` before tokenization.

    Returns:
        A list of translated strings, one per input text. Empty input yields
        an empty list without calling the tokenizer or the model.
    """
    if isinstance(texts, str):
        texts = [texts]

    # Prepare the text data into the format expected by the model.
    src_texts = [
        f"{text}" if language == "en" else f">>{language}<< {text}"
        for text in texts
    ]
    if not src_texts:
        return []

    # Tokenize by calling the tokenizer directly; `prepare_seq2seq_batch` is
    # deprecated. Padding/truncation are spelled out to keep its defaults.
    encoded = tokenizer(src_texts,
                        return_tensors='pt',
                        padding=True,
                        truncation=True)

    # Generate translation using model
    translated = model.generate(**encoded)

    # Convert the generated token indices back into text
    return tokenizer.batch_decode(translated, skip_special_tokens=True)

def back_translate(texts, target_model, target_tokenizer, source_model, source_tokenizer, target_lang="fr", source_lang="en"):
    """Round-trip texts through a pivot language and back.

    The texts are first passed to ``translate`` with the target model and
    tokenizer (``language=target_lang``); that result is then passed to
    ``translate`` again with the source model and tokenizer
    (``language=source_lang``).

    Returns:
        Whatever the second ``translate`` call returns.
    """
    # Forward pass: source -> target (pivot) language.
    pivot_texts = translate(texts, target_model, target_tokenizer, language=target_lang)

    # Backward pass: pivot -> source language.
    return translate(pivot_texts, source_model, source_tokenizer, language=source_lang)

Load models and tokenizers

In [None]:
from transformers import MarianMTModel, MarianTokenizer

# Forward direction: English -> Romance languages.
target_model_name = "Helsinki-NLP/opus-mt-en-ROMANCE"
target_tokenizer = MarianTokenizer.from_pretrained(target_model_name)
target_model = MarianMTModel.from_pretrained(target_model_name)

# Reverse direction: Romance languages -> English.
en_model_name = "Helsinki-NLP/opus-mt-ROMANCE-en"
en_tokenizer = MarianTokenizer.from_pretrained(en_model_name)
en_model = MarianMTModel.from_pretrained(en_model_name)

Translation test

In [9]:
# Round-trip a few English sentences through Spanish and compare.
source_lang = "en"
target_lang = "es"
en_texts = [
    'This is so cool',
    'I hated the food',
    'They were very helpful',
]

aug_texts = back_translate(
    en_texts,
    target_model=target_model,
    target_tokenizer=target_tokenizer,
    source_model=en_model,
    source_tokenizer=en_tokenizer,
    target_lang=target_lang,
    source_lang=source_lang,
)
print(en_texts, "\n", aug_texts)

['This is so cool', 'I hated the food', 'They were very helpful'] 
 ['This is so great.', 'I hated food.', 'They were very helpful.']
