From a3791156e85f2c6e8d300b5444539a8235f955dc Mon Sep 17 00:00:00 2001 From: Motoki Wu Date: Tue, 3 Jan 2023 12:18:24 -0800 Subject: [PATCH] Add custom stop token ids for generation (#20727) * Add StopIdStoppingCriteria * add a working test for stop id criteria * add to global scope * add stop_ids to generate * add pipeline test * use tokenizer encode in test * add test to generation utils * reformat * fixup * make-fix-copies * rename to stop_token_id * use stop_tokens instead * add to text to text generation * make fixup * make repo-consistency * Add support for list of ints for eos_token_id inside generation/utils.py * Instead of having if elses, cast the eos_token_id into a List[int] * Add List[int] support for logits_process.py * add List[int] for beam_search.py * add List[int] for forced_eos_token_id * revert stop token id stopping criteria changes * make fixup * fix tests * add eos_token_id to generation/utils.py and added tests test_utils.py * add eos_token_id type hints and fix for pad tokens * add comments * remove some prints and remove forced false test * fix * put back test_stop_sequence_stopping_criteria * remove unused import and make fixup * add a none check * update docstring * add more docstring for list ints * make fixup --- src/transformers/generation/beam_search.py | 44 +++++--- .../generation/configuration_utils.py | 13 ++- src/transformers/generation/logits_process.py | 61 ++++++---- src/transformers/generation/utils.py | 51 ++++++--- tests/generation/test_utils.py | 105 +++++++++++++++++- 5 files changed, 210 insertions(+), 64 deletions(-) diff --git a/src/transformers/generation/beam_search.py b/src/transformers/generation/beam_search.py index d22fbaf280deef..6e4f9cb936e8f6 100644 --- a/src/transformers/generation/beam_search.py +++ b/src/transformers/generation/beam_search.py @@ -16,7 +16,7 @@ import warnings from abc import ABC, abstractmethod from collections import UserDict -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Union import numpy as np import torch @@ -42,8 +42,8 @@ Beam indices indicating to which beam hypothesis the `next_tokens` correspond. pad_token_id (`int`, *optional*): The id of the *padding* token. - eos_token_id (`int`, *optional*): - The id of the *end-of-sequence* token. + eos_token_id (`Union[int, List[int]]`, *optional*): + The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. Return: `UserDict`: A dictionary composed of the fields as defined above: @@ -74,8 +74,8 @@ The beam indices indicating to which beam the `final_beam_tokens` shall be added. pad_token_id (`int`, *optional*): The id of the *padding* token. - eos_token_id (`int`, *optional*): - The id of the *end-of-sequence* token. + eos_token_id (`Union[int, List[int]]`, *optional*): + The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. Return: `torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`: The generated sequences. 
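The hunks that follow all apply the same normalization: a plain `int` `eos_token_id` is cast to a one-element list, and equality checks against a single id become membership checks against that list. A minimal standalone sketch of the pattern, assuming nothing beyond PyTorch (the helper name and token id below are illustrative, not part of the patch):

from typing import List, Optional, Union

import torch

def _normalize_eos(eos_token_id: Optional[Union[int, List[int]]]) -> Optional[List[int]]:
    # Cast a single end-of-sequence id to a one-element list so downstream
    # code can always iterate over the ids or test membership in a list.
    if isinstance(eos_token_id, int):
        return [eos_token_id]
    return eos_token_id

eos_token_id = _normalize_eos(50256)   # also accepts a list such as [50256, 0]
next_token = torch.tensor(50256)
# Membership test replaces the old `next_token.item() == eos_token_id` check.
is_eos = (eos_token_id is not None) and (next_token.item() in eos_token_id)
assert is_eos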
@@ -212,7 +212,7 @@ def process( next_tokens: torch.LongTensor, next_indices: torch.LongTensor, pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, + eos_token_id: Optional[Union[int, List[int]]] = None, beam_indices: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor]: cur_len = input_ids.shape[-1] @@ -234,6 +234,9 @@ def process( next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device) next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device) + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + for batch_idx, beam_hyp in enumerate(self._beam_hyps): if self._done[batch_idx]: if self.num_beams < len(beam_hyp): @@ -253,7 +256,7 @@ def process( ): batch_beam_idx = batch_idx * self.group_size + next_index # add to generated hypotheses if end of sentence - if (eos_token_id is not None) and (next_token.item() == eos_token_id): + if (eos_token_id is not None) and (next_token.item() in eos_token_id): # if beam_token does not belong to top num_beams tokens, it should not be added is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size if is_beam_token_worse_than_top_num_beams: @@ -307,11 +310,14 @@ def finalize( final_beam_indices: torch.LongTensor, max_length: int, pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, + eos_token_id: Optional[Union[int, List[int]]] = None, beam_indices: Optional[torch.LongTensor] = None, ) -> Tuple[torch.LongTensor]: batch_size = len(self._beam_hyps) + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + # finalize all open beam hypotheses and add to generated hypotheses for batch_idx, beam_hyp in enumerate(self._beam_hyps): if self._done[batch_idx]: @@ -376,7 +382,8 @@ def finalize( indices[i, : len(best_idx)] = torch.tensor(best_idx) if sent_lengths[i] < sent_max_len: - decoded[i, sent_lengths[i]] = eos_token_id + # inserting only the first eos_token_id + decoded[i, sent_lengths[i]] = eos_token_id[0] return UserDict( { @@ -491,7 +498,7 @@ def process( next_indices: torch.LongTensor, scores_for_all_vocab: torch.FloatTensor, pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, + eos_token_id: Optional[Union[int, List[int]]] = None, ) -> Tuple[torch.Tensor]: r""" Args: @@ -512,8 +519,8 @@ def process( The scores of all tokens in the vocabulary for each of the beam hypotheses. pad_token_id (`int`, *optional*): The id of the *padding* token. - eos_token_id (`int`, *optional*): - The id of the *end-of-sequence* token. + eos_token_id (`Union[int, List[int]]`, *optional*): + The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. 
Return: `UserDict`: A dictionary composed of the fields as defined above: @@ -549,6 +556,9 @@ def process( next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device) next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device) + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + for batch_idx, beam_hyp in enumerate(self._beam_hyps): if self._done[batch_idx]: if self.num_beams < len(beam_hyp): @@ -568,7 +578,7 @@ def process( ): batch_beam_idx = batch_idx * self.group_size + next_index # add to generated hypotheses if end of sentence - if (eos_token_id is not None) and (next_token.item() == eos_token_id): + if (eos_token_id is not None) and (next_token.item() in eos_token_id): # if beam_token does not belong to top num_beams tokens, it should not be added is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size @@ -773,10 +783,13 @@ def finalize( final_beam_indices: torch.LongTensor, max_length: int, pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, + eos_token_id: Optional[Union[int, List[int]]] = None, ) -> Tuple[torch.LongTensor]: batch_size = len(self._beam_hyps) + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + # finalize all open beam hypotheses and add to generated hypotheses for batch_idx, beam_hyp in enumerate(self._beam_hyps): if self._done[batch_idx]: @@ -840,7 +853,8 @@ def finalize( for i, hypo in enumerate(best): decoded[i, : sent_lengths[i]] = hypo if sent_lengths[i] < sent_max_len: - decoded[i, sent_lengths[i]] = eos_token_id + # inserting only the first eos_token_id + decoded[i, sent_lengths[i]] = eos_token_id[0] return UserDict( { diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index a477ebe4203c43..a01222c8b41ecf 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -142,8 +142,9 @@ class GenerationConfig(PushToHubMixin): The id of the token to force as the first generated token after the `decoder_start_token_id`. Useful for multilingual models like [mBART](../model_doc/mbart) where the first generated token needs to be the target language token. - forced_eos_token_id (`int`, *optional*, defaults to `model.config.forced_eos_token_id`): - The id of the token to force as the last generated token when `max_length` is reached. + forced_eos_token_id (`Union[int, List[int]]`, *optional*, defaults to `model.config.forced_eos_token_id`): + The id of the token to force as the last generated token when `max_length` is reached. Optionally, use a + list to set multiple *end-of-sequence* tokens. remove_invalid_values (`bool`, *optional*, defaults to `model.config.remove_invalid_values`): Whether to remove possible *nan* and *inf* outputs of the model to prevent the generation method to crash. Note that using `remove_invalid_values` can slow down generation. @@ -152,10 +153,10 @@ class GenerationConfig(PushToHubMixin): generated. The tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates where penalty starts and `decay_factor` represents the factor of exponential decay suppress_tokens (`List[int]`, *optional*): - A list of tokens that will be supressed at generation. The `SupressTokens` logit processor will set their + A list of tokens that will be suppressed at generation. 
The `SupressTokens` logit processor will set their log probs to `-inf` so that they are not sampled. begin_suppress_tokens (`List[int]`, *optional*): - A list of tokens that will be supressed at the begining of the generation. The `SupressBeginTokens` logit + A list of tokens that will be suppressed at the beginning of the generation. The `SupressBeginTokens` logit processor will set their log probs to `-inf` so that they are not sampled. forced_decoder_ids (`List[List[int]]`, *optional*): A list of pairs of integers which indicates a mapping from generation indices to token indices that will be @@ -183,8 +184,8 @@ class GenerationConfig(PushToHubMixin): The id of the *padding* token. bos_token_id (`int`, *optional*): The id of the *beginning-of-sequence* token. - eos_token_id (`int`, *optional*): - The id of the *end-of-sequence* token. + eos_token_id (`Union[int, List[int]]`, *optional*): + The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. > Generation parameters exclusive to encoder-decoder models diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index 09d88d89fe5135..1639a98d93be2c 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -15,7 +15,7 @@ import inspect import math -from typing import Callable, Iterable, List, Optional, Tuple +from typing import Callable, Iterable, List, Optional, Tuple, Union import numpy as np import torch @@ -100,16 +100,18 @@ class MinLengthLogitsProcessor(LogitsProcessor): Args: min_length (`int`): The minimum length below which the score of `eos_token_id` is set to `-float("Inf")`. - eos_token_id (`int`): - The id of the *end-of-sequence* token. + eos_token_id (`Union[int, List[int]]`): + The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. """ - def __init__(self, min_length: int, eos_token_id: int): + def __init__(self, min_length: int, eos_token_id: Union[int, List[int]]): if not isinstance(min_length, int) or min_length < 0: raise ValueError(f"`min_length` has to be a positive integer, but is {min_length}") - if not isinstance(eos_token_id, int) or eos_token_id < 0: - raise ValueError(f"`eos_token_id` has to be a positive integer, but is {eos_token_id}") + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + if not all([isinstance(i, int) for i in eos_token_id]) or any([i < 0 for i in eos_token_id]): + raise ValueError(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}") self.min_length = min_length self.eos_token_id = eos_token_id @@ -117,7 +119,8 @@ def __init__(self, min_length: int, eos_token_id: int): def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: cur_len = input_ids.shape[-1] if cur_len < self.min_length: - scores[:, self.eos_token_id] = -float("inf") + for i in self.eos_token_id: + scores[:, i] = -float("inf") return scores @@ -431,11 +434,11 @@ class NoBadWordsLogitsProcessor(LogitsProcessor): List of list of token ids that are not allowed to be generated. In order to get the token ids of the words that should not appear in the generated text, use `tokenizer(bad_words, add_prefix_space=True, add_special_tokens=False).input_ids`. - eos_token_id (`int`): - The id of the *end-of-sequence* token. + eos_token_id (`Union[int, List[int]]`): + The id of the *end-of-sequence* token. 
Optionally, use a list to set multiple *end-of-sequence* tokens. """ - def __init__(self, bad_words_ids: List[List[int]], eos_token_id: int): + def __init__(self, bad_words_ids: List[List[int]], eos_token_id: Union[int, List[int]]): if not isinstance(bad_words_ids, List) or len(bad_words_ids) == 0: raise ValueError(f"`bad_words_ids` has to be a non-empty list, but is {bad_words_ids}.") @@ -449,7 +452,14 @@ def __init__(self, bad_words_ids: List[List[int]], eos_token_id: int): f"Each list in `bad_words_ids` has to be a list of positive integers, but is {bad_words_ids}." ) - bad_words_ids = list(filter(lambda bad_token_seq: bad_token_seq != [eos_token_id], bad_words_ids)) + if eos_token_id is None: + eos_token_id = [] + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + + bad_words_ids = list( + filter(lambda bad_token_seq: all([bad_token_seq != [i] for i in eos_token_id]), bad_words_ids) + ) self.bad_words_id_length_1 = [] self.bad_words_id_length_greater_than_1 = [] for word in bad_words_ids: @@ -664,20 +674,24 @@ class ForcedEOSTokenLogitsProcessor(LogitsProcessor): Args: max_length (`int`): The maximum length of the sequence to be generated. - eos_token_id (`int`): - The id of the token to force as the last generated token when `max_length` is reached. + eos_token_id (`Union[int, List[int]]`): + The id of the token to force as the last generated token when `max_length` is reached. Optionally, use a + list to set multiple *end-of-sequence* tokens. """ - def __init__(self, max_length: int, eos_token_id: int): + def __init__(self, max_length: int, eos_token_id: Union[int, List[int]]): self.max_length = max_length + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] self.eos_token_id = eos_token_id def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: cur_len = input_ids.shape[-1] if cur_len == self.max_length - 1: num_tokens = scores.shape[1] - scores[:, [i for i in range(num_tokens) if i != self.eos_token_id]] = -float("inf") - scores[:, self.eos_token_id] = 0 + scores[:, [i for i in range(num_tokens) if i not in self.eos_token_id]] = -float("inf") + for i in self.eos_token_id: + scores[:, i] = 0 return scores @@ -707,23 +721,26 @@ class ExponentialDecayLengthPenalty(LogitsProcessor): exponential_decay_length_penalty (`tuple(int, float)`, *optional*): This tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates where penalty starts and `decay_factor` represents the factor of exponential decay - eos_token_id (`int`): - The id of the *end-of-sequence* token. + eos_token_id (`Union[int, List[int]]`): + The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. input_ids_seq_length (`int`): The length of the input sequence. 
""" - def __init__(self, exponential_decay_length_penalty: Tuple, eos_token_id: int, input_ids_seq_length: int): + def __init__( + self, exponential_decay_length_penalty: Tuple, eos_token_id: Union[int, List[int]], input_ids_seq_length: int + ): self.regulation_start = exponential_decay_length_penalty[0] + input_ids_seq_length self.regulation_factor = exponential_decay_length_penalty[1] + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] self.eos_token_id = eos_token_id def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.FloatTensor: cur_len = input_ids.shape[-1] if cur_len > self.regulation_start: - scores[:, self.eos_token_id] = scores[:, self.eos_token_id] * pow( - self.regulation_factor, cur_len - self.regulation_start - ) + for i in self.eos_token_id: + scores[:, i] = scores[:, i] * pow(self.regulation_factor, cur_len - self.regulation_start) return scores diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 1f7ac3b2e98e4f..65c968a3d55a40 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -575,11 +575,13 @@ def _prepare_attention_mask_for_generation( self, inputs: torch.Tensor, pad_token_id: Optional[int], - eos_token_id: Optional[int], + eos_token_id: Optional[Union[int, List[int]]], ) -> torch.LongTensor: is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long] is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs) - is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (pad_token_id != eos_token_id) + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (pad_token_id not in eos_token_id) # Check if input is input_ids and padded -> only then is attention_mask defined if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id: @@ -902,7 +904,7 @@ def compute_transition_beam_scores( sequences: torch.Tensor, scores: Tuple[torch.Tensor], beam_indices: torch.Tensor, - eos_token_id: int = None, + eos_token_id: Union[int, List[int]] = None, ): """compute the transition probabilities of sequences given generation scores and beam indices""" @@ -1165,10 +1167,11 @@ def generate( "The attention mask and the pad token id were not set. As a consequence, you may observe " "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results." ) - logger.warning( - f"Setting `pad_token_id` to `eos_token_id`:{generation_config.eos_token_id} for open-end generation." - ) - generation_config.pad_token_id = generation_config.eos_token_id + eos_token_id = generation_config.eos_token_id + if isinstance(eos_token_id, list): + eos_token_id = eos_token_id[0] + logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.") + generation_config.pad_token_id = eos_token_id # 3. 
Define model inputs # inputs_tensor has to be defined @@ -1624,7 +1627,7 @@ def contrastive_search( logits_warper: Optional[LogitsProcessorList] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, + eos_token_id: Optional[Union[int, List[int]]] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, @@ -1708,6 +1711,8 @@ def contrastive_search( stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] output_scores = output_scores if output_scores is not None else self.generation_config.output_scores output_attentions = ( output_attentions if output_attentions is not None else self.generation_config.output_attentions @@ -1930,7 +1935,7 @@ def contrastive_search( # if eos_token was found in one sentence, set sentence to finished if eos_token_id is not None: - unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long()) + unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long()) # stop when each sentence is finished, or if we exceed the maximum length if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): @@ -1967,7 +1972,7 @@ def greedy_search( stopping_criteria: Optional[StoppingCriteriaList] = None, max_length: Optional[int] = None, pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, + eos_token_id: Optional[Union[int, List[int]]] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, @@ -2067,6 +2072,8 @@ def greedy_search( stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] output_scores = output_scores if output_scores is not None else self.generation_config.output_scores output_attentions = ( output_attentions if output_attentions is not None else self.generation_config.output_attentions @@ -2162,7 +2169,7 @@ def greedy_search( # if eos_token was found in one sentence, set sentence to finished if eos_token_id is not None: - unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long()) + unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long()) # stop when each sentence is finished, or if we exceed the maximum length if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): @@ -2200,7 +2207,7 @@ def sample( logits_warper: Optional[LogitsProcessorList] = None, max_length: Optional[int] = None, pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, + eos_token_id: Optional[Union[int, List[int]]] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, @@ -2320,6 +2327,8 @@ def sample( logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() pad_token_id 
= pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] output_scores = output_scores if output_scores is not None else self.generation_config.output_scores output_attentions = ( output_attentions if output_attentions is not None else self.generation_config.output_attentions @@ -2418,7 +2427,7 @@ def sample( # if eos_token was found in one sentence, set sentence to finished if eos_token_id is not None: - unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long()) + unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long()) # stop when each sentence is finished, or if we exceed the maximum length if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): @@ -2456,7 +2465,7 @@ def beam_search( stopping_criteria: Optional[StoppingCriteriaList] = None, max_length: Optional[int] = None, pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, + eos_token_id: Optional[Union[int, List[int]]] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, @@ -2576,6 +2585,8 @@ def beam_search( warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning) pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] output_scores = output_scores if output_scores is not None else self.generation_config.output_scores output_attentions = ( output_attentions if output_attentions is not None else self.generation_config.output_attentions @@ -2770,7 +2781,7 @@ def beam_sample( logits_warper: Optional[LogitsProcessorList] = None, max_length: Optional[int] = None, pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, + eos_token_id: Optional[Union[int, List[int]]] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, @@ -2900,6 +2911,8 @@ def beam_sample( stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] output_scores = output_scores if output_scores is not None else self.generation_config.output_scores output_attentions = ( output_attentions if output_attentions is not None else self.generation_config.output_attentions @@ -3089,7 +3102,7 @@ def group_beam_search( stopping_criteria: Optional[StoppingCriteriaList] = None, max_length: Optional[int] = None, pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, + eos_token_id: Optional[Union[int, List[int]]] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, @@ -3213,6 +3226,8 @@ def group_beam_search( stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id 
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] output_scores = output_scores if output_scores is not None else self.generation_config.output_scores output_attentions = ( output_attentions if output_attentions is not None else self.generation_config.output_attentions @@ -3455,7 +3470,7 @@ def constrained_beam_search( stopping_criteria: Optional[StoppingCriteriaList] = None, max_length: Optional[int] = None, pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, + eos_token_id: Optional[Union[int, List[int]]] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, @@ -3586,6 +3601,8 @@ def constrained_beam_search( warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning) pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] output_scores = output_scores if output_scores is not None else self.generation_config.output_scores output_attentions = ( output_attentions if output_attentions is not None else self.generation_config.output_attentions diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index a03f0d12b9d147..dfe8be0efd7e8b 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -17,7 +17,7 @@ import inspect import unittest -from transformers import is_torch_available +from transformers import is_torch_available, pipeline from transformers.testing_utils import require_torch, slow, torch_device from ..test_modeling_common import floats_tensor, ids_tensor @@ -39,7 +39,6 @@ SpeechEncoderDecoderModel, T5ForConditionalGeneration, VisionEncoderDecoderModel, - pipeline, top_k_top_p_filtering, ) from transformers.generation import ( @@ -91,8 +90,9 @@ def _get_input_ids_and_config(self): max_length = input_ids.shape[-1] + 3 if config.eos_token_id is not None and config.pad_token_id is None: # hack to allow generate for models such as GPT2 as is done in `generate()` - config.pad_token_id = config.eos_token_id - + if isinstance(config.eos_token_id, int): + config.eos_token_id = [config.eos_token_id] + config.pad_token_id = config.eos_token_id[0] # TransfoXL has no attention mask if "transfoxl" in config.__class__.__name__.lower(): attention_mask = None @@ -3025,3 +3025,100 @@ def test_validate_generation_inputs(self): # However, valid model_kwargs are accepted valid_model_kwargs = {"attention_mask": torch.zeros_like(input_ids)} model.generate(input_ids, **valid_model_kwargs) + + def test_eos_token_id_int_and_list_greedy_search(self): + generation_kwargs = { + "do_sample": False, + "num_beams": 1, + } + expectation = 13 + + tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") + text = """Hello, my dog is cute and""" + tokens = tokenizer(text, return_tensors="pt") + + model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) + + torch.manual_seed(0) + eos_token_id = 873 + generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs) + self.assertTrue(expectation == len(generated_tokens[0])) + + torch.manual_seed(0) + eos_token_id = [873] + generated_tokens = model.generate(**tokens, 
eos_token_id=eos_token_id, **generation_kwargs) + self.assertTrue(expectation == len(generated_tokens[0])) + + def test_eos_token_id_int_and_list_contrastive_search(self): + generation_kwargs = { + "do_sample": False, + "num_beams": 1, + "penalty_alpha": 0.6, + "top_k": 4, + } + expectation = 17 + + tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") + text = """Hello, my dog is cute and""" + tokens = tokenizer(text, return_tensors="pt") + + model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) + + torch.manual_seed(0) + eos_token_id = 225 + generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs) + self.assertTrue(expectation == len(generated_tokens[0])) + + torch.manual_seed(0) + eos_token_id = [225] + generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs) + self.assertTrue(expectation == len(generated_tokens[0])) + + def test_eos_token_id_int_and_list_top_k_top_sampling(self): + generation_kwargs = { + "do_sample": True, + "num_beams": 1, + "top_p": 0.7, + "top_k": 10, + "temperature": 0.7, + } + expectation = 15 + + tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") + text = """Hello, my dog is cute and""" + tokens = tokenizer(text, return_tensors="pt") + + model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) + + torch.manual_seed(0) + eos_token_id = 846 + generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs) + self.assertTrue(expectation == len(generated_tokens[0])) + + torch.manual_seed(0) + eos_token_id = [846] + generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs) + self.assertTrue(expectation == len(generated_tokens[0])) + + def test_eos_token_id_int_and_list_beam_search(self): + generation_kwargs = { + "do_sample": False, + "num_beams": 3, + } + expectation = 13 + + tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") + text = """Hello, my dog is cute and""" + tokens = tokenizer(text, return_tensors="pt") + + model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) + + torch.manual_seed(0) + eos_token_id = 873 + generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs) + self.assertTrue(expectation == len(generated_tokens[0])) + + torch.manual_seed(0) + eos_token_id = [873] + generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs) + self.assertTrue(expectation == len(generated_tokens[0]))
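End to end, the feature added by this patch lets callers pass either a single id or a list of ids as `eos_token_id` to `generate`; decoding stops as soon as any id in the list is produced. A short usage sketch along the lines of the tests above, assuming the same tiny test checkpoint (the token ids are illustrative):

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2")

tokens = tokenizer("Hello, my dog is cute and", return_tensors="pt")

torch.manual_seed(0)
# A single int still works as before ...
out_single = model.generate(**tokens, eos_token_id=873, do_sample=False, num_beams=1)

torch.manual_seed(0)
# ... and a list of ids now ends generation at whichever id appears first.
out_list = model.generate(**tokens, eos_token_id=[873, 198], do_sample=False, num_beams=1)

print(out_single[0])
print(out_list[0])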