In [1]:
# --- Setup ---
import time, re, torch
from datasets import load_dataset, Audio
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from jiwer import wer, cer

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TARGET_SR = 16_000
SPLIT = "test[:10%]"      # change to "test" for full set, or "test[:1%]" for a quick run
print("Device:", DEVICE)

# --- Load LibriSpeech and resample on-the-fly to 16 kHz ---
ds = load_dataset("librispeech_asr", "clean", split=SPLIT)
ds = ds.cast_column("audio", Audio(sampling_rate=TARGET_SR))

# --- Normalizer (same as used in many papers for fair WER) ---
def normalize_text(s: str) -> str:
    s = s.lower().strip()
    s = re.sub(r"[^a-z' ]+", " ", s)      # keep letters + apostrophes
    s = re.sub(r"\s+", " ", s).strip()
    return s

references = [normalize_text(x["text"]) for x in ds]
print(f"Items: {len(ds)}")

Device: cpu


Resolving data files:   0%|          | 0/48 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/48 [00:00<?, ?it/s]

Items: 262


In [2]:
def collate_raw(batch):
    # batch: list of dicts with "audio" etc.
    wavs = [torch.tensor(ex["audio"]["array"]) for ex in batch]
    lens = [len(w) for w in wavs]
    # Keep raw list (each 1D); each model will pad in its own processor
    return wavs, lens

BATCH_SIZE = 8
loader = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_raw)

In [3]:
import torch, numpy as np, re, time
from datasets import load_dataset, Audio
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_ID = "facebook/wav2vec2-large-960h-lv60-self"
SPLIT = "test[:10%]"  # adjust

# 1) Load data and resample on the fly to 16k for CTC
TARGET_SR = 16_000
ds = load_dataset("librispeech_asr", "clean", split=SPLIT)
ds = ds.cast_column("audio", Audio(sampling_rate=TARGET_SR))

# 2) Simple normalizer to match WER setups
def normalize_text(s: str) -> str:
    s = s.lower().strip()
    s = re.sub(r"[^a-z' ]+", " ", s)   # keep letters and apostrophes
    s = re.sub(r"\s+", " ", s).strip()
    return s

references = [normalize_text(x["text"]) for x in ds]

# 3) Load processor + model
processor = Wav2Vec2Processor.from_pretrained(MODEL_ID)
model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID).to(DEVICE).eval()

# 4) Batched evaluation
BATCH_SIZE = 8
hypotheses = []
audio_seconds_total = 0.0

t0 = time.time()
with torch.no_grad():
    for i in range(0, len(ds), BATCH_SIZE):
        # make a “batch” by slicing the dataset
        batch = ds.select(range(i, min(i + BATCH_SIZE, len(ds))))

        # *** extract a LIST of 1-D float arrays (not 2-D) ***
        wavs = [ex["audio"]["array"] for ex in batch]
        audio_seconds_total += sum(len(w) for w in wavs) / TARGET_SR

        # processor pads and builds tensors
        inputs = processor(
            wavs,
            sampling_rate=TARGET_SR,
            return_tensors="pt",
            padding=True  # pad-to-longest
        )

        # move to device, run model
        inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
        logits = model(**inputs).logits                       # [B, T, vocab]
        pred_ids = torch.argmax(logits, dim=-1)               # [B, T]
        texts = processor.batch_decode(pred_ids, skip_special_tokens=True)
        hypotheses.extend(normalize_text(t) for t in texts)

wall = time.time() - t0
rtf = wall / audio_seconds_total

print(f"Items: {len(ds)}  audio: {audio_seconds_total/60:.2f} min  wall: {wall:.2f}s  RTF: {rtf:.3f}")

Resolving data files:   0%|          | 0/48 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/48 [00:00<?, ?it/s]

Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-large-960h-lv60-self and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Items: 262  audio: 32.41 min  wall: 194.46s  RTF: 0.100


In [4]:
MODEL_ID = "facebook/wav2vec2-large-960h-lv60-self"  # English CTC
processor = Wav2Vec2Processor.from_pretrained(MODEL_ID)
model     = Wav2Vec2ForCTC.from_pretrained(MODEL_ID).to(DEVICE).eval()

Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-large-960h-lv60-self and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [5]:
BATCH_SIZE = 8
hypotheses = []
audio_seconds_total = 0.0

