Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Accept None as an argument to decoder_lengths in GreedyBatchedCTCInfer::forward #9278

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 36 additions & 14 deletions nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@
from nemo.utils import logging


def pack_hypotheses(hypotheses: List[rnnt_utils.Hypothesis], logitlen: torch.Tensor,) -> List[rnnt_utils.Hypothesis]:
def pack_hypotheses(
hypotheses: List[rnnt_utils.Hypothesis],
logitlen: torch.Tensor,
) -> List[rnnt_utils.Hypothesis]:

if logitlen is not None:
if hasattr(logitlen, 'cpu'):
Expand Down Expand Up @@ -55,6 +58,9 @@ def _states_to_device(dec_state, device='cpu'):
return dec_state


_DECODER_LENGTHS_NONE_WARNING = "Passing in decoder_lengths=None for CTC decoding is likely to be an error, since it is unlikely that each element of your batch has exactly the same length. decoder_lengths will default to decoder_output.shape[0]."


class GreedyCTCInfer(Typing, ConfidenceMethodMixin):
"""A greedy CTC decoder.

Expand Down Expand Up @@ -108,8 +114,7 @@ class GreedyCTCInfer(Typing, ConfidenceMethodMixin):

@property
def input_types(self):
"""Returns definitions of module input ports.
"""
"""Returns definitions of module input ports."""
# Input can be of dimension -
# ('B', 'T', 'D') [Log probs] or ('B', 'T') [Labels]

Expand All @@ -120,8 +125,7 @@ def input_types(self):

@property
def output_types(self):
"""Returns definitions of module output ports.
"""
"""Returns definitions of module output ports."""
return {"predictions": [NeuralType(elements_type=HypothesisType())]}

def __init__(
Expand All @@ -145,7 +149,9 @@ def __init__(

@typecheck()
def forward(
self, decoder_output: torch.Tensor, decoder_lengths: torch.Tensor,
self,
decoder_output: torch.Tensor,
decoder_lengths: Optional[torch.Tensor],
):
"""Returns a list of hypotheses given an input batch of the encoder hidden embedding.
Output token is generated auto-regressively.
Expand All @@ -158,6 +164,15 @@ def forward(
Returns:
packed list containing batch number of sentences (Hypotheses).
"""

logging.warning(
"CTC decoding strategy 'greedy' is slower than 'greedy_batch', which implements the same exact interface. Consider changing your strategy to 'greedy_batch' for a free performance improvement.",
mode=logging_mode.ONCE,
)

if decoder_lengths is None:
logging.warning(_DECODER_LENGTHS_NONE_WARNING, mode=logging_mode.ONCE)

with torch.inference_mode():
hypotheses = []
# Process each sequence independently
Expand Down Expand Up @@ -204,7 +219,7 @@ def forward(
return (packed_result,)

@torch.no_grad()
def _greedy_decode_logprobs(self, x: torch.Tensor, out_len: torch.Tensor):
def _greedy_decode_logprobs(self, x: torch.Tensor, out_len: Optional[torch.Tensor]):
# x: [T, D]
# out_len: [seq_len]

Expand Down Expand Up @@ -234,7 +249,7 @@ def _greedy_decode_logprobs(self, x: torch.Tensor, out_len: torch.Tensor):
return hypothesis

@torch.no_grad()
def _greedy_decode_labels(self, x: torch.Tensor, out_len: torch.Tensor):
def _greedy_decode_labels(self, x: torch.Tensor, out_len: Optional[torch.Tensor]):
# x: [T]
# out_len: [seq_len]

Expand Down Expand Up @@ -324,8 +339,7 @@ class GreedyBatchedCTCInfer(Typing, ConfidenceMethodMixin):

@property
def input_types(self):
"""Returns definitions of module input ports.
"""
"""Returns definitions of module input ports."""
# Input can be of dimension -
# ('B', 'T', 'D') [Log probs] or ('B', 'T') [Labels]

Expand All @@ -336,8 +350,7 @@ def input_types(self):

@property
def output_types(self):
"""Returns definitions of module output ports.
"""
"""Returns definitions of module output ports."""
return {"predictions": [NeuralType(elements_type=HypothesisType())]}

def __init__(
Expand All @@ -361,7 +374,9 @@ def __init__(

@typecheck()
def forward(
self, decoder_output: torch.Tensor, decoder_lengths: torch.Tensor,
self,
decoder_output: torch.Tensor,
decoder_lengths: Optional[torch.Tensor],
):
"""Returns a list of hypotheses given an input batch of the encoder hidden embedding.
Output token is generated auto-regressively.
Expand All @@ -374,11 +389,18 @@ def forward(
Returns:
packed list containing batch number of sentences (Hypotheses).
"""

input_decoder_lengths = decoder_lengths

if decoder_lengths is None:
logging.warning(_DECODER_LENGTHS_NONE_WARNING, mode=logging_mode.ONCE)
decoder_lengths = torch.tensor([decoder_output.shape[1]], dtype=torch.long).expand(decoder_output.shape[0])

if decoder_output.ndim == 2:
hypotheses = self._greedy_decode_labels_batched(decoder_output, decoder_lengths)
else:
hypotheses = self._greedy_decode_logprobs_batched(decoder_output, decoder_lengths)
packed_result = pack_hypotheses(hypotheses, decoder_lengths)
packed_result = pack_hypotheses(hypotheses, input_decoder_lengths)
return (packed_result,)

@torch.no_grad()
Expand Down
22 changes: 17 additions & 5 deletions tests/collections/asr/decoding/test_ctc_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,9 @@ def test_constructor_subword(self, tmp_tokenizer):
assert decoding is not None

@pytest.mark.unit
def test_char_decoding_greedy_forward(self,):
def test_char_decoding_greedy_forward(
self,
):
cfg = CTCDecodingConfig(strategy='greedy')
vocab = char_vocabulary()
decoding = CTCDecoding(decoding_cfg=cfg, vocabulary=vocab)
Expand Down Expand Up @@ -197,7 +199,10 @@ def test_subword_decoding_greedy_forward_hypotheses(self, tmp_tokenizer, alignme
@pytest.mark.parametrize('alignments', [False, True])
@pytest.mark.parametrize('timestamps', [False, True])
@pytest.mark.parametrize('preserve_frame_confidence', [False, True])
def test_batched_decoding_logprobs(self, tmp_tokenizer, alignments, timestamps, preserve_frame_confidence):
@pytest.mark.parametrize('length_is_none', [False, True])
def test_batched_decoding_logprobs(
self, tmp_tokenizer, alignments, timestamps, preserve_frame_confidence, length_is_none
):
cfg = CTCBPEDecodingConfig(
strategy='greedy',
preserve_alignments=alignments,
Expand All @@ -217,7 +222,10 @@ def test_batched_decoding_logprobs(self, tmp_tokenizer, alignments, timestamps,
# that we always handle at least a few blanks.
input_signal[:, 0, unbatched_decoding.tokenizer.tokenizer.vocab_size] = 1000
input_signal[:, 1, unbatched_decoding.tokenizer.tokenizer.vocab_size] = 1000
length = torch.randint(low=1, high=T, size=[B])
if length_is_none:
length = None
else:
length = torch.randint(low=1, high=T, size=[B])

with torch.inference_mode():
hyps, _ = unbatched_decoding.ctc_decoder_predictions_tensor(
Expand All @@ -240,7 +248,8 @@ def test_batched_decoding_logprobs(self, tmp_tokenizer, alignments, timestamps,

@pytest.mark.unit
@pytest.mark.parametrize('timestamps', [False, True])
def test_batched_decoding_labels(self, tmp_tokenizer, timestamps):
@pytest.mark.parametrize('length_is_none', [False, True])
def test_batched_decoding_labels(self, tmp_tokenizer, timestamps, length_is_none):
cfg = CTCBPEDecodingConfig(strategy='greedy', compute_timestamps=timestamps)
unbatched_decoding = CTCBPEDecoding(decoding_cfg=cfg, tokenizer=tmp_tokenizer)
cfg.strategy = 'greedy_batched'
Expand All @@ -254,7 +263,10 @@ def test_batched_decoding_labels(self, tmp_tokenizer, timestamps):
# at least a few blanks.
input_labels[:, 0] = unbatched_decoding.tokenizer.tokenizer.vocab_size
input_labels[:, 1] = unbatched_decoding.tokenizer.tokenizer.vocab_size
length = torch.randint(low=1, high=T, size=[B])
if length_is_none:
length = None
else:
length = torch.randint(low=1, high=T, size=[B])

with torch.inference_mode():
hyps, _ = unbatched_decoding.ctc_decoder_predictions_tensor(
Expand Down
Loading