In [None]:
# type: ignore

import typing
from typing import Any
from pathlib import Path

import librosa
from datasets import load_dataset, Audio
import gigaam
from gigaam.model import GigaAMASR
import numpy as np
import numpy.typing as npt
import matplotlib.pyplot as plt

from asr_eval.datasets.recording import Recording
from asr_eval.align.data import MatchesList
from asr_eval.align.timings import fill_word_timings_inplace, CannotFillTimings
from asr_eval.align.parsing import parse_multivariant_string, colorize_parsed_string
from asr_eval.align.plots import draw_timed_transcription
from asr_eval.streaming.models.vosk import VoskStreaming
from asr_eval.streaming.evaluation import default_evaluation_pipeline
from asr_eval.streaming.model import TranscriptionChunk
from asr_eval.streaming.evaluation import RecordingStreamingEvaluation
from asr_eval.align.recursive import align
from asr_eval.streaming.plots import (
    partial_alignments_plot,
    visualize_history,
    streaming_error_vs_latency_histogram,
    latency_plot,
    show_last_alignments,
)

In [None]:
# Parse one multivariant string with every available method and print the
# colorized result of each, so the splits can be compared line by line.
text = (
    '(7-8 мая) в Пуэрто-Рико прошёл {шестнадцатый|16-й|16}'
    ' этап "Формулы-1" с фондом 100,000$!'
)

parsing_methods = ('space', 'razdel', 'wordpunct_tokenize', 'asr_eval')
for method in parsing_methods:
    tokens = parse_multivariant_string(text, method=method)
    colored_str, colors = colorize_parsed_string(text, tokens)
    # Left-align the method name in a 20-character column.
    print(f'{method: <20}', colored_str)

In [None]:
# type: ignore

# Load the test recording resampled to 16 kHz; librosa.load returns
# (samples, sampling_rate), only the samples are kept.
waveform: npt.NDArray[np.floating[Any]] = (
    librosa.load('tests/testdata/formula1.mp3', sr=16000)[0])
# Mix in the time-reversed signal at quarter amplitude. The right-hand side is
# evaluated into a temporary array first, so the in-place add is safe.
waveform += waveform[::-1] / 4  # add some speech-like noise

# Decode the reference transcription as UTF-8 explicitly: without `encoding`
# the platform locale is used, which can garble non-ASCII text.
# NOTE(review): assumes the file is stored as UTF-8 - confirm.
text = Path('tests/testdata/formula1.txt').read_text(encoding='utf-8')
tokens = parse_multivariant_string(text)

# Fill word timings on `tokens` in place using the GigaAM CTC model.
model = typing.cast(GigaAMASR, gigaam.load_model('ctc', device='cuda'))
fill_word_timings_inplace(model, waveform, tokens, verbose=True)

In [None]:
 # type: ignore

# Draw the waveform (rescaled so its maximum sits at 3) as a faint background
# and overlay the timed transcription on top of it.
time_axis_sec = np.arange(len(waveform)) / 16000  # sample index -> seconds
scaled_waveform = 3 * waveform / waveform.max()

plt.figure(figsize=(15, 4))
plt.plot(time_axis_sec, scaled_waveform, alpha=0.3, zorder=-1)
draw_timed_transcription(tokens, y_delta=-3)
plt.ylim(-3.5, 3.5)
plt.show()

# Print the colorized transcription below the figure.
colored_text = colorize_parsed_string(text, tokens)[0]
print(colored_text)

In [None]:
# Run the default streaming evaluation pipeline on the single test recording
# using a Vosk streaming model.
asr = VoskStreaming(model_name='vosk-model-ru-0.42', chunk_length_sec=0.5)

recording = Recording(
    transcription=text,
    transcription_words=tokens,
    waveform=waveform,
)

asr.start_thread()
try:
    # NOTE: `eval` shadows the builtin; the name is kept because the following
    # cells read it.
    eval = default_evaluation_pipeline(recording, asr)
finally:
    # Stop the ASR thread even if the evaluation raises, so a failed run does
    # not leave it running in the kernel.
    asr.stop_thread()

In [None]:
# Show the joined output chunks next to the prediction stored in the last
# partial alignment.
joined_output = TranscriptionChunk.join(eval.output_chunks)
last_partial_alignment = eval.partial_alignments[-1]
print(joined_output)
print(last_partial_alignment.pred)

