# XLS-R (Wav2Vec2) CTC Fine-tuning on BAAI/CS-Dialogue (Mandarin–English Code-Switching)

**Goal:** Fine-tune `facebook/wav2vec2-xls-r-300m` with a custom CTC vocabulary built from CS-Dialogue.

**Key points**
- Uses `datasets.Audio` + `cast_column(..., Audio(16000))`.
- Builds mixed EN (A–Z) + Chinese vocab; maps space→`|`.
- Correct padding & metrics (WER/CER).

References:
- CS-Dialogue dataset card (16kHz, structure): https://huggingface.co/datasets/BAAI/CS-Dialogue
- XLS-R-300M model card (16kHz input): https://huggingface.co/facebook/wav2vec2-xls-r-300m
- Datasets audio processing: https://huggingface.co/docs/datasets/en/audio_process
- WER / CER metrics: https://huggingface.co/spaces/evaluate-metric/wer , https://huggingface.co/spaces/evaluate-metric/cer


In [1]:
# installations
!pip -q install   "evaluate==0.4.3" "jiwer==3.0.4"
!pip install "torchcodec==0.7.*" --index-url https://download.pytorch.org/whl/cu126
!pip install -U "peft==0.17.1" "accelerate>=1.1.0"
!pip install -U "bitsandbytes==0.48.1"
!pip install -U "pyctcdecode==0.5.0"
# !pip install "regex==2024.9.11"
#===============================================================================
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True,max_split_size_mb:512"
# os.environ["HF_HUB_DISABLE_XET"] = "1"         # 强制不用 XET 后端
# os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"  # 使用更快更稳的 hf_transfer
from importlib.metadata import version, PackageNotFoundError
import transformers, datasets, evaluate, tokenizers, huggingface_hub, torch, torchaudio, jiwer
import peft
from peft import LoraConfig

def safe_ver(pkg_name):
    try:
        return version(pkg_name)
    except PackageNotFoundError:
        return "(missing)"

# 改用 metadata 查询
print("Transformers  :", safe_ver("transformers"))
print("Datasets      :", safe_ver("datasets"))
print("Evaluate      :", safe_ver("evaluate"))
print("JiWER         :", safe_ver("jiwer"))
print("Tokenizers    :", safe_ver("tokenizers"))
print("HF Hub        :", safe_ver("huggingface_hub"))
print("Torch         :", torch.__version__)
print("Torchaudio    :", torchaudio.__version__)
print("PEFT          :", safe_ver("peft"))
print("Accelerate    :", safe_ver("accelerate"))
print("Bitsandbytes  :", safe_ver("bitsandbytes"))
print("PyCTCDecode   :", safe_ver("pyctcdecode"))
print("Regex         :", safe_ver("regex"))
print("PYTORCH_CUDA_ALLOC_CONF =", os.environ.get("PYTORCH_CUDA_ALLOC_CONF"))
# 简单构建一个 LoRA 配置，验证 API 可用
test_cfg = LoraConfig(r=8, lora_alpha=16, lora_dropout=0.05, bias="none",
                      target_modules=["q_proj","k_proj","v_proj","out_proj"])
print(test_cfg)


