# YouTube 多说话人字幕系统（Kaggle GPU）

> 本 Notebook 为端到端版本，按 13 个阶段组织。
> 每个阶段前有 Markdown 说明，便于逐格调试与定位问题。


## 1. 参数区

- 定义输入参数与输出路径
- 统一缓存与产出到 `/kaggle/working/`
- 读取 `HF_TOKEN`（建议来自 Kaggle Secrets）


In [None]:
print("[Cell 1] 参数区初始化")
from pathlib import Path
import os
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()

# --- User-tunable inputs ---------------------------------------------------
YOUTUBE_URL = ""
DOWNLOAD_MODE = "single"  # either "single" or "all"
PLAYLIST_INDEX = 1
TARGET_REF_AUDIO_PATH = "/kaggle/working/biao.mp3"
TARGET_REF_AUDIO_URL = "https://raw.githubusercontent.com/Hana19951208/youtube-speaker-diarization/master/biao.mp3"

# The token is optional: without it the notebook is meant to continue in
# single-speaker mode (see the warning printed below). The secret lookup may
# raise (e.g. secret not attached to the notebook) or yield an empty value, so
# guard it and normalise to "" — later code calls ``.strip()`` on this value.
try:
    HF_TOKEN = user_secrets.get_secret("HF_TOKEN") or ""
except Exception as exc:  # best-effort: a missing secret must not abort the run
    print(f"⚠️ 读取 Kaggle Secret HF_TOKEN 失败: {exc}")
    HF_TOKEN = ""

ENABLE_DEMUCS = True
WHISPER_MODEL = "large-v3"
LANGUAGE = None
NUM_SPEAKERS = None
TARGET_SIM_THRESHOLD = 0.3

# --- Working directories (everything lives under /kaggle/working) ----------
BASE_WORKDIR = Path("/kaggle/working")
RAW_AUDIO_DIR = BASE_WORKDIR / "raw_audio"
PROCESSED_AUDIO_DIR = BASE_WORKDIR / "processed_audio"
DEMUCS_OUTPUT_DIR = BASE_WORKDIR / "demucs_output"
VOCALS_DIR = BASE_WORKDIR / "vocals_audio"
OUTPUT_DIR = BASE_WORKDIR / "outputs"
MODEL_CACHE_DIR = BASE_WORKDIR / "model_cache"
TARGET_REF_DIR = BASE_WORKDIR / "target_ref_processed"

for p in [
    RAW_AUDIO_DIR,
    PROCESSED_AUDIO_DIR,
    DEMUCS_OUTPUT_DIR,
    VOCALS_DIR,
    OUTPUT_DIR,
    MODEL_CACHE_DIR,
    TARGET_REF_DIR,
]:
    p.mkdir(parents=True, exist_ok=True)

# Redirect the model/library caches into the working directory.
os.environ["HF_HOME"] = str(MODEL_CACHE_DIR / "hf_home")
os.environ["TRANSFORMERS_CACHE"] = str(MODEL_CACHE_DIR / "transformers")
os.environ["TORCH_HOME"] = str(MODEL_CACHE_DIR / "torch")
os.environ["XDG_CACHE_HOME"] = str(MODEL_CACHE_DIR / "xdg")

print("参数初始化完成")
print(f"HF_TOKEN 已设置: {bool(HF_TOKEN)}")
if not HF_TOKEN:
    print("⚠️ HF_TOKEN 为空：pyannote 将回退为单说话人模式")
print(f"TARGET_REF_AUDIO_PATH: {TARGET_REF_AUDIO_PATH}")
print(f"TARGET_REF_AUDIO_URL: {TARGET_REF_AUDIO_URL}")


## 2. GPU 与网络检测

- 检查 CUDA 可用性
- 通过尝试连接 `huggingface.co:443` 间接检查 Kaggle Internet 是否已开启
- 提前提示高风险运行条件


In [None]:
print("[Cell 2] GPU/网络环境检测")
import socket
import platform
import torch

def check_internet(host="huggingface.co", port=443, timeout=5):
    """Return True if a TCP connection to ``host:port`` can be opened.

    Args:
        host: Host name or address to probe.
        port: TCP port to probe.
        timeout: Connection timeout in seconds.

    Returns:
        True when the connection succeeds, False on any failure.
    """
    try:
        # The context manager closes the probe socket immediately; the
        # previous version left it open until garbage collection.
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except Exception:
        # Best-effort probe: any failure simply means "not reachable".
        return False

# Probe connectivity once and choose the compute device for the session.
INTERNET_AVAILABLE = check_internet()
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Environment summary.
print("=== 环境检测 ===")
print(f"Platform: {platform.platform()}")
print(f"Torch CUDA Available: {torch.cuda.is_available()}")
print(f"DEVICE: {DEVICE}")
if DEVICE == "cuda":
    # GPU details are only meaningful when CUDA is usable.
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"CUDA (torch): {torch.version.cuda}")

if not INTERNET_AVAILABLE:
    print("⚠️ 检测到 Kaggle Internet 可能未开启；请到 Notebook Settings 打开 Internet")


## 3. 依赖安装（锁版本 + 兼容修复）

- **先强制重装科学计算栈**，修复 `numpy/scipy/sklearn` ABI 混装问题
- 再安装 ASR/diarization 栈
- 优先保留 Kaggle 预置 torch，不强制重装


In [None]:
print("[Cell 3] 依赖安装（含 Python 3.12 兼容修复）")
import sys
import subprocess
import traceback
from pathlib import Path

# Key safeguard: reinstalling numpy/scipy/sklearn repeatedly within the same
# session easily leaves a mixed-ABI state, so refuse to continue when any of
# them has already been imported into this interpreter.
# NOTE(review): Cell 2 runs `import torch` before this cell; if torch imports
# numpy at import time, this guard would trip on a normal top-to-bottom run —
# confirm on Kaggle.
preloaded = [m for m in ("numpy", "scipy", "sklearn") if m in sys.modules]
if preloaded:
    raise RuntimeError(
        f"检测到已加载模块 {preloaded}。请先 Restart Session，再从 Cell 1 -> Cell 3 顺序重跑。"
    )


def run_cmd_capture(cmd):
    """Echo ``cmd`` prefixed with ``$``, run it, and return the result.

    Output is captured as text; a non-zero exit status does not raise.
    """
    echoed = " ".join(cmd)
    print("$", echoed)
    completed = subprocess.run(cmd, capture_output=True, text=True)
    return completed