t0 = time.time()
with torch.inference_mode():
    for i in range(0, len(ds), BATCH_SIZE):
        # Slice the dataset instead of using a DataLoader
        batch = ds.select(range(i, min(i + BATCH_SIZE, len(ds))))

        # IMPORTANT: pass a LIST of raw 1-D arrays to the processor
        wavs = [ex["audio"]["array"] for ex in batch]          # list of 1-D float arrays
        audio_seconds_total += sum(len(w) for w in wavs) / TARGET_SR

        inputs = processor(
            wavs,
            sampling_rate=TARGET_SR,
            return_tensors="pt",
            padding=True,            # let processor pad
        )

        logits = model(
            input_values=inputs.input_values.to(DEVICE),
            attention_mask=inputs.attention_mask.to(DEVICE),
        ).logits                                   # [B, T, vocab]

        pred_ids = torch.argmax(logits, dim=-1)
        texts = processor.batch_decode(pred_ids, skip_special_tokens=True)
        hypotheses.extend(normalize_text(t) for t in texts)

t1 = time.time()
wer_score = 100.0 * (
    sum(1 for _ in references)  # dummy, replace with your WER function if needed
)

wall = t1 - t0
rtf  = wall / audio_seconds_total  # Real-Time Factor

print(f"Items: {len(ds)}  audio: {audio_seconds_total/60:.2f} min  wall: {wall:.2f}s  RTF: {rtf:.3f}")
# print(f"WER: {wer_score:.2f}%")

Items: 262  audio: 32.41 min  wall: 196.95s  RTF: 0.101


In [6]:
import time, torch, numpy as np, re
from datasets import load_dataset, Audio
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
# from jiwer import wer 

DEVICE   = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_ID = "facebook/wav2vec2-large-960h-lv60-self"
TARGET_SR = 16_000

# Load LibriSpeech (resampled on the fly)
ds = load_dataset("librispeech_asr", "clean", split="test[:10%]")  # adjust split if you like
ds = ds.cast_column("audio", Audio(sampling_rate=TARGET_SR))

# simple normalizer (same as you used above)
def normalize_text(s: str) -> str:
    s = s.lower().strip()
    s = re.sub(r"[^a-z' ]+", " ", s)   # keep letters + apostrophes
    s = re.sub(r"\s+", " ", s).strip()
    return s

references = [normalize_text(x["text"]) for x in ds]

# Load processor + model
processor = Wav2Vec2Processor.from_pretrained(MODEL_ID)
model     = Wav2Vec2ForCTC.from_pretrained(MODEL_ID).to(DEVICE).eval()

Resolving data files:   0%|          | 0/48 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/48 [00:00<?, ?it/s]

Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-large-960h-lv60-self and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [7]:
BATCH_SIZE = 8

def transcribe_batch(batch_examples):
    # a list of 1-D float32 arrays at 16 kHz
    wavs = [ex["audio"]["array"] for ex in batch_examples]

    #the processor pad and build tensors
    inputs = processor(
        wavs,
        sampling_rate=TARGET_SR,
        return_tensors="pt",
        padding=True,              # pad to longest
        # pad_to_multiple_of=8,    # optional, a tiny speedup on some CPUs/GPUs
    ).to(DEVICE)

    # forward
    with torch.inference_mode():
        logits = model(
            input_values   = inputs.input_values,
            attention_mask = inputs.attention_mask
        ).logits                       # [B, T, vocab]
        pred_ids = torch.argmax(logits, dim=-1)   # [B, T] -> ids
        texts = processor.batch_decode(pred_ids, skip_special_tokens=True)

    # normalize transcripts to match refs
    return [normalize_text(t) for t in texts]

# --- run the loop ---
hyps = []
audio_seconds_total = 0.0
t0 = time.time()

for i in range(0, len(ds), BATCH_SIZE):
    batch = ds.select(range(i, min(i + BATCH_SIZE, len(ds))))
    # keep speed/RTF accounting
    wavs = [ex["audio"]["array"] for ex in batch]
    audio_seconds_total += sum(len(w) for w in wavs) / TARGET_SR
    hyps.extend(transcribe_batch(batch))

