# **Fine-tuning Wav2Vec2 for English ASR with Huggingface Transformers**

In [1]:
!pip install -q datasets==1.18.3
!pip install -q transformers==4.17.0
!pip install -q jiwer

In [2]:
# from huggingface_hub import notebook_login

# notebook_login()

In [3]:
# %%capture
# !apt install git-lfs

## Prepare Data, Tokenizer, Feature Extractor

### Create Wav2Vec2CTCTokenizer

In [4]:
from datasets import load_dataset, load_metric

timit = load_dataset("timit_asr")



  0%|          | 0/2 [00:00<?, ?it/s]

In [5]:
timit

DatasetDict({
    train: Dataset({
        features: ['file', 'audio', 'text', 'phonetic_detail', 'word_detail', 'dialect_region', 'sentence_type', 'speaker_id', 'id'],
        num_rows: 4620
    })
    test: Dataset({
        features: ['file', 'audio', 'text', 'phonetic_detail', 'word_detail', 'dialect_region', 'sentence_type', 'speaker_id', 'id'],
        num_rows: 1680
    })
})

In [6]:
timit = timit.remove_columns(["phonetic_detail", "word_detail", "dialect_region", "speaker_id", "sentence_type", "id"])

Let's write a short function to display some random samples of the dataset and run it a couple of times to get a feeling for the transcriptions.

In [7]:
from datasets import ClassLabel
import random
import pandas as pd
from IPython.display import display, HTML

def show_random_elements(dataset, num_examples=10):
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    picks = []
    for _ in range(num_examples):
        pick = random.randint(0, len(dataset)-1)
        while pick in picks:
            pick = random.randint(0, len(dataset)-1)
        picks.append(pick)

    df = pd.DataFrame(dataset[picks])
    display(HTML(df.to_html()))

In [13]:
show_random_elements(timit["train"].remove_columns(["audio", "file"]), num_examples=10)

Unnamed: 0,text
0,She had your dark suit in greasy wash water all year.
1,Gently place Jim's foam sculpture in the box.
2,She had your dark suit in greasy wash water all year.
3,You're so preoccupied that you've let your faith grow dim.
4,She had your dark suit in greasy wash water all year.
5,A lawyer was appointed to execute her will.
6,What outfit does she drive for?
7,Forget we ever knew what?
8,The easygoing zoologist relaxed throughout the voyage.
9,She had your dark suit in greasy wash water all year.


In [14]:
import re
chars_to_ignore_regex = '[\,\?\.\!\-\;\:\"]'

def remove_special_characters(batch):
    batch["text"] = re.sub(chars_to_ignore_regex, '', batch["text"]).lower() + " "
    return batch

In [15]:
timit = timit.map(remove_special_characters)



0ex [00:00, ?ex/s]

0ex [00:00, ?ex/s]

In [16]:
show_random_elements(timit["train"].remove_columns(["audio", "file"]))

Unnamed: 0,text
0,this girl soon drops the bourgeoisie psychiatrist who disapproves of her life
1,she had your dark suit in greasy wash water all year
2,with her sharp tongue she'd have cut his pompousness to ribbons
3,some make beautiful chairs cabinets chests doll houses etc
4,a complete plan we have made limited application of the parallel ladder plan
5,the ward was a small one four beds kept reserved for female alcoholics
6,she had your dark suit in greasy wash water all year
7,lots of foreign movies have subtitles
8,the figure five is important in insurance
9,don't ask me to carry an oily rag like that


In [17]:
def extract_all_chars(batch):
  all_text = " ".join(batch["text"])
  vocab = list(set(all_text))
  return {"vocab": [vocab], "all_text": [all_text]}

In [18]:
vocabs = timit.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=timit.column_names["train"])

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

Now, we create the union of all distinct letters in the training dataset and test dataset and convert the resulting list into an enumerated dictionary.

In [19]:
vocab_list = list(set(vocabs["train"]["vocab"][0]) | set(vocabs["test"]["vocab"][0]))

In [20]:
vocab_dict = {v: k for k, v in enumerate(vocab_list)}
vocab_dict

{'m': 0,
 "'": 1,
 'q': 2,
 't': 3,
 'r': 4,
 'i': 5,
 'p': 6,
 'y': 7,
 'u': 8,
 'd': 9,
 'k': 10,
 's': 11,
 'o': 12,
 'x': 13,
 'n': 14,
 'b': 15,
 'g': 16,
 'j': 17,
 'e': 18,
 'f': 19,
 'w': 20,
 ' ': 21,
 'a': 22,
 'c': 23,
 'z': 24,
 'v': 25,
 'h': 26,
 'l': 27}

In [21]:
vocab_dict["|"] = vocab_dict[" "]
del vocab_dict[" "]

In [22]:
vocab_dict["[UNK]"] = len(vocab_dict)
vocab_dict["[PAD]"] = len(vocab_dict)
len(vocab_dict)

30

In [23]:
import json
with open('/content/vocab.json', 'w') as vocab_file:
    json.dump(vocab_dict, vocab_file)

