update is_casual_mask to use_casual_mask
iosmers committed May 23, 2024
1 parent 92b106f commit 370d2c9
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions paddlenlp/transformers/llama/modeling.py
@@ -1532,10 +1532,10 @@ def forward(
         if position_ids is None:
             position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length))
 
-        is_casual_mask = (
+        use_casual_mask = (
             True if hasattr(self.config, "use_casual_mask") and self.config.use_casual_mask is True else False
         )
-        if is_casual_mask:
+        if use_casual_mask:
             attention_mask = None
         else:
             attention_mask = self._prepare_decoder_attention_mask(
@@ -1545,7 +1545,7 @@
         is_casual = False
 
         if self.config.use_flash_attention and get_env_device() != "gcu":
-            if is_casual_mask:
+            if use_casual_mask:
                 is_casual = True
             else:
                 is_casual = is_casual_mask(attention_mask)
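Note that the last changed line still calls a module-level helper named `is_casual_mask(attention_mask)`. Before this commit, the local boolean of the same name shadowed that helper for the entire `forward` body, so reaching the `else` branch under flash attention would raise `TypeError: 'bool' object is not callable`. Below is a minimal, self-contained sketch of that failure mode; the names and `SimpleNamespace` config are hypothetical stand-ins, not the actual PaddleNLP code, and no paddle dependency is needed to reproduce it:

```python
from types import SimpleNamespace


def is_casual_mask(attention_mask):
    # Stand-in for the module-level helper called in the diff's last line.
    return attention_mask is None


def pick_is_casual_before(config, attention_mask):
    # Old code: assigning to `is_casual_mask` makes it a local variable for
    # the whole function body, shadowing the helper defined above.
    is_casual_mask = getattr(config, "use_casual_mask", False)
    if is_casual_mask:
        return True
    return is_casual_mask(attention_mask)  # TypeError: calls a bool


def pick_is_casual_after(config, attention_mask):
    # New code: the distinct name leaves the module-level helper callable.
    use_casual_mask = getattr(config, "use_casual_mask", False)
    if use_casual_mask:
        return True
    return is_casual_mask(attention_mask)


config = SimpleNamespace(use_casual_mask=False)
try:
    pick_is_casual_before(config, attention_mask=None)
except TypeError as exc:
    print("before the rename:", exc)  # 'bool' object is not callable
print("after the rename:", pick_is_casual_after(config, attention_mask=None))  # True
```

`getattr(config, "use_casual_mask", False)` is shorthand here for the diff's `True if hasattr(...) and ... is True else False` expression; the two behave identically when the attribute is absent or a bool.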
