In [None]:
import os, re, time, unicodedata, csv
from pathlib import Path

def u_nfc(s: str) -> str:
    return unicodedata.normalize("NFC", s)

# 간단 정규화: 한글 보존, 구두점 제거, 공백 정리
_re_punct = re.compile(r"[^\w\s]", flags=re.UNICODE)
_re_ws = re.compile(r"\s+")

def normalize_for_wer(s: str) -> str:
    s = u_nfc(s).lower()
    s = _re_punct.sub(" ", s)
    s = _re_ws.sub(" ", s).strip()
    return s

def normalize_for_cer(s: str) -> str:
    s = normalize_for_wer(s)
    s = s.replace(" ", "")
    return s

def levenshtein(seq_a, seq_b):
    # 문자열 또는 토큰 리스트 모두 지원
    n, m = len(seq_a), len(seq_b)
    if n == 0: return m
    if m == 0: return n
    dp = list(range(m+1))
    for i in range(1, n+1):
        prev, dp[0] = dp[0], i
        for j in range(1, m+1):
            cur = dp[j]
            cost = 0 if seq_a[i-1] == seq_b[j-1] else 1
            dp[j] = min(dp[j] + 1, dp[j-1] + 1, prev + cost)
            prev = cur
    return dp[m]

def cer_score(ref: str, hyp: str):
    r = normalize_for_cer(ref)
    h = normalize_for_cer(hyp)
    if len(r) == 0:
        return 0.0, 0, 0  # (CER, dist, N)
    dist = levenshtein(r, h)
    return dist / len(r), dist, len(r)

def wer_score(ref: str, hyp: str):
    r = normalize_for_wer(ref).split()
    h = normalize_for_wer(hyp).split()
    if len(r) == 0:
        return 0.0, 0, 0
    dist = levenshtein(r, h)
    return dist / len(r), dist, len(r)

In [2]:
import sherpa_onnx
from pathlib import Path

# 모델 폴더 지정 (네 환경 그대로 사용하거나 필요한 경로로 바꿔도 됨)
MODEL_DIR = Path("/Users/leejeje/Desktop/DSL/25-1/Modeling/model/sherpa-onnx-streaming-zipformer-korean-2024-06-16")

def pick_model_files(model_dir: Path):
    def pick(prefix: str):
        int8 = sorted(model_dir.glob(f"{prefix}-epoch-*.int8.onnx"))
        fp32 = sorted(p for p in model_dir.glob(f"{prefix}-epoch-*.onnx") if ".int8." not in p.name)
        return (int8[-1] if int8 else (fp32[-1] if fp32 else None))
    tokens = model_dir / "tokens.txt"
    enc = pick("encoder"); dec = pick("decoder"); join = pick("joiner")
    assert tokens.exists(), "tokens.txt 없음"
    assert enc and dec and join, "encoder/decoder/joiner onnx를 찾지 못함"
    return tokens, enc, dec, join

TOKENS, ENC, DEC, JOIN = pick_model_files(MODEL_DIR)

recognizer = sherpa_onnx.OnlineRecognizer.from_transducer(
    tokens=str(TOKENS),
    encoder=str(ENC),
    decoder=str(DEC),
    joiner=str(JOIN),
    decoding_method="greedy_search",
    num_threads=2,
    provider="cpu",   # 가능하면 'coreml'도 시도 가능 (onnxruntime에서 CoreML EP가 보일 때)
)

print("✅ 모델 로드 완료")
print(" - tokens :", TOKENS.name)
print(" - encoder:", ENC.name)
print(" - decoder:", DEC.name)
print(" - joiner :", JOIN.name)

✅ 모델 로드 완료
 - tokens : tokens.txt
 - encoder: encoder-epoch-99-avg-1.int8.onnx
 - decoder: decoder-epoch-99-avg-1.int8.onnx
 - joiner : joiner-epoch-99-avg-1.int8.onnx


In [11]:
import os, re, time, unicodedata, csv
from pathlib import Path

def u_nfc(s: str) -> str:
    return unicodedata.normalize("NFC", s)

# 간단 정규화: 한글 보존, 구두점 제거, 공백 정리
import re
_re_punct = re.compile(r"[^\w\s]", flags=re.UNICODE)
_re_ws = re.compile(r"\s+")

def normalize_for_wer(s: str) -> str:
    s = u_nfc(s).lower()
    s = _re_punct.sub(" ", s)
    s = _re_ws.sub(" ", s).strip()
    return s

def normalize_for_cer(s: str) -> str:
    s = normalize_for_wer(s)
    s = s.replace(" ", "")
    return s

def levenshtein(seq_a, seq_b):
    n, m = len(seq_a), len(seq_b)
    if n == 0: return m
    if m == 0: return n
    dp = list(range(m+1))
    for i in range(1, n+1):
        prev, dp[0] = dp[0], i
        for j in range(1, m+1):
            cur = dp[j]
            cost = 0 if seq_a[i-1] == seq_b[j-1] else 1
            dp[j] = min(dp[j] + 1, dp[j-1] + 1, prev + cost)
            prev = cur
    return dp[m]

def cer_score(ref: str, hyp: str):
    r = normalize_for_cer(ref)
    h = normalize_for_cer(hyp)
    if len(r) == 0: return 0.0, 0, 0
    dist = levenshtein(r, h)
    return dist / len(r), dist, len(r)

def wer_score(ref: str, hyp: str):
    r = normalize_for_wer(ref).split()
    h = normalize_for_wer(hyp).split()
    if len(r) == 0: return 0.0, 0, 0
    dist = levenshtein(r, h)
    return dist / len(r), dist, len(r)

def load_trn(trn_path) -> dict:
    """
    TRN 라인 예시 여러 형태를 모두 허용:
      1) '문장 텍스트 ... (KsponSpeech_E00001)'
      2) 'KsponSpeech_E00001 문장 텍스트 ...'
      3) 'KsponSpeech_E00001.wav\t문장 텍스트 ...'
    반환: { 'KsponSpeech_E00001': '문장 텍스트 ...', ... }
    """
    trn_path = Path(trn_path)
    mapping = {}
    with trn_path.open("r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue

            # 케이스 1) 마지막 괄호의 ID
            m = re.search(r"\(([^)]+)\)\s*$", line)
            if m:
                utt = m.group(1)
                text = line[:m.start()].strip()
            else:
                # 케이스 2/3) <utt>[.wav] <sep> <text>
                parts = re.split(r"[\t ]+", line, maxsplit=1)
                if len(parts) == 2:
                    utt, text = parts[0], parts[1]
                else:
                    # 파싱 실패 시 스킵
                    continue

            utt = os.path.basename(utt)
            utt = os.path.splitext(utt)[0]  # .wav 제거
            mapping[utt] = text
    return mapping


In [9]:
import soundfile as sf
import numpy as np

def resample_to_16k(wave: np.ndarray, sr: int) -> np.ndarray:
    if sr == 16000:
        return wave.astype("float32", copy=False)
    # 의존성 없이 간단 선형보간 리샘플
    import numpy as np
    x = np.arange(len(wave))
    new_len = int(round(len(wave) * 16000 / sr))
    new_x = np.linspace(0, len(wave)-1, new_len)
    out = np.interp(new_x, x, wave).astype("float32")
    return out

def decode_once(audio: np.ndarray, sr: int) -> str:
    if audio.ndim == 2:
        audio = audio.mean(axis=1)
    if sr != 16000:
        audio = resample_to_16k(audio, sr)
        sr = 16000
    stream = recognizer.create_stream()
    stream.accept_waveform(sr, audio.astype("float32", copy=False))
    stream.input_finished()
    while recognizer.is_ready(stream):
        recognizer.decode_stream(stream)
    res = recognizer.get_result(stream)
    # sherpa-onnx 1.12.10은 str을 반환
    return res if isinstance(res, str) else (res.text if hasattr(res, "text") else str(res))
    
def evaluate_split(wav_dir, trn_path, csv_out=None):
    from pathlib import Path
    wav_dir  = Path(wav_dir)
    trn_path = Path(trn_path)
    csv_out  = Path(csv_out) if csv_out is not None else None
    refs = load_trn(trn_path)
    wavs = sorted(wav_dir.glob("*.wav"))
    assert wavs, f"WAV 없음: {wav_dir}"
    print(f"Files: {len(wavs)} | TRN keys: {len(refs)}")

    rows = []
    tot_cdist = tot_cN = 0
    tot_wdist = tot_wN = 0
    tot_secs  = tot_infer = 0.0
    miss_ref = 0

    for i, wav in enumerate(wavs, 1):
        utt = wav.stem
        ref = refs.get(utt)
        if ref is None:
            miss_ref += 1
            continue

        audio, sr = sf.read(str(wav), dtype="float32", always_2d=False)
        dur = float(len(audio) / sr)
        t0 = time.perf_counter()
        hyp = decode_once(audio, sr)
        t1 = time.perf_counter()
        infer = t1 - t0
        rtf = infer / max(dur, 1e-9)

        cer, cdist, cN = cer_score(ref, hyp)
        wer, wdist, wN = wer_score(ref, hyp)

        rows.append({
            "utt": utt,
            "dur_s": round(dur, 3),
            "infer_s": round(infer, 3),
            "rtf": round(rtf, 3),
            "CER": round(cer, 4),
            "WER": round(wer, 4),
            "ref": ref,
            "hyp": hyp,
        })

        tot_cdist += cdist; tot_cN += cN
        tot_wdist += wdist; tot_wN += wN
        tot_secs  += dur;   tot_infer += infer

        if i % 20 == 0:
            print(f"[{i}/{len(wavs)}] RTF~{rtf:.2f} | CER~{cer:.3f} | WER~{wer:.3f}")

    overall_cer = (tot_cdist / tot_cN) if tot_cN else 0.0
    overall_wer = (tot_wdist / tot_wN) if tot_wN else 0.0
    avg_rtf     = (tot_infer / tot_secs) if tot_secs else 0.0
    print("\n=== Summary ===")
    print(f"Files scored     : {len(rows)} (missing refs: {miss_ref})")
    print(f"Total audio (s)  : {tot_secs:.1f}")
    print(f"Total infer (s)  : {tot_infer:.1f}")
    print(f"Avg RTF          : {avg_rtf:.3f}")
    print(f"CER (char)       : {overall_cer:.4f}")
    print(f"WER (word)       : {overall_wer:.4f}")

    if csv_out:
        csv_out.parent.mkdir(parents=True, exist_ok=True)
        with open(csv_out, "w", newline="", encoding="utf-8") as f:
            w = csv.DictWriter(f, fieldnames=list(rows[0].keys()))
            w.writeheader()
            w.writerows(rows)
        print(f"Saved: {csv_out}")

    # 상위 5개 오차 큰 샘플(참고)
    rows_sorted = sorted(rows, key=lambda r: (-r["CER"], -r["WER"], -r["rtf"]))
    print("\nTop-5 by CER:")
    for r in rows_sorted[:5]:
        print(f"- {r['utt']} | CER {r['CER']:.3f} WER {r['WER']:.3f} RTF {r['rtf']:.2f} | ref: {r['ref']} | hyp: {r['hyp']}")
    return {
        "rows": rows,
        "summary": dict(files=len(rows), miss_ref=miss_ref, total_audio_s=tot_secs,
                        total_infer_s=tot_infer, avg_rtf=avg_rtf,
                        cer=overall_cer, wer=overall_wer)
    }


In [12]:
# 데이터 루트 지정 (네 스샷 구조 기준)
DATA_ROOT = Path("data")  # 필요하면 절대경로로 바꿔도 됨

# 1) eval_clean
wav_dir = "/Users/leejeje/Desktop/DSL/25-1/Modeling/data/KsponSpeech_eval/eval_clean"
trn_path = "/Users/leejeje/Desktop/DSL/25-1/Modeling/data/KsponSpeech_scripts/eval_clean.trn"
out_csv = Path("results_eval") / "eval_clean_results.csv"
res_clean = evaluate_split(wav_dir, trn_path, csv_out=out_csv)

# # 2) (선택) eval_other
# wav_dir = DATA_ROOT / "KsponSpeech_eval" / "eval_other"
# trn_path = DATA_ROOT / "KsponSpeech_scripts" / "eval_other.trn"
# out_csv = Path("results_eval") / "eval_other_results.csv"
# res_other = evaluate_split(wav_dir, trn_path, csv_out=out_csv)


Files: 3000 | TRN keys: 3000
[20/3000] RTF~0.05 | CER~0.000 | WER~0.105
[40/3000] RTF~0.04 | CER~0.545 | WER~0.600
[60/3000] RTF~0.04 | CER~0.143 | WER~0.333
[80/3000] RTF~0.04 | CER~0.714 | WER~0.750
[100/3000] RTF~0.05 | CER~0.322 | WER~0.519
[120/3000] RTF~0.09 | CER~0.750 | WER~0.500
[140/3000] RTF~0.05 | CER~0.115 | WER~0.231
[160/3000] RTF~0.06 | CER~0.182 | WER~0.500
[180/3000] RTF~0.04 | CER~0.500 | WER~0.400
[200/3000] RTF~0.07 | CER~0.000 | WER~0.667
[220/3000] RTF~0.04 | CER~1.000 | WER~1.000
[240/3000] RTF~0.06 | CER~0.085 | WER~0.167
[260/3000] RTF~0.03 | CER~0.667 | WER~0.500
[280/3000] RTF~0.06 | CER~0.182 | WER~0.300
[300/3000] RTF~0.04 | CER~0.667 | WER~0.500
[320/3000] RTF~0.05 | CER~0.000 | WER~0.000
[340/3000] RTF~0.05 | CER~0.200 | WER~0.500
[360/3000] RTF~0.06 | CER~0.184 | WER~0.412
[380/3000] RTF~0.04 | CER~0.667 | WER~0.667
[400/3000] RTF~0.07 | CER~0.231 | WER~0.400
[420/3000] RTF~0.14 | CER~0.200 | WER~0.200
[440/3000] RTF~0.06 | CER~0.021 | WER~0.056
[460/30