wall = time.time() - t0
rtf  = wall / audio_seconds_total

print(f"Items: {len(ds)}  audio: {audio_seconds_total/60:.2f}m  wall: {wall:.2f}s  RTF: {rtf:.3f}")

Items: 262  audio: 32.41m  wall: 195.92s  RTF: 0.101


In [8]:
from torch.utils.data import DataLoader

def collate_keep_lists(examples):
    return {
        "wavs": [ex["audio"]["array"] for ex in examples],
        "text": [ex["text"] for ex in examples]
    }

loader = DataLoader(ds, batch_size=8, shuffle=False, collate_fn=collate_keep_lists)

hyps, audio_seconds_total = [], 0.0
t0 = time.time()
for batch in loader:
    wavs = batch["wavs"]
    audio_seconds_total += sum(len(w) for w in wavs) / TARGET_SR

    inputs = processor(wavs, sampling_rate=TARGET_SR, return_tensors="pt", padding=True).to(DEVICE)
    with torch.inference_mode():
        logits = model(**inputs).logits
        pred_ids = torch.argmax(logits, dim=-1)
        texts = processor.batch_decode(pred_ids, skip_special_tokens=True)
    hyps.extend(normalize_text(t) for t in texts)

wall = time.time() - t0
rtf  = wall / audio_seconds_total
print(f"Items: {len(ds)}  audio: {audio_seconds_total/60:.2f}m  wall: {wall:.2f}s  RTF: {rtf:.3f}")

Items: 262  audio: 32.41m  wall: 197.46s  RTF: 0.102


In [9]:
from jiwer import wer
score = wer(references, hyps) * 100.0
print(f"WER: {score:.2f}%")

WER: 10.64%


In [10]:
BATCH_SIZE = 8
hyps = []
audio_seconds_total = 0.0

t0 = time.time()
with torch.inference_mode():
    for i in range(0, len(ds), BATCH_SIZE):
        batch = ds.select(range(i, min(i + BATCH_SIZE, len(ds))))

        wavs = [ex["audio"]["array"] for ex in batch]        # list of 1-D arrays
        audio_seconds_total += sum(len(w) for w in wavs) / TARGET_SR

        inputs = processor(
            wavs,
            sampling_rate=TARGET_SR,
            return_tensors="pt",
            padding=True
        ).to(DEVICE)

        logits = model(
            input_values   = inputs.input_values,
            attention_mask = inputs.attention_mask
        ).logits
        pred_ids = torch.argmax(logits, dim=-1)
        texts = processor.batch_decode(pred_ids, skip_special_tokens=True)
        hyps.extend(normalize_text(t) for t in texts)

wall = time.time() - t0
rtf  = wall / audio_seconds_total
print(f"Items: {len(ds)}  audio: {audio_seconds_total/60:.2f}m  wall: {wall:.2f}s  RTF: {rtf:.3f}")

Items: 262  audio: 32.41m  wall: 198.17s  RTF: 0.102


import torch, numpy as np, re, time
from datasets import load_dataset, Audio
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
from jiwer import wer

DEVICE   = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_ID = "facebook/wav2vec2-large-960h-lv60-self"
TARGET_SR = 16_000
SPLIT = "test[:10%]"   # adjust as you like

# dataset + resample-on-the-fly
ds = load_dataset("librispeech_asr", "clean", split=SPLIT)
ds = ds.cast_column("audio", Audio(sampling_rate=TARGET_SR))

# normalizer used for WER
def normalize_text(s): 
    s = s.lower()
    s = re.sub("[^a-z' ]+", " ", s)
    s = re.sub("\s+", " ", s).strip()
    return s
references = [normalize_text(x["text"]) for x in ds]

# models
processor = Wav2Vec2Processor.from_pretrained(MODEL_ID)
model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID).to(DEVICE).eval()

BATCH_SIZE = 8
hyps, audio_seconds_total = [], 0.0
t0 = time.time()

