Audio Norm #2285

Merged (24 commits) on Jun 8, 2021
Changes from 21 commits
26 changes: 26 additions & 0 deletions Jenkinsfile
@@ -167,6 +167,32 @@ pipeline {
sh 'rm -rf /home/TestData/nlp/text_denorm/output/*'
}
}
stage('L2: TN with Audio (audio and raw text)') {
steps {
sh 'cd nemo_text_processing/text_normalization && \
python normalize_with_audio.py --text "The total amounts to \\$4.76." \
--audio_data /home/TestData/nlp/text_norm/audio_based/audio.wav | tail -n2 | head -n1 > /home/TestData/nlp/text_norm/audio_based/output/out_raw.txt 2>&1 && \
cmp --silent /home/TestData/nlp/text_norm/audio_based/output/out_raw.txt /home/TestData/nlp/text_norm/audio_based/result.txt || exit 1'
sh 'rm -rf /home/TestData/nlp/text_norm/audio_based/output/out_raw.txt'
}
}
stage('L2: TN with Audio (audio and text file)') {
steps {
sh 'cd nemo_text_processing/text_normalization && \
python normalize_with_audio.py --text /home/TestData/nlp/text_norm/audio_based/text.txt \
--audio_data /home/TestData/nlp/text_norm/audio_based/audio.wav | tail -n2 | head -n1 > /home/TestData/nlp/text_norm/audio_based/output/out_file.txt 2>&1 && \
cmp --silent /home/TestData/nlp/text_norm/audio_based/output/out_file.txt /home/TestData/nlp/text_norm/audio_based/result.txt || exit 1'
sh 'rm -rf /home/TestData/nlp/text_norm/audio_based/output/out_file.txt'
}
}
stage('L2: TN with Audio (manifest)') {
steps {
sh 'cd nemo_text_processing/text_normalization && \
python normalize_with_audio.py --audio_data /home/TestData/nlp/text_norm/audio_based/manifest.json --n_tagged=120 && \
cmp --silent /home/TestData/nlp/text_norm/audio_based/manifest_normalized.json /home/TestData/nlp/text_norm/audio_based/manifest_result.json || exit 1'
sh 'rm -rf /home/TestData/nlp/text_norm/audio_based/manifest_normalized.json'
}
}
}
}
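The three stages above drive normalize_with_audio.py through its CLI with raw text, a text file, and a .json manifest. As a reference point, below is a minimal sketch of the raw-text path in Python, using the class added later in this diff; the input_case value and the example sentence are assumptions, not taken from the CI configuration.

# Hedged sketch of the debugging path (text only, no audio), mirroring the CLI calls above.
from nemo_text_processing.text_normalization.normalize_with_audio import NormalizerWithAudio

normalizer = NormalizerWithAudio(input_case='cased')  # 'cased' is an assumed input_case value
options = normalizer.normalize(
    text='The total amounts to $4.76.',
    n_tagged=120,               # as in the manifest stage above; -1 would keep all tagged options
    punct_post_process=True,    # clean up punctuation and spacing in the candidates
    verbose=False,
)
for option in options:          # normalize() returns a set of candidate normalizations
    print(option)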

@@ -0,0 +1,2 @@
k kay
j jay
@@ -1,17 +1,14 @@
II. the Second
II Second
III. the Third
III Third
IV. the Fourth
IV Fourth
VIII. the Eighth
Hon. honorable
Hon. honourable
St. street
St street
St. saint
St saint
Dr. drive
Dr. doctor
Mr mister
Mrs misses
Hon. Honorable
Hon. Honourable
St. Street
St Street
St. Saint
St Saint
Dr. Drive
Dr. Doctor
Mr. Mister
Mrs. Misses
Ms. Miss
Mr Mister
Mrs Misses
Ms Miss
139 changes: 87 additions & 52 deletions nemo_text_processing/text_normalization/normalize_with_audio.py
@@ -15,6 +15,7 @@
import json
import os
import re
import time
from argparse import ArgumentParser
from typing import List, Tuple

@@ -28,6 +29,7 @@

try:
import pynini
from pynini.lib import rewrite

PYNINI_AVAILABLE = True
except (ModuleNotFoundError, ImportError):
@@ -76,14 +78,15 @@ def __init__(self, input_case: str):
self.tagger = ClassifyFst(input_case=input_case, deterministic=False)
self.verbalizer = VerbalizeFinalFst(deterministic=False)

def normalize_with_audio(self, text: str, verbose: bool = False) -> str:
def normalize(self, text: str, n_tagged: int, punct_post_process: bool = True, verbose: bool = False) -> str:
"""
Main function. Normalizes tokens from written to spoken form
e.g. 12 kg -> twelve kilograms

Args:
text: string that may include semiotic classes
transcript: transcription of the audio
n_tagged: number of tagged options to consider, -1 - to get all possible tagged options
punct_post_process: whether to normalize punctuation
verbose: whether to print intermediate meta information

Returns:
@@ -94,51 +97,65 @@ def normalize_with_audio(self, text: str, verbose: bool = False) -> str:
if verbose:
print(text)
return text
text = pynini.escape(text)

def get_tagged_texts(text):
tagged_lattice = self.find_tags(text)
tagged_texts = self.select_all_semiotic_tags(tagged_lattice)
return tagged_texts
text = pynini.escape(text)
if n_tagged == -1:
tagged_texts = rewrite.rewrites(text, self.tagger.fst)
else:
tagged_texts = rewrite.top_rewrites(text, self.tagger.fst, nshortest=n_tagged)

tagged_texts = set(get_tagged_texts(text))
normalized_texts = []

for tagged_text in tagged_texts:
self.parser(tagged_text)
tokens = self.parser.parse()
tags_reordered = self.generate_permutations(tokens)
for tagged_text_reordered in tags_reordered:
tagged_text_reordered = pynini.escape(tagged_text_reordered)

verbalizer_lattice = self.find_verbalizer(tagged_text_reordered)
if verbalizer_lattice.num_states() == 0:
continue

verbalized = self.get_all_verbalizers(verbalizer_lattice)
for verbalized_option in verbalized:
normalized_texts.append(verbalized_option)
self._verbalize(tagged_text, normalized_texts)

if len(normalized_texts) == 0:
raise ValueError()

normalized_texts = [post_process(t) for t in normalized_texts]
if punct_post_process:
normalized_texts = [post_process(t) for t in normalized_texts]
normalized_texts = set(normalized_texts)
return normalized_texts

def select_best_match(self, normalized_texts: List[str], transcript: str, verbose: bool = False):
def _verbalize(self, tagged_text: str, normalized_texts: List[str]):
"""
Verbalizes tagged text

Args:
tagged_text: text with tags
normalized_texts: list of possible normalization options
"""

def get_verbalized_text(tagged_text):
tagged_text = pynini.escape(tagged_text)
return rewrite.rewrites(tagged_text, self.verbalizer.fst)

try:
normalized_texts.extend(get_verbalized_text(tagged_text))
except pynini.lib.rewrite.Error:
self.parser(tagged_text)
tokens = self.parser.parse()
tags_reordered = self.generate_permutations(tokens)
for tagged_text_reordered in tags_reordered:
try:
normalized_texts.extend(get_verbalized_text(tagged_text_reordered))
except pynini.lib.rewrite.Error:
continue

def select_best_match(
self, normalized_texts: List[str], transcript: str, verbose: bool = False, remove_punct: bool = False
):
"""
Selects the best normalization option based on the lowest CER

Args:
normalized_texts: normalized text options
transcript: ASR model transcript of the audio file corresponding to the normalized text
verbose: whether to print intermediate meta information
remove_punct: whether to remove punctuation before calculating CER

Returns:
normalized text with the lowest CER and CER value
"""
normalized_texts = calculate_cer(normalized_texts, transcript)
normalized_texts = calculate_cer(normalized_texts, transcript, remove_punct)
normalized_texts = sorted(normalized_texts, key=lambda x: x[1])
normalized_text, cer = normalized_texts[0]

@@ -149,16 +166,6 @@ def select_best_match(self, normalized_texts: List[str], transcript: str, verbose: bool = False):
print('-' * 30)
return normalized_text, cer

def select_all_semiotic_tags(self, lattice: 'pynini.FstLike', n=100) -> List[str]:
tagged_text_options = pynini.shortestpath(lattice, nshortest=n)
tagged_text_options = [t[1] for t in tagged_text_options.paths("utf8").items()]
return tagged_text_options

def get_all_verbalizers(self, lattice: 'pynini.FstLike', n=100) -> List[str]:
verbalized_options = pynini.shortestpath(lattice, nshortest=n)
verbalized_options = [t[1] for t in verbalized_options.paths("utf8").items()]
return verbalized_options


def calculate_cer(normalized_texts: List[str], transcript: str, remove_punct=False) -> List[Tuple[str, float]]:
"""
@@ -172,7 +179,7 @@
"""
normalized_options = []
for text in normalized_texts:
text_clean = text.replace('-', ' ').lower().strip()
text_clean = text.replace('-', ' ').lower()
if remove_punct:
for punct in "!?:;,.-()*+-/<=>@^_":
text_clean = text_clean.replace(punct, " ")
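The remainder of calculate_cer is collapsed in this view. For intuition, here is a hedged illustration of the ranking that feeds select_best_match: score every candidate against the ASR transcript with a character error rate and keep the lowest. The edit-distance helper and the example strings below are illustrative assumptions, not the repository's implementation.

# Illustration only: character-level Levenshtein distance normalized by reference length.
def char_error_rate(hyp: str, ref: str) -> float:
    prev = list(range(len(ref) + 1))
    for i, h in enumerate(hyp, 1):
        curr = [i]
        for j, r in enumerate(ref, 1):
            curr.append(min(prev[j] + 1, curr[j - 1] + 1, prev[j - 1] + (h != r)))
        prev = curr
    return prev[-1] / max(len(ref), 1)

transcript = 'the total amounts to four dollars seventy six cents'
candidates = [
    'the total amounts to four dollars seventy six cents',
    'the total amounts to four point seven six dollars',
]
# Same cleanup as above: drop hyphens and lower-case before scoring.
ranked = sorted(candidates, key=lambda c: char_error_rate(c.replace('-', ' ').lower(), transcript))
print(ranked[0])  # the candidate with the lowest CER is selected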
@@ -215,8 +222,7 @@ def post_process(text: str, punctuation='!,.:;?') -> str:
Returns: text with normalized spaces and quotes
"""
text = (
text.replace('--', '-')
.replace('( ', '(')
text.replace('( ', '(')
.replace(' )', ')')
.replace(' ', ' ')
.replace('”', '"')
@@ -265,17 +271,27 @@ def parse_args():
parser.add_argument(
'--model', type=str, default='QuartzNet15x5Base-En', help='Pre-trained model name or path to model checkpoint'
)
parser.add_argument("--verbose", help="print info for debugging", action='store_true')
parser.add_argument(
"--n_tagged",
type=int,
default=300,
help="number of tagged options to consider, -1 - return all possible tagged options",
)
parser.add_argument("--verbose", help="print info for debugging", action="store_true")
parser.add_argument("--remove_punct", help="remove punctuation before calculating cer", action="store_true")
parser.add_argument(
"--no_punct_post_process", help="set to True to disable punctuation post processing", action="store_true"
)
return parser.parse_args()


def normalize_manifest(args):
"""
Args:
manifest: path to .json manifest file.
args.audio_data: path to .json manifest file.
"""
normalizer = NormalizerWithAudio(input_case=args.input_case)
manifest_out = args.audio_data.replace('.json', '_nemo_wfst.json')
manifest_out = args.audio_data.replace('.json', '_normalized.json')
asr_model = None
with open(args.audio_data, 'r') as f:
with open(manifest_out, 'w') as f_out:
@@ -288,30 +304,48 @@ def normalize_manifest(args):
if asr_model is None:
asr_model = get_asr_model(args.model)
transcript = asr_model.transcribe([audio])[0]
normalized_texts = normalizer.normalize_with_audio(line['text'], args.verbose)
normalized_text, cer = normalizer.select_best_match(normalized_texts, transcript, args.verbose)

line['nemo_wfst'] = normalized_text
line['CER_nemo_wfst'] = cer
normalized_texts = normalizer.normalize(
text=line['text'],
verbose=args.verbose,
n_tagged=args.n_tagged,
punct_post_process=not args.no_punct_post_process,
)
normalized_text, cer = normalizer.select_best_match(
normalized_texts, transcript, args.verbose, args.remove_punct
)
line['nemo_normalized'] = normalized_text
line['CER_nemo_normalized'] = cer
f_out.write(json.dumps(line, ensure_ascii=False) + '\n')
print(f'Normalized version saved at {manifest_out}')


if __name__ == "__main__":
args = parse_args()

start = time.time()
if args.text:
normalizer = NormalizerWithAudio(input_case=args.input_case)
if os.path.exists(args.text):
with open(args.text, 'r') as f:
args.text = f.read()
normalized_texts = normalizer.normalize_with_audio(args.text, args.verbose)
for norm_text in normalized_texts:
print(norm_text)
args.text = f.read().strip()
normalized_texts = normalizer.normalize(
text=args.text,
verbose=args.verbose,
n_tagged=args.n_tagged,
punct_post_process=not args.no_punct_post_process,
)
if args.audio_data:
asr_model = get_asr_model(args.model)
transcript = asr_model.transcribe([args.audio_data])[0]
normalized_text, cer = normalizer.select_best_match(normalized_texts, transcript, args.verbose)
normalized_text, cer = normalizer.select_best_match(
normalized_texts, transcript, args.verbose, args.remove_punct
)
print(f'Transcript: {transcript}')
print(f'Normalized: {normalized_text}')
else:
print('Normalization options:')
for norm_text in normalized_texts:
print(norm_text)
elif not os.path.exists(args.audio_data):
raise ValueError(f'{args.audio_data} not found.')
elif args.audio_data.endswith('.json'):
@@ -322,3 +356,4 @@ def normalize_manifest(args):
+ "'--audio_data' path to audio file and '--text' path to a text file OR"
"'--text' string text (for debugging without audio)"
)
print(f'Execution time: {round((time.time() - start)/60, 2)} min.')
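Putting the pieces together, here is a hedged sketch of the audio path that __main__ wires up: generate candidates, transcribe the audio, and pick the candidate closest to the transcript. The audio path is a placeholder and importing get_asr_model from this module is an assumption; the call signatures follow the code above.

# Hedged end-to-end sketch; '/path/to/audio.wav' is a placeholder.
from nemo_text_processing.text_normalization.normalize_with_audio import (
    NormalizerWithAudio,
    get_asr_model,  # assumed importable from this module, where it is called above
)

normalizer = NormalizerWithAudio(input_case='cased')  # 'cased' is an assumed input_case value
options = normalizer.normalize(text='12 kg', n_tagged=300, punct_post_process=True, verbose=False)

asr_model = get_asr_model('QuartzNet15x5Base-En')      # default --model from parse_args()
transcript = asr_model.transcribe(['/path/to/audio.wav'])[0]
best, cer = normalizer.select_best_match(options, transcript, verbose=False, remove_punct=False)
print(f'Normalized: {best} (CER: {cer})')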