In [None]:
# IPython line magic: change the kernel's working directory to the parent folder.
# NOTE(review): presumably the notebook lives one level below the repository
# root, so this makes `asr_eval` importable and anchors the relative `srt/`
# output path used later — confirm against the repository layout.
%cd ..

In [None]:
import typing
from pathlib import Path

from datasets import load_dataset, Audio
import gigaam
from gigaam.model import GigaAMASR
import soundfile as sf
from tqdm.auto import tqdm

from asr_eval.models.gigaam import encode
from asr_eval.streaming.evaluation import get_word_timings
from asr_eval.srt_utils import utterances_to_srt

In [None]:
# Load the 'ctc' GigaAM variant onto the CUDA device. `typing.cast` does
# nothing at runtime; it only tells the type checker the result is a GigaAMASR.
model = typing.cast(
    GigaAMASR,
    gigaam.load_model('ctc', device='cuda'),
)

In [None]:
# type: ignore

# Evaluation corpora keyed by a short name (the name is reused as the output
# sub-directory further down). For each Hugging Face repository the 'test'
# split is loaded, its "audio" column is cast to Audio(sampling_rate=16_000),
# and only the first 20 samples are kept.
datasets = {
    short_name: (
        load_dataset(repo_id)['test']
        .cast_column("audio", Audio(sampling_rate=16_000))
        .take(20)
    )
    for short_name, repo_id in (
        ('podlodka', 'bond005/podlodka_speech'),
        ('golos_farfield', 'bond005/sberdevices_golos_100h_farfield'),
    )
}

# Single-pass text normalisation applied after lower-casing:
# 'ё' -> 'е', '-' -> ' ', and the punctuation marks . , ! ? ; : " ( ) are removed.
# Built once here instead of chaining eleven str.replace calls per sample.
normalization_table = str.maketrans('ё-', 'е ', '.,!?;:"()')

# For every sample of every dataset: normalise the reference transcription,
# compute word timings with the model, and write `<idx>.srt` + `<idx>.wav`
# into `srt/<dataset_name>/`.
for dataset_name, dataset in datasets.items():
    # Named `out_dir` rather than `dir` so the builtin dir() is not shadowed
    # for the rest of the notebook session.
    out_dir = Path(f'srt/{dataset_name}')
    out_dir.mkdir(exist_ok=True, parents=True)
    # tqdm wraps the dataset itself (not the enumerate object) so it can read
    # the dataset length, when available, and show a total / ETA.
    for sample_idx, sample in enumerate(tqdm(dataset, desc=dataset_name)):
        waveform = sample['audio']['array']
        text = sample['transcription'].lower().translate(normalization_table)
        # `encode` is called only as a pre-check and its result is discarded:
        # a ValueError means the text cannot be encoded for this model, so the
        # sample is skipped (its index is then missing from the output files).
        try:
            encode(model, text)
        except ValueError:
            print(f'Cannot encode: {text}')
            continue
        timings = get_word_timings(model, waveform, text)
        # Explicit UTF-8: the transcriptions contain Cyrillic, and the default
        # locale encoding may be unable to represent it.
        (out_dir / f'{sample_idx}.srt').write_text(
            utterances_to_srt(timings), encoding='utf-8'
        )
        # 16_000 must match the sampling rate the audio column was cast to.
        sf.write(out_dir / f'{sample_idx}.wav', waveform, samplerate=16_000)