In [24]:
from transformers import Wav2Vec2CTCTokenizer

tokenizer = Wav2Vec2CTCTokenizer("/content/vocab.json", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|")

### Create Wav2Vec2 Feature Extractor

In [25]:
from transformers import Wav2Vec2FeatureExtractor

feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=False)

In [26]:
from transformers import Wav2Vec2Processor

processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

### Preprocess Data



In [27]:
timit["train"][0]["file"]

'/root/.cache/huggingface/datasets/downloads/extracted/404950a46da14eac65eb4e2a8317b1372fb3971d980d91d5d5b221275b1fd7e0/data/TRAIN/DR4/MMDM0/SI681.WAV'

`Wav2Vec2` expects the input in the format of a 1-dimensional array of 16 kHz. This means that the audio file has to be loaded and resampled.

 Thankfully, `datasets` does this automatically when calling the column `audio`. Let try it out.

In [28]:
timit["train"][0]["audio"]

{'path': '/root/.cache/huggingface/datasets/downloads/extracted/404950a46da14eac65eb4e2a8317b1372fb3971d980d91d5d5b221275b1fd7e0/data/TRAIN/DR4/MMDM0/SI681.WAV',
 'array': array([-2.1362305e-04,  6.1035156e-05,  3.0517578e-05, ...,
        -3.0517578e-05, -9.1552734e-05, -6.1035156e-05], dtype=float32),
 'sampling_rate': 16000}

In [29]:
import IPython.display as ipd
import numpy as np
import random

rand_int = random.randint(0, len(timit["train"]))

print(timit["train"][rand_int]["text"])
ipd.Audio(data=np.asarray(timit["train"][rand_int]["audio"]["array"]), autoplay=True, rate=16000)

she had your dark suit in greasy wash water all year 


In [31]:
rand_int = random.randint(0, len(timit["train"]))

print("Target text:", timit["train"][rand_int]["text"])
print("Input array shape:", np.asarray(timit["train"][rand_int]["audio"]["array"]).shape)
print("Sampling rate:", timit["train"][rand_int]["audio"]["sampling_rate"])

Target text: those who teach values first abolish cheating 
Input array shape: (45671,)
Sampling rate: 16000


In [32]:
def prepare_dataset(batch):
    audio = batch["audio"]

    # batched output is "un-batched" to ensure mapping is correct
    batch["input_values"] = processor(audio["array"], sampling_rate=audio["sampling_rate"]).input_values[0]
    batch["input_length"] = len(batch["input_values"])

    with processor.as_target_processor():
        batch["labels"] = processor(batch["text"]).input_ids
    return batch

In [33]:
import numpy as np

np.object = object

In [34]:
timit = timit.map(prepare_dataset, remove_columns=timit.column_names["train"], num_proc=4)

  table = cls._concat_blocks(blocks, axis=0)


Long input sequences require a lot of memory. Since `Wav2Vec2` is based on `self-attention` the memory requirement scales quadratically with the input length for long input sequences (*cf.* with [this](https://www.reddit.com/r/MachineLearning/comments/genjvb/d_why_is_the_maximum_input_sequence_length_of/) reddit post). For this demo, let's filter all sequences that are longer than 4 seconds out of the training dataset.

In [35]:
max_input_length_in_sec = 4.0
timit["train"] = timit["train"].filter(lambda x: x < max_input_length_in_sec * processor.feature_extractor.sampling_rate, input_columns=["input_length"])

  0%|          | 0/5 [00:00<?, ?ba/s]

## Training & Evaluation

### Set-up Trainer


In [36]:
import torch

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union

@dataclass
class DataCollatorCTCWithPadding:
    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lenghts and need
        # different padding methods
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        batch = self.processor.pad(
            input_features,
            padding=self.padding,
            return_tensors="pt",
        )
        with self.processor.as_target_processor():
            labels_batch = self.processor.pad(
                label_features,
                padding=self.padding,
                return_tensors="pt",
            )

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        batch["labels"] = labels

        return batch

In [37]:
data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)

In [38]:
wer_metric = load_metric("wer")

In [39]:
def compute_metrics(pred):
    pred_logits = pred.predictions
    pred_ids = np.argmax(pred_logits, axis=-1)

    pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id

    pred_str = processor.batch_decode(pred_ids)
    # we do not want to group tokens when computing the metrics
    label_str = processor.batch_decode(pred.label_ids, group_tokens=False)

    wer = wer_metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}

In [40]:
from transformers import Wav2Vec2ForCTC

model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-base",
    ctc_loss_reduction="mean",
    pad_token_id=processor.tokenizer.pad_token_id,
)

