Skip to content

Commit

Permalink
Add support for OpenAI Whisper API (#371)
Browse files Browse the repository at this point in the history
  • Loading branch information
chidiwilliams committed Mar 9, 2023
1 parent 3a59172 commit 62ed04d
Show file tree
Hide file tree
Showing 7 changed files with 1,259 additions and 747 deletions.
4 changes: 3 additions & 1 deletion README.md
Expand Up @@ -19,7 +19,9 @@ OpenAI's [Whisper](https://github.com/openai/whisper).
VTT ([Demo](https://www.loom.com/share/cf263b099ac3481082bb56d19b7c87fe))
- Supports [Whisper](https://github.com/openai/whisper#available-models-and-languages),
[Whisper.cpp](https://github.com/ggerganov/whisper.cpp),
and [Whisper-compatible Hugging Face models](https://huggingface.co/models?other=whisper)
[Whisper-compatible Hugging Face models](https://huggingface.co/models?other=whisper), and
the [OpenAI Whisper API](https://platform.openai.com/docs/api-reference/introduction)
- Available on Mac, Windows, and Linux

## Installation

Expand Down
64 changes: 46 additions & 18 deletions buzz/gui.py
Expand Up @@ -4,11 +4,10 @@
import logging
import os
import platform
import random
import sys
from datetime import datetime
from enum import auto
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple

import humanize
import sounddevice
Expand All @@ -33,7 +32,6 @@
from .recording import RecordingAmplitudeListener
from .transcriber import (SUPPORTED_OUTPUT_FORMATS, FileTranscriptionOptions, OutputFormat,
Task,
WhisperCppFileTranscriber, WhisperFileTranscriber,
get_default_output_file_path, segments_to_text, write_output, TranscriptionOptions,
FileTranscriberQueueWorker, FileTranscriptionTask, RecordingTranscriber, LOADED_WHISPER_DLL)

Expand Down Expand Up @@ -217,24 +215,23 @@ def show_model_download_error_dialog(parent: QWidget, error: str):

class FileTranscriberWidget(QWidget):
model_download_progress_dialog: Optional[DownloadModelProgressDialog] = None
file_transcriber: Optional[Union[WhisperFileTranscriber,
WhisperCppFileTranscriber]] = None
model_loader: Optional[ModelLoader] = None
transcriber_thread: Optional[QThread] = None
file_transcription_options: FileTranscriptionOptions
transcription_options: TranscriptionOptions
is_transcribing = False
# (TranscriptionOptions, FileTranscriptionOptions, str)
triggered = pyqtSignal(tuple)
openai_access_token_changed = pyqtSignal(str)

def __init__(self, file_paths: List[str], parent: Optional[QWidget] = None,
flags: Qt.WindowType = Qt.WindowType.Widget) -> None:
def __init__(self, file_paths: List[str], openai_access_token: Optional[str] = None,
parent: Optional[QWidget] = None, flags: Qt.WindowType = Qt.WindowType.Widget) -> None:
super().__init__(parent, flags)

self.setWindowTitle(file_paths_as_title(file_paths))

self.file_paths = file_paths
self.transcription_options = TranscriptionOptions()
self.transcription_options = TranscriptionOptions(openai_access_token=openai_access_token)
self.file_transcription_options = FileTranscriptionOptions(
file_paths=self.file_paths)

Expand Down Expand Up @@ -266,7 +263,9 @@ def __init__(self, file_paths: List[str], parent: Optional[QWidget] = None,
def on_transcription_options_changed(self, transcription_options: TranscriptionOptions):
    """Store the updated options and sync dependent UI state.

    Word-level timings are unavailable for the Hugging Face and OpenAI
    Whisper API model types, so the checkbox is disabled for those. Any
    entered OpenAI access token is re-broadcast so the owning window can
    remember it for later transcription windows.
    """
    self.transcription_options = transcription_options
    model_type = transcription_options.model.model_type
    timings_unsupported = model_type in (ModelType.HUGGING_FACE, ModelType.OPEN_AI_WHISPER_API)
    self.word_level_timings_checkbox.setDisabled(timings_unsupported)
    token = transcription_options.openai_access_token
    if token is not None:
        self.openai_access_token_changed.emit(token)

def on_click_run(self):
self.run_button.setDisabled(True)
Expand Down Expand Up @@ -503,7 +502,10 @@ def __init__(self, parent: Optional[QWidget] = None, flags: Optional[Qt.WindowTy
self.text_box.setPlaceholderText(_('Click Record to begin...'))

transcription_options_group_box = TranscriptionOptionsGroupBox(
default_transcription_options=self.transcription_options, parent=self)
default_transcription_options=self.transcription_options,
# Live transcription with OpenAI Whisper API not implemented
model_types=[model_type for model_type in ModelType if model_type is not ModelType.OPEN_AI_WHISPER_API],
parent=self)
transcription_options_group_box.transcription_options_changed.connect(
self.on_transcription_options_changed)

Expand Down Expand Up @@ -820,7 +822,7 @@ def upsert_task(self, task: FileTranscriptionTask):
elif task.status == FileTranscriptionTask.Status.COMPLETED:
status_widget.setText(_('Completed'))
elif task.status == FileTranscriptionTask.Status.FAILED:
status_widget.setText(_('Failed'))
status_widget.setText(f'{_("Failed")} ({task.error})')
elif task.status == FileTranscriptionTask.Status.CANCELED:
status_widget.setText(_('Canceled'))

Expand Down Expand Up @@ -925,6 +927,7 @@ class MainWindow(QMainWindow):
table_widget: TranscriptionTasksTableWidget
tasks: Dict[int, 'FileTranscriptionTask']
tasks_changed = pyqtSignal()
openai_access_token: Optional[str] = None

def __init__(self, tasks_cache=TasksCache()):
super().__init__(flags=Qt.WindowType.Window)
Expand Down Expand Up @@ -1026,11 +1029,17 @@ def on_new_transcription_action_triggered(self):
return

file_transcriber_window = FileTranscriberWidget(
file_paths, self, flags=Qt.WindowType.Window)
file_paths, self.openai_access_token, self, flags=Qt.WindowType.Window)
file_transcriber_window.triggered.connect(
self.on_file_transcriber_triggered)
file_transcriber_window.openai_access_token_changed.connect(self.on_openai_access_token_changed)
file_transcriber_window.show()

# Cache the token on the main window so the user doesn't have to re-enter it
# for each new transcription window while the app stays open.
def on_openai_access_token_changed(self, access_token: str):
    """Remember the OpenAI access token entered in a transcriber widget."""
    self.openai_access_token = access_token

def on_open_transcript_action_triggered(self):
selected_rows = self.table_widget.selectionModel().selectedRows()
for selected_row in selected_rows:
Expand Down Expand Up @@ -1092,6 +1101,7 @@ def on_tasks_changed(self):
self.toolbar.set_open_transcript_action_enabled(self.should_enable_open_transcript_action())
self.toolbar.set_stop_transcription_action_enabled(self.should_enable_stop_transcription_action())
self.toolbar.set_clear_history_action_enabled(self.should_enable_clear_history_action())
self.save_tasks_to_cache()

def closeEvent(self, event: QtGui.QCloseEvent) -> None:
self.transcriber_worker.stop()
Expand Down Expand Up @@ -1236,6 +1246,7 @@ class TranscriptionOptionsGroupBox(QGroupBox):
transcription_options_changed = pyqtSignal(TranscriptionOptions)

def __init__(self, default_transcription_options: TranscriptionOptions = TranscriptionOptions(),
model_types: Optional[List[ModelType]] = None,
parent: Optional[QWidget] = None):
super().__init__(title='', parent=parent)
self.transcription_options = default_transcription_options
Expand All @@ -1261,7 +1272,9 @@ def __init__(self, default_transcription_options: TranscriptionOptions = Transcr
self.hugging_face_search_line_edit.model_selected.connect(self.on_hugging_face_model_changed)

self.model_type_combo_box = QComboBox(self)
for model_type in ModelType:
if model_types is None:
model_types = [model_type for model_type in ModelType]
for model_type in model_types:
# Hide the Whisper.cpp option if whisper.dll did not load correctly.
# See: https://github.com/chidiwilliams/buzz/issues/274, https://github.com/chidiwilliams/buzz/issues/197
if model_type == ModelType.WHISPER_CPP and LOADED_WHISPER_DLL is False:
Expand All @@ -1277,18 +1290,28 @@ def __init__(self, default_transcription_options: TranscriptionOptions = Transcr
default_transcription_options.model.whisper_model_size.value.title())
self.whisper_model_size_combo_box.currentTextChanged.connect(self.on_whisper_model_size_changed)

self.form_layout.addRow(_('Task:'), self.tasks_combo_box)
self.form_layout.addRow(_('Language:'), self.languages_combo_box)
self.openai_access_token_edit = QLineEdit(self)
self.openai_access_token_edit.setText(default_transcription_options.openai_access_token)
self.openai_access_token_edit.setEchoMode(QLineEdit.EchoMode.Password)
self.openai_access_token_edit.textChanged.connect(self.on_openai_access_token_edit_changed)

self.form_layout.addRow(_('Model:'), self.model_type_combo_box)
self.form_layout.addRow('', self.whisper_model_size_combo_box)
self.form_layout.addRow('', self.hugging_face_search_line_edit)
self.form_layout.addRow('Access Token:', self.openai_access_token_edit)
self.form_layout.addRow(_('Task:'), self.tasks_combo_box)
self.form_layout.addRow(_('Language:'), self.languages_combo_box)

self.form_layout.setRowVisible(self.hugging_face_search_line_edit, False)
self.reset_visible_rows()

self.form_layout.addRow('', self.advanced_settings_button)

self.setLayout(self.form_layout)

def on_openai_access_token_edit_changed(self, access_token: str):
    """Propagate edits to the access-token field into the shared options."""
    options = self.transcription_options
    options.openai_access_token = access_token
    self.transcription_options_changed.emit(options)

def on_language_changed(self, language: str):
    """Record the newly selected language and notify listeners."""
    options = self.transcription_options
    options.language = language
    self.transcription_options_changed.emit(options)
Expand Down Expand Up @@ -1316,12 +1339,17 @@ def on_transcription_options_changed(self, transcription_options: TranscriptionO
self.transcription_options = transcription_options
self.transcription_options_changed.emit(transcription_options)

def on_model_type_changed(self, text: str):
model_type = ModelType(text)
def reset_visible_rows(self):
    """Show only the option rows relevant to the selected model type."""
    model_type = self.transcription_options.model.model_type
    show_hugging_face_search = model_type == ModelType.HUGGING_FACE
    show_model_size = model_type in (ModelType.WHISPER, ModelType.WHISPER_CPP)
    show_access_token = model_type == ModelType.OPEN_AI_WHISPER_API
    self.form_layout.setRowVisible(self.hugging_face_search_line_edit, show_hugging_face_search)
    self.form_layout.setRowVisible(self.whisper_model_size_combo_box, show_model_size)
    self.form_layout.setRowVisible(self.openai_access_token_edit, show_access_token)

def on_model_type_changed(self, text: str):
    """Handle a model-type selection from the combo box.

    Updates the options, re-computes which form rows are visible for the
    chosen backend, and notifies listeners.
    """
    self.transcription_options.model.model_type = ModelType(text)
    self.reset_visible_rows()
    self.transcription_options_changed.emit(self.transcription_options)

def on_whisper_model_size_changed(self, text: str):
Expand Down
9 changes: 8 additions & 1 deletion buzz/model_loader.py
Expand Up @@ -26,6 +26,7 @@ class ModelType(enum.Enum):
WHISPER = 'Whisper'
WHISPER_CPP = 'Whisper.cpp'
HUGGING_FACE = 'Hugging Face'
OPEN_AI_WHISPER_API = 'OpenAI Whisper API'


@dataclass()
Expand Down Expand Up @@ -82,7 +83,7 @@ def run(self):
expected_sha256 = url.split('/')[-2]
self.download_model(url, file_path, expected_sha256)

else: # ModelType.HUGGING_FACE:
elif self.model_type == ModelType.HUGGING_FACE:
self.progress.emit((0, 100))

try:
Expand All @@ -95,6 +96,12 @@ def run(self):
self.progress.emit((100, 100))
file_path = self.hugging_face_model_id

elif self.model_type == ModelType.OPEN_AI_WHISPER_API:
file_path = ""

else:
raise Exception("Invalid model type: " + self.model_type.value)

self.finished.emit(file_path)

def download_model(self, url: str, file_path: str, expected_sha256: Optional[str]):
Expand Down
84 changes: 72 additions & 12 deletions buzz/transcriber.py
Expand Up @@ -12,11 +12,13 @@
import sys
import tempfile
import threading
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from multiprocessing.connection import Connection
from random import randint
from threading import Thread
from typing import Any, List, Optional, Tuple, Union
import openai

import ffmpeg
import numpy as np
Expand Down Expand Up @@ -61,10 +63,11 @@ class Segment:
class TranscriptionOptions:
    """User-selected options controlling how audio is transcribed."""
    # Spoken language of the audio; None presumably lets the model
    # auto-detect — TODO confirm against the transcriber backends.
    language: Optional[str] = None
    # Transcribe in the source language, or translate (per Whisper's tasks).
    task: Task = Task.TRANSCRIBE
    # Model backend/size selection. default_factory avoids sharing one
    # mutable TranscriptionModel instance across all dataclass instances.
    model: TranscriptionModel = field(default_factory=TranscriptionModel)
    # Emit per-word rather than per-segment timestamps.
    word_level_timings: bool = False
    # Sampling temperature fallback sequence passed to Whisper.
    temperature: Tuple[float, ...] = DEFAULT_WHISPER_TEMPERATURE
    # Optional prompt used to prime the transcription.
    initial_prompt: str = ''
    # API key; only meaningful when model type is OPEN_AI_WHISPER_API.
    openai_access_token: Optional[str] = None


@dataclass()
Expand Down Expand Up @@ -219,17 +222,34 @@ class OutputFormat(enum.Enum):
VTT = 'vtt'


class WhisperCppFileTranscriber(QObject):
class FileTranscriber(QObject):
    """Abstract base for file transcribers.

    Subclasses implement run()/stop() for a specific backend and report
    results to listeners through the Qt signals declared below.
    """
    # The task (file path + options) this transcriber was created for.
    transcription_task: FileTranscriptionTask
    progress = pyqtSignal(tuple)  # (current, total)
    completed = pyqtSignal(list)  # List[Segment]
    error = pyqtSignal(str)  # human-readable failure description

    def __init__(self, task: FileTranscriptionTask,
                 parent: Optional['QObject'] = None):
        """Store the task to transcribe; `parent` follows Qt object ownership."""
        super().__init__(parent)
        self.transcription_task = task

    # NOTE(review): QObject's metaclass is not ABCMeta, so @abstractmethod is
    # documentation only here — instantiation is not actually prevented.
    # Subclasses must still override both methods.
    @abstractmethod
    def run(self):
        """Perform the transcription, emitting progress/completed/error."""
        ...

    @abstractmethod
    def stop(self):
        """Request cancellation of a transcription in progress."""
        ...


class WhisperCppFileTranscriber(FileTranscriber):
duration_audio_ms = sys.maxsize # max int
segments: List[Segment]
running = False

def __init__(self, task: FileTranscriptionTask,
parent: Optional['QObject'] = None) -> None:
super().__init__(parent)
super().__init__(task, parent)

self.file_path = task.file_path
self.language = task.transcription_options.language
Expand Down Expand Up @@ -332,22 +352,60 @@ def read_std_err(self):
pass


class WhisperFileTranscriber(QObject):
class OpenAIWhisperAPIFileTranscriber(FileTranscriber):
    """Transcribes an audio file by uploading it to the OpenAI Whisper API."""

    def __init__(self, task: FileTranscriptionTask, parent: Optional['QObject'] = None):
        """Capture the file path and task (transcribe vs. translate) to run."""
        super().__init__(task=task, parent=parent)
        self.file_path = task.file_path
        self.task = task.transcription_options.task

    @pyqtSlot()
    def run(self):
        """Convert the input to WAV, call the API, and emit the result.

        Emits `completed` with a list of Segment (times in milliseconds) on
        success, or `error` with the exception message on any failure.
        """
        try:
            logging.debug('Starting OpenAI Whisper API file transcription, file path = %s, task = %s', self.file_path,
                          self.task)

            # NOTE(review): tempfile.mktemp is race-prone (deprecated); prefer
            # mkstemp/NamedTemporaryFile. The temp WAV is also never deleted.
            wav_file = tempfile.mktemp() + '.wav'
            # Re-encode to 16-bit mono PCM at Whisper's expected sample rate.
            (
                ffmpeg.input(self.file_path)
                .output(wav_file, acodec="pcm_s16le", ac=1, ar=whisper.audio.SAMPLE_RATE)
                .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
            )

            # TODO: Check if file size is more than the API's 25MB upload limit, then chunk
            options = self.transcription_task.transcription_options
            openai.api_key = options.openai_access_token
            language = options.language
            response_format = "verbose_json"
            # Use a context manager so the upload handle is always closed
            # (the previous code leaked the open file object).
            with open(wav_file, "rb") as audio_file:
                if options.task == Task.TRANSLATE:
                    transcript = openai.Audio.translate("whisper-1", audio_file, response_format=response_format,
                                                        language=language)
                else:
                    transcript = openai.Audio.transcribe("whisper-1", audio_file, response_format=response_format,
                                                         language=language)

            # API segment times are in seconds; Segment stores milliseconds.
            segments = [Segment(segment["start"] * 1000, segment["end"] * 1000, segment["text"]) for segment in
                        transcript["segments"]]
            self.completed.emit(segments)
        except Exception as exc:
            self.error.emit(str(exc))
            logging.exception('')

    def stop(self):
        """No-op: the blocking API request cannot currently be cancelled."""
        pass


class WhisperFileTranscriber(FileTranscriber):
"""WhisperFileTranscriber transcribes an audio file to text, writes the text to a file, and then opens the file
using the default program for opening txt files. """

current_process: multiprocessing.Process
progress = pyqtSignal(tuple) # (current, total)
completed = pyqtSignal(list) # List[Segment]
error = pyqtSignal(str)
running = False
read_line_thread: Optional[Thread] = None
READ_LINE_THREAD_STOP_TOKEN = '--STOP--'

def __init__(self, task: FileTranscriptionTask,
parent: Optional['QObject'] = None) -> None:
super().__init__(parent)
self.transcription_task = task
super().__init__(task, parent)
self.segments = []
self.started_process = False
self.stopped = False
Expand Down Expand Up @@ -570,8 +628,7 @@ def __del__(self):
class FileTranscriberQueueWorker(QObject):
tasks_queue: multiprocessing.Queue
current_task: Optional[FileTranscriptionTask] = None
current_transcriber: Optional[WhisperFileTranscriber |
WhisperCppFileTranscriber] = None
current_transcriber: Optional[FileTranscriber] = None
current_transcriber_thread: Optional[QThread] = None
task_updated = pyqtSignal(FileTranscriptionTask)
completed = pyqtSignal()
Expand Down Expand Up @@ -605,9 +662,12 @@ def run(self):

logging.debug('Starting next transcription task')

if self.current_task.transcription_options.model.model_type == ModelType.WHISPER_CPP:
model_type = self.current_task.transcription_options.model.model_type
if model_type == ModelType.WHISPER_CPP:
self.current_transcriber = WhisperCppFileTranscriber(
task=self.current_task)
elif model_type == ModelType.OPEN_AI_WHISPER_API:
self.current_transcriber = OpenAIWhisperAPIFileTranscriber(task=self.current_task)
else:
self.current_transcriber = WhisperFileTranscriber(
task=self.current_task)
Expand Down

0 comments on commit 62ed04d

Please sign in to comment.