In [None]:
import pypinyin
import re
import pandas as pd
from pypinyin import Style
from datasets import load_dataset

In [None]:
# Source data: https://huggingface.co/datasets/swaption2009/20k-en-zh-translation-pinyin-hsk
# Only the "train" split of the Hub dataset is used; the eval set is carved out
# of it later in this notebook.
_DATASET_NAME = "swaption2009/20k-en-zh-translation-pinyin-hsk"

ds = load_dataset(_DATASET_NAME)
dataset = ds["train"]

# Matches any ASCII Latin letter; used to drop sentences that still contain English.
contains_english = re.compile(r'[a-zA-Z]')

# Compiled once: the set of Chinese full-width and ASCII punctuation to strip.
_PUNCTUATION_PATTERN = re.compile(r"[。.,，！？!?:：；;\"'‘’“”()（）《》【】＇｀……\-－／/、\[\]［］＂·—]")

def clean_punctuations(p):
    """
    Strip every character of the punctuation set above (Chinese full-width and
    western/ASCII forms alike) from ``p`` and return the remaining text.
    """
    return _PUNCTUATION_PATTERN.sub("", p)

def clean_spaces(text):
    """
    Remove all whitespace from Chinese text.

    Removes ASCII spaces, the non-breaking space (U+00A0) and the ideographic
    full-width space (U+3000), as before. It now also removes every other
    character for which ``str.isspace()`` is true (tabs, newlines, ...), so no
    stray separator survives in the cleaned sentence.

    Args:
        text: Sentence to clean.

    Returns:
        ``text`` with every whitespace character removed.
    """
    # str.split() with no argument splits on runs of any Unicode whitespace,
    # including U+00A0 and U+3000, so joining the pieces drops all of it.
    return "".join(text.split())

# Full-width forms U+FF01..U+FF5E sit exactly 0xFEE0 above ASCII U+0021..U+007E.
# Translation tables are built once instead of per call / per character.
_FULLWIDTH_DIGIT_TABLE = {code: code - 0xFEE0 for code in range(0xFF10, 0xFF1A)}  # ０..９
_FULLWIDTH_ASCII_TABLE = {code: code - 0xFEE0 for code in range(0xFF01, 0xFF5F)}  # ！..～

def convert_fullwidth_to_normal(text, *, digits_only=True):
    """
    Convert full-width characters to their ASCII (half-width) equivalents.

    By default only full-width digits (０１２３４５６７８９) are converted to
    normal digits (0123456789). Pass ``digits_only=False`` to convert every
    full-width ASCII variant in U+FF01..U+FF5E instead: letters, digits and
    symbols, e.g. ``％`` -> ``%``, ``Ａ`` -> ``A``.

    Args:
        text: String to normalise.
        digits_only: Keyword-only; when True (default) restrict the conversion
            to digits.

    Returns:
        ``text`` with the selected full-width characters replaced.
    """
    table = _FULLWIDTH_DIGIT_TABLE if digits_only else _FULLWIDTH_ASCII_TABLE
    return text.translate(table)

def chinese_to_pinyin(text):
    """
    Return the pinyin of ``text`` as a single space-separated string.

    Uses ``pypinyin.lazy_pinyin`` with ``Style.NORMAL`` (plain letters, no tone
    marks per the pypinyin docs). Each element it returns becomes one
    space-separated token.
    """
    tokens = pypinyin.lazy_pinyin(text, style=Style.NORMAL)
    return " ".join(tokens)

from pathlib import Path

# Build the train/eval Chinese->Pinyin pairs and write them to CSV.
OUTPUT_DIR = Path("../data/inputs/standard")
MAX_CHINESE_LENGTH = 60   # sentences of this length or longer are dropped
EVAL_EVERY = 2000         # i steps by 5 from 2, so (i + 3) % 2000 == 0 picks every 400th record

formatted_dataset, formatted_dataset_eval = [], []
# Sets give O(1) duplicate checks instead of rebuilding a list every iteration.
seen_train_chinese, seen_eval_pinyin = set(), set()

