fix bug
fix bug
ming1753 committed May 16, 2024
1 parent 642ab05 commit 16840b1
Showing 2 changed files with 14 additions and 10 deletions.
20 changes: 12 additions & 8 deletions paddle/phi/kernels/fusion/gpu/block_multi_head_attention_kernel.cu
@@ -311,10 +311,12 @@ void DispatchWithDtype(
     max_dec_len_this_time_data =
         GetMaxLen(dev_ctx, seq_lens_decoder, &max_dec_len_tensor, bsz);
   } else {
-    PADDLE_ENFORCE_EQ(max_dec_len_this_time.get().place().GetType(),
-                      phi::AllocationType::CPU,
-                      "max_dec_len_this_time must be on CPU, but Got %s.",
-                      max_dec_len_this_time.get().place());
+    PADDLE_ENFORCE_EQ(
+        max_dec_len_this_time.get().place().GetType(),
+        phi::AllocationType::CPU,
+        errors::InvalidArgument(
+            "The place of input max_dec_len_this_time must be CPU, but got %s.",
+            max_dec_len_this_time.get().place()));
     max_dec_len_this_time_data = *max_dec_len_this_time.get().data<int>();
   }

@@ -327,10 +329,12 @@ void DispatchWithDtype(
     max_enc_len_this_time_data =
         GetMaxLen(dev_ctx, seq_lens_encoder, &max_enc_len_tensor, bsz);
   } else {
-    PADDLE_ENFORCE_EQ(max_enc_len_this_time.get().place().GetType(),
-                      phi::AllocationType::CPU,
-                      "max_enc_len_this_time must be on CPU, but Got %s.",
-                      max_enc_len_this_time.get().place());
+    PADDLE_ENFORCE_EQ(
+        max_enc_len_this_time.get().place().GetType(),
+        phi::AllocationType::CPU,
+        errors::InvalidArgument(
+            "The place of input max_enc_len_this_time must be CPU, but got %s.",
+            max_enc_len_this_time.get().place()));
     max_enc_len_this_time_data = *max_enc_len_this_time.get().data<int>();
   }

4 changes: 2 additions & 2 deletions python/paddle/incubate/nn/functional/blha_get_max_len.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from paddle import _C_ops
-from paddle.framework import LayerHelper, in_dynamic_mode
+from paddle.framework import LayerHelper, in_dynamic_or_pir_mode
 
 
 def blha_get_max_len(seq_lens_encoder, seq_lens_decoder, batch_size):
@@ -41,7 +41,7 @@ def blha_get_max_len(seq_lens_encoder, seq_lens_decoder, batch_size):
             >>> batch_size = paddle.ones(shape=[bsz])
             >>> max_enc_len_this_time, max_dec_len_this_time = paddle.incubate.nn.functional.blha_get_max_len(seq_lens_encoder, seq_lens_decoder, batch_size)
     """
-    if in_dynamic_mode():
+    if in_dynamic_or_pir_mode():
         return _C_ops.blha_get_max_len(
             seq_lens_encoder, seq_lens_decoder, batch_size
         )
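For reference (not part of the diff), a minimal sketch of calling the Python wrapper touched by this commit. The int32 dtype and [bsz] shape of the sequence-length inputs, and the GPU device requirement for the fused op, are assumptions inferred from the docstring excerpt above rather than facts stated by the commit.

    import paddle

    # Assumption: the fused block-attention helper op needs a GPU build of Paddle.
    paddle.device.set_device("gpu")

    bsz = 10
    # Assumption: per-example sequence lengths are int32 tensors of shape [bsz].
    seq_lens_encoder = paddle.randint(0, 64, [bsz], dtype="int32")
    seq_lens_decoder = paddle.randint(0, 64, [bsz], dtype="int32")
    batch_size = paddle.ones(shape=[bsz])

    # With this commit the dynamic branch is also taken under PIR mode, so the
    # same call works in dygraph and in the new static-graph IR.
    max_enc_len_this_time, max_dec_len_this_time = (
        paddle.incubate.nn.functional.blha_get_max_len(
            seq_lens_encoder, seq_lens_decoder, batch_size
        )
    )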
