Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Patch transcribe_util for streaming mode and add WER calculation back to inference scripts #6601

Merged
merged 14 commits into from
May 9, 2023
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@
total_buffer_in_secs=4.0 \
chunk_len_in_secs=1.6 \
model_stride=4 \
batch_size=32
batch_size=32 \
clean_groundtruth_text=True \
langid='en'

# NOTE:
You can use `DEBUG=1 python speech_to_text_buffered_infer_ctc.py ...` to print out the
Expand All @@ -45,6 +47,8 @@
import torch
from omegaconf import OmegaConf

from nemo.collections.asr.metrics.wer import CTCDecodingConfig
from nemo.collections.asr.parts.utils.eval_utils import cal_write_wer
from nemo.collections.asr.parts.utils.streaming_utils import FrameBatchASR
from nemo.collections.asr.parts.utils.transcribe_utils import (
compute_output_filename,
Expand Down Expand Up @@ -79,6 +83,9 @@ class TranscriptionConfig:
total_buffer_in_secs: float = 4.0 # Length of buffer (chunk + left and right padding) in seconds
model_stride: int = 8 # Model downsampling factor, 8 for Citrinet models and 4 for Conformer models",

# Decoding strategy for CTC models
decoding: CTCDecodingConfig = CTCDecodingConfig()

# Set `cuda` to int to define CUDA device. If 'None', will look for CUDA
# device anyway, and do inference on CPU only if CUDA device is not found.
# If `cuda` is a negative number, inference will be on CPU only.
Expand All @@ -89,6 +96,12 @@ class TranscriptionConfig:
# Recompute model transcription, even if the output folder exists with scores.
overwrite_transcripts: bool = True

# Config for word / character error rate calculation
calculate_wer: bool = True
clean_groundtruth_text: bool = False
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would you please add some comments on the top for users how and when to use these two arguments: clean_groundtruth_text and langid?

langid: str = "en" # specify this for convert_num_to_words step in groundtruth cleaning
use_cer: bool = False


@hydra_runner(config_name="TranscriptionConfig", schema=TranscriptionConfig)
def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
Expand Down Expand Up @@ -188,11 +201,23 @@ def autocast():
manifest,
filepaths,
)
output_filename = write_transcription(
output_filename, pred_text_attr_name = write_transcription(
hyps, cfg, model_name, filepaths=filepaths, compute_langs=False, compute_timestamps=False
)
logging.info(f"Finished writing predictions to {output_filename}!")

if cfg.calculate_wer:
output_manifest_w_wer, total_res, _ = cal_write_wer(
pred_manifest=output_filename,
pred_text_attr_name=pred_text_attr_name,
clean_groundtruth_text=cfg.clean_groundtruth_text,
langid=cfg.langid,
use_cer=cfg.use_cer,
output_filename=None,
)
logging.info(f"Writing prediction and error rate of each sample to {output_manifest_w_wer}!")
logging.info(f"{total_res}")

return cfg


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@
total_buffer_in_secs=4.0 \
chunk_len_in_secs=1.6 \
model_stride=4 \
batch_size=32
batch_size=32 \
clean_groundtruth_text=True \
langid='en'

# Longest Common Subsequence (LCS) Merge algorithm

Expand Down Expand Up @@ -66,6 +68,7 @@
import torch
from omegaconf import OmegaConf, open_dict

from nemo.collections.asr.parts.utils.eval_utils import cal_write_wer
from nemo.collections.asr.parts.utils.streaming_utils import (
BatchedFrameASRRNNT,
LongestCommonSubsequenceBatchedFrameASRRNNT,
Expand Down Expand Up @@ -101,7 +104,7 @@ class TranscriptionConfig:
# Chunked configs
chunk_len_in_secs: float = 1.6 # Chunk length in seconds
total_buffer_in_secs: float = 4.0 # Length of buffer (chunk + left and right padding) in seconds
model_stride: int = 8 # Model downsampling factor, 8 for Citrinet models and 4 for Conformer models",
model_stride: int = 8 # Model downsampling factor, 8 for Citrinet models and 4 for Conformer models

# Set `cuda` to int to define CUDA device. If 'None', will look for CUDA
# device anyway, and do inference on CPU only if CUDA device is not found.
Expand All @@ -120,6 +123,12 @@ class TranscriptionConfig:
merge_algo: Optional[str] = 'middle' # choices=['middle', 'lcs'], choice of algorithm to apply during inference.
lcs_alignment_dir: Optional[str] = None # Path to a directory to store LCS algo alignments

# Config for word / character error rate calculation
calculate_wer: bool = True
clean_groundtruth_text: bool = False
langid: str = "en" # specify this for convert_num_to_words step in groundtruth cleaning
use_cer: bool = False


@hydra_runner(config_name="TranscriptionConfig", schema=TranscriptionConfig)
def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
Expand Down Expand Up @@ -194,9 +203,13 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
decoding_cfg.strategy = "greedy_batch"
decoding_cfg.preserve_alignments = True # required to compute the middle token for transducers.
decoding_cfg.fused_batch_size = -1 # temporarily stop fused batch during inference.
decoding_cfg.beam.return_best_hypothesis = True

asr_model.change_decoding_strategy(decoding_cfg)

with open_dict(cfg):
cfg.decoding = decoding_cfg

feature_stride = model_cfg.preprocessor['window_stride']
model_stride_in_secs = feature_stride * cfg.model_stride
total_buffer = cfg.total_buffer_in_secs
Expand Down Expand Up @@ -242,11 +255,23 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
filepaths=filepaths,
)

output_filename = write_transcription(
output_filename, pred_text_attr_name = write_transcription(
hyps, cfg, model_name, filepaths=filepaths, compute_langs=False, compute_timestamps=False
)
logging.info(f"Finished writing predictions to {output_filename}!")

if cfg.calculate_wer:
output_manifest_w_wer, total_res, _ = cal_write_wer(
pred_manifest=output_filename,
pred_text_attr_name=pred_text_attr_name,
clean_groundtruth_text=cfg.clean_groundtruth_text,
langid=cfg.langid,
use_cer=cfg.use_cer,
output_filename=None,
)
logging.info(f"Writing prediction and error rate of each sample to {output_manifest_w_wer}!")
logging.info(f"{total_res}")

return cfg


Expand Down
30 changes: 29 additions & 1 deletion examples/asr/transcribe_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from nemo.collections.asr.metrics.wer import CTCDecodingConfig
from nemo.collections.asr.models import EncDecCTCModel, EncDecHybridRNNTCTCModel
from nemo.collections.asr.modules.conformer_encoder import ConformerChangeConfig
from nemo.collections.asr.parts.utils.eval_utils import cal_write_wer
from nemo.collections.asr.parts.utils.transcribe_utils import (
compute_output_filename,
prepare_audio_data,
Expand Down Expand Up @@ -69,6 +70,11 @@
ctc_decoding: Decoding sub-config for CTC. Refer to documentation for specific values.
rnnt_decoding: Decoding sub-config for RNNT. Refer to documentation for specific values.

calculate_wer: Bool to decide whether to calculate wer/cer at end of this script
clean_groundtruth_text: Bool to clean groundtruth text
langid: Str used for convert_num_to_words during groundtruth cleaning
use_cer: Bool; if True, report Character Error Rate (CER) instead of Word Error Rate (WER)

# Usage
ASR model can be specified by either "model_path" or "pretrained_name".
Data for transcription can be defined with either "audio_dir" or "dataset_manifest".
Expand All @@ -82,6 +88,8 @@
audio_dir="<remove or path to folder of audio files>" \
dataset_manifest="<remove or path to manifest>" \
output_filename="<remove or specify output filename>" \
clean_groundtruth_text=True \
langid='en' \
batch_size=32 \
compute_timestamps=False \
compute_langs=False \
Expand Down Expand Up @@ -149,6 +157,12 @@ class TranscriptionConfig:
# Use this for model-specific changes before transcription
model_change: ModelChangeConfig = ModelChangeConfig()

# Config for word / character error rate calculation
calculate_wer: bool = True
clean_groundtruth_text: bool = False
langid: str = "en" # specify this for convert_num_to_words step in groundtruth cleaning
use_cer: bool = False


@hydra_runner(config_name="TranscriptionConfig", schema=TranscriptionConfig)
def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
Expand Down Expand Up @@ -254,6 +268,8 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
else:
cfg.decoding = cfg.rnnt_decoding

cfg.decoding.beam.return_best_hypothesis = True

# prepare audio filepaths and decide whether it's partial audio
filepaths, partial_audio = prepare_audio_data(cfg)

Expand Down Expand Up @@ -322,7 +338,7 @@ def autocast():
transcriptions = transcriptions[0]

# write audio transcriptions
output_filename = write_transcription(
output_filename, pred_text_attr_name = write_transcription(
transcriptions,
cfg,
model_name,
Expand All @@ -332,6 +348,18 @@ def autocast():
)
logging.info(f"Finished writing predictions to {output_filename}!")

if cfg.calculate_wer:
output_manifest_w_wer, total_res, _ = cal_write_wer(
pred_manifest=output_filename,
pred_text_attr_name=pred_text_attr_name,
clean_groundtruth_text=cfg.clean_groundtruth_text,
langid=cfg.langid,
use_cer=cfg.use_cer,
output_filename=None,
)
logging.info(f"Writing prediction and error rate of each sample to {output_manifest_w_wer}!")
logging.info(f"{total_res}")

return cfg


Expand Down
152 changes: 152 additions & 0 deletions nemo/collections/asr/parts/utils/eval_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
from typing import Tuple

from nemo.collections.asr.metrics.wer import word_error_rate_detail
from nemo.utils import logging


def clean_label(_str: str, num_to_words: bool = True, langid="en") -> str:
    """
    Normalize a transcript: strip and lowercase it, drop or blank out punctuation
    and quote-like characters, optionally spell out digits, and collapse whitespace.

    Args:
        _str: Text to clean.
        num_to_words: When True, digit-only tokens are spelled out through
            ``convert_num_to_words``. Only applied when ``langid`` is "en".
        langid: Language of the text. Any value other than "en" skips the digit
            conversion and logs a message instead.

    Returns:
        The cleaned text with tokens separated by single spaces.
    """
    # NOTE(review): a reviewer on the change that introduced this function flagged
    # this kind of global character-level normalization as error prone and expected
    # it to be replaced by the model's own normalizer.
    to_space = '/?*\",.:=?_{|}~¨«·»¡¿„…‧‹›≪≫!:;ː→'
    to_drop = '`¨´‘’“”`ʻ‘’“"‘”'
    to_apostrophe = '‘’ʻ‘’‘'

    # One translation table applied in a single pass. Entries are added from the
    # lowest to the highest priority so that a character listed in several sets
    # resolves as: drop > space > apostrophe.
    # NOTE(review): every character in `to_apostrophe` is also in `to_drop`, so the
    # apostrophe mapping never takes effect (e.g. "don’t" becomes "dont").
    table = {ord(ch): "'" for ch in to_apostrophe}
    table.update({ord(ch): " " for ch in to_space})
    table.update({ord(ch): None for ch in to_drop})

    text = _str.strip().lower().translate(table)

    if num_to_words:
        if langid != "en":
            logging.info(
                "Currently support basic num_to_words in English only. Please use Text Normalization to convert other languages! Skipping!"
            )
        else:
            text = convert_num_to_words(text, langid="en")

    # Collapse any run of whitespace left behind by the replacements.
    return " ".join(text.split())


def convert_num_to_words(_str: str, langid: str = "en") -> str:
    """
    Spell out every digit-only token digit by digit, e.g. "42" -> "four two".

    This is a naive approach and could be replaced with text normalization.
    Each token is read one character at a time, so zeros are kept
    ("0" -> "zero", "007" -> "zero zero seven") and arbitrarily long digit
    strings are converted exactly.

    Args:
        _str: Whitespace-separated text.
        langid: Language of the text. Only "en" is supported.

    Returns:
        The text with digit-only tokens replaced by digit names, tokens joined
        by single spaces and no leading or trailing whitespace.

    Raises:
        ValueError: If ``langid`` is not "en".
    """
    if langid != "en":
        raise ValueError(
            "Currently support basic num_to_words in English only. Please use Text Normalization to convert other languages!"
        )

    digit_names = ("zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine")
    converted = []
    for word in _str.split():
        # isdecimal() (unlike isdigit()) only accepts characters that int() can
        # parse, so tokens such as "²" pass through unchanged instead of raising.
        if word.isdecimal():
            converted.extend(digit_names[int(ch)] for ch in word)
        else:
            converted.append(word)
    return " ".join(converted)


def cal_write_wer(
    pred_manifest: str = None,
    pred_text_attr_name: str = "pred_text",
    clean_groundtruth_text: bool = False,
    langid: str = 'en',
    use_cer: bool = False,
    output_filename: str = None,
) -> Tuple[str, dict, str]:
    """
    Calculate the error rate plus insertion, deletion and substitution rates between
    the ground-truth ``text`` and the prediction stored under ``pred_text_attr_name``,
    per sample and over the whole manifest, then write the per-sample results out.

    "WER" is used in the function name as a convention; the metric is Character
    Error Rate (CER) when ``use_cer`` is True and Word Error Rate (WER) otherwise.

    Args:
        pred_manifest: Path to a JSON-lines manifest. Every entry must contain
            ``text`` and the ``pred_text_attr_name`` field.
        pred_text_attr_name: Key under which the predicted text is stored.
        clean_groundtruth_text: When True, the ground-truth text is passed through
            ``clean_label`` before scoring.
        langid: Language id forwarded to ``clean_label``.
        use_cer: Score characters instead of words.
        output_filename: Where to write the manifest with the added metrics. When
            empty or None, ``pred_manifest`` is overwritten in place.

    Returns:
        A tuple ``(output_manifest_path, total_res, eval_metric)``. ``total_res``
        holds the corpus-level sample count, token count, error rate and
        insertion/deletion/substitution rates; ``eval_metric`` is "cer" or "wer".

    Raises:
        ValueError: If a manifest entry has no ``text`` field.
    """
    samples = []
    hyps = []
    refs = []
    eval_metric = "cer" if use_cer else "wer"

    # Manifests are read and written explicitly as UTF-8 so the result does not
    # depend on the platform's default locale encoding.
    with open(pred_manifest, 'r', encoding='utf-8') as fp:
        for line in fp:
            # Tolerate blank lines (e.g. a trailing newline at the end of the file).
            if not line.strip():
                continue
            sample = json.loads(line)

            if 'text' not in sample:
                raise ValueError(
                    "ground-truth text is not present in manifest! Cannot calculate Word Error Rate. Exiting!"
                )

            hyp = sample[pred_text_attr_name]
            ref = sample['text']

            if clean_groundtruth_text:
                ref = clean_label(ref, langid=langid)

            wer, tokens, ins_rate, del_rate, sub_rate = word_error_rate_detail(
                hypotheses=[hyp], references=[ref], use_cer=use_cer
            )
            sample[eval_metric] = wer  # evaluation metric: word error rate or character error rate
            sample['tokens'] = tokens  # number of words/characters/tokens
            sample['ins_rate'] = ins_rate  # insertion error rate
            sample['del_rate'] = del_rate  # deletion error rate
            sample['sub_rate'] = sub_rate  # substitution error rate

            samples.append(sample)
            hyps.append(hyp)
            refs.append(ref)

    # Corpus-level metrics are computed over all pairs at once rather than by
    # averaging the per-sample values.
    total_wer, total_tokens, total_ins_rate, total_del_rate, total_sub_rate = word_error_rate_detail(
        hypotheses=hyps, references=refs, use_cer=use_cer
    )

    # The input manifest has been fully read and closed above, so it is safe to
    # overwrite it when no separate output path is given.
    output_manifest_w_wer = output_filename if output_filename else pred_manifest

    with open(output_manifest_w_wer, 'w', encoding='utf-8') as fout:
        for sample in samples:
            json.dump(sample, fout)
            fout.write('\n')

    total_res = {
        "samples": len(samples),
        "tokens": total_tokens,
        eval_metric: total_wer,
        "ins_rate": total_ins_rate,
        "del_rate": total_del_rate,
        "sub_rate": total_sub_rate,
    }
    return output_manifest_w_wer, total_res, eval_metric