In [None]:
import pypinyin
import re
import pandas as pd
from pypinyin import Style
from datasets import load_dataset, Dataset
from transformers import MT5Tokenizer, MT5ForConditionalGeneration, Seq2SeqTrainer, Seq2SeqTrainingArguments, DataCollatorForSeq2Seq, EarlyStoppingCallback
from evaluate import load
import numpy as np
import torch
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


# Data Preprocessing (Example: Standard clean data)

In [None]:
# Source corpus: https://huggingface.co/datasets/swaption2009/20k-en-zh-translation-pinyin-hsk
DATASET_ID = "swaption2009/20k-en-zh-translation-pinyin-hsk"

# Load every split of the corpus, then keep a handle on the "train" split,
# which is the one the preprocessing loop below iterates over.
ds = load_dataset(DATASET_ID)
dataset = ds["train"]

# Matches any ASCII letter; sentences that still contain Latin text after
# cleaning are filtered out with this pattern.
contains_english = re.compile(r"[a-zA-Z]")

# Character class listing every punctuation mark that gets stripped
# (both full-width/Chinese and ASCII/western variants).
_PUNCTUATION_RE = re.compile(r"[。.,，！？!?:：；;\"'‘’“”()（）《》【】＇｀……\-－／/、\[\]［］＂·—]")


def clean_punctuations(p):
    """
    Strip common Chinese-style and western punctuation marks from ``p``.

    Every character matched by the punctuation class is deleted; all other
    characters (including whitespace) are returned unchanged and in order.
    """
    return _PUNCTUATION_RE.sub("", p)

def clean_spaces(text):
    """
    Delete space characters from Chinese text.

    Three characters are removed: the regular ASCII space, the non-breaking
    space (U+00A0) and the full-width space written literally below. Other
    whitespace such as tabs or newlines is left untouched.
    """
    # Mapping a code point to None makes str.translate drop that character.
    space_table = dict.fromkeys(map(ord, (" ", "\u00A0", "　")))
    return text.translate(space_table)

def convert_fullwidth_to_normal(text):
    """
    Replace full-width digits (０１２３４５６７８９) with ASCII digits (0123456789).

    Only the ten full-width digits are remapped; every other character,
    including full-width letters, is returned unchanged.
    """
    # Each full-width digit sits exactly 0xFEE0 code points above its ASCII twin.
    digit_table = {
        code_point: code_point - 0xFEE0
        for code_point in range(ord("０"), ord("９") + 1)
    }
    return text.translate(digit_table)

def chinese_to_pinyin(text):
    """
    Convert ``text`` to a pinyin string.

    The text is passed to ``pypinyin.lazy_pinyin`` with ``Style.NORMAL`` and
    the returned items are joined with single spaces.
    """
    syllables = pypinyin.lazy_pinyin(text, style=Style.NORMAL)
    return " ".join(syllables)

formatted_dataset, formatted_dataset_eval = [], []
# Sets shadowing the two lists above so each duplicate check is O(1) instead of
# rebuilding a list of every previously accepted entry on every iteration.
seen_train_chinese, seen_eval_pinyin = set(), set()

# Records are visited every 5 rows starting at row 2; the first 10 characters
# of the row text are dropped (assumed to be a field-label prefix — TODO confirm
# against the dataset card).
for i in range(2, dataset.num_rows, 5):
    chinese = dataset[i]["text"][10:]

    chinese = clean_punctuations(chinese)
    chinese = clean_spaces(chinese)
    chinese = convert_fullwidth_to_normal(chinese)

    # Skip sentences that are empty after cleaning (they would be written as an
    # empty CSV cell and read back as NaN), too long, or still contain Latin
    # letters. The same filter applies to both the train and the eval split.
    if not chinese or len(chinese) >= 60 or contains_english.search(chinese):
        continue

    # The pinyin is always regenerated from the cleaned Chinese text, so the
    # dataset's own pinyin row is never read.
    pinyin = chinese_to_pinyin(chinese)

    # i + 3 is always a multiple of 5, so this condition holds for every 400th
    # record; those records form the held-out evaluation set.
    if ((i + 3) % 2000) == 0:
        # Eval entries are de-duplicated by their pinyin string.
        if pinyin not in seen_eval_pinyin:
            seen_eval_pinyin.add(pinyin)
            formatted_dataset_eval.append({
                "Pinyin": pinyin,
                "Chinese": chinese
            })
    else:
        # Train entries are de-duplicated by their Chinese string.
        if chinese not in seen_train_chinese:
            seen_train_chinese.add(chinese)
            formatted_dataset.append({
                "Chinese": chinese,
                "Pinyin": pinyin
            })

df = pd.DataFrame(formatted_dataset)
df_eval = pd.DataFrame(formatted_dataset_eval)

df.to_csv("train.csv", index=False, encoding="utf-8")
df_eval.to_csv("eval.csv", index=False, encoding="utf-8")

# Training (Example: mT5-base)

In [None]:
csv_path = "train.csv"
df = pd.read_csv(csv_path)

