From ebcf9431134b5a5cda7de5dc8edf7b1f4c43559e Mon Sep 17 00:00:00 2001
From: fayejf <36722593+fayejf@users.noreply.github.com>
Date: Fri, 29 Jul 2022 02:05:06 -0700
Subject: [PATCH] Tiny VAD refactoring for postprocessing (#4625)

* binarization start index

Signed-off-by: fayejf

* fix frame len

Signed-off-by: fayejf

* style fix

Signed-off-by: fayejf

* rename UNIT_FRAME_LEN

Signed-off-by: fayejf

* update overlap script and fix lgtm

Signed-off-by: fayejf

* style fix

Signed-off-by: fayejf

Signed-off-by: David Mosallanezhad
---
 .../vad/vad_inference_postprocessing.yaml     |  1 -
 .../asr/speech_classification/vad_infer.py    |  6 ++-
 .../asr/models/clustering_diarizer.py         |  4 +-
 nemo/collections/asr/parts/utils/vad_utils.py | 46 ++++++++++---------
 .../vad_overlap_posterior.py                  |  6 +--
 .../vad_tune_threshold.py                     | 11 ++++-
 6 files changed, 44 insertions(+), 30 deletions(-)

diff --git a/examples/asr/conf/vad/vad_inference_postprocessing.yaml b/examples/asr/conf/vad/vad_inference_postprocessing.yaml
index d2f606d3db40..f51d5480bd2b 100644
--- a/examples/asr/conf/vad/vad_inference_postprocessing.yaml
+++ b/examples/asr/conf/vad/vad_inference_postprocessing.yaml
@@ -5,7 +5,6 @@ num_workers: 4
 sample_rate: 16000
 
 # functionality
-gen_overlap_seq: True # whether to generate predictions with overlapping input segments and smoothing filter
 gen_seg_table: True # whether to converting frame level prediction to speech/no-speech segment in start and end times format
 write_to_manifest: True # whether to writing above segments to a single manifest json file.
 
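With `gen_overlap_seq` removed from the YAML above, the overlap-smoothing step is driven entirely by `vad.parameters.smoothing`. A minimal sketch of the new gating, not part of the patch; the parameter values below are illustrative assumptions, not NeMo defaults:

```python
# Sketch (not part of the patch) of how the pruned config is consumed.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "num_workers": 4,
        "sample_rate": 16000,
        "gen_seg_table": True,
        "write_to_manifest": True,
        "vad": {
            "parameters": {
                "window_length_in_sec": 0.63,  # illustrative value
                "shift_length_in_sec": 0.01,   # illustrative value
                "smoothing": False,  # or "median" / "mean" to enable the overlap filter
            }
        },
    }
)

# After this patch there is no top-level gen_overlap_seq flag: the smoothing
# field itself decides whether the overlap step runs.
if cfg.vad.parameters.smoothing:
    print("run overlap smoothing, then postprocess 10 ms frames")
else:
    print("postprocess frames of length", cfg.vad.parameters.shift_length_in_sec)
```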
diff --git a/examples/asr/speech_classification/vad_infer.py b/examples/asr/speech_classification/vad_infer.py
index 675a1e11cf6f..8ab040b34c79 100644
--- a/examples/asr/speech_classification/vad_infer.py
+++ b/examples/asr/speech_classification/vad_infer.py
@@ -122,9 +122,10 @@ def main(cfg):
     logging.info(
         f"Finish generating VAD frame level prediction with window_length_in_sec={cfg.vad.parameters.window_length_in_sec} and shift_length_in_sec={cfg.vad.parameters.shift_length_in_sec}"
     )
+    frame_length_in_sec = cfg.vad.parameters.shift_length_in_sec
 
     # overlap smoothing filter
-    if cfg.gen_overlap_seq:
+    if cfg.vad.parameters.smoothing:
         # Generate predictions with overlapping input segments. Then a smoothing filter is applied to decide the label for a frame spanned by multiple segments.
         # smoothing_method would be either in majority vote (median) or average (mean)
         logging.info("Generating predictions with overlapping input segments")
@@ -141,6 +142,7 @@ def main(cfg):
             f"Finish generating predictions with overlapping input segments with smoothing_method={cfg.vad.parameters.smoothing} and overlap={cfg.vad.parameters.overlap}"
         )
         pred_dir = smoothing_pred_dir
+        frame_length_in_sec = 0.01
 
     # postprocessing and generate speech segments
     if cfg.gen_seg_table:
@@ -148,7 +150,7 @@
         table_out_dir = generate_vad_segment_table(
             vad_pred_dir=pred_dir,
             postprocessing_params=cfg.vad.parameters.postprocessing,
-            shift_length_in_sec=cfg.vad.parameters.shift_length_in_sec,
+            frame_length_in_sec=frame_length_in_sec,
             num_workers=cfg.num_workers,
             out_dir=cfg.table_out_dir,
         )
diff --git a/nemo/collections/asr/models/clustering_diarizer.py b/nemo/collections/asr/models/clustering_diarizer.py
index 8cfb7d7636b2..688c482a370a 100644
--- a/nemo/collections/asr/models/clustering_diarizer.py
+++ b/nemo/collections/asr/models/clustering_diarizer.py
@@ -237,6 +237,7 @@ def _run_vad(self, manifest_file):
         if not self._vad_params.smoothing:
             # Shift the window by 10ms to generate the frame and use the prediction of the window to represent the label for the frame;
             self.vad_pred_dir = self._vad_dir
+            frame_length_in_sec = self._vad_shift_length_in_sec
         else:
             # Generate predictions with overlapping input segments. Then a smoothing filter is applied to decide the label for a frame spanned by multiple segments.
             # smoothing_method would be either in majority vote (median) or average (mean)
@@ -250,13 +251,14 @@
                 num_workers=self._cfg.num_workers,
             )
             self.vad_pred_dir = smoothing_pred_dir
+            frame_length_in_sec = 0.01
 
         logging.info("Converting frame level prediction to speech/no-speech segment in start and end times format.")
 
         table_out_dir = generate_vad_segment_table(
             vad_pred_dir=self.vad_pred_dir,
             postprocessing_params=self._vad_params,
-            shift_length_in_sec=self._vad_shift_length_in_sec,
+            frame_length_in_sec=frame_length_in_sec,
             num_workers=self._cfg.num_workers,
         )
 
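Both call sites above now compute the same quantity. Stated as a hypothetical one-liner (this helper does not exist in NeMo), the rule they encode is:

```python
# Hypothetical helper (not in NeMo) stating the rule vad_infer.py and
# clustering_diarizer.py now follow: postprocessing operates on frames whose
# length equals the window shift, unless overlap smoothing has already
# resampled the predictions onto a 10 ms grid.
def pick_frame_length_in_sec(shift_length_in_sec: float, smoothing: bool) -> float:
    return 0.01 if smoothing else shift_length_in_sec


assert pick_frame_length_in_sec(0.02, smoothing=False) == 0.02
assert pick_frame_length_in_sec(0.02, smoothing=True) == 0.01
```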
""" - shift_length_in_sec = per_args.get('shift_length_in_sec', 0.01) + frame_length_in_sec = per_args.get('frame_length_in_sec', 0.01) onset = per_args.get('onset', 0.5) offset = per_args.get('offset', 0.5) @@ -477,30 +477,30 @@ def binarization(sequence: torch.Tensor, per_args: Dict[str, float]) -> torch.Te speech_segments = torch.empty(0) - for i in range(1, len(sequence)): + for i in range(0, len(sequence)): # Current frame is speech if speech: # Switch from speech to non-speech if sequence[i] < offset: - if i * shift_length_in_sec + pad_offset > max(0, start - pad_onset): + if i * frame_length_in_sec + pad_offset > max(0, start - pad_onset): new_seg = torch.tensor( - [max(0, start - pad_onset), i * shift_length_in_sec + pad_offset] + [max(0, start - pad_onset), i * frame_length_in_sec + pad_offset] ).unsqueeze(0) speech_segments = torch.cat((speech_segments, new_seg), 0) - start = i * shift_length_in_sec + start = i * frame_length_in_sec speech = False # Current frame is non-speech else: # Switch from non-speech to speech if sequence[i] > onset: - start = i * shift_length_in_sec + start = i * frame_length_in_sec speech = True # if it's speech at the end, add final segment if speech: - new_seg = torch.tensor([max(0, start - pad_onset), i * shift_length_in_sec + pad_offset]).unsqueeze(0) + new_seg = torch.tensor([max(0, start - pad_onset), i * frame_length_in_sec + pad_offset]).unsqueeze(0) speech_segments = torch.cat((speech_segments, new_seg), 0) # Merge the overlapped speech segments due to padding @@ -627,8 +627,8 @@ def generate_vad_segment_table_per_tensor(sequence: torch.Tensor, per_args: Dict See description in generate_overlap_vad_seq. Use this for single instance pipeline. """ + UNIT_FRAME_LEN = 0.01 - shift_length_in_sec = per_args['shift_length_in_sec'] speech_segments = binarization(sequence, per_args) speech_segments = filtering(speech_segments, per_args) @@ -637,7 +637,7 @@ def generate_vad_segment_table_per_tensor(sequence: torch.Tensor, per_args: Dict speech_segments, _ = torch.sort(speech_segments, 0) - dur = speech_segments[:, 1:2] - speech_segments[:, 0:1] + shift_length_in_sec + dur = speech_segments[:, 1:2] - speech_segments[:, 0:1] + UNIT_FRAME_LEN speech_segments = torch.column_stack((speech_segments, dur)) return speech_segments @@ -667,7 +667,7 @@ def generate_vad_segment_table_per_file(pred_filepath: str, per_args: dict) -> s def generate_vad_segment_table( - vad_pred_dir: str, postprocessing_params: dict, shift_length_in_sec: float, num_workers: int, out_dir: str = None, + vad_pred_dir: str, postprocessing_params: dict, frame_length_in_sec: float, num_workers: int, out_dir: str = None, ) -> str: """ Convert frame level prediction to speech segment in start and end times format. @@ -677,7 +677,7 @@ def generate_vad_segment_table( Args: vad_pred_dir (str): directory of prediction files to be processed. postprocessing_params (dict): dictionary of thresholds for prediction score. See details in binarization and filtering. - shift_length_in_sec (float): amount of shift of window for generating the frame. + frame_length_in_sec (float): frame length. out_dir (str): output dir of generated table/csv file. 
@@ -627,8 +627,8 @@ def generate_vad_segment_table_per_tensor(sequence: torch.Tensor, per_args: Dict[str, float]) -> torch.Tensor:
     See description in generate_overlap_vad_seq.
     Use this for single instance pipeline.
     """
+    UNIT_FRAME_LEN = 0.01
 
-    shift_length_in_sec = per_args['shift_length_in_sec']
     speech_segments = binarization(sequence, per_args)
     speech_segments = filtering(speech_segments, per_args)
 
@@ -637,7 +637,7 @@
 
     speech_segments, _ = torch.sort(speech_segments, 0)
 
-    dur = speech_segments[:, 1:2] - speech_segments[:, 0:1] + shift_length_in_sec
+    dur = speech_segments[:, 1:2] - speech_segments[:, 0:1] + UNIT_FRAME_LEN
     speech_segments = torch.column_stack((speech_segments, dur))
 
     return speech_segments
@@ -667,7 +667,7 @@ def generate_vad_segment_table_per_file(pred_filepath: str, per_args: dict) -> str:
 
 
 def generate_vad_segment_table(
-    vad_pred_dir: str, postprocessing_params: dict, shift_length_in_sec: float, num_workers: int, out_dir: str = None,
+    vad_pred_dir: str, postprocessing_params: dict, frame_length_in_sec: float, num_workers: int, out_dir: str = None,
 ) -> str:
     """
     Convert frame level prediction to speech segment in start and end times format.
@@ -677,7 +677,7 @@ def generate_vad_segment_table(
     Args:
         vad_pred_dir (str): directory of prediction files to be processed.
         postprocessing_params (dict): dictionary of thresholds for prediction score. See details in binarization and filtering.
-        shift_length_in_sec (float): amount of shift of window for generating the frame.
+        frame_length_in_sec (float): frame length.
         out_dir (str): output dir of generated table/csv file.
         num_workers(float): number of process for multiprocessing
     Returns:
@@ -700,7 +700,7 @@ def generate_vad_segment_table(
         os.mkdir(table_out_dir)
 
     per_args = {
-        "shift_length_in_sec": shift_length_in_sec,
+        "frame_length_in_sec": frame_length_in_sec,
         "out_dir": table_out_dir,
     }
     per_args = {**per_args, **postprocessing_params}
@@ -778,7 +778,7 @@
     result_file: str = "res",
     vad_pred_method: str = "frame",
     focus_metric: str = "DetER",
-    shift_length_in_sec: float = 0.01,
+    frame_length_in_sec: float = 0.01,
     num_workers: int = 20,
 ) -> Tuple[dict, dict]:
     """
@@ -788,6 +788,8 @@
         vad_pred_method (str): suffix of prediction file. Use to locate file. Should be either in "frame", "mean" or "median".
         groundtruth_RTTM_dir (str): directory of ground-truth rttm files or a file contains the paths of them.
         focus_metric (str): metrics we care most when tuning threshold. Should be either in "DetER", "FA", "MISS"
+        frame_length_in_sec (float): frame length.
+        num_workers (int): number of workers.
     Returns:
         best_threshold (float): threshold that gives lowest DetER.
     """
@@ -810,7 +812,7 @@
         # Generate speech segments by performing binarization on the VAD prediction according to param.
         # Filter speech segments according to param and write the result to rttm-like table.
         vad_table_dir = generate_vad_segment_table(
-            vad_pred, param, shift_length_in_sec=shift_length_in_sec, num_workers=num_workers
+            vad_pred, param, frame_length_in_sec=frame_length_in_sec, num_workers=num_workers
         )
         # add reference and hypothesis to metrics
         for filename in paired_filenames:
@@ -938,14 +940,14 @@ def plot(
         per_args(dict): a dict that stores the thresholds for postprocessing.
     """
     plt.figure(figsize=[20, 2])
-    FRAME_LEN = 0.01
+    UNIT_FRAME_LEN = 0.01
 
     audio, sample_rate = librosa.load(path=path2audio_file, sr=16000, mono=True, offset=offset, duration=duration)
     dur = librosa.get_duration(y=audio, sr=sample_rate)
 
-    time = np.arange(offset, offset + dur, FRAME_LEN)
+    time = np.arange(offset, offset + dur, UNIT_FRAME_LEN)
     frame, _ = load_tensor_from_file(path2_vad_pred)
-    frame_snippet = frame[int(offset / FRAME_LEN) : int((offset + dur) / FRAME_LEN)]
+    frame_snippet = frame[int(offset / UNIT_FRAME_LEN) : int((offset + dur) / UNIT_FRAME_LEN)]
     len_pred = len(frame_snippet)
 
     ax1 = plt.subplot()
@@ -969,14 +971,14 @@
     )  # take whole frame here for calculating onset and offset
     speech_segments = generate_vad_segment_table_per_tensor(frame, per_args_float)
     pred = gen_pred_from_speech_segments(speech_segments, frame)
-    pred_snippet = pred[int(offset / FRAME_LEN) : int((offset + dur) / FRAME_LEN)]
+    pred_snippet = pred[int(offset / UNIT_FRAME_LEN) : int((offset + dur) / UNIT_FRAME_LEN)]
 
     if path2ground_truth_label:
         label = extract_labels(path2ground_truth_label, time)
-        ax2.plot(np.arange(len_pred) * FRAME_LEN, label, 'r', label='label')
+        ax2.plot(np.arange(len_pred) * UNIT_FRAME_LEN, label, 'r', label='label')
 
-    ax2.plot(np.arange(len_pred) * FRAME_LEN, pred_snippet, 'b', label='pred')
-    ax2.plot(np.arange(len_pred) * FRAME_LEN, frame_snippet, 'g--', label='speech prob')
+    ax2.plot(np.arange(len_pred) * UNIT_FRAME_LEN, pred_snippet, 'b', label='pred')
+    ax2.plot(np.arange(len_pred) * UNIT_FRAME_LEN, frame_snippet, 'g--', label='speech prob')
     ax2.tick_params(axis='y', labelcolor='r')
     ax2.legend(loc='lower right', shadow=True)
     ax2.set_ylabel('Preds and Probas')
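To see the renamed keyword end to end, here is an illustrative call against the patched `generate_vad_segment_table` signature; paths and thresholds are placeholders, not recommended settings, and the dict keys mirror the `per_args` documented in `binarization` and `filtering`:

```python
# Illustrative call against the patched API; paths are placeholders.
from nemo.collections.asr.parts.utils.vad_utils import generate_vad_segment_table

postprocessing_params = {
    "onset": 0.5,       # start-of-speech threshold
    "offset": 0.3,      # end-of-speech threshold
    "pad_onset": 0.2,   # seconds prepended to each segment
    "pad_offset": 0.2,  # seconds appended to each segment
}

table_out_dir = generate_vad_segment_table(
    vad_pred_dir="<frame_pred_dir>",  # placeholder directory of frame predictions
    postprocessing_params=postprocessing_params,
    frame_length_in_sec=0.01,         # 10 ms frames, e.g. after smoothing
    num_workers=4,
    out_dir="<table_out_dir>",        # placeholder output directory
)
```

Note the duration convention in `generate_vad_segment_table_per_tensor` above: each row carries `dur = end - start + UNIT_FRAME_LEN`, because a segment covering frames i through j spans j - i + 1 frames, so the last frame's own length is counted.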
diff --git a/scripts/voice_activity_detection/vad_overlap_posterior.py b/scripts/voice_activity_detection/vad_overlap_posterior.py
index 340ffaafae45..d1c2d0e0d264 100644
--- a/scripts/voice_activity_detection/vad_overlap_posterior.py
+++ b/scripts/voice_activity_detection/vad_overlap_posterior.py
@@ -83,19 +83,19 @@
     start = time.time()
     logging.info("Converting frame level prediction to speech/no-speech segment in start and end times format.")
+    frame_length_in_sec = args.shift_length_in_sec
     if args.gen_overlap_seq:
         logging.info("Use overlap prediction. Change if you want to use basic frame level prediction")
         vad_pred_dir = overlap_out_dir
-        shift_length_in_sec = 0.01
+        frame_length_in_sec = 0.01
     else:
         logging.info("Use basic frame level prediction")
         vad_pred_dir = args.frame_folder
-        shift_length_in_sec = args.shift_length_in_sec
 
     table_out_dir = generate_vad_segment_table(
         vad_pred_dir=vad_pred_dir,
         postprocessing_params=postprocessing_params,
-        shift_length_in_sec=args.shift_length_in_sec,
+        frame_length_in_sec=frame_length_in_sec,
         num_workers=args.num_workers,
         out_dir=args.table_out_dir,
     )
diff --git a/scripts/voice_activity_detection/vad_tune_threshold.py b/scripts/voice_activity_detection/vad_tune_threshold.py
index 397d7c439cfe..2f41677fe3ae 100644
--- a/scripts/voice_activity_detection/vad_tune_threshold.py
+++ b/scripts/voice_activity_detection/vad_tune_threshold.py
@@ -81,6 +81,9 @@
         type=str,
         default='DetER',
     )
+    parser.add_argument(
+        "--frame_length_in_sec", help="frame_length_in_sec ", type=float, default=0.01,
+    )
     args = parser.parse_args()
 
     params = {}
@@ -125,7 +128,13 @@
     )
 
     best_threhsold, optimal_scores = vad_tune_threshold_on_dev(
-        params, args.vad_pred, args.groundtruth_RTTM, args.result_file, args.vad_pred_method, args.focus_metric
+        params,
+        args.vad_pred,
+        args.groundtruth_RTTM,
+        args.result_file,
+        args.vad_pred_method,
+        args.focus_metric,
+        args.frame_length_in_sec,
     )
     logging.info(
         f"Best combination of thresholds for binarization selected from input ranges is {best_threhsold}, and the optimal score is {optimal_scores}"
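The script passes the new value as the seventh positional argument. For readers who prefer keywords, the equivalent call against the patched `vad_tune_threshold_on_dev` signature looks like this (paths are placeholders, and the `params` grid is an assumed example of the threshold ranges the script builds from its arguments):

```python
# Keyword-form equivalent of the new positional call; paths are placeholders
# and the params grid is an assumed illustrative example.
from nemo.collections.asr.parts.utils.vad_utils import vad_tune_threshold_on_dev

params = {"onset": [0.4, 0.5, 0.6], "offset": [0.3, 0.4, 0.5]}

best_threshold, optimal_scores = vad_tune_threshold_on_dev(
    params,
    "<vad_pred_dir>",      # vad_pred: placeholder directory of predictions
    "<groundtruth_rttm>",  # groundtruth_RTTM: placeholder rttm dir or file list
    result_file="res",
    vad_pred_method="frame",
    focus_metric="DetER",
    frame_length_in_sec=0.01,  # the newly threaded-through parameter
)
```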