Tiny VAD refactoring for postprocessing #4625

Merged
merged 9 commits on Jul 29, 2022
Changes from all commits
1 change: 0 additions & 1 deletion examples/asr/conf/vad/vad_inference_postprocessing.yaml
@@ -5,7 +5,6 @@ num_workers: 4
sample_rate: 16000

# functionality
- gen_overlap_seq: True # whether to generate predictions with overlapping input segments and smoothing filter
gen_seg_table: True # whether to convert frame-level predictions to speech/no-speech segments in start and end times format
write_to_manifest: True # whether to write the above segments to a single manifest JSON file.

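For context on the removed flag: smoothing is now read from vad.parameters.smoothing instead of a dedicated gen_overlap_seq switch. A minimal sketch of how the remaining flags gate the pipeline, assuming this YAML also defines vad.parameters.smoothing (the config path and the pass-bodies are placeholders):

import omegaconf

cfg = omegaconf.OmegaConf.load("vad_inference_postprocessing.yaml")  # placeholder path

if cfg.vad.parameters.smoothing:  # e.g. "median", "mean", or False
    pass  # generate overlapping predictions, then smooth them
if cfg.gen_seg_table:
    pass  # convert frame-level predictions to speech segments
if cfg.write_to_manifest:
    pass  # write the segments to a single manifest JSON file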
6 changes: 4 additions & 2 deletions examples/asr/speech_classification/vad_infer.py
@@ -122,9 +122,10 @@ def main(cfg):
logging.info(
f"Finish generating VAD frame level prediction with window_length_in_sec={cfg.vad.parameters.window_length_in_sec} and shift_length_in_sec={cfg.vad.parameters.shift_length_in_sec}"
)
+ frame_length_in_sec = cfg.vad.parameters.shift_length_in_sec

# overlap smoothing filter
- if cfg.gen_overlap_seq:
+ if cfg.vad.parameters.smoothing:
# Generate predictions with overlapping input segments. Then a smoothing filter is applied to decide the label for a frame spanned by multiple segments.
# smoothing_method is either majority vote (median) or average (mean)
logging.info("Generating predictions with overlapping input segments")
@@ -141,14 +142,15 @@ def main(cfg):
f"Finish generating predictions with overlapping input segments with smoothing_method={cfg.vad.parameters.smoothing} and overlap={cfg.vad.parameters.overlap}"
)
pred_dir = smoothing_pred_dir
+ frame_length_in_sec = 0.01

# postprocessing and generate speech segments
if cfg.gen_seg_table:
logging.info("Converting frame level prediction to speech/no-speech segment in start and end times format.")
table_out_dir = generate_vad_segment_table(
vad_pred_dir=pred_dir,
postprocessing_params=cfg.vad.parameters.postprocessing,
- shift_length_in_sec=cfg.vad.parameters.shift_length_in_sec,
+ frame_length_in_sec=frame_length_in_sec,
num_workers=cfg.num_workers,
out_dir=cfg.table_out_dir,
)
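The smoothing comments above say each frame spanned by multiple overlapping segments gets its label by majority vote (median) or average (mean). A toy illustration with made-up probabilities, not NeMo's implementation:

import torch

# Three overlapping windows all cover the same 10 ms frame, and each
# contributes its own speech probability for that frame.
probs_over_frame = torch.tensor([0.9, 0.8, 0.2])

median_label = probs_over_frame.median()  # "median" method -> tensor(0.8000)
mean_label = probs_over_frame.mean()      # "mean" method   -> tensor(0.6333)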
4 changes: 3 additions & 1 deletion nemo/collections/asr/models/clustering_diarizer.py
@@ -237,6 +237,7 @@ def _run_vad(self, manifest_file):
if not self._vad_params.smoothing:
# Shift the window by 10 ms to generate frames, and use each window's prediction as the label for its frame.
self.vad_pred_dir = self._vad_dir
+ frame_length_in_sec = self._vad_shift_length_in_sec
else:
# Generate predictions with overlapping input segments. Then a smoothing filter is applied to decide the label for a frame spanned by multiple segments.
# smoothing_method is either majority vote (median) or average (mean)
@@ -250,13 +251,14 @@ def _run_vad(self, manifest_file):
num_workers=self._cfg.num_workers,
)
self.vad_pred_dir = smoothing_pred_dir
+ frame_length_in_sec = 0.01

logging.info("Converting frame level prediction to speech/no-speech segment in start and end times format.")

table_out_dir = generate_vad_segment_table(
vad_pred_dir=self.vad_pred_dir,
postprocessing_params=self._vad_params,
- shift_length_in_sec=self._vad_shift_length_in_sec,
+ frame_length_in_sec=frame_length_in_sec,
num_workers=self._cfg.num_workers,
)

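Both vad_infer.py and the diarizer above now follow the same rule: if smoothing ran, predictions are emitted per 10 ms frame; otherwise each prediction covers one window shift. A condensed sketch of that selection (the helper name is invented):

def select_frame_length(smoothing: bool, shift_length_in_sec: float) -> float:
    # Invented helper summarizing the branch above: smoothed predictions
    # are per 10 ms frame; unsmoothed ones are per window shift.
    return 0.01 if smoothing else shift_length_in_sec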
46 changes: 24 additions & 22 deletions nemo/collections/asr/parts/utils/vad_utils.py
@@ -459,12 +459,12 @@ def binarization(sequence: torch.Tensor, per_args: Dict[str, float]) -> torch.Tensor
offset (float): offset threshold for detecting the end of speech.
pad_onset (float): duration added before each speech segment.
pad_offset (float): duration added after each speech segment.
- shift_length_in_sec (float): amount of shift of window for generating the frame.
+ frame_length_in_sec (float): length of frame.

Returns:
speech_segments (torch.Tensor): A tensor of speech segments in torch.Tensor([[start1, end1], [start2, end2]]) format.
"""
- shift_length_in_sec = per_args.get('shift_length_in_sec', 0.01)
+ frame_length_in_sec = per_args.get('frame_length_in_sec', 0.01)

onset = per_args.get('onset', 0.5)
offset = per_args.get('offset', 0.5)
@@ -477,30 +477,30 @@ def binarization(sequence: torch.Tensor, per_args: Dict[str, float]) -> torch.Tensor

speech_segments = torch.empty(0)

- for i in range(1, len(sequence)):
+ for i in range(0, len(sequence)):
# Current frame is speech
if speech:
# Switch from speech to non-speech
if sequence[i] < offset:
- if i * shift_length_in_sec + pad_offset > max(0, start - pad_onset):
+ if i * frame_length_in_sec + pad_offset > max(0, start - pad_onset):
new_seg = torch.tensor(
- [max(0, start - pad_onset), i * shift_length_in_sec + pad_offset]
+ [max(0, start - pad_onset), i * frame_length_in_sec + pad_offset]
).unsqueeze(0)
speech_segments = torch.cat((speech_segments, new_seg), 0)

- start = i * shift_length_in_sec
+ start = i * frame_length_in_sec
speech = False

# Current frame is non-speech
else:
# Switch from non-speech to speech
if sequence[i] > onset:
- start = i * shift_length_in_sec
+ start = i * frame_length_in_sec
speech = True

# if it's speech at the end, add final segment
if speech:
- new_seg = torch.tensor([max(0, start - pad_onset), i * shift_length_in_sec + pad_offset]).unsqueeze(0)
+ new_seg = torch.tensor([max(0, start - pad_onset), i * frame_length_in_sec + pad_offset]).unsqueeze(0)
speech_segments = torch.cat((speech_segments, new_seg), 0)

# Merge the overlapped speech segments due to padding
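A tiny standalone re-implementation of the loop above, with pad_onset/pad_offset set to zero and made-up probabilities, to show what binarization produces (illustration only, not the NeMo function itself):

import torch

def toy_binarization(seq, onset=0.5, offset=0.5, frame_len=0.01):
    speech, start, segments = False, 0.0, []
    for i in range(len(seq)):
        if speech and seq[i] < offset:       # switch from speech to non-speech
            segments.append([start, i * frame_len])
            speech = False
        elif not speech and seq[i] > onset:  # switch from non-speech to speech
            start = i * frame_len
            speech = True
    if speech:                               # still speech at the end
        segments.append([start, (len(seq) - 1) * frame_len])
    return torch.tensor(segments)

# Frames: speech, speech, pause, pause, speech, speech
print(toy_binarization(torch.tensor([0.9, 0.9, 0.2, 0.2, 0.8, 0.8])))
# -> tensor([[0.0000, 0.0200],
#            [0.0400, 0.0500]])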
@@ -627,8 +627,8 @@ def generate_vad_segment_table_per_tensor(sequence: torch.Tensor, per_args: Dict
See description in generate_overlap_vad_seq.
Use this for single instance pipeline.
"""
+ UNIT_FRAME_LEN = 0.01

- shift_length_in_sec = per_args['shift_length_in_sec']
speech_segments = binarization(sequence, per_args)
speech_segments = filtering(speech_segments, per_args)

@@ -637,7 +637,7 @@ def generate_vad_segment_table_per_tensor(sequence: torch.Tensor, per_args: Dict

speech_segments, _ = torch.sort(speech_segments, 0)

- dur = speech_segments[:, 1:2] - speech_segments[:, 0:1] + shift_length_in_sec
+ dur = speech_segments[:, 1:2] - speech_segments[:, 0:1] + UNIT_FRAME_LEN
speech_segments = torch.column_stack((speech_segments, dur))

return speech_segments
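A quick numeric check of the duration column computed above, using toy segments; note the code adds one UNIT_FRAME_LEN on top of the raw end-minus-start difference:

import torch

UNIT_FRAME_LEN = 0.01
segments = torch.tensor([[0.00, 0.02], [0.04, 0.05]])  # [start, end] rows
dur = segments[:, 1:2] - segments[:, 0:1] + UNIT_FRAME_LEN
print(torch.column_stack((segments, dur)))
# -> tensor([[0.0000, 0.0200, 0.0300],
#            [0.0400, 0.0500, 0.0200]])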
@@ -667,7 +667,7 @@ def generate_vad_segment_table_per_file(pred_filepath: str, per_args: dict) -> str


def generate_vad_segment_table(
- vad_pred_dir: str, postprocessing_params: dict, shift_length_in_sec: float, num_workers: int, out_dir: str = None,
+ vad_pred_dir: str, postprocessing_params: dict, frame_length_in_sec: float, num_workers: int, out_dir: str = None,
) -> str:
"""
Convert frame level prediction to speech segment in start and end times format.
@@ -677,7 +677,7 @@ def generate_vad_segment_table(
Args:
vad_pred_dir (str): directory of prediction files to be processed.
postprocessing_params (dict): dictionary of thresholds for prediction score. See details in binarization and filtering.
- shift_length_in_sec (float): amount of shift of window for generating the frame.
+ frame_length_in_sec (float): frame length.
out_dir (str): output dir of generated table/csv file.
num_workers (int): number of processes for multiprocessing.
Returns:
@@ -700,7 +700,7 @@ def generate_vad_segment_table(
os.mkdir(table_out_dir)

per_args = {
- "shift_length_in_sec": shift_length_in_sec,
+ "frame_length_in_sec": frame_length_in_sec,
"out_dir": table_out_dir,
}
per_args = {**per_args, **postprocessing_params}
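A hedged usage sketch of the updated signature; the directories and threshold values below are placeholders, not values from this PR:

from nemo.collections.asr.parts.utils.vad_utils import generate_vad_segment_table

postprocessing_params = {"onset": 0.5, "offset": 0.3, "pad_onset": 0.1, "pad_offset": 0.1}
table_out_dir = generate_vad_segment_table(
    vad_pred_dir="vad_frame_preds/",      # placeholder prediction directory
    postprocessing_params=postprocessing_params,
    frame_length_in_sec=0.01,             # use 0.01 if smoothing was applied
    num_workers=4,
    out_dir="vad_tables/",
)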
@@ -778,7 +778,7 @@ def vad_tune_threshold_on_dev(
result_file: str = "res",
vad_pred_method: str = "frame",
focus_metric: str = "DetER",
- shift_length_in_sec: float = 0.01,
+ frame_length_in_sec: float = 0.01,
num_workers: int = 20,
) -> Tuple[dict, dict]:
"""
@@ -788,6 +788,8 @@ def vad_tune_threshold_on_dev(
vad_pred_method (str): suffix of prediction files, used to locate them. Should be one of "frame", "mean" or "median".
groundtruth_RTTM_dir (str): directory of ground-truth RTTM files, or a file containing their paths.
focus_metric (str): metric we care about most when tuning thresholds. Should be one of "DetER", "FA", "MISS".
+ frame_length_in_sec (float): frame length.
+ num_workers (int): number of workers.
Returns:
best_threshold (float): threshold that gives lowest DetER.
"""
@@ -810,7 +812,7 @@
# Generate speech segments by performing binarization on the VAD prediction according to param.
# Filter speech segments according to param and write the result to rttm-like table.
vad_table_dir = generate_vad_segment_table(
- vad_pred, param, shift_length_in_sec=shift_length_in_sec, num_workers=num_workers
+ vad_pred, param, frame_length_in_sec=frame_length_in_sec, num_workers=num_workers
)
# add reference and hypothesis to metrics
for filename in paired_filenames:
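For reference, a hedged sketch of calling the tuner with a small threshold grid; the ranges and paths are invented for illustration:

from nemo.collections.asr.parts.utils.vad_utils import vad_tune_threshold_on_dev

params = {"onset": [0.4, 0.5, 0.6], "offset": [0.3, 0.4, 0.5]}  # invented grid
best_threshold, optimal_scores = vad_tune_threshold_on_dev(
    params,
    vad_pred="vad_frame_preds/",      # placeholder prediction directory
    groundtruth_RTTM="dev_rttms/",    # placeholder ground-truth RTTMs
    frame_length_in_sec=0.01,
    num_workers=8,
)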
@@ -938,14 +940,14 @@ def plot(
per_args(dict): a dict that stores the thresholds for postprocessing.
"""
plt.figure(figsize=[20, 2])
- FRAME_LEN = 0.01
+ UNIT_FRAME_LEN = 0.01

audio, sample_rate = librosa.load(path=path2audio_file, sr=16000, mono=True, offset=offset, duration=duration)
dur = librosa.get_duration(y=audio, sr=sample_rate)

- time = np.arange(offset, offset + dur, FRAME_LEN)
+ time = np.arange(offset, offset + dur, UNIT_FRAME_LEN)
frame, _ = load_tensor_from_file(path2_vad_pred)
- frame_snippet = frame[int(offset / FRAME_LEN) : int((offset + dur) / FRAME_LEN)]
+ frame_snippet = frame[int(offset / UNIT_FRAME_LEN) : int((offset + dur) / UNIT_FRAME_LEN)]

len_pred = len(frame_snippet)
ax1 = plt.subplot()
@@ -969,14 +971,14 @@ def plot(
) # take whole frame here for calculating onset and offset
speech_segments = generate_vad_segment_table_per_tensor(frame, per_args_float)
pred = gen_pred_from_speech_segments(speech_segments, frame)
- pred_snippet = pred[int(offset / FRAME_LEN) : int((offset + dur) / FRAME_LEN)]
+ pred_snippet = pred[int(offset / UNIT_FRAME_LEN) : int((offset + dur) / UNIT_FRAME_LEN)]

if path2ground_truth_label:
label = extract_labels(path2ground_truth_label, time)
- ax2.plot(np.arange(len_pred) * FRAME_LEN, label, 'r', label='label')
+ ax2.plot(np.arange(len_pred) * UNIT_FRAME_LEN, label, 'r', label='label')

- ax2.plot(np.arange(len_pred) * FRAME_LEN, pred_snippet, 'b', label='pred')
- ax2.plot(np.arange(len_pred) * FRAME_LEN, frame_snippet, 'g--', label='speech prob')
+ ax2.plot(np.arange(len_pred) * UNIT_FRAME_LEN, pred_snippet, 'b', label='pred')
+ ax2.plot(np.arange(len_pred) * UNIT_FRAME_LEN, frame_snippet, 'g--', label='speech prob')
ax2.tick_params(axis='y', labelcolor='r')
ax2.legend(loc='lower right', shadow=True)
ax2.set_ylabel('Preds and Probas')
6 changes: 3 additions & 3 deletions scripts/voice_activity_detection/vad_overlap_posterior.py
@@ -83,19 +83,19 @@
start = time.time()
logging.info("Converting frame level prediction to speech/no-speech segment in start and end times format.")

+ frame_length_in_sec = args.shift_length_in_sec
if args.gen_overlap_seq:
logging.info("Use overlap prediction. Change if you want to use basic frame level prediction")
vad_pred_dir = overlap_out_dir
- shift_length_in_sec = 0.01
+ frame_length_in_sec = 0.01
else:
logging.info("Use basic frame level prediction")
vad_pred_dir = args.frame_folder
- shift_length_in_sec = args.shift_length_in_sec

table_out_dir = generate_vad_segment_table(
vad_pred_dir=vad_pred_dir,
postprocessing_params=postprocessing_params,
- shift_length_in_sec=args.shift_length_in_sec,
+ frame_length_in_sec=frame_length_in_sec,
num_workers=args.num_workers,
out_dir=args.table_out_dir,
)
11 changes: 10 additions & 1 deletion scripts/voice_activity_detection/vad_tune_threshold.py
@@ -81,6 +81,9 @@
type=str,
default='DetER',
)
+ parser.add_argument(
+ "--frame_length_in_sec", help="frame length in seconds", type=float, default=0.01,
+ )
args = parser.parse_args()

params = {}
@@ -125,7 +128,13 @@
)

best_threshold, optimal_scores = vad_tune_threshold_on_dev(
- params, args.vad_pred, args.groundtruth_RTTM, args.result_file, args.vad_pred_method, args.focus_metric
+ params,
+ args.vad_pred,
+ args.groundtruth_RTTM,
+ args.result_file,
+ args.vad_pred_method,
+ args.focus_metric,
+ args.frame_length_in_sec,
)
logging.info(
f"Best combination of thresholds for binarization selected from input ranges is {best_threshold}, and the optimal score is {optimal_scores}"
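To close out, a self-contained illustration of how the new command-line flag reaches the tuner; this is a toy argparse demo, not the script itself:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--frame_length_in_sec", type=float, default=0.01)
args = parser.parse_args(["--frame_length_in_sec", "0.02"])  # e.g. 20 ms frames

# args.frame_length_in_sec is forwarded positionally to
# vad_tune_threshold_on_dev(...), as in the diff above.
print(args.frame_length_in_sec)  # -> 0.02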