Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rm nn.Graph.train #5424

Merged
merged 6 commits into from
Jul 8, 2021
Merged

rm nn.Graph.train #5424

merged 6 commits into from
Jul 8, 2021

Conversation

strint
Copy link
Contributor

@strint strint commented Jul 8, 2021

讨论认为把train的控制还都放在nn.Module层面处理,能处理GAN这些部分网络Train、部分网络Predict的情况。

这里总结下利用Module freeze部分网络的方法:

  • nn.Moudle.train(False),修改Module.training标识,控制batchnorm/dropout等的forward执行逻辑;
  • nn.Moudle.requires_grad_(False),修改Module的parameters,使其不需要grad,变成普通tensor,不为其生成grad;
  • optimizer不绑定特定parameter,这样就optimizer就不会更新对应parameter;

这部分对train的控制比较精细、也复杂,所以nn.Graph层面就不做总体控制。后面nn.Graph根据是否add_optimizer判断时训练任务,然后做些检查。

@strint strint requested review from chengtbf and leaves-zwx July 8, 2021 04:39
@strint strint requested a review from oneflow-ci-bot July 8, 2021 04:42
@oneflow-ci-bot oneflow-ci-bot removed their request for review July 8, 2021 05:54
@strint strint requested a review from oneflow-ci-bot July 8, 2021 10:10
@strint strint requested review from oneflow-ci-bot and removed request for oneflow-ci-bot July 8, 2021 15:24
@oneflow-ci-bot oneflow-ci-bot removed their request for review July 8, 2021 17:28
@oneflow-ci-bot oneflow-ci-bot merged commit a94748f into master Jul 8, 2021
@oneflow-ci-bot oneflow-ci-bot deleted the fix/nn_graph/rm_train_func branch July 8, 2021 17:30
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants