In [None]:
pip install evaluate



In [None]:
pip install rouge_score



In [None]:
pip install sacremoses sacrebleu



In [None]:
import html
import inspect
import os
import re
from dataclasses import dataclass
from typing import Dict, List

import evaluate
import nltk
import numpy as np
from datasets import load_dataset, DatasetDict

In [None]:
# Make sure the NLTK sentence-tokenizer resources used by nltk.sent_tokenize
# (see sentence_count_filter) are installed, downloading any that are missing.
# Each entry: (resource name, nltk.data lookup path, message printed before download).
_NLTK_RESOURCES = (
    ("punkt", "tokenizers/punkt", "Downloading NLTK sentence tokenizer data ('punkt')..."),
    ("punkt_tab", "tokenizers/punkt_tab", "NLTK 'punkt_tab' data not found. Downloading..."),
)

for _res_name, _res_path, _missing_msg in _NLTK_RESOURCES:
    try:
        nltk.data.find(_res_path)
    except LookupError:
        print(_missing_msg)
        nltk.download(_res_name, quiet=True)
        print(f"NLTK '{_res_name}' data downloaded.")
    else:
        print(f"NLTK '{_res_name}' data is already available.")

NLTK 'punkt' data is already available.
NLTK 'punkt_tab' data is already available.


In [None]:
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
)

## Model and Dataset

In [None]:
# Hugging Face Hub identifiers for the base checkpoint and the training corpus.
model_name: str = "t5-small"
dataset_id: str = "bogdancazan/wikilarge-text-simplification"
# Optional dataset config name; when falsy, load_dataset is called without one.
dataset_subset = None

In [None]:
# Dataset columns used as model input (source) and as the training target.
source_text_column: str = "Normal"
target_text_column: str = "Simple"

In [None]:
# When enabled, train/validation are cut down to at most these many rows
# (the test split is left as is) to keep experiments fast.
use_subset: bool = True
train_subset_size: int = 20_000
val_subset_size: int = 2_000

In [None]:
# Truncation limits (in tokens) for tokenized sources and targets, plus the
# minimum source length (in characters) an example needs to be kept.
max_source_length: int = 256
max_target_length: int = 128
min_source_length_chars: int = 50

In [None]:
# Optimisation hyper-parameters fed into Seq2SeqTrainingArguments.
# Effective per-device batch size = per_device_batch_size * grad_accum_steps.
per_device_batch_size: int = 4
grad_accum_steps: int = 4
num_train_epochs: int = 3
# NOTE(review): 1e-5 is low for T5 + AdamW; the Hugging Face T5 docs suggest
# roughly 1e-4 to 3e-4 for fine-tuning. Consider tuning.
learning_rate: float = 1e-5
weight_decay: float = 0.01
warmup_ratio: float = 0.03
lr_scheduler: str = "cosine"

In [None]:
# "fp16", "bf16" or "no"; selects the Trainer's precision flags and collator padding.
# NOTE(review): T5 checkpoints are reportedly prone to overflow/NaN losses in
# fp16; prefer "bf16" where the hardware supports it.
mixed_precision: str = "fp16"
# Worker processes for datasets .map/.filter; None means no multiprocessing.
NUM_PROC = None
# Beam width used whenever the model generates text for evaluation.
gen_num_beams: int = 4

## Text Cleaning

In [None]:
# Matches an http(s) URL up to the next whitespace character; clean_text deletes these.
URL_RE: re.Pattern = re.compile(r"https?://\S+")

In [None]:
def source_len_filter(example):
    """Filter predicate: keep examples whose source text has at least
    `min_source_length_chars` characters."""
    source_text = example[source_text_column]
    return min_source_length_chars <= len(source_text)

In [None]:
def clean_text(t: str) -> str:
    """Normalise one raw text field.

    Returns "" for None. Otherwise the value is coerced to str, HTML entities
    are unescaped, URLs matched by URL_RE are deleted, non-breaking spaces are
    replaced by ordinary spaces, and each whitespace run is collapsed to a
    single space with leading/trailing whitespace trimmed.
    """
    if t is not None:
        text = html.unescape(str(t))
        text = URL_RE.sub("", text)
        text = text.replace("\u00A0", " ")
        return re.sub(r"\s+", " ", text).strip()
    return ""

In [None]:
def cleaner_batch(batch):
    """Batched `map` callback: apply clean_text to every value of the source
    and target columns and return just those two columns."""
    return {
        column: [clean_text(value) for value in batch[column]]
        for column in (source_text_column, target_text_column)
    }

## Load and clean the dataset

