# XLS-R (Wav2Vec2) CTC Fine-tuning on BAAI/CS-Dialogue (Mandarin–English Code-Switching)

**Goal:** Fine-tune `facebook/wav2vec2-xls-r-300m` with a custom CTC vocabulary built from CS-Dialogue.

**Key points**
- Uses `datasets.Audio` + `cast_column(..., Audio(16000))`.
- Builds mixed EN (A–Z) + Chinese vocab; maps space→`|`.
- Correct padding & metrics (WER/CER).

References:
- CS-Dialogue dataset card (16kHz, structure): https://huggingface.co/datasets/BAAI/CS-Dialogue
- XLS-R-300M model card (16kHz input): https://huggingface.co/facebook/wav2vec2-xls-r-300m
- Datasets audio processing: https://huggingface.co/docs/datasets/en/audio_process
- WER / CER metrics: https://huggingface.co/spaces/evaluate-metric/wer , https://huggingface.co/spaces/evaluate-metric/cer


In [1]:
!pip list

Package                                  Version
---------------------------------------- --------------------
absl-py                                  1.4.0
absolufy-imports                         0.3.1
accelerate                               1.10.1
aiofiles                                 24.1.0
aiohappyeyeballs                         2.6.1
aiohttp                                  3.13.0
aiosignal                                1.4.0
alabaster                                1.0.0
albucore                                 0.0.24
albumentations                           2.0.8
ale-py                                   0.11.2
alembic                                  1.17.0
altair                                   5.5.0
annotated-types                          0.7.0
antlr4-python3-runtime                   4.9.3
anyio                                    4.11.0
anywidget                                0.9.18
argon2-cffi                              25.1.0
argon2-cffi-bindings              

In [5]:

# !pip -q install  "datasets[audio]" "evaluate==0.4.3" "jiwer==3.0.4"
# !pip -q install   "evaluate==0.4.3" "jiwer==3.0.4"

# import transformers, datasets, evaluate, jiwer, tokenizers, soundfile, torch, torchaudio, huggingface_hub
# print("Transformers  :", transformers.__version__)
# print("Datasets      :", datasets.__version__)
# print("Evaluate      :", evaluate.__version__)
# print("JiWER         :", jiwer.__version__)
# print("Tokenizers    :", tokenizers.__version__)
# print("HF Hub        :", huggingface_hub.__version__)
# print("Torch         :", torch.__version__)
# print("Torchaudio    :", torchaudio.__version__)

from importlib.metadata import version, PackageNotFoundError
import transformers, datasets, evaluate, tokenizers, huggingface_hub, torch, torchaudio, jiwer

def safe_ver(pkg_name):
    try:
        return version(pkg_name)
    except PackageNotFoundError:
        return "(missing)"

print("Transformers  :", safe_ver("transformers"))
print("Datasets      :", safe_ver("datasets"))
print("Evaluate      :", safe_ver("evaluate"))
print("JiWER         :", safe_ver("jiwer"))          # ✅ 改用 metadata 查询
print("Tokenizers    :", safe_ver("tokenizers"))
print("HF Hub        :", safe_ver("huggingface_hub"))
print("Torch         :", torch.__version__)
print("Torchaudio    :", torchaudio.__version__)



Transformers  : 4.57.1
Datasets      : 4.0.0
Evaluate      : 0.4.3
JiWER         : 3.0.4
Tokenizers    : 0.22.1
HF Hub        : 0.35.3
Torch         : 2.8.0+cu126
Torchaudio    : 2.8.0+cu126


In [15]:
import os, re, tarfile, json
from pathlib import Path
from collections import Counter
from typing import Dict, List
import numpy as np, torch
from datasets import Dataset, DatasetDict, Audio, Features, Value
from huggingface_hub import hf_hub_download
from transformers import AutoFeatureExtractor, AutoModelForCTC, TrainingArguments, Trainer, Wav2Vec2Processor, Wav2Vec2CTCTokenizer
import evaluate

SEED=42
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available(): torch.cuda.manual_seed_all(SEED)

BASE_DIR = Path('/content') if Path('/content').exists() else Path.cwd()
DATA_ROOT = BASE_DIR / 'cs_dialogue'
AUDIO_DIR = DATA_ROOT / 'short_wav'
INDEX_DIR = DATA_ROOT / 'index' / 'short_wav'
VOCAB_DIR = DATA_ROOT / 'custom_vocab'
for p in [DATA_ROOT, AUDIO_DIR, INDEX_DIR, VOCAB_DIR]: p.mkdir(parents=True, exist_ok=True)