Looking in indexes: https://download.pytorch.org/whl/cu126
Transformers  : 4.57.1
Datasets      : 4.0.0
Evaluate      : 0.4.3
JiWER         : 3.0.4
Tokenizers    : 0.22.1
HF Hub        : 0.35.3
Torch         : 2.8.0+cu126
Torchaudio    : 2.8.0+cu126
PEFT          : 0.17.1
Accelerate    : 1.11.0
Bitsandbytes  : 0.48.1
PyCTCDecode   : 0.5.0
Regex         : 2024.11.6
PYTORCH_CUDA_ALLOC_CONF = expandable_segments:True,max_split_size_mb:512
LoraConfig(task_type=None, peft_type=<PeftType.LORA: 'LORA'>, auto_mapping=None, base_model_name_or_path=None, revision=None, inference_mode=False, r=8, target_modules={'v_proj', 'q_proj', 'out_proj', 'k_proj'}, exclude_modules=None, lora_alpha=16, lora_dropout=0.05, fan_in_fan_out=False, bias='none', use_rslora=False, modules_to_save=None, init_lora_weights=True, layers_to_transform=None, layers_pattern=None, rank_pattern={}, alpha_pattern={}, megatron_config=None, megatron_core='megatron.core', trainable_token_indices=None, loftq_config={}, eva_config=

In [2]:
!python -m bitsandbytes


Platform: Linux-6.6.105+-x86_64-with-glibc2.35
  libc: glibc-2.35
Python: 3.12.12
PyTorch: 2.8.0+cu126
  CUDA: 12.6
  HIP: N/A
  XPU: N/A
Related packages:
  accelerate: 1.11.0
  diffusers: 0.35.2
  numpy: 1.26.4
  pip: 24.1.2
  peft: 0.17.1
  safetensors: 0.6.2
  transformers: 4.57.1
  triton: 3.4.0
  trl: not found
PyTorch settings found: CUDA_VERSION=126, Highest Compute Capability: (8, 0).
Checking that the library is importable and CUDA is callable...
SUCCESS!


In [3]:
import os, re, tarfile, json
from pathlib import Path
from collections import Counter
from typing import Dict, List
import numpy as np, torch
from datasets import Dataset, DatasetDict, Audio, Features, Value
from huggingface_hub import hf_hub_download
from transformers import AutoFeatureExtractor, AutoModelForCTC, TrainingArguments, Trainer, Wav2Vec2Processor, Wav2Vec2CTCTokenizer
import evaluate

SEED=42
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available(): torch.cuda.manual_seed_all(SEED)

BASE_DIR = Path('/content') if Path('/content').exists() else Path.cwd()
DATA_ROOT = BASE_DIR / 'cs_dialogue'
AUDIO_DIR = DATA_ROOT / 'short_wav'
INDEX_DIR = DATA_ROOT / 'data' / 'index' / 'short_wav'
VOCAB_DIR = DATA_ROOT / 'custom_vocab'
for p in [DATA_ROOT, AUDIO_DIR, INDEX_DIR, VOCAB_DIR]: p.mkdir(parents=True, exist_ok=True)

NUM_SHARDS = int(os.environ.get('CS_NUM_SHARDS', 19))  # increase to 19 for full short_wav


## 1) Download index & audio shards (short_wav)

In [4]:
from huggingface_hub import notebook_login

notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [5]:
from pathlib import Path
import os, shutil, tarfile, glob

REPO_ID = "BAAI/CS-Dialogue"        # 数据集仓库
# 你已有：DATA_ROOT = Path('/content/cs_dialogue'); AUDIO_DIR = DATA_ROOT/'short_wav'; INDEX_DIR = DATA_ROOT/'index'/'short_wav'

def download_index():
    files = [
        "data/index/short_wav/train/text",
        "data/index/short_wav/train/wav.scp",
        "data/index/short_wav/dev/text",
        "data/index/short_wav/dev/wav.scp",
        "data/index/short_wav/test/text",
        "data/index/short_wav/test/wav.scp",
    ]
    local = []
    for fp in files:
        # ★ 关键：直接落到 INDEX_DIR/ 相对目录；禁用 symlink
        dst_dir = INDEX_DIR / Path(fp).parent.relative_to("data/index/short_wav")
        dst_dir.mkdir(parents=True, exist_ok=True)
        src = hf_hub_download(
            repo_id=REPO_ID,
            filename=fp,
            repo_type="dataset",
            local_dir=str(DATA_ROOT),   # 关键：交给 local_dir 来还原相对层级
        )
        local.append(Path(src))
    print("Index ready:", local)
    return local


def download_shards(n=19):
    got = []
    for i in range(n):
        rel = f"data/short_wav/short_wav.tar.gz{i:02d}"
        src = hf_hub_download(
            repo_id=REPO_ID,
            filename=rel,
            repo_type="dataset",
            local_dir=str(DATA_ROOT),   # 落到 DATA_ROOT/rel
        )
        dst = Path(src)
        print("Downloaded:", dst)
        assert dst.exists() and dst.stat().st_size > 0, f"missing or empty part: {dst}"
        got.append(dst)
    # 关键：分片实际在 DATA_ROOT/data/short_wav 下
    print("In data/short_wav:", sorted(p.name for p in (DATA_ROOT/'data'/'short_wav').glob('short_wav.tar.gz*')))
    return got





def concat_parts(parts, out_file: Path):
    """把 *.tar.gz00.. 拼接成一个完整 tar.gz"""
    parts = sorted(parts, key=lambda p: p.name)
    if not parts:
        raise RuntimeError("no parts provided to concat_parts")
    # 逐个断言存在
    for p in parts:
        if not Path(p).exists():
            raise FileNotFoundError(f"part not found on disk: {p}")
    out_file.parent.mkdir(parents=True, exist_ok=True)
    total = 0
    with open(out_file, "wb") as w:
        for p in parts:
            with open(p, "rb") as r:
                shutil.copyfileobj(r, w)
            sz = Path(p).stat().st_size
            total += sz
            print(f"  appended {Path(p).name} ({sz/1e6:.1f} MB)")
    print(f"==> concatenated -> {out_file} (~{total/1e9:.2f} GB)")

def extract_concatenated_tar_gz(out_dir: Path, parts=None):
    """
    解压流程：
    1) 如未给 parts，则自动从 DATA_ROOT 扫描 short_wav.tar.gz[0-9][0-9]
    2) 拼成 short_wav.tar.gz
    3) 正常 tar 解压到 out_dir
    """
    out_dir.mkdir(parents=True, exist_ok=True)

    # ★ 自动发现分片（避免用了旧的 tars 变量）
    if parts is None:
        # 同时在 DATA_ROOT 根 & DATA_ROOT/data/short_wav 下搜索
        candidates = sorted((DATA_ROOT).glob("short_wav.tar.gz[0-9][0-9]")) \
               + sorted((DATA_ROOT / "data" / "short_wav").glob("short_wav.tar.gz[0-9][0-9]"))
    else:
        candidates = list(map(Path, parts))

    if not candidates:
        raise RuntimeError(f"No split parts found in {DATA_ROOT}. Expected files like short_wav.tar.gz00")

    merged = DATA_ROOT / "short_wav.tar.gz"

    if merged.exists():
      merged.unlink()  # 清除此前拼接失败的半成品
    concat_parts(candidates, merged)

    # 解压
    with tarfile.open(merged, "r:gz") as tf:
        tf.extractall(out_dir, filter="data")  # 3.12+ 推荐
    print("Extracted ->", out_dir)

    # 可选：节省空间
    try:
        merged.unlink()
    except Exception:
        pass

# === 调用顺序 ===
_ = download_index()
_ = download_shards(NUM_SHARDS)                 # e.g. NUM_SHARDS=2

# ★ 完整性检查：必须 19 片全部到位
expected = [f"short_wav.tar.gz{i:02d}" for i in range(19)]
have = sorted(p.name for p in (DATA_ROOT/'data'/'short_wav').glob('short_wav.tar.gz*'))
missing = [x for x in expected if x not in have]
assert not missing, f"缺少分片：{missing}。请把 NUM_SHARDS 设为 19 或补齐后再解压。"

extract_concatenated_tar_gz(AUDIO_DIR, None)   # 让它自己扫描分片再拼接


text: 0.00B [00:00, ?B/s]

wav.scp: 0.00B [00:00, ?B/s]

text: 0.00B [00:00, ?B/s]

wav.scp: 0.00B [00:00, ?B/s]

text: 0.00B [00:00, ?B/s]

wav.scp: 0.00B [00:00, ?B/s]

Index ready: [PosixPath('/content/cs_dialogue/data/index/short_wav/train/text'), PosixPath('/content/cs_dialogue/data/index/short_wav/train/wav.scp'), PosixPath('/content/cs_dialogue/data/index/short_wav/dev/text'), PosixPath('/content/cs_dialogue/data/index/short_wav/dev/wav.scp'), PosixPath('/content/cs_dialogue/data/index/short_wav/test/text'), PosixPath('/content/cs_dialogue/data/index/short_wav/test/wav.scp')]


data/short_wav/short_wav.tar.gz00:   0%|          | 0.00/524M [00:00<?, ?B/s]

Downloaded: /content/cs_dialogue/data/short_wav/short_wav.tar.gz00


data/short_wav/short_wav.tar.gz01:   0%|          | 0.00/524M [00:00<?, ?B/s]

Downloaded: /content/cs_dialogue/data/short_wav/short_wav.tar.gz01


data/short_wav/short_wav.tar.gz02:   0%|          | 0.00/524M [00:00<?, ?B/s]

Downloaded: /content/cs_dialogue/data/short_wav/short_wav.tar.gz02


data/short_wav/short_wav.tar.gz03:   0%|          | 0.00/524M [00:00<?, ?B/s]

Downloaded: /content/cs_dialogue/data/short_wav/short_wav.tar.gz03


data/short_wav/short_wav.tar.gz04:   0%|          | 0.00/524M [00:00<?, ?B/s]

Downloaded: /content/cs_dialogue/data/short_wav/short_wav.tar.gz04


data/short_wav/short_wav.tar.gz05:   0%|          | 0.00/524M [00:00<?, ?B/s]

Downloaded: /content/cs_dialogue/data/short_wav/short_wav.tar.gz05


data/short_wav/short_wav.tar.gz06:   0%|          | 0.00/524M [00:00<?, ?B/s]

Downloaded: /content/cs_dialogue/data/short_wav/short_wav.tar.gz06


data/short_wav/short_wav.tar.gz07:   0%|          | 0.00/524M [00:00<?, ?B/s]

Downloaded: /content/cs_dialogue/data/short_wav/short_wav.tar.gz07


data/short_wav/short_wav.tar.gz08:   0%|          | 0.00/524M [00:00<?, ?B/s]

Downloaded: /content/cs_dialogue/data/short_wav/short_wav.tar.gz08


data/short_wav/short_wav.tar.gz09:   0%|          | 0.00/524M [00:00<?, ?B/s]

Downloaded: /content/cs_dialogue/data/short_wav/short_wav.tar.gz09


data/short_wav/short_wav.tar.gz10:   0%|          | 0.00/524M [00:00<?, ?B/s]

Downloaded: /content/cs_dialogue/data/short_wav/short_wav.tar.gz10


data/short_wav/short_wav.tar.gz11:   0%|          | 0.00/524M [00:00<?, ?B/s]

Downloaded: /content/cs_dialogue/data/short_wav/short_wav.tar.gz11


data/short_wav/short_wav.tar.gz12:   0%|          | 0.00/524M [00:00<?, ?B/s]

Downloaded: /content/cs_dialogue/data/short_wav/short_wav.tar.gz12


data/short_wav/short_wav.tar.gz13:   0%|          | 0.00/524M [00:00<?, ?B/s]

Downloaded: /content/cs_dialogue/data/short_wav/short_wav.tar.gz13


data/short_wav/short_wav.tar.gz14:   0%|          | 0.00/524M [00:00<?, ?B/s]

Downloaded: /content/cs_dialogue/data/short_wav/short_wav.tar.gz14


data/short_wav/short_wav.tar.gz15:   0%|          | 0.00/524M [00:00<?, ?B/s]

Downloaded: /content/cs_dialogue/data/short_wav/short_wav.tar.gz15


data/short_wav/short_wav.tar.gz16:   0%|          | 0.00/524M [00:00<?, ?B/s]

Downloaded: /content/cs_dialogue/data/short_wav/short_wav.tar.gz16


data/short_wav/short_wav.tar.gz17:   0%|          | 0.00/524M [00:00<?, ?B/s]

Downloaded: /content/cs_dialogue/data/short_wav/short_wav.tar.gz17


data/short_wav/short_wav.tar.gz18:   0%|          | 0.00/158M [00:00<?, ?B/s]

Downloaded: /content/cs_dialogue/data/short_wav/short_wav.tar.gz18
In data/short_wav: ['short_wav.tar.gz00', 'short_wav.tar.gz01', 'short_wav.tar.gz02', 'short_wav.tar.gz03', 'short_wav.tar.gz04', 'short_wav.tar.gz05', 'short_wav.tar.gz06', 'short_wav.tar.gz07', 'short_wav.tar.gz08', 'short_wav.tar.gz09', 'short_wav.tar.gz10', 'short_wav.tar.gz11', 'short_wav.tar.gz12', 'short_wav.tar.gz13', 'short_wav.tar.gz14', 'short_wav.tar.gz15', 'short_wav.tar.gz16', 'short_wav.tar.gz17', 'short_wav.tar.gz18']
  appended short_wav.tar.gz00 (524.3 MB)
  appended short_wav.tar.gz01 (524.3 MB)
  appended short_wav.tar.gz02 (524.3 MB)
  appended short_wav.tar.gz03 (524.3 MB)
  appended short_wav.tar.gz04 (524.3 MB)
  appended short_wav.tar.gz05 (524.3 MB)
  appended short_wav.tar.gz06 (524.3 MB)
  appended short_wav.tar.gz07 (524.3 MB)
  appended short_wav.tar.gz08 (524.3 MB)
  appended short_wav.tar.gz09 (524.3 MB)
  appended short_wav.tar.gz10 (524.3 MB)
  appended short_wav.tar.gz11 (524.3 MB)
  a

In [6]:
print("INDEX_DIR =", INDEX_DIR)
print("Has train/text?   ->", (INDEX_DIR/'train'/'text').exists())
print("Has train/wav.scp?->", (INDEX_DIR/'train'/'wav.scp').exists())

print("Split parts in DATA_ROOT/data/short_wav:",
      sorted(p.name for p in (DATA_ROOT/'data'/'short_wav').glob('short_wav.tar.gz*'))[:19])


INDEX_DIR = /content/cs_dialogue/data/index/short_wav
Has train/text?   -> True
Has train/wav.scp?-> True
Split parts in DATA_ROOT/data/short_wav: ['short_wav.tar.gz00', 'short_wav.tar.gz01', 'short_wav.tar.gz02', 'short_wav.tar.gz03', 'short_wav.tar.gz04', 'short_wav.tar.gz05', 'short_wav.tar.gz06', 'short_wav.tar.gz07', 'short_wav.tar.gz08', 'short_wav.tar.gz09', 'short_wav.tar.gz10', 'short_wav.tar.gz11', 'short_wav.tar.gz12', 'short_wav.tar.gz13', 'short_wav.tar.gz14', 'short_wav.tar.gz15', 'short_wav.tar.gz16', 'short_wav.tar.gz17', 'short_wav.tar.gz18']


In [7]:
from pathlib import Path

def print_tree(root: Path, max_depth: int = 3):
    root = Path(root)

    def walk(p: Path, depth: int, prefix: str = ""):
        if depth < 0:
            return
        items = sorted(p.iterdir(), key=lambda x: (x.is_file(), x.name.lower()))
        for i, it in enumerate(items):
            is_last = (i == len(items) - 1)
            connector = "└── " if is_last else "├── "
            name = it.name + ("/" if it.is_dir() else "")
            print(prefix + connector + name)
            if it.is_dir():
                next_prefix = prefix + ("    " if is_last else "│   ")
                walk(it, depth - 1, next_prefix)

    root = root.resolve()
    print(root.as_posix())
    walk(root, max_depth)

# 使用：打印 DATA_ROOT 下三层
print_tree(DATA_ROOT, max_depth=4)


/content/cs_dialogue
├── .cache/
│   └── huggingface/
│       ├── download/
│       │   └── data/
│       │       ├── index/
│       │       └── short_wav/
│       └── .gitignore
├── custom_vocab/
├── data/
│   ├── index/
│   │   └── short_wav/
│   │       ├── dev/
│   │       │   ├── text
│   │       │   └── wav.scp
│   │       ├── test/
│   │       │   ├── text
│   │       │   └── wav.scp
│   │       └── train/
│   │           ├── text
│   │           └── wav.scp
│   └── short_wav/
│       ├── short_wav.tar.gz00
│       ├── short_wav.tar.gz01
│       ├── short_wav.tar.gz02
│       ├── short_wav.tar.gz03
│       ├── short_wav.tar.gz04
│       ├── short_wav.tar.gz05
│       ├── short_wav.tar.gz06
│       ├── short_wav.tar.gz07
│       ├── short_wav.tar.gz08
│       ├── short_wav.tar.gz09
│       ├── short_wav.tar.gz10
│       ├── short_wav.tar.gz11
│       ├── short_wav.tar.gz12
│       ├── short_wav.tar.gz13
│       ├── short_wav.tar.gz14
│       ├── short_wav.tar.gz15
│       ├── sho

## 2) Build DatasetDict from wav.scp & text

In [8]:

from pathlib import Path

def resolve_local_audio_path(wavpath: str) -> Path | None:
    """
    把 wav.scp 的路径映射到本地真实存在的文件。
    兼容如下情况：
    - 路径前缀带不带 'data/'
    - 是否出现双重 'short_wav/short_wav'
    - 绝对路径/相对路径
    """
    p = Path(wavpath.strip())

    # 1) 若本身就是绝对路径且存在，直接返回
    if p.is_absolute() and p.exists():
        return p

    candidates: list[Path] = []

    # 2) wav.scp 一般包含 ".../short_wav/..." 这段；抽取从第一处 'short_wav' 之后的尾部
    parts = p.parts
    if "short_wav" in parts:
        i = parts.index("short_wav")
        tail = Path(*parts[i+1:])  # 去掉第一个 'short_wav' 及之前的前缀
        # 情况 A：你的磁盘上是 short_wav/short_wav/WAVE/...
        if (AUDIO_DIR / "short_wav").exists():
            candidates.append(AUDIO_DIR / "short_wav" / tail)  # /.../short_wav/short_wav/...
        # 情况 B：只有一层 short_wav/WAVE/...
        candidates.append(AUDIO_DIR / tail)                     # /.../short_wav/...
    else:
        # 没出现 short_wav 关键词时，尝试几种常见组合
        candidates.append(DATA_ROOT / p)            # /content/cs_dialogue/<wav.scp里的相对路径>
        candidates.append(DATA_ROOT / "data" / p)   # /content/cs_dialogue/data/<...>

    # 3) 再补充几种保守候选
    candidates.append(DATA_ROOT / p)
    if str(p).startswith("data/"):
        candidates.append(DATA_ROOT / str(p))                    # /content/cs_dialogue/data/short_wav/...
        candidates.append(DATA_ROOT / str(p).replace("data/", "", 1))  # /content/cs_dialogue/short_wav/...

    for c in candidates:
        if c.exists():
            return c.resolve()
    return None


def read_kv(fp: Path):
    d={}
    with open(fp, 'r', encoding='utf-8') as f:
        for line in f:
            line=line.strip()
            if not line:
                continue
            # Kaldi风格：<key><space><value...>
            k,v=line.split(' ',1)
            d[k]=v
    return d


def make_split(split: str):
    wavscp = read_kv(INDEX_DIR/split/'wav.scp')
    text   = read_kv(INDEX_DIR/split/'text')

    ids, paths, trans = [], [], []
    miss_audio, miss_text = 0, 0

    for uid, wavpath in wavscp.items():
        local = resolve_local_audio_path(wavpath)
        if local is None or not local.exists():
            miss_audio += 1
            continue
        if uid not in text:
            miss_text += 1
            continue

        ids.append(uid)
        paths.append(str(local))
        trans.append(text[uid])

    print(f"[{split}] matched {len(ids)} items "
          f"(missing audio: {miss_audio}, missing text: {miss_text})")

    feats = Features({
        'id': Value('string'),
        'audio': Audio(sampling_rate=16000),
        'transcription': Value('string')
    })
    return Dataset.from_dict(
        {'id': ids, 'audio': paths, 'transcription': trans},
        features=feats
    )


train_ds = make_split('train')
val_ds   = make_split('dev')
test_ds  = make_split('test')
minds = DatasetDict(train=train_ds, validation=val_ds, test=test_ds)
minds = minds.cast_column('audio', Audio(sampling_rate=16000))
minds


[train] matched 26239 items (missing audio: 0, missing text: 0)
[dev] matched 6186 items (missing audio: 0, missing text: 0)
[test] matched 6257 items (missing audio: 0, missing text: 0)


DatasetDict({
    train: Dataset({
        features: ['id', 'audio', 'transcription'],
        num_rows: 26239
    })
    validation: Dataset({
        features: ['id', 'audio', 'transcription'],
        num_rows: 6186
    })
    test: Dataset({
        features: ['id', 'audio', 'transcription'],
        num_rows: 6257
    })
})

In [9]:

MIN_SEC, MAX_SEC = 0.3, 6
def _keep_ok(ex):
    sec = ex["audio"]["array"].shape[0] / ex["audio"]["sampling_rate"]
    return (sec >= MIN_SEC) and (sec <= MAX_SEC)

minds = minds.filter(_keep_ok, num_proc=8)

Filter (num_proc=8):   0%|          | 0/26239 [00:00<?, ? examples/s]

Filter (num_proc=8):   0%|          | 0/6186 [00:00<?, ? examples/s]

Filter (num_proc=8):   0%|          | 0/6257 [00:00<?, ? examples/s]

In [10]:
minds

DatasetDict({
    train: Dataset({
        features: ['id', 'audio', 'transcription'],
        num_rows: 13001
    })
    validation: Dataset({
        features: ['id', 'audio', 'transcription'],
        num_rows: 2952
    })
    test: Dataset({
        features: ['id', 'audio', 'transcription'],
        num_rows: 3313
    })
})

In [11]:
from datasets import Audio
print("sr(train) =", minds['train'].features['audio'].sampling_rate)
print("sr(val)   =", minds['validation'].features['audio'].sampling_rate)
print("sr(test)  =", minds['test'].features['audio'].sampling_rate)

sr(train) = 16000
sr(val)   = 16000
sr(test)  = 16000


## 3) Normalize transcripts (EN upper + Chinese)

In [12]:
import unicodedata
import regex as re

# 大小写不敏感地移除这些标签
TAG_RE = re.compile(r"</?EN>|</?CN>|</?MIX>|<SPK\s*/?>|<FIL\s*/?>", re.IGNORECASE)

def normalize_text(ex):
    s = ex['transcription']
    if s is None:
        return {'transcription': ""}

    # 0) 统一 Unicode + 去标签
    s = unicodedata.normalize("NFKC", s)
    s = TAG_RE.sub(" ", s)

    # 1) 英文统一大写；仅保留：中文、A–Z、内部撇号'、空格
    s = s.upper()
    s = re.sub(r"[^\u4E00-\u9FFFA-Z' ]+", " ", s)

    # 2) 合并多空格并去首尾空格
    s = re.sub(r"\s+", " ", s).strip()

    return {'transcription': s}

# 注意：map 只调用一次，避免重复处理
minds = minds.map(normalize_text)

minds['train'][7]['transcription'][:120]



Map:   0%|          | 0/13001 [00:00<?, ? examples/s]

Map:   0%|          | 0/2952 [00:00<?, ? examples/s]

Map:   0%|          | 0/3313 [00:00<?, ? examples/s]

'哦 所以你平 你你喜欢看日本动漫是吗'

In [13]:
minds['train'][7]['audio']

<datasets.features._torchcodec.AudioDecoder at 0x7bd0839d98e0>

In [14]:
from datasets import Audio
from IPython.display import Audio as IPyAudio
import numpy as np

# 只挑出第 0 条，并把 audio 列转换为可解码，统一采样率到 16k
ex_ds = minds["train"].select([7]).cast_column(
    "audio",
    Audio(sampling_rate=16000, decode=True)   # 注意：去掉 mono
)

ex  = ex_ds[0]
wav = ex["audio"]["array"]                   # numpy.ndarray
sr  = ex["audio"]["sampling_rate"]           # 16000

# 如果是多声道，手动转单声道（多数数据本来就是一维，不需要这步）
if wav.ndim > 1:
    wav = np.mean(wav, axis=0)

IPyAudio(wav, rate=sr)


## 4) Build CTC vocab (space→`|`)

In [15]:

# 4) Build CTC vocab (train-only; keep VOCAB_DIR)
import os, json, unicodedata, regex as re
from collections import Counter

VOCAB_JSON = os.path.join(VOCAB_DIR, "vocab.json")

TAG_RE  = re.compile(r"</?EN>|</?CN>|</?MIX>|<SPK\s*/?>|<FIL\s*/?>", re.IGNORECASE)
_CJK_RE = re.compile(r"[\u3400-\u4DBF\u4E00-\u9FFF\uF900-\uFAFF]")

def _norm_for_vocab(s: str, delim="|"):
    s = s.replace(" ", delim)
    s = unicodedata.normalize("NFKC", s)
    s = TAG_RE.sub(" ", s)
    s = s.upper()
    s = re.sub(rf"[^\u4E00-\u9FFFA-Z0-9'{delim}]+", " ", s)
    s = re.sub(r"\s+", " ", s).strip()
    return s

def _build_vocab_from_train(train_texts, min_freq=5, delim="|"):
    cnt = Counter()
    for raw in train_texts:
        s = _norm_for_vocab(raw, delim)
        for ch in s:
            if ch != " ":
                cnt[ch] += 1

    tokens = [delim, "'", "<unk>", "<pad>"] + list("ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789")

    total_chars = sum(cnt.values())
    # With tiny subsets we cannot afford to drop rare CJK characters.
    cjk_threshold = min_freq if total_chars >= 200000 else 1

    zh_tokens = sorted({c for c, f in cnt.items() if _CJK_RE.match(c) and f >= cjk_threshold})

    if cjk_threshold == 1:
        rare_cjk = [c for c, f in cnt.items() if _CJK_RE.match(c) and f < min_freq]
        if rare_cjk:
            print(f"[vocab] promoted {len(rare_cjk)} rare CJK tokens (min_freq -> 1)")

    tokens += zh_tokens
    seen, ordered = set(), []
    for t in tokens:
        if t not in seen:
            ordered.append(t); seen.add(t)
    return {t:i for i,t in enumerate(ordered)}

os.makedirs(VOCAB_DIR, exist_ok=True)

if os.path.isfile(VOCAB_JSON):
    with open(VOCAB_JSON, "r", encoding="utf-8") as f:
        vocab_json = json.load(f)
    need = set(list("ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789") + ["|", "'", "<unk>", "<pad>"])
    missing = [t for t in need if t not in vocab_json]
    print(f"[vocab] loaded from {VOCAB_JSON}, size={len(vocab_json)}, missing={missing}")
else:
    train_texts = [t for t in minds["train"]["transcription"]]
    vocab_json = _build_vocab_from_train(train_texts, min_freq=5, delim="|")
    with open(VOCAB_JSON, "w", encoding="utf-8") as f:
        json.dump(vocab_json, f, ensure_ascii=False, indent=2)
    print(f"[vocab] built & saved to {VOCAB_JSON}, size={len(vocab_json)}")


[vocab] built & saved to /content/cs_dialogue/custom_vocab/vocab.json, size=805


## 5) Init tokenizer/processor & XLS-R-300M model

In [16]:

from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, Wav2Vec2Processor, Wav2Vec2ForCTC
import torch

tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(
    VOCAB_DIR,
    unk_token="<unk>",
    pad_token="<pad>",
    word_delimiter_token="|",
    lowercase=False
)
feature_extractor = Wav2Vec2FeatureExtractor(
    feature_size=1, sampling_rate=16000, padding_value=0.0,
    do_normalize=True, return_attention_mask=True
)
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-xls-r-300m")

model.config.pad_token_id = processor.tokenizer.pad_token_id
model.config.vocab_size   = processor.tokenizer.vocab_size
model.config.ctc_loss_reduction = "mean"
model.config.ctc_zero_infinity  = True

processor.tokenizer.bos_token = None
processor.tokenizer.eos_token = None
if getattr(model, "config", None) is not None:
    if hasattr(model.config, "bos_token_id"): model.config.bos_token_id = None
    if hasattr(model.config, "eos_token_id"): model.config.eos_token_id = None
gen_cfg = getattr(model, "generation_config", None)
if gen_cfg is not None:
    if hasattr(gen_cfg, "bos_token_id"): gen_cfg.bos_token_id = None
    if hasattr(gen_cfg, "eos_token_id"): gen_cfg.eos_token_id = None

if model.lm_head.out_features != processor.tokenizer.vocab_size:
    in_f = model.lm_head.in_features
    new_head = torch.nn.Linear(in_f, processor.tokenizer.vocab_size, bias=True)
    with torch.no_grad():
        old = model.lm_head.out_features
        copy_n = min(old, processor.tokenizer.vocab_size)
        new_head.weight[:copy_n].copy_(model.lm_head.weight[:copy_n])
        new_head.bias[:copy_n].copy_(model.lm_head.bias[:copy_n])
    model.lm_head = new_head.to(model.device)

with torch.no_grad():
    model.lm_head.bias.zero_()
    model.lm_head.bias[processor.tokenizer.pad_token_id] = -5.0

model.config.activation_dropout = 0.0
model.config.attention_dropout  = 0.0
model.config.feat_proj_dropout  = 0.0
model.config.hidden_dropout     = 0.0
model.config.final_dropout      = 0.0
model.config.mask_time_prob     = 0.0
model.config.layerdrop          = 0.0

print("[ok] vocab_size=", processor.tokenizer.vocab_size, "blank/pad id=", processor.tokenizer.pad_token_id)


The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'DiaTokenizer'. 
The class this function is called from is 'Wav2Vec2CTCTokenizer'.


config.json: 0.00B [00:00, ?B/s]

pytorch_model.bin:   0%|          | 0.00/1.27G [00:00<?, ?B/s]

Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-xls-r-300m and are newly initialized: ['lm_head.bias', 'lm_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


[ok] vocab_size= 805 blank/pad id= 3


In [17]:
# # ==== （Partial Freezing） ====
# def partially_freeze_wav2vec2_ctc(model, train_last_n_layers=4, freeze_feature_encoder=True):
#     """
#     仅微调 Transformer 编码器的最后 N 层 + CTC 头，其他部分全部冻结。
#     适用于 Wav2Vec2ForCTC/HuBERTForCTC/WavLMForCTC。
#     """
#     # 1) 冻结特征提取器（Conv 特征编码器）
#     if freeze_feature_encoder and hasattr(model, "freeze_feature_encoder"):
#         model.freeze_feature_encoder()
#     # 2) 先冻结整个 encoder
#     # enc = model.wav2vec2.encoder
#     # 2) 兼容不同主干命名：wav2vec2 / hubert / wavlm
#     backbone = getattr(model, "wav2vec2", None) \
#             or getattr(model, "hubert", None) \
#             or getattr(model, "wavlm", None)
#     if backbone is None:
#         raise AttributeError(f"Unsupported model backbone. Expect wav2vec2 / hubert / wavlm, got: {type(model)}")

#     # 3) 先冻结整个 encoder
#     enc = backbone.encoder
#     for p in enc.parameters():
#         p.requires_grad = False
#     # 3) 只解冻最后 N 层
#     if hasattr(enc, "layers"):
#         num_layers = len(enc.layers)
#         n = min(train_last_n_layers, num_layers)
#         for i in range(num_layers - n, num_layers):
#             for p in enc.layers[i].parameters():
#                 p.requires_grad = True
#     # 4) 保持 encoder 的 LayerNorm 可训练（通常有助于稳定）
#     if hasattr(enc, "layer_norm"):
#         for p in enc.layer_norm.parameters():
#             p.requires_grad = True
#     # 5) CTC 头始终可训练
#     for p in model.lm_head.parameters():
#         p.requires_grad = True

# def print_trainable_parameters(model, max_show=20):
#     total, trainable = 0, 0
#     names = []
#     for n, p in model.named_parameters():
#         num = p.numel()
#         total += num
#         if p.requires_grad:
#             trainable += num
#             if len(names) < max_show:
#                 names.append(n)
#     print(f"Trainable params: {trainable:,} / {total:,} ({100*trainable/total:.2f}%)")
#     print("Some trainable tensors:", *names, sep="\n  - ")

# # ==== 调用：只训练最后 4 层 Transformer + lm_head ====
# # partially_freeze_wav2vec2_ctc(model, train_last_n_layers=4, freeze_feature_encoder=True)
# # 过拟合 64 条阶段：
# partially_freeze_wav2vec2_ctc(model, train_last_n_layers=12, freeze_feature_encoder=False)

# # 跑完过拟合确认“能学”后，再恢复正式训练：
# # partially_freeze_wav2vec2_ctc(model, train_last_n_layers=8, freeze_feature_encoder=True)


# print_trainable_parameters(model)
# assert any(p.requires_grad for n,p in model.named_parameters() if "lm_head" in n), "lm_head 没解冻"


## 6) Encode → input_values / attention_mask / labels

In [18]:

from datasets import Audio
for sp in ["train","validation","test"]:
    minds[sp] = minds[sp].cast_column("audio", Audio(sampling_rate=16000))

import numpy as np, unicodedata, regex as re
TAG_RE  = re.compile(r"</?EN>|</?CN>|</?MIX>|<SPK\s*/?>|<FIL\s*/?>", re.IGNORECASE)

def _normalize_for_train(s: str):
    s = s.replace(" ", processor.tokenizer.word_delimiter_token)
    s = unicodedata.normalize("NFKC", s)
    s = TAG_RE.sub(" ", s)
    s = s.upper()
    s = re.sub(rf"[^\u4E00-\u9FFFA-Z'{processor.tokenizer.word_delimiter_token}]+", " ", s)
    s = re.sub(r"\s+", " ", s).strip()
    return s

def prepare_batch(ex):
    audio = ex["audio"]
    out = processor(audio["array"], sampling_rate=audio["sampling_rate"], return_attention_mask=True)
    with processor.as_target_processor():
        text = _normalize_for_train(ex["transcription"])
        labels = processor(text).input_ids
    ex["input_values"]   = out.input_values[0]
    ex["attention_mask"] = out.attention_mask[0]
    ex["labels"]         = labels
    return ex

encoded = minds.map(prepare_batch, remove_columns=minds["train"].column_names, num_proc=1)
print({k: encoded[k].features for k in encoded})


Map:   0%|          | 0/13001 [00:00<?, ? examples/s]



model.safetensors:   0%|          | 0.00/1.27G [00:00<?, ?B/s]

Map:   0%|          | 0/2952 [00:00<?, ? examples/s]

Map:   0%|          | 0/3313 [00:00<?, ? examples/s]

{'train': {'input_values': List(Value('float32')), 'attention_mask': List(Value('int32')), 'labels': List(Value('int64'))}, 'validation': {'input_values': List(Value('float32')), 'attention_mask': List(Value('int32')), 'labels': List(Value('int64'))}, 'test': {'input_values': List(Value('float32')), 'attention_mask': List(Value('int32')), 'labels': List(Value('int64'))}}


In [19]:
# CTC blank 使用 tokenizer.pad_token_id（Wav2Vec2 官方约定）
assert model.config.pad_token_id == processor.tokenizer.pad_token_id, \
    f"pad/blank mismatch: model={model.config.pad_token_id}, tok={processor.tokenizer.pad_token_id}"
# 采样率检查（必须是 16,000 Hz）
print("sr(train) =", minds['train'].features['audio'].sampling_rate)
print("sr(val)   =", minds['validation'].features['audio'].sampling_rate)
print("sr(test)  =", minds['test'].features['audio'].sampling_rate)


sr(train) = 16000
sr(val)   = 16000
sr(test)  = 16000


## 7) Collator & Metrics (WER/CER) + Sanity check

In [20]:

from dataclasses import dataclass
from typing import Dict, List, Union
import torch

@dataclass
class DataCollatorCTCWithPadding:
    processor: Wav2Vec2Processor
    padding: Union[bool, str] = "longest"
    def __call__(self, features: List[Dict]):
        inputs = [{"input_values": f["input_values"]} for f in features]
        batch  = self.processor.pad(inputs, padding=self.padding, return_tensors="pt", pad_to_multiple_of=8)
        labels = [{"input_ids": f["labels"]} for f in features]
        with self.processor.as_target_processor():
            labels_batch = self.processor.pad(labels, padding=self.padding, return_tensors="pt", pad_to_multiple_of=8)
        label_ids = labels_batch["input_ids"]
        attn = labels_batch.get("attention_mask")
        if attn is not None:
            label_ids = label_ids.masked_fill(attn.ne(1), -100)
        else:
            pad_id = self.processor.tokenizer.pad_token_id
            label_ids = label_ids.masked_fill(label_ids.eq(pad_id), -100)
        batch["labels"] = label_ids
        return batch

data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)


In [21]:

# ===== Metrics (WER / CER / MER) =====
import evaluate, numpy as np, unicodedata, regex as re

wer_metric = evaluate.load('wer')
cer_metric = evaluate.load('cer')

def _norm_eval(xs):
    out = []
    for s in xs:
        s = s.replace(processor.tokenizer.word_delimiter_token, " ")
        s = unicodedata.normalize("NFKC", s)
        s = TAG_RE.sub(" ", s)
        s = s.upper()
        s = re.sub(r"[^\u4E00-\u9FFFA-Z' ]+", " ", s)
        s = re.sub(r"\s+", " ", s).strip()
        out.append(s)
    return out

_CJK_RE = re.compile(r"[\u3400-\u4DBF\u4E00-\u9FFF\uF900-\uFAFF]")
_EN_RE  = re.compile(r"[A-Za-z]")

def _mixed_tokens(s: str):
    s = " ".join(s.strip().split())
    toks = []; i = 0; n = len(s)
    while i < n:
        ch = s[i]
        if _EN_RE.match(ch):
            j = i + 1
            while j < n and (s[j].isalpha() or s[j] == "'"):
                j += 1
            toks.append(s[i:j].upper()); i = j; continue
        if _CJK_RE.match(ch):
            toks.append(ch); i += 1; continue
        if ch.isdigit():
            j = i + 1
            while j < n and s[j].isdigit():
                j += 1
            toks.append(s[i:j]); i = j; continue
        i += 1
    return toks

def _edit_ops(ref_toks, hyp_toks):
    R, H = len(ref_toks), len(hyp_toks)
    dp = [[(0,0,0,0) for _ in range(H+1)] for _ in range(R+1)]
    for j in range(1, H+1):
        c,S,I,D = dp[0][j-1]; dp[0][j] = (c+1, S, I+1, D)
    for i in range(1, R+1):
        c,S,I,D = dp[i-1][0]; dp[i][0] = (c+1, S, I, D+1)
    for i in range(1, R+1):
        for j in range(1, H+1):
            r, h = ref_toks[i-1], hyp_toks[j-1]
            c1,S1,I1,D1 = dp[i-1][j-1]
            sub = (c1, S1, I1, D1) if r == h else (c1+1, S1+1, I1, D1)
            c2,S2,I2,D2 = dp[i][j-1]; ins = (c2+1, S2, I2+1, D2)
            c3,S3,I3,D3 = dp[i-1][j]; dele= (c3+1, S3, I3, D3+1)
            dp[i][j] = min([sub, dele, ins], key=lambda x: (x[0], x[1], x[3], x[2]))
    c,S,I,D = dp[R][H]
    return S,I,D,c

def compute_mer_list(hyps, refs):
    totS = totI = totD = totN = 0
    for hyp, ref in zip(hyps, refs):
        rt, ht = _mixed_tokens(ref), _mixed_tokens(hyp)
        S,I,D,_ = _edit_ops(rt, ht)
        totS += S; totI += I; totD += D; totN += len(rt)
    mer = (totS + totI + totD) / max(1, totN)
    return {"mer": mer, "S": totS, "I": totI, "D": totD, "N": totN}

def compute_metrics(pred):
    pred_ids = np.argmax(pred.predictions, axis=-1)
    labels   = pred.label_ids.copy()
    labels[labels==-100] = processor.tokenizer.pad_token_id

    hyp_raw = processor.batch_decode(pred_ids, skip_special_tokens=False)
    ref_raw = processor.batch_decode(labels,   skip_special_tokens=False)

    pad_tok = processor.tokenizer.pad_token
    unk_tok = processor.tokenizer.unk_token or "<unk>"
    delim_tok = processor.tokenizer.word_delimiter_token or "|"

    def _cleanup(decoded_list):
        cleaned = []
        unk_total = 0
        for s in decoded_list:
            if pad_tok:
                s = s.replace(pad_tok, " ")
            if delim_tok:
                s = s.replace(delim_tok, " ")
            if unk_tok:
                unk_total += s.count(unk_tok)
            s = re.sub(r"\s+", " ", s).strip()
            cleaned.append(s)
        return cleaned, unk_total

    hyp_texts, hyp_unk = _cleanup(hyp_raw)
    ref_texts, ref_unk = _cleanup(ref_raw)

    hyp_n = _norm_eval(hyp_texts)
    ref_n = _norm_eval(ref_texts)

    empty = sum(1 for s in hyp_n if s=="")
    pct = (100.0*empty/len(hyp_n)) if hyp_n else 0.0
    print(f"[metrics-debug] preds={len(hyp_n)}, empty_after_norm={empty} ({pct:.1f}%), unk_pred={hyp_unk}, unk_ref={ref_unk}")

    mer_stats = compute_mer_list(hyp_n, ref_n)

    return {
        "wer": wer_metric.compute(predictions=hyp_n, references=ref_n),
        "cer": cer_metric.compute(predictions=hyp_n, references=ref_n),
        "mer": mer_stats["mer"],
    }


Downloading builder script: 0.00B [00:00, ?B/s]

Downloading builder script: 0.00B [00:00, ?B/s]

In [22]:

# 探针 A：原始/规范化解码对照
from collections import Counter
import numpy as np, torch

device = model.device
model.eval()

sample = encoded["validation"].select(range(min(3, len(encoded["validation"]))))
if len(sample):
    feats = [{"input_values": iv} for iv in sample["input_values"]]
    ins = processor.pad(feats, padding=True, return_tensors="pt")
    ins = {k: v.to(device) for k, v in ins.items()}
    with torch.no_grad():
        logits = model(input_values=ins["input_values"],
                       attention_mask=ins.get("attention_mask", None)).logits
    hyp_ids = logits.argmax(dim=-1).cpu().numpy()

    lbl_ids = []
    for seq in sample["labels"]:
        arr = np.array(seq, dtype=np.int64)
        arr[arr == -100] = processor.tokenizer.pad_token_id
        lbl_ids.append(arr.tolist())

    hyp_raw = processor.batch_decode(hyp_ids, skip_special_tokens=False)
    ref_raw = processor.batch_decode(lbl_ids, skip_special_tokens=False)
    hyp = processor.batch_decode(hyp_ids, skip_special_tokens=True)
    ref = processor.batch_decode(lbl_ids,  skip_special_tokens=True)

    def _tidy(xs):
        return [x.replace(processor.tokenizer.word_delimiter_token, " ") for x in xs]

    for i, (r0, h0, r1, h1) in enumerate(zip(ref_raw, hyp_raw, _tidy(ref), _tidy(hyp)), 1):
        print(f"[{i}] REF_RAW: {r0}")
        print(f"[{i}] HYP_RAW: {h0}")
        print(f"[{i}] REF    : {r1}")
        print(f"[{i}] HYP    : {h1}")
        print("-"*60)


[1] REF_RAW: 哎 你好啊 你好啊 你可以叫我<unk>丽
[1] HYP_RAW: 饭岁方努秀由毛由努秀由努毛努毛由毛由毛由毛由努毛秀努毛秀努秀努方概秀努毛努毛努由努秀努秀毛努毛努由毛罗毛由毛秀由秀毛由毛由努毛努秀努毛方秀毛秀岁秀由毛由秀毛努毛努毛由毛秀毛秀毛秀毛努毛努由努秀努D
[1] REF    : 哎 你好啊 你好啊 你可以叫我丽
[1] HYP    : 饭岁方努秀由毛由努秀由努毛努毛由毛由毛由毛由努毛秀努毛秀努秀努方概秀努毛努毛努由努秀努秀毛努毛努由毛罗毛由毛秀由秀毛由毛由努毛努秀努毛方秀毛秀岁秀由毛由秀毛努毛努毛由毛秀毛秀毛秀毛努毛努由努秀努D
------------------------------------------------------------
[2] REF_RAW: 好的
[2] HYP_RAW: 概方努方努秀努秀方秀由秀方秀方秀由努秀努秀努概啥无概
[2] REF    : 好的
[2] HYP    : 概方努方努秀努秀方秀由秀方秀方秀由努秀努秀努概啥无概
------------------------------------------------------------
[3] REF_RAW: 那你大概 嗯 你说
[3] HYP_RAW: 概秀努修努由份努由努秀由毛由秀由秀岁由秀罗秀由秀罗秀毛罗毛秀努秀努秀努秀努秀努秀努秀罗方努秀努由秀努秀方罗由努罗毛秀概秀由毛由秀毛秀毛秀由秀努秀罗秀罗秀由努秀概喝概
[3] REF    : 那你大概 嗯 你说
[3] HYP    : 概秀努修努由份努由努秀由毛由秀由秀岁由秀罗秀由秀罗秀毛罗毛秀努秀努秀努秀努秀努秀努秀罗方努秀努由秀努秀方罗由努罗毛秀概秀由毛由秀毛秀毛秀由秀努秀罗秀罗秀由努秀概喝概
------------------------------------------------------------


In [23]:

# 探针 B：blank 比例
model.eval()
sample2 = encoded["validation"].select(range(min(2, len(encoded["validation"]))))
if len(sample2):
    feats = [{"input_values": iv} for iv in sample2["input_values"]]
    ins = processor.pad(feats, padding=True, return_tensors="pt")
    ins = {k: v.to(model.device) for k, v in ins.items()}
    with torch.no_grad():
        logits = model(input_values=ins["input_values"],
                       attention_mask=ins.get("attention_mask", None)).logits
    hyp_ids = logits.argmax(dim=-1).cpu().numpy()
    pad_id  = processor.tokenizer.pad_token_id
    print("CTC blank/pad id =", pad_id)
    for b, path in enumerate(hyp_ids):
        counts = Counter(path.tolist())
        total = len(path)
        top5 = counts.most_common(5)
        blank_ratio = 100.0 * counts.get(pad_id, 0) / total
        uniq = len(counts)
        print(f"[B{b}] steps={total}, unique_ids={uniq}, blank_ratio={blank_ratio:.1f}%")
        print("      top5 ids(freq):", top5)
        if uniq <= 3 or blank_ratio > 70:
            print("  ⚠️ 解码高度单一：可能是采样率/词表/分词符口径问题。")


CTC blank/pad id = 3
[B0] steps=270, unique_ids=10, blank_ratio=0.0%
      top5 ids(freq): [(163, 111), (492, 82), (554, 38), (579, 30), (434, 3)]
[B1] steps=270, unique_ids=7, blank_ratio=0.0%
      top5 ids(freq): [(477, 181), (579, 44), (163, 34), (434, 7), (554, 2)]


In [24]:

# 可选：pyctcdecode beam 搜索
try:
    from pyctcdecode import build_ctcdecoder
    id2tok = {i:s for s,i in processor.tokenizer.get_vocab().items()}
    alphabet = []
    blank_id = processor.tokenizer.pad_token_id
    for i in range(len(id2tok)):
        if i == blank_id: alphabet.append("")
        elif id2tok[i] == processor.tokenizer.word_delimiter_token: alphabet.append(" ")
        else: alphabet.append(id2tok[i])
    decoder = build_ctcdecoder(alphabet)

    model.eval()
    sm = encoded["validation"].select(range(1))
    fe = [{"input_values": iv} for iv in sm["input_values"]]
    ins = processor.pad(fe, padding=True, return_tensors="pt").to(model.device)
    import numpy as np, torch
    with torch.no_grad():
        logp = model(**ins).logits[0].cpu().numpy()
    beam_text = decoder.decode(logp, beam_width=50)
    greedy_ids = logp.argmax(-1)
    greedy_text = processor.decode(greedy_ids, skip_special_tokens=True).replace(processor.tokenizer.word_delimiter_token, " ")
    print("[beam ]", beam_text[:120])
    print("[greedy]", greedy_text[:120])
except Exception as e:
    print("pyctcdecode beam 对照跳过：", repr(e))




pyctcdecode beam 对照跳过： ValueError('Input logits shape is (270, 805), but vocabulary is size 807. Need logits of shape: (time, vocabulary)')


In [25]:

def partially_freeze_wav2vec2_ctc(model, train_last_n_layers=12, freeze_feature_encoder=False):
    backbone = getattr(model, "wav2vec2", None) or getattr(model, "hubert", None) or getattr(model, "wavlm", None)
    assert backbone is not None
    enc = backbone.encoder
    for p in enc.parameters(): p.requires_grad = False
    if hasattr(enc, "layers"):
        L = len(enc.layers); n = min(train_last_n_layers, L)
        for i in range(L-n, L):
            for p in enc.layers[i].parameters(): p.requires_grad = True
    if hasattr(enc, "layer_norm"):
        for p in enc.layer_norm.parameters(): p.requires_grad = True
    for p in model.lm_head.parameters(): p.requires_grad = True
    if freeze_feature_encoder and hasattr(model, "freeze_feature_encoder"):
        model.freeze_feature_encoder()

def print_trainable_parameters(model, k=30):
    total=train=0; names=[]
    for n,p in model.named_parameters():
        total+=p.numel()
        if p.requires_grad:
            train+=p.numel()
            if len(names)<k: names.append(n)
    print(f"Trainable params: {train:,}/{total:,} ({100*train/total:.2f}%)")
    print("Some trainables:", *['  - '+x for x in names], sep='\n')

partially_freeze_wav2vec2_ctc(model, train_last_n_layers=12, freeze_feature_encoder=False)
print_trainable_parameters(model)
assert any(p.requires_grad for n,p in model.named_parameters() if 'lm_head' in n), "lm_head 未解冻"


Trainable params: 156,719,397/316,263,845 (49.55%)
Some trainables:
  - wav2vec2.masked_spec_embed
  - wav2vec2.feature_extractor.conv_layers.0.conv.weight
  - wav2vec2.feature_extractor.conv_layers.0.conv.bias
  - wav2vec2.feature_extractor.conv_layers.0.layer_norm.weight
  - wav2vec2.feature_extractor.conv_layers.0.layer_norm.bias
  - wav2vec2.feature_extractor.conv_layers.1.conv.weight
  - wav2vec2.feature_extractor.conv_layers.1.conv.bias
  - wav2vec2.feature_extractor.conv_layers.1.layer_norm.weight
  - wav2vec2.feature_extractor.conv_layers.1.layer_norm.bias
  - wav2vec2.feature_extractor.conv_layers.2.conv.weight
  - wav2vec2.feature_extractor.conv_layers.2.conv.bias
  - wav2vec2.feature_extractor.conv_layers.2.layer_norm.weight
  - wav2vec2.feature_extractor.conv_layers.2.layer_norm.bias
  - wav2vec2.feature_extractor.conv_layers.3.conv.weight
  - wav2vec2.feature_extractor.conv_layers.3.conv.bias
  - wav2vec2.feature_extractor.conv_layers.3.layer_norm.weight
  - wav2vec2.featu

In [44]:

from transformers import TrainingArguments, Trainer, EarlyStoppingCallback

# 8.1 过拟合 64 条
small_idx = list(range(min(64, len(encoded["train"]))))
train_small = encoded["train"].select(small_idx)
eval_small  = encoded["validation"].select(range(min(64, len(encoded["validation"]))))

test_args = TrainingArguments(
    output_dir=str(DATA_ROOT/'outputs/xlsr_overfit64'),
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=1,
    learning_rate=1e-4,
    weight_decay=0.0,
    warmup_ratio=0.0,
    num_train_epochs=30,
    lr_scheduler_type="linear",
    gradient_checkpointing=False,
    bf16=False, fp16=False,
    group_by_length=False,
    eval_strategy="epoch",
    save_strategy="no",
    logging_steps=10,
    report_to='none',
    optim="adamw_torch",
    max_grad_norm=1.0,
)

trainer_small = Trainer(
    model=model,
    args=test_args,
    train_dataset=train_small,
    eval_dataset=eval_small,
    processing_class=processor,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    # callbacks=[EarlyStoppingCallback(early_stopping_patience=5, early_stopping_threshold=1e-3)]
)

# 如需运行，解除下一行注释：
trainer_small.train()






Epoch,Training Loss,Validation Loss,Wer,Cer,Mer
1,184.5201,138.816681,1.0,1.0,1.0
2,142.3595,125.947594,1.0,1.0,1.0
3,177.6275,117.27063,1.0,1.0,1.0
4,139.4254,110.300308,1.0,1.0,1.0
5,119.4282,104.063568,1.0,1.0,1.0
6,109.2868,99.01091,1.0,1.0,1.0
7,93.1055,93.688812,1.0,1.0,1.0
8,85.8053,89.580971,1.0,1.0,1.0
9,92.8683,86.211555,1.0,1.0,1.0
10,80.9916,85.417938,1.041667,1.006061,1.007317


[metrics-debug] preds=64, empty_after_norm=64 (100.0%)
[metrics-debug] preds=64, empty_after_norm=64 (100.0%)
[metrics-debug] preds=64, empty_after_norm=64 (100.0%)
[metrics-debug] preds=64, empty_after_norm=64 (100.0%)
[metrics-debug] preds=64, empty_after_norm=64 (100.0%)
[metrics-debug] preds=64, empty_after_norm=64 (100.0%)
[metrics-debug] preds=64, empty_after_norm=64 (100.0%)
[metrics-debug] preds=64, empty_after_norm=64 (100.0%)
[metrics-debug] preds=64, empty_after_norm=64 (100.0%)
[metrics-debug] preds=64, empty_after_norm=44 (68.8%)
[metrics-debug] preds=64, empty_after_norm=33 (51.6%)
[metrics-debug] preds=64, empty_after_norm=31 (48.4%)
[metrics-debug] preds=64, empty_after_norm=45 (70.3%)
[metrics-debug] preds=64, empty_after_norm=34 (53.1%)
[metrics-debug] preds=64, empty_after_norm=30 (46.9%)
[metrics-debug] preds=64, empty_after_norm=34 (53.1%)
[metrics-debug] preds=64, empty_after_norm=39 (60.9%)
[metrics-debug] preds=64, empty_after_norm=28 (43.8%)
[metrics-debug] pre

TrainOutput(global_step=480, training_loss=85.34016958872478, metrics={'train_runtime': 281.5825, 'train_samples_per_second': 6.819, 'train_steps_per_second': 1.705, 'total_flos': 2.611699747699872e+17, 'train_loss': 85.34016958872478, 'epoch': 30.0})

In [45]:
metrics_small = trainer_small.evaluate()
metrics_small

[metrics-debug] preds=64, empty_after_norm=32 (50.0%)


{'eval_loss': 72.07487487792969,
 'eval_wer': 1.1964285714285714,
 'eval_cer': 1.0896969696969696,
 'eval_mer': 1.2682926829268293,
 'eval_runtime': 4.3646,
 'eval_samples_per_second': 14.663,
 'eval_steps_per_second': 3.666,
 'epoch': 30.0}

## 8) Train

In [27]:

# 8.2 正式训练（按需开启 bf16 / grad_ckpt / 8bit 优化器等）
args = TrainingArguments(
    output_dir=str(DATA_ROOT/'outputs/xlsr_full'),
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=2,
    learning_rate=1e-4,
    weight_decay=0.01,
    warmup_ratio=0.1,
    num_train_epochs=24.0,
    lr_scheduler_type="cosine",
    gradient_checkpointing=False,
    bf16=False, fp16=False,
    group_by_length=False,
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=2,
    logging_steps=200,
    load_best_model_at_end=True,
    metric_for_best_model='mer',
    greater_is_better=False,
    report_to='none',
    optim="adamw_torch",
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=encoded['train'],
    eval_dataset=encoded['validation'],
    processing_class=processor,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3, early_stopping_threshold=1e-3)]
)

