From 219422146e6f8b9ddefee0d62f7b60996b282eb2 Mon Sep 17 00:00:00 2001
From: Yam <40912707+Yam0214@users.noreply.github.com>
Date: Mon, 16 Jan 2023 08:56:14 +0000
Subject: [PATCH] fix generating attention_mask of ernie-m

---
 paddlenlp/transformers/ernie_m/configuration.py | 2 +-
 paddlenlp/transformers/ernie_m/modeling.py      | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/paddlenlp/transformers/ernie_m/configuration.py b/paddlenlp/transformers/ernie_m/configuration.py
index 0db82249afd..87d8ee768e1 100644
--- a/paddlenlp/transformers/ernie_m/configuration.py
+++ b/paddlenlp/transformers/ernie_m/configuration.py
@@ -160,7 +160,7 @@ def __init__(
         max_position_embeddings: int = 514,
         type_vocab_size: int = 16,
         initializer_range: float = 0.02,
-        pad_token_id: int = 0,
+        pad_token_id: int = 1,
         **kwargs
     ):
         super().__init__(pad_token_id=pad_token_id, **kwargs)
diff --git a/paddlenlp/transformers/ernie_m/modeling.py b/paddlenlp/transformers/ernie_m/modeling.py
index 24c70ed9856..69cfaf67571 100644
--- a/paddlenlp/transformers/ernie_m/modeling.py
+++ b/paddlenlp/transformers/ernie_m/modeling.py
@@ -278,7 +278,7 @@ def forward(
         if attention_mask is None:
             attention_mask = paddle.unsqueeze(
-                (input_ids == 0).astype(self.pooler.dense.weight.dtype) * -1e4, axis=[1, 2]
+                (input_ids == self.pad_token_id).astype(self.pooler.dense.weight.dtype) * -1e4, axis=[1, 2]
            )
         if past_key_values is not None:
             batch_size = past_key_values[0][0].shape[0]
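
Note (not part of the patch): a minimal standalone sketch of the mask construction the patch fixes. ERNIE-M pads with token id 1, so comparing against a hardcoded 0 masked the wrong positions; comparing against pad_token_id marks the actual padding. The input ids below are hypothetical and the dtype is fixed to float32 for illustration only.

    import paddle

    pad_token_id = 1  # default set by this patch in ErnieMConfig
    # Hypothetical batch of one sequence; the trailing 1s are padding tokens.
    input_ids = paddle.to_tensor([[0, 2435, 7, 1, 1]])

    # Same construction as the patched modeling.py: a large negative bias on
    # pad positions, shaped [batch, 1, 1, seq_len] so it broadcasts over heads
    # and query positions when added to the attention scores.
    attention_mask = paddle.unsqueeze(
        (input_ids == pad_token_id).astype("float32") * -1e4, axis=[1, 2]
    )
    print(attention_mask.shape)  # [1, 1, 1, 5]

With the old `input_ids == 0` check, the two padding positions (id 1) would have received no bias while a real token with id 0 would have been suppressed.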