From f63d157c66e2b24ce8b7464dcfadd291e060d2b5 Mon Sep 17 00:00:00 2001
From: yinwei <>
Date: Thu, 23 May 2024 11:32:45 +0800
Subject: [PATCH] update

---
 paddlenlp/transformers/llama/modeling.py | 19 ++++++++++++-------
 1 file changed, 12 insertions(+), 7 deletions(-)

diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py
index 7e6a3607843c..802f12ab4956 100755
--- a/paddlenlp/transformers/llama/modeling.py
+++ b/paddlenlp/transformers/llama/modeling.py
@@ -118,8 +118,7 @@ def _get_interleave_power_of_2(n):
 def build_alibi_tensor(
     bool_attention_mask: Tensor, num_heads: int, dtype: paddle.dtype, tensor_parallel_degree=1
 ) -> Tensor:
-    attention_mask = bool_attention_mask.astype("float32")
-    batch_size, seq_length = attention_mask.shape[0], attention_mask.shape[-1]
+    batch_size, seq_length = bool_attention_mask.shape[0], bool_attention_mask.shape[-1]
     slopes = paddle.to_tensor(_get_interleave(num_heads), dtype="float32")
     alibi = slopes.unsqueeze(axis=[1, 2]) * paddle.arange(seq_length, dtype="float32").unsqueeze(axis=[0, 1]).expand(
         [num_heads, -1, -1]
@@ -1533,15 +1532,21 @@ def forward(
         if position_ids is None:
             position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length))
 
-        attention_mask = self._prepare_decoder_attention_mask(
-            attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
-        )  # [bs, 1, seq_len, seq_len]
+        use_casual_mask = (
+            hasattr(self.config, "use_casual_mask") and self.config.use_casual_mask is True
+        )
+        if use_casual_mask:
+            attention_mask = None
+        else:
+            attention_mask = self._prepare_decoder_attention_mask(
+                attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
+            )  # [bs, 1, seq_len, seq_len]
 
         is_casual = False
 
         if self.config.use_flash_attention and get_env_device() != "gcu":
-            if hasattr(self.config, "use_casual_mask"):
-                is_casual = self.config.use_casual_mask
+            if use_casual_mask:
+                is_casual = True
             else:
                 is_casual = is_casual_mask(attention_mask)
             if get_env_device() != "npu":
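
Note (illustrative, not part of the patch): the point of skipping _prepare_decoder_attention_mask when the config flag use_casual_mask is set is that the dense [bs, 1, seq_len, seq_len] additive mask only re-encodes plain causality, which a flash-attention kernel can apply internally when told the mask is causal. The sketch below shows roughly what that skipped tensor looks like; build_causal_float_mask is a hypothetical stand-in for illustration, not the PaddleNLP helper.

import paddle


def build_causal_float_mask(batch_size: int, seq_len: int, dtype: str = "float32") -> paddle.Tensor:
    # Hypothetical stand-in for the [bs, 1, seq_len, seq_len] additive mask that
    # _prepare_decoder_attention_mask would otherwise materialize.
    # Lower-triangular pattern: position i may attend to positions <= i.
    allowed = paddle.tril(paddle.ones([seq_len, seq_len], dtype=dtype)).astype("bool")
    # Additive mask: 0.0 where attention is allowed, a large negative value elsewhere
    # (real implementations typically use the dtype's finfo minimum).
    mask = paddle.where(
        allowed,
        paddle.zeros([seq_len, seq_len], dtype=dtype),
        paddle.full([seq_len, seq_len], -1e9, dtype=dtype),
    )
    return mask.unsqueeze(axis=[0, 1]).expand([batch_size, 1, seq_len, seq_len])


# With use_casual_mask enabled, the patched forward() never allocates this tensor;
# the attention path is simply flagged as causal (is_casual=True), saving one
# batch_size * seq_len * seq_len float buffer per forward pass.
print(build_causal_float_mask(batch_size=2, seq_len=8).shape)  # [2, 1, 8, 8]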