print("正式训练：已配置完成。如需运行，请在本单元格末尾调用 trainer.train()")
# trainer.train()


正式训练：已配置完成。如需运行，请在本单元格末尾调用 trainer.train()


In [28]:
# #=======================================Round 1=======================================

# from transformers import TrainingArguments, Trainer, EarlyStoppingCallback

# args = TrainingArguments(
#     output_dir=str(DATA_ROOT/'outputs/xlsr'),
#     dataloader_num_workers=4,      # >0 才会有多进程加载
#     dataloader_pin_memory=True,    # GPU 传输更快
#     per_device_train_batch_size=8,
#     per_device_eval_batch_size=8,
#     gradient_accumulation_steps=2,      # 有效 batch = 16
#     learning_rate=1e-4,                 # 若抖动大，改 5e-5
#     weight_decay=0.01,
#     warmup_ratio=0.1,
#     num_train_epochs=24.0,
#     lr_scheduler_type="cosine",         # 显式写出，虽然默认就是 linear
#     gradient_checkpointing=True,
#     bf16=True,
#     # fp16=torch.cuda.is_available(),
#     group_by_length=False,
#     eval_strategy="epoch",
#     save_strategy="epoch",
#     save_total_limit=2,
#     logging_steps=1000,
#     load_best_model_at_end=True,
#     metric_for_best_model='mer',
#     greater_is_better=False,
#     report_to='none',
#     optim="adamw_bnb_8bit",
# )

