From c810f8bde24a88fdc3cd431b7ec8a8d6cb9e0b64 Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Thu, 3 Nov 2022 11:44:25 -0400 Subject: [PATCH 1/4] Make Float32 stable for both arguments --- src/rules.jl | 40 ++++++++++++++++++++-------------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/src/rules.jl b/src/rules.jl index 03df753..fad8603 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -85,11 +85,11 @@ _abs_deriv(x) = signbit(x) ? -one(x) : one(x) @define_diffrule Base.atan(x, y) = :( $y / ($x^2 + $y^2) ), :( -$x / ($x^2 + $y^2) ) @define_diffrule Base.hypot(x, y) = :( $x / hypot($x, $y) ), :( $y / hypot($x, $y) ) @define_diffrule Base.log(b, x) = :( log($x) * inv(-log($b)^2 * $b) ), :( inv($x) / log($b) ) -@define_diffrule Base.ldexp(x, y) = :( exp2($y) ), :NaN +@define_diffrule Base.ldexp(x, y) = :( oftype(float($x), exp2($y)) ), :(oftype(float($x), NaN)) @define_diffrule Base.mod(x, y) = :( z = $x / $y; ifelse(isinteger(z), oftype(float(z), NaN), one(float(z))) ), :( z = $x / $y; ifelse(isinteger(z), oftype(float(z), NaN), -floor(float(z))) ) @define_diffrule Base.rem(x, y) = :( z = $x / $y; ifelse(isinteger(z), oftype(float(z), NaN), one(float(z))) ), :( z = $x / $y; ifelse(isinteger(z), oftype(float(z), NaN), -trunc(float(z))) ) -@define_diffrule Base.rem2pi(x, r) = :( 1 ), :NaN +@define_diffrule Base.rem2pi(x, r) = :( 1 ), :(oftype(float($x), NaN)) @define_diffrule Base.max(x, y) = :( $x > $y ? one($x) : zero($x) ), :( $x > $y ? zero($y) : one($y) ) @define_diffrule Base.min(x, y) = :( $x > $y ? zero($x) : one($x) ), :( $x > $y ? one($y) : zero($y) ) @@ -200,36 +200,36 @@ _abs_deriv(x) = signbit(x) ? -one(x) : one(x) # for forward-mode and reverse-mode derivatives for complex inputs @define_diffrule SpecialFunctions.besselj(ν, x) = - :NaN, :( (SpecialFunctions.besselj($ν - 1, $x) - SpecialFunctions.besselj($ν + 1, $x)) / 2 ) + :(oftype($x, NaN)), :( (SpecialFunctions.besselj($ν - 1, $x) - SpecialFunctions.besselj($ν + 1, $x)) / 2 ) @define_diffrule SpecialFunctions.besseljx(ν, x) = - :NaN, :( (SpecialFunctions.besseljx($ν - 1, $x) - SpecialFunctions.besseljx($ν + 1, $x)) / 2 ) + :(oftype($x, NaN)), :( (SpecialFunctions.besseljx($ν - 1, $x) - SpecialFunctions.besseljx($ν + 1, $x)) / 2 ) @define_diffrule SpecialFunctions.besseli(ν, x) = - :NaN, :( (SpecialFunctions.besseli($ν - 1, $x) + SpecialFunctions.besseli($ν + 1, $x)) / 2 ) + :(oftype($x, NaN)), :( (SpecialFunctions.besseli($ν - 1, $x) + SpecialFunctions.besseli($ν + 1, $x)) / 2 ) @define_diffrule SpecialFunctions.besselix(ν, x) = - :NaN, :( (SpecialFunctions.besselix($ν - 1, $x) + SpecialFunctions.besselix($ν + 1, $x)) / 2 - sign($x) * SpecialFunctions.besselix($ν, $x) ) + :(oftype($x, NaN)), :( (SpecialFunctions.besselix($ν - 1, $x) + SpecialFunctions.besselix($ν + 1, $x)) / 2 - sign($x) * SpecialFunctions.besselix($ν, $x) ) @define_diffrule SpecialFunctions.bessely(ν, x) = - :NaN, :( (SpecialFunctions.bessely($ν - 1, $x) - SpecialFunctions.bessely($ν + 1, $x)) / 2 ) + :(oftype($x, NaN)), :( (SpecialFunctions.bessely($ν - 1, $x) - SpecialFunctions.bessely($ν + 1, $x)) / 2 ) @define_diffrule SpecialFunctions.besselyx(ν, x) = - :NaN, :( (SpecialFunctions.besselyx($ν - 1, $x) - SpecialFunctions.besselyx($ν + 1, $x)) / 2 ) + :(oftype($x, NaN)), :( (SpecialFunctions.besselyx($ν - 1, $x) - SpecialFunctions.besselyx($ν + 1, $x)) / 2 ) @define_diffrule SpecialFunctions.besselk(ν, x) = - :NaN, :( -(SpecialFunctions.besselk($ν - 1, $x) + SpecialFunctions.besselk($ν + 1, $x)) / 2 ) + :(oftype($x, NaN)), :( -(SpecialFunctions.besselk($ν - 1, $x) + SpecialFunctions.besselk($ν + 1, $x)) / 2 ) @define_diffrule SpecialFunctions.besselkx(ν, x) = - :NaN, :( -(SpecialFunctions.besselkx($ν - 1, $x) + SpecialFunctions.besselkx($ν + 1, $x)) / 2 + SpecialFunctions.besselkx($ν, $x) ) + :(oftype($x, NaN)), :( -(SpecialFunctions.besselkx($ν - 1, $x) + SpecialFunctions.besselkx($ν + 1, $x)) / 2 + SpecialFunctions.besselkx($ν, $x) ) @define_diffrule SpecialFunctions.besselh(ν, x) = - :NaN, :( (SpecialFunctions.besselh($ν - 1, $x) - SpecialFunctions.besselh($ν + 1, $x)) / 2 ) + :(oftype($x, NaN)), :( (SpecialFunctions.besselh($ν - 1, $x) - SpecialFunctions.besselh($ν + 1, $x)) / 2 ) @define_diffrule SpecialFunctions.besselhx(ν, x) = - :NaN, :( (SpecialFunctions.besselhx($ν - 1, $x) - SpecialFunctions.besselhx($ν + 1, $x)) / 2 - im * SpecialFunctions.besselhx($ν, $x) ) + :(oftype($x, NaN)), :( (SpecialFunctions.besselhx($ν - 1, $x) - SpecialFunctions.besselhx($ν + 1, $x)) / 2 - im * SpecialFunctions.besselhx($ν, $x) ) @define_diffrule SpecialFunctions.hankelh1(ν, x) = - :NaN, :( (SpecialFunctions.hankelh1($ν - 1, $x) - SpecialFunctions.hankelh1($ν + 1, $x)) / 2 ) + :(oftype($x, NaN)), :( (SpecialFunctions.hankelh1($ν - 1, $x) - SpecialFunctions.hankelh1($ν + 1, $x)) / 2 ) @define_diffrule SpecialFunctions.hankelh1x(ν, x) = - :NaN, :( (SpecialFunctions.hankelh1x($ν - 1, $x) - SpecialFunctions.hankelh1x($ν + 1, $x)) / 2 - im * SpecialFunctions.hankelh1x($ν, $x) ) + :(oftype($x, NaN)), :( (SpecialFunctions.hankelh1x($ν - 1, $x) - SpecialFunctions.hankelh1x($ν + 1, $x)) / 2 - im * SpecialFunctions.hankelh1x($ν, $x) ) @define_diffrule SpecialFunctions.hankelh2(ν, x) = - :NaN, :( (SpecialFunctions.hankelh2($ν - 1, $x) - SpecialFunctions.hankelh2($ν + 1, $x)) / 2 ) + :(oftype($x, NaN)), :( (SpecialFunctions.hankelh2($ν - 1, $x) - SpecialFunctions.hankelh2($ν + 1, $x)) / 2 ) @define_diffrule SpecialFunctions.hankelh2x(ν, x) = - :NaN, :( (SpecialFunctions.hankelh2x($ν - 1, $x) - SpecialFunctions.hankelh2x($ν + 1, $x)) / 2 + im * SpecialFunctions.hankelh2x($ν, $x) ) + :(oftype($x, NaN)), :( (SpecialFunctions.hankelh2x($ν - 1, $x) - SpecialFunctions.hankelh2x($ν + 1, $x)) / 2 + im * SpecialFunctions.hankelh2x($ν, $x) ) @define_diffrule SpecialFunctions.polygamma(m, x) = - :NaN, :( SpecialFunctions.polygamma($m + 1, $x) ) + :(oftype($x, NaN)), :( SpecialFunctions.polygamma($m + 1, $x) ) @define_diffrule SpecialFunctions.beta(a, b) = :( SpecialFunctions.beta($a, $b)*(SpecialFunctions.digamma($a) - SpecialFunctions.digamma($a + $b)) ), :( SpecialFunctions.beta($a, $b)*(SpecialFunctions.digamma($b) - SpecialFunctions.digamma($a + $b)) ) @@ -238,7 +238,7 @@ _abs_deriv(x) = signbit(x) ? -one(x) : one(x) # derivative wrt to `s` is not implemented @define_diffrule SpecialFunctions.zeta(s, z) = - :NaN, :( - $s * SpecialFunctions.zeta($s + 1, $z) ) + :(oftype($z, NaN)), :( - $s * SpecialFunctions.zeta($s + 1, $z) ) # ternary # #---------# @@ -296,7 +296,7 @@ _abs_deriv(x) = signbit(x) ? -one(x) : one(x) @define_diffrule LogExpFunctions.logmxp1(x) = :((1 - $x) / $x) # binary -@define_diffrule LogExpFunctions.xlogy(x, y) = +@define_diffrule LogExpFunctions.xlogy(x, y) = :(log($y)), :(z = $x / $y; iszero($x) && !isnan($y) ? zero(z) : z) @define_diffrule LogExpFunctions.logaddexp(x, y) = @@ -304,6 +304,6 @@ _abs_deriv(x) = signbit(x) ? -one(x) : one(x) @define_diffrule LogExpFunctions.logsubexp(x, y) = :(z = LogExpFunctions.logsubexp($x, $y); $x > $y ? exp($x - z) : -exp($x - z)), :(z = LogExpFunctions.logsubexp($x, $y); $x > $y ? -exp($y - z) : exp($y - z)) -@define_diffrule LogExpFunctions.xlog1py(x, y) = +@define_diffrule LogExpFunctions.xlog1py(x, y) = :(log1p($y)), :(z = $x / (1 + $y); iszero($x) && !isnan($y) ? zero(z) : z) From 535597e00cbea5d54d47c6a95bdec838ccb4066d Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Thu, 3 Nov 2022 14:20:09 -0400 Subject: [PATCH 2/4] revert :NaN change --- src/rules.jl | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/src/rules.jl b/src/rules.jl index fad8603..3dc1866 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -89,7 +89,7 @@ _abs_deriv(x) = signbit(x) ? -one(x) : one(x) @define_diffrule Base.mod(x, y) = :( z = $x / $y; ifelse(isinteger(z), oftype(float(z), NaN), one(float(z))) ), :( z = $x / $y; ifelse(isinteger(z), oftype(float(z), NaN), -floor(float(z))) ) @define_diffrule Base.rem(x, y) = :( z = $x / $y; ifelse(isinteger(z), oftype(float(z), NaN), one(float(z))) ), :( z = $x / $y; ifelse(isinteger(z), oftype(float(z), NaN), -trunc(float(z))) ) -@define_diffrule Base.rem2pi(x, r) = :( 1 ), :(oftype(float($x), NaN)) +@define_diffrule Base.rem2pi(x, r) = :( 1 ), :NaN @define_diffrule Base.max(x, y) = :( $x > $y ? one($x) : zero($x) ), :( $x > $y ? zero($y) : one($y) ) @define_diffrule Base.min(x, y) = :( $x > $y ? zero($x) : one($x) ), :( $x > $y ? one($y) : zero($y) ) @@ -200,36 +200,36 @@ _abs_deriv(x) = signbit(x) ? -one(x) : one(x) # for forward-mode and reverse-mode derivatives for complex inputs @define_diffrule SpecialFunctions.besselj(ν, x) = - :(oftype($x, NaN)), :( (SpecialFunctions.besselj($ν - 1, $x) - SpecialFunctions.besselj($ν + 1, $x)) / 2 ) + :NaN, :( (SpecialFunctions.besselj($ν - 1, $x) - SpecialFunctions.besselj($ν + 1, $x)) / 2 ) @define_diffrule SpecialFunctions.besseljx(ν, x) = - :(oftype($x, NaN)), :( (SpecialFunctions.besseljx($ν - 1, $x) - SpecialFunctions.besseljx($ν + 1, $x)) / 2 ) + :NaN, :( (SpecialFunctions.besseljx($ν - 1, $x) - SpecialFunctions.besseljx($ν + 1, $x)) / 2 ) @define_diffrule SpecialFunctions.besseli(ν, x) = - :(oftype($x, NaN)), :( (SpecialFunctions.besseli($ν - 1, $x) + SpecialFunctions.besseli($ν + 1, $x)) / 2 ) + :NaN, :( (SpecialFunctions.besseli($ν - 1, $x) + SpecialFunctions.besseli($ν + 1, $x)) / 2 ) @define_diffrule SpecialFunctions.besselix(ν, x) = - :(oftype($x, NaN)), :( (SpecialFunctions.besselix($ν - 1, $x) + SpecialFunctions.besselix($ν + 1, $x)) / 2 - sign($x) * SpecialFunctions.besselix($ν, $x) ) + :NaN, :( (SpecialFunctions.besselix($ν - 1, $x) + SpecialFunctions.besselix($ν + 1, $x)) / 2 - sign($x) * SpecialFunctions.besselix($ν, $x) ) @define_diffrule SpecialFunctions.bessely(ν, x) = - :(oftype($x, NaN)), :( (SpecialFunctions.bessely($ν - 1, $x) - SpecialFunctions.bessely($ν + 1, $x)) / 2 ) + :NaN, :( (SpecialFunctions.bessely($ν - 1, $x) - SpecialFunctions.bessely($ν + 1, $x)) / 2 ) @define_diffrule SpecialFunctions.besselyx(ν, x) = - :(oftype($x, NaN)), :( (SpecialFunctions.besselyx($ν - 1, $x) - SpecialFunctions.besselyx($ν + 1, $x)) / 2 ) + :NaN, :( (SpecialFunctions.besselyx($ν - 1, $x) - SpecialFunctions.besselyx($ν + 1, $x)) / 2 ) @define_diffrule SpecialFunctions.besselk(ν, x) = - :(oftype($x, NaN)), :( -(SpecialFunctions.besselk($ν - 1, $x) + SpecialFunctions.besselk($ν + 1, $x)) / 2 ) + :NaN, :( -(SpecialFunctions.besselk($ν - 1, $x) + SpecialFunctions.besselk($ν + 1, $x)) / 2 ) @define_diffrule SpecialFunctions.besselkx(ν, x) = - :(oftype($x, NaN)), :( -(SpecialFunctions.besselkx($ν - 1, $x) + SpecialFunctions.besselkx($ν + 1, $x)) / 2 + SpecialFunctions.besselkx($ν, $x) ) + :NaN, :( -(SpecialFunctions.besselkx($ν - 1, $x) + SpecialFunctions.besselkx($ν + 1, $x)) / 2 + SpecialFunctions.besselkx($ν, $x) ) @define_diffrule SpecialFunctions.besselh(ν, x) = - :(oftype($x, NaN)), :( (SpecialFunctions.besselh($ν - 1, $x) - SpecialFunctions.besselh($ν + 1, $x)) / 2 ) + :NaN, :( (SpecialFunctions.besselh($ν - 1, $x) - SpecialFunctions.besselh($ν + 1, $x)) / 2 ) @define_diffrule SpecialFunctions.besselhx(ν, x) = - :(oftype($x, NaN)), :( (SpecialFunctions.besselhx($ν - 1, $x) - SpecialFunctions.besselhx($ν + 1, $x)) / 2 - im * SpecialFunctions.besselhx($ν, $x) ) + :NaN, :( (SpecialFunctions.besselhx($ν - 1, $x) - SpecialFunctions.besselhx($ν + 1, $x)) / 2 - im * SpecialFunctions.besselhx($ν, $x) ) @define_diffrule SpecialFunctions.hankelh1(ν, x) = - :(oftype($x, NaN)), :( (SpecialFunctions.hankelh1($ν - 1, $x) - SpecialFunctions.hankelh1($ν + 1, $x)) / 2 ) + :NaN, :( (SpecialFunctions.hankelh1($ν - 1, $x) - SpecialFunctions.hankelh1($ν + 1, $x)) / 2 ) @define_diffrule SpecialFunctions.hankelh1x(ν, x) = - :(oftype($x, NaN)), :( (SpecialFunctions.hankelh1x($ν - 1, $x) - SpecialFunctions.hankelh1x($ν + 1, $x)) / 2 - im * SpecialFunctions.hankelh1x($ν, $x) ) + :NaN, :( (SpecialFunctions.hankelh1x($ν - 1, $x) - SpecialFunctions.hankelh1x($ν + 1, $x)) / 2 - im * SpecialFunctions.hankelh1x($ν, $x) ) @define_diffrule SpecialFunctions.hankelh2(ν, x) = - :(oftype($x, NaN)), :( (SpecialFunctions.hankelh2($ν - 1, $x) - SpecialFunctions.hankelh2($ν + 1, $x)) / 2 ) + :NaN, :( (SpecialFunctions.hankelh2($ν - 1, $x) - SpecialFunctions.hankelh2($ν + 1, $x)) / 2 ) @define_diffrule SpecialFunctions.hankelh2x(ν, x) = - :(oftype($x, NaN)), :( (SpecialFunctions.hankelh2x($ν - 1, $x) - SpecialFunctions.hankelh2x($ν + 1, $x)) / 2 + im * SpecialFunctions.hankelh2x($ν, $x) ) + :NaN, :( (SpecialFunctions.hankelh2x($ν - 1, $x) - SpecialFunctions.hankelh2x($ν + 1, $x)) / 2 + im * SpecialFunctions.hankelh2x($ν, $x) ) @define_diffrule SpecialFunctions.polygamma(m, x) = - :(oftype($x, NaN)), :( SpecialFunctions.polygamma($m + 1, $x) ) + :NaN, :( SpecialFunctions.polygamma($m + 1, $x) ) @define_diffrule SpecialFunctions.beta(a, b) = :( SpecialFunctions.beta($a, $b)*(SpecialFunctions.digamma($a) - SpecialFunctions.digamma($a + $b)) ), :( SpecialFunctions.beta($a, $b)*(SpecialFunctions.digamma($b) - SpecialFunctions.digamma($a + $b)) ) @@ -238,7 +238,7 @@ _abs_deriv(x) = signbit(x) ? -one(x) : one(x) # derivative wrt to `s` is not implemented @define_diffrule SpecialFunctions.zeta(s, z) = - :(oftype($z, NaN)), :( - $s * SpecialFunctions.zeta($s + 1, $z) ) + :NaN, :( - $s * SpecialFunctions.zeta($s + 1, $z) ) # ternary # #---------# From 6828599ce945d19fb4fb5d44f3bd0aabc2568a1f Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Fri, 4 Nov 2022 10:14:34 -0400 Subject: [PATCH 3/4] remove float guard --- src/rules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rules.jl b/src/rules.jl index 3dc1866..464b0e0 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -85,7 +85,7 @@ _abs_deriv(x) = signbit(x) ? -one(x) : one(x) @define_diffrule Base.atan(x, y) = :( $y / ($x^2 + $y^2) ), :( -$x / ($x^2 + $y^2) ) @define_diffrule Base.hypot(x, y) = :( $x / hypot($x, $y) ), :( $y / hypot($x, $y) ) @define_diffrule Base.log(b, x) = :( log($x) * inv(-log($b)^2 * $b) ), :( inv($x) / log($b) ) -@define_diffrule Base.ldexp(x, y) = :( oftype(float($x), exp2($y)) ), :(oftype(float($x), NaN)) +@define_diffrule Base.ldexp(x, y) = :( oftype($x, exp2($y)) ), :(oftype(float($x), NaN)) @define_diffrule Base.mod(x, y) = :( z = $x / $y; ifelse(isinteger(z), oftype(float(z), NaN), one(float(z))) ), :( z = $x / $y; ifelse(isinteger(z), oftype(float(z), NaN), -floor(float(z))) ) @define_diffrule Base.rem(x, y) = :( z = $x / $y; ifelse(isinteger(z), oftype(float(z), NaN), one(float(z))) ), :( z = $x / $y; ifelse(isinteger(z), oftype(float(z), NaN), -trunc(float(z))) ) From 5767bca5c8635c4c5c01c72dc57fc766987d7d75 Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Sat, 5 Nov 2022 15:38:22 -0400 Subject: [PATCH 4/4] Removed spurious NaN --- src/rules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rules.jl b/src/rules.jl index 464b0e0..07e5524 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -85,7 +85,7 @@ _abs_deriv(x) = signbit(x) ? -one(x) : one(x) @define_diffrule Base.atan(x, y) = :( $y / ($x^2 + $y^2) ), :( -$x / ($x^2 + $y^2) ) @define_diffrule Base.hypot(x, y) = :( $x / hypot($x, $y) ), :( $y / hypot($x, $y) ) @define_diffrule Base.log(b, x) = :( log($x) * inv(-log($b)^2 * $b) ), :( inv($x) / log($b) ) -@define_diffrule Base.ldexp(x, y) = :( oftype($x, exp2($y)) ), :(oftype(float($x), NaN)) +@define_diffrule Base.ldexp(x, y) = :( oftype($x, exp2($y)) ), :NaN @define_diffrule Base.mod(x, y) = :( z = $x / $y; ifelse(isinteger(z), oftype(float(z), NaN), one(float(z))) ), :( z = $x / $y; ifelse(isinteger(z), oftype(float(z), NaN), -floor(float(z))) ) @define_diffrule Base.rem(x, y) = :( z = $x / $y; ifelse(isinteger(z), oftype(float(z), NaN), one(float(z))) ), :( z = $x / $y; ifelse(isinteger(z), oftype(float(z), NaN), -trunc(float(z))) )