In [None]:
# %cd ..

In [None]:
# type: ignore

import typing
from typing import cast
from itertools import chain

import gigaam
from gigaam.model import GigaAMASR
import numpy as np
from datasets import Dataset, load_dataset, Audio
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

from asr_eval.data import Recording
from asr_eval.align.data import Token
from asr_eval.models.gigaam import EncodeError
from asr_eval.streaming.models.vosk import VoskStreaming
from asr_eval.streaming.evaluation import default_evaluation_pipeline
from asr_eval.streaming.plots import partial_alignment_plot, streaming_error_vs_latency_histogram, latency_plot
from asr_eval.serializing import save_to_json, load_from_json
from asr_eval.utils import N

In [None]:
# Load the GigaAM 'ctc' model on the CUDA device. The cast has no runtime
# effect; it only narrows the static type of the loader's return value.
gigaam_model = cast(
    GigaAMASR,
    gigaam.load_model('ctc', device='cuda'),
)

In [None]:
# type: ignore

# Dataset identifier and split to evaluate on.
name, split = 'bond005/podlodka_speech', 'test'

# Cast the audio column so that samples are decoded at a 16 kHz sampling rate.
dataset: Dataset = (
    load_dataset(name)[split]
    .cast_column("audio", Audio(sampling_rate=16_000))
)

samples: list[Recording] = []
# Indices of dataset rows that raised EncodeError, kept so that the dropped
# samples are visible instead of disappearing silently.
skipped_indices: list[int] = []

for i in tqdm(range(len(dataset))):
    try:
        samples.append(Recording.from_sample(
            sample=dataset[i],
            name=name,
            split=split,
            index=i,
            use_gigaam=gigaam_model,
        ))
    except EncodeError:
        # Best-effort: skip the sample, but remember which one was skipped.
        skipped_indices.append(i)

if skipped_indices:
    print(
        f'Skipped {len(skipped_indices)} of {len(dataset)} samples '
        f'due to EncodeError: {skipped_indices}'
    )

In [None]:
# Run the default streaming evaluation pipeline over every recording using a
# Vosk streaming model fed with 1-second chunks.
asr = VoskStreaming(model_name='vosk-model-ru-0.42', chunk_length_sec=1)
asr.start_thread()
try:
    for recording in tqdm(samples):
        recording.evals = default_evaluation_pipeline(recording, asr)
        # Drop the waveform once evaluated - presumably to keep memory use and
        # the serialized JSON small (TODO confirm).
        recording.waveform = None
finally:
    # Always stop the worker thread, even if evaluation fails or the cell is
    # interrupted; otherwise the thread stays alive in the kernel.
    asr.stop_thread()

In [None]:
# Persist the evaluated recordings to 'tmp/samples.json'; the next cell loads
# them back from the same path.
# NOTE(review): assumes the 'tmp' directory already exists - confirm whether
# save_to_json creates it.
save_to_json(samples, 'tmp/samples.json')

In [None]:
# Reload the recordings from 'tmp/samples.json', replacing the in-memory list.
# Presumably this lets the plotting cells below run without repeating the
# evaluation - confirm that the file is up to date before relying on it.
samples: list[Recording] = load_from_json('tmp/samples.json')

In [None]:
# Plot the partial alignments of one example recording against its reference
# words. N(...) unwraps optional values (assumed to be a not-None guard from
# asr_eval.utils - confirm).
recording = samples[2]
recording_evals = N(recording.evals)
first_input_chunk = N(recording_evals.input_chunks)[0]
last_output_chunk = N(recording_evals.output_chunks)[-1]

partial_alignment_plot(
    N(recording_evals.partial_alignments),
    cast(list[Token], recording.transcription_words),
    # The time axis spans from the first input chunk to the last output chunk.
    start_real_time=N(first_input_chunk.put_timestamp),
    end_real_time=N(last_output_chunk.put_timestamp),
    figsize=(12, 12),
)

In [None]:
# Histogram of streaming errors vs. latency over all recordings.
# chain(<one generator>) only re-yields each get_error_positions() result as a
# whole; chain.from_iterable flattens them into a single stream of error
# positions, which is what wrapping the generator in chain was meant to do.
# NOTE(review): assumes get_error_positions() returns an iterable of positions
# and that the plot function expects a flat iterable - confirm.
streaming_error_vs_latency_histogram(chain.from_iterable(
    partial_alignment.get_error_positions()
    for recording in samples
    for partial_alignment in recording.evals.partial_alignments # type: ignore
))

In [None]:
latency_plot(samples)