In [None]:
print(f"Loading dataset: {dataset_id}...")
try:
    # dataset_subset is an optional config name; only pass it when it is set.
    if dataset_subset:
        raw = load_dataset(dataset_id, dataset_subset)
    else:
        raw = load_dataset(dataset_id)
except Exception as e:
    print(f"Error loading dataset: {e}")
    print("Please check the dataset ID / config name and your network connection.")
    # Re-raise: every later cell needs `raw`, so carrying on would only fail
    # further down with a confusing NameError.
    raise

Loading dataset: bogdancazan/wikilarge-text-simplification...


In [None]:
# Ensure usable validation and test splits exist. Validation must have at least
# 10% as many rows as train and test at least 1%; a split that is missing or
# smaller than that is (re)carved from train with a fixed seed.
train_split = raw["train"]
validation_split = raw.get("validation")
test_split = raw.get("test")

if validation_split is None or len(validation_split) < 0.1 * len(train_split):
    print("Insufficient validation split, creating one from train split (10%)...")
    # Take 10% of train for validation
    carved = train_split.train_test_split(test_size=0.1, seed=42)
    train_split, validation_split = carved["train"], carved["test"]

if test_split is None or len(test_split) < 0.01 * len(train_split):
    print("Insufficient test split, creating one from train split (1%)...")
    # Take 1% of the (remaining) train for test
    carved = train_split.train_test_split(test_size=0.01, seed=42)
    train_split, test_split = carved["train"], carved["test"]

# Only real Dataset objects go into the dict: storing None for a missing split
# would break the len() checks above and every later .map/.filter call.
raw = DatasetDict({
    "train": train_split,
    "validation": validation_split,
    "test": test_split,
})

Insufficient validation split, creating one from train split (10%)...
Insufficient test split, creating one from train split (1%)...


In [None]:
# Show the resulting splits and their row counts (rendered from the DatasetDict repr).
raw

DatasetDict({
    train: Dataset({
        features: ['Normal', 'Simple'],
        num_rows: 132618
    })
    validation: Dataset({
        features: ['Normal', 'Simple'],
        num_rows: 14885
    })
    test: Dataset({
        features: ['Normal', 'Simple'],
        num_rows: 1340
    })
})

In [None]:
# Normalise source/target text in every split with cleaner_batch.
print("Cleaning text...")
raw = raw.map(
    function=cleaner_batch,
    batched=True,
    num_proc=NUM_PROC,
    desc="Cleaning text",
)

Cleaning text...


In [None]:
# --- Filter out empty examples after cleaning ---
def non_empty(ex):
    """Filter predicate: True iff both source and target text are non-empty.

    A truthiness check covers both "" and None; bool() makes the predicate
    return a real bool (the previous `x and len(x) > 0` form returned the
    string itself on the falsy path, and its length check was redundant).
    """
    return bool(ex[source_text_column]) and bool(ex[target_text_column])

print("Filtering empty examples...")
raw = raw.filter(non_empty, num_proc=NUM_PROC, desc="Filtering empty examples")

Filtering empty examples...


In [None]:
# Split sizes after cleaning and dropping empty examples.
print("Final dataset sizes:", raw, sep="\n")

Final dataset sizes:
DatasetDict({
    train: Dataset({
        features: ['Normal', 'Simple'],
        num_rows: 132618
    })
    validation: Dataset({
        features: ['Normal', 'Simple'],
        num_rows: 14885
    })
    test: Dataset({
        features: ['Normal', 'Simple'],
        num_rows: 1340
    })
})


In [None]:
def sentence_count_filter(example):
    """Filter predicate: keep examples whose source and target contain the
    same number of sentences according to nltk.sent_tokenize.

    If sentence tokenization raises, a warning with both texts is printed and
    the example is dropped (returns False).
    """
    try:
        n_source = len(nltk.sent_tokenize(example[source_text_column]))
        n_target = len(nltk.sent_tokenize(example[target_text_column]))
    except Exception as e:
        print(f"Warning: Error tokenizing sentences, discarding example. Error: {e}")
        print(f"Source: {example.get(source_text_column, 'N/A')}")
        print(f"Target: {example.get(target_text_column, 'N/A')}")
        return False
    return n_source == n_target

In [None]:
# --- Apply source length and sentence count filter ---
# 1) drop short sources, 2) keep only sentence-aligned pairs; report sizes after each.
print(f"Filtering examples shorter than {min_source_length_chars} characters...")
raw = raw.filter(source_len_filter, num_proc=NUM_PROC, desc="Filtering short sources")
print("Final dataset sizes after length filtering:", raw, sep="\n")