with torch.inference_mode():
    for i in range(0, len(ds), BATCH_SIZE):
        batch = ds.select(range(i, min(i+BATCH_SIZE, len(ds))))
        waves = [ex["audio"]["array"] for ex in batch]           # list of 1-D float arrays
        audio_seconds_total += sum(len(w) for w in waves) / TARGET_SR

        inputs = processor(waves, sampling_rate=TARGET_SR,
                           return_tensors="pt", padding=True).to(DEVICE)
        logits = model(**inputs).logits                           # [B, T, vocab]
        pred_ids = torch.argmax(logits, dim=-1)
        texts = processor.batch_decode(pred_ids, skip_special_tokens=True)
        hyps.extend(normalize_text(t) for t in texts)

wall = time.time() - t0
rtf = wall / audio_seconds_total
score = wer(references, hyps) * 100.0

print(f"Items: {len(ds)}  audio: {audio_seconds_total/60:.2f}m  wall: {wall:.2f}s  RTF: {rtf:.3f}")
print(f"WER: {score:.2f}%")


In [12]:
# Picking the SR this dataset was resampled towards
WHISPER_SR = getattr(wh_proc.feature_extractor, "sampling_rate", 16000)
assert isinstance(WHISPER_SR, (int, float)), f"Bad sampling rate: {WHISPER_SR!r}"

NameError: name 'wh_proc' is not defined

In [13]:
print("WHISPER_SR =", WHISPER_SR, type(WHISPER_SR))

NameError: name 'WHISPER_SR' is not defined

In [14]:
from transformers import WhisperProcessor, WhisperForConditionalGeneration

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

WHISPER_ID  = "openai/whisper-medium"   # or "small"/"base" for speed
wh_proc     = WhisperProcessor.from_pretrained(WHISPER_ID)
wh_model    = WhisperForConditionalGeneration.from_pretrained(WHISPER_ID).to(DEVICE).eval()

# Whisper’s expected SR (should be 16000)
WHISPER_SR = getattr(wh_proc.feature_extractor, "sampling_rate", 16000)
assert isinstance(WHISPER_SR, int), f"Bad sampling rate: {WHISPER_SR}"
print("Whisper SR:", WHISPER_SR)

Whisper SR: 16000


In [15]:
from transformers import WhisperProcessor, WhisperForConditionalGeneration
import numpy as np, torch, time

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

WHISPER_ID = "openai/whisper-medium"   # or "small"/"base" for speed
wh_proc  = WhisperProcessor.from_pretrained(WHISPER_ID)
wh_model = WhisperForConditionalGeneration.from_pretrained(WHISPER_ID).to(DEVICE).eval()

WHISPER_SR = int(wh_proc.feature_extractor.sampling_rate)
print("Whisper SR:", WHISPER_SR)  # should print 16000

Whisper SR: 16000


In [16]:
from torch.utils.data import DataLoader

def collate_wavs(examples):
    # examples is a list of dataset rows (dicts)
    wavs = [ex["audio"]["array"].astype("float32") for ex in examples]
    lens = [len(w) for w in wavs]
    return wavs, lens

BATCH_SIZE = 8
loader = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_wavs)

In [17]:
wavs, lens = next(iter(loader))
assert isinstance(wavs, list) and len(wavs) > 0
assert isinstance(wavs[0], np.ndarray) and wavs[0].ndim == 1
assert len(wavs) == len(lens)

In [18]:
import time, torch, numpy as np
from torch.utils.data import DataLoader
from transformers import WhisperProcessor, WhisperForConditionalGeneration

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Using a small/base model for CPU; switch back to "openai/whisper-medium".
WHISPER_ID = "openai/whisper-small"
wh_proc = WhisperProcessor.from_pretrained(WHISPER_ID)
wh_model = WhisperForConditionalGeneration.from_pretrained(WHISPER_ID).to(DEVICE).eval()

# Whisper expects 16 kHz features internally
WHISPER_SR = int(getattr(wh_proc.feature_extractor, "sampling_rate", 16000))

# ---- Dataloader that returns lists of 1-D float arrays + lengths ----
def collate_keep_lists(examples):
    return {
        "wavs":  [ex["audio"]["array"] for ex in examples],  # list[np.ndarray], dtype float32, 1-D
        "lens":  [ex["audio"]["array"].shape[0] for ex in examples],  # sample counts
        "text":  [ex["text"] for ex in examples],
    }

BATCH_SIZE = 16  # 8–32 is fine on CPU for small/base; lower if you see OOM
loader = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_keep_lists)

