In this notebook, we're going to compare the transcription quality of the Wav2vec2 speech-to-text model to two others, Speech2Text2 and HuBERT.

# Setup

In [1]:
!pip install transformers==4.11.3 datasets==1.13.3 librosa==0.8.1 jiwer==2.2.0 torchaudio==0.9.1 sentencepiece==0.1.96

Collecting transformers==4.11.3
  Downloading transformers-4.11.3-py3-none-any.whl (2.9 MB)
[K     |████████████████████████████████| 2.9 MB 5.2 MB/s 
[?25hCollecting datasets==1.13.3
  Downloading datasets-1.13.3-py3-none-any.whl (287 kB)
[K     |████████████████████████████████| 287 kB 48.1 MB/s 
Collecting jiwer==2.2.0
  Downloading jiwer-2.2.0-py3-none-any.whl (13 kB)
Collecting torchaudio==0.9.1
  Downloading torchaudio-0.9.1-cp37-cp37m-manylinux1_x86_64.whl (1.9 MB)
[K     |████████████████████████████████| 1.9 MB 31.6 MB/s 
[?25hCollecting sentencepiece==0.1.96
  Downloading sentencepiece-0.1.96-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB)
[K     |████████████████████████████████| 1.2 MB 6.1 MB/s 
Collecting huggingface-hub>=0.0.17
  Downloading huggingface_hub-0.4.0-py3-none-any.whl (67 kB)
[K     |████████████████████████████████| 67 kB 2.0 MB/s 
Collecting tokenizers<0.11,>=0.10.1
  Downloading tokenizers-0.10.3-cp37-cp37m-manylinux_2_5_x86_64.man

In [2]:
from datasets import load_dataset, load_metric
from transformers import (Wav2Vec2Processor, 
                          Wav2Vec2ForCTC, 
                          Speech2TextProcessor, 
                          Speech2TextForConditionalGeneration, 
                          HubertForCTC)
import librosa
import torch

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Data

We'll evaluate each model on TIMIT.



In [4]:
timit = load_dataset("timit_asr", split="test")
print("TIMIT test set size: ", len(timit))

Downloading:   0%|          | 0.00/2.40k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.06k [00:00<?, ?B/s]

Downloading and preparing dataset timit_asr/clean (download: 828.75 MiB, generated: 7.90 MiB, post-processed: Unknown size, total: 836.65 MiB) to /root/.cache/huggingface/datasets/timit_asr/clean/2.0.1/5bebea6cd9df0fc2c8c871250de23293a94c1dc49324182b330b6759ae6718f8...


Downloading:   0%|          | 0.00/869M [00:00<?, ?B/s]

0 examples [00:00, ? examples/s]

0 examples [00:00, ? examples/s]

Dataset timit_asr downloaded and prepared to /root/.cache/huggingface/datasets/timit_asr/clean/2.0.1/5bebea6cd9df0fc2c8c871250de23293a94c1dc49324182b330b6759ae6718f8. Subsequent calls will reuse this data.
TIMIT test set size:  1680


We're using Huggingface Datasets to handle data loading, so we need a simple function load audio files in batches.

In [5]:
def process_data(batch):
    batch["speech"], batch["sampling_rate"] = librosa.load(batch["file"], sr=16000)    
    return batch

Each model needs the data in the same format, so we can run the preprocessing first.

In [6]:
timit = timit.map(process_data, remove_columns=['file', 'audio', 'phonetic_detail', 'word_detail', 'dialect_region', 'sentence_type', 'speaker_id', 'id'])

  0%|          | 0/1680 [00:00<?, ?ex/s]

In [7]:
timit

Dataset({
    features: ['text', 'speech', 'sampling_rate'],
    num_rows: 1680
})

In [8]:
timit["text"][:5]

['The bungalow was pleasantly situated near the shore.',
 "Don't ask me to carry an oily rag like that.",
 'Are you looking for employment?',
 'She had your dark suit in greasy wash water all year.',
 "At twilight on the twelfth day we'll have Chablis."]

# Evaluation

Automatic speech recognition models are generally evaluated using the word error rate (WER), and sometimes with the character error rate (CER), which is a character-level analog of WER

$WER = \frac{S + D + I}{N}$

$CER = \frac{S + D + I}{N}$

- S: number of substitutions
- D: number of deletions
- I: number of insertions,
- N: number of words in the reference text

Note that for CER, $N$ is the number of characters in the reference text.

Huggingface Datasets has both metrics built-in.


In [9]:
wer = load_metric("wer")
cer = load_metric("cer")

Downloading:   0%|          | 0.00/1.95k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.91k [00:00<?, ?B/s]

# Models

## Wav2vec2

In [10]:
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-960h")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h").to(device)

Downloading:   0%|          | 0.00/159 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/291 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/163 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/85.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/843 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.18G [00:00<?, ?B/s]

Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-large-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [11]:
import re

def remove_punc(text):
    text = str(text)
#     text = text.lower()
    text = text.replace('.', '')
    text = text.replace(',', '')
    text = text.replace('?', '')
    text = text.replace(';', '')
    text = text.replace('!', '')    
    return text

def wav2vec2_predict(batch):
    features = processor(
        batch["speech"],
        sampling_rate=batch["sampling_rate"][0],
        padding=True,
        return_tensors="pt")

    input_values = features["input_values"].to(device)

    with torch.no_grad():
        logits = model(input_values).logits

    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.batch_decode(predicted_ids)
    batch["transcription"] = transcription    
    # Wav2vec 2's base model doesn't produce punctuation and uppercases text
    batch["target"] = [remove_punc(x.upper()) for x in batch["text"]]
    return batch

In [12]:
BATCH_SIZE = 16
result = timit.map(wav2vec2_predict, 
                   batched=True, 
                   batch_size=BATCH_SIZE, 
                   remove_columns=["speech", "sampling_rate"])

print("WER: ", wer.compute(predictions=result["transcription"], 
                           references=result["target"]))
print("CER: ", cer.compute(predictions=result["transcription"], 
                           references=result["target"]))

  0%|          | 0/105 [00:00<?, ?ba/s]

WER:  0.09291913486706158
CER:  0.022128702861186938


## Speech2Text2

Speech2Text2 is a transformer decoder model that can be used with _any_ speech encoder, such as Wav2Vec2 or HuBERT.

We'll use a Wav2vec2 variant, so that we can compare it to Wav2vec2 itself.



In [13]:
s2t_processor = Speech2TextProcessor.from_pretrained("facebook/s2t-large-librispeech-asr")
s2t_model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-large-librispeech-asr").to(device)

Downloading:   0%|          | 0.00/242 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/230k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/407k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/85.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/457 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.11k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.00G [00:00<?, ?B/s]

In [14]:
def remove_punc(text):
    text = str(text)
    text = text.lower()
    text = text.replace('.', '')
    text = text.replace(',', '')
    text = text.replace('?', '')
    text = text.replace(';', '')
    text = text.replace('!', '')    
    return text


def s2t_predict(batch):
    features = s2t_processor(
        batch["speech"],
        sampling_rate=batch["sampling_rate"][0],
        padding=True,
        return_tensors="pt")

    input_features = features["input_features"].to(device)
    # including the attention mask is important for this model
    # if it is omitted, then the model may generate transcription 
    # that is noticably longer than the target
    attention_mask = features["attention_mask"].to(device)
    
    with torch.no_grad():
        generated_ids = s2t_model.generate(input_ids=input_features,
                                           attention_mask=attention_mask)

    batch["transcription"] = s2t_processor.batch_decode(generated_ids, 
                                                        skip_special_tokens=True)    
    # Speech2Text2 model doesn't produce punctuation and lowercases text
    batch["target"] = [remove_punc(x) for x in batch["text"]]
    return batch

In [15]:
# s2t_predict(timit[:2])["transcription"]

In [16]:
BATCH_SIZE = 16
result = timit.map(s2t_predict, 
                   batched=True, 
                   batch_size=BATCH_SIZE, 
                   remove_columns=["speech", "sampling_rate"])

print("WER: ", wer.compute(predictions=result["transcription"], 
                           references=result["target"]))
print("CER: ", cer.compute(predictions=result["transcription"], 
                           references=result["target"]))

  0%|          | 0/105 [00:00<?, ?ba/s]

To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor'). (Triggered internally at  ../aten/src/ATen/native/BinaryOps.cpp:467.)
  return torch.floor_divide(self, other)


WER:  0.10283785645405703
CER:  0.03624507127691841


## HuBERT



In [17]:
hb_processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft")
hb_model = HubertForCTC.from_pretrained("facebook/hubert-large-ls960-ft").to(device)

Downloading:   0%|          | 0.00/212 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/291 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/138 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/85.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.34k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.18G [00:00<?, ?B/s]

In [18]:
def remove_punc(text):    
    text = text.upper()
    text = text.replace('.', '')
    text = text.replace(',', '')
    text = text.replace('?', '')
    text = text.replace(';', '')
    text = text.replace('!', '')    
    return text


def hb_predict(batch):
    features = hb_processor(
        batch["speech"],
        sampling_rate=batch["sampling_rate"][0],
        padding=True,
        return_tensors="pt")

    input_values = features["input_values"].to(device)
    with torch.no_grad():
        logits = model(input_values).logits

    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.batch_decode(predicted_ids)
    batch["transcription"] = transcription    
    # HuBERT doesn't produce punctuation and uppercases text
    batch["target"] = [remove_punc(x) for x in batch["text"]]    
    return batch

In [19]:
hb_predict(timit[:2])["transcription"]

['THE BUNGALOW WAS PLEASANTLY SITUATED NEAR THE SHORE',
 "DON'T ASK ME TO CARRY AN OILY RAG LIKE THAT"]

In [20]:
BATCH_SIZE = 16
result = timit.map(hb_predict, 
                   batched=True, 
                   batch_size=BATCH_SIZE, 
                   remove_columns=["speech", "sampling_rate"])

print("WER: ", wer.compute(predictions=result["transcription"], 
                           references=result["target"]))
print("CER: ", cer.compute(predictions=result["transcription"], 
                           references=result["target"]))

  0%|          | 0/105 [00:00<?, ?ba/s]

WER:  0.09298801487808238
CER:  0.022128702861186938
