From 92b106fc3412a7b53a56916bc8127ee8d0660bb3 Mon Sep 17 00:00:00 2001
From: iosmers
Date: Thu, 23 May 2024 13:06:05 +0800
Subject: [PATCH 1/3] update

---
 llm/run_pretrain.py                      |  5 +++++
 paddlenlp/trainer/trainer.py             | 16 +++++++++++++++-
 paddlenlp/transformers/llama/modeling.py | 24 +++++++++++++++++-------
 3 files changed, 37 insertions(+), 8 deletions(-)

diff --git a/llm/run_pretrain.py b/llm/run_pretrain.py
index e58888772a5d..b7a5f9b6d35d 100644
--- a/llm/run_pretrain.py
+++ b/llm/run_pretrain.py
@@ -223,6 +223,10 @@ class ModelArguments:
         default=None,
         metadata={"help": "num_hidden_layers."},
     )
+    use_casual_mask: Optional[bool] = field(
+        default=True,
+        metadata={"help": "whether to use casual mask"},
+    )
 
 
 def create_pretrained_dataset(
@@ -476,6 +480,7 @@ def main():
     config.pp_recompute_interval = model_args.pp_recompute_interval
     config.recompute_use_reentrant = model_args.recompute_use_reentrant
     config.use_recompute = training_args.recompute
+    config.use_casual_mask = model_args.use_casual_mask
 
     config.tensor_parallel_degree = training_args.tensor_parallel_degree
     config.tensor_parallel_rank = training_args.tensor_parallel_rank
diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py
index 01b902478622..7f3825c7f939 100644
--- a/paddlenlp/trainer/trainer.py
+++ b/paddlenlp/trainer/trainer.py
@@ -1892,7 +1892,6 @@ def get_expected_keys(inputs, keys):
                     optimizer._set_broadcast_overlap(True, model)
 
             self.optimizer = optimizer
-
         # pure tesnor parallel mode, no pipeline_parallel, no sharding.
         if (
             not in_pipeline_parallel_mode
@@ -1908,6 +1907,21 @@ def get_expected_keys(inputs, keys):
                 self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer)
             self.optimizer = fleet.distributed_optimizer(self.optimizer)
 
+        # stage1 has v1 and v2 version
+        if in_sharding_parallel_mode and ShardingOption.SHARD_OP in self.args.sharding:
+            if "split_param" in self.args.sharding_parallel_config:
+                if (
+                    hasattr(self.optimizer, "_set_all_gather_overlap_forward")
+                    and "enable_stage1_allgather_overlap" in self.args.sharding_parallel_config
+                ):
+                    self.optimizer._set_all_gather_overlap_forward(True, model)
+            else:
+                if (
+                    hasattr(self.optimizer, "_set_broadcast_overlap")
+                    and "enable_stage1_broadcast_overlap" in self.args.sharding_parallel_config
+                ):
+                    self.optimizer._set_broadcast_overlap(True, model)
+
         return model
 
     def _prepare_input(self, data: Union[paddle.Tensor, Any]) -> Union[paddle.Tensor, Any]:
diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py
index 8f2dd1c36415..802f12ab4956 100755
--- a/paddlenlp/transformers/llama/modeling.py
+++ b/paddlenlp/transformers/llama/modeling.py
@@ -118,8 +118,7 @@ def _get_interleave_power_of_2(n):
 def build_alibi_tensor(
     bool_attention_mask: Tensor, num_heads: int, dtype: paddle.dtype, tensor_parallel_degree=1
 ) -> Tensor:
-    attention_mask = bool_attention_mask.astype("float32")
-    batch_size, seq_length = attention_mask.shape[0], attention_mask.shape[-1]
+    batch_size, seq_length = bool_attention_mask.shape[0], bool_attention_mask.shape[-1]
     slopes = paddle.to_tensor(_get_interleave(num_heads), dtype="float32")
     alibi = slopes.unsqueeze(axis=[1, 2]) * paddle.arange(seq_length, dtype="float32").unsqueeze(axis=[0, 1]).expand(
         [num_heads, -1, -1]
     )
@@ -307,7 +306,7 @@ def is_casual_mask(attention_mask):
 def _make_causal_mask(input_ids_shape, past_key_values_length):
     """
-    Make causal mask used for self-attention
+    Make casual mask used for self-attention
     """
     batch_size, target_length = input_ids_shape  # target_length: seq_len
@@ -1533,12 +1532,23 @@ def forward(
         if position_ids is None:
             position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length))
 
-        attention_mask = self._prepare_decoder_attention_mask(
-            attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
-        )  # [bs, 1, seq_len, seq_len]
+        is_casual_mask = (
+            True if hasattr(self.config, "use_casual_mask") and self.config.use_casual_mask is True else False
+        )
+        if is_casual_mask:
+            attention_mask = None
+        else:
+            attention_mask = self._prepare_decoder_attention_mask(
+                attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
+            )  # [bs, 1, seq_len, seq_len]
+
         is_casual = False
+
         if self.config.use_flash_attention and get_env_device() != "gcu":
-            is_casual = is_casual_mask(attention_mask)
+            if is_casual_mask:
+                is_casual = True
+            else:
+                is_casual = is_casual_mask(attention_mask)
             if get_env_device() != "npu":
                 if is_casual and alibi is None:
                     attention_mask = None
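Editorial note on the trainer.py hunk in PATCH 1/3: pulled out of Trainer._wrap_model into a hypothetical standalone helper, the new stage1 branch reads as the sketch below. "split_param" in sharding_parallel_config selects sharding stage1 v2 where, as the method name suggests, the parameter all-gather is overlapped with the forward pass; the original stage1 v1 path instead overlaps the post-step parameter broadcast. The helper name is invented for illustration; the optimizer methods and config keys are the ones used in the diff above.

# Sketch only, not part of the patch: the stage1 overlap selection as a
# free-standing helper. `optimizer` and `model` are the objects the real
# Trainer already holds; `sharding_parallel_config` is the parsed option set.
def _enable_stage1_overlap(optimizer, model, sharding_parallel_config):
    if "split_param" in sharding_parallel_config:
        # stage1 v2: parameters are split, so the all-gather of parameters
        # can be overlapped with the next forward pass.
        if hasattr(optimizer, "_set_all_gather_overlap_forward") and (
            "enable_stage1_allgather_overlap" in sharding_parallel_config
        ):
            optimizer._set_all_gather_overlap_forward(True, model)
    else:
        # stage1 v1: whole parameters are broadcast after the optimizer step,
        # and that broadcast is what gets overlapped.
        if hasattr(optimizer, "_set_broadcast_overlap") and (
            "enable_stage1_broadcast_overlap" in sharding_parallel_config
        ):
            optimizer._set_broadcast_overlap(True, model)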
From 370d2c908c158b18dca69850bdaef4cb80cfedad Mon Sep 17 00:00:00 2001
From: iosmers
Date: Thu, 23 May 2024 13:46:48 +0800
Subject: [PATCH 2/3] update is_casual_mask to use_casual_mask

---
 paddlenlp/transformers/llama/modeling.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py
index 802f12ab4956..49fd0236f45e 100755
--- a/paddlenlp/transformers/llama/modeling.py
+++ b/paddlenlp/transformers/llama/modeling.py
@@ -1532,10 +1532,10 @@ def forward(
         if position_ids is None:
             position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length))
 
-        is_casual_mask = (
+        use_casual_mask = (
             True if hasattr(self.config, "use_casual_mask") and self.config.use_casual_mask is True else False
         )
-        if is_casual_mask:
+        if use_casual_mask:
             attention_mask = None
         else:
             attention_mask = self._prepare_decoder_attention_mask(
@@ -1545,7 +1545,7 @@ def forward(
         is_casual = False
 
         if self.config.use_flash_attention and get_env_device() != "gcu":
-            if is_casual_mask:
+            if use_casual_mask:
                 is_casual = True
             else:
                 is_casual = is_casual_mask(attention_mask)
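Why the rename in PATCH 2/3 is more than cosmetic: in PATCH 1/3 the new local boolean is_casual_mask shadowed the module-level helper is_casual_mask(attention_mask), so the surviving call `is_casual = is_casual_mask(attention_mask)` on the flash-attention path resolved to the bool rather than the function and would fail at runtime whenever the flag was off. A minimal reproduction follows; the helper body here is a simplified stand-in, not the real implementation.

# Minimal reproduction of the shadowing problem fixed by PATCH 2/3.
def is_casual_mask(attention_mask):
    return attention_mask is None  # stand-in; the real helper inspects the mask

def forward_as_in_patch_1(attention_mask):
    is_casual_mask = False                  # local bool shadows the helper above
    if is_casual_mask:
        return True
    return is_casual_mask(attention_mask)   # TypeError: 'bool' object is not callable

def forward_as_in_patch_2(attention_mask):
    use_casual_mask = False                 # distinct name, no shadowing
    if use_casual_mask:
        return True
    return is_casual_mask(attention_mask)   # resolves to the module-level helper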
From 594a05051899dda932e1314a73cb66db2711be3b Mon Sep 17 00:00:00 2001
From: iosmers
Date: Thu, 23 May 2024 20:01:25 +0800
Subject: [PATCH 3/3] update by environment

---
 llm/run_pretrain.py                      |  5 -----
 paddlenlp/transformers/llama/modeling.py | 10 +++++++---
 2 files changed, 7 insertions(+), 8 deletions(-)

diff --git a/llm/run_pretrain.py b/llm/run_pretrain.py
index b7a5f9b6d35d..e58888772a5d 100644
--- a/llm/run_pretrain.py
+++ b/llm/run_pretrain.py
@@ -223,10 +223,6 @@ class ModelArguments:
         default=None,
         metadata={"help": "num_hidden_layers."},
     )
-    use_casual_mask: Optional[bool] = field(
-        default=True,
-        metadata={"help": "whether to use casual mask"},
-    )
 
 
 def create_pretrained_dataset(
@@ -480,7 +476,6 @@ def main():
     config.pp_recompute_interval = model_args.pp_recompute_interval
     config.recompute_use_reentrant = model_args.recompute_use_reentrant
     config.use_recompute = training_args.recompute
-    config.use_casual_mask = model_args.use_casual_mask
 
     config.tensor_parallel_degree = training_args.tensor_parallel_degree
     config.tensor_parallel_rank = training_args.tensor_parallel_rank
diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py
index 49fd0236f45e..1e4f2577ba40 100755
--- a/paddlenlp/transformers/llama/modeling.py
+++ b/paddlenlp/transformers/llama/modeling.py
@@ -115,6 +115,11 @@ def _get_interleave_power_of_2(n):
     )
 
 
+def get_use_casual_mask():
+    """Get the value of the 'USE_CASUAL_MASK' environment variable."""
+    return os.getenv("USE_CASUAL_MASK", "False")
+
+
 def build_alibi_tensor(
     bool_attention_mask: Tensor, num_heads: int, dtype: paddle.dtype, tensor_parallel_degree=1
 ) -> Tensor:
@@ -1532,9 +1537,8 @@ def forward(
         if position_ids is None:
             position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length))
 
-        use_casual_mask = (
-            True if hasattr(self.config, "use_casual_mask") and self.config.use_casual_mask is True else False
-        )
+        use_casual_mask = get_use_casual_mask()
+
        if use_casual_mask:
             attention_mask = None
         else:
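Note on the environment toggle added in PATCH 3/3: os.getenv returns a string, and a non-empty string such as "False" is truthy in Python, so with get_use_casual_mask() exactly as shown above the branch `if use_casual_mask:` is taken regardless of the variable's value. A boolean parse along the following lines is presumably what is intended; this is a hedged sketch, not the merged upstream implementation.

import os

# Sketch only: turn USE_CASUAL_MASK into a real boolean instead of returning
# the raw string from os.getenv.
def get_use_casual_mask():
    """Return True only when the USE_CASUAL_MASK environment variable is "True"."""
    return os.getenv("USE_CASUAL_MASK", "False") == "True"

# Example usage: export USE_CASUAL_MASK=True before launching llm/run_pretrain.py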