diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py
index 01b902478622..7f3825c7f939 100644
--- a/paddlenlp/trainer/trainer.py
+++ b/paddlenlp/trainer/trainer.py
@@ -1892,7 +1892,6 @@ def get_expected_keys(inputs, keys):
                 optimizer._set_broadcast_overlap(True, model)
 
             self.optimizer = optimizer
 
-        # pure tesnor parallel mode, no pipeline_parallel, no sharding.
         if (
             not in_pipeline_parallel_mode
@@ -1908,6 +1907,21 @@ def get_expected_keys(inputs, keys):
             self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer)
             self.optimizer = fleet.distributed_optimizer(self.optimizer)
 
+        # stage1 has v1 and v2 version
+        if in_sharding_parallel_mode and ShardingOption.SHARD_OP in self.args.sharding:
+            if "split_param" in self.args.sharding_parallel_config:
+                if (
+                    hasattr(self.optimizer, "_set_all_gather_overlap_forward")
+                    and "enable_stage1_allgather_overlap" in self.args.sharding_parallel_config
+                ):
+                    self.optimizer._set_all_gather_overlap_forward(True, model)
+            else:
+                if (
+                    hasattr(self.optimizer, "_set_broadcast_overlap")
+                    and "enable_stage1_broadcast_overlap" in self.args.sharding_parallel_config
+                ):
+                    self.optimizer._set_broadcast_overlap(True, model)
+
         return model
 
     def _prepare_input(self, data: Union[paddle.Tensor, Any]) -> Union[paddle.Tensor, Any]:
diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py
index 8f2dd1c36415..1e4f2577ba40 100755
--- a/paddlenlp/transformers/llama/modeling.py
+++ b/paddlenlp/transformers/llama/modeling.py
@@ -115,11 +115,15 @@ def _get_interleave_power_of_2(n):
         )
 
 
+def get_use_casual_mask():
+    """Get the value of the 'USE_CASUAL_MASK' environment variable."""
+    return os.getenv("USE_CASUAL_MASK", "False") == "True"
+
+
 def build_alibi_tensor(
     bool_attention_mask: Tensor, num_heads: int, dtype: paddle.dtype, tensor_parallel_degree=1
 ) -> Tensor:
-    attention_mask = bool_attention_mask.astype("float32")
-    batch_size, seq_length = attention_mask.shape[0], attention_mask.shape[-1]
+    batch_size, seq_length = bool_attention_mask.shape[0], bool_attention_mask.shape[-1]
     slopes = paddle.to_tensor(_get_interleave(num_heads), dtype="float32")
     alibi = slopes.unsqueeze(axis=[1, 2]) * paddle.arange(seq_length, dtype="float32").unsqueeze(axis=[0, 1]).expand(
         [num_heads, -1, -1]
@@ -307,7 +311,7 @@ def is_casual_mask(attention_mask):
 
 def _make_causal_mask(input_ids_shape, past_key_values_length):
     """
-    Make causal mask used for self-attention
+    Make casual mask used for self-attention
     """
     batch_size, target_length = input_ids_shape  # target_length: seq_len
 
@@ -1533,12 +1537,22 @@ def forward(
         if position_ids is None:
             position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length))
 
-        attention_mask = self._prepare_decoder_attention_mask(
-            attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
-        )  # [bs, 1, seq_len, seq_len]
+        use_casual_mask = get_use_casual_mask()
+
+        if use_casual_mask:
+            attention_mask = None
+        else:
+            attention_mask = self._prepare_decoder_attention_mask(
+                attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
+            )  # [bs, 1, seq_len, seq_len]
+
         is_casual = False
+
         if self.config.use_flash_attention and get_env_device() != "gcu":
-            is_casual = is_casual_mask(attention_mask)
+            if use_casual_mask:
+                is_casual = True
+            else:
+                is_casual = is_casual_mask(attention_mask)
             if get_env_device() != "npu":
                 if is_casual and alibi is None:
                     attention_mask = None
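
For the trainer change above: the new block only runs for sharding stage1 (`ShardingOption.SHARD_OP` in `self.args.sharding`) and picks the overlap hook to enable based on the strings found in `sharding_parallel_config`. The sketch below is a minimal, self-contained illustration of that dispatch, not the Trainer itself; the option strings (`split_param`, `enable_stage1_allgather_overlap`, `enable_stage1_broadcast_overlap`) are the ones checked in the diff, while the stand-in optimizer classes and the `enable_stage1_overlap` helper are hypothetical names used only for this example.

```python
class _FakeStage1V2Optimizer:
    """Stand-in for a stage1 v2 (split_param) optimizer exposing the forward all-gather overlap hook."""

    def _set_all_gather_overlap_forward(self, flag, model):
        print(f"all-gather overlap forward -> {flag}")


class _FakeStage1V1Optimizer:
    """Stand-in for a stage1 v1 optimizer exposing the broadcast overlap hook."""

    def _set_broadcast_overlap(self, flag, model):
        print(f"broadcast overlap -> {flag}")


def enable_stage1_overlap(optimizer, model, sharding_parallel_config):
    """Mirror of the dispatch added to Trainer._wrap_model for sharding stage1."""
    if "split_param" in sharding_parallel_config:
        # stage1 v2: parameters are split, so overlap the forward all-gather
        if (
            hasattr(optimizer, "_set_all_gather_overlap_forward")
            and "enable_stage1_allgather_overlap" in sharding_parallel_config
        ):
            optimizer._set_all_gather_overlap_forward(True, model)
    else:
        # stage1 v1: overlap the parameter broadcast instead
        if (
            hasattr(optimizer, "_set_broadcast_overlap")
            and "enable_stage1_broadcast_overlap" in sharding_parallel_config
        ):
            optimizer._set_broadcast_overlap(True, model)


enable_stage1_overlap(_FakeStage1V2Optimizer(), None, "split_param enable_stage1_allgather_overlap")
enable_stage1_overlap(_FakeStage1V1Optimizer(), None, "enable_stage1_broadcast_overlap")
```

In an actual run these option strings arrive through the `sharding` and `sharding_parallel_config` training arguments; the `hasattr` guards keep the call a no-op on optimizer implementations that do not expose the corresponding hook.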
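
For the modeling change: when the `USE_CASUAL_MASK` environment variable is set to `"True"` (the "casual" spelling follows the existing `is_casual_mask` naming in this file), the Llama forward skips building the dense `[bs, 1, seq_len, seq_len]` mask and, on the flash-attention path, simply marks the attention as causal. Below is a minimal sketch of that gate outside the model, assuming the helper compares the variable against the string `"True"`:

```python
import os

# Illustrative toggle: in practice the variable would be exported before
# launching so the model's forward pass sees it.
os.environ["USE_CASUAL_MASK"] = "True"


def get_use_casual_mask():
    """Return True only when USE_CASUAL_MASK is set to the string 'True'."""
    return os.getenv("USE_CASUAL_MASK", "False") == "True"


use_casual_mask = get_use_casual_mask()

if use_casual_mask:
    # No dense decoder mask is materialized; the flash-attention kernel is
    # told the mask is causal (is_casual = True) and gets attention_mask=None.
    attention_mask = None
    is_casual = True
else:
    # Otherwise the explicit decoder mask is built as before the change
    # (placeholder stands in for self._prepare_decoder_attention_mask(...)).
    attention_mask = "dense [bs, 1, seq_len, seq_len] mask"
    is_casual = False

print(use_casual_mask, is_casual, attention_mask)
```

Because the explicit mask is dropped entirely when the variable is set, this path presumably only fits plain left-to-right causal decoding where no padding- or document-specific masking is needed.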