Skip to content

Commit

Permalink
Add segments debug logging (#294)
Browse files Browse the repository at this point in the history
  • Loading branch information
chidiwilliams committed Jan 3, 2023
1 parent 40b0236 commit 614b395
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 19 deletions.
29 changes: 12 additions & 17 deletions buzz/transcriber.py
Expand Up @@ -126,8 +126,8 @@ def start(self, model_path: str):
self.is_running = True
try:
with sounddevice.InputStream(samplerate=self.sample_rate,
device=self.input_device_index, dtype="float32",
channels=1, callback=self.stream_callback):
device=self.input_device_index, dtype="float32",
channels=1, callback=self.stream_callback):
while self.is_running:
self.mutex.acquire()
if self.queue.size >= self.n_batch_samples:
Expand All @@ -136,7 +136,7 @@ def start(self, model_path: str):
self.mutex.release()

logging.debug('Processing next frame, sample size = %s, queue size = %s, amplitude = %s',
samples.size, self.queue.size, self.amplitude(samples))
samples.size, self.queue.size, self.amplitude(samples))
time_started = datetime.datetime.now()

if self.transcription_options.model.model_type == ModelType.WHISPER:
Expand All @@ -157,17 +157,17 @@ def start(self, model_path: str):
else:
assert isinstance(model, TransformersWhisper)
result = model.transcribe(audio=samples,
language=self.transcription_options.language
if self.transcription_options.language is not None else 'en',
task=self.transcription_options.task.value)
language=self.transcription_options.language
if self.transcription_options.language is not None else 'en',
task=self.transcription_options.task.value)

next_text: str = result.get('text')

# Update initial prompt between successive recording chunks
initial_prompt += next_text

logging.debug('Received next result, length = %s, time taken = %s',
len(next_text), datetime.datetime.now() - time_started)
len(next_text), datetime.datetime.now() - time_started)
self.transcription.emit(next_text)
else:
self.mutex.release()
Expand Down Expand Up @@ -367,19 +367,16 @@ def run(self):

self.current_process.join()

logging.debug(
'whisper process completed with code = %s, time taken = %s',
self.current_process.exitcode, datetime.datetime.now() - time_started)

if self.current_process.exitcode != 0:
send_pipe.close()

self.read_line_thread.join()

# TODO: fix error handling when process crashes
if self.current_process.exitcode != 0 and self.current_process.exitcode is not None:
self.completed.emit([])
logging.debug(
'whisper process completed with code = %s, time taken = %s, number of segments = %s',
self.current_process.exitcode, datetime.datetime.now() - time_started, len(self.segments))

self.completed.emit(self.segments)
self.running = False

def stop(self):
Expand All @@ -403,9 +400,7 @@ def read_line(self, pipe: Connection):
end=segment.get('end'),
text=segment.get('text'),
) for segment in segments_dict]
self.current_process.join()
# TODO: move this back to the parent thread
self.completed.emit(segments)
self.segments = segments
else:
try:
progress = int(line.split('|')[0].strip().strip('%'))
Expand Down
7 changes: 5 additions & 2 deletions tests/transcriber_test.py
@@ -1,3 +1,4 @@
import logging
import os
import pathlib
import platform
Expand All @@ -7,7 +8,7 @@
from unittest.mock import Mock, patch

import pytest
from PyQt6.QtCore import QThread
from PyQt6.QtCore import QThread, QCoreApplication
from pytestqt.qtbot import QtBot

from buzz.model_loader import WhisperModelSize, ModelType, TranscriptionModel, ModelLoader
Expand Down Expand Up @@ -135,7 +136,8 @@ def test_transcribe(self, qtbot: QtBot, word_level_timings: bool, expected_segme
file_path='testdata/whisper-french.mp3', model_path=model_path))
transcriber.progress.connect(mock_progress)
transcriber.completed.connect(mock_completed)
with qtbot.wait_signal(transcriber.completed, timeout=10 * 6000):
with qtbot.wait_signal(transcriber.progress, timeout=10 * 6000), qtbot.wait_signal(transcriber.completed,
timeout=10 * 6000):
transcriber.run()

if check_progress:
Expand All @@ -150,6 +152,7 @@ def test_transcribe(self, qtbot: QtBot, word_level_timings: bool, expected_segme

mock_completed.assert_called()
segments = mock_completed.call_args[0][0]
assert len(segments) >= len(expected_segments)
for (i, expected_segment) in enumerate(expected_segments):
assert segments[i] == expected_segment

Expand Down

0 comments on commit 614b395

Please sign in to comment.