[BugFix] Removing <unk> tokens from decoding timestamp (#5481)
* Removing <unk> tokens from timestamp decoding

Signed-off-by: Taejin Park <tango4j@gmail.com>

* Fixed notebook bug

Signed-off-by: Taejin Park <tango4j@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Taejin Park <tango4j@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
tango4j and pre-commit-ci[bot] committed Nov 29, 2022
1 parent e711467 commit f0d9c51
Showing 2 changed files with 27 additions and 19 deletions.
14 changes: 10 additions & 4 deletions nemo/collections/asr/parts/utils/decoder_timestamps_utils.py
@@ -118,8 +118,13 @@ def get_ts_from_decoded_prediction(self, decoded_prediction: List[str], hypothesis
         word_ts, word_seq = [], []
         word_open_flag = False
         for idx, ch in enumerate(decoded_char_list):
+
+            # If the symbol is a space and not the end of the utterance, move on
             if idx != end_idx and (space == ch and space in decoded_char_list[idx + 1]):
                 continue
+            # If the symbol is an unknown token such as '<unk>', move on
+            elif ch in ['<unk>']:
+                continue
 
             if (idx == stt_idx or space == decoded_char_list[idx - 1] or (space in ch and len(ch) > 1)) and (
                 ch != space
@@ -135,7 +140,7 @@ def get_ts_from_decoded_prediction(self, decoded_prediction: List[str], hypothesis
                 word_ts.append([_stt, _end])
                 stitched_word = ''.join(decoded_char_list[stt_ch_idx : end_ch_idx + 1]).replace(space, '')
                 word_seq.append(stitched_word)
-        assert len(word_ts) == len(hypothesis.split()), "Hypothesis does not match word time stamp."
+        assert len(word_ts) == len(hypothesis.split()), "Text hypothesis does not match word timestamps."
         return word_ts
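To see why the new elif branch matters, here is a minimal, self-contained sketch with toy inputs (the character list, separator, and hypothesis are hypothetical, and the word stitching is simplified relative to the loop above): a '<unk>' token sitting between two spaces would otherwise be stitched into a word of its own, so the word count would disagree with the text hypothesis and the assert above would fire.

# Toy illustration (hypothetical inputs) of the '<unk>' filter added above.
decoded_char_list = ['he', 'l', 'lo', '_', '<unk>', '_', 'wor', 'ld']
space = '_'
hypothesis = 'hello world'

without_filter = ''.join(decoded_char_list).replace(space, ' ').split()
with_filter = ''.join(ch for ch in decoded_char_list if ch not in ['<unk>']).replace(space, ' ').split()

print(without_filter)  # ['hello', '<unk>', 'world'] -> length check against the hypothesis fails
print(with_filter)     # ['hello', 'world']          -> matches hypothesis.split()
assert len(with_filter) == len(hypothesis.split())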


@@ -338,7 +343,6 @@ def set_asr_model(self):
         self.word_ts_anchor_offset = if_none_get_default(self.params['word_ts_anchor_offset'], 0.12)
         self.asr_batch_size = if_none_get_default(self.params['asr_batch_size'], 16)
         self.model_stride_in_secs = 0.04
-
         # Conformer requires buffered inference and the parameters for buffered processing.
         self.chunk_len_in_sec = 5
         self.total_buffer_in_secs = 25
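As a rough sanity check on these numbers (assumption: buffered inference keeps each chunk centered in the buffer, with the remainder split evenly as left and right acoustic context; the variable names below are illustrative):

# Back-of-envelope sketch of what these settings imply.
chunk_len_in_sec = 5
total_buffer_in_secs = 25
model_stride_in_secs = 0.04  # one output token per 40 ms of audio

context_per_side = (total_buffer_in_secs - chunk_len_in_sec) / 2
tokens_per_chunk = int(chunk_len_in_sec / model_stride_in_secs)
print(f"{context_per_side} s of context per side, {tokens_per_chunk} tokens per chunk")
# -> 10.0 s of context per side, 125 tokens per chunk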
@@ -739,11 +743,13 @@ def run_pyctcdecode(
         return hyp_words, word_ts
 
     @staticmethod
-    def get_word_ts_from_wordframes(idx, word_frames: List[List[float]], frame_duration: float, onset_delay: float):
+    def get_word_ts_from_wordframes(
+        idx, word_frames: List[List[float]], frame_duration: float, onset_delay: float, word_block_delay: float = 2.25
+    ):
         """
         Extract word timestamps from word frames generated from pyctcdecode.
         """
-        offset = -1 * 2.25 * frame_duration - onset_delay
+        offset = -1 * word_block_delay * frame_duration - onset_delay
         frame_begin = word_frames[idx][1][0]
         if frame_begin == -1:
             frame_begin = word_frames[idx - 1][1][1] if idx != 0 else 0
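The signature change only lifts the hard-coded 2.25-frame block delay into a word_block_delay keyword whose default reproduces the old offset. A sketch with illustrative numbers (the word_frames layout, frame_duration and onset_delay values, and the final start-time formula are assumptions for demonstration, not the full method):

# Illustrative numbers only; word_frames mimics pyctcdecode's (word, (start_frame, end_frame)) pairs.
word_frames = [('hello', (12, 30)), ('world', (-1, 55))]
frame_duration = 0.04    # assumed 40 ms per frame, matching the model stride above
onset_delay = 0.1        # assumed decoder onset delay
word_block_delay = 2.25  # new keyword argument; the default keeps the old behavior

offset = -1 * word_block_delay * frame_duration - onset_delay

idx = 1
frame_begin = word_frames[idx][1][0]
if frame_begin == -1:
    # fall back to the previous word's end frame, as in the diff above
    frame_begin = word_frames[idx - 1][1][1] if idx != 0 else 0
print(round(frame_begin * frame_duration + offset, 2))  # hypothetical word start time: 1.01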
32 changes: 17 additions & 15 deletions tutorials/speaker_tasks/ASR_with_SpeakerDiarization.ipynb
@@ -485,8 +485,8 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "hyp1 = \"eleven twenty seven fifty seven october twenty fourth nineteen seventy\"\n",
-    "hyp2 = \"october twenty fourth nineteen seventy eleven twenty seven fifty seven\""
+    "hyp1 = \"eleven twenty seven fifty seven october twenty four nineteen seventy\"\n",
+    "hyp2 = \"october twenty four nineteen seventy eleven twenty seven fifty seven\""
    ]
   },
   {
@@ -558,7 +558,7 @@
    "\"an4_diarize_test speaker_0 2.32 0.75 seven 0\",\n",
    "\"an4_diarize_test speaker_1 3.12 0.42 october 0\",\n",
    "\"an4_diarize_test speaker_1 3.6 0.28 twenty 0\",\n",
-    "\"an4_diarize_test speaker_1 3.95 0.35 fourth 0\",\n",
+    "\"an4_diarize_test speaker_1 3.95 0.35 four 0\",\n",
    "\"an4_diarize_test speaker_1 4.3 0.31 nineteen 0\",\n",
    "\"an4_diarize_test speaker_1 4.65 0.35 seventy 0\"]"
    ]
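The reference strings above follow a CTM-style layout; a quick parse makes the corrected entry easy to eyeball (field meanings inferred from context: session, speaker, onset in seconds, duration in seconds, word, confidence — this is a reading aid, not NeMo's parser):

line = "an4_diarize_test speaker_1 3.95 0.35 four 0"
session, speaker, onset, duration, word, confidence = line.split()
print(f"{speaker} says {word!r} from {float(onset):.2f}s to {float(onset) + float(duration):.2f}s")
# -> speaker_1 says 'four' from 3.95s to 4.30s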
@@ -661,27 +661,27 @@
    "outputs": [],
    "source": [
     "from nemo.collections.asr.parts.utils.diarization_utils import convert_word_dict_seq_to_text\n",
-    "\n",
+    "import copy\n",
+    "word_seq_lists_modified = copy.deepcopy(word_seq_lists)\n",
     "# Let's artificially flip a speaker label and check whether cpWER reflects it\n",
-    "word_seq_lists[0][-1]['speaker_label'] = 'speaker_0'\n",
-    "spk_hypothesis, mix_hypothesis = convert_word_dict_seq_to_text(word_seq_list)\n",
+    "word_seq_lists_modified[0][-1]['speaker'] = 'speaker_0'\n",
+    "print(word_seq_lists_modified[0])\n",
     "\n",
     "\n",
+    "spk_hypothesis, mix_hypothesis = convert_word_dict_seq_to_text(word_seq_list)\n",
+    "spk_hypothesis_modified, mix_hypothesis_modified = convert_word_dict_seq_to_text(word_seq_lists_modified[0])\n",
     "\n",
     "# Check that \"seventy\" in spk_hypothesis has been moved to speaker_0\n",
     "print(f\"spk_hypothesis: {spk_hypothesis}\")\n",
     "print(f\"mix_hypothesis: {mix_hypothesis}\\n\")\n",
+    "print(f\"spk_hypothesis_modified: {spk_hypothesis_modified}\")\n",
+    "print(f\"mix_hypothesis_modified: {mix_hypothesis_modified}\\n\")\n",
     "\n",
     "print(f\"spk_reference: {spk_reference}\")\n",
     "print(f\"mix_reference: {mix_reference}\")\n",
     "\n",
     "# Recalculate cpWER and WER\n",
-    "cpWER, concat_hyp, concat_ref = concat_perm_word_error_rate([spk_hypothesis], [spk_reference])\n",
-    "WER = word_error_rate([mix_hypothesis], [mix_reference])\n",
+    "cpWER_modified, concat_hyp, concat_ref = concat_perm_word_error_rate([spk_hypothesis_modified], [spk_reference])\n",
+    "WER_modified = word_error_rate([mix_hypothesis_modified], [mix_reference])\n",
     "\n",
-    "print(f\"cpWER: {cpWER[0]}\")\n",
-    "print(f\"WER: {WER}\")"
+    "print(f\"cpWER: {cpWER_modified[0]}\")\n",
+    "print(f\"WER: {WER_modified}\")"
    ]
   },
  {
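For readers unfamiliar with the metric this cell recomputes: cpWER concatenates per-speaker transcripts under every speaker permutation and keeps the lowest WER, so a flipped speaker label is penalized while a global relabeling of speakers is not. A self-contained sketch in pure Python (hypothetical helper names; the notebook's concat_perm_word_error_rate and word_error_rate are the real NeMo APIs):

from itertools import permutations
from typing import List

def word_edit_distance(ref: List[str], hyp: List[str]) -> int:
    # Standard dynamic-programming Levenshtein distance over word lists.
    d = [[0] * (len(hyp) + 1) for _ in range(len(ref) + 1)]
    for i in range(len(ref) + 1):
        d[i][0] = i
    for j in range(len(hyp) + 1):
        d[0][j] = j
    for i in range(1, len(ref) + 1):
        for j in range(1, len(hyp) + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            d[i][j] = min(d[i - 1][j] + 1, d[i][j - 1] + 1, d[i - 1][j - 1] + cost)
    return d[-1][-1]

def cp_wer(spk_hypothesis: List[str], spk_reference: List[str]) -> float:
    # Concatenate per-speaker transcripts under every speaker permutation
    # and keep the lowest WER.
    ref_words = " ".join(spk_reference).split()
    best = min(
        word_edit_distance(ref_words, " ".join(perm).split())
        for perm in permutations(spk_hypothesis)
    )
    return best / len(ref_words)

spk_reference = ["eleven twenty seven fifty seven", "october twenty four nineteen seventy"]
# "seventy" artificially flipped to the other speaker, as in the cell above:
spk_hypothesis = ["eleven twenty seven fifty seven seventy", "october twenty four nineteen"]
print(f"cpWER: {cp_wer(spk_hypothesis, spk_reference):.4f}")  # 2 errors / 10 words = 0.2

With the flip, the best permutation still differs from the reference by one inserted and one deleted word, so cpWER rises, while the speaker-agnostic WER is unchanged because the word sequence itself did not move.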
@@ -921,7 +921,9 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {},
+   "metadata": {
+    "scrolled": true
+   },
    "outputs": [],
    "source": [
     "asr_diar_offline.get_transcript_with_speaker_labels(diar_hyp, word_hyp, word_ts_hyp)"
