# XLS-R (Wav2Vec2) CTC Fine-tuning on BAAI/CS-Dialogue (Mandarin–English Code-Switching)

**Goal:** Fine-tune `facebook/wav2vec2-xls-r-300m` with a custom CTC vocabulary built from CS-Dialogue.

**Key points**
- Uses `datasets.Audio` + `cast_column(..., Audio(16000))`.
- Builds mixed EN (A–Z) + Chinese vocab; maps space→`|`.
- Correct padding & metrics (WER/CER).

References:
- CS-Dialogue dataset card (16kHz, structure): https://huggingface.co/datasets/BAAI/CS-Dialogue
- XLS-R-300M model card (16kHz input): https://huggingface.co/facebook/wav2vec2-xls-r-300m
- Datasets audio processing: https://huggingface.co/docs/datasets/en/audio_process
- WER / CER metrics: https://huggingface.co/spaces/evaluate-metric/wer , https://huggingface.co/spaces/evaluate-metric/cer


In [None]:
# installations
!pip -q install   "evaluate==0.4.3" "jiwer==3.0.4"
!pip install "torchcodec==0.7.*" --index-url https://download.pytorch.org/whl/cu126
!pip install -U "peft==0.17.1" "accelerate>=1.1.0"
!pip install -U "bitsandbytes==0.48.1"
!pip install -U "pyctcdecode==0.5.0"

#===============================================================================
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True,max_split_size_mb:512"
# os.environ["HF_HUB_DISABLE_XET"] = "1"         # 强制不用 XET 后端
# os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"  # 使用更快更稳的 hf_transfer
from importlib.metadata import version, PackageNotFoundError
import transformers, datasets, evaluate, tokenizers, huggingface_hub, torch, torchaudio, jiwer
import peft
from peft import LoraConfig

def safe_ver(pkg_name):
    try:
        return version(pkg_name)
    except PackageNotFoundError:
        return "(missing)"

# 改用 metadata 查询
print("Transformers  :", safe_ver("transformers"))
print("Datasets      :", safe_ver("datasets"))
print("Evaluate      :", safe_ver("evaluate"))
print("JiWER         :", safe_ver("jiwer"))
print("Tokenizers    :", safe_ver("tokenizers"))
print("HF Hub        :", safe_ver("huggingface_hub"))
print("Torch         :", torch.__version__)
print("Torchaudio    :", torchaudio.__version__)
print("PEFT          :", safe_ver("peft"))
print("Accelerate    :", safe_ver("accelerate"))
print("Bitsandbytes  :", safe_ver("bitsandbytes"))
print("PyCTCDecode   :", safe_ver("pyctcdecode"))
print("PYTORCH_CUDA_ALLOC_CONF =", os.environ.get("PYTORCH_CUDA_ALLOC_CONF"))
# 简单构建一个 LoRA 配置，验证 API 可用
test_cfg = LoraConfig(r=8, lora_alpha=16, lora_dropout=0.05, bias="none",
                      target_modules=["q_proj","k_proj","v_proj","out_proj"])
print(test_cfg)


In [2]:
!python -m bitsandbytes


Platform: Linux-6.6.105+-x86_64-with-glibc2.35
  libc: glibc-2.35
Python: 3.12.12
PyTorch: 2.8.0+cu126
  CUDA: 12.6
  HIP: N/A
  XPU: N/A
Related packages:
  accelerate: 1.11.0
  diffusers: 0.35.2
  numpy: 1.26.4
  pip: 24.1.2
  peft: 0.17.1
  safetensors: 0.6.2
  transformers: 4.57.1
  triton: 3.4.0
  trl: not found
PyTorch settings found: CUDA_VERSION=126, Highest Compute Capability: (8, 0).
Checking that the library is importable and CUDA is callable...
SUCCESS!


In [3]:
import os, re, tarfile, json
from pathlib import Path
from collections import Counter
from typing import Dict, List
import numpy as np, torch
from datasets import Dataset, DatasetDict, Audio, Features, Value
from huggingface_hub import hf_hub_download
from transformers import (
    WhisperProcessor,
    WhisperForConditionalGeneration,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    DataCollatorSpeechSeq2SeqWithPadding,
    TrainingArguments,
    Trainer,
)
import evaluate

SEED=42
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available(): torch.cuda.manual_seed_all(SEED)

BASE_DIR = Path('/content') if Path('/content').exists() else Path.cwd()
DATA_ROOT = BASE_DIR / 'cs_dialogue'
AUDIO_DIR = DATA_ROOT / 'short_wav'
INDEX_DIR = DATA_ROOT / 'data' / 'index' / 'short_wav'
VOCAB_DIR = DATA_ROOT / 'custom_vocab'
for p in [DATA_ROOT, AUDIO_DIR, INDEX_DIR, VOCAB_DIR]: p.mkdir(parents=True, exist_ok=True)


NUM_SHARDS = int(os.environ.get('CS_NUM_SHARDS', 19))  # increase to 19 for full short_wav


## 1) Download index & audio shards (short_wav)

In [4]:
from huggingface_hub import notebook_login

notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [6]:
from pathlib import Path
import os, shutil, tarfile, glob

REPO_ID = "BAAI/CS-Dialogue"        # 数据集仓库
# 你已有：DATA_ROOT = Path('/content/cs_dialogue'); AUDIO_DIR = DATA_ROOT/'short_wav'; INDEX_DIR = DATA_ROOT/'index'/'short_wav'

def download_index():
    files = [
        "data/index/short_wav/train/text",
        "data/index/short_wav/train/wav.scp",
        "data/index/short_wav/dev/text",
        "data/index/short_wav/dev/wav.scp",
        "data/index/short_wav/test/text",
        "data/index/short_wav/test/wav.scp",
    ]
    local = []
    for fp in files:
        # ★ 关键：直接落到 INDEX_DIR/ 相对目录；禁用 symlink
        dst_dir = INDEX_DIR / Path(fp).parent.relative_to("data/index/short_wav")
        dst_dir.mkdir(parents=True, exist_ok=True)
        src = hf_hub_download(
            repo_id=REPO_ID,
            filename=fp,
            repo_type="dataset",
            local_dir=str(DATA_ROOT),   # 关键：交给 local_dir 来还原相对层级
        )
        local.append(Path(src))
    print("Index ready:", local)
    return local


def download_shards(n=19):
    got = []
    for i in range(n):
        rel = f"data/short_wav/short_wav.tar.gz{i:02d}"
        src = hf_hub_download(
            repo_id=REPO_ID,
            filename=rel,
            repo_type="dataset",
            local_dir=str(DATA_ROOT),   # 落到 DATA_ROOT/rel
        )
        dst = Path(src)
        print("Downloaded:", dst)
        assert dst.exists() and dst.stat().st_size > 0, f"missing or empty part: {dst}"
        got.append(dst)
    # 关键：分片实际在 DATA_ROOT/data/short_wav 下
    print("In data/short_wav:", sorted(p.name for p in (DATA_ROOT/'data'/'short_wav').glob('short_wav.tar.gz*')))
    return got





def concat_parts(parts, out_file: Path):
    """把 *.tar.gz00.. 拼接成一个完整 tar.gz"""
    parts = sorted(parts, key=lambda p: p.name)
    if not parts:
        raise RuntimeError("no parts provided to concat_parts")
    # 逐个断言存在
    for p in parts:
        if not Path(p).exists():
            raise FileNotFoundError(f"part not found on disk: {p}")
    out_file.parent.mkdir(parents=True, exist_ok=True)
    total = 0
    with open(out_file, "wb") as w:
        for p in parts:
            with open(p, "rb") as r:
                shutil.copyfileobj(r, w)
            sz = Path(p).stat().st_size
            total += sz
            print(f"  appended {Path(p).name} ({sz/1e6:.1f} MB)")
    print(f"==> concatenated -> {out_file} (~{total/1e9:.2f} GB)")

