In [1]:
!pip install nlpaug

Defaulting to user installation because normal site-packages is not writeable


In [2]:
!pip install tqdm

Defaulting to user installation because normal site-packages is not writeable


In [3]:
import torch
import random
import pandas as pd
from datasets import load_dataset, Dataset
import nlpaug.augmenter.word as naw

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Set augmentation ratio
pct_augment = 0.3
num_augmented_per_sample = 3  # can be increased

In [5]:
# Load original AGNews
dataset = load_dataset("ag_news", split="train")
texts, labels = dataset["text"], dataset["label"]

In [6]:
# Initialize augmenter
aug = naw.ContextualWordEmbsAug(
    model_path='bert-base-uncased',
    action="substitute",
    device='cuda' if torch.cuda.is_available() else 'cpu'
)

In [7]:
augmented_texts = []
augmented_labels = []

In [10]:
from tqdm import tqdm
BATCH_SIZE = 512
PCT_AUGMENT = 0.3
NUM_AUG_PER_SAMPLE = 7

for i in tqdm(range(0, len(texts), BATCH_SIZE)):
    batch_texts = texts[i:i + BATCH_SIZE]
    batch_labels = labels[i:i + BATCH_SIZE]

    # Append originals
    augmented_texts.extend(batch_texts)
    augmented_labels.extend(batch_labels)

    # Decide which to augment in this batch
    augment_indices = [j for j in range(len(batch_texts)) if random.random() < PCT_AUGMENT]
    texts_to_augment = [batch_texts[j] for j in augment_indices]
    labels_to_augment = [batch_labels[j] for j in augment_indices]

    if texts_to_augment:
        try:
            augmented_batch = aug.augment(texts_to_augment)
            for _ in range(NUM_AUG_PER_SAMPLE):
                augmented_texts.extend(augmented_batch)
                augmented_labels.extend(labels_to_augment)
        except Exception as e:
            print(f"Augmentation failed: {e}")
            continue

100%|██████████| 235/235 [15:40<00:00,  4.00s/it]


In [11]:
# Save to CSV
df = pd.DataFrame({'text': augmented_texts, 'label': augmented_labels})
df.to_csv("augmented_agnews.csv", index=False)
print("Augmented dataset saved to augmented_agnews.csv")

Augmented dataset saved to augmented_agnews.csv
