diff --git a/nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py b/nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py
index 1ef26cd7adf3..c4e9a14f6e1d 100644
--- a/nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py
+++ b/nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py
@@ -25,7 +25,10 @@
 from nemo.utils import logging
 
 
-def pack_hypotheses(hypotheses: List[rnnt_utils.Hypothesis], logitlen: torch.Tensor,) -> List[rnnt_utils.Hypothesis]:
+def pack_hypotheses(
+    hypotheses: List[rnnt_utils.Hypothesis],
+    logitlen: torch.Tensor,
+) -> List[rnnt_utils.Hypothesis]:
 
     if logitlen is not None:
         if hasattr(logitlen, 'cpu'):
@@ -55,6 +58,9 @@ def _states_to_device(dec_state, device='cpu'):
     return dec_state
 
 
+_DECODER_LENGTHS_NONE_WARNING = "Passing in decoder_lengths=None for CTC decoding is likely to be an error, since it is unlikely that each element of your batch has exactly the same length. decoder_lengths will default to decoder_output.shape[0]."
+
+
 class GreedyCTCInfer(Typing, ConfidenceMethodMixin):
     """A greedy CTC decoder.
 
@@ -108,8 +114,7 @@ class GreedyCTCInfer(Typing, ConfidenceMethodMixin):
 
     @property
     def input_types(self):
-        """Returns definitions of module input ports.
-        """
+        """Returns definitions of module input ports."""
         # Input can be of dimension -
         # ('B', 'T', 'D') [Log probs] or ('B', 'T') [Labels]
 
@@ -120,8 +125,7 @@ def input_types(self):
 
     @property
     def output_types(self):
-        """Returns definitions of module output ports.
-        """
+        """Returns definitions of module output ports."""
         return {"predictions": [NeuralType(elements_type=HypothesisType())]}
 
     def __init__(
@@ -145,7 +149,9 @@ def __init__(
 
     @typecheck()
     def forward(
-        self, decoder_output: torch.Tensor, decoder_lengths: torch.Tensor,
+        self,
+        decoder_output: torch.Tensor,
+        decoder_lengths: Optional[torch.Tensor],
     ):
         """Returns a list of hypotheses given an input batch of the encoder hidden embedding.
         Output token is generated auto-repressively.
@@ -158,6 +164,15 @@ def forward(
         Returns:
             packed list containing batch number of sentences (Hypotheses).
         """
+
+        logging.warning(
+            "CTC decoding strategy 'greedy' is slower than 'greedy_batch', which implements the same exact interface. Consider changing your strategy to 'greedy_batch' for a free performance improvement.",
+            mode=logging_mode.ONCE,
+        )
+
+        if decoder_lengths is None:
+            logging.warning(_DECODER_LENGTHS_NONE_WARNING, mode=logging_mode.ONCE)
+
         with torch.inference_mode():
             hypotheses = []
             # Process each sequence independently
@@ -204,7 +219,7 @@ def forward(
         return (packed_result,)
 
     @torch.no_grad()
-    def _greedy_decode_logprobs(self, x: torch.Tensor, out_len: torch.Tensor):
+    def _greedy_decode_logprobs(self, x: torch.Tensor, out_len: Optional[torch.Tensor]):
         # x: [T, D]
         # out_len: [seq_len]
 
@@ -234,7 +249,7 @@ def _greedy_decode_logprobs(self, x: torch.Tensor, out_len: torch.Tensor):
         return hypothesis
 
     @torch.no_grad()
-    def _greedy_decode_labels(self, x: torch.Tensor, out_len: torch.Tensor):
+    def _greedy_decode_labels(self, x: torch.Tensor, out_len: Optional[torch.Tensor]):
         # x: [T]
         # out_len: [seq_len]
 
@@ -324,8 +339,7 @@ class GreedyBatchedCTCInfer(Typing, ConfidenceMethodMixin):
 
     @property
    def input_types(self):
-        """Returns definitions of module input ports.
-        """
+        """Returns definitions of module input ports."""
         # Input can be of dimension -
         # ('B', 'T', 'D') [Log probs] or ('B', 'T') [Labels]
 
@@ -336,8 +350,7 @@ def input_types(self):
 
     @property
     def output_types(self):
-        """Returns definitions of module output ports.
-        """
+        """Returns definitions of module output ports."""
         return {"predictions": [NeuralType(elements_type=HypothesisType())]}
 
     def __init__(
@@ -361,7 +374,9 @@ def __init__(
 
     @typecheck()
     def forward(
-        self, decoder_output: torch.Tensor, decoder_lengths: torch.Tensor,
+        self,
+        decoder_output: torch.Tensor,
+        decoder_lengths: Optional[torch.Tensor],
     ):
         """Returns a list of hypotheses given an input batch of the encoder hidden embedding.
         Output token is generated auto-repressively.
@@ -374,11 +389,18 @@ def forward(
         Returns:
             packed list containing batch number of sentences (Hypotheses).
         """
+
+        input_decoder_lengths = decoder_lengths
+
+        if decoder_lengths is None:
+            logging.warning(_DECODER_LENGTHS_NONE_WARNING, mode=logging_mode.ONCE)
+            decoder_lengths = torch.tensor([decoder_output.shape[1]], dtype=torch.long).expand(decoder_output.shape[0])
+
         if decoder_output.ndim == 2:
             hypotheses = self._greedy_decode_labels_batched(decoder_output, decoder_lengths)
         else:
             hypotheses = self._greedy_decode_logprobs_batched(decoder_output, decoder_lengths)
-        packed_result = pack_hypotheses(hypotheses, decoder_lengths)
+        packed_result = pack_hypotheses(hypotheses, input_decoder_lengths)
         return (packed_result,)
 
     @torch.no_grad()
diff --git a/tests/collections/asr/decoding/test_ctc_decoding.py b/tests/collections/asr/decoding/test_ctc_decoding.py
index 8eceb822fd38..a42d61f051ad 100644
--- a/tests/collections/asr/decoding/test_ctc_decoding.py
+++ b/tests/collections/asr/decoding/test_ctc_decoding.py
@@ -90,7 +90,9 @@ def test_constructor_subword(self, tmp_tokenizer):
         assert decoding is not None
 
     @pytest.mark.unit
-    def test_char_decoding_greedy_forward(self,):
+    def test_char_decoding_greedy_forward(
+        self,
+    ):
         cfg = CTCDecodingConfig(strategy='greedy')
         vocab = char_vocabulary()
         decoding = CTCDecoding(decoding_cfg=cfg, vocabulary=vocab)
@@ -197,7 +199,10 @@ def test_subword_decoding_greedy_forward_hypotheses(self, tmp_tokenizer, alignme
     @pytest.mark.parametrize('alignments', [False, True])
     @pytest.mark.parametrize('timestamps', [False, True])
     @pytest.mark.parametrize('preserve_frame_confidence', [False, True])
-    def test_batched_decoding_logprobs(self, tmp_tokenizer, alignments, timestamps, preserve_frame_confidence):
+    @pytest.mark.parametrize('length_is_none', [False, True])
+    def test_batched_decoding_logprobs(
+        self, tmp_tokenizer, alignments, timestamps, preserve_frame_confidence, length_is_none
+    ):
         cfg = CTCBPEDecodingConfig(
             strategy='greedy',
             preserve_alignments=alignments,
@@ -217,7 +222,10 @@ def test_batched_decoding_logprobs(self, tmp_tokenizer, alignments, timestamps,
         # that we always handle at least a few blanks.
         input_signal[:, 0, unbatched_decoding.tokenizer.tokenizer.vocab_size] = 1000
         input_signal[:, 1, unbatched_decoding.tokenizer.tokenizer.vocab_size] = 1000
-        length = torch.randint(low=1, high=T, size=[B])
+        if length_is_none:
+            length = None
+        else:
+            length = torch.randint(low=1, high=T, size=[B])
 
         with torch.inference_mode():
             hyps, _ = unbatched_decoding.ctc_decoder_predictions_tensor(
@@ -240,7 +248,8 @@ def test_batched_decoding_logprobs(self, tmp_tokenizer, alignments, timestamps,
 
     @pytest.mark.unit
     @pytest.mark.parametrize('timestamps', [False, True])
-    def test_batched_decoding_labels(self, tmp_tokenizer, timestamps):
+    @pytest.mark.parametrize('length_is_none', [False, True])
+    def test_batched_decoding_labels(self, tmp_tokenizer, timestamps, length_is_none):
         cfg = CTCBPEDecodingConfig(strategy='greedy', compute_timestamps=timestamps)
         unbatched_decoding = CTCBPEDecoding(decoding_cfg=cfg, tokenizer=tmp_tokenizer)
         cfg.strategy = 'greedy_batched'
@@ -254,7 +263,10 @@ def test_batched_decoding_labels(self, tmp_tokenizer, timestamps):
         # at least a few blanks.
         input_labels[:, 0] = unbatched_decoding.tokenizer.tokenizer.vocab_size
         input_labels[:, 1] = unbatched_decoding.tokenizer.tokenizer.vocab_size
-        length = torch.randint(low=1, high=T, size=[B])
+        if length_is_none:
+            length = None
+        else:
+            length = torch.randint(low=1, high=T, size=[B])
 
         with torch.inference_mode():
             hyps, _ = unbatched_decoding.ctc_decoder_predictions_tensor(
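Below is a minimal standalone sketch of the decoder_lengths=None fallback that the patched GreedyBatchedCTCInfer.forward applies. The batch, time, and vocabulary sizes are made-up toy values; only the tensor expression inside the if-block mirrors the patched code.

import torch

# Toy shapes (assumptions, not from the patch): B=4 utterances, T=8 frames,
# 29 classes including the CTC blank.
decoder_output = torch.randn(4, 8, 29).log_softmax(dim=-1)
decoder_lengths = None

if decoder_lengths is None:
    # Same fallback expression the patched forward() uses: every batch
    # element is assumed to span the full time dimension T.
    decoder_lengths = torch.tensor([decoder_output.shape[1]], dtype=torch.long).expand(decoder_output.shape[0])

print(decoder_lengths)  # tensor([8, 8, 8, 8])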