def extract_concatenated_tar_gz(out_dir: Path, parts=None):
    """
    解压流程：
    1) 如未给 parts，则自动从 DATA_ROOT 扫描 short_wav.tar.gz[0-9][0-9]
    2) 拼成 short_wav.tar.gz
    3) 正常 tar 解压到 out_dir
    """
    out_dir.mkdir(parents=True, exist_ok=True)

    # ★ 自动发现分片（避免用了旧的 tars 变量）
    if parts is None:
        # 同时在 DATA_ROOT 根 & DATA_ROOT/data/short_wav 下搜索
        candidates = sorted((DATA_ROOT).glob("short_wav.tar.gz[0-9][0-9]")) \
               + sorted((DATA_ROOT / "data" / "short_wav").glob("short_wav.tar.gz[0-9][0-9]"))
    else:
        candidates = list(map(Path, parts))

    if not candidates:
        raise RuntimeError(f"No split parts found in {DATA_ROOT}. Expected files like short_wav.tar.gz00")

    merged = DATA_ROOT / "short_wav.tar.gz"

    if merged.exists():
      merged.unlink()  # 清除此前拼接失败的半成品
    concat_parts(candidates, merged)

    # 解压
    with tarfile.open(merged, "r:gz") as tf:
        tf.extractall(out_dir, filter="data")  # 3.12+ 推荐
    print("Extracted ->", out_dir)

    # 可选：节省空间
    try:
        merged.unlink()
    except Exception:
        pass

# === 调用顺序 ===
_ = download_index()
_ = download_shards(NUM_SHARDS)                 # e.g. NUM_SHARDS=2

# ★ 完整性检查：必须 19 片全部到位
expected = [f"short_wav.tar.gz{i:02d}" for i in range(19)]
have = sorted(p.name for p in (DATA_ROOT/'data'/'short_wav').glob('short_wav.tar.gz*'))
missing = [x for x in expected if x not in have]
assert not missing, f"缺少分片：{missing}。请把 NUM_SHARDS 设为 19 或补齐后再解压。"

extract_concatenated_tar_gz(AUDIO_DIR, None)   # 让它自己扫描分片再拼接


Index ready: [PosixPath('/content/cs_dialogue/data/index/short_wav/train/text'), PosixPath('/content/cs_dialogue/data/index/short_wav/train/wav.scp'), PosixPath('/content/cs_dialogue/data/index/short_wav/dev/text'), PosixPath('/content/cs_dialogue/data/index/short_wav/dev/wav.scp'), PosixPath('/content/cs_dialogue/data/index/short_wav/test/text'), PosixPath('/content/cs_dialogue/data/index/short_wav/test/wav.scp')]


data/short_wav/short_wav.tar.gz00:   0%|          | 0.00/524M [00:00<?, ?B/s]

Downloaded: /content/cs_dialogue/data/short_wav/short_wav.tar.gz00


data/short_wav/short_wav.tar.gz01:   0%|          | 0.00/524M [00:00<?, ?B/s]

Downloaded: /content/cs_dialogue/data/short_wav/short_wav.tar.gz01


data/short_wav/short_wav.tar.gz02:   0%|          | 0.00/524M [00:00<?, ?B/s]

Downloaded: /content/cs_dialogue/data/short_wav/short_wav.tar.gz02


data/short_wav/short_wav.tar.gz03:   0%|          | 0.00/524M [00:00<?, ?B/s]

Downloaded: /content/cs_dialogue/data/short_wav/short_wav.tar.gz03


data/short_wav/short_wav.tar.gz04:   0%|          | 0.00/524M [00:00<?, ?B/s]

Downloaded: /content/cs_dialogue/data/short_wav/short_wav.tar.gz04


data/short_wav/short_wav.tar.gz05:   0%|          | 0.00/524M [00:00<?, ?B/s]

Downloaded: /content/cs_dialogue/data/short_wav/short_wav.tar.gz05


data/short_wav/short_wav.tar.gz06:   0%|          | 0.00/524M [00:00<?, ?B/s]

Downloaded: /content/cs_dialogue/data/short_wav/short_wav.tar.gz06


data/short_wav/short_wav.tar.gz07:   0%|          | 0.00/524M [00:00<?, ?B/s]

Downloaded: /content/cs_dialogue/data/short_wav/short_wav.tar.gz07


data/short_wav/short_wav.tar.gz08:   0%|          | 0.00/524M [00:00<?, ?B/s]

Downloaded: /content/cs_dialogue/data/short_wav/short_wav.tar.gz08


data/short_wav/short_wav.tar.gz09:   0%|          | 0.00/524M [00:00<?, ?B/s]

Downloaded: /content/cs_dialogue/data/short_wav/short_wav.tar.gz09


data/short_wav/short_wav.tar.gz10:   0%|          | 0.00/524M [00:00<?, ?B/s]

Downloaded: /content/cs_dialogue/data/short_wav/short_wav.tar.gz10


data/short_wav/short_wav.tar.gz11:   0%|          | 0.00/524M [00:00<?, ?B/s]

Downloaded: /content/cs_dialogue/data/short_wav/short_wav.tar.gz11


data/short_wav/short_wav.tar.gz12:   0%|          | 0.00/524M [00:00<?, ?B/s]

Downloaded: /content/cs_dialogue/data/short_wav/short_wav.tar.gz12


data/short_wav/short_wav.tar.gz13:   0%|          | 0.00/524M [00:00<?, ?B/s]

Downloaded: /content/cs_dialogue/data/short_wav/short_wav.tar.gz13


data/short_wav/short_wav.tar.gz14:   0%|          | 0.00/524M [00:00<?, ?B/s]

Downloaded: /content/cs_dialogue/data/short_wav/short_wav.tar.gz14


data/short_wav/short_wav.tar.gz15:   0%|          | 0.00/524M [00:00<?, ?B/s]

Downloaded: /content/cs_dialogue/data/short_wav/short_wav.tar.gz15


data/short_wav/short_wav.tar.gz16:   0%|          | 0.00/524M [00:00<?, ?B/s]

Downloaded: /content/cs_dialogue/data/short_wav/short_wav.tar.gz16


data/short_wav/short_wav.tar.gz17:   0%|          | 0.00/524M [00:00<?, ?B/s]

Downloaded: /content/cs_dialogue/data/short_wav/short_wav.tar.gz17


data/short_wav/short_wav.tar.gz18:   0%|          | 0.00/158M [00:00<?, ?B/s]

Downloaded: /content/cs_dialogue/data/short_wav/short_wav.tar.gz18
In data/short_wav: ['short_wav.tar.gz00', 'short_wav.tar.gz01', 'short_wav.tar.gz02', 'short_wav.tar.gz03', 'short_wav.tar.gz04', 'short_wav.tar.gz05', 'short_wav.tar.gz06', 'short_wav.tar.gz07', 'short_wav.tar.gz08', 'short_wav.tar.gz09', 'short_wav.tar.gz10', 'short_wav.tar.gz11', 'short_wav.tar.gz12', 'short_wav.tar.gz13', 'short_wav.tar.gz14', 'short_wav.tar.gz15', 'short_wav.tar.gz16', 'short_wav.tar.gz17', 'short_wav.tar.gz18']
  appended short_wav.tar.gz00 (524.3 MB)
  appended short_wav.tar.gz01 (524.3 MB)
  appended short_wav.tar.gz02 (524.3 MB)
  appended short_wav.tar.gz03 (524.3 MB)
  appended short_wav.tar.gz04 (524.3 MB)
  appended short_wav.tar.gz05 (524.3 MB)
  appended short_wav.tar.gz06 (524.3 MB)
  appended short_wav.tar.gz07 (524.3 MB)
  appended short_wav.tar.gz08 (524.3 MB)
  appended short_wav.tar.gz09 (524.3 MB)
  appended short_wav.tar.gz10 (524.3 MB)
  appended short_wav.tar.gz11 (524.3 MB)
  a

