diff --git a/nemo/collections/asr/modules/transformer/transformer_generators.py b/nemo/collections/asr/modules/transformer/transformer_generators.py
index 4061f54a907a..b78c473a956e 100644
--- a/nemo/collections/asr/modules/transformer/transformer_generators.py
+++ b/nemo/collections/asr/modules/transformer/transformer_generators.py
@@ -15,6 +15,7 @@
 from contextlib import contextmanager
 
 import torch
+from torch.distributions import Categorical
 
 from nemo.collections.common.parts import NEG_INF, mask_padded_tokens
 
@@ -34,8 +35,8 @@ class GreedySequenceGenerator:
     Args:
         embedding: nn.Module, transforms input_ids into vector embeddings
         decoder: nn.Module, takes embeddings and produces hidden_states
-        log_softmax: nn.Module, takes hidden_states and produces log_probs
-            which correspond to probability distribution of tokens (ids)
+        classifier: nn.Module, takes hidden_states and produces
+            logits or log-probability distribution of tokens (ids)
        pad: index of padding token in the vocabulary
        bos: index of beginning of sequence token in the vocabulary
        eos: index of end of sequence token in the vocabulary
@@ -51,22 +52,26 @@ def __init__(
         self,
         embedding,
         decoder,
-        log_softmax,
+        classifier,
         pad=0,
         bos=1,
         eos=2,
         max_sequence_length=512,
         max_delta_length=20,
         batch_size=1,
+        n_samples=1,
+        temperature=None,
     ):
         super().__init__()
         self.embedding = embedding
         self.decoder = decoder
-        self.log_softmax = log_softmax
+        self.classifier = classifier.set_log_softmax_enabled(False)
         self.pad, self.bos, self.eos = pad, bos, eos
         self.max_seq_length = max_sequence_length
         self.max_delta_len = max_delta_length
         self.batch_size = batch_size
+        self.n_samples = n_samples
+        self.temperature = temperature
 
     def _one_step_forward(
         self,
@@ -107,8 +112,8 @@ def _one_step_forward(
         decoder_mems_list = self.decoder.forward(
             decoder_hidden_states, decoder_input_mask, decoder_mems_list, return_mems=True
         )
-        log_probs = self.log_softmax.forward(hidden_states=decoder_mems_list[-1][:, -1:])
-        return log_probs, decoder_mems_list
+        logits = self.classifier.forward(hidden_states=decoder_mems_list[-1][:, -1:], temperature=self.temperature)
+        return logits, decoder_mems_list
 
     def _prepare_for_search(self, decoder_input_ids=None, encoder_hidden_states=None):
         """
@@ -145,30 +150,49 @@ def _forward(
         self, decoder_input_ids=None, encoder_hidden_states=None, encoder_input_mask=None, return_beam_scores=False
     ):
         assert not return_beam_scores
+        is_sampling = self.temperature is not None and self.n_samples > 1
+
         tgt, batch_size, max_generation_length = self._prepare_for_search(decoder_input_ids, encoder_hidden_states)
+        if is_sampling:
+            tgt = torch.repeat_interleave(tgt, self.n_samples, dim=0)
+            encoder_hidden_states = torch.repeat_interleave(encoder_hidden_states, self.n_samples, dim=0)
+            encoder_input_mask = torch.repeat_interleave(encoder_input_mask, self.n_samples, dim=0)
+            orig_batch_size = batch_size
+            batch_size = batch_size * self.n_samples
 
         # pad profile tracks sequences ending with <eos> token to replace
        # everything after <eos> with <pad> token
         decoder_parameter = next(self.decoder.parameters())
-        pad_profile = torch.zeros(batch_size, 1).long().to(decoder_parameter.device)
+        pad_profile = torch.zeros(batch_size).long().to(decoder_parameter.device)
 
         decoder_mems_list = None
         for i in range(max_generation_length):
-
-            log_probs, decoder_mems_list = self._one_step_forward(
-                tgt[:, -1:], encoder_hidden_states, encoder_input_mask, decoder_mems_list, i
+            if i == 0:
+                input_ids = tgt
+            else:
+                input_ids = tgt[:, -1:]
+            logits, decoder_mems_list = self._one_step_forward(
+                input_ids, encoder_hidden_states, encoder_input_mask, decoder_mems_list, i
             )
+            if self.temperature is None:
+                next_tokens = torch.argmax(logits[:, -1], dim=-1)
+            else:
+                next_tokens = Categorical(logits=logits[:, -1] / self.temperature).sample()
 
-            next_tokens = torch.argmax(log_probs[:, -1], dim=-1, keepdim=True)
             next_tokens = self.pad * pad_profile + next_tokens * (1 - pad_profile)
             pad_profile = torch.max(pad_profile, (next_tokens == self.eos).long())
-            tgt = torch.cat((tgt, next_tokens), dim=-1)
+            tgt = torch.cat((tgt, next_tokens.unsqueeze(1)), dim=-1)
 
             # abort generation if all sequences end with <eos>
             if pad_profile.sum() == batch_size:
                 break
 
-        return tgt
+        samples = None
+        if is_sampling:
+            samples = list(tgt.view(orig_batch_size, self.n_samples, -1))
+            tgt = tgt[:: self.n_samples]
+
+        return tgt, samples
 
     def __call__(
         self, decoder_input_ids=None, encoder_hidden_states=None, encoder_input_mask=None, return_beam_scores=False
@@ -196,9 +220,9 @@ def freeze(self) -> None:
         for param in self.decoder.parameters():
             param.requires_grad = False
         self.decoder.eval()
-        for param in self.log_softmax.parameters():
+        for param in self.classifier.parameters():
             param.requires_grad = False
-        self.log_softmax.eval()
+        self.classifier.eval()
 
     def unfreeze(self) -> None:
         """Unfreeze weights of embedding, decoder, and classification layers.
@@ -209,14 +233,14 @@ def unfreeze(self) -> None:
         for param in self.decoder.parameters():
             param.requires_grad = True
         self.decoder.train()
-        for param in self.log_softmax.parameters():
+        for param in self.classifier.parameters():
             param.requires_grad = True
-        self.log_softmax.train()
+        self.classifier.train()
 
     @contextmanager
     def as_frozen(self):
         """
-        Context manager which temporarily freezes embedding, decoder, and log_softmax modules,
+        Context manager which temporarily freezes embedding, decoder, and classifier modules,
         yields control and finally unfreezes the modules.
         """
         self.freeze()
diff --git a/nemo/collections/asr/parts/submodules/multitask_decoding.py b/nemo/collections/asr/parts/submodules/multitask_decoding.py
index c336ae7d4170..f03469e0f5d3 100644
--- a/nemo/collections/asr/parts/submodules/multitask_decoding.py
+++ b/nemo/collections/asr/parts/submodules/multitask_decoding.py
@@ -25,6 +25,10 @@
     AEDBeamInferConfig,
     TransformerAEDBeamInfer,
 )
+from nemo.collections.asr.parts.submodules.multitask_greedy_decoding import (
+    AEDGreedyInferConfig,
+    TransformerAEDGreedyInfer,
+)
 from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis, NBestHypotheses
 from nemo.collections.common.tokenizers.aggregate_tokenizer import AggregateTokenizer
 from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
@@ -60,11 +64,11 @@ class AbstractMultiTaskDecoding(ABC):
 
         The config may further contain the following sub-dictionaries:
             "greedy":
-                max_symbols: int, describing the maximum number of target tokens to decode per
-                    timestep during greedy decoding. Setting to larger values allows longer sentences
-                    to be decoded, at the cost of increased execution time.
-                preserve_frame_confidence: Same as above, overrides above value.
-                confidence_method_cfg: Same as above, overrides confidence_cfg.method_cfg.
+                temperature: None (disabled) or float; when set, enables temperature sampling instead of greedy decoding.
+
+                max_generation_delta: int = -1  # -1 means up to the max length of the decoder
+
+                preserve_alignments: bool = False (unsupported)
 
             "beam":
                 beam_size: int, defining the beam size for beam search. Must be >= 1.
@@ -103,30 +107,44 @@ def __init__(
         self.preserve_alignments = self.cfg.get('preserve_alignments', None)
         self.compute_langs = self.cfg.get('compute_langs', False)
         self.compute_hypothesis_token_set = self.cfg.get('compute_hypothesis_token_set', False)
+        self.transformer_decoder = transformer_decoder
+        self.log_softmax_module = log_softmax_module
+        self.tokenizer = tokenizer
+        self.change_strategy(self.cfg.strategy)
 
+    def change_strategy(self, strategy: str) -> "AbstractMultiTaskDecoding":
         possible_strategies = ['greedy', 'greedy_batch', 'beam']
-        if self.cfg.strategy not in possible_strategies:
+        if strategy not in possible_strategies:
             raise ValueError(f"Decoding strategy must be one of {possible_strategies}")
+        self.cfg.strategy = strategy
+
         # Update preserve alignments
         if self.preserve_alignments is None:
-            if self.cfg.strategy in ['greedy', 'greedy_batch']:
+            if strategy in ['greedy', 'greedy_batch']:
                 self.preserve_alignments = self.cfg.greedy.get('preserve_alignments', False)
 
-            elif self.cfg.strategy in ['beam']:
+            elif strategy in ['beam']:
                 self.preserve_alignments = self.cfg.beam.get('preserve_alignments', False)
 
-        if self.cfg.strategy == 'greedy' or self.cfg.strategy == 'greedy_batch':
+        if strategy == 'greedy' or strategy == 'greedy_batch':
 
-            # self.decoding = None
-            raise NotImplementedError("Greedy decoding is not implemented yet.")
+            self.decoding = TransformerAEDGreedyInfer(
+                transformer_decoder=self.transformer_decoder,
+                log_softmax_module=self.log_softmax_module,
+                tokenizer=self.tokenizer,
+                max_generation_delta=self.cfg.greedy.get('max_generation_delta', 50),
+                preserve_alignments=self.preserve_alignments,
+                temperature=self.cfg.greedy.temperature,
+                n_samples=self.cfg.greedy.n_samples,
+            )
 
-        elif self.cfg.strategy == 'beam':
+        elif strategy == 'beam':
 
             self.decoding = TransformerAEDBeamInfer(
-                transformer_decoder=transformer_decoder,
-                log_softmax_module=log_softmax_module,
-                tokenizer=tokenizer,
+                transformer_decoder=self.transformer_decoder,
+                log_softmax_module=self.log_softmax_module,
+                tokenizer=self.tokenizer,
                 search_type=self.cfg.beam.get('search_type', 'default'),
                 beam_size=self.cfg.beam.beam_size,
                 length_penalty=self.cfg.beam.get('length_penalty', 0.0),
@@ -139,9 +157,11 @@ def __init__(
             raise ValueError(
                 f"Incorrect decoding strategy provided. Must be one of {possible_strategies}\n"
-                f"but was provided {self.cfg.strategy}"
+                f"but was provided {strategy}"
             )
 
+        return self
+
     def decode_predictions_tensor(
         self,
         encoder_hidden_states: torch.Tensor,
@@ -476,9 +496,7 @@ class MultiTaskDecodingConfig:
     compute_langs: bool = False
 
     # greedy decoding config
-    # greedy: rnnt_greedy_decoding.GreedyBatchedRNNTInferConfig = field(
-    #     default_factory=rnnt_greedy_decoding.GreedyBatchedRNNTInferConfig
-    # )
+    greedy: AEDGreedyInferConfig = field(default_factory=AEDGreedyInferConfig)
 
     # beam decoding config
     beam: AEDBeamInferConfig = field(default_factory=lambda: AEDBeamInferConfig(beam_size=1))
diff --git a/nemo/collections/asr/parts/submodules/multitask_greedy_decoding.py b/nemo/collections/asr/parts/submodules/multitask_greedy_decoding.py
new file mode 100644
index 000000000000..4701dab5548b
--- /dev/null
+++ b/nemo/collections/asr/parts/submodules/multitask_greedy_decoding.py
@@ -0,0 +1,211 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import List, Optional
+
+import torch
+
+from nemo.collections.asr.modules.transformer import GreedySequenceGenerator
+from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis, NBestHypotheses
+from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
+from nemo.core import Typing, typecheck
+from nemo.core.neural_types import ChannelType, HypothesisType, LabelsType, MaskType, NeuralType
+from nemo.utils import logging
+
+
+def pack_hypotheses(
+    hypotheses: List[Hypothesis], beam_hypotheses: torch.Tensor, scores: List[Optional[float]]
+) -> List[Hypothesis]:
+
+    for idx, hyp in enumerate(hypotheses):  # type: Hypothesis
+        if scores[idx] is not None:
+            hyp.score = scores[idx]
+
+        hypi = beam_hypotheses[idx]
+        if torch.is_tensor(hypi):
+            hyp.y_sequence = hypi.long()
+        else:
+            hyp.y_sequence = torch.tensor(hypi, dtype=torch.long)
+
+        if hyp.dec_state is not None:
+            hyp.dec_state = _states_to_device(hyp.dec_state)
+
+    return hypotheses
+
+
+def _states_to_device(dec_state, device='cpu'):
+    if torch.is_tensor(dec_state):
+        dec_state = dec_state.to(device)
+
+    elif isinstance(dec_state, (list, tuple)):
+        dec_state = tuple(_states_to_device(dec_i, device) for dec_i in dec_state)
+
+    return dec_state
+
+
+class AEDGreedyInfer(ABC):
+    def __init__(
+        self,
+        transformer_decoder: torch.nn.Module,
+        log_softmax_module: torch.nn.Module,
+        tokenizer: TokenizerSpec,
+        search_type: str = 'default',
+        preserve_alignments: bool = False,
+    ):
+        super().__init__()
+
+        self.transformer_decoder = transformer_decoder
+        self.log_softmax_module = log_softmax_module
+        self.tokenizer = tokenizer
+        self.search_type = search_type
+
+        self.preserve_alignments = preserve_alignments
+
+    def __call__(self, *args, **kwargs):
+        return self.forward(*args, **kwargs)
+
+    @abstractmethod
+    def forward(
+        self,
+        encoder_hidden_states: torch.Tensor,
+        encoder_input_mask: torch.Tensor,
+        decoder_input_ids: Optional[torch.Tensor] = None,
+        partial_hypotheses: Optional[List[Hypothesis]] = None,
+    ):
+        raise NotImplementedError()
+
+    def set_decoding_type(self, decoding_type: str):
+        self.decoding_type = decoding_type
+
+
+class TransformerAEDGreedyInfer(AEDGreedyInfer, Typing):
+    """
+    A greedy decoder engine for AED Transformer models with support for temperature sampling.
+    """
+
+    @property
+    def input_types(self):
+        """Returns definitions of module input ports.
+        """
+        # Input can be of dimension -
+        # ('B', 'T', 'D') [Log probs] or ('B', 'T') [Labels]
+
+        return {
+            "encoder_hidden_states": NeuralType(tuple(('B', 'T', 'D')), ChannelType()),
+            "encoder_input_mask": NeuralType(tuple(('B', 'T')), MaskType()),
+            "decoder_input_ids": NeuralType(('B', 'T'), LabelsType()),
+            "partial_hypotheses": NeuralType(optional=True),
+        }
+
+    @property
+    def output_types(self):
+        """Returns definitions of module output ports.
+ """ + return {"predictions": [NeuralType(elements_type=HypothesisType())]} + + def __init__( + self, + transformer_decoder: torch.nn.Module, + log_softmax_module: torch.nn.Module, + tokenizer: TokenizerSpec, + temperature: float | None = None, + max_generation_delta: int = 50, + preserve_alignments: bool = False, + n_samples: int = 1, + ): + super().__init__( + transformer_decoder=transformer_decoder, + log_softmax_module=log_softmax_module, + tokenizer=tokenizer, + preserve_alignments=preserve_alignments, + ) + self.temperature = temperature + self.n_samples = n_samples + self.greedy_search = GreedySequenceGenerator( + embedding=transformer_decoder.embedding, + decoder=transformer_decoder.decoder, + log_softmax=log_softmax_module, + max_sequence_length=transformer_decoder.max_sequence_length, + bos=tokenizer.bos_id, + pad=tokenizer.pad_id, + eos=tokenizer.eos_id, + max_delta_length=max_generation_delta, + temperature=self.temperature, + n_samples=n_samples, + ) + + self.preserve_alignments = preserve_alignments + if self.preserve_alignments: + logging.info( + "Preservation of alignments was requested but {} does not implement it.".format( + self.__class__.__name__ + ) + ) + + @typecheck() + def forward( + self, + encoder_hidden_states: torch.Tensor, + encoder_input_mask: torch.Tensor, + decoder_input_ids: Optional[torch.Tensor] = None, + partial_hypotheses: Optional[List[Hypothesis]] = None, + ): + """Returns a list of hypotheses given an input batch of the encoder hidden embedding. + Output token is generated auto-repressively. + + Args: + decoder_output: A tensor of size (batch, timesteps, features) or (batch, timesteps) (each timestep is a label). + decoder_lengths: list of int representing the length of each sequence + output sequence. + + Returns: + packed list containing batch number of sentences (Hypotheses). 
+ """ + with torch.inference_mode(): + best_hypo, topk_hypotheses = self.greedy_search( + encoder_hidden_states=encoder_hidden_states, + encoder_input_mask=encoder_input_mask, + decoder_input_ids=decoder_input_ids, + ) + + if topk_hypotheses is not None: + topk_hypotheses = [x.detach().cpu() for x in topk_hypotheses] # each item is [beam, seq_len] + beam_scores = [[None] * self.n_samples for _ in topk_hypotheses] # each item is [beam,] + packed_result = [] + for i in range(len(topk_hypotheses)): + hypotheses = [Hypothesis(score=0.0, y_sequence=[], timestep=[]) for _ in range(self.n_samples)] + # Pack results into Hypotheses + packed_result.append( + NBestHypotheses(pack_hypotheses(hypotheses, topk_hypotheses[i], beam_scores[i])) + ) + else: + beam_scores = [None for _ in range(len(best_hypo))] + best_hypo = best_hypo.cpu() + hypotheses = [ + Hypothesis(score=0.0, y_sequence=[], timestep=[]) for _ in range(encoder_hidden_states.shape[0]) + ] + # Pack results into Hypotheses + packed_result = pack_hypotheses(hypotheses, best_hypo, beam_scores) + + return (packed_result,) + + +@dataclass +class AEDGreedyInferConfig: + temperature: float | None = None + max_generation_delta: int = -1 # -1 means up to the max length of the decoder + preserve_alignments: bool = False + n_samples: int = 1 diff --git a/nemo/collections/asr/parts/submodules/token_classifier.py b/nemo/collections/asr/parts/submodules/token_classifier.py index 4061d19d9015..3a75da5ef0cb 100644 --- a/nemo/collections/asr/parts/submodules/token_classifier.py +++ b/nemo/collections/asr/parts/submodules/token_classifier.py @@ -15,12 +15,13 @@ from dataclasses import dataclass from typing import Dict, Optional +import torch from torch import nn as nn from nemo.collections.asr.parts.submodules.classifier import Classifier from nemo.collections.common.parts import MultiLayerPerceptron from nemo.core.classes import typecheck -from nemo.core.neural_types import LogitsType, LogprobsType, NeuralType +from nemo.core.neural_types import ChannelType, FloatType, LogitsType, LogprobsType, NeuralType __all__ = ['BertPretrainingTokenClassifier', 'TokenClassifier'] @@ -42,10 +43,14 @@ class TokenClassifier(Classifier): """ @property - def output_types(self) -> Optional[Dict[str, NeuralType]]: - """ - Returns definitions of module output ports. - """ + def input_types(self) -> Dict[str, NeuralType]: + return { + "hidden_states": NeuralType(('B', 'T', 'D'), ChannelType()), + "temperature": NeuralType(None, FloatType(), optional=True), + } + + @property + def output_types(self) -> Dict[str, NeuralType]: if not self.log_softmax: return {"logits": NeuralType(('B', 'T', 'C'), LogitsType())} else: @@ -81,8 +86,12 @@ def __init__( ) self.post_init(use_transformer_init=use_transformer_init) + def set_log_softmax_enabled(self, value: bool) -> "TokenClassifier": + self.log_softmax = value + return self + @typecheck() - def forward(self, hidden_states): + def forward(self, hidden_states: torch.Tensor, temperature: float | None = None) -> torch.Tensor: """ Performs the forward step of the module. Args: @@ -91,7 +100,7 @@ def forward(self, hidden_states): Returns: logits value for each class [BATCH_SIZE x SEQ_LENGTH x NUM_CLASSES] """ hidden_states = self.dropout(hidden_states) - logits = self.mlp(hidden_states) + logits = self.mlp(hidden_states, temperature=temperature) return logits @@ -100,11 +109,15 @@ class BertPretrainingTokenClassifier(Classifier): A module to perform token level classification tasks for Bert pretraining. 
""" + @property + def input_types(self) -> Dict[str, NeuralType]: + return { + "hidden_states": NeuralType(('B', 'T', 'D'), ChannelType()), + "temperature": NeuralType(None, FloatType(), optional=True), + } + @property def output_types(self) -> Optional[Dict[str, NeuralType]]: - """ - Returns definitions of module output ports. - """ if not self.log_softmax: return {"logits": NeuralType(('B', 'T', 'C'), LogitsType())} else: @@ -147,8 +160,12 @@ def __init__( ) self.post_init(use_transformer_init=use_transformer_init) + def set_log_softmax_enabled(self, value: bool) -> "BertPretrainingTokenClassifier": + self.log_softmax = value + return self + @typecheck() - def forward(self, hidden_states): + def forward(self, hidden_states: torch.Tensor, temperature: float | None = None) -> torch.Tensor: """ Performs the forward step of the module. Args: @@ -160,5 +177,5 @@ def forward(self, hidden_states): hidden_states = self.dense(hidden_states) hidden_states = self.act(hidden_states) transform = self.norm(hidden_states) - logits = self.mlp(transform) + logits = self.mlp(transform, temperature=temperature) return logits diff --git a/nemo/collections/common/parts/multi_layer_perceptron.py b/nemo/collections/common/parts/multi_layer_perceptron.py index 76c06bf23ea6..5110406fedfd 100644 --- a/nemo/collections/common/parts/multi_layer_perceptron.py +++ b/nemo/collections/common/parts/multi_layer_perceptron.py @@ -51,11 +51,15 @@ def __init__( def last_linear_layer(self): return getattr(self, f'layer{self.layers - 1}') - def forward(self, hidden_states): + def forward(self, hidden_states, temperature: float | None = None): output_states = hidden_states[:] for i in range(self.layers): output_states = getattr(self, f'layer{i}')(output_states) + if temperature is not None: + output_states = output_states / temperature + if self.log_softmax: output_states = torch.log_softmax(output_states, dim=-1) + return output_states diff --git a/tests/collections/asr/decoding/test_aed_decoding.py b/tests/collections/asr/decoding/test_aed_decoding.py new file mode 100644 index 000000000000..53571e77e506 --- /dev/null +++ b/tests/collections/asr/decoding/test_aed_decoding.py @@ -0,0 +1,107 @@ +import pytest +import torch + +from nemo.collections.asr.modules.transformer import TransformerDecoder, TransformerEmbedding +from nemo.collections.asr.modules.transformer.transformer_generators import ( + BeamSearchSequenceGenerator, + GreedySequenceGenerator, +) +from nemo.collections.asr.parts.submodules.token_classifier import TokenClassifier + + +@pytest.fixture() +def deterministic_rng(): + state = torch.get_rng_state() + torch.manual_seed(0) + yield + torch.set_rng_state(state) + + +@pytest.fixture() +def nnet(deterministic_rng): + ans = ( + TransformerEmbedding(vocab_size=8, hidden_size=2, max_sequence_length=32), + TransformerDecoder(num_layers=1, hidden_size=2, inner_size=4), + TokenClassifier(hidden_size=2, num_classes=8), + ) + ans = tuple(m.eval() for m in ans) + return ans + + +@pytest.fixture() +def inputs(): + B, T, C = 1, 5, 2 + return ( + torch.tensor([[1]], dtype=torch.long), # decoder_input_ids + torch.ones(B, T, C, dtype=torch.float), # encoder_hidden_states + torch.ones(B, T, dtype=torch.float), # encoder_input_mask + ) + + +def test_greedy_decoding(inputs, nnet): + gen = GreedySequenceGenerator(*nnet) + output = gen(*inputs) + + assert len(output) == 2 + best_path, hypotheses = output + + assert best_path is not None + assert torch.is_tensor(best_path) + assert best_path.shape == (1, 25) + + assert hypotheses is 
None + + +def test_temperature_sampling_decoding(inputs, nnet): + gen = GreedySequenceGenerator(*nnet, temperature=10.0, n_samples=2) + output = gen(*inputs) + + assert len(output) == 2 + best_path, hypotheses = output + + assert best_path is not None + assert torch.is_tensor(best_path) + assert best_path.shape == (1, 25) + + assert isinstance(hypotheses, list) + assert len(hypotheses) == 1 + (seq0,) = hypotheses + assert seq0.shape == (2, 25) + + +def test_beam_decoding_beam_scores_false(inputs, nnet): + gen = BeamSearchSequenceGenerator(*nnet, beam_size=2) + output = gen(*inputs, return_beam_scores=False) + + assert len(output) == 1 + (best_path,) = output + + assert best_path is not None + assert torch.is_tensor(best_path) + assert best_path.shape == (26,) + + +def test_beam_decoding_beam_scores_true(inputs, nnet): + gen = BeamSearchSequenceGenerator(*nnet, beam_size=2) + output = gen(*inputs, return_beam_scores=True) + + assert len(output) == 3 + beam_paths, scores, best_path = output + + assert beam_paths is not None + assert isinstance(beam_paths, list) + assert len(beam_paths) == 1 + (beam_paths_seq0,) = beam_paths + assert torch.is_tensor(beam_paths_seq0) + assert beam_paths_seq0.shape == (2, 26) + + assert scores is not None + assert isinstance(scores, list) + assert len(scores) == 1 + (scores_seq0,) = scores + assert torch.is_tensor(scores_seq0) + assert scores_seq0.shape == (2,) + + assert best_path is not None + assert torch.is_tensor(best_path) + assert best_path.shape == (1, 26)
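
Reviewer note: a minimal sketch of how the new decoding surface introduced in this diff is expected to be configured. It only uses names that appear in the diff (MultiTaskDecodingConfig, its strategy and greedy fields, AEDGreedyInferConfig); the model object and its change_decoding_strategy call are illustrative assumptions, not part of this change.

    from nemo.collections.asr.parts.submodules.multitask_decoding import MultiTaskDecodingConfig
    from nemo.collections.asr.parts.submodules.multitask_greedy_decoding import AEDGreedyInferConfig

    # Plain greedy decoding: temperature=None keeps argmax token selection.
    greedy_cfg = MultiTaskDecodingConfig(strategy="greedy")

    # Temperature sampling: scale logits by T=1.2 and draw 4 sampled hypotheses per utterance.
    sampling_cfg = MultiTaskDecodingConfig(
        strategy="greedy",
        greedy=AEDGreedyInferConfig(temperature=1.2, n_samples=4),
    )

    # With n_samples > 1, TransformerAEDGreedyInfer.forward packs results as NBestHypotheses per
    # utterance; otherwise it returns a single Hypothesis per utterance.
    # model.change_decoding_strategy(sampling_cfg)  # hypothetical call on an AED multitask model

The config path mirrors what the tests exercise directly on GreedySequenceGenerator: with temperature and n_samples set, the generator returns both the first sampled path per utterance and the full list of sampled sequences.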