diff --git a/examples/asr/speech_classification/vad_infer.py b/examples/asr/speech_classification/vad_infer.py index 992096b6a6e1..8ab040b34c79 100644 --- a/examples/asr/speech_classification/vad_infer.py +++ b/examples/asr/speech_classification/vad_infer.py @@ -123,7 +123,7 @@ def main(cfg): f"Finish generating VAD frame level prediction with window_length_in_sec={cfg.vad.parameters.window_length_in_sec} and shift_length_in_sec={cfg.vad.parameters.shift_length_in_sec}" ) frame_length_in_sec = cfg.vad.parameters.shift_length_in_sec - + # overlap smoothing filter if cfg.vad.parameters.smoothing: # Generate predictions with overlapping input segments. Then a smoothing filter is applied to decide the label for a frame spanned by multiple segments. diff --git a/nemo/collections/asr/parts/utils/vad_utils.py b/nemo/collections/asr/parts/utils/vad_utils.py index 021e36744eda..074286a0805e 100644 --- a/nemo/collections/asr/parts/utils/vad_utils.py +++ b/nemo/collections/asr/parts/utils/vad_utils.py @@ -637,7 +637,7 @@ def generate_vad_segment_table_per_tensor(sequence: torch.Tensor, per_args: Dict speech_segments, _ = torch.sort(speech_segments, 0) - dur = speech_segments[:, 1:2] - speech_segments[:, 0:1] + 0.01 # 10ms frame unit + dur = speech_segments[:, 1:2] - speech_segments[:, 0:1] + 0.01 # 10ms frame unit speech_segments = torch.column_stack((speech_segments, dur)) return speech_segments diff --git a/scripts/voice_activity_detection/vad_tune_threshold.py b/scripts/voice_activity_detection/vad_tune_threshold.py index a3b871e887f8..2f41677fe3ae 100644 --- a/scripts/voice_activity_detection/vad_tune_threshold.py +++ b/scripts/voice_activity_detection/vad_tune_threshold.py @@ -82,10 +82,7 @@ default='DetER', ) parser.add_argument( - "--frame_length_in_sec", - help="frame_length_in_sec ", - type=float, - default=0.01, + "--frame_length_in_sec", help="frame_length_in_sec ", type=float, default=0.01, ) args = parser.parse_args() @@ -131,7 +128,13 @@ ) best_threhsold, optimal_scores = vad_tune_threshold_on_dev( - params, args.vad_pred, args.groundtruth_RTTM, args.result_file, args.vad_pred_method, args.focus_metric, args.frame_length_in_sec + params, + args.vad_pred, + args.groundtruth_RTTM, + args.result_file, + args.vad_pred_method, + args.focus_metric, + args.frame_length_in_sec, ) logging.info( f"Best combination of thresholds for binarization selected from input ranges is {best_threhsold}, and the optimal score is {optimal_scores}"