In [7]:
print("INDEX_DIR =", INDEX_DIR)
print("Has train/text?   ->", (INDEX_DIR/'train'/'text').exists())
print("Has train/wav.scp?->", (INDEX_DIR/'train'/'wav.scp').exists())

print("Split parts in DATA_ROOT/data/short_wav:",
      sorted(p.name for p in (DATA_ROOT/'data'/'short_wav').glob('short_wav.tar.gz*'))[:19])


INDEX_DIR = /content/cs_dialogue/data/index/short_wav
Has train/text?   -> True
Has train/wav.scp?-> True
Split parts in DATA_ROOT/data/short_wav: ['short_wav.tar.gz00', 'short_wav.tar.gz01', 'short_wav.tar.gz02', 'short_wav.tar.gz03', 'short_wav.tar.gz04', 'short_wav.tar.gz05', 'short_wav.tar.gz06', 'short_wav.tar.gz07', 'short_wav.tar.gz08', 'short_wav.tar.gz09', 'short_wav.tar.gz10', 'short_wav.tar.gz11', 'short_wav.tar.gz12', 'short_wav.tar.gz13', 'short_wav.tar.gz14', 'short_wav.tar.gz15', 'short_wav.tar.gz16', 'short_wav.tar.gz17', 'short_wav.tar.gz18']


In [9]:
from pathlib import Path

def print_tree(root: Path, max_depth: int = 3):
    root = Path(root)

    def walk(p: Path, depth: int, prefix: str = ""):
        if depth < 0:
            return
        items = sorted(p.iterdir(), key=lambda x: (x.is_file(), x.name.lower()))
        for i, it in enumerate(items):
            is_last = (i == len(items) - 1)
            connector = "└── " if is_last else "├── "
            name = it.name + ("/" if it.is_dir() else "")
            print(prefix + connector + name)
            if it.is_dir():
                next_prefix = prefix + ("    " if is_last else "│   ")
                walk(it, depth - 1, next_prefix)

    root = root.resolve()
    print(root.as_posix())
    walk(root, max_depth)

# 使用：打印 DATA_ROOT 下三层
print_tree(DATA_ROOT, max_depth=4)


/content/cs_dialogue
├── .cache/
│   └── huggingface/
│       ├── download/
│       │   └── data/
│       │       ├── index/
│       │       └── short_wav/
│       └── .gitignore
├── custom_vocab/
├── data/
│   ├── index/
│   │   └── short_wav/
│   │       ├── dev/
│   │       │   ├── text
│   │       │   └── wav.scp
│   │       ├── test/
│   │       │   ├── text
│   │       │   └── wav.scp
│   │       └── train/
│   │           ├── text
│   │           └── wav.scp
│   └── short_wav/
│       ├── short_wav.tar.gz00
│       ├── short_wav.tar.gz01
│       ├── short_wav.tar.gz02
│       ├── short_wav.tar.gz03
│       ├── short_wav.tar.gz04
│       ├── short_wav.tar.gz05
│       ├── short_wav.tar.gz06
│       ├── short_wav.tar.gz07
│       ├── short_wav.tar.gz08
│       ├── short_wav.tar.gz09
│       ├── short_wav.tar.gz10
│       ├── short_wav.tar.gz11
│       ├── short_wav.tar.gz12
│       ├── short_wav.tar.gz13
│       ├── short_wav.tar.gz14
│       ├── short_wav.tar.gz15
│       ├── sho

## 2) Build DatasetDict from wav.scp & text

In [10]:
# def read_kv(fp: Path):
#     d={}
#     with open(fp, 'r', encoding='utf-8') as f:
#         for line in f:
#             line=line.strip()
#             if not line: continue
#             k,v=line.split(' ',1)
#             d[k]=v
#     return d

# def make_split(split: str):
#     wavscp = read_kv(INDEX_DIR/split/'wav.scp')
#     text   = read_kv(INDEX_DIR/split/'text')
#     ids, paths, trans = [], [], []
#     for uid, wavpath in wavscp.items():

#         shard = Path(wavpath).parts[-2]
#         fname = Path(wavpath).name
#         local = AUDIO_DIR / shard / fname
#         if local.exists() and uid in text:
#             ids.append(uid); paths.append(str(local)); trans.append(text[uid])
#     feats = Features({'id': Value('string'), 'audio': Audio(sampling_rate=16000), 'transcription': Value('string')})
#     return Dataset.from_dict({'id': ids, 'audio': paths, 'transcription': trans}, features=feats)
from pathlib import Path

def resolve_local_audio_path(wavpath: str) -> Path | None:
    """
    把 wav.scp 的路径映射到本地真实存在的文件。
    兼容如下情况：
    - 路径前缀带不带 'data/'
    - 是否出现双重 'short_wav/short_wav'
    - 绝对路径/相对路径
    """
    p = Path(wavpath.strip())

    # 1) 若本身就是绝对路径且存在，直接返回
    if p.is_absolute() and p.exists():
        return p

    candidates: list[Path] = []

    # 2) wav.scp 一般包含 ".../short_wav/..." 这段；抽取从第一处 'short_wav' 之后的尾部
    parts = p.parts
    if "short_wav" in parts:
        i = parts.index("short_wav")
        tail = Path(*parts[i+1:])  # 去掉第一个 'short_wav' 及之前的前缀
        # 情况 A：你的磁盘上是 short_wav/short_wav/WAVE/...
        if (AUDIO_DIR / "short_wav").exists():
            candidates.append(AUDIO_DIR / "short_wav" / tail)  # /.../short_wav/short_wav/...
        # 情况 B：只有一层 short_wav/WAVE/...
        candidates.append(AUDIO_DIR / tail)                     # /.../short_wav/...
    else:
        # 没出现 short_wav 关键词时，尝试几种常见组合
        candidates.append(DATA_ROOT / p)            # /content/cs_dialogue/<wav.scp里的相对路径>
        candidates.append(DATA_ROOT / "data" / p)   # /content/cs_dialogue/data/<...>

    # 3) 再补充几种保守候选
    candidates.append(DATA_ROOT / p)
    if str(p).startswith("data/"):
        candidates.append(DATA_ROOT / str(p))                    # /content/cs_dialogue/data/short_wav/...
        candidates.append(DATA_ROOT / str(p).replace("data/", "", 1))  # /content/cs_dialogue/short_wav/...

    for c in candidates:
        if c.exists():
            return c.resolve()
    return None


def read_kv(fp: Path):
    d={}
    with open(fp, 'r', encoding='utf-8') as f:
        for line in f:
            line=line.strip()
            if not line:
                continue
            # Kaldi风格：<key><space><value...>
            k,v=line.split(' ',1)
            d[k]=v
    return d


def make_split(split: str):
    wavscp = read_kv(INDEX_DIR/split/'wav.scp')
    text   = read_kv(INDEX_DIR/split/'text')

    ids, paths, trans = [], [], []
    miss_audio, miss_text = 0, 0

    for uid, wavpath in wavscp.items():
        local = resolve_local_audio_path(wavpath)
        if local is None or not local.exists():
            miss_audio += 1
            continue
        if uid not in text:
            miss_text += 1
            continue

        ids.append(uid)
        paths.append(str(local))
        trans.append(text[uid])

    print(f"[{split}] matched {len(ids)} items "
          f"(missing audio: {miss_audio}, missing text: {miss_text})")

    feats = Features({
        'id': Value('string'),
        'audio': Audio(sampling_rate=16000),
        'transcription': Value('string')
    })
    return Dataset.from_dict(
        {'id': ids, 'audio': paths, 'transcription': trans},
        features=feats
    )


train_ds = make_split('train')
val_ds   = make_split('dev')
test_ds  = make_split('test')
minds = DatasetDict(train=train_ds, validation=val_ds, test=test_ds)
minds = minds.cast_column('audio', Audio(sampling_rate=16000))
minds


[train] matched 26239 items (missing audio: 0, missing text: 0)
[dev] matched 6186 items (missing audio: 0, missing text: 0)
[test] matched 6257 items (missing audio: 0, missing text: 0)


