Skip to content

Commit

Permalink
Fix recording transcriber (#286)
Browse files Browse the repository at this point in the history
  • Loading branch information
chidiwilliams committed Jan 2, 2023
1 parent 611e623 commit 380e975
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 59 deletions.
33 changes: 24 additions & 9 deletions buzz/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,11 +157,11 @@ def __init__(self, parent: Optional[QWidget]) -> None:
self.setDefault(True)
self.setSizePolicy(QSizePolicy(QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Fixed))

def set_to_record(self):
def set_stopped(self):
self.setText('Record')
self.setDefault(True)

def set_to_stop(self):
def set_recording(self):
self.setText('Stop')
self.setDefault(False)

Expand Down Expand Up @@ -529,11 +529,10 @@ def on_record_button_clicked(self):
if self.current_status == self.RecordingStatus.STOPPED:
self.start_recording()
self.current_status = self.RecordingStatus.RECORDING
self.record_button.set_to_stop()
self.record_button.set_recording()
else: # RecordingStatus.RECORDING
self.stop_recording()
self.record_button.set_to_record()
self.current_status = self.RecordingStatus.STOPPED
self.set_recording_status_stopped()

def start_recording(self):
self.record_button.setDisabled(True)
Expand Down Expand Up @@ -567,6 +566,10 @@ def start_recording(self):
self.transcriber.finished.connect(self.transcription_thread.quit)
self.transcriber.finished.connect(self.transcriber.deleteLater)

self.transcriber.error.connect(self.on_transcriber_error)
self.transcriber.error.connect(self.transcription_thread.quit)
self.transcriber.error.connect(self.transcriber.deleteLater)

self.transcription_thread.start()

def on_download_model_progress(self, progress: Tuple[float, float]):
Expand All @@ -580,11 +583,15 @@ def on_download_model_progress(self, progress: Tuple[float, float]):
if self.model_download_progress_dialog is not None:
self.model_download_progress_dialog.set_fraction_completed(fraction_completed=current_size / total_size)

def set_recording_status_stopped(self):
self.record_button.set_stopped()
self.current_status = self.RecordingStatus.STOPPED

def on_download_model_error(self, error: str):
self.reset_model_download()
show_model_download_error_dialog(self, error)
self.stop_recording()
self.record_button.set_to_stop()
self.set_recording_status_stopped()
self.record_button.setDisabled(False)

def on_next_transcription(self, text: str):
Expand All @@ -603,13 +610,18 @@ def stop_recording(self):
self.record_button.setDisabled(True)

def on_transcriber_finished(self):
self.record_button.setEnabled(True)
self.reset_record_button()

def on_transcriber_error(self, error: str):
self.reset_record_button()
self.set_recording_status_stopped()
QMessageBox.critical(self, '', f'An error occurred while starting a new recording: {error}. Please check your audio devices or check the application logs for more information.')

def on_cancel_model_progress_dialog(self):
if self.model_loader is not None:
self.model_loader.stop()
self.reset_model_download()
self.record_button.set_to_stop()
self.set_recording_status_stopped()
self.record_button.setDisabled(False)

def reset_model_download(self):
Expand All @@ -620,11 +632,14 @@ def reset_model_download(self):
def reset_recording_controls(self):
# Clear text box placeholder because the first chunk takes a while to process
self.text_box.setPlaceholderText('')
self.record_button.setDisabled(False)
self.reset_record_button()
if self.model_download_progress_dialog is not None:
self.model_download_progress_dialog.close()
self.model_download_progress_dialog = None

def reset_record_button(self):
self.record_button.setEnabled(True)

def on_recording_amplitude_changed(self, amplitude: float):
self.audio_meter_widget.update_amplitude(amplitude)

Expand Down
15 changes: 10 additions & 5 deletions buzz/recording.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Optional

import logging
import numpy as np
import sounddevice
from PyQt6.QtCore import QObject, pyqtSignal
Expand All @@ -16,13 +17,17 @@ def __init__(self, input_device_index: Optional[int] = None,
self.input_device_index = input_device_index

def start_recording(self):
self.stream = sounddevice.InputStream(device=self.input_device_index, dtype='float32',
channels=1, callback=self.stream_callback)
self.stream.start()
try:
self.stream = sounddevice.InputStream(device=self.input_device_index, dtype='float32',
channels=1, callback=self.stream_callback)
self.stream.start()
except sounddevice.PortAudioError:
logging.exception('')

def stop_recording(self):
self.stream.stop()
self.stream.close()
if self.stream is not None:
self.stream.stop()
self.stream.close()

def stream_callback(self, in_data: np.ndarray, frame_count, time_info, status):
chunk = in_data.ravel()
Expand Down
96 changes: 51 additions & 45 deletions buzz/transcriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ class Status(enum.Enum):
class RecordingTranscriber(QObject):
transcription = pyqtSignal(str)
finished = pyqtSignal()
error = pyqtSignal(str)
is_running = False
MAX_QUEUE_SIZE = 10

Expand Down Expand Up @@ -123,52 +124,57 @@ def start(self, model_path: str):
self.transcription_options, model_path, self.sample_rate, self.input_device_index)

self.is_running = True
with sounddevice.InputStream(samplerate=self.sample_rate,
device=self.input_device_index, dtype="float32",
channels=1, callback=self.stream_callback):
while self.is_running:
self.mutex.acquire()
if self.queue.size >= self.n_batch_samples:
samples = self.queue[:self.n_batch_samples]
self.queue = self.queue[self.n_batch_samples:]
self.mutex.release()

logging.debug('Processing next frame, sample size = %s, queue size = %s, amplitude = %s',
samples.size, self.queue.size, self.amplitude(samples))
time_started = datetime.datetime.now()

if self.transcription_options.model.model_type == ModelType.WHISPER:
assert isinstance(model, whisper.Whisper)
result = model.transcribe(
audio=samples, language=self.transcription_options.language,
task=self.transcription_options.task.value,
initial_prompt=initial_prompt,
temperature=self.transcription_options.temperature)
elif self.transcription_options.model.model_type == ModelType.WHISPER_CPP:
assert isinstance(model, WhisperCpp)
result = model.transcribe(
audio=samples,
params=whisper_cpp_params(
language=self.transcription_options.language
if self.transcription_options.language is not None else 'en',
task=self.transcription_options.task.value, word_level_timings=False))
try:
with sounddevice.InputStream(samplerate=self.sample_rate,
device=self.input_device_index, dtype="float32",
channels=1, callback=self.stream_callback):
while self.is_running:
self.mutex.acquire()
if self.queue.size >= self.n_batch_samples:
samples = self.queue[:self.n_batch_samples]
self.queue = self.queue[self.n_batch_samples:]
self.mutex.release()

logging.debug('Processing next frame, sample size = %s, queue size = %s, amplitude = %s',
samples.size, self.queue.size, self.amplitude(samples))
time_started = datetime.datetime.now()

if self.transcription_options.model.model_type == ModelType.WHISPER:
assert isinstance(model, whisper.Whisper)
result = model.transcribe(
audio=samples, language=self.transcription_options.language,
task=self.transcription_options.task.value,
initial_prompt=initial_prompt,
temperature=self.transcription_options.temperature)
elif self.transcription_options.model.model_type == ModelType.WHISPER_CPP:
assert isinstance(model, WhisperCpp)
result = model.transcribe(
audio=samples,
params=whisper_cpp_params(
language=self.transcription_options.language
if self.transcription_options.language is not None else 'en',
task=self.transcription_options.task.value, word_level_timings=False))
else:
assert isinstance(model, TransformersWhisper)
result = model.transcribe(audio=samples,
language=self.transcription_options.language
if self.transcription_options.language is not None else 'en',
task=self.transcription_options.task.value)

next_text: str = result.get('text')

# Update initial prompt between successive recording chunks
initial_prompt += next_text

logging.debug('Received next result, length = %s, time taken = %s',
len(next_text), datetime.datetime.now() - time_started)
self.transcription.emit(next_text)
else:
assert isinstance(model, TransformersWhisper)
result = model.transcribe(audio=samples,
language=self.transcription_options.language
if self.transcription_options.language is not None else 'en',
task=self.transcription_options.task.value)

next_text: str = result.get('text')

# Update initial prompt between successive recording chunks
initial_prompt += next_text

logging.debug('Received next result, length = %s, time taken = %s',
len(next_text), datetime.datetime.now() - time_started)
self.transcription.emit(next_text)
else:
self.mutex.release()
self.mutex.release()
except PortAudioError as exc:
self.error.emit(str(exc))
logging.exception('')
return

self.finished.emit()

Expand Down

0 comments on commit 380e975

Please sign in to comment.