# Overview
This tutorial demonstrates how to run inference with [SpellMapper](https://arxiv.org/abs/2306.02317) - a model for Spellchecking ASR (Automatic Speech Recognition) Customization.

Estimated time: 10-15 min.

SpellMapper is a non-autoregressive (NAR) model based on transformer architecture ([BERT](https://arxiv.org/pdf/1810.04805.pdf) with multiple separators).
It takes as input a single ASR hypothesis (text) and a **custom vocabulary**, and predicts which fragments of the ASR hypothesis, if any, should be replaced by which custom words/phrases.

This model is an alternative to word boosting/shallow fusion approaches:
  - it does not require retraining the ASR model;
  - it does not require beam search or a language model (LM);
  - it can be applied on top of the output of any English ASR model.

## What is custom vocabulary?
**Custom vocabulary** is a list of words/phrases that are important for a particular user, for example, the user's contact names, playlists, selected terminology and so on. The size of the custom vocabulary can vary from several hundred to **several thousand entries**, but it is not equivalent to an n-gram language model.

![Scope of customization with user vocabulary](images/spellmapper_customization_vocabulary.png)

Note that unlike traditional spellchecking approaches, which aim to correct known words using language models, the goal of contextual spelling correction is to correct highly specific user terms, most of which are 1) out-of-vocabulary (OOV) words or 2) spelling variations (e.g., "John Koehn", "Jon Cohen"); language models cannot help much with those.

## Tutorial Plan

1.   Create a sample custom vocabulary using some medical terminology.
2.   Study what customization does - a detailed analysis of a small example.
3.   Run a bigger example:
   *  Create sample ASR results by running TTS (text-to-speech synthesis) + ASR on some medical paper abstracts.
   *  Run SpellMapper inference and show how it can improve ASR results using custom vocabulary.

TL;DR We reduce WER from `14.3%` to `11.4%` by correcting medical terms, e.g.
* `puramesin` => `puromycin`
* `parromsin` => `puromycin`
* `and hydrod` => `anhydride`
* `lesh night and` => `lesch-nyhan`


# Preparation

## Installing NeMo

In [None]:
# Install NeMo library. If you are running locally (rather than on Google Colab), comment out the below lines
# and instead follow the instructions at https://github.com/NVIDIA/NeMo#Installation
GITHUB_ACCOUNT = "NVIDIA"
BRANCH = 'main'
!python -m pip install git+https://github.com/{GITHUB_ACCOUNT}/NeMo.git@{BRANCH}#egg=nemo_toolkit[all]

# Download local version of NeMo scripts. If you are running locally and want to use your own local NeMo code,
# comment out the below lines and set NEMO_DIR to your local path.
NEMO_DIR = 'nemo'
!git clone -b {BRANCH} https://github.com/{GITHUB_ACCOUNT}/NeMo.git $NEMO_DIR

## Additional installs
We will use `sentence_splitter` to split abstracts into sentences.

In [None]:
!pip install sentence_splitter

Clone the SpellMapper model from Hugging Face.
Note that we need not only the checkpoint itself but also the n-gram mapping vocabulary `replacement_vocab_filt.txt` from the same folder.

In [None]:
!git clone https://huggingface.co/bene-ges/spellmapper_asr_customization_en

## Imports


In [None]:
import IPython.display as ipd
import json
import random
import re
import soundfile as sf
import torch

from collections import Counter, defaultdict
from difflib import SequenceMatcher
from matplotlib.pyplot import imshow
from matplotlib import pyplot as plt
from sentence_splitter import SentenceSplitter
from typing import List, Set, Tuple

from nemo.collections.tts.models import FastPitchModel
from nemo.collections.tts.models import HifiGanModel

from nemo.collections.asr.parts.utils.manifest_utils import read_manifest

from nemo.collections.nlp.data.spellchecking_asr_customization.utils import (
    get_all_candidates_coverage,
    get_index,
    load_ngram_mappings,
    search_in_index,
    get_candidates,
    read_spellmapper_predictions,
    apply_replacements_to_text,
    load_ngram_mappings_for_dp,
    get_alignment_by_dp,
)

Set random seeds to get reproducible behaviour.

In [None]:
random.seed(0)
torch.manual_seed(0)

## Download data

The file `pubmed23n0009.xml`, taken from the public FTP server of https://www.ncbi.nlm.nih.gov/pmc/, contains information about 5593 medical papers, from which we extract only the abstracts. We will feed sentences from these abstracts to TTS + ASR to get initial ASR results.

The file `wordlist.txt` contains 100k **single-word** medical terms.

The file `valid_adam.txt` contains 24k medical abbreviations with their full forms. We will use those full forms as examples of **multi-word** medical terms.

The file `count_1w.txt` contains 330k single words with their frequencies from the Google Ngrams corpus. We will use this file to filter out frequent words from our custom vocabulary.


In [None]:
!wget https://ftp.ncbi.nlm.nih.gov/pubmed/baseline/pubmed23n0009.xml.gz
!gunzip pubmed23n0009.xml.gz
!grep "AbstractText" pubmed23n0009.xml > abstract.txt

!wget https://raw.githubusercontent.com/McGill-NLP/medal/master/toy_data/valid_adam.txt
!wget https://raw.githubusercontent.com/glutanimate/wordlist-medicalterms-en/master/wordlist.txt
!wget https://norvig.com/ngrams/count_1w.txt

## Auxiliary functions




In [None]:
CHARS_TO_IGNORE_REGEX = re.compile(r"[\.\,\?\:!;()«»…\]\[/\*–‽+&_\\½√>€™$•¼}{~—=“\"”″‟„]")


def get_medical_vocabulary() -> Tuple[Set[str], defaultdict]:
    """This function builds a vocabulary of medical terms using downloaded sources:
        wordlist.txt - 100k single-word medical terms.
        valid_adam.txt - 24k medical abbreviations with their full forms. We use those full forms as examples of multi-word medical terms.
        count_1w.txt - 330k single words with their frequencies from Google Ngrams corpus. We use this file to filter out frequent words from our custom vocabulary.
    Returns:
        medical_vocabulary: set of medical words and phrases
        abbreviations: mapping from a lowercased abbreviation to the set of its full forms
    """
    common_words = set()
    with open("count_1w.txt", "r", encoding="utf-8") as f:
        for line in f:
            word, freq = line.strip().casefold().split("\t")
            # the file is sorted by descending frequency, so we can stop at the first rare word
            if int(freq) < 500000:
                break
            common_words.add(word)
    print("Size of common words vocabulary:", len(common_words))

    abbreviations = defaultdict(set)
    medical_vocabulary = set()
    with open("valid_adam.txt", "r", encoding="utf-8") as f:
        lines = f.readlines()
        # first line is header
        for line in lines[1:]:
            abbrev, _, phrase = line.strip().split("\t")
            # skip phrases longer than 3 words because some of them are long explanations
            if phrase.count(" ") > 2:
                continue
            if phrase in common_words:
                continue
            medical_vocabulary.add(phrase)
            abbrev = abbrev.lower()
            abbreviations[abbrev].add(phrase)

    with open("wordlist.txt", "r", encoding="utf-8") as f:
        for line in f:
            word = line.strip().casefold()
            # skip words containing digits
            if re.match(r".*\d.*", word):
                continue
            # skip words containing brackets, plus signs, commas or periods
            if re.match(r".*[\[\]\(\)\+\,\.].*", word):
                continue
            if word in common_words:
                continue
            medical_vocabulary.add(word)

    print("Size of medical vocabulary:", len(medical_vocabulary))
    print("Size of abbreviation vocabulary:", len(abbreviations))
    return medical_vocabulary, abbreviations


def read_abstracts(medical_vocabulary: Set[str]) -> Tuple[List[Tuple[str, str]], Set[str], Set[str]]:
    """This function reads the downloaded medical abstracts and extracts sentences containing any word/phrase from the medical vocabulary.
    Args:
        medical_vocabulary: set of known medical words or phrases
    Returns:
        sentences: list of tuples (sentence, found phrases joined by ";")
        all_found_singleword: set of single words from the medical vocabulary that occurred in at least one sentence
        all_found_multiword: set of multi-word phrases from the medical vocabulary that occurred in at least one sentence
    """
    splitter = SentenceSplitter(language='en')

    all_sentences = []
    all_found_singleword = set()
    all_found_multiword = set()
    with open("abstract.txt", "r", encoding="utf-8") as f:
        for line in f:
            text = line.strip().replace("<AbstractText>", "").replace("</AbstractText>", "")
            sents = splitter.split(text)
            found_singleword = set()
            found_multiword = set()
            for sent in sents:
                # remove anything in brackets from text (the regex is greedy: it removes everything from the first "(" to the last ")")
                sent = re.sub(r"\(.+\)", r"", sent)
                # remove quotes from text
                sent = sent.replace("\"", "")
                # skip sentences containing digits because normalization is out of scope of this tutorial
                if re.match(r".*\d.*", sent):
                    continue
                # skip sentences containing abbreviations with a period inside the sentence (for the same reason)
                if ". " in sent:
                    continue
                # skip long sentences as they may cause OOM issues
                if len(sent) > 150:
                    continue
                # replace all punctuation with spaces, convert to lowercase and collapse repeated spaces
                sent_clean = CHARS_TO_IGNORE_REGEX.sub(" ", sent).lower()
                sent_clean = " ".join(sent_clean.split())
                words = sent_clean.split(" ")

                found_phrases = set()
                # check every fragment of 1 to 3 words (end is exclusive, so it can be equal to len(words))
                for begin in range(len(words)):
                    for end in range(begin + 1, min(begin + 4, len(words) + 1)):
                        phrase = " ".join(words[begin:end])
                        if phrase in medical_vocabulary:
                            found_phrases.add(phrase)
                            if end - begin == 1:
                                found_singleword.add(phrase)
                            else:
                                found_multiword.add(phrase)
                if len(found_phrases) > 0:
                    all_sentences.append((sent, ";".join(found_phrases)))
            all_found_singleword = all_found_singleword.union(found_singleword)
            all_found_multiword = all_found_multiword.union(found_multiword)

    print("Sentences:", len(all_sentences))
    print("Unique single-word terms found:", len(all_found_singleword))
    print("Unique multi-word terms found:", len(all_found_multiword))
    print("Examples of multi-word terms", str(list(all_found_multiword)[0:10]))
    
    return all_sentences, all_found_singleword, all_found_multiword

In [None]:
def get_fragments(i_words: List[str], j_words: List[str]) -> List[Tuple[str, str, str, int, int, int, int]]:
    """This function is used to compare two word sequences to find minimal fragments that differ.
    Args:
        i_words: list of words in first sequence
        j_words: list of words in second sequence
    Returns:
        list of tuples (difference_type, fragment1, fragment2, begin_of_fragment1, end_of_fragment1, begin_of_fragment2, end_of_fragment2)
    """
    s = SequenceMatcher(None, i_words, j_words)
    result = []
    for tag, i1, i2, j1, j2 in s.get_opcodes():
        result.append((tag, " ".join(i_words[i1:i2]), " ".join(j_words[j1:j2]), i1, i2, j1, j2))
    result = sorted(result, key=lambda x: x[3])
    return result

## Read medical data

In [None]:
medical_vocabulary, _ = get_medical_vocabulary()
sentences, found_singleword, found_multiword = read_abstracts(medical_vocabulary)
# in case we need random candidates from a big sample, we use the full medical vocabulary for that purpose
big_sample = list(medical_vocabulary)

In [None]:
for sent, phrases in sentences[0:10]:
    print(sent, "\t", phrases)

# SpellMapper ASR Customization

The SpellMapper model relies on two offline preparation steps:
1. Collecting n-gram mappings from a large corpus (this mapping vocabulary has been collected once and is supplied with the model).
2. Indexing the user vocabulary by n-grams.

![Offline data preparation](images/spellmapper_data_preparation.png)

At inference time we take as input an ASR hypothesis and an n-gram-indexed user vocabulary, and perform the following steps:
1. Retrieve the top 10 candidate phrases from the user vocabulary that are likely to be contained in the given ASR hypothesis, possibly in a misspelled form.
2. Run the neural model that tags the input characters with the correct candidate labels, or 0 if no match is found.
3. Post-process the predictions to combine the results.

![Inference pipeline](images/spellmapper_inference_pipeline.png)


## N-gram mappings
Note that the n-gram mapping vocabulary has been collected from a large corpus and is supplied with the model. It is intended to be "universal" for the English language.


Let's see what n-gram mappings look like, for example, for the n-gram `l u c`.
Note that n-grams in `replacement_vocab_filt.txt` preserve a one-to-one correspondence between original letters and misspelled fragments:
* `+` means that adjacent letters are concatenated and correspond to a single source letter.
* `<DELETE>` means that the original letter is deleted.

This auxiliary markup is removed automatically during loading.

`_` is used instead of the real space symbol.

The last three columns are:
* joint frequency
* frequency of the original n-gram
* frequency of the misspelled n-gram

$$\text{TranslationProbability}=\frac{\text{JointFrequency}}{\text{SourceFrequency}}$$
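To make the formula concrete, here is a minimal sketch that parses one mapping line and computes its translation log probability. It assumes a five-column tab-separated layout (original n-gram, misspelled n-gram, then the three frequencies listed above); that column order, the sample line and its numbers are illustrative assumptions, and the `+`/`<DELETE>` markup is ignored.

```python
import math
from typing import NamedTuple


class NgramMapping(NamedTuple):
    """One row of the mapping file (assumed five-column layout)."""
    original: str         # n-gram in the correct spelling, letters separated by spaces
    misspelled: str       # how it was (mis)recognized
    joint_freq: int       # how often this pair was observed together
    orig_freq: int        # how often the original n-gram was observed
    misspelled_freq: int  # how often the misspelled n-gram was observed


def parse_mapping_line(line: str) -> NgramMapping:
    """Parse one tab-separated line into an NgramMapping."""
    orig, misspelled, joint, orig_freq, missp_freq = line.rstrip("\n").split("\t")
    return NgramMapping(orig, misspelled, int(joint), int(orig_freq), int(missp_freq))


def translation_log_prob(m: NgramMapping) -> float:
    """Natural log of JointFrequency / SourceFrequency."""
    return math.log(m.joint_freq / m.orig_freq)


# Hypothetical line with invented numbers: "l u c" recognized as "l u k".
mapping = parse_mapping_line("l u c\tl u k\t250\t1000\t5000\n")
print(mapping.original, "->", mapping.misspelled, translation_log_prob(mapping))
```

Working in log space lets probabilities of consecutive n-grams be added instead of multiplied, which is how scores are combined later in this tutorial.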



In [None]:
!awk 'BEGIN {FS="\t"} ($1=="l u c"){print $0}' < spellmapper_asr_customization_en/replacement_vocab_filt.txt | sort -t$'\t' -k3nr

Now we read the n-gram mappings from the file. The parameter `max_misspelled_freq` sets the maximum frequency of misspelled n-grams: n-grams more frequent than that are put in the list of banned n-grams and are not used in indexing.

In [None]:
print("load n-gram mappings...")
ngram_mapping_vocab, ban_ngram = load_ngram_mappings("spellmapper_asr_customization_en/replacement_vocab_filt.txt", max_misspelled_freq=125000)
# CAUTION: entries in ban_ngram end with a space and can contain "+" "="
print("Size of ngram mapping vocabulary:", len(ngram_mapping_vocab))
print("Size of banned ngrams:", len(ban_ngram))
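The banning rule described above can be sketched as follows. This is a simplified illustration, not the code of `load_ngram_mappings`: records are assumed to be already parsed into tuples, the name `split_banned` is hypothetical, and banned entries are stored as plain strings (the real ones end with a space, as the CAUTION comment notes).

```python
from collections import defaultdict
from typing import Dict, Iterable, Set, Tuple

# Each record: (original_ngram, misspelled_ngram, joint_freq, orig_freq, misspelled_freq)
Record = Tuple[str, str, int, int, int]


def split_banned(
    records: Iterable[Record], max_misspelled_freq: int
) -> Tuple[Dict[str, Dict[str, int]], Set[str]]:
    """Keep mappings whose misspelled n-gram is rare enough; ban the others."""
    vocab: Dict[str, Dict[str, int]] = defaultdict(dict)
    banned: Set[str] = set()
    for orig, missp, joint, _orig_freq, missp_freq in records:
        if missp_freq > max_misspelled_freq:
            # too frequent: it would match almost any text and flood the index
            banned.add(missp)
            continue
        vocab[orig][missp] = joint
    return dict(vocab), banned


records = [
    ("t h e", "t h e", 900000, 1000000, 1200000),  # invented numbers for a very common n-gram
    ("l u c", "l u k", 250, 1000, 5000),
]
vocab, banned = split_banned(records, max_misspelled_freq=125000)
print(vocab)
print(banned)
```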


## Indexing of custom vocabulary

As we mentioned earlier, this pipeline is intended to work with custom vocabularies of up to several thousand entries. Since the whole medical vocabulary contains 110k entries, we restrict our custom vocabulary to the 5000+ terms that occurred in the given corpus of abstracts.

The goal of indexing our custom vocabulary is to build an index where the key is a letter n-gram and the value is the whole phrase. The keys are the n-grams of the given user phrase and their misspelled variants taken from our collection of n-gram mappings (see Index of custom vocabulary in Fig. 1).
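A minimal sketch of such an inverted index, assuming phrases in the letter format used here (letters separated by spaces) and a toy mapping dictionary `{original n-gram: {misspelled n-gram: log probability}}`. The function `build_index` is hypothetical; it mirrors the `(phrase_id, begin, size, logprob)` entries that `get_index` returns, but it does not grow n-grams or apply any bans.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

IndexEntry = Tuple[int, int, int, float]  # (phrase_id, begin, size, log_prob)


def build_index(
    phrases: List[str],
    mappings: Dict[str, Dict[str, float]],
    max_n: int = 3,
) -> Dict[str, List[IndexEntry]]:
    """Index space-separated letter phrases by their n-grams and by misspelled variants."""
    index: Dict[str, List[IndexEntry]] = defaultdict(list)
    for phrase_id, phrase in enumerate(phrases):
        letters = phrase.split(" ")
        for begin in range(len(letters)):
            for size in range(1, max_n + 1):
                if begin + size > len(letters):
                    break
                ngram = " ".join(letters[begin:begin + size])
                # the correctly spelled n-gram itself
                index[ngram].append((phrase_id, begin, size, 0.0))
                # its known misspelled variants, with their log probabilities
                for variant, log_prob in mappings.get(ngram, {}).items():
                    index[variant].append((phrase_id, begin, size, log_prob))
    return dict(index)


phrases = ["a o r t a"]
toy_mappings = {"a o": {"o o": -2.0}}  # hypothetical: "ao" misrecognized as "oo"
index = build_index(phrases, toy_mappings, max_n=2)
print(index["o o"])
```

A misspelled fragment such as `oorda` in an ASR hypothesis can then point back to `aorta` through the variant key `o o`, even though that n-gram never occurs in the correctly spelled phrase.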

*Though it is possible to index and search the whole 110k vocabulary, it will require additional optimizations and is beyond the scope of this tutorial.*

In [None]:
custom_phrases = []
for phrase in medical_vocabulary:
    if phrase not in found_singleword and phrase not in found_multiword:
        continue
    # convert to SpellMapper's letter format: "thoracic aorta" -> "t h o r a c i c _ a o r t a"
    custom_phrases.append(" ".join(list(phrase.replace(" ", "_"))))
print("Size of customization vocabulary:", len(custom_phrases))

Now we build the index for our custom phrases.

The parameter `min_log_prob` sets the minimum log probability: once the accumulated log probability of an n-gram variant falls below it, we stop growing that n-gram.

The parameter `max_phrases_per_ngram` sets the maximum number of phrases that can be indexed by one n-gram. N-grams exceeding this limit are also banned and not used in indexing.
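A sketch of the `max_phrases_per_ngram` rule, assuming an index shaped like `ngram2phrases` (an n-gram mapped to a list of `(phrase_id, begin, size, logprob)` tuples). The helper `drop_overloaded_ngrams` is hypothetical and only illustrates why overly generic n-grams are dropped: an n-gram that points to hundreds of phrases gives almost no evidence for any single one of them.

```python
from typing import Dict, List, Set, Tuple

IndexEntry = Tuple[int, int, int, float]  # (phrase_id, begin, size, log_prob)


def drop_overloaded_ngrams(
    index: Dict[str, List[IndexEntry]], max_phrases_per_ngram: int
) -> Tuple[Dict[str, List[IndexEntry]], Set[str]]:
    """Remove n-grams that point to too many distinct phrases; return them as banned."""
    kept: Dict[str, List[IndexEntry]] = {}
    banned: Set[str] = set()
    for ngram, entries in index.items():
        # count distinct phrases, not entries (one phrase may contain an n-gram twice)
        n_phrases = len({entry[0] for entry in entries})
        if n_phrases > max_phrases_per_ngram:
            banned.add(ngram)
        else:
            kept[ngram] = entries
    return kept, banned


example_index = {
    "a": [(0, 0, 1, 0.0), (1, 2, 1, 0.0), (2, 1, 1, 0.0)],  # occurs in three phrases
    "x y": [(1, 0, 2, 0.0)],
}
kept, banned = drop_overloaded_ngrams(example_index, max_phrases_per_ngram=2)
print(sorted(kept), banned)
```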



In [None]:
phrases, ngram2phrases = get_index(custom_phrases, ngram_mapping_vocab, ban_ngram, min_log_prob=-4.0, max_phrases_per_ngram=600)
print("Size of phrases:", len(phrases))
print("Size of ngram2phrases:", len(ngram2phrases))

# Save the index to a file; later we will use it in another script
with open("index.txt", "w", encoding="utf-8") as out:
    for ngram in ngram2phrases:
        for phrase_id, begin, size, logprob in ngram2phrases[ngram]:
            phrase = phrases[phrase_id]
            out.write(ngram + "\t" + phrase + "\t" + str(begin) + "\t" + str(size) + "\t" + str(logprob) + "\n")


## Small detailed example

Let's consider, for example, the custom phrase `thoracic aorta` and an incorrect ASR hypothesis `the tarasic oorda is a part of the aorta located in the thorax`, containing the misspelled phrase `tarasic oorda`.

We will see 
1. How this custom phrase is indexed.
2. How candidate retrieval works, given ASR-hypothesis.
3. How inference and post-processing work.


### N-grams in index

Let's look at the n-grams by which the custom phrase `thoracic aorta` is indexed.
Columns: 
1. n-gram
2. beginning position in the phrase
3. length
4. log probability

Note that many n-grams do not come directly from the n-gram mappings file. They are derived by growing shorter n-grams with new replacements; in this case their log probabilities are summed. Growing stops when the accumulated log probability falls below the minimum (`min_log_prob`).
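The growing procedure can be sketched as a depth-first search over mapped chunks. This is a simplified illustration under assumptions (the toy `mappings` dictionary and the function `grow_variants` are hypothetical, not NeMo code): a variant is extended one mapped chunk at a time, the log probabilities add up, and a branch is abandoned as soon as its total drops below `min_log_prob`.

```python
from typing import Dict, List, Tuple


def grow_variants(
    letters: List[str],
    begin: int,
    mappings: Dict[str, Dict[str, float]],
    min_log_prob: float,
    max_len: int,
) -> List[Tuple[str, int, float]]:
    """Enumerate misspelled variants of n-grams starting at `begin`.

    Returns (variant, size, log_prob), where `size` is the number of original letters covered.
    """
    results: List[Tuple[str, int, float]] = []
    # stack of partial variants: (variant letters so far, next original position, log_prob)
    stack: List[Tuple[List[str], int, float]] = [([], begin, 0.0)]
    while stack:
        variant, pos, log_prob = stack.pop()
        for size in range(1, max_len + 1):
            if pos + size > len(letters):
                break
            chunk = " ".join(letters[pos:pos + size])
            for replacement, lp in mappings.get(chunk, {}).items():
                total = log_prob + lp  # log probabilities of consecutive chunks add up
                if total < min_log_prob:
                    continue  # too improbable: stop growing this branch
                new_variant = variant + replacement.split()
                results.append((" ".join(new_variant), pos + size - begin, total))
                stack.append((new_variant, pos + size, total))
    return results


toy_mappings = {
    "t": {"t": -0.1},
    "h": {"h": -0.1},
    "o": {"o": -0.1, "a": -2.0},  # "o" is rarely recognized as "a"
}
for variant, size, log_prob in grow_variants("t h o".split(), 0, toy_mappings, min_log_prob=-1.0, max_len=2):
    print(variant, size, round(log_prob, 2))
# "t h a" is not produced here: its total (-2.2) is below min_log_prob
```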


In [None]:
for ngram in ngram2phrases:
    for phrase_id, b, length, lprob in ngram2phrases[ngram]:
        if phrases[phrase_id] == "t h o r a c i c _ a o r t a":
            print(ngram.ljust(16) + "\t" + str(b).rjust(4) + "\t" + str(length).rjust(4) + "\t" + str(lprob))

### Candidate retrieval
The candidate retrieval tasks are:
 - Given an input sentence and an index of the custom vocabulary, find all n-grams from the index that match the sentence.
 - Find which sentence fragments and which custom phrases have the most "hits" (potential candidates).
 - Find the approximate starting position of each candidate phrase.


Let's look at the hits that the phrase "thoracic aorta" gets by searching for all n-grams in the input text. We can see some hits in different parts of the sentence, but a moving window can find the fragment with the most hits.

In [None]:
sent = "the_tarasic_oorda_is_a_part_of_the_aorta_located_in_the_thorax"
phrases2positions, position2ngrams = search_in_index(ngram2phrases, phrases, sent)
print(" ".join(list(sent)))
print(" ".join(list(map(str, phrases2positions[phrases.index("t h o r a c i c _ a o r t a")].astype(int)))))

`phrases2positions` is a matrix of size (len(phrases), len(ASR_hypothesis)).
It contains 1.0 (a hit) at the positions of letter n-grams that index the corresponding phrase, and 0.0 elsewhere.
It is used to find phrases with many hits within a contiguous window, i.e. potential matching candidates.

`position2ngrams` is a list of sets of n-grams; the list index is the starting position in the ASR hypothesis.
It is used later to check how well each found candidate is covered by n-grams (to avoid cases where some repeated n-gram gives many hits to a phrase although the phrase itself is not well covered).
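A simplified sketch of how the hits for one phrase could be collected, assuming that a hit marks every letter covered by a matching n-gram (exactly how NeMo fills `phrases2positions` is not shown here). `hits_for_phrase` and the toy index are hypothetical.

```python
from typing import Dict, List, Tuple

IndexEntry = Tuple[int, int, int, float]  # (phrase_id, begin, size, log_prob)


def hits_for_phrase(
    index: Dict[str, List[IndexEntry]], sentence: str, phrase_id: int, max_n: int = 3
) -> List[int]:
    """Mark every letter of `sentence` covered by an indexed n-gram of phrase `phrase_id`."""
    letters = list(sentence)
    hits = [0] * len(letters)
    for start in range(len(letters)):
        for n in range(1, max_n + 1):
            if start + n > len(letters):
                break
            ngram = " ".join(letters[start:start + n])  # index keys use space-separated letters
            if any(entry[0] == phrase_id for entry in index.get(ngram, [])):
                for pos in range(start, start + n):
                    hits[pos] = 1
    return hits


toy_index = {"o r": [(0, 1, 2, 0.0)], "r t": [(0, 2, 2, 0.0)]}
print(hits_for_phrase(toy_index, "xx_ort", phrase_id=0))
```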

In [None]:
candidate2coverage, candidate2position = get_all_candidates_coverage(phrases, phrases2positions)
print("Coverage=", candidate2coverage[phrases.index("t h o r a c i c _ a o r t a")])
print("Starting position=", candidate2position[phrases.index("t h o r a c i c _ a o r t a")])

`candidate2coverage` is a list of size len(phrases) containing the coverage (0.0 to 1.0) in the best window.
Coverage is a smoothed percentage of hits in a window whose size equals the length of the given phrase.

`candidate2position` is a list of size len(phrases) containing the starting position of the best window.

The starting position is approximate, and that is fine: if it is not at the beginning of a word, SpellMapper will try to adjust it later. In this particular example we get 5 as the starting position instead of 4, missing the first letter.
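The moving-window search can be sketched with a running sum. This assumes a 0/1 hit vector like the one printed above and uses plain (unsmoothed) coverage; the helper `best_window` is hypothetical and does not reproduce the smoothing used by `get_all_candidates_coverage`.

```python
from typing import List, Tuple


def best_window(hits: List[int], window: int) -> Tuple[float, int]:
    """Return (coverage, start) of the window of length `window` with the most hits."""
    if window <= 0 or not hits:
        return 0.0, 0
    window = min(window, len(hits))
    current = sum(hits[:window])
    best, best_start = current, 0
    for start in range(1, len(hits) - window + 1):
        # slide by one position: add the entering letter, drop the leaving one
        current += hits[start + window - 1] - hits[start - 1]
        if current > best:
            best, best_start = current, start
    return best / window, best_start


# window size = length of the candidate phrase in letters
print(best_window([0, 1, 1, 0, 0, 1, 1, 1, 1, 0], window=4))
```

The running sum keeps the search linear in the hypothesis length, whereas re-summing every window would cost O(len(hits) * window).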

### Inference

Now let's generate the input for SpellMapper inference.
An input line should consist of 4 tab-separated columns:
  - text of the ASR hypothesis
  - texts of 10 candidates separated by semicolons
  - 1-based ids of non-dummy candidates
  - approximate start/end coordinates of non-dummy candidates (corresponding to the ids)

Note that candidate retrieval is done inside the function `get_candidates`.

In [None]:
out = open("spellmapper_input.txt", "w", encoding="utf-8")
letters = list(sent)
candidates = get_candidates(ngram2phrases, phrases, letters, big_sample)
# We add two columns with targets and span_info.
# They have the same format as during training, but the start and end positions are APPROXIMATE; they will be adjusted when constructing BertExample.
targets = []
span_info = []
for idx, c in enumerate(candidates):
    if c[1] == -1:
        continue
    targets.append(str(idx + 1))  # targets are 1-based
    start = c[1]
    end = min(c[1] + c[2], len(letters))  # ensure that end is not outside sentence length (it can happen because c[2] is candidate length used as approximation)
    span_info.append("CUSTOM " + str(start) + " " + str(end))

out.write(" ".join(letters) + "\t" + ";".join([x[0] for x in candidates])  + "\t" + " ".join(targets) + "\t" + ";".join(span_info) + "\n")
out.close()


In [None]:
!cat spellmapper_input.txt
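To inspect `spellmapper_input.txt` programmatically, a small parser can split a line back into its four columns. The column format is taken from the cell that writes the file above; the `SpellMapperInput` type and `parse_input_line` are hypothetical helpers, not part of NeMo.

```python
from typing import List, NamedTuple, Tuple


class SpellMapperInput(NamedTuple):
    text: str                     # ASR hypothesis, letters separated by spaces
    candidates: List[str]         # candidates, letters separated by spaces
    target_ids: List[int]         # 1-based ids of non-dummy candidates
    spans: List[Tuple[int, int]]  # approximate (start, end) letter positions


def parse_input_line(line: str) -> SpellMapperInput:
    """Split one input line into its four tab-separated columns."""
    text, candidates, targets, span_info = line.rstrip("\n").split("\t")
    spans = []
    span_items = span_info.split(";") if span_info else []
    for item in span_items:
        _label, start, end = item.split(" ")  # e.g. "CUSTOM 5 19"
        spans.append((int(start), int(end)))
    target_ids = [int(t) for t in targets.split(" ")] if targets else []
    return SpellMapperInput(text, candidates.split(";"), target_ids, spans)


example = parse_input_line("t h e\tt h o r a c i c;a b c\t1\tCUSTOM 0 3\n")
print(example.target_ids, example.spans)
```

For example, `parse_input_line(open("spellmapper_input.txt", encoding="utf-8").readline())` should return the example sentence, its candidates and the approximate spans of the non-dummy ones.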

In [None]:
!python nemo/examples/nlp/spellchecking_asr_customization/spellchecking_asr_customization_infer.py \
      pretrained_model=spellmapper_asr_customization_en/training_10m_5ep.nemo \
      model.max_sequence_len=512 \
      inference.from_file=spellmapper_input.txt \
      inference.out_file=spellmapper_output.txt \
      inference.batch_size=16 \
      lang=en


Each line in the SpellMapper output is tab-separated and consists of 4 columns:
1. ASR hypothesis (same as in the input)
2. 10 candidates separated by semicolons (same as in the input)
3. fragment predictions separated by semicolons; each prediction is a tuple (start, end, candidate_id, probability)
4. letter predictions: the candidate_id predicted for each letter (for debugging purposes only)

In [None]:
!cat spellmapper_output.txt

We can use some utility functions to apply the found replacements and get the actual corrected text.

In [None]:
spellmapper_results = read_spellmapper_predictions("spellmapper_output.txt")
text, replacements, _ = spellmapper_results[0]
corrected_text = apply_replacements_to_text(text, replacements, replace_hyphen_to_space=False)
print("Text before correction:\n", text)
print("Text after correction:\n", corrected_text)
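Conceptually, applying replacements means substituting character spans of the hypothesis. A minimal sketch, assuming replacements given as hypothetical `(start, end, new_text)` character offsets; this is not necessarily the format returned by `read_spellmapper_predictions`, and `apply_replacements` is not the NeMo utility.

```python
from typing import List, Tuple


def apply_replacements(text: str, replacements: List[Tuple[int, int, str]]) -> str:
    """Replace character spans [start, end) of `text` with new strings.

    Spans are applied from right to left so that earlier offsets stay valid.
    Overlapping spans are not handled in this sketch.
    """
    for start, end, new in sorted(replacements, reverse=True):
        text = text[:start] + new + text[end:]
    return text


hyp = "the tarasic oorda is a part of the aorta located in the thorax"
print(apply_replacements(hyp, [(4, 17, "thoracic aorta")]))
```

Applying spans from right to left is the usual design choice here: a replacement of a different length does not shift the offsets of spans that lie to its left.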



# Bigger customization example

Let's test customization on more data. The plan is:
   *  Get baseline ASR transcriptions by running TTS + ASR on some medical paper abstracts.
   *  Run SpellMapper inference and show how it can improve ASR results using custom vocabulary.


## Run TTS

In [None]:
# create a folder for wav files (TTS output)
!rm -rf audio
!mkdir audio

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load FastPitch from HuggingFace
spectrogram_generator = FastPitchModel.from_pretrained("nvidia/tts_en_fastpitch").eval().to(device)
# Load HifiGan vocoder from HuggingFace
vocoder = HifiGanModel.from_pretrained(model_name="nvidia/tts_hifigan").eval().to(device)

# Write sentences that we want to feed to TTS
with open("tts_input.txt", "w", encoding="utf-8") as out:
    for sent, _ in sentences[0:100]:
        out.write(sent + "\n")

# FastPitch + HiFiGAN generate audio at 22050 Hz; use the same rate for the saved file and for playback
sample_rate = 22050

i = 0
with open("tts_input.txt", "r", encoding="utf-8") as inp, open("manifest.json", "w", encoding="utf-8") as out_manifest:
    for line in inp:
        text = line.strip()
        text_clean = CHARS_TO_IGNORE_REGEX.sub(" ", text).lower()  # replace all punctuation with spaces and convert to lowercase
        text_clean = " ".join(text_clean.split())

        parsed = spectrogram_generator.parse(text, normalize=True)

        spectrogram = spectrogram_generator.generate_spectrogram(tokens=parsed)
        audio = vocoder.convert_spectrogram_to_audio(spec=spectrogram)

        # Note that the vocoder returns a batch of audio. In this example, we just take the first and only sample.
        filename = "audio/" + str(i) + ".wav"
        sf.write(filename, audio.to('cpu').detach().numpy()[0], sample_rate)
        # json.dumps takes care of escaping quotes and backslashes in the text
        out_manifest.write(json.dumps({"audio_filepath": filename, "text": text_clean, "orig_text": text}) + "\n")
        i += 1

        # display some examples
        if i < 10:
            print(f'"{text}"\n')
            ipd.display(ipd.Audio(audio.to('cpu').detach(), rate=sample_rate))


Now we have a folder with generated audio files `audio/*.wav` and a NeMo manifest with JSON records like `{"audio_filepath": "audio/0.wav", "text": "no renal auditory or vestibular toxicity was observed", "orig_text": "No renal, auditory, or vestibular toxicity was observed."}`.

In [None]:
# Sanity check: every line of the manifest must be valid JSON
with open("manifest.json", "r", encoding="utf-8") as f:
    lines = f.readlines()

for line in lines:
    try:
        data = json.loads(line.strip())
    except json.JSONDecodeError:
        print("Invalid JSON line:", line)

Free GPU memory to avoid OOM.

In [None]:
del spectrogram_generator
del vocoder
torch.cuda.empty_cache()

## Run baseline ASR

Next we transcribe our .wav files with a general-domain [ASR model](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_en_conformer_ctc_large). It will generate an output file `ctc_baseline_transcript_tmp.json` where the predicted transcriptions are stored in the field `pred_text` of each record.

Note that this ASR model was not trained or fine-tuned on the medical domain, so we expect it to make mistakes on medical terms.

In [None]:
!python nemo/examples/asr/transcribe_speech.py \
      pretrained_name="stt_en_conformer_ctc_large" \
      dataset_manifest=manifest.json \
      output_filename=ctc_baseline_transcript_tmp.json \
      batch_size=2

ATTENTION: SpellMapper relies on words being separated by a _single_ space.

We have observed multiple consecutive spaces in ASR results produced by Conformer-CTC; this is probably related to this issue: https://github.com/NVIDIA/NeMo/issues/4034.

So we need to correct the manifest to ensure that all spaces are single.

In [None]:
test_data = read_manifest("ctc_baseline_transcript_tmp.json")

for i in range(len(test_data)):
    # if there are multiple spaces in the string they will be merged to one
    test_data[i]["pred_text"] = " ".join(test_data[i]["pred_text"].split())

with open("ctc_baseline_transcript.json", "w", encoding="utf-8") as out:
    for d in test_data:
        line = json.dumps(d)
        out.write(line + "\n")


In [None]:
!head -n 4 ctc_baseline_transcript.json

### Calculating WER of baseline transcript
We use the standard NeMo script to calculate the WER and CER of our baseline transcript. Internally, it compares the text in `pred_text` (predicted transcript) to `text` (reference transcript).

In [None]:
!python nemo/examples/asr/speech_to_text_eval.py \
  dataset_manifest=ctc_baseline_transcript.json \
  only_score_manifest=True
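For reference, corpus-level WER is usually defined as the total number of word substitutions, deletions and insertions divided by the total number of reference words. A minimal sketch of that definition (the evaluation script may additionally normalize the text; that is not reproduced here, and both functions are hypothetical helpers):

```python
from typing import List


def word_edit_distance(ref: List[str], hyp: List[str]) -> int:
    """Levenshtein distance over words (substitutions, insertions, deletions)."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i] + [0] * len(hyp)
        for j, h in enumerate(hyp, start=1):
            curr[j] = min(
                prev[j] + 1,             # deletion of a reference word
                curr[j - 1] + 1,         # insertion of a hypothesis word
                prev[j - 1] + (r != h),  # substitution (or match if equal)
            )
        prev = curr
    return prev[-1]


def corpus_wer(refs: List[str], hyps: List[str]) -> float:
    """Total word errors divided by the total number of reference words."""
    errors = sum(word_edit_distance(r.split(), h.split()) for r, h in zip(refs, hyps))
    words = sum(len(r.split()) for r in refs)
    return errors / words


print(corpus_wer(["the thoracic aorta is long"], ["the tarasic oorda is long"]))
```

Note that a split term like `and hydrod` for `anhydride` costs two errors (one substitution plus one insertion), so a single correct replacement can remove more than one word error.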


### See fragments that differ
We use SequenceMatcher to see the fragments that differ. (Another option is the more powerful analytics tool [Speech Data Explorer](https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/tools/speech_data_explorer.html).)

In [None]:
test_data = read_manifest("ctc_baseline_transcript.json")
pred_text = [data['pred_text'] for data in test_data]
ref_text = [data['text'] for data in test_data]

# count how often each (difference type, hypothesis fragment, reference fragment) occurs
diff_vocab = Counter()

for pred_sent, ref_sent in zip(pred_text, ref_text):
    pred_words = pred_sent.split()
    ref_words = ref_sent.split()

    for tag, hyp_fragment, ref_fragment, i1, i2, j1, j2 in get_fragments(pred_words, ref_words):
        if tag != "equal":
            diff_vocab[(tag, hyp_fragment, ref_fragment)] += 1

sum_ = 0
print("PRED vs REF")
for k, v in diff_vocab.most_common():
    sum_ += v
    print(k, v, "sum=", sum_)
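A tiny standalone illustration of what `SequenceMatcher` reports for a hypothesis/reference pair, using the example sentence from the detailed example above. `changed_fragments` is a hypothetical condensed version of the loop above:

```python
from difflib import SequenceMatcher
from typing import List, Tuple


def changed_fragments(hyp: str, ref: str) -> List[Tuple[str, str, str]]:
    """Return (tag, hypothesis fragment, reference fragment) for every non-equal opcode."""
    hyp_words, ref_words = hyp.split(), ref.split()
    result = []
    for tag, i1, i2, j1, j2 in SequenceMatcher(None, hyp_words, ref_words).get_opcodes():
        if tag != "equal":
            result.append((tag, " ".join(hyp_words[i1:i2]), " ".join(ref_words[j1:j2])))
    return result


print(changed_fragments(
    "the tarasic oorda is a part of the aorta",
    "the thoracic aorta is a part of the aorta",
))
# [('replace', 'tarasic oorda', 'thoracic aorta')]
```

Because the comparison works on word lists, a misrecognized two-word term shows up as one `replace` fragment rather than two separate word errors, which makes the counts above easy to read.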


## Run SpellMapper

Now we run retrieval on our input manifest and prepare the input for SpellMapper inference. Note that we use the index of the custom vocabulary (the file `index.txt` that we saved earlier).

In [None]:
!python nemo/examples/nlp/spellchecking_asr_customization/prepare_input_from_manifest.py \
  --manifest ctc_baseline_transcript.json \
  --custom_vocab_index index.txt \
  --big_sample spellmapper_asr_customization_en/big_sample.txt \
  --short2full_name short2full.txt \
  --output_name spellmapper_input.txt

Run the inference.

In [None]:
!python nemo/examples/nlp/spellchecking_asr_customization/spellchecking_asr_customization_infer.py \
      pretrained_model=spellmapper_asr_customization_en/training_10m_5ep.nemo \
      model.max_sequence_len=512 \
      inference.from_file=spellmapper_input.txt \
      inference.out_file=spellmapper_output.txt \
      inference.batch_size=16 \
      lang=en


Now we post-process the SpellMapper output and create the corrected output manifest.

In [None]:
!python nemo/examples/nlp/spellchecking_asr_customization/postprocess_and_update_manifest.py \
  --input_manifest ctc_baseline_transcript.json \
  --short2full_name short2full.txt \
  --output_manifest ctc_corrected_transcript.json \
  --spellmapper_result spellmapper_output.txt \
  --replace_hyphen_to_space \
  --field_name pred_text \
  --ngram_mappings ""


### Calculating WER of the corrected transcript

In [None]:
!python nemo/examples/asr/speech_to_text_eval.py \
  dataset_manifest=ctc_corrected_transcript.json \
  only_score_manifest=True


In [None]:
test_data = read_manifest("ctc_corrected_transcript.json")
pred_text = [data['pred_text'] for data in test_data]
# compare against the uncorrected baseline kept in the corrected manifest
ref_text = [data['pred_text_before_correction'] for data in test_data]

diff_vocab = Counter()

for pred_sent, ref_sent in zip(pred_text, ref_text):
    pred_words = pred_sent.split()
    ref_words = ref_sent.split()

    for tag, hyp_fragment, ref_fragment, i1, i2, j1, j2 in get_fragments(pred_words, ref_words):
        if tag != "equal":
            diff_vocab[(tag, hyp_fragment, ref_fragment)] += 1

sum_ = 0
print("Corrected vs baseline")
for k, v in diff_vocab.most_common():
    sum_ += v
    print(k, v, "sum=", sum_)


### Filtering by Dynamic Programming (DP) score

What else can be done?
Given a fragment and its potential replacement, we can apply **dynamic programming** to find the most probable "translation" path between them. We use the same n-gram mapping vocabulary, because its frequencies give us the "translation probability" of each n-gram pair. The final path score is the maximum sum of log probabilities of matching n-grams along a path.
Let's look at an example.

In [None]:
joint_vocab, orig_vocab, misspelled_vocab, max_len = load_ngram_mappings_for_dp("spellmapper_asr_customization_en/replacement_vocab_filt.txt")

fragment = "and hydrod"
replacement = "anhydride"
fragment_spaced = " ".join(list(fragment.replace(" ", "_")))
replacement_spaced = " ".join(list(replacement.replace(" ", "_")))
path = get_alignment_by_dp(
    replacement_spaced,
    fragment_spaced,
    dp_data=(joint_vocab, orig_vocab, misspelled_vocab, max_len)
)
print("Dynamic Programming path:")
for fragment_ngram, replacement_ngram, score, sum_score, joint_freq, orig_freq, misspelled_freq in path:
    print(
        "\t",
        "frag=",
        fragment_ngram,
        "; repl=",
        replacement_ngram,
        "; score=",
        score,
        "; sum_score=",
        sum_score,
        "; joint_freq=",
        joint_freq,
        "; orig_freq=",
        orig_freq,
        "; misspelled_freq=",
        misspelled_freq,
    )

print("Final path score is in path[-1][3]: ", path[-1][3])
print("Dynamic programming(DP) score per symbol is final score divided by len(fragment): ", path[-1][3] / (len(fragment)))


The idea is that we can skip replacements whose average DP score per symbol is below some predefined minimum, say -1.5.
Note that dynamic programming is slow because of its quadratic complexity, but it allows us to get rid of some false positives. Let's apply it to the same test set.
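The DP score and the per-symbol filter can be sketched as follows. This is a simplified illustration, not `get_alignment_by_dp`: it assumes a dictionary `log_probs` keyed by `(original chunk, misspelled chunk)` plain strings without `+`/`<DELETE>` markup (an empty misspelled chunk stands for a deletion), it returns only the best score rather than the path, and the helpers `best_path_score` and `keep_replacement` are hypothetical. As in the cell above, the per-symbol score divides by the length of the (misspelled) fragment. The cost is O(len(original) * len(misspelled) * max_len^2), which is the quadratic behaviour mentioned above.

```python
import math
from typing import Dict, Tuple

NEG_INF = -math.inf


def best_path_score(
    original: str,
    misspelled: str,
    log_probs: Dict[Tuple[str, str], float],
    max_len: int = 3,
) -> float:
    """Maximum total log probability of segmenting both strings into aligned chunk pairs.

    An original chunk must be non-empty; a misspelled chunk may be empty (deletion).
    Returns -inf if no complete path exists.
    """
    n, m = len(original), len(misspelled)
    # best[i][j]: best score for aligning original[:i] with misspelled[:j]
    best = [[NEG_INF] * (m + 1) for _ in range(n + 1)]
    best[0][0] = 0.0
    for i in range(1, n + 1):
        for j in range(0, m + 1):
            for a in range(1, min(max_len, i) + 1):
                for b in range(0, min(max_len, j) + 1):
                    prev = best[i - a][j - b]
                    if prev == NEG_INF:
                        continue
                    lp = log_probs.get((original[i - a:i], misspelled[j - b:j]))
                    if lp is not None and prev + lp > best[i][j]:
                        best[i][j] = prev + lp
    return best[n][m]


def keep_replacement(original: str, misspelled: str, log_probs: Dict[Tuple[str, str], float],
                     min_score_per_symbol: float = -1.5, max_len: int = 3) -> bool:
    """Keep a replacement only if its DP score per fragment symbol reaches the threshold."""
    if not misspelled:
        return False
    score = best_path_score(original, misspelled, log_probs, max_len)
    return score / len(misspelled) >= min_score_per_symbol


log_probs = {
    ("c", "c"): -0.1, ("a", "a"): -0.1, ("t", "t"): -0.1,
    ("c", "k"): -2.0,  # invented: "c" recognized as "k"
    ("a", ""): -3.0,   # invented: "a" deleted
}
for fragment in ["kat", "ct"]:
    print(fragment, best_path_score("cat", fragment, log_probs), keep_replacement("cat", fragment, log_probs))
```

In the tutorial the strings also contain `_` for spaces, so a split fragment like `and_hydrod` can still be aligned with `anhydride` as long as the mapping vocabulary contains a pair that absorbs the `_`.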

In [None]:
!python nemo/examples/nlp/spellchecking_asr_customization/postprocess_and_update_manifest.py \
  --input_manifest ctc_baseline_transcript.json \
  --short2full_name short2full.txt \
  --output_manifest ctc_corrected_transcript_dp.json \
  --spellmapper_result spellmapper_output.txt \
  --replace_hyphen_to_space \
  --field_name pred_text \
  --use_dp \
  --ngram_mappings spellmapper_asr_customization_en/replacement_vocab_filt.txt \
  --min_dp_score_per_symbol -1.5

In [None]:
!python nemo/examples/asr/speech_to_text_eval.py \
  dataset_manifest=ctc_corrected_transcript_dp.json \
  only_score_manifest=True

# Final notes
1. Bash script with an example of the inference pipeline: [run_infer.sh](https://github.com/NVIDIA/NeMo/blob/main/examples/nlp/spellchecking_asr_customization/run_infer.sh)

2. Check our paper: [SpellMapper: A non-autoregressive neural spellchecker for ASR customization with candidate retrieval based on n-gram mappings](https://arxiv.org/abs/2306.02317)

3. To reproduce the evaluation experiments from this paper, see these scripts:
 - [test_on_kensho.sh](https://github.com/bene-ges/nemo_compatible/blob/main/scripts/nlp/en_spellmapper/evaluation/test_on_kensho.sh)
 - [test_on_userlibri.sh](https://github.com/bene-ges/nemo_compatible/blob/main/scripts/nlp/en_spellmapper/evaluation/test_on_userlibri.sh)
 - [test_on_spoken_wikipedia.sh](https://github.com/bene-ges/nemo_compatible/blob/main/scripts/nlp/en_spellmapper/evaluation/test_on_spoken_wikipedia.sh)

4. To reproduce creation of training data see [README.md](https://github.com/bene-ges/nemo_compatible/blob/main/scripts/nlp/en_spellmapper/README.md)

5. To run training see [run_training.sh](https://github.com/NVIDIA/NeMo/blob/main/examples/nlp/spellchecking_asr_customization/run_training.sh)

6. Promising future research directions are:
  - adding a simple trainable classifier on top of SpellMapper predictions instead of using multiple thresholds;
  - retraining with more diverse false positives added to the training data.