Back-Translation Augmentation of SQuAD Questions (English → German → English)

In [1]:
from datasets import load_dataset
from transformers import AutoTokenizer, DefaultDataCollator

# Hold out 20% of the SQuAD training split as an evaluation split (fixed seed for reproducibility)
squad = load_dataset("squad", split="train").train_test_split(test_size=0.2, seed=42)

# Tokenizer presumably intended for a downstream QA model (not used in this section)
tokenizer = AutoTokenizer.from_pretrained("distilbert/distilbert-base-uncased")

# Collator that batches already-tokenized features into tensors (it does not pad)
data_collator = DefaultDataCollator()

In [2]:
import copy

# Independent copy of the splits, so that replacing augmented_squad['train'] later leaves `squad` untouched
augmented_squad = copy.deepcopy(squad, memo=None)

In [3]:
import torch
from transformers import MarianMTModel, MarianTokenizer

# --- CONFIGURATION ---

device = "cuda" if torch.cuda.is_available() else "cpu"

# Pivot language for the round trip English -> PIVOT_LANG -> English.
# NOTE(review): both Helsinki-NLP/opus-mt-en-{lang} and opus-mt-{lang}-en must exist for the chosen lang.
PIVOT_LANG = "de"

# --- LOAD MODELS + TOKENIZERS ---

src_to_tgt_model_name = f"Helsinki-NLP/opus-mt-en-{PIVOT_LANG}"
tgt_to_src_model_name = f"Helsinki-NLP/opus-mt-{PIVOT_LANG}-en"

# MarianTokenizer is SentencePiece-based and has no fast (Rust) variant; `use_fast` only has an
# effect when passed to AutoTokenizer, so it is not passed here.
tokenizer_src2tgt = MarianTokenizer.from_pretrained(src_to_tgt_model_name)
model_src2tgt = MarianMTModel.from_pretrained(src_to_tgt_model_name).to(device)

tokenizer_tgt2src = MarianTokenizer.from_pretrained(tgt_to_src_model_name)
model_tgt2src = MarianMTModel.from_pretrained(tgt_to_src_model_name).to(device)

def translate_batch(sentences, tokenizer, model, max_length=256, device=None, **generate_kwargs):
    """
    Translate a batch of sentences with a MarianMT (or any seq2seq) model.

    Args:
        sentences: list of source-language strings.
        tokenizer: tokenizer matching ``model``; called with padding/truncation and decoded
            with ``batch_decode``.
        model: model exposing ``generate`` and ``device``.
        max_length: cap applied both to input truncation and to the generated sequence length.
        device: where to place the encoded inputs. Defaults to ``model.device`` so the inputs
            always sit on the same device as the model weights.
        **generate_kwargs: extra options forwarded to ``model.generate`` (e.g. ``num_beams``).

    Returns:
        list[str]: one decoded string per generated sequence, with special tokens stripped.
        An empty input returns an empty list without calling the tokenizer or the model.
    """
    if len(sentences) == 0:
        # Nothing to translate; skip the model entirely.
        return []

    target_device = model.device if device is None else device
    encoded = tokenizer(
        sentences,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_length,
    ).to(target_device)

    generated_tokens = model.generate(**encoded, max_length=max_length, **generate_kwargs)

    return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)


def back_translate(sentences_list):
    """
    Paraphrase English sentences by round-tripping them English -> German -> English.

    Uses the module-level Marian models/tokenizers ``model_src2tgt``/``tokenizer_src2tgt``
    (en->de) and ``model_tgt2src``/``tokenizer_tgt2src`` (de->en) via ``translate_batch``.

    Args:
        sentences_list: list of English strings.

    Returns:
        list: paraphrased English strings, one per input. This is best-effort: if either
        translation step raises, the error is printed and a list of ``None`` with the same
        length as the input is returned, so callers must filter out ``None`` entries.
        An empty input returns ``[]`` without touching the models.
    """
    if len(sentences_list) == 0:
        return []

    try:
        # 1. English -> German
        pivot_batch = translate_batch(sentences_list, tokenizer_src2tgt, model_src2tgt)
        # 2. German -> English
        return translate_batch(pivot_batch, tokenizer_tgt2src, model_tgt2src)
    except Exception as e:
        # Deliberately broad: one failing batch (e.g. GPU out-of-memory) should not abort a long
        # augmentation run; the caller skips the None placeholders.
        print(f"Error during back-translation: {e}")
        return [None] * len(sentences_list)


In [None]:
import copy
from datasets import Dataset, concatenate_datasets
from tqdm import tqdm
import time
import os

# Configuration
BATCH_SIZE = 8
LANG = "de"         # pivot language tag embedded in augmented ids
SAVE_EVERY = 100    # flush the in-memory cache into a Dataset chunk every 100 batches

# Prepare
original_train_dataset = squad['train']
original_list = original_train_dataset.to_list()
total_records = len(original_list)

new_augmented_datasets = []   # flushed Dataset chunks, concatenated at the end
batch_cache = []              # paraphrased records not yet flushed
num_failed = 0                # questions whose back-translation returned nothing usable

print(f"Starting batch augmentation on {total_records} samples...")

for b, i in enumerate(tqdm(range(0, total_records, BATCH_SIZE), desc="Batch Translating")):
    batch_records = original_list[i:i + BATCH_SIZE]
    batch_questions = [r['question'] for r in batch_records]

    paraphrased_questions = back_translate(batch_questions)

    for j, record in enumerate(batch_records):
        p = paraphrased_questions[j]
        if not (p and isinstance(p, str)):
            num_failed += 1
            continue
        # Shallow copy is enough: only the top-level 'id' and 'question' keys are reassigned.
        new_record = dict(record)
        new_record['id'] = f"{record['id']}-aug-{LANG}-{i + j}"
        new_record['question'] = p
        batch_cache.append(new_record)

    is_last_batch = i + BATCH_SIZE >= total_records
    # Only flush non-empty caches: Dataset.from_list([]) yields a schema-less chunk that
    # cannot be concatenated with the others.
    if batch_cache and ((b + 1) % SAVE_EVERY == 0 or is_last_batch):
        # Reuse the source schema so column dtypes match the original split.
        temp_dataset = Dataset.from_list(batch_cache, features=original_train_dataset.features)
        new_augmented_datasets.append(temp_dataset)
        batch_cache.clear()  # free memory

print("Merging all temporary datasets...")
if new_augmented_datasets:
    new_augmented_dataset = concatenate_datasets(new_augmented_datasets)
else:
    # Nothing survived (empty input or every translation failed): keep an empty dataset with the
    # original schema instead of letting concatenate_datasets([]) raise.
    new_augmented_dataset = original_train_dataset.select([])
print(f"Skipped {num_failed} questions with failed back-translation")
print("Final size:", len(new_augmented_dataset))


In [None]:
from datasets import concatenate_datasets

# Keep the original questions alongside their paraphrases (the "-aug-" id suffix distinguishes
# them). Set to False to train on the back-translated questions only.
KEEP_ORIGINAL_QUESTIONS = True

# Align the schema with the source split before concatenating: Dataset.from_list may infer
# dtypes (e.g. for answers.answer_start) that differ from the original features.
paraphrased_train = new_augmented_dataset.cast(squad['train'].features)
if KEEP_ORIGINAL_QUESTIONS:
    augmented_squad['train'] = concatenate_datasets([squad['train'], paraphrased_train])
else:
    augmented_squad['train'] = paraphrased_train
print(augmented_squad)