In [None]:
!pip install -q transformers datasets jiwer evaluate
!pip install -q seacrowd
!pip install -U datasets transformers

In [None]:
import torch
import pandas as pd
from tqdm import tqdm
from datasets import load_dataset, Audio
from transformers import pipeline, AutoProcessor
from jiwer import wer, cer

# Prefer the GPU with half-precision weights; on CPU fall back to full fp32 precision.
_cuda_ok = torch.cuda.is_available()
device, dtype = ("cuda", torch.float16) if _cuda_ok else ("cpu", torch.float32)

In [None]:
from kaggle_secrets import UserSecretsClient
import os
# Read the Hugging Face access token from Kaggle Secrets and expose it as HF_TOKEN in the
# process environment (presumably picked up by the Hub client for authenticated downloads).
user_secrets = UserSecretsClient()
os.environ.update({"HF_TOKEN": user_secrets.get_secret("HF_TOKEN")})

In [None]:
import os
import warnings
from transformers.utils import logging

# Keep the output focused on the evaluation's own progress bar and metrics.
# Progress bars are switched off through transformers' logging API rather than an environment
# variable: transformers has already been imported by this point, so any env var set here would
# be read too late.
logging.disable_progress_bar()

# Silence known, benign generation warnings emitted while decoding with Whisper.
warnings.filterwarnings("ignore", message="`generation_config` default values have been modified")
warnings.filterwarnings("ignore", message="A custom logits processor of type.*SuppressTokensLogitsProcessor")
warnings.filterwarnings("ignore", message="A custom logits processor of type.*SuppressTokensAtBeginLogitsProcessor")

# Only errors from transformers' own logger are shown.
logging.set_verbosity_error()

In [None]:
# Load the fine-tuned Thai Whisper checkpoint through a pipeline. The evaluation below uses only
# `pipe.model` and a processor, calling `generate` directly. The pipeline-level
# `chunk_length_s` / `return_timestamps` settings therefore do not affect the reported scores.
pipe = pipeline(
    task="automatic-speech-recognition",
    model="PogusTheWhisper/Pathumma-whisper-th-large-v3-natural-noise-finetuned",
    device=device,
    torch_dtype=dtype,
    chunk_length_s=30,
    return_timestamps=False,
)

# Fix decoding to Thai transcription on the generation config, which is what Whisper's
# `generate()` consults. Assigning `model.config.forced_decoder_ids` is the legacy route and may
# not be honoured by recent transformers versions. Clearing `forced_decoder_ids` keeps it from
# conflicting with the explicit language/task settings.
pipe.model.generation_config.language = "th"
pipe.model.generation_config.task = "transcribe"
pipe.model.generation_config.forced_decoder_ids = None

processor = AutoProcessor.from_pretrained(pipe.model.name_or_path)
model = pipe.model

In [None]:
def _load_eval_subset(dataset_id, config, split, audio_col, text_col, limit, seed):
    """Return ``(audios, refs)`` for a seeded random sample of at most ``limit`` rows.

    Audio is cast to 16 kHz to match the sampling rate passed to the processor. Rows whose
    reference text is missing or blank are dropped, because jiwer cannot score an empty
    reference.
    """
    ds = load_dataset(dataset_id, config, split=split, trust_remote_code=True)
    ds = ds.cast_column(audio_col, Audio(sampling_rate=16000))
    subset = ds.shuffle(seed=seed).select(range(min(limit, len(ds))))

    # Reading the text column on its own avoids decoding every audio clip just to get transcripts.
    texts = [(text or "").strip() for text in subset[text_col]]
    keep = [idx for idx, text in enumerate(texts) if text]
    if len(keep) < len(texts):
        subset = subset.select(keep)

    audios = [clip["array"] for clip in subset[audio_col]]
    refs = [texts[idx] for idx in keep]
    return audios, refs


def _transcribe_batches(audios, processor, model, device, dtype, batch_size, language, task):
    """Run Whisper generation over ``audios`` in batches and return stripped transcripts."""
    # DataParallel does not expose `generate`; call it on the wrapped module instead.
    is_parallel = isinstance(model, torch.nn.DataParallel)
    generate_fn = model.module.generate if is_parallel else model.generate

    preds = []
    for start in tqdm(range(0, len(audios), batch_size), desc="Transcribing"):
        batch_audio = audios[start:start + batch_size]

        # Use the feature extractor's default padding, which produces the fixed 30 s window
        # Whisper expects. Audio longer than that is truncated; long-form audio is not handled
        # here. No attention mask is passed: every clip occupies a full-length feature window.
        features = processor(
            batch_audio,
            sampling_rate=16000,
            return_tensors="pt",
        )["input_features"]
        # Match the model's weight dtype (fp16 on GPU, fp32 on CPU in this notebook).
        features = features.to(device=device, dtype=dtype)

        with torch.no_grad():
            output_ids = generate_fn(
                input_features=features,
                language=language,
                task=task,
            )

        batch_preds = processor.batch_decode(output_ids, skip_special_tokens=True)
        preds.extend(p.strip() for p in batch_preds)
    return preds


def evaluate_dataset(dataset_id, config, split, audio_col, text_col,
                     processor, model, device, dtype,
                     limit=100, batch_size=8, seed=134,
                     language="th", task="transcribe"):
    """Transcribe a sample of one Hugging Face ASR dataset and score it with CER and WER.

    Args:
        dataset_id, config, split: Forwarded to ``load_dataset``.
        audio_col: Name of the audio column (resampled to 16 kHz).
        text_col: Name of the reference-transcript column.
        processor: Processor providing feature extraction and ``batch_decode``.
        model: Seq2seq model with ``generate``, optionally wrapped in ``torch.nn.DataParallel``.
        device: Device string the input features are moved to.
        dtype: Floating dtype for the input features; should match the model weights.
        limit: Maximum number of rows sampled from the split.
        batch_size: Number of clips per ``generate`` call.
        seed: Shuffle seed, so the same rows are sampled on every run.
        language, task: Whisper decoding options forwarded to ``generate``.

    Returns:
        ``(summary, detailed)``.
        ``summary`` is a dict with dataset, config, samples, CER and WER.
        ``detailed`` is a DataFrame of per-utterance references and predictions.
        On any failure, ``summary`` has ``samples=0``, ``None`` scores and an ``"error"``
        message, and ``detailed`` is an empty DataFrame.

    NOTE(review): jiwer's WER splits on whitespace. Thai is mostly written without spaces
    between words, so CER is likely the more meaningful number for these datasets.
    """
    print(f"Evaluating: {dataset_id} ({config})")

    # Deliberately broad: one broken dataset (missing config, download error, ...) is
    # reported as an error row instead of aborting the whole sweep.
    try:
        audios, refs = _load_eval_subset(
            dataset_id, config, split, audio_col, text_col, limit, seed
        )
        if not refs:
            raise ValueError(f"no non-empty references found in column {text_col!r}")

        preds = _transcribe_batches(
            audios, processor, model, device, dtype, batch_size, language, task
        )

        cer_score = cer(refs, preds)
        wer_score = wer(refs, preds)

        print("CER:", cer_score)
        print("WER:", wer_score)
        print("-" * 50)

        detailed = pd.DataFrame({
            "dataset": dataset_id,
            "config": config,
            "reference": refs,
            "prediction": preds
        })

        summary = {
            "dataset": dataset_id,
            "config": config,
            "samples": len(refs),
            "CER": cer_score,
            "WER": wer_score
        }

        return summary, detailed

    except Exception as e:
        print(f"Error evaluating {dataset_id}: {e}")
        return {
            "dataset": dataset_id,
            "config": config,
            "samples": 0,
            "CER": None,
            "WER": None,
            "error": str(e)
        }, pd.DataFrame()

In [None]:
# Benchmarks to score, as (dataset_id, config, split, audio_col, text_col).
datasets_to_eval = [
    ("tingwry/asr-augmented", "default", "test", "wav", "transcript"),
    ("google/fleurs", "th_th", "test", "audio", "raw_transcription"),
    ("fsicoli/common_voice_18_0", "th", "test", "audio", "sentence"),
    ("SEACrowd/gowajee", "gowajee_source", "train", "audio", "transcription"),
    ("SEACrowd/thai_elderly_speech", "thai_elderly_speech_healthcare_source", "train", "audio", "transcription"),
    # Not currently evaluated (left commented out):
    # ("speechcolab/gigaspeech2", "default", "test", "audio", "text"),
]

EVAL_LIMIT = 388       # maximum utterances sampled per dataset
EVAL_BATCH_SIZE = 8    # clips per generate() call

summary_report = []
detailed_frames = []

for args in datasets_to_eval:
    summary, detailed = evaluate_dataset(
        *args,
        processor=processor,
        model=model,
        device=device,
        dtype=dtype,
        limit=EVAL_LIMIT,
        batch_size=EVAL_BATCH_SIZE
    )
    summary_report.append(summary)
    # Failed datasets come back with an empty frame; skip those rather than concatenating them.
    if not detailed.empty:
        detailed_frames.append(detailed)

# Concatenate once at the end instead of re-copying the accumulated frame on every iteration.
detailed_report = (
    pd.concat(detailed_frames, ignore_index=True) if detailed_frames else pd.DataFrame()
)

summary_df = pd.DataFrame(summary_report)
summary_df.to_csv("asr_evaluation_summary.csv", index=False)
detailed_report.to_csv("asr_evaluation_details.csv", index=False)

summary_df

In [None]:
# Re-read the summary CSV from the same relative path the evaluation loop saved it to, so this
# works regardless of the working directory (not only under /kaggle/working).
pd.read_csv("asr_evaluation_summary.csv").head()

In [None]:
# Per-utterance reference/prediction pairs collected across all evaluated datasets.
detailed_report