def pip_install(pkgs, force_reinstall=False, constraints_file=None, no_deps=False):
    """Install ``pkgs`` with pip in the current interpreter.

    Args:
        pkgs: Requirement strings appended to the pip command line.
        force_reinstall: Add ``--force-reinstall`` when true.
        constraints_file: Optional constraints file passed via ``-c``.
        no_deps: Add ``--no-deps`` when true.

    Raises:
        RuntimeError: If pip exits with a non-zero status.
    """
    base_args = [
        sys.executable,
        "-m",
        "pip",
        "install",
        "--no-cache-dir",
        "--upgrade-strategy",
        "only-if-needed",
    ]

    # Optional flags keep the same relative order as they appear on the CLI.
    optional_args = []
    if force_reinstall:
        optional_args.append("--force-reinstall")
    if constraints_file:
        optional_args += ["-c", str(constraints_file)]
    if no_deps:
        optional_args.append("--no-deps")

    outcome = run_cmd_capture([*base_args, *optional_args, *pkgs])
    if outcome.returncode == 0:
        # Only the tail of pip's output is shown to keep the cell readable.
        print(outcome.stdout[-1800:])
        return

    print("stderr tail:", outcome.stderr[-8000:])
    raise RuntimeError("pip install failed")


def soft_pip_check():
    """Run ``pip check`` and report conflicts as a warning, never an error."""
    outcome = run_cmd_capture([sys.executable, "-m", "pip", "check"])
    if outcome.returncode == 0:
        print("✅ pip check 通过")
        return

    print("⚠️ pip check 有冲突（Kaggle/基础镜像预装包常见），仅告警不失败。")
    # Show the tail of stdout first, then stderr.
    for stream_text in (outcome.stdout, outcome.stderr):
        print((stream_text or "")[-3000:])


# Record the torch/torchaudio versions present before any installation so the
# before/after comparison at the end of the cell can reveal an accidental swap.
# NOTE(review): importing torch here happens before the numeric stack is
# force-reinstalled below; if torch loads numpy at import time, the later
# `import numpy` checks may see the previously loaded module — confirm.
torch_before = None
torchaudio_before = None
try:
    import torch
    torch_before = torch.__version__
except Exception:
    pass
try:
    import torchaudio
    torchaudio_before = torchaudio.__version__
except Exception:
    pass

# Constraints file: pin the key binary stack so the resolver cannot upgrade to
# incompatible versions.
constraints_path = Path('/kaggle/working/constraints_whisperx_py312.txt')
constraints_text = """numpy==2.1.3
scipy==1.14.1
scikit-learn==1.5.2
transformers==4.46.3
tokenizers==0.20.3
accelerate==0.34.2
huggingface-hub==0.36.2
faster-whisper==1.0.0
ctranslate2==4.4.0
"""
constraints_path.write_text(constraints_text, encoding='utf-8')
print(f"写入 constraints: {constraints_path}")

# First unify the numeric stack (NumPy 2.1.x, avoiding NumPy 2.4.x chain issues).
numeric_stack = ["numpy==2.1.3", "scipy==1.14.1", "scikit-learn==1.5.2"]
pip_install(numeric_stack, force_reinstall=True)

# Plan A: normal dependency resolution under the constraints file.
plan_a = [
    "yt-dlp==2025.2.19",
    "ffmpeg-python==0.2.0",
    "demucs==4.0.1",
    "faster-whisper==1.0.0",
    "ctranslate2==4.4.0",
    "whisperx==3.2.0",
    "pyannote.audio==3.1.1",
    "speechbrain==0.5.16",
    "soundfile==0.12.1",
    "transformers==4.46.3",
    "accelerate==0.34.2",
]

# Plan B: install the dependencies first, then install whisperx with --no-deps.
# This list is plan A without the whisperx entry.
plan_b_deps = [
    "yt-dlp==2025.2.19",
    "ffmpeg-python==0.2.0",
    "demucs==4.0.1",
    "faster-whisper==1.0.0",
    "ctranslate2==4.4.0",
    "pyannote.audio==3.1.1",
    "speechbrain==0.5.16",
    "soundfile==0.12.1",
    "transformers==4.46.3",
    "accelerate==0.34.2",
]

# Set to True by whichever plan installs and passes the import checks.
installed = False

try:
    print("\n=== 尝试方案 A（constraints + 正常解析）===")
    pip_install(plan_a, constraints_file=constraints_path)

    # Import verification: a failing import counts as a failed plan.
    import numpy  # noqa: F401
    import scipy  # noqa: F401
    import sklearn  # noqa: F401
    import transformers  # noqa: F401
    import whisperx  # noqa: F401
    import pyannote.audio  # noqa: F401

    print("✅ 方案 A 安装并导入验证成功")
    installed = True
except Exception as e:
    print(f"❌ 方案 A 失败: {e}")
    traceback.print_exc()

if not installed:
    try:
        print("\n=== 尝试方案 B（constraints + whisperx no-deps）===")
        pip_install(plan_b_deps, constraints_file=constraints_path)
        pip_install(["whisperx==3.2.0"], constraints_file=constraints_path, no_deps=True)

        # Same import verification as plan A.
        import numpy  # noqa: F401
        import scipy  # noqa: F401
        import sklearn  # noqa: F401
        import transformers  # noqa: F401
        import whisperx  # noqa: F401
        import pyannote.audio  # noqa: F401

        print("✅ 方案 B 安装并导入验证成功")
        installed = True
    except Exception as e:
        print(f"❌ 方案 B 失败: {e}")
        traceback.print_exc()

# Both plans failed: stop the notebook here rather than in a later cell.
if not installed:
    raise RuntimeError("所有依赖方案都失败，请检查 Kaggle Internet 与 pip 日志")

# Conflicts reported by `pip check` are only warned about, never fatal.
soft_pip_check()

# Report torch/torchaudio versions before and after the installation.
try:
    import torch
    print(f"torch(before): {torch_before}, torch(after): {torch.__version__}")
except Exception:
    print("⚠️ torch 导入失败")

try:
    import torchaudio
    print(f"torchaudio(before): {torchaudio_before}, torchaudio(after): {torchaudio.__version__}")
except Exception:
    print("⚠️ torchaudio 导入失败")



## 4. 依赖验证（含 whisperx smoke test）

- 打印关键库版本
- 检查 ffmpeg
- 立即验证 `import whisperx`，尽早失败、尽早定位


In [None]:
print("[Cell 4] 依赖验证与 whisperx smoke test")
import importlib
import importlib.metadata as md
import subprocess

def pkg_ver(name):
    """Return the installed version of distribution ``name``.

    Returns the string ``"NOT_FOUND"`` when the lookup fails for any reason.
    """
    try:
        version = md.version(name)
    except Exception:
        version = "NOT_FOUND"
    return version

