Audio Norm (#2285)
* add jenkins test, refactoring

Signed-off-by: ekmb <ebakhturina@nvidia.com>

* update test

Signed-off-by: ekmb <ebakhturina@nvidia.com>

* fix new test

Signed-off-by: ekmb <ebakhturina@nvidia.com>

* add serial to the default normalizer, add tests

Signed-off-by: ekmb <ebakhturina@nvidia.com>

* manifest test added

Signed-off-by: ekmb <ebakhturina@nvidia.com>

* expose more params, new test cases

Signed-off-by: ekmb <ebakhturina@nvidia.com>

* fix jenkins, serial clean, exclude range from cardinal

Signed-off-by: ekmb <ebakhturina@nvidia.com>

* jenkins

Signed-off-by: ekmb <ebakhturina@nvidia.com>

* jenkins dollar sign format

Signed-off-by: ekmb <ebakhturina@nvidia.com>

* jenkins

Signed-off-by: ekmb <ebakhturina@nvidia.com>

* jenkins dollar sign format

Signed-off-by: ekmb <ebakhturina@nvidia.com>

* addressed review comments

Signed-off-by: ekmb <ebakhturina@nvidia.com>

* fix decimal in measure

Signed-off-by: ekmb <ebakhturina@nvidia.com>

* move serial in cardinal

Signed-off-by: ekmb <ebakhturina@nvidia.com>

* clean up

Signed-off-by: ekmb <ebakhturina@nvidia.com>

* update for SH zero -> oh

Signed-off-by: ekmb <ebakhturina@nvidia.com>

* change n_tagger default

Signed-off-by: ekmb <ebakhturina@nvidia.com>
Signed-off-by: mchrzanowski <mchrzanowski@nvidia.com>
ekmb authored and mchrzanowski committed Jun 23, 2021
1 parent 020d925 commit 5ff8182
Showing 31 changed files with 458 additions and 249 deletions.
26 changes: 26 additions & 0 deletions Jenkinsfile
@@ -167,6 +167,32 @@ pipeline {
sh 'rm -rf /home/TestData/nlp/text_denorm/output/*'
}
}
stage('L2: TN with Audio (audio and raw text)') {
steps {
sh 'cd nemo_text_processing/text_normalization && \
python normalize_with_audio.py --text "The total amounts to \\$4.76." \
--audio_data /home/TestData/nlp/text_norm/audio_based/audio.wav | tail -n2 | head -n1 > /home/TestData/nlp/text_norm/audio_based/output/out_raw.txt 2>&1 && \
cmp --silent /home/TestData/nlp/text_norm/audio_based/output/out_raw.txt /home/TestData/nlp/text_norm/audio_based/result.txt || exit 1'
sh 'rm -rf /home/TestData/nlp/text_norm/audio_based/output/out_raw.txt'
}
}
stage('L2: TN with Audio (audio and text file)') {
steps {
sh 'cd nemo_text_processing/text_normalization && \
python normalize_with_audio.py --text /home/TestData/nlp/text_norm/audio_based/text.txt \
--audio_data /home/TestData/nlp/text_norm/audio_based/audio.wav | tail -n2 | head -n1 > /home/TestData/nlp/text_norm/audio_based/output/out_file.txt 2>&1 && \
cmp --silent /home/TestData/nlp/text_norm/audio_based/output/out_file.txt /home/TestData/nlp/text_norm/audio_based/result.txt || exit 1'
sh 'rm -rf /home/TestData/nlp/text_norm/audio_based/output/out_file.txt'
}
}
stage('L2: TN with Audio (manifest)') {
steps {
sh 'cd nemo_text_processing/text_normalization && \
python normalize_with_audio.py --audio_data /home/TestData/nlp/text_norm/audio_based/manifest.json --n_tagged=120 && \
cmp --silent /home/TestData/nlp/text_norm/audio_based/manifest_normalized.json /home/TestData/nlp/text_norm/audio_based/manifest_result.json || exit 1'
sh 'rm -rf /home/TestData/nlp/text_norm/audio_based/manifest_normalized.json'
}
}
}
}
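Outside CI, the check performed by the first new stage can be reproduced roughly as follows. This is a sketch only: it assumes a repository checkout with nemo_text_processing installed, and the /home/TestData paths are simply the ones the stage above uses.

```python
# Sketch: mirror the "audio and raw text" stage locally (paths are the CI ones).
import subprocess

cmd = [
    "python", "normalize_with_audio.py",
    "--text", "The total amounts to $4.76.",
    "--audio_data", "/home/TestData/nlp/text_norm/audio_based/audio.wav",
]
result = subprocess.run(
    cmd,
    cwd="nemo_text_processing/text_normalization",  # run from the repo root
    capture_output=True, text=True, check=True,
)
# As in the stage's `tail -n2 | head -n1`, the line holding the normalized text
# is the second-to-last line of stdout (the last line reports execution time).
normalized = result.stdout.splitlines()[-2]
expected = open("/home/TestData/nlp/text_norm/audio_based/result.txt").read().strip()
assert normalized.strip() == expected
```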

@@ -0,0 +1,2 @@
k kay
j jay
@@ -1,17 +1,14 @@
II. the Second
II Second
III. the Third
III Third
IV. the Fourth
IV Fourth
VIII. the Eighth
Hon. honorable
Hon. honourable
St. street
St street
St. saint
St saint
Dr. drive
Dr. doctor
Mr mister
Mrs misses
Hon. Honorable
Hon. Honourable
St. Street
St Street
St. Saint
St Saint
Dr. Drive
Dr. Doctor
Mr. Mister
Mrs. Misses
Ms. Miss
Mr Mister
Mrs Misses
Ms Miss
139 changes: 87 additions & 52 deletions nemo_text_processing/text_normalization/normalize_with_audio.py
@@ -15,6 +15,7 @@
import json
import os
import re
import time
from argparse import ArgumentParser
from typing import List, Tuple

@@ -28,6 +29,7 @@

try:
import pynini
from pynini.lib import rewrite

PYNINI_AVAILABLE = True
except (ModuleNotFoundError, ImportError):
@@ -76,14 +78,15 @@ def __init__(self, input_case: str):
self.tagger = ClassifyFst(input_case=input_case, deterministic=False)
self.verbalizer = VerbalizeFinalFst(deterministic=False)

def normalize_with_audio(self, text: str, verbose: bool = False) -> str:
def normalize(self, text: str, n_tagged: int, punct_post_process: bool = True, verbose: bool = False) -> str:
"""
Main function. Normalizes tokens from written to spoken form
e.g. 12 kg -> twelve kilograms
Args:
text: string that may include semiotic classes
transcript: transcription of the audio
n_tagged: number of tagged options to consider, -1 - to get all possible tagged options
punct_post_process: whether to normalize punctuation
verbose: whether to print intermediate meta information
Returns:
@@ -94,51 +97,65 @@ def normalize_with_audio(self, text: str, verbose: bool = False) -> str:
if verbose:
print(text)
return text
text = pynini.escape(text)

def get_tagged_texts(text):
tagged_lattice = self.find_tags(text)
tagged_texts = self.select_all_semiotic_tags(tagged_lattice)
return tagged_texts
text = pynini.escape(text)
if n_tagged == -1:
tagged_texts = rewrite.rewrites(text, self.tagger.fst)
else:
tagged_texts = rewrite.top_rewrites(text, self.tagger.fst, nshortest=n_tagged)

tagged_texts = set(get_tagged_texts(text))
normalized_texts = []

for tagged_text in tagged_texts:
self.parser(tagged_text)
tokens = self.parser.parse()
tags_reordered = self.generate_permutations(tokens)
for tagged_text_reordered in tags_reordered:
tagged_text_reordered = pynini.escape(tagged_text_reordered)

verbalizer_lattice = self.find_verbalizer(tagged_text_reordered)
if verbalizer_lattice.num_states() == 0:
continue

verbalized = self.get_all_verbalizers(verbalizer_lattice)
for verbalized_option in verbalized:
normalized_texts.append(verbalized_option)
self._verbalize(tagged_text, normalized_texts)

if len(normalized_texts) == 0:
raise ValueError()

normalized_texts = [post_process(t) for t in normalized_texts]
if punct_post_process:
normalized_texts = [post_process(t) for t in normalized_texts]
normalized_texts = set(normalized_texts)
return normalized_texts
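The n_tagged switch maps directly onto pynini's rewrite helpers used above: rewrite.rewrites() enumerates every output string of the composed lattice, while rewrite.top_rewrites() keeps only the n shortest paths. A toy illustration (a sketch with a hand-built rule, assuming pynini is installed; it is not the NeMo tagger itself):

```python
import pynini
from pynini.lib import rewrite

# Hypothetical two-way rule: "2" can be read as a cardinal or digit-by-digit.
rule = (pynini.cross("2", "two") | pynini.cross("2", "digit two")).optimize()

print(rewrite.rewrites("2", rule))                   # all rewrite options
print(rewrite.top_rewrites("2", rule, nshortest=1))  # at most one option (ties broken arbitrarily)
```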

def select_best_match(self, normalized_texts: List[str], transcript: str, verbose: bool = False):
def _verbalize(self, tagged_text: str, normalized_texts: List[str]):
"""
Verbalizes tagged text
Args:
tagged_text: text with tags
normalized_texts: list of possible normalization options
"""

def get_verbalized_text(tagged_text):
tagged_text = pynini.escape(tagged_text)
return rewrite.rewrites(tagged_text, self.verbalizer.fst)

try:
normalized_texts.extend(get_verbalized_text(tagged_text))
except pynini.lib.rewrite.Error:
self.parser(tagged_text)
tokens = self.parser.parse()
tags_reordered = self.generate_permutations(tokens)
for tagged_text_reordered in tags_reordered:
try:
normalized_texts.extend(get_verbalized_text(tagged_text_reordered))
except pynini.lib.rewrite.Error:
continue

def select_best_match(
self, normalized_texts: List[str], transcript: str, verbose: bool = False, remove_punct: bool = False
):
"""
Selects the best normalization option based on the lowest CER
Args:
normalized_texts: normalized text options
transcript: ASR model transcript of the audio file corresponding to the normalized text
verbose: whether to print intermediate meta information
remove_punct: whether to remove punctuation before calculating CER
Returns:
normalized text with the lowest CER and CER value
"""
normalized_texts = calculate_cer(normalized_texts, transcript)
normalized_texts = calculate_cer(normalized_texts, transcript, remove_punct)
normalized_texts = sorted(normalized_texts, key=lambda x: x[1])
normalized_text, cer = normalized_texts[0]

@@ -149,16 +166,6 @@ def select_best_match(self, normalized_texts: List[str], transcript: str, verbos
print('-' * 30)
return normalized_text, cer
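Putting the reworked methods together, a typical call sequence would look like the sketch below. The audio path is a placeholder, input_case="cased" is assumed to be a valid choice, and get_asr_model is assumed to be importable from this module:

```python
from nemo_text_processing.text_normalization.normalize_with_audio import (
    NormalizerWithAudio,
    get_asr_model,  # assumption: exposed by this module
)

normalizer = NormalizerWithAudio(input_case="cased")
# Collect candidate normalizations; n_tagged=-1 would return every tagged option.
options = normalizer.normalize(
    text="The total amounts to $4.76.",
    n_tagged=100,
    punct_post_process=True,
    verbose=False,
)

# Rank the candidates by CER against what the ASR model heard in the audio.
asr_model = get_asr_model("QuartzNet15x5Base-En")
transcript = asr_model.transcribe(["/path/to/audio.wav"])[0]
best, cer = normalizer.select_best_match(options, transcript, verbose=False, remove_punct=False)
print(best, cer)
```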

def select_all_semiotic_tags(self, lattice: 'pynini.FstLike', n=100) -> List[str]:
tagged_text_options = pynini.shortestpath(lattice, nshortest=n)
tagged_text_options = [t[1] for t in tagged_text_options.paths("utf8").items()]
return tagged_text_options

def get_all_verbalizers(self, lattice: 'pynini.FstLike', n=100) -> List[str]:
verbalized_options = pynini.shortestpath(lattice, nshortest=n)
verbalized_options = [t[1] for t in verbalized_options.paths("utf8").items()]
return verbalized_options


def calculate_cer(normalized_texts: List[str], transcript: str, remove_punct=False) -> List[Tuple[str, float]]:
"""
@@ -172,7 +179,7 @@ def calculate_cer(normalized_texts: List[str], transcript: str, remove_punct=Fal
"""
normalized_options = []
for text in normalized_texts:
text_clean = text.replace('-', ' ').lower().strip()
text_clean = text.replace('-', ' ').lower()
if remove_punct:
for punct in "!?:;,.-()*+-/<=>@^_":
text_clean = text_clean.replace(punct, " ")
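The CER computation itself is outside this hunk; conceptually it is an edit distance over the cleaned strings, normalized by the transcript length. A self-contained sketch (not the helper NeMo actually calls):

```python
def char_error_rate(hypothesis: str, reference: str) -> float:
    """Levenshtein distance over characters, normalized by reference length."""
    prev = list(range(len(reference) + 1))
    for i, h in enumerate(hypothesis, start=1):
        cur = [i] + [0] * len(reference)
        for j, r in enumerate(reference, start=1):
            cur[j] = min(prev[j] + 1,             # deletion
                         cur[j - 1] + 1,          # insertion
                         prev[j - 1] + (h != r))  # substitution
        prev = cur
    return prev[-1] / max(len(reference), 1)

# Rank candidate normalizations the way select_best_match does: lowest CER first.
options = ["twelve kilograms", "one two kilograms"]
transcript = "twelve kilograms"
ranked = sorted(((t, char_error_rate(t, transcript)) for t in options), key=lambda x: x[1])
print(ranked[0])  # ('twelve kilograms', 0.0)
```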
@@ -215,8 +222,7 @@ def post_process(text: str, punctuation='!,.:;?') -> str:
Returns: text with normalized spaces and quotes
"""
text = (
text.replace('--', '-')
.replace('( ', '(')
text.replace('( ', '(')
.replace(' )', ')')
.replace('  ', ' ')
.replace('”', '"')
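For reference, the whitespace/quote part of post_process reduces to chained str.replace calls like the ones visible in this hunk. A standalone sketch of just those rules (the real function also handles the punctuation characters listed in its signature):

```python
def tidy_spaces_and_quotes(text: str) -> str:
    # Only the rules visible above: spaces inside parentheses, double spaces,
    # and curly double quotes.
    return (
        text.replace('( ', '(')
        .replace(' )', ')')
        .replace('  ', ' ')
        .replace('”', '"')
    )

print(tidy_spaces_and_quotes('he said ” hello ”  ( twice )'))
# -> 'he said " hello " (twice)'
```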
@@ -265,17 +271,27 @@ def parse_args():
parser.add_argument(
'--model', type=str, default='QuartzNet15x5Base-En', help='Pre-trained model name or path to model checkpoint'
)
parser.add_argument("--verbose", help="print info for debugging", action='store_true')
parser.add_argument(
"--n_tagged",
type=int,
default=1000,
help="number of tagged options to consider, -1 - return all possible tagged options",
)
parser.add_argument("--verbose", help="print info for debugging", action="store_true")
parser.add_argument("--remove_punct", help="remove punctuation before calculating cer", action="store_true")
parser.add_argument(
"--no_punct_post_process", help="set to True to disable punctuation post processing", action="store_true"
)
return parser.parse_args()


def normalize_manifest(args):
"""
Args:
manifest: path to .json manifest file.
args.audio_data: path to .json manifest file.
"""
normalizer = NormalizerWithAudio(input_case=args.input_case)
manifest_out = args.audio_data.replace('.json', '_nemo_wfst.json')
manifest_out = args.audio_data.replace('.json', '_normalized.json')
asr_model = None
with open(args.audio_data, 'r') as f:
with open(manifest_out, 'w') as f_out:
@@ -288,30 +304,48 @@ def normalize_manifest(args):
if asr_model is None:
asr_model = get_asr_model(args.model)
transcript = asr_model.transcribe([audio])[0]
normalized_texts = normalizer.normalize_with_audio(line['text'], args.verbose)
normalized_text, cer = normalizer.select_best_match(normalized_texts, transcript, args.verbose)

line['nemo_wfst'] = normalized_text
line['CER_nemo_wfst'] = cer
normalized_texts = normalizer.normalize(
text=line['text'],
verbose=args.verbose,
n_tagged=args.n_tagged,
punct_post_process=not args.no_punct_post_process,
)
normalized_text, cer = normalizer.select_best_match(
normalized_texts, transcript, args.verbose, args.remove_punct
)
line['nemo_normalized'] = normalized_text
line['CER_nemo_normalized'] = cer
f_out.write(json.dumps(line, ensure_ascii=False) + '\n')
print(f'Normalized version saved at {manifest_out}')
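Each manifest line is one JSON object; after normalization the script appends the two new keys shown above. A sketch of inspecting the produced file (the path is the one used by the manifest Jenkins stage; field names other than 'text' and the two added keys are not shown in this diff):

```python
import json

with open("/home/TestData/nlp/text_norm/audio_based/manifest_normalized.json") as f:
    for raw in f:
        entry = json.loads(raw)
        print(entry["text"])                 # original written-form text
        print(entry["nemo_normalized"])      # lowest-CER normalization option
        print(entry["CER_nemo_normalized"])  # its CER against the ASR transcript
```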


if __name__ == "__main__":
args = parse_args()

start = time.time()
if args.text:
normalizer = NormalizerWithAudio(input_case=args.input_case)
if os.path.exists(args.text):
with open(args.text, 'r') as f:
args.text = f.read()
normalized_texts = normalizer.normalize_with_audio(args.text, args.verbose)
for norm_text in normalized_texts:
print(norm_text)
args.text = f.read().strip()
normalized_texts = normalizer.normalize(
text=args.text,
verbose=args.verbose,
n_tagged=args.n_tagged,
punct_post_process=not args.no_punct_post_process,
)
if args.audio_data:
asr_model = get_asr_model(args.model)
transcript = asr_model.transcribe([args.audio_data])[0]
normalized_text, cer = normalizer.select_best_match(normalized_texts, transcript, args.verbose)
normalized_text, cer = normalizer.select_best_match(
normalized_texts, transcript, args.verbose, args.remove_punct
)
print(f'Transcript: {transcript}')
print(f'Normalized: {normalized_text}')
else:
print('Normalization options:')
for norm_text in normalized_texts:
print(norm_text)
elif not os.path.exists(args.audio_data):
raise ValueError(f'{args.audio_data} not found.')
elif args.audio_data.endswith('.json'):
@@ -322,3 +356,4 @@ def normalize_manifest(args):
+ "'--audio_data' path to audio file and '--text' path to a text file OR"
"'--text' string text (for debugging without audio)"
)
print(f'Execution time: {round((time.time() - start)/60, 2)} min.')