DatasetDict({
    train: Dataset({
        features: ['id', 'audio', 'transcription'],
        num_rows: 26239
    })
    validation: Dataset({
        features: ['id', 'audio', 'transcription'],
        num_rows: 6186
    })
    test: Dataset({
        features: ['id', 'audio', 'transcription'],
        num_rows: 6257
    })
})

In [11]:
MIN_SEC, MAX_SEC = 0.3, 6.0
def _keep_ok(ex):
    sec = ex["audio"]["array"].shape[0] / ex["audio"]["sampling_rate"]
    return (sec >= MIN_SEC) and (sec <= MAX_SEC)

minds = minds.filter(_keep_ok, num_proc=4)

Filter (num_proc=4):   0%|          | 0/26239 [00:00<?, ? examples/s]

Filter (num_proc=4):   0%|          | 0/6186 [00:00<?, ? examples/s]

Filter (num_proc=4):   0%|          | 0/6257 [00:00<?, ? examples/s]

## 3) Normalize transcripts (EN upper + Chinese)

In [12]:
CN_RANGE = r"\u4E00-\u9FFF"
def normalize_text(ex):
    t = ex['transcription'].strip().upper()
    t = re.sub(fr"[^{CN_RANGE}A-Z' ]+", "", t)
    t = re.sub(r"\s+", " ", t)
    return {'transcription': t}

minds = minds.map(normalize_text)
minds['train'][0]['transcription'][:120]


Map:   0%|          | 0/13001 [00:00<?, ? examples/s]

Map:   0%|          | 0/2952 [00:00<?, ? examples/s]

Map:   0%|          | 0/3313 [00:00<?, ? examples/s]

'嗨'

## 4) Build CTC vocab (space→`|`)

In [None]:
# from collections import Counter
# def collect_chars(ds, key='transcription'):
#     c=Counter()
#     for s in ds[key]: c.update(list(s))
#     return c
# cnt=Counter()
# for sp in ['train','validation','test']:
#     cnt.update(collect_chars(minds[sp]))
# chars=sorted([ch for ch in cnt if ch!=' '])
# vocab={ch:i for i,ch in enumerate(chars)}
# vocab['|']=len(vocab); vocab['<unk>']=len(vocab); vocab['<pad>']=len(vocab)
# VOCAB_DIR.mkdir(parents=True, exist_ok=True)
# with open(VOCAB_DIR/'vocab.json','w',encoding='utf-8') as f:
#     json.dump(vocab, f, ensure_ascii=False, indent=2)
# len(vocab)


## 5) Init tokenizer/processor & Whisper-Small model

In [15]:
from transformers import (
    WhisperProcessor,
    WhisperForConditionalGeneration,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    # DataCollatorSpeechSeq2SeqWithPadding,
    DataCollatorForSeq2Seq,
    TrainingArguments,
    Trainer,
)
CKPT = "openai/whisper-small"

processor = WhisperProcessor.from_pretrained(
    CKPT,
    language=None,
    task="transcribe",
)
processor.feature_extractor.return_attention_mask = True

model = WhisperForConditionalGeneration.from_pretrained(CKPT)
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []
model.generation_config.forced_decoder_ids = None
model.generation_config.suppress_tokens = []
processor.tokenizer.pad_token_id = processor.tokenizer.eos_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
model.generation_config.pad_token_id = processor.tokenizer.pad_token_id

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
model.config.use_cache = False
model.train()
device

preprocessor_config.json: 0.00B [00:00, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

normalizer.json: 0.00B [00:00, ?B/s]

added_tokens.json: 0.00B [00:00, ?B/s]

special_tokens_map.json: 0.00B [00:00, ?B/s]

config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/967M [00:00<?, ?B/s]

generation_config.json: 0.00B [00:00, ?B/s]

device(type='cuda')

## 6) Encode → input_values / attention_mask / labels

In [18]:
def prepare_batch(batch):
    # 1) 音频 -> log-mel
    audio = batch["audio"]
    out = processor(
        audio=audio["array"],
        sampling_rate=audio["sampling_rate"],
        text=batch["transcription"],   # 直接传文本
        return_attention_mask=True
    )
    # WhisperProcessor 会同时产出特征与labels
    batch["input_features"] = out["input_features"][0]
    batch["labels"] = out["labels"]            # 有些版本是 list[int]，有些是 tensor/list[list[int]]
    return batch


encoded = minds.map(
    prepare_batch,
    remove_columns=minds['train'].column_names,
    num_proc=4,
    desc='Preparing Whisper inputs',
)
encoded

Preparing Whisper inputs (num_proc=4):   0%|          | 0/13001 [00:00<?, ? examples/s]

Preparing Whisper inputs (num_proc=4):   0%|          | 0/2952 [00:00<?, ? examples/s]

Preparing Whisper inputs (num_proc=4):   0%|          | 0/3313 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['input_features', 'labels'],
        num_rows: 13001
    })
    validation: Dataset({
        features: ['input_features', 'labels'],
        num_rows: 2952
    })
    test: Dataset({
        features: ['input_features', 'labels'],
        num_rows: 3313
    })
})

## 7) Collator & Metrics (WER/CER) + Sanity check

In [25]:
from dataclasses import dataclass
from typing import Dict, List, Union
import torch
from transformers import DataCollatorForSeq2Seq

# 1) 专用 collator：对 input_features 用 feature_extractor.pad，对 labels 用 tokenizer.pad
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: any
    padding: Union[bool, str] = "longest"

    def __call__(self, features: List[Dict]) -> Dict[str, torch.Tensor]:
        # 拆出声学输入与标签
        input_feats = [{"input_features": f["input_features"]} for f in features]
        label_feats = [{"input_ids": f["labels"]} for f in features]

        batch = {}
        # a) pad 声学特征
        batch_inputs = self.processor.feature_extractor.pad(
            input_feats, padding=self.padding, return_tensors="pt"
        )
        batch["input_features"] = batch_inputs["input_features"]

        # b) pad 文本标签（pad→-100 由 tokenizer.pad + Trainer 处理）
        labels_batch = self.processor.tokenizer.pad(
            label_feats, padding=self.padding, return_tensors="pt"
        )
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
        batch["labels"] = labels
        return batch

data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)



wer_metric = evaluate.load("wer")
cer_metric = evaluate.load("cer")

def compute_metrics(pred):
    pred_ids, label_ids = pred
    preds = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_ids = np.where(label_ids != -100, label_ids, processor.tokenizer.pad_token_id)
    refs = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)
    return {"wer": wer_metric.compute(predictions=preds, references=refs),
            "cer": cer_metric.compute(predictions=preds, references=refs)}

##Quick sanity checks

In [29]:
# ==== Quick sanity checks for Whisper + LoRA ====

import numpy as np
import torch

# 1) 采样率 + 单声道检查（原始数据集 minds）
TARGET_SR = processor.feature_extractor.sampling_rate
sample_audio = minds["train"][0]["audio"]                  # 若你的变量名不同，这里换成你的原始 Dataset 变量
assert sample_audio["sampling_rate"] == TARGET_SR and sample_audio["array"].ndim == 1
print(f"[OK] audio resampled to {TARGET_SR} Hz mono")

# 2) collator 输出形状检查（已预处理后的 encoded）
sample_feats = encoded["validation"].select(range(min(2, len(encoded["validation"]))))
batch = data_collator(sample_feats)
print("[OK] collator output keys:", list(batch.keys()))
print("[OK] collator output shapes:", {k: tuple(v.shape) for k, v in batch.items()})

# 3) 目标 labels 构造检查（不再使用 as_target_processor，直接 tokenizer）
ids = processor.tokenizer(
    [minds["train"][i]["transcription"] for i in range(min(2, len(minds["train"])))],
    padding=True,
    return_tensors="np",
).input_ids
print("[OK] target ids shape:", ids.shape)

