-
Notifications
You must be signed in to change notification settings - Fork 0
fix: normalize faster-whisper CUDA device args #18
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
ce9812d
861ea4f
4c5e688
2061e9c
ba06f44
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -10,6 +10,10 @@ | |||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| import logging | ||||||||||||||||||||||||||||||||||||||||
| import hashlib | ||||||||||||||||||||||||||||||||||||||||
| import json | ||||||||||||||||||||||||||||||||||||||||
| import re | ||||||||||||||||||||||||||||||||||||||||
| import tempfile | ||||||||||||||||||||||||||||||||||||||||
| from contextlib import nullcontext | ||||||||||||||||||||||||||||||||||||||||
| from pathlib import Path | ||||||||||||||||||||||||||||||||||||||||
| from typing import Any | ||||||||||||||||||||||||||||||||||||||||
|
|
@@ -37,6 +41,16 @@ | |||||||||||||||||||||||||||||||||||||||
| "Specifications", | ||||||||||||||||||||||||||||||||||||||||
| "Resolution", | ||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||
| _LOCAL_PYANNOTE_CONFIG_MODELS = { | ||||||||||||||||||||||||||||||||||||||||
| "segmentation": ("pytorch_model.bin", "pyannote segmentation"), | ||||||||||||||||||||||||||||||||||||||||
| "embedding": ("pytorch_model.bin", "pyannote embedding"), | ||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||
| _PYANNOTE_PARAMS_RE = re.compile(r"^(\s*)params\s*:\s*(?:#.*)?$") | ||||||||||||||||||||||||||||||||||||||||
| _PYANNOTE_COMPONENT_RE = re.compile(r"^(\s*)(segmentation|embedding)\s*:\s*(.*?)\s*$") | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| class LocalPyannoteModelArtifactError(RuntimeError): | ||||||||||||||||||||||||||||||||||||||||
| """Public-safe error for incomplete local pyannote model snapshots.""" | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| def _trusted_pyannote_checkpoint_globals() -> list[type]: | ||||||||||||||||||||||||||||||||||||||||
|
|
@@ -85,6 +99,160 @@ def _load_trusted_pyannote_model( | |||||||||||||||||||||||||||||||||||||||
| return from_pretrained(model_ref, **auth_kwargs) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| def _is_local_model_ref(model_ref: str | Path) -> bool: | ||||||||||||||||||||||||||||||||||||||||
| if isinstance(model_ref, Path): | ||||||||||||||||||||||||||||||||||||||||
| return True | ||||||||||||||||||||||||||||||||||||||||
| path = Path(model_ref).expanduser() | ||||||||||||||||||||||||||||||||||||||||
| return path.is_absolute() or model_ref.startswith((".", "~")) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| def _resolve_local_pyannote_file(model_ref: str | Path, snapshot_filename: str) -> str: | ||||||||||||||||||||||||||||||||||||||||
| """Convert a local HF snapshot directory into pyannote's expected file path.""" | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| if not _is_local_model_ref(model_ref): | ||||||||||||||||||||||||||||||||||||||||
| return str(model_ref) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| local_path = Path(model_ref).expanduser() | ||||||||||||||||||||||||||||||||||||||||
| if local_path.is_dir(): | ||||||||||||||||||||||||||||||||||||||||
| local_path = local_path / snapshot_filename | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| if not local_path.is_file(): | ||||||||||||||||||||||||||||||||||||||||
| raise FileNotFoundError(f"Local pyannote model file not found: {local_path}") | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| return str(local_path) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| def _public_safe_missing_pyannote_artifact(component: str) -> str: | ||||||||||||||||||||||||||||||||||||||||
| return ( | ||||||||||||||||||||||||||||||||||||||||
| f"Local pyannote diarization config requires a cached {component} model " | ||||||||||||||||||||||||||||||||||||||||
| "artifact. Preload the required Hugging Face snapshot or allow Hub loading." | ||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| def _split_yaml_scalar_and_comment(raw_value: str) -> tuple[str, str]: | ||||||||||||||||||||||||||||||||||||||||
| value, separator, comment = raw_value.partition(" #") | ||||||||||||||||||||||||||||||||||||||||
| if not separator: | ||||||||||||||||||||||||||||||||||||||||
| return raw_value.strip(), "" | ||||||||||||||||||||||||||||||||||||||||
| return value.strip(), f" #{comment}" | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| def _unquote_yaml_scalar(raw_value: str) -> str: | ||||||||||||||||||||||||||||||||||||||||
| value = raw_value.strip() | ||||||||||||||||||||||||||||||||||||||||
| if len(value) >= 2 and value[0] == value[-1] and value[0] in {"'", '"'}: | ||||||||||||||||||||||||||||||||||||||||
| return value[1:-1] | ||||||||||||||||||||||||||||||||||||||||
| return value | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| def _resolve_local_config_component_file( | ||||||||||||||||||||||||||||||||||||||||
| component: str, | ||||||||||||||||||||||||||||||||||||||||
| model_ref: str, | ||||||||||||||||||||||||||||||||||||||||
| *, | ||||||||||||||||||||||||||||||||||||||||
| token: str | None, | ||||||||||||||||||||||||||||||||||||||||
| ) -> str: | ||||||||||||||||||||||||||||||||||||||||
| snapshot_filename, purpose = _LOCAL_PYANNOTE_CONFIG_MODELS[component] | ||||||||||||||||||||||||||||||||||||||||
| resolved_ref = ( | ||||||||||||||||||||||||||||||||||||||||
| model_ref | ||||||||||||||||||||||||||||||||||||||||
| if _is_local_model_ref(model_ref) | ||||||||||||||||||||||||||||||||||||||||
| else resolve_hf_model_ref(model_ref, token=token, purpose=purpose) | ||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||
| if not _is_local_model_ref(resolved_ref): | ||||||||||||||||||||||||||||||||||||||||
| raise LocalPyannoteModelArtifactError( | ||||||||||||||||||||||||||||||||||||||||
| _public_safe_missing_pyannote_artifact(component) | ||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||||
| return _resolve_local_pyannote_file(resolved_ref, snapshot_filename) | ||||||||||||||||||||||||||||||||||||||||
| except FileNotFoundError as exc: | ||||||||||||||||||||||||||||||||||||||||
| raise LocalPyannoteModelArtifactError( | ||||||||||||||||||||||||||||||||||||||||
| _public_safe_missing_pyannote_artifact(component) | ||||||||||||||||||||||||||||||||||||||||
| ) from exc | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| def _localized_pyannote_config_path( | ||||||||||||||||||||||||||||||||||||||||
| source_config: Path, | ||||||||||||||||||||||||||||||||||||||||
| localized_content: str, | ||||||||||||||||||||||||||||||||||||||||
| ) -> Path: | ||||||||||||||||||||||||||||||||||||||||
| digest = hashlib.sha256() | ||||||||||||||||||||||||||||||||||||||||
| digest.update(str(source_config).encode("utf-8")) | ||||||||||||||||||||||||||||||||||||||||
| digest.update(b"\0") | ||||||||||||||||||||||||||||||||||||||||
| digest.update(localized_content.encode("utf-8")) | ||||||||||||||||||||||||||||||||||||||||
| cache_dir = ( | ||||||||||||||||||||||||||||||||||||||||
| Path(tempfile.gettempdir()) | ||||||||||||||||||||||||||||||||||||||||
| / "voscript-pyannote-localized" | ||||||||||||||||||||||||||||||||||||||||
| / digest.hexdigest()[:16] | ||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||
| cache_dir.mkdir(parents=True, exist_ok=True) | ||||||||||||||||||||||||||||||||||||||||
| localized_config = cache_dir / "config.yaml" | ||||||||||||||||||||||||||||||||||||||||
| localized_config.write_text(localized_content, encoding="utf-8") | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
| localized_config.write_text(localized_content, encoding="utf-8") | |
| temp_config_path: Path | None = None | |
| try: | |
| with tempfile.NamedTemporaryFile( | |
| mode="w", | |
| encoding="utf-8", | |
| dir=cache_dir, | |
| prefix="config.", | |
| suffix=".tmp", | |
| delete=False, | |
| ) as temp_config: | |
| temp_config.write(localized_content) | |
| temp_config.flush() | |
| temp_config_path = Path(temp_config.name) | |
| temp_config_path.replace(localized_config) | |
| except Exception: | |
| if temp_config_path is not None: | |
| temp_config_path.unlink(missing_ok=True) | |
| raise |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Raising FileNotFoundError here includes the full local filesystem path in the exception message. That string is later surfaced to clients via the job status
errorfield (see run_transcription storingstr(e)), so this leaks host paths. Consider raising a public-safe error message (or a custom exception) that does not embedlocal_path, while still logging the full path server-side for debugging.