# trainer = Trainer(
#     model=model,
#     args=args,
#     train_dataset=encoded['train'],
#     eval_dataset=encoded['validation'],
#     # tokenizer=processor,
#     processing_class=processor,
#     data_collator=data_collator,
#     compute_metrics=compute_metrics,
#     callbacks=[EarlyStoppingCallback(
#         early_stopping_patience=3,               # 连续3次评估无显著提升就停
#         early_stopping_threshold=1e-3            # CER 改善 < 0.001 视为不显著
#     )]
# )
# trainer.train()


In [29]:
# #=======================================Round 2=======================================

# best_ckpt = trainer.state.best_model_checkpoint
# print("best ckpt:", best_ckpt)

# from transformers import Wav2Vec2ForCTC
# model = Wav2Vec2ForCTC.from_pretrained(best_ckpt)   # 仅加载权重
# # 重新构建新的 TrainingArguments / Trainer（优化器会重新初始化）

# new_args = TrainingArguments(
#     output_dir=str(DATA_ROOT/'outputs_v2'),
#     per_device_train_batch_size=4,
#     per_device_eval_batch_size=4,
#     gradient_accumulation_steps=4,        # 有效 batch 仍为 16
#     learning_rate=5e-5,                   # ← 从 1e-4 降一档（后半程更稳）
#     weight_decay=0.01,
#     # 用 epoch 续训 3~5 个；或改成 max_steps 也可
#     num_train_epochs=5.0,
#     lr_scheduler_type="cosine",           # 线性 warmup→线性衰减
#     warmup_ratio=0.1,                     # ← 用比例更方便（约 10% warmup）
#     gradient_checkpointing=True,
#     bf16=True,                            # A100 建议用 BF16（比 FP16 稳）
#     dataloader_num_workers=4,
#     dataloader_pin_memory=True,
#     group_by_length=False,                 # 续训阶段可开启，提升吞吐/省 pad
#     eval_strategy="steps",          # ← 官方参数名是 eval_strategy
#     eval_steps=1000,
#     save_steps=1000,
#     save_total_limit=2,                   # 只保留“最佳+最近”两个 ckpt
#     logging_steps=50,
#     load_best_model_at_end=True,
#     metric_for_best_model='cer',
#     greater_is_better=False,
#     report_to='none',
# )