# 4) 模型 generate 小样本自检（Whisper 用 input_features，而不是 input_ids）
model.eval()
with torch.no_grad():
    gen_kwargs = {"max_length": 225}
    generated_ids = model.generate(
        input_features=batch["input_features"].to(model.device),
        **gen_kwargs,
    )

# 5) 预测/参考文本对齐与打印
label_ids = batch["labels"].cpu().numpy()
label_ids = np.where(label_ids == -100, processor.tokenizer.pad_token_id, label_ids)

# 对于 Whisper，解码要用 tokenizer.batch_decode（不要用 processor.batch_decode）
pred_str = processor.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
ref_str  = processor.tokenizer.batch_decode(label_ids,    skip_special_tokens=True)

for i, (ref, hyp) in enumerate(zip(ref_str, pred_str), 1):
    print(f"[{i}] REF: {ref[:120]}")
    print(f"[{i}] HYP: {hyp[:120]}")


[OK] audio resampled to 16000 Hz mono


The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


[OK] collator output keys: ['input_features', 'labels']
[OK] collator output shapes: {'input_features': (2, 80, 3000), 'labels': (2, 16)}
[OK] target ids shape: (2, 12)
[1] REF: 哎你好啊 你好啊你可以叫我玛丽
[1] HYP: 嗯我是 LUNA 嗯你现在
[2] REF: 好的
[2] HYP: 啊你现在在哪


In [26]:
tiny = encoded["train"].select(range(12))
args = TrainingArguments(output_dir="tmp_overfit", max_steps=300,
                         per_device_train_batch_size=3, learning_rate=3e-4,
                         logging_steps=20, save_steps=10_000, fp16=torch.cuda.is_available(),report_to='none',)
trainer = Trainer(model=model, args=args, train_dataset=tiny, data_collator=data_collator,
                  processing_class=processor)
trainer.train()
print("[OK] tiny overfit finished; loss should drop")


You're using a WhisperTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step,Training Loss
20,10.067
40,1.4843
60,0.687
80,0.5323
100,0.3685
120,0.3045
140,0.2735
160,0.2218
180,0.2373
200,0.1992




[OK] tiny overfit finished; loss should drop


In [36]:
# ==== Tiny Sanity Run: Whisper-small + LoRA ====
import numpy as np, torch
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer

# 0) 基础自检：采样率与单声道（若你的 encoded 保存了原始 audio 列，可启用）
if "audio" in encoded["train"].column_names:
    sr = processor.feature_extractor.sampling_rate
    a0 = encoded["train"][0]["audio"]
    assert a0["sampling_rate"] == sr and a0["array"].ndim == 1, "audio必须是16k单声道"
    print(f"[OK] audio at {sr} Hz mono")

# 1) 抽样一个超小训练/验证子集（可调 n_train/n_val）
n_train = min(200, len(encoded["train"]))
n_val   = min(50,  len(encoded["validation"]))
rng = np.random.default_rng(42)
train_idx = rng.choice(len(encoded["train"]), size=n_train, replace=False).tolist()
val_idx   = rng.choice(len(encoded["validation"]), size=n_val, replace=False).tolist()

tiny_ds = {
    "train": encoded["train"].select(train_idx),
    "validation": encoded["validation"].select(val_idx),
}

# 2) collator 输出形状快速检查
batch_probe = data_collator(tiny_ds["validation"].select(range(min(2, len(tiny_ds["validation"])))))
print("[OK] collator keys:", list(batch_probe.keys()))
print("[OK] collator shapes:", {k: tuple(v.shape) for k,v in batch_probe.items()})

# 3) Tiny 训练参数（快速过拟合训练集为目标）
tiny_args = Seq2SeqTrainingArguments(
    output_dir=str(DATA_ROOT/"_tmp_whisper_lora_tiny"),
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    gradient_accumulation_steps=4,
    learning_rate=5e-4,            # LoRA 可用稍大 LR
    num_train_epochs=2,            # 1~2 就能观察趋势
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    weight_decay=0.01,

    eval_strategy="epoch",
    save_strategy="no",            # 不保存大权重
    logging_steps=20,
    report_to="none",

    predict_with_generate=True,    # ⭐ Whisper/Seq2Seq 评测基于 generate
    generation_max_length=128,
    generation_num_beams=1,

    gradient_checkpointing=True,   # 显存紧张就保留；宽裕想提速可关
    bf16=torch.cuda.is_available(),# 你的环境如支持就用
)

tiny_trainer = Seq2SeqTrainer(
    model=model,                         # 你已套 LoRA 的 PeftModel
    args=tiny_args,
    train_dataset=tiny_ds["train"],
    eval_dataset=tiny_ds["validation"],
    processing_class=processor,          # ✅ 代替 tokenizer（官方已弃用 tokenizer 参数）
    data_collator=data_collator,         # 你的 DataCollatorSpeechSeq2SeqWithPadding
    compute_metrics=compute_metrics,     # 你的 WER/CER 函数（用 tokenizer.batch_decode）
    # （可选）也能把早停/最佳适配器保存回调加上，但小样本一般不必
)

print("[Tiny] start training on", len(tiny_ds["train"]), "samples...")
tiny_trainer.train()

# 4) 评估 + 少量解码展示
metrics = tiny_trainer.evaluate()
print("[Tiny] eval metrics:", metrics)

# 单步 generate 验证：注意 Whisper 的输入是 input_features
model.eval()
with torch.no_grad():
    gen_ids = model.generate(
        input_features=batch_probe["input_features"].to(model.device),
        max_length=64
    )
pred_str = processor.tokenizer.batch_decode(gen_ids, skip_special_tokens=True)

lab = batch_probe["labels"].cpu().numpy()
lab = np.where(lab == -100, processor.tokenizer.pad_token_id, lab)
ref_str = processor.tokenizer.batch_decode(lab, skip_special_tokens=True)

for i, (r, h) in enumerate(zip(ref_str, pred_str), 1):
    print(f"[Tiny {i}] REF: {r[:120]}")
    print(f"[Tiny {i}] HYP: {h[:120]}")

print("[Tiny] sanity run finished.")


[OK] collator keys: ['input_features', 'labels']
[OK] collator shapes: {'input_features': (2, 80, 3000), 'labels': (2, 11)}
[Tiny] start training on 200 samples...




Epoch,Training Loss,Validation Loss,Wer,Cer
1,No log,4.447331,1.096774,0.926009
2,No log,4.06696,0.903226,0.932735


[Tiny] eval metrics: {'eval_loss': 4.066959857940674, 'eval_wer': 0.9032258064516129, 'eval_cer': 0.9327354260089686, 'eval_runtime': 8.6994, 'eval_samples_per_second': 5.748, 'eval_steps_per_second': 0.46, 'epoch': 2.0}
[Tiny 1] REF: 整体上的效果是怎么样
[Tiny 1] HYP: YEY
[Tiny 2] REF: JUST HORRIBLE
[Tiny 2] HYP: YEAH
[Tiny] sanity run finished.


## 8) Train

In [None]:
#=======================================Round 1(Whisper + LoRA)=======================================



# 从 Round 2 的“最佳 checkpoint”继续（仍然只加载权重）

from peft import LoraConfig
from transformers import BitsAndBytesConfig,EarlyStoppingCallback
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer

# === 1) 配置 LoRA（对自注意力投影层加 LoRA）===
# *注意*：Wav2Vec2 编码器的注意力层常见投影名为 q_proj/k_proj/v_proj/out_proj
lora_config = LoraConfig(
    r=8, lora_alpha=16, lora_dropout=0.05, bias="none",
    # target_modules=["q_proj", "v_proj"]
    target_modules=["q_proj", "k_proj", "v_proj", "out_proj"]
)

ADAPTER_NAME = "run_LoRA_whisper"
ADAPTER_BEST_DIR = str((DATA_ROOT / "outputs_whisper_lora_adapter_best").resolve())
model.add_adapter(lora_config, adapter_name=ADAPTER_NAME)  # ← 官方推荐调用
model.set_adapter(ADAPTER_NAME)                          # 训练时启用该 adapter

