# Automatic Speech Recognition with CTC Decoder

**Name**:

The first lab introduced the building blocks of an ASR system, including feature extraction and classification with an acoustic model (wav2vec2), which produced an *emission* matrix (a probability for each character at each time step). From this emission matrix, we could compute the most likely character at each time step using a naïve *greedy* decoder. The drawback of such an approach is its lack of context: it can produce sequences of characters that do not correspond to actual words, and/or sequences of words that are incorrect or violate the rules of the language.
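For reference, a typical greedy CTC decoder can be sketched in a few lines of plain Python. This is a hedged illustration on a made-up emission matrix: the token list, the scores, and the `greedy_decode` helper are assumptions, not the first lab's code. It picks the most likely token at each time step, merges consecutive repeats, removes the CTC blank token (here `-`), and turns the word-boundary token `|` into spaces. Note that nothing prevents it from outputting "ab", which is not an English word.

```python
# Made-up token list: index 0 is the CTC blank "-", "|" marks word boundaries
tokens = ["-", "|", "a", "b", "c"]

# emission[t][k] = score of token k at time step t (rows: time, columns: tokens)
emission = [
    [0.1, 0.1, 0.7, 0.05, 0.05],   # "a"
    [0.1, 0.1, 0.6, 0.1, 0.1],     # "a" again (repeat, will be merged)
    [0.8, 0.05, 0.05, 0.05, 0.05], # blank
    [0.1, 0.1, 0.1, 0.6, 0.1],     # "b"
    [0.1, 0.7, 0.1, 0.05, 0.05],   # "|"
    [0.1, 0.1, 0.1, 0.1, 0.6],     # "c"
]

def greedy_decode(emission, tokens, blank=0):
    """Pick the most likely token per frame, merge repeats, drop blanks."""
    best = [max(range(len(row)), key=row.__getitem__) for row in emission]
    out = []
    prev = None
    for idx in best:
        # Keep a token only if it differs from the previous frame and is not blank
        if idx != prev and idx != blank:
            out.append(tokens[idx])
        prev = idx
    # Turn word-boundary tokens into spaces
    return "".join(out).replace("|", " ").strip()

print(greedy_decode(emission, tokens))  # ab c
```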

In this lab, we introduce a more advanced decoding technique based on [connectionist temporal classification](https://towardsdatascience.com/intuitively-understanding-connectionist-temporal-classification-3797e43a86c) (CTC). The general idea of such a decoder is to take some context into account (sequences of characters, possible words, and possible sequences of words), in order to yield more likely and realistic outputs than those given by the greedy decoder.
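The key idea behind CTC can be illustrated by brute force on a tiny example. In CTC, many frame-level paths (alignments) map to the same transcript once consecutive repeats are merged and blank tokens `-` are removed, and the probability of a transcript is the sum of the probabilities of all its alignments. The sketch below (made-up values, hypothetical helper names) enumerates every path over 2 frames: the single most likely path, which is what a greedy decoder picks, collapses to the empty string, while the transcript "a" is actually more probable overall. Exhaustive enumeration grows exponentially with the number of frames, which is why practical decoders rely on beam search instead.

```python
import itertools
from collections import defaultdict

BLANK = "-"

def collapse(path):
    """CTC collapsing rule: merge consecutive repeats, then remove blanks."""
    merged = [tok for i, tok in enumerate(path) if i == 0 or tok != path[i - 1]]
    return "".join(tok for tok in merged if tok != BLANK)

tokens = [BLANK, "a"]
# Made-up per-frame probabilities for 2 time steps
probs = [
    {BLANK: 0.6, "a": 0.4},
    {BLANK: 0.6, "a": 0.4},
]

# Add the probability of every alignment to the transcript it collapses to
label_prob = defaultdict(float)
for path in itertools.product(tokens, repeat=len(probs)):
    prob = 1.0
    for t, tok in enumerate(path):
        prob *= probs[t][tok]
    label_prob[collapse(path)] += prob

greedy_path = [max(frame, key=frame.get) for frame in probs]
print(greedy_path, "->", repr(collapse(greedy_path)))
print({label: round(prob, 2) for label, prob in label_prob.items()})
# ['-', '-'] -> ''
# {'': 0.36, 'a': 0.64}
```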

<center><a href="https://gab41.lab41.org/speech-recognition-you-down-with-ctc-8d3b558943f0">
    <img src="https://miro.medium.com/v2/resize:fit:640/format:webp/1*XbIp4pn38rxL_DwzSeYVxQ.png" width="400"></a></center>

To do so, the CTC decoder relies on three main components:
    
- A **beam search**, which is an algorithm that efficiently searches the emission matrix for the *best path*, that is, the sequence of characters with the highest overall probability (rather than the sequence of individually most likely characters). Since it only keeps a limited number of candidates at each step, it is an approximate search.
- A **lexicon**, which is a mapping between words and their token sequences (lists of characters). It is used to restrict the decoder's search space to words that belong to this dictionary (e.g., the string "azfpojazflj" does not exist in the English vocabulary).
- A **language model**, which specifies which sequences of words are more likely to occur than others. A common choice of language model is an $n$-gram model, which estimates the probability of a given word based on the $n-1$ preceding words (for instance, the sequence "the sky is" is more likely to be followed by "blue" than by "trumpet").
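To make the $n$-gram idea concrete, here is a minimal bigram ($n = 2$) model estimated by counting on a made-up four-sentence corpus: each word is conditioned on the single preceding word, and $P(w \mid w_{\text{prev}})$ is the count of the pair divided by the count of the context word. This is only a sketch; real $n$-gram models, such as the 4-gram used later in this lab, are trained on large corpora and apply smoothing so that unseen word sequences do not get zero probability.

```python
from collections import Counter

# Made-up training corpus
corpus = [
    "the sky is blue",
    "the sea is blue",
    "the sky is grey",
    "the trumpet is loud",
]

# Count bigrams (pairs of consecutive words) and their left contexts
bigram_counts = Counter()
context_counts = Counter()
for sentence in corpus:
    words = sentence.split()
    for prev, word in zip(words, words[1:]):
        bigram_counts[(prev, word)] += 1
        context_counts[prev] += 1

def bigram_prob(prev, word):
    """Maximum-likelihood estimate of P(word | prev), without smoothing."""
    if context_counts[prev] == 0:
        return 0.0
    return bigram_counts[(prev, word)] / context_counts[prev]

print(bigram_prob("is", "blue"))     # 0.5
print(bigram_prob("is", "trumpet"))  # 0.0
```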

The CTC decoder combines these ingredients to compute the score of several word sequences (or *hypotheses*), in order to find the best possible transcript. In this lab, we study the influence of the lexicon, the language model, and the beam size on ASR performance from a practical perspective, without going into the technical details of the [beam search algorithm](https://www.width.ai/post/what-is-beam-search) or the [CTC loss](https://distill.pub/2017/ctc/) (which can also be used for training the network).
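As a rough illustration of how these ingredients interact, a hypothesis score can be thought of as the acoustic log-probability plus a weighted language model log-probability (plus, optionally, a per-word bonus). The formula, the numbers, and the `combined_score` helper below are simplifying assumptions made for illustration, not the exact scoring used by torchaudio. Here the acoustic model slightly prefers the homophone "blew", and only the language model term tips the decision towards "blue".

```python
import math

# Made-up log-probabilities for two candidate transcripts: the acoustic model
# slightly prefers "blew", the language model strongly prefers "blue".
hypotheses = {
    "the sky is blue": {"acoustic": math.log(0.45), "lm": math.log(0.20)},
    "the sky is blew": {"acoustic": math.log(0.55), "lm": math.log(0.001)},
}

def combined_score(text, lm_weight, word_score=0.0):
    """Simplified score (an assumption): acoustic + lm_weight * LM + word_score * #words."""
    h = hypotheses[text]
    return h["acoustic"] + lm_weight * h["lm"] + word_score * len(text.split())

for lm_weight in (0.0, 1.0):
    best = max(hypotheses, key=lambda text: combined_score(text, lm_weight))
    print(f"lm_weight={lm_weight}: {best}")
# lm_weight=0.0: the sky is blew
# lm_weight=1.0: the sky is blue
```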

**Note**: This lab is based on this [tutorial](https://pytorch.org/audio/main/tutorials/asr_inference_with_ctc_decoder_tutorial.html), which you can check for more details on CTC decoder parameters in torchaudio.

In [None]:
import torch
import torchaudio
from torchaudio.models.decoder import ctc_decoder, download_pretrained_files
import IPython
import os
import fnmatch
import matplotlib.pyplot as plt
import time
torch.random.manual_seed(0)

MAX_FILES = 100 # lower this number for processing a subset of the dataset

In [None]:
# Main dataset path - Only change it HERE if needed and not later in the notebook
data_dir = "asr-dataset/"

## Preparation

As in the previous lab, we first load an example speech signal and display it. We also provide functions to load the true transcript and to compute the word error rate (WER). Finally, we load the wav2vec2 acoustic model.
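The WER is the word-level edit distance between the reference and estimated transcripts (the minimum number of substitutions, deletions, and insertions), divided by the number of words in the reference; this is what `get_wer` below computes with `torchaudio.functional.edit_distance`. For intuition, here is a plain-Python sketch of that edit distance (a standard dynamic-programming formulation, independent of torchaudio) on made-up sentences. Note that the WER can exceed 1 when the estimate contains many insertions.

```python
def edit_distance(ref, hyp):
    """Levenshtein distance between two sequences (here: lists of words)."""
    # At the start of iteration i, prev[j] is the distance between ref[:i-1] and hyp[:j]
    prev = list(range(len(hyp) + 1))
    for i in range(1, len(ref) + 1):
        curr = [i] + [0] * len(hyp)
        for j in range(1, len(hyp) + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            curr[j] = min(
                prev[j] + 1,         # deletion of ref[i-1]
                curr[j - 1] + 1,     # insertion of hyp[j-1]
                prev[j - 1] + cost,  # substitution (or match if cost == 0)
            )
        prev = curr
    return prev[-1]

ref = "the sky is blue".split()         # made-up reference transcript
hyp = "the sky was blue today".split()  # made-up estimate: 1 substitution, 1 insertion
wer = edit_distance(ref, hyp) / len(ref)
print(wer)  # 0.5
```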

In [None]:
# Dataset path (audio and transcripts)
data_speech_dir = os.path.join(data_dir, 'speech')
data_transc_dir = os.path.join(data_dir, 'transcription')

In [None]:
# Example file
audio_file = '61-70968-0001.wav'
audio_file_path = os.path.join(data_speech_dir, audio_file)
print(f"Audio file path: {audio_file_path}")

waveform, sr = torchaudio.load(audio_file_path, channels_first=True)
IPython.display.Audio(data=waveform, rate=sr)

In [None]:
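# We provide the functions for loading the true transcript and computing the WER
def get_true_transcript(transc_file_path):
    # Read the reference text, lowercase it, and split it into a list of words
    with open(transc_file_path, "r") as f:
        true_transcript = f.read()
    true_transcript = true_transcript.lower().split()
    return true_transcript

def get_wer(true_transcript, est_transcript):
    # WER = word-level edit distance (substitutions + deletions + insertions)
    # normalized by the number of words in the reference transcript
    wer = torchaudio.functional.edit_distance(true_transcript, est_transcript) / len(true_transcript)
    return wer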
    

In [None]:
# Load and display the true transcription
transc_file_path = os.path.join(data_transc_dir, os.path.splitext(audio_file)[0] + '.txt')
true_transcript = get_true_transcript(transc_file_path)
print(true_transcript)

In [None]:
# Load the acoustic model
model_name = 'WAV2VEC2_ASR_BASE_100H'
bundle = getattr(torchaudio.pipelines, model_name)
acoustic_model = bundle.get_model()
labels = bundle.get_labels()

# Apply the model to the waveform to get the emission tensor
with torch.inference_mode():
    emission, _ = acoustic_model(waveform)

## CTC Decoder

The CTC decoder can be constructed directly by using the `ctc_decoder` function in torchaudio. In addition to the parameters related to the beam search (we will study them later on), it takes as inputs:
- the list of tokens, which maps each column of the emission matrix (i.e., each output class of the acoustic model) to a character.
- the path to the lexicon, expected as a .txt file containing, on each line, a word followed by its space-separated tokens (and the special end-of-word token `|`):

```
# lexicon.txt
a     a |
able  a b l e |
about a b o u t |
...
```
- the path to the language model, expected as a binary .bin file (in the KenLM format).

All three are provided as pretrained files that can be downloaded using the `download_pretrained_files` function (this might take some time, as the language model can be large), and then used to construct the decoder.
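If you need to write your own lexicon file, the format described above can be generated and read back with a few lines of plain Python. The helpers `word_to_lexicon_line` and `read_lexicon` are hypothetical, and we assume whitespace-separated fields as in the `lexicon.txt` excerpt above; check the downloaded lexicon for its exact layout before relying on this sketch.

```python
import os
import tempfile

def word_to_lexicon_line(word):
    """Format one lexicon entry: the word, then its characters, then the '|' token."""
    return f"{word} {' '.join(word)} |"

def read_lexicon(path):
    """Parse a lexicon file into a dict mapping each word to its token list."""
    lexicon = {}
    with open(path, "r") as f:
        for line in f:
            parts = line.split()
            if parts:
                lexicon[parts[0]] = parts[1:]
    return lexicon

# Write a tiny lexicon (hypothetical word list) to a temporary file and read it back
words = ["a", "able", "about"]
with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "lexicon.txt")
    with open(path, "w") as f:
        f.write("\n".join(word_to_lexicon_line(w) for w in words) + "\n")
    lexicon = read_lexicon(path)

print(word_to_lexicon_line("able"))  # able a b l e |
print(lexicon["able"])               # ['a', 'b', 'l', 'e', '|']
```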

In [None]:
# Download the files corresponding to the model: here we use a 4-gram language model (which comes with lexicon and tokens)
files = download_pretrained_files("librispeech-4-gram")
print(files.tokens)
print(files.lexicon)
print(files.lm)

In [None]:
# Visualize the first 10 tokens (including the blank token and the end-of-word token)
with open(files.tokens, 'r') as f:
    tok = f.read().splitlines()
print("\n".join(tok[:10]))

In [None]:
# Visualize the lexicon content (first 10 entries)
with open(files.lexicon, 'r') as f:
    lex = f.read().splitlines()
print("\n".join(lex[:10]))

### Construct the decoder

In [None]:
# Instantiate the CTC decoder
decoder = ctc_decoder(
    lexicon=files.lexicon,
    tokens=files.tokens,
    lm=files.lm,
)

We can now apply the constructed decoder to the `emission` tensor.

In [None]:
# Apply the decoder, and get the first element (batch_size=1) and best hypothesis
ctc_decoder_result = decoder(emission)[0][0]

### Getting the transcript

The decoder output `ctc_decoder_result` is a hypothesis object with several fields, such as the predicted token IDs (`tokens`), the predicted words (`words`), and the hypothesis score (`score`).

In [None]:
# Get the token IDs using the .tokens field
ctc_decoder_indices = ctc_decoder_result.tokens
print(f"Token indices: {ctc_decoder_indices}")

# You can manually convert token IDs to tokens using the decoder.idxs_to_tokens method
# (+ a bit of postprocessing)
ctc_decoder_tokens = decoder.idxs_to_tokens(ctc_decoder_indices)
ctc_decoder_transcript = "".join(ctc_decoder_tokens).replace("|", " ").strip().split() 

print(f"Transcript: {ctc_decoder_transcript}")

Alternatively, you can obtain the transcript directly via the `.words` field, if a lexicon is provided; otherwise the `.words` field is an empty list, so the transcript needs to be built manually as done above.

In [None]:
print(f"Words: {ctc_decoder_result.words}")

## Influence of the lexicon

The lexicon is expected to have a strong influence on ASR performance, since it constrains the decoder to produce only words that belong to this lexicon, thereby preventing it from producing words that do not exist in a given language or corpus.
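To see how a lexicon prunes the search, note that a decoder can discard a hypothesis as soon as its current (unfinished) word can no longer be completed into a lexicon entry. Lexicon-based decoders typically implement this check with a prefix tree (trie); the sketch below approximates it with a plain set of prefixes, and the helper names and three-word lexicon are illustrative assumptions.

```python
def build_prefixes(words):
    """Collect every prefix of every lexicon word (a flat stand-in for a trie)."""
    prefixes = set()
    for word in words:
        for i in range(len(word) + 1):
            prefixes.add(word[:i])
    return prefixes

def is_valid_partial(text, lexicon_words, prefixes):
    """True if all completed words are in the lexicon and the unfinished
    last word is still a prefix of some lexicon word."""
    *complete, last = text.split(" ")
    return all(w in lexicon_words for w in complete) and last in prefixes

# Hypothetical lexicon restricted to three words
lexicon_words = {"give", "not", "so"}
prefixes = build_prefixes(lexicon_words)

for partial in ["give no", "give nt", "gove not"]:
    print(partial, "->", is_valid_partial(partial, lexicon_words, prefixes))
# give no -> True
# give nt -> False
# gove not -> False
```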

<span style="color:red"> **Exercise 1**</span>. Create your own custom lexicon file `mylexicon.txt` that contains only words that belong to the example audio sentence ("give", "not", "so", etc.). Construct a decoder using this lexicon (as well as the language model) and perform ASR. Display the result and compare it with the true transcript. Does it seem relevant to use such a specific lexicon?

## Influence of the language model

The language model is also expected to have a strong impact on performance, since it guides the decoder towards more likely word sequences.

<span style="color:red"> **Exercise 2**</span>. Construct a decoder with neither a language model nor a lexicon (pass `None` as the corresponding input arguments). Perform ASR on the example audio and display the transcript. Compare it with the transcript obtained using the greedy decoder in the first lab.

<span style="color:red"> **Exercise 3**</span>. Re-construct a CTC decoder using the `librispeech-4-gram` files (lexicon, tokens, and language model) downloaded above in this notebook. Perform ASR on the whole dataset folder (feel free to reuse/adapt a function from the previous lab) and compute the mean WER. Compare it with the WER obtained with the greedy decoder.

## Beam search parameters

The beam search algorithm used in the CTC decoder depends on other parameters, such as `nbest`, which determines the number of hypotheses to return, or `lm_weight`, which adjusts the relative importance of the language model vs. the acoustic model predictions. Here we only focus on `beam_size`, which determines the maximum number of best hypotheses to keep after each decoding step. A larger beam size lets the decoder explore a wider range of hypotheses, and hence potentially find ones with higher scores; this is [the core](https://distill.pub/2017/ctc/) of the beam search algorithm.
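The effect of `beam_size` can be seen on a toy example with a textbook CTC prefix beam search (no lexicon, no language model, plain probabilities instead of log-probabilities). This is a simplified sketch, not the implementation behind torchaudio's `ctc_decoder`, and the emission values are made up. Each prefix keeps two probabilities, ending in a blank (`p_b`) and ending in a non-blank token (`p_nb`), so that alignments collapsing to the same prefix are summed. With `beam_size=1`, the prefix "a" is pruned after the first frame and the decoder returns the empty transcript; with `beam_size=2`, it survives and wins with probability 0.64.

```python
from collections import defaultdict

def prefix_beam_search(probs, beam_size, blank=0):
    """Toy CTC prefix beam search (no LM, no lexicon) on per-frame probabilities.

    Each prefix (tuple of token indices) stores [p_b, p_nb]: the total probability
    of all alignments of that prefix ending in a blank / in a non-blank token.
    """
    beams = {(): [1.0, 0.0]}
    for frame in probs:
        next_beams = defaultdict(lambda: [0.0, 0.0])
        for prefix, (p_b, p_nb) in beams.items():
            for k, p in enumerate(frame):
                if k == blank:
                    # A blank leaves the prefix unchanged
                    next_beams[prefix][0] += (p_b + p_nb) * p
                elif prefix and k == prefix[-1]:
                    # Repeated token: it extends the prefix only after a blank...
                    next_beams[prefix + (k,)][1] += p_b * p
                    # ...otherwise it is merged with the previous occurrence
                    next_beams[prefix][1] += p_nb * p
                else:
                    next_beams[prefix + (k,)][1] += (p_b + p_nb) * p
        # Keep only the beam_size most probable prefixes
        ranked = sorted(next_beams.items(), key=lambda kv: sum(kv[1]), reverse=True)
        beams = dict(ranked[:beam_size])
    best_prefix, (p_b, p_nb) = max(beams.items(), key=lambda kv: sum(kv[1]))
    return best_prefix, p_b + p_nb

tokens = ["-", "a"]               # index 0 is the CTC blank
probs = [[0.6, 0.4], [0.6, 0.4]]  # made-up per-frame probabilities
for beam_size in (1, 2):
    prefix, p = prefix_beam_search(probs, beam_size)
    print(beam_size, repr("".join(tokens[k] for k in prefix)), round(p, 2))
# 1 '' 0.36
# 2 'a' 0.64
```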

<span style="color:red"> **Exercise 4**</span>. Perform ASR on the whole dataset folder for several values of the beam size parameter: `beam_size` $\in \{1, 10, 100\}$. Compute the WER and the computation time (e.g., via the [time](https://docs.python.org/3/library/time.html#time.time) package). Comment.