In [1]:
import os
from pathlib import Path
import pickle

from datasets import Audio, Dataset
import pandas as pd
import IPython.display
import numpy as np

from asr_eval.models.base import ASREvalWrapper
from asr_eval.models.gigaam import GigaAMWrapper
from asr_eval.datasets.datasets import load_multivariant_v1_200, load_golos_farfield, load_common_voice_17_0, load_podlodka
from asr_eval.datasets.recording import Recording
from asr_eval.models.whisper_wrapper import WhisperLongformWrapper
from asr_eval.models.pisets_wrapper import PisetsWrapper
from asr_eval.models.legacy_pisets_wrapper import LegacyPisetsWrapper
from asr_eval.align.recursive import align, select_shortest_multi_variants
from asr_eval.align.parsing import split_text_into_tokens
from asr_eval.align.data import MatchesList
from asr_eval.align.data import MatchesList

In [2]:
# dataset = (
#     load_multivariant_v1_200()
#     .cast_column('audio', Audio(sampling_rate=16000, decode=True)) # type: ignore
#     .take(30)
# )

# waveform = dataset[0]['audio']['array']

# model = GigaAMWrapper()
# model([waveform])

In [3]:
# for path in [
#     'openai/whisper-large-v3',
#     'openai/whisper-medium',
#     'tmp/whisper-podlodka-turbo',
# ]:
#     wrapper = WhisperLongformWrapper(path)
#     wrapper._maybe_instantiate()

#     dim = wrapper.model.model.encoder.layers[0].fc1.in_features
#     n_params = sum(p.numel() for p in wrapper.model.parameters())
#     print(f'{path}: {dim=}, n_params={n_params/10**9:.2}B')

In [4]:
def _atomic_write_bytes(path: Path, data: bytes) -> None:
    """Write ``data`` to ``path`` through a sibling temp file, then rename it into place.

    ``Path.replace`` swaps the finished temp file in a single rename. An
    interrupted or failing write therefore never leaves a truncated file at
    ``path``.
    """
    tmp_path = path.with_name(path.name + '.tmp')
    tmp_path.write_bytes(data)
    tmp_path.replace(path)


def get_preds(
    model: ASREvalWrapper,
    model_name: str,
    dataset_name: str,
    dataset_index: int,
    recording: Recording,
) -> MatchesList:
    """Return the alignment of ``model``'s transcription of ``recording`` against its reference.

    Both stages are cached on disk under ``tmp/predictions/{dataset_name}/``,
    so re-running the notebook only computes missing entries:

    - ``{model_name}_{dataset_index}.txt`` holds the predicted text (UTF-8);
    - ``{model_name}_{dataset_index}.pkl`` holds the pickled alignment.

    Each result is fully computed before its cache file is written, and files
    are written atomically. An exception or kernel interrupt therefore never
    leaves an empty or partial cache file that would break later runs.

    Args:
        model: ASR wrapper; called as ``model([waveform])`` and its first output is used.
        model_name: key used in the cache file name.
        dataset_name: key used as the cache sub-directory.
        dataset_index: index of the sample within the dataset, used in the cache file name.
        recording: provides ``waveform`` (model input) and ``transcription_words`` (reference).

    Returns:
        The alignment object produced by ``align`` (loaded back from the cache).

    Note:
        The alignment cache is not invalidated when the ``.txt`` prediction
        changes. Delete the matching ``.pkl`` by hand after editing a prediction.
    """
    cache_file = Path(f'tmp/predictions/{dataset_name}/{model_name}_{dataset_index}.txt')
    cache_file.parent.mkdir(exist_ok=True, parents=True)

    # predicting
    if not cache_file.is_file():
        print(f'Predicting {cache_file}')
        pred = model([recording.waveform])[0] # type: ignore
        # explicit UTF-8: the default locale encoding may not cover Cyrillic text
        _atomic_write_bytes(cache_file, pred.encode('utf-8'))
    pred = cache_file.read_text(encoding='utf-8')

    # aligning
    cache_file_alignment = cache_file.with_suffix('.pkl')
    if not cache_file_alignment.is_file():
        # compute BEFORE touching the file, so a failing align() leaves no empty .pkl behind
        alignment = align(recording.transcription_words, split_text_into_tokens(pred))
        _atomic_write_bytes(cache_file_alignment, pickle.dumps(alignment))

    # the .pkl is produced locally by this function; never point this cache at untrusted pickles
    with open(cache_file_alignment, 'rb') as f:
        return pickle.load(f)

# Local checkout of the legacy Pisets repository. The default is the author's
# machine; set the LEGACY_PISETS_REPO_DIR environment variable elsewhere.
LEGACY_PISETS_REPO_DIR = os.environ.get('LEGACY_PISETS_REPO_DIR', '/home/oleg/pisets_legacy')

# Models to evaluate. The key is used both as the report row label and in the
# prediction cache file names (tmp/predictions/...). Renaming a key therefore
# discards its cached predictions and forces re-prediction.
models: dict[str, ASREvalWrapper] = {
    'pisets-podlodka-old': LegacyPisetsWrapper(repo_dir=LEGACY_PISETS_REPO_DIR),
    'pisets-podlodka': PisetsWrapper(diarization=None),
    'pisets-whisper-large-v3': PisetsWrapper(diarization=None, recognizer='openai/whisper-large-v3'),
    'whisper-large-v3-podlodka': WhisperLongformWrapper('bond005/whisper-large-v3-ru-podlodka'),
    'whisper-large-v3': WhisperLongformWrapper('openai/whisper-large-v3'),
    'whisper-medium': WhisperLongformWrapper('openai/whisper-medium'),
    'whisper-podlodka-turbo': WhisperLongformWrapper('bond005/whisper-podlodka-turbo'),
    'gigaam2': GigaAMWrapper(),
}

# Audio is decoded (and resampled if needed) to this rate, in Hz.
SAMPLING_RATE = 16_000


def _prepare_dataset(dataset: Dataset, n_samples: int) -> Dataset:
    """Decode the ``audio`` column at ``SAMPLING_RATE`` and keep only the first ``n_samples`` rows."""
    return (
        dataset
        .cast_column('audio', Audio(sampling_rate=SAMPLING_RATE, decode=True)) # type: ignore
        .take(n_samples)
    )


# Evaluation subsets, keyed by the name used as report column and cache sub-directory.
datasets: dict[str, Dataset] = {
    'multivariant_v1_200': _prepare_dataset(load_multivariant_v1_200(), 30),
    'golos_farfield': _prepare_dataset(load_golos_farfield(), 200),
    'common_voice': _prepare_dataset(load_common_voice_17_0(), 100),
}

# Mean per-utterance WER in percent: rows are model names, columns are dataset names.
# float dtype: cells hold fractional percentages (e.g. 7.93), which an Int64 dtype cannot represent.
report_df = pd.DataFrame(dtype=float)

for dataset_name, dataset in datasets.items():
    for model_name, model in models.items():
        print(f'{dataset_name:<30} {model_name:<30} predicting and aligning...')

        wers: list[float] = []
        for i, sample in enumerate(dataset): # type: ignore
            recording = Recording.from_sample(sample) # type: ignore
            true_length = len(select_shortest_multi_variants(recording.transcription_words))
            if true_length == 0:
                # WER is undefined for an empty reference: skip instead of raising ZeroDivisionError.
                print(f'{dataset_name:<30} sample {i}: empty reference, skipped')
                continue
            alignment = get_preds(model, model_name, dataset_name, i, recording) # type: ignore
            wers.append(alignment.n_errors_with_insertions_tolerance() / true_length)

        # Macro-average: every utterance weighs equally, regardless of its length.
        # An explicit nan avoids np.mean's "mean of empty slice" warning.
        wer = float(np.mean(wers)) if wers else float('nan')
        print(f'{dataset_name:<30} {model_name:<30} WER: {wer * 100:.1f}%')
        report_df.at[model_name, dataset_name] = round(wer * 100, 2) # type: ignore

        # wait=True clears only when the new table arrives, avoiding flicker
        IPython.display.clear_output(wait=True)
        IPython.display.display(report_df)

Unnamed: 0,multivariant_v1_200,golos_farfield,common_voice
pisets-podlodka-old,7.93,14.82,13.55
pisets-podlodka,7.46,14.82,13.55
pisets-whisper-large-v3,7.98,19.6,12.31
whisper-large-v3-podlodka,12.06,17.86,20.93
whisper-large-v3,6.1,21.19,13.69
whisper-medium,9.63,32.79,18.94
whisper-podlodka-turbo,10.58,19.68,13.03
gigaam2,6.74,4.8,7.37


1