Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Transfer the value of stop_gradient for feeding data. #4831

Merged
merged 1 commit into from
Sep 3, 2020

Conversation

Xreki
Copy link
Contributor

@Xreki Xreki commented Sep 2, 2020

图像分类模型中,当设置data_format=“NHWC”时,图像数据依旧按NCHW读进来,之后通过一个tranpose转换成NHWC格式。feed数据都是设置stop_gradient=True,直接传给conv2d,则该conv2d_grad不需要计算input_grad。插入transpose后,feed数据的stop_gradient属性没有传递给tranpose的输出变量,导致了conv2d_grad中产生了多余的计算(计算input_grad)。

这个PR将feed image的stop_gradient属性值传递给transpose的输出变量。

Copy link
Contributor

@wzzju wzzju left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

@Xreki Xreki merged commit bc07a01 into PaddlePaddle:develop Sep 3, 2020
@Xreki Xreki deleted the resnet/fix_stop_gradient branch September 3, 2020 05:59
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants