In [1]:
pip install torchcodec datasets transformers librosa evaluate jiwer gradio accelerate bitsandbytes peft accelerate

Collecting torchcodec
  Downloading torchcodec-0.9.1-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (11 kB)
Collecting evaluate
  Downloading evaluate-0.4.6-py3-none-any.whl.metadata (9.5 kB)
Collecting jiwer
  Downloading jiwer-4.0.0-py3-none-any.whl.metadata (3.3 kB)
Collecting bitsandbytes
  Downloading bitsandbytes-0.49.0-py3-none-manylinux_2_24_x86_64.whl.metadata (10 kB)
Collecting fsspec<=2025.3.0,>=2023.1.0 (from fsspec[http]<=2025.3.0,>=2023.1.0->datasets)
  Downloading fsspec-2025.3.0-py3-none-any.whl.metadata (11 kB)
Collecting rapidfuzz>=3.9.7 (from jiwer)
  Downloading rapidfuzz-3.14.3-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (12 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=2.0.0->accelerate)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=2.0.0->accelerate)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py

# WER Prediction Dataset Pipeline
This notebook contains the full pipeline for loading the test set, running Whisper inference, extracting features, and building the WER prediction dataset.

In [2]:
from datasets import load_dataset
from transformers import (WhisperFeatureExtractor, WhisperTokenizer, WhisperProcessor,
                          WhisperForConditionalGeneration, Seq2SeqTrainingArguments, Seq2SeqTrainer)
import numpy as np
import pandas as pd
from jiwer import wer
import librosa
import torch
from dataclasses import dataclass
from scipy.stats import entropy
from typing import Any, Dict, List, Union

MODEL = 'openai/whisper-tiny'

2025-12-13 21:34:29.214280: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1765661669.398277      20 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1765661669.449314      20 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [3]:
from kaggle_secrets import UserSecretsClient
from huggingface_hub import login
user_secrets = UserSecretsClient()
#from google.colab import userdata
import os
login(token=user_secrets.get_secret('HUGGINGFACE_API_KEY'))#userdata.get("HUGGINGFACE_API_KEY"))

os.environ["WANDB_API_KEY"]= user_secrets.get_secret("WANDB_API_KEY")#userdata.get("WANDB_API_KEY")
os.environ["WANDB_PROJECT"] = "Fine-tuning Whisper Large"
os.environ["WANDB_NOTES"] = "Fine tune model whisper"
os.environ["WANDB_NAME"] = "ft-whisper-med-asr"

In [4]:
# Load processors
feature_extractor = WhisperFeatureExtractor.from_pretrained(MODEL)
tokenizer = WhisperTokenizer.from_pretrained(MODEL, language='English', task='transcribe')
processor = WhisperProcessor.from_pretrained(MODEL, language='English', task='transcribe')
model = WhisperForConditionalGeneration.from_pretrained(MODEL)

model.config.forced_decoder_ids = processor.tokenizer.get_decoder_prompt_ids(language='English', task='transcribe')
model.config.suppress_tokens = []

preprocessor_config.json: 0.00B [00:00, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

normalizer.json: 0.00B [00:00, ?B/s]

added_tokens.json: 0.00B [00:00, ?B/s]

special_tokens_map.json: 0.00B [00:00, ?B/s]

config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/151M [00:00<?, ?B/s]

generation_config.json: 0.00B [00:00, ?B/s]

In [5]:
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        input_features = [{'input_features': f['input_features']} for f in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors='pt')
        label_features = [{'input_ids': f['labels']} for f in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors='pt')
        labels = labels_batch['input_ids'].masked_fill(labels_batch.attention_mask.ne(1), -100)
        if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
            labels = labels[:, 1:]
        batch['labels'] = labels
        return batch

data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor,
    decoder_start_token_id=model.config.decoder_start_token_id,
)

def prepare_med_dataset(batch):
    audio_array = batch["audio"]["array"]
    sampling_rate = batch["audio"]["sampling_rate"]

    batch["input_features"] = feature_extractor(
        audio_array,
        sampling_rate=sampling_rate
    ).input_features[0]

    batch["labels"] = tokenizer(batch["text"]).input_ids
    return batch


# Load only test split
ds = load_dataset('leduckhai/MultiMed-ST','English', split='train', streaming=True)

processed_test = ds.map(prepare_med_dataset, remove_columns=ds.column_names)

README.md: 0.00B [00:00, ?B/s]

In [6]:
# Inference-only trainer
training_args = Seq2SeqTrainingArguments(
    output_dir='.',
    per_device_eval_batch_size=32,
    predict_with_generate=True,
    generation_max_length=448,
    fp16=True,
    do_train=False,
    do_eval=False,
    logging_strategy='no'
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator, 
    processing_class=processor,
)

