diff --git a/buzz/transcriber.py b/buzz/transcriber.py index 8f75f6f91..7c42912eb 100644 --- a/buzz/transcriber.py +++ b/buzz/transcriber.py @@ -126,8 +126,8 @@ def start(self, model_path: str): self.is_running = True try: with sounddevice.InputStream(samplerate=self.sample_rate, - device=self.input_device_index, dtype="float32", - channels=1, callback=self.stream_callback): + device=self.input_device_index, dtype="float32", + channels=1, callback=self.stream_callback): while self.is_running: self.mutex.acquire() if self.queue.size >= self.n_batch_samples: @@ -136,7 +136,7 @@ def start(self, model_path: str): self.mutex.release() logging.debug('Processing next frame, sample size = %s, queue size = %s, amplitude = %s', - samples.size, self.queue.size, self.amplitude(samples)) + samples.size, self.queue.size, self.amplitude(samples)) time_started = datetime.datetime.now() if self.transcription_options.model.model_type == ModelType.WHISPER: @@ -157,9 +157,9 @@ def start(self, model_path: str): else: assert isinstance(model, TransformersWhisper) result = model.transcribe(audio=samples, - language=self.transcription_options.language - if self.transcription_options.language is not None else 'en', - task=self.transcription_options.task.value) + language=self.transcription_options.language + if self.transcription_options.language is not None else 'en', + task=self.transcription_options.task.value) next_text: str = result.get('text') @@ -167,7 +167,7 @@ def start(self, model_path: str): initial_prompt += next_text logging.debug('Received next result, length = %s, time taken = %s', - len(next_text), datetime.datetime.now() - time_started) + len(next_text), datetime.datetime.now() - time_started) self.transcription.emit(next_text) else: self.mutex.release() @@ -367,19 +367,16 @@ def run(self): self.current_process.join() - logging.debug( - 'whisper process completed with code = %s, time taken = %s', - self.current_process.exitcode, datetime.datetime.now() - time_started) - if self.current_process.exitcode != 0: send_pipe.close() self.read_line_thread.join() - # TODO: fix error handling when process crashes - if self.current_process.exitcode != 0 and self.current_process.exitcode is not None: - self.completed.emit([]) + logging.debug( + 'whisper process completed with code = %s, time taken = %s, number of segments = %s', + self.current_process.exitcode, datetime.datetime.now() - time_started, len(self.segments)) + self.completed.emit(self.segments) self.running = False def stop(self): @@ -403,9 +400,7 @@ def read_line(self, pipe: Connection): end=segment.get('end'), text=segment.get('text'), ) for segment in segments_dict] - self.current_process.join() - # TODO: move this back to the parent thread - self.completed.emit(segments) + self.segments = segments else: try: progress = int(line.split('|')[0].strip().strip('%')) diff --git a/tests/transcriber_test.py b/tests/transcriber_test.py index b4c49f260..21f29b3bb 100644 --- a/tests/transcriber_test.py +++ b/tests/transcriber_test.py @@ -1,3 +1,4 @@ +import logging import os import pathlib import platform @@ -7,7 +8,7 @@ from unittest.mock import Mock, patch import pytest -from PyQt6.QtCore import QThread +from PyQt6.QtCore import QThread, QCoreApplication from pytestqt.qtbot import QtBot from buzz.model_loader import WhisperModelSize, ModelType, TranscriptionModel, ModelLoader @@ -135,7 +136,8 @@ def test_transcribe(self, qtbot: QtBot, word_level_timings: bool, expected_segme file_path='testdata/whisper-french.mp3', model_path=model_path)) transcriber.progress.connect(mock_progress) transcriber.completed.connect(mock_completed) - with qtbot.wait_signal(transcriber.completed, timeout=10 * 6000): + with qtbot.wait_signal(transcriber.progress, timeout=10 * 6000), qtbot.wait_signal(transcriber.completed, + timeout=10 * 6000): transcriber.run() if check_progress: @@ -150,6 +152,7 @@ def test_transcribe(self, qtbot: QtBot, word_level_timings: bool, expected_segme mock_completed.assert_called() segments = mock_completed.call_args[0][0] + assert len(segments) >= len(expected_segments) for (i, expected_segment) in enumerate(expected_segments): assert segments[i] == expected_segment