for i in range(2, dataset.num_rows, 5):
    # NOTE(review): assumes row i's text carries a 10-character label before the
    # Chinese sentence, and that records are 5 rows long — confirm against the dataset.
    chinese = dataset[i]["text"][10:]

    chinese = clean_punctuations(chinese)
    chinese = clean_spaces(chinese)
    chinese = convert_fullwidth_to_normal(chinese)

    # Skip sentences that are empty after cleaning (e.g. punctuation-only),
    # too long, or still contain Latin letters.
    if not chinese or len(chinese) >= MAX_CHINESE_LENGTH or contains_english.search(chinese):
        continue

    # Pinyin is regenerated from the cleaned Chinese so both columns describe
    # the same text; the dataset's own pinyin row is not used.
    pinyin = chinese_to_pinyin(chinese)

    # NOTE(review): train dedups on Chinese and eval on Pinyin, and the splits
    # are not checked against each other, so an eval sentence may also occur
    # in train — confirm whether that leakage is acceptable.
    if (i + 3) % EVAL_EVERY == 0:
        if pinyin not in seen_eval_pinyin:
            seen_eval_pinyin.add(pinyin)
            formatted_dataset_eval.append({
                "Pinyin": pinyin,
                "Chinese": chinese
            })
    elif chinese not in seen_train_chinese:
        seen_train_chinese.add(chinese)
        formatted_dataset.append({
            "Chinese": chinese,
            "Pinyin": pinyin
        })

# Explicit columns keep the headers (and the CSV column order) even when a split is empty.
df = pd.DataFrame(formatted_dataset, columns=["Chinese", "Pinyin"])
df_eval = pd.DataFrame(formatted_dataset_eval, columns=["Pinyin", "Chinese"])

OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
df.to_csv(OUTPUT_DIR / "train.csv", index=False, encoding="utf-8")
df_eval.to_csv(OUTPUT_DIR / "eval.csv", index=False, encoding="utf-8")

### Check sentence lengths and verify Chinese–Pinyin alignment

In [None]:
def _longest(frame, column):
    """
    Return the length of the longest string in ``frame[column]``.

    Returns 0 when the column is missing or has no rows (e.g. a split that
    ended up empty) instead of raising KeyError / ValueError.
    """
    if column not in frame:
        return 0
    return max((len(text) for text in frame[column]), default=0)

# Report the longest sentence per split and column; each line names its split.
# max_length is reassigned so it still ends up holding the longest eval Pinyin
# sentence, in case a later cell reads it.
max_length = _longest(df, "Pinyin")
print(f"The length of the longest Pinyin sentence (train) is: {max_length}")
max_length = _longest(df, "Chinese")
print(f"The length of the longest Chinese sentence (train) is: {max_length}")

max_length = _longest(df_eval, "Chinese")
print(f"The length of the longest Chinese sentence (eval) is: {max_length}")
max_length = _longest(df_eval, "Pinyin")
print(f"The length of the longest Pinyin sentence (eval) is: {max_length}")

# Alignment check for the training split. The expected pinyin token count is
# the Chinese length with each run of digits counted as one token and each "%"
# counted as none. Rows where that differs from the number of space-separated
# pinyin tokens are printed for manual inspection.
# NOTE(review): the digit/"%" adjustment presumably mirrors how lazy_pinyin
# groups non-Chinese characters; other leftover symbols are not accounted for.
for row_number, row in enumerate(df.itertuples(index=False)):
    chinese_text = row.Chinese
    pinyin_text = row.Pinyin

    digit_runs = re.findall(r"\d+", chinese_text)
    # A run of k digits contributes one token instead of k characters.
    collapsed_digits = sum(len(run) - 1 for run in digit_runs)
    percent_signs = chinese_text.count("%")

    adjusted_chinese_length = len(chinese_text) - collapsed_digits - percent_signs
    # Equivalent to len(pinyin_text.split(" ")).
    pinyin_word_count = pinyin_text.count(" ") + 1

    if adjusted_chinese_length == pinyin_word_count:
        continue

    print(
        f"Line {row_number} has a mismatch in length",
        f"Chinese length (adjusted): {adjusted_chinese_length}, Pinyin word count: {pinyin_word_count}",
        f"Chinese: {chinese_text}",
        f"Pinyin: {pinyin_text}",
        "\n\n",
        sep="\n",
    )