diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py
index 802f12ab495..49fd0236f45 100755
--- a/paddlenlp/transformers/llama/modeling.py
+++ b/paddlenlp/transformers/llama/modeling.py
@@ -1532,10 +1532,10 @@ def forward(
         if position_ids is None:
            position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length))

-        is_casual_mask = (
+        use_casual_mask = (
            True if hasattr(self.config, "use_casual_mask") and self.config.use_casual_mask is True else False
        )
-        if is_casual_mask:
+        if use_casual_mask:
            attention_mask = None
        else:
            attention_mask = self._prepare_decoder_attention_mask(
@@ -1545,7 +1545,7 @@ def forward(
         is_casual = False

        if self.config.use_flash_attention and get_env_device() != "gcu":
-            if is_casual_mask:
+            if use_casual_mask:
                is_casual = True
            else:
                is_casual = is_casual_mask(attention_mask)
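
Why the rename matters: before this change the local boolean `is_casual_mask` shadowed the module-level helper of the same name, so the later call `is_casual = is_casual_mask(attention_mask)` on the non-flash path would try to call a `bool` and raise `TypeError`. Below is a minimal, self-contained sketch of that failure mode and of the fix; the helper body and the `Cfg` class here are illustrative stand-ins, not the actual PaddleNLP implementation.

```python
def is_casual_mask(attention_mask):
    # Stand-in for the module-level helper in modeling.py; the real one
    # inspects the mask contents to decide whether it is purely causal.
    return attention_mask is None


def forward_before_fix(config, attention_mask):
    # The local flag reuses the helper's name, shadowing it in this scope.
    is_casual_mask = bool(getattr(config, "use_casual_mask", False))

    if is_casual_mask:
        return True
    # TypeError: 'bool' object is not callable -- the helper is unreachable here.
    return is_casual_mask(attention_mask)


def forward_after_fix(config, attention_mask):
    # Renaming the flag to `use_casual_mask` leaves the helper callable.
    use_casual_mask = bool(getattr(config, "use_casual_mask", False))

    if use_casual_mask:
        return True
    return is_casual_mask(attention_mask)


if __name__ == "__main__":
    class Cfg:
        use_casual_mask = False

    mask = object()  # stand-in for a real attention-mask tensor
    print(forward_after_fix(Cfg(), mask))   # works: falls through to the helper
    try:
        forward_before_fix(Cfg(), mask)
    except TypeError as exc:
        print("before the fix:", exc)       # 'bool' object is not callable
```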