In [None]:
# %cd ..

In [None]:
# type: ignore

import sys
import os
import typing
from typing import cast

import gigaam
from gigaam.model import GigaAMASR
from datasets import Dataset, load_dataset, Audio
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

from asr_eval.align.timings import CannotFillTimings
from asr_eval.datasets.recording import Recording
from asr_eval.models.gigaam import GigaAMEncodeError, encode
from asr_eval.streaming.models.vosk import VoskStreaming
from asr_eval.streaming.evaluation import default_evaluation_pipeline, RecordingStreamingEvaluation
from asr_eval.streaming.plots import (
    partial_alignments_plot,
    visualize_history,
    streaming_error_vs_latency_histogram,
    latency_plot,
    show_last_alignments,
)
from asr_eval.utils.serializing import save_to_json, load_from_json

In [None]:
# Ask the inline matplotlib backend to emit figures in SVG format.
%config InlineBackend.figure_formats = ['svg']

In [None]:
# Load the GigaAM 'ctc' model onto the CUDA device. The cast has no runtime
# effect; it only tells the type checker the result is a GigaAMASR.
# The model is later passed to Recording.from_sample(use_gigaam=...).
gigaam_model = cast(
    GigaAMASR,
    gigaam.load_model('ctc', device='cuda'),
)

In [None]:
# type: ignore

# Collect at most MAX_SAMPLES recordings in total, walking the splits in order.
# Samples for which Recording.from_sample raises CannotFillTimings are skipped.
MAX_SAMPLES = 100

samples: list[Recording] = []

# name, split = 'mozilla-foundation/common_voice_17_0', 'test'  #, 'ru'
name = 'bond005/podlodka_speech'

# Load the dataset once; each split is picked from the result below instead of
# calling load_dataset again on every iteration.
dataset_splits = load_dataset(name)

for split in ['train', 'validation', 'test']:
    # Enforce the cap across splits: without this check every later split
    # would still contribute one extra sample before its inner loop stopped.
    if len(samples) >= MAX_SAMPLES:
        break

    dataset: Dataset = (
        dataset_splits[split]
        .cast_column("audio", Audio(sampling_rate=16_000))
        # .rename_column('sentence', 'transcription')
    )

    for i in tqdm(range(len(dataset))):
        # Check before appending so the total never exceeds MAX_SAMPLES.
        if len(samples) >= MAX_SAMPLES:
            break
        try:
            samples.append(Recording.from_sample(
                sample=dataset[i],
                name=name,
                split=split,
                index=i,
                use_gigaam=gigaam_model,
            ))
        except CannotFillTimings:
            # Deliberate best-effort: skip this sample and keep going.
            continue

    # Running total after each processed split.
    print(len(samples))

In [None]:
# Run the default streaming evaluation pipeline on every collected recording
# using a Vosk streaming model configured with chunk_length_sec=1.
asr = VoskStreaming(model_name='vosk-model-ru-0.42', chunk_length_sec=1)
asr.start_thread()

evals: list[RecordingStreamingEvaluation] = []
try:
    for recording in tqdm(samples):
        evals.append(default_evaluation_pipeline(recording, asr))
        # Drop the waveform reference once the recording has been evaluated
        # (presumably to release memory — confirm). NOTE(review): this mutates
        # `samples` in place, so re-running this cell likely requires
        # re-running the loading cell first.
        recording.waveform = None
finally:
    # Always stop the ASR thread, even if a recording fails mid-evaluation;
    # otherwise the started thread would be left running in the kernel.
    asr.stop_thread()

In [None]:
# Persist the evaluations to tmp/evals.json (relative to the working directory).
# The target directory is created first so the save also works on a fresh
# checkout where tmp/ does not exist yet; exist_ok makes this a no-op otherwise.
os.makedirs(os.path.dirname('tmp/evals.json'), exist_ok=True)
save_to_json(evals, 'tmp/evals.json')

In [None]:
# Rebind `evals` to whatever is stored in tmp/evals.json — the same path the
# save cell writes to — so the cells below can work from the saved file.
evals: list[RecordingStreamingEvaluation] = load_from_json('tmp/evals.json')

In [None]:
# Pick the last evaluation; the single-recording plots below read it.
# NOTE(review): the name `eval` shadows the Python builtin `eval()` for the rest
# of the kernel session — consider renaming it (e.g. `last_eval`) together with
# every cell that uses it.
eval = evals[-1]

In [None]:
# Open a 10x8-inch figure, then plot the partial alignments of the selected
# evaluation (presumably drawn onto the current figure — confirm).
plt.figure(figsize=(10, 8)) # type: ignore
partial_alignments_plot(eval)

In [None]:
# Pass the selected evaluation's input and output chunks to visualize_history.
visualize_history(eval.input_chunks, eval.output_chunks)

In [None]:
# Aggregate view over the whole list of evaluations (not just the selected one):
# per its name, a histogram of streaming error vs. latency.
streaming_error_vs_latency_histogram(evals)

In [None]:
# Latency plot computed from the whole list of evaluations.
latency_plot(evals)

In [None]:
# Open a 10x8-inch figure, then show the last alignments for all evaluations
# (presumably drawn onto the current figure — confirm).
plt.figure(figsize=(10, 8)) # type: ignore
show_last_alignments(evals)