In [36]:
import os
import numpy as np


import torch
import json
import torchaudio
import sys

from torch.utils.data import DataLoader
from evaluate import load
from datasets import Audio

from model import Whisper


DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
from transformers import WhisperProcessor

from operator import attrgetter

wer_metric = load("wer")

In [2]:
from types import SimpleNamespace

with open('whisper-small-config.json', 'r') as f:
    dims = json.load(f)
dims = SimpleNamespace(**{
    "n_mels": dims['num_mel_bins'],
    "n_audio_ctx": dims['max_source_positions'],
    "n_audio_state": dims['d_model'],
    "n_audio_head": dims['encoder_attention_heads'],
    "n_audio_layer": dims['encoder_layers'],
    "n_vocab": dims['vocab_size'],
    "n_text_ctx": dims['max_target_positions'],
    "n_text_state": dims['d_model'],
    "n_text_head": dims['decoder_attention_heads'],
    "n_text_layer": dims['decoder_layers']
})
model = Whisper(dims)
model.load_state_dict(torch.load("../whisper-small.pt"))


<All keys matched successfully>

In [37]:
from datasets import load_dataset

dataset_hf =load_dataset("ylacombe/english_dialects", "scottish_male")
print(dataset_hf)
dataset_hf['train'] = dataset_hf['train'].cast_column("audio", Audio(sampling_rate=16000))
print(dataset_hf['train'][0])

DatasetDict({
    train: Dataset({
        features: ['line_id', 'audio', 'text', 'speaker_id'],
        num_rows: 1649
    })
})
{'line_id': 'EN0003', 'audio': {'path': 'scm_04310_00656117725.wav', 'array': array([-8.02338036e-06,  6.05384776e-05,  4.38706120e-05, ...,
       -3.86700849e-04, -3.47789202e-04, -2.66398594e-04]), 'sampling_rate': 16000}, 'text': 'These take the shape of a long round arch with its path high above and its two ends apparently beyond the horizon', 'speaker_id': 4310}


In [62]:
class Dataset(torch.utils.data.Dataset):
    """
    A simple class to wrap LibriSpeech and trim/pad the audio to 30 seconds.
    It will drop the last few seconds of a very small portion of the utterances.
    """
    def __init__(self, dataset, device=DEVICE, padding_token_id=-100):
        self.dataset = dataset
        self.processor = WhisperProcessor.from_pretrained("openai/whisper-small", language="English", task="transcribe")
        self.device = device
        self.feature_extractor = self.processor.feature_extractor
        self.tokenizer = self.processor.tokenizer
        self.padding_token_id = padding_token_id
    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        data = self.dataset[idx]
        audio = data['audio']
        audio_array = np.array(audio['array'])
        print(audio_array.shape)
        text = data['text']
        sample_rate = audio['sampling_rate']
        assert sample_rate == 16000
        max_length = sample_rate * 30
        if audio_array.shape[0] < max_length:
            # Pad with zeros
            padded_audio = np.zeros(max_length)
            padded_audio[:audio_array.shape[0]] = audio_array
            audio_array = padded_audio
        else:
            # Truncate ones over 30 seconds - need to fix this
            audio_array = audio_array[:, :max_length]
        print(audio_array.shape)
        
        mel = self.feature_extractor(audio_array.flatten(), sampling_rate=sample_rate)
        
        
        text = self.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=400)
        labels = text.input_ids
        labels.masked_fill_(text.attention_mask.eq(0), self.padding_token_id)
            
        
        return (mel.input_features.squeeze(0), labels.squeeze(0), data['text'])

batch_size=4
dataset = Dataset(dataset_hf['train'])
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)


subset = torch.utils.data.Subset(dataset, range(100))
loaderSubset = DataLoader(subset, batch_size=batch_size, shuffle=True)

In [66]:

model.eval()
all_predictions = []
all_references = []
wer = 0
with torch.no_grad():
    for _, (mel, text, original_text) in enumerate(loader):
        print(original_text)
        mel = mel.to(DEVICE)
        text = text.to(DEVICE)
        # Get model predictions
        print(mel.shape)
        # mel = mel.to(device)
        outputs = model.decode(mel)
        # first, lang, probs, tokens = outputs
        
        batch_predictions = list(map(attrgetter('text'), outputs))  # More efficient than list comprehension
        # print(first.text, lang.text, probs, tokens)
        # Decode predictions and reference text
        # pred_text = dataset.processor.batch_decode(outputs, skip_special_tokens=True)
        # ref_text = dataset.processor.batch_decode(text, skip_special_tokens=True)
        print(original_text)
        print(batch_predictions)
        batch_wer = wer_metric.compute(predictions=batch_predictions, references=original_text)
        print(batch_wer)
        wer += batch_wer
        # Collect predictions and references
        # all_predictions.extend(batch_predictions)
        # all_references.extend(original_text)
        
        # Print batch results for debugging
        # print("Batch Predictions:", pred_text)
        # print("Batch References:", ref_text)

# Calculate total WER
print(f"wer TOTAL rate: {wer}")
print(f"Word Error Rate: {wer/len(loader):.4f}")

(69632,)
(480000,)
(154283,)
(480000,)
(77824,)
(480000,)
(96939,)
(480000,)
('Peugeot is a French automotive manufacturer', 'The best goose feathers for quill pens were the primary feathers of the left wing whose curvature bent away from the eyes of right-handed writers', "This is the cinematic superhero showdown you've dreamt of since childhood", 'Nonstop flights from Pretoria to Medellin are about ten hours and four minutes long')
torch.Size([4, 80, 3000])
('Peugeot is a French automotive manufacturer', 'The best goose feathers for quill pens were the primary feathers of the left wing whose curvature bent away from the eyes of right-handed writers', "This is the cinematic superhero showdown you've dreamt of since childhood", 'Nonstop flights from Pretoria to Medellin are about ten hours and four minutes long')
['Peugeot is a French automotive manufacturer.', 'The best goose feathers for quill pens were the primary feathers of the left wing, whose curvature bent away from the eyes of

KeyboardInterrupt: 