# XLS-R (Wav2Vec2) CTC Fine-tuning on BAAI/CS-Dialogue (Mandarin–English Code-Switching)

**Goal:** Fine-tune `facebook/wav2vec2-xls-r-300m` with a custom CTC vocabulary built from CS-Dialogue.

**Key points**
- Uses `datasets.Audio` + `cast_column(..., Audio(16000))`.
- Builds mixed EN (A–Z) + Chinese vocab; maps space→`|`.
- Correct padding & metrics (WER/CER).

References:
- CS-Dialogue dataset card (16kHz, structure): https://huggingface.co/datasets/BAAI/CS-Dialogue
- XLS-R-300M model card (16kHz input): https://huggingface.co/facebook/wav2vec2-xls-r-300m
- Datasets audio processing: https://huggingface.co/docs/datasets/en/audio_process
- WER / CER metrics: https://huggingface.co/spaces/evaluate-metric/wer , https://huggingface.co/spaces/evaluate-metric/cer


In [1]:

# !pip -q install  "datasets[audio]" "evaluate==0.4.3" "jiwer==3.0.4"
!pip -q install   "evaluate==0.4.3" "jiwer==3.0.4"
!pip install "torchcodec==0.7.*" --index-url https://download.pytorch.org/whl/cu126

# import transformers, datasets, evaluate, jiwer, tokenizers, soundfile, torch, torchaudio, huggingface_hub
# print("Transformers  :", transformers.__version__)
# print("Datasets      :", datasets.__version__)
# print("Evaluate      :", evaluate.__version__)
# print("JiWER         :", jiwer.__version__)
# print("Tokenizers    :", tokenizers.__version__)
# print("HF Hub        :", huggingface_hub.__version__)
# print("Torch         :", torch.__version__)
# print("Torchaudio    :", torchaudio.__version__)
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True,max_split_size_mb:512"

from importlib.metadata import version, PackageNotFoundError
import transformers, datasets, evaluate, tokenizers, huggingface_hub, torch, torchaudio, jiwer

def safe_ver(pkg_name):
    try:
        return version(pkg_name)
    except PackageNotFoundError:
        return "(missing)"

print("Transformers  :", safe_ver("transformers"))
print("Datasets      :", safe_ver("datasets"))
print("Evaluate      :", safe_ver("evaluate"))
print("JiWER         :", safe_ver("jiwer"))          # ✅ 改用 metadata 查询
print("Tokenizers    :", safe_ver("tokenizers"))
print("HF Hub        :", safe_ver("huggingface_hub"))
print("Torch         :", torch.__version__)
print("Torchaudio    :", torchaudio.__version__)
print("PYTORCH_CUDA_ALLOC_CONF =", os.environ.get("PYTORCH_CUDA_ALLOC_CONF"))


