# Automatic Speech Recognition with CTC Decoder

**Name**:

The first lab introduced the building blocks of an ASR system, including feature extraction and classification with an acoustic model (wav2vec2), which produced an *emission* matrix (probability for each character at each time step). From this emission matrix, we could compute the most likely character at each time step using a naïve *greedy* decoder. The drawback of such an approach is its lack of context: it can produce sequences of characters that do not correspond to actual words, or sequences of words that are unlikely or do not follow the rules of the language.

In this lab, we introduce a more advanced decoding technique based on [connectionist temporal classification](https://towardsdatascience.com/intuitively-understanding-connectionist-temporal-classification-3797e43a86c) (CTC). The general idea of such a decoder is to consider some context (sequences of characters, possible words, and possible sequences of words), in order to yield more likely and realistic outputs than those given by the greedy decoder.

<center><a href="https://gab41.lab41.org/speech-recognition-you-down-with-ctc-8d3b558943f0">
    <img src="https://miro.medium.com/v2/resize:fit:640/format:webp/1*XbIp4pn38rxL_DwzSeYVxQ.png" width="400"></a></center>

To do so, the CTC decoder relies on three main components:
    
- A **beam search**, which is an algorithm to efficiently find the *best path* from the emission matrix, that is, the sequence of characters with highest probability (rather than the sequence of individually most likely characters).
- A **lexicon**, which is a mapping between token sequences (lists of characters) and words. It is used to restrict the search space of the decoder to words that belong to this dictionary (e.g., the word "azfpojazflj" does not exist in the English vocabulary).
- A **language model**, which specifies which sequences of words are more likely to occur than others. A common choice of language model is an $n$-gram, which is a statistical model for the probability of occurrence of a given word based on the previous $n-1$ ones (for instance, the sequence "the sky is" is more likely to be followed by "blue" than by "trumpet").
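
The distinction made in the first point, between the most probable character sequence and the sequence of individually most likely characters, can be checked on a toy example. The sketch below is not a beam search: it enumerates every alignment by brute force (only feasible for tiny inputs), applies the CTC collapsing rule (merge repeated symbols, then remove blanks), and sums the probabilities of the alignments that yield the same string. The 2-step emission values and the helper names are made up for illustration.

```python
from itertools import product

BLANK = "-"

def collapse(path):
    """Apply the CTC collapsing rule: merge repeated symbols, then drop blanks."""
    merged = [s for i, s in enumerate(path) if i == 0 or s != path[i - 1]]
    return "".join(s for s in merged if s != BLANK)

def best_labelling(emission, labels):
    """Sum the probability of every alignment that collapses to the same string."""
    totals = {}
    for path in product(range(len(labels)), repeat=len(emission)):
        prob = 1.0
        for t, k in enumerate(path):
            prob *= emission[t][k]
        text = collapse([labels[k] for k in path])
        totals[text] = totals.get(text, 0.0) + prob
    return max(totals, key=totals.get), totals

def greedy(emission, labels):
    """Pick the most likely symbol at each time step, then collapse."""
    return collapse([labels[row.index(max(row))] for row in emission])

labels = [BLANK, "a"]
emission = [[0.6, 0.4], [0.6, 0.4]]  # toy values: 2 time steps, 2 symbols

best, totals = best_labelling(emission, labels)
print(repr(greedy(emission, labels)))  # '' (the blank wins at every step)
print(repr(best))                      # 'a'
```

Here the greedy output is the empty string (probability 0.36), whereas the string "a" gathers three alignments (`a-`, `-a`, `aa`) for a total probability of 0.64. A beam search aims at finding such high-probability sequences without enumerating all alignments.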

The CTC decoder combines these ingredients to compute the score of several word sequences (or *hypotheses*), in order to find the best possible transcript. In this lab, we study the influence of the lexicon, the language model, and the beam search size on ASR performance from a practical perspective, without going into the technical details of the [beam search algorithm](https://www.width.ai/post/what-is-beam-search) or the [CTC loss](https://distill.pub/2017/ctc/) (which can also be used for training the network).
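
As a rough illustration of how a language model can rank competing hypotheses, here is a minimal sketch of a bigram model ($n=2$) with add-one smoothing. The four-sentence corpus, the smoothing choice, and the function names are assumptions made for this example; the pretrained 4-gram model used later in the lab is trained on far more text and is queried by the decoder itself.

```python
import math
from collections import Counter

# Tiny made-up training corpus (assumption for illustration only)
corpus = [
    "the sky is blue",
    "the sea is blue",
    "the sky is clear",
    "the trumpet is loud",
]

contexts, bigrams, vocab = Counter(), Counter(), set()
for sentence in corpus:
    words = ["<s>"] + sentence.split()  # <s> marks the sentence start
    vocab.update(words)
    contexts.update(words[:-1])          # words that are followed by another word
    bigrams.update(zip(words, words[1:]))

def log_prob(sentence):
    """Log-probability of a sentence under the add-one smoothed bigram model."""
    words = ["<s>"] + sentence.split()
    return sum(
        math.log((bigrams[(prev, word)] + 1) / (contexts[prev] + len(vocab)))
        for prev, word in zip(words, words[1:])
    )

hypotheses = ["the sky is blue", "the sky is trumpet"]
best = max(hypotheses, key=log_prob)
print(best)  # the sky is blue
```

Both hypotheses share the same first three words, so the ranking is decided by the last bigram: "is blue" was seen twice in the corpus, "is trumpet" never.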

**Note**: This lab is based on this [tutorial](https://pytorch.org/audio/main/tutorials/asr_inference_with_ctc_decoder_tutorial.html), which you can check for more details on CTC decoder parameters in torchaudio.

In [18]:
import torch
import torchaudio
from torchaudio.models.decoder import ctc_decoder, download_pretrained_files
import IPython
import os
import fnmatch
import matplotlib.pyplot as plt
import time
torch.random.manual_seed(0)

MAX_FILES = 100 # lower this number for processing a subset of the dataset

In [19]:
# Main dataset path - Only change it HERE if needed and not later in the notebook
data_dir = "../dataset/"

## Preparation

As in the previous lab, we first load an example speech signal, and we display it. We also provide the function to get the true transcript and compute the WER. Finally, we load the wav2vec2 acoustic model.
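
As a reminder, the WER is the word-level edit distance between the true and estimated transcripts, divided by the number of words in the true transcript. The sketch below recomputes it in pure Python on two ten-word transcripts that differ by one word; `edit_distance` here is a hand-written stand-in for illustration, not the torchaudio function used in this notebook.

```python
def edit_distance(ref, hyp):
    """Word-level Levenshtein distance (substitutions, insertions, deletions)."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i]
        for j, h in enumerate(hyp, start=1):
            curr.append(min(
                prev[j] + 1,             # deletion of a reference word
                curr[j - 1] + 1,         # insertion of a hypothesis word
                prev[j - 1] + (r != h),  # substitution (cost 0 if words match)
            ))
        prev = curr
    return prev[-1]

ref = "give not so earnest a mind to these mummeries child".split()
hyp = "give not so earnest a mind to these memories child".split()

wer = edit_distance(ref, hyp) / len(ref)
print(wer)  # 0.1 (one substitution out of ten words)
```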

In [20]:
# Dataset path (audio and transcripts)
data_speech_dir = os.path.join(data_dir, 'speech')
data_transc_dir = os.path.join(data_dir, 'transcription')

In [21]:
# Example file
audio_file = '61-70968-0001.wav'
audio_file_path = os.path.join(data_speech_dir, audio_file)
print(f"Audio file path: {audio_file_path}")

waveform, sr = torchaudio.load(audio_file_path, channels_first=True)
IPython.display.Audio(data=waveform, rate=sr)

Audio file path: ../dataset/speech/61-70968-0001.wav


In [22]:
# We provide the function for loading the true transcript and computing the WER
def get_true_transcript(transc_file_path):
    with open(transc_file_path, "r") as f:
        true_transcript = f.read()
    true_transcript = true_transcript.lower().split()
    return true_transcript

def get_wer(true_transcript, est_transcript):
    wer = torchaudio.functional.edit_distance(true_transcript, est_transcript) / len(true_transcript)
    return wer
    

In [23]:
# Load and display the true transcription
transc_file_path = os.path.join(data_transc_dir, audio_file.replace('wav', 'txt'))
true_transcript = get_true_transcript(transc_file_path)
print(true_transcript)

['give', 'not', 'so', 'earnest', 'a', 'mind', 'to', 'these', 'mummeries', 'child']


In [24]:
# Load the acoustic model
model_name = 'WAV2VEC2_ASR_BASE_100H'
bundle = getattr(torchaudio.pipelines, model_name)
acoustic_model = bundle.get_model()
labels = bundle.get_labels()

# Apply the model to the waveform to get the emission tensor
with torch.inference_mode():
    emission, _ = acoustic_model(waveform)

## CTC Decoder

The CTC decoder can be constructed directly by using the `ctc_decoder` function in torchaudio. In addition to the parameters related to the beam search (we will study them later on), it takes as inputs:
- the list of tokens, in order to map emissions to characters in the classifier.
- the path to the lexicon, expected as a .txt file containing, on each line, a word followed by its space-split tokens (and the special end-of-word token `|`):

```
# lexicon.txt
a     a |
able  a b l e |
about a b o u t |
...
```
- the path to the language model, expected as a .bin file.

All these are assembled in pretrained files that can be downloaded using the `download_pretrained_files` function (this might take some time as the language model can be large), and then used to construct the decoder.
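
To make the lexicon format concrete, here is a minimal sketch that parses such entries into a dictionary mapping each word to its token list. The `parse_lexicon` helper is hypothetical (the decoder reads the file itself), and it assumes one spelling per word, which is a simplification.

```python
def parse_lexicon(text):
    """Map each word to its list of tokens (one lexicon entry per line)."""
    lexicon = {}
    for line in text.splitlines():
        if not line.strip():
            continue  # skip empty lines
        word, *tokens = line.split()  # first field is the word, the rest are tokens
        lexicon[word] = tokens
    return lexicon

sample = "a\ta |\nable\ta b l e |\nabout\ta b o u t |\n"
lexicon = parse_lexicon(sample)
print(lexicon["able"])  # ['a', 'b', 'l', 'e', '|']
```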

In [25]:
# Download the files corresponding to the model: here we use a 4-gram language model (which comes with lexicon and tokens)
files = download_pretrained_files("librispeech-4-gram")
print(files.tokens)
print(files.lexicon)
print(files.lm)

/Users/tunji/.cache/torch/hub/torchaudio/decoder-assets/librispeech-4-gram/tokens.txt
/Users/tunji/.cache/torch/hub/torchaudio/decoder-assets/librispeech-4-gram/lexicon.txt
/Users/tunji/.cache/torch/hub/torchaudio/decoder-assets/librispeech-4-gram/lm.bin


In [26]:
# Visualize the first 10 tokens (includes the blank and end-of-word token)
with open(files.tokens, 'r') as f:
    tok = f.read().splitlines()
print("\n".join(tok[:10]))

-
|
e
t
a
o
n
i
h
s


In [27]:
# Visualize the lexicon content (first 10 entries)
with open(files.lexicon, 'r') as f:
    lex = f.read().splitlines()
print("\n".join(lex[:10]))

a	a |
a''s	a ' ' s |
a'body	a ' b o d y |
a'court	a ' c o u r t |
a'd	a ' d |
a'gha	a ' g h a |
a'goin	a ' g o i n |
a'll	a ' l l |
a'm	a ' m |
a'mighty	a ' m i g h t y |


### Construct the decoder

In [28]:
print(files.lexicon)

/Users/tunji/.cache/torch/hub/torchaudio/decoder-assets/librispeech-4-gram/lexicon.txt


In [29]:
# Instantiate the CTC decoder
decoder = ctc_decoder(
    lexicon=files.lexicon,
    tokens=files.tokens,
    lm=files.lm,
)

We can now apply the constructed decoder to the `emission` tensor.

In [30]:
# Apply the decoder, and get the first element (batch_size=1) and best hypothesis
ctc_decoder_result = decoder(emission)[0][0]

### Getting the transcript

The decoder output `ctc_decoder_result` contains many fields, including the predicted token IDs.

In [31]:
# Get the token IDs using the .tokens field
ctc_decoder_indices = ctc_decoder_result.tokens
print(f"Token indices: {ctc_decoder_indices}")

# You can manually convert token IDs to tokens using the decoder.idxs_to_tokens method
# (+ a bit of postprocessing)
ctc_decoder_tokens = decoder.idxs_to_tokens(ctc_decoder_indices)
ctc_decoder_transcript = "".join(ctc_decoder_tokens).replace("|", " ").strip().split() 

print(f"Transcript: {ctc_decoder_transcript}")

Token indices: tensor([ 1, 18,  7, 22,  2,  1,  6,  5,  3,  1,  9,  5,  1,  2,  4, 10,  6,  2,
         9,  3,  1,  4,  1, 14,  7,  6, 11,  1,  3,  5,  1,  3,  8,  2,  9,  2,
         1, 14,  2, 14,  5, 10,  7,  2,  9,  1, 16,  8,  7, 12, 11,  1,  1])
Transcript: ['give', 'not', 'so', 'earnest', 'a', 'mind', 'to', 'these', 'memories', 'child']


Alternatively, you can obtain the transcript directly via the `.words` field, if a lexicon is provided; otherwise the `.words` field is an empty list, so the transcript needs to be built manually as done above.
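
The manual conversion can be checked by hand on a short excerpt. The sketch below uses only the first 10 entries of the token list displayed earlier and a contiguous excerpt of the token IDs printed above; `ids_to_words` is a hypothetical helper that mirrors the post-processing in the previous cell.

```python
# First 10 entries of tokens.txt, as displayed earlier in the notebook
tokens = ["-", "|", "e", "t", "a", "o", "n", "i", "h", "s"]

# Contiguous excerpt of the token IDs returned by the decoder
indices = [1, 3, 5, 1, 3, 8, 2, 9, 2, 1]

def ids_to_words(indices, tokens):
    """Join the tokens, then split on the end-of-word token '|'."""
    chars = "".join(tokens[i] for i in indices)  # "|to|these|"
    return chars.replace("|", " ").split()

print(ids_to_words(indices, tokens))  # ['to', 'these']
```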

In [32]:
print(f"Words: {ctc_decoder_result.words}")

Words: ['give', 'not', 'so', 'earnest', 'a', 'mind', 'to', 'these', 'memories', 'child']


## Influence of the lexicon

The lexicon is expected to have a strong influence on ASR performance, since it constrains the decoder to produce only words that belong to this lexicon, and therefore prevents it from producing words that do not exist in a language or in a given corpus.

<span style="color:red"> **Exercise 1**</span>. Create your own custom lexicon file `mylexicon.txt` that contains only words that belong to the example audio sentence ("give", "not", "so", etc.). Construct a decoder using this lexicon (as well as the language model) and perform ASR. Display the result and compare it with the true transcript. Does it seem relevant to use such a specific lexicon?
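
Lexicon entries can be typed by hand, but they can also be generated from a list of words. The sketch below is one possible way to do so, assuming that each token is a single character (as in the lexicon entries shown earlier) and that the word and its spelling are separated by whitespace; `lexicon_lines` is a hypothetical helper.

```python
def lexicon_lines(words):
    """Format each distinct word as '<word> <space-separated characters> |'."""
    return [f"{w} {' '.join(w)} |" for w in sorted(set(words))]

# Words of the true transcript of the example audio file
words = "give not so earnest a mind to these mummeries child".split()

lines = lexicon_lines(words)
print("\n".join(lines))
```

The resulting lines can then be written to a text file and passed to `ctc_decoder` through its `lexicon` argument.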

In [33]:
mylexicon_file = "mylexicon.txt"

if not os.path.exists(os.path.join(data_dir, mylexicon_file)):
    with open(os.path.join(data_dir, mylexicon_file), "w") as f:
        f.write(
            "give g i v e |\n"
            "not n o t |\n"
            "so s o |\n"
            "earnest e a r n e s t |\n"
            "a a |\n"
            "mind m i n d |\n"
            "to t o |\n"
            "these t h e s e |\n"
            "memories m e m o r i e s |\n"
            "child c h i l d |\n"
        )

mylexicon_file_path = os.path.join(data_dir, mylexicon_file)
print(f"Lexicon file path: {mylexicon_file_path}")

Lexicon file path: ../dataset/mylexicon.txt


In [34]:
# Instantiate the CTC decoder with the custom lexicon
custom_decoder = ctc_decoder(
    lexicon=mylexicon_file_path,
    tokens=files.tokens,
    lm=files.lm,
)

# perform ASR with the custom lexicon
ctc_custom_decoder_result = custom_decoder(emission)[0][0]
# Display the result 
ctc_custom_decoder_indices = ctc_custom_decoder_result.tokens
ctc_custom_decoder_tokens = custom_decoder.idxs_to_tokens(ctc_custom_decoder_indices)
ctc_custom_decoder_transcript = "".join(ctc_custom_decoder_tokens).replace("|", " ").strip().split()
print(f"Transcript: {ctc_custom_decoder_transcript}")

# Compare custom transcription with true transcription
wer = get_wer(true_transcript, ctc_custom_decoder_transcript)
print(f"WER: {wer}")

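# Comment: The WER is 0.1 (one wrong word out of ten): 'mummeries' is transcribed as 'memories', exactly as with the full lexicon.
# Note that the custom lexicon above lists 'memories' rather than 'mummeries', so the decoder cannot output the true word.
# More generally, a lexicon this specific only suits this one sentence, so it does not seem relevant for a general ASR system.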


Transcript: ['give', 'not', 'so', 'earnest', 'a', 'mind', 'to', 'these', 'memories', 'child']
WER: 0.1


## Influence of the language model

The language model is also expected to have a strong impact on performance, since it guides the decoder towards more likely word sequences.

<span style="color:red"> **Exercise 2**</span>. Construct a decoder with neither a language model nor a lexicon (pass `None` as input arguments). Perform ASR on the example audio and display the transcript. Compare it with the transcript obtained using the greedy decoder in the first lab.

In [35]:
no_lm_decoder = ctc_decoder(
    lexicon=None,
    tokens=files.tokens,
    lm=None,
)

# perform ASR on the example waveform without the language model
ctc_no_lm_decoder_result = no_lm_decoder(emission)[0][0]
# Display the result
ctc_no_lm_decoder_indices = ctc_no_lm_decoder_result.tokens
ctc_no_lm_decoder_tokens = no_lm_decoder.idxs_to_tokens(ctc_no_lm_decoder_indices)
ctc_no_lm_decoder_transcript = "".join(ctc_no_lm_decoder_tokens).replace("|", " ").strip().split()
print(f"Transcript: {ctc_no_lm_decoder_transcript}")

Transcript: ['give', 'not', 'so', 'earnest', 'a', 'mind', 'to', 'these', 'mumories', 'child']


In [36]:
class GreedyDecoder(torch.nn.Module):
    def __init__(self, labels, blank_token_indx=0):
        super().__init__()
        self.labels = labels
        self.blank_token_indx = blank_token_indx

    def forward(self, emission):
        """Given a sequence emission over labels, decode the transcript
        Args:
          emission (Tensor): Logit tensors. Shape `[1, num_seq, num_label]`.
        Returns:
          transcript (List of strings): The resulting transcript
        """
        emission = emission[0] # take the first element in the batch (only one)
        indices = torch.argmax(emission, dim=-1)  # take the most likely index at each time step
        indices = torch.unique_consecutive(indices, dim=-1) # remove duplicates
        indices = [i for i in indices if i != self.blank_token_indx] # remove blank token
        transcript = "".join([self.labels[i] for i in indices]) # convert indices into tokens
        transcript = transcript.replace("|", " ").lower().split() # a bit of post-processing
        return transcript

In [37]:
# Instantiate the decoder, apply it, and display the result
greedy_decoder = GreedyDecoder(labels=labels)
est_transcript = greedy_decoder(emission)
print(est_transcript)

['give', 'not', 'so', 'earnest', 'a', 'mind', 'to', 'these', 'mumories', 'child']


In [38]:
# Compare the greedy transcript with the transcript from the CTC decoder without lexicon or language model
print("ctc_no_lm_decoder_transcript", ctc_no_lm_decoder_transcript)
print("est_transcript", est_transcript)
wer = get_wer(est_transcript, ctc_no_lm_decoder_transcript)
print(f"WER: {wer:.4f}")

ctc_no_lm_decoder_transcript ['give', 'not', 'so', 'earnest', 'a', 'mind', 'to', 'these', 'mumories', 'child']
est_transcript ['give', 'not', 'so', 'earnest', 'a', 'mind', 'to', 'these', 'mumories', 'child']
WER: 0.0


<span style="color:red"> **Exercise 3**</span>. Re-construct a CTC decoder using the `librispeech-4-gram` files (lexicon, tokens, and language model downloaded earlier in this notebook). Perform ASR on the whole dataset folder (feel free to reuse/adapt a function from the previous lab) and compute the mean WER. Compare it with the WER when performing ASR with the greedy decoder.

In [39]:
# Re-construct the CTC decoder using librispeech-4-gram files
ctc_decoder_with_lm = ctc_decoder(
  lexicon=files.lexicon,
  tokens=files.tokens,
  lm=files.lm,
)

# Function to perform ASR on the whole dataset folder and compute mean WER
def compute_mean_wer(data_speech_dir, data_transc_dir, decoder, acoustic_model, max_files=MAX_FILES, is_greedy=False):
  total_wer = 0
  num_files = 0

  for root, _, filenames in os.walk(data_speech_dir):
    for filename in fnmatch.filter(filenames, '*.wav'):
      if num_files >= max_files:
        break
      audio_file_path = os.path.join(root, filename)
      transc_file_path = os.path.join(data_transc_dir, filename.replace('wav', 'txt'))

      # Load waveform
      waveform, sr = torchaudio.load(audio_file_path, channels_first=True)

      # Get true transcript
      true_transcript = get_true_transcript(transc_file_path)

      # Apply the acoustic model to get the emission tensor
      with torch.inference_mode():
        emission, _ = acoustic_model(waveform)

      # Apply the decoder

      if is_greedy:
        est_transcript = decoder(emission)
      else:
        ctc_decoder_result = decoder(emission)[0][0]
        ctc_decoder_indices = ctc_decoder_result.tokens
        ctc_decoder_tokens = decoder.idxs_to_tokens(ctc_decoder_indices)
        est_transcript = "".join(ctc_decoder_tokens).replace("|", " ").strip().split()
        
      # Compute WER
      wer = get_wer(true_transcript, est_transcript)
      total_wer += wer
      num_files += 1
      
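  # Guard against an empty or mistyped dataset folder (avoids a division by zero)
  if num_files == 0:
    raise ValueError(f"No .wav files found in {data_speech_dir}")
  mean_wer = total_wer / num_files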
  return mean_wer

# Compute mean WER for CTC decoder with language model
mean_wer_ctc_lm = compute_mean_wer(data_speech_dir, data_transc_dir, ctc_decoder_with_lm, acoustic_model)
print(f"Mean WER with CTC decoder and language model: {mean_wer_ctc_lm}")

# Compute mean WER for greedy decoder
mean_wer_greedy = compute_mean_wer(data_speech_dir, data_transc_dir, greedy_decoder, acoustic_model, is_greedy=True)
print(f"Mean WER with greedy decoder: {mean_wer_greedy}")

Mean WER with CTC decoder and language model: 0.07857539682539683
Mean WER with greedy decoder: 0.09489163614163613


## Beam search parameters

The beam search algorithm used in the CTC decoder depends on other parameters, such as `nbest`, which determines the number of hypotheses to return, or `lm_weight`, which adjusts the relative importance of the language model vs. the acoustic model predictions. Here we focus only on `beam_size`, which determines the maximum number of best hypotheses to hold after each decoding step. A larger beam size lets the decoder explore a wider range of hypotheses, and can therefore find hypotheses with higher scores, at the cost of more computation. This trade-off is at [the core](https://distill.pub/2017/ctc/) of the beam search algorithm.
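
The effect of the beam size can be seen on a toy problem. The sketch below is a generic beam search over a made-up first-order model with two symbols; it ignores blanks, the lexicon, and the language model, so it is only an illustration of the pruning step, not the algorithm implemented in torchaudio. All probabilities and names are assumptions.

```python
import math

# Toy first-order model (assumed values): P(first symbol) and P(next | previous)
START = {"a": 0.6, "b": 0.4}
NEXT = {"a": {"a": 0.55, "b": 0.45}, "b": {"a": 0.1, "b": 0.9}}

def beam_search(num_steps, beam_size):
    """Keep the `beam_size` highest-scoring partial sequences after each step."""
    beams = [((), 0.0)]  # list of (sequence, log-probability)
    for _ in range(num_steps):
        candidates = []
        for seq, score in beams:
            probs = NEXT[seq[-1]] if seq else START
            for symbol, p in probs.items():
                candidates.append((seq + (symbol,), score + math.log(p)))
        # Prune: only the best `beam_size` candidates survive to the next step
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = candidates[:beam_size]
    return beams[0]

for size in (1, 2):
    seq, score = beam_search(num_steps=2, beam_size=size)
    print(size, "".join(seq), round(math.exp(score), 2))
# 1 aa 0.33
# 2 bb 0.36
```

With a beam size of 1, the search commits to "a" after the first step and ends with "aa" (probability 0.33). With a beam size of 2, the less likely first symbol "b" is kept alive and leads to the better sequence "bb" (probability 0.36).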

<span style="color:red"> **Exercise 4**</span>. Perform ASR on the whole dataset folder for several values of the beam search size parameter: `beam_size` $\in \{1, 10, 100\}$. Compute the WER and the computation time (e.g., via the [time](https://docs.python.org/3/library/time.html#time.time) module). Comment.

In [40]:
beam_sizes = [1, 10, 100]

results = {}

for beam_size in beam_sizes:
  start_time = time.time()

  # Construct the CTC decoder with the specified beam size
  decoder_with_beam_size = ctc_decoder(
    lexicon=files.lexicon,
    tokens=files.tokens,
    lm=files.lm,
    beam_size=beam_size
  )
  
  # Compute mean WER for the current beam size
  mean_wer = compute_mean_wer(data_speech_dir, data_transc_dir, decoder_with_beam_size, acoustic_model)
  
  end_time = time.time()
  elapsed_time = end_time - start_time
  
  results[beam_size] = {
    "mean_wer": mean_wer,
    "time": elapsed_time
  }
  
  print(f"Beam size: {beam_size}, Mean WER: {mean_wer}, Time: {elapsed_time} seconds")

# Display the results
print(results)

Beam size: 1, Mean WER: 0.15342994359441728, Time: 28.33608102798462 seconds
Beam size: 10, Mean WER: 0.07857539682539681, Time: 28.807101011276245 seconds
Beam size: 100, Mean WER: 0.08024206349206349, Time: 33.238343954086304 seconds
{1: {'mean_wer': 0.15342994359441728, 'time': 28.33608102798462}, 10: {'mean_wer': 0.07857539682539681, 'time': 28.807101011276245}, 100: {'mean_wer': 0.08024206349206349, 'time': 33.238343954086304}}


In [41]:
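# Comment: A beam size of 1 gives the worst WER (about 0.153). A beam size of 10 roughly halves it (about 0.079)
# for almost the same run time (28.8 s vs. 28.3 s). A beam size of 100 does not improve the WER further (about 0.080)
# and is slower (33.2 s). On this dataset, a beam size of 10 therefore offers the best trade-off between WER and computation time.
# Note that the measured time also includes the decoder construction and the acoustic model forward pass, not only the decoding.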