add mt5 #382

Merged
merged 13 commits into main on Sep 15, 2022

xiezipeng-ML
Contributor

MT5 release plan for 0.3.0

Discussion: #380

xiezipeng-ML commented Sep 10, 2022

  • add training config
  • add MT5 model
  • add model config
  • add model test
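  • Sync the T5 loader: the MT5 model loads the lm_head weights, while the T5 model shares the embedding-layer weights
  • Generator test: load the Hugging Face weights, then check the generation quality
  • Plan to replace the T5 model part of the IDEA project with the model in this branch, to avoid duplicate model code. This branch serves as the "main" MT5 in libai, and the IDEA deliverable, as a separate project, can use the model from here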
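
To illustrate the loader item above, here is a minimal plain-Python sketch of the difference between a tied output head (T5 reusing the embedding table) and a separate lm_head (MT5). `TinyLMHead` and its fields are hypothetical names for illustration, not LiBai classes, and the numbers are made up.

```python
# Sketch only: plain-Python stand-ins for the embedding and output projection.
# `TinyLMHead` is a hypothetical name, not part of LiBai or OneFlow.


class TinyLMHead:
    def __init__(self, embedding, lm_head=None):
        self.embedding = embedding  # vocab_size x hidden_size
        # T5-style: no separate matrix, reuse the embedding table (tied weights).
        # MT5-style: a separate lm_head matrix loaded from the checkpoint.
        self.tied = lm_head is None
        self.lm_head = embedding if lm_head is None else lm_head

    def logits(self, hidden):
        # One logit per vocabulary entry: dot product of hidden with each row.
        return [sum(h * w for h, w in zip(hidden, row)) for row in self.lm_head]


embedding = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
separate_head = [[0.5, 0.5], [2.0, 0.0], [0.0, -1.0]]

t5_like = TinyLMHead(embedding)                          # shares embedding weights
mt5_like = TinyLMHead(embedding, lm_head=separate_head)  # loads its own lm_head

hidden = [2.0, 3.0]
t5_logits = t5_like.logits(hidden)
mt5_logits = mt5_like.logits(hidden)
```

The point of the sketch is only that a loader for the untied case has one extra tensor to copy, which is why the loader needs a model-type check.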

@CPFLAME @strint

xiezipeng-ML commented Sep 10, 2022

Model unit tests

image

Loader unit tests

image

@xiezipeng-ML

Testing T5 with the generator

Translation task test

image

@xiezipeng-ML xiezipeng-ML requested review from oneflow-ci-bot and removed request for oneflow-ci-bot September 15, 2022 06:41
if attention_mask is not None:
    attention_scores = flow.mul(attention_scores, attention_mask)
    attention_scores = attention_scores - 10000.0 * (1 - attention_mask)
    # TODO(xingyu.liao): graph will occur `where_scalar` errors
Contributor

Could the comment here be updated?

Contributor Author

OK, will do.
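
For readers following this thread, a minimal plain-Python sketch of the multiply-then-subtract masking used in the snippet under review. It assumes a 0/1 mask and uses ordinary lists instead of OneFlow tensors; `mask_scores` and `softmax` are hypothetical helper names, not functions from the PR.

```python
import math


def softmax(row):
    # Numerically stable softmax over one row of scores.
    peak = max(row)
    exps = [math.exp(x - peak) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def mask_scores(scores, mask, penalty=10000.0):
    # Same arithmetic as the reviewed lines:
    #   scores * mask - penalty * (1 - mask)
    # Kept positions (mask == 1) are unchanged; masked positions (mask == 0)
    # become a large negative number, so softmax gives them ~0 probability.
    return [s * m - penalty * (1 - m) for s, m in zip(scores, mask)]


scores = [2.0, 1.0, 3.0]
mask = [1, 1, 0]  # the last position is masked out

masked = mask_scores(scores, mask)
probs = softmax(masked)
```

This arithmetic form avoids a `where`/`masked_fill` style op, which appears to be the motivation behind the TODO in the snippet.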

@@ -75,6 +75,12 @@ def _convert_state_dict(self, flow_state_dict, cfg):
prefix1 + "decoder.final_layer_norm.weight"
)

# Convert MT5's lm_head
if cfg.model_type == "mt5":
Contributor

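One question: this t5_loader.py lives in libai's own files,

but this change appears to specifically support the loader written for MT5 under projects.

Wouldn't it make more sense to put it under the projects/MT5 folder?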

Contributor Author

OK, will do.
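
As a rough illustration of the model-type branch discussed in this thread, here is a hedged sketch of a key-renaming step that copies `lm_head.weight` only for MT5. The function name `convert_keys` and the target key names are assumptions made for the example; they are not taken from `t5_loader.py`.

```python
def convert_keys(src_state_dict, model_type):
    """Copy weights under new names. Target key names are illustrative only."""
    # Both model types need the shared embedding table.
    rename = {"shared.weight": "embedding.word_embeddings.weight"}
    # Only MT5 carries a separate output projection in its checkpoint.
    if model_type == "mt5":
        rename["lm_head.weight"] = "lm_head.weight"
    return {new: src_state_dict[old] for old, new in rename.items()}


# Toy checkpoint with made-up values standing in for tensors.
checkpoint = {"shared.weight": [[0.1, 0.2]], "lm_head.weight": [[0.3, 0.4]]}

t5_out = convert_keys(checkpoint, "t5")
mt5_out = convert_keys(checkpoint, "mt5")
```

Keeping the branch in one function is what ties the shared loader to MT5; moving the MT5-specific part under projects/MT5, as suggested above, would be one way to separate the two.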

)
)
self.init_method(self.weight)
# FIXME(lxy): Fill padding_idx is not supported in nd_sbp right now.
Contributor

The comment here could be cleaned up as well.

@xiezipeng-ML xiezipeng-ML requested review from oneflow-ci-bot and removed request for oneflow-ci-bot September 15, 2022 11:49
@xiezipeng-ML xiezipeng-ML merged commit b9ee884 into main Sep 15, 2022

2 participants