
Commit

update
yinwei committed May 23, 2024
1 parent 533c12d commit f63d157
Showing 1 changed file with 12 additions and 7 deletions.
19 changes: 12 additions & 7 deletions paddlenlp/transformers/llama/modeling.py
@@ -118,8 +118,7 @@ def _get_interleave_power_of_2(n):
 def build_alibi_tensor(
     bool_attention_mask: Tensor, num_heads: int, dtype: paddle.dtype, tensor_parallel_degree=1
 ) -> Tensor:
-    attention_mask = bool_attention_mask.astype("float32")
-    batch_size, seq_length = attention_mask.shape[0], attention_mask.shape[-1]
+    batch_size, seq_length = bool_attention_mask.shape[0], bool_attention_mask.shape[-1]
     slopes = paddle.to_tensor(_get_interleave(num_heads), dtype="float32")
     alibi = slopes.unsqueeze(axis=[1, 2]) * paddle.arange(seq_length, dtype="float32").unsqueeze(axis=[0, 1]).expand(
         [num_heads, -1, -1]
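
The change in build_alibi_tensor drops the float32 cast of the boolean mask: within this hunk the cast tensor is only ever read for its shape, so batch_size and seq_length are now taken from the bool tensor directly. A minimal before/after sketch (the mask shape and values here are made up for illustration, not taken from the repository):

import paddle

# Hypothetical boolean padding mask: 2 sequences of 8 positions.
bool_attention_mask = paddle.ones([2, 8], dtype="bool")

# Old path: cast to float32, then read the shape from the copy.
attention_mask = bool_attention_mask.astype("float32")
batch_size, seq_length = attention_mask.shape[0], attention_mask.shape[-1]

# New path: .shape is identical on the bool tensor, so the extra copy is skipped.
batch_size, seq_length = bool_attention_mask.shape[0], bool_attention_mask.shape[-1]
print(batch_size, seq_length)  # 2 8
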
@@ -1533,15 +1532,21 @@ def forward(
         if position_ids is None:
             position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length))

-        attention_mask = self._prepare_decoder_attention_mask(
-            attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
-        )  # [bs, 1, seq_len, seq_len]
+        is_casual_mask = (
+            True if hasattr(self.config, "use_casual_mask") and self.config.use_casual_mask is True else False
+        )
+        if is_casual_mask:
+            attention_mask = None
+        else:
+            attention_mask = self._prepare_decoder_attention_mask(
+                attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
+            )  # [bs, 1, seq_len, seq_len]

         is_casual = False

         if self.config.use_flash_attention and get_env_device() != "gcu":
-            if hasattr(self.config, "use_casual_mask"):
-                is_casual = self.config.use_casual_mask
+            if is_casual_mask:
+                is_casual = True
             else:
                 is_casual = is_casual_mask(attention_mask)
         if get_env_device() != "npu":
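
The second hunk makes the dense attention mask optional: when the config exposes use_casual_mask=True, _prepare_decoder_attention_mask is skipped, attention_mask becomes None, and the flash-attention branch marks the attention as causal without inspecting a mask. A condensed, standalone sketch of that control flow follows; the names resolve_attention_mask, prepare_mask_fn, mask_is_causal_fn and _Cfg are illustrative stand-ins rather than repository API, and the get_env_device() != "gcu" guard from the diff is omitted for brevity.

def resolve_attention_mask(config, attention_mask, prepare_mask_fn, mask_is_causal_fn):
    # Optional config flag; a missing attribute counts as False, mirroring the
    # hasattr(...) / is True check added in the diff.
    use_casual_mask = getattr(config, "use_casual_mask", False) is True

    if use_casual_mask:
        attention_mask = None  # skip building the dense [bs, 1, seq_len, seq_len] mask
    else:
        attention_mask = prepare_mask_fn(attention_mask)

    is_casual = False
    if getattr(config, "use_flash_attention", False):
        # The flag wins; otherwise fall back to inspecting the prepared mask,
        # as the retained is_casual_mask(attention_mask) call does in the diff.
        is_casual = True if use_casual_mask else mask_is_causal_fn(attention_mask)
    return attention_mask, is_casual


class _Cfg:
    use_casual_mask = True
    use_flash_attention = True


mask, causal = resolve_attention_mask(_Cfg(), [[1, 1, 1]], lambda m: m, lambda m: False)
print(mask, causal)  # None True

One detail worth double-checking in the diff itself: the retained else branch still calls is_casual_mask(attention_mask), and after this change that name is also bound to the new local boolean, so the call would resolve to the local value rather than the helper it appears to target; the sketch above avoids this by giving the flag and the helper different names.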