print("Applying sentence count filter...")
raw = raw.filter(sentence_count_filter, num_proc=NUM_PROC, desc="Filtering by sentence count")
print("Final dataset sizes after sentence count filtering:", raw, sep="\n")

Filtering examples shorter than 50 characters...
Final dataset sizes after length filtering:
DatasetDict({
    train: Dataset({
        features: ['Normal', 'Simple'],
        num_rows: 130240
    })
    validation: Dataset({
        features: ['Normal', 'Simple'],
        num_rows: 14626
    })
    test: Dataset({
        features: ['Normal', 'Simple'],
        num_rows: 1320
    })
})
Applying sentence count filter...
Final dataset sizes after sentence count filtering:
DatasetDict({
    train: Dataset({
        features: ['Normal', 'Simple'],
        num_rows: 102965
    })
    validation: Dataset({
        features: ['Normal', 'Simple'],
        num_rows: 11617
    })
    test: Dataset({
        features: ['Normal', 'Simple'],
        num_rows: 1028
    })
})


In [None]:
if use_subset:
    print(f"Selecting subset: {train_subset_size} train, {val_subset_size} validation...")
    # Keep the first N rows of train/validation (capped at the split size);
    # the test split is never subsampled.
    for split_key, cap in (("train", train_subset_size), ("validation", val_subset_size)):
        raw[split_key] = raw[split_key].select(range(min(cap, len(raw[split_key]))))

Selecting subset: 2500 train, 200 validation...


In [None]:
# Split sizes actually used for tokenization and training.
print("Final dataset sizes", raw, sep="\n")

Final dataset sizes
DatasetDict({
    train: Dataset({
        features: ['Normal', 'Simple'],
        num_rows: 2500
    })
    validation: Dataset({
        features: ['Normal', 'Simple'],
        num_rows: 200
    })
    test: Dataset({
        features: ['Normal', 'Simple'],
        num_rows: 1028
    })
})


## Tokenizer & Model

In [None]:
# Task prefix prepended to every source text in preprocess_function (T5-style task prefix).
prefix = "simplify: "

print(f"Loading model/tokenizer: {model_name}")
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

Loading model/tokenizer: t5-small


In [None]:
def preprocess_function(examples):
    """Batched `map` callback: tokenize prefixed sources and their targets.

    Sources (with `prefix` prepended) are truncated to max_source_length
    tokens and targets to max_target_length. No padding is requested here.
    Any pad_token_id found in the target ids is replaced by -100 so those
    positions are ignored by the loss; the result is stored under "labels".
    """
    model_inputs = tokenizer(
        [prefix + doc for doc in examples[source_text_column]],
        max_length=max_source_length,
        truncation=True,
    )
    tokenized_targets = tokenizer(
        text_target=examples[target_text_column],
        max_length=max_target_length,
        truncation=True,
    )

    pad_id = tokenizer.pad_token_id
    model_inputs["labels"] = [
        [-100 if token_id == pad_id else token_id for token_id in sequence]
        for sequence in tokenized_targets["input_ids"]
    ]
    return model_inputs

In [None]:
# Tokenize every split; the original text columns are dropped so only model inputs remain.
text_columns = raw["train"].column_names
print("Tokenizing dataset...")
tokenized = raw.map(
    preprocess_function,
    batched=True,
    num_proc=NUM_PROC,
    remove_columns=text_columns,
    desc="Tokenizing dataset",
)

Tokenizing dataset...


Tokenizing dataset:   0%|          | 0/200 [00:00<?, ? examples/s]

In [None]:
# Dynamic padding per batch; with mixed precision enabled, lengths are rounded
# up to a multiple of 8 for efficiency.
pad_multiple = None if mixed_precision == "no" else 8
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model,
    pad_to_multiple_of=pad_multiple,
)

In [None]:
rouge = evaluate.load("rouge")
bleu = evaluate.load("bleu")
sari = evaluate.load("sari")

