Fix segmenting for pcla inference #5849

Merged
merged 7 commits into from
Jan 26, 2023
@@ -36,7 +36,7 @@
--input_manifest <PATH/TO/INPUT/MANIFEST> \
--output_manifest <PATH/TO/OUTPUT/MANIFEST> \
--use_audio


<PATH/TO/INPUT/MANIFEST> is a path to NeMo ASR manifest. Usually it is an output of
NeMo/examples/asr/transcribe_speech.py but can be a manifest with 'text' key. Alternatively you can use
@@ -55,7 +55,7 @@ def get_args() -> argparse.Namespace:
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
description="The script is for restoring punctuation and capitalization in text or text and audio. To use text and audio use '--use_audio'. Long strings are split into "
"segments of length `--max_seq_length`. `--max_seq_length` is the length which includes [CLS] and [SEP] "
- "tokens. Long audios are split into segments of length 4000*`--max_seq_length`. Parameter `--step` controls segments overlapping. `--step` is a distance between beginnings of "
+ "tokens. If `--use_audio` is set, samples with texts longer than `--max_seq_length` will be ignored. Parameter `--step` controls segments overlapping. `--step` is a distance between beginnings of "
"consequent segments. Model outputs for tokens near the borders of tensors are less accurate and can be "
"discarded before final predictions computation. Parameter `--margin` is number of discarded outputs near "
"segments borders. Probabilities of tokens in overlapping parts of segments multiplied before selecting the "
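The help text above describes how a long text is cut into overlapping segments. As a rough illustration only, the sketch below computes the subtoken slices for one query. The function name `segment_windows` is hypothetical, and the `max_seq_length - 2` content length is an assumption drawn from the statement that `--max_seq_length` includes the [CLS] and [SEP] tokens; the range expression mirrors the loop shown in `get_features_infer`.

```python
from typing import List, Tuple


def segment_windows(num_subtokens: int, max_seq_length: int, step: int) -> List[Tuple[int, int]]:
    """Return (start, end) subtoken slices for one query (hypothetical helper)."""
    # Assumption: two positions are reserved for [CLS] and [SEP].
    length = max_seq_length - 2
    # Same range expression as the segment loop in get_features_infer.
    starts = range(0, max(num_subtokens, length) - length + step, step)
    # Clamp the end so the last window stops at the final subtoken.
    return [(i, min(i + length, num_subtokens)) for i in starts]


if __name__ == "__main__":
    # 10 subtokens, 4 content positions per segment, starts 2 apart.
    print(segment_windows(num_subtokens=10, max_seq_length=6, step=2))
```

With `step` smaller than the content length, consecutive windows overlap, which is the situation the `--margin` description refers to.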
@@ -132,6 +132,9 @@ def get_features_infer(
for q_i, (query_st, query_audio) in enumerate(zip(st, audios)):
q_inp_ids, q_segment_ids, q_subtokens_mask, q_inp_mask, q_quantities_of_preceding_words = [], [], [], [], []
q_audio_queries, q_audio_lengths = [], []
+ if query_audio and length < len(query_st):
+ logging.info(f'Ignoring query with id {q_i}')
+ continue
for i in range(0, max(len(query_st), length) - length + step, step):
subtokens = [tokenizer.cls_token] + query_st[i : i + length] + [tokenizer.sep_token]
q_inp_ids.append(tokenizer.tokens_to_ids(subtokens))
@@ -140,7 +143,7 @@ def get_features_infer(
q_inp_mask.append([True] * len(subtokens))
q_quantities_of_preceding_words.append(np.count_nonzero(stm[q_i][:i]))
if query_audio:
- samples = query_audio.samples[i * 4000 : (i + length) * 4000]
+ samples = query_audio.samples
q_audio_queries.append(samples)
q_audio_lengths.append(len(samples))
all_input_ids.append(q_inp_ids)
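One way to read the change to `get_features_infer`: once queries with audio and more than `length` subtokens are skipped, any remaining audio query yields exactly one segment, so pairing that segment with the whole `query_audio.samples` is consistent. The sketch below is a simplified stand-in, not the NeMo implementation; `plan_query` and its `has_audio` flag are hypothetical names, and plain strings stand in for subtokens.

```python
from typing import List, Optional, Sequence


def plan_query(
    query_st: Sequence[str], has_audio: bool, length: int, step: int
) -> Optional[List[List[str]]]:
    """Return text segments for a query, or None if the query is ignored."""
    # Mirrors the added guard: audio queries longer than `length` are skipped.
    if has_audio and length < len(query_st):
        return None
    # Same range expression as the segment loop in the diff.
    starts = range(0, max(len(query_st), length) - length + step, step)
    return [list(query_st[i : i + length]) for i in starts]


if __name__ == "__main__":
    tokens = list("abcdef")
    print(plan_query(tokens, has_audio=True, length=4, step=2))       # skipped
    print(plan_query(tokens, has_audio=False, length=4, step=2))      # two overlapping segments
    print(plan_query(tokens[:3], has_audio=True, length=4, step=2))   # single segment
```

For an audio query that passes the guard, `max(len(query_st), length) - length` is zero, so the range reduces to `range(0, step, step)` and the loop body runs once.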