In this notebook, we load the pre-trained wav2vec2 model from TFHub and fine-tune it on the LibriSpeech dataset by adding a language modeling (LM) head on top of the pre-trained model.

LibriSpeech dataset: https://www.openslr.org/12

In [1]:
!pip3 install -q git+https://github.com/vasudevgupta7/gsoc-wav2vec2@main
!sudo apt-get install -y libsndfile1-dev
!pip3 install -q SoundFile

Reading package lists... Done
Building dependency tree       
Reading state information... Done
libsndfile1-dev is already the newest version (1.0.28-4ubuntu0.18.04.2).
The following package was automatically installed and is no longer required:
  libnvidia-common-460
Use 'sudo apt autoremove' to remove it.
0 upgraded, 0 newly installed, 0 to remove and 49 not upgraded.


In [2]:
import os

import tensorflow as tf
import tensorflow_hub as hub
from wav2vec2 import Wav2Vec2Config

config = Wav2Vec2Config()

print("TF version:", tf.__version__)

TF version: 2.8.2


In [4]:
pretrained_layer = hub.KerasLayer("https://tfhub.dev/vasudevgupta7/wav2vec2/1", trainable=True)

In [5]:
AUDIO_MAXLEN = 246000
LABEL_MAXLEN = 256
BATCH_SIZE = 2

In [6]:
# Wrap pretrained_layer and a dense layer (the LM head) using the Keras Functional API.

inputs = tf.keras.Input(shape=(AUDIO_MAXLEN,))
hidden_states = pretrained_layer(inputs)
outputs = tf.keras.layers.Dense(config.vocab_size)(hidden_states)

model = tf.keras.Model(inputs=inputs, outputs=outputs)


In [7]:
model(tf.random.uniform(shape=(BATCH_SIZE, AUDIO_MAXLEN)))
model.summary()

Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 246000)]          0         
                                                                 
 keras_layer (KerasLayer)    (None, 768, 768)          94371712  
                                                                 
 dense (Dense)               (None, 768, 32)           24608     
                                                                 
Total params: 94,396,320
Trainable params: 94,396,320
Non-trainable params: 0
_________________________________________________________________
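The 768 time steps in the summary above can be reproduced by hand. The sketch below assumes the convolutional feature encoder uses the kernel sizes and strides of the standard wav2vec 2.0 base architecture (these values are an assumption, they are not stated in this notebook) and applies the usual length formula for unpadded convolutions layer by layer.

```python
# Assumed feature-encoder configuration (standard wav2vec 2.0 base values,
# not taken from this notebook).
CONV_KERNELS = (10, 3, 3, 3, 3, 2, 2)
CONV_STRIDES = (5, 2, 2, 2, 2, 2, 2)

def num_output_frames(num_samples):
  """Sequence length after a stack of unpadded 1-D convolutions."""
  length = num_samples
  for kernel, stride in zip(CONV_KERNELS, CONV_STRIDES):
    length = (length - kernel) // stride + 1
  return length

print(num_output_frames(246000))  # 768
```

Under these assumptions, an input of AUDIO_MAXLEN = 246000 samples maps to 768 frames, which matches the second dimension of the keras_layer output.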


Next, we define the loss function and optimizer needed to train the model; the following cell does that for us. We use the Adam optimizer for simplicity. CTC (Connectionist Temporal Classification) loss is commonly used for tasks such as ASR, where segments of the input cannot easily be aligned with segments of the output.

CTCLoss (from the gsoc-wav2vec2 package) accepts 3 arguments: config, model_input_shape, and division_factor. With division_factor=1 the per-sample losses are simply summed, so pass division_factor=BATCH_SIZE to get the mean over the batch.
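To make the idea concrete, here is a minimal pure-Python sketch of the CTC forward algorithm for a single sample, followed by the sum-versus-mean reduction that division_factor controls. It is an illustration only, not the implementation used by the gsoc-wav2vec2 package; the function name and the choice of blank index 0 are assumptions.

```python
import math

def ctc_neg_log_likelihood(probs, labels, blank=0):
  """CTC loss for one sample, computed with the forward algorithm.

  probs: list of T frames, each a list of per-token probabilities.
  labels: list of target token ids (no blanks).
  Assumes at least one valid alignment exists (likelihood > 0).
  """
  # Interleave blanks: [blank, l1, blank, l2, ..., blank].
  extended = [blank]
  for token in labels:
    extended += [token, blank]

  # alpha[s] = total probability of all partial alignments ending at extended[s].
  alpha = [0.0] * len(extended)
  alpha[0] = probs[0][blank]
  if len(extended) > 1:
    alpha[1] = probs[0][extended[1]]

  for frame in probs[1:]:
    new_alpha = []
    for s, token in enumerate(extended):
      total = alpha[s]                      # stay on the same symbol
      if s >= 1:
        total += alpha[s - 1]               # advance by one position
      if s >= 2 and token != blank and token != extended[s - 2]:
        total += alpha[s - 2]               # skip the blank between two different tokens
      new_alpha.append(total * frame[token])
    alpha = new_alpha

  likelihood = alpha[-1] + (alpha[-2] if len(alpha) > 1 else 0.0)
  return -math.log(likelihood)

# Two frames, two tokens (0 = blank, 1 = a letter), uniform probabilities.
probs = [[0.5, 0.5], [0.5, 0.5]]
per_sample = [ctc_neg_log_likelihood(probs, [1]) for _ in range(2)]

summed = sum(per_sample)             # what division_factor=1 would give
mean = summed / len(per_sample)      # what division_factor=BATCH_SIZE would give
print(summed, mean)
```

In this toy case three alignments (letter-letter, blank-letter, letter-blank) produce the target, so the likelihood is 0.75 and each per-sample loss is -log(0.75), roughly 0.288.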

In [8]:
from wav2vec2 import CTCLoss

LEARNING_RATE = 5e-5

loss_fn = CTCLoss(config, (BATCH_SIZE, AUDIO_MAXLEN), division_factor=BATCH_SIZE)
optimizer = tf.keras.optimizers.Adam(LEARNING_RATE)

Loading & Pre-processing data
Let's now download the LibriSpeech dataset from the official website and set it up.

In [10]:
!wget https://www.openslr.org/resources/12/dev-clean.tar.gz -P ./data/train/
!tar -xf ./data/train/dev-clean.tar.gz -C ./data/train/

--2022-07-07 07:37:59--  https://www.openslr.org/resources/12/dev-clean.tar.gz
Resolving www.openslr.org (www.openslr.org)... 46.101.158.64
Connecting to www.openslr.org (www.openslr.org)|46.101.158.64|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 337926286 (322M) [application/x-gzip]
Saving to: ‘./data/train/dev-clean.tar.gz’


2022-07-07 07:38:29 (10.9 MB/s) - ‘./data/train/dev-clean.tar.gz’ saved [337926286/337926286]



In [11]:
ls ./data/train/

dev-clean.tar.gz  LibriSpeech/


In [12]:
data_dir = "./data/train/LibriSpeech/dev-clean/2428/83705/"
all_files = os.listdir(data_dir)

flac_files = [f for f in all_files if f.endswith(".flac")]
txt_files = [f for f in all_files if f.endswith(".txt")]

print("Transcription files:", txt_files, "\nSound files:", flac_files)

Transcription files: ['2428-83705.trans.txt'] 
Sound files: ['2428-83705-0013.flac', '2428-83705-0004.flac', '2428-83705-0016.flac', '2428-83705-0012.flac', '2428-83705-0003.flac', '2428-83705-0021.flac', '2428-83705-0018.flac', '2428-83705-0028.flac', '2428-83705-0001.flac', '2428-83705-0024.flac', '2428-83705-0030.flac', '2428-83705-0019.flac', '2428-83705-0005.flac', '2428-83705-0032.flac', '2428-83705-0043.flac', '2428-83705-0009.flac', '2428-83705-0029.flac', '2428-83705-0041.flac', '2428-83705-0042.flac', '2428-83705-0008.flac', '2428-83705-0026.flac', '2428-83705-0031.flac', '2428-83705-0007.flac', '2428-83705-0035.flac', '2428-83705-0034.flac', '2428-83705-0036.flac', '2428-83705-0014.flac', '2428-83705-0022.flac', '2428-83705-0023.flac', '2428-83705-0025.flac', '2428-83705-0038.flac', '2428-83705-0002.flac', '2428-83705-0006.flac', '2428-83705-0040.flac', '2428-83705-0027.flac', '2428-83705-0015.flac', '2428-83705-0033.flac', '2428-83705-0039.flac', '2428-83705-0020.flac', '24

In [13]:
def read_txt_file(path):
  with open(path, "r") as f:
    lines = f.read().split("\n")
  # Each line is "<utterance-id> <transcript>"; keep lines with an ID and at least one word.
  samples = {s.split()[0]: " ".join(s.split()[1:]) for s in lines if len(s.split()) >= 2}
  return samples
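As a quick illustration of the transcript format, the sketch below parses an in-memory string laid out the same way (one utterance ID followed by its words per line). The helper name and the utterance IDs are hypothetical; only the transcripts are taken from this notebook's data.

```python
def parse_transcript_lines(text):
  """Map utterance ID -> transcript for lines of the form 'ID WORD WORD ...'."""
  mapping = {}
  for line in text.split("\n"):
    parts = line.split()
    # The >= 2 check keeps one-word transcripts while skipping blank lines.
    if len(parts) >= 2:
      mapping[parts[0]] = " ".join(parts[1:])
  return mapping

# Utterance IDs below are illustrative placeholders.
example = "utt-0001 IT IS MOST DELIGHTFUL\nutt-0002 THERE WERE NO SIGNS OF FALTERING ABOUT HER FLOW OF LANGUAGE\n"
parsed = parse_transcript_lines(example)
print(parsed["utt-0001"])  # IT IS MOST DELIGHTFUL
```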

We will define a function for loading a speech sample from a .flac file.

REQUIRED_SAMPLE_RATE is set to 16000 because wav2vec2 was pre-trained on audio sampled at 16 kHz, and it is recommended to fine-tune it without any major change in the data distribution caused by a different sampling rate.

In [14]:
import soundfile as sf

REQUIRED_SAMPLE_RATE = 16000

def read_flac_file(file_path):
  with open(file_path, "rb") as f:
    audio, sample_rate = sf.read(f)
  if sample_rate != REQUIRED_SAMPLE_RATE:
    raise ValueError(
      f"sample rate (={sample_rate}) of your files must be {REQUIRED_SAMPLE_RATE}"
    )
  file_id = os.path.split(file_path)[-1][:-len(".flac")]
  return {file_id: audio}

In [16]:
from IPython.display import Audio
import random

file_id = random.choice([f[:-len(".flac")] for f in flac_files])
flac_file_path, txt_file_path = os.path.join(data_dir, f"{file_id}.flac"), os.path.join(data_dir, "2428-83705.trans.txt")

print("Text Transcription:", read_txt_file(txt_file_path)[file_id], "\nAudio:")
Audio(filename=flac_file_path)

Text Transcription: THERE WERE NO SIGNS OF FALTERING ABOUT HER FLOW OF LANGUAGE 
Audio:


In [17]:
def fetch_sound_text_mapping(data_dir):
  all_files = os.listdir(data_dir)

  flac_files = [os.path.join(data_dir, f) for f in all_files if f.endswith(".flac")]
  txt_files = [os.path.join(data_dir, f) for f in all_files if f.endswith(".txt")]

  txt_samples = {}
  for f in txt_files:
    txt_samples.update(read_txt_file(f))

  speech_samples = {}
  for f in flac_files:
    speech_samples.update(read_flac_file(f))

  assert len(txt_samples) == len(speech_samples)

  samples = [(speech_samples[file_id], txt_samples[file_id]) for file_id in speech_samples.keys() if len(speech_samples[file_id]) < AUDIO_MAXLEN]
  return samples

Let's try it out on a few samples.

In [18]:
samples = fetch_sound_text_mapping(data_dir)
samples[:5]

[(array([-0.0007019 , -0.00057983, -0.00033569, ..., -0.00021362,
         -0.00015259, -0.00012207]),
  'IT IS FROM HER ACTION IN THAT MATTER THAT MY SUSPICION SPRINGS'),
 (array([-0.00112915, -0.00131226, -0.00158691, ...,  0.00067139,
          0.00091553,  0.00100708]),
  "IT MIGHT JUST AS WELL BE SOME ONE ELSE'S WEDDING SO UNIMPORTANT IS THE PART WHICH I AM SET TO PLAY IN IT"),
 (array([-0.00036621, -0.00015259, -0.00012207, ..., -0.0005188 ,
         -0.00048828, -0.00048828]),
  'THERE WERE NO SIGNS OF FALTERING ABOUT HER FLOW OF LANGUAGE'),
 (array([-0.00073242, -0.00054932, -0.00045776, ...,  0.        ,
          0.00024414,  0.00042725]),
  "I WAS PERSUADED THAT SOMEBODY BESIDES THAT COUSIN GOT A PROFIT OUT OF MARY ANN'S ENGAGEMENT RING BUT I HANDED OVER THE AMOUNT"),
 (array([-1.22070312e-04,  3.05175781e-05,  6.10351562e-05, ...,
         -4.27246094e-04, -6.10351562e-04, -9.15527344e-04]),
  'IT IS MOST DELIGHTFUL')]

Let's pre-process the data now.

We first define the tokenizer and processor using the gsoc-wav2vec2 package, then do some very simple pre-processing. The processor normalizes raw speech along the frames axis. The tokenizer converts text into token ids and model outputs back into strings (using the defined vocabulary), and takes care of removing special tokens (depending on your tokenizer configuration).
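As a rough picture of what normalizing along the frames axis means, the sketch below rescales a waveform to zero mean and approximately unit variance. This is an assumption about the kind of normalization involved, not the package's actual code; the function name and the eps value are hypothetical.

```python
import math

def normalize(samples, eps=1e-5):
  """Scale a waveform to zero mean and (approximately) unit variance."""
  mean = sum(samples) / len(samples)
  variance = sum((x - mean) ** 2 for x in samples) / len(samples)
  # eps guards against division by zero for silent (constant) audio.
  return [(x - mean) / math.sqrt(variance + eps) for x in samples]

normalized = normalize([0.1, -0.2, 0.3, -0.4])
new_mean = sum(normalized) / len(normalized)
new_variance = sum((x - new_mean) ** 2 for x in normalized) / len(normalized)
print(new_mean, new_variance)
```

After this step, recordings made at different volumes end up on a comparable scale, which is typically what a pre-trained speech model expects.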

In [19]:
from wav2vec2 import Wav2Vec2Processor
tokenizer = Wav2Vec2Processor(is_tokenizer=True)
processor = Wav2Vec2Processor(is_tokenizer=False)

def preprocess_text(text):
  label = tokenizer(text)
  return tf.constant(label, dtype=tf.int32)

def preprocess_speech(audio):
  audio = tf.constant(audio, dtype=tf.float32)
  return processor(tf.transpose(audio))

Downloading `vocab.json` from https://github.com/vasudevgupta7/gsoc-wav2vec2/raw/main/data/vocab.json ... DONE


In [20]:
def inputs_generator():
  for speech, text in samples:
    yield preprocess_speech(speech), preprocess_text(text)

Setting up tf.data.Dataset
The following cell sets up a tf.data.Dataset object using its .from_generator(...) method, with the generator function defined in the cell above.

Note: For distributed training (especially on TPUs), .from_generator(...) doesn't currently work, and it is recommended to train on data stored in .tfrecord format. The TFRecords should ideally be stored in a GCS bucket so that the TPUs can be used to their fullest extent.

You can refer to the script for more details on how to convert LibriSpeech data into tfrecords.

In [21]:
output_signature = (
    tf.TensorSpec(shape=(None,), dtype=tf.float32),  # 1-D waveform of variable length
    tf.TensorSpec(shape=(None,), dtype=tf.int32),  # 1-D label sequence of variable length
)

dataset = tf.data.Dataset.from_generator(inputs_generator, output_signature=output_signature)

In [22]:
BUFFER_SIZE = len(flac_files)
SEED = 42

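# Keep the shuffle order fixed across epochs so that the take/skip split below
# always yields the same training and validation batches.
dataset = dataset.shuffle(BUFFER_SIZE, seed=SEED, reshuffle_each_iteration=False)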

We will feed the dataset to the model in batches, so let's prepare them in the following cell. All sequences in a batch must be padded to a constant length; we use the .padded_batch(...) method for that purpose.
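Conceptually, padding a batch to a constant length looks like the plain-Python sketch below. The helper name is hypothetical and the toy sequences are made up; the real pipeline pads speech to AUDIO_MAXLEN with 0.0 and labels to LABEL_MAXLEN with 0.

```python
def pad_to_length(sequence, target_length, pad_value):
  """Right-pad a sequence with pad_value up to target_length."""
  if len(sequence) > target_length:
    # Padding can only grow a sequence, so longer inputs are rejected here.
    raise ValueError(f"sequence of length {len(sequence)} exceeds {target_length}")
  return list(sequence) + [pad_value] * (target_length - len(sequence))

batch = [pad_to_length(seq, 5, 0) for seq in ([7, 8, 9], [4, 5])]
print(batch)  # [[7, 8, 9, 0, 0], [4, 5, 0, 0, 0]]
```

This is also why samples longer than AUDIO_MAXLEN are filtered out when the sound/text mapping is built: they could not be padded to the fixed shape.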

In [23]:
dataset = dataset.padded_batch(BATCH_SIZE, padded_shapes=(AUDIO_MAXLEN, LABEL_MAXLEN), padding_values=(0.0, 0))

In [24]:
dataset = dataset.prefetch(tf.data.AUTOTUNE)

In [25]:
num_train_batches = 10
num_val_batches = 4

train_dataset = dataset.take(num_train_batches)
val_dataset = dataset.skip(num_train_batches).take(num_val_batches)

Model Training

In [26]:
model.compile(optimizer, loss=loss_fn)

In [27]:
history = model.fit(train_dataset, validation_data=val_dataset, epochs=3)
history.history

Epoch 1/3
Instructions for updating:
Prefer tf.tensor_scatter_nd_add, which offers the same functionality with well-defined read-write semantics.

Instructions for updating:
Prefer tf.tensor_scatter_nd_update, which offers the same functionality with well-defined read-write semantics.

Epoch 2/3
Epoch 3/3


{'loss': [1358.8096923828125, 827.9127197265625, 761.6920166015625],
 'val_loss': [872.3134765625, 519.2056884765625, 817.7880859375]}

Let's save our model with the .save(...) method so that we can perform inference later.

In [28]:
save_dir = "finetuned-wav2vec2"
model.save(save_dir, include_optimizer=False)



INFO:tensorflow:Assets written to: finetuned-wav2vec2/assets




Word error rate (WER) is a common metric for measuring the performance of an automatic speech recognition system. It is derived from the Levenshtein distance, computed at the word level:

WER = (S + D + I) / N = (S + D + I) / (S + D + C)

where S is the number of substitutions, D is the number of deletions, I is the number of insertions, C is the number of correct words, and N is the number of words in the reference (N = S + D + C). The value is the fraction of reference words that were predicted incorrectly; because insertions are counted, it can exceed 1.
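The definition can be checked with a small pure-Python sketch that computes the word-level Levenshtein distance between one reference and one prediction. The function name is hypothetical, and the sketch scores a single sentence pair; it assumes a non-empty reference.

```python
def word_error_rate(reference, prediction):
  """WER = (S + D + I) / N for a single reference/prediction pair."""
  ref = reference.split()
  hyp = prediction.split()
  # previous[j] = edit distance between the reference words seen so far
  # and the first j predicted words.
  previous = list(range(len(hyp) + 1))
  for i, ref_word in enumerate(ref, start=1):
    current = [i]
    for j, hyp_word in enumerate(hyp, start=1):
      substitution = previous[j - 1] + (ref_word != hyp_word)
      deletion = previous[j] + 1
      insertion = current[j - 1] + 1
      current.append(min(substitution, deletion, insertion))
    previous = current
  return previous[-1] / len(ref)

print(word_error_rate("IT IS MOST DELIGHTFUL", "IT IS MOST DELIGHTFUL"))  # 0.0
print(word_error_rate("IT IS MOST DELIGHTFUL", "IT WAS MOST"))            # 0.5
```

In the second call there is one substitution (IS to WAS) and one deletion (DELIGHTFUL) against four reference words, giving 2 / 4 = 0.5. An empty prediction yields 1.0, since every reference word counts as a deletion.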

In [30]:
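!pip install -q datasets

# Note: load_metric was deprecated in later releases of datasets;
# the separate evaluate package provides evaluate.load("wer") as its replacement.
from datasets import load_metric
metric = load_metric("wer")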

In [31]:
@tf.function(jit_compile=True)
def eval_fwd(batch):
  logits = model(batch, training=False)
  return tf.argmax(logits, axis=-1)
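The per-frame argmax ids returned above still contain repeated tokens and blanks. A minimal sketch of greedy CTC decoding, assuming the usual convention of collapsing consecutive repeats and then dropping the blank token, is shown below. The vocabulary and the blank id of 0 are hypothetical and do not come from the gsoc-wav2vec2 tokenizer.

```python
from itertools import groupby

def greedy_ctc_decode(token_ids, id_to_char, blank_id=0):
  """Collapse consecutive repeats, then remove blanks."""
  collapsed = [token for token, _ in groupby(token_ids)]
  return "".join(id_to_char[token] for token in collapsed if token != blank_id)

# Hypothetical toy vocabulary, for illustration only.
vocab = {0: "<pad>", 1: "C", 2: "A", 3: "T"}

print(greedy_ctc_decode([0, 1, 1, 0, 2, 2, 2, 0, 0, 3], vocab))  # CAT
print(greedy_ctc_decode([1, 1, 0, 1], vocab))                    # CC
```

The second call shows why the blank matters: a blank between two identical tokens keeps them as two separate letters. Reference labels, by contrast, are not frame-level outputs, so their genuine double letters should not be collapsed, which is presumably why the labels are decoded with group_tokens=False in the evaluation loop.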

Let's run the evaluation on the validation data.

In [32]:
from tqdm.auto import tqdm

for speech, labels in tqdm(val_dataset, total=num_val_batches):
  predictions = eval_fwd(speech)
  predictions = [tokenizer.decode(pred) for pred in predictions.numpy().tolist()]
  references = [tokenizer.decode(label, group_tokens=False) for label in labels.numpy().tolist()]
  metric.add_batch(references=references, predictions=predictions)

  0%|          | 0/4 [00:00<?, ?it/s]

In [33]:
metric.compute()

1.0

Inference

In [34]:
finetuned_model = tf.keras.models.load_model(save_dir)
!wget https://github.com/vasudevgupta7/gsoc-wav2vec2/raw/main/data/SA2.wav
import numpy as np

speech, _ = sf.read("SA2.wav")
speech = np.pad(speech, (0, AUDIO_MAXLEN - len(speech)))
speech = tf.expand_dims(processor(tf.constant(speech)), 0)

outputs = finetuned_model(speech)
outputs





--2022-07-07 08:24:23--  https://github.com/vasudevgupta7/gsoc-wav2vec2/raw/main/data/SA2.wav
Resolving github.com (github.com)... 20.205.243.166
Connecting to github.com (github.com)|20.205.243.166|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/vasudevgupta7/gsoc-wav2vec2/main/data/SA2.wav [following]
--2022-07-07 08:24:23--  https://raw.githubusercontent.com/vasudevgupta7/gsoc-wav2vec2/main/data/SA2.wav
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.109.133, 185.199.108.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 94252 (92K) [audio/wav]
Saving to: ‘SA2.wav’


2022-07-07 08:24:25 (46.5 MB/s) - ‘SA2.wav’ saved [94252/94252]



<tf.Tensor: shape=(1, 768, 32), dtype=float32, numpy=
array([[[ 0.9268575 , -0.82134336, -0.08236028, ..., -1.0640054 ,
         -0.7868407 , -0.5740926 ],
        [ 0.9274203 , -0.8210297 , -0.08172926, ..., -1.0638304 ,
         -0.7880098 , -0.5723727 ],
        [ 0.9278157 , -0.82224953, -0.07799153, ..., -1.0646136 ,
         -0.7883819 , -0.57123953],
        ...,
        [ 0.9231877 , -0.8248101 , -0.08464065, ..., -1.0613258 ,
         -0.7787218 , -0.5690993 ],
        [ 0.9235312 , -0.8247275 , -0.08395161, ..., -1.0613401 ,
         -0.77876025, -0.56889355],
        [ 0.9236906 , -0.8242859 , -0.08403508, ..., -1.061506  ,
         -0.7788548 , -0.56927687]]], dtype=float32)>

In [35]:
predictions = tf.argmax(outputs, axis=-1)
predictions = [tokenizer.decode(pred) for pred in predictions.numpy().tolist()]
predictions

['R']

In [36]:
from google.colab import files
files.download("/content/finetuned-wav2vec2")

In [38]:
from google.colab import files
files.download("/content/data")