# Load and Clean Datasets

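This notebook loads the grammatical-error-correction and text-simplification datasets (BEA-2019 W&I+LOCNESS from local `.m2` files; JFLEG, WikiAuto and ASSET from the Hugging Face Hub), cleans them by collapsing spaced hyphens (` - ` → `-`), and saves each split as a CSV file in the `hpc_datasets/` directory.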

In [None]:
import re
import os
import pandas as pd
from datasets import load_dataset

# Output directory for every cleaned CSV below; exist_ok makes re-running the cell safe
os.makedirs(name="hpc_datasets", exist_ok=True)

# Function to clean spaced hyphens from text
def clean_spaced_hyphens(text):
    """Replace ' - ' with '-' to fix tokenization artifacts"""
    if pd.isna(text):
        return text
    return text.replace(' - ', '-')

# Parse an "A ..." (edit) line from an .m2 file.
# Groups: (1) start token offset, (2) end token offset, (3) replacement text.
# The field right after the offsets (the error type) is skipped, and anything
# after the replacement field is ignored (in the standard M2 layout: required
# flag, comment, annotator id). Offsets must be plain digits, so lines such as
# "A -1 -1|||noop|||..." do not match and are skipped by the caller.
A_RE = re.compile(r"^A (\d+) (\d+)\|\|\|[^|]*\|\|\|([^|]*)\|\|\|")

def apply_edits(src, edits=None):
    """Apply M2 token edits to a tokenized source sentence and return the corrected text.

    Args:
        src: Whitespace-tokenized source sentence (the text after "S " in an .m2 file).
        edits: Iterable of ``(start, end, replacement)`` tuples whose offsets index
            ``src.split()``. Defaults to the module-level holder ``apply_edits.edits``,
            which ``m2_to_pairs`` fills while parsing.

    Returns:
        The corrected sentence with tokens joined by single spaces. A replacement of
        "" or "-NONE-" deletes the span; a span with start == end is an insertion.
    """
    if edits is None:
        edits = apply_edits.edits
    toks = src.split()
    # Apply right-to-left so earlier offsets stay valid. Ties on start are broken by
    # end, so a replacement of [s, e) runs before an insertion at [s, s); otherwise the
    # replacement would overwrite the freshly inserted tokens. Remaining ties use
    # reverse file order, so several insertions at one point keep their file order.
    ordered = sorted(
        enumerate(edits),
        key=lambda idx_edit: (idx_edit[1][0], idx_edit[1][1], idx_edit[0]),
        reverse=True,
    )
    for _, (s, e, repl) in ordered:
        repl_toks = [] if repl in ("", "-NONE-") else repl.split()
        toks[s:e] = repl_toks
    return " ".join(toks)
apply_edits.edits = []  # static holder, filled by m2_to_pairs

def m2_to_pairs(path, annotator=0):
    """Parse an .m2 file into a list of (source, corrected) sentence pairs.

    Each "S ..." line starts a new sentence, the following "A ..." lines are its
    edits, and a blank line ends it. Edits are staged in ``apply_edits.edits`` and
    applied by ``apply_edits`` when the sentence is flushed.

    Args:
        path: Path to a UTF-8 encoded .m2 file.
        annotator: Keep only edits whose trailing annotator-id field equals this
            value. Files with several annotators list every annotator's edits for
            the same sentence, and merging them into one target would apply
            overlapping or conflicting corrections. Edit lines without a numeric
            annotator field are always kept. Pass ``None`` to merge every
            annotator's edits.

    Returns:
        List of ``(src, tgt)`` string tuples, one per "S" line, in file order.
    """
    pairs, src = [], None

    def flush():
        # Emit the pending sentence (if any) and reset the staged edits.
        if src is not None:
            pairs.append((src, apply_edits(src)))
        apply_edits.edits = []

    with open(path, encoding="utf8") as f:
        for raw_line in f:
            # Strip "\r" as well so CRLF files still produce blank sentence boundaries.
            line = raw_line.rstrip("\r\n")
            if line.startswith("S "):
                flush()
                src = line[2:]
            elif line.startswith("A "):
                m = A_RE.match(line)
                if not m:
                    continue  # e.g. "A -1 -1|||noop|||..." lines carry no edit
                ann_field = line.rsplit("|||", 1)[-1].strip()
                if annotator is not None and ann_field.isdigit() and int(ann_field) != annotator:
                    continue
                s, e, repl = int(m.group(1)), int(m.group(2)), m.group(3).strip()
                apply_edits.edits.append((s, e, repl))
            elif line == "":  # sentence boundary
                flush()
                src = None
    flush()  # the file may not end with a blank line
    return pairs

print("Loading and cleaning datasets...")
print("=" * 70)

# 1. BEA-2019 train and validation sets
print("\nLoading BEA-2019 dataset...")
# Check multiple possible locations
possible_dirs = [
    "wi+locness_v2.1.bea19/wi+locness/m2",
    "data/wi_locness/m2",
    "wi_locness/m2"
]
m2_dir = None
for dir_path in possible_dirs:
    if os.path.exists(dir_path):
        m2_dir = dir_path
        break

if m2_dir:
    train, dev = [], []
    for fname in os.listdir(m2_dir): 
        if fname.endswith(".m2"):
            path = os.path.join(m2_dir, fname)
            if "train" in fname.lower():
                train += m2_to_pairs(path)
            elif "dev" in fname.lower():
                dev += m2_to_pairs(path)
    
    if train:
        bea_train_df = pd.DataFrame(train, columns=["input_text", "target_text"])
        # Clean spaced hyphens
        for col in bea_train_df.columns:
            bea_train_df[col] = bea_train_df[col].apply(clean_spaced_hyphens)
        bea_train_df.to_csv("hpc_datasets/bea_train.csv", index=False)
        print(f"✓ Saved BEA-2019 train set ({len(bea_train_df)} examples) to hpc_datasets/bea_train.csv")
    else:
        print("⚠ No BEA train files found in .m2 files")
    
    if dev:
        bea_dev_df = pd.DataFrame(dev, columns=["input_text", "target_text"])
        # Clean spaced hyphens
        for col in bea_dev_df.columns:
            bea_dev_df[col] = bea_dev_df[col].apply(clean_spaced_hyphens)
        bea_dev_df.to_csv("hpc_datasets/bea_dev.csv", index=False)
        print(f"✓ Saved BEA-2019 validation set ({len(bea_dev_df)} examples) to hpc_datasets/bea_dev.csv")
    else:
        print("⚠ No BEA dev files found in .m2 files")
else:
    print(f"⚠ BEA-2019 data directory not found. Checked:")
    for dir_path in possible_dirs:
        print(f"  - {dir_path}")
    print("  Please download W&I+LOCNESS v2.1 from:")
    print("  https://www.cl.cam.ac.uk/research/nl/bea2019st/#data")
    print("  and extract to the project root directory")

# 2. JFLEG test set
print("\nLoading JFLEG dataset...")
jfleg = load_dataset("jfleg", split="test")
# Each sentence comes with a list of reference corrections; keep only the first one.
first_corrections = [corrections[0] for corrections in jfleg["corrections"]]
jfleg_df = pd.DataFrame({"input_text": jfleg["sentence"], "target_text": first_corrections})
# Normalize spaced hyphens in both columns, element by element
for column_name in list(jfleg_df):
    jfleg_df[column_name] = jfleg_df[column_name].map(clean_spaced_hyphens)
jfleg_df.to_csv("hpc_datasets/jfleg_test.csv", index=False)
print(f"✓ Saved JFLEG test set ({len(jfleg_df)} examples) to hpc_datasets/jfleg_test.csv")

# 3. WikiAuto train set
print("\nLoading WikiAuto dataset...")
wiki_auto = load_dataset(
    "chaojiang06/wiki_auto",
    "default",
    revision="refs/convert/parquet",
)
# Pair each normal_sentence (input) with its simple_sentence (target)
wiki_train = wiki_auto["train"]
wiki_train_df = pd.DataFrame(
    {"input_text": wiki_train["normal_sentence"], "target_text": wiki_train["simple_sentence"]}
)
# Normalize spaced hyphens in both columns, element by element
for column_name in list(wiki_train_df):
    wiki_train_df[column_name] = wiki_train_df[column_name].map(clean_spaced_hyphens)
wiki_train_df.to_csv("hpc_datasets/wikiauto_train.csv", index=False)
print(f"✓ Saved WikiAuto train set ({len(wiki_train_df)} examples) to hpc_datasets/wikiauto_train.csv")

# 4./5. ASSET validation and test sets
print("\nLoading ASSET dataset...")
asset = load_dataset("asset")


def _asset_split_frame(dataset, split_name):
    """Build the hyphen-cleaned (input_text, target_text) frame for one ASSET split.

    Each original sentence comes with a list of simplifications; only the first is kept.
    """
    split = dataset[split_name]
    frame = pd.DataFrame({
        "input_text": split["original"],
        "target_text": [simplifications[0] for simplifications in split["simplifications"]],
    })
    for column_name in list(frame):
        frame[column_name] = frame[column_name].map(clean_spaced_hyphens)
    return frame


asset_val_df = _asset_split_frame(asset, "validation")
asset_val_df.to_csv("hpc_datasets/asset_validation.csv", index=False)
print(f"✓ Saved ASSET validation set ({len(asset_val_df)} examples) to hpc_datasets/asset_validation.csv")

asset_test_df = _asset_split_frame(asset, "test")
asset_test_df.to_csv("hpc_datasets/asset_test.csv", index=False)
print(f"✓ Saved ASSET test set ({len(asset_test_df)} examples) to hpc_datasets/asset_test.csv")

# Closing banner, framed by the same 70-character rule used at the start
banner = "=" * 70
print("\n" + banner)
print("All datasets saved successfully to hpc_datasets/ directory!")
print(banner)

Loading and cleaning datasets...

Loading BEA-2019 dataset...
✓ Saved BEA-2019 train set (68616 examples) to hpc_datasets/bea_train.csv
✓ Saved BEA-2019 validation set (8768 examples) to hpc_datasets/bea_dev.csv

Loading JFLEG dataset...
✓ Saved JFLEG test set (748 examples) to hpc_datasets/jfleg_test.csv

Loading WikiAuto dataset...
✓ Saved WikiAuto train set (373801 examples) to hpc_datasets/wikiauto_train.csv

Loading ASSET dataset...
✓ Saved ASSET validation set (2000 examples) to hpc_datasets/asset_validation.csv
✓ Saved ASSET test set (359 examples) to hpc_datasets/asset_test.csv

All datasets saved successfully to hpc_datasets/ directory!