# Print the installed version of every key distribution.
for name in (
    "torch", "torchaudio", "numpy", "scipy", "scikit-learn",
    "transformers", "yt-dlp", "demucs", "whisperx", "pyannote.audio", "speechbrain", "soundfile",
):
    print(f"{name}: {pkg_ver(name)}")

# ffmpeg must be runnable; show only the first line of its version banner.
ffmpeg_check = subprocess.run(["ffmpeg", "-version"], capture_output=True, text=True)
if ffmpeg_check.returncode != 0:
    raise RuntimeError("ffmpeg 不可用")
print((ffmpeg_check.stdout.splitlines() or ["ffmpeg ok"])[0])

# Smoke test now so that import problems surface here rather than deep inside
# a later functional cell.
try:
    import numpy  # noqa: F401
    import scipy  # noqa: F401
    import sklearn  # noqa: F401
    import transformers  # noqa: F401
    import whisperx  # noqa: F401
except Exception:
    print("❌ whisperx 导入失败，错误如下：")
    raise
else:
    print("✅ whisperx 及依赖导入成功")


## 5. 下载模块

- 用 `yt-dlp` 下载 YouTube 音频
- 自动识别单视频 / playlist
- 支持 `single` 指定集数与 `all` 全量模式


In [None]:
print("[Cell 5] 定义下载模块")
import os
import re
import time
import yt_dlp
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple


def sanitize_filename(name: str) -> str:
    """Turn ``name`` into a file-system friendly string of at most 120 chars.

    Surrounding whitespace is stripped and every run of characters other than
    word characters, ``-`` and ``.`` is collapsed into a single underscore.
    """
    cleaned = re.sub(r"[^\w\-.]+", "_", name.strip(), flags=re.UNICODE)
    # Slicing is a no-op for strings that are already short enough.
    return cleaned[:120]


def detect_playlist(url: str) -> bool:
    """Return True when yt-dlp's metadata for ``url`` describes a playlist.

    Only metadata is extracted (no download). The URL counts as a playlist
    when the result has ``_type == "playlist"`` or contains an ``entries`` key.
    """
    probe_opts = dict(
        quiet=True,
        skip_download=True,
        extract_flat="in_playlist",
        noplaylist=False,
    )
    with yt_dlp.YoutubeDL(probe_opts) as probe:
        meta = probe.extract_info(url, download=False)

    if meta.get("_type") == "playlist":
        return True
    return "entries" in meta


def download_youtube_audio(url: str, mode: str = "single", playlist_index: int = 1) -> List[str]:
    """Download audio from ``url`` as WAV files into ``RAW_AUDIO_DIR``.

    Args:
        url: Video or playlist URL; must not be blank.
        mode: ``"single"`` (one item) or ``"all"`` (every item); case-insensitive.
        playlist_index: 1-based playlist item used in ``"single"`` mode.

    Returns:
        Paths of the WAV files attributed to this call. In ``"single"`` mode
        only the most recently modified new file is returned.

    Raises:
        ValueError: On a blank URL, unknown mode, or ``playlist_index < 1``.
        RuntimeError: If no resulting WAV file can be found.
    """
    if not url.strip():
        raise ValueError("YOUTUBE_URL 为空，请先填写")

    mode = mode.lower().strip()
    if mode not in {"single", "all"}:
        raise ValueError("DOWNLOAD_MODE 只支持 'single' 或 'all'")
    if playlist_index < 1:
        raise ValueError("PLAYLIST_INDEX 必须 >= 1")

    print("=== YouTube 下载开始 ===")
    print(f"URL: {url}")
    print(f"MODE: {mode}, PLAYLIST_INDEX: {playlist_index}")

    is_playlist = detect_playlist(url)
    print(f"检测类型: {'Playlist/合集' if is_playlist else '单视频'}")

    def _mtime(path):
        return path.stat().st_mtime

    # Snapshot of existing WAVs so the files produced by this call can be told apart.
    wavs_before = set(RAW_AUDIO_DIR.glob("*.wav"))
    started_at = time.time()

    def progress_hook(event):
        state = event.get("status", "")
        if state == "downloading":
            pct, spd, eta = (
                event.get(key, "").strip()
                for key in ("_percent_str", "_speed_str", "_eta_str")
            )
            fname = os.path.basename(event.get("filename", "unknown"))
            print(f"\r下载中: {fname} | {pct} | {spd} | ETA {eta}", end="")
        elif state == "finished":
            fname = os.path.basename(event.get("filename", "unknown"))
            print(f"\n✅ 下载完成(待转码): {fname}")

    single_mode = mode == "single"
    if is_playlist:
        name_pattern = "%(playlist_index)s_%(id)s_%(title).120B.%(ext)s"
    else:
        name_pattern = "%(id)s_%(title).120B.%(ext)s"

    ydl_opts = {
        "format": "bestaudio/best",
        "outtmpl": str(RAW_AUDIO_DIR / name_pattern),
        "noplaylist": (single_mode and not is_playlist),
        "playlist_items": str(playlist_index) if (single_mode and is_playlist) else None,
        "ignoreerrors": False,
        "continuedl": True,
        "restrictfilenames": True,
        "progress_hooks": [progress_hook],
        "postprocessors": [{"key": "FFmpegExtractAudio", "preferredcodec": "wav"}],
        "quiet": False,
    }

    with yt_dlp.YoutubeDL(ydl_opts) as downloader:
        downloader.download([url])

    wavs_after = set(RAW_AUDIO_DIR.glob("*.wav"))
    new_files = sorted(wavs_after - wavs_before, key=_mtime)

    if not new_files:
        # No new path appeared (e.g. an existing file was rewritten): fall back
        # to files modified since shortly before the download started.
        cutoff = started_at - 3
        new_files = sorted(
            (wav for wav in RAW_AUDIO_DIR.glob("*.wav") if _mtime(wav) >= cutoff),
            key=_mtime,
        )

    if single_mode:
        if not new_files:
            raise RuntimeError("single 模式未找到下载结果")
        selected = [str(new_files[-1])]
    else:
        selected = [str(wav) for wav in new_files]

    if not selected:
        raise RuntimeError("未找到下载后的 wav 文件")

    print(f"\n下载完成，共 {len(selected)} 个音频")
    for shown in selected[:10]:
        print(" -", shown)
    hidden = len(selected) - 10
    if hidden > 0:
        print(f" ...其余 {hidden} 个已省略")
    return selected


## 6. 预处理模块

- 使用 ffmpeg 统一音频格式
- 输出 `16kHz + mono + PCM s16le`
- 输出路径固定到 `/kaggle/working/processed_audio/`


In [None]:
print("[Cell 6] 定义预处理模块")
import subprocess

