# Whisper (Seq2Seq) LoRA Fine-tuning on BAAI/CS-Dialogue (Mandarin–English Code-Switching)

**Goal:** Fine-tune `openai/whisper-small` with a LoRA adapter for the CS-Dialogue short_wav splits. The notebook downloads the dataset locally, prepares Whisper-compatible inputs, and trains a sequence-to-sequence model with generation-based evaluation (WER/CER).


In [None]:
# Install pinned dependencies (IPython shell magics; run inside a notebook).
!pip -q install "evaluate==0.4.3" "jiwer==3.0.4" "soundfile"
!pip install -U "transformers==4.40.1" "datasets>=2.19.0" "accelerate>=0.29.0" "peft==0.17.1"
!pip install -U "bitsandbytes==0.48.1"
#===============================================================================
import os
# Configure the CUDA caching allocator here, before torch is imported further down in this cell.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True,max_split_size_mb:512"
# Optional Hugging Face Hub transfer toggles (disabled):
# os.environ["HF_HUB_DISABLE_XET"] = "1"
# os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

from importlib.metadata import version, PackageNotFoundError
import transformers, datasets, evaluate, tokenizers, huggingface_hub, torch, torchaudio, jiwer, soundfile
import peft
from peft import LoraConfig


def safe_ver(pkg_name):
    """Return the installed version of *pkg_name*, or "(missing)" if it is not installed."""
    try:
        installed = version(pkg_name)
    except PackageNotFoundError:
        installed = "(missing)"
    return installed


# Environment report: one "<label> <version>" line per dependency.
version_report = (
    ("Transformers :", safe_ver("transformers")),
    ("Datasets     :", safe_ver("datasets")),
    ("Evaluate     :", safe_ver("evaluate")),
    ("JiWER        :", safe_ver("jiwer")),
    ("Tokenizers   :", safe_ver("tokenizers")),
    ("HF Hub       :", safe_ver("huggingface_hub")),
    ("SoundFile    :", safe_ver("soundfile")),
    ("Torch        :", torch.__version__),
    ("Torchaudio   :", torchaudio.__version__),
    ("PEFT         :", safe_ver("peft")),
    ("Accelerate   :", safe_ver("accelerate")),
    ("Bitsandbytes :", safe_ver("bitsandbytes")),
)
for label, value in version_report:
    print(label, value)
print("PYTORCH_CUDA_ALLOC_CONF =", os.environ.get("PYTORCH_CUDA_ALLOC_CONF"))

# Smoke test: confirm the installed PEFT accepts a LoRA config for the attention projections.
smoke_lora_kwargs = {
    "r": 8,
    "lora_alpha": 16,
    "lora_dropout": 0.05,
    "bias": "none",
    "target_modules": ["q_proj", "k_proj", "v_proj", "out_proj"],
}
test_cfg = LoraConfig(**smoke_lora_kwargs)
print(test_cfg)


In [None]:
# Run the bitsandbytes module entry point as a shell command (IPython magic).
!python -m bitsandbytes


In [None]:
import os
import re
import tarfile
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List

import numpy as np
import torch
from datasets import Dataset, DatasetDict, Audio, Features, Value
from huggingface_hub import hf_hub_download
from transformers import (
    WhisperProcessor,
    WhisperForConditionalGeneration,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)
import evaluate


