Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix bce loss half grad functor #7476

Merged
merged 4 commits into from
Feb 11, 2022
Merged

Conversation

guo-ran
Copy link
Contributor

@guo-ran guo-ran commented Feb 11, 2022

BinaryCrossEntropyGradFunctor的half的内部计算应该转成float进行计算,现有计算方式可能会产生nan

@@ -90,17 +90,16 @@ struct BinaryCrossEntropyGradFunctor {

template<>
struct BinaryCrossEntropyGradFunctor<half> {
half eps_;
half one_;
BinaryCrossEntropyGradFunctor() : eps_(__float2half(1e-12)), one_(__float2half(1.f)) {}
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里1e-12超出了half的表示范围

__device__ __forceinline__ half operator()(half input_val, half target_val, half dy_val) const {
half divisor = (one_ - input_val) * input_val;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

当divisor非常小时可能会出现0

__device__ __forceinline__ half operator()(half input_val, half target_val, half dy_val) const {
half divisor = (one_ - input_val) * input_val;
if (divisor < eps_) { divisor = eps_; }
return dy_val * (input_val - target_val) / divisor;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

当divisor为0时会出现nan。

@@ -90,17 +90,16 @@ struct BinaryCrossEntropyGradFunctor {

template<>
struct BinaryCrossEntropyGradFunctor<half> {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BinaryCrossEntropyGradFunctor的half的内部计算应该转成float进行计算

@oneflow-ci-bot oneflow-ci-bot requested review from oneflow-ci-bot and removed request for oneflow-ci-bot February 11, 2022 09:42
@github-actions
Copy link
Contributor

Speed stats:
GPU Name: GeForce GTX 1080 

✔️ OneFlow resnet50 time: 129.0ms (= 12896.1ms / 100, input_shape=[16, 3, 224, 224])
PyTorch resnet50 time: 142.3ms (= 14234.3ms / 100, input_shape=[16, 3, 224, 224])
✔️ Relative speed: 1.10 (= 142.3ms / 129.0ms)

✔️ OneFlow resnet50 time: 75.7ms (= 7573.1ms / 100, input_shape=[8, 3, 224, 224])
PyTorch resnet50 time: 84.5ms (= 8454.2ms / 100, input_shape=[8, 3, 224, 224])
✔️ Relative speed: 1.12 (= 84.5ms / 75.7ms)

OneFlow resnet50 time: 52.9ms (= 10572.5ms / 200, input_shape=[4, 3, 224, 224])
PyTorch resnet50 time: 57.8ms (= 11567.2ms / 200, input_shape=[4, 3, 224, 224])
✔️ Relative speed: 1.09 (= 57.8ms / 52.9ms)

OneFlow resnet50 time: 41.1ms (= 8214.4ms / 200, input_shape=[2, 3, 224, 224])
PyTorch resnet50 time: 49.9ms (= 9980.5ms / 200, input_shape=[2, 3, 224, 224])
✔️ Relative speed: 1.21 (= 49.9ms / 41.1ms)

OneFlow resnet50 time: 36.7ms (= 7345.5ms / 200, input_shape=[1, 3, 224, 224])
PyTorch resnet50 time: 41.7ms (= 8349.5ms / 200, input_shape=[1, 3, 224, 224])
✔️ Relative speed: 1.14 (= 41.7ms / 36.7ms)

✔️ OneFlow resnet50 time: 143.5ms (= 14347.8ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 159.1ms (= 15907.9ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.11 (= 159.1ms / 143.5ms)

OneFlow resnet50 time: 88.1ms (= 8813.0ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 102.3ms (= 10229.1ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.16 (= 102.3ms / 88.1ms)

OneFlow resnet50 time: 61.6ms (= 12321.4ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 75.3ms (= 15060.9ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.22 (= 75.3ms / 61.6ms)

OneFlow resnet50 time: 56.3ms (= 11261.3ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 66.9ms (= 13376.6ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.19 (= 66.9ms / 56.3ms)

OneFlow resnet50 time: 62.3ms (= 12454.2ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 61.1ms (= 12221.3ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 0.98 (= 61.1ms / 62.3ms)

@oneflow-ci-bot oneflow-ci-bot removed their request for review February 11, 2022 11:40
@oneflow-ci-bot oneflow-ci-bot merged commit 9fda259 into master Feb 11, 2022
@oneflow-ci-bot oneflow-ci-bot deleted the dev_fix_bce_half_grad_functor branch February 11, 2022 11:40
marigoold pushed a commit that referenced this pull request Mar 15, 2022
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants