In [1]:
from datasets import load_dataset
from models import get_supported_models
from asr_model_evaluator import ASRModelEvaluator

In [None]:
# Evaluation data: the "ru" config, "test" split of
# mozilla-foundation/common_voice_17_0. Only the audio and its reference text are
# kept, and the text column is renamed to "transcription".
# Alternative (private) source previously used here:
#   load_dataset("VDK/hse_lectures_dataset_private", split="test")
dataset = (
    load_dataset("mozilla-foundation/common_voice_17_0", "ru", split="test")
    .select_columns(["audio", "sentence"])
    .rename_column("sentence", "transcription")
)

In [None]:
# Evaluator that scores every supported model on `dataset`.
evaluator = ASRModelEvaluator()

# Models to compare, as reported by the local `models` module.
models = get_supported_models()

# Word error rate for each model. The cells below zip the result with `models`,
# so it is presumably one score per model in the same order (confirm in
# ASRModelEvaluator.evaluate).
wer_list = evaluator.evaluate(
    data=dataset,
    models=models,
    metric="wer",
    use_text_normalization=True,  # normalise text before comparing
    use_device="cuda",            # requires a CUDA-capable GPU
    verbose=True,
)

In [None]:
# Pad names to the longest one so the "->" arrows line up in a column.
max_width = max(map(len, models))

# Print the models ranked from lowest (best) to highest WER.
for model_name, wer in sorted(zip(models, wer_list), key=lambda pair: pair[1]):
    print(f"{model_name:<{max_width}} -> {wer:.5f}")

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Bar colour per model family. A model gets the colour of the first family whose
# key is a substring of its name, so entry order matters (dicts keep insertion order).
color_map = {
    "SeamlessM4T": "#FF5733",
    "MMS": "#FF3322",
    "NVIDIA": "#00FF00",
    "Whisper": "#F0A500",
    "GigaAM": "#88EEDD",
}
DEFAULT_BAR_COLOR = "#999999"  # used for models that match no family above

bar_colors = [
    next((color for family, color in color_map.items() if family in name), DEFAULT_BAR_COLOR)
    for name in models
]

# The 0-46 range suits scores below ~45. Widen it when a bar, or its text label
# drawn at v + 1, would otherwise be clipped.
x_max = max(46, max(wer_list, default=0) + 6)

fig, ax = plt.subplots(figsize=(13, 6))
y_pos = np.arange(len(models))
# Bar lengths come from `wer_list`. The bare name `wer` is only the last loop
# variable leaked from the printing cell, so using it gave every bar the same length.
ax.barh(y_pos, wer_list, color=bar_colors, alpha=0.7)

# NOTE(review): the title says "custom dataset", but the loading cell currently
# uses Common Voice 17.0 (ru, test). Update the title when switching datasets.
ax.set_title("Word Error Rate (WER) on custom dataset")
# NOTE(review): the axis label and bar labels assume `wer_list` is in percent.
# Confirm against ASRModelEvaluator.evaluate.
ax.set_xlabel("WER (%)")
ax.set_xlim(0, x_max)
# Keep every tick strictly below x_max, because set_xticks widens the view
# limits to show all given ticks.
ax.set_xticks(np.arange(0, x_max, 5))
ax.set_yticks(y_pos)
ax.set_yticklabels(models)

# Print each score just to the right of its bar.
for i, v in enumerate(wer_list):
    ax.text(v + 1, i, f"{v:.2f}%", va="center")

# Put the first model at the top, matching the order of `models`.
ax.invert_yaxis()
plt.show();