Canary greedy and temperature decoding #8885

Closed
wants to merge 5 commits
60 changes: 42 additions & 18 deletions nemo/collections/asr/modules/transformer/transformer_generators.py
@@ -15,6 +15,7 @@
from contextlib import contextmanager

import torch
from torch.distributions import Categorical

from nemo.collections.common.parts import NEG_INF, mask_padded_tokens

@@ -34,8 +35,8 @@ class GreedySequenceGenerator:
Args:
embedding: nn.Module, transforms input_ids into vector embeddings
decoder: nn.Module, takes embeddings and produces hidden_states
log_softmax: nn.Module, takes hidden_states and produces log_probs
which correspond to probability distribution of tokens (ids)
classifier: nn.Module, takes hidden_states and produces
logits or log-probability distribution of tokens (ids)
pad: index of padding token in the vocabulary
bos: index of beginning of sequence token in the vocabulary
eos: index of end of sequence token in the vocabulary
@@ -51,22 +52,26 @@ def __init__(
self,
embedding,
decoder,
log_softmax,
classifier,
pad=0,
bos=1,
eos=2,
max_sequence_length=512,
max_delta_length=20,
batch_size=1,
n_samples=1,
temperature=None,
):
super().__init__()
self.embedding = embedding
self.decoder = decoder
self.log_softmax = log_softmax
self.classifier = classifier.set_log_softmax_enabled(False)
Review comment (Collaborator): Shall we add a check to see if the classifier has the set_log_softmax_enabled() function, and if not, we default to not using temperature sampling and print out a warning?
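A minimal sketch of the suggested guard, assuming the classifier may or may not expose set_log_softmax_enabled(); the helper name configure_classifier and the fallback behaviour are illustrative, not part of this PR:

import logging

def configure_classifier(classifier, temperature):
    # Hypothetical guard: disable the classifier's built-in log-softmax only
    # when it supports doing so; otherwise warn and fall back to greedy decoding.
    if hasattr(classifier, "set_log_softmax_enabled"):
        return classifier.set_log_softmax_enabled(False), temperature
    logging.warning(
        "Classifier has no set_log_softmax_enabled(); temperature sampling "
        "is disabled and greedy decoding will be used."
    )
    return classifier, None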

self.pad, self.bos, self.eos = pad, bos, eos
self.max_seq_length = max_sequence_length
self.max_delta_len = max_delta_length
self.batch_size = batch_size
self.n_samples = n_samples
self.temperature = temperature

def _one_step_forward(
self,
@@ -107,8 +112,8 @@ def _one_step_forward(
decoder_mems_list = self.decoder.forward(
decoder_hidden_states, decoder_input_mask, decoder_mems_list, return_mems=True
)
log_probs = self.log_softmax.forward(hidden_states=decoder_mems_list[-1][:, -1:])
return log_probs, decoder_mems_list
logits = self.classifier.forward(hidden_states=decoder_mems_list[-1][:, -1:], temperature=self.temperature)
Review comment (Collaborator): Shall we add a check to see if the forward function has a temperature argument?
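One way to make that check, sketched with inspect.signature; the helper name call_classifier and the no-temperature fallback are assumptions for illustration:

import inspect

def call_classifier(classifier, hidden_states, temperature):
    # Hypothetical guard: pass `temperature` only if the classifier's forward()
    # accepts it; otherwise call it with hidden_states alone.
    if "temperature" in inspect.signature(classifier.forward).parameters:
        return classifier.forward(hidden_states=hidden_states, temperature=temperature)
    return classifier.forward(hidden_states=hidden_states)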

return logits, decoder_mems_list

def _prepare_for_search(self, decoder_input_ids=None, encoder_hidden_states=None):
"""
@@ -145,30 +150,49 @@ def _forward(
self, decoder_input_ids=None, encoder_hidden_states=None, encoder_input_mask=None, return_beam_scores=False
):
assert not return_beam_scores
is_sampling = self.temperature is not None and self.n_samples > 1

tgt, batch_size, max_generation_length = self._prepare_for_search(decoder_input_ids, encoder_hidden_states)
if is_sampling:
tgt = torch.repeat_interleave(tgt, self.n_samples, dim=0)
encoder_hidden_states = torch.repeat_interleave(encoder_hidden_states, self.n_samples, dim=0)
encoder_input_mask = torch.repeat_interleave(encoder_input_mask, self.n_samples, dim=0)
orig_batch_size = batch_size
batch_size = batch_size * self.n_samples

# pad profile tracks sequences ending with <eos> token to replace
# everything after <eos> with <pad> token
decoder_parameter = next(self.decoder.parameters())
pad_profile = torch.zeros(batch_size, 1).long().to(decoder_parameter.device)
pad_profile = torch.zeros(batch_size).long().to(decoder_parameter.device)

decoder_mems_list = None
for i in range(max_generation_length):

log_probs, decoder_mems_list = self._one_step_forward(
tgt[:, -1:], encoder_hidden_states, encoder_input_mask, decoder_mems_list, i
if i == 0:
input_ids = tgt
else:
input_ids = tgt[:, -1:]
logits, decoder_mems_list = self._one_step_forward(
input_ids, encoder_hidden_states, encoder_input_mask, decoder_mems_list, i
)
if self.temperature is None:
next_tokens = torch.argmax(logits[:, -1], dim=-1)
else:
next_tokens = Categorical(logits=logits[:, -1] / self.temperature).sample()

next_tokens = torch.argmax(log_probs[:, -1], dim=-1, keepdim=True)
next_tokens = self.pad * pad_profile + next_tokens * (1 - pad_profile)
pad_profile = torch.max(pad_profile, (next_tokens == self.eos).long())
tgt = torch.cat((tgt, next_tokens), dim=-1)
tgt = torch.cat((tgt, next_tokens.unsqueeze(1)), dim=-1)

# abort generation if all sequences end with <eos>
if pad_profile.sum() == batch_size:
break

return tgt
samples = None
if is_sampling:
samples = list(tgt.view(orig_batch_size, self.n_samples, -1))
tgt = tgt[:: self.n_samples]

return tgt, samples
Review comment (Collaborator): Since we're adding an additional output (samples), do we also need to update the __call__ function where the _forward function is called?
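A rough sketch of how __call__ could be adapted to the new (tgt, samples) return value; whether __call__ should expose samples to callers, and in what form, is an open question, so treat this as illustrative only:

def __call__(
    self, decoder_input_ids=None, encoder_hidden_states=None, encoder_input_mask=None, return_beam_scores=False
):
    # _forward now returns a tuple, so unpack it instead of treating the
    # result as a single tensor; keep the old behaviour when sampling is off.
    best_sequences, samples = self._forward(
        decoder_input_ids, encoder_hidden_states, encoder_input_mask, return_beam_scores
    )
    if samples is not None:
        return best_sequences, samples
    return best_sequences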


def __call__(
self, decoder_input_ids=None, encoder_hidden_states=None, encoder_input_mask=None, return_beam_scores=False
@@ -196,9 +220,9 @@ def freeze(self) -> None:
for param in self.decoder.parameters():
param.requires_grad = False
self.decoder.eval()
for param in self.log_softmax.parameters():
for param in self.classifier.parameters():
param.requires_grad = False
self.log_softmax.eval()
self.classifier.eval()

def unfreeze(self) -> None:
"""Unfreeze weights of embedding, decoder, and classification layers.
@@ -209,14 +233,14 @@ def unfreeze(self) -> None:
for param in self.decoder.parameters():
param.requires_grad = True
self.decoder.train()
for param in self.log_softmax.parameters():
for param in self.classifier.parameters():
param.requires_grad = True
self.log_softmax.train()
self.classifier.train()

@contextmanager
def as_frozen(self):
"""
Context manager which temporarily freezes embedding, decoder, and log_softmax modules,
Context manager which temporarily freezes embedding, decoder, and classifier modules,
yields control and finally unfreezes the modules.
"""
self.freeze()
56 changes: 37 additions & 19 deletions nemo/collections/asr/parts/submodules/multitask_decoding.py
@@ -25,6 +25,10 @@
AEDBeamInferConfig,
TransformerAEDBeamInfer,
)
from nemo.collections.asr.parts.submodules.multitask_greedy_decoding import (
AEDGreedyInferConfig,
TransformerAEDGreedyInfer,
)
from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis, NBestHypotheses
from nemo.collections.common.tokenizers.aggregate_tokenizer import AggregateTokenizer
from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
@@ -60,11 +64,11 @@

The config may further contain the following sub-dictionaries:
"greedy":
max_symbols: int, describing the maximum number of target tokens to decode per
timestep during greedy decoding. Setting to larger values allows longer sentences
to be decoded, at the cost of increased execution time.
preserve_frame_confidence: Same as above, overrides above value.
confidence_method_cfg: Same as above, overrides confidence_cfg.method_cfg.
temperature: None (disabled) or float; setting a float enables temperature sampling instead of greedy decoding.

max_generation_delta: int = -1 # -1 means up to the max length of the decoder

preserve_alignments: bool = False (unsupported)

"beam":
beam_size: int, defining the beam size for beam search. Must be >= 1.
Expand Down Expand Up @@ -103,30 +107,44 @@
self.preserve_alignments = self.cfg.get('preserve_alignments', None)
self.compute_langs = self.cfg.get('compute_langs', False)
self.compute_hypothesis_token_set = self.cfg.get('compute_hypothesis_token_set', False)
self.transformer_decoder = transformer_decoder
self.log_softmax_module = log_softmax_module
self.tokenizer = tokenizer

CodeQL warning (code scanning): Overwriting attribute in super-class or sub-class. Assignment overwrites attribute tokenizer, which was previously defined in subclass MultiTaskDecoding.
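The pattern CodeQL is flagging, reduced to a generic example; the class names here are illustrative and not taken from the NeMo code base:

class Base:
    def __init__(self, tokenizer):
        # Base-class assignment; this is what overwrites an attribute that a
        # subclass has already set.
        self.tokenizer = tokenizer

class Sub(Base):
    def __init__(self, tokenizer):
        self.tokenizer = ("wrapped", tokenizer)  # subclass-specific value (illustrative)
        super().__init__(tokenizer)              # silently replaces the attribute above

s = Sub("bpe_tokenizer")
print(s.tokenizer)  # prints 'bpe_tokenizer', not ('wrapped', 'bpe_tokenizer')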
self.change_strategy(self.cfg.strategy)

def change_strategy(self, strategy: str) -> "AbstractMultiTaskDecoding":
possible_strategies = ['greedy', 'greedy_batch', 'beam']
if self.cfg.strategy not in possible_strategies:
if strategy not in possible_strategies:
raise ValueError(f"Decoding strategy must be one of {possible_strategies}")

self.cfg.strategy = strategy

# Update preserve alignments
if self.preserve_alignments is None:
if self.cfg.strategy in ['greedy', 'greedy_batch']:
if strategy in ['greedy', 'greedy_batch']:
self.preserve_alignments = self.cfg.greedy.get('preserve_alignments', False)

elif self.cfg.strategy in ['beam']:
elif strategy in ['beam']:
self.preserve_alignments = self.cfg.beam.get('preserve_alignments', False)

if self.cfg.strategy == 'greedy' or self.cfg.strategy == 'greedy_batch':
if strategy == 'greedy' or strategy == 'greedy_batch':

# self.decoding = None
raise NotImplementedError("Greedy decoding is not implemented yet.")
self.decoding = TransformerAEDGreedyInfer(
transformer_decoder=self.transformer_decoder,
log_softmax_module=self.log_softmax_module,
tokenizer=self.tokenizer,
max_generation_delta=self.cfg.greedy.get('max_generation_delta', 50),
preserve_alignments=self.preserve_alignments,
temperature=self.cfg.greedy.temperature,
n_samples=self.cfg.greedy.n_samples,
)

elif self.cfg.strategy == 'beam':
elif strategy == 'beam':

self.decoding = TransformerAEDBeamInfer(
transformer_decoder=transformer_decoder,
log_softmax_module=log_softmax_module,
tokenizer=tokenizer,
transformer_decoder=self.transformer_decoder,
log_softmax_module=self.log_softmax_module,
tokenizer=self.tokenizer,
search_type=self.cfg.beam.get('search_type', 'default'),
beam_size=self.cfg.beam.beam_size,
length_penalty=self.cfg.beam.get('length_penalty', 0.0),
@@ -139,9 +157,11 @@

raise ValueError(
f"Incorrect decoding strategy provided. Must be one of {possible_strategies}\n"
f"but was provided {self.cfg.strategy}"
f"but was provided {strategy}"
)

return self
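Because change_strategy now returns self, the strategy can be switched in place at inference time; a usage sketch, assuming `decoding` is an already-built AbstractMultiTaskDecoding instance whose config contains the greedy and beam sub-configs from this diff:

decoding.cfg.greedy.temperature = 0.7  # enable temperature sampling
decoding.cfg.greedy.n_samples = 4      # draw several hypotheses per utterance
decoding = decoding.change_strategy("greedy")

# Switch back to beam search later without rebuilding the decoding object.
decoding = decoding.change_strategy("beam")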

def decode_predictions_tensor(
self,
encoder_hidden_states: torch.Tensor,
@@ -476,9 +496,7 @@
compute_langs: bool = False

# greedy decoding config
# greedy: rnnt_greedy_decoding.GreedyBatchedRNNTInferConfig = field(
# default_factory=rnnt_greedy_decoding.GreedyBatchedRNNTInferConfig
# )
greedy: AEDGreedyInferConfig = field(default_factory=AEDGreedyInferConfig)

# beam decoding config
beam: AEDBeamInferConfig = field(default_factory=lambda: AEDBeamInferConfig(beam_size=1))
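A minimal sketch of a decoding config with temperature sampling enabled; the keys mirror this diff, but any AEDGreedyInferConfig fields beyond those shown are assumptions:

from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "strategy": "greedy",
        "greedy": {
            "temperature": 0.7,          # None falls back to pure greedy decoding
            "n_samples": 4,              # sampled hypotheses per utterance
            "max_generation_delta": -1,  # -1 means up to the decoder's max length
            "preserve_alignments": False,
        },
        "beam": {"beam_size": 1},
    }
)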