CKPT='facebook/wav2vec2-xls-r-300m'
NUM_SHARDS = int(os.environ.get('CS_NUM_SHARDS', 2))  # increase to 19 for full short_wav


## 1) Download index & audio shards (short_wav)

In [16]:
from huggingface_hub import notebook_login

notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [18]:
from huggingface_hub import hf_hub_download
from pathlib import Path
import os, shutil, tarfile, glob

REPO_ID = "BAAI/CS-Dialogue"        # 数据集仓库
# 你已有：DATA_ROOT = Path('/content/cs_dialogue'); AUDIO_DIR = DATA_ROOT/'short_wav'; INDEX_DIR = DATA_ROOT/'index'/'short_wav'

def download_index():
    files = [
        "data/index/short_wav/train/text",
        "data/index/short_wav/train/wav.scp",
        "data/index/short_wav/dev/text",
        "data/index/short_wav/dev/wav.scp",
        "data/index/short_wav/test/text",
        "data/index/short_wav/test/wav.scp",
    ]
    local = []
    for fp in files:
        # ★ 关键：直接落到 INDEX_DIR/ 相对目录；禁用 symlink
        dst_dir = INDEX_DIR / Path(fp).parent.relative_to("data/index/short_wav")
        dst_dir.mkdir(parents=True, exist_ok=True)
        src = hf_hub_download(
            repo_id=REPO_ID,
            filename=fp,
            repo_type="dataset",
            local_dir=str(dst_dir),
            local_dir_use_symlinks=False,   # ★ 避免产生符号链接
        )
        local.append(Path(src))
    print("Index ready:", local)
    return local


def download_shards(n=2):
    got = []
    for i in range(n):
        rel = f"data/short_wav/short_wav.tar.gz{i:02d}"
        # ★ 关键：直接落到 DATA_ROOT；禁用 symlink
        src = hf_hub_download(
            repo_id=REPO_ID,
            filename=rel,
            repo_type="dataset",
            local_dir=str(DATA_ROOT),
            local_dir_use_symlinks=False,   # ★ 不要 symlink
        )
        dst = Path(src)
        print("Downloaded:", dst)
        assert dst.exists() and dst.stat().st_size > 0, f"missing or empty part: {dst}"
        got.append(dst)
    print("Now in DATA_ROOT:", sorted(p.name for p in DATA_ROOT.glob("short_wav.tar.gz*")))
    return got


def concat_parts(parts, out_file: Path):
    """把 *.tar.gz00.. 拼接成一个完整 tar.gz"""
    parts = sorted(parts, key=lambda p: p.name)
    if not parts:
        raise RuntimeError("no parts provided to concat_parts")
    # 逐个断言存在
    for p in parts:
        if not Path(p).exists():
            raise FileNotFoundError(f"part not found on disk: {p}")
    out_file.parent.mkdir(parents=True, exist_ok=True)
    total = 0
    with open(out_file, "wb") as w:
        for p in parts:
            with open(p, "rb") as r:
                shutil.copyfileobj(r, w)
            sz = Path(p).stat().st_size
            total += sz
            print(f"  appended {Path(p).name} ({sz/1e6:.1f} MB)")
    print(f"==> concatenated -> {out_file} (~{total/1e9:.2f} GB)")

def extract_concatenated_tar_gz(out_dir: Path, parts=None):
    """
    解压流程：
    1) 如未给 parts，则自动从 DATA_ROOT 扫描 short_wav.tar.gz[0-9][0-9]
    2) 拼成 short_wav.tar.gz
    3) 正常 tar 解压到 out_dir
    """
    out_dir.mkdir(parents=True, exist_ok=True)

    # ★ 自动发现分片（避免用了旧的 tars 变量）
    if parts is None:
        candidates = sorted(DATA_ROOT.glob("short_wav.tar.gz[0-9][0-9]"))
    else:
        candidates = list(map(Path, parts))

    if not candidates:
        raise RuntimeError(f"No split parts found in {DATA_ROOT}. Expected files like short_wav.tar.gz00")

    merged = DATA_ROOT / "short_wav.tar.gz"
    concat_parts(candidates, merged)

    # 解压
    with tarfile.open(merged, "r:gz") as tf:
        tf.extractall(out_dir)
    print("Extracted ->", out_dir)

    # 可选：节省空间
    try:
        merged.unlink()
    except Exception:
        pass

# === 调用顺序 ===
_ = download_index()
_ = download_shards(NUM_SHARDS)                 # e.g. NUM_SHARDS=2
extract_concatenated_tar_gz(AUDIO_DIR, None)   # 让它自己扫描分片再拼接


