diff --git a/src/rules.jl b/src/rules.jl index 03df753..07e5524 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -85,7 +85,7 @@ _abs_deriv(x) = signbit(x) ? -one(x) : one(x) @define_diffrule Base.atan(x, y) = :( $y / ($x^2 + $y^2) ), :( -$x / ($x^2 + $y^2) ) @define_diffrule Base.hypot(x, y) = :( $x / hypot($x, $y) ), :( $y / hypot($x, $y) ) @define_diffrule Base.log(b, x) = :( log($x) * inv(-log($b)^2 * $b) ), :( inv($x) / log($b) ) -@define_diffrule Base.ldexp(x, y) = :( exp2($y) ), :NaN +@define_diffrule Base.ldexp(x, y) = :( oftype($x, exp2($y)) ), :NaN @define_diffrule Base.mod(x, y) = :( z = $x / $y; ifelse(isinteger(z), oftype(float(z), NaN), one(float(z))) ), :( z = $x / $y; ifelse(isinteger(z), oftype(float(z), NaN), -floor(float(z))) ) @define_diffrule Base.rem(x, y) = :( z = $x / $y; ifelse(isinteger(z), oftype(float(z), NaN), one(float(z))) ), :( z = $x / $y; ifelse(isinteger(z), oftype(float(z), NaN), -trunc(float(z))) ) @@ -296,7 +296,7 @@ _abs_deriv(x) = signbit(x) ? -one(x) : one(x) @define_diffrule LogExpFunctions.logmxp1(x) = :((1 - $x) / $x) # binary -@define_diffrule LogExpFunctions.xlogy(x, y) = +@define_diffrule LogExpFunctions.xlogy(x, y) = :(log($y)), :(z = $x / $y; iszero($x) && !isnan($y) ? zero(z) : z) @define_diffrule LogExpFunctions.logaddexp(x, y) = @@ -304,6 +304,6 @@ _abs_deriv(x) = signbit(x) ? -one(x) : one(x) @define_diffrule LogExpFunctions.logsubexp(x, y) = :(z = LogExpFunctions.logsubexp($x, $y); $x > $y ? exp($x - z) : -exp($x - z)), :(z = LogExpFunctions.logsubexp($x, $y); $x > $y ? -exp($y - z) : exp($y - z)) -@define_diffrule LogExpFunctions.xlog1py(x, y) = +@define_diffrule LogExpFunctions.xlog1py(x, y) = :(log1p($y)), :(z = $x / (1 + $y); iszero($x) && !isnan($y) ? zero(z) : z)