200 changes: 141 additions & 59 deletions src/diarizer.py
@@ -1,6 +1,6 @@
"""Speaker diarization using PyAnnote.audio"""
from pathlib import Path
from typing import List, Dict, Optional, Tuple, Any, Callable
from typing import List, Dict, Optional, Tuple, Any, Callable, Union
from dataclasses import dataclass
import os
import sys
@@ -364,26 +364,20 @@ def _load_component(factory: Callable, model_name: str, **factory_kwargs):
self.logger.info("4. Set HF_TOKEN in your .env file")
self.pipeline = None # Ensure it's None on failure
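
# The recovery hints above mention HF_TOKEN in a .env file; a minimal sketch
# of reading it, assuming python-dotenv (an assumption, not shown in this diff):
import os
from dotenv import load_dotenv

load_dotenv()  # reads key=value pairs from .env into the process environment
hf_token = os.environ.get("HF_TOKEN")
if hf_token is None:
    raise RuntimeError("HF_TOKEN not set; create one at https://huggingface.co/settings/tokens")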

def diarize(self, audio_path: Path) -> Tuple[List[SpeakerSegment], Dict[str, np.ndarray]]:
def _load_audio_for_diarization(self, audio_path: Path) -> Union[Dict, str]:
"""
Perform speaker diarization on audio file.
Load audio file for diarization, preferring in-memory loading.

Attempts to load audio using torchaudio for in-memory processing.
Falls back to file path if in-memory loading fails.

Args:
audio_path: Path to WAV file
audio_path: Path to audio file

Returns:
A tuple containing:
- A list of SpeakerSegment objects.
- A dictionary mapping speaker IDs to their embeddings.
Either a dict with 'waveform' and 'sample_rate' keys (in-memory),
or a string path (fallback for file-based loading)
"""
self._load_pipeline_if_needed()

if self.pipeline is None:
# Fallback: create dummy single-speaker segments
segments = self._create_fallback_diarization(audio_path)
return segments, {}

# Run diarization
diarization_input = str(audio_path)
try:
import torchaudio # type: ignore
@@ -392,12 +386,28 @@ def diarize(self, audio_path: Path) -> Tuple[List[SpeakerSegment], Dict[str, np.
"waveform": waveform,
"sample_rate": sample_rate
}
self.logger.debug("Loaded audio in-memory for diarization")
except Exception as exc:
self.logger.debug(
"Falling back to on-disk audio loading for diarization: %s",
exc
)

return diarization_input
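
# pyannote pipelines accept either form returned here: a plain path or a
# preloaded-waveform dict. A hedged sketch of both call forms (model and
# file names are illustrative; gated pyannote models also need an HF token):
import torchaudio
from pyannote.audio import Pipeline

pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1",
                                    use_auth_token="hf_...")  # placeholder token

result = pipeline("meeting.wav")  # form 1: path, pyannote decodes the file itself

waveform, sample_rate = torchaudio.load("meeting.wav")  # form 2: in-memory audio,
result = pipeline({"waveform": waveform, "sample_rate": sample_rate})  # skips a re-read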

def _perform_diarization(self, diarization_input: Union[Dict, str]) -> Tuple[Any, List[SpeakerSegment]]:
"""
Execute diarization pipeline and convert results to segments.

Args:
diarization_input: Either a dict with audio data or a file path string

Returns:
A tuple of (diarization_result, segments_list) where:
- diarization_result: Raw result from pyannote pipeline (needed for embeddings)
- segments_list: List of SpeakerSegment objects
"""
self.logger.debug("Running diarization pipeline...")
diarization = self.pipeline(diarization_input)

# Convert to our format
@@ -409,58 +419,130 @@ def diarize(self, audio_path: Path) -> Tuple[List[SpeakerSegment], Dict[str, np.
end_time=turn.end
))

# Extract speaker embeddings using the loaded model
self.logger.info(
"Diarization complete: %d segments, %d speakers",
len(segments),
len(set(seg.speaker_id for seg in segments))
)

return diarization, segments
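
# The conversion folded out of this hunk follows pyannote's standard iteration
# pattern; a sketch assuming Annotation.itertracks and the SpeakerSegment
# fields visible elsewhere in the diff (start_time is inferred):
segments = []
for turn, _, speaker in diarization.itertracks(yield_label=True):
    segments.append(SpeakerSegment(
        speaker_id=speaker,     # e.g. "SPEAKER_00"
        start_time=turn.start,  # seconds
        end_time=turn.end
    ))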

def _extract_speaker_embeddings(
self,
audio_path: Path,
diarization: Any
) -> Dict[str, np.ndarray]:
"""
Extract speaker embeddings for each diarized speaker.

Uses the embedding model to extract a voice embedding for each speaker
identified in the diarization result. Audio from all of a speaker's
segments is concatenated before a single embedding is computed.

Args:
audio_path: Path to audio file
diarization: Raw diarization result from pyannote pipeline

Returns:
Dictionary mapping speaker IDs to their embedding arrays
"""
speaker_embeddings: Dict[str, np.ndarray] = {}
if self.embedding_model is not None:

if self.embedding_model is None:
self.logger.debug("No embedding model available, skipping embedding extraction")
return speaker_embeddings

# Import pydub for audio segment extraction
try:
from pydub import AudioSegment
except ImportError as exc:
self.logger.warning(
"Unable to import pydub for embedding extraction: %s",
exc
)
return speaker_embeddings

# Load audio file
try:
audio = AudioSegment.from_wav(str(audio_path))
except Exception as exc:
self.logger.warning(
"Unable to load %s for speaker embeddings: %s",
audio_path,
exc
)
return speaker_embeddings

# Extract embeddings for each speaker
for speaker_id in diarization.labels():
try:
from pydub import AudioSegment
speaker_segments = diarization.label_timeline(speaker_id)
speaker_audio = AudioSegment.empty()

# Concatenate all segments for this speaker
for segment in speaker_segments:
speaker_audio += audio[segment.start * 1000:segment.end * 1000]

if len(speaker_audio) <= 0:
self.logger.debug("Skipping %s: no audio data", speaker_id)
continue

# Convert to numpy array and normalize
samples = np.array(
speaker_audio.get_array_of_samples(),
dtype=np.float32
) / 32768.0

# Prepare tensor and run inference
samples_tensor = self._prepare_waveform_tensor(samples)
embedding = self._run_embedding_inference(samples_tensor, audio.frame_rate)
speaker_embeddings[speaker_id] = self._embedding_to_numpy(embedding)

self.logger.debug("Extracted embedding for %s", speaker_id)

except Exception as exc:
self.logger.warning(
"Unable to import pydub for embedding extraction: %s",
"Failed to extract embedding for %s: %s",
speaker_id,
exc
)
AudioSegment = None # type: ignore

audio = None
if AudioSegment is not None:
try:
audio = AudioSegment.from_wav(str(audio_path))
except Exception as exc:
self.logger.warning(
"Unable to load %s for speaker embeddings: %s",
audio_path,
exc
)
self.logger.info("Extracted embeddings for %d speakers", len(speaker_embeddings))
return speaker_embeddings

if audio is None:
self.logger.warning(
"Skipping speaker embedding extraction because audio could not be loaded."
)
else:
for speaker_id in diarization.labels():
speaker_segments = diarization.label_timeline(speaker_id)
speaker_audio = AudioSegment.empty()
for segment in speaker_segments:
speaker_audio += audio[segment.start * 1000:segment.end * 1000]

if len(speaker_audio) <= 0:
continue

samples = np.array(
speaker_audio.get_array_of_samples(),
dtype=np.float32
) / 32768.0

samples_tensor = self._prepare_waveform_tensor(samples)
try:
embedding = self._run_embedding_inference(samples_tensor, audio.frame_rate)
speaker_embeddings[speaker_id] = self._embedding_to_numpy(embedding)
except Exception as exc:
self.logger.warning(
"Failed to extract embedding for %s: %s",
speaker_id,
exc
)
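
# Two pydub conventions drive the loop above: AudioSegment slicing is in
# milliseconds, and get_array_of_samples() returns raw integer PCM. A
# standalone sketch of the same conversion, assuming 16-bit audio:
import numpy as np
from pydub import AudioSegment

audio = AudioSegment.from_wav("meeting.wav")  # illustrative file name
clip = audio[2.5 * 1000:7.0 * 1000]           # pydub slices by milliseconds

# 16-bit PCM spans [-32768, 32767]; dividing by 32768.0 normalizes to [-1, 1]
samples = np.array(clip.get_array_of_samples(), dtype=np.float32) / 32768.0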
def diarize(self, audio_path: Path) -> Tuple[List[SpeakerSegment], Dict[str, np.ndarray]]:
"""
Perform speaker diarization on audio file.

This method orchestrates the complete diarization pipeline:
1. Load and initialize the diarization pipeline
2. Load audio file for processing
3. Execute speaker diarization
4. Extract speaker embeddings

Args:
audio_path: Path to WAV file

Returns:
A tuple containing:
- A list of SpeakerSegment objects
- A dictionary mapping speaker IDs to their embeddings
"""
self._load_pipeline_if_needed()

if self.pipeline is None:
# Fallback: create dummy single-speaker segments
segments = self._create_fallback_diarization(audio_path)
return segments, {}

# Step 1: Load audio for diarization
diarization_input = self._load_audio_for_diarization(audio_path)

# Step 2: Perform diarization
diarization, segments = self._perform_diarization(diarization_input)

# Step 3: Extract speaker embeddings
speaker_embeddings = self._extract_speaker_embeddings(audio_path, diarization)

return segments, speaker_embeddings
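
With the helpers in place, diarize() keeps its original two-value contract. A
hedged usage sketch; the class name Diarizer and its no-argument constructor
are assumptions, since the diff only shows methods from src/diarizer.py:

    from pathlib import Path
    from diarizer import Diarizer  # hypothetical class name, not shown in the diff

    diarizer = Diarizer()
    segments, embeddings = diarizer.diarize(Path("meeting.wav"))

    for seg in segments:
        print(f"{seg.speaker_id}: {seg.start_time:.2f}s - {seg.end_time:.2f}s")

    # embeddings maps speaker IDs to numpy arrays, e.g. embeddings.get("SPEAKER_00")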