# model.print_trainable_parameters()  # 调试：应显示仅少量参数可训练

from transformers import TrainerCallback
import os, math

class BestAdapterSaver(TrainerCallback):
    def __init__(self, adapter_name: str, out_dir: str, metric_name="eval_cer", greater_is_better=False):
        self.adapter_name = adapter_name
        self.out_dir = out_dir
        self.metric_name = metric_name
        self.sign = 1.0 if greater_is_better else -1.0
        self.best_score = -math.inf  # 用带符号的分数做“越大越好”的统一比较
        os.makedirs(out_dir, exist_ok=True)

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        if metrics is None or self.metric_name not in metrics:
            return
        score = self.sign * metrics[self.metric_name]
        if score > self.best_score:
            self.best_score = score
            # 只保存 LoRA 适配器（不保存全模型）
            adapter_dir = os.path.join(self.out_dir, f"best_adapter_step-{state.global_step}")
            # kwargs["model"].save_adapter(adapter_dir, adapter_name=self.adapter_name)
            kwargs["model"].save_pretrained(adapter_dir)
            # 也可额外写一个“latest-best”软链接/拷贝，方便复用：
            latest_dir = os.path.join(self.out_dir, "best_adapter_latest")
            try:
                if os.path.islink(latest_dir) or os.path.exists(latest_dir):
                    import shutil
                    shutil.rmtree(latest_dir)
                import shutil
                shutil.copytree(adapter_dir, latest_dir)
            except Exception:
                pass
        return


lora_args = Seq2SeqTrainingArguments(          # ← 原来是 TrainingArguments
    output_dir=str(DATA_ROOT/'outputs_whisper_lora'),
    per_device_train_batch_size=128,
    per_device_eval_batch_size=128,
    gradient_accumulation_steps=1,
    learning_rate=5e-4,               # LoRA 常见较大学习率
    weight_decay=0.01,
    num_train_epochs=24.0,
    lr_scheduler_type="cosine",
    warmup_ratio=0.10,

    gradient_checkpointing=False,
    bf16=True,
    dataloader_num_workers=16,
    dataloader_pin_memory=True,
    dataloader_persistent_workers=True,  # 复用 worker，连续迭代更快
    dataloader_prefetch_factor=4,        # 每 worker 预取 4 个 batch

    # 评估/保存策略
    eval_strategy="epoch",
    save_strategy="no",
    save_total_limit=2,

    logging_steps=200,
    report_to="none",
    optim="adamw_torch_fused",

    # ⭐ Seq2Seq 生成相关（在评估/预测时使用 generate）
    predict_with_generate=True,
    generation_max_length=225,        # 也可改成 generation_config 控制
    generation_num_beams=1,           # 如需更稳可设 2~4
    metric_for_best_model="cer",      # 你用 CER 作为指标
)


lora_trainer = Seq2SeqTrainer(        # ← 原来是 Trainer
    model=model,
    args=lora_args,
    train_dataset=encoded["train"],
    eval_dataset=encoded["validation"],
    processing_class=processor,       # ✅ 代替 tokenizer（已弃用）
    data_collator=data_collator,      # 你自定义的 SpeechSeq2Seq collator
    compute_metrics=compute_metrics,
    callbacks=[
        EarlyStoppingCallback(early_stopping_patience=3, early_stopping_threshold=1e-3),
        BestAdapterSaver(adapter_name=ADAPTER_NAME, out_dir=ADAPTER_BEST_DIR,
                         metric_name="eval_cer", greater_is_better=False),
    ],
)

# 与 Round 2 一样：这是“从权重开始新一轮”，不要恢复旧 optimizer/scheduler
lora_trainer.train()

ADAPTER_OUT = str(DATA_ROOT/'outputs_whisper_lora_adapter')
# model.save_adapter(ADAPTER_OUT, adapter_name=ADAPTER_NAME)
model.save_pretrained(ADAPTER_OUT)  # ← 同上，用 save_pretrained

# 之后可在任意同架构模型上：
# model.load_adapter(ADAPTER_OUT, adapter_name="lora_round3")
# model.set_adapter("lora_round3")


Using EarlyStoppingCallback without load_best_model_at_end=True. Once training is finished, the best model will not be loaded automatically.


Epoch,Training Loss,Validation Loss,Wer,Cer
1,No log,5.950615,1.338523,1.266835


In [None]:
# #=======================================Round 1=======================================

# from transformers import TrainingArguments, Trainer, EarlyStoppingCallback

# args = TrainingArguments(
#     output_dir=str(DATA_ROOT/'outputs'),
#     dataloader_num_workers=4,      # >0 才会有多进程加载
#     dataloader_pin_memory=True,    # GPU 传输更快
#     per_device_train_batch_size=4,
#     per_device_eval_batch_size=4,
#     gradient_accumulation_steps=4,      # 有效 batch = 16
#     learning_rate=1e-4,                 # 若抖动大，改 5e-5
#     weight_decay=0.01,
#     warmup_ratio=0.1,
#     num_train_epochs=24.0,
#     lr_scheduler_type="cosine",         # 显式写出，虽然默认就是 linear
#     gradient_checkpointing=True,
#     fp16=torch.cuda.is_available(),
#     group_by_length=False,
#     eval_strategy="epoch",
#     save_strategy="epoch",
#     save_total_limit=2,
#     logging_steps=1000,
#     load_best_model_at_end=True,
#     metric_for_best_model='cer',
#     greater_is_better=False,
#     report_to='none',
#     optim="adamw_bnb_8bit",
# )

# trainer = Trainer(
#     model=model,
#     args=args,
#     train_dataset=encoded['train'],
#     eval_dataset=encoded['validation'],
#     # tokenizer=processor,
#     processing_class=processor,
#     data_collator=data_collator,
#     compute_metrics=compute_metrics,
#     callbacks=[EarlyStoppingCallback(
#         early_stopping_patience=3,               # 连续3次评估无显著提升就停
#         early_stopping_threshold=1e-3            # CER 改善 < 0.001 视为不显著
#     )]
# )
# trainer.train()


In [None]:
# #=======================================Round 2=======================================

# best_ckpt = trainer.state.best_model_checkpoint
# print("best ckpt:", best_ckpt)

# from transformers import Wav2Vec2ForCTC
# model = Wav2Vec2ForCTC.from_pretrained(best_ckpt)   # 仅加载权重
# # 重新构建新的 TrainingArguments / Trainer（优化器会重新初始化）

# new_args = TrainingArguments(
#     output_dir=str(DATA_ROOT/'outputs_v2'),
#     per_device_train_batch_size=4,
#     per_device_eval_batch_size=4,
#     gradient_accumulation_steps=4,        # 有效 batch 仍为 16
#     learning_rate=5e-5,                   # ← 从 1e-4 降一档（后半程更稳）
#     weight_decay=0.01,
#     # 用 epoch 续训 3~5 个；或改成 max_steps 也可
#     num_train_epochs=5.0,
#     lr_scheduler_type="cosine",           # 线性 warmup→线性衰减
#     warmup_ratio=0.1,                     # ← 用比例更方便（约 10% warmup）
#     gradient_checkpointing=True,
#     bf16=True,                            # A100 建议用 BF16（比 FP16 稳）
#     dataloader_num_workers=4,
#     dataloader_pin_memory=True,
#     group_by_length=False,                 # 续训阶段可开启，提升吞吐/省 pad
#     eval_strategy="steps",          # ← 官方参数名是 eval_strategy
#     eval_steps=1000,
#     save_steps=1000,
#     save_total_limit=2,                   # 只保留“最佳+最近”两个 ckpt
#     logging_steps=50,
#     load_best_model_at_end=True,
#     metric_for_best_model='cer',
#     greater_is_better=False,
#     report_to='none',
# )