In [7]:
# --------------------------
#  WER + Decoding
# --------------------------

pred_results = trainer.predict(processed_test)

pred_ids = pred_results.predictions
label_ids = pred_results.label_ids

# pad -100 to tokenizer pad token
label_ids = np.where(label_ids != -100, label_ids, tokenizer.pad_token_id)

decoded_preds = processor.batch_decode(pred_ids, skip_special_tokens=True)
decoded_labels = processor.batch_decode(label_ids, skip_special_tokens=True)

# WER per sample
wers = [wer(t.strip(), p.strip()) for p, t in zip(decoded_preds, decoded_labels)]

Due to a bug fix in https://github.com/huggingface/transformers/pull/28687 transcription using a multilingual Whisper will default to language detection followed by transcription instead of translation to English.This might be a breaking change for your use case. If you want to instead always translate your audio to English, make sure to pass `language='en'`.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.43.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.
[34m[1mwandb[0m: Currently logged in as: [33maminq[0m ([33maminq-lums[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34

In [8]:
# --------------------------
#  FEATURE EXTRACTION (extended, no logits)
# --------------------------

durations = []
word_lengths = []
char_lengths = []
avg_word_lens = []
contexts = []
speaking_rates = []

energies = []
zcrs = []
centroids = []
silence_ratios = []
snrs = []

for row in ds:
    audio = np.array(row['audio']["array"], dtype=float)
    sr = row['audio']["sampling_rate"]

    # --- Duration ---
    duration = len(audio) / sr
    durations.append(duration)

    # --- Text features ---
    text = row["text"]
    wc = len(text.split())
    cc = len(text)

    word_lengths.append(wc)
    char_lengths.append(cc)
    avg_word_lens.append(cc / wc if wc > 0 else 0)

    # speaking rate (w/s)
    speaking_rates.append(wc / duration if duration > 0 else 0)

    # --- Audio Features ---
    # Energy
    energies.append(float(np.mean(audio**2)))

    # Zero Crossing Rate
    try:
        zcrs.append(float(librosa.feature.zero_crossing_rate(audio)[0].mean()))
    except:
        zcrs.append(0.0)

    # Spectral Centroid
    try:
        centroids.append(float(librosa.feature.spectral_centroid(y=audio, sr=sr)[0].mean()))
    except:
        centroids.append(0.0)

    # Silence Ratio
    silence_ratios.append(float(np.mean(np.abs(audio) < 0.01)))

    # SNR estimate
    signal = np.percentile(np.abs(audio), 95)
    noise = np.percentile(np.abs(audio), 5) + 1e-6
    snrs.append(float(signal / noise))

In [9]:
# --------------------------
#  Build Final Dataset
# --------------------------

df = pd.DataFrame({
    "duration_sec": durations,
    "word_count": word_lengths,
    "char_count": char_lengths,
    "avg_word_len": avg_word_lens,
    "speaking_rate": speaking_rates,

    "energy": energies,
    "zcr": zcrs,
    "spectral_centroid": centroids,
    "silence_ratio": silence_ratios,
    "snr": snrs,

    "wer": wers,
    "pred_text": decoded_preds,
    "gt_text": decoded_labels,
})

df.to_csv("wer_prediction_dataset_extended.csv", index=False)
df.head()

Unnamed: 0,duration_sec,word_count,char_count,avg_word_len,speaking_rate,energy,zcr,spectral_centroid,silence_ratio,snr,wer,pred_text,gt_text
0,13.724,31,174,5.612903,2.258817,0.001302,0.121971,1164.707789,0.551921,8770.855076,0.419355,"As I already said, I'm a wife and mother and ...","As already said, I'm a wife, a mother, and a f..."
1,10.995,21,124,5.904762,1.909959,0.001752,0.127213,1227.217683,0.554565,14935.06884,0.238095,I was diagnosed in 1995 after recovering from...,I was diagnosed in 1995 after recovering from ...
2,9.525,21,94,4.47619,2.204724,0.001793,0.158352,1491.089096,0.630846,20233.191995,17.904762,I was a little bit late. I was a little bit l...,advil shovel and I was okay and then one day I...
3,10.996,30,167,5.566667,2.728265,0.001594,0.126224,1204.152998,0.599207,14788.428382,0.333333,More people go from failure with the term Yup...,More people were familiar with the term yuppie...
4,11.475,24,135,5.625,2.091503,0.001771,0.118303,1122.158219,0.575272,9482.378436,0.25,"In 1995, during the late night hours, I was r...","In 1995, during the late night hours that I wa..."
