Skip to content

Commit

Permalink
fix bce loss half grad functor (#7476)
Browse files Browse the repository at this point in the history
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
  • Loading branch information
2 people authored and marigoold committed Mar 15, 2022
1 parent dd8b790 commit d229d4d
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions oneflow/user/kernels/binary_cross_entropy_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -90,17 +90,16 @@ struct BinaryCrossEntropyGradFunctor {

template<>
struct BinaryCrossEntropyGradFunctor<half> {
half eps_;
half one_;
BinaryCrossEntropyGradFunctor() : eps_(__float2half(1e-12)), one_(__float2half(1.f)) {}
BinaryCrossEntropyGradFunctor<float> float_functor;
BinaryCrossEntropyGradFunctor() {}
__device__ __forceinline__ half operator()(half input_val, half target_val, half dy_val) const {
half divisor = (one_ - input_val) * input_val;
if (divisor < eps_) { divisor = eps_; }
return dy_val * (input_val - target_val) / divisor;
return __float2half(
float_functor(__half2float(input_val), __half2float(target_val), __half2float(dy_val)));
}
__device__ __forceinline__ half operator()(half input_val, half target_val, half dy_val,
half weight_val) const {
return (*this)(input_val, target_val, dy_val) * weight_val;
return __float2half(float_functor(__half2float(input_val), __half2float(target_val),
__half2float(dy_val), __half2float(weight_val)));
}
};

Expand Down

0 comments on commit d229d4d

Please sign in to comment.