# Context-Biasing for ASR models with CTC-based Word Spotter

This tutorial shows how to improve the recognition accuracy of specific words in the NeMo framework
for CTC and Transducer ASR models by using a fast context-biasing method with a CTC-based Word Spotter.

## Tutorial content:
* Intro to the context-biasing problem
* Description of the CTC-based Word Spotter (CTC-WS) approach
* Practical part 1 (base):
    * Download data set and ASR models
    * Build context-biasing list
    * Evaluate recognition results with and without context-biasing
* Practical part 2 (advanced):
    * Visualization of the context-biasing graph
    * Running word spotter
    * Merge greedy decoding results with word spotter hypotheses
    * Results analysis
* Summary

## Context-biasing: intro

ASR models often struggle to recognize words that are absent from the training data or appear in it only a few times.
This problem is especially acute because new names and titles emerge constantly in a rapidly developing world.
Users nevertheless need these new words to be recognized correctly.
Context-biasing methods attempt to solve this problem by assuming that we have a list of words and phrases (context-biasing list) in advance
for which we want to improve recognition accuracy.

One of the directions of context-biasing methods is a `deep fusion` approach.
These methods require integration into the ASR model and its training process.
The main disadvantage of these methods is that they require a lot of computational resources and time to train the model.

Another direction is methods based on the `shallow fusion` approach. In this case, only the decoding process is modified.
During beam-search decoding, a hypothesis is rescored depending on whether the current word is present in the context-biasing list or is supported by an external language model.
Beam-search decoding may be computationally expensive, especially for models with a large vocabulary and a large context-biasing list.
This problem is considerably worse for the Transducer (RNNT) model, since beam-search decoding involves multiple Decoder (Prediction) and Joint network calculations.
Moreover, context-biasing recognition is limited by the initial predictions of the model. In the case of rare or new words,
the model may not produce any hypothesis containing the desired word from the context-biasing list, so there is nothing whose probability could be boosted.
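To make the shallow-fusion idea and its limitation concrete, here is a minimal toy sketch. The function `rescore`, the `bonus` value, and the example beam are illustrative assumptions, not part of NeMo or of any specific shallow-fusion implementation.

```python
# Toy shallow-fusion rescoring (hypothetical names and scores, not a NeMo API).
def rescore(hypotheses, context_words, bonus=2.0):
    """Add a fixed bonus to a hypothesis score for every context-biasing word it contains."""
    rescored = []
    for text, score in hypotheses:
        n_hits = sum(1 for word in text.split() if word in context_words)
        rescored.append((text, score + bonus * n_hits))
    # best hypothesis first
    return sorted(rescored, key=lambda item: item[1], reverse=True)


# made-up beam of (text, log-score) pairs
beam = [
    ("the new g p u is fast", -4.0),
    ("the new gpu is fast", -5.0),
    ("the new jeep is fast", -4.5),
]
best_text, best_score = rescore(beam, {"gpu"})[0]
print(best_text, best_score)  # the new gpu is fast -3.0
```

Note that the bonus can only promote a hypothesis that is already in the beam: if no hypothesis contains `gpu`, rescoring changes nothing, which is the limitation described above.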

## CTC-based Word Spotter

<img width="500px" height="auto"
     src="cws.png"
     alt="CTC-WS"
     style="float: right; margin-left: 20px;">

In this tutorial we consider a fast context-biasing method using a CTC-based Word Spotter (CTC-WS) -- paper link.
The method involves decoding CTC log probabilities with a context graph built for words and phrases from the context-biasing list.
The spotted context-biasing candidates (with their scores and time intervals) are compared by score with the words from the greedy
CTC decoding results to improve recognition accuracy and prevent false accepts of context-biasing words (Figure 1).

A Hybrid Transducer-CTC model [link] (a shared encoder trained together with CTC and Transducer output heads) 
allows the use of the CTC-WS method for the Transducer model.
Context-biasing candidates obtained by CTC-WS are also filtered by the scores with greedy CTC predictions and then merged with greedy Transducer results.

The CTC-WS method makes it possible to use pretrained NeMo models (`CTC` or `Hybrid Transducer-CTC`) for context-biasing recognition without any model retraining.
The method shows promising context-biasing results while requiring only a little additional decoding time and computational resources.
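Since the spotted candidates are compared against the greedy CTC decoding results, it may help to recall what greedy CTC decoding does: take the best token per frame, collapse repeated tokens, and drop blanks. The sketch below is a plain-Python illustration; the function name and the toy frame sequence are assumptions made for this example.

```python
# Minimal greedy CTC decoding sketch (illustrative, not the NeMo implementation).
def greedy_ctc_decode(frame_token_ids, blank_id):
    """Collapse repeated per-frame tokens and remove blank tokens."""
    decoded = []
    prev = None
    for token in frame_token_ids:
        # emit a token only when it differs from the previous frame and is not blank
        if token != prev and token != blank_id:
            decoded.append(token)
        prev = token
    return decoded


# per-frame argmax token ids; 0 plays the role of the blank token here
frames = [0, 3, 3, 0, 0, 5, 5, 5, 0, 5, 0]
tokens = greedy_ctc_decode(frames, blank_id=0)
print(tokens)  # [3, 5, 5]
```

The blank between the two runs of `5` is what allows the same token to be emitted twice in a row.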

## Practical part 1 (base)
In this part, we will consider the base usage of the CTC-WS method for the pretrained NeMo model.

### Data preparation.
We will use a subset (10 files) of the GTC dataset. The dataset contains audio files with NVIDIA GTC talks. 
The main feature of the dataset is its computer science and engineering domain,
which has a large number of unique terms and product names (NVIDIA, GPU, GeForce, Ray Tracing, Omniverse, teraflops, etc.). This makes it ideal for the context-biasing task. All the text data is normalized and lowercased.

In [None]:
from nemo.collections.asr.parts.utils.manifest_utils import read_manifest

# data is already stored in nemo data manifest format
test_nemo_manifest = "/nemo_recipes/librispeech/data/gtc_test_set_3.0/gtc_jhh_manifest.json.filt_ml-35.json.filter_high_wer.tutorial"
test_data = read_manifest(test_nemo_manifest)

for idx, item in enumerate(test_data):
    print(f"[{idx}]: {item['text']}")
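For readers who have not seen a NeMo data manifest before: it is a text file with one JSON object per line, typically with `audio_filepath`, `duration`, and `text` fields. The sketch below writes and reads such a file with the standard library only; the file name and the two entries are made up for illustration.

```python
import json
import os
import tempfile

# made-up manifest entries (paths and texts are placeholders)
entries = [
    {"audio_filepath": "audio/sample_0.wav", "duration": 3.2, "text": "nvidia announced a new gpu"},
    {"audio_filepath": "audio/sample_1.wav", "duration": 2.7, "text": "omniverse runs on dgx"},
]

with tempfile.TemporaryDirectory() as tmp_dir:
    manifest_path = os.path.join(tmp_dir, "toy_manifest.json")

    # write one JSON object per line
    with open(manifest_path, "w", encoding="utf-8") as fout:
        for entry in entries:
            fout.write(json.dumps(entry) + "\n")

    # read the manifest back, skipping empty lines
    with open(manifest_path, "r", encoding="utf-8") as fin:
        loaded = [json.loads(line) for line in fin if line.strip()]

for idx, item in enumerate(loaded):
    print(f"[{idx}]: {item['text']}")
```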

In [None]:
import librosa
import IPython.display as ipd

# load and listen to an example audio file
file_id = 0
example_file = test_data[file_id]['audio_filepath']
audio, sample_rate = librosa.load(example_file)

print(f"[TEXT {file_id}]: {test_data[file_id]['text']}\n")
ipd.Audio(example_file, rate=sample_rate)

### Load ASR models

To test the CTC-WS method we will use the following NeMo models:
 - (CTC): [stt_en_fastconformer_ctc_large](https://huggingface.co/nvidia/stt_en_fastconformer_ctc_large) - a large fast-conformer model trained on English ASR data
 - (Hybrid Transducer-CTC): [stt_en_fastconformer_hybrid_large_streaming_multi](https://huggingface.co/nvidia/stt_en_fastconformer_hybrid_large_streaming_multi) - a large fast-conformer model trained jointly with CTC and Transducer heads on English ASR data. The second model is a streaming model, which means it can process audio in real time. This can cause a slight WER degradation in comparison with the first model.

In [None]:
from nemo.collections.asr.models import EncDecCTCModelBPE, EncDecHybridRNNTCTCBPEModel

# ctc model
ctc_model_name = "stt_en_fastconformer_ctc_large"
ctc_model = EncDecCTCModelBPE.from_pretrained(model_name=ctc_model_name)

# hybrid transducer-ctc model
hybrid_ctc_rnnt_model_name = "stt_en_fastconformer_hybrid_large_streaming_multi"
hybrid_ctc_rnnt_model = EncDecHybridRNNTCTCBPEModel.from_pretrained(model_name=hybrid_ctc_rnnt_model_name)

### Transcribe 
Let's transcribe the test data and analyze the recognition accuracy of specific words.

In [None]:
test_audio_files = [item['audio_filepath'] for item in test_data]
recog_results = ctc_model.transcribe(test_audio_files)

### Compute per-word recognition statistics

In [None]:
import texterrors

word_dict = {}  # {word: [num_of_occurrences, num_of_correct_recognitions]}
eps = "<eps>"
ref_text = [item['text'] for item in test_data]

for idx, ref in enumerate(ref_text):
    ref = ref.split()
    hyp = recog_results[idx].split()
    texterrors_ali = texterrors.align_texts(ref, hyp, False)
    ali = []
    for i in range(len(texterrors_ali[0])):
        ali.append((texterrors_ali[0][i], texterrors_ali[1][i]))

    for pair in ali:
        word_ref, word_hyp = pair
        if word_ref == eps:
            continue
        if word_ref in word_dict:
            word_dict[word_ref][0] += 1
        else:
            word_dict[word_ref] = [1, 0]
        if word_ref == word_hyp:
            word_dict[word_ref][1] += 1

word_candidats = {}

for word in word_dict:
    gt = word_dict[word][0]
    tp = word_dict[word][1]
    if tp/gt < 1.0:
        word_candidats[word] = [gt, round(tp/gt, 2)]
        
# print the obtained per-word statistics: word, number of occurrences / recognition accuracy
word_candidats_sorted = sorted(word_candidats.items(), key=lambda x:x[1][0], reverse=True)
max_word_len = max([len(x[0]) for x in word_candidats_sorted])
for item in word_candidats_sorted:
    print(f"{item[0]:<{max_word_len}} {item[1][0]}/{item[1][1]}")

## Create context-biasing list

Now we need to select the words whose recognition we want to improve with CTC-WS context-biasing.
Usually, we select only non-trivial words with the lowest recognition accuracy.
Such words should have a character length >= 3, because short words in the context-biasing list may lead to a high rate of false positive recognitions.
In this toy example we will select all the words that look like names and have a recognition accuracy of less than 1.0.

The structure of the context-biasing file is:

WORD1_TRANSCRIPTION1  
WORD2_TRANSCRIPTION1   
...

TRANSCRIPTION here is a word spelling. We need this structure to be able to add more alternative transcriptions (spellings) for one specific word. We will cover such a case further.
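As a small illustration of this file structure, the sketch below parses one line into the word and its list of transcriptions. The helper name `parse_cb_line` is hypothetical, and the sketch assumes that neither the word nor its spellings contain an underscore.

```python
# Hypothetical helper for reading one line of the context-biasing file.
def parse_cb_line(line):
    """Split 'WORD_TRANSCRIPTION1_TRANSCRIPTION2...' into (word, [transcriptions])."""
    word, *transcriptions = line.strip().split("_")
    return word, transcriptions


print(parse_cb_line("nvidia_nvidia\n"))    # ('nvidia', ['nvidia'])
print(parse_cb_line("gpu_gpu_g p u\n"))    # ('gpu', ['gpu', 'g p u'])
```

The second line shows a word with two alternative spellings, which is the case this structure is designed for.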

In [None]:
cb_words = ["gpu", "nvidia", "nvidia's", "nvlink", "omniverse", "cunumeric", "numpy", "dgx", "dgxs", "dlss",
            "cpu", "tsmc", "culitho", "xlabs", "tensorrt", "tensorflow", "pytorch", "aws", "chatgpt", "pcie"]

# write context-biasing file 
cb_list_file = "context_biasing_list.txt"
with open(cb_list_file, "w", encoding="utf-8") as fn:
    for word in cb_words:
        fn.write(f"{word}_{word}\n")

In [None]:
!cat {cb_list_file}

## Run context biasing evaluation

The main script for CTC-WS context-biasing in NeMo is:\
`{NeMo_root}/scripts/asr_context_biasing/eval_greedy_decoding_with_context_biasing.py`

Context-biasing is managed by `apply_context_biasing` parameter [true or false].\
Other important context-biasing parameters are:
- `beam_threshold` - threshold for CTC-WS beam pruning
- `context_score` - per token weight for context biasing
- `ctc_ali_token_weight` - per token weight for CTC alignment (prevents false acceptances of context-biasing words)

All the context-biasing parameters below are set to the default values from the biasing script.\
You can tune them for your data and ASR model by listing all the candidate values in `[]`, separated by commas,\
for example: `beam_threshold=[7.0,8.0,9.0]`, `context_score=[3.0,4.0,5.0]`, `ctc_ali_token_weight=[0.5,0.6,0.7]`.\
The script will run recognition with every combination of the parameters and will select the best one based on the WER value.
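To get a feeling for how quickly such a parameter search grows, the sketch below enumerates all combinations for the example values above. It only illustrates the grid itself; how the evaluation script iterates over the parameters internally is not shown here and the code is not taken from it.

```python
from itertools import product

# candidate values from the example above
grid = {
    "beam_threshold": [7.0, 8.0, 9.0],
    "context_score": [3.0, 4.0, 5.0],
    "ctc_ali_token_weight": [0.5, 0.6, 0.7],
}

# every combination of the three parameters, one dict per run
combinations = [dict(zip(grid, values)) for values in product(*grid.values())]

print(len(combinations))  # 27
print(combinations[0])    # {'beam_threshold': 7.0, 'context_score': 3.0, 'ctc_ali_token_weight': 0.5}
```

Three values per parameter already mean 27 decoding runs, so it is usually worth starting from a small grid around the defaults.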

In [None]:
# create directory with experimental results
import os

exp_dir = "exp"
if not os.path.isdir(exp_dir):
    os.makedirs(exp_dir)
else:
    print(f"Directory '{exp_dir}' already exists")

In [None]:
# ctc model (no context-biasing)

!python ../../scripts/asr_context_biasing/eval_greedy_decoding_with_context_biasing.py \
            nemo_model_file={ctc_model_name} \
            input_manifest={test_nemo_manifest} \
            preds_output_folder={exp_dir} \
            decoder_type="ctc" \
            acoustic_batch_size=64 \
            apply_context_biasing=false \
            context_file={cb_list_file} \
            beam_threshold=[7.0] \
            context_score=[3.0] \
            ctc_ali_token_weight=[0.5]

The expected results are:

`Precision`: 1.0000 (1/1) fp:0 (fp - false positive recognition)  
`Recall`:    0.0333 (1/30)  
`Fscore`:    0.0645  
`Greedy WER/CER` = 35.68%/8.16%

The model was able to recognize only 1 out of 30 words from the context-biasing list.
Let's enable context-biasing during decoding:



In [None]:
# ctc model (with context biasing)
!python ../../scripts/asr_context_biasing/eval_greedy_decoding_with_context_biasing.py \
            nemo_model_file={ctc_model_name} \
            input_manifest={test_nemo_manifest} \
            preds_output_folder={exp_dir} \
            decoder_type="ctc" \
            acoustic_batch_size=64 \
            apply_context_biasing=true \
            context_file={cb_list_file} \
            beam_threshold=[7.0] \
            context_score=[3.0] \
            ctc_ali_token_weight=[0.5]

Now recognition results are better:

`Precision`: 1.0000 (21/21) fp:0  
`Recall`:    0.7000 (21/30)  
`Fscore`:    0.8235  
`Greedy WER/CER` = 17.09%/4.43%

But we are still able to recognize only 21 out of 30 specific words.\
You can see that the unrecognized words are mostly abbreviations (`dgxs`, `dlss`, `gpu`, `aws`, etc.) or complicated words (`culitho`).\
ASR models tend to recognize such words as a sequence of characters (`"aws" -> "a w s"`) or subwords (`"culitho" -> "cu litho"`).\
We can try to improve the recognition of such words by adding alternative transcriptions to the context-biasing list.
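As a quick sanity check of the reported numbers, the sketch below recomputes precision, recall, and F-score from the counts, assuming the standard definitions (the exact formulas used inside the evaluation script are an assumption here). The function name is hypothetical.

```python
# Standard precision / recall / F-score from counts (assumed to match the script's definitions).
def precision_recall_fscore(tp, fp, total):
    """tp: correctly recognized context words, fp: false accepts, total: context words in the reference."""
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / total if total else 0.0
    fscore = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, fscore


# 21 of the 30 context-biasing words recognized, no false accepts
p, r, f = precision_recall_fscore(tp=21, fp=0, total=30)
print(f"{p:.4f} {r:.4f} {f:.4f}")  # 1.0000 0.7000 0.8235
```

With 21 of 30 words recognized and no false accepts, the F-score is 2 * 1.0 * 0.7 / 1.7, which matches the 0.8235 reported above.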

### Alternative transcriptions

We use the `wordninja` package to split complicated words into simpler words.

In [None]:
!pip install wordninja

In [None]:
import wordninja

cb_list_file_modified = cb_list_file + ".abbr_and_ninja"

with open(cb_list_file, "r", encoding="utf-8") as fn1, \
    open(cb_list_file_modified, "w", encoding="utf-8") as fn2:

    for line in fn1:
        word = line.strip().split("_")[0]
        new_line = f"{word}_{word}"
        # split all the short words into characters
        if len(word) <= 4 and len(word.split()) == 1:
            abbr = ' '.join(list(word))
            new_line += f"_{abbr}"
        # split the long words into the simple words (not for phrases)
        new_segmentation = wordninja.split(word)
        if word != new_segmentation[0]:
            new_segmentation = ' '.join(new_segmentation)
            new_line += f"_{new_segmentation}"
        fn2.write(f"{new_line}\n")

In [None]:
!cat {cb_list_file_modified}

Run context-biasing with the modified context-biasing list:

In [None]:
# ctc models (with context biasing and modified cb list)
!python ../../scripts/asr_context_biasing/eval_greedy_decoding_with_context_biasing.py \
            nemo_model_file={ctc_model_name} \
            input_manifest={test_nemo_manifest} \
            preds_output_folder={exp_dir} \
            decoder_type="ctc" \
            acoustic_batch_size=64 \
            apply_context_biasing=true \
            context_file={cb_list_file_modified} \
            beam_threshold=[7.0] \
            context_score=[3.0] \
            ctc_ali_token_weight=[0.5]

Now recognition results are:

`Precision`: 1.0000 (28/28) fp:0  
`Recall`:    0.9333 (28/30)  
`Fscore`:    0.9655  
`Greedy WER/CER` = 7.04%/2.93%

As you can see, adding alternative transcriptions to the cb_list file gives an additional improvement in the recognition accuracy of the context-biasing words. However, we still miss 2 words.

In some cases you can add alternative transcriptions manually, based on the recognition errors your ASR model makes for specific words (for example, `"nvidia" -> "n video"`).

### Hybrid Transducer-CTC model
CTC-WS context-biasing for Transducer models is supported only for the Hybrid Transducer-CTC model.\
To use the Transducer head of the Hybrid Transducer-CTC model, we need to set `decoder_type="rnnt"`.\
Other parameters are the same as for the CTC model, because context-biasing is applied only to the CTC part of the model.\
The spotted context-biasing words are then merged with the greedy decoding results of the Transducer head.

In [None]:
# Transducer model (no context-biasing)
!python ../../scripts/asr_context_biasing/eval_greedy_decoding_with_context_biasing.py \
            nemo_model_file={hybrid_ctc_rnnt_model_name} \
            input_manifest={test_nemo_manifest} \
            preds_output_folder={exp_dir} \
            decoder_type="rnnt" \
            acoustic_batch_size=64 \
            apply_context_biasing=false \
            context_file={cb_list_file_modified} \
            beam_threshold=[7.0] \
            context_score=[3.0] \
            ctc_ali_token_weight=[0.5]

In [None]:
# Transducer model (with context-biasing)
!python ../../scripts/asr_context_biasing/eval_greedy_decoding_with_context_biasing.py \
            nemo_model_file={hybrid_ctc_rnnt_model_name} \
            input_manifest={test_nemo_manifest} \
            preds_output_folder={exp_dir} \
            decoder_type="rnnt" \
            acoustic_batch_size=64 \
            apply_context_biasing=true \
            context_file={cb_list_file_modified} \
            beam_threshold=[7.0] \
            context_score=[3.0] \
            ctc_ali_token_weight=[0.5]

CTC-WS context-biasing works for the Transducer model as well as for CTC. The difference in the results may be caused by the different nature of the offline and streaming (online) models.

## Practical part 2 (advanced)
In this section we will consider the context-biasing process more deeply:
- Visualization of the context-biasing graph
- Running CTC-WS with the context-biasing graph
- Merge the obtained spotted words with greedy decoding results
- Analysis of the results

### Build context graph (for visualization only)
The context graph consists of a composition of a prefix tree (Trie) with the CTC transition topology for words and phrases from the context-biasing list.\
We use a BPE tokenizer from the target ASR model for word segmentation.

In [None]:
from nemo.collections.asr.parts import context_biasing

# get bpe tokenization
cb_words_small = ["nvidia", "gpu", "nvlink", "numpy"]
context_transcripts = []
for word in cb_words_small:
    # use text_to_tokens method for visualization only
    # each word item is [word, list_of_tokenizations]; tokenize the whole word, not its characters
    word_tokenization = [hybrid_ctc_rnnt_model.tokenizer.text_to_tokens(word)]
    context_transcripts.append([word, word_tokenization])

# build context graph
context_graph = context_biasing.ContextGraphCTC(blank_id=hybrid_ctc_rnnt_model.decoder.blank_idx)
context_graph.add_to_graph(context_transcripts)
#context_graph.draw()
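To make the prefix-tree part of the context graph more tangible, here is a minimal pure-Python trie over token sequences. It is only a sketch: it does not model the CTC transition topology (blank transitions and self-loops) that `ContextGraphCTC` adds, and the tokenizations in the example are made up rather than produced by a real BPE tokenizer.

```python
# Toy prefix tree over token sequences (illustrative, not the ContextGraphCTC implementation).
def build_trie(word_items):
    """word_items: list of (word, list_of_tokenizations); returns the root node of the trie."""
    root = {"next": {}, "word": None}
    for word, tokenizations in word_items:
        for tokens in tokenizations:
            node = root
            for token in tokens:
                # shared prefixes reuse existing nodes
                node = node["next"].setdefault(token, {"next": {}, "word": None})
            node["word"] = word  # mark the end of a complete word
    return root


def count_nodes(node):
    """Count all nodes in the trie, including the root."""
    return 1 + sum(count_nodes(child) for child in node["next"].values())


# made-up tokenizations, including an alternative spelling for "gpu"
items = [
    ("gpu", [["▁g", "p", "u"], ["▁g", "▁p", "▁u"]]),
    ("gpus", [["▁g", "p", "u", "s"]]),
]
trie = build_trie(items)
print(count_nodes(trie))  # 7
```

Because all three token sequences start with `▁g`, they share the first node, which is what keeps the graph compact for large context-biasing lists.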

### Build a real context graph (for decoding)

In [None]:
# get bpe tokenization
# the graph must use the tokenizer of the model that produces the CTC log probabilities (ctc_model below)
context_transcripts = []
for word in cb_words:
    # each word item is [word, list_of_tokenizations]; here we use a single spelling per word
    word_tokenization = [ctc_model.tokenizer.text_to_ids(word)]
    context_transcripts.append([word, word_tokenization])

# build context graph
context_graph = context_biasing.ContextGraphCTC(blank_id=ctc_model.decoding.blank_id)
context_graph.add_to_graph(context_transcripts)

### Run CTC-based Word Spotter

In [None]:
import torch

# get ctc logprobs
audio_file_paths = [item['audio_filepath'] for item in test_data]

with torch.no_grad():
    ctc_model.eval()
    ctc_model.encoder.freeze()
    device = next(ctc_model.parameters()).device
    ctc_logprobs = ctc_model.transcribe(audio_file_paths, batch_size=10, logprobs=True)
    blank_idx = ctc_model.decoding.blank_id

In [None]:
from tqdm.notebook import tqdm

# run ctc-based word spotter
ws_results = {}
for idx, logits in tqdm(
    enumerate(ctc_logprobs), desc=f"Eval CTC-based Word Spotter...", total=len(ctc_logprobs)
):
    ws_results[audio_file_paths[idx]] = context_biasing.run_word_spotter(
        logits,
        context_graph,
        ctc_model,
        blank_idx=blank_idx,
        beam_threshold=7.0,
        cb_weight=3.0,
        ctc_ali_token_weight=0.5,
    )

### Merge CTC-WS words with greedy CTC decoding results

We use `print_stats=True` to print more information about the spotted words and the greedy CTC word alignment.

In [None]:
import numpy as np

target_transcripts = [item['text'] for item in test_data]

# merge spotted words with greedy results
for idx, logprobs in enumerate(ctc_logprobs):
    greedy_predicts = np.argmax(logprobs, axis=1)
    if ws_results[audio_file_paths[idx]]:
        # make new text by merging the greedy alignment with ctc-ws predictions:
        print("\n" + "********" * 10)
        print(f"File name: {audio_file_paths[idx]}")
        pred_text, raw_text = context_biasing.merge_alignment_with_ws_hyps(
            greedy_predicts,
            ctc_model,
            ws_results[audio_file_paths[idx]],
            decoder_type="ctc",
            blank_idx=blank_idx,
            print_stats=True,
        )
        print(f"[raw text]: {raw_text}")
        print(f"[hyp text]: {pred_text}")
        print(f"[ref text]: {target_transcripts[idx]}")
    else:
        pred_text = ctc_model.wer.decoding.ctc_decoder_predictions_tensor(greedy_predicts)[0][0]
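The merging step can be pictured with the toy sketch below: a spotted word replaces the greedy words it overlaps in time only if its score is higher than their summed score. This is a simplified assumption made for illustration; the function name, the data layout, and the scores are made up, and the actual scoring inside `merge_alignment_with_ws_hyps` is not reproduced here.

```python
# Toy merge of spotted words with a greedy word alignment (simplified, hypothetical logic).
def merge_spotted_words(greedy_words, spotted_words):
    """Words are dicts with 'text', 'start', 'end' (frames) and 'score' (log-probability)."""
    result = [dict(word, source="greedy") for word in greedy_words]
    # consider the most confident spotted candidates first
    for spot in sorted(spotted_words, key=lambda w: w["score"], reverse=True):
        overlapped = [w for w in result if w["start"] <= spot["end"] and spot["start"] <= w["end"]]
        if any(w["source"] == "spotter" for w in overlapped):
            continue  # this region was already taken by a better candidate
        if spot["score"] > sum(w["score"] for w in overlapped):
            result = [w for w in result if w not in overlapped]
            result.append(dict(spot, source="spotter"))
    result.sort(key=lambda w: w["start"])
    return " ".join(w["text"] for w in result)


greedy = [
    {"text": "the", "start": 0, "end": 3, "score": -0.2},
    {"text": "n", "start": 5, "end": 7, "score": -1.5},
    {"text": "video", "start": 8, "end": 14, "score": -2.0},
    {"text": "chip", "start": 16, "end": 20, "score": -0.3},
]
spotted = [
    {"text": "nvidia", "start": 5, "end": 14, "score": -1.0},
    {"text": "cpu", "start": 16, "end": 20, "score": -4.0},
]
merged_text = merge_spotted_words(greedy, spotted)
print(merged_text)  # the nvidia chip
```

In this example `nvidia` wins over the low-scoring `n video`, while the weak `cpu` candidate is rejected in favor of the confident greedy word `chip`, which is how the score comparison guards against false accepts.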

Some words about context-biasing statistic analysis

## Summary

This tutorial demonstrates how to use the CTC-WS context-biasing technique to improve the recognition of specific words in audio data for CTC and Transducer ASR models.\
The tutorial covers how to create the context-biasing list, how to improve the recognition accuracy of abbreviations and complex words, how to visualize the context-biasing graph, and how to analyze the results.
