## Data Augmentation with MarianMT using Back-Translation

### Initialize the models for English <-> Foreign Languages


In [1]:
import warnings
warnings.filterwarnings('ignore')

import pandas as pd
from tqdm import trange
from transformers import MarianMTModel, MarianTokenizer

target_model_name = 'Helsinki-NLP/opus-mt-en-ROMANCE'
target_tokenizer = MarianTokenizer.from_pretrained(target_model_name)
target_model = MarianMTModel.from_pretrained(target_model_name)


en_model_name = 'Helsinki-NLP/opus-mt-ROMANCE-en'
en_tokenizer = MarianTokenizer.from_pretrained(en_model_name)
en_model = MarianMTModel.from_pretrained(en_model_name)

Some weights of MarianMTModel were not initialized from the model checkpoint at Helsinki-NLP/opus-mt-en-ROMANCE and are newly initialized: ['lm_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of MarianMTModel were not initialized from the model checkpoint at Helsinki-NLP/opus-mt-ROMANCE-en and are newly initialized: ['lm_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [2]:
def translate(texts, model, tokenizer, language="fr"):
    """Prepare the text data into appropriate format for the model"""
    template = lambda text: f"{text}" if language == "en" else f">>{language}<< {text}"
    src_texts = [template(text) for text in texts]

    # Tokenize the texts
    encoded = tokenizer.prepare_seq2seq_batch(src_texts,
                                              return_tensors='pt')
    
    # Generate translation using model
    translated = model.generate(**encoded)

    # Convert the generated tokens indices back into text
    translated_texts = tokenizer.batch_decode(translated, skip_special_tokens=True)
    
    return translated_texts

In [3]:
def back_translate(texts, target_lang="fr", source_lang="en"):
    """translate to target language and back to source language"""
    fr_texts = translate(texts, target_model, target_tokenizer, 
                         language=target_lang)

    # Translate from target language back to source language
    back_translated_texts = translate(fr_texts, en_model, en_tokenizer, 
                                      language=source_lang)
    
    back_translated_texts = [t for t in back_translated_texts if t not in texts]
    return back_translated_texts

### Perform Augmentation using English <-> Spanish


In [4]:
en_texts = ['Cannot access website', 'I hated the food', "I can't login to my vpn"]

In [5]:
aug_texts = back_translate(en_texts, source_lang="en", target_lang="es")
print(aug_texts)

['Cannot access the website', 'I hated food.', "I can't access my vpn"]


### Perform Augmentation using English <-> French



In [6]:
aug_texts = back_translate(en_texts, source_lang="en", target_lang="fr")
print(aug_texts)

['Unable to access website', 'I hated food.', "I can't connect to my vpn"]


In [7]:
dataset = pd.read_csv('./data/preprocessed_data_l123.csv')

In [8]:
minority_class_descr = dataset[dataset.label == 1].translated_description.tolist()

In [9]:
len(minority_class_descr)

2514

In [10]:
size = 3
augmented = list()
for i in trange(len(minority_class_descr)//size):
    subset = minority_class_descr[i*size:(i+1)*size]
    augmented.append(back_translate(subset, source_lang="en", target_lang="es"))
    augmented.append(back_translate(subset, source_lang="en", target_lang="es"))

len(augmented)

100%|███████████████████████████████████████████████████████████████████████████| 838/838 [3:23:39<00:00, 14.58s/it]


1676

In [15]:
augmented = [i for j in augmented for i in j]
augmented = [i for i in augmented if i not in minority_class_descr]
len(augmented)

4834

In [20]:
dataset.label.value_counts()

0    5985
1    2514
Name: label, dtype: int64

In [30]:
pd.Series(augmented)

0       event critical hostname company with value mou...
1       duplicate soft network two devices try sharing...
2       problem solving printer work printer replaceme...
3       event critical hostname company with value mou...
4       duplicate soft network two devices try sharing...
                              ...                        
4829    no it's working you can not access macne finis...
4830      multiple pcs cannot open prgramdntyme cnc range
4831            come receive e-mail send zz e-mail advise
4832    no it's working you can not access macne finis...
4833      multiple pcs cannot open prgramdntyme cnc range
Length: 4834, dtype: object

In [41]:
augmented_df = pd.DataFrame(columns=dataset.columns)
augmented_df.translated_description = augmented
augmented_df.label = 1
augmented_df

Unnamed: 0,translated_description,keywords,short_description,description,group,cleaned_description,cleaned_short_description,merged_description,char_length,word_length,short_char_length,short_word_length,language,language_confidence,label
0,event critical hostname company with value mou...,,,,,,,,,,,,,,1
1,duplicate soft network two devices try sharing...,,,,,,,,,,,,,,1
2,problem solving printer work printer replaceme...,,,,,,,,,,,,,,1
3,event critical hostname company with value mou...,,,,,,,,,,,,,,1
4,duplicate soft network two devices try sharing...,,,,,,,,,,,,,,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4829,no it's working you can not access macne finis...,,,,,,,,,,,,,,1
4830,multiple pcs cannot open prgramdntyme cnc range,,,,,,,,,,,,,,1
4831,come receive e-mail send zz e-mail advise,,,,,,,,,,,,,,1
4832,no it's working you can not access macne finis...,,,,,,,,,,,,,,1


In [42]:
augmented_df = pd.concat([dataset, augmented_df])

In [43]:
augmented_df.shape

(13333, 15)

In [45]:
augmented_df.label.value_counts()

1    7348
0    5985
Name: label, dtype: int64

In [46]:
augmented_df.to_csv('./data/augmented_data.csv')