In [1]:
from datasets import load_dataset
from transformers import WhisperForConditionalGeneration,WhisperProcessor,WhisperTokenizer
import soundfile as sf
import torchaudio
import torch
from jiwer import wer
import os
import sys



In [2]:
models = ["openai/whisper-large-v2","openai/whisper-medium.en","openai/whisper-small.en"]
current_model = models[0]

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
model = WhisperForConditionalGeneration.from_pretrained(current_model)
processor = WhisperProcessor.from_pretrained(current_model)

In [5]:
tokenizer = WhisperTokenizer.from_pretrained(current_model)

In [6]:
dataset = load_dataset("librispeech_asr","clean")

Found cached dataset librispeech_asr (D:/AI/HugginFace/datasets/librispeech_asr/clean/2.1.0/cff5df6e7955c80a67f80e27e7e655de71c689e2d2364bece785b972acb37fe7)


  0%|          | 0/4 [00:00<?, ?it/s]

In [39]:
dataset['test']["text"][0]

'CONCORD RETURNED TO ITS PLACE AMIDST THE TENTS'

In [7]:
x = processor(dataset["test"]['audio'][0]['array'], sampling_rate=16_000, return_tensors="pt").input_features

In [8]:
x.shape

torch.Size([1, 80, 3000])

In [16]:
out = model(x,decoder_input_ids = torch.tensor([[50258]]))

In [41]:
processor.batch_decode(out)[0]

'<|startoftranscript|><|en|><|transcribe|><|notimestamps|> Concord returned to its place amidst the tents.<|endoftext|>'

In [7]:
decoder_input_ids = tokenizer.encode("<|startoftranscript|><|en|><|transcribe|><|notimestamps|>",return_tensors="pt").to(device)

In [30]:
generated = model.generate(x,decoder_input_ids=decoder_input_ids,max_length=100,do_sample=True,top_k=50,top_p=0.95,temperature=0.9)

In [34]:
tokenizer.decode(generated[0],skip_special_tokens=True)

' Concord returned to its place amidst the tents.'

In [43]:
score = wer(tokenizer.decode(generated[0],skip_special_tokens=True).upper()[:-1],dataset['test']["text"][0])
print(tokenizer.decode(generated[0],skip_special_tokens=True).upper()[:-1],dataset['test']["text"][0],score)

 CONCORD RETURNED TO ITS PLACE AMIDST THE TENTS CONCORD RETURNED TO ITS PLACE AMIDST THE TENTS 0.0


In [11]:
test_dataset = dataset["test"]

In [22]:
test_dataset

Dataset({
    features: ['file', 'audio', 'text', 'speaker_id', 'chapter_id', 'id'],
    num_rows: 2620
})

In [9]:
from tqdm import tqdm

In [8]:
def get_transcription(array):
    x = processor(array, sampling_rate=16_000, return_tensors="pt").input_features.to(device)
    with torch.no_grad():
        generated = model.generate(x,decoder_input_ids=decoder_input_ids,max_length=100)
    generated = processor.tokenizer.batch_decode(generated,skip_special_tokens=True)
    
    return generated



In [25]:
test_dataset_subset = test_dataset.select(range(500))


In [13]:
model.to(device)

WhisperForConditionalGeneration(
  (model): WhisperModel(
    (encoder): WhisperEncoder(
      (conv1): Conv1d(80, 1280, kernel_size=(3,), stride=(1,), padding=(1,))
      (conv2): Conv1d(1280, 1280, kernel_size=(3,), stride=(2,), padding=(1,))
      (embed_positions): Embedding(1500, 1280)
      (layers): ModuleList(
        (0): WhisperEncoderLayer(
          (self_attn): WhisperAttention(
            (k_proj): Linear(in_features=1280, out_features=1280, bias=False)
            (v_proj): Linear(in_features=1280, out_features=1280, bias=True)
            (q_proj): Linear(in_features=1280, out_features=1280, bias=True)
            (out_proj): Linear(in_features=1280, out_features=1280, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=1280, out_features=5120, bias=True)
          (fc2): Linear(in_features=5120, out_features=1280, bias=True)
     

In [26]:
results = []
for i in tqdm(range(len(test_dataset_subset))):
    batch = test_dataset['audio'][i]['array']
    text = test_dataset['text'][i]
    transcription = get_transcription(batch)

    results.append((text,transcription))

100%|██████████| 500/500 [1:13:41<00:00,  8.84s/it]


In [27]:
def normalize(x):
    x = list(x)
    x[1] = x[1][0].upper()
    x[1] = x[1][:-1]
    return x

results_copy = results
results = list(map(normalize,results))

In [30]:
wer([r[0] for r in results],[r[1].replace(',',"").replace('?',"") for r in results])

0.05652496751438649

In [37]:
punctuation = [',','.','?','!',':',';']
un_punctuated = [r[1].replace(',',"").replace('?',"").replace('.',"").replace('!',"") for r in results]

In [38]:
wer([r[0] for r in results],un_punctuated)

0.04241692964544273