From 23f76647afa63fe32ca5f7cbcd6282e05703737b Mon Sep 17 00:00:00 2001
From: Saibo-creator <53392976+Saibo-creator@users.noreply.github.com>
Date: Fri, 19 Jan 2024 12:36:54 +0100
Subject: [PATCH] feat: Sequential beam search (#26304)

---
 .../generation/configuration_utils.py |   3 +-
 src/transformers/generation/utils.py  | 229 ++++++++++++++----
 tests/generation/test_utils.py        |  46 ++++
 3 files changed, 235 insertions(+), 43 deletions(-)

diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py
index 4353a113223870..abc118aa8c1d60 100644
--- a/src/transformers/generation/configuration_utils.py
+++ b/src/transformers/generation/configuration_utils.py
@@ -200,7 +200,8 @@ class GenerationConfig(PushToHubMixin):
             Higher guidance scale encourages the model to generate samples that are more closely linked to the input
             prompt, usually at the expense of poorer quality.
         low_memory (`bool`, *optional*):
-            Switch to sequential topk for contrastive search to reduce peak memory. Used with contrastive search.
+            Switch to sequential beam search and sequential topk for contrastive search to reduce peak memory.
+            Used with beam search and contrastive search.
 
         > Parameters that define the output variables of `generate`
 
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index f8d1e3c3ef423d..ebbefa7cb89396 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -1558,6 +1558,7 @@ def generate(
                 output_scores=generation_config.output_scores,
                 return_dict_in_generate=generation_config.return_dict_in_generate,
                 synced_gpus=synced_gpus,
+                sequential=generation_config.low_memory,
                 **model_kwargs,
             )
 
@@ -1951,8 +1952,7 @@ def contrastive_search(
             model_kwargs["past_key_values"] = tuple(new_key_values)
 
             if sequential:
-                all_outputs = {key: [] for key in outputs}  # defined in first loop iteration
-                all_last_hstates, all_hstates, all_logits = [], [], []
+                all_outputs = []
                 for i in range(top_k):
                     # compute the candidate tokens by the language model and collect their hidden_states
                     next_model_inputs = self.prepare_inputs_for_generation(top_k_ids[:, i].view(-1, 1), **model_kwargs)
@@ -1963,32 +1963,8 @@ def contrastive_search(
                         output_hidden_states=True,
                         output_attentions=output_attentions,
                     )
-                    for key in all_outputs:
-                        all_outputs[key].append(outputs[key])
-
-                    if self.config.is_encoder_decoder:
-                        next_hidden = outputs.decoder_hidden_states[-1]
-                        full_hidden_states = outputs.decoder_hidden_states
-
-                    else:
-                        next_hidden = outputs.hidden_states[-1]
-                        full_hidden_states = outputs.hidden_states
-
-                    all_last_hstates.append(torch.squeeze(next_hidden, 0))
-                    all_hstates.append(full_hidden_states)
-                    all_logits.append(outputs.logits[:, -1, :])
-
-                # stack hidden states
-                next_hidden = torch.stack([all_last_hstates[i] for i in range(top_k)], dim=0)
-                final_full_hstates = [0 for i in range(len(full_hidden_states))]
-                for layer in range(len(full_hidden_states)):
-                    final_full_hstates[layer] = torch.stack(
-                        [torch.squeeze(all_hstates[i][layer], 0) for i in range(top_k)], dim=0
-                    )
-                full_hidden_states = tuple(final_full_hstates)
-
-                # stack logits
-                logits = torch.cat(all_logits, dim=0)
+                    all_outputs.append(outputs)
+                outputs = stack_model_outputs(all_outputs)
 
             else:
                 # compute the candidate tokens by the language model and collect their hidden_states
@@ -2001,15 +1977,15 @@ def contrastive_search(
                     output_hidden_states=True,
                     output_attentions=output_attentions,
                 )
-                # name is different for encoder-decoder and decoder-only models
-                if self.config.is_encoder_decoder:
-                    next_hidden = outputs.decoder_hidden_states[-1]
-                    full_hidden_states = outputs.decoder_hidden_states
-                else:
-                    next_hidden = outputs.hidden_states[-1]
-                    full_hidden_states = outputs.hidden_states
+            # name is different for encoder-decoder and decoder-only models
+            if self.config.is_encoder_decoder:
+                next_hidden = outputs.decoder_hidden_states[-1]
+                full_hidden_states = outputs.decoder_hidden_states
+            else:
+                next_hidden = outputs.hidden_states[-1]
+                full_hidden_states = outputs.hidden_states
 
-                logits = outputs.logits[:, -1, :]
+            logits = outputs.logits[:, -1, :]
 
             context_hidden = last_hidden_states.repeat_interleave(top_k, dim=0)
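For context, a minimal usage sketch of the low-memory path from the caller's side (the checkpoint and prompt are arbitrary placeholders): contrastive search already honors `low_memory` via the sequential top-k loop above, and with this patch the same flag also switches beam search to the sequential path.

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("The low-memory decoding paths", return_tensors="pt")

# Contrastive search with the sequential top-k loop (lower peak memory).
contrastive_out = model.generate(
    **inputs, penalty_alpha=0.6, top_k=4, max_new_tokens=32, low_memory=True
)

# Beam search with the new sequential sub-batching (same tokens, lower peak memory).
beam_out = model.generate(**inputs, num_beams=4, max_new_tokens=32, low_memory=True)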
@@ -2747,6 +2723,7 @@ def beam_search(
         output_scores: Optional[bool] = None,
         return_dict_in_generate: Optional[bool] = None,
         synced_gpus: bool = False,
+        sequential: Optional[bool] = None,
         **model_kwargs,
     ) -> Union[GenerateBeamOutput, torch.LongTensor]:
         r"""
@@ -2792,6 +2769,10 @@ def beam_search(
                 Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
             synced_gpus (`bool`, *optional*, defaults to `False`):
                 Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
+            sequential (`bool`, defaults to `False`):
+                By default, beam search has `batch_size * num_beams` as effective batch size (see `beam_search()` for
+                more details). This flag will avoid parallelizing the beam search and will instead run beam search
+                sequentially.
             model_kwargs:
                 Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
                 an encoder-decoder model the kwargs should include `encoder_outputs`.
@@ -2858,6 +2839,7 @@ def beam_search(
         # init values
         logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
         stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
+        sequential = sequential if sequential is not None else self.generation_config.low_memory
         if max_length is not None:
             warnings.warn(
                 "`max_length` is deprecated in this function, use"
@@ -2932,12 +2914,39 @@ def beam_search(
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
 
-            outputs = self(
-                **model_inputs,
-                return_dict=True,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-            )
+            # if sequential is True, split the input to batches of batch_size and run sequentially
+            if sequential:
+                if any(
+                    model_name in self.__class__.__name__.lower()
+                    for model_name in ["fsmt", "reformer", "bloom", "ctrl", "gpt_bigcode", "transo_xl", "xlnet", "cpm"]
+                ):
+                    raise RuntimeError(
+                        f"Currently generation for {self.__class__.__name__} is not supported "
+                        f"for `low_memory beam_search`. Please open an issue on GitHub if you need this feature."
+                    )
+
+                inputs_per_sub_batches = _split_model_inputs(
+                    model_inputs, split_size=batch_size, full_batch_size=batch_beam_size
+                )
+                outputs_per_sub_batch = [
+                    self(
+                        **inputs_per_sub_batch,
+                        return_dict=True,
+                        output_attentions=output_attentions,
+                        output_hidden_states=output_hidden_states,
+                    )
+                    for inputs_per_sub_batch in inputs_per_sub_batches
+                ]
+
+                outputs = stack_model_outputs(outputs_per_sub_batch)
+
+            else:  # Unchanged original behavior
+                outputs = self(
+                    **model_inputs,
+                    return_dict=True,
+                    output_attentions=output_attentions,
+                    output_hidden_states=output_hidden_states,
+                )
 
             if synced_gpus and this_peer_finished:
                 cur_len = cur_len + 1
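To make the sub-batching above concrete, here is a self-contained sketch (a toy linear layer stands in for the model's forward pass; all names and sizes are illustrative): running the `batch_size * num_beams` rows in `batch_size`-sized slices and concatenating the per-slice results reproduces the single batched pass, just with a smaller peak activation footprint.

import torch

vocab_size, hidden = 50, 16
proj = torch.nn.Linear(hidden, vocab_size)  # stand-in for the model forward

batch_size, num_beams = 2, 5
batch_beam_size = batch_size * num_beams
hidden_states = torch.randn(batch_beam_size, hidden)

# Regular beam search: one forward pass over all batch_size * num_beams rows.
full_logits = proj(hidden_states)

# Sequential (low_memory) beam search: forward each batch_size-sized slice
# separately, then concatenate along the batch dimension.
chunks = [proj(hidden_states[i : i + batch_size]) for i in range(0, batch_beam_size, batch_size)]
sequential_logits = torch.cat(chunks, dim=0)

assert torch.allclose(full_logits, sequential_logits, atol=1e-6)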
@@ -4656,3 +4665,139 @@ def _ranking_fast(
     contrastive_score = torch.stack(torch.split(contrastive_score, beam_width))  # [B, K]
     _, selected_idx = contrastive_score.max(dim=-1)  # [B]
     return selected_idx
+
+
+def _split(data, full_batch_size: int, split_size: int = None):
+    """
+    Takes care of three cases:
+    1. data is a tensor: e.g. last_hidden_state, pooler_output etc. split them on the batch_size dim
+    2. data is a tuple: e.g. hidden_states, attentions etc. Keep the tuple as it is and split each tensor in it and
+       return a list of tuples
+    3. data is a tuple of tuples, e.g. past_key_values. Keep the tuple as it is and split each tuple in it and
+       return a list of tuples of tuples
+    (see documentation of ModelOutput)
+    """
+    if data is None:
+        return [None] * (full_batch_size // split_size)
+    if isinstance(data, torch.Tensor):
+        return [data[i : i + split_size] for i in range(0, full_batch_size, split_size)]
+    elif isinstance(data, tuple):
+        # If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example)
+        if isinstance(data[0], tuple):
+            return [
+                tuple(tuple(tensor[i : i + split_size] for tensor in inner_tuple) for inner_tuple in data)
+                for i in range(0, full_batch_size, split_size)
+            ]
+
+        else:
+            return [
+                tuple(sub_tensor[i : i + split_size] for sub_tensor in data)
+                for i in range(0, full_batch_size, split_size)
+            ]
+    else:
+        raise ValueError(f"Unexpected attribute type: {type(data)}")
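An illustration of the splitting cases handled above, using toy tensors (this assumes `_split` is imported from `transformers.generation.utils`, where this patch defines it; it is an internal helper, so the import is for demonstration only):

import torch
from transformers.generation.utils import _split  # internal helper added in this patch

# Case 1: a plain tensor is sliced along the batch dimension.
logits = torch.randn(6, 10)
parts = _split(logits, full_batch_size=6, split_size=2)
print(len(parts), parts[0].shape)  # 3 torch.Size([2, 10])

# Case 3: a tuple of tuples (past_key_values-like) is sliced tensor by tensor.
past = ((torch.zeros(6, 4), torch.zeros(6, 4)),)
parts = _split(past, full_batch_size=6, split_size=3)
print(len(parts), parts[0][0][0].shape)  # 2 torch.Size([3, 4])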
+
+
+def _split_model_inputs(
+    model_input: Union[ModelOutput, Dict], split_size: int, full_batch_size: int
+) -> List[Union[ModelOutput, Dict]]:
+    """
+    Split a ModelOutput object (or its subclasses) or Dict into a list of same-class objects based on a specified split
+    size. The input object is dict when it was prepared for forward pass and ModelOutput when it was returned from
+    previous forward pass.
+    """
+    # Edge case: if model_input is None, return a list of Nones
+    # this happens with Whisper where encoder_outputs is None
+    if model_input is None:
+        return [model_input] * (full_batch_size // split_size)
+    # Infer the class from the object
+    model_output_cls = type(model_input)
+    if (full_batch_size % split_size) != 0:
+        raise ValueError("`full_batch_size` must be divisible by `split_size`")
+
+    if split_size > full_batch_size:
+        raise ValueError("`split_size` must be smaller or equal to `full_batch_size`")
+
+    # Helper function to split tensors or tuples of tensors
+
+    # Find all the dataclass fields (e.g., last_hidden_state, pooler_output etc.) and split them
+    keys = (
+        model_input.__dataclass_fields__.keys() if hasattr(model_input, "__dataclass_fields__") else model_input.keys()
+    )
+    # We only keep keys that are in the model_input
+    keys = [k for k in keys if k in model_input]
+    # Here we can have four types of values: tensors, tuples of tensors and booleans, and encoder_outputs which is a
+    # ModelOutput object.
+    # bool should not be split but replicated for each split
+    bool_keys = [k for k in keys if isinstance(model_input[k], bool)]
+    non_bool_keys = [k for k in keys if not isinstance(model_input[k], bool) and not k == "encoder_outputs"]
+
+    # we split the tensors and tuples of tensors
+    data_split_list = [
+        {k: _split(model_input[k], full_batch_size, split_size)[i] for k in non_bool_keys}
+        for i in range(full_batch_size // split_size)
+    ]
+    # bool values are the same and replicated for each split
+    bool_data = {k: model_input[k] for k in bool_keys}
+    # encoder_outputs is a ModelOutput object and should be split by its own
+    if "encoder_outputs" in model_input:
+        encoder_outputs_split = _split_model_inputs(model_input["encoder_outputs"], split_size, full_batch_size)
+        data_split_list = [
+            {**data_split, "encoder_outputs": encoder_outputs_split[i]} for i, data_split in enumerate(data_split_list)
+        ]
+
+    # Convert each dictionary in the list to an object of the inferred class
+    split_model_inputs: List[Union[ModelOutput, Dict]] = [
+        model_output_cls(**data_split, **bool_data) for data_split in data_split_list
+    ]
+
+    return split_model_inputs
+
+
+def stack_model_outputs(model_outputs: List[ModelOutput]) -> ModelOutput:
+    """
+    Stack a list of ModelOutput objects (or its subclasses) along the batch_size dimension. The function infers the
+    specific ModelOutput subclass from the list provided.
+    """
+    if not model_outputs:
+        raise ValueError("Input list is empty.")
+
+    # Infer the class from the first object in the list
+    model_output_cls = type(model_outputs[0])
+
+    # Ensure all objects are of the same type
+    if not all(isinstance(obj, model_output_cls) for obj in model_outputs):
+        raise ValueError("All elements in the list should be of the same type.")
+
+    # Helper function to concat tensors or tuples of tensors
+    def _concat(data):
+        """
+        Reverse of `_split` function above.
+        """
+        if any(data is None for data in data):
+            return None
+        if isinstance(data[0], torch.Tensor):
+            return torch.cat(data, dim=0)
+        elif isinstance(data[0], tuple):
+            # If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example)
+            if isinstance(data[0][0], tuple):
+                return tuple(
+                    tuple(torch.cat([attr[i][j] for attr in data], dim=0) for j in range(len(data[0][0])))
+                    for i in range(len(data[0]))
+                )
+            else:
+                return tuple(torch.cat([attr[i] for attr in data], dim=0) for i in range(len(data[0])))
+        elif isinstance(data[0], (int, float)):
+            # If the elements are integers or floats, return a tensor
+            return torch.tensor(data)
+        else:
+            raise ValueError(f"Unexpected attribute type: {type(data[0])}")
+
+    # Use a dictionary comprehension to gather attributes from all objects and concatenate them
+    concatenated_data = {
+        k: _concat([getattr(model_output, k) for model_output in model_outputs])
+        for k in model_output_cls.__dataclass_fields__.keys()
+    }
+
+    # Return a new object of the inferred class with the concatenated attributes
+    return model_output_cls(**concatenated_data)
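The two helpers are designed to be inverses of each other along the batch dimension. A small round-trip check, assuming both are imported from `transformers.generation.utils` (internal functions, shown here purely for illustration) and using an arbitrary `CausalLMOutput` as the container:

import torch
from transformers.generation.utils import _split_model_inputs, stack_model_outputs
from transformers.modeling_outputs import CausalLMOutput

# A batched output with batch size 4, split into two sub-batches of 2 and re-stacked.
full = CausalLMOutput(logits=torch.randn(4, 1, 50))
parts = _split_model_inputs(full, split_size=2, full_batch_size=4)
rebuilt = stack_model_outputs(parts)

assert type(rebuilt) is CausalLMOutput
assert torch.equal(rebuilt.logits, full.logits)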
diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index c41bc3b21a4ee3..05f0981dba3714 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -1539,6 +1539,39 @@ def test_contrastive_generate_low_memory(self):
             )
             self.assertListEqual(low_output.tolist(), high_output.tolist())
 
+    def test_beam_search_low_memory(self):
+        # Check that choosing 'low_memory' does not change the model output
+        for model_class in self.all_generative_model_classes:
+            if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
+                self.skipTest("Won't fix: old model with different cache format")
+            if any(
+                model_name in model_class.__name__.lower()
+                for model_name in [
+                    "bloom",
+                    "ctrl",
+                    "gptbigcode",
+                    "transo_xl",
+                    "xlnet",
+                    "cpm",
+                ]
+            ):
+                self.skipTest("May fix in the future: need model-specific fixes")
+            config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=2)
+            # batch_size=1 is ok, but batch_size>1 will cause non-identical output
+
+            config.use_cache = True
+            config.is_decoder = True
+
+            # test output equality of low versus high memory
+            model = model_class(config).to(torch_device).eval()
+
+            low_output = model.generate(input_ids, max_new_tokens=8, num_beams=5, early_stopping=True, low_memory=True)
+
+            high_output = model.generate(
+                input_ids, max_new_tokens=8, num_beams=5, early_stopping=True, low_memory=False
+            )
+            self.assertListEqual(low_output.tolist(), high_output.tolist())
+
     @is_flaky()  # Read NOTE (1) below. If there are API issues, all attempts will fail.
     def test_assisted_decoding_matches_greedy_search(self):
         # This test ensures that the assisted generation does not introduce output changes over greedy search.
@@ -2766,6 +2799,19 @@ def test_transition_scores_group_beam_search_encoder_decoder(self):
 
         self.assertTrue(torch.allclose(transition_scores_sum, outputs.sequences_scores, atol=1e-3))
 
+    def test_beam_search_low_memory(self):
+        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+        model = AutoModelForCausalLM.from_pretrained("gpt2")
+        tokenizer.pad_token_id = tokenizer.eos_token_id
+        model_inputs = tokenizer("I", return_tensors="pt")["input_ids"]
+
+        low_output = model.generate(model_inputs, max_new_tokens=40, num_beams=5, early_stopping=True, low_memory=True)
+
+        high_output = model.generate(
+            model_inputs, max_new_tokens=40, num_beams=5, early_stopping=True, low_memory=False
+        )
+        self.assertListEqual(low_output.tolist(), high_output.tolist())
+
     @slow
     def test_beam_search_example_integration(self):
         # PT-only test: TF doesn't have a BeamSearchScorer
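Finally, a sketch of how the memory saving could be observed in practice; it assumes a CUDA device and an arbitrary GPT-2 checkpoint, and no particular numbers are claimed here since the saving depends on the model, batch size, and number of beams:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").to("cuda")
inputs = tokenizer("I", return_tensors="pt").to("cuda")

for low_memory in (False, True):
    torch.cuda.reset_peak_memory_stats()
    out = model.generate(
        **inputs, max_new_tokens=40, num_beams=5, early_stopping=True, low_memory=low_memory
    )
    # The generated tokens are expected to match in both settings (see the tests above);
    # only the peak activation memory should differ.
    peak_mib = torch.cuda.max_memory_allocated() / 2**20
    print(f"low_memory={low_memory}: peak {peak_mib:.1f} MiB")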