diff --git a/src/rules.jl b/src/rules.jl index f6a976e..47e7576 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -208,7 +208,12 @@ end @define_diffrule NaNMath.log1p(x) = :( inv($x + 1) ) @define_diffrule NaNMath.lgamma(x) = :( SpecialFunctions.digamma($x) ) + # binary # #--------# @define_diffrule NaNMath.pow(x, y) = :( $y * NaNMath.pow($x, ($y - 1)) ), :( NaNMath.pow($x, $y) * NaNMath.log($x) ) +@define_diffrule NaNMath.max(x, y) = :(ifelse(($y > $x) | (signbit($y) < signbit($x)), ifelse(isnan($y), one($x), zero($x)), ifelse(isnan($x), zero($x), one($x)))), + :(ifelse(($y > $x) | (signbit($y) < signbit($x)), ifelse(isnan($y), zero($y), one($y)), ifelse(isnan($x), one($y), zero($y)))) +@define_diffrule NaNMath.min(x, y) = :(ifelse(($y < $x) | (signbit($y) > signbit($x)), ifelse(isnan($y), one($x), zero($x)), ifelse(isnan($x), zero($x), one($x)))), + :(ifelse(($y < $x) | (signbit($y) > signbit($x)), ifelse(isnan($y), zero($y), one($y)), ifelse(isnan($x), one($x), zero($x))))