
Upgrade Roberta tokenizer #1821

Merged
merged 13 commits into PaddlePaddle:develop from the roberta branch on Mar 28, 2022
Conversation

Contributor

@yingyibiao yingyibiao commented Mar 23, 2022

PR types

Performance optimization

PR changes

Models

Description

  1. Upgrade the Roberta tokenizer to support both "Bert"-style and "BPE"-style tokenizers (see the usage sketch below).
  2. Add an optional argument "output_hidden_states" to return the hidden states of every hidden layer.
  3. Move the community directory one level up in BOS.
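
Below is a minimal usage sketch of points 1 and 2, not code from this PR; the checkpoint names ("roberta-wwm-ext", "roberta-base") and the exact return layout under output_hidden_states=True are assumptions and may differ from what PaddleNLP actually ships.

import paddle
from paddlenlp.transformers import RobertaModel, RobertaTokenizer

# "Bert"-style checkpoint (WordPiece vocab.txt); name assumed.
zh_tok = RobertaTokenizer.from_pretrained("roberta-wwm-ext")
# "BPE"-style checkpoint (vocab.json + merges.txt); name assumed.
en_tok = RobertaTokenizer.from_pretrained("roberta-base")

model = RobertaModel.from_pretrained("roberta-wwm-ext")
input_ids = paddle.to_tensor([zh_tok("欢迎使用PaddleNLP")["input_ids"]])
# With output_hidden_states=True the per-layer hidden states are expected to be
# returned alongside sequence_output and pooled_output (exact packing assumed).
outputs = model(input_ids, output_hidden_states=True)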

@yingyibiao yingyibiao marked this pull request as ready for review March 23, 2022 08:18
Member

@ZeyuChen ZeyuChen left a comment

Please double-check whether the URL changes break downloads for any models.
Also, after the 2.3 release, will we be able to monitor download counts for community models versus our own models?

paddlenlp/utils/downloader.py
@yingyibiao
Contributor Author

"Please double-check whether the URL changes break downloads for any models. Also, after the 2.3 release, will we be able to monitor download counts for community models versus our own models?"

  1. The migration of all model URLs is complete.
  2. Model download counts are not being monitored yet.

attention_mask = attention_mask.unsqueeze(axis=[1, 2])
attention_mask = attention_mask.unsqueeze(
axis=[1, 2]).astype(paddle.get_default_dtype())
attention_mask = (1.0 - attention_mask) * -1e4
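
For context, the new lines above turn a 0/1 padding mask of shape [batch_size, seq_len] into an additive bias of shape [batch_size, 1, 1, seq_len]: 0 where attention is allowed and -1e4 where it is masked, so it can simply be added to the raw attention scores before softmax. A standalone sketch:

import paddle

# [batch_size, seq_len]; 1 marks a real token, 0 marks padding.
attention_mask = paddle.to_tensor([[1, 1, 1, 0, 0]])
# Broadcast over heads and query positions: [batch_size, 1, 1, seq_len].
attention_mask = paddle.unsqueeze(
    attention_mask, axis=[1, 2]).astype(paddle.get_default_dtype())
# 0.0 for real tokens, -1e4 for padding; padded positions ~vanish after softmax.
attention_mask = (1.0 - attention_mask) * -1e4
print(attention_mask)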
Collaborator

Is this consistent with the ndim == 2 behavior in the other models? Has this mask semantics already been unified?

Contributor Author

Not yet fully unified; this PR unifies bert and roberta.

@@ -360,17 +375,26 @@ def forward(self,
attention_mask = paddle.unsqueeze(
attention_mask, axis=[1, 2]).astype(paddle.get_default_dtype())
attention_mask = (1.0 - attention_mask) * -1e4
attention_mask.stop_gradient = True
Collaborator

Why was this removed? attention_mask really can have stop_gradient set, right?

Contributor Author

stop_gradient is already True for attention_mask by default, so this line was redundant.

Collaborator

Did you confirm this by printing attention_mask?

Contributor Author

attention_mask is a tensor that is unrelated to any parameters, so it has no gradient. (Also verified by printing: stop_gradient=True.)
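
A quick check of this point, assuming the mask is an ordinary (non-parameter) Paddle tensor:

import paddle

# Tensors that are not parameters default to stop_gradient=True in Paddle,
# so the removed `attention_mask.stop_gradient = True` line was a no-op.
mask = paddle.ones([1, 1, 1, 5])
print(mask.stop_gradient)  # True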

sequence_output = encoder_outputs
pooled_output = self.pooler(sequence_output)
return sequence_output, pooled_output
if output_hidden_states:
Collaborator

Also keep a plan in mind for unifying this consistently across all the models that follow, and look into whether this feature request can be made pluggable #1752 (comment)

Contributor Author

OK. roberta currently uses the same scheme as bert.
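
For illustration, a toy version of the per-layer collection pattern such an output_hidden_states flag implies; this is not the actual RoBERTa/BERT code, and details such as layer types and return packing differ:

import paddle
import paddle.nn as nn

class TinyEncoder(nn.Layer):
    """Toy encoder mimicking the output_hidden_states behavior."""

    def __init__(self, num_layers=2, hidden=8):
        super().__init__()
        self.layers = nn.LayerList(
            [nn.Linear(hidden, hidden) for _ in range(num_layers)])

    def forward(self, x, output_hidden_states=False):
        all_hidden_states = [x] if output_hidden_states else None
        for layer in self.layers:
            x = layer(x)
            if output_hidden_states:
                all_hidden_states.append(x)  # keep every layer's output
        return (x, all_hidden_states) if output_hidden_states else x

encoder = TinyEncoder()
out, hidden_states = encoder(paddle.randn([2, 4, 8]), output_hidden_states=True)
print(len(hidden_states))  # embedding input plus one entry per layer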

"roberta-base-ft-cluener2020-chn":
"https://bj.bcebos.com/paddlenlp/models/transformers/community/nosaydomore/uer_roberta_base_finetuned_cluener2020_chinese/vocab.txt",
"roberta-base-chn-extractive-qa":
"https://bj.bcebos.com/paddlenlp/models/transformers/community/nosaydomore/uer_roberta_base_chinese_extractive_qa/vocab.txt",
Collaborator

We need to be clear about whether this affects how existing models are used.

Contributor Author

It has no impact on existing usage.

pretrained_model_name_or_path, *model_args, **kwargs)
else:
return RobertaBPETokenizer.from_pretrained(
pretrained_model_name_or_path, *model_args, **kwargs)
Collaborator

If the goal is to solve the problem that a tokenizer like the Roberta tokenizer, which wraps multiple tokenizers requiring different resource files, has when loading community models via from_pretrained, can we extract a common solution that also works for other tokenizers with multiple different resource files, such as ALBERT? As it stands, the cost for a developer to contribute this kind of tokenizer is still considerably higher than for other tokenizers.

Contributor Author

As discussed last week, this scheme only adds a lightweight wrapper on top of the concrete tokenizer implementations, so the development cost is much lower than the original approach (e.g. ALBERT) and it is less error-prone.

Collaborator

The key point here is to figure out how to generalize the approach instead of only considering RobertaTokenizer, and to provide a scheme that is simple for developers. Ordinary model contributors cannot be expected to know about things like COMMUNITY_MODEL_PREFIX, so how do we let them solve this class of problems as well?

Collaborator

"the development cost is much lower than the original approach (e.g. ALBERT)"

That is mainly because the original implementation did not use __getattribute__ to forward requests to the wrapped tokenizer, which is why the original code looks longer.
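
A hedged sketch of the forwarding idea, not PaddleNLP code: the wrapper stores the concrete tokenizer and routes attribute access to it through __getattribute__, so the wrapper itself stays tiny. FakeBPETokenizer is a hypothetical stand-in used only to exercise the wrapper.

class ForwardingTokenizer:
    """Thin wrapper that delegates everything to the wrapped tokenizer."""

    def __init__(self, inner_tokenizer):
        # Store without going through our own attribute machinery.
        object.__setattr__(self, "_inner", inner_tokenizer)

    def __getattribute__(self, name):
        inner = object.__getattribute__(self, "_inner")
        try:
            return getattr(inner, name)  # forward to the real tokenizer
        except AttributeError:
            return object.__getattribute__(self, name)  # fall back to the wrapper

class FakeBPETokenizer:
    def tokenize(self, text):
        return text.split()

tok = ForwardingTokenizer(FakeBPETokenizer())
print(tok.tokenize("hello world"))  # ['hello', 'world']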

Contributor Author

OK, I will take another look at this later.

Collaborator

Please file issues for all currently known problems and add them to the development plan @yingyibiao

Contributor Author

OK

@@ -88,7 +88,7 @@ def __init__(self,
sentencepiece_model_file,
do_lower_case=False,
remove_space=True,
keep_accents=False,
keep_accents=True,
Collaborator

Why was this default value changed? If it changed, will any examples be affected?

Contributor Author

The corresponding default in HF is True; NLP colleagues reported this issue.
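
For readers unfamiliar with the flag, this is roughly what keep_accents controls in sentencepiece-based tokenizers (illustrative only, not the PaddleNLP implementation): with keep_accents=False, text is NFKD-normalized and combining accent marks are dropped before tokenization.

import unicodedata

def preprocess(text, keep_accents=True):
    # keep_accents=False strips combining accent marks after NFKD normalization.
    if not keep_accents:
        text = unicodedata.normalize("NFKD", text)
        text = "".join(ch for ch in text if not unicodedata.combining(ch))
    return text

print(preprocess("naïve café", keep_accents=True))   # naïve café
print(preprocess("naïve café", keep_accents=False))  # naive cafe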

@yingyibiao yingyibiao merged commit 3351ab0 into PaddlePaddle:develop Mar 28, 2022
@yingyibiao yingyibiao deleted the roberta branch March 28, 2022 10:55
ZeyuChen pushed a commit to ZeyuChen/PaddleNLP that referenced this pull request Apr 17, 2022
* update roberta

* update roberta tokenizer

* update roberta tokenizer

* update

* update