# Keep only rows where both fields are present and expose them under the
# column names the tokenization step expects ("input" = pinyin, "target" = Chinese).
examples = (
    df.dropna(subset=["Pinyin", "Chinese"])
    .rename(columns={"Pinyin": "input", "Chinese": "target"})[["input", "target"]]
    .to_dict(orient="records")
)
dataset = Dataset.from_pandas(pd.DataFrame(examples))

# 80/20 train/eval split. The generator is seeded so the partition is identical
# on every run; without it the eval set changes between runs and scores are not
# comparable.
SPLIT_SEED = 42
train_size = int(0.8 * len(dataset))
eval_size = len(dataset) - train_size
train_subset, eval_subset = torch.utils.data.random_split(
    dataset,
    [train_size, eval_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Select the rows by index instead of materializing them one by one in Python.
train_dataset = dataset.select(train_subset.indices)
eval_dataset = dataset.select(eval_subset.indices)
del dataset

model_name = "google/mt5-base"
tokenizer = MT5Tokenizer.from_pretrained(model_name)
model = MT5ForConditionalGeneration.from_pretrained(model_name)

def preprocess_function(examples):
    """
    Tokenize a batch of pinyin/Chinese pairs for seq2seq training.

    Args:
        examples: batch mapping with an "input" list (pinyin strings) and a
            "target" list (Chinese strings) of equal length.

    Returns:
        The tokenizer output for the prefixed inputs, with an added "labels"
        entry holding the token ids of the targets.

    No padding is applied here: sequences keep their natural length and are
    padded per training batch by the data collator (DataCollatorForSeq2Seq pads
    labels with -100 by default). Padding at this stage would instead pad every
    example to the longest sequence of the whole map batch and store that
    padding in the dataset.
    """
    prefix = "拼音转中文："
    inputs = [prefix + text for text in examples["input"]]
    targets = examples["target"]  # Chinese

    model_inputs = tokenizer(
        inputs,
        truncation=True,
    )

    labels = tokenizer(
        text_target=targets,
        truncation=True,
    )

    # Unpadded label ids contain no pad tokens, so nothing needs masking here.
    model_inputs["labels"] = labels["input_ids"]

    return model_inputs

# Run the tokenization step over both splits, batch by batch (train first, then eval).
tokenized_train_dataset, tokenized_eval_dataset = (
    split.map(preprocess_function, batched=True)
    for split in (train_dataset, eval_dataset)
)

# Collator that assembles seq2seq batches for the trainer.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

# Text-similarity metrics used during and after training.
bleu, chrf, rouge = (load(metric_name) for metric_name in ("bleu", "chrf", "rouge"))

def space_chars(text):
    """
    Put a single space between every character of ``text``.

    Leading and trailing whitespace is stripped first, so each remaining
    character becomes its own space-separated token.
    """
    trimmed = text.strip()
    return " ".join(trimmed)

def compute_metrics(eval_preds):
    """
    Score generated token ids against reference token ids.

    Args:
        eval_preds: a ``(predictions, labels)`` pair of 2-D token-id arrays.
            ``predictions`` may also arrive as a tuple, in which case its first
            element is used (assumed to be the generated ids, as in the
            Hugging Face seq2seq examples — confirm for other models).

    Returns:
        dict with "char_accuracy", "chrf", "bleu", "rouge1", "rouge2", "rougeL".

    Raises:
        ValueError: if a label id is outside ``[0, tokenizer.vocab_size)`` after
            the -100 placeholders have been replaced.
    """
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    preds = np.asarray(preds)
    labels = np.asarray(labels)

    # -100 marks ignored/padded positions; swap it for the pad id so the
    # tokenizer can decode. Predictions get the same treatment explicitly
    # rather than relying on clipping to turn -100 into id 0.
    pad_id = tokenizer.pad_token_id
    labels = np.where(labels == -100, pad_id, labels)
    preds = np.where(preds == -100, pad_id, preds)

    # clip predictions to valid token ID range
    preds = np.clip(preds, 0, tokenizer.vocab_size - 1)

    # Labels are validated (not clipped): an out-of-range id signals a data bug.
    invalid_mask = (labels < 0) | (labels >= tokenizer.vocab_size)
    if invalid_mask.any():
        rows, cols = np.nonzero(invalid_mask)
        invalid = [(int(r), int(labels[r, c])) for r, c in zip(rows[:5], cols[:5])]
        print("Invalid token IDs found:", invalid)  # print just a few for now
        raise ValueError("Found token ids out of tokenizer vocab range.")

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    decoded_preds = [pred.strip() for pred in decoded_preds]
    decoded_labels = [label.strip() for label in decoded_labels]

    # Character-level accuracy: position-wise matches divided by the total
    # number of reference characters.
    char_total = sum(len(label) for label in decoded_labels)
    char_correct = sum(
        p == l
        for pred, label in zip(decoded_preds, decoded_labels)
        for p, l in zip(pred, label)
    )
    char_accuracy = char_correct / char_total if char_total > 0 else 0.0

    # Space-separate characters so word-based metrics treat each one as a token.
    spaced_preds = [space_chars(pred) for pred in decoded_preds]
    spaced_labels = [space_chars(label) for label in decoded_labels]
    references = [[lbl] for lbl in spaced_labels]

    # BLEU: when every prediction (or every reference) is empty — common in the
    # first evaluations of a fresh model — the metric's length ratio is zero and
    # its computation fails with a division by zero, so report 0.0 instead.
    if any(spaced_preds) and any(spaced_labels):
        bleu_score = bleu.compute(predictions=spaced_preds, references=references)["bleu"]
    else:
        bleu_score = 0.0

    # chrf
    chrf_score = chrf.compute(predictions=spaced_preds, references=references)["score"]

    # ROUGE (use spaced strings so it treats each char as a token)
    rouge_result = rouge.compute(predictions=spaced_preds, references=spaced_labels, use_stemmer=False)

    return {
        "char_accuracy": char_accuracy,
        "chrf": chrf_score,
        "bleu": bleu_score,
        "rouge1": rouge_result["rouge1"],
        "rouge2": rouge_result["rouge2"],
        "rougeL": rouge_result["rougeL"],
    }

# Training configuration. Evaluation and checkpointing both happen every 200
# optimizer steps so the best checkpoint (by char_accuracy) can be reloaded at
# the end. Effective train batch size per device is 16 * 8 accumulation steps.
training_args = Seq2SeqTrainingArguments(
    output_dir="./mt5_pinyin_to_chinese",
    evaluation_strategy="steps",
    learning_rate=3e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    weight_decay=0.01,
    save_total_limit=2,
    num_train_epochs=20,
    predict_with_generate=True,
    # Without an explicit limit, evaluation-time generation falls back to the
    # model's default generation length, which is shorter than the longest
    # targets (up to 59 characters) and would truncate predictions before
    # scoring. 64 matches the budget used by the stand-alone evaluation cell.
    generation_max_length=64,
    load_best_model_at_end=True,
    greater_is_better=True,
    metric_for_best_model="char_accuracy",
    gradient_accumulation_steps=8,
    save_steps=200,
    logging_steps=100,
    fp16=False,
    bf16=True,
    eval_steps=200,
    save_strategy="steps",
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_eval_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    # NOTE(review): the callback is created with its library-default patience;
    # confirm that default is not too aggressive for evaluation every 200 steps.
    callbacks=[EarlyStoppingCallback()]
)

trainer.train()
trainer.evaluate()
trainer.save_model("./mt5-base")

# Evaluate (Example: Against standard clean data)

In [None]:
model_name = "./mt5-base"

model = MT5ForConditionalGeneration.from_pretrained(model_name)
tokenizer = MT5Tokenizer.from_pretrained(model_name)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

eval_csv_path = "./eval.csv"

# Empty CSV cells are read back as NaN (a float), which would break the string
# concatenation and .strip() calls below, so incomplete rows are dropped — the
# same rule the training cell applies to train.csv.
eval_data = pd.read_csv(eval_csv_path).dropna(subset=["Pinyin", "Chinese"])

preds = []
labels = []
for sample in tqdm(eval_data.to_dict(orient="records")):
    # Same task prefix as used when tokenizing the training data.
    input_text = "拼音转中文：" + sample["Pinyin"]
    encoded = tokenizer(input_text, return_tensors="pt", truncation=True)
    input_ids = encoded["input_ids"].to(model.device)
    attention_mask = encoded["attention_mask"].to(model.device)
    generated_ids = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=64,
        num_beams=4,
        early_stopping=True,
        num_return_sequences=1,
    )
    pred = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    print(f"Pinyin: {sample['Pinyin']}\nChinese: {pred}\n\n")
    preds.append(pred.strip())
    labels.append(sample["Chinese"].strip())

def space_chars(text):
    """
    Return ``text`` with surrounding whitespace stripped and a single space
    inserted between each pair of adjacent characters.
    """
    characters = text.strip()
    return " ".join(characters)

# Character-level accuracy: position-wise matches between each prediction and
# its reference, divided by the total number of reference characters.
char_total = sum(len(label) for _, label in zip(preds, labels))
char_correct = sum(
    p == l
    for pred, label in zip(preds, labels)
    for p, l in zip(pred, label)
)
char_accuracy = 0.0 if char_total <= 0 else char_correct / char_total

# Space-separate the characters so each one counts as a token for the metrics.
spaced_preds = list(map(space_chars, preds))
spaced_labels = list(map(space_chars, labels))
wrapped_references = [[reference] for reference in spaced_labels]

bleu_result = bleu.compute(predictions=spaced_preds, references=wrapped_references)
chrf_result = chrf.compute(predictions=spaced_preds, references=wrapped_references)
rouge_result = rouge.compute(predictions=spaced_preds, references=spaced_labels, use_stemmer=False)

# Scores plus the raw predictions/references for inspection.
results = dict(
    char_accuracy=char_accuracy,
    bleu=bleu_result["bleu"],
    chrf=chrf_result["score"],
    rouge1=rouge_result["rouge1"],
    rouge2=rouge_result["rouge2"],
    rougeL=rouge_result["rougeL"],
    predictions=preds,
    references=labels,
)

results