In [None]:
import random
import numpy as np
import pandas as pd
from time import sleep
from textblob import TextBlob
from textblob.translate import NotTranslated
import torch
from nlpaug.augmenter.word import ContextualWordEmbsAug
from tqdm.notebook import tqdm

In [None]:
# Load train.csv from the headline-generation corpus; the augmentation
# functions below rewrite its 'input_text' column.
data = pd.read_csv(
    filepath_or_buffer='../data/headlines_generation_corpus/train.csv',
)

In [None]:
## Aug1: Sentence Shuffle
def apply_sentenceShuffle(df, shuffle=False, random_state=42):
    '''
    Augment ``df['input_text']`` by reordering the sentences of each text.

    Each text is split on the Japanese full stop '。'; empty fragments (such
    as the one after a trailing '。') are discarded, and the reordered
    sentences are joined with a single space (the '。' delimiters are not
    restored, matching the previous output format).

    Parameters
    ----------
    df : pandas.DataFrame
        Frame with an ``input_text`` column. It is NOT modified.
    shuffle : bool, default False
        If True, sentences are randomly permuted; otherwise they are rotated
        right by one position ([A, B, C] -> [C, A, B]).
    random_state : int, default 42
        Seed for the random permutation (only used when ``shuffle`` is True).

    Returns
    -------
    pandas.DataFrame
        A copy of ``df`` with the augmented ``input_text`` column and the
        same index as ``df``.
    '''
    # One private RNG per call: output is reproducible for a given seed, yet
    # each row gets its own permutation (re-seeding per row would give every
    # text with the same sentence count the identical permutation), and the
    # global ``random`` state is left untouched.
    rng = random.Random(random_state)

    def _split_sentences(text):
        # Filtering empties drops the '' after a trailing '。' without losing
        # the last sentence when the text does not end with '。'.
        return [sentence for sentence in text.split('。') if sentence]

    def random_sentence_shuffle(text):
        '''
        Split the text into sentences and shuffle them randomly.
        '''
        sentences = _split_sentences(text)
        rng.shuffle(sentences)
        return ' '.join(sentences)

    def sentence_shuffle(text):
        '''
        Rotate the sentences right by one position.

        example:
        [sentenceA, sentenceB, sentenceC] --apply--> [sentenceC, sentenceA, sentenceB]
        '''
        sentences = _split_sentences(text)
        return ' '.join(sentences[-1:] + sentences[:-1])

    augment = random_sentence_shuffle if shuffle else sentence_shuffle

    augmented = df.copy()
    # Series.map keeps df's index, so rows stay aligned for any index.
    # Non-string cells (e.g. NaN from read_csv) are passed through unchanged.
    augmented['input_text'] = df['input_text'].map(
        lambda text: augment(text) if isinstance(text, str) else text
    )
    return augmented


## Aug2: Back Translation
def apply_backTranslation(df):
    '''
    Augment ``df['input_text']`` by back-translation (Japanese -> English -> Japanese).

    Parameters
    ----------
    df : pandas.DataFrame
        Frame with an ``input_text`` column. It is NOT modified.

    Returns
    -------
    pandas.DataFrame
        A copy of ``df`` whose ``input_text`` holds the back-translated text.
        Rows for which TextBlob raises ``NotTranslated`` keep their original
        text, and non-string cells (e.g. NaN) are passed through unchanged.
    '''
    # NOTE(review): TextBlob.translate relies on an external translation
    # service and is deprecated/removed in newer TextBlob releases — TODO
    # confirm the installed version still provides it.

    def back_translation(text):
        '''
        Back-translate one training text via Japanese -> English -> Japanese.

        Returns the original text (instead of None, which was previously
        written out as the string 'None') when TextBlob raises NotTranslated.
        '''
        if not isinstance(text, str):
            return text
        textblob = TextBlob(text)
        try:
            english = textblob.translate(to='en')
            # Pause between requests, presumably to avoid rate limiting by
            # the translation backend.
            sleep(0.4)
            japanese = english.translate(to='ja')
            sleep(0.4)
            return str(japanese)
        except NotTranslated:
            return text

    translated = [back_translation(value) for value in tqdm(df['input_text'])]
    augmented = df.copy()
    # A plain list is assigned positionally, so rows stay aligned regardless
    # of df's index.
    augmented['input_text'] = translated
    return augmented

# Not used (kept for reference)
# ## Aug3: Contextual Word Embedded Augmentation by BERT
# def word_embedded_aug(df):
#     '''
#     BERTによる類似単語埋め込み（置き換え）増強
#     '''
#     params = {
#         'model_path': 'cl-tohoku/bert-base-japanese-char-whole-word-masking',
#         'aug_p': 0.1,
#         'batch_size': 32,
#         'device': 'cuda' if torch.cuda.is_available() else 'cpu'
#     }
#     aug_df = df.copy()
#     aug = ContextualWordEmbsAug(**params)
#     aug_df['input_text'] = [aug.augment(text) for text in tqdm(df['input_text'])]
#     # denoising
#     aug_df['input_text'] = aug_df['input_text'].apply(lambda x: x.replace(' ',''))
#     aug_df['include_unk'] = aug_df['input_text'].str.contains('[UNK]')
#     aug_df = aug_df.query('include_unk == False')
#     aug_df = aug_df.drop(['include_unk'], axis=1)
#     return aug_df

# Not used (kept for reference)
# # Aug4: Random Character Deletion
# def random_char_deletion(text, random_state=42):
#     '''
#     無作為に1センテンスにつき、1文字を削除
#     '''
#     random.seed(random_state)
#     new_text = []
#     for sentence in text.split('。')[:-1]:
#         sentence += '。'
#         chars = list(sentence)
#         while True:
#             del_char = random.choice(chars)
#             if del_char not in ['、', '。']:
#                 break
#         chars.remove(del_char)
#         new_sentence = ''.join(chars)
#         new_text.append(new_sentence)
#     new_text = ''.join(new_text)
#     return new_text

# def apply_randomCharDeletion(df):
#     new_df = df.copy()
#     new_df['input_text'] = df['input_text'].map(random_char_deletion)
#     return new_df

In [None]:
# Give each augmenter its own copy of the training data: as written, the
# augmenters assign into the frame they receive, so sharing `data` would make
# the back-translation run on already-shuffled text and leave both results
# (and `data` itself) pointing at the same mutated frame.
ss_aug_data = apply_sentenceShuffle(data.copy(), shuffle=False)
bt_aug_data = apply_backTranslation(data.copy())

In [None]:
# Write each augmented corpus into the same directory as train.csv,
# without the pandas index column.
for aug_frame, out_path in (
    (ss_aug_data, '../data/headlines_generation_corpus/sentence_shuffle_aug_data.csv'),
    (bt_aug_data, '../data/headlines_generation_corpus/back_translation_aug_data.csv'),
):
    aug_frame.to_csv(out_path, index=False)