# Beam search config – set num_beams=1 for greedy (faster, fewer stalls)
GEN_KW = dict(
    task="transcribe",          # or "translate"
    language="en",              # set explicitly if needed
    num_beams=1,                # <-- greedy = fastest
    max_new_tokens=225,         # cap generated length; adjust if needed
    length_penalty=1.0,
    no_repeat_ngram_size=3,
)

hypotheses = []
audio_seconds_total = 0.0
t0 = time.time()

with torch.inference_mode():
    for bi, batch in enumerate(loader, 1):
        wavs = batch["wavs"]             # list of 1-D float arrays
        lens = batch["lens"]             # lengths in input sampling rate
        refs = batch["text"]

        # Updating RTF denominator using *model* SR
        audio_seconds_total += sum(l for l in lens) / WHISPER_SR

        # Processor computes log-mels and attention_mask; pad to longest in the batch
        inputs = wh_proc(
            wavs,
            sampling_rate=WHISPER_SR,
            return_tensors="pt",
            padding=True
        )

        input_features = inputs["input_features"].to(DEVICE)
        attention_mask = inputs.get("attention_mask")  # present in recent transformers

        # Generating ids (pass attention_mask to avoid ‘pad==eos’ pathology)
        gen_ids = wh_model.generate(
            input_features,
            attention_mask=attention_mask,   # <-- IMPORTANT on padded batches
            **GEN_KW
        )

        texts = wh_proc.batch_decode(gen_ids, skip_special_tokens=True)
        hypotheses.extend(texts)

        # progress heartbeat
        if bi % 5 == 0 or bi == len(loader):
            wall = time.time() - t0
            rtf = wall / max(audio_seconds_total, 1e-6)
            print(f"[Whisper] batch {bi}/{len(loader)} | items: {len(hypotheses)} | wall: {wall:,.1f}s | audio: {audio_seconds_total/60:,.1f}m | RTF: {rtf:.3f}")

wall = time.time() - t0
rtf = wall / max(audio_seconds_total, 1e-6)
print(f"\n[Whisper DONE] items: {len(hypotheses)} | wall: {wall:,.1f}s | audio: {audio_seconds_total/60:,.1f}m | RTF: {rtf:.3f}")


The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


[Whisper] batch 5/17 | items: 80 | wall: 65.7s | audio: 8.4m | RTF: 0.131
[Whisper] batch 10/17 | items: 160 | wall: 138.4s | audio: 19.7m | RTF: 0.117
[Whisper] batch 15/17 | items: 240 | wall: 203.7s | audio: 29.9m | RTF: 0.114
[Whisper] batch 17/17 | items: 262 | wall: 225.7s | audio: 32.4m | RTF: 0.116

[Whisper DONE] items: 262 | wall: 225.7s | audio: 32.4m | RTF: 0.116


In [23]:
# 'references' should be your ground-truth list, already normalized the same way
from jiwer import wer, cer

wh_wer = wer(references, hypotheses) * 100
wh_cer = cer(references, hypotheses) * 100
print(f"[Whisper]  WER: {wh_wer:.2f}%  CER: {wh_cer:.2f}%")

[Whisper]  WER: 22.75%  CER: 5.63%


In [25]:
import pandas as pd, json, time, os

# 1) Make sure we expose the hyps with the expected name
wh_hyps = hypotheses  # <-- 'hypotheses' came from your Whisper loop

# 2) If your timing variables have different names, map them here:
# wall = wh_wall
# rtf  = wh_rtf

# 3) Build a compact summary of the run
run = {
    "model": "whisper-medium",        # adjust if you used a different size
    "beam": 5,                        # match your GEN_KW settings
    "length_penalty": 1.0,
    "no_repeat_ngram_size": 3,
    "dataset": "LibriSpeech test-clean (slice)",  # or your exact split label
    "items": len(references),
    "wall_s": wall,                   # total wall time (seconds)
    "audio_s": audio_seconds_total,   # total audio (seconds)
    "rtf": rtf,                       # real-time factor
    "wer": wh_wer / 100.0,            # store as fraction (e.g., 0.2275)
    "cer": wh_cer / 100.0             # store as fraction (e.g., 0.0563)
}

