From 2b13c686e8c39c6e5f357ff82bcaa123f4c9bd29 Mon Sep 17 00:00:00 2001
From: GuoxiaWang
Date: Fri, 10 May 2024 11:26:29 +0800
Subject: [PATCH] support sparse mask for sft

---
 llm/argument.py                          |  1 +
 llm/data.py                              | 24 ++++
 llm/finetune_generation.py               | 19 ++++-
 llm/llama/sft_argument.json              |  8 +-
 paddlenlp/datasets/intokens_dataset.py   | 79 +++++++++++++++++++
 paddlenlp/trainer/trainer.py             | 16 ++++
 paddlenlp/transformers/llama/modeling.py | 39 +++++++--
 .../transformers/tokenizer_utils_base.py | 11 +++
 8 files changed, 183 insertions(+), 14 deletions(-)

diff --git a/llm/argument.py b/llm/argument.py
index fcec69a93de..efe2729fe89 100644
--- a/llm/argument.py
+++ b/llm/argument.py
@@ -48,6 +48,7 @@ class DataArgument:
     dataset_name_or_path: str = field(default=None, metadata={"help": "Name or path for dataset"})
     task_name: str = field(default=None, metadata={"help": "Additional name to select a more specific task."})
     zero_padding: bool = field(default=False, metadata={"help": "Whether to use Zero Padding data stream"})
+    use_sparse_mask: bool = field(default=False, metadata={"help": "Whether to use Column Sparse Mask"})
     src_length: int = field(default=1024, metadata={"help": "The maximum length of source(context) tokens."})
     max_length: int = field(
         default=2048,
diff --git a/llm/data.py b/llm/data.py
index 6a38a709604..cff948d0bea 100644
--- a/llm/data.py
+++ b/llm/data.py
@@ -188,6 +188,30 @@ def convert_example_common(example, tokenizer, data_args, is_test=True, intokens
     return features


+def convert_example_using_sparse_mask(example, tokenizer, data_args, is_test=True, intokens=False):
+    if tokenizer.chat_template is not None:
+        raise NotImplementedError
+        #return convert_rounds_example_common(example, tokenizer, data_args, is_test, intokens)
+
+    tokenized_source, tokenized_target_input_ids = tokenize_example(tokenizer, example, data_args)
+    if is_test:
+        return {
+            **tokenized_source,
+            "labels": tokenized_target_input_ids,
+        }
+    else:
+        input_ids = tokenized_source["input_ids"] + tokenized_target_input_ids
+        source_length = len(tokenized_source["input_ids"])
+        labels = [-100] * source_length + input_ids[source_length:]
+        # shift input_ids and labels
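+        # NOTE: shifting per example here (before packing) keeps every label inside its own example
+        # once InTokensSparseMaskMapDataset concatenates several examples into one sequence, so the
+        # last token of one packed example is never supervised to predict the first token of the next.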
+        input_ids, labels = input_ids[:-1], labels[1:]
+        seq_length = len(input_ids)
+        features = {"input_ids": input_ids, "labels": labels}
+        if "position_ids" in tokenized_source:
+            features["position_ids"] = list(range(seq_length))
+
+        return features
+
 def convert_rounds_example_common(example, tokenizer, data_args, is_test=True, intokens=False):
     """convert multi-rounds conversation example

diff --git a/llm/finetune_generation.py b/llm/finetune_generation.py
index d9a54a0e622..76213654065 100644
--- a/llm/finetune_generation.py
+++ b/llm/finetune_generation.py
@@ -35,7 +35,7 @@
 )
 from paddlenlp.data import DataCollatorForSeq2Seq
-from paddlenlp.datasets import InTokensIterableDataset, InTokensMapDataset, load_dataset
+from paddlenlp.datasets import InTokensIterableDataset, InTokensMapDataset, InTokensSparseMaskMapDataset, load_dataset
 from paddlenlp.metrics import BLEU, Rouge1, Rouge2, RougeL
 from paddlenlp.peft import LoRAConfig, LoRAModel, PrefixConfig, PrefixModelForCausalLM
 from paddlenlp.trainer import PdArgumentParser, get_last_checkpoint
@@ -182,6 +182,7 @@ def neft_post_hook(module, input, output):

     # Load tokenizer & dataset
     tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, from_aistudio=model_args.from_aistudio)
+    #tokenizer.padding_side = "right"
     # init chat_template for tokenizer
     init_chat_template(tokenizer, model_args.model_name_or_path, data_args.chat_template)

@@ -315,7 +316,11 @@ def neft_post_hook(module, input, output):
             )
             train_ds = train_ds.skip(consumed_samples)

-    if training_args.pipeline_parallel_degree > 1:
+    if data_args.use_sparse_mask:
+        from data import convert_example_using_sparse_mask
+
+        trans_func = partial(convert_example_using_sparse_mask, tokenizer=tokenizer, data_args=data_args)
+        print('convert_example_using_sparse_mask')
+    elif training_args.pipeline_parallel_degree > 1:
         from data import convert_example_common

         trans_func = partial(convert_example_common, tokenizer=tokenizer, data_args=data_args)
@@ -350,11 +355,14 @@ def neft_post_hook(module, input, output):
         else None
     )
     if data_args.zero_padding:
-        if data_args.lazy:
+        if data_args.use_sparse_mask:
+            intoken_dataset = InTokensSparseMaskMapDataset
+        elif data_args.lazy:
             intoken_dataset = InTokensIterableDataset
         else:
             intoken_dataset = InTokensMapDataset
         logger.info("Creating Zero Padding Data Stream. This may take a few minutes.")
+        print('InTokensSparseMaskMapDataset', intoken_dataset)
         train_ds = (
             intoken_dataset(
                 train_ds,
@@ -471,6 +479,10 @@ def compute_metrics_do_generation(eval_preds):
     else:
         metrics = compute_metrics

+    padding = "max_length"
+    max_length = data_args.max_length
+    training_args.max_length = max_length
+    return_attention_mask = False if data_args.use_sparse_mask else True
     trainer = CausalLMTrainer(
         model=model,
         args=training_args,
@@ -484,6 +496,7 @@ def compute_metrics_do_generation(eval_preds):
             padding=padding,
             max_label_length=max_length,
             return_tensors="np",
+            return_attention_mask=return_attention_mask,
         ),
         do_generation=data_args.eval_with_do_generation,
         callbacks=[InTokensIterDatasetCallback()] if isinstance(train_ds, InTokensIterableDataset) else None,
diff --git a/llm/llama/sft_argument.json b/llm/llama/sft_argument.json
index 487074cb12b..ae16758d097 100644
--- a/llm/llama/sft_argument.json
+++ b/llm/llama/sft_argument.json
@@ -26,7 +26,7 @@
     "save_total_limit": 1,
     "tensor_parallel_degree": 4,
     "pipeline_parallel_degree": 1,
-    "intokens": true,
-    "zero_padding": false,
-    "use_flash_attention": false
-  }
\ No newline at end of file
+    "zero_padding": true,
+    "use_sparse_mask": true,
+    "use_flash_attention": true
+  }
diff --git a/paddlenlp/datasets/intokens_dataset.py b/paddlenlp/datasets/intokens_dataset.py
index 795d82d93e8..bca34d7bcfa 100644
--- a/paddlenlp/datasets/intokens_dataset.py
+++ b/paddlenlp/datasets/intokens_dataset.py
@@ -57,6 +57,44 @@ def _pad_batch_records(cls, batch_records):
             batched_features["position_ids"] = np.concatenate(batched_features["position_ids"], axis=-1).tolist()
         return batched_features


+class InTokensSparseMask:
+    required_input_keys = ["input_ids", "labels"]
+    required_output_keys = ["input_ids", "labels", "attn_mask_start_row_indices"]
+    # Only the following keys are supported for InTokens. Keys outside of this set will be ignored.
+    supported_input_keys = ["input_ids", "labels", "position_ids"]
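+    # attn_mask_start_row_indices stores, for every position of the packed sequence, the end offset
+    # (exclusive) of the packed example that position belongs to, reshaped to [1, seq_length].
+    # Example: packing two examples of lengths 3 and 2 gives [[3, 3, 3, 5, 5]]. The sparse
+    # flash-attention path uses these offsets, together with causal masking, so tokens only attend
+    # to earlier tokens of their own packed example.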
+
+    @classmethod
+    def _pad_batch_records(cls, batch_records):
+        # Only consider supported input keys
+        input_keys = [key for key in batch_records[0].keys() if key in cls.supported_input_keys]
+
+        # Check required_keys
+        for key in cls.required_input_keys:
+            if key not in input_keys:
+                raise ValueError(f"feature `{key}` is required for InTokensDataset")
+        # Output features must include all required output keys
+        for key in cls.required_output_keys:
+            if key not in input_keys:
+                input_keys.append(key)
+
+        batched_features = {key: [] for key in input_keys}
+        cur_len_so_far = 0
+        for record in batch_records:
+            batched_features["input_ids"].extend(record["input_ids"])
+            batched_features["labels"].extend(record["labels"])
+            seq_length = len(record["input_ids"])
+            cur_len_so_far += seq_length
+            batched_features["attn_mask_start_row_indices"].extend(seq_length * [cur_len_so_far])
+            # NOTE: position_ids is optional and not required by every model
+            # We append instead of extend here to accommodate 2D position ids
+            if "position_ids" in record:
+                batched_features["position_ids"].append(record["position_ids"])
+        # convert to 2-D [num_head(1), seq_length]
+        batched_features["attn_mask_start_row_indices"] = np.array(batched_features["attn_mask_start_row_indices"], dtype=np.int32).reshape([1, -1])
+        if "position_ids" in batched_features:
+            # Accommodate both 1D and 2D position ids
+            batched_features["position_ids"] = np.concatenate(batched_features["position_ids"], axis=-1).tolist()
+        return batched_features
 class InTokensMapDataset(InTokens, Dataset):
     def __init__(self, data, tokenizer, max_length):
@@ -131,3 +169,44 @@ def __iter__(self):
         if batch_records:
             padded_list = self._pad_batch_records(batch_records)
             yield padded_list
+
+class InTokensSparseMaskMapDataset(InTokensSparseMask, Dataset):
+    def __init__(self, data, tokenizer, max_length):
+        self.tokenizer = tokenizer
+        self.max_length = max_length
+        self.new_data = self._create_intokens_data(data)
+
+    def _create_intokens_data(self, data):
+        batch_records, max_len = [], 0
+        cur_len_so_far = 0
+
+        total_data = []
+        for i in range(len(data)):
+            record = data[i]
+            max_len = max(max_len, len(record["input_ids"]))
+            to_append = (cur_len_so_far + len(record["input_ids"])) <= self.max_length
+            if to_append:
+                batch_records.append(record)
+                cur_len_so_far += len(record["input_ids"])
+            else:
+                # exceed max length
+                padded_list = self._pad_batch_records(batch_records)
+                total_data.append(padded_list)
+                # reset
+                batch_records, max_len = [], 0
+                cur_len_so_far = 0
+                # append current data
+                batch_records.append(record)
+                cur_len_so_far += len(record["input_ids"])
+
+        # remaining data
+        if batch_records:
+            padded_list = self._pad_batch_records(batch_records)
+            total_data.append(padded_list)
+        return total_data
+
+    def __getitem__(self, idx):
+        return self.new_data[idx]
+
+    def __len__(self):
+        return len(self.new_data)
diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py
index 298f6e9540a..f8a7c7433e0 100644
--- a/paddlenlp/trainer/trainer.py
+++ b/paddlenlp/trainer/trainer.py
@@ -264,6 +264,8 @@ def __init__(
         self.args = args
         self.is_in_train = False
         # self.do_grad_scaling = args.fp16
+        self._tokens_per_sec_per_card_buffer = []
+        self.skip_first_tokens_per_sec_per_card_buffer = True

         # memory metrics - must set up as early as possible
         self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics)
@@ -1226,6 +1228,7 @@ def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval,

             logs["learning_rate"] = float("{0:.3e}".format(self._get_learning_rate()))
             logs["global_step"] = int(self.state.global_step)
+
             total_train_batch_size = (
                 self.args.train_batch_size * self.args.gradient_accumulation_steps * self.args.dataset_world_size
             )
@@ -1239,6 +1242,13 @@ def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval,
                 )
             )

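+            # NOTE: this metric assumes every sample is packed and padded to exactly args.max_length
+            # tokens (finetune_generation.py sets `training_args.max_length = data_args.max_length`
+            # and pads to max_length), which holds for the zero-padding / sparse-mask data stream;
+            # for variable-length batches it over-counts. The first few logged values are treated as
+            # warm-up and are dropped from the running average further below.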
+            logs["tokens_per_sec_per_card"] = round(
+                self.args.max_length * logs["interval_samples_per_second"] / self.args.world_size, 2
+            )
+
+            self._tokens_per_sec_per_card_buffer.append(logs["tokens_per_sec_per_card"])
+            logs["tokens_per_sec_per_card_average"] = round(np.mean(self._tokens_per_sec_per_card_buffer), 2)
+
             self._total_loss_scalar += tr_loss_scalar
             self._globalstep_last_logged = self.state.global_step
             self._globalstep_last_start_time = time.time()
@@ -1261,6 +1271,10 @@ def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval,

             self.log(logs, **kwargs)

+            if self.skip_first_tokens_per_sec_per_card_buffer and len(self._tokens_per_sec_per_card_buffer) == 5:
+                self._tokens_per_sec_per_card_buffer = []
+                self.skip_first_tokens_per_sec_per_card_buffer = False
+
         metrics = None
         if self.control.should_evaluate:
             if isinstance(self.optimizer, GroupShardedOptimizerStage2) and self.optimizer._broadcast_overlap:
@@ -1283,6 +1297,8 @@ def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval,
             self._save_checkpoint(model, metrics=metrics)
             logger.info(f"{self.runtime_timer.log()}")
             self.control = self.callback_handler.on_save(self.args, self.state, self.control)
+            self._tokens_per_sec_per_card_buffer = []
+            self.skip_first_tokens_per_sec_per_card_buffer = True

     def _get_learning_rate(self):
         return self.optimizer.get_lr()
diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py
index 94d7f5b1ef1..edd1957425b 100644
--- a/paddlenlp/transformers/llama/modeling.py
+++ b/paddlenlp/transformers/llama/modeling.py
@@ -202,6 +202,7 @@ def scaled_dot_product_attention(
     attention_mask,
     output_attentions,
     alibi=None,
+    attn_mask_start_row_indices=None,
     sequence_parallel=False,
     reshard_layer=None,
 ):
@@ -213,7 +214,17 @@ def scaled_dot_product_attention(

         # Torch Flash Attention input [ bz, nhead, seqlen, head_dim]
         version = paddle.version.full_version
-        if version != "0.0.0" and version <= "2.5.2":
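+        # NOTE: this branch assumes the installed Paddle exposes F.flash_attention_with_sparse_mask.
+        # attn_mask_start_row_indices encodes where each packed example ends (see InTokensSparseMask),
+        # so together with is_causal=True attention stays within each packed example; alibi is not
+        # supported on this path (see the assert below).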
+        if attn_mask_start_row_indices is not None:
+            assert alibi is None, "flash_attention_with_sparse_mask not support alibi"
+            attn_output = F.flash_attention_with_sparse_mask(
+                query_states,
+                key_states,
+                value_states,
+                attn_mask_start_row_indices=attn_mask_start_row_indices,
+                is_causal=True,
+            )
+            attn_weights = None
+        elif version != "0.0.0" and version <= "2.5.2":
             if alibi is not None:
                 raise ValueError("Flash Attention doesn't support alibi")
             attn_output, attn_weights = flash_attention(
@@ -808,6 +819,7 @@ def forward(
         output_attentions: bool = False,
         use_cache: bool = False,
         alibi: Optional[paddle.Tensor] = None,
+        attn_mask_start_row_indices: Optional[paddle.Tensor] = None,
     ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]:
         """Input shape: Batch x Time x Channel"""
         # [bs, seq_len, num_head * head_dim] -> [seq_len / n, bs, num_head * head_dim] (n is model parallelism)
@@ -1010,6 +1022,7 @@ def forward(
                 attention_mask,
                 output_attentions,
                 alibi,
+                attn_mask_start_row_indices,
                 self.sequence_parallel,
                 reshard_layer=self.reshard_layer,
                 use_reentrant=self.config.recompute_use_reentrant,
@@ -1023,6 +1036,7 @@ def forward(
                 attention_mask,
                 output_attentions,
                 alibi,
+                attn_mask_start_row_indices,
                 self.sequence_parallel,
                 reshard_layer=self.reshard_layer,
             )
@@ -1077,6 +1091,7 @@ def forward(
         past_key_value: Optional[Tuple[paddle.Tensor]] = None,
         use_cache: Optional[bool] = False,
         alibi: Optional[paddle.Tensor] = None,
+        attn_mask_start_row_indices: Optional[paddle.Tensor] = None,
     ) -> Tuple[paddle.Tensor, Optional[Tuple[paddle.Tensor, paddle.Tensor]]]:
         """
         Args:
@@ -1113,6 +1128,7 @@ def forward(
                 output_attentions,
                 use_cache,
                 alibi,
+                attn_mask_start_row_indices,
                 use_reentrant=self.config.recompute_use_reentrant,
             )
         else:
@@ -1124,6 +1140,7 @@ def forward(
                 output_attentions,
                 use_cache,
                 alibi,
+                attn_mask_start_row_indices,
             )

         if type(outputs) is tuple:
@@ -1396,6 +1413,7 @@ def recompute_training_full(
         past_key_value: Tensor,
         use_cache: bool,
         alibi=None,
+        attn_mask_start_row_indices=None,
     ):
         def create_custom_forward(module):
             def custom_forward(*inputs):
@@ -1412,6 +1430,7 @@ def custom_forward(*inputs):
             past_key_value,
             use_cache,
             alibi,
+            attn_mask_start_row_indices,
             use_reentrant=self.config.recompute_use_reentrant,
         )

@@ -1428,6 +1447,7 @@ def forward(
         output_attentions=False,
         output_hidden_states=None,
         return_dict=False,
+        attn_mask_start_row_indices=None,
         **kwargs,
     ):
         if self.sequence_parallel and use_cache:
@@ -1472,10 +1492,10 @@ def forward(
             inputs_embeds = ScatterOp.apply(inputs_embeds)

         # embed positions
-        if attention_mask is None:
+        if attn_mask_start_row_indices is None and attention_mask is None:
             # [bs, seq_len]
             attention_mask = paddle.ones((batch_size, seq_length_with_past), dtype=paddle.bool)
-        if self.config.alibi:
+        if attn_mask_start_row_indices is None and self.config.alibi:
             alibi = build_alibi_tensor(attention_mask, self.config.num_attention_heads, dtype=inputs_embeds.dtype)
             if self.config.tensor_parallel_degree > 1:
                 block_size = self.config.num_attention_heads // self.config.tensor_parallel_degree
@@ -1494,10 +1514,11 @@ def forward(
         if position_ids is None:
             position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length))

-        attention_mask = self._prepare_decoder_attention_mask(
-            attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
-        )  # [bs, 1, seq_len, seq_len]
-        if self.config.use_flash_attention:
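+        # When attn_mask_start_row_indices is provided (packed zero-padding data), no dense
+        # [bs, 1, seq_len, seq_len] mask is built here; the sparse row indices are forwarded to the
+        # attention layers instead.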
+        if attn_mask_start_row_indices is None:
+            attention_mask = self._prepare_decoder_attention_mask(
+                attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
+            )  # [bs, 1, seq_len, seq_len]
+        if attn_mask_start_row_indices is None and self.config.use_flash_attention:
             is_casual = is_casual_mask(attention_mask)
             if is_casual and alibi is None:
                 attention_mask = None
@@ -1529,6 +1550,7 @@ def forward(
                     past_key_value,
                     use_cache,
                     alibi=alibi,
+                    attn_mask_start_row_indices=attn_mask_start_row_indices,
                 )
             else:
                 layer_outputs = decoder_layer(
@@ -1539,6 +1561,7 @@ def forward(
                     past_key_value,
                     use_cache,
                     alibi=alibi,
+                    attn_mask_start_row_indices=attn_mask_start_row_indices,
                 )

             # NOTE: clear outdate cache after it has been used for memory saving
@@ -1763,6 +1786,7 @@ def forward(
         output_attentions=None,
         output_hidden_states=None,
         return_dict=None,
+        attn_mask_start_row_indices=None,
     ):
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -1779,6 +1803,7 @@ def forward(
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            attn_mask_start_row_indices=attn_mask_start_row_indices,
         )

         hidden_states = outputs[0]  # [bs, seq_len, dim]
diff --git a/paddlenlp/transformers/tokenizer_utils_base.py b/paddlenlp/transformers/tokenizer_utils_base.py
index 2c3ac240114..af84ad12cdf 100644
--- a/paddlenlp/transformers/tokenizer_utils_base.py
+++ b/paddlenlp/transformers/tokenizer_utils_base.py
@@ -3169,6 +3169,12 @@ def _pad(

             if return_attention_mask:
                 encoded_inputs["attention_mask"] = encoded_inputs["attention_mask"] + [0] * difference
+            if "attn_mask_start_row_indices" in encoded_inputs:
+                full_value = encoded_inputs["attn_mask_start_row_indices"][-1, -1]
+                encoded_inputs["attn_mask_start_row_indices"] = (
+                    np.concatenate([encoded_inputs["attn_mask_start_row_indices"],
+                                    np.full([1, difference], full_value, dtype=np.int32)], axis=-1)
+                )
             if "token_type_ids" in encoded_inputs:
                 encoded_inputs["token_type_ids"] = (
                     encoded_inputs["token_type_ids"] + [self.pad_token_type_id] * difference
@@ -3188,6 +3194,11 @@ def _pad(
         elif self.padding_side == "left":
             if return_attention_mask:
                 encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
+            if "attn_mask_start_row_indices" in encoded_inputs:
+                encoded_inputs["attn_mask_start_row_indices"] = (
+                    np.concatenate([np.zeros([1, difference], dtype=np.int32),
+                                    encoded_inputs["attn_mask_start_row_indices"] + difference], axis=-1)
+                )
             if "token_type_ids" in encoded_inputs:
                 encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[
                     "token_type_ids"