def run_cmd(cmd, check=True):
    """Run ``cmd`` with captured text output and return the result.

    Args:
        cmd: Command and arguments as a list of strings.
        check: When true, a non-zero exit status is reported and raised.

    Raises:
        RuntimeError: If ``check`` is true and the command fails.
    """
    completed = subprocess.run(cmd, capture_output=True, text=True)
    if not check or completed.returncode == 0:
        return completed

    rendered = " ".join(cmd)
    print("命令失败:", rendered)
    # Only the tails of the streams are printed to keep the log readable.
    print("stdout:", completed.stdout[-3000:])
    print("stderr:", completed.stderr[-3000:])
    raise RuntimeError(f"Command failed: {rendered}")


def preprocess_audio_ffmpeg(input_audio_path: str, output_dir: Path = PROCESSED_AUDIO_DIR) -> str:
    """Convert an audio file to 16 kHz mono PCM s16le WAV using ffmpeg.

    Args:
        input_audio_path: Path of the source audio file.
        output_dir: Directory for the converted file (created if missing).

    Returns:
        Path of the converted file, named ``<sanitized stem>_16k_mono.wav``.

    Raises:
        FileNotFoundError: If the input file does not exist.
        RuntimeError: If ffmpeg fails (raised by ``run_cmd``).
    """
    source = Path(input_audio_path)
    if not source.exists():
        raise FileNotFoundError(f"输入音频不存在: {input_audio_path}")

    output_dir.mkdir(parents=True, exist_ok=True)
    destination = output_dir / (sanitize_filename(source.stem) + "_16k_mono.wav")

    ffmpeg_args = ["ffmpeg", "-y", "-i", str(source)]
    ffmpeg_args += ["-ar", "16000"]       # sample rate
    ffmpeg_args += ["-ac", "1"]           # mono
    ffmpeg_args += ["-c:a", "pcm_s16le"]  # 16-bit PCM
    ffmpeg_args.append(str(destination))

    run_cmd(ffmpeg_args, check=True)
    return str(destination)


## 7. Demucs 模块（可开关）

- `ENABLE_DEMUCS=True` 时执行 `htdemucs --two-stems=vocals`
- 若失败自动 fallback 到原音频
- 固定输出到 `/kaggle/working/vocals_audio/`


In [None]:
print("[Cell 7] 定义 Demucs 模块")
import sys
import shutil

def separate_vocals_demucs(input_audio_path: str, enable_demucs: bool = True) -> str:
    """Extract the vocal stem with Demucs, falling back to the input on failure.

    Args:
        input_audio_path: Audio file to separate.
        enable_demucs: When false the input path is returned unchanged.

    Returns:
        Path of the copied vocals file under ``VOCALS_DIR``, or
        ``input_audio_path`` when Demucs is disabled or fails.

    Raises:
        FileNotFoundError: If Demucs is enabled and the input does not exist.
    """
    if not enable_demucs:
        print("Demucs 已关闭，跳过")
        return input_audio_path

    source = Path(input_audio_path)
    if not source.exists():
        raise FileNotFoundError(f"输入音频不存在: {input_audio_path}")

    print(f"=== Demucs 分离: {input_audio_path} ===")
    started_at = time.time()

    demucs_cmd = [sys.executable, "-m", "demucs.separate"]
    demucs_cmd += ["-n", "htdemucs", "--two-stems=vocals"]
    demucs_cmd += ["-o", str(DEMUCS_OUTPUT_DIR), str(source)]

    try:
        run_cmd(demucs_cmd, check=True)

        # Newest vocals.wav first.
        found = sorted(
            DEMUCS_OUTPUT_DIR.glob("**/vocals.wav"),
            key=lambda path: path.stat().st_mtime,
            reverse=True,
        )
        if not found:
            raise RuntimeError("未找到 vocals.wav")

        # Prefer a file written during this run; otherwise take the newest one.
        recent_enough = started_at - 5
        chosen = next(
            (path for path in found if path.stat().st_mtime >= recent_enough),
            found[0],
        )

        out_path = VOCALS_DIR / f"{sanitize_filename(source.stem)}_vocals.wav"
        shutil.copy2(chosen, out_path)
        print(f"✅ Demucs 成功: {out_path}")
        return str(out_path)
    except Exception as err:
        print(f"⚠️ Demucs 失败，fallback 原音频: {err}")
        return input_audio_path


## 8. WhisperX 转写模块

- GPU 优先，自动降级
- 自动选择 compute_type
- 执行 alignment，失败时回退到未对齐段落
- 输出句子级 segment 结构


In [None]:
print("[Cell 8] 定义 WhisperX 模块")
import gc
import torch
import whisperx
from typing import Any, Dict, Optional, List, Tuple