@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Pad Whisper features and labels into one training batch.

    `transformers` does not export a class with this name (it only appears in
    the library's example scripts), so it is defined here instead of imported.

    Attributes:
        processor: Object exposing `feature_extractor.pad` and `tokenizer.pad`
            (a `WhisperProcessor` in this notebook).
        decoder_start_token_id: Token id the model prepends to decoder inputs;
            it is stripped from the labels when every row starts with it.
    """

    processor: Any
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        """Collate per-example dicts into a padded batch.

        Each feature must provide `input_features` and `labels`; an optional
        `attention_mask` is forwarded when every feature carries one.
        """
        # Materialise once: callers may pass a dataset slice rather than a list.
        features = list(features)

        # Pad only the spectrograms. The extractor's own mask is disabled
        # because it would be sized by the first axis of each feature rather
        # than by the number of frames.
        batch = self.processor.feature_extractor.pad(
            [{"input_features": f["input_features"]} for f in features],
            return_attention_mask=False,
            return_tensors="pt",
        )
        if all(f.get("attention_mask") is not None for f in features):
            batch["attention_mask"] = torch.tensor(
                np.asarray([f["attention_mask"] for f in features]),
                dtype=torch.long,
            )

        labels_batch = self.processor.tokenizer.pad(
            [{"input_ids": f["labels"]} for f in features],
            return_tensors="pt",
        )
        # Padded label positions are set to -100 so the loss ignores them.
        labels = labels_batch["input_ids"].masked_fill(
            labels_batch["attention_mask"].ne(1), -100
        )
        # The model prepends the start token itself when it builds decoder
        # inputs from labels, so drop a leading copy to avoid doubling it.
        if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch


# Seed NumPy and torch (CPU and, when present, every CUDA device).
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# Work under /content when it exists, otherwise the current directory.
BASE_DIR = Path('/content') if Path('/content').exists() else Path.cwd()
DATA_ROOT = BASE_DIR / 'cs_dialogue'
AUDIO_DIR = DATA_ROOT / 'short_wav'
INDEX_DIR = DATA_ROOT / 'data' / 'index' / 'short_wav'

for p in [DATA_ROOT, AUDIO_DIR, INDEX_DIR]:
    p.mkdir(parents=True, exist_ok=True)

# Number of archive parts to download; the default of 19 is the full
# short_wav set. Override with the CS_NUM_SHARDS environment variable.
NUM_SHARDS = int(os.environ.get('CS_NUM_SHARDS', 19))


## 1) Download index & audio shards (short_wav)

In [None]:
from huggingface_hub import notebook_login

# Interactive Hugging Face Hub login for the downloads that follow.
notebook_login()

In [None]:
from pathlib import Path
import os
import shutil
import tarfile
import glob

REPO_ID = "BAAI/CS-Dialogue"        # Hugging Face dataset repository
# Defined in an earlier cell: DATA_ROOT, AUDIO_DIR = DATA_ROOT / 'short_wav',
# INDEX_DIR = DATA_ROOT / 'data' / 'index' / 'short_wav'.


def download_index():
    """Download the text / wav.scp index files for the train, dev and test splits.

    Files are fetched with `local_dir=DATA_ROOT`, so each one lands at
    DATA_ROOT/<path inside the repo>. Returns the list of local paths.
    """
    index_prefix = "data/index/short_wav"
    wanted = [
        f"{index_prefix}/{split_name}/{file_name}"
        for split_name in ("train", "dev", "test")
        for file_name in ("text", "wav.scp")
    ]

    local = []
    for repo_path in wanted:
        # Pre-create the split directory under INDEX_DIR.
        split_dir = INDEX_DIR / Path(repo_path).parent.relative_to(index_prefix)
        split_dir.mkdir(parents=True, exist_ok=True)
        # local_dir recreates the repo-relative layout below DATA_ROOT.
        fetched = hf_hub_download(
            repo_id=REPO_ID,
            filename=repo_path,
            repo_type="dataset",
            local_dir=str(DATA_ROOT),
        )
        local.append(Path(fetched))

    print("Index ready:", local)
    return local


def download_shards(n=19):
    """Download the first *n* split parts of the short_wav archive.

    Each part is stored at DATA_ROOT/data/short_wav/short_wav.tar.gzNN and is
    asserted to exist and be non-empty. Returns the list of local part paths.
    """
    got = []
    for shard_no in range(n):
        fetched = hf_hub_download(
            repo_id=REPO_ID,
            filename=f"data/short_wav/short_wav.tar.gz{shard_no:02d}",
            repo_type="dataset",
            local_dir=str(DATA_ROOT),   # lands at DATA_ROOT/<repo-relative path>
        )
        dst = Path(fetched)
        print("Downloaded:", dst)
        assert dst.exists() and dst.stat().st_size > 0, \
            f"missing or empty part: {dst}"
        got.append(dst)

    # The parts live under DATA_ROOT/data/short_wav; list what is there.
    shard_dir = DATA_ROOT / 'data' / 'short_wav'
    on_disk = sorted(part.name for part in shard_dir.glob('short_wav.tar.gz*'))
    print("In data/short_wav:", on_disk)
    return got


def concat_parts(parts, out_file: Path):
    """Concatenate split archive parts (*.tar.gz00, *.tar.gz01, ...) into one file.

    Args:
        parts: Iterable of part paths (`str` or `Path`). They are joined in
            order of file name, not in the order given.
        out_file: Destination file; its parent directory is created if needed.

    Raises:
        RuntimeError: If *parts* is empty.
        FileNotFoundError: If any part does not exist on disk.
    """
    # Coerce first so plain string paths can be sorted by file name too.
    ordered = sorted((Path(p) for p in parts), key=lambda p: p.name)
    if not ordered:
        raise RuntimeError("no parts provided to concat_parts")
    # Check every part up front so a gap never yields a half-written archive.
    for part in ordered:
        if not part.exists():
            raise FileNotFoundError(f"part not found on disk: {part}")

    out_file = Path(out_file)
    out_file.parent.mkdir(parents=True, exist_ok=True)
    total = 0
    with open(out_file, "wb") as sink:
        for part in ordered:
            with open(part, "rb") as source:
                shutil.copyfileobj(source, sink)
            size = part.stat().st_size
            total += size
            print(f"  appended {part.name} ({size/1e6:.1f} MB)")
    print(f"==> concatenated -> {out_file} (~{total/1e9:.2f} GB)")


def extract_concatenated_tar_gz(out_dir: Path, parts=None):
    """Join the split archive parts and extract the result into *out_dir*.

    Steps:
    1) If *parts* is None, discover short_wav.tar.gz[0-9][0-9] files in
       DATA_ROOT and in DATA_ROOT/data/short_wav.
    2) Concatenate them into DATA_ROOT/short_wav.tar.gz.
    3) Extract that archive into *out_dir* and delete the merged file.

    Args:
        out_dir: Extraction target; created if missing.
        parts: Optional explicit list of part paths (`str` or `Path`).

    Raises:
        RuntimeError: If no parts are found or given.
    """
    out_dir.mkdir(parents=True, exist_ok=True)

    if parts is None:
        # Discover parts on disk rather than trusting a stale variable.
        discovered = sorted(DATA_ROOT.glob("short_wav.tar.gz[0-9][0-9]")) \
            + sorted((DATA_ROOT / "data" / "short_wav")
                     .glob("short_wav.tar.gz[0-9][0-9]"))
        # A part present in both locations must be joined only once, or the
        # merged archive would contain it twice. Keep the first hit per name.
        unique_by_name = {}
        for part in discovered:
            unique_by_name.setdefault(part.name, part)
        candidates = list(unique_by_name.values())
    else:
        candidates = list(map(Path, parts))

    if not candidates:
        raise RuntimeError(
            f"No split parts found in {DATA_ROOT}. Expected files like short_wav.tar.gz00")

    merged = DATA_ROOT / "short_wav.tar.gz"
    # Remove any half-written archive left by an earlier failed run.
    merged.unlink(missing_ok=True)

    try:
        concat_parts(candidates, merged)
        with tarfile.open(merged, "r:gz") as tf:
            # The "data" filter rejects unsafe members; it needs a Python
            # version whose tarfile supports extraction filters.
            tf.extractall(out_dir, filter="data")
        print("Extracted ->", out_dir)
    finally:
        # Reclaim disk space whether or not extraction succeeded.
        try:
            merged.unlink(missing_ok=True)
        except OSError as err:
            print(f"Warning: could not remove {merged}: {err}")


# === Call order ===
_ = download_index()
_ = download_shards(NUM_SHARDS)                 # NUM_SHARDS comes from the setup cell

# Integrity check: all 19 parts must be on disk before extraction.
expected = [f"short_wav.tar.gz{i:02d}" for i in range(19)]
have = sorted(p.name for p in (DATA_ROOT/'data' /
              'short_wav').glob('short_wav.tar.gz*'))
missing = [x for x in expected if x not in have]
assert not missing, f"缺少分片：{missing}。请把 NUM_SHARDS 设为 19 或补齐后再解压。"

extract_concatenated_tar_gz(AUDIO_DIR, None)   # parts=None: scan for the parts, then join them

In [None]:
# Quick look at what the download step left on disk.
print("INDEX_DIR =", INDEX_DIR)
for label, file_name in (("Has train/text?   ->", 'text'),
                         ("Has train/wav.scp?->", 'wav.scp')):
    print(label, (INDEX_DIR / 'train' / file_name).exists())

part_names = sorted(
    part.name for part in (DATA_ROOT / 'data' / 'short_wav').glob('short_wav.tar.gz*')
)
print("Split parts in DATA_ROOT/data/short_wav:", part_names[:19])

In [None]:
from pathlib import Path


def print_tree(root: Path, max_depth: int = 3):
    """Print an ASCII tree of *root*.

    The resolved root path is printed first. Entries are listed directories
    first, then case-insensitively by name. A directory is expanded while the
    remaining depth budget is non-negative, so `max_depth=0` lists only the
    immediate children of *root*.
    """

    def render(folder: Path, remaining: int, indent: str = ""):
        if remaining < 0:
            return
        entries = sorted(
            folder.iterdir(),
            key=lambda entry: (entry.is_file(), entry.name.lower()),
        )
        last_index = len(entries) - 1
        for position, entry in enumerate(entries):
            at_end = position == last_index
            is_folder = entry.is_dir()
            branch = "└── " if at_end else "├── "
            print(indent + branch + entry.name + ("/" if is_folder else ""))
            if is_folder:
                render(entry, remaining - 1,
                       indent + ("    " if at_end else "│   "))

    top = Path(root).resolve()
    print(top.as_posix())
    render(top, max_depth)


# Usage: print the directory tree under DATA_ROOT (max_depth=4).
print_tree(DATA_ROOT, max_depth=4)

## 2) Build DatasetDict from wav.scp & text

In [None]:
# def read_kv(fp: Path):
#     d={}
#     with open(fp, 'r', encoding='utf-8') as f:
#         for line in f:
#             line=line.strip()
#             if not line: continue
#             k,v=line.split(' ',1)
#             d[k]=v
#     return d

# def make_split(split: str):
#     wavscp = read_kv(INDEX_DIR/split/'wav.scp')
#     text   = read_kv(INDEX_DIR/split/'text')
#     ids, paths, trans = [], [], []
#     for uid, wavpath in wavscp.items():

#         shard = Path(wavpath).parts[-2]
#         fname = Path(wavpath).name
#         local = AUDIO_DIR / shard / fname
#         if local.exists() and uid in text:
#             ids.append(uid); paths.append(str(local)); trans.append(text[uid])
#     feats = Features({'id': Value('string'), 'audio': Audio(sampling_rate=16000), 'transcription': Value('string')})
#     return Dataset.from_dict({'id': ids, 'audio': paths, 'transcription': trans}, features=feats)
from pathlib import Path


def resolve_local_audio_path(wavpath: str) -> Path | None:
    """
    把 wav.scp 的路径映射到本地真实存在的文件。
    兼容如下情况：
    - 路径前缀带不带 'data/'
    - 是否出现双重 'short_wav/short_wav'
    - 绝对路径/相对路径
    """
    p = Path(wavpath.strip())

    # 1) 若本身就是绝对路径且存在，直接返回
    if p.is_absolute() and p.exists():
        return p

    candidates: list[Path] = []

    # 2) wav.scp 一般包含 ".../short_wav/..." 这段；抽取从第一处 'short_wav' 之后的尾部
    parts = p.parts
    if "short_wav" in parts:
        i = parts.index("short_wav")
        tail = Path(*parts[i+1:])  # 去掉第一个 'short_wav' 及之前的前缀
        # 情况 A：你的磁盘上是 short_wav/short_wav/WAVE/...
        if (AUDIO_DIR / "short_wav").exists():
            # /.../short_wav/short_wav/...
            candidates.append(AUDIO_DIR / "short_wav" / tail)
        # 情况 B：只有一层 short_wav/WAVE/...
        # /.../short_wav/...
        candidates.append(AUDIO_DIR / tail)
    else:
        # 没出现 short_wav 关键词时，尝试几种常见组合
        # /content/cs_dialogue/<wav.scp里的相对路径>
        candidates.append(DATA_ROOT / p)
        # /content/cs_dialogue/data/<...>
        candidates.append(DATA_ROOT / "data" / p)

    # 3) 再补充几种保守候选
    candidates.append(DATA_ROOT / p)
    if str(p).startswith("data/"):
        # /content/cs_dialogue/data/short_wav/...
        candidates.append(DATA_ROOT / str(p))
        # /content/cs_dialogue/short_wav/...
        candidates.append(DATA_ROOT / str(p).replace("data/", "", 1))

    for c in candidates:
        if c.exists():
            return c.resolve()
    return None


def read_kv(fp: Path):
    """Read a Kaldi-style `<key> <value...>` file into a dict.

    Key and value may be separated by any run of whitespace (spaces or tabs).
    Blank lines are ignored. Lines that carry a key but no value are skipped
    and counted, so a missing transcript or path never becomes an empty entry.
    When a key repeats, the last occurrence wins.

    Args:
        fp: Path of the UTF-8 text file to read.

    Returns:
        Mapping from key to the remainder of the line, with surrounding
        whitespace removed.
    """
    entries = {}
    skipped = 0
    with open(fp, 'r', encoding='utf-8') as f:
        for line in f:
            fields = line.split(maxsplit=1)
            if not fields:
                continue  # blank line
            if len(fields) < 2:
                skipped += 1  # key without a value
                continue
            key, value = fields
            entries[key] = value.strip()
    if skipped:
        print(f"[read_kv] {fp}: skipped {skipped} line(s) without a value")
    return entries


def make_split(split: str):
    """Build a `Dataset` with id / audio / transcription columns for one split.

    Reads INDEX_DIR/<split>/wav.scp and INDEX_DIR/<split>/text. An utterance
    is kept only if its audio resolves to an existing local file and it has a
    transcript; the number dropped for each reason is printed.
    """
    split_dir = INDEX_DIR / split
    scp_entries = read_kv(split_dir / 'wav.scp')
    transcripts = read_kv(split_dir / 'text')

    rows = []
    miss_audio = miss_text = 0
    for uid, wavpath in scp_entries.items():
        local = resolve_local_audio_path(wavpath)
        if local is None or not local.exists():
            miss_audio += 1
        elif uid not in transcripts:
            miss_text += 1
        else:
            rows.append((uid, str(local), transcripts[uid]))

    print(f"[{split}] matched {len(rows)} items "
          f"(missing audio: {miss_audio}, missing text: {miss_text})")

    columns = {
        'id': [row[0] for row in rows],
        'audio': [row[1] for row in rows],
        'transcription': [row[2] for row in rows],
    }
    schema = Features({
        'id': Value('string'),
        'audio': Audio(sampling_rate=16000),
        'transcription': Value('string'),
    })
    return Dataset.from_dict(columns, features=schema)


# Build the three splits; the 'dev' index is exposed as 'validation'.
train_ds, val_ds, test_ds = (make_split(name) for name in ('train', 'dev', 'test'))
minds = DatasetDict(
    {'train': train_ds, 'validation': val_ds, 'test': test_ds}
).cast_column('audio', Audio(sampling_rate=16000))
minds

In [None]:
MIN_SEC, MAX_SEC = 0.3, 12.0


def _keep_ok(ex):
    sec = ex["audio"]["array"].shape[0] / ex["audio"]["sampling_rate"]
    return (sec >= MIN_SEC) and (sec <= MAX_SEC)


# Keep only the examples accepted by _keep_ok, using 4 worker processes.
minds = minds.filter(_keep_ok, num_proc=4)

## 3) Normalize transcripts (uppercase English, keep Chinese characters)

In [None]:
# Unicode range of the CJK characters kept by normalize_text.
CN_RANGE = r"\u4E00-\u9FFF"


def normalize_text(ex):
    """Normalize one transcript: uppercase, strip symbols, tidy whitespace.

    Only characters in CN_RANGE, A-Z, apostrophes and spaces survive.
    Whitespace is turned into single spaces before filtering so that tabs and
    newlines separate words rather than being deleted. The result is stripped
    last, because removing punctuation can expose spaces at either end.

    Args:
        ex: Mapping with a 'transcription' string.

    Returns:
        Dict with the normalized 'transcription'.
    """
    t = ex['transcription'].upper()
    # Every whitespace run (tab, newline, full-width space) -> one space.
    t = re.sub(r"\s+", " ", t)
    t = re.sub(fr"[^{CN_RANGE}A-Z' ]+", "", t)
    # Deleting symbols can leave doubled or dangling spaces behind.
    t = re.sub(r" +", " ", t).strip()
    return {'transcription': t}


# Apply normalize_text to every split, then show the start of the first training transcript.
minds = minds.map(normalize_text)
minds['train'][0]['transcription'][:120]

## 4) Load Whisper processor & base model


In [None]:
# Base checkpoint used for the processor and the model.
CKPT = "openai/whisper-small"

# No fixed language is set; the task is transcription.
processor = WhisperProcessor.from_pretrained(
    CKPT,
    language=None,
    task="transcribe",
)
# Ask the feature extractor to return an attention mask by default.
processor.feature_extractor.return_attention_mask = True

model = WhisperForConditionalGeneration.from_pretrained(CKPT)
# Clear forced decoder ids and token suppression on both the model config
# and the generation config.
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []
model.generation_config.forced_decoder_ids = None
model.generation_config.suppress_tokens = []
# Use the EOS token as the pad token everywhere (tokenizer, config, generation config).
processor.tokenizer.pad_token_id = processor.tokenizer.eos_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
model.generation_config.pad_token_id = processor.tokenizer.pad_token_id

# Move to GPU when available, turn the decoder cache off and enter training mode.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
model.config.use_cache = False
model.train()
device


## 5) Encode Whisper input features & decoder labels


In [None]:
def prepare_batch(batch):
    """Turn one example into Whisper model inputs and decoder labels.

    Adds to the example:
    - 'input_features': the processor's features for the audio,
    - 'attention_mask': the processor's mask, when it returns one,
    - 'labels': token ids of the transcription.

    The tokenizer is called directly rather than through the deprecated
    `processor.as_target_processor()` context manager, which routed the call
    to the same tokenizer and warned on every use.
    """
    audio = batch['audio']
    inputs = processor(
        audio['array'],
        sampling_rate=audio['sampling_rate'],
        return_attention_mask=True,
    )
    # The processor returns a batch of one; unwrap it.
    batch['input_features'] = inputs['input_features'][0]
    if 'attention_mask' in inputs:
        batch['attention_mask'] = inputs['attention_mask'][0]
    batch['labels'] = processor.tokenizer(batch['transcription'])['input_ids']
    return batch


# Encode every split with prepare_batch in 4 worker processes. The original
# columns are removed, so only what prepare_batch adds remains.
encoded = minds.map(
    prepare_batch,
    remove_columns=minds['train'].column_names,
    num_proc=4,
    desc='Preparing Whisper inputs',
)
encoded


## 6) Collator & metrics (WER/CER) + sanity checks


In [None]:
# Padding collator shared by the sanity checks and both trainers.
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor,
    decoder_start_token_id=model.config.decoder_start_token_id,
)

# Word and character error-rate metrics from the `evaluate` library.
wer_metric = evaluate.load('wer')
cer_metric = evaluate.load('cer')


def _norm_text(s: str) -> str:
    s = s.strip()
    return ' '.join(s.split())


def compute_metrics(pred):
    """Compute WER and CER from generated ids and label ids.

    Args:
        pred: Prediction object with `predictions` (ids, or a tuple whose
            first item holds the ids) and `label_ids`.

    Returns:
        Dict with 'wer' and 'cer'. Pairs whose reference is empty after
        decoding are left out, because the jiwer backend rejects empty
        references and would abort the evaluation. If no pair remains, both
        values are NaN.
    """
    pad_id = processor.tokenizer.pad_token_id

    pred_ids = pred.predictions
    if isinstance(pred_ids, tuple):
        pred_ids = pred_ids[0]

    # -100 is padding/ignore filler, not a token id; map it to the pad token
    # before decoding (predictions can carry it too after batch gathering).
    pred_ids = np.asarray(pred_ids)
    pred_ids = np.where(pred_ids == -100, pad_id, pred_ids)
    label_ids = np.asarray(pred.label_ids)
    label_ids = np.where(label_ids == -100, pad_id, label_ids)

    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.batch_decode(label_ids, skip_special_tokens=True)

    pairs = [
        (_norm_text(hyp), _norm_text(ref))
        for hyp, ref in zip(pred_str, label_str)
    ]
    scored = [(hyp, ref) for hyp, ref in pairs if ref]
    if not scored:
        return {'wer': float('nan'), 'cer': float('nan')}

    hyps = [hyp for hyp, _ in scored]
    refs = [ref for _, ref in scored]
    return {
        'wer': wer_metric.compute(predictions=hyps, references=refs),
        'cer': cer_metric.compute(predictions=hyps, references=refs),
    }


### Quick sanity checks


In [None]:
# 1) The decoded audio must be at the feature extractor's sampling rate and one-dimensional.
TARGET_SR = processor.feature_extractor.sampling_rate
sample_audio = minds['train'][0]['audio']
assert sample_audio['sampling_rate'] == TARGET_SR and sample_audio['array'].ndim == 1
print(f"[OK] audio resampled to {TARGET_SR} Hz mono")

# 2) The collator must turn up to two encoded validation examples into tensors.
sample_feats = encoded['validation'].select(range(min(2, len(encoded['validation']))))
batch = data_collator(sample_feats)
print('[OK] collator output shapes:', {k: tuple(v.shape) for k, v in batch.items()})

# 3) Up to two training transcripts must tokenize into a padded id array.
with processor.as_target_processor():
    ids = processor(
        [minds['train'][i]['transcription'] for i in range(min(2, len(minds['train'])))],
        padding=True,
        return_tensors='np',
    ).input_ids
print('[OK] target ids shape:', ids.shape)

# 4) Generation must run end to end on the collated batch.
#    NOTE: this leaves the model in eval mode.
model.eval()
with torch.no_grad():
    gen_kwargs = {'max_length': 225}
    attn_mask = batch.get('attention_mask')
    generated_ids = model.generate(
        input_features=batch['input_features'].to(device),
        attention_mask=attn_mask.to(device) if attn_mask is not None else None,
        **gen_kwargs,
    )

# Replace the -100 label filler with the pad token so the labels can be decoded.
label_ids = batch['labels'].cpu().numpy()
label_ids = np.where(label_ids == -100, processor.tokenizer.pad_token_id, label_ids)

pred_str = processor.batch_decode(generated_ids, skip_special_tokens=True)
ref_str = processor.batch_decode(label_ids, skip_special_tokens=True)

# Print the first 80 characters of each reference / hypothesis pair.
for i, (ref, hyp) in enumerate(zip(ref_str, pred_str), 1):
    print(f"[{i}] REF: {ref[:80]}")
    print(f"[{i}] HYP: {hyp[:80]}")


In [None]:
# (cell intentionally left blank after migrating to Whisper)


In [None]:
# (cell intentionally left blank after migrating to Whisper)


In [None]:
# (cell intentionally left blank after migrating to Whisper)


In [None]:
# (cell intentionally left blank after migrating to Whisper)


In [None]:
# (cell intentionally left blank after migrating to Whisper)


In [None]:
# (cell intentionally left blank after migrating to Whisper)


In [None]:
# (cell intentionally left blank after migrating to Whisper)


In [None]:
# (cell intentionally left blank after migrating to Whisper)


In [None]:
# (cell intentionally left blank after migrating to Whisper)


In [None]:
import copy
import gc

from transformers import Trainer, TrainingArguments

# Smoke test: overfit a handful of training examples to confirm that the
# data, collator and loss path work end to end.
#
# The check runs on a throwaway deep copy of the model. It updates every
# weight, so training `model` itself would corrupt the base weights before
# LoRA is applied, and a trained adapter would then no longer match a fresh
# copy of the checkpoint at merge time.
tiny = encoded["train"].select(range(min(12, len(encoded["train"]))))
overfit_model = copy.deepcopy(model)
args = TrainingArguments(output_dir="tmp_overfit", max_steps=300,
                         per_device_train_batch_size=3, learning_rate=3e-4,
                         logging_steps=20, save_steps=10_000, fp16=torch.cuda.is_available(), report_to='none',)
trainer = Trainer(model=overfit_model, args=args, train_dataset=tiny, data_collator=data_collator,
                  tokenizer=processor.feature_extractor)
trainer.train()
print("[OK] tiny overfit finished; loss should drop")

# Drop the copy and its optimizer state so the memory is free for the real run.
del trainer, overfit_model
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

## 7) LoRA fine-tuning setup


In [None]:
from peft import LoraConfig, get_peft_model

# No `task_type` is set on purpose. With 'SEQ_2_SEQ_LM' PEFT wraps the model in
# a class whose forward() passes `input_ids=...` to the base model, and
# Whisper's forward() takes `input_features`, so training fails with a
# TypeError. The generic wrapper forwards keyword arguments unchanged.
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.1,
    bias='none',
    target_modules=['q_proj', 'k_proj', 'v_proj', 'out_proj'],
)


def _make_inputs_require_grad(module, inputs, output):
    """Forward hook: mark the module output as requiring grad."""
    output.requires_grad_(True)


# With gradient checkpointing and frozen base weights, each checkpointed stack
# needs an input that requires grad, otherwise its adapters get no gradients.
# enable_input_require_grads() hooks the token embeddings; the encoder consumes
# audio features through a convolution, so its first conv is hooked too.
model.enable_input_require_grads()
model.get_encoder().conv1.register_forward_hook(_make_inputs_require_grad)
model = get_peft_model(model, lora_config)
model.gradient_checkpointing_enable()
model.print_trainable_parameters()


In [None]:
OUTPUT_DIR = (DATA_ROOT / 'outputs_whisper_small_lora').as_posix()
training_args = Seq2SeqTrainingArguments(
    output_dir=OUTPUT_DIR,
    report_to=['none'],
    # Batching and memory
    per_device_train_batch_size=4,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=2,
    gradient_checkpointing=True,
    fp16=torch.cuda.is_available(),
    # Optimisation schedule
    num_train_epochs=3,
    learning_rate=1e-4,
    warmup_steps=500,
    label_smoothing_factor=0.1,
    # Logging, evaluation and checkpoints (evaluate and save on the same step grid)
    logging_steps=50,
    evaluation_strategy='steps',
    eval_steps=500,
    save_steps=500,
    save_total_limit=2,
    # Generation-based evaluation
    predict_with_generate=True,
    generation_max_length=225,
    # Best-checkpoint selection: lower WER is better
    load_best_model_at_end=True,
    metric_for_best_model='wer',
    greater_is_better=False,
)
training_args


In [None]:
# Trainer for the LoRA run: trains on the encoded train split, evaluates on
# validation, and scores with compute_metrics.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=encoded['train'],
    eval_dataset=encoded['validation'],
    data_collator=data_collator,
    tokenizer=processor,
    compute_metrics=compute_metrics,
)
trainer


In [None]:
# Run training, then save the trainer state.
train_result = trainer.train()
trainer.save_state()
train_result


In [None]:
# Evaluate on the trainer's eval_dataset (the validation split).
val_metrics = trainer.evaluate()
print('Validation metrics:', val_metrics)


In [None]:
# Evaluate on the held-out test split.
test_metrics = trainer.evaluate(eval_dataset=encoded['test'])
print('Test metrics:', test_metrics)


In [None]:
# Save the trained adapter and the processor next to the checkpoints.
run_dir = Path(training_args.output_dir)
adapter_dir = run_dir / 'lora_adapter'
adapter_dir.mkdir(parents=True, exist_ok=True)
trainer.model.save_pretrained(adapter_dir)
processor.save_pretrained(run_dir / 'processor')
print('Saved LoRA adapter to', adapter_dir)


In [None]:
# Decode up to three validation examples and print them beside their references.
pred_samples = encoded['validation'].select(range(min(3, len(encoded['validation']))))
preds = trainer.predict(pred_samples, max_length=225)
pred_ids = preds.predictions
if isinstance(pred_ids, tuple):
    pred_ids = pred_ids[0]
# Replace the -100 label filler with the pad token so the labels can be decoded.
label_ids = np.array(preds.label_ids, copy=True)
label_ids[label_ids == -100] = processor.tokenizer.pad_token_id

pred_texts = processor.batch_decode(pred_ids, skip_special_tokens=True)
ref_texts = processor.batch_decode(label_ids, skip_special_tokens=True)

for i, (ref, hyp) in enumerate(zip(ref_texts, pred_texts), 1):
    print(f'[{i}] REF: {ref}')
    print(f'[{i}] HYP: {hyp}')


In [None]:
# Optional: merge LoRA adapter into a standalone Whisper checkpoint
from peft import PeftModel

merge_dir = Path(training_args.output_dir) / 'whisper_small_lora_merged'
merge_dir.mkdir(parents=True, exist_ok=True)

# Reload the base checkpoint from CKPT, attach the saved adapter, then fold
# the adapter weights into the base model.
base_model = WhisperForConditionalGeneration.from_pretrained(CKPT)
peft_model = PeftModel.from_pretrained(base_model, adapter_dir)
merged = peft_model.merge_and_unload()
# Save the merged model and the processor side by side.
merged.save_pretrained(merge_dir)
processor.save_pretrained(merge_dir)
print('Merged model saved to', merge_dir)


In [None]:
# (cell intentionally left blank after migrating to Whisper)


In [None]:
# (cell intentionally left blank after migrating to Whisper)


In [None]:
# (cell intentionally left blank after migrating to Whisper)


In [None]:
# (cell intentionally left blank after migrating to Whisper)


In [None]:
# (cell intentionally left blank after migrating to Whisper)


In [None]:
# (cell intentionally left blank after migrating to Whisper)
