From 60cafd95973ca247038ec37ef38ddb4e3a6e90a7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Piotr=20=C5=BBelasko?=
Date: Wed, 10 Apr 2024 12:14:06 -0400
Subject: [PATCH 1/5] Greedy and temperature sampling decoding for Canary/multi-task
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Piotr Żelasko
---
 .../transformer/transformer_generators.py     |  15 +-
 .../parts/submodules/multitask_decoding.py    |  28 ++-
 .../submodules/multitask_greedy_decoding.py   | 196 ++++++++++++++++++
 3 files changed, 225 insertions(+), 14 deletions(-)
 create mode 100644 nemo/collections/asr/parts/submodules/multitask_greedy_decoding.py

diff --git a/nemo/collections/asr/modules/transformer/transformer_generators.py b/nemo/collections/asr/modules/transformer/transformer_generators.py
index 4061f54a907a..6066db6681ab 100644
--- a/nemo/collections/asr/modules/transformer/transformer_generators.py
+++ b/nemo/collections/asr/modules/transformer/transformer_generators.py
@@ -15,6 +15,7 @@
 from contextlib import contextmanager
 
 import torch
+from torch.distributions import Categorical
 
 from nemo.collections.common.parts import NEG_INF, mask_padded_tokens
 
@@ -58,6 +59,7 @@ def __init__(
         max_sequence_length=512,
         max_delta_length=20,
         batch_size=1,
+        temperature=None,
     ):
         super().__init__()
         self.embedding = embedding
@@ -67,6 +69,7 @@ def __init__(
         self.max_seq_length = max_sequence_length
         self.max_delta_len = max_delta_length
         self.batch_size = batch_size
+        self.temperature = temperature
 
     def _one_step_forward(
         self,
@@ -107,8 +110,8 @@ def _one_step_forward(
         decoder_mems_list = self.decoder.forward(
             decoder_hidden_states, decoder_input_mask, decoder_mems_list, return_mems=True
         )
-        log_probs = self.log_softmax.forward(hidden_states=decoder_mems_list[-1][:, -1:])
-        return log_probs, decoder_mems_list
+        logits = decoder_mems_list[-1][:, -1]
+        return logits, decoder_mems_list
 
     def _prepare_for_search(self, decoder_input_ids=None, encoder_hidden_states=None):
         """
@@ -155,11 +158,15 @@ def _forward(
         decoder_mems_list = None
         for i in range(max_generation_length):
 
-            log_probs, decoder_mems_list = self._one_step_forward(
+            logits, decoder_mems_list = self._one_step_forward(
                 tgt[:, -1:], encoder_hidden_states, encoder_input_mask, decoder_mems_list, i
             )
+            # log_probs = self.log_softmax.forward(hidden_states=logits)
+            if self.temperature is None:
+                next_tokens = torch.argmax(logits, dim=-1, keepdim=True)
+            else:
+                next_tokens = Categorical(logits=logits / self.temperature).sample()
 
-            next_tokens = torch.argmax(log_probs[:, -1], dim=-1, keepdim=True)
             next_tokens = self.pad * pad_profile + next_tokens * (1 - pad_profile)
             pad_profile = torch.max(pad_profile, (next_tokens == self.eos).long())
             tgt = torch.cat((tgt, next_tokens), dim=-1)
diff --git a/nemo/collections/asr/parts/submodules/multitask_decoding.py b/nemo/collections/asr/parts/submodules/multitask_decoding.py
index c336ae7d4170..913c865c9e4b 100644
--- a/nemo/collections/asr/parts/submodules/multitask_decoding.py
+++ b/nemo/collections/asr/parts/submodules/multitask_decoding.py
@@ -25,6 +25,10 @@
     AEDBeamInferConfig,
     TransformerAEDBeamInfer,
 )
+from nemo.collections.asr.parts.submodules.multitask_greedy_decoding import (
+    AEDGreedyInferConfig,
+    TransformerAEDGreedyInfer,
+)
 from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis, NBestHypotheses
 from nemo.collections.common.tokenizers.aggregate_tokenizer import AggregateTokenizer
 from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
@@ -60,11 +64,11 @@ class AbstractMultiTaskDecoding(ABC):
     The config may further contain the following sub-dictionaries:
         "greedy":
-            max_symbols: int, describing the maximum number of target tokens to decode per
-                timestep during greedy decoding. Setting to larger values allows longer sentences
-                to be decoded, at the cost of increased execution time.
-            preserve_frame_confidence: Same as above, overrides above value.
-            confidence_method_cfg: Same as above, overrides confidence_cfg.method_cfg.
+            temperature: None (disabled) or float; when set, enables temperature sampling instead of greedy decoding.
+
+            max_generation_delta: int = -1  # -1 means up to the max length of the decoder
+
+            preserve_alignments: bool = False (unsupported)
 
         "beam":
             beam_size: int, defining the beam size for beam search. Must be >= 1.
@@ -118,8 +122,14 @@ def __init__(
 
         if self.cfg.strategy == 'greedy' or self.cfg.strategy == 'greedy_batch':
 
-            # self.decoding = None
-            raise NotImplementedError("Greedy decoding is not implemented yet.")
+            self.decoding = TransformerAEDGreedyInfer(
+                transformer_decoder=transformer_decoder,
+                log_softmax_module=log_softmax_module,
+                tokenizer=tokenizer,
+                max_generation_delta=self.cfg.greedy.get('max_generation_delta', 50),
+                preserve_alignments=self.preserve_alignments,
+                temperature=self.cfg.greedy.temperature,
+            )
 
         elif self.cfg.strategy == 'beam':
 
@@ -476,9 +486,7 @@ class MultiTaskDecodingConfig:
     compute_langs: bool = False
 
     # greedy decoding config
-    # greedy: rnnt_greedy_decoding.GreedyBatchedRNNTInferConfig = field(
-    #     default_factory=rnnt_greedy_decoding.GreedyBatchedRNNTInferConfig
-    # )
+    greedy: AEDGreedyInferConfig = field(default_factory=AEDGreedyInferConfig)
 
     # beam decoding config
     beam: AEDBeamInferConfig = field(default_factory=lambda: AEDBeamInferConfig(beam_size=1))
diff --git a/nemo/collections/asr/parts/submodules/multitask_greedy_decoding.py b/nemo/collections/asr/parts/submodules/multitask_greedy_decoding.py
new file mode 100644
index 000000000000..3798faeb95b5
--- /dev/null
+++ b/nemo/collections/asr/parts/submodules/multitask_greedy_decoding.py
@@ -0,0 +1,196 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import List, Optional
+
+import torch
+
+from nemo.collections.asr.modules.transformer import GreedySequenceGenerator
+from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis, NBestHypotheses
+from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
+from nemo.core import Typing, typecheck
+from nemo.core.neural_types import ChannelType, HypothesisType, LabelsType, MaskType, NeuralType
+from nemo.utils import logging
+
+
+def pack_hypotheses(
+    hypotheses: List[Hypothesis], beam_hypotheses: torch.Tensor, scores: List[Optional[float]]
+) -> List[Hypothesis]:
+
+    for idx, hyp in enumerate(hypotheses):  # type: Hypothesis
+        if scores[idx] is not None:
+            hyp.score = scores[idx]
+
+        hypi = beam_hypotheses[idx]
+        if torch.is_tensor(hypi):
+            hyp.y_sequence = hypi.long()
+        else:
+            hyp.y_sequence = torch.tensor(hypi, dtype=torch.long)
+
+        if hyp.dec_state is not None:
+            hyp.dec_state = _states_to_device(hyp.dec_state)
+
+    return hypotheses
+
+
+def _states_to_device(dec_state, device='cpu'):
+    if torch.is_tensor(dec_state):
+        dec_state = dec_state.to(device)
+
+    elif isinstance(dec_state, (list, tuple)):
+        dec_state = tuple(_states_to_device(dec_i, device) for dec_i in dec_state)
+
+    return dec_state
+
+
+class AEDGreedyInfer(ABC):
+    def __init__(
+        self,
+        transformer_decoder: torch.nn.Module,
+        log_softmax_module: torch.nn.Module,
+        tokenizer: TokenizerSpec,
+        search_type: str = 'default',
+        preserve_alignments: bool = False,
+    ):
+        super().__init__()
+
+        self.transformer_decoder = transformer_decoder
+        self.log_softmax_module = log_softmax_module
+        self.tokenizer = tokenizer
+        self.search_type = search_type
+
+        self.preserve_alignments = preserve_alignments
+
+    def __call__(self, *args, **kwargs):
+        return self.forward(*args, **kwargs)
+
+    @abstractmethod
+    def forward(
+        self,
+        encoder_hidden_states: torch.Tensor,
+        encoder_input_mask: torch.Tensor,
+        decoder_input_ids: Optional[torch.Tensor] = None,
+        partial_hypotheses: Optional[List[Hypothesis]] = None,
+    ):
+        raise NotImplementedError()
+
+    def set_decoding_type(self, decoding_type: str):
+        self.decoding_type = decoding_type
+
+
+class TransformerAEDGreedyInfer(AEDGreedyInfer, Typing):
+    """
+    A greedy decoder engine for AED Transformer models with support for temperature sampling.
+    """
+
+    @property
+    def input_types(self):
+        """Returns definitions of module input ports.
+        """
+        # Inputs can be of dimension -
+        # ('B', 'T', 'D') [encoder states] or ('B', 'T') [masks / label ids]
+
+        return {
+            "encoder_hidden_states": NeuralType(tuple(('B', 'T', 'D')), ChannelType()),
+            "encoder_input_mask": NeuralType(tuple(('B', 'T')), MaskType()),
+            "decoder_input_ids": NeuralType(('B', 'T'), LabelsType()),
+            "partial_hypotheses": NeuralType(optional=True),
+        }
+
+    @property
+    def output_types(self):
+        """Returns definitions of module output ports.
+        """
+        return {"predictions": [NeuralType(elements_type=HypothesisType())]}
+
+    def __init__(
+        self,
+        transformer_decoder: torch.nn.Module,
+        log_softmax_module: torch.nn.Module,
+        tokenizer: TokenizerSpec,
+        temperature: float | None = None,
+        max_generation_delta: int = 50,
+        preserve_alignments: bool = False,
+    ):
+        super().__init__(
+            transformer_decoder=transformer_decoder,
+            log_softmax_module=log_softmax_module,
+            tokenizer=tokenizer,
+            preserve_alignments=preserve_alignments,
+        )
+        self.temperature = temperature
+        self.greedy_search = GreedySequenceGenerator(
+            embedding=transformer_decoder.embedding,
+            decoder=transformer_decoder.decoder,
+            log_softmax=log_softmax_module,
+            max_sequence_length=transformer_decoder.max_sequence_length,
+            bos=tokenizer.bos_id,
+            pad=tokenizer.pad_id,
+            eos=tokenizer.eos_id,
+            max_delta_length=max_generation_delta,
+            temperature=self.temperature,
+        )
+
+        self.preserve_alignments = preserve_alignments
+        if self.preserve_alignments:
+            logging.info(
+                "Preservation of alignments was requested but {} does not implement it.".format(
+                    self.__class__.__name__
+                )
+            )
+
+    @typecheck()
+    def forward(
+        self,
+        encoder_hidden_states: torch.Tensor,
+        encoder_input_mask: torch.Tensor,
+        decoder_input_ids: Optional[torch.Tensor] = None,
+        partial_hypotheses: Optional[List[Hypothesis]] = None,
+    ):
+        """Returns a list of hypotheses given an input batch of the encoder hidden embedding.
+        Output tokens are generated auto-regressively.
+
+        Args:
+            encoder_hidden_states: A tensor of size (batch, timesteps, features) with the encoder outputs.
+            encoder_input_mask: A tensor of size (batch, timesteps) masking padded encoder frames.
+            decoder_input_ids: An optional tensor of token ids used to prime the decoder.
+
+        Returns:
+            packed list containing batch number of sentences (Hypotheses).
+        """
+        with torch.inference_mode():
+            best_hypo = self.greedy_search(
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_input_mask=encoder_input_mask,
+                decoder_input_ids=decoder_input_ids,
+                return_beam_scores=False,
+            )
+            beam_scores = [None for _ in range(len(best_hypo))]
+            best_hypo = best_hypo.cpu()
+            hypotheses = [
+                Hypothesis(score=0.0, y_sequence=[], timestep=[]) for _ in range(encoder_hidden_states.shape[0])
+            ]
+            # Pack results into Hypotheses
+            packed_result = pack_hypotheses(hypotheses, best_hypo, beam_scores)
+
+        return (packed_result,)
+
+
+@dataclass
+class AEDGreedyInferConfig:
+    temperature: float | None = None
+    max_generation_delta: int = -1  # -1 means up to the max length of the decoder
+    preserve_alignments: bool = False

From b3b23dd7288272efea01a5d9e54168297253c5c6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Piotr=20=C5=BBelasko?=
Date: Wed, 10 Apr 2024 12:20:35 -0400
Subject: [PATCH 2/5] Enable changing multitask decoding strategy
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Piotr Żelasko
---
 .../parts/submodules/multitask_decoding.py    | 33 ++++++++++++-------
 1 file changed, 21 insertions(+), 12 deletions(-)

diff --git a/nemo/collections/asr/parts/submodules/multitask_decoding.py b/nemo/collections/asr/parts/submodules/multitask_decoding.py
index 913c865c9e4b..058fec4ac1bf 100644
--- a/nemo/collections/asr/parts/submodules/multitask_decoding.py
+++ b/nemo/collections/asr/parts/submodules/multitask_decoding.py
@@ -107,36 +107,43 @@ def __init__(
         self.preserve_alignments = self.cfg.get('preserve_alignments', None)
         self.compute_langs = self.cfg.get('compute_langs', False)
         self.compute_hypothesis_token_set = self.cfg.get('compute_hypothesis_token_set', False)
+        self.transformer_decoder = transformer_decoder
+        self.log_softmax_module = log_softmax_module
+        self.tokenizer = tokenizer
+        self.change_strategy(self.cfg.strategy)
+
+    def change_strategy(self, strategy: str) -> "AbstractMultiTaskDecoding":
         possible_strategies = ['greedy', 'greedy_batch', 'beam']
-        if self.cfg.strategy not in possible_strategies:
+        if strategy not in possible_strategies:
             raise ValueError(f"Decoding strategy must be one of {possible_strategies}")
 
+        self.cfg.strategy = strategy
+
         # Update preserve alignments
         if self.preserve_alignments is None:
-            if self.cfg.strategy in ['greedy', 'greedy_batch']:
+            if strategy in ['greedy', 'greedy_batch']:
                 self.preserve_alignments = self.cfg.greedy.get('preserve_alignments', False)
-            elif self.cfg.strategy in ['beam']:
+            elif strategy in ['beam']:
                 self.preserve_alignments = self.cfg.beam.get('preserve_alignments', False)
 
-        if self.cfg.strategy == 'greedy' or self.cfg.strategy == 'greedy_batch':
+        if strategy == 'greedy' or strategy == 'greedy_batch':
 
             self.decoding = TransformerAEDGreedyInfer(
-                transformer_decoder=transformer_decoder,
-                log_softmax_module=log_softmax_module,
-                tokenizer=tokenizer,
+                transformer_decoder=self.transformer_decoder,
+                log_softmax_module=self.log_softmax_module,
+                tokenizer=self.tokenizer,
                 max_generation_delta=self.cfg.greedy.get('max_generation_delta', 50),
                 preserve_alignments=self.preserve_alignments,
                 temperature=self.cfg.greedy.temperature,
             )
 
-        elif self.cfg.strategy == 'beam':
+        elif strategy == 'beam':
 
            self.decoding = TransformerAEDBeamInfer(
-                transformer_decoder=transformer_decoder,
-                log_softmax_module=log_softmax_module,
-                tokenizer=tokenizer,
+                transformer_decoder=self.transformer_decoder,
+                log_softmax_module=self.log_softmax_module,
+                tokenizer=self.tokenizer,
                 search_type=self.cfg.beam.get('search_type', 'default'),
                 beam_size=self.cfg.beam.beam_size,
                 length_penalty=self.cfg.beam.get('length_penalty', 0.0),
@@ -149,9 +156,11 @@ def __init__(
             raise ValueError(
                 f"Incorrect decoding strategy provided. Must be one of {possible_strategies}\n"
-                f"but was provided {self.cfg.strategy}"
+                f"but was provided {strategy}"
             )
 
+        return self
+
     def decode_predictions_tensor(
         self,
         encoder_hidden_states: torch.Tensor,

From e5a82a2b961fe99fee64784f7a94a05a4f6879f1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Piotr=20=C5=BBelasko?=
Date: Wed, 10 Apr 2024 16:50:52 -0400
Subject: [PATCH 3/5] fix various bugs and support temperature sampling n samples
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Piotr Żelasko
---
 examples/asr/transcribe_speech.py             |  3 ++
 .../transformer/transformer_generators.py     | 36 ++++++++++++++-----
 .../parts/submodules/multitask_decoding.py    |  1 +
 .../submodules/multitask_greedy_decoding.py   | 33 ++++++++++-----
 4 files changed, 55 insertions(+), 18 deletions(-)

diff --git a/examples/asr/transcribe_speech.py b/examples/asr/transcribe_speech.py
index c8372c422e7b..944e42ecc409 100644
--- a/examples/asr/transcribe_speech.py
+++ b/examples/asr/transcribe_speech.py
@@ -278,6 +278,9 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis
     # Setup decoding strategy
     if hasattr(asr_model, 'change_decoding_strategy') and hasattr(asr_model, 'decoding'):
         if isinstance(asr_model.decoding, MultiTaskDecoding):
+            cfg.multitask_decoding.strategy = "greedy"
+            cfg.multitask_decoding.greedy.temperature = 0.7
+            cfg.multitask_decoding.greedy.n_samples = 2
             cfg.multitask_decoding.compute_langs = cfg.compute_langs
             cfg.multitask_decoding.preserve_alignments = cfg.preserve_alignment
             if cfg.extract_nbest:
diff --git a/nemo/collections/asr/modules/transformer/transformer_generators.py b/nemo/collections/asr/modules/transformer/transformer_generators.py
index 6066db6681ab..3028dc86a17a 100644
--- a/nemo/collections/asr/modules/transformer/transformer_generators.py
+++ b/nemo/collections/asr/modules/transformer/transformer_generators.py
@@ -59,16 +59,19 @@ def __init__(
         max_sequence_length=512,
         max_delta_length=20,
         batch_size=1,
+        n_samples=1,
         temperature=None,
     ):
         super().__init__()
         self.embedding = embedding
         self.decoder = decoder
         self.log_softmax = log_softmax
+        self.log_softmax.mlp.log_softmax = False
         self.pad, self.bos, self.eos = pad, bos, eos
         self.max_seq_length = max_sequence_length
         self.max_delta_len = max_delta_length
         self.batch_size = batch_size
+        self.n_samples = n_samples
         self.temperature = temperature
 
     def _one_step_forward(
@@ -110,7 +113,7 @@ def _one_step_forward(
         decoder_mems_list = self.decoder.forward(
             decoder_hidden_states, decoder_input_mask, decoder_mems_list, return_mems=True
         )
-        logits = decoder_mems_list[-1][:, -1]
+        logits = self.log_softmax.forward(hidden_states=decoder_mems_list[-1][:, -1:])
         return logits, decoder_mems_list
 
     def _prepare_for_search(self, decoder_input_ids=None, encoder_hidden_states=None):
@@ -148,34 +151,49 @@ def _forward(
         self, decoder_input_ids=None, encoder_hidden_states=None, encoder_input_mask=None, return_beam_scores=False
     ):
         assert not return_beam_scores
+        is_sampling = self.temperature is not None and self.n_samples > 1
+
         tgt, batch_size, max_generation_length = self._prepare_for_search(decoder_input_ids, encoder_hidden_states)
+        if is_sampling:
+            tgt = torch.repeat_interleave(tgt, self.n_samples, dim=0)
+            encoder_hidden_states = torch.repeat_interleave(encoder_hidden_states, self.n_samples, dim=0)
+            encoder_input_mask = torch.repeat_interleave(encoder_input_mask, self.n_samples, dim=0)
+            orig_batch_size = batch_size
+            batch_size = batch_size * self.n_samples
 
         # pad profile tracks sequences ending with <eos> token to replace
         # everything after <eos> with <pad> token
         decoder_parameter = next(self.decoder.parameters())
-        pad_profile = torch.zeros(batch_size, 1).long().to(decoder_parameter.device)
+        pad_profile = torch.zeros(batch_size).long().to(decoder_parameter.device)
 
         decoder_mems_list = None
         for i in range(max_generation_length):
-
+            if i == 0:
+                input_ids = tgt
+            else:
+                input_ids = tgt[:, -1:]
             logits, decoder_mems_list = self._one_step_forward(
-                tgt[:, -1:], encoder_hidden_states, encoder_input_mask, decoder_mems_list, i
+                input_ids, encoder_hidden_states, encoder_input_mask, decoder_mems_list, i
             )
-            # log_probs = self.log_softmax.forward(hidden_states=logits)
             if self.temperature is None:
-                next_tokens = torch.argmax(logits, dim=-1, keepdim=True)
+                next_tokens = torch.argmax(logits[:, -1], dim=-1, keepdim=True)
             else:
-                next_tokens = Categorical(logits=logits / self.temperature).sample()
+                next_tokens = Categorical(logits=logits[:, -1] / self.temperature).sample()
 
             next_tokens = self.pad * pad_profile + next_tokens * (1 - pad_profile)
             pad_profile = torch.max(pad_profile, (next_tokens == self.eos).long())
-            tgt = torch.cat((tgt, next_tokens), dim=-1)
+            tgt = torch.cat((tgt, next_tokens.unsqueeze(1)), dim=-1)
 
             # abort generation if all sequences end with <pad>
             if pad_profile.sum() == batch_size:
                 break
 
-        return tgt
+        samples = None
+        if is_sampling:
+            samples = list(tgt.view(orig_batch_size, self.n_samples, -1))
+            tgt = tgt[:: self.n_samples]
+
+        return tgt, samples
 
     def __call__(
         self, decoder_input_ids=None, encoder_hidden_states=None, encoder_input_mask=None, return_beam_scores=False
diff --git a/nemo/collections/asr/parts/submodules/multitask_decoding.py b/nemo/collections/asr/parts/submodules/multitask_decoding.py
index 058fec4ac1bf..f03469e0f5d3 100644
--- a/nemo/collections/asr/parts/submodules/multitask_decoding.py
+++ b/nemo/collections/asr/parts/submodules/multitask_decoding.py
@@ -136,6 +136,7 @@ def change_strategy(self, strategy: str) -> "AbstractMultiTaskDecoding":
                 max_generation_delta=self.cfg.greedy.get('max_generation_delta', 50),
                 preserve_alignments=self.preserve_alignments,
                 temperature=self.cfg.greedy.temperature,
+                n_samples=self.cfg.greedy.n_samples,
             )
 
         elif strategy == 'beam':
diff --git a/nemo/collections/asr/parts/submodules/multitask_greedy_decoding.py b/nemo/collections/asr/parts/submodules/multitask_greedy_decoding.py
index 3798faeb95b5..4701dab5548b 100644
--- a/nemo/collections/asr/parts/submodules/multitask_greedy_decoding.py
+++ b/nemo/collections/asr/parts/submodules/multitask_greedy_decoding.py
@@ -124,6 +124,7 @@ def __init__(
         temperature: float | None = None,
         max_generation_delta: int = 50,
         preserve_alignments: bool = False,
+        n_samples: int = 1,
     ):
         super().__init__(
             transformer_decoder=transformer_decoder,
@@ -132,6 +133,7 @@ def __init__(
             preserve_alignments=preserve_alignments,
         )
         self.temperature = temperature
+        self.n_samples = n_samples
         self.greedy_search = GreedySequenceGenerator(
             embedding=transformer_decoder.embedding,
             decoder=transformer_decoder.decoder,
@@ -142,6 +144,7 @@ def __init__(
             eos=tokenizer.eos_id,
             max_delta_length=max_generation_delta,
             temperature=self.temperature,
+            n_samples=n_samples,
         )
 
         self.preserve_alignments = preserve_alignments
@@ -172,19 +175,30 @@ def forward(
             packed list containing batch number of sentences (Hypotheses).
""" with torch.inference_mode(): - best_hypo = self.greedy_search( + best_hypo, topk_hypotheses = self.greedy_search( encoder_hidden_states=encoder_hidden_states, encoder_input_mask=encoder_input_mask, decoder_input_ids=decoder_input_ids, - return_beam_scores=False, ) - beam_scores = [None for _ in range(len(best_hypo))] - best_hypo = best_hypo.cpu() - hypotheses = [ - Hypothesis(score=0.0, y_sequence=[], timestep=[]) for _ in range(encoder_hidden_states.shape[0]) - ] - # Pack results into Hypotheses - packed_result = pack_hypotheses(hypotheses, best_hypo, beam_scores) + + if topk_hypotheses is not None: + topk_hypotheses = [x.detach().cpu() for x in topk_hypotheses] # each item is [beam, seq_len] + beam_scores = [[None] * self.n_samples for _ in topk_hypotheses] # each item is [beam,] + packed_result = [] + for i in range(len(topk_hypotheses)): + hypotheses = [Hypothesis(score=0.0, y_sequence=[], timestep=[]) for _ in range(self.n_samples)] + # Pack results into Hypotheses + packed_result.append( + NBestHypotheses(pack_hypotheses(hypotheses, topk_hypotheses[i], beam_scores[i])) + ) + else: + beam_scores = [None for _ in range(len(best_hypo))] + best_hypo = best_hypo.cpu() + hypotheses = [ + Hypothesis(score=0.0, y_sequence=[], timestep=[]) for _ in range(encoder_hidden_states.shape[0]) + ] + # Pack results into Hypotheses + packed_result = pack_hypotheses(hypotheses, best_hypo, beam_scores) return (packed_result,) @@ -194,3 +208,4 @@ class AEDGreedyInferConfig: temperature: float | None = None max_generation_delta: int = -1 # -1 means up to the max length of the decoder preserve_alignments: bool = False + n_samples: int = 1 From 5ad0e695cd4584066b0adad192073d46f548bab6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Wed, 10 Apr 2024 17:11:34 -0400 Subject: [PATCH 4/5] fix greedy non-temperature decoding MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- examples/asr/transcribe_speech.py | 3 --- .../asr/modules/transformer/transformer_generators.py | 2 +- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/examples/asr/transcribe_speech.py b/examples/asr/transcribe_speech.py index 944e42ecc409..c8372c422e7b 100644 --- a/examples/asr/transcribe_speech.py +++ b/examples/asr/transcribe_speech.py @@ -278,9 +278,6 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis # Setup decoding strategy if hasattr(asr_model, 'change_decoding_strategy') and hasattr(asr_model, 'decoding'): if isinstance(asr_model.decoding, MultiTaskDecoding): - cfg.multitask_decoding.strategy = "greedy" - cfg.multitask_decoding.greedy.temperature = 0.7 - cfg.multitask_decoding.greedy.n_samples = 2 cfg.multitask_decoding.compute_langs = cfg.compute_langs cfg.multitask_decoding.preserve_alignments = cfg.preserve_alignment if cfg.extract_nbest: diff --git a/nemo/collections/asr/modules/transformer/transformer_generators.py b/nemo/collections/asr/modules/transformer/transformer_generators.py index 3028dc86a17a..bba13a05f25b 100644 --- a/nemo/collections/asr/modules/transformer/transformer_generators.py +++ b/nemo/collections/asr/modules/transformer/transformer_generators.py @@ -176,7 +176,7 @@ def _forward( input_ids, encoder_hidden_states, encoder_input_mask, decoder_mems_list, i ) if self.temperature is None: - next_tokens = torch.argmax(logits[:, -1], dim=-1, keepdim=True) + next_tokens = torch.argmax(logits[:, -1], dim=-1) else: next_tokens = Categorical(logits=logits[:, -1] / 
self.temperature).sample() From 8e6b2204336e76c48fb1704c506d2eb2807661e1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Fri, 12 Apr 2024 15:41:19 -0400 Subject: [PATCH 5/5] Refactor + unit tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- .../transformer/transformer_generators.py | 21 ++-- .../asr/parts/submodules/token_classifier.py | 41 +++++-- .../common/parts/multi_layer_perceptron.py | 6 +- .../asr/decoding/test_aed_decoding.py | 107 ++++++++++++++++++ 4 files changed, 151 insertions(+), 24 deletions(-) create mode 100644 tests/collections/asr/decoding/test_aed_decoding.py diff --git a/nemo/collections/asr/modules/transformer/transformer_generators.py b/nemo/collections/asr/modules/transformer/transformer_generators.py index bba13a05f25b..b78c473a956e 100644 --- a/nemo/collections/asr/modules/transformer/transformer_generators.py +++ b/nemo/collections/asr/modules/transformer/transformer_generators.py @@ -35,8 +35,8 @@ class GreedySequenceGenerator: Args: embedding: nn.Module, transforms input_ids into vector embeddings decoder: nn.Module, takes embeddings and produces hidden_states - log_softmax: nn.Module, takes hidden_states and produces log_probs - which correspond to probability distribution of tokens (ids) + classifier: nn.Module, takes hidden_states and produces + logits or log-probability distribution of tokens (ids) pad: index of padding token in the vocabulary bos: index of beginning of sequence token in the vocabulary eos: index of end of sequence token in the vocabulary @@ -52,7 +52,7 @@ def __init__( self, embedding, decoder, - log_softmax, + classifier, pad=0, bos=1, eos=2, @@ -65,8 +65,7 @@ def __init__( super().__init__() self.embedding = embedding self.decoder = decoder - self.log_softmax = log_softmax - self.log_softmax.mlp.log_softmax = False + self.classifier = classifier.set_log_softmax_enabled(False) self.pad, self.bos, self.eos = pad, bos, eos self.max_seq_length = max_sequence_length self.max_delta_len = max_delta_length @@ -113,7 +112,7 @@ def _one_step_forward( decoder_mems_list = self.decoder.forward( decoder_hidden_states, decoder_input_mask, decoder_mems_list, return_mems=True ) - logits = self.log_softmax.forward(hidden_states=decoder_mems_list[-1][:, -1:]) + logits = self.classifier.forward(hidden_states=decoder_mems_list[-1][:, -1:], temperature=self.temperature) return logits, decoder_mems_list def _prepare_for_search(self, decoder_input_ids=None, encoder_hidden_states=None): @@ -221,9 +220,9 @@ def freeze(self) -> None: for param in self.decoder.parameters(): param.requires_grad = False self.decoder.eval() - for param in self.log_softmax.parameters(): + for param in self.classifier.parameters(): param.requires_grad = False - self.log_softmax.eval() + self.classifier.eval() def unfreeze(self) -> None: """Unfreeze weights of embedding, decoder, and classification layers. @@ -234,14 +233,14 @@ def unfreeze(self) -> None: for param in self.decoder.parameters(): param.requires_grad = True self.decoder.train() - for param in self.log_softmax.parameters(): + for param in self.classifier.parameters(): param.requires_grad = True - self.log_softmax.train() + self.classifier.train() @contextmanager def as_frozen(self): """ - Context manager which temporarily freezes embedding, decoder, and log_softmax modules, + Context manager which temporarily freezes embedding, decoder, and classifier modules, yields control and finally unfreezes the modules. 
""" self.freeze() diff --git a/nemo/collections/asr/parts/submodules/token_classifier.py b/nemo/collections/asr/parts/submodules/token_classifier.py index 4061d19d9015..3a75da5ef0cb 100644 --- a/nemo/collections/asr/parts/submodules/token_classifier.py +++ b/nemo/collections/asr/parts/submodules/token_classifier.py @@ -15,12 +15,13 @@ from dataclasses import dataclass from typing import Dict, Optional +import torch from torch import nn as nn from nemo.collections.asr.parts.submodules.classifier import Classifier from nemo.collections.common.parts import MultiLayerPerceptron from nemo.core.classes import typecheck -from nemo.core.neural_types import LogitsType, LogprobsType, NeuralType +from nemo.core.neural_types import ChannelType, FloatType, LogitsType, LogprobsType, NeuralType __all__ = ['BertPretrainingTokenClassifier', 'TokenClassifier'] @@ -42,10 +43,14 @@ class TokenClassifier(Classifier): """ @property - def output_types(self) -> Optional[Dict[str, NeuralType]]: - """ - Returns definitions of module output ports. - """ + def input_types(self) -> Dict[str, NeuralType]: + return { + "hidden_states": NeuralType(('B', 'T', 'D'), ChannelType()), + "temperature": NeuralType(None, FloatType(), optional=True), + } + + @property + def output_types(self) -> Dict[str, NeuralType]: if not self.log_softmax: return {"logits": NeuralType(('B', 'T', 'C'), LogitsType())} else: @@ -81,8 +86,12 @@ def __init__( ) self.post_init(use_transformer_init=use_transformer_init) + def set_log_softmax_enabled(self, value: bool) -> "TokenClassifier": + self.log_softmax = value + return self + @typecheck() - def forward(self, hidden_states): + def forward(self, hidden_states: torch.Tensor, temperature: float | None = None) -> torch.Tensor: """ Performs the forward step of the module. Args: @@ -91,7 +100,7 @@ def forward(self, hidden_states): Returns: logits value for each class [BATCH_SIZE x SEQ_LENGTH x NUM_CLASSES] """ hidden_states = self.dropout(hidden_states) - logits = self.mlp(hidden_states) + logits = self.mlp(hidden_states, temperature=temperature) return logits @@ -100,11 +109,15 @@ class BertPretrainingTokenClassifier(Classifier): A module to perform token level classification tasks for Bert pretraining. """ + @property + def input_types(self) -> Dict[str, NeuralType]: + return { + "hidden_states": NeuralType(('B', 'T', 'D'), ChannelType()), + "temperature": NeuralType(None, FloatType(), optional=True), + } + @property def output_types(self) -> Optional[Dict[str, NeuralType]]: - """ - Returns definitions of module output ports. - """ if not self.log_softmax: return {"logits": NeuralType(('B', 'T', 'C'), LogitsType())} else: @@ -147,8 +160,12 @@ def __init__( ) self.post_init(use_transformer_init=use_transformer_init) + def set_log_softmax_enabled(self, value: bool) -> "BertPretrainingTokenClassifier": + self.log_softmax = value + return self + @typecheck() - def forward(self, hidden_states): + def forward(self, hidden_states: torch.Tensor, temperature: float | None = None) -> torch.Tensor: """ Performs the forward step of the module. 
         Args:
@@ -160,5 +177,5 @@ def forward(self, hidden_states):
         hidden_states = self.dense(hidden_states)
         hidden_states = self.act(hidden_states)
         transform = self.norm(hidden_states)
-        logits = self.mlp(transform)
+        logits = self.mlp(transform, temperature=temperature)
         return logits
diff --git a/nemo/collections/common/parts/multi_layer_perceptron.py b/nemo/collections/common/parts/multi_layer_perceptron.py
index 76c06bf23ea6..5110406fedfd 100644
--- a/nemo/collections/common/parts/multi_layer_perceptron.py
+++ b/nemo/collections/common/parts/multi_layer_perceptron.py
@@ -51,11 +51,15 @@ def __init__(
     def last_linear_layer(self):
         return getattr(self, f'layer{self.layers - 1}')
 
-    def forward(self, hidden_states):
+    def forward(self, hidden_states, temperature: float | None = None):
         output_states = hidden_states[:]
         for i in range(self.layers):
             output_states = getattr(self, f'layer{i}')(output_states)
 
+        if temperature is not None:
+            output_states = output_states / temperature
+
         if self.log_softmax:
             output_states = torch.log_softmax(output_states, dim=-1)
+
         return output_states
diff --git a/tests/collections/asr/decoding/test_aed_decoding.py b/tests/collections/asr/decoding/test_aed_decoding.py
new file mode 100644
index 000000000000..53571e77e506
--- /dev/null
+++ b/tests/collections/asr/decoding/test_aed_decoding.py
@@ -0,0 +1,107 @@
+import pytest
+import torch
+
+from nemo.collections.asr.modules.transformer import TransformerDecoder, TransformerEmbedding
+from nemo.collections.asr.modules.transformer.transformer_generators import (
+    BeamSearchSequenceGenerator,
+    GreedySequenceGenerator,
+)
+from nemo.collections.asr.parts.submodules.token_classifier import TokenClassifier
+
+
+@pytest.fixture()
+def deterministic_rng():
+    state = torch.get_rng_state()
+    torch.manual_seed(0)
+    yield
+    torch.set_rng_state(state)
+
+
+@pytest.fixture()
+def nnet(deterministic_rng):
+    ans = (
+        TransformerEmbedding(vocab_size=8, hidden_size=2, max_sequence_length=32),
+        TransformerDecoder(num_layers=1, hidden_size=2, inner_size=4),
+        TokenClassifier(hidden_size=2, num_classes=8),
+    )
+    ans = tuple(m.eval() for m in ans)
+    return ans
+
+
+@pytest.fixture()
+def inputs():
+    B, T, C = 1, 5, 2
+    return (
+        torch.tensor([[1]], dtype=torch.long),  # decoder_input_ids
+        torch.ones(B, T, C, dtype=torch.float),  # encoder_hidden_states
+        torch.ones(B, T, dtype=torch.float),  # encoder_input_mask
+    )
+
+
+def test_greedy_decoding(inputs, nnet):
+    gen = GreedySequenceGenerator(*nnet)
+    output = gen(*inputs)
+
+    assert len(output) == 2
+    best_path, hypotheses = output
+
+    assert best_path is not None
+    assert torch.is_tensor(best_path)
+    assert best_path.shape == (1, 25)
+
+    assert hypotheses is None
+
+
+def test_temperature_sampling_decoding(inputs, nnet):
+    gen = GreedySequenceGenerator(*nnet, temperature=10.0, n_samples=2)
+    output = gen(*inputs)
+
+    assert len(output) == 2
+    best_path, hypotheses = output
+
+    assert best_path is not None
+    assert torch.is_tensor(best_path)
+    assert best_path.shape == (1, 25)
+
+    assert isinstance(hypotheses, list)
+    assert len(hypotheses) == 1
+    (seq0,) = hypotheses
+    assert seq0.shape == (2, 25)
+
+
+def test_beam_decoding_beam_scores_false(inputs, nnet):
+    gen = BeamSearchSequenceGenerator(*nnet, beam_size=2)
+    output = gen(*inputs, return_beam_scores=False)
+
+    assert len(output) == 1
+    (best_path,) = output
+
+    assert best_path is not None
+    assert torch.is_tensor(best_path)
+    assert best_path.shape == (26,)
+
+
+def test_beam_decoding_beam_scores_true(inputs, nnet):
+    gen = BeamSearchSequenceGenerator(*nnet, beam_size=2)
+    output = gen(*inputs, return_beam_scores=True)
+
+    assert len(output) == 3
+    beam_paths, scores, best_path = output
+
+    assert beam_paths is not None
+    assert isinstance(beam_paths, list)
+    assert len(beam_paths) == 1
+    (beam_paths_seq0,) = beam_paths
+    assert torch.is_tensor(beam_paths_seq0)
+    assert beam_paths_seq0.shape == (2, 26)
+
+    assert scores is not None
+    assert isinstance(scores, list)
+    assert len(scores) == 1
+    (scores_seq0,) = scores
+    assert torch.is_tensor(scores_seq0)
+    assert scores_seq0.shape == (2,)
+
+    assert best_path is not None
+    assert torch.is_tensor(best_path)
+    assert best_path.shape == (1, 26)
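
---

Taken together, these five patches expose temperature sampling and multi-sample decoding for Canary/multi-task models through the decoding config. Below is a minimal usage sketch of the API the series introduces; it assumes the usual NeMo workflow — the `nvidia/canary-1b` checkpoint name, the `model.cfg.decoding` attribute path, and the `transcribe()` call are assumptions for illustration and are not part of this diff.

```python
# Minimal sketch: driving the decoding options added in this series.
# Assumptions (not in the diff): the "nvidia/canary-1b" checkpoint name,
# the cfg.decoding attribute path, and an "audio.wav" input file.
from nemo.collections.asr.models import EncDecMultiTaskModel

model = EncDecMultiTaskModel.from_pretrained("nvidia/canary-1b")

decoding_cfg = model.cfg.decoding
decoding_cfg.strategy = "greedy"                # routes to TransformerAEDGreedyInfer
decoding_cfg.greedy.temperature = 0.7           # None would fall back to pure argmax decoding
decoding_cfg.greedy.n_samples = 2               # >1 wraps each utterance in NBestHypotheses
decoding_cfg.greedy.max_generation_delta = -1   # -1: generate up to the decoder's max length
model.change_decoding_strategy(decoding_cfg)    # re-runs AbstractMultiTaskDecoding.change_strategy

hypotheses = model.transcribe(["audio.wav"], batch_size=1)
```

Internally, with `temperature=None` the generator takes an argmax over raw classifier logits (log-softmax is disabled via `set_log_softmax_enabled(False)`, since it would not change the argmax anyway); with a temperature set, the logits are divided by it before sampling from a `Categorical` distribution. With `n_samples > 1`, the inputs are repeat-interleaved along the batch dimension so that all samples are drawn in a single decoding pass.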