# 4) Building a pairwise ref/hyp table (useful for error analysis)
df = pd.DataFrame({
    "ref": references,
    "hyp": wh_hyps
})

# 5) Write artifacts
os.makedirs("runs", exist_ok=True)
ts = time.strftime("%Y%m%d-%H%M%S")
df.to_csv(f"runs/whisper_medium_librispeech_{ts}.csv", index=False)

with open(f"runs/whisper_medium_librispeech_{ts}.json", "w") as f:
    json.dump(run, f, indent=2)

print(f"[Saved] CSV + JSON under runs/ with timestamp {ts}")
# ---- end block ----

[Saved] CSV + JSON under runs/ with timestamp 20250904-132637


In [32]:
import numpy as np
import librosa  # only needed if you ever get filepaths

TARGET_SR = 16000  # make sure this matches your pipeline

def extract_wavs_and_lens(batch, target_sr=TARGET_SR):
    """
    Return: wavs (list of 1-D float arrays), lens (list of ints)
    Handles several common DataLoader shapes:
      - dict with "wavs"
      - dict with "audio" (list of dicts or arrays)
      - list of examples where each ex["audio"]["array"] exists
      - (wavs, lens) tuple prepared by a collate_fn
      - paths (rare) -> will load with librosa
    """
    # 1) dict batch (HuggingFace default)
    if isinstance(batch, dict):
        if "wavs" in batch:                               # custom collate_keep_lists
            wavs = batch["wavs"]
        elif "audio" in batch:                            # HF Datasets default
            aud = batch["audio"]
            if isinstance(aud, list):
                # list of dicts or arrays
                wavs = [
                    (a["array"] if isinstance(a, dict) and "array" in a else np.asarray(a, dtype=np.float32))
                    for a in aud
                ]
            elif isinstance(aud, dict) and "array" in aud:
                wavs = [aud["array"]]
            else:
                raise TypeError(f"Unsupported 'audio' field type: {type(aud)}")
        else:
            raise KeyError("Batch dict has neither 'wavs' nor 'audio'")
        lens = [len(w) for w in wavs]
        return wavs, lens

    # 2) list of examples, each a dict
    if isinstance(batch, list) and len(batch) > 0 and isinstance(batch[0], dict):
        wavs = [
            (ex["audio"]["array"] if isinstance(ex["audio"], dict) else np.asarray(ex["audio"], dtype=np.float32))
            for ex in batch
        ]
        lens = [len(w) for w in wavs]
        return wavs, lens

    # 3) collate_fn already produced (wavs, lens)
    if isinstance(batch, (tuple, list)) and len(batch) == 2:
        wavs, lens = batch
        return wavs, lens

    # 4) paths (rare)
    if isinstance(batch, list) and len(batch) > 0 and isinstance(batch[0], str):
        wavs = [librosa.load(p, sr=target_sr)[0] for p in batch]
        lens = [len(w) for w in wavs]
        return wavs, lens

    raise TypeError(f"Don't know how to read batch of type {type(batch)}")

In [33]:
# Lists to store predictions
w2v_hyps = []
audio_seconds_total = 0.0
t0 = time.time()

with torch.no_grad():
    for batch in loader:
        wavs, lens = extract_wavs_and_lens(batch)  # <-- robust extraction
        audio_seconds_total += sum(lens) / TARGET_SR

        inputs = processor(
            wavs, sampling_rate=TARGET_SR, return_tensors="pt", padding=True
        )
        inputs = {k: v.to(DEVICE) for k, v in inputs.items()}

        logits = model(**inputs).logits
        pred_ids = torch.argmax(logits, dim=-1)

        texts = processor.batch_decode(pred_ids, skip_special_tokens=True)
        w2v_hyps.extend([t.lower().strip() for t in texts])  # normalize

w2v_wall = time.time() - t0
w2v_rtf  = w2v_wall / audio_seconds_total

from jiwer import wer, cer
w2v_wer = wer(references, w2v_hyps) * 100
w2v_cer = cer(references, w2v_hyps) * 100

print(f"[Wav2Vec2] WER: {w2v_wer:.2f}% CER: {w2v_cer:.2f}% RTF: {w2v_rtf:.3f}")


[Wav2Vec2] WER: 10.64% CER: 2.33% RTF: 0.169
