#### To install KenLM:

`pip install https://github.com/kpu/kenlm/archive/master.zip`

In [36]:
import pandas as pd
import numpy as np
import os
import kenlm
from transformers import Wav2Vec2ProcessorWithLM, Wav2Vec2ForCTC, Wav2Vec2Tokenizer, Wav2Vec2Processor, AutoModelForCTC, AutoProcessor
from datasets import load_dataset
from jiwer import wer
import librosa
import nltk
import tarfile
import torch
import urllib.request
import soundfile as sf

nltk.download('punkt')

[nltk_data] Downloading package punkt to /Users/juliawang/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [21]:
# set paths
datasets_path = os.path.join(os.getcwd(), 'datasets') 
# create folders if they do not already exist
if not os.path.exists(datasets_path): os.makedirs(datasets_path)
# set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [22]:
def download_and_extract_dataset_from_url(url: str, datasets_path: str = datasets_path):
    """
    downloads and extracts dataset from url into datasets_path/
    """
    if not os.path.exists('./datasets/LibriSpeech'):
        print('downloading test dataset (300mb)')
        urllib.request.urlretrieve("https://www.openslr.org/resources/12/test-clean.tar.gz", "test.tar.gz")

        print('extracting data')
        file = tarfile.open('test.tar.gz')
        file.extractall('./data')
        file.close()

In [51]:
# load extracted lr data as dataset
librispeech_eval = load_dataset("datasets/LibriSpeech", "clean", split='test[:5%]')

Resolving data files:   0%|          | 0/2707 [00:00<?, ?it/s]

Using custom data configuration LibriSpeech-ba5746f5e657f0cb
Found cached dataset audiofolder (/Users/juliawang/.cache/huggingface/datasets/audiofolder/LibriSpeech-ba5746f5e657f0cb/0.0.0/d21214990bdb6d1fa3e71fd8d9083f8303b0c0ca6911ad366f2e988039ac58c5)


In [52]:
def map_to_ground_truth(batch):
    """
    inserts ground truth in dataset
    """
    transcription_file_path = batch['audio']['path'][:-10] + '.trans.txt'
    f = open(transcription_file_path, 'r')
    lines= str.splitlines(f.read())
    txt=lines[int(batch['audio']['path'][-7:-5])].split(' ', 1)[1]
    batch['txt'] = txt
    return batch

In [53]:
librispeech_eval = librispeech_eval.map(map_to_ground_truth)

  0%|          | 0/131 [00:00<?, ?ex/s]

In [54]:
def load_wav2vec_model(process_path: str):
    """
    load and return wav2vec tokenizer and model from huggingface
    """
    model = AutoModelForCTC.from_pretrained(process_path)
    processor = AutoProcessor.from_pretrained(process_path)
    return processor, model

In [41]:
processor, model = load_wav2vec_model("patrickvonplaten/wav2vec2-base-960h-4-gram")


Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at patrickvonplaten/wav2vec2-base-960h-4-gram and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

In [48]:
def map_to_pred(batch):
    """
    predicts transcription
    """
    #tokenize
    input_values = processor(batch["audio"]["array"], return_tensors="pt")
    #take logits
    with torch.no_grad():
        logits = model(**input_values).logits

    transcription = processor.batch_decode(logits.cpu().numpy()).text[0]
    batch["transcription"] = transcription
    return batch
    

In [55]:
result = librispeech_eval.map(map_to_pred)

  0%|          | 0/131 [00:00<?, ?ex/s]

It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.
It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.
It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.
It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.
It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.
It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.
It is strongly recommended to pass the ``sampling_ra

It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.
It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.
It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.
It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.
It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.
It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.
It is strongly recommended to pass the ``sampling_ra

It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.
It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.
It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.
It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.
It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.
It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.
It is strongly recommended to pass the ``sampling_ra

In [56]:
print('WER with wav2vec2-base-960h-4-gram on lr-test-clean:', round(100 * wer(result["txt"], result["transcription"]), 1), '%.')

WER with wav2vec2-base-960h-4-gram on lr-test-clean: 2.4 %.