[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/84.0 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.0/84.0 kB[0m [31m7.7 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/3.2 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.2/3.2 MB[0m [31m111.0 MB/s[0m eta [36m0:00:00[0m
[?25hLooking in indexes: https://download.pytorch.org/whl/cu126
Collecting torchcodec==0.7.*
  Downloading https://download.pytorch.org/whl/cu126/torchcodec-0.7.0%2Bcu126-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (9.4 kB)
Downloading https://download.pytorch.org/whl/cu126/torchcodec-0.7.0%2Bcu126-cp312-cp312-manylinux_2_28_x86_64.whl (1.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.5/1.5 MB[0m [31m66.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torchcodec
Successfully instal

In [None]:
import os
import re
import tarfile
import json
from pathlib import Path
from collections import Counter
from typing import Dict, List
import numpy as np
import torch
from datasets import Dataset, DatasetDict, Audio, Features, Value
from huggingface_hub import hf_hub_download
from transformers import AutoFeatureExtractor, AutoModelForCTC, TrainingArguments, Trainer, Wav2Vec2Processor, Wav2Vec2CTCTokenizer
import evaluate

SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

BASE_DIR = Path('/content') if Path('/content').exists() else Path.cwd()
DATA_ROOT = BASE_DIR / 'cs_dialogue'
AUDIO_DIR = DATA_ROOT / 'short_wav'
INDEX_DIR = DATA_ROOT / 'data' / 'index' / 'short_wav'
VOCAB_DIR = DATA_ROOT / 'custom_vocab'
for p in [DATA_ROOT, AUDIO_DIR, INDEX_DIR, VOCAB_DIR]:
    p.mkdir(parents=True, exist_ok=True)

CKPT = 'facebook/wav2vec2-xls-r-300m'
# increase to 19 for full short_wav
NUM_SHARDS = int(os.environ.get('CS_NUM_SHARDS', 19))

## 1) Download index & audio shards (short_wav)

In [3]:
from huggingface_hub import notebook_login

notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [None]:
from pathlib import Path
import os
import shutil
import tarfile
import glob

REPO_ID = "BAAI/CS-Dialogue"        # 数据集仓库
# 你已有：DATA_ROOT = Path('/content/cs_dialogue'); AUDIO_DIR = DATA_ROOT/'short_wav'; INDEX_DIR = DATA_ROOT/'index'/'short_wav'


def download_index():
    files = [
        "data/index/short_wav/train/text",
        "data/index/short_wav/train/wav.scp",
        "data/index/short_wav/dev/text",
        "data/index/short_wav/dev/wav.scp",
        "data/index/short_wav/test/text",
        "data/index/short_wav/test/wav.scp",
    ]
    local = []
    for fp in files:
        # ★ 关键：直接落到 INDEX_DIR/ 相对目录；禁用 symlink
        dst_dir = INDEX_DIR / \
            Path(fp).parent.relative_to("data/index/short_wav")
        dst_dir.mkdir(parents=True, exist_ok=True)
        src = hf_hub_download(
            repo_id=REPO_ID,
            filename=fp,
            repo_type="dataset",
            local_dir=str(DATA_ROOT),   # 关键：交给 local_dir 来还原相对层级
        )
        local.append(Path(src))
    print("Index ready:", local)
    return local


def download_shards(n=19):
    got = []
    for i in range(n):
        rel = f"data/short_wav/short_wav.tar.gz{i:02d}"
        src = hf_hub_download(
            repo_id=REPO_ID,
            filename=rel,
            repo_type="dataset",
            local_dir=str(DATA_ROOT),   # 落到 DATA_ROOT/rel
        )
        dst = Path(src)
        print("Downloaded:", dst)
        assert dst.exists() and dst.stat(
        ).st_size > 0, f"missing or empty part: {dst}"
        got.append(dst)
    # 关键：分片实际在 DATA_ROOT/data/short_wav 下
    print("In data/short_wav:", sorted(p.name for p in (DATA_ROOT /
          'data'/'short_wav').glob('short_wav.tar.gz*')))
    return got


def concat_parts(parts, out_file: Path):
    """把 *.tar.gz00.. 拼接成一个完整 tar.gz"""
    parts = sorted(parts, key=lambda p: p.name)
    if not parts:
        raise RuntimeError("no parts provided to concat_parts")
    # 逐个断言存在
    for p in parts:
        if not Path(p).exists():
            raise FileNotFoundError(f"part not found on disk: {p}")
    out_file.parent.mkdir(parents=True, exist_ok=True)
    total = 0
    with open(out_file, "wb") as w:
        for p in parts:
            with open(p, "rb") as r:
                shutil.copyfileobj(r, w)
            sz = Path(p).stat().st_size
            total += sz
            print(f"  appended {Path(p).name} ({sz/1e6:.1f} MB)")
    print(f"==> concatenated -> {out_file} (~{total/1e9:.2f} GB)")


def extract_concatenated_tar_gz(out_dir: Path, parts=None):
    """
    解压流程：
    1) 如未给 parts，则自动从 DATA_ROOT 扫描 short_wav.tar.gz[0-9][0-9]
    2) 拼成 short_wav.tar.gz
    3) 正常 tar 解压到 out_dir
    """
    out_dir.mkdir(parents=True, exist_ok=True)

    # ★ 自动发现分片（避免用了旧的 tars 变量）
    if parts is None:
        # 同时在 DATA_ROOT 根 & DATA_ROOT/data/short_wav 下搜索
        candidates = sorted((DATA_ROOT).glob("short_wav.tar.gz[0-9][0-9]")) \
            + sorted((DATA_ROOT / "data" /
                     "short_wav").glob("short_wav.tar.gz[0-9][0-9]"))
    else:
        candidates = list(map(Path, parts))

    if not candidates:
        raise RuntimeError(
            f"No split parts found in {DATA_ROOT}. Expected files like short_wav.tar.gz00")

    merged = DATA_ROOT / "short_wav.tar.gz"

    if merged.exists():
        merged.unlink()  # 清除此前拼接失败的半成品
    concat_parts(candidates, merged)

    # 解压
    with tarfile.open(merged, "r:gz") as tf:
        tf.extractall(out_dir, filter="data")  # 3.12+ 推荐
    print("Extracted ->", out_dir)

    # 可选：节省空间
    try:
        merged.unlink()
    except Exception:
        pass


# === 调用顺序 ===
_ = download_index()
_ = download_shards(NUM_SHARDS)                 # e.g. NUM_SHARDS=2

# ★ 完整性检查：必须 19 片全部到位
expected = [f"short_wav.tar.gz{i:02d}" for i in range(19)]
have = sorted(p.name for p in (DATA_ROOT/'data' /
              'short_wav').glob('short_wav.tar.gz*'))
missing = [x for x in expected if x not in have]
assert not missing, f"缺少分片：{missing}。请把 NUM_SHARDS 设为 19 或补齐后再解压。"

extract_concatenated_tar_gz(AUDIO_DIR, None)   # 让它自己扫描分片再拼接

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


text: 0.00B [00:00, ?B/s]

wav.scp: 0.00B [00:00, ?B/s]

text: 0.00B [00:00, ?B/s]

wav.scp: 0.00B [00:00, ?B/s]

text: 0.00B [00:00, ?B/s]

wav.scp: 0.00B [00:00, ?B/s]

Index ready: [PosixPath('/content/cs_dialogue/data/index/short_wav/train/text'), PosixPath('/content/cs_dialogue/data/index/short_wav/train/wav.scp'), PosixPath('/content/cs_dialogue/data/index/short_wav/dev/text'), PosixPath('/content/cs_dialogue/data/index/short_wav/dev/wav.scp'), PosixPath('/content/cs_dialogue/data/index/short_wav/test/text'), PosixPath('/content/cs_dialogue/data/index/short_wav/test/wav.scp')]


data/short_wav/short_wav.tar.gz00:   0%|          | 0.00/524M [00:00<?, ?B/s]

Downloaded: /content/cs_dialogue/data/short_wav/short_wav.tar.gz00


data/short_wav/short_wav.tar.gz01:   0%|          | 0.00/524M [00:00<?, ?B/s]

Downloaded: /content/cs_dialogue/data/short_wav/short_wav.tar.gz01


data/short_wav/short_wav.tar.gz02:   0%|          | 0.00/524M [00:00<?, ?B/s]

Downloaded: /content/cs_dialogue/data/short_wav/short_wav.tar.gz02


data/short_wav/short_wav.tar.gz03:   0%|          | 0.00/524M [00:00<?, ?B/s]

Downloaded: /content/cs_dialogue/data/short_wav/short_wav.tar.gz03


data/short_wav/short_wav.tar.gz04:   0%|          | 0.00/524M [00:00<?, ?B/s]

Downloaded: /content/cs_dialogue/data/short_wav/short_wav.tar.gz04


data/short_wav/short_wav.tar.gz05:   0%|          | 0.00/524M [00:00<?, ?B/s]

Downloaded: /content/cs_dialogue/data/short_wav/short_wav.tar.gz05


data/short_wav/short_wav.tar.gz06:   0%|          | 0.00/524M [00:00<?, ?B/s]

Downloaded: /content/cs_dialogue/data/short_wav/short_wav.tar.gz06


data/short_wav/short_wav.tar.gz07:   0%|          | 0.00/524M [00:00<?, ?B/s]

Downloaded: /content/cs_dialogue/data/short_wav/short_wav.tar.gz07


data/short_wav/short_wav.tar.gz08:   0%|          | 0.00/524M [00:00<?, ?B/s]

Downloaded: /content/cs_dialogue/data/short_wav/short_wav.tar.gz08


data/short_wav/short_wav.tar.gz09:   0%|          | 0.00/524M [00:00<?, ?B/s]

Downloaded: /content/cs_dialogue/data/short_wav/short_wav.tar.gz09


data/short_wav/short_wav.tar.gz10:   0%|          | 0.00/524M [00:00<?, ?B/s]

Downloaded: /content/cs_dialogue/data/short_wav/short_wav.tar.gz10


data/short_wav/short_wav.tar.gz11:   0%|          | 0.00/524M [00:00<?, ?B/s]

Downloaded: /content/cs_dialogue/data/short_wav/short_wav.tar.gz11


data/short_wav/short_wav.tar.gz12:   0%|          | 0.00/524M [00:00<?, ?B/s]

Downloaded: /content/cs_dialogue/data/short_wav/short_wav.tar.gz12


data/short_wav/short_wav.tar.gz13:   0%|          | 0.00/524M [00:00<?, ?B/s]

Downloaded: /content/cs_dialogue/data/short_wav/short_wav.tar.gz13


data/short_wav/short_wav.tar.gz14:   0%|          | 0.00/524M [00:00<?, ?B/s]

Downloaded: /content/cs_dialogue/data/short_wav/short_wav.tar.gz14


data/short_wav/short_wav.tar.gz15:   0%|          | 0.00/524M [00:00<?, ?B/s]

Downloaded: /content/cs_dialogue/data/short_wav/short_wav.tar.gz15


data/short_wav/short_wav.tar.gz16:   0%|          | 0.00/524M [00:00<?, ?B/s]

Downloaded: /content/cs_dialogue/data/short_wav/short_wav.tar.gz16


data/short_wav/short_wav.tar.gz17:   0%|          | 0.00/524M [00:00<?, ?B/s]

Downloaded: /content/cs_dialogue/data/short_wav/short_wav.tar.gz17


data/short_wav/short_wav.tar.gz18:   0%|          | 0.00/158M [00:00<?, ?B/s]

Downloaded: /content/cs_dialogue/data/short_wav/short_wav.tar.gz18
In data/short_wav: ['short_wav.tar.gz00', 'short_wav.tar.gz01', 'short_wav.tar.gz02', 'short_wav.tar.gz03', 'short_wav.tar.gz04', 'short_wav.tar.gz05', 'short_wav.tar.gz06', 'short_wav.tar.gz07', 'short_wav.tar.gz08', 'short_wav.tar.gz09', 'short_wav.tar.gz10', 'short_wav.tar.gz11', 'short_wav.tar.gz12', 'short_wav.tar.gz13', 'short_wav.tar.gz14', 'short_wav.tar.gz15', 'short_wav.tar.gz16', 'short_wav.tar.gz17', 'short_wav.tar.gz18']
  appended short_wav.tar.gz00 (524.3 MB)
  appended short_wav.tar.gz01 (524.3 MB)
  appended short_wav.tar.gz02 (524.3 MB)
  appended short_wav.tar.gz03 (524.3 MB)
  appended short_wav.tar.gz04 (524.3 MB)
  appended short_wav.tar.gz05 (524.3 MB)
  appended short_wav.tar.gz06 (524.3 MB)
  appended short_wav.tar.gz07 (524.3 MB)
  appended short_wav.tar.gz08 (524.3 MB)
  appended short_wav.tar.gz09 (524.3 MB)
  appended short_wav.tar.gz10 (524.3 MB)
  appended short_wav.tar.gz11 (524.3 MB)
  a

In [None]:
print("INDEX_DIR =", INDEX_DIR)
print("Has train/text?   ->", (INDEX_DIR/'train'/'text').exists())
print("Has train/wav.scp?->", (INDEX_DIR/'train'/'wav.scp').exists())

print("Split parts in DATA_ROOT/data/short_wav:",
      sorted(p.name for p in (DATA_ROOT/'data'/'short_wav').glob('short_wav.tar.gz*'))[:19])

INDEX_DIR = /content/cs_dialogue/data/index/short_wav
Has train/text?   -> True
Has train/wav.scp?-> True
Split parts in DATA_ROOT/data/short_wav: ['short_wav.tar.gz00', 'short_wav.tar.gz01', 'short_wav.tar.gz02', 'short_wav.tar.gz03', 'short_wav.tar.gz04', 'short_wav.tar.gz05', 'short_wav.tar.gz06', 'short_wav.tar.gz07', 'short_wav.tar.gz08', 'short_wav.tar.gz09', 'short_wav.tar.gz10', 'short_wav.tar.gz11', 'short_wav.tar.gz12', 'short_wav.tar.gz13', 'short_wav.tar.gz14', 'short_wav.tar.gz15', 'short_wav.tar.gz16', 'short_wav.tar.gz17', 'short_wav.tar.gz18']


In [None]:
from pathlib import Path


def print_tree(root: Path, max_depth: int = 3):
    root = Path(root)

    def walk(p: Path, depth: int, prefix: str = ""):
        if depth < 0:
            return
        items = sorted(p.iterdir(), key=lambda x: (
            x.is_file(), x.name.lower()))
        for i, it in enumerate(items):
            is_last = (i == len(items) - 1)
            connector = "└── " if is_last else "├── "
            name = it.name + ("/" if it.is_dir() else "")
            print(prefix + connector + name)
            if it.is_dir():
                next_prefix = prefix + ("    " if is_last else "│   ")
                walk(it, depth - 1, next_prefix)

    root = root.resolve()
    print(root.as_posix())
    walk(root, max_depth)


# 使用：打印 DATA_ROOT 下三层
print_tree(DATA_ROOT, max_depth=4)

/content/cs_dialogue
├── .cache/
│   └── huggingface/
│       ├── download/
│       │   └── data/
│       │       ├── index/
│       │       └── short_wav/
│       └── .gitignore
├── custom_vocab/
├── data/
│   ├── index/
│   │   └── short_wav/
│   │       ├── dev/
│   │       │   ├── text
│   │       │   └── wav.scp
│   │       ├── test/
│   │       │   ├── text
│   │       │   └── wav.scp
│   │       └── train/
│   │           ├── text
│   │           └── wav.scp
│   └── short_wav/
│       ├── short_wav.tar.gz00
│       ├── short_wav.tar.gz01
│       ├── short_wav.tar.gz02
│       ├── short_wav.tar.gz03
│       ├── short_wav.tar.gz04
│       ├── short_wav.tar.gz05
│       ├── short_wav.tar.gz06
│       ├── short_wav.tar.gz07
│       ├── short_wav.tar.gz08
│       ├── short_wav.tar.gz09
│       ├── short_wav.tar.gz10
│       ├── short_wav.tar.gz11
│       ├── short_wav.tar.gz12
│       ├── short_wav.tar.gz13
│       ├── short_wav.tar.gz14
│       ├── short_wav.tar.gz15
│       ├── sho

## 2) Build DatasetDict from wav.scp & text

In [None]:
# def read_kv(fp: Path):
#     d={}
#     with open(fp, 'r', encoding='utf-8') as f:
#         for line in f:
#             line=line.strip()
#             if not line: continue
#             k,v=line.split(' ',1)
#             d[k]=v
#     return d

# def make_split(split: str):
#     wavscp = read_kv(INDEX_DIR/split/'wav.scp')
#     text   = read_kv(INDEX_DIR/split/'text')
#     ids, paths, trans = [], [], []
#     for uid, wavpath in wavscp.items():

#         shard = Path(wavpath).parts[-2]
#         fname = Path(wavpath).name
#         local = AUDIO_DIR / shard / fname
#         if local.exists() and uid in text:
#             ids.append(uid); paths.append(str(local)); trans.append(text[uid])
#     feats = Features({'id': Value('string'), 'audio': Audio(sampling_rate=16000), 'transcription': Value('string')})
#     return Dataset.from_dict({'id': ids, 'audio': paths, 'transcription': trans}, features=feats)
from pathlib import Path


def resolve_local_audio_path(wavpath: str) -> Path | None:
    """
    把 wav.scp 的路径映射到本地真实存在的文件。
    兼容如下情况：
    - 路径前缀带不带 'data/'
    - 是否出现双重 'short_wav/short_wav'
    - 绝对路径/相对路径
    """
    p = Path(wavpath.strip())

    # 1) 若本身就是绝对路径且存在，直接返回
    if p.is_absolute() and p.exists():
        return p

    candidates: list[Path] = []

    # 2) wav.scp 一般包含 ".../short_wav/..." 这段；抽取从第一处 'short_wav' 之后的尾部
    parts = p.parts
    if "short_wav" in parts:
        i = parts.index("short_wav")
        tail = Path(*parts[i+1:])  # 去掉第一个 'short_wav' 及之前的前缀
        # 情况 A：你的磁盘上是 short_wav/short_wav/WAVE/...
        if (AUDIO_DIR / "short_wav").exists():
            # /.../short_wav/short_wav/...
            candidates.append(AUDIO_DIR / "short_wav" / tail)
        # 情况 B：只有一层 short_wav/WAVE/...
        # /.../short_wav/...
        candidates.append(AUDIO_DIR / tail)
    else:
        # 没出现 short_wav 关键词时，尝试几种常见组合
        # /content/cs_dialogue/<wav.scp里的相对路径>
        candidates.append(DATA_ROOT / p)
        # /content/cs_dialogue/data/<...>
        candidates.append(DATA_ROOT / "data" / p)

    # 3) 再补充几种保守候选
    candidates.append(DATA_ROOT / p)
    if str(p).startswith("data/"):
        # /content/cs_dialogue/data/short_wav/...
        candidates.append(DATA_ROOT / str(p))
        # /content/cs_dialogue/short_wav/...
        candidates.append(DATA_ROOT / str(p).replace("data/", "", 1))

    for c in candidates:
        if c.exists():
            return c.resolve()
    return None


def read_kv(fp: Path):
    d = {}
    with open(fp, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            # Kaldi风格：<key><space><value...>
            k, v = line.split(' ', 1)
            d[k] = v
    return d


def make_split(split: str):
    wavscp = read_kv(INDEX_DIR/split/'wav.scp')
    text = read_kv(INDEX_DIR/split/'text')

    ids, paths, trans = [], [], []
    miss_audio, miss_text = 0, 0

    for uid, wavpath in wavscp.items():
        local = resolve_local_audio_path(wavpath)
        if local is None or not local.exists():
            miss_audio += 1
            continue
        if uid not in text:
            miss_text += 1
            continue

        ids.append(uid)
        paths.append(str(local))
        trans.append(text[uid])

    print(f"[{split}] matched {len(ids)} items "
          f"(missing audio: {miss_audio}, missing text: {miss_text})")

    feats = Features({
        'id': Value('string'),
        'audio': Audio(sampling_rate=16000),
        'transcription': Value('string')
    })
    return Dataset.from_dict(
        {'id': ids, 'audio': paths, 'transcription': trans},
        features=feats
    )


train_ds = make_split('train')
val_ds = make_split('dev')
test_ds = make_split('test')
minds = DatasetDict(train=train_ds, validation=val_ds, test=test_ds)
minds = minds.cast_column('audio', Audio(sampling_rate=16000))
minds

[train] matched 26239 items (missing audio: 0, missing text: 0)
[dev] matched 6186 items (missing audio: 0, missing text: 0)
[test] matched 6257 items (missing audio: 0, missing text: 0)


DatasetDict({
    train: Dataset({
        features: ['id', 'audio', 'transcription'],
        num_rows: 26239
    })
    validation: Dataset({
        features: ['id', 'audio', 'transcription'],
        num_rows: 6186
    })
    test: Dataset({
        features: ['id', 'audio', 'transcription'],
        num_rows: 6257
    })
})

In [None]:
MIN_SEC, MAX_SEC = 0.3, 12.0


def _keep_ok(ex):
    sec = ex["audio"]["array"].shape[0] / ex["audio"]["sampling_rate"]
    return (sec >= MIN_SEC) and (sec <= MAX_SEC)


minds = minds.filter(_keep_ok, num_proc=4)

Filter (num_proc=4):   0%|          | 0/26239 [00:00<?, ? examples/s]

Filter (num_proc=4):   0%|          | 0/6186 [00:00<?, ? examples/s]

Filter (num_proc=4):   0%|          | 0/6257 [00:00<?, ? examples/s]

## 3) Normalize transcripts (EN upper + Chinese)

In [None]:
CN_RANGE = r"\u4E00-\u9FFF"


def normalize_text(ex):
    t = ex['transcription'].strip().upper()
    t = re.sub(fr"[^{CN_RANGE}A-Z' ]+", "", t)
    t = re.sub(r"\s+", " ", t)
    return {'transcription': t}


minds = minds.map(normalize_text)
minds['train'][0]['transcription'][:120]

Map:   0%|          | 0/19044 [00:00<?, ? examples/s]

Map:   0%|          | 0/4212 [00:00<?, ? examples/s]

Map:   0%|          | 0/4624 [00:00<?, ? examples/s]

'嗨'

## 4) Build CTC vocab (space→`|`)

In [None]:
from collections import Counter


def collect_chars(ds, key='transcription'):
    c = Counter()
    for s in ds[key]:
        c.update(list(s))
    return c


cnt = Counter()
for sp in ['train', 'validation', 'test']:
    cnt.update(collect_chars(minds[sp]))
chars = sorted([ch for ch in cnt if ch != ' '])
vocab = {ch: i for i, ch in enumerate(chars)}
vocab['|'] = len(vocab)
vocab['<unk>'] = len(vocab)
vocab['<pad>'] = len(vocab)
VOCAB_DIR.mkdir(parents=True, exist_ok=True)
with open(VOCAB_DIR/'vocab.json', 'w', encoding='utf-8') as f:
    json.dump(vocab, f, ensure_ascii=False, indent=2)
len(vocab)

2579

## 5) Init tokenizer/processor & XLS-R-300M model

In [None]:
feature_extractor = AutoFeatureExtractor.from_pretrained(
    CKPT, return_attention_mask=True)
tokenizer = Wav2Vec2CTCTokenizer(str(
    VOCAB_DIR/'vocab.json'), unk_token='<unk>', pad_token='<pad>', word_delimiter_token='|')
processor = Wav2Vec2Processor(
    feature_extractor=feature_extractor, tokenizer=tokenizer)

model = AutoModelForCTC.from_pretrained(
    CKPT,
    ctc_loss_reduction='mean',
    mask_time_prob=0.05,    # 轻量时间掩码
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(processor.tokenizer),

)
model.freeze_feature_extractor()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
model.train()
device

preprocessor_config.json:   0%|          | 0.00/212 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

pytorch_model.bin:   0%|          | 0.00/1.27G [00:00<?, ?B/s]

Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-xls-r-300m and are newly initialized: ['lm_head.bias', 'lm_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


device(type='cuda')

## 6) Encode → input_values / attention_mask / labels

In [None]:
def prepare_batch(batch):
    audio = batch['audio']
    ins = processor(
        audio['array'], sampling_rate=audio['sampling_rate'], return_attention_mask=True)
    batch['input_values'] = ins['input_values'][0]
    batch['attention_mask'] = ins['attention_mask'][0]
    with processor.as_target_processor():
        batch['labels'] = processor(batch['transcription']).input_ids
    return batch


encoded = minds.map(
    prepare_batch, remove_columns=minds['train'].column_names, num_proc=1)
encoded

Map:   0%|          | 0/19044 [00:00<?, ? examples/s]



model.safetensors:   0%|          | 0.00/1.27G [00:00<?, ?B/s]

Map:   0%|          | 0/4212 [00:00<?, ? examples/s]

Map:   0%|          | 0/4624 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['input_values', 'attention_mask', 'labels'],
        num_rows: 19044
    })
    validation: Dataset({
        features: ['input_values', 'attention_mask', 'labels'],
        num_rows: 4212
    })
    test: Dataset({
        features: ['input_values', 'attention_mask', 'labels'],
        num_rows: 4624
    })
})

## 7) Collator & Metrics (WER/CER) + Sanity check

In [None]:
from dataclasses import dataclass
from typing import Union


@dataclass
# class DataCollatorCTCWithPadding:
#     processor: Wav2Vec2Processor
#     padding: Union[bool,str]='longest'
#     def __call__(self, features: List[Dict]):
#         inf = [{'input_values': f['input_values']} for f in features]
#         lab = [{'input_ids': f['labels']} for f in features]
#         batch = self.processor.pad(inf, padding=self.padding, return_tensors='pt')
#         with self.processor.as_target_processor():
#             labels_batch = self.processor.pad(lab, padding=self.padding, return_tensors='pt')
#         labels = labels_batch['input_ids'].masked_fill(labels_batch['attention_mask'].ne(1), -100)
#         batch['labels']=labels
#         return batch
class DataCollatorCTCWithPadding:
    processor: Wav2Vec2Processor
    padding: Union[bool, str] = "longest"   # True 等价于 'longest'，写字符串更直观

    def __call__(self, features: List[Dict]):
        # 1) 音频特征
        inputs = [{"input_values": f["input_values"]} for f in features]
        batch = self.processor.pad(
            inputs,
            padding=self.padding,
            return_tensors="pt",
            pad_to_multiple_of=8,
        )
        # 2) 文本 labels
        labels = [{"input_ids": f["labels"]} for f in features]
        with self.processor.as_target_processor():
            labels_batch = self.processor.pad(
                labels,
                padding=self.padding,
                return_tensors="pt",
                pad_to_multiple_of=8,
            )
        # 3) 把 label 的 padding 位改成 -100
        label_mask = labels_batch.get("attention_mask", None)
        label_ids = labels_batch["input_ids"]
        if label_mask is not None:
            label_ids = label_ids.masked_fill(label_mask.ne(1), -100)
        else:
            # 极少数情况下没有 attention_mask，就把 pad_token_id 直接置 -100
            pad_id = self.processor.tokenizer.pad_token_id
            label_ids = label_ids.masked_fill(label_ids.eq(pad_id), -100)

        batch["labels"] = label_ids
        return batch


data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)

wer_metric = evaluate.load('wer')
cer_metric = evaluate.load('cer')


def _norm_str_list(xs):
    out = []
    for s in xs:
        s = s.strip()
        # 若是中文任务，通常不强制 lower；英文可视需要加：s = s.lower()
        s = " ".join(s.split())  # 合并多空格
        out.append(s)
    return out


def compute_metrics(pred):
    pred_ids = np.argmax(pred.predictions, axis=-1)

    # 把 -100 还原为 pad_token_id 才能正常 decode
    label_ids = pred.label_ids.copy()
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id

    pred_str = processor.batch_decode(pred_ids,  skip_special_tokens=True)
    label_str = processor.batch_decode(label_ids, skip_special_tokens=True)

    pred_str = _norm_str_list(pred_str)
    label_str = _norm_str_list(label_str)

    return {
        "wer": wer_metric.compute(predictions=pred_str, references=label_str),
        "cer": cer_metric.compute(predictions=pred_str, references=label_str),
    }
# def compute_metrics(pred):
#     pred_ids = np.argmax(pred.predictions, axis=-1)
#     label_ids = pred.label_ids.copy()
#     label_ids[label_ids == -100] = processor.tokenizer.pad_token_id
#     pred_str  = processor.batch_decode(pred_ids,  skip_special_tokens=True)
#     label_str = processor.batch_decode(label_ids, skip_special_tokens=True)
#     return {'wer': wer_metric.compute(predictions=pred_str, references=label_str),
#             'cer': cer_metric.compute(predictions=pred_str, references=label_str)}

# # sanity check
# sample = encoded['validation'].select(range(min(3, len(encoded['validation']))))
# if len(sample):
#     ins = processor.pad({'input_values': sample['input_values']}, padding=True, return_tensors='pt')
#     with torch.no_grad():
#         lg = model(input_values=ins['input_values'].to(device), attention_mask=ins['attention_mask'].to(device)).logits
#     hyp_ids = lg.argmax(dim=-1).cpu().numpy()
#     lbl_ids=[]
#     for seq in sample['labels']:
#         arr=np.array(seq, dtype=np.int64); arr[arr==-100]=processor.tokenizer.pad_token_id; lbl_ids.append(arr.tolist())
#     hyp = processor.batch_decode(hyp_ids, skip_special_tokens=True)
#     ref = processor.batch_decode(lbl_ids, skip_special_tokens=True)
#     for i,(r,h) in enumerate(zip(ref,hyp),1):
#         print(f'[{i}] REF: {r[:80]}')
#         print(f'[{i}] HYP: {h[:80]}')


# sanity check（确保 model 已 .eval() 且在正确 device）
model.eval()

sample = encoded["validation"].select(
    range(min(3, len(encoded["validation"]))))
if len(sample):
    # 正确的打包方式：列表[{"input_values": ...}, ...]
    feats = [{"input_values": iv} for iv in sample["input_values"]]
    ins = processor.pad(feats, padding=True, return_tensors="pt")

    ins = {k: v.to(device) for k, v in ins.items()}  # 送去同一 device
    with torch.no_grad():
        logits = model(input_values=ins["input_values"],
                       attention_mask=ins.get("attention_mask", None)).logits

    hyp_ids = logits.argmax(dim=-1).cpu().numpy()

    # 处理 labels：把 -100 还原为 pad_token_id 再 decode
    lbl_ids = []
    for seq in sample["labels"]:
        arr = np.array(seq, dtype=np.int64)
        arr[arr == -100] = processor.tokenizer.pad_token_id
        lbl_ids.append(arr.tolist())

    hyp = processor.batch_decode(hyp_ids, skip_special_tokens=True)
    ref = processor.batch_decode(lbl_ids, skip_special_tokens=True)

    for i, (r, h) in enumerate(zip(ref, hyp), 1):
        print(f"[{i}] REF: {r[:80]}")
        print(f"[{i}] HYP: {h[:80]}")

Downloading builder script: 0.00B [00:00, ?B/s]

Downloading builder script: 0.00B [00:00, ?B/s]

[1] REF: 哎你好啊 你好啊你可以叫我玛丽
[1] HYP: 冰展胺坎拖坎拖坎拖坎拖坎拖坎拖坎嫌办坎胺坎胺坎拖坎拖坎胺拖坎拖坎冰拖仰拖坎拖冰拖坎胺坎拖坎缩
[2] REF: 好的
[2] HYP: 究择坎胺坎亭胺坎亭择胺坎亭择胺拖坎择胺择坎拖坎胺坎亭坎胺亭坎亭坎亭坎亭坎亭缩辈究冰傲
[3] REF: 那你大概 嗯你说
[3] HYP: 傲胺坎择坎拖坎拖坎亭胺缩坎拖坎拖坎拖坎拖亭坎胺坎拖坎亭坎胺坎胺坎择坎胺坎胺坎胺坎胺拖坎拖坎拖坎亭办胺坎亭坎胺坎择胺坎胺坎胺坎亭究粮究


In [None]:
vocab_size = processor.tokenizer.vocab_size
in_features = model.lm_head.in_features
old_out = model.lm_head.out_features

if old_out != vocab_size:
    print(f"[Fix] Rebuilding lm_head: {old_out} -> {vocab_size}")
    new_head = torch.nn.Linear(in_features, vocab_size, bias=True)
    # 可选：若 old_out > vocab_size，保留前 vocab_size 行的权重做“部分继承”
    with torch.no_grad():
        if old_out >= vocab_size:
            new_head.weight[:vocab_size].copy_(
                model.lm_head.weight[:vocab_size])
            new_head.bias[:vocab_size].copy_(model.lm_head.bias[:vocab_size])
        else:
            # old_out < vocab_size，拷贝已有部分，其余随机初始化即可
            new_head.weight[:old_out].copy_(model.lm_head.weight)
            new_head.bias[:old_out].copy_(model.lm_head.bias)
    model.lm_head = new_head.to(model.device)
    model.config.vocab_size = vocab_size

# 这些最好也设置好（一次即可）
if processor.tokenizer.pad_token is None and "<pad>" in processor.tokenizer.get_vocab():
    processor.tokenizer.pad_token = "<pad>"
if getattr(model.config, "pad_token_id", None) is None:
    model.config.pad_token_id = processor.tokenizer.pad_token_id
model.config.ctc_zero_infinity = True

# 再次强断言
head_out = model.lm_head.out_features
assert head_out == processor.tokenizer.vocab_size, f"mismatch: head={head_out}, vocab={processor.tokenizer.vocab_size}"
print(
    f"[OK] vocab_size={processor.tokenizer.vocab_size}, lm_head_out={head_out}")

[Fix] Rebuilding lm_head: 2581 -> 2579
[OK] vocab_size=2579, lm_head_out=2579


In [None]:
if processor.tokenizer.pad_token is None and "<pad>" in processor.tokenizer.get_vocab():
    processor.tokenizer.pad_token = "<pad>"
model.config.pad_token_id = processor.tokenizer.pad_token_id
model.config.ctc_zero_infinity = True
print("[OK] pad_token_id & ctc_zero_infinity set")

[OK] pad_token_id & ctc_zero_infinity set


In [None]:

TARGET_SR = getattr(processor.feature_extractor, "sampling_rate", 16000)
x = minds["train"][0]["audio"]
assert x["sampling_rate"] == TARGET_SR and x["array"].ndim == 1
print("[OK] audio resampled to 16k mono")

[OK] audio resampled to 16k mono


In [None]:
feats = [{"input_values": minds["validation"][i]["audio"]["array"]}
         for i in range(2)]
bat = processor.pad(feats, padding=True, return_tensors="pt")
assert bat["input_values"].ndim == 2 and "attention_mask" in bat
print("[OK] collator input structure & mask")

[OK] collator input structure & mask


In [None]:
def clean_zh(s):
    return " ".join(s.strip().split())  # 你的规范化策略


with processor.as_target_processor():
    ids = processor([clean_zh(minds["train"][i]["transcription"])
                    for i in range(2)]).input_ids
assert isinstance(ids[0][0], int)
print("[OK] labels from as_target_processor & normalized")

[OK] labels from as_target_processor & normalized


In [None]:
model.eval()
with torch.no_grad():
    lg = model(**{k: v.to(model.device)
               for k, v in bat.items() if k != "labels"}).logits
hyp_ids = lg.argmax(dim=-1).cpu().numpy()
_ = processor.batch_decode(hyp_ids, skip_special_tokens=True)
print("[OK] greedy decode runs")

[OK] greedy decode runs


In [None]:
blank_id = getattr(model.config, "blank_token_id",
                   processor.tokenizer.vocab_size - 1)
br = (hyp_ids[0] == blank_id).mean()
print(f"[OK] blank_ratio≈{br:.2f}")

[OK] blank_ratio≈0.00


In [None]:
import evaluate
import numpy as np
wer = evaluate.load("wer")
cer = evaluate.load("cer")
lbl = np.array([[processor.tokenizer.pad_token_id if x == -
               100 else x for x in y] for y in ids], dtype=object)
pred = processor.batch_decode(hyp_ids, skip_special_tokens=True)
ref = processor.batch_decode(lbl,     skip_special_tokens=True)
_wer, _cer = wer.compute(predictions=pred, references=ref), cer.compute(
    predictions=pred, references=ref)
print(f"[OK] metrics: WER={_wer:.3f}, CER={_cer:.3f}")

[OK] metrics: WER=1.000, CER=10.214


In [None]:
tiny = encoded["train"].select(range(12))
args = TrainingArguments(output_dir="tmp_overfit", max_steps=300,
                         per_device_train_batch_size=3, learning_rate=3e-4,
                         logging_steps=20, save_steps=10_000, fp16=torch.cuda.is_available(), report_to='none',)
trainer = Trainer(model=model, args=args, train_dataset=tiny, data_collator=data_collator,
                  tokenizer=processor.feature_extractor)
trainer.train()
print("[OK] tiny overfit finished; loss should drop")

  trainer = Trainer(model=model, args=args, train_dataset=tiny, data_collator=data_collator,


Step,Training Loss
20,151.2279
40,75.6841
60,34.9433
80,11.2105
100,5.4711
120,4.8224
140,4.7141
160,4.696
180,4.6581
200,4.6401


[OK] tiny overfit finished; loss should drop


## 8) Train

In [None]:
from transformers import TrainingArguments, Trainer

TOTAL_STEPS = 12000      # 或者用 num_train_epochs=7.5（≈ 12000 / 1640）
WARMUP = 800             # 约 6.5% warmup（500~1200 都可试）
# WARMUP = 1200

args = TrainingArguments(
    output_dir=str(DATA_ROOT/'outputs'),
    dataloader_num_workers=4,      # >0 才会有多进程加载
    dataloader_pin_memory=True,    # GPU 传输更快
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=4,      # 有效 batch = 16
    learning_rate=1e-4,                 # 若抖动大，改 5e-5
    weight_decay=0.01,
    warmup_steps=WARMUP,
    # max_steps=TOTAL_STEPS,              # 或注释掉，用 num_train_epochs=6~8
    num_train_epochs=7.0,
    lr_scheduler_type="linear",         # 显式写出，虽然默认就是 linear
    gradient_checkpointing=True,
    fp16=torch.cuda.is_available(),
    # tf32=True,
    # group_by_length=True,
    group_by_length=False,
    eval_strategy='steps',
    eval_steps=1000,
    save_steps=1000,
    logging_steps=25,
    load_best_model_at_end=True,
    metric_for_best_model='cer',
    greater_is_better=False,
    report_to='none',
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=encoded['train'],
    eval_dataset=encoded['validation'],
    # tokenizer=processor,
    processing_class=processor,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)
trainer.train()



Step,Training Loss,Validation Loss,Wer,Cer
1000,4.5864,4.498214,1.0,0.996347
2000,4.2939,3.894025,1.0,0.992421
3000,3.5688,3.496095,0.999804,0.969878
4000,2.7794,2.394094,0.99915,0.731286
5000,2.1295,1.875605,0.909174,0.4929
6000,1.7799,1.62389,0.839861,0.423194
7000,1.6593,1.473105,0.802197,0.397592
8000,1.5752,1.414354,0.788138,0.382114




TrainOutput(global_step=8337, training_loss=3.044629129128618, metrics={'train_runtime': 7606.7445, 'train_samples_per_second': 17.525, 'train_steps_per_second': 1.096, 'total_flos': 3.316441567386576e+19, 'train_loss': 3.044629129128618, 'epoch': 7.0})

In [None]:
from transformers import Wav2Vec2ForCTC
best_ckpt = trainer.state.best_model_checkpoint
print("best ckpt:", best_ckpt)

model = Wav2Vec2ForCTC.from_pretrained(best_ckpt)   # 仅加载权重
# 重新构建新的 TrainingArguments / Trainer（优化器会重新初始化）

new_args = TrainingArguments(
    output_dir=str(DATA_ROOT/'outputs_v2'),
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=4,        # 有效 batch 仍为 16
    learning_rate=5e-5,                   # ← 从 1e-4 降一档（后半程更稳）
    weight_decay=0.01,
    # 用 epoch 续训 3~5 个；或改成 max_steps 也可
    num_train_epochs=5.0,
    lr_scheduler_type="cosine",           # 线性 warmup→线性衰减
    warmup_ratio=0.1,                     # ← 用比例更方便（约 10% warmup）
    gradient_checkpointing=True,
    bf16=True,                            # A100 建议用 BF16（比 FP16 稳）
    dataloader_num_workers=4,
    dataloader_pin_memory=True,
    group_by_length=False,                 # 续训阶段可开启，提升吞吐/省 pad
    eval_strategy="steps",          # ← 官方参数名是 eval_strategy
    eval_steps=1000,
    save_steps=1000,
    save_total_limit=2,                   # 只保留“最佳+最近”两个 ckpt
    logging_steps=50,
    load_best_model_at_end=True,
    metric_for_best_model='cer',
    greater_is_better=False,
    report_to='none',
)

new_trainer = Trainer(
    model=model,
    args=new_args,
    train_dataset=encoded["train"],
    eval_dataset=encoded["validation"],
    processing_class=processor,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

new_trainer.train()

best ckpt: /content/cs_dialogue/outputs/checkpoint-8000




Step,Training Loss,Validation Loss,Wer,Cer
1000,1.6035,1.381998,0.790689,0.37474
2000,1.4833,1.317766,0.77493,0.35658
3000,1.2319,1.173648,0.702217,0.321002
4000,1.0916,1.133475,0.694108,0.307819
5000,1.0358,1.117216,0.681881,0.298162




TrainOutput(global_step=5955, training_loss=1.3140891110366577, metrics={'train_runtime': 6336.7348, 'train_samples_per_second': 15.027, 'train_steps_per_second': 0.94, 'total_flos': 2.367422348232796e+19, 'train_loss': 1.3140891110366577, 'epoch': 5.0})

In [None]:
# 1) 最终在 test 上做指标评估（会返回 metrics）
test_metrics = trainer.evaluate(eval_dataset=encoded["test"])
print("TEST (evaluate):", test_metrics)

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [45]:
!cp -r cs_dialogue/outputs_v2/checkpoint-5000/ '/content/drive/MyDrive/Project2-PRS/models'