def compute_metrics(eval_preds):
    """Trainer `compute_metrics` hook: ROUGE, BLEU and (when possible) SARI.

    Args:
        eval_preds: EvalPrediction or tuple whose first two items are the
            generated token ids and the label ids (-100 marks ignored slots).

    Returns:
        dict of scores rounded to 2 decimals. ROUGE and BLEU are scaled to
        0-100. "sari" is included only when the sources can be aligned.

    SARI needs the source sentences, which are not part of `eval_preds` here.
    They are taken from raw["validation"] only when its length equals the
    number of predictions, i.e. when the Trainer is evaluating the validation
    split. Otherwise (e.g. trainer.predict on the test split) SARI is skipped
    rather than computed against misaligned sources, which would raise.
    """
    # Index access works for both EvalPrediction and tuples, and keeps working
    # when the Trainer also attaches inputs/losses (tuple unpacking would not).
    preds, labels = eval_preds[0], eval_preds[1]
    if isinstance(preds, tuple):
        preds = preds[0]

    # -100 cannot be decoded; map it back to the pad token first.
    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_preds = [p.strip() for p in tokenizer.batch_decode(preds, skip_special_tokens=True)]
    decoded_labels = [t.strip() for t in tokenizer.batch_decode(labels, skip_special_tokens=True)]
    # BLEU and SARI take a list of references per example; we have exactly one.
    references = [[t] for t in decoded_labels]

    # --- ROUGE ---
    result = rouge.compute(predictions=decoded_preds,
                           references=decoded_labels,
                           use_stemmer=True)
    result = {k: round(v * 100, 2) for k, v in result.items()}

    # --- BLEU ---
    bleu_result = bleu.compute(predictions=decoded_preds, references=references)
    result["bleu"] = round(bleu_result["bleu"] * 100, 2)

    # --- SARI ---
    sources = list(raw["validation"][source_text_column])
    if len(sources) == len(decoded_preds):
        sari_result = sari.compute(
            sources=sources,
            predictions=decoded_preds,
            references=references,
        )
        result["sari"] = round(sari_result["sari"], 2)
    else:
        print(
            f"compute_metrics: skipping SARI - {len(decoded_preds)} predictions but "
            f"{len(sources)} validation sources, so sources cannot be aligned."
        )

    return result

## Training arguments

In [None]:
# Name the output directory after the base checkpoint (dropping any "org/"
# namespace), e.g. "t5-small" -> "t5-small-wikilarge-simplifier".
output_dir = f"{model_name.split('/')[-1]}-wikilarge-simplifier"
training_args = Seq2SeqTrainingArguments(
    output_dir=output_dir,

    # Periodic evaluation / checkpointing. eval_steps has no effect unless
    # eval_strategy is "steps"; the Trainer's default strategy is "no".
    # NOTE: transformers < 4.41 names this argument `evaluation_strategy`.
    eval_strategy="steps",
    eval_steps=1000,
    logging_steps=200,
    save_strategy="steps",
    save_steps=2000,
    save_total_limit=2,

    # Optimisation (effective per-device batch = batch size * grad accumulation).
    learning_rate=learning_rate,
    per_device_train_batch_size=per_device_batch_size,
    per_device_eval_batch_size=per_device_batch_size,
    gradient_accumulation_steps=grad_accum_steps,
    num_train_epochs=num_train_epochs,
    weight_decay=weight_decay,
    warmup_ratio=warmup_ratio,
    lr_scheduler_type=lr_scheduler,
    gradient_checkpointing=True,  # trade extra compute for lower activation memory

    # Evaluate by generating text (required for ROUGE/BLEU/SARI), using the
    # notebook's target length and beam width.
    predict_with_generate=True,
    generation_max_length=max_target_length,
    generation_num_beams=gen_num_beams,

    fp16=(mixed_precision == "fp16"),
    bf16=(mixed_precision == "bf16"),

    report_to=["none"],  # disable wandb/tensorboard integrations
)

