From ce9812dd83c2b9a13064371fe3de8609746547e1 Mon Sep 17 00:00:00 2001 From: Maple Gao Date: Wed, 29 Apr 2026 11:12:31 +0800 Subject: [PATCH 1/5] fix: normalize faster-whisper cuda devices --- app/pipeline/orchestrator.py | 14 ++- doc/changelog.en.md | 7 ++ doc/changelog.zh.md | 6 ++ tests/unit/test_pipeline_model_lifecycle.py | 97 ++++++++++++++++----- 4 files changed, 101 insertions(+), 23 deletions(-) diff --git a/app/pipeline/orchestrator.py b/app/pipeline/orchestrator.py index 532b75d..355268d 100644 --- a/app/pipeline/orchestrator.py +++ b/app/pipeline/orchestrator.py @@ -85,6 +85,18 @@ def _load_trusted_pyannote_model( return from_pretrained(model_ref, **auth_kwargs) +def _faster_whisper_device_kwargs(device: str) -> dict[str, Any]: + """Translate torch-style CUDA device strings to faster-whisper kwargs.""" + + if not device.startswith("cuda:"): + return {"device": device} + + device_kind, _, raw_index = device.partition(":") + if raw_index.isdigit(): + return {"device": device_kind, "device_index": int(raw_index)} + return {"device": device} + + class TranscriptionPipeline: def __init__( self, @@ -152,7 +164,7 @@ def whisper(self): ) self._whisper = WhisperModel( model_ref, - device=self.device, + **_faster_whisper_device_kwargs(self.device), compute_type=compute_type, ) return self._whisper diff --git a/doc/changelog.en.md b/doc/changelog.en.md index 44016f6..812548e 100644 --- a/doc/changelog.en.md +++ b/doc/changelog.en.md @@ -8,6 +8,13 @@ _No changes yet._ ## 0.7.5 — Idle GPU model unload and CI quality gates (2026-04-29) +### Bug Fixes + +- Fixed faster-whisper CUDA device arguments: internal torch-facing device + state can still use `cuda:0` / `cuda:1`, but faster-whisper loads now receive + `device="cuda"` with the matching `device_index`, avoiding + `unsupported device cuda:0`. + ### Features - Added optional `MODEL_IDLE_TIMEOUT_SEC` support. The default `0` keeps the diff --git a/doc/changelog.zh.md b/doc/changelog.zh.md index c9a655d..2ad9523 100644 --- a/doc/changelog.zh.md +++ b/doc/changelog.zh.md @@ -8,6 +8,12 @@ _暂无变更。_ ## 0.7.5 — GPU 模型空闲卸载与 CI 质量门禁 (2026-04-29) +### Bug 修复 + +- 修复 faster-whisper CUDA 设备传参:内部仍可用 `cuda:0` / `cuda:1` + 表示 torch 设备,但加载 faster-whisper 时会转换为 `device="cuda"` 与 + 对应 `device_index`,避免 `unsupported device cuda:0`。 + ### 功能 - 新增可选 `MODEL_IDLE_TIMEOUT_SEC`。默认 `0` 保持当前模型常驻行为;设为正数后, diff --git a/tests/unit/test_pipeline_model_lifecycle.py b/tests/unit/test_pipeline_model_lifecycle.py index 79c8928..2aa82a7 100644 --- a/tests/unit/test_pipeline_model_lifecycle.py +++ b/tests/unit/test_pipeline_model_lifecycle.py @@ -27,6 +27,17 @@ def _new_pipeline(*, device="cuda"): return pipeline +def _install_fake_faster_whisper(monkeypatch, loaded_models): + class FakeWhisperModel: + def __init__(self, model_ref, **kwargs): + loaded_models.append((model_ref, kwargs)) + + faster_whisper = ModuleType("faster_whisper") + faster_whisper.WhisperModel = FakeWhisperModel + monkeypatch.setitem(sys.modules, "faster_whisper", faster_whisper) + return FakeWhisperModel + + def test_unload_models_drops_loaded_references_without_selecting_device(monkeypatch): pipeline = _new_pipeline(device="cuda") pipeline._whisper = object() @@ -54,16 +65,9 @@ def test_unload_models_drops_loaded_references_without_selecting_device(monkeypa def test_whisper_lazy_reload_selects_best_cuda_device(monkeypatch): pipeline = _new_pipeline(device="cuda") calls = [] - loaded_devices = [] - - class FakeWhisperModel: - def __init__(self, model_ref, *, device, compute_type): - loaded_devices.append((model_ref, device, compute_type)) - - faster_whisper = ModuleType("faster_whisper") - faster_whisper.WhisperModel = FakeWhisperModel + loaded_models = [] + fake_model = _install_fake_faster_whisper(monkeypatch, loaded_models) - monkeypatch.setitem(__import__("sys").modules, "faster_whisper", faster_whisper) monkeypatch.setattr(orchestrator.Path, "exists", lambda self: False) monkeypatch.setattr( orchestrator, @@ -71,32 +75,81 @@ def __init__(self, model_ref, *, device, compute_type): lambda configured: calls.append(configured) or "cuda:1", ) - assert pipeline.whisper.__class__ is FakeWhisperModel + assert pipeline.whisper.__class__ is fake_model assert calls == ["cuda"] - assert loaded_devices == [("tiny", "cuda:1", "float16")] + assert loaded_models == [ + ("tiny", {"device": "cuda", "device_index": 1, "compute_type": "float16"}) + ] assert pipeline.device == "cuda:1" def test_cpu_lazy_load_does_not_probe_cuda(monkeypatch): pipeline = _new_pipeline(device="cpu") - loaded_devices = [] - - class FakeWhisperModel: - def __init__(self, model_ref, *, device, compute_type): - loaded_devices.append((device, compute_type)) - - faster_whisper = ModuleType("faster_whisper") - faster_whisper.WhisperModel = FakeWhisperModel + loaded_models = [] + fake_model = _install_fake_faster_whisper(monkeypatch, loaded_models) def fail_if_called(configured): raise AssertionError("CPU-only loads must not probe CUDA") - monkeypatch.setitem(__import__("sys").modules, "faster_whisper", faster_whisper) monkeypatch.setattr(orchestrator.Path, "exists", lambda self: False) monkeypatch.setattr(orchestrator, "select_best_cuda_device", fail_if_called) - assert pipeline.whisper.__class__ is FakeWhisperModel + assert pipeline.whisper.__class__ is fake_model - assert loaded_devices == [("cpu", "int8")] + assert loaded_models == [("tiny", {"device": "cpu", "compute_type": "int8"})] assert pipeline.device == "cpu" + + +def test_whisper_lazy_load_keeps_unindexed_cuda_supported(monkeypatch): + pipeline = _new_pipeline(device="cuda") + loaded_models = [] + fake_model = _install_fake_faster_whisper(monkeypatch, loaded_models) + + monkeypatch.setattr(orchestrator.Path, "exists", lambda self: False) + monkeypatch.setattr( + orchestrator, "select_best_cuda_device", lambda configured: configured + ) + + assert pipeline.whisper.__class__ is fake_model + + assert loaded_models == [("tiny", {"device": "cuda", "compute_type": "float16"})] + assert pipeline.device == "cuda" + + +def test_whisper_lazy_load_normalizes_cuda_zero_for_faster_whisper(monkeypatch): + pipeline = _new_pipeline(device="cuda:0") + loaded_models = [] + fake_model = _install_fake_faster_whisper(monkeypatch, loaded_models) + + monkeypatch.setattr(orchestrator.Path, "exists", lambda self: False) + monkeypatch.setattr( + orchestrator, "select_best_cuda_device", lambda configured: configured + ) + + assert pipeline.whisper.__class__ is fake_model + + assert loaded_models == [ + ("tiny", {"device": "cuda", "device_index": 0, "compute_type": "float16"}) + ] + assert pipeline.device == "cuda:0" + + +def test_whisper_lazy_load_normalizes_fallback_cuda_index_for_faster_whisper( + monkeypatch, +): + pipeline = _new_pipeline(device="cuda:1") + loaded_models = [] + fake_model = _install_fake_faster_whisper(monkeypatch, loaded_models) + + monkeypatch.setattr(orchestrator.Path, "exists", lambda self: False) + monkeypatch.setattr( + orchestrator, "select_best_cuda_device", lambda configured: configured + ) + + assert pipeline.whisper.__class__ is fake_model + + assert loaded_models == [ + ("tiny", {"device": "cuda", "device_index": 1, "compute_type": "float16"}) + ] + assert pipeline.device == "cuda:1" From 861ea4fafdd9247dd7c59242e94ac3842eaaf760 Mon Sep 17 00:00:00 2001 From: Maple Gao Date: Wed, 29 Apr 2026 12:03:36 +0800 Subject: [PATCH 2/5] fix: load pyannote snapshots from local files --- app/pipeline/orchestrator.py | 25 ++++ doc/changelog.en.md | 5 + doc/changelog.zh.md | 3 + tests/unit/test_huggingface_models.py | 161 ++++++++++++++++++++++++-- 4 files changed, 185 insertions(+), 9 deletions(-) diff --git a/app/pipeline/orchestrator.py b/app/pipeline/orchestrator.py index 355268d..111e882 100644 --- a/app/pipeline/orchestrator.py +++ b/app/pipeline/orchestrator.py @@ -85,6 +85,29 @@ def _load_trusted_pyannote_model( return from_pretrained(model_ref, **auth_kwargs) +def _is_local_model_ref(model_ref: str | Path) -> bool: + if isinstance(model_ref, Path): + return True + path = Path(model_ref).expanduser() + return path.is_absolute() or model_ref.startswith((".", "~")) + + +def _resolve_local_pyannote_file(model_ref: str | Path, snapshot_filename: str) -> str: + """Convert a local HF snapshot directory into pyannote's expected file path.""" + + if not _is_local_model_ref(model_ref): + return str(model_ref) + + local_path = Path(model_ref).expanduser() + if local_path.is_dir(): + local_path = local_path / snapshot_filename + + if not local_path.is_file(): + raise FileNotFoundError(f"Local pyannote model file not found: {local_path}") + + return str(local_path) + + def _faster_whisper_device_kwargs(device: str) -> dict[str, Any]: """Translate torch-style CUDA device strings to faster-whisper kwargs.""" @@ -180,6 +203,7 @@ def diarization(self): token=self.hf_token, purpose="pyannote diarization", ) + model_ref = _resolve_local_pyannote_file(model_ref, "config.yaml") logger.info("Loading pyannote diarization model") self._diarization = _load_trusted_pyannote_model( PyannotePipeline.from_pretrained, @@ -216,6 +240,7 @@ def embedding_model(self): token=self.hf_token, purpose="WeSpeaker speaker encoder", ) + model_ref = _resolve_local_pyannote_file(model_ref, "pytorch_model.bin") logger.info("Loading WeSpeaker speaker encoder") model = _load_trusted_pyannote_model( Model.from_pretrained, diff --git a/doc/changelog.en.md b/doc/changelog.en.md index 812548e..94a27ce 100644 --- a/doc/changelog.en.md +++ b/doc/changelog.en.md @@ -14,6 +14,11 @@ _No changes yet._ state can still use `cuda:0` / `cuda:1`, but faster-whisper loads now receive `device="cuda"` with the matching `device_index`, avoiding `unsupported device cuda:0`. +- Fixed pyannote local-cache loading: when a complete Hugging Face snapshot is + already present, diarization and speaker-embedding models now receive local + file paths that pyannote can load directly. Missing caches still fall back to + Hub repo ids, and missing local files fail explicitly instead of being + swallowed. ### Features diff --git a/doc/changelog.zh.md b/doc/changelog.zh.md index 2ad9523..462a750 100644 --- a/doc/changelog.zh.md +++ b/doc/changelog.zh.md @@ -13,6 +13,9 @@ _暂无变更。_ - 修复 faster-whisper CUDA 设备传参:内部仍可用 `cuda:0` / `cuda:1` 表示 torch 设备,但加载 faster-whisper 时会转换为 `device="cuda"` 与 对应 `device_index`,避免 `unsupported device cuda:0`。 +- 修复 pyannote 本地缓存加载:当已有完整 Hugging Face snapshot 时,说话人分离和 + 声纹模型会传入 pyannote 可直接接受的本地文件路径;缓存缺失仍回退到 Hub + repo id,缺失本地文件会明确失败而不是被吞掉。 ### 功能 diff --git a/tests/unit/test_huggingface_models.py b/tests/unit/test_huggingface_models.py index 0c536c0..c7c31ab 100644 --- a/tests/unit/test_huggingface_models.py +++ b/tests/unit/test_huggingface_models.py @@ -4,8 +4,11 @@ import sys import os +from pathlib import Path from types import ModuleType +import pytest + def _stub_numpy(monkeypatch) -> None: numpy_stub = ModuleType("numpy") @@ -95,13 +98,16 @@ def test_diarization_loader_uses_cache_resolved_model_reference( from pipeline import TranscriptionPipeline import pipeline.orchestrator as orchestrator - cached_snapshot = str(tmp_path / "diarization-snapshot") + cached_snapshot = tmp_path / "diarization-snapshot" + cached_snapshot.mkdir() + config_yml = cached_snapshot / "config.yaml" + config_yml.write_text("pipeline:\n name: fake.Pipeline\n", encoding="utf-8") calls = [] monkeypatch.setattr( orchestrator, "resolve_hf_model_ref", - lambda repo_id, *, token, purpose: cached_snapshot, + lambda repo_id, *, token, purpose: str(cached_snapshot), ) class FakeLoadedPipeline: @@ -110,6 +116,8 @@ class FakeLoadedPipeline: class FakePyannotePipeline: @classmethod def from_pretrained(cls, model_ref, use_auth_token=None): + if Path(model_ref).is_dir(): + raise ValueError("HFValidationError: repo id cannot be a directory") calls.append((model_ref, use_auth_token)) return FakeLoadedPipeline() @@ -126,7 +134,83 @@ def from_pretrained(cls, model_ref, use_auth_token=None): pipeline._diarization = None assert pipeline.diarization.__class__ is FakeLoadedPipeline - assert calls == [(cached_snapshot, "test-token")] + assert calls == [(str(config_yml), "test-token")] + + +def test_diarization_loader_keeps_hub_repo_id_reference(monkeypatch): + _stub_numpy(monkeypatch) + from pipeline import TranscriptionPipeline + import pipeline.orchestrator as orchestrator + + repo_id = "pyannote/speaker-diarization-3.1" + calls = [] + + monkeypatch.setattr( + orchestrator, + "resolve_hf_model_ref", + lambda repo_id, *, token, purpose: repo_id, + ) + + class FakeLoadedPipeline: + pass + + class FakePyannotePipeline: + @classmethod + def from_pretrained(cls, model_ref, use_auth_token=None): + calls.append((model_ref, use_auth_token)) + return FakeLoadedPipeline() + + monkeypatch.setattr( + sys.modules["pyannote.audio"], + "Pipeline", + FakePyannotePipeline, + raising=False, + ) + + pipeline = TranscriptionPipeline.__new__(TranscriptionPipeline) + pipeline.device = "cpu" + pipeline.hf_token = "test-token" + pipeline._diarization = None + + assert pipeline.diarization.__class__ is FakeLoadedPipeline + assert calls == [(repo_id, "test-token")] + + +def test_diarization_loader_rejects_missing_local_snapshot(monkeypatch, tmp_path): + _stub_numpy(monkeypatch) + from pipeline import TranscriptionPipeline + import pipeline.orchestrator as orchestrator + + missing_snapshot = tmp_path / "missing-diarization-snapshot" + calls = [] + + monkeypatch.setattr( + orchestrator, + "resolve_hf_model_ref", + lambda repo_id, *, token, purpose: str(missing_snapshot), + ) + + class FakePyannotePipeline: + @classmethod + def from_pretrained(cls, model_ref, use_auth_token=None): + calls.append((model_ref, use_auth_token)) + return object() + + monkeypatch.setattr( + sys.modules["pyannote.audio"], + "Pipeline", + FakePyannotePipeline, + raising=False, + ) + + pipeline = TranscriptionPipeline.__new__(TranscriptionPipeline) + pipeline.device = "cpu" + pipeline.hf_token = "test-token" + pipeline._diarization = None + + with pytest.raises(FileNotFoundError): + _ = pipeline.diarization + assert calls == [] def test_diarization_loader_scopes_torch26_safe_globals(monkeypatch, tmp_path): @@ -134,7 +218,10 @@ def test_diarization_loader_scopes_torch26_safe_globals(monkeypatch, tmp_path): from pipeline import TranscriptionPipeline import pipeline.orchestrator as orchestrator - cached_snapshot = str(tmp_path / "diarization-snapshot") + cached_snapshot = tmp_path / "diarization-snapshot" + cached_snapshot.mkdir() + config_yml = cached_snapshot / "config.yaml" + config_yml.write_text("pipeline:\n name: fake.Pipeline\n", encoding="utf-8") events = [] class TorchVersion: @@ -180,7 +267,7 @@ class Resolution: monkeypatch.setattr( orchestrator, "resolve_hf_model_ref", - lambda repo_id, *, token, purpose: cached_snapshot, + lambda repo_id, *, token, purpose: str(cached_snapshot), ) class FakeSafeGlobals: @@ -225,7 +312,7 @@ def from_pretrained(cls, model_ref, use_auth_token=None): assert events == [ ("globals", (TorchVersion, Problem, Specifications, Resolution)), ("enter",), - ("load", cached_snapshot, "test-token"), + ("load", str(config_yml), "test-token"), ("exit", None), ] @@ -258,13 +345,69 @@ def test_embedding_loader_uses_cache_resolved_model_reference(monkeypatch, tmp_p from pipeline import TranscriptionPipeline import pipeline.orchestrator as orchestrator - cached_snapshot = str(tmp_path / "embedding-snapshot") + cached_snapshot = tmp_path / "embedding-snapshot" + cached_snapshot.mkdir() + weights_file = cached_snapshot / "pytorch_model.bin" + weights_file.write_bytes(b"fake weights") + calls = [] + + monkeypatch.setattr( + orchestrator, + "resolve_hf_model_ref", + lambda repo_id, *, token, purpose: str(cached_snapshot), + ) + + class FakeModel: + @classmethod + def from_pretrained(cls, model_ref, use_auth_token=None): + if Path(model_ref).is_dir(): + raise ValueError("HFValidationError: repo id cannot be a directory") + calls.append(("from_pretrained", model_ref, use_auth_token)) + return cls() + + def to(self, device): + calls.append(("to", device)) + return self + + class FakeInference: + def __init__(self, model, window): + calls.append(("inference", model.__class__.__name__, window)) + + monkeypatch.setattr( + sys.modules["pyannote.audio"], "Model", FakeModel, raising=False + ) + monkeypatch.setattr( + sys.modules["pyannote.audio"], + "Inference", + FakeInference, + raising=False, + ) + + pipeline = TranscriptionPipeline.__new__(TranscriptionPipeline) + pipeline.device = "cpu" + pipeline.hf_token = "test-token" + pipeline._embedding_model = None + + assert pipeline.embedding_model.__class__ is FakeInference + assert calls == [ + ("from_pretrained", str(weights_file), "test-token"), + ("to", "cpu"), + ("inference", "FakeModel", "whole"), + ] + + +def test_embedding_loader_keeps_hub_repo_id_reference(monkeypatch): + _stub_numpy(monkeypatch) + from pipeline import TranscriptionPipeline + import pipeline.orchestrator as orchestrator + + repo_id = "pyannote/wespeaker-voxceleb-resnet34-LM" calls = [] monkeypatch.setattr( orchestrator, "resolve_hf_model_ref", - lambda repo_id, *, token, purpose: cached_snapshot, + lambda repo_id, *, token, purpose: repo_id, ) class FakeModel: @@ -298,7 +441,7 @@ def __init__(self, model, window): assert pipeline.embedding_model.__class__ is FakeInference assert calls == [ - ("from_pretrained", cached_snapshot, "test-token"), + ("from_pretrained", repo_id, "test-token"), ("to", "cpu"), ("inference", "FakeModel", "whole"), ] From 4c5e6887f2357b52bc95c967a5706e694052db51 Mon Sep 17 00:00:00 2001 From: Maple Gao Date: Wed, 29 Apr 2026 12:07:13 +0800 Subject: [PATCH 3/5] ci: run FOSSA baseline on main --- .github/workflows/fossa.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/fossa.yml b/.github/workflows/fossa.yml index fe1438f..98f33fb 100644 --- a/.github/workflows/fossa.yml +++ b/.github/workflows/fossa.yml @@ -1,6 +1,8 @@ name: FOSSA on: + push: + branches: [main] pull_request: branches: [main] workflow_dispatch: From 2061e9c2b2aef149d131570375447f411e1a7dff Mon Sep 17 00:00:00 2001 From: Maple Gao Date: Wed, 29 Apr 2026 12:09:16 +0800 Subject: [PATCH 4/5] ci: scan before FOSSA diff test --- .github/workflows/fossa.yml | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/.github/workflows/fossa.yml b/.github/workflows/fossa.yml index 98f33fb..792ea97 100644 --- a/.github/workflows/fossa.yml +++ b/.github/workflows/fossa.yml @@ -27,5 +27,10 @@ jobs: uses: fossas/fossa-action@v1.7.0 with: api-key: ${{ secrets.FOSSA_API_KEY }} - run-tests: ${{ github.event_name == 'pull_request' }} + - name: Run FOSSA diff test + if: ${{ env.FOSSA_API_KEY != '' && github.event_name == 'pull_request' }} + uses: fossas/fossa-action@v1.7.0 + with: + api-key: ${{ secrets.FOSSA_API_KEY }} + run-tests: true test-diff-revision: ${{ github.event_name == 'pull_request' && github.event.pull_request.base.sha || '' }} From ba06f44409c3398a30655d857c0ee38e96443f5a Mon Sep 17 00:00:00 2001 From: Maple Gao Date: Wed, 29 Apr 2026 12:37:02 +0800 Subject: [PATCH 5/5] fix: localize nested pyannote model refs --- app/pipeline/orchestrator.py | 138 +++++++++++++++++++ doc/changelog.en.md | 10 +- doc/changelog.zh.md | 6 +- tests/unit/test_huggingface_models.py | 182 +++++++++++++++++++++++++- 4 files changed, 322 insertions(+), 14 deletions(-) diff --git a/app/pipeline/orchestrator.py b/app/pipeline/orchestrator.py index 111e882..3bf5fbb 100644 --- a/app/pipeline/orchestrator.py +++ b/app/pipeline/orchestrator.py @@ -10,6 +10,10 @@ """ import logging +import hashlib +import json +import re +import tempfile from contextlib import nullcontext from pathlib import Path from typing import Any @@ -37,6 +41,16 @@ "Specifications", "Resolution", ) +_LOCAL_PYANNOTE_CONFIG_MODELS = { + "segmentation": ("pytorch_model.bin", "pyannote segmentation"), + "embedding": ("pytorch_model.bin", "pyannote embedding"), +} +_PYANNOTE_PARAMS_RE = re.compile(r"^(\s*)params\s*:\s*(?:#.*)?$") +_PYANNOTE_COMPONENT_RE = re.compile(r"^(\s*)(segmentation|embedding)\s*:\s*(.*?)\s*$") + + +class LocalPyannoteModelArtifactError(RuntimeError): + """Public-safe error for incomplete local pyannote model snapshots.""" def _trusted_pyannote_checkpoint_globals() -> list[type]: @@ -108,6 +122,125 @@ def _resolve_local_pyannote_file(model_ref: str | Path, snapshot_filename: str) return str(local_path) +def _public_safe_missing_pyannote_artifact(component: str) -> str: + return ( + f"Local pyannote diarization config requires a cached {component} model " + "artifact. Preload the required Hugging Face snapshot or allow Hub loading." + ) + + +def _split_yaml_scalar_and_comment(raw_value: str) -> tuple[str, str]: + value, separator, comment = raw_value.partition(" #") + if not separator: + return raw_value.strip(), "" + return value.strip(), f" #{comment}" + + +def _unquote_yaml_scalar(raw_value: str) -> str: + value = raw_value.strip() + if len(value) >= 2 and value[0] == value[-1] and value[0] in {"'", '"'}: + return value[1:-1] + return value + + +def _resolve_local_config_component_file( + component: str, + model_ref: str, + *, + token: str | None, +) -> str: + snapshot_filename, purpose = _LOCAL_PYANNOTE_CONFIG_MODELS[component] + resolved_ref = ( + model_ref + if _is_local_model_ref(model_ref) + else resolve_hf_model_ref(model_ref, token=token, purpose=purpose) + ) + if not _is_local_model_ref(resolved_ref): + raise LocalPyannoteModelArtifactError( + _public_safe_missing_pyannote_artifact(component) + ) + try: + return _resolve_local_pyannote_file(resolved_ref, snapshot_filename) + except FileNotFoundError as exc: + raise LocalPyannoteModelArtifactError( + _public_safe_missing_pyannote_artifact(component) + ) from exc + + +def _localized_pyannote_config_path( + source_config: Path, + localized_content: str, +) -> Path: + digest = hashlib.sha256() + digest.update(str(source_config).encode("utf-8")) + digest.update(b"\0") + digest.update(localized_content.encode("utf-8")) + cache_dir = ( + Path(tempfile.gettempdir()) + / "voscript-pyannote-localized" + / digest.hexdigest()[:16] + ) + cache_dir.mkdir(parents=True, exist_ok=True) + localized_config = cache_dir / "config.yaml" + localized_config.write_text(localized_content, encoding="utf-8") + return localized_config + + +def _localize_pyannote_diarization_config( + config_path: str | Path, + *, + token: str | None, +) -> str: + """Rewrite nested pyannote config model refs to local snapshot weight files.""" + + source_config = Path(config_path) + lines = source_config.read_text(encoding="utf-8").splitlines() + rewritten: list[str] = [] + params_indent: int | None = None + changed = False + + for line in lines: + if params_indent is None: + match = _PYANNOTE_PARAMS_RE.match(line) + if match: + params_indent = len(match.group(1)) + rewritten.append(line) + continue + + if line.strip() and not line.lstrip().startswith("#"): + current_indent = len(line) - len(line.lstrip()) + if current_indent <= params_indent: + params_indent = None + rewritten.append(line) + continue + + component_match = _PYANNOTE_COMPONENT_RE.match(line) + if not component_match: + rewritten.append(line) + continue + + indent, component, raw_value = component_match.groups() + raw_scalar, comment = _split_yaml_scalar_and_comment(raw_value) + model_ref = _unquote_yaml_scalar(raw_scalar) + if not model_ref: + rewritten.append(line) + continue + + local_file = _resolve_local_config_component_file( + component, + model_ref, + token=token, + ) + rewritten.append(f"{indent}{component}: {json.dumps(local_file)}{comment}") + changed = True + + if not changed: + return str(source_config) + + localized_content = "\n".join(rewritten) + "\n" + return str(_localized_pyannote_config_path(source_config, localized_content)) + + def _faster_whisper_device_kwargs(device: str) -> dict[str, Any]: """Translate torch-style CUDA device strings to faster-whisper kwargs.""" @@ -204,6 +337,11 @@ def diarization(self): purpose="pyannote diarization", ) model_ref = _resolve_local_pyannote_file(model_ref, "config.yaml") + if _is_local_model_ref(model_ref): + model_ref = _localize_pyannote_diarization_config( + model_ref, + token=self.hf_token, + ) logger.info("Loading pyannote diarization model") self._diarization = _load_trusted_pyannote_model( PyannotePipeline.from_pretrained, diff --git a/doc/changelog.en.md b/doc/changelog.en.md index 94a27ce..5774520 100644 --- a/doc/changelog.en.md +++ b/doc/changelog.en.md @@ -14,11 +14,11 @@ _No changes yet._ state can still use `cuda:0` / `cuda:1`, but faster-whisper loads now receive `device="cuda"` with the matching `device_index`, avoiding `unsupported device cuda:0`. -- Fixed pyannote local-cache loading: when a complete Hugging Face snapshot is - already present, diarization and speaker-embedding models now receive local - file paths that pyannote can load directly. Missing caches still fall back to - Hub repo ids, and missing local files fail explicitly instead of being - swallowed. +- Fixed pyannote local-cache loading: when complete Hugging Face snapshots are + already present, diarization now receives a runtime-localized config whose + nested segmentation / embedding models also point at local weight files. + Missing caches still fall back to Hub repo ids, and incomplete local artifacts + fail explicitly before model loading. ### Features diff --git a/doc/changelog.zh.md b/doc/changelog.zh.md index 462a750..6ba9639 100644 --- a/doc/changelog.zh.md +++ b/doc/changelog.zh.md @@ -13,9 +13,9 @@ _暂无变更。_ - 修复 faster-whisper CUDA 设备传参:内部仍可用 `cuda:0` / `cuda:1` 表示 torch 设备,但加载 faster-whisper 时会转换为 `device="cuda"` 与 对应 `device_index`,避免 `unsupported device cuda:0`。 -- 修复 pyannote 本地缓存加载:当已有完整 Hugging Face snapshot 时,说话人分离和 - 声纹模型会传入 pyannote 可直接接受的本地文件路径;缓存缺失仍回退到 Hub - repo id,缺失本地文件会明确失败而不是被吞掉。 +- 修复 pyannote 本地缓存加载:当已有完整 Hugging Face snapshot 时,说话人分离 + 会生成 runtime-localized config,把内嵌 segmentation / embedding 子模型也指向 + 本地权重文件;缓存缺失仍回退到 Hub repo id,缺失本地工件会在加载前明确失败。 ### 功能 diff --git a/tests/unit/test_huggingface_models.py b/tests/unit/test_huggingface_models.py index c7c31ab..e6cc9f4 100644 --- a/tests/unit/test_huggingface_models.py +++ b/tests/unit/test_huggingface_models.py @@ -137,20 +137,187 @@ def from_pretrained(cls, model_ref, use_auth_token=None): assert calls == [(str(config_yml), "test-token")] -def test_diarization_loader_keeps_hub_repo_id_reference(monkeypatch): +def test_diarization_loader_localizes_nested_model_refs(monkeypatch, tmp_path): _stub_numpy(monkeypatch) from pipeline import TranscriptionPipeline import pipeline.orchestrator as orchestrator - repo_id = "pyannote/speaker-diarization-3.1" + diarization_snapshot = tmp_path / "diarization-snapshot" + segmentation_snapshot = tmp_path / "segmentation-snapshot" + embedding_snapshot = tmp_path / "embedding-snapshot" + for snapshot in (diarization_snapshot, segmentation_snapshot, embedding_snapshot): + snapshot.mkdir() + + config_yml = diarization_snapshot / "config.yaml" + config_yml.write_text( + "\n".join( + [ + "pipeline:", + " name: pyannote.audio.pipelines.SpeakerDiarization", + "params:", + " segmentation: pyannote/segmentation-3.0", + " embedding: pyannote/wespeaker-voxceleb-resnet34-LM", + " embedding_exclude_overlap: true", + ] + ), + encoding="utf-8", + ) + segmentation_weights = segmentation_snapshot / "pytorch_model.bin" + embedding_weights = embedding_snapshot / "pytorch_model.bin" + segmentation_weights.write_bytes(b"segmentation") + embedding_weights.write_bytes(b"embedding") calls = [] + def fake_resolve(repo_id, *, token, purpose): + calls.append((repo_id, token, purpose)) + return { + "pyannote/speaker-diarization-3.1": str(diarization_snapshot), + "pyannote/segmentation-3.0": str(segmentation_snapshot), + "pyannote/wespeaker-voxceleb-resnet34-LM": str(embedding_snapshot), + }[repo_id] + + monkeypatch.setattr(orchestrator, "resolve_hf_model_ref", fake_resolve) + + class FakeLoadedPipeline: + pass + + class FakePyannotePipeline: + @classmethod + def from_pretrained(cls, model_ref, use_auth_token=None): + loaded_config = Path(model_ref) + assert loaded_config != config_yml + content = loaded_config.read_text(encoding="utf-8") + assert "pyannote/segmentation-3.0" not in content + assert "pyannote/wespeaker-voxceleb-resnet34-LM" not in content + assert str(segmentation_weights) in content + assert str(embedding_weights) in content + calls.append(("from_pretrained", model_ref, use_auth_token)) + return FakeLoadedPipeline() + monkeypatch.setattr( - orchestrator, - "resolve_hf_model_ref", - lambda repo_id, *, token, purpose: repo_id, + sys.modules["pyannote.audio"], + "Pipeline", + FakePyannotePipeline, + raising=False, + ) + + pipeline = TranscriptionPipeline.__new__(TranscriptionPipeline) + pipeline.device = "cpu" + pipeline.hf_token = "test-token" + pipeline._diarization = None + + assert pipeline.diarization.__class__ is FakeLoadedPipeline + assert calls[:3] == [ + ( + "pyannote/speaker-diarization-3.1", + "test-token", + "pyannote diarization", + ), + ("pyannote/segmentation-3.0", "test-token", "pyannote segmentation"), + ( + "pyannote/wespeaker-voxceleb-resnet34-LM", + "test-token", + "pyannote embedding", + ), + ] + assert calls[3][0] == "from_pretrained" + assert Path(calls[3][1]).is_file() + assert calls[3][2] == "test-token" + + +@pytest.mark.parametrize( + "missing_name", + [ + "segmentation", + "embedding", + ], +) +def test_diarization_loader_rejects_missing_nested_local_artifact_without_loading( + monkeypatch, + tmp_path, + missing_name, +): + _stub_numpy(monkeypatch) + from pipeline import TranscriptionPipeline + import pipeline.orchestrator as orchestrator + + diarization_snapshot = tmp_path / "diarization-snapshot" + segmentation_snapshot = tmp_path / "segmentation-snapshot" + embedding_snapshot = tmp_path / "embedding-snapshot" + for snapshot in (diarization_snapshot, segmentation_snapshot, embedding_snapshot): + snapshot.mkdir() + + config_yml = diarization_snapshot / "config.yaml" + config_yml.write_text( + "\n".join( + [ + "pipeline:", + " name: pyannote.audio.pipelines.SpeakerDiarization", + "params:", + " segmentation: pyannote/segmentation-3.0", + " embedding: pyannote/wespeaker-voxceleb-resnet34-LM", + ] + ), + encoding="utf-8", + ) + if missing_name != "segmentation": + (segmentation_snapshot / "pytorch_model.bin").write_bytes(b"segmentation") + if missing_name != "embedding": + (embedding_snapshot / "pytorch_model.bin").write_bytes(b"embedding") + loader_calls = [] + + def fake_resolve(repo_id, *, token, purpose): + return { + "pyannote/speaker-diarization-3.1": str(diarization_snapshot), + "pyannote/segmentation-3.0": str(segmentation_snapshot), + "pyannote/wespeaker-voxceleb-resnet34-LM": str(embedding_snapshot), + }[repo_id] + + monkeypatch.setattr(orchestrator, "resolve_hf_model_ref", fake_resolve) + + class FakePyannotePipeline: + @classmethod + def from_pretrained(cls, model_ref, use_auth_token=None): + loader_calls.append((model_ref, use_auth_token)) + return object() + + monkeypatch.setattr( + sys.modules["pyannote.audio"], + "Pipeline", + FakePyannotePipeline, + raising=False, ) + pipeline = TranscriptionPipeline.__new__(TranscriptionPipeline) + pipeline.device = "cpu" + pipeline.hf_token = "test-token" + pipeline._diarization = None + + with pytest.raises(RuntimeError) as exc_info: + _ = pipeline.diarization + + message = str(exc_info.value) + assert missing_name in message + assert str(tmp_path) not in message + assert "test-token" not in message + assert "https://" not in message + assert loader_calls == [] + + +def test_diarization_loader_keeps_hub_repo_id_reference(monkeypatch): + _stub_numpy(monkeypatch) + from pipeline import TranscriptionPipeline + import pipeline.orchestrator as orchestrator + + repo_id = "pyannote/speaker-diarization-3.1" + calls = [] + + def fake_resolve(repo_id, *, token, purpose): + calls.append(("resolve", repo_id, token, purpose)) + return repo_id + + monkeypatch.setattr(orchestrator, "resolve_hf_model_ref", fake_resolve) + class FakeLoadedPipeline: pass @@ -173,7 +340,10 @@ def from_pretrained(cls, model_ref, use_auth_token=None): pipeline._diarization = None assert pipeline.diarization.__class__ is FakeLoadedPipeline - assert calls == [(repo_id, "test-token")] + assert calls == [ + ("resolve", repo_id, "test-token", "pyannote diarization"), + (repo_id, "test-token"), + ] def test_diarization_loader_rejects_missing_local_snapshot(monkeypatch, tmp_path):