# **Fine-tuning Wav2Vec2 for English ASR with 🤗 Transformers**

This notebook fine-tunes the "base"-sized [pretrained checkpoint](https://huggingface.co/facebook/wav2vec2-base) on the LJSpeech dataset.

Wav2Vec2 is fine-tuned with Connectionist Temporal Classification (CTC), an algorithm for training neural networks on sequence-to-sequence problems in which the alignment between input and output is unknown. It is used mainly in Automatic Speech Recognition and handwriting recognition.
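
As a rough illustration of how a CTC output becomes text, the sketch below applies the usual greedy decoding rule: merge consecutive repeated symbols, then drop the blank symbol. The function name `ctc_greedy_collapse` and the `_` blank marker are illustrative choices and not part of 🤗 Transformers; in this notebook the `[PAD]` token plays the role of the blank.

```python
from itertools import groupby

BLANK = "_"  # hypothetical blank marker, used only in this sketch

def ctc_greedy_collapse(frames, blank=BLANK):
    """Collapse a per-frame symbol sequence into a label sequence."""
    # Step 1: merge runs of identical symbols ("hh" -> "h").
    merged = [symbol for symbol, _ in groupby(frames)]
    # Step 2: remove the blank symbol, which only separates repeats.
    return "".join(symbol for symbol in merged if symbol != blank)

frames = list("hh_e_ll_ll_oo")  # one symbol per audio frame
print(ctc_greedy_collapse(frames))
```

The blank between the two `ll` runs is what allows a genuine double letter to survive the merge step.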


## Setup

In [None]:
# !pip install datasets
# !pip install transformers
# !pip install librosa
# !pip install jiwer

In [1]:
import torch
import librosa
torch.cuda.is_available()

True

## Prepare Data, Tokenizer, Feature Extractor

In 🤗 Transformers, the Wav2Vec2 model is accompanied by both a tokenizer, called [Wav2Vec2CTCTokenizer](https://huggingface.co/transformers/master/model_doc/wav2vec2.html#wav2vec2ctctokenizer), and a feature extractor, called [Wav2Vec2FeatureExtractor](https://huggingface.co/transformers/master/model_doc/wav2vec2.html#wav2vec2featureextractor).

Let's start by creating the tokenizer responsible for decoding the model's predictions.

### Create Wav2Vec2CTCTokenizer

First, let's load the dataset and take a look at its structure.

In [2]:
from datasets import load_dataset, load_metric, Audio

train_ds = load_dataset("lj_speech",split='train[:90%]')
val_ds = load_dataset("lj_speech",split='train[-10%:]')

Reusing dataset lj_speech (C:\Users\Harman\.cache\huggingface\datasets\lj_speech\main\1.1.0\f4518d8fe24e62a9045ac697d23037d073cc76202777ee14267664978c222c2f)
Reusing dataset lj_speech (C:\Users\Harman\.cache\huggingface\datasets\lj_speech\main\1.1.0\f4518d8fe24e62a9045ac697d23037d073cc76202777ee14267664978c222c2f)


In [3]:
train_ds

Dataset({
    features: ['id', 'audio', 'file', 'text', 'normalized_text'],
    num_rows: 11790
})

In [4]:
val_ds

Dataset({
    features: ['id', 'audio', 'file', 'text', 'normalized_text'],
    num_rows: 1310
})

In [5]:
train_ds = train_ds.remove_columns(["id","file","text"])
val_ds = val_ds.remove_columns(["id","file","text"])

train_ds

Dataset({
    features: ['audio', 'normalized_text'],
    num_rows: 11790
})

Let's write a short function to display some random samples of the dataset and run it a couple of times to get a feeling for the transcriptions.

In [6]:
import random
import pandas as pd
from IPython.display import display, HTML

def show_random_elements(dataset, num_examples=10):
    """Display `num_examples` distinct, randomly chosen rows of `dataset` as an HTML table."""
    # random.sample draws without replacement, so no index is picked twice
    picks = random.sample(range(len(dataset)), num_examples)

    df = pd.DataFrame(dataset[picks])
    display(HTML(df.to_html()))

In [7]:
show_random_elements(train_ds.remove_columns(["audio"]), num_examples=10)

Unnamed: 0,normalized_text
0,"The courtyards and streets are paved with them, and as you walk about the city the name of Nebuchadnezzar everywhere meets your eye."
1,to carry out a carefully planned killing of another human being and was willing to consummate such a purpose if he thought there was sufficient reason to do so.
2,"and said, quote, Everybody will know who I am now, end quote."
3,"Concerning the shots which struck the President in the back of the neck,"
4,"Mr. Cobbett was also a lodger of Mr. Newman's; and so were any members of the aristocracy,"
5,"to May fourteen, nineteen sixty-three."
6,The relation of skeleton and muscle in arthropods is exactly the reverse.
7,"All vertebrates, and none other, have two cavities,"
8,"The Aliases ""Hidell"" and ""O. H. Lee"""
9,"This was the well and astutely devised plot of the brothers Bidwell,"


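Normalize the text: remove every character except letters, digits, whitespace, and apostrophes, convert the result to lower case, and append a trailing space that will later serve as the word separator.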

In [8]:
import re
# raw string, so that \s is passed to the regex engine unchanged
chars_to_ignore_regex = r"[^A-Za-z0-9\s']"

def remove_special_characters(batch):
    batch["normalized_text"] = re.sub(chars_to_ignore_regex, '', batch["normalized_text"]).lower() + " "
    return batch
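
To see what this normalization does to a single transcription, here is a small standalone sketch. It uses the same character class as the cell above; the helper name `normalize` is only illustrative.

```python
import re

# Keep letters, digits, whitespace and apostrophes; drop everything else.
chars_to_ignore_regex = r"[^A-Za-z0-9\s']"

def normalize(text):
    """Strip punctuation, lower-case, and append the trailing word separator."""
    return re.sub(chars_to_ignore_regex, "", text).lower() + " "

sample = "and said, quote, Everybody will know who I am now, end quote."
print(repr(normalize(sample)))
print(repr(normalize("nineteen sixty-three.")))
```

Note that hyphens are removed rather than replaced by a space, so hyphenated words are merged into one word.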

In [9]:
train_ds = train_ds.map(remove_special_characters)
val_ds = val_ds.map(remove_special_characters)

Loading cached processed dataset at C:\Users\Harman\.cache\huggingface\datasets\lj_speech\main\1.1.0\f4518d8fe24e62a9045ac697d23037d073cc76202777ee14267664978c222c2f\cache-2485e78dd968bfcc.arrow
Loading cached processed dataset at C:\Users\Harman\.cache\huggingface\datasets\lj_speech\main\1.1.0\f4518d8fe24e62a9045ac697d23037d073cc76202777ee14267664978c222c2f\cache-5ebba5d7b5386287.arrow


In [10]:
show_random_elements(train_ds.remove_columns(["audio"]))

Unnamed: 0,normalized_text
0,and it was on the entire absence of the latter that the defense was principally based when palmer was brought to trial
1,the governor himself admitted that a prisoner of weak intellect who had been severely beaten and much injured by a wardsman did not dare complain
2,committed by the house of commons who had been lodged in the governor's own house
3,the sheepstealer smiles and extending his arms upwards looks with a glad expression to the roof of the chapel
4,the economic embargo against that country and the general policy of the united states with regard to cuba
5,both from the soil through the roots liquids and from the atmosphere through the leaves gases
6,he saw certain rooms fill up and yet took no steps to open others that were locked up and empty
7,the car lurched forward causing him to lose his footing he ran three or four steps regained his position and mounted the car
8,the proceeds of these forgeries amounted it was said to some thousands per annum
9,the 'dryad' was a brig owned principally by two persons named wallace one a seaman the other merchant


In [11]:
def extract_all_chars(batch):
    all_text = " ".join(batch["normalized_text"])
    vocab = list(set(all_text))
    return {"vocab": [vocab], "all_text": [all_text]}

In [12]:
vocab_train = train_ds.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=train_ds.column_names)
vocab_test = val_ds.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=val_ds.column_names)


Now, we create the union of all distinct characters in the training and validation sets and convert the resulting list into an enumerated dictionary.

In [13]:
vocab_list = list(set(vocab_train["vocab"][0]) | set(vocab_test["vocab"][0]))

In [14]:
vocab_dict = {v: k for k, v in enumerate(vocab_list)}
vocab_dict

{'a': 0,
 'c': 1,
 'm': 2,
 'w': 3,
 'x': 4,
 'd': 5,
 'g': 6,
 'q': 7,
 'b': 8,
 'j': 9,
 'k': 10,
 'v': 11,
 's': 12,
 "'": 13,
 'o': 14,
 'y': 15,
 'f': 16,
 't': 17,
 'h': 18,
 'z': 19,
 'e': 20,
 'l': 21,
 'i': 22,
 'p': 23,
 'u': 24,
 'r': 25,
 'n': 26,
 ' ': 27}

In [15]:
vocab_dict["|"] = vocab_dict[" "]
del vocab_dict[" "]

In [16]:
vocab_dict["[UNK]"] = len(vocab_dict)
vocab_dict["[PAD]"] = len(vocab_dict)
len(vocab_dict)

30

Our vocabulary is complete and consists of 30 tokens, which means that the linear layer that we will add on top of the pretrained Wav2Vec2 checkpoint will have an output dimension of 30.
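
The vocabulary steps above can be sketched on a toy corpus as follows. The helpers `build_vocab` and `encode` are hypothetical names for this illustration, and the characters are sorted here only to make the ids reproducible (the notebook itself does not sort them).

```python
def build_vocab(transcripts):
    """Build a character-level vocabulary in the style of the cells above."""
    chars = sorted(set("".join(transcripts)))  # sorted only for reproducibility
    vocab = {ch: idx for idx, ch in enumerate(chars)}
    # Make the word boundary visible by using "|" instead of " ".
    if " " in vocab:
        vocab["|"] = vocab.pop(" ")
    # Special tokens get the next free ids.
    vocab["[UNK]"] = len(vocab)
    vocab["[PAD]"] = len(vocab)
    return vocab

def encode(text, vocab):
    """Map each character to its id; unseen characters fall back to [UNK]."""
    return [vocab.get("|" if ch == " " else ch, vocab["[UNK]"]) for ch in text]

toy_vocab = build_vocab(["the cat ", "a hat "])
print(toy_vocab)
print(encode("the bat ", toy_vocab))
```

The size of the resulting dictionary is what determines the output dimension of the CTC head.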

Let's now save the vocabulary as a json file.

In [18]:
import json
with open('vocab.json', 'w') as vocab_file:
    json.dump(vocab_dict, vocab_file)

Use the json file to instantiate an object of the `Wav2Vec2CTCTokenizer` class.

In [19]:
from transformers import Wav2Vec2CTCTokenizer

tokenizer = Wav2Vec2CTCTokenizer("./vocab.json", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|")

### Create Wav2Vec2 Feature Extractor

In [20]:
from transformers import Wav2Vec2FeatureExtractor

feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=False)

In [21]:
from transformers import Wav2Vec2Processor

processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

### Preprocess Data

So far, we have not looked at the actual values of the speech signal but just the transcription. In addition to `'normalized_text'`, our datasets include one more column `'audio'`. 

In [22]:
train_ds[0]["audio"]

{'path': 'C:\\Users\\Harman\\.cache\\huggingface\\datasets\\downloads\\extracted\\9260b0031f61f6673594580f2aaf1cb6d744ff1b7ea8ef7afc4131ecd7b02cd5\\LJSpeech-1.1\\wavs\\LJ001-0001.wav',
 'array': array([-7.3242188e-04, -7.6293945e-04, -6.4086914e-04, ...,
         7.3242188e-04,  2.1362305e-04,  6.1035156e-05], dtype=float32),
 'sampling_rate': 22050}

In [23]:
train_ds = train_ds.cast_column("audio", Audio(sampling_rate=16_000))
val_ds = val_ds.cast_column("audio", Audio(sampling_rate=16_000))

In [24]:
train_ds[0]['audio']

{'path': 'C:\\Users\\Harman\\.cache\\huggingface\\datasets\\downloads\\extracted\\9260b0031f61f6673594580f2aaf1cb6d744ff1b7ea8ef7afc4131ecd7b02cd5\\LJSpeech-1.1\\wavs\\LJ001-0001.wav',
 'array': array([-0.00064146, -0.00074657, -0.00068768, ...,  0.00068341,
         0.00014045,  0.        ], dtype=float32),
 'sampling_rate': 16000}

The sampling rate is set to 16kHz which is what `Wav2Vec2` expects as an input.

Verify that the audio was correctly loaded. 

In [25]:
import IPython.display as ipd
import numpy as np
import random

# random.randint is inclusive on both ends, so the upper bound must be len - 1
rand_int = random.randint(0, len(train_ds) - 1)

print(train_ds[rand_int]["normalized_text"])
ipd.Audio(data=np.asarray(train_ds[rand_int]["audio"]["array"]), autoplay=True, rate=16000)

to which a total of upwards of one hundred thousand prisoners had been committed in the year only twentythree prisons were divided according to law 


In [26]:
# random.randint is inclusive on both ends, so the upper bound must be len - 1
rand_int = random.randint(0, len(train_ds) - 1)

print("Target text:", train_ds[rand_int]["normalized_text"])
print("Input array shape:", np.asarray(train_ds[rand_int]["audio"]["array"]).shape)
print("Sampling rate:", train_ds[rand_int]["audio"]["sampling_rate"])

Target text: no fixed rates or rules governed the hiring out of rooms or parts of a room and all sorts of imposition was practiced 
Input array shape: (132747,)
Sampling rate: 16000


First, we load and resample the audio data, simply by calling `batch["audio"]`.
Second, we extract the `input_values` from the loaded audio file. In our case, the `Wav2Vec2Processor` only normalizes the data. 
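
With `do_normalize=True`, the feature extractor rescales each waveform to zero mean and unit variance. The plain-Python sketch below shows that operation on made-up samples; the function name and the small `eps` guard against division by zero are assumptions of this sketch.

```python
import math

def zero_mean_unit_variance(samples, eps=1e-7):
    """Rescale one waveform so that its mean is 0 and its variance is close to 1."""
    n = len(samples)
    mean = sum(samples) / n
    variance = sum((x - mean) ** 2 for x in samples) / n
    scale = math.sqrt(variance + eps)  # eps avoids dividing by zero on silent clips
    return [(x - mean) / scale for x in samples]

waveform = [0.1, -0.2, 0.05, 0.3, -0.25]  # made-up samples, not real audio
normalized = zero_mean_unit_variance(waveform)
print(normalized)
```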

In [27]:
def prepare_dataset(batch):
    audio = batch["audio"]

    # batched output is "un-batched" to ensure mapping is correct
    batch["input_values"] = processor(audio["array"], sampling_rate=audio["sampling_rate"]).input_values[0]
    batch["input_length"] = len(batch["input_values"])
    
    with processor.as_target_processor():
        batch["labels"] = processor(batch["normalized_text"]).input_ids
    return batch

Let's apply the data preparation function to all examples.

In [28]:
train_ds = train_ds.map(prepare_dataset, remove_columns=train_ds.column_names)
val_ds = val_ds.map(prepare_dataset, remove_columns=val_ds.column_names)


`datasets` makes use of [`torchaudio`](https://pytorch.org/audio/stable/index.html) and [`librosa`](https://librosa.org/doc/latest/index.html) for audio loading and resampling.

Because of memory constraints, filter all sequences longer than 5 seconds out of the training dataset.
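
The length limit is expressed in samples, not seconds. A minimal sketch of the arithmetic, using the 132747-sample clip printed earlier as an example (the helper names are illustrative):

```python
SAMPLING_RATE = 16_000        # Hz, as configured on the feature extractor
MAX_INPUT_LENGTH_IN_SEC = 5   # same limit as in this notebook

max_samples = MAX_INPUT_LENGTH_IN_SEC * SAMPLING_RATE

def duration_in_sec(input_length, rate=SAMPLING_RATE):
    """Convert a number of samples into seconds."""
    return input_length / rate

def keep_example(input_length, limit=max_samples):
    """Same predicate as the filter: keep clips strictly shorter than the limit."""
    return input_length < limit

print(max_samples)
print(duration_in_sec(132_747), keep_example(132_747))
```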

In [28]:
max_input_length_in_sec = 5
train_ds = train_ds.filter(lambda x: x < max_input_length_in_sec * processor.feature_extractor.sampling_rate, input_columns=["input_length"])


In [29]:
len(train_ds)

3000

In [32]:
train_ds

Dataset({
    features: ['input_values', 'input_length', 'labels'],
    num_rows: 3000
})

## Training & Evaluation

Training pipeline:

- Define a data collator. In contrast to most NLP models, Wav2Vec2 has a much larger input length than output length. *E.g.*, a sample of input length 50000 has an output length of no more than 100. Given the large input sizes, it is much more efficient to pad the training batches dynamically, meaning that each training sample is padded only to the longest sample in its batch and not to the overall longest sample. Fine-tuning Wav2Vec2 therefore requires a special padding data collator, which we will define below.

- Evaluation metric. During training, the model should be evaluated on the word error rate. We should define a `compute_metrics` function accordingly.

- Load a pretrained checkpoint. We need to load a pretrained checkpoint and configure it correctly for training.

- Define the training configuration.

After fine-tuning the model, we will evaluate it on the held-out validation data and verify that it has indeed learned to transcribe speech correctly.
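
The gap between input and output length comes from the convolutional feature encoder, which downsamples the waveform before the Transformer layers. As a sketch, the number of frames that reach the CTC head can be estimated by applying `(input_length - kernel_size) // stride + 1` once per convolutional layer, using the kernel sizes and strides of the base checkpoint's configuration. The function name is illustrative.

```python
# Kernel sizes and strides of the seven convolutional layers (base checkpoint config).
CONV_KERNEL = (10, 3, 3, 3, 3, 2, 2)
CONV_STRIDE = (5, 2, 2, 2, 2, 2, 2)

def num_output_frames(input_length):
    """Number of frames left after the convolutional feature encoder."""
    length = input_length
    for kernel_size, stride in zip(CONV_KERNEL, CONV_STRIDE):
        length = (length - kernel_size) // stride + 1
    return length

print(num_output_frames(16_000))  # one second of 16 kHz audio
print(num_output_frames(50_000))
```

One second of audio is thus reduced to a few dozen frames, and the label sequence of a clip has to fit within that number of frames for the CTC loss to be finite.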

### Set-up Trainer

Let's start by defining the data collator.

In [29]:
import torch

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union

@dataclass
class DataCollatorCTCWithPadding:
    """
    Data collator that will dynamically pad the inputs received.
    Args:
        processor (:class:`~transformers.Wav2Vec2Processor`)
            The processor used for processing the data.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:
            * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
              sequence is provided).
            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
              maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
              different lengths).
        max_length (:obj:`int`, `optional`):
            Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
        max_length_labels (:obj:`int`, `optional`):
            Maximum length of the ``labels`` returned list and optionally padding length (see above).
        pad_to_multiple_of (:obj:`int`, `optional`):
            If set will pad the sequence to a multiple of the provided value.
            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
            7.0 (Volta).
    """

    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True
    max_length: Optional[int] = None
    max_length_labels: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    pad_to_multiple_of_labels: Optional[int] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need
        # different padding methods
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        batch = self.processor.pad(
            input_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )
        with self.processor.as_target_processor():
            labels_batch = self.processor.pad(
                label_features,
                padding=self.padding,
                max_length=self.max_length_labels,
                pad_to_multiple_of=self.pad_to_multiple_of_labels,
                return_tensors="pt",
            )

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        batch["labels"] = labels

        return batch

In [30]:
data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)
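
To make the collator's behaviour concrete without any library code, the sketch below pads a toy batch to the longest sequence in that batch and then replaces label padding with -100 so that those positions are ignored by the loss. `pad_batch` is a hypothetical helper, and the id 7 merely stands in for the `[PAD]` token id.

```python
IGNORE_INDEX = -100  # label positions with this value are skipped by the loss

def pad_batch(sequences, pad_value):
    """Right-pad every sequence to the longest one in this batch; also return a mask."""
    longest = max(len(seq) for seq in sequences)
    padded = [list(seq) + [pad_value] * (longest - len(seq)) for seq in sequences]
    mask = [[1] * len(seq) + [0] * (longest - len(seq)) for seq in sequences]
    return padded, mask

# Toy batch: two "waveforms" and two label sequences of different lengths.
input_values = [[0.1, 0.2, 0.3], [0.5]]
labels = [[5, 4, 3, 0], [1, 0]]

padded_inputs, _ = pad_batch(input_values, pad_value=0.0)
padded_labels, label_mask = pad_batch(labels, pad_value=7)  # 7 stands in for [PAD]

# Replace padded label positions with -100, as the collator does via masked_fill.
masked_labels = [
    [token if keep else IGNORE_INDEX for token, keep in zip(row, mask_row)]
    for row, mask_row in zip(padded_labels, label_mask)
]
print(padded_inputs)
print(masked_labels)
```

Inputs and labels are padded separately because their lengths are unrelated: one is counted in audio samples, the other in characters.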

Next, the evaluation metric is defined. As mentioned earlier, the predominant metric in ASR is the word error rate (WER), hence we will use it in this notebook as well.

In [31]:
wer_metric = load_metric("wer")

In [32]:
def compute_metrics(pred):
    pred_logits = pred.predictions
    pred_ids = np.argmax(pred_logits, axis=-1)

    pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id

    pred_str = processor.batch_decode(pred_ids)
    # we do not want to group tokens when computing the metrics
    label_str = processor.batch_decode(pred.label_ids, group_tokens=False)

    wer = wer_metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}
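
For intuition, the word error rate of a single sentence is the word-level edit distance (substitutions, deletions, insertions) divided by the number of reference words. The function below is a plain-Python illustration under that definition, not the implementation behind `load_metric("wer")`, and it assumes a non-empty reference.

```python
def word_error_rate(reference, hypothesis):
    """WER = word-level edit distance / number of reference words."""
    ref, hyp = reference.split(), hypothesis.split()
    # previous[j] holds the edit distance between the reference prefix
    # processed so far and the first j hypothesis words.
    previous = list(range(len(hyp) + 1))
    for i, ref_word in enumerate(ref, start=1):
        current = [i]
        for j, hyp_word in enumerate(hyp, start=1):
            substitution = previous[j - 1] + (ref_word != hyp_word)
            deletion = previous[j] + 1
            insertion = current[j - 1] + 1
            current.append(min(substitution, deletion, insertion))
        previous = current
    return previous[-1] / len(ref)

print(word_error_rate("the cat sat on the mat", "the cat sit on mat"))
print(word_error_rate("the cat sat", ""))
```

An empty prediction counts every reference word as a deletion, which gives a WER of exactly 1.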

Load the pretrained `Wav2Vec2` checkpoint.

In [33]:
from transformers import Wav2Vec2ForCTC

model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-base", 
    ctc_loss_reduction="mean", 
    pad_token_id=processor.tokenizer.pad_token_id,
)

Some weights of the model checkpoint at facebook/wav2vec2-base were not used when initializing Wav2Vec2ForCTC: ['project_q.weight', 'quantizer.weight_proj.weight', 'quantizer.codevectors', 'project_hid.bias', 'project_hid.weight', 'quantizer.weight_proj.bias', 'project_q.bias']
- This IS expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-base and are newly initialized: ['lm_head.weight', 'lm_head.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predicti

In [34]:
from transformers import TrainingArguments

training_args = TrainingArguments(
  group_by_length=True,
  output_dir = './temp/',
  per_device_train_batch_size=2,
  evaluation_strategy="steps",
  num_train_epochs=30,
  fp16=True,
  gradient_checkpointing=True,
  save_steps=500,
  eval_steps=500,
  logging_steps=500,
  learning_rate=1e-4,
  weight_decay=0.005,
  warmup_steps=1000,
  save_total_limit=1,
  push_to_hub=False,
)
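
Assuming the default linear scheduler of `TrainingArguments`, the learning rate rises linearly from 0 to `learning_rate` over `warmup_steps` and then decays linearly to 0 at the end of training. The sketch below reproduces that shape in plain Python; `total_steps=45_000` is an illustrative assumption (3000 clips, batch size 2, 30 epochs on a single device), not a value taken from a training run.

```python
def linear_schedule_lr(step, base_lr=1e-4, warmup_steps=1000, total_steps=45_000):
    """Learning rate at a given optimization step for linear warmup and linear decay."""
    if step < warmup_steps:
        # warmup phase: ramp up from 0 to base_lr
        return base_lr * step / max(1, warmup_steps)
    # decay phase: ramp down from base_lr to 0 at total_steps
    remaining = max(0, total_steps - step)
    return base_lr * remaining / max(1, total_steps - warmup_steps)

for step in (0, 500, 1000, 23_000, 45_000):
    print(step, linear_schedule_lr(step))
```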

Now, all instances can be passed to Trainer and we are ready to start training!

In [36]:
from transformers import Trainer

trainer = Trainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    tokenizer=processor.feature_extractor,
)

Using amp fp16 backend


### Training

In [37]:
trainer.train()

The following columns in the training set  don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length.
***** Running training *****
  Num examples = 11790
  Num Epochs = 1
  Instantaneous batch size per device = 2
  Total train batch size (w. parallel, distributed & accumulation) = 2
  Gradient Accumulation steps = 1
  Total optimization steps = 5895
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"
wandb: Currently logged in as: harman92 (use `wandb login --relogin` to force relogin)


  tensor = as_tensor(value)
  return (input_length - kernel_size) // stride + 1


Step,Training Loss,Validation Loss,Wer
500,6.6908,,0.986916
1000,0.0,,0.986916


The following columns in the evaluation set  don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length.
***** Running Evaluation *****
  Num examples = 1310
  Batch size = 8
Saving model checkpoint to ./temp/checkpoint-500
Configuration saved in ./temp/checkpoint-500\config.json
Model weights saved in ./temp/checkpoint-500\pytorch_model.bin
Configuration saved in ./temp/checkpoint-500\preprocessor_config.json
Deleting older checkpoint [temp\checkpoint-120000] due to args.save_total_limit
  return (input_length - kernel_size) // stride + 1
The following columns in the evaluation set  don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length.
***** Running Evaluation *****
  Num examples = 1310
  Batch size = 8
Saving model checkpoint to ./temp/checkpoint-1000
Configuration saved in ./temp/checkpoint-1000\config.json
Model weights saved in ./temp/checkpoint-1000\pytorch_model.bin
Configuration saved in ./

KeyboardInterrupt: 

In [None]:
# Save the model

model.save_pretrained('../Models/model_2')

### Evaluate

In the final part, we run our model on some of the validation data to get a feeling for how well it works.

Let's load the `model`.

In [43]:
model = Wav2Vec2ForCTC.from_pretrained("../Models/model_2")

loading configuration file ../Models/model_2\config.json
Model config Wav2Vec2Config {
  "_name_or_path": "facebook/wav2vec2-base",
  "activation_dropout": 0.0,
  "apply_spec_augment": true,
  "architectures": [
    "Wav2Vec2ForCTC"
  ],
  "attention_dropout": 0.1,
  "bos_token_id": 1,
  "classifier_proj_size": 256,
  "codevector_dim": 256,
  "contrastive_logits_temperature": 0.1,
  "conv_bias": false,
  "conv_dim": [
    512,
    512,
    512,
    512,
    512,
    512,
    512
  ],
  "conv_kernel": [
    10,
    3,
    3,
    3,
    3,
    2,
    2
  ],
  "conv_stride": [
    5,
    2,
    2,
    2,
    2,
    2,
    2
  ],
  "ctc_loss_reduction": "mean",
  "ctc_zero_infinity": false,
  "diversity_loss_weight": 0.1,
  "do_stable_layer_norm": false,
  "eos_token_id": 2,
  "feat_extract_activation": "gelu",
  "feat_extract_norm": "group",
  "feat_proj_dropout": 0.1,
  "feat_quantizer_dropout": 0.0,
  "final_dropout": 0.0,
  "freeze_feat_extract_train": true,
  "hidden_act": "gelu",
  "

In [44]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [45]:
# Move the model to gpu if available

model.to(device)

Wav2Vec2ForCTC(
  (wav2vec2): Wav2Vec2Model(
    (feature_extractor): Wav2Vec2FeatureExtractor(
      (conv_layers): ModuleList(
        (0): Wav2Vec2GroupNormConvLayer(
          (conv): Conv1d(1, 512, kernel_size=(10,), stride=(5,), bias=False)
          (layer_norm): GroupNorm(512, 512, eps=1e-05, affine=True)
        )
        (1): Wav2Vec2NoLayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
        )
        (2): Wav2Vec2NoLayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
        )
        (3): Wav2Vec2NoLayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
        )
        (4): Wav2Vec2NoLayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
        )
        (5): Wav2Vec2NoLayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(2,), stride=(2,), bias=False)
        )
        (6): Wav2Vec

In [46]:
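def map_to_result(batch):
    """Run the model on one validation example and decode both prediction and reference."""
    with torch.no_grad():
        # place the input on the `device` selected above instead of hard-coding "cuda"
        input_values = torch.tensor(batch["input_values"], device=device).unsqueeze(0)
        logits = model(input_values).logits
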
    pred_ids = torch.argmax(logits, dim=-1)
    batch["pred_str"] = processor.batch_decode(pred_ids)[0]
    batch["normalized_text"] = processor.decode(batch["labels"], group_tokens=False)

    return batch

In [47]:
results = val_ds.map(map_to_result, remove_columns=val_ds.column_names)

print(results['normalized_text'][0])
print(results['pred_str'][0])


many factors were undoubtedly involved in oswald's motivation for the assassination and the commission does not believe



Let's compute the overall WER now.

In [48]:
print("Test WER: {:.3f}".format(wer_metric.compute(predictions=results["pred_str"], references=results["normalized_text"])))

Test WER: 1.000


Let's download a speech sample for running inference. You can also replace it with a speech sample of your own.

In [38]:
# Download a sample wav

!wget https://github.com/vasudevgupta7/gsoc-wav2vec2/raw/main/data/SA2.wav

--2021-11-15 15:39:29--  https://github.com/vasudevgupta7/gsoc-wav2vec2/raw/main/data/SA2.wav
Resolving github.com (github.com)... 13.236.229.21
Connecting to github.com (github.com)|13.236.229.21|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/vasudevgupta7/gsoc-wav2vec2/main/data/SA2.wav [following]
--2021-11-15 15:39:30--  https://raw.githubusercontent.com/vasudevgupta7/gsoc-wav2vec2/main/data/SA2.wav
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.109.133, 185.199.108.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 94252 (92K) [audio/wav]
Saving to: 'SA2.wav.1'

     0K .......... .......... .......... .......... .......... 54% 1.27M 0s
    50K .......... .......... .......... .......... ..        100% 13.3M=0.04s

2021-11-15 15:39:30 (2.17 MB/s) - 'SA2.wav.

In [49]:
def read_audio(file_path):
    """Read an audio file and return it as a NumPy array resampled to 16 kHz."""
    y, _ = librosa.load(file_path, sr=16000)
    return y

def predict_text(file_path):
    """Given an audio file path, return the ASR model's predicted text."""
    with torch.no_grad():
        # add a batch dimension and run on the `device` selected above
        input_values = torch.tensor(np.expand_dims(read_audio(file_path), axis=0), device=device)
        logits = model(input_values).logits
    pred_ids = torch.argmax(logits, dim=-1)
    print(pred_ids)  # raw token ids, printed for inspection
    text = processor.batch_decode(pred_ids)[0]
    return text

In [50]:
# Play Audio 

ipd.Audio(data=read_audio("SA2.wav"), autoplay=True, rate=16000)

In [51]:
# Print Predicted Speech

predict_text("SA2.wav")

tensor([[29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29,
         29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29,
         29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29,
         29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29,
         29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29,
         29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29,
         29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29,
         29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29,
         29, 29]], device='cuda:0')


''