For more details, check out https://huggingface.co/docs/huggingface_hub/main/en/guides/download#download-files-to-local-folder.


text: 0.00B [00:00, ?B/s]

wav.scp: 0.00B [00:00, ?B/s]

text: 0.00B [00:00, ?B/s]

wav.scp: 0.00B [00:00, ?B/s]

text: 0.00B [00:00, ?B/s]

wav.scp: 0.00B [00:00, ?B/s]

Index ready: [PosixPath('/content/cs_dialogue/index/short_wav/train/data/index/short_wav/train/text'), PosixPath('/content/cs_dialogue/index/short_wav/train/data/index/short_wav/train/wav.scp'), PosixPath('/content/cs_dialogue/index/short_wav/dev/data/index/short_wav/dev/text'), PosixPath('/content/cs_dialogue/index/short_wav/dev/data/index/short_wav/dev/wav.scp'), PosixPath('/content/cs_dialogue/index/short_wav/test/data/index/short_wav/test/text'), PosixPath('/content/cs_dialogue/index/short_wav/test/data/index/short_wav/test/wav.scp')]


data/short_wav/short_wav.tar.gz00:   0%|          | 0.00/524M [00:00<?, ?B/s]

Downloaded: /content/cs_dialogue/data/short_wav/short_wav.tar.gz00


data/short_wav/short_wav.tar.gz01:   0%|          | 0.00/524M [00:00<?, ?B/s]

Downloaded: /content/cs_dialogue/data/short_wav/short_wav.tar.gz01
Now in DATA_ROOT: ['short_wav.tar.gz', 'short_wav.tar.gz00', 'short_wav.tar.gz01']


FileNotFoundError: part not found on disk: /content/cs_dialogue/short_wav.tar.gz00

In [19]:
print("INDEX_DIR tree:", list((INDEX_DIR/'train').glob('*')))
print("DATA_ROOT parts:", sorted(p.name for p in DATA_ROOT.glob('short_wav.tar.gz*')))

INDEX_DIR tree: [PosixPath('/content/cs_dialogue/index/short_wav/train/.cache'), PosixPath('/content/cs_dialogue/index/short_wav/train/data'), PosixPath('/content/cs_dialogue/index/short_wav/train/wav.scp'), PosixPath('/content/cs_dialogue/index/short_wav/train/text')]
DATA_ROOT parts: ['short_wav.tar.gz', 'short_wav.tar.gz01']


## 2) Build DatasetDict from wav.scp & text

In [None]:
def read_kv(fp: Path):
    d={}
    with open(fp, 'r', encoding='utf-8') as f:
        for line in f:
            line=line.strip()
            if not line: continue
            k,v=line.split(' ',1)
            d[k]=v
    return d

def make_split(split: str):
    wavscp = read_kv(INDEX_DIR/split/'wav.scp')
    text   = read_kv(INDEX_DIR/split/'text')
    ids, paths, trans = [], [], []
    for uid, wavpath in wavscp.items():
        # expect .../short_wav/{shard}/{file}.wav
        shard = Path(wavpath).parts[-2]
        fname = Path(wavpath).name
        local = AUDIO_DIR / shard / fname
        if local.exists() and uid in text:
            ids.append(uid); paths.append(str(local)); trans.append(text[uid])
    feats = Features({'id': Value('string'), 'audio': Audio(sampling_rate=16000), 'transcription': Value('string')})
    return Dataset.from_dict({'id': ids, 'audio': paths, 'transcription': trans}, features=feats)

train_ds = make_split('train')
val_ds   = make_split('dev')
test_ds  = make_split('test')
minds = DatasetDict(train=train_ds, validation=val_ds, test=test_ds)
minds = minds.cast_column('audio', Audio(sampling_rate=16000))
minds


## 3) Normalize transcripts (EN upper + Chinese)

In [None]:
CN_RANGE = r"\u4E00-\u9FFF"
def normalize_text(ex):
    t = ex['transcription'].strip().upper()
    t = re.sub(fr"[^{CN_RANGE}A-Z' ]+", "", t)
    t = re.sub(r"\s+", " ", t)
    return {'transcription': t}

minds = minds.map(normalize_text)
minds['train'][0]['transcription'][:120]


## 4) Build CTC vocab (space→`|`)

In [None]:
from collections import Counter
def collect_chars(ds, key='transcription'):
    c=Counter()
    for s in ds[key]: c.update(list(s))
    return c
cnt=Counter()
for sp in ['train','validation','test']:
    cnt.update(collect_chars(minds[sp]))
