In [1]:
# !pip install nlpaug torch>=1.6.0 transformers>=4.11.3 sentencepiece sacremoses

In [2]:
import os
import glob
import torch
from datasets import load_dataset, load_from_disk

import nlpaug.augmenter.word as naw
import nlpaug.augmenter.sentence as nas

2023-08-23 16:47:36.060394: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [3]:
os.environ["TOKENIZERS_PARALLELISM"] = "true"

In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [5]:
text = 'The quick brown fox jumped over the lazy dog'
back_translation_aug = naw.BackTranslationAug(
    from_model_name='facebook/wmt19-en-de',
    to_model_name='facebook/wmt19-de-en',
    device=device
)
augmented_text = back_translation_aug.augment(text)
augmented_text

Some weights of FSMTForConditionalGeneration were not initialized from the model checkpoint at facebook/wmt19-en-de and are newly initialized: ['model.encoder.embed_positions.weight', 'model.decoder.embed_positions.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of FSMTForConditionalGeneration were not initialized from the model checkpoint at facebook/wmt19-de-en and are newly initialized: ['model.encoder.embed_positions.weight', 'model.decoder.embed_positions.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


['The speedy brown fox leapt over the lazy dog']

In [6]:
context_ins_aug = naw.ContextualWordEmbsAug(
    model_path='bert-base-uncased', action="insert", device=device)
augmented_text = context_ins_aug.augment(text)
augmented_text

['even the quick acting brown fox jumped over at the lazy dog']

In [7]:
context_sub_aug = naw.ContextualWordEmbsAug(
    model_path='bert-base-uncased', action="substitute", device=device)
augmented_text = context_sub_aug.augment(text)
augmented_text

['the quick tail fox variations on the lazy dog']

In [8]:
context_cont_aug = nas.ContextualWordEmbsForSentenceAug(model_path='gpt2', device=device)
augmented_texts = context_cont_aug.augment(text, n=1)
augmented_text

['the quick tail fox variations on the lazy dog']

In [9]:
class AugMapper:
    def __init__(self, augmenter):
        self.augmenter = augmenter # textattack augmenter recipe

    def apply_to_batch(self, batch):
        new_texts, new_labels = [], []
        for text, label in zip(batch['text'], batch['label']):
            augmented_text = self.augmenter.augment(text)
            new_texts.extend(augmented_text)
            new_labels.extend([label] * len(augmented_text))

        return {
            "text": new_texts,
            "label": new_labels,
            "idx": list(range(len(new_labels))),
        }

In [None]:
augs = [
    ("backtranslate", AugMapper(back_translation_aug)),
    ("context_ins", AugMapper(context_ins_aug)), 
    ("context_sub", AugMapper(context_sub_aug)), 
    ("context_cont", AugMapper(context_cont_aug))
]

dataset_paths = glob.glob("./fada/fadata/datasets/*original.*")
dataset_paths.sort()

for dataset_path in dataset_paths:

    print(dataset_path)

    dataset = load_from_disk(dataset_path)
    
    for aug_name, aug_fn in augs:

        aug_save_path = dataset_path.replace("original", aug_name)
        
        print(aug_save_path)
        
        if os.path.exists(aug_save_path):
            print(f"found existing dataset {aug_save_path}... skipping...") 
        else:
            aug_dataset = dataset.map(aug_fn.apply_to_batch, batched=True, batch_size=10)
            aug_dataset.save_to_disk(aug_save_path)

./fada/fadata/datasets/ag_news.default.original.10
./fada/fadata/datasets/ag_news.default.backtranslate.10
found existing dataset ./fada/fadata/datasets/ag_news.default.backtranslate.10... skipping...
./fada/fadata/datasets/ag_news.default.context_ins.10
found existing dataset ./fada/fadata/datasets/ag_news.default.context_ins.10... skipping...
./fada/fadata/datasets/ag_news.default.context_sub.10
found existing dataset ./fada/fadata/datasets/ag_news.default.context_sub.10... skipping...
./fada/fadata/datasets/ag_news.default.context_cont.10
found existing dataset ./fada/fadata/datasets/ag_news.default.context_cont.10... skipping...
./fada/fadata/datasets/ag_news.default.original.200
./fada/fadata/datasets/ag_news.default.backtranslate.200


Map:   0%|          | 0/800 [00:00<?, ? examples/s]