# new_trainer = Trainer(
#     model=model,
#     args=new_args,
#     train_dataset=encoded["train"],
#     eval_dataset=encoded["validation"],
#     processing_class=processor,
#     data_collator=data_collator,
#     compute_metrics=compute_metrics,
# )

# new_trainer.train()

In [None]:
# !cp -r cs_dialogue/outputs_v2/checkpoint-5000/ '/content/drive/MyDrive/Project2-PRS/models'


In [None]:
# #=======================================Round 3=======================================
# from transformers import Wav2Vec2ForCTC

# CKPT = "/content/drive/MyDrive/Project2-PRS/models/round2-checkpoint-5000"
# model = Wav2Vec2ForCTC.from_pretrained(CKPT)

# # === 3) 训练参数：cosine + warmup 10% + 早停；注意把 eval/save 策略写成官方键名 ===
# from transformers import EarlyStoppingCallback

# new_args_round3 = TrainingArguments(
#     output_dir=str(DATA_ROOT/'outputs_round3'),  # 新输出目录
#     per_device_train_batch_size=4,
#     per_device_eval_batch_size=4,
#     gradient_accumulation_steps=4,               # 有效 batch=16
#     learning_rate=5e-5,                          # 二轮基础上再稳一点
#     weight_decay=0.01,
#     num_train_epochs=12.0,                        # 先续 3~5 个 epoch
#     lr_scheduler_type="cosine",                  # 余弦退火（Trainer 原生）
#     warmup_ratio=0.10,                           # 10% warmup
#     gradient_checkpointing=True,
#     bf16=True,                                   # A100 建议 bf16
#     dataloader_num_workers=4,
#     dataloader_pin_memory=True,

#     eval_strategy="epoch",
#     # eval_steps=1000,
#     save_strategy="epoch",
#     # save_steps=1000,
#     save_total_limit=2,

#     logging_steps=50,
#     load_best_model_at_end=True,
#     metric_for_best_model="cer",
#     greater_is_better=False,
#     report_to="none",

#     #8-bit optimizer
#     optim="adamw_bnb_8bit",
# )

# new_trainer = Trainer(
#     model=model,
#     args=new_args_round3,
#     train_dataset=encoded["train"],
#     eval_dataset=encoded["validation"],
#     processing_class=processor,                  # 你之前已改用 processing_class ✅
#     data_collator=data_collator,
#     compute_metrics=compute_metrics,
#     callbacks=[EarlyStoppingCallback(
#         early_stopping_patience=3,               # 连续3次评估无显著提升就停
#         early_stopping_threshold=1e-3            # CER 改善 < 0.001 视为不显著
#     )]
# )



In [None]:
# new_trainer.train()

In [None]:
# #=======================================Round 3 (continue)=======================================
# # BEST = "/content/cs_dialogue/outputs_round3/checkpoint-14292"
# CKPT = "/content/drive/MyDrive/Project2-PRS/models/checkpoint-14292"


# from transformers import Wav2Vec2ForCTC, TrainingArguments, Trainer, EarlyStoppingCallback
# # 1) 仅加载模型权重（不会带上一次的优化器/调度器状态）
# model = Wav2Vec2ForCTC.from_pretrained(CKPT)

# # 2) 新一轮参数（示例：把 LR 小降；按 epoch 评估/保存，早停照旧）
# round3_coutinue = TrainingArguments(
#     output_dir=str(DATA_ROOT/'outputs_round3_cont'),
#     per_device_train_batch_size=4,
#     per_device_eval_batch_size=4,
#     gradient_accumulation_steps=4,
#     learning_rate=1e-5,                 # ← 比 5e-5 再稳一点
#     weight_decay=0.01,
#     num_train_epochs=12.0,               # 再跑 1–2 个 epoch，交给早停兜底
#     lr_scheduler_type="cosine",
#     warmup_ratio=0,
#     gradient_checkpointing=True,
#     bf16=True,
#     dataloader_num_workers=4,
#     dataloader_pin_memory=True,
#     eval_strategy="epoch",
#     save_strategy="epoch",
#     save_total_limit=2,
#     logging_steps=1000,
#     load_best_model_at_end=True,
#     metric_for_best_model="cer",
#     greater_is_better=False,
#     report_to="none",
#     optim="adamw_bnb_8bit",             # 继续用 8-bit 优化器（已安装 bnb）
# )

# trainer = Trainer(
#     model=model,
#     args=round3_coutinue,
#     train_dataset=encoded["train"],
#     eval_dataset=encoded["validation"],
#     processing_class=processor,
#     data_collator=data_collator,
#     compute_metrics=compute_metrics,
#     callbacks=[EarlyStoppingCallback(early_stopping_patience=3, early_stopping_threshold=1e-3)]
# )

# trainer.train()  # ← 不要传 resume_from_checkpoint


In [None]:
# BEST = "/content/cs_dialogue/outputs_round3/checkpoint-14292"  # 你JSON里给的
# from transformers import Wav2Vec2ForCTC, Trainer, TrainingArguments

# model = Wav2Vec2ForCTC.from_pretrained(BEST)
# args = TrainingArguments(output_dir=str(DATA_ROOT/'final_eval'), report_to='none')
# tester = Trainer(
#     model=model, args=args,
#     eval_dataset=encoded["test"],        # ← 用 test split
#     processing_class=processor,
#     data_collator=data_collator,
#     compute_metrics=compute_metrics,     # 仍计算 CER/WER（主：CER）
# )
# test_metrics = tester.evaluate(metric_key_prefix="test")
# print(test_metrics)  # 建议保存到文件
# # 保存“可部署包”
# tester.save_model(str(DATA_ROOT/'final_best'))
# processor.save_pretrained(str(DATA_ROOT/'final_best'))


In [None]:
# BEST = "/content/cs_dialogue/outputs_round3_cont/checkpoint-8337"  # 你JSON里给的
# from transformers import Wav2Vec2ForCTC, Trainer, TrainingArguments

# model = Wav2Vec2ForCTC.from_pretrained(BEST)
# args = TrainingArguments(output_dir=str(DATA_ROOT/'final_eval_8337'), report_to='none')
# tester = Trainer(
#     model=model, args=args,
#     eval_dataset=encoded["test"],        # ← 用 test split
#     processing_class=processor,
#     data_collator=data_collator,
#     compute_metrics=compute_metrics_beam,     # 仍计算 CER/WER（主：CER）
# )
# test_metrics = tester.evaluate(metric_key_prefix="test")
# print(test_metrics)  # 建议保存到文件
# # 保存“可部署包”
# tester.save_model(str(DATA_ROOT/'final_best_8337'))
# processor.save_pretrained(str(DATA_ROOT/'final_best_8337'))


In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

# BEST = "/content/drive/MyDrive/Project2-PRS/models/checkpoint-8337"  # 你JSON里给的
# from transformers import Wav2Vec2ForCTC, Trainer, TrainingArguments

# model = Wav2Vec2ForCTC.from_pretrained(BEST)
# args = TrainingArguments(output_dir=str(DATA_ROOT/'final_eval_8337_beam'), report_to='none')
# tester = Trainer(
#     model=model, args=args,
#     eval_dataset=encoded["test"],        # ← 用 test split
#     processing_class=processor,
#     data_collator=data_collator,
#     compute_metrics=compute_metrics_beam,     # 仍计算 CER/WER（主：CER）
# )
# test_metrics = tester.evaluate(metric_key_prefix="test_beam")
# print(test_metrics)  # 建议保存到文件
# # 保存“可部署包”
# tester.save_model(str(DATA_ROOT/'final_best_8337_beam'))
# processor.save_pretrained(str(DATA_ROOT/'final_best_8337_beam'))


In [None]:
# #=======================================Round 4(LoRA)=======================================
# from transformers import Wav2Vec2ForCTC


# # 从 Round 2 的“最佳 checkpoint”继续（仍然只加载权重）
# CKPT = "/content/drive/MyDrive/Project2-PRS/models/checkpoint-8337"

