In [1]:
from datasets import load_dataset
from transformers import pipeline, WhisperForConditionalGeneration, WhisperProcessor
from evaluate import load
import torch

In [2]:
# Evaluate on the first 32 utterances of the streamed TED-LIUM release3 validation split.
dataset = load_dataset("LIUM/tedlium", "release3", split="validation", streaming=True)
dataset = dataset.take(32)

whisper_asr = pipeline(
    "automatic-speech-recognition", model="openai/whisper-tiny.en", device=0
)

# Token ids removed from the model's suppress list so they may appear in transcripts.
# Presumably the apostrophe (') and hyphen (-) tokens — TODO confirm with
# whisper_asr.tokenizer.convert_ids_to_tokens([6, 12]).
# The removal is guarded: if an id is not in the list (or there is no list), the token
# is already unsuppressed, so there is nothing to do and no ValueError is raised.
# NOTE(review): newer transformers versions may read suppress_tokens from
# model.generation_config rather than model.config — verify this edit takes effect.
UNSUPPRESSED_TOKEN_IDS = (6, 12)
suppress_tokens = whisper_asr.model.config.suppress_tokens
for token_id in UNSUPPRESSED_TOKEN_IDS:
    if suppress_tokens is not None and token_id in suppress_tokens:
        suppress_tokens.remove(token_id)

wer_metric = load("wer")

In [3]:
def normalise(text):
    """Normalise ``text`` with the Whisper pipeline tokenizer's normaliser.

    Requires ``whisper_asr`` to exist already. NOTE(review): this calls the
    private ``_normalize`` method; check whether the installed transformers
    version offers a public equivalent.
    """
    tokenizer = whisper_asr.tokenizer
    return tokenizer._normalize(text)

In [4]:
# helper function: return the transcript values stored in a sample or batch
def get_text(sample):
    """Return the value of the first known transcript column present in ``sample``.

    Columns are tried in this priority order: "text", "sentence",
    "normalized_text", "transcript". Works on any mapping; in this notebook it
    receives batched mappings, so the returned value is a list of transcripts.

    Raises:
        ValueError: if none of the known transcript columns is present.
    """
    for column in ("text", "sentence", "normalized_text", "transcript"):
        if column in sample:
            return sample[column]
    raise ValueError(f"Sample: {sample.keys()} has no transcript.")

## Method 1: with pipeline

In [5]:
def predict_and_normalise(batch):
    """Transcribe one batch with the ASR pipeline and add normalised texts.

    ``batch`` is a batched ``datasets`` mapping (column name -> list of values).
    Two columns are added:
      * ``"ref"``  - normalised reference transcripts (taken via ``get_text``);
      * ``"pred"`` - normalised pipeline transcriptions of ``batch["audio"]``.

    NOTE(review): the pipeline is called without ``batch_size``, so it presumably
    transcribes the clips one at a time — TODO confirm, as this affects the timing
    comparison with the processor + model method.
    """
    transcripts = get_text(batch)
    outputs = whisper_asr(batch["audio"])

    batch["ref"] = list(map(normalise, transcripts))
    batch["pred"] = [normalise(output["text"]) for output in outputs]
    return batch

In [6]:
# batch size for extracting references and predictions
batch_size = 8

# Lazily attach normalised "ref"/"pred" columns via the pipeline; every original
# column of the dataset is dropped from the mapped output.
result_set = dataset.map(
    function=predict_and_normalise,
    batched=True,
    batch_size=batch_size,
    remove_columns=list(dataset.features.keys()),
)

In [7]:
def is_target_text_in_range(ref):
    """Return True when ``ref`` is a reference worth scoring.

    After stripping surrounding whitespace, a reference is rejected if it is
    empty or equals "ignore time segment in scoring". That string is presumably
    the normalised form of the dataset's ignore-segment marker — TODO confirm.
    """
    stripped = ref.strip()
    return stripped not in ("", "ignore time segment in scoring")

In [8]:
# keep only rows whose normalised reference is non-empty and not an ignore marker
result_set = result_set.filter(function=is_target_text_in_range, input_columns=["ref"])

In [9]:
%%time

refs = []
preds = []

for i, sample in enumerate(result_set):
    refs.append(sample["ref"])
    preds.append(sample["pred"])



CPU times: user 51.2 s, sys: 1.17 s, total: 52.4 s
Wall time: 16.3 s


In [10]:
# WER over all collected refs/preds, reported as a percentage rounded to 2 d.p.
wer = round(wer_metric.compute(references=refs, predictions=preds) * 100, 2)

print(f"WER:  {wer}")

WER:  3.95


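#### Re-run the pipeline to exclude warm-up effects (dataset streaming set-up and CUDA initialisation)

`result_set` is a lazily evaluated streaming dataset, so iterating over it again repeats the streaming and the inference. This second timing therefore excludes one-off start-up costs.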

In [11]:
%%time

refs = []
preds = []

for i, sample in enumerate(result_set):
    refs.append(sample["ref"])
    preds.append(sample["pred"])

CPU times: user 41.3 s, sys: 690 ms, total: 42 s
Wall time: 6.89 s


In [12]:
# WER over all collected refs/preds, reported as a percentage rounded to 2 d.p.
wer = round(wer_metric.compute(references=refs, predictions=preds) * 100, 2)

print(f"WER:  {wer}")

WER:  3.95


## Method 2: with processor + model

In [13]:
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
# NOTE(review): Method 1's pipeline was created with device=0, whereas this model is
# placed on cuda:1, so the two timings come from different GPUs — confirm intended.
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en").to("cuda:1")

# Un-suppress token ids 6 and 12 (presumably ' and - ; TODO confirm via
# processor.tokenizer.convert_ids_to_tokens([6, 12])). The removal is guarded: if an id
# is not in the list (or there is no list), the token is already allowed and no error is raised.
UNSUPPRESSED_TOKEN_IDS = (6, 12)
suppress_tokens = model.config.suppress_tokens
for token_id in UNSUPPRESSED_TOKEN_IDS:
    if suppress_tokens is not None and token_id in suppress_tokens:
        suppress_tokens.remove(token_id)

In [14]:
def predict_and_normalise_2(batch):
    """Transcribe a batch with the processor + model and add normalised texts.

    ``batch`` is a batched ``datasets`` mapping (column name -> list of values)
    whose ``"audio"`` entries carry ``"array"`` and ``"sampling_rate"``.
    Adds ``"ref"`` (normalised references from ``get_text``) and ``"pred"``
    (normalised decodings of ``model.generate``) and returns the batch.

    Raises:
        ValueError: if the clips in the batch do not share a single sampling rate.
    """
    audio_items = batch["audio"]
    audios = [audio["array"] for audio in audio_items]

    # Pass the clips' real sampling rate instead of assuming 16 kHz, so the feature
    # extractor can reject audio at an unexpected rate rather than silently computing
    # features on mis-sampled input.
    sampling_rates = {audio["sampling_rate"] for audio in audio_items}
    if len(sampling_rates) != 1:
        raise ValueError(
            f"Expected one sampling rate per batch, got {sorted(sampling_rates)}"
        )
    (sampling_rate,) = sampling_rates

    references = get_text(batch)
    input_features = processor(
        audios, sampling_rate=sampling_rate, return_tensors="pt"
    ).input_features

    with torch.no_grad():
        # Send inputs to wherever the model lives rather than a hard-coded device string.
        predicted_ids = model.generate(input_features.to(model.device))
    predictions = processor.batch_decode(predicted_ids, skip_special_tokens=True)

    batch["ref"] = [normalise(ref) for ref in references]
    batch["pred"] = [normalise(pred) for pred in predictions]
    return batch

In [15]:
# batch size for extracting references and predictions
batch_size = 8

# Lazily attach normalised "ref"/"pred" columns via the processor + model path;
# every original column of the dataset is dropped from the mapped output.
result_set_2 = dataset.map(
    function=predict_and_normalise_2,
    batched=True,
    batch_size=batch_size,
    remove_columns=list(dataset.features.keys()),
)

# keep only rows whose normalised reference is non-empty and not an ignore marker
result_set_2 = result_set_2.filter(function=is_target_text_in_range, input_columns=["ref"])

In [16]:
%%time

refs = []
preds = []

for i, sample in enumerate(result_set_2):
    refs.append(sample["ref"])
    preds.append(sample["pred"])

CPU times: user 18.5 s, sys: 364 ms, total: 18.8 s
Wall time: 4.55 s


In [17]:
# WER over all collected refs/preds, reported as a percentage rounded to 2 d.p.
wer = round(wer_metric.compute(references=refs, predictions=preds) * 100, 2)

print(f"WER:  {wer}")

WER:  3.95