# new_trainer = Trainer(
#     model=model,
#     args=new_args,
#     train_dataset=encoded["train"],
#     eval_dataset=encoded["validation"],
#     processing_class=processor,
#     data_collator=data_collator,
#     compute_metrics=compute_metrics,
# )

# new_trainer.train()

In [30]:
# !cp -r cs_dialogue/outputs_v2/checkpoint-5000/ '/content/drive/MyDrive/Project2-PRS/models'


In [31]:
# #=======================================Round 3=======================================
# from transformers import Wav2Vec2ForCTC

# CKPT = "/content/drive/MyDrive/Project2-PRS/models/round2-checkpoint-5000"
# model = Wav2Vec2ForCTC.from_pretrained(CKPT)

# # === 3) 训练参数：cosine + warmup 10% + 早停；注意把 eval/save 策略写成官方键名 ===
# from transformers import EarlyStoppingCallback

# new_args_round3 = TrainingArguments(
#     output_dir=str(DATA_ROOT/'outputs_round3'),  # 新输出目录
#     per_device_train_batch_size=4,
#     per_device_eval_batch_size=4,
#     gradient_accumulation_steps=4,               # 有效 batch=16
#     learning_rate=5e-5,                          # 二轮基础上再稳一点
#     weight_decay=0.01,
#     num_train_epochs=12.0,                        # 先续 3~5 个 epoch
#     lr_scheduler_type="cosine",                  # 余弦退火（Trainer 原生）
#     warmup_ratio=0.10,                           # 10% warmup
#     gradient_checkpointing=True,
#     bf16=True,                                   # A100 建议 bf16
#     dataloader_num_workers=4,
#     dataloader_pin_memory=True,