Some weights of the model checkpoint at facebook/wav2vec2-base were not used when initializing Wav2Vec2ForCTC: ['quantizer.weight_proj.bias', 'quantizer.weight_proj.weight', 'project_q.weight', 'project_hid.bias', 'project_hid.weight', 'quantizer.codevectors', 'project_q.bias']
- This IS expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-base and are newly initialized: ['lm_head.weight', 'lm_head.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predicti

In [41]:
model.freeze_feature_encoder()

In [42]:
from transformers import TrainingArguments

training_args = TrainingArguments(
  output_dir="nothing",
  group_by_length=True,
  per_device_train_batch_size=8,
  evaluation_strategy="steps",
  num_train_epochs=1, # 30
  fp16=True,
  gradient_checkpointing=True,
  save_steps=500,
  eval_steps=500,
  logging_steps=500,
  learning_rate=1e-4,
  weight_decay=0.005,
  warmup_steps=1000,
  push_to_hub=False,
  save_total_limit=2,
)

In [43]:
from transformers import Trainer

trainer = Trainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=timit["train"],
    eval_dataset=timit["test"],
    tokenizer=processor.feature_extractor,
)

Using amp half precision backend


### Training

Depending on what GPU was allocated to your google colab it might be possible that you are seeing an `"out-of-memory"` error here. In this case, it's probably best to reduce `per_device_train_batch_size` to 16 or even less and eventually make use of [`gradient_accumulation`](https://huggingface.co/transformers/master/main_classes/trainer.html#trainingarguments).

In [44]:
import numpy as np

np.bool = np.bool_

In [45]:
trainer.train()

The following columns in the training set  don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length. If input_length are not expected by `Wav2Vec2ForCTC.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 3978
  Num Epochs = 1
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 498
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


Step,Training Loss,Validation Loss


KeyboardInterrupt: 

In [None]:
#trainer.push_to_hub()

In [46]:
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC

processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")


loading feature extractor configuration file https://huggingface.co/facebook/wav2vec2-base-960h/resolve/main/preprocessor_config.json from cache at /root/.cache/huggingface/transformers/07e398f6c4f4eb4f676c75befc5ace223491c79cea1109fb4029751892d380a1.bc3155ca0bae3a39fc37fc6d64829c6a765f46480894658bb21c08db6155358d
Feature extractor Wav2Vec2FeatureExtractor {
  "do_normalize": true,
  "feature_extractor_type": "Wav2Vec2FeatureExtractor",
  "feature_size": 1,
  "padding_side": "right",
  "padding_value": 0.0,
  "return_attention_mask": false,
  "sampling_rate": 16000
}

loading configuration file https://huggingface.co/facebook/wav2vec2-base-960h/resolve/main/config.json from cache at /root/.cache/huggingface/transformers/cbb3014bb9f03ead9b94f4a791ff8e777465307670e85079d35e28cbc5d88727.0e2d739358c9b58747bd19db5f9f4320dacabbeb1e6282f5cc1069c5c55a82d2
Model config Wav2Vec2Config {
  "_name_or_path": "facebook/wav2vec2-base-960h",
  "activation_dropout": 0.1,
  "adapter_kernel_size": 3,
  "

In [47]:
from datasets import load_dataset

timit = load_dataset("timit_asr")




  0%|          | 0/2 [00:00<?, ?it/s]

In [48]:
import torchaudio

def prepare_dataset(batch):
    # Resample the audio to 16000 Hz (as expected by Wav2Vec2)
    speech_array, sampling_rate = torchaudio.load(batch["file"])
    resampler = torchaudio.transforms.Resample(orig_freq=sampling_rate, new_freq=16000)
    batch["speech"] = resampler(speech_array).squeeze().numpy()
    # Tokenize the audio
    batch["input_values"] = processor(batch["speech"], sampling_rate=16000, return_tensors="pt", padding=True).input_values.squeeze()
    return batch

# Apply the function to preprocess the audio data
timit = timit.map(prepare_dataset)


0ex [00:00, ?ex/s]

0ex [00:00, ?ex/s]

In [49]:
# Move the model to the correct device (GPU if available)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

def predict(batch):
    # Move input to the correct device
    input_values = torch.tensor(batch["input_values"]).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    batch["transcription"] = processor.decode(predicted_ids[0])
    return batch


# Apply prediction to one example or more
result = timit["test"].select(range(5)).map(predict)
for res in result:
    print("Predicted transcription:", res["transcription"])
    print("Actual transcription:", res["text"])


0ex [00:00, ?ex/s]

Predicted transcription: THE BUNGALOW WAS PLEASANTLY SITUATED NEAR THE SHORE
Actual transcription: The bungalow was pleasantly situated near the shore.
Predicted transcription: DON'T ASK ME TO CARRY AN OILY RAG LIKE THAT
Actual transcription: Don't ask me to carry an oily rag like that.
Predicted transcription: ARE YOU LOOKING FOR EMPLOYMENT
Actual transcription: Are you looking for employment?
Predicted transcription: SHE HAD YOUR DARK SUIT AND GREASY WASHWATER ALL YEAR
Actual transcription: She had your dark suit in greasy wash water all year.
Predicted transcription: AT TWILIGHT ON THE TWELFTH DAY WE'LL HAVE CHABLI
Actual transcription: At twilight on the twelfth day we'll have Chablis.
