Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upgrade Roberta tokenizer #1821

Merged
merged 13 commits into from
Mar 28, 2022
3 changes: 2 additions & 1 deletion paddlenlp/transformers/auto/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@
("MBartTokenizer", "mbart"),
("MPNetTokenizer", "mpnet"),
("NeZhaTokenizer", "nezha"),
("RobertaTokenizer", "roberta"),
("RobertaChineseTokenizer", "roberta"),
("RobertaBPETokenizer", "roberta"),
("RoFormerTokenizer", "roformer"),
("ReformerTokenizer", "reformer"),
("SqueezeBertTokenizer", "squeezebert"),
Expand Down
4 changes: 3 additions & 1 deletion paddlenlp/transformers/bert/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,9 @@ def forward(self,
else:
if attention_mask.ndim == 2:
# attention_mask [batch_size, sequence_length] -> [batch_size, 1, 1, sequence_length]
attention_mask = attention_mask.unsqueeze(axis=[1, 2])
attention_mask = attention_mask.unsqueeze(
axis=[1, 2]).astype(paddle.get_default_dtype())
attention_mask = (1.0 - attention_mask) * -1e4
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里和其他模型中ndim==2的行为是一致的不,是已经统一了mask的这个语义了吗

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

还没有统一,这个pr统一了 bert 和 roberta


embedding_output = self.embeddings(
input_ids=input_ids,
Expand Down
Loading