In [None]:
from datasets import load_dataset, Dataset
import os

def load_bea_pairs():
    """Collect (source, target) sentence pairs from the BEA-19 corruption dataset.

    Reads the ``train`` split of ``juancavallotti/bea-19-corruption`` and uses
    each row's ``"broken"`` field as the source and its ``"sentence"`` field as
    the target. Both fields are whitespace-stripped; a row is dropped when
    either side ends up empty or when the two sides are identical.

    Returns:
        A ``(sources, targets)`` tuple of two equally long lists of strings,
        where ``sources[i]`` pairs with ``targets[i]``.
    """
    corpus = load_dataset("juancavallotti/bea-19-corruption", split="train")

    # Strip both sides first, then filter, so the comparison is made on the
    # cleaned text rather than on the raw field values.
    stripped_rows = (
        (row["broken"].strip(), row["sentence"].strip()) for row in corpus
    )
    kept = [
        (broken, fixed)
        for broken, fixed in stripped_rows
        if broken and fixed and broken != fixed
    ]

    sources = [broken for broken, _ in kept]
    targets = [fixed for _, fixed in kept]
    return sources, targets

def load_lang8_pairs():
    """Collect (source, target) sentence pairs from the cleaned Lang-8 dataset.

    Reads the ``train`` split of ``rahuln2002/GED-lang8-cleaned`` and uses
    column ``"0"`` as the source and column ``"1"`` as the target. Both fields
    are whitespace-stripped; a row is dropped when either field is missing or
    not a string, when either side is empty after stripping, or when the two
    sides are identical.

    Returns:
        A ``(sources, targets)`` tuple of two equally long lists of strings,
        where ``sources[i]`` pairs with ``targets[i]``.
    """
    lang8 = load_dataset("rahuln2002/GED-lang8-cleaned", split="train")

    sources, targets = [], []
    for ex in lang8:
        raw_src, raw_tgt = ex["0"], ex["1"]
        # This dataset is served from a CSV file, so a blank cell may come back
        # as None instead of "". Skip such rows rather than letting .strip()
        # raise AttributeError and abort the whole load.
        if not isinstance(raw_src, str) or not isinstance(raw_tgt, str):
            continue

        src = raw_src.strip()
        tgt = raw_tgt.strip()
        if not src or not tgt:
            continue
        # Identical pairs carry no correction signal.
        if src == tgt:
            continue
        sources.append(src)
        targets.append(tgt)

    return sources, targets

def load_jfleg_pairs():
    """Collect (source, target) sentence pairs from the JFLEG dataset.

    Reads the ``validation`` and ``test`` splits of ``jhu-clsp/jfleg`` (in that
    order). Each example contributes one pair per entry in its
    ``"corrections"`` list, with the example's ``"sentence"`` as the source.
    Text is whitespace-stripped; examples with an empty sentence are skipped
    entirely, and individual corrections are skipped when empty or identical
    to the source.

    Returns:
        A ``(sources, targets)`` tuple of two equally long lists of strings,
        where ``sources[i]`` pairs with ``targets[i]``. A source sentence
        appears once for every correction that was kept.
    """
    # Load both splits up front, before any example is processed.
    splits = [
        load_dataset("jhu-clsp/jfleg", split=split_name)
        for split_name in ("validation", "test")
    ]

    def iter_pairs(examples):
        """Yield the cleaned (sentence, correction) pairs of one split."""
        for example in examples:
            sentence = example["sentence"].strip()
            if sentence:
                for reference in example["corrections"]:
                    corrected = reference.strip()
                    if corrected and corrected != sentence:
                        yield sentence, corrected

    pairs = [pair for split_ds in splits for pair in iter_pairs(split_ds)]

    sources = [sentence for sentence, _ in pairs]
    targets = [corrected for _, corrected in pairs]
    return sources, targets