In [None]:
# Display the matches of the last partial alignment.
last_partial_alignment = eval.partial_alignments[-1]
last_partial_alignment.alignment.matches

In [None]:
# Display the error positions reported by the last partial alignment.
last_partial_alignment = eval.partial_alignments[-1]
last_partial_alignment.get_error_positions()

In [None]:
# type: ignore

# Plot the partial alignments of the single-recording evaluation.
figure_size = (15, 6)
plt.figure(figsize=figure_size)
partial_alignments_plot(eval)
plt.show()

In [None]:
# type: ignore

# Evaluate the streaming Vosk model on the 'test' split of
# bond005/podlodka_speech, with audio cast to a 16 kHz sampling rate.
dataset = (
    load_dataset('bond005/podlodka_speech')['test']
    .cast_column('audio', Audio(sampling_rate=16_000))
)

# GigaAM is passed to Recording.from_sample via `use_gigaam`.
gigaam_model = typing.cast(GigaAMASR, gigaam.load_model('ctc', device='cuda'))
asr = VoskStreaming(model_name='vosk-model-ru-0.42', chunk_length_sec=0.5)

evals: list[RecordingStreamingEvaluation] = []
n_skipped = 0

asr.start_thread()
try:
    for sample in dataset:
        try:
            recording = Recording.from_sample(sample, use_gigaam=gigaam_model)
        except CannotFillTimings:
            # Samples whose timings cannot be filled are skipped, but counted
            # so the drop is visible in the cell output.
            n_skipped += 1
            continue
        evals.append(default_evaluation_pipeline(
            recording, asr, partial_alignment_interval=0.5
        ))
finally:
    # Stop the ASR thread even if a sample fails, so an aborted run does not
    # leave it running in the kernel.
    asr.stop_thread()

print(f'evaluated {len(evals)} samples, skipped {n_skipped} (CannotFillTimings)')

In [None]:
# Print the reference length and the score of the last partial alignment for
# every evaluated sample.
# NOTE: the loop rebinds `eval` and `alignment`; both stay bound to the last
# sample after the loop and are read by later cells.
for i, eval in enumerate(evals):
    alignment: MatchesList = eval.partial_alignments[-1].alignment
    sample_label = f'sample {i},'
    length_label = f'total_true_len={alignment.total_true_len},'
    print(sample_label, length_label, alignment.score)

In [None]:
# Word errors per reference word for the last evaluated sample. The alignment
# is taken from `evals` explicitly instead of relying on a loop variable left
# over from another cell. max(1, ...) guards against division by zero.
# NOTE(review): this covers the last sample only; aggregate over `evals` if a
# dataset-level figure is intended.
last_alignment = evals[-1].partial_alignments[-1].alignment
last_alignment.score.n_word_errors / max(1, last_alignment.total_true_len)

In [None]:
# type: ignore

# Toy example: align a single word against the same letters split in two.
joined_tokens = parse_multivariant_string('nothing')
split_tokens = parse_multivariant_string('no thing')

matches_list = align(joined_tokens, split_tokens)
print(matches_list.matches)
print(matches_list.score)

In [None]:
# type: ignore

# Plot the partial alignments of one dataset sample (index 10 in `evals`).
sample_index = 10
plt.figure(figsize=(15, 6))
partial_alignments_plot(evals[sample_index])
plt.show()

In [None]:
# type: ignore

# Visualize the input and output chunk history of the evaluation currently
# bound to `eval`.
# NOTE: when the notebook is run top to bottom, `eval` was last rebound by the
# per-sample printing loop, i.e. it is the last element of `evals`.
input_chunks = eval.input_chunks
output_chunks = eval.output_chunks

plt.figure(figsize=(15, 3))
visualize_history(input_chunks, output_chunks)
plt.show()

In [None]:
# type: ignore

# Two panels side by side (2:1 width): the error-vs-latency histogram on the
# left and the latency plot on the right, both over all evaluated samples.
fig, axes = plt.subplots(figsize=(12, 4), ncols=2, width_ratios=[2, 1])
ax1, ax2 = axes

streaming_error_vs_latency_histogram(evals, ax=ax1)
latency_plot(evals, ax=ax2)
plt.show()

In [None]:
# type: ignore

# Show the last alignment of every evaluated sample in one wide figure.
figure_size = (15, 3)
plt.figure(figsize=figure_size)
show_last_alignments(evals)
plt.show()