chars=sorted([ch for ch in cnt if ch!=' '])
vocab={ch:i for i,ch in enumerate(chars)}
vocab['|']=len(vocab); vocab['<unk>']=len(vocab); vocab['<pad>']=len(vocab)
VOCAB_DIR.mkdir(parents=True, exist_ok=True)
with open(VOCAB_DIR/'vocab.json','w',encoding='utf-8') as f:
    json.dump(vocab, f, ensure_ascii=False, indent=2)
len(vocab)


## 5) Init tokenizer/processor & XLS-R-300M model

In [None]:
feature_extractor = AutoFeatureExtractor.from_pretrained(CKPT, return_attention_mask=True)
tokenizer = Wav2Vec2CTCTokenizer(str(VOCAB_DIR/'vocab.json'), unk_token='<unk>', pad_token='<pad>', word_delimiter_token='|')
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

model = AutoModelForCTC.from_pretrained(
    CKPT,
    ctc_loss_reduction='mean',
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(processor.tokenizer)
)
device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device); model.train(); device


## 6) Encode → input_values / attention_mask / labels

In [None]:
def prepare_batch(batch):
    audio = batch['audio']
    ins = processor(audio['array'], sampling_rate=audio['sampling_rate'], return_attention_mask=True)
    batch['input_values']   = ins['input_values'][0]
    batch['attention_mask'] = ins['attention_mask'][0]
    with processor.as_target_processor():
        batch['labels'] = processor(batch['transcription']).input_ids
    return batch
encoded = minds.map(prepare_batch, remove_columns=minds['train'].column_names, num_proc=1)
encoded


## 7) Collator & Metrics (WER/CER) + Sanity check

In [None]:
from dataclasses import dataclass
from typing import Union
@dataclass
class DataCollatorCTCWithPadding:
    processor: Wav2Vec2Processor
    padding: Union[bool,str]='longest'
    def __call__(self, features: List[Dict]):
        inf = [{'input_values': f['input_values']} for f in features]
        lab = [{'input_ids': f['labels']} for f in features]
        batch = self.processor.pad(inf, padding=self.padding, return_tensors='pt')
        with self.processor.as_target_processor():
            labels_batch = self.processor.pad(lab, padding=self.padding, return_tensors='pt')
        labels = labels_batch['input_ids'].masked_fill(labels_batch['attention_mask'].ne(1), -100)
        batch['labels']=labels
        return batch
data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)

wer_metric = evaluate.load('wer')
cer_metric = evaluate.load('cer')
def compute_metrics(pred):
    pred_ids = np.argmax(pred.predictions, axis=-1)
    label_ids = pred.label_ids.copy()
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id
    pred_str  = processor.batch_decode(pred_ids,  skip_special_tokens=True)
    label_str = processor.batch_decode(label_ids, skip_special_tokens=True)
    return {'wer': wer_metric.compute(predictions=pred_str, references=label_str),
            'cer': cer_metric.compute(predictions=pred_str, references=label_str)}

# sanity check
sample = encoded['validation'].select(range(min(3, len(encoded['validation']))))
if len(sample):
    ins = processor.pad({'input_values': sample['input_values']}, padding=True, return_tensors='pt')
    with torch.no_grad():
        lg = model(input_values=ins['input_values'].to(device), attention_mask=ins['attention_mask'].to(device)).logits
    hyp_ids = lg.argmax(dim=-1).cpu().numpy()
    lbl_ids=[]
    for seq in sample['labels']:
        arr=np.array(seq, dtype=np.int64); arr[arr==-100]=processor.tokenizer.pad_token_id; lbl_ids.append(arr.tolist())
    hyp = processor.batch_decode(hyp_ids, skip_special_tokens=True)
    ref = processor.batch_decode(lbl_ids, skip_special_tokens=True)
    for i,(r,h) in enumerate(zip(ref,hyp),1):
        print(f'[{i}] REF: {r[:80]}')
        print(f'[{i}] HYP: {h[:80]}')


## 8) Train

In [None]:
from transformers import TrainingArguments, Trainer
args = TrainingArguments(
    output_dir=str(DATA_ROOT/'outputs'),
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=2,
    learning_rate=1e-4,
    warmup_steps=500,
    max_steps=2000,  # increase to 8000-10000 later
    gradient_checkpointing=True,
    fp16=torch.cuda.is_available(),
    group_by_length=True,
    evaluation_strategy='steps',
    eval_steps=200,
    save_steps=1000,
    logging_steps=25,
    load_best_model_at_end=True,
    metric_for_best_model='cer',
    greater_is_better=False,
    report_to='none',
)
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=encoded['train'],
    eval_dataset=encoded['validation'],
    tokenizer=processor,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)
trainer.train()
