Skip to content

Commit

Permalink
Adding fix for multi-byte segments in whisper.cpp (#734)
Browse files Browse the repository at this point in the history
Co-authored-by: Chidi Williams <williamschidi1@gmail.com>
  • Loading branch information
raivisdejus and chidiwilliams committed May 14, 2024
1 parent ca49b8e commit 38f5d26
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 17 deletions.
75 changes: 60 additions & 15 deletions buzz/transcriber/whisper_cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,26 @@
class WhisperCpp:
def __init__(self, model: str) -> None:
self.ctx = whisper_cpp.whisper_init_from_file(model.encode())
self.segments: List[Segment] = []

def append_segment(self, txt: bytes, start: int, end: int):
if txt == b'':
return True

# try-catch will guard against multi-byte utf-8 characters
# https://github.com/ggerganov/whisper.cpp/issues/1798
try:
self.segments.append(
Segment(
start=start * 10, # centisecond to ms
end=end * 10, # centisecond to ms
text=txt.decode("utf-8"),
)
)

return True
except UnicodeDecodeError:
return False

def transcribe(self, audio: Union[np.ndarray, str], params: Any):
if isinstance(audio, str):
Expand All @@ -29,25 +49,50 @@ def transcribe(self, audio: Union[np.ndarray, str], params: Any):
if result != 0:
raise Exception(f"Error from whisper.cpp: {result}")

segments: List[Segment] = []
n_segments = whisper_cpp.whisper_full_n_segments(self.ctx)

n_segments = whisper_cpp.whisper_full_n_segments((self.ctx))
for i in range(n_segments):
txt = whisper_cpp.whisper_full_get_segment_text((self.ctx), i)
t0 = whisper_cpp.whisper_full_get_segment_t0((self.ctx), i)
t1 = whisper_cpp.whisper_full_get_segment_t1((self.ctx), i)
if params.token_timestamps:
# Will process word timestamps
txt_buffer = b''
txt_start = 0
txt_end = 0

segments.append(
Segment(
start=t0 * 10, # centisecond to ms
end=t1 * 10, # centisecond to ms
text=txt.decode("utf-8"),
)
)
for i in range(n_segments):
txt = whisper_cpp.whisper_full_get_segment_text(self.ctx, i)
start = whisper_cpp.whisper_full_get_segment_t0(self.ctx, i)
end = whisper_cpp.whisper_full_get_segment_t1(self.ctx, i)

if txt.startswith(b' ') and self.append_segment(txt_buffer, txt_start, txt_end):
txt_buffer = txt
txt_start = start
txt_end = end
continue

if txt.startswith(b', '):
txt_buffer += b','
self.append_segment(txt_buffer, txt_start, txt_end)
txt_buffer = txt.lstrip(b',')
txt_start = start
txt_end = end
continue

txt_buffer += txt
txt_end = end

# Append the last segment
self.append_segment(txt_buffer, txt_start, txt_end)

else:
for i in range(n_segments):
txt = whisper_cpp.whisper_full_get_segment_text(self.ctx, i)
start = whisper_cpp.whisper_full_get_segment_t0(self.ctx, i)
end = whisper_cpp.whisper_full_get_segment_t1(self.ctx, i)

self.append_segment(txt, start, end)

return {
"segments": segments,
"text": "".join([segment.text for segment in segments]),
"segments": self.segments,
"text": "".join([segment.text for segment in self.segments]),
}

def __del__(self):
Expand Down
Binary file added testdata/whisper-latvian.wav
Binary file not shown.
4 changes: 4 additions & 0 deletions tests/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,7 @@
test_audio_path = os.path.abspath(
os.path.join(os.path.dirname(__file__), "../testdata/whisper-french.mp3")
)

test_multibyte_utf8_audio_path = os.path.abspath(
os.path.join(os.path.dirname(__file__), "../testdata/whisper-latvian.wav")
)
64 changes: 62 additions & 2 deletions tests/transcriber/whisper_cpp_file_transcriber_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
FileTranscriptionTask,
)
from buzz.transcriber.whisper_cpp_file_transcriber import WhisperCppFileTranscriber
from tests.audio import test_audio_path
from tests.audio import test_audio_path, test_multibyte_utf8_audio_path
from tests.model_loader import get_model_path


Expand All @@ -25,7 +25,7 @@ class TestWhisperCppFileTranscriber:
False,
[Segment(0, 6560, "Bienvenue dans Passe-Relle. Un podcast pensé pour")],
),
(True, [Segment(30, 330, "Bien"), Segment(330, 740, "venue")]),
(True, [Segment(30, 740, "Bienvenue"), Segment(740, 1070, " dans")]),
],
)
def test_transcribe(
Expand Down Expand Up @@ -75,3 +75,63 @@ def test_transcribe(
assert expected_segment.start == segments[i].start
assert expected_segment.end == segments[i].end
assert expected_segment.text in segments[i].text

@pytest.mark.parametrize(
"word_level_timings,expected_segments",
[
(
False,
[Segment(0, 7000, " Mani uzstrauts, laikabstākļi, tapēc uz jūru, es diezvajī braukša.")],
),
(True, [Segment(380, 500, " Mani"), Segment(500, 1880, " uzstrauts,"), Segment(1880, 3920, " laikabstākļi")]),
],
)
# Problematic part is in "laikabstākļi" where "ļ" gets returned from whisper.cpp in two segments
# First segment has first byte b'\xc4' and the second has second byte b'\xbc'.
def test_transcribe_latvian(
self, qtbot: QtBot, word_level_timings: bool, expected_segments: List[Segment]
):
file_transcription_options = FileTranscriptionOptions(
file_paths=[test_multibyte_utf8_audio_path]
)
transcription_options = TranscriptionOptions(
language="lv",
task=Task.TRANSCRIBE,
word_level_timings=word_level_timings,
model=TranscriptionModel(
model_type=ModelType.WHISPER_CPP,
whisper_model_size=WhisperModelSize.TINY,
),
)

model_path = get_model_path(transcription_options.model)
transcriber = WhisperCppFileTranscriber(
task=FileTranscriptionTask(
file_path=test_multibyte_utf8_audio_path,
transcription_options=transcription_options,
file_transcription_options=file_transcription_options,
model_path=model_path,
)
)
mock_progress = Mock(side_effect=lambda value: print("progress: ", value))
mock_completed = Mock()
mock_error = Mock()
transcriber.progress.connect(mock_progress)
transcriber.completed.connect(mock_completed)
transcriber.error.connect(mock_error)

with qtbot.wait_signal(transcriber.completed, timeout=10 * 60 * 1000):
transcriber.run()

mock_error.assert_not_called()

mock_progress.assert_called()
segments = [
segment
for segment in mock_completed.call_args[0][0]
if len(segment.text) > 0
]
for i, expected_segment in enumerate(expected_segments):
assert expected_segment.start == segments[i].start
assert expected_segment.end == segments[i].end
assert expected_segment.text in segments[i].text

0 comments on commit 38f5d26

Please sign in to comment.