In [None]:
def build_combined_dataset(test_size=0.1, seed=42):
    """Merge all corpora into one de-duplicated, shuffled train/test dataset.

    Calls ``load_bea_pairs``, ``load_lang8_pairs`` and ``load_jfleg_pairs`` in
    that order, concatenates their pairs, removes exact duplicate
    ``(source, target)`` pairs (first occurrence wins, order preserved), then
    shuffles and splits the result.

    Args:
        test_size: Passed to ``Dataset.train_test_split``. Defaults to ``0.1``.
        seed: Seed used for both the shuffle and the split. Defaults to ``42``.

    Returns:
        The object returned by ``Dataset.train_test_split``; its ``"train"``
        and ``"test"`` entries each have ``"source"`` and ``"target"`` columns.

    Raises:
        ValueError: If a loader returns source/target lists of different
            lengths, or if no pairs were loaded at all.
    """
    pairs = []
    for loader in (load_bea_pairs, load_lang8_pairs, load_jfleg_pairs):
        src, tgt = loader()
        # zip() truncates silently; a length mismatch would pair sentences
        # with the wrong corrections, so fail loudly instead.
        if len(src) != len(tgt):
            raise ValueError(
                f"{loader.__name__} returned {len(src)} sources "
                f"but {len(tgt)} targets"
            )
        pairs.extend(zip(src, tgt))

    # dict keys keep insertion order, so this drops repeated pairs while
    # keeping the first occurrence of each in its original position.
    unique_pairs = list(dict.fromkeys(pairs))
    if not unique_pairs:
        raise ValueError(
            "No (source, target) pairs were loaded; cannot build a dataset."
        )

    ds = Dataset.from_dict(
        {
            "source": [s for s, _ in unique_pairs],
            "target": [t for _, t in unique_pairs],
        }
    )

    # train_test_split shuffles on its own by default; the explicit shuffle is
    # kept so the resulting split stays identical to earlier runs of this code.
    ds = ds.shuffle(seed=seed)

    # NOTE(review): the split is made over individual pairs, so a source
    # sentence with several distinct targets (load_jfleg_pairs emits one pair
    # per correction) can end up in both train and test. Split by source
    # sentence instead if that overlap matters for evaluation.
    dataset_dict = ds.train_test_split(test_size=test_size, seed=seed)

    return dataset_dict

In [None]:
# NOTE(review): nothing in this cell writes into "data" -- every artifact below
# lands in the current working directory. The call is kept as-is in case other
# cells rely on the directory existing; confirm and drop it if not.
os.makedirs("data", exist_ok=True)

dataset_dict = build_combined_dataset()

# Print the split sizes. The output saved with this cell contains
# "train size:" / "valid size:" lines, so emit them here to keep the code and
# its recorded output in sync on a fresh run.
print(f"train size: {len(dataset_dict['train'])}")
print(f"valid size: {len(dataset_dict['test'])}")

# Persist the full train/test dataset in the datasets on-disk format.
out_dir = "grammar_correction_pairs"
dataset_dict.save_to_disk(out_dir)

# The "test" split produced by build_combined_dataset is used as the
# validation set from here on.
train_df = dataset_dict["train"].to_pandas()
valid_df = dataset_dict["test"].to_pandas()

# Also export plain CSV copies (no index column).
train_df.to_csv("grammar_train.csv", index=False)
valid_df.to_csv("grammar_valid.csv", index=False)

README.md:   0%|          | 0.00/255 [00:00<?, ?B/s]

Repo card metadata block was not found. Setting CardData to empty.


dataset_infos.json: 0.00B [00:00, ?B/s]

data/train-00000-of-00001.parquet:   0%|          | 0.00/7.41M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/84106 [00:00<?, ? examples/s]

README.md:   0%|          | 0.00/196 [00:00<?, ?B/s]

Cleaned_Lang8.csv:   0%|          | 0.00/25.6M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/200000 [00:00<?, ? examples/s]

README.md: 0.00B [00:00, ?B/s]

data/validation-00000-of-00001.parquet:   0%|          | 0.00/148k [00:00<?, ?B/s]

data/test-00000-of-00001.parquet:   0%|          | 0.00/141k [00:00<?, ?B/s]

Generating validation split:   0%|          | 0/755 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/748 [00:00<?, ? examples/s]

train size: 244462
valid size: 27163


Saving the dataset (0/1 shards):   0%|          | 0/244462 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/27163 [00:00<?, ? examples/s]