def _normalize_whisper_segments(segments: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Reduce raw segments to ``{"start", "end", "text"}`` dictionaries.

    Segments whose text is missing or blank are dropped. ``start`` defaults
    to 0.0, ``end`` defaults to ``start``, and ``end`` is never allowed to be
    smaller than ``start``.
    """
    cleaned = []
    for raw in segments:
        text = (raw.get("text") or "").strip()
        if text:
            begin = float(raw.get("start", 0.0))
            finish = float(raw.get("end", begin))
            # Clamp inverted intervals to zero length.
            finish = max(finish, begin)
            cleaned.append({"start": begin, "end": finish, "text": text})
    return cleaned


def transcribe_with_whisperx(
    audio_path: str,
    model_name: str = "large-v3",
    language: Optional[str] = None,
) -> Tuple[List[Dict[str, Any]], str]:
    """Transcribe ``audio_path`` with WhisperX and try to align the result.

    Args:
        audio_path: Audio file to transcribe.
        model_name: WhisperX model name.
        language: Language code passed to the model, or None.

    Returns:
        ``(segments, detected_language)``. Segments are normalized dicts; the
        aligned segments are used when alignment succeeds and is non-empty,
        otherwise the unaligned ones.

    Raises:
        RuntimeError: If the model cannot be loaded with any compute type.
    """
    on_gpu = torch.cuda.is_available()
    if on_gpu:
        device, batch_size, compute_candidates = "cuda", 16, ["float16", "int8"]
    else:
        device, batch_size, compute_candidates = "cpu", 4, ["int8"]

    # Try each compute type in order until one loads.
    model = None
    used_compute = None
    for candidate in compute_candidates:
        try:
            print(f"加载 WhisperX: model={model_name}, device={device}, compute_type={candidate}")
            model = whisperx.load_model(
                model_name,
                device=device,
                compute_type=candidate,
                download_root=str(MODEL_CACHE_DIR / "whisperx"),
            )
        except Exception as load_err:
            print(f"加载失败 compute_type={candidate}: {load_err}")
        else:
            used_compute = candidate
            break

    if model is None:
        raise RuntimeError("WhisperX 模型加载失败")

    print(f"WhisperX 已加载，compute_type={used_compute}")
    raw_result = model.transcribe(audio_path, batch_size=batch_size, language=language)
    fallback_lang = language if language else "unknown"
    detected_lang = raw_result.get("language", fallback_lang)
    segments = _normalize_whisper_segments(raw_result.get("segments", []))

    # Alignment is best-effort: any failure keeps the unaligned segments.
    try:
        print("执行 alignment...")
        align_model, metadata = whisperx.load_align_model(
            language_code=detected_lang,
            device=device,
            model_dir=str(MODEL_CACHE_DIR / "whisperx_align"),
        )
        aligned = whisperx.align(
            raw_result["segments"],
            align_model,
            metadata,
            audio_path,
            device,
            return_char_alignments=False,
        )
        refined = _normalize_whisper_segments(aligned.get("segments", []))
        if refined:
            segments = refined
            print("✅ alignment 成功")
        else:
            print("⚠️ alignment 为空，使用原始段落")
    except Exception as align_err:
        print(f"⚠️ alignment 失败，fallback 原始段落: {align_err}")

    # Drop the model reference and release cached GPU memory.
    del model
    gc.collect()
    if on_gpu:
        torch.cuda.empty_cache()

    return segments, detected_lang


## 9. pyannote 说话人分离模块

- 使用 `pyannote/speaker-diarization-3.1`
- 支持 `NUM_SPEAKERS` 指定
- 若失败，回退为单 speaker


In [None]:
print("[Cell 9] 定义 Diarization 模块")
from typing import Any, Dict, List, Optional

def get_audio_duration_sec(audio_path: str) -> float:
    """Return the duration reported by ffprobe, or 0.0 when it is unavailable.

    0.0 is returned both when ffprobe exits with an error and when its output
    cannot be parsed as a float.
    """
    probe = run_cmd(
        [
            "ffprobe", "-v", "error",
            "-show_entries", "format=duration",
            "-of", "default=nokey=1:noprint_wrappers=1",
            audio_path,
        ],
        check=False,
    )
    if probe.returncode == 0:
        try:
            return float(probe.stdout.strip())
        except Exception:
            pass
    return 0.0


def fallback_single_speaker_turns(audio_path: str) -> List[Dict[str, Any]]:
    """Return one ``SPEAKER_00`` turn spanning the whole audio file.

    The turn ends at the probed duration, but never earlier than 0.01 s.
    """
    duration = get_audio_duration_sec(audio_path)
    whole_file_turn = {
        "start": 0.0,
        "end": max(duration, 0.01),
        "speaker": "SPEAKER_00",
    }
    return [whole_file_turn]


def diarize_with_pyannote(
    audio_path: str,
    hf_token: Optional[str],
    num_speakers: Optional[int] = None,
) -> List[Dict[str, Any]]:
    """Run pyannote speaker diarization, falling back to a single speaker.

    Args:
        audio_path: Audio file to diarize.
        hf_token: Hugging Face token; ``None`` or a blank string triggers the
            single-speaker fallback instead of an error.
        num_speakers: Optional fixed number of speakers passed to the pipeline.

    Returns:
        Turns sorted by start time, each a dict with ``start``, ``end`` and
        ``speaker``. Any failure yields the single-speaker fallback turns.
    """
    # Tolerate None as well as blank strings: the token is optional.
    if not (hf_token or "").strip():
        print("⚠️ HF_TOKEN 为空，回退单说话人")
        return fallback_single_speaker_turns(audio_path)

    try:
        from pyannote.audio import Pipeline
        import torch as _torch

        pipeline = Pipeline.from_pretrained(
            "pyannote/speaker-diarization-3.1",
            use_auth_token=hf_token,
            cache_dir=str(MODEL_CACHE_DIR / "pyannote"),
        )
        # NOTE(review): from_pretrained appears to return None (rather than
        # raise) when the model cannot be fetched; fail with a readable
        # message instead of an AttributeError on ``pipeline.to``.
        if pipeline is None:
            raise RuntimeError("pyannote pipeline 加载失败（请检查 HF_TOKEN 与模型访问授权）")

        if _torch.cuda.is_available():
            pipeline.to(_torch.device("cuda"))

        kwargs = {}
        if num_speakers is not None:
            kwargs["num_speakers"] = int(num_speakers)

        diarization = pipeline(audio_path, **kwargs)

        turns = [
            {
                "start": float(turn.start),
                "end": float(turn.end),
                "speaker": str(speaker),
            }
            for turn, _, speaker in diarization.itertracks(yield_label=True)
        ]

        turns.sort(key=lambda x: x["start"])
        if not turns:
            print("⚠️ pyannote 返回空，回退单说话人")
            return fallback_single_speaker_turns(audio_path)

        print(f"✅ diarization 完成，turn 数: {len(turns)}")
        return turns

    except Exception as e:
        # Deliberate best-effort: diarization problems must not stop the run.
        print(f"⚠️ diarization 失败，回退单说话人: {e}")
        return fallback_single_speaker_turns(audio_path)


## 10. Target 识别模块

- 自动确保参考音频可用（本地不存在则从 GitHub 下载）
- 每个 speaker 最多累计 90 秒采样计算 embedding（不足 30 秒时会提示结果可能不稳定）
- 输出 `target_id / speaker_scores / confidence_flag`


In [None]:
print("[Cell 10] 定义 Target 识别模块")
import torch
import torchaudio
import torch.nn.functional as F
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple


def ensure_target_reference_audio(local_path: str, source_url: str) -> str:
    """Make sure the reference audio exists locally, downloading it if needed.

    Args:
        local_path: Where the reference audio should live.
        source_url: URL to fetch it from; a blank value disables downloading.

    Returns:
        ``local_path`` as a string in every case, including a failed download
        (callers are expected to check for existence themselves).
    """
    target = Path(local_path)
    resolved = str(target)

    # Nothing to do when the file is present or there is no URL to fetch from.
    if target.exists() or not source_url.strip():
        return resolved

    print(f"参考音频不存在，尝试下载: {source_url}")
    target.parent.mkdir(parents=True, exist_ok=True)
    fetch = run_cmd(["curl", "-L", "--fail", source_url, "-o", resolved], check=False)

    if fetch.returncode == 0:
        print(f"✅ 参考音频已下载: {target}")
    else:
        print("参考音频下载失败 stderr:")
        print(fetch.stderr[-2000:])
    return resolved


def load_wave_slice(audio_path: str, start_sec: float, end_sec: float, target_sr: int = 16000):
    """Load ``[start_sec, end_sec)`` of an audio file as a mono waveform.

    Args:
        audio_path: Audio file to read.
        start_sec: Slice start in seconds.
        end_sec: Slice end in seconds.
        target_sr: Sample rate the slice is resampled to when it differs.

    Returns:
        The waveform tensor, or None when the interval is empty or no samples
        were read.
    """
    if end_sec <= start_sec:
        return None

    source_rate = torchaudio.info(audio_path).sample_rate

    # Convert the time window to frame counts at the file's own sample rate.
    first_frame = max(0, int(start_sec * source_rate))
    frame_count = max(1, int((end_sec - start_sec) * source_rate))

    samples, loaded_rate = torchaudio.load(
        audio_path,
        frame_offset=first_frame,
        num_frames=frame_count,
    )
    if samples.numel() == 0:
        return None

    multi_channel = samples.shape[0] > 1
    if multi_channel:
        samples = samples.mean(dim=0, keepdim=True)

    if loaded_rate != target_sr:
        samples = torchaudio.functional.resample(samples, loaded_rate, target_sr)

    return samples


def compute_speaker_embedding(classifier, waveform: torch.Tensor, device: str = "cpu") -> torch.Tensor:
    """Encode ``waveform`` with ``classifier`` and return the result on the CPU.

    A 1-D waveform gets a leading dimension added; a 2-D waveform with more
    than one row is averaged over its first dimension. The output of
    ``classifier.encode_batch`` is squeezed, detached and moved to the CPU.
    """
    rank = waveform.dim()
    if rank == 1:
        waveform = waveform.unsqueeze(0)
    elif rank == 2 and waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        batch = waveform.to(device)
        encoded = classifier.encode_batch(batch)
        embedding = encoded.squeeze().detach().cpu()
    return embedding


def identify_target_speaker(
    audio_path: str,
    turns: List[Dict[str, Any]],
    target_ref_audio_path: str,
    threshold: float = 0.3,
) -> Tuple[Optional[str], Dict[str, float], bool]:
    """Pick the diarized speaker whose voice best matches the reference audio.

    For each speaker, up to 90 seconds of audio is gathered from its turns and
    embedded with the SpeechBrain ECAPA classifier; the cosine similarity to
    the reference embedding is that speaker's score.

    Args:
        audio_path: Audio file the turns refer to.
        turns: Diarization turns with ``start``, ``end`` and ``speaker`` keys.
        target_ref_audio_path: Reference recording of the target speaker.
        threshold: Minimum similarity for the match to count as confident.

    Returns:
        ``(target_id, speaker_scores, confidence_flag)``. ``target_id`` is the
        best-scoring speaker (or None when identification was skipped),
        ``speaker_scores`` maps speaker label to similarity, and
        ``confidence_flag`` tells whether the best score reached ``threshold``.
    """
    speaker_scores: Dict[str, float] = {}

    if not target_ref_audio_path.strip():
        print("未提供 TARGET_REF_AUDIO_PATH，跳过 target 识别")
        return None, speaker_scores, False

    ref_path = Path(target_ref_audio_path)
    if not ref_path.exists():
        print(f"⚠️ 参考音频不存在: {target_ref_audio_path}")
        return None, speaker_scores, False

    if not turns:
        print("⚠️ turns 为空，跳过 target 识别")
        return None, speaker_scores, False

    try:
        try:
            from speechbrain.inference.speaker import EncoderClassifier
        except ImportError:
            # Older SpeechBrain releases (such as the 0.5.16 pinned by this
            # notebook) expose the class from ``speechbrain.pretrained``.
            from speechbrain.pretrained import EncoderClassifier
    except Exception as e:
        print(f"⚠️ speechbrain 导入失败: {e}")
        return None, speaker_scores, False

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Normalize the reference recording the same way as the main audio.
    ref_processed = preprocess_audio_ffmpeg(str(ref_path), TARGET_REF_DIR)
    ref_wav, ref_sr = torchaudio.load(ref_processed)
    if ref_wav.shape[0] > 1:
        ref_wav = ref_wav.mean(dim=0, keepdim=True)
    if ref_sr != 16000:
        ref_wav = torchaudio.functional.resample(ref_wav, ref_sr, 16000)

    ref_sec = ref_wav.shape[1] / 16000.0
    if ref_sec < 3.0:
        print(f"⚠️ 参考音频过短: {ref_sec:.2f}s（建议 >= 3s）")

    classifier = EncoderClassifier.from_hparams(
        source="speechbrain/spkrec-ecapa-voxceleb",
        savedir=str(MODEL_CACHE_DIR / "speechbrain"),
        run_opts={"device": device},
    )

    ref_emb = compute_speaker_embedding(classifier, ref_wav, device=device)

    # Group turns per speaker in chronological order.
    speaker_turns = defaultdict(list)
    for t in sorted(turns, key=lambda x: x["start"]):
        speaker_turns[t["speaker"]].append(t)

    for spk, tlist in speaker_turns.items():
        chunks = []
        total_sec = 0.0
        for t in tlist:
            if total_sec >= 90.0:
                break
            s, e = float(t["start"]), float(t["end"])
            if e <= s:
                continue

            # Trim the last turn so at most 90 s are collected per speaker.
            remain = 90.0 - total_sec
            e = min(e, s + remain)

            w = load_wave_slice(audio_path, s, e, target_sr=16000)
            if w is None or w.numel() == 0:
                continue

            chunks.append(w)
            total_sec += w.shape[1] / 16000.0

        print(f"[采样时长] {spk}: {total_sec:.2f}s")
        if total_sec <= 0.2:
            # Too little audio to embed; this speaker gets no score.
            continue
        if total_sec < 30.0:
            print(f"⚠️ {spk} 采样不足 30s，结果可能不稳定")

        spk_wav = torch.cat(chunks, dim=1)
        spk_emb = compute_speaker_embedding(classifier, spk_wav, device=device)
        sim = float(F.cosine_similarity(ref_emb.unsqueeze(0), spk_emb.unsqueeze(0)).item())
        speaker_scores[spk] = sim
        print(f"[相似度] {spk}: {sim:.4f}")

    if not speaker_scores:
        print("⚠️ 未得到有效 speaker embedding")
        return None, speaker_scores, False

    target_id = max(speaker_scores, key=speaker_scores.get)
    best_score = speaker_scores[target_id]
    confidence_flag = best_score >= threshold

    print(f"[最终选择] target_id={target_id}, score={best_score:.4f}, threshold={threshold}")
    if not confidence_flag:
        print("⚠️ 低于阈值，标记为不确定")

    return target_id, speaker_scores, confidence_flag


## 11. 对齐与输出模块

- segment 与 diarization turn 按重叠时间对齐
- speaker 映射为 `speaker1/speaker2...`
- target speaker 显示为 `TARGET`
- 输出 `.srt` 与 `.json`


In [None]:
print("[Cell 11] 定义对齐与输出模块")
import json
from collections import OrderedDict


def overlap_duration(a_start: float, a_end: float, b_start: float, b_end: float) -> float:
    """Return the length of the overlap of two intervals (0.0 when disjoint)."""
    latest_start = max(a_start, b_start)
    earliest_end = min(a_end, b_end)
    return max(0.0, earliest_end - latest_start)


def pick_speaker_for_segment(seg, turns):
    """Choose the speaker label for a transcript segment.

    The turn with the largest positive time overlap wins (the earliest such
    turn on ties). Without any overlap, the turn whose midpoint is closest to
    the segment midpoint is used; with no turns at all the label is
    ``"SPEAKER_00"``.
    """
    seg_start, seg_end = float(seg["start"]), float(seg["end"])

    top_speaker, top_overlap = None, -1.0
    for turn in turns:
        turn_start, turn_end = float(turn["start"]), float(turn["end"])
        shared = max(0.0, min(seg_end, turn_end) - max(seg_start, turn_start))
        if shared > top_overlap:
            top_speaker, top_overlap = turn["speaker"], shared

    if top_speaker is not None and top_overlap > 0:
        return top_speaker

    if not turns:
        return "SPEAKER_00"

    # No overlap anywhere: fall back to the nearest turn by midpoint distance.
    seg_mid = (seg_start + seg_end) / 2.0
    closest = turns[0]["speaker"]
    closest_gap = float("inf")
    for turn in turns:
        turn_mid = (float(turn["start"]) + float(turn["end"])) / 2.0
        gap = abs(turn_mid - seg_mid)
        if gap < closest_gap:
            closest, closest_gap = turn["speaker"], gap
    return closest


def align_segments_with_diarization(segments, turns, target_id, speaker_scores):
    """Attach a display speaker to every transcript segment.

    Args:
        segments: Transcript segments with ``start``, ``end`` and ``text``.
        turns: Diarization turns; when empty, a single ``SPEAKER_00`` turn
            covering everything is assumed.
        target_id: Raw speaker label to display as ``TARGET``, or None.
        speaker_scores: Mapping of raw speaker label to similarity score.

    Returns:
        Segments in start order with ``speaker`` (``TARGET`` or
        ``speakerN``, numbered by first appearance), ``is_target`` and
        ``similarity_score`` (None when the speaker has no score).
    """
    if turns:
        turns = sorted(turns, key=lambda turn: turn["start"])
    else:
        turns = [{"start": 0.0, "end": 1e9, "speaker": "SPEAKER_00"}]

    alias_map = OrderedDict()
    aligned = []

    for seg in sorted(segments, key=lambda item: item["start"]):
        raw_label = pick_speaker_for_segment(seg, turns)

        # Aliases are numbered in order of first appearance.
        if raw_label not in alias_map:
            alias_map[raw_label] = f"speaker{len(alias_map) + 1}"

        is_target = target_id is not None and raw_label == target_id
        if is_target:
            shown_label = "TARGET"
        else:
            shown_label = alias_map[raw_label]

        if raw_label in speaker_scores:
            score = float(speaker_scores[raw_label])
        else:
            score = None

        aligned.append({
            "start": float(seg["start"]),
            "end": float(seg["end"]),
            "text": (seg["text"] or "").strip(),
            "speaker": shown_label,
            "is_target": bool(is_target),
            "similarity_score": score,
        })

    return aligned


def sec_to_srt_time(sec: float) -> str:
    """Format seconds as an SRT timestamp ``HH:MM:SS,mmm``.

    The value is rounded to whole milliseconds before being split up.
    """
    total_ms = int(round(sec * 1000))
    hours, remainder = divmod(total_ms, 3_600_000)
    minutes, remainder = divmod(remainder, 60_000)
    seconds, millis = divmod(remainder, 1000)
    return f"{hours:02d}:{minutes:02d}:{seconds:02d},{millis:03d}"


def save_srt_json(aligned_segments, audio_path):
    """Write the aligned segments as ``.srt`` and ``.json`` under ``OUTPUT_DIR``.

    Both files are named ``output_<sanitized audio stem>``.

    Returns:
        ``(srt_path, json_path)`` as strings.
    """
    stem = sanitize_filename(Path(audio_path).stem)
    srt_path = OUTPUT_DIR / f"output_{stem}.srt"
    json_path = OUTPUT_DIR / f"output_{stem}.json"

    # One SRT cue per segment: index, time range, "speaker: text", blank line.
    with open(srt_path, "w", encoding="utf-8") as srt_file:
        srt_file.writelines(
            f"{number}\n"
            f"{sec_to_srt_time(seg['start'])} --> {sec_to_srt_time(seg['end'])}\n"
            f"{seg['speaker']}: {seg['text']}\n\n"
            for number, seg in enumerate(aligned_segments, start=1)
        )

    exported_keys = ("start", "end", "text", "speaker", "is_target", "similarity_score")
    json_data = [{key: seg[key] for key in exported_keys} for seg in aligned_segments]

    with open(json_path, "w", encoding="utf-8") as json_file:
        json.dump(json_data, json_file, ensure_ascii=False, indent=2)

    for written in (srt_path, json_path):
        print(f"✅ 输出: {written}")
    return str(srt_path), str(json_path)


## 12. 主流程运行

- 逐视频顺序处理（支持 `single/all`）
- 每阶段计时、捕获异常、不中断整体批处理
- 汇总每个视频的处理结果


In [None]:
print("[Cell 12] 执行主流程")
import traceback
import gc
from collections import Counter


def process_one_audio(raw_audio_path: str, idx: int, total: int):
    """Run the full per-file pipeline and return a result summary dict.

    Stages: preprocess, Demucs, WhisperX, diarization, target identification,
    alignment/export. Each stage is timed. Any exception marks the result as
    ``"failed"`` (with the error text) instead of propagating.

    Args:
        raw_audio_path: Downloaded audio file to process.
        idx: 1-based position of this file, used for logging only.
        total: Total number of files, used for logging only.
    """
    rule = "=" * 80
    print("\n" + rule)
    print(f"[{idx}/{total}] 开始处理: {raw_audio_path}")
    print(rule)

    result = {
        "input_audio": raw_audio_path,
        "status": "success",
        "error": None,
        "stage_time_sec": {},
        "srt_path": None,
        "json_path": None,
        "target_id": None,
        "target_confidence": False,
        "speaker_scores": {},
        "speaker_distribution": {},
    }
    timings = result["stage_time_sec"]

    def _record(stage, started):
        # Store the elapsed wall-clock time of a stage, rounded to milliseconds.
        timings[stage] = round(time.time() - started, 3)

    total_start = time.time()

    try:
        started = time.time()
        processed_audio = preprocess_audio_ffmpeg(raw_audio_path, PROCESSED_AUDIO_DIR)
        _record("preprocess", started)

        started = time.time()
        work_audio = separate_vocals_demucs(processed_audio, enable_demucs=ENABLE_DEMUCS)
        _record("demucs", started)

        started = time.time()
        segments, detected_lang = transcribe_with_whisperx(
            work_audio,
            model_name=WHISPER_MODEL,
            language=LANGUAGE,
        )
        _record("whisperx", started)
        result["detected_language"] = detected_lang
        print(f"[whisperx] segments={len(segments)}, language={detected_lang}")

        started = time.time()
        turns = diarize_with_pyannote(
            work_audio,
            hf_token=HF_TOKEN,
            num_speakers=NUM_SPEAKERS,
        )
        _record("diarization", started)
        print(f"[diarization] turns={len(turns)}")

        started = time.time()
        target_id, speaker_scores, confidence_flag = identify_target_speaker(
            work_audio,
            turns,
            TARGET_REF_AUDIO_PATH,
            threshold=TARGET_SIM_THRESHOLD,
        )
        _record("target_identification", started)
        result["target_id"] = target_id
        result["target_confidence"] = bool(confidence_flag)
        result["speaker_scores"] = speaker_scores

        started = time.time()
        # Only a confident match is labelled TARGET in the output.
        confident_target = target_id if confidence_flag else None
        aligned_segments = align_segments_with_diarization(
            segments,
            turns,
            target_id=confident_target,
            speaker_scores=speaker_scores,
        )
        srt_path, json_path = save_srt_json(aligned_segments, raw_audio_path)
        _record("align_and_export", started)
        result["srt_path"] = srt_path
        result["json_path"] = json_path

        result["speaker_distribution"] = dict(
            Counter([item["speaker"] for item in aligned_segments])
        )

    except Exception as err:
        result["status"] = "failed"
        result["error"] = str(err)
        traceback.print_exc()

    _record("total", total_start)
    print(f"[完成] status={result['status']}, total={timings['total']}s")

    # Release memory between files.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return result


def run_pipeline():
    """Download the configured audio and process every file in order.

    Reads the notebook-level settings (``YOUTUBE_URL``, ``DOWNLOAD_MODE``,
    ``PLAYLIST_INDEX``, ``HF_TOKEN``, ...) and updates the global
    ``TARGET_REF_AUDIO_PATH`` with the path returned by
    ``ensure_target_reference_audio``.

    Returns:
        One result dict per downloaded audio file.

    Raises:
        ValueError: If ``YOUTUBE_URL`` is blank.
    """
    if not YOUTUBE_URL.strip():
        raise ValueError("YOUTUBE_URL 不能为空")

    if not INTERNET_AVAILABLE:
        print("⚠️ 可能无网络，YouTube 与模型下载可能失败。请开启 Kaggle Internet")

    # The token is optional and may be None; never call .strip() on None.
    if not (HF_TOKEN or "").strip():
        print("⚠️ HF_TOKEN 为空，Diarization 将回退到单 speaker")

    globals()["TARGET_REF_AUDIO_PATH"] = ensure_target_reference_audio(
        TARGET_REF_AUDIO_PATH,
        TARGET_REF_AUDIO_URL,
    )

    start_all = time.time()
    wav_paths = download_youtube_audio(
        YOUTUBE_URL,
        mode=DOWNLOAD_MODE,
        playlist_index=PLAYLIST_INDEX,
    )

    total = len(wav_paths)
    results = [
        process_one_audio(wav_path, i, total)
        for i, wav_path in enumerate(wav_paths, 1)
    ]

    print("\n" + "#" * 80)
    print(f"全部处理完成，视频数={len(results)}，总耗时={round(time.time() - start_all, 2)}s")
    print("#" * 80)
    return results


RUN_RESULTS = run_pipeline()


## 13. 结果统计

- 汇总每个视频耗时
- 统计 speaker 分布
- 打印 target 相似度结果


In [None]:
print("[Cell 13] 输出统计结果")
import pandas as pd
from collections import Counter
from pathlib import Path
from IPython.display import display

if "RUN_RESULTS" not in globals() or not RUN_RESULTS:
    print("暂无结果，请先运行主流程")
else:
    # Summary-table column -> key inside each result's "stage_time_sec".
    _stage_columns = (
        ("total_sec", "total"),
        ("preprocess_sec", "preprocess"),
        ("demucs_sec", "demucs"),
        ("whisperx_sec", "whisperx"),
        ("diarization_sec", "diarization"),
        ("target_identification_sec", "target_identification"),
        ("align_export_sec", "align_and_export"),
    )

    rows = []
    agg_speaker = Counter()

    for r in RUN_RESULTS:
        agg_speaker.update(r.get("speaker_distribution", {}))

        run_scores = r.get("speaker_scores")
        best_score = max(run_scores.values()) if run_scores else None

        stage_times = r.get("stage_time_sec", {})
        row = {"input_audio": r.get("input_audio"), "status": r.get("status")}
        row.update((column, stage_times.get(stage)) for column, stage in _stage_columns)
        row.update({
            "target_id": r.get("target_id"),
            "target_confidence": r.get("target_confidence"),
            "best_similarity_score": best_score,
            "srt_path": r.get("srt_path"),
            "json_path": r.get("json_path"),
            "error": r.get("error"),
        })
        rows.append(row)

    df = pd.DataFrame(rows)
    display(df)

    # Speaker distribution, counted in subtitle segments across all files.
    print("\n=== Speaker 分布（按字幕段数）===")
    if not agg_speaker:
        print("无统计数据")
    else:
        for label, count in agg_speaker.items():
            print(f"{label}: {count}")

    # Per-file similarity scores, best first.
    print("\n=== Target 分数 ===")
    for r in RUN_RESULTS:
        name = Path(r.get("input_audio", "")).name
        scores = r.get("speaker_scores", {})
        if not scores:
            print(f"{name}: 无 target score")
            continue
        ranked = sorted(scores.items(), key=lambda pair: pair[1], reverse=True)
        items = ", ".join(f"{label}={value:.4f}" for label, value in ranked)
        print(f"{name}: {items}")