#     eval_strategy="epoch",
#     # eval_steps=1000,
#     save_strategy="epoch",
#     # save_steps=1000,
#     save_total_limit=2,

#     logging_steps=50,
#     load_best_model_at_end=True,
#     metric_for_best_model="cer",
#     greater_is_better=False,
#     report_to="none",

#     #8-bit optimizer
#     optim="adamw_bnb_8bit",
# )

# new_trainer = Trainer(
#     model=model,
#     args=new_args_round3,
#     train_dataset=encoded["train"],
#     eval_dataset=encoded["validation"],
#     processing_class=processor,                  # 你之前已改用 processing_class ✅
#     data_collator=data_collator,
#     compute_metrics=compute_metrics,
#     callbacks=[EarlyStoppingCallback(
#         early_stopping_patience=3,               # 连续3次评估无显著提升就停
#         early_stopping_threshold=1e-3            # CER 改善 < 0.001 视为不显著
#     )]
# )



In [32]:
# new_trainer.train()

In [33]:
# #=======================================Round 3 (continue)=======================================
# # BEST = "/content/cs_dialogue/outputs_round3/checkpoint-14292"
# CKPT = "/content/drive/MyDrive/Project2-PRS/models/checkpoint-14292"


# from transformers import Wav2Vec2ForCTC, TrainingArguments, Trainer, EarlyStoppingCallback
# # 1) 仅加载模型权重（不会带上一次的优化器/调度器状态）
# model = Wav2Vec2ForCTC.from_pretrained(CKPT)

