ASR implementation using pretrained wav2vec2


# Fine-tuning Wav2Vec2 with an LM head

In this notebook, we will load the pre-trained wav2vec2 model from [TFHub](https://tfhub.dev) and fine-tune it on the LJSpeech dataset by appending a language modeling (LM) head on top of the pre-trained model.

## Setting Up


In [None]:
!pip3 install -q git+https://github.com/vasudevgupta7/gsoc-wav2vec2@main

In [1]:
import os

import tensorflow as tf
import tensorflow_hub as hub
from wav2vec2 import Wav2Vec2Config

config = Wav2Vec2Config()

print("TF version:", tf.__version__)

TF version: 2.6.1


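Next, we download the pre-trained model from TFHub. Setting `trainable=True` makes its weights trainable, so they are fine-tuned together with the new head.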


In [2]:
pretrained_layer = hub.KerasLayer("https://tfhub.dev/vasudevgupta7/wav2vec2/1", trainable=True)

In [28]:
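# Setting hyperparameters

AUDIO_MAXLEN = 246000  # input length in samples (~15.4 s of audio at 16 kHz)

LABEL_MAXLEN = 256  # maximum number of tokens per transcription

BATCH_SIZE = 128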

In the following cell, we wrap `pretrained_layer` and a dense layer (the LM head) into a single model using the [Keras Functional API](https://www.tensorflow.org/guide/keras/functional).

In [4]:
inputs = tf.keras.Input(shape=(AUDIO_MAXLEN,))
hidden_states = pretrained_layer(inputs)
outputs = tf.keras.layers.Dense(config.vocab_size)(hidden_states)

model = tf.keras.Model(inputs=inputs, outputs=outputs)

The dense layer defined above has an output dimension of `vocab_size` because we want to predict a score (logit) for each token in the vocabulary at each time step.
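The Dense layer has no activation, so it emits unnormalized logits. The NumPy sketch below is not part of the original notebook. It treats a random array as hypothetical model logits of shape `(batch, time_steps, vocab_size)` and shows two things:

- A softmax over the last axis turns each time step's scores into a probability distribution.
- Softmax preserves ordering, so the per-step `argmax` of the probabilities equals the `argmax` of the raw logits.

```python
import numpy as np


def softmax(logits, axis=-1):
    """Numerically stable softmax along `axis`."""
    shifted = logits - logits.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)


# Hypothetical logits laid out like the Dense layer's output:
# (batch, time_steps, vocab_size) = (2, 768, 32), matching the model summary.
rng = np.random.default_rng(0)
logits = rng.normal(size=(2, 768, 32))

probs = softmax(logits)            # one probability distribution per time step
token_ids = probs.argmax(axis=-1)  # most likely token at each time step
print(probs.shape, token_ids.shape)  # (2, 768, 32) (2, 768)
```

Because the `argmax` is unchanged by the softmax, taking `tf.argmax` directly on the model's logits (as the evaluation code does) is enough for greedy decoding.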

## Setting up training state

For subclassed Keras models, weights are created lazily, the first time `model.call` or `model.build` runs. With the Functional API, the layers are already built when they are called on the symbolic `Input`, so the forward pass in the following cell mainly serves as a sanity check of the output shape. We then run `model.summary()` to check the total number of trainable parameters.

In [5]:
model(tf.random.uniform(shape=(BATCH_SIZE, AUDIO_MAXLEN)))
model.summary()

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 246000)]          0         
_________________________________________________________________
keras_layer (KerasLayer)     (None, 768, 768)          94371712  
_________________________________________________________________
dense (Dense)                (None, 768, 32)           24608     
Total params: 94,396,320
Trainable params: 94,396,320
Non-trainable params: 0
_________________________________________________________________
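The `(None, 768, 768)` output shape means that the 246,000 input samples are mapped to 768 time steps, each represented by a 768-dimensional hidden state. The sketch below reproduces the 768 time steps from the convolutional feature encoder.

It assumes the kernel sizes and strides given in the wav2vec 2.0 paper; these values are not read from the TFHub model. The total stride is 5 × 2⁶ = 320 samples, which is one frame every 20 ms at 16 kHz. For CTC, the frame sequence must not be shorter than the label sequence, so `LABEL_MAXLEN = 256` fits comfortably.

```python
# Assumed wav2vec 2.0 feature-encoder configuration (7 temporal conv layers,
# no padding); taken from the paper, not inspected from the TFHub layer.
CONV_KERNELS = (10, 3, 3, 3, 3, 2, 2)
CONV_STRIDES = (5, 2, 2, 2, 2, 2, 2)


def num_output_frames(num_samples, kernels=CONV_KERNELS, strides=CONV_STRIDES):
    """Sequence length after a stack of unpadded 1-D convolutions."""
    length = num_samples
    for kernel, stride in zip(kernels, strides):
        length = (length - kernel) // stride + 1
    return length


AUDIO_MAXLEN = 246000
print(num_output_frames(AUDIO_MAXLEN))  # 768
print(AUDIO_MAXLEN / 16000)             # 15.375 (seconds of audio at 16 kHz)
```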


Next, we define the loss function (`loss_fn`) and the optimizer needed to train the model. We use the `Adam` optimizer for simplicity. For the loss, we use Connectionist Temporal Classification (CTC). CTC is a common choice for tasks such as ASR, where the alignment between input frames and output tokens is unknown: it sums the probability over all possible alignments, so no frame-level labels are needed.

In [6]:
from wav2vec2 import CTCLoss

LEARNING_RATE = 5e-5

loss_fn = CTCLoss(config, (BATCH_SIZE, AUDIO_MAXLEN), division_factor=BATCH_SIZE)
optimizer = tf.keras.optimizers.Adam(LEARNING_RATE)
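CTC sums the probability of every frame-level path that collapses to the target transcription. A path collapses by merging repeated symbols and then dropping blanks. The NumPy sketch below is a textbook reference implementation of that sum via the forward (alpha) recursion, checked against brute-force enumeration on a tiny input.

It is not the `wav2vec2.CTCLoss` implementation, and it makes two simplifications:

- It assumes token id 0 is the blank.
- It works in probability space for readability. Practical implementations work in log space to avoid underflow on long inputs such as 768 frames.

```python
import itertools

import numpy as np


def ctc_nll(probs, labels, blank=0):
    """Negative log-likelihood of `labels` under CTC (forward/alpha recursion).

    probs:  array of shape (T, V); row t is a distribution over the vocabulary.
    labels: target token ids without blanks, e.g. [1, 2, 2].
    """
    ext = [blank]
    for tok in labels:
        ext += [tok, blank]  # blank-interleaved target of length 2 * len(labels) + 1
    T, S = probs.shape[0], len(ext)
    alpha = np.zeros((T, S))
    alpha[0, 0] = probs[0, ext[0]]          # start with a blank ...
    if S > 1:
        alpha[0, 1] = probs[0, ext[1]]      # ... or with the first label
    for t in range(1, T):
        for s in range(S):
            total = alpha[t - 1, s]                 # stay on the same symbol
            if s >= 1:
                total += alpha[t - 1, s - 1]        # advance by one position
            if s >= 2 and ext[s] != blank and ext[s] != ext[s - 2]:
                total += alpha[t - 1, s - 2]        # skip a blank between different labels
            alpha[t, s] = total * probs[t, ext[s]]
    # A valid path ends on the last label or on the trailing blank.
    p = alpha[T - 1, S - 1] + (alpha[T - 1, S - 2] if S > 1 else 0.0)
    return -np.log(p)


def collapse(path, blank=0):
    """CTC many-to-one map: merge repeats, then remove blanks."""
    out, prev = [], None
    for tok in path:
        if tok != prev and tok != blank:
            out.append(tok)
        prev = tok
    return out


def brute_force_nll(probs, labels, blank=0):
    """Sum the probability of every path that collapses to `labels` (tiny inputs only)."""
    T, V = probs.shape
    p = sum(
        np.prod([probs[t, k] for t, k in enumerate(path)])
        for path in itertools.product(range(V), repeat=T)
        if collapse(path, blank) == list(labels)
    )
    return -np.log(p)


rng = np.random.default_rng(0)
logits = rng.normal(size=(5, 4))  # T=5 frames, V=4 symbols (id 0 = blank)
probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
print(ctc_nll(probs, [1, 2, 2]), brute_force_nll(probs, [1, 2, 2]))  # the two values agree
```

Note that the repeated label `2, 2` needs a blank between its copies, so at least 4 frames are required for this 3-token target.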

## Loading & Pre-processing data

Let's now download the LJSpeech dataset from the [official website](https://keithito.com/LJ-Speech-Dataset/) and set it up.

In [7]:
from tensorflow import keras
import os
from glob import glob

In [None]:
keras.utils.get_file(
    os.path.join(os.getcwd(), "data.tar.gz"),
    "https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2",
    extract=True,
    archive_format="tar",
    cache_dir=".",
    )

In [8]:
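saveto = "./datasets/LJSpeech-1.1"
wavs = glob(f"{saveto}/**/*.wav", recursive=True)
# metadata.csv rows: <clip id>|<transcription>|<normalized transcription>
with open(os.path.join(saveto, "metadata.csv"), encoding="utf-8") as f:
    rows = [line.strip().split("|") for line in f]
id_to_text = {row[0]: row[2] for row in rows}  # keep the normalized transcription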


In [9]:
len(id_to_text)

13100
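Each line of `metadata.csv` holds three `|`-separated fields:

1. the clip ID;
2. the raw transcription;
3. the normalized transcription, with numbers and similar tokens spelled out.

The code above keeps the third field. An equivalent parser built on Python's `csv` module is sketched below. `parse_ljspeech_metadata` is a hypothetical helper and the rows are made up. Passing `quoting=csv.QUOTE_NONE` matters because the default dialect would treat a field that starts with `"` as a quoted field and drop its quotation marks.

```python
import csv
import io


def parse_ljspeech_metadata(f):
    """Map clip id -> normalized transcription (third '|'-separated field)."""
    # QUOTE_NONE keeps quotation marks inside transcriptions as literal characters.
    reader = csv.reader(f, delimiter="|", quoting=csv.QUOTE_NONE)
    return {row[0]: row[2] for row in reader if len(row) >= 3}


# Made-up rows in the LJSpeech layout: id | transcription | normalized transcription
demo = io.StringIO(
    "LJ999-0001|It cost $5.|It cost five dollars.\n"
    'LJ999-0002|"Quoted," he said.|"Quoted," he said.\n'
)
demo_id_to_text = parse_ljspeech_metadata(demo)
print(demo_id_to_text["LJ999-0001"])  # It cost five dollars.
print(demo_id_to_text["LJ999-0002"])  # "Quoted," he said.
```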

In [11]:
!pip install librosa


In [17]:
import librosa

def read_audio(file_path):
    # LJSpeech is recorded at 22.05 kHz; resample to the 16 kHz rate wav2vec2 expects
    y, _ = librosa.load(file_path, sr=16000)
    return y

In [15]:
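def get_data(wavs, id_to_text, maxlen=50):
    """Return (waveform, transcription) pairs for transcriptions shorter than `maxlen` characters."""
    data = []
    for w in wavs:
        # portable replacement for w.split("\\"), which only worked with Windows paths
        clip_id = os.path.splitext(os.path.basename(w))[0]
        if len(id_to_text[clip_id]) < maxlen:
            data.append((read_audio(w), id_to_text[clip_id]))
    return data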

In [18]:
samples = get_data(wavs,id_to_text)

In [None]:
samples[:5]

Each element of `samples` is a `(waveform, transcription)` pair. Before feeding these pairs to the model, we need to convert the transcriptions into token IDs and normalize the waveforms. We use `Wav2Vec2Processor` for both:

In [19]:
from wav2vec2 import Wav2Vec2Processor
tokenizer = Wav2Vec2Processor(is_tokenizer=True)   # text <-> token IDs
processor = Wav2Vec2Processor(is_tokenizer=False)  # speech normalization

def preprocess_text(text):
    label = tokenizer(text)  # token IDs for the transcription
    return tf.constant(label, dtype=tf.int32)

def preprocess_speech(audio):
    audio = tf.constant(audio, dtype=tf.float32)
    return processor(tf.transpose(audio))

Downloading `vocab.json` from https://github.com/vasudevgupta7/gsoc-wav2vec2/raw/main/data/vocab.json ... DONE
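This notebook describes `processor` as normalizing the speech. We assume this is the usual wav2vec2 input normalization: zero mean and unit variance per waveform, with a small epsilon for numerical stability. We have not checked the exact epsilon used by `gsoc-wav2vec2`. Under that assumption, the operation amounts to the NumPy sketch below. `normalize_waveform` is a hypothetical stand-in, not the library's function.

```python
import numpy as np


def normalize_waveform(audio, eps=1e-5):
    """Zero-mean, unit-variance normalization of a 1-D waveform (assumed behavior)."""
    audio = np.asarray(audio, dtype=np.float32)
    return (audio - audio.mean()) / np.sqrt(audio.var() + eps)


# A hypothetical 1-second, 440 Hz tone sampled at 16 kHz, with a DC offset.
t = np.arange(16000) / 16000
wave = 0.3 * np.sin(2 * np.pi * 440 * t) + 0.1
normed = normalize_waveform(wave)
print(float(normed.mean()), float(normed.std()))  # approximately 0 and 1
```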


Next, we define a Python generator that applies the preprocessing functions above to each `(speech, text)` pair.

In [20]:
def inputs_generator():
    for speech, text in samples:
        yield preprocess_speech(speech), preprocess_text(text)

## Setting up `tf.data.Dataset`

The following cell sets up a `tf.data.Dataset` object using its `.from_generator(...)` method, fed by the `inputs_generator` defined above.

**Note:** For distributed training (especially on TPUs), `.from_generator(...)` currently doesn't work. Instead, it is recommended to train on data stored in the `.tfrecord` format. Ideally, the TFRecords should be stored in a GCS bucket so that the TPUs can be used to their full extent.

You can refer to [this script](https://github.com/vasudevgupta7/gsoc-wav2vec2/blob/main/src/make_tfrecords.py) for more details on how to convert LibriSpeech data into tfrecords.

In [21]:
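# shape=(None,) is a 1-D tensor of variable length; shape=(None) would mean "unknown rank"
output_signature = (
    tf.TensorSpec(shape=(None,), dtype=tf.float32),  # normalized waveform
    tf.TensorSpec(shape=(None,), dtype=tf.int32),    # token IDs
)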

dataset = tf.data.Dataset.from_generator(inputs_generator, output_signature=output_signature)

In [22]:
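BUFFER_SIZE = len(samples)  # number of examples the generator actually yields
SEED = 42

# reshuffle_each_iteration=False keeps the take/skip train/validation split below fixed across passes
dataset = dataset.shuffle(BUFFER_SIZE, seed=SEED, reshuffle_each_iteration=False)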
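By default, `Dataset.shuffle` uses `reshuffle_each_iteration=True`. As a result, a `take`/`skip` split applied after shuffling can put different examples into the validation split on each pass over the data, for example once during `.fit(...)` and again during evaluation.

The pure-Python simulation below illustrates the effect with `random.Random` permutations; it does not use `tf.data`. `split_after_shuffle` is a hypothetical helper.

```python
import random


def split_after_shuffle(items, num_train, rng):
    """Shuffle a copy of `items`, then split it like take(num_train) / skip(num_train)."""
    order = list(items)
    rng.shuffle(order)
    return order[:num_train], order[num_train:]


items = range(100)

# Reshuffling on every pass: the training pass and a later pass use different permutations.
rng = random.Random(42)
train_pass1, _ = split_after_shuffle(items, 90, rng)
_, val_pass2 = split_after_shuffle(items, 90, rng)
leaked = set(val_pass2) & set(train_pass1)
print(len(leaked), "validation examples were also used for training")

# Same permutation on every pass: the split stays disjoint.
train_a, _ = split_after_shuffle(items, 90, random.Random(42))
_, val_b = split_after_shuffle(items, 90, random.Random(42))
print(set(val_b).isdisjoint(train_a))  # True
```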

We will feed the dataset to the model in batches. All sequences in a batch must be padded to the same length, which is what the `.padded_batch(...)` method does.

In [23]:
dataset = dataset.padded_batch(BATCH_SIZE, padded_shapes=(AUDIO_MAXLEN, LABEL_MAXLEN), padding_values=(0.0, 0))
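With fixed `padded_shapes`, every waveform is right-padded with `0.0` to `AUDIO_MAXLEN` and every label with `0` to `LABEL_MAXLEN`. An element longer than its padded shape cannot be padded and causes an error when the dataset is iterated.

The NumPy sketch below mirrors that behavior for a list of 1-D sequences. `pad_batch` is a hypothetical helper, not the `tf.data` implementation.

```python
import numpy as np


def pad_batch(sequences, maxlen, pad_value=0):
    """Right-pad 1-D sequences to `maxlen` and stack them; reject overlong ones."""
    dtype = np.asarray(sequences[0]).dtype
    batch = np.full((len(sequences), maxlen), pad_value, dtype=dtype)
    for i, seq in enumerate(sequences):
        seq = np.asarray(seq)
        if len(seq) > maxlen:
            raise ValueError(f"sequence {i} has length {len(seq)} > maxlen={maxlen}")
        batch[i, : len(seq)] = seq
    return batch


labels = [[5, 9, 3], [7, 1], [4, 4, 4, 2]]  # hypothetical token-ID sequences
print(pad_batch(labels, maxlen=6))
# [[5 9 3 0 0 0]
#  [7 1 0 0 0 0]
#  [4 4 4 2 0 0]]
```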

Accelerators such as GPUs and TPUs are very fast, so data loading and pre-processing, which run on the CPU, often become the bottleneck during training. This can increase training time significantly, especially when there is a lot of online pre-processing or when data is streamed from GCS buckets. To handle this, `tf.data.Dataset` offers the `.prefetch(...)` method, which prepares the next few batches in parallel (on the CPU) while the model processes the current batch (on the GPU/TPU).

In [24]:
dataset = dataset.prefetch(tf.data.AUTOTUNE)
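The producer/consumer idea behind `.prefetch(...)` can be sketched with a background thread and a bounded queue. This is only a conceptual illustration: `tf.data` implements prefetching natively inside the TensorFlow runtime, and `AUTOTUNE` picks the buffer size dynamically. `prefetch` and `slow_batches` are hypothetical helpers.

```python
import queue
import threading
import time

_DONE = object()  # sentinel marking the end of the stream


def prefetch(iterable, buffer_size=2):
    """Yield items from `iterable` while a background thread stays up to
    `buffer_size` items ahead of the consumer."""
    buf = queue.Queue(maxsize=buffer_size)

    def producer():
        try:
            for item in iterable:
                buf.put((item, None))  # blocks while the buffer is full
        except Exception as exc:       # forward producer errors to the consumer
            buf.put((None, exc))
        finally:
            buf.put((_DONE, None))

    threading.Thread(target=producer, daemon=True).start()
    while True:
        item, exc = buf.get()
        if exc is not None:
            raise exc
        if item is _DONE:
            return
        yield item


def slow_batches(n):
    for i in range(n):
        time.sleep(0.01)  # stand-in for CPU-side loading and pre-processing
        yield i


print(list(prefetch(slow_batches(5))))  # [0, 1, 2, 3, 4]
```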

Since this notebook is made for demonstration purposes, we train only on the first `num_train_batches` batches and evaluate on the following `num_val_batches` batches. You are encouraged to train on the whole dataset, though.

In [25]:
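split = 0.9  # fraction of batches used for training

# ceil(len(samples) / BATCH_SIZE) batches in total; the last one may be partial
num_batches = -(-len(samples) // BATCH_SIZE)
num_train_batches = int(split * num_batches)
num_val_batches = num_batches - num_train_batches

train_dataset = dataset.take(num_train_batches)
val_dataset = dataset.skip(num_train_batches).take(num_val_batches)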

## Model training

To train the model, we first compile it with `.compile(...)` and then call `.fit(...)` directly.

In [26]:
model.compile(optimizer, loss=loss_fn)

The above cell will set up our training state. Now we can initiate training with the `.fit(...)` method.

In [27]:
history = model.fit(train_dataset, validation_data=val_dataset, epochs=1)
history.history

Instructions for updating:
Prefer tf.tensor_scatter_nd_add, which offers the same functionality with well-defined read-write semantics.

Instructions for updating:
Prefer tf.tensor_scatter_nd_update, which offers the same functionality with well-defined read-write semantics.

    199/Unknown - 336s 2s/step - loss: 281.4029

KeyboardInterrupt: 

Let's save our model with the `.save(...)` method so that we can use it for inference later. You can also export this SavedModel to TFHub by following the [TFHub documentation](https://www.tensorflow.org/hub/publish).

In [None]:
save_dir = "../Models/finetuned-wav2vec2"
model.save(save_dir, include_optimizer=False)

Note: We are setting `include_optimizer=False` as we want to use this model for inference only.

## Evaluation

Now we will compute the Word Error Rate (WER) over the validation dataset.

**Word error rate** (WER) is a common metric for measuring the performance of an automatic speech recognition system. It is derived from the Levenshtein distance, computed at the word level:

$$\mathrm{WER} = \frac{S + D + I}{N} = \frac{S + D + I}{S + D + C}$$

Here $S$ is the number of substitutions, $D$ the number of deletions, $I$ the number of insertions, $C$ the number of correct words, and $N = S + D + C$ the number of words in the reference. Lower is better, and a WER of 0 means a perfect transcript. Insertions are counted as errors but do not add to $N$, so the WER can exceed 1 (100%). It is therefore only loosely "the percentage of words that were predicted incorrectly".

You can refer to [this paper](https://www.isca-speech.org/archive_v0/interspeech_2004/i04_2765.html) to learn more about WER.
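The formula can be made concrete with a minimal pure-Python sketch. It computes the word-level Levenshtein distance (S + D + I) with dynamic programming and divides by the number of reference words. `word_edits`, `wer` and `corpus_wer` are illustrative helpers. The corpus-level variant sums edits and reference words over all utterances instead of averaging per-sentence WERs. Whether the `wer` metric loaded below aggregates exactly this way is an assumption here. The sketch assumes non-empty references.

```python
def word_edits(reference, hypothesis):
    """Minimum number of word substitutions + deletions + insertions (Levenshtein)."""
    ref, hyp = reference.split(), hypothesis.split()
    prev = list(range(len(hyp) + 1))  # distance from an empty ref prefix to hyp[:j]
    for i in range(1, len(ref) + 1):
        cur = [i] + [0] * len(hyp)
        for j in range(1, len(hyp) + 1):
            cur[j] = min(
                prev[j] + 1,                               # deletion
                cur[j - 1] + 1,                            # insertion
                prev[j - 1] + (ref[i - 1] != hyp[j - 1]),  # substitution or match
            )
        prev = cur
    return prev[-1]


def wer(reference, hypothesis):
    """Sentence-level WER = (S + D + I) / N, with N = number of reference words."""
    return word_edits(reference, hypothesis) / len(reference.split())


def corpus_wer(references, hypotheses):
    """Corpus-level WER: total edits divided by total reference words."""
    edits = sum(word_edits(r, h) for r, h in zip(references, hypotheses))
    return edits / sum(len(r.split()) for r in references)


print(wer("the cat sat on the mat", "the cat sit on mat"))  # 1 substitution + 1 deletion, N = 6
print(wer("hello world", "hello big wide world"))          # 1.0 (2 insertions, N = 2)
print(corpus_wer(["a b c", "d e"], ["a b c", "x e"]))      # 0.2
```

The second example shows why WER is not strictly a percentage of wrong words: every reference word was recognized, yet the WER is 100% because of the two insertions.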

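We will use the `load_metric(...)` function from the [HuggingFace datasets](https://huggingface.co/docs/datasets/) library. Let's first install the `datasets` library using `pip` and then define the `metric` object. In recent `datasets` releases, `load_metric` has been removed in favor of the separate `evaluate` library, where the equivalent call is `evaluate.load("wer")`.

In [None]:
# the "wer" metric relies on the jiwer package
!pip3 install -q datasets jiwer

from datasets import load_metric
metric = load_metric("wer")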

In [None]:
@tf.function(jit_compile=True)
def eval_fwd(batch):
  logits = model(batch, training=False)
  return tf.argmax(logits, axis=-1)

It's time to run the evaluation on validation data now.

In [None]:
from tqdm.auto import tqdm

for speech, labels in tqdm(val_dataset, total=num_val_batches):
    predictions  = eval_fwd(speech)
    predictions = [tokenizer.decode(pred) for pred in predictions.numpy().tolist()]
    references = [tokenizer.decode(label, group_tokens=False) for label in labels.numpy().tolist()]
    metric.add_batch(references=references, predictions=predictions)

We are using the `tokenizer.decode(...)` method for decoding our predictions and labels back into the text and will add them to the metric for `WER` computation later.
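Taking the `argmax` at every frame yields one token per time step, so a CTC-trained model's raw predictions contain repeated tokens and blanks. These must be collapsed before they read as text. We assume that `tokenizer.decode(...)` with its default grouping performs this kind of collapsing and that the padding ID doubles as the CTC blank. This is inferred from the `group_tokens=False` argument used for the labels, not verified against the library source.

The sketch below shows the collapsing rule with a hypothetical five-entry character vocabulary. It also shows why the reference labels must not be grouped: collapsing them would merge the genuine double "l" in "hello".

```python
def ctc_greedy_collapse(frame_ids, blank_id=0):
    """Best-path CTC decoding: merge consecutive repeats, then drop blanks."""
    out, prev = [], None
    for tok in frame_ids:
        if tok != prev and tok != blank_id:
            out.append(tok)
        prev = tok
    return out


# Hypothetical vocabulary; id 0 plays the role of the CTC blank / padding token.
id_to_char = {0: "<pad>", 1: "h", 2: "e", 3: "l", 4: "o"}

frames = [1, 1, 2, 0, 3, 3, 0, 3, 4, 4, 0]  # per-frame argmax predictions
print("".join(id_to_char[i] for i in ctc_greedy_collapse(frames)))  # hello

label_ids = [1, 2, 3, 3, 4]  # reference label "hello" (no blanks)
print("".join(id_to_char[i] for i in ctc_greedy_collapse(label_ids)))  # helo (wrong)
```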

Now, let's calculate the metric value in the following cell:

In [None]:
metric.compute()

**Note:** The metric value here is not meaningful, because the model was trained on very little data, and ASR-like tasks usually require a large amount of data to learn a mapping from speech to text. You should train on much more data to get good results; this notebook is meant as a template for fine-tuning a pre-trained speech model.

## Inference

Now that we are satisfied with the training process & have saved the model in `save_dir`, we will see how this model can be used for inference.

First, we will load our model using `tf.keras.models.load_model(...)`.

In [None]:
finetuned_model = tf.keras.models.load_model(save_dir)

Let's download a speech sample for inference. You can also replace it with your own speech sample.

In [None]:
!wget https://github.com/vasudevgupta7/gsoc-wav2vec2/raw/main/data/SA2.wav

Now, we will read the speech sample using `soundfile.read(...)` and pad it to `AUDIO_MAXLEN` to satisfy the model signature. Then we will normalize that speech sample using the `Wav2Vec2Processor` instance & will feed it into the model.

In [1]:
import numpy as np
import soundfile as sf  # needed for sf.read below

speech, _ = sf.read("SA2.wav")
speech = np.pad(speech, (0, AUDIO_MAXLEN - len(speech)))  # assumes len(speech) <= AUDIO_MAXLEN
speech = tf.expand_dims(processor(tf.constant(speech, dtype=tf.float32)), 0)

outputs = finetuned_model(speech)
outputs

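Let's decode the predicted token IDs back into text using the `tokenizer` (the `Wav2Vec2Processor` instance created with `is_tokenizer=True`) that we defined above. Taking the `argmax` over the vocabulary at each time step is greedy decoding.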

In [None]:
predictions = tf.argmax(outputs, axis=-1)
predictions = [tokenizer.decode(pred) for pred in predictions.numpy().tolist()]
predictions

This prediction is quite random, because the model was never trained on a large amount of data in this notebook (the notebook is not meant for complete training). You should get much better predictions by training this model on the complete dataset or on a larger corpus such as LibriSpeech.

This brings us to the end of the notebook, but not to the end of learning TensorFlow for speech-related tasks: this [repository](https://github.com/tulasiram58827/TTS_TFLite) contains more tutorials. If you encounter a bug in this notebook, please create an issue [here](https://github.com/vasudevgupta7/gsoc-wav2vec2/issues).