# from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, EarlyStoppingCallback
# from peft import LoraConfig
# from transformers import BitsAndBytesConfig

# model = Wav2Vec2ForCTC.from_pretrained(
#     CKPT,
#     # quantization_config=BitsAndBytesConfig(load_in_8bit=True),
#     # device_map="auto",
# )

# # === 1) 配置 LoRA（对自注意力投影层加 LoRA）===
# # *注意*：Wav2Vec2 编码器的注意力层常见投影名为 q_proj/k_proj/v_proj/out_proj
# lora_config = LoraConfig(
#     r=8, lora_alpha=16, lora_dropout=0.05, bias="none",
#     target_modules=["q_proj", "k_proj", "v_proj", "out_proj"]
# )

# ADAPTER_NAME = "lora_wav2vec2_round4"
# ADAPTER_BEST_DIR = str((DATA_ROOT / "outputs_round4_lora_adapter_best").resolve())
# model.add_adapter(lora_config, adapter_name=ADAPTER_NAME)  # ← 官方推荐调用
# model.set_adapter(ADAPTER_NAME)                          # 训练时启用该 adapter

# # model.print_trainable_parameters()  # 调试：应显示仅少量参数可训练

# from transformers import TrainerCallback
# import os, math

# class BestAdapterSaver(TrainerCallback):
#     def __init__(self, adapter_name: str, out_dir: str, metric_name="eval_cer", greater_is_better=False):
#         self.adapter_name = adapter_name
#         self.out_dir = out_dir
#         self.metric_name = metric_name
#         self.sign = 1.0 if greater_is_better else -1.0
#         self.best_score = -math.inf  # 用带符号的分数做“越大越好”的统一比较
#         os.makedirs(out_dir, exist_ok=True)

#     def on_evaluate(self, args, state, control, metrics=None, **kwargs):
#         if metrics is None or self.metric_name not in metrics:
#             return
#         score = self.sign * metrics[self.metric_name]
#         if score > self.best_score:
#             self.best_score = score
#             # 只保存 LoRA 适配器（不保存全模型）
#             adapter_dir = os.path.join(self.out_dir, f"best_adapter_step-{state.global_step}")
#             # kwargs["model"].save_adapter(adapter_dir, adapter_name=self.adapter_name)
#             kwargs["model"].save_pretrained(adapter_dir)
#             # 也可额外写一个“latest-best”软链接/拷贝，方便复用：
#             latest_dir = os.path.join(self.out_dir, "best_adapter_latest")
#             try:
#                 if os.path.islink(latest_dir) or os.path.exists(latest_dir):
#                     import shutil
#                     shutil.rmtree(latest_dir)
#                 import shutil
#                 shutil.copytree(adapter_dir, latest_dir)
#             except Exception:
#                 pass
#         return


# lora_args = TrainingArguments(
#     output_dir=str(DATA_ROOT/'outputs_round4_lora'),
#     per_device_train_batch_size=4,
#     per_device_eval_batch_size=4,
#     gradient_accumulation_steps=4,
#     learning_rate=5e-4,               # LoRA 常用更大的起始 LR（5e-4 ~ 1e-3）
#     weight_decay=0.01,
#     num_train_epochs=7.0,
#     lr_scheduler_type="cosine",
#     warmup_ratio=0.10,
#     gradient_checkpointing=False,
#     bf16=True,
#     dataloader_num_workers=4,
#     dataloader_pin_memory=True,

#     # ★ 改为按 epoch
#     eval_strategy="epoch",
#     save_strategy="no",
#     save_total_limit=2,

#     logging_steps=200,
#     load_best_model_at_end=False,
#     metric_for_best_model="cer",
#     greater_is_better=False,
#     report_to="none",
#     remove_unused_columns=False,
# )


# lora_trainer = Trainer(
#     model=model,                                # LoRA 包裹后的 PeftModel
#     args=lora_args,
#     train_dataset=encoded["train"],
#     eval_dataset=encoded["validation"],
#     processing_class=processor,
#     data_collator=data_collator,
#     compute_metrics=compute_metrics,
#     callbacks=[EarlyStoppingCallback(early_stopping_patience=3, early_stopping_threshold=1e-3),
#                BestAdapterSaver(adapter_name=ADAPTER_NAME,
#                                 out_dir=ADAPTER_BEST_DIR,
#                                 metric_name="eval_cer",
#                                 greater_is_better=False),
#                ]
# )

# # 与 Round 2 一样：这是“从权重开始新一轮”，不要恢复旧 optimizer/scheduler
# lora_trainer.train()

# ADAPTER_OUT = str(DATA_ROOT/'outputs_round4_lora_adapter')
# # model.save_adapter(ADAPTER_OUT, adapter_name=ADAPTER_NAME)
# model.save_pretrained(ADAPTER_OUT)  # ← 同上，用 save_pretrained

# # 之后可在任意同架构模型上：
# # model.load_adapter(ADAPTER_OUT, adapter_name="lora_round3")
# # model.set_adapter("lora_round3")


In [None]:
# BASE = "/content/drive/MyDrive/Project2-PRS/models/checkpoint-8337"
# BEST_ADAPTER_DIR = str((DATA_ROOT / "outputs_round4_lora_adapter_best" / "best_adapter_latest").resolve())

# model = Wav2Vec2ForCTC.from_pretrained(BASE)
# model.load_adapter(BEST_ADAPTER_DIR, adapter_name="lora_best")
# model.set_adapter("lora_best")
# # 然后用 Trainer 做 evaluate / 继续训练即可


In [None]:
# from transformers import Wav2Vec2ForCTC
# from google.colab import drive
# drive.mount('/content/drive')

# #载入你当前最佳权重（全量权重）
# CKPT = "/content/drive/MyDrive/Project2-PRS/models/checkpoint-8337"
# model = Wav2Vec2ForCTC.from_pretrained(CKPT)

# # 给自注意力投影层加 LoRA（按你的 Wav2Vec2 命名可能是 q_proj/k_proj/v_proj/out_proj）
# lora_cfg = LoraConfig(
#     r=8, lora_alpha=16, lora_dropout=0.05, bias="none",
#     target_modules=["q_proj","k_proj","v_proj","out_proj"]
# )
# ADAPTER_NAME = "lora_round4"
# model.add_adapter(lora_cfg, adapter_name=ADAPTER_NAME)   # 官方推荐 API
# model.set_adapter(ADAPTER_NAME)

# lora_args = TrainingArguments(
#     output_dir=str(DATA_ROOT/'outputs_round4_lora'),
#     per_device_train_batch_size=4,
#     per_device_eval_batch_size=4,
#     gradient_accumulation_steps=4,
#     learning_rate=5e-4,                # LoRA 可用更大 LR
#     weight_decay=0.01,
#     num_train_epochs=7.0,
#     lr_scheduler_type="cosine",
#     warmup_ratio=0.10,
#     bf16=True,
#     gradient_checkpointing=True,
#     dataloader_num_workers=4,
#     dataloader_pin_memory=True,
#     evaluation_strategy="epoch",
#     save_strategy="epoch",
#     save_total_limit=2,
#     load_best_model_at_end=True,
#     metric_for_best_model="cer",
#     greater_is_better=False,
#     logging_steps=50,
#     report_to="none",
# )

# lora_trainer = Trainer(
#     model=model,
#     args=lora_args,
#     train_dataset=encoded["train"], eval_dataset=encoded["validation"],
#     processing_class=processor, data_collator=data_collator,
#     compute_metrics=compute_metrics,   # 建议直接用上面“beam版”评估
# )
# lora_trainer.train()

# # 仅保存适配器，便于复用/分享
# model.save_adapter(str(DATA_ROOT/'outputs_round4_lora_adapter'), adapter_name=ADAPTER_NAME)


In [None]:
# # 1) 最终在 test 上做指标评估（会返回 metrics）
# test_metrics = trainer.evaluate(eval_dataset=encoded["test"])
# print("TEST (evaluate):", test_metrics)