
add fuse_bn_act op #27230

Merged: 6 commits, Sep 23, 2020

Conversation

@zhangting2020 (Contributor) commented Sep 10, 2020

PR types

Function optimization

PR changes

OPs

Describe

This Op performs batch norm on input x, adds the result to input y, and then applies an activation to the sum. We use the cuDNN API to implement this function; the following points should be noted:

This Op will be used in automatic mixed precision (AMP) training of the ResNet model. The diagram below shows the relevant part of the model: the red parts represent the inputs of this Op, and the green parts represent the computation it performs.
[Model diagram: op inputs in red, fused computation in green]
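
To make the semantics concrete, here is a minimal NumPy sketch of what the fused op computes, out = act(batch_norm(x) + y). It illustrates only the math, not the cuDNN-backed kernel; the NHWC layout and ReLU activation are assumptions for illustration.

import numpy as np

def bn_add_act_reference(x, y, gamma, beta, eps=1e-5):
    # Per-channel batch norm statistics over the N, H, W axes (NHWC assumed).
    mean = x.mean(axis=(0, 1, 2))
    var = x.var(axis=(0, 1, 2))
    bn = gamma * (x - mean) / np.sqrt(var + eps) + beta
    # Add the shortcut input y, then apply the activation (ReLU assumed).
    return np.maximum(bn + y, 0.0)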

Performance of ResNet50 AMP Training

Test on V100, CUDA 10.1, cuDNN 7.6, single card, BS=128

  • before: 1015.18 imgs/s
  • after: 1085.98 imgs/s (+6.9%)

Loss and accuracy

Set fuse_bn_add_act=True and train for 63 epochs:
[Loss and accuracy curves over 63 epochs]

@paddle-bot-old commented:

Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@luotao1 (Contributor) previously approved these changes Sep 21, 2020:

LGTM

@wangchaochaohu (Contributor) previously approved these changes Sep 21, 2020:

LGTM

@zhiqiu self-requested a review September 21, 2020 05:45
@zhiqiu (Contributor) previously approved these changes Sep 21, 2020:

LGTM for paddle.fluid.contrib.fused_bn_add_act without core.ops since it is used in static graph only.

Review comment on the AMP op lists:

white_list = {'conv2d', 'matmul', 'mul', 'fused_bn_add_activation'}

Contributor: I think fused_bn_add_activation should be added to the gray_list instead.

Author: done.
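
For context, the list choice matters because of how an AMP pass typically decides dtypes: white-list ops always run in float16, black-list ops always in float32, and gray-list ops follow the dtype of their inputs. The sketch below is illustrative only; the names and structure are assumptions, not Paddle's actual implementation.

def choose_dtype(op_type, input_dtype, white_list, black_list, gray_list):
    # White-list ops are numerically safe and always run in float16.
    if op_type in white_list:
        return 'float16'
    # Black-list ops are numerically sensitive and always keep float32.
    if op_type in black_list:
        return 'float32'
    # Gray-list ops inherit whatever dtype their producers emitted.
    if op_type in gray_list:
        return input_dtype
    # Unlisted ops fall back to the safe default.
    return 'float32'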

Comment on lines +1743 to +1745:

check_variable_and_dtype(x, 'input', ['float16', 'float32', 'float64'],
                         'fused_bn_add_act')
check_variable_and_dtype(y, 'input', ['float16', 'float32', 'float64'],

Contributor: BTW, you have only registered the float16 kernel, so 'float32' and 'float64' are not needed here.

Author: The dtype check is performed at compile time, and limiting the list to float16 would cause the check to fail.
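
A hypothetical usage sketch of the author's point (the shapes here are invented for illustration): when the network is built, the inputs are still declared float32, and the AMP pass inserts the float16 casts only afterwards, so a float16-only check would reject the graph at construction time.

import paddle.fluid as fluid

# At graph-construction time the inputs are declared float32; the AMP
# rewrite pass later casts them to float16 for the registered kernel.
x = fluid.data(name='x', shape=[-1, 56, 56, 64], dtype='float32')
y = fluid.data(name='y', shape=[-1, 56, 56, 64], dtype='float32')
out = fluid.contrib.fused_bn_add_act(x, y)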

Comment on lines 76 to 77:

if in_name != 'X' or in_name != 'Z':
    continue

@wzzju (Contributor) commented Sep 21, 2020: I think the condition test here may be wrong: in_name can never equal both 'X' and 'Z' at once, so this disjunction is always true and the loop continues on every iteration. Maybe the logic below is more understandable.

if in_name not in {'X', 'Z'}:
    continue

The batch_norm condition test can perhaps also be simplified as below:

if src_dtype == core.VarDesc.VarType.FP32 and op.type in {'batch_norm', 'fused_bn_add_activation'}:
    if in_name not in {'X', 'Z'}:
        continue

Author: done.

Comment on lines 108 to 111:

if op.type == 'batch_norm' and out_name != 'Y':
    continue
if op.type == 'fused_bn_add_activation' and out_name != 'Y':
    continue

Contributor: Suggested simplification:

if op.type in {'batch_norm', 'fused_bn_add_activation'} and out_name != 'Y':
    continue

Author: done.

Comment on lines 257 to 260:

const auto *saved_mean_data =
    saved_mean->template data<BatchNormParamType<float>>();
const auto *saved_var_data =
    saved_var->template data<BatchNormParamType<float>>();

Contributor: Please use T instead of float.

Author: done.

@wzzju (Contributor) left a comment:

LGTM.

@TCChenlong (Contributor) left a comment:

LGTM

@zhiqiu self-requested a review September 22, 2020 06:59
@zhangting2020 merged commit 906e7f9 into PaddlePaddle:develop on Sep 23, 2020