Skip to content

Commit

Permalink
Move transcriptions to individual cache files (#519)
Browse files Browse the repository at this point in the history
  • Loading branch information
chidiwilliams committed Jul 4, 2023
1 parent 2042947 commit f83d2d6
Show file tree
Hide file tree
Showing 7 changed files with 149 additions and 20 deletions.
50 changes: 42 additions & 8 deletions buzz/cache.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import logging
import json
import os
import pickle
from typing import List
Expand All @@ -11,22 +11,56 @@
class TasksCache:
def __init__(self, cache_dir=user_cache_dir('Buzz')):
os.makedirs(cache_dir, exist_ok=True)
self.file_path = os.path.join(cache_dir, 'tasks')
self.cache_dir = cache_dir
self.pickle_cache_file_path = os.path.join(cache_dir, 'tasks')
self.tasks_list_file_path = os.path.join(cache_dir, 'tasks.json')

def save(self, tasks: List[FileTranscriptionTask]):
with open(self.file_path, 'wb') as file:
pickle.dump(tasks, file)
self.save_json_tasks(tasks=tasks)

def load(self) -> List[FileTranscriptionTask]:
if os.path.exists(self.tasks_list_file_path):
return self.load_json_tasks()

try:
with open(self.file_path, 'rb') as file:
with open(self.pickle_cache_file_path, 'rb') as file:
return pickle.load(file)
except FileNotFoundError:
return []
except (pickle.UnpicklingError, AttributeError, ValueError): # delete corrupted cache
os.remove(self.file_path)
os.remove(self.pickle_cache_file_path)
return []

def load_json_tasks(self) -> List[FileTranscriptionTask]:
with open(self.tasks_list_file_path, 'r') as file:
task_ids = json.load(file)

tasks = []
for task_id in task_ids:
try:
with open(self.get_task_path(task_id=task_id)) as file:
tasks.append(FileTranscriptionTask.from_json(file.read()))
except FileNotFoundError:
pass

return tasks

def save_json_tasks(self, tasks: List[FileTranscriptionTask]):
json_str = json.dumps([task.id for task in tasks])
with open(self.tasks_list_file_path, "w") as file:
file.write(json_str)

for task in tasks:
file_path = self.get_task_path(task_id=task.id)
json_str = task.to_json()
with open(file_path, "w") as file:
file.write(json_str)

def get_task_path(self, task_id: int):
path = os.path.join(self.cache_dir, 'transcriptions', f'{task_id}.json')
os.makedirs(os.path.dirname(path), exist_ok=True)
return path

def clear(self):
if os.path.exists(self.file_path):
os.remove(self.file_path)
if os.path.exists(self.pickle_cache_file_path):
os.remove(self.pickle_cache_file_path)
8 changes: 6 additions & 2 deletions buzz/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -866,9 +866,12 @@ def on_file_transcriber_triggered(self, options: Tuple[TranscriptionOptions, Fil
file_path, transcription_options, file_transcription_options, model_path)
self.add_task(task)

def update_task_table_row(self, task: FileTranscriptionTask):
def load_task(self, task: FileTranscriptionTask):
self.table_widget.upsert_task(task)
self.tasks[task.id] = task

def update_task_table_row(self, task: FileTranscriptionTask):
self.load_task(task=task)
self.tasks_changed.emit()

@staticmethod
Expand Down Expand Up @@ -965,6 +968,7 @@ def open_transcription_viewer(self, task_id: int):

transcription_viewer_widget = TranscriptionViewerWidget(
transcription_task=task, parent=self, flags=Qt.WindowType.Window)
transcription_viewer_widget.task_changed.connect(self.on_tasks_changed)
transcription_viewer_widget.show()

def add_task(self, task: FileTranscriptionTask):
Expand All @@ -978,7 +982,7 @@ def load_tasks_from_cache(self):
task.status = None
self.transcriber_worker.add_task(task)
else:
self.update_task_table_row(task)
self.load_task(task=task)

def save_tasks_to_cache(self):
self.tasks_cache.save(list(self.tasks.values()))
Expand Down
8 changes: 4 additions & 4 deletions buzz/model_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ class TranscriptionModel:
}


def get_hugging_face_dataset_file_url(author: str, repository_name: str, filename: str):
return f'https://huggingface.co/datasets/{author}/{repository_name}/resolve/main/{filename}'
def get_hugging_face_file_url(author: str, repository_name: str, filename: str):
return f'https://huggingface.co/{author}/{repository_name}/resolve/main/{filename}'


def get_whisper_cpp_file_path(size: WhisperModelSize) -> str:
Expand Down Expand Up @@ -132,8 +132,8 @@ def __init__(self, model: TranscriptionModel):
def run(self) -> None:
if self.model.model_type == ModelType.WHISPER_CPP:
model_name = self.model.whisper_model_size.value
url = get_hugging_face_dataset_file_url(author='ggerganov', repository_name='whisper.cpp',
filename=f'ggml-{model_name}.bin')
url = get_hugging_face_file_url(author='ggerganov', repository_name='whisper.cpp',
filename=f'ggml-{model_name}.bin')
file_path = get_whisper_cpp_file_path(
size=self.model.whisper_model_size)
expected_sha256 = WHISPER_CPP_MODELS_SHA256[model_name]
Expand Down
6 changes: 4 additions & 2 deletions buzz/transcriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import tqdm
import whisper
from PyQt6.QtCore import QObject, pyqtSignal, pyqtSlot
from dataclasses_json import dataclass_json, config, Exclude
from whisper import tokenizer

from . import transformers_whisper
Expand Down Expand Up @@ -65,7 +66,7 @@ class TranscriptionOptions:
word_level_timings: bool = False
temperature: Tuple[float, ...] = DEFAULT_WHISPER_TEMPERATURE
initial_prompt: str = ''
openai_access_token: str = ''
openai_access_token: str = field(default='', metadata=config(exclude=Exclude.ALWAYS))


@dataclass()
Expand All @@ -74,6 +75,7 @@ class FileTranscriptionOptions:
output_formats: Set['OutputFormat'] = field(default_factory=set)


@dataclass_json
@dataclass
class FileTranscriptionTask:
class Status(enum.Enum):
Expand All @@ -87,7 +89,7 @@ class Status(enum.Enum):
transcription_options: TranscriptionOptions
file_transcription_options: FileTranscriptionOptions
model_path: str
id: int = field(default_factory=lambda: randint(0, 1_000_000))
id: int = field(default_factory=lambda: randint(0, 100_000_000))
segments: List[Segment] = field(default_factory=list)
status: Optional[Status] = None
fraction_completed = 0.0
Expand Down
10 changes: 7 additions & 3 deletions buzz/widgets/transcription_viewer_widget.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import List, Optional

from PyQt6.QtCore import Qt
from PyQt6.QtCore import Qt, pyqtSignal
from PyQt6.QtGui import QUndoCommand, QUndoStack, QKeySequence, QAction
from PyQt6.QtWidgets import QWidget, QHBoxLayout, QMenu, QPushButton, QVBoxLayout, QFileDialog

Expand All @@ -16,16 +16,18 @@

class TranscriptionViewerWidget(QWidget):
transcription_task: FileTranscriptionTask
task_changed = pyqtSignal()

class ChangeSegmentTextCommand(QUndoCommand):
def __init__(self, table_widget: TranscriptionSegmentsEditorWidget, segments: List[Segment],
segment_index: int, segment_text: str):
segment_index: int, segment_text: str, task_changed: pyqtSignal):
super().__init__()

self.table_widget = table_widget
self.segments = segments
self.segment_index = segment_index
self.segment_text = segment_text
self.task_changed = task_changed

self.previous_segment_text = self.segments[self.segment_index].text

Expand All @@ -41,6 +43,7 @@ def set_segment_text(self, text: str):
self.table_widget.set_segment_text(self.segment_index, text)
self.table_widget.blockSignals(False)
self.segments[self.segment_index].text = text
self.task_changed.emit()

def __init__(
self, transcription_task: FileTranscriptionTask,
Expand Down Expand Up @@ -102,7 +105,8 @@ def on_segment_text_changed(self, event: tuple):
segment_index, segment_text = event
self.undo_stack.push(
self.ChangeSegmentTextCommand(table_widget=self.table_widget, segments=self.transcription_task.segments,
segment_index=segment_index, segment_text=segment_text))
segment_index=segment_index, segment_text=segment_text,
task_changed=self.task_changed))

def on_menu_triggered(self, action: QAction):
output_format = OutputFormat[action.text()]
Expand Down
86 changes: 85 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ faster-whisper = "^0.4.1"
keyring = "^23.13.1"
openai-whisper = "v20230124"
platformdirs = "^3.5.3"
dataclasses-json = "^0.5.9"

[tool.poetry.group.dev.dependencies]
autopep8 = "^1.7.0"
Expand Down

0 comments on commit f83d2d6

Please sign in to comment.