# # 2) 新一轮参数（示例：把 LR 小降；按 epoch 评估/保存，早停照旧）
# round3_coutinue = TrainingArguments(
#     output_dir=str(DATA_ROOT/'outputs_round3_cont'),
#     per_device_train_batch_size=4,
#     per_device_eval_batch_size=4,
#     gradient_accumulation_steps=4,
#     learning_rate=1e-5,                 # ← 比 5e-5 再稳一点
#     weight_decay=0.01,
#     num_train_epochs=12.0,               # 再跑 1–2 个 epoch，交给早停兜底
#     lr_scheduler_type="cosine",
#     warmup_ratio=0,
#     gradient_checkpointing=True,
#     bf16=True,
#     dataloader_num_workers=4,
#     dataloader_pin_memory=True,
#     eval_strategy="epoch",
#     save_strategy="epoch",
#     save_total_limit=2,
#     logging_steps=1000,
#     load_best_model_at_end=True,
#     metric_for_best_model="cer",
#     greater_is_better=False,
#     report_to="none",
#     optim="adamw_bnb_8bit",             # 继续用 8-bit 优化器（已安装 bnb）
# )

# trainer = Trainer(
#     model=model,
#     args=round3_coutinue,
#     train_dataset=encoded["train"],
#     eval_dataset=encoded["validation"],
#     processing_class=processor,
#     data_collator=data_collator,
#     compute_metrics=compute_metrics,
#     callbacks=[EarlyStoppingCallback(early_stopping_patience=3, early_stopping_threshold=1e-3)]
# )

# trainer.train()  # ← 不要传 resume_from_checkpoint


In [34]:

# from pyctcdecode import build_ctcdecoder

# # 1) 严格按 id 构造 labels，避免 get_vocab() 的无序/多余项
# vocab_size = processor.tokenizer.vocab_size
# labels = [processor.tokenizer.convert_ids_to_tokens(i) for i in range(vocab_size)]

# # 2) CTC 规则：pad 作为 blank；word_delimiter 作为空格
# pad_id = processor.tokenizer.pad_token_id
# if pad_id is not None:
#     labels[pad_id] = ""   # CTC blank

# # 有的 tokenizer 定义了 word_delimiter_token / 其 id；常见是 "|"
# if hasattr(processor.tokenizer, "word_delimiter_token_id") and \
#    processor.tokenizer.word_delimiter_token_id is not None:
#     labels[processor.tokenizer.word_delimiter_token_id] = " "

# # 3) 保险起见再断言一次，必须与 logits 维度一致
# assert len(labels) == model.lm_head.out_features, \
#     f"labels={len(labels)} vs lm_head={model.lm_head.out_features}"

# # 4) 构建 decoder（无 LM；有 KenLM 时加 kenlm_model_path=...）
# decoder = build_ctcdecoder(labels)

# print("vocab_size =", processor.tokenizer.vocab_size)
# print("lm_head_out =", model.lm_head.out_features)
# assert processor.tokenizer.vocab_size == model.lm_head.out_features, "vocab != lm_head"

# wer_metric = evaluate.load('wer')
# cer_metric = evaluate.load('cer')

# def _norm(xs):
#     out = []
#     for s in xs:
#         s = " ".join(s.strip().split())
#         out.append(s)
#     return out

# def compute_metrics_beam(pred):
#     # pred.predictions: [B, T, V] 的 logits (numpy)
#     logits = pred.predictions
#     # 逐条 beam 解码（示例 beam_width=50）
#     pred_str = [decoder.decode(logit, beam_width=50) for logit in logits]

#     # 还原 -100 并 decode 参考
#     label_ids = pred.label_ids.copy()
#     label_ids[label_ids == -100] = processor.tokenizer.pad_token_id
#     label_str = processor.batch_decode(label_ids, skip_special_tokens=True)

#     pred_str  = _norm(pred_str)
#     label_str = _norm(label_str)
#     return {
#         "wer": wer_metric.compute(predictions=pred_str, references=label_str),
#         "cer": cer_metric.compute(predictions=pred_str, references=label_str),
#     }


# # 取一个很小的 batch（比如 test 的前 2~4 条）
# small = encoded["test"].select(range(min(4, len(encoded["test"]))))

# batch = data_collator([{k: v for k, v in small[i].items()} for i in range(len(small))])
# with torch.no_grad():
#     lg = model(input_values=batch["input_values"].to(model.device),
#                attention_mask=batch["attention_mask"].to(model.device)).logits
# lg = lg.cpu().numpy()  # [B, T, V]

# # ① labels 长度 = logits 末维 = lm_head_out
# print("logits.shape =", lg.shape)
# assert lg.shape[-1] == len(labels) == model.lm_head.out_features, "decoder labels dim mismatch"

# # ② 解码一条，确保不包含分隔符“|”或 pad 符号名
# out = decoder.decode(lg[0], beam_width=50)
# print("beam sample:", out[:120])
# assert "|" not in out, "decoded text still contains '|' (word delimiter not mapped to space)"
# if processor.tokenizer.pad_token is not None:
#     assert processor.tokenizer.pad_token not in out, "decoded text contains PAD token"
# print("[OK] labels/logits 对齐 & 空格/blank 映射正常")


In [35]:
# BEST = "/content/cs_dialogue/outputs_round3/checkpoint-14292"  # 你JSON里给的
# from transformers import Wav2Vec2ForCTC, Trainer, TrainingArguments

# model = Wav2Vec2ForCTC.from_pretrained(BEST)
# args = TrainingArguments(output_dir=str(DATA_ROOT/'final_eval'), report_to='none')
# tester = Trainer(
#     model=model, args=args,
#     eval_dataset=encoded["test"],        # ← 用 test split
#     processing_class=processor,
#     data_collator=data_collator,
#     compute_metrics=compute_metrics,     # 仍计算 CER/WER（主：CER）
# )
# test_metrics = tester.evaluate(metric_key_prefix="test")
# print(test_metrics)  # 建议保存到文件
# # 保存“可部署包”
# tester.save_model(str(DATA_ROOT/'final_best'))
# processor.save_pretrained(str(DATA_ROOT/'final_best'))


In [36]:
# BEST = "/content/cs_dialogue/outputs_round3_cont/checkpoint-8337"  # 你JSON里给的
# from transformers import Wav2Vec2ForCTC, Trainer, TrainingArguments

# model = Wav2Vec2ForCTC.from_pretrained(BEST)
# args = TrainingArguments(output_dir=str(DATA_ROOT/'final_eval_8337'), report_to='none')
# tester = Trainer(
#     model=model, args=args,
#     eval_dataset=encoded["test"],        # ← 用 test split
#     processing_class=processor,
#     data_collator=data_collator,
#     compute_metrics=compute_metrics_beam,     # 仍计算 CER/WER（主：CER）
# )
# test_metrics = tester.evaluate(metric_key_prefix="test")
# print(test_metrics)  # 建议保存到文件
# # 保存“可部署包”
# tester.save_model(str(DATA_ROOT/'final_best_8337'))
# processor.save_pretrained(str(DATA_ROOT/'final_best_8337'))


In [37]:
# from google.colab import drive
# drive.mount('/content/drive')

# BEST = "/content/drive/MyDrive/Project2-PRS/models/checkpoint-8337"  # 你JSON里给的
# from transformers import Wav2Vec2ForCTC, Trainer, TrainingArguments

# model = Wav2Vec2ForCTC.from_pretrained(BEST)
# args = TrainingArguments(output_dir=str(DATA_ROOT/'final_eval_8337_beam'), report_to='none')
# tester = Trainer(
#     model=model, args=args,
#     eval_dataset=encoded["test"],        # ← 用 test split
#     processing_class=processor,
#     data_collator=data_collator,
#     compute_metrics=compute_metrics_beam,     # 仍计算 CER/WER（主：CER）
# )
# test_metrics = tester.evaluate(metric_key_prefix="test_beam")
# print(test_metrics)  # 建议保存到文件
# # 保存“可部署包”
# tester.save_model(str(DATA_ROOT/'final_best_8337_beam'))
# processor.save_pretrained(str(DATA_ROOT/'final_best_8337_beam'))


In [38]:
# #=======================================Round 5(LoRA MER)=======================================
# from transformers import Wav2Vec2ForCTC


# # 从 Round 2 的“最佳 checkpoint”继续（仍然只加载权重）
# CKPT = "/content/drive/MyDrive/Project2-PRS/models/checkpoint-8337"

# from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, EarlyStoppingCallback
# from peft import LoraConfig
# from transformers import BitsAndBytesConfig

# model = Wav2Vec2ForCTC.from_pretrained(
#     CKPT,
#     # quantization_config=BitsAndBytesConfig(load_in_8bit=True),
#     # device_map="auto",
# )

# # === 1) 配置 LoRA（对自注意力投影层加 LoRA）===
# # *注意*：Wav2Vec2 编码器的注意力层常见投影名为 q_proj/k_proj/v_proj/out_proj
# lora_config = LoraConfig(
#     r=8, lora_alpha=16, lora_dropout=0.05, bias="none",
#     target_modules=["q_proj", "k_proj", "v_proj", "out_proj"]
# )

# ADAPTER_NAME = "lora_wav2vec2_round4"
# ADAPTER_BEST_DIR = str((DATA_ROOT / "outputs_round4_lora_adapter_best").resolve())
# model.add_adapter(lora_config, adapter_name=ADAPTER_NAME)  # ← 官方推荐调用
# model.set_adapter(ADAPTER_NAME)                          # 训练时启用该 adapter

# # model.print_trainable_parameters()  # 调试：应显示仅少量参数可训练

# from transformers import TrainerCallback
# import os, math

# class BestAdapterSaver(TrainerCallback):
#     def __init__(self, adapter_name: str, out_dir: str, metric_name="eval_cer", greater_is_better=False):
#         self.adapter_name = adapter_name
#         self.out_dir = out_dir
#         self.metric_name = metric_name
#         self.sign = 1.0 if greater_is_better else -1.0
#         self.best_score = -math.inf  # 用带符号的分数做“越大越好”的统一比较
#         os.makedirs(out_dir, exist_ok=True)

