diff --git a/src/diarizer.py b/src/diarizer.py
index af28c8a..a5baf2d 100644
--- a/src/diarizer.py
+++ b/src/diarizer.py
@@ -1,6 +1,6 @@
 """Speaker diarization using PyAnnote.audio"""
 from pathlib import Path
-from typing import List, Dict, Optional, Tuple, Any, Callable
+from typing import List, Dict, Optional, Tuple, Any, Callable, Union
 from dataclasses import dataclass
 import os
 import sys
@@ -364,26 +364,20 @@ def _load_component(factory: Callable, model_name: str, **factory_kwargs):
             self.logger.info("4. Set HF_TOKEN in your .env file")
             self.pipeline = None  # Ensure it's None on failure
 
-    def diarize(self, audio_path: Path) -> Tuple[List[SpeakerSegment], Dict[str, np.ndarray]]:
+    def _load_audio_for_diarization(self, audio_path: Path) -> Union[Dict, str]:
         """
-        Perform speaker diarization on audio file.
+        Load audio file for diarization, preferring in-memory loading.
+
+        Attempts to load audio using torchaudio for in-memory processing.
+        Falls back to file path if in-memory loading fails.
 
         Args:
-            audio_path: Path to WAV file
+            audio_path: Path to audio file
 
         Returns:
-            A tuple containing:
-            - A list of SpeakerSegment objects.
-            - A dictionary mapping speaker IDs to their embeddings.
+            Either a dict with 'waveform' and 'sample_rate' keys (in-memory),
+            or a string path (fallback for file-based loading)
         """
-        self._load_pipeline_if_needed()
-
-        if self.pipeline is None:
-            # Fallback: create dummy single-speaker segments
-            segments = self._create_fallback_diarization(audio_path)
-            return segments, {}
-
-        # Run diarization
         diarization_input = str(audio_path)
         try:
             import torchaudio  # type: ignore
@@ -392,12 +386,28 @@ def diarize(self, audio_path: Path) -> Tuple[List[SpeakerSegment], Dict[str, np.
                 "waveform": waveform,
                 "sample_rate": sample_rate
             }
+            self.logger.debug("Loaded audio in-memory for diarization")
         except Exception as exc:
             self.logger.debug(
                 "Falling back to on-disk audio loading for diarization: %s",
                 exc
             )
+        return diarization_input
+
+    def _perform_diarization(self, diarization_input: Union[Dict, str]) -> Tuple[Any, List[SpeakerSegment]]:
+        """
+        Execute diarization pipeline and convert results to segments.
+
+        Args:
+            diarization_input: Either a dict with audio data or a file path string
+
+        Returns:
+            A tuple of (diarization_result, segments_list) where:
+            - diarization_result: Raw result from pyannote pipeline (needed for embeddings)
+            - segments_list: List of SpeakerSegment objects
+        """
+        self.logger.debug("Running diarization pipeline...")
         diarization = self.pipeline(diarization_input)
 
         # Convert to our format
@@ -409,58 +419,130 @@ def diarize(self, audio_path: Path) -> Tuple[List[SpeakerSegment], Dict[str, np.
                 end_time=turn.end
             ))
 
-        # Extract speaker embeddings using the loaded model
+        self.logger.info(
+            "Diarization complete: %d segments, %d speakers",
+            len(segments),
+            len(set(seg.speaker_id for seg in segments))
+        )
+
+        return diarization, segments
+
+    def _extract_speaker_embeddings(
+        self,
+        audio_path: Path,
+        diarization: Any
+    ) -> Dict[str, np.ndarray]:
+        """
+        Extract speaker embeddings for each diarized speaker.
+
+        Uses the embedding model to extract voice embeddings for each speaker
+        identified in the diarization result. Audio from all of a speaker's
+        segments is concatenated before a single embedding is computed.
+
+        Args:
+            audio_path: Path to audio file
+            diarization: Raw diarization result from pyannote pipeline
+
+        Returns:
+            Dictionary mapping speaker IDs to their embedding arrays
+        """
         speaker_embeddings: Dict[str, np.ndarray] = {}
-        if self.embedding_model is not None:
+
+        if self.embedding_model is None:
+            self.logger.debug("No embedding model available, skipping embedding extraction")
+            return speaker_embeddings
+
+        # Import pydub for audio segment extraction
+        try:
+            from pydub import AudioSegment
+        except ImportError as exc:
+            self.logger.warning(
+                "Unable to import pydub for embedding extraction: %s",
+                exc
+            )
+            return speaker_embeddings
+
+        # Load audio file
+        try:
+            audio = AudioSegment.from_wav(str(audio_path))
+        except Exception as exc:
+            self.logger.warning(
+                "Unable to load %s for speaker embeddings: %s",
+                audio_path,
+                exc
+            )
+            return speaker_embeddings
+
+        # Extract embeddings for each speaker
+        for speaker_id in diarization.labels():
             try:
-                from pydub import AudioSegment
+                speaker_segments = diarization.label_timeline(speaker_id)
+                speaker_audio = AudioSegment.empty()
+
+                # Concatenate all segments for this speaker
+                for segment in speaker_segments:
+                    speaker_audio += audio[segment.start * 1000:segment.end * 1000]
+
+                if len(speaker_audio) <= 0:
+                    self.logger.debug("Skipping %s: no audio data", speaker_id)
+                    continue
+
+                # Convert to numpy array and normalize
+                samples = np.array(
+                    speaker_audio.get_array_of_samples(),
+                    dtype=np.float32
+                ) / 32768.0
+
+                # Prepare tensor and run inference
+                samples_tensor = self._prepare_waveform_tensor(samples)
+                embedding = self._run_embedding_inference(samples_tensor, audio.frame_rate)
+                speaker_embeddings[speaker_id] = self._embedding_to_numpy(embedding)
+
+                self.logger.debug("Extracted embedding for %s", speaker_id)
+            except Exception as exc:
                 self.logger.warning(
-                    "Unable to import pydub for embedding extraction: %s",
+                    "Failed to extract embedding for %s: %s",
+                    speaker_id,
                     exc
                 )
-                AudioSegment = None  # type: ignore
-            audio = None
-            if AudioSegment is not None:
-                try:
-                    audio = AudioSegment.from_wav(str(audio_path))
-                except Exception as exc:
-                    self.logger.warning(
-                        "Unable to load %s for speaker embeddings: %s",
-                        audio_path,
-                        exc
-                    )
+        self.logger.info("Extracted embeddings for %d speakers", len(speaker_embeddings))
+        return speaker_embeddings
 
-            if audio is None:
-                self.logger.warning(
-                    "Skipping speaker embedding extraction because audio could not be loaded."
-                )
-            else:
-                for speaker_id in diarization.labels():
-                    speaker_segments = diarization.label_timeline(speaker_id)
-                    speaker_audio = AudioSegment.empty()
-                    for segment in speaker_segments:
-                        speaker_audio += audio[segment.start * 1000:segment.end * 1000]
-
-                    if len(speaker_audio) <= 0:
-                        continue
-
-                    samples = np.array(
-                        speaker_audio.get_array_of_samples(),
-                        dtype=np.float32
-                    ) / 32768.0
-
-                    samples_tensor = self._prepare_waveform_tensor(samples)
-                    try:
-                        embedding = self._run_embedding_inference(samples_tensor, audio.frame_rate)
-                        speaker_embeddings[speaker_id] = self._embedding_to_numpy(embedding)
-                    except Exception as exc:
-                        self.logger.warning(
-                            "Failed to extract embedding for %s: %s",
-                            speaker_id,
-                            exc
-                        )
+    def diarize(self, audio_path: Path) -> Tuple[List[SpeakerSegment], Dict[str, np.ndarray]]:
+        """
+        Perform speaker diarization on audio file.
+
+        This method orchestrates the complete diarization pipeline:
+        1. Load and initialize the diarization pipeline
+        2. Load audio file for processing
+        3. Execute speaker diarization
+        4. Extract speaker embeddings
+
+        Args:
+            audio_path: Path to WAV file
+
+        Returns:
+            A tuple containing:
+            - A list of SpeakerSegment objects
+            - A dictionary mapping speaker IDs to their embeddings
+        """
+        self._load_pipeline_if_needed()
+
+        if self.pipeline is None:
+            # Fallback: create dummy single-speaker segments
+            segments = self._create_fallback_diarization(audio_path)
+            return segments, {}
+
+        # Step 1: Load audio for diarization
+        diarization_input = self._load_audio_for_diarization(audio_path)
+
+        # Step 2: Perform diarization
+        diarization, segments = self._perform_diarization(diarization_input)
+
+        # Step 3: Extract speaker embeddings
+        speaker_embeddings = self._extract_speaker_embeddings(audio_path, diarization)
 
         return segments, speaker_embeddings
diff --git a/tests/test_diarizer.py b/tests/test_diarizer.py
index 148e808..d2fbef1 100644
--- a/tests/test_diarizer.py
+++ b/tests/test_diarizer.py
@@ -334,6 +334,263 @@ def test_create_unknown_backend_raises_error(self, monkeypatch):
         with pytest.raises(ValueError):
             DiarizerFactory.create()
 
+class TestExtractedMethods:
+    """Test suite for extracted diarization methods."""
+
+    @pytest.fixture
+    def diarizer(self):
+        """Provides a SpeakerDiarizer instance."""
+        return SpeakerDiarizer()
+
+    @pytest.fixture
+    def mock_audio_segment_setup(self):
+        """Provides common mock setup for pydub.AudioSegment."""
+        # Mock audio chunk
+        mock_audio_chunk = MagicMock()
+        mock_audio_chunk.frame_rate = 16000
+        mock_audio_chunk.__len__.return_value = 16000
+        mock_audio_chunk.get_array_of_samples.return_value = np.zeros(16000, dtype=np.int16)
+
+        # Mock full audio
+        mock_full_audio = MagicMock()
+        mock_full_audio.frame_rate = 16000
+        mock_full_audio.__getitem__ = lambda self, key: mock_audio_chunk
+
+        # Mock empty audio accumulation
+        accumulated_audio = MagicMock()
+        accumulated_audio.frame_rate = 16000
+        accumulated_audio.__len__.return_value = 16000
+        accumulated_audio.get_array_of_samples.return_value = np.zeros(16000, dtype=np.int16)
+        accumulated_audio.__add__ = lambda self, other: accumulated_audio
+        accumulated_audio.__iadd__ = lambda self, other: accumulated_audio
+
+        # Mock empty audio
+        mock_empty = MagicMock()
+        mock_empty.frame_rate = 16000
+        mock_empty.__len__.return_value = 0
+        mock_empty.__add__ = lambda self, other: accumulated_audio
+        mock_empty.__iadd__ = lambda self, other: accumulated_audio
+
+        return {
+            'audio_chunk': mock_audio_chunk,
+            'full_audio': mock_full_audio,
+            'accumulated_audio': accumulated_audio,
+            'empty': mock_empty
+        }
+
+    def test_load_audio_for_diarization_with_torchaudio(self, diarizer, tmp_path, monkeypatch):
+        """Test audio loading with torchaudio succeeds."""
+        # Create a dummy audio file
+        from scipy.io.wavfile import write
+        dummy_audio_path = tmp_path / "test.wav"
+        write(dummy_audio_path, 16000, np.zeros(16000, dtype=np.int16))
+
+        # Mock torchaudio to ensure it's used
+        mock_torchaudio = MagicMock()
+        mock_waveform = torch.zeros((1, 16000))
+        mock_torchaudio.load.return_value = (mock_waveform, 16000)
+
+        with patch.dict('sys.modules', {'torchaudio': mock_torchaudio}):
+            result = diarizer._load_audio_for_diarization(dummy_audio_path)
+
+        # Should return dict with waveform and sample_rate
+        assert isinstance(result, dict)
+        assert "waveform" in result
+        assert "sample_rate" in result
+        assert result["sample_rate"] == 16000
+
+    def test_load_audio_for_diarization_fallback(self, diarizer, tmp_path):
+        """Test audio loading falls back to file path when torchaudio fails."""
+        dummy_audio_path = tmp_path / "test.wav"
+        dummy_audio_path.touch()
+
+        # Mock torchaudio to raise an exception
+        with patch.dict('sys.modules', {'torchaudio': None}):
+            result = diarizer._load_audio_for_diarization(dummy_audio_path)
+
+        # Should return string path as fallback
+        assert isinstance(result, str)
+        assert result == str(dummy_audio_path)
+
+    def test_perform_diarization_returns_segments(self, diarizer):
+        """Test that _perform_diarization converts pipeline results to segments."""
+        # Mock pipeline result
+        mock_diarization = MagicMock()
+        mock_diarization.itertracks.return_value = [
+            (MagicMock(start=0.0, end=2.5), None, "SPEAKER_00"),
+            (MagicMock(start=2.5, end=5.0), None, "SPEAKER_01"),
+            (MagicMock(start=5.0, end=7.0), None, "SPEAKER_00"),
+        ]
+
+        # Mock the pipeline
+        diarizer.pipeline = MagicMock(return_value=mock_diarization)
+
+        # Execute
+        diarization_result, segments = diarizer._perform_diarization("dummy_input.wav")
+
+        # Verify results
+        assert diarization_result == mock_diarization
+        assert len(segments) == 3
+        assert segments[0].speaker_id == "SPEAKER_00"
+        assert segments[0].start_time == 0.0
+        assert segments[0].end_time == 2.5
+        assert segments[1].speaker_id == "SPEAKER_01"
+        assert segments[2].speaker_id == "SPEAKER_00"
+
+    def test_perform_diarization_with_dict_input(self, diarizer):
+        """Test that _perform_diarization works with dict input."""
+        mock_diarization = MagicMock()
+        mock_diarization.itertracks.return_value = [
+            (MagicMock(start=1.0, end=3.0), None, "SPEAKER_00"),
+        ]
+
+        diarizer.pipeline = MagicMock(return_value=mock_diarization)
+
+        dict_input = {
+            "waveform": torch.zeros((1, 16000)),
+            "sample_rate": 16000
+        }
+
+        diarization_result, segments = diarizer._perform_diarization(dict_input)
+
+        assert len(segments) == 1
+        diarizer.pipeline.assert_called_once_with(dict_input)
+
+    def test_extract_speaker_embeddings_no_model(self, diarizer, tmp_path):
+        """Test that embedding extraction returns empty dict when model is None."""
+        diarizer.embedding_model = None
+        mock_diarization = MagicMock()
+
+        dummy_audio_path = tmp_path / "test.wav"
+        dummy_audio_path.touch()
+
+        result = diarizer._extract_speaker_embeddings(dummy_audio_path, mock_diarization)
+
+        assert result == {}
+
+    @patch('pydub.AudioSegment')
+    def test_extract_speaker_embeddings_successful(self, MockAudioSegment, diarizer, tmp_path, mock_audio_segment_setup):
+        """Test successful embedding extraction for multiple speakers."""
+        # Setup
+        mock_embedding_model = MagicMock()
+        mock_embedding_model.return_value = torch.tensor([[0.1, 0.2, 0.3]])
+        diarizer.embedding_model = mock_embedding_model
+
+        # Use fixture for audio segment mocking
+        mocks = mock_audio_segment_setup
+        MockAudioSegment.from_wav.return_value = mocks['full_audio']
+        MockAudioSegment.empty.return_value = mocks['empty']
+
+        # Mock diarization result
+        mock_diarization = MagicMock()
+        mock_diarization.labels.return_value = ["SPEAKER_00", "SPEAKER_01"]
+
+        def mock_label_timeline(speaker_id):
+            return [MagicMock(start=0.0, end=2.0)]
+
+        mock_diarization.label_timeline = mock_label_timeline
+
+        dummy_audio_path = tmp_path / "test.wav"
+        dummy_audio_path.touch()
+
+        # Execute
+        embeddings = diarizer._extract_speaker_embeddings(dummy_audio_path, mock_diarization)
+
+        # Verify
+        assert "SPEAKER_00" in embeddings
+        assert "SPEAKER_01" in embeddings
+        assert isinstance(embeddings["SPEAKER_00"], np.ndarray)
+        assert isinstance(embeddings["SPEAKER_01"], np.ndarray)
+
+    @patch('pydub.AudioSegment')
+    def test_extract_speaker_embeddings_skips_empty_segments(self, MockAudioSegment, diarizer, tmp_path):
+        """Test that embedding extraction skips speakers with no audio."""
+        mock_embedding_model = MagicMock()
+        diarizer.embedding_model = mock_embedding_model
+
+        # Mock empty audio segment
+        mock_empty = MagicMock()
+        mock_empty.frame_rate = 16000
+        mock_empty.__len__.return_value = 0
+
+        mock_full_audio = MagicMock()
+        mock_full_audio.frame_rate = 16000
+
+        MockAudioSegment.from_wav.return_value = mock_full_audio
+        MockAudioSegment.empty.return_value = mock_empty
+
+        mock_diarization = MagicMock()
+        mock_diarization.labels.return_value = ["SPEAKER_00"]
+        mock_diarization.label_timeline.return_value = [MagicMock(start=0.0, end=0.0)]
+
+        dummy_audio_path = tmp_path / "test.wav"
+        dummy_audio_path.touch()
+
+        embeddings = diarizer._extract_speaker_embeddings(dummy_audio_path, mock_diarization)
+
+        # Should be empty since audio segment was empty
+        assert embeddings == {}
+        # Embedding model should not be called
+        mock_embedding_model.assert_not_called()
+
+    def test_extract_speaker_embeddings_handles_pydub_import_error(self, diarizer, tmp_path):
+        """Test that embedding extraction handles pydub import failure gracefully."""
+        diarizer.embedding_model = MagicMock()
+
+        with patch.dict('sys.modules', {'pydub': None}):
+            with patch('builtins.__import__', side_effect=ImportError("pydub not found")):
+                mock_diarization = MagicMock()
+                dummy_audio_path = tmp_path / "test.wav"
+
+                embeddings = diarizer._extract_speaker_embeddings(dummy_audio_path, mock_diarization)
+
+        assert embeddings == {}
+
+    @patch('pydub.AudioSegment')
+    def test_extract_speaker_embeddings_handles_audio_load_error(self, MockAudioSegment, diarizer, tmp_path):
+        """Test that embedding extraction handles audio loading errors."""
+        diarizer.embedding_model = MagicMock()
+        MockAudioSegment.from_wav.side_effect = Exception("Audio file corrupt")
+
+        mock_diarization = MagicMock()
+        dummy_audio_path = tmp_path / "test.wav"
+        dummy_audio_path.touch()
+
+        embeddings = diarizer._extract_speaker_embeddings(dummy_audio_path, mock_diarization)
+
+        assert embeddings == {}
+
+    @patch('pydub.AudioSegment')
+    def test_extract_speaker_embeddings_handles_inference_error(self, MockAudioSegment, diarizer, tmp_path, caplog, mock_audio_segment_setup):
+        """Test that embedding extraction continues when inference fails for one speaker."""
+        # Setup: first speaker succeeds, second fails
+        mock_embedding_model = MagicMock()
+        mock_embedding_model.side_effect = [
+            torch.tensor([[0.1, 0.2]]),  # Success for SPEAKER_00
+            RuntimeError("Inference failed"),  # Failure for SPEAKER_01
+        ]
+        diarizer.embedding_model = mock_embedding_model
+
+        # Use fixture for audio segment mocking
+        mocks = mock_audio_segment_setup
+        MockAudioSegment.from_wav.return_value = mocks['full_audio']
+        MockAudioSegment.empty.return_value = mocks['empty']
+
+        mock_diarization = MagicMock()
+        mock_diarization.labels.return_value = ["SPEAKER_00", "SPEAKER_01"]
+        mock_diarization.label_timeline.return_value = [MagicMock(start=0.0, end=1.0)]
+
+        dummy_audio_path = tmp_path / "test.wav"
+        dummy_audio_path.touch()
+
+        with caplog.at_level("WARNING"):
+            embeddings = diarizer._extract_speaker_embeddings(dummy_audio_path, mock_diarization)
+
+        # Should have SPEAKER_00 but not SPEAKER_01
+        assert "SPEAKER_00" in embeddings
+        assert "SPEAKER_01" not in embeddings
+        assert "Failed to extract embedding for SPEAKER_01" in caplog.text
+
 class TestHuggingFaceApiDiarizer:
     """Test the HuggingFaceApiDiarizer."""