In [None]:
# Newer transformers (>= 4.46) take the tokenizer as `processing_class` and
# deprecate `tokenizer=`; pass it under whichever keyword the installed
# Seq2SeqTrainer accepts so this cell works on old and new versions.
_tokenizer_kwarg = (
    "processing_class"
    if "processing_class" in inspect.signature(Seq2SeqTrainer.__init__).parameters
    else "tokenizer"
)
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized["train"],
    eval_dataset=tokenized["validation"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    **{_tokenizer_kwarg: tokenizer},
)

  trainer = Seq2SeqTrainer(


In [None]:
print("Starting training...")
train_output = trainer.train()
train_output  # shown as the cell result (TrainOutput: global_step, loss, metrics)

Starting training...


Step,Training Loss
200,2.0136
400,1.8409


TrainOutput(global_step=471, training_loss=1.9151148704966163, metrics={'train_runtime': 467.3142, 'train_samples_per_second': 16.049, 'train_steps_per_second': 1.008, 'total_flos': 120352696958976.0, 'train_loss': 1.9151148704966163, 'epoch': 3.0})

In [None]:
# Final evaluation on the Trainer's eval_dataset (the validation split),
# generating with the same length/beam settings as the training arguments.
print("Evaluating final model...")
metrics = trainer.evaluate(
    num_beams=gen_num_beams,
    max_length=max_target_length,
)

Evaluating final model...


In [None]:
print("Saving final model...")
# Persist the fine-tuned model and the tokenizer files into the same directory
# the Trainer was configured with.
trainer.save_model(output_dir)
tokenizer.save_pretrained(output_dir)

print("Final eval metrics:", metrics)

Saving final model...
Final eval metrics: {'eval_loss': 1.6726874113082886, 'eval_rouge1': 62.16, 'eval_rouge2': 45.16, 'eval_rougeL': 57.95, 'eval_rougeLsum': 58.02, 'eval_bleu': 37.85, 'eval_runtime': 77.919, 'eval_samples_per_second': 2.567, 'eval_steps_per_second': 0.642, 'epoch': 3.0}


In [None]:
def evaluate_on_split(split_name="test"):
    """Generate predictions for one split, print ROUGE/BLEU/SARI and return them.

    Uses module-level objects from earlier cells: `trainer`, `tokenizer`,
    `tokenized`, `raw`, `source_text_column`, `max_target_length`,
    `gen_num_beams`, and the `rouge`, `bleu` and `sari` metric objects loaded
    next to `compute_metrics`.

    Args:
        split_name: key present in both `tokenized` and `raw` (e.g. "test").

    Returns:
        dict with the ROUGE scores from `rouge.compute` and "bleu" (both scaled
        to 0-100) plus "sari", each rounded to 2 decimals.
    """
    dataset = tokenized[split_name]
    raw_split = raw[split_name]

    print(f"Generating predictions on {split_name.upper()} split...")
    # Temporarily detach the Trainer's compute_metrics hook: it aligns SARI
    # sources with raw["validation"], which is wrong for other splits, and the
    # metrics are computed below from this split's own sources anyway.
    trainer_compute_metrics = trainer.compute_metrics
    trainer.compute_metrics = None
    try:
        output = trainer.predict(
            dataset,
            max_length=max_target_length,
            num_beams=gen_num_beams,
        )
    finally:
        trainer.compute_metrics = trainer_compute_metrics

    pred_ids = output.predictions
    label_ids = output.label_ids
    if isinstance(pred_ids, tuple):
        pred_ids = pred_ids[0]

    # -100 marks padded/ignored positions and cannot be decoded.
    pred_ids = np.where(pred_ids != -100, pred_ids, tokenizer.pad_token_id)
    label_ids = np.where(label_ids != -100, label_ids, tokenizer.pad_token_id)

    decoded_preds = [p.strip() for p in tokenizer.batch_decode(pred_ids, skip_special_tokens=True)]
    decoded_labels = [t.strip() for t in tokenizer.batch_decode(label_ids, skip_special_tokens=True)]
    references = [[t] for t in decoded_labels]  # one reference per example

    # ROUGE
    rouge_result = rouge.compute(
        predictions=decoded_preds,
        references=decoded_labels,
        use_stemmer=True
    )
    rouge_result = {k: round(v * 100, 2) for k, v in rouge_result.items()}

    # BLEU
    bleu_result = bleu.compute(predictions=decoded_preds, references=references)
    bleu_score = round(bleu_result["bleu"] * 100, 2)

    # SARI: raw_split is already raw[split_name], so read its source column directly.
    sources = list(raw_split[source_text_column])[:len(decoded_preds)]
    sari_result = sari.compute(
        sources=sources,
        predictions=decoded_preds,
        references=references
    )
    sari_score = round(sari_result["sari"], 2)

    print(f"\n===== {split_name.upper()} METRICS =====")
    print(f"ROUGE-1:     {rouge_result['rouge1']}")
    print(f"ROUGE-2:     {rouge_result['rouge2']}")
    print(f"ROUGE-L:     {rouge_result['rougeL']}")
    print(f"ROUGE-Lsum:  {rouge_result['rougeLsum']}")
    print(f"BLEU:        {bleu_score}")
    print(f"SARI:        {sari_score}")
    print("============================\n")

    return {**rouge_result, "bleu": bleu_score, "sari": sari_score}

In [None]:
# Score the held-out test split; evaluate_on_split prints a metric summary.
test_metrics = evaluate_on_split("test")

1028
1028
Generating predictions on TEST split...



===== TEST METRICS =====
ROUGE-1: 61.12
ROUGE-2: 43.49
ROUGE-L: 56.76
BLEU:    34.64
SARI:    48.64