#     def on_evaluate(self, args, state, control, metrics=None, **kwargs):
#         if metrics is None or self.metric_name not in metrics:
#             return
#         score = self.sign * metrics[self.metric_name]
#         if score > self.best_score:
#             self.best_score = score
#             # 只保存 LoRA 适配器（不保存全模型）
#             adapter_dir = os.path.join(self.out_dir, f"best_adapter_step-{state.global_step}")
#             # kwargs["model"].save_adapter(adapter_dir, adapter_name=self.adapter_name)
#             kwargs["model"].save_pretrained(adapter_dir)
#             # 也可额外写一个“latest-best”软链接/拷贝，方便复用：
#             latest_dir = os.path.join(self.out_dir, "best_adapter_latest")
#             try:
#                 if os.path.islink(latest_dir) or os.path.exists(latest_dir):
#                     import shutil
#                     shutil.rmtree(latest_dir)
#                 import shutil
#                 shutil.copytree(adapter_dir, latest_dir)
#             except Exception:
#                 pass
#         return


# lora_args = TrainingArguments(
#     output_dir=str(DATA_ROOT/'outputs_round4_lora'),
#     per_device_train_batch_size=4,
#     per_device_eval_batch_size=4,
#     gradient_accumulation_steps=4,
#     learning_rate=5e-4,               # LoRA 常用更大的起始 LR（5e-4 ~ 1e-3）
#     weight_decay=0.01,
#     num_train_epochs=7.0,
#     lr_scheduler_type="cosine",
#     warmup_ratio=0.10,
#     gradient_checkpointing=False,
#     bf16=True,
#     dataloader_num_workers=4,
#     dataloader_pin_memory=True,

#     # ★ 改为按 epoch
#     eval_strategy="epoch",
#     save_strategy="no",
#     save_total_limit=2,

#     logging_steps=200,
#     load_best_model_at_end=False,
#     metric_for_best_model="mer",
#     # metric_for_best_model="cer",
#     greater_is_better=False,
#     report_to="none",
#     remove_unused_columns=False,
# )


# lora_trainer = Trainer(
#     model=model,                                # LoRA 包裹后的 PeftModel
#     args=lora_args,
#     train_dataset=encoded["train"],
#     eval_dataset=encoded["validation"],
#     processing_class=processor,
#     data_collator=data_collator,
#     compute_metrics=compute_metrics,
#     callbacks=[EarlyStoppingCallback(early_stopping_patience=3, early_stopping_threshold=1e-3),
#                BestAdapterSaver(adapter_name=ADAPTER_NAME,
#                                 out_dir=ADAPTER_BEST_DIR,
#                                 # metric_name="eval_cer",
#                                 metric_name="eval_mer",
#                                 greater_is_better=False),
#                ]
# )

# # 与 Round 2 一样：这是“从权重开始新一轮”，不要恢复旧 optimizer/scheduler
# lora_trainer.train()

# ADAPTER_OUT = str(DATA_ROOT/'outputs_round4_lora_adapter')
# # model.save_adapter(ADAPTER_OUT, adapter_name=ADAPTER_NAME)
# model.save_pretrained(ADAPTER_OUT)  # ← 同上，用 save_pretrained

# # 之后可在任意同架构模型上：
# # model.load_adapter(ADAPTER_OUT, adapter_name="lora_round3")
# # model.set_adapter("lora_round3")


In [39]:
# #=======================================Round 4(LoRA)=======================================
# from transformers import Wav2Vec2ForCTC


# # 从 Round 2 的“最佳 checkpoint”继续（仍然只加载权重）
# CKPT = "/content/drive/MyDrive/Project2-PRS/models/checkpoint-8337"

# from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, EarlyStoppingCallback
# from peft import LoraConfig
# from transformers import BitsAndBytesConfig

# model = Wav2Vec2ForCTC.from_pretrained(
#     CKPT,
#     # quantization_config=BitsAndBytesConfig(load_in_8bit=True),
#     # device_map="auto",
# )

# # === 1) 配置 LoRA（对自注意力投影层加 LoRA）===
# # *注意*：Wav2Vec2 编码器的注意力层常见投影名为 q_proj/k_proj/v_proj/out_proj
# lora_config = LoraConfig(
#     r=8, lora_alpha=16, lora_dropout=0.05, bias="none",
#     target_modules=["q_proj", "k_proj", "v_proj", "out_proj"]
# )

# ADAPTER_NAME = "lora_wav2vec2_round4"
# ADAPTER_BEST_DIR = str((DATA_ROOT / "outputs_round4_lora_adapter_best").resolve())
# model.add_adapter(lora_config, adapter_name=ADAPTER_NAME)  # ← 官方推荐调用
# model.set_adapter(ADAPTER_NAME)                          # 训练时启用该 adapter

# # model.print_trainable_parameters()  # 调试：应显示仅少量参数可训练

# from transformers import TrainerCallback
# import os, math

# class BestAdapterSaver(TrainerCallback):
#     def __init__(self, adapter_name: str, out_dir: str, metric_name="eval_cer", greater_is_better=False):
#         self.adapter_name = adapter_name
#         self.out_dir = out_dir
#         self.metric_name = metric_name
#         self.sign = 1.0 if greater_is_better else -1.0
#         self.best_score = -math.inf  # 用带符号的分数做“越大越好”的统一比较
#         os.makedirs(out_dir, exist_ok=True)

#     def on_evaluate(self, args, state, control, metrics=None, **kwargs):
#         if metrics is None or self.metric_name not in metrics:
#             return
#         score = self.sign * metrics[self.metric_name]
#         if score > self.best_score:
#             self.best_score = score
#             # 只保存 LoRA 适配器（不保存全模型）
#             adapter_dir = os.path.join(self.out_dir, f"best_adapter_step-{state.global_step}")
#             # kwargs["model"].save_adapter(adapter_dir, adapter_name=self.adapter_name)
#             kwargs["model"].save_pretrained(adapter_dir)
#             # 也可额外写一个“latest-best”软链接/拷贝，方便复用：
#             latest_dir = os.path.join(self.out_dir, "best_adapter_latest")
#             try:
#                 if os.path.islink(latest_dir) or os.path.exists(latest_dir):
#                     import shutil
#                     shutil.rmtree(latest_dir)
#                 import shutil
#                 shutil.copytree(adapter_dir, latest_dir)
#             except Exception:
#                 pass
#         return


# lora_args = TrainingArguments(
#     output_dir=str(DATA_ROOT/'outputs_round4_lora'),
#     per_device_train_batch_size=4,
#     per_device_eval_batch_size=4,
#     gradient_accumulation_steps=4,
#     learning_rate=5e-4,               # LoRA 常用更大的起始 LR（5e-4 ~ 1e-3）
#     weight_decay=0.01,
#     num_train_epochs=7.0,
#     lr_scheduler_type="cosine",
#     warmup_ratio=0.10,
#     gradient_checkpointing=False,
#     bf16=True,
#     dataloader_num_workers=4,
#     dataloader_pin_memory=True,

#     # ★ 改为按 epoch
#     eval_strategy="epoch",
#     save_strategy="no",
#     save_total_limit=2,

#     logging_steps=200,
#     load_best_model_at_end=False,
#     metric_for_best_model="cer",
#     greater_is_better=False,
#     report_to="none",
#     remove_unused_columns=False,
# )


# lora_trainer = Trainer(
#     model=model,                                # LoRA 包裹后的 PeftModel
#     args=lora_args,
#     train_dataset=encoded["train"],
#     eval_dataset=encoded["validation"],
#     processing_class=processor,
#     data_collator=data_collator,
#     compute_metrics=compute_metrics,
#     callbacks=[EarlyStoppingCallback(early_stopping_patience=3, early_stopping_threshold=1e-3),
#                BestAdapterSaver(adapter_name=ADAPTER_NAME,
#                                 out_dir=ADAPTER_BEST_DIR,
#                                 metric_name="eval_cer",
#                                 greater_is_better=False),
#                ]
# )

# # 与 Round 2 一样：这是“从权重开始新一轮”，不要恢复旧 optimizer/scheduler
# lora_trainer.train()

# ADAPTER_OUT = str(DATA_ROOT/'outputs_round4_lora_adapter')
# # model.save_adapter(ADAPTER_OUT, adapter_name=ADAPTER_NAME)
# model.save_pretrained(ADAPTER_OUT)  # ← 同上，用 save_pretrained

# # 之后可在任意同架构模型上：
# # model.load_adapter(ADAPTER_OUT, adapter_name="lora_round3")
# # model.set_adapter("lora_round3")


In [40]:
# BASE = "/content/drive/MyDrive/Project2-PRS/models/checkpoint-8337"
# BEST_ADAPTER_DIR = str((DATA_ROOT / "outputs_round4_lora_adapter_best" / "best_adapter_latest").resolve())

# model = Wav2Vec2ForCTC.from_pretrained(BASE)
# model.load_adapter(BEST_ADAPTER_DIR, adapter_name="lora_best")
# model.set_adapter("lora_best")
# # 然后用 Trainer 做 evaluate / 继续训练即可


In [41]:
# from transformers import Wav2Vec2ForCTC
# from google.colab import drive
# drive.mount('/content/drive')

# #载入你当前最佳权重（全量权重）
# CKPT = "/content/drive/MyDrive/Project2-PRS/models/checkpoint-8337"
# model = Wav2Vec2ForCTC.from_pretrained(CKPT)

# # 给自注意力投影层加 LoRA（按你的 Wav2Vec2 命名可能是 q_proj/k_proj/v_proj/out_proj）
# lora_cfg = LoraConfig(
#     r=8, lora_alpha=16, lora_dropout=0.05, bias="none",
#     target_modules=["q_proj","k_proj","v_proj","out_proj"]
# )
# ADAPTER_NAME = "lora_round4"
# model.add_adapter(lora_cfg, adapter_name=ADAPTER_NAME)   # 官方推荐 API
# model.set_adapter(ADAPTER_NAME)

# lora_args = TrainingArguments(
#     output_dir=str(DATA_ROOT/'outputs_round4_lora'),
#     per_device_train_batch_size=4,
#     per_device_eval_batch_size=4,
#     gradient_accumulation_steps=4,
#     learning_rate=5e-4,                # LoRA 可用更大 LR
#     weight_decay=0.01,
#     num_train_epochs=7.0,
#     lr_scheduler_type="cosine",
#     warmup_ratio=0.10,
#     bf16=True,
#     gradient_checkpointing=True,
#     dataloader_num_workers=4,
#     dataloader_pin_memory=True,
#     evaluation_strategy="epoch",
#     save_strategy="epoch",
#     save_total_limit=2,
#     load_best_model_at_end=True,
#     metric_for_best_model="cer",
#     greater_is_better=False,
#     logging_steps=50,
#     report_to="none",
# )

# lora_trainer = Trainer(
#     model=model,
#     args=lora_args,
#     train_dataset=encoded["train"], eval_dataset=encoded["validation"],
#     processing_class=processor, data_collator=data_collator,
#     compute_metrics=compute_metrics,   # 建议直接用上面“beam版”评估
# )
# lora_trainer.train()

# # 仅保存适配器，便于复用/分享
# model.save_adapter(str(DATA_ROOT/'outputs_round4_lora_adapter'), adapter_name=ADAPTER_NAME)


In [42]:
# # 1) 最终在 test 上做指标评估（会返回 metrics）
# test_metrics = trainer.evaluate(eval_dataset=encoded["test"])
# print("TEST (evaluate):", test_metrics)