# XLS-R (Wav2Vec2) CTC Fine-tuning on BAAI/CS-Dialogue (Mandarin–English Code-Switching)

**Goal:** Fine-tune `facebook/wav2vec2-xls-r-300m` with a custom CTC vocabulary built from CS-Dialogue.

**Key points**
- Uses `datasets.Audio` + `cast_column(..., Audio(16000))`.
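- Builds a mixed character vocabulary (uppercase English letters A–Z plus the apostrophe, and Chinese characters); the space is mapped to the CTC word delimiter `|`.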
- Pads audio inputs and labels separately (label padding is masked with `-100`) and reports WER/CER as metrics.

References:
- CS-Dialogue dataset card (16kHz, structure): https://huggingface.co/datasets/BAAI/CS-Dialogue
- XLS-R-300M model card (16kHz input): https://huggingface.co/facebook/wav2vec2-xls-r-300m
- Datasets audio processing: https://huggingface.co/docs/datasets/en/audio_process
- WER / CER metrics: https://huggingface.co/spaces/evaluate-metric/wer , https://huggingface.co/spaces/evaluate-metric/cer


In [None]:
%%bash
# Install the training stack. transformers, datasets, evaluate and jiwer are pinned to exact
# versions; huggingface_hub has a lower bound only; soundfile and torchaudio are unpinned.
pip -q install "transformers==4.57.1" "datasets[audio]==2.21.0" "evaluate==0.4.2" "jiwer==3.0.4" \
                 "huggingface_hub>=0.24.0" soundfile torchaudio --upgrade
# Print the versions of the three Hugging Face libraries that the interpreter actually imports.
python - << 'PY'
import transformers, datasets, evaluate
print('Transformers:', transformers.__version__)
print('Datasets    :', datasets.__version__)
print('Evaluate    :', evaluate.__version__)
PY


In [None]:
import os, re, tarfile, json
from pathlib import Path
from collections import Counter
from typing import Dict, List
import numpy as np, torch
from datasets import Dataset, DatasetDict, Audio, Features, Value
from huggingface_hub import hf_hub_download
from transformers import AutoFeatureExtractor, AutoModelForCTC, TrainingArguments, Trainer, Wav2Vec2Processor, Wav2Vec2CTCTokenizer
import evaluate

# One fixed seed for NumPy and PyTorch (CPU, and every CUDA device when CUDA is available).
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# Working tree: rooted at /content when that directory exists, otherwise at the current
# working directory. All sub-directories are created up front.
_content_dir = Path('/content')
BASE_DIR = _content_dir if _content_dir.exists() else Path.cwd()
DATA_ROOT = BASE_DIR / 'cs_dialogue'
AUDIO_DIR = DATA_ROOT / 'short_wav'
INDEX_DIR = DATA_ROOT / 'index' / 'short_wav'
VOCAB_DIR = DATA_ROOT / 'custom_vocab'
for directory in (DATA_ROOT, AUDIO_DIR, INDEX_DIR, VOCAB_DIR):
    directory.mkdir(parents=True, exist_ok=True)

# Pretrained checkpoint to fine-tune.
CKPT = 'facebook/wav2vec2-xls-r-300m'
# Number of audio shards to download; the CS_NUM_SHARDS environment variable overrides the default of 2.
NUM_SHARDS = int(os.environ.get('CS_NUM_SHARDS', 2))  # increase to 19 for full short_wav


## 1) Download index & audio shards (short_wav)

In [None]:
REPO_ID = 'BAAI/CS-Dialogue'
def download_index():
    """Download the `text` and `wav.scp` index files of the train/dev/test splits.

    Each file is fetched from the `REPO_ID` dataset repository and copied to
    `INDEX_DIR/<split>/<name>`. Prints the index location when done.

    Raises:
        Whatever `hf_hub_download` raises when a file cannot be fetched, or
        `OSError` when the copy fails.
    """
    import shutil  # local import: only this helper needs it

    for split in ('train', 'dev', 'test'):
        for fname in ('text', 'wav.scp'):
            rel = f'data/index/short_wav/{split}/{fname}'
            # CS-Dialogue lives under huggingface.co/datasets/..., so the repo type must be
            # given explicitly; the default repo type is a model repo.
            cached = hf_hub_download(REPO_ID, rel, repo_type='dataset')
            dst = INDEX_DIR / split / fname
            dst.parent.mkdir(parents=True, exist_ok=True)
            # Copy rather than move: the returned path belongs to the hub cache (and may be a
            # relative symlink), so moving it would break both the cache and the destination.
            shutil.copyfile(cached, dst)
    print('Index ready at', INDEX_DIR)

def download_shards(n=2):
    """Download the first `n` audio shards (`data/short_wav/00.tar.gz`, `01.tar.gz`, ...).

    Each shard is fetched from the `REPO_ID` dataset repository and copied into `DATA_ROOT`.
    A shard that cannot be fetched or copied is reported with a `Skip ...` line and left out.

    Args:
        n: number of shards to try, starting at index 0.

    Returns:
        List of local archive paths (under `DATA_ROOT`) that were obtained successfully.
    """
    import shutil  # local import: only this helper needs it

    paths = []
    for i in range(n):
        rel = f'data/short_wav/{i:02d}.tar.gz'
        dst = DATA_ROOT / Path(rel).name
        try:
            # The repo is a dataset repo, so repo_type must be given explicitly.
            cached = hf_hub_download(REPO_ID, rel, repo_type='dataset')
            # Copy instead of moving the hub-cache entry (which may be a relative symlink).
            shutil.copyfile(cached, dst)
        except Exception as e:  # deliberate best-effort: report and carry on with other shards
            print('Skip', rel, e)
            continue
        paths.append(dst)
        print('Downloaded', dst)
    return paths

def extract_all(tar_paths: List[Path], out_dir: Path):
    """Extract every gzip-compressed tar archive in `tar_paths` into `out_dir`.

    Args:
        tar_paths: archives to unpack, processed in order.
        out_dir: destination directory; created (with parents) if missing.

    Raises:
        tarfile.TarError / OSError if an archive cannot be read or extracted.
    """
    out_dir.mkdir(parents=True, exist_ok=True)
    # The archives come from the network, so use the 'data' extraction filter (blocks absolute
    # paths, path traversal and special files) on interpreters that provide it.
    extract_kwargs = {'filter': 'data'} if hasattr(tarfile, 'data_filter') else {}
    for tp in tar_paths:
        with tarfile.open(tp, 'r:gz') as tf:
            tf.extractall(out_dir, **extract_kwargs)
        print('Extracted', tp, '->', out_dir)

# Fetch the split indexes, then the requested number of audio shards, and unpack them into AUDIO_DIR.
download_index()
tars = download_shards(n=NUM_SHARDS)
extract_all(tar_paths=tars, out_dir=AUDIO_DIR)


## 2) Build DatasetDict from wav.scp & text

In [None]:
def read_kv(fp: Path):
    """Parse a `<key> <value>` text file into a dict.

    Each line is split at its first run of whitespace (spaces or tabs); the remainder of the
    line, with surrounding whitespace removed, is the value. Blank lines and lines that hold
    only a key are skipped. When a key occurs more than once the last occurrence wins.

    Args:
        fp: path of the UTF-8 encoded file to read.

    Returns:
        Dict mapping each key to its value string.
    """
    d = {}
    with open(fp, 'r', encoding='utf-8') as f:
        for line in f:
            fields = line.split(None, 1)
            if len(fields) < 2:
                # Blank line, or a key with no value: nothing usable to store.
                continue
            key, value = fields
            # With maxsplit the last field keeps its trailing whitespace/newline.
            d[key] = value.rstrip()
    return d

def make_split(split: str):
    """Build a `Dataset` for one split from its `wav.scp` and `text` index files.

    An utterance is kept only when its audio file exists locally under
    `AUDIO_DIR/<shard>/<file>` and it has an entry in the `text` index.

    Args:
        split: sub-directory name under `INDEX_DIR` (e.g. 'train', 'dev', 'test').

    Returns:
        Dataset with columns `id`, `audio` (16 kHz `Audio` feature) and `transcription`.
    """
    split_dir = INDEX_DIR / split
    wav_index = read_kv(split_dir / 'wav.scp')
    transcripts = read_kv(split_dir / 'text')

    rows = {'id': [], 'audio': [], 'transcription': []}
    for uid, remote in wav_index.items():
        remote_path = Path(remote)
        # Index paths are expected to end in .../short_wav/{shard}/{file}.wav
        local_path = AUDIO_DIR / remote_path.parts[-2] / remote_path.name
        if not local_path.exists() or uid not in transcripts:
            continue
        rows['id'].append(uid)
        rows['audio'].append(str(local_path))
        rows['transcription'].append(transcripts[uid])

    schema = Features({
        'id': Value('string'),
        'audio': Audio(sampling_rate=16000),
        'transcription': Value('string'),
    })
    return Dataset.from_dict(rows, features=schema)

# Build the three splits (the index calls the validation split 'dev') and bundle them.
train_ds, val_ds, test_ds = (make_split(name) for name in ('train', 'dev', 'test'))
minds = DatasetDict({'train': train_ds, 'validation': val_ds, 'test': test_ds})
# Cast the audio column to a 16 kHz Audio feature on every split.
minds = minds.cast_column('audio', Audio(sampling_rate=16000))
minds


## 3) Normalize transcripts (EN upper + Chinese)

In [None]:
CN_RANGE = r"\u4E00-\u9FFF"
def normalize_text(ex):
    """Normalize one example's transcript for CTC training.

    The transcript is upper-cased, every character other than A-Z, the apostrophe, the space
    and characters in `CN_RANGE` is deleted, and whitespace is collapsed to single spaces with
    none at either end.

    Args:
        ex: mapping with a string under the key 'transcription'.

    Returns:
        Dict with the single key 'transcription' holding the normalized text.
    """
    t = ex['transcription'].upper()
    # Convert tabs, newlines, ideographic spaces etc. to a plain space *before* filtering;
    # otherwise the filter below would delete them and glue neighbouring words together.
    t = re.sub(r"\s+", " ", t)
    t = re.sub(fr"[^{CN_RANGE}A-Z' ]+", "", t)
    # Deleting punctuation can leave doubled or leading/trailing spaces, each of which would
    # otherwise turn into a spurious word-delimiter label.
    t = re.sub(r" +", " ", t).strip()
    return {'transcription': t}

# Run normalize_text over every split; its returned 'transcription' value replaces the column.
minds = minds.map(normalize_text)
# Notebook preview: first 120 characters of the first training transcript.
minds['train'][0]['transcription'][:120]


## 4) Build CTC vocab (space→`|`)

In [None]:
from collections import Counter
def collect_chars(ds, key='transcription'):
    """Return a `Counter` of every character found in the strings of `ds[key]`."""
    return Counter(ch for text in ds[key] for ch in text)
# Gather the character inventory over all three splits.
cnt = Counter()
for sp in ('train', 'validation', 'test'):
    cnt.update(collect_chars(minds[sp]))

# Sorted characters get ids 0..N-1; the space is left out because it is represented by '|'.
chars = sorted(ch for ch in cnt if ch != ' ')
vocab = dict(zip(chars, range(len(chars))))
# Word delimiter, unknown and padding tokens take the next free ids, in that order.
for extra_token in ('|', '<unk>', '<pad>'):
    vocab[extra_token] = len(vocab)

VOCAB_DIR.mkdir(parents=True, exist_ok=True)
with (VOCAB_DIR / 'vocab.json').open('w', encoding='utf-8') as f:
    json.dump(vocab, f, ensure_ascii=False, indent=2)
len(vocab)


## 5) Init tokenizer/processor & XLS-R-300M model

In [None]:
# Feature extractor comes from the pretrained checkpoint; the tokenizer is built from the
# custom vocab.json written earlier. Both are bundled into one processor.
feature_extractor = AutoFeatureExtractor.from_pretrained(CKPT, return_attention_mask=True)
tokenizer = Wav2Vec2CTCTokenizer(
    str(VOCAB_DIR / 'vocab.json'),
    unk_token='<unk>',
    pad_token='<pad>',
    word_delimiter_token='|',
)
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

# CTC model loaded from the checkpoint, with vocab size and pad id taken from the tokenizer.
model = AutoModelForCTC.from_pretrained(
    CKPT,
    vocab_size=len(processor.tokenizer),
    pad_token_id=processor.tokenizer.pad_token_id,
    ctc_loss_reduction='mean',
)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)
model.train()
device


## 6) Encode → input_values / attention_mask / labels

In [None]:
def prepare_batch(batch):
    """Turn one example into model inputs and CTC labels.

    Args:
        batch: a single example with an 'audio' entry (providing 'array' and 'sampling_rate')
            and a 'transcription' string.

    Returns:
        The same mapping with 'input_values', 'attention_mask' and 'labels' added.
    """
    audio = batch['audio']
    ins = processor(audio['array'], sampling_rate=audio['sampling_rate'], return_attention_mask=True)
    # The processor returns batched outputs; index 0 is this single example.
    batch['input_values']   = ins['input_values'][0]
    batch['attention_mask'] = ins['attention_mask'][0]
    # Tokenize the transcript with the CTC tokenizer directly; the as_target_processor()
    # context manager is deprecated in Transformers.
    batch['labels'] = processor.tokenizer(batch['transcription']).input_ids
    return batch
# Run prepare_batch over every split in a single process (num_proc=1); the columns named in
# minds['train'].column_names are removed from the result.
encoded = minds.map(prepare_batch, remove_columns=minds['train'].column_names, num_proc=1)
encoded


## 7) Collator & Metrics (WER/CER) + Sanity check

In [None]:
from dataclasses import dataclass
from typing import Union
@dataclass
class DataCollatorCTCWithPadding:
    """Pad a list of encoded examples into one batch for CTC training.

    Audio inputs and label sequences are padded independently because they have different
    lengths and use different padding values. Padded label positions are set to -100.
    """
    # Processor whose feature extractor pads the audio and whose tokenizer pads the labels.
    processor: Wav2Vec2Processor
    # Padding strategy forwarded to both pad calls.
    padding: Union[bool,str]='longest'
    def __call__(self, features: List[Dict]):
        """Collate `features` (each with 'input_values' and 'labels') into padded tensors.

        Returns:
            The padded audio batch with an added 'labels' tensor.
        """
        inf = [{'input_values': f['input_values']} for f in features]
        lab = [{'input_ids': f['labels']} for f in features]
        batch = self.processor.pad(inf, padding=self.padding, return_tensors='pt')
        # Pad labels with the tokenizer directly; the as_target_processor() context manager
        # is deprecated in Transformers.
        labels_batch = self.processor.tokenizer.pad(lab, padding=self.padding, return_tensors='pt')
        # Replace padded label positions (attention_mask != 1) with -100.
        labels = labels_batch['input_ids'].masked_fill(labels_batch['attention_mask'].ne(1), -100)
        batch['labels']=labels
        return batch
# Collator instance handed to the Trainer; padding=True replaces the class default of 'longest'.
data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)

# Word-error-rate and character-error-rate metrics, loaded through `evaluate`.
wer_metric = evaluate.load('wer')
cer_metric = evaluate.load('cer')
def compute_metrics(pred):
    """Compute WER and CER from an evaluation prediction object.

    Args:
        pred: object exposing `predictions` (per-frame logits; argmax is taken over the last
            axis) and `label_ids` (label ids, with -100 at padded positions).

    Returns:
        Dict with keys 'wer' and 'cer'.
    """
    pred_ids = np.argmax(pred.predictions, axis=-1)
    # Work on a copy so the caller's label array is left untouched.
    label_ids = pred.label_ids.copy()
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id
    # Predictions are raw CTC paths, so repeated tokens are merged while decoding (the default).
    pred_str  = processor.batch_decode(pred_ids,  skip_special_tokens=True)
    # References are already final token sequences: merging repeats would corrupt them
    # (e.g. "谢谢" -> "谢", "HELLO" -> "HELO"), so grouping is switched off.
    label_str = processor.batch_decode(label_ids, skip_special_tokens=True, group_tokens=False)
    return {'wer': wer_metric.compute(predictions=pred_str, references=label_str),
            'cer': cer_metric.compute(predictions=pred_str, references=label_str)}

# sanity check: decode up to three validation examples with the not-yet-fine-tuned model and
# print reference/hypothesis pairs, to confirm the encode -> model -> decode path runs.
sample = encoded['validation'].select(range(min(3, len(encoded['validation']))))
if len(sample):
    ins = processor.pad({'input_values': sample['input_values']}, padding=True, return_tensors='pt')
    # Run inference in eval mode, then restore whatever mode the model was in before.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            lg = model(input_values=ins['input_values'].to(device), attention_mask=ins['attention_mask'].to(device)).logits
    finally:
        model.train(was_training)
    hyp_ids = lg.argmax(dim=-1).cpu().numpy()
    # Map any -100 label entries back to the pad id so they can be decoded.
    pad_id = processor.tokenizer.pad_token_id
    lbl_ids = [[pad_id if tok == -100 else tok for tok in seq] for seq in sample['labels']]
    hyp = processor.batch_decode(hyp_ids, skip_special_tokens=True)
    # References are final token sequences: do not merge repeated tokens when decoding them.
    ref = processor.batch_decode(lbl_ids, skip_special_tokens=True, group_tokens=False)
    for i,(r,h) in enumerate(zip(ref,hyp),1):
        print(f'[{i}] REF: {r[:80]}')
        print(f'[{i}] HYP: {h[:80]}')


## 8) Train

In [None]:
from transformers import TrainingArguments, Trainer
# Training configuration. Checkpoints and logs go to DATA_ROOT/outputs.
args = TrainingArguments(
    output_dir=str(DATA_ROOT/'outputs'),
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=2,
    learning_rate=1e-4,
    warmup_steps=500,
    max_steps=2000,  # increase to 8000-10000 later
    gradient_checkpointing=True,
    fp16=torch.cuda.is_available(),  # mixed precision only when a GPU is present
    group_by_length=True,
    # `eval_strategy` is the current name of this argument; the older `evaluation_strategy`
    # spelling is not accepted by the Transformers release pinned in the install cell.
    eval_strategy='steps',
    eval_steps=200,
    save_steps=1000,  # a multiple of eval_steps, so every saved checkpoint has an eval result
    logging_steps=25,
    load_best_model_at_end=True,
    metric_for_best_model='cer',
    greater_is_better=False,  # lower CER is better
    report_to='none',
)
# Assemble the Trainer from the encoded splits, the CTC collator and the WER/CER metric function.
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=encoded['train'],
    eval_dataset=encoded['validation'],
    # `processing_class` replaces the deprecated `tokenizer` argument of Trainer.
    processing_class=processor,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)
trainer.train()
