diff --git a/src/ChainRules.jl b/src/ChainRules.jl index 94289d502..7c43c8710 100644 --- a/src/ChainRules.jl +++ b/src/ChainRules.jl @@ -54,21 +54,4 @@ include("rulesets/LinearAlgebra/factorization.jl") include("rulesets/Random/random.jl") -# Note: The following is only required because package authors sometimes do not -# declare their own rules using `ChainRulesCore.jl`. For arguably good reasons. -# So we define them here for them. -function __init__() - @require NaNMath="77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" begin - include("rulesets/packages/NaNMath.jl") - end - - # Note: drop SpecialFunctions dependency in next breaking release - # https://github.com/JuliaDiff/ChainRules.jl/issues/319 - @require SpecialFunctions="276daf66-3868-5448-9aa4-cd146d93841b" begin - if !isdefined(SpecialFunctions, :ChainRulesCore) - include("rulesets/packages/SpecialFunctions.jl") - end - end -end - end # module diff --git a/src/rulesets/packages/NaNMath.jl b/src/rulesets/packages/NaNMath.jl deleted file mode 100644 index 56aea6055..000000000 --- a/src/rulesets/packages/NaNMath.jl +++ /dev/null @@ -1,38 +0,0 @@ -@scalar_rule(NaNMath.sin(x), NaNMath.cos(x)) -@scalar_rule(NaNMath.cos(x), -NaNMath.sin(x)) -@scalar_rule(NaNMath.asin(x), inv(NaNMath.sqrt(1 - NaNMath.pow(x, 2)))) -@scalar_rule(NaNMath.acos(x), -inv(NaNMath.sqrt(1 - NaNMath.pow(x, 2)))) -@scalar_rule(NaNMath.acosh(x), inv(NaNMath.sqrt(NaNMath.pow(x, 2) - 1))) -@scalar_rule(NaNMath.tan(x), 1 + Ω^2) -@scalar_rule(NaNMath.atanh(x), inv(1 - NaNMath.pow(x, 2))) -@scalar_rule(NaNMath.log(x), inv(x)) -@scalar_rule(NaNMath.log2(x), inv(x) / NaNMath.log(oftype(x, 2))) -@scalar_rule(NaNMath.log10(x), inv(x) / NaNMath.log(oftype(x, 10))) -@scalar_rule(NaNMath.log1p(x), inv(x + 1)) -@scalar_rule(NaNMath.lgamma(x), SpecialFunctions.digamma(x)) -@scalar_rule(NaNMath.sqrt(x), inv(2 * Ω)) -@scalar_rule(NaNMath.pow(x, y), (y * NaNMath.pow(x, y - 1), Ω * NaNMath.log(x))) -@scalar_rule( - NaNMath.max(x, y), - (ifelse( - (y > x) | (signbit(y) < signbit(x)), - ifelse(isnan(y), true, ZeroTangent()), - ifelse(isnan(x), ZeroTangent(), true)), - ifelse( - (y > x) | (signbit(y) < signbit(x)), - ifelse(isnan(y), ZeroTangent(), true), - ifelse(isnan(x), true, ZeroTangent())), - ) -) -@scalar_rule( - NaNMath.min(x, y), - (ifelse( - (y < x) | (signbit(y) > signbit(x)), - ifelse(isnan(y), true, ZeroTangent()), - ifelse(isnan(x), ZeroTangent(), true)), - ifelse( - (y < x) | (signbit(y) > signbit(x)), - ifelse(isnan(y), ZeroTangent(), true), - ifelse(isnan(x), true, ZeroTangent())), - ) -) diff --git a/src/rulesets/packages/README.md b/src/rulesets/packages/README.md deleted file mode 100644 index 6e989d902..000000000 --- a/src/rulesets/packages/README.md +++ /dev/null @@ -1,7 +0,0 @@ -## Package Glue Code - -In the ideal world, everyone would write ChainRules for their functions -in the packages where they are defined. -By depending only on [ChainRulesCore.jl](https://github.com/JuliaDiff/ChainRulesCore.jl) -We do not live in an ideal world, so some of those definitions live here. -In the long-term the plan is to move them out of this repo. diff --git a/src/rulesets/packages/SpecialFunctions.jl b/src/rulesets/packages/SpecialFunctions.jl deleted file mode 100644 index 55f15cf82..000000000 --- a/src/rulesets/packages/SpecialFunctions.jl +++ /dev/null @@ -1,123 +0,0 @@ -const BESSEL_ORDER_INFO = """ -derivatives of Bessel functions with respect to the order are not implemented currently: -https://github.com/JuliaMath/SpecialFunctions.jl/issues/160 -""" - -@scalar_rule(SpecialFunctions.airyai(x), SpecialFunctions.airyaiprime(x)) -@scalar_rule(SpecialFunctions.airyaiprime(x), x * SpecialFunctions.airyai(x)) -@scalar_rule(SpecialFunctions.airybi(x), SpecialFunctions.airybiprime(x)) -@scalar_rule(SpecialFunctions.airybiprime(x), x * SpecialFunctions.airybi(x)) -@scalar_rule(SpecialFunctions.besselj0(x), -SpecialFunctions.besselj1(x)) -@scalar_rule( - SpecialFunctions.besselj1(x), - (SpecialFunctions.besselj0(x) - SpecialFunctions.besselj(2, x)) / 2, -) -@scalar_rule(SpecialFunctions.bessely0(x), -SpecialFunctions.bessely1(x)) -@scalar_rule( - SpecialFunctions.bessely1(x), - (SpecialFunctions.bessely0(x) - SpecialFunctions.bessely(2, x)) / 2, -) -@scalar_rule(SpecialFunctions.dawson(x), 1 - (2 * x * Ω)) -@scalar_rule(SpecialFunctions.digamma(x), SpecialFunctions.trigamma(x)) -@scalar_rule(SpecialFunctions.erf(x), (2 / sqrt(π)) * exp(-x * x)) -@scalar_rule(SpecialFunctions.erfc(x), -(2 / sqrt(π)) * exp(-x * x)) -@scalar_rule(SpecialFunctions.erfcinv(x), -(sqrt(π) / 2) * exp(Ω^2)) -@scalar_rule(SpecialFunctions.erfcx(x), (2 * x * Ω) - (2 / sqrt(π))) -@scalar_rule(SpecialFunctions.erfi(x), (2 / sqrt(π)) * exp(x * x)) -@scalar_rule(SpecialFunctions.erfinv(x), (sqrt(π) / 2) * exp(Ω^2)) -@scalar_rule(SpecialFunctions.gamma(x), Ω * SpecialFunctions.digamma(x)) -@scalar_rule( - SpecialFunctions.invdigamma(x), - inv(SpecialFunctions.trigamma(SpecialFunctions.invdigamma(x))), -) -@scalar_rule(SpecialFunctions.trigamma(x), SpecialFunctions.polygamma(2, x)) - -# binary -@scalar_rule( - SpecialFunctions.besselj(ν, x), - ( - @not_implemented(BESSEL_ORDER_INFO), - (SpecialFunctions.besselj(ν - 1, x) - SpecialFunctions.besselj(ν + 1, x)) / 2 - ), -) -@scalar_rule( - SpecialFunctions.besseli(ν, x), - ( - @not_implemented(BESSEL_ORDER_INFO), - (SpecialFunctions.besseli(ν - 1, x) + SpecialFunctions.besseli(ν + 1, x)) / 2, - ), -) -@scalar_rule( - SpecialFunctions.bessely(ν, x), - ( - @not_implemented(BESSEL_ORDER_INFO), - (SpecialFunctions.bessely(ν - 1, x) - SpecialFunctions.bessely(ν + 1, x)) / 2, - ), -) -@scalar_rule( - SpecialFunctions.besselk(ν, x), - ( - @not_implemented(BESSEL_ORDER_INFO), - -(SpecialFunctions.besselk(ν - 1, x) + SpecialFunctions.besselk(ν + 1, x)) / 2, - ), -) -@scalar_rule( - SpecialFunctions.hankelh1(ν, x), - ( - @not_implemented(BESSEL_ORDER_INFO), - (SpecialFunctions.hankelh1(ν - 1, x) - SpecialFunctions.hankelh1(ν + 1, x)) / 2, - ), -) -@scalar_rule( - SpecialFunctions.hankelh2(ν, x), - ( - @not_implemented(BESSEL_ORDER_INFO), - (SpecialFunctions.hankelh2(ν - 1, x) - SpecialFunctions.hankelh2(ν + 1, x)) / 2, - ), -) -@scalar_rule( - SpecialFunctions.polygamma(m, x), - ( - NoTangent(), - SpecialFunctions.polygamma(m + 1, x), - ), -) -# todo: setup for common expr -@scalar_rule( - SpecialFunctions.beta(a, b), - (Ω*(SpecialFunctions.digamma(a) - SpecialFunctions.digamma(a + b)), - Ω*(SpecialFunctions.digamma(b) - SpecialFunctions.digamma(a + b)),) -) - -# Changes between SpecialFunctions 0.7 and 0.8 -if isdefined(SpecialFunctions, :lgamma) - # actually is the absolute value of the logorithm of gamma - @scalar_rule(SpecialFunctions.lgamma(x), SpecialFunctions.digamma(x)) -end - -if isdefined(SpecialFunctions, :logabsgamma) - # actually is the absolute value of the logorithm of gamma paired with sign gamma - @scalar_rule(SpecialFunctions.logabsgamma(x), SpecialFunctions.digamma(x), ZeroTangent()) -end - -if isdefined(SpecialFunctions, :loggamma) - @scalar_rule(SpecialFunctions.loggamma(x), SpecialFunctions.digamma(x)) -end - -if isdefined(SpecialFunctions, :lbeta) - # todo: setup for common expr - @scalar_rule( - SpecialFunctions.lbeta(a, b), - (SpecialFunctions.digamma(a) - SpecialFunctions.digamma(a + b), - SpecialFunctions.digamma(b) - SpecialFunctions.digamma(a + b),) - ) -end - -if isdefined(SpecialFunctions, :logbeta) - # todo: setup for common expr - @scalar_rule( - SpecialFunctions.logbeta(a, b), - (SpecialFunctions.digamma(a) - SpecialFunctions.digamma(a + b), - SpecialFunctions.digamma(b) - SpecialFunctions.digamma(a + b),) - ) -end diff --git a/test/rulesets/packages/NaNMath.jl b/test/rulesets/packages/NaNMath.jl deleted file mode 100644 index e69de29bb..000000000 diff --git a/test/rulesets/packages/SpecialFunctions.jl b/test/rulesets/packages/SpecialFunctions.jl deleted file mode 100644 index 6cd9c8c0b..000000000 --- a/test/rulesets/packages/SpecialFunctions.jl +++ /dev/null @@ -1,109 +0,0 @@ -@testset "general: single input" begin - for x in (1.0, -1.0, 0.0, 0.5, 10.0, -17.1, 1.5 + 0.7im) - test_scalar(SpecialFunctions.erf, x) - test_scalar(SpecialFunctions.erfc, x) - test_scalar(SpecialFunctions.erfi, x) - - test_scalar(SpecialFunctions.airyai, x) - test_scalar(SpecialFunctions.airyaiprime, x) - test_scalar(SpecialFunctions.airybi, x) - test_scalar(SpecialFunctions.airybiprime, x) - - test_scalar(SpecialFunctions.erfcx, x) - test_scalar(SpecialFunctions.dawson, x) - - if x isa Real - test_scalar(SpecialFunctions.invdigamma, x) - end - - if x isa Real && 0 < x < 1 - test_scalar(SpecialFunctions.erfinv, x) - test_scalar(SpecialFunctions.erfcinv, x) - end - - if x isa Real && x > 0 || x isa Complex - test_scalar(SpecialFunctions.gamma, x) - test_scalar(SpecialFunctions.digamma, x) - test_scalar(SpecialFunctions.trigamma, x) - end - end -end - -@testset "Bessel functions" begin - for x in (1.5, 2.5, 10.5, -0.6, -2.6, -3.3, 1.6 + 1.6im, 1.6 - 1.6im, -4.6 + 1.6im) - test_scalar(SpecialFunctions.besselj0, x) - test_scalar(SpecialFunctions.besselj1, x) - - isreal(x) && x < 0 && continue - - test_scalar(SpecialFunctions.bessely0, x) - test_scalar(SpecialFunctions.bessely1, x) - - for nu in (-1.5, 2.2, 4.0) - test_frule(SpecialFunctions.besseli, nu, x) - test_rrule(SpecialFunctions.besseli, nu, x) - - test_frule(SpecialFunctions.besselj, nu, x) - test_rrule(SpecialFunctions.besselj, nu, x) - - test_frule(SpecialFunctions.besselk, nu, x) - test_rrule(SpecialFunctions.besselk, nu, x) - - test_frule(SpecialFunctions.bessely, nu, x) - test_rrule(SpecialFunctions.bessely, nu, x) - - # use complex numbers in `rrule` for FiniteDifferences - test_frule(SpecialFunctions.hankelh1, nu, x) - test_rrule(SpecialFunctions.hankelh1, nu, complex(x)) - - # use complex numbers in `rrule` for FiniteDifferences - test_frule(SpecialFunctions.hankelh2, nu, x) - test_rrule(SpecialFunctions.hankelh2, nu, complex(x)) - end - end -end - -@testset "beta and logbeta" begin - test_points = (1.5, 2.5, 10.5, 1.6 + 1.6im, 1.6 - 1.6im, 4.6 + 1.6im) - for _x in test_points, _y in test_points - # ensure all complex if any complex for FiniteDifferences - x, y = promote(_x, _y) - test_frule(SpecialFunctions.beta, x, y) - test_rrule(SpecialFunctions.beta, x, y) - - if isdefined(SpecialFunctions, :lbeta) - test_frule(SpecialFunctions.lbeta, x, y) - test_rrule(SpecialFunctions.lbeta, x, y) - end - - if isdefined(SpecialFunctions, :logbeta) - test_frule(SpecialFunctions.logbeta, x, y) - test_rrule(SpecialFunctions.logbeta, x, y) - end - end -end - -@testset "log gamma and co" begin - # It is important that we have negative numbers with both odd and even integer parts - for x in (1.5, 2.5, 10.5, -0.6, -2.6, -3.3, 1.6 + 1.6im, 1.6 - 1.6im, -4.6 + 1.6im) - for m in (0, 1, 2, 3) - test_frule(SpecialFunctions.polygamma, m, x) - test_rrule(SpecialFunctions.polygamma, m, x) - end - - if isdefined(SpecialFunctions, :lgamma) - test_scalar(SpecialFunctions.lgamma, x) - end - - if isdefined(SpecialFunctions, :loggamma) - isreal(x) && x < 0 && continue - test_scalar(SpecialFunctions.loggamma, x) - end - - if isdefined(SpecialFunctions, :logabsgamma) - isreal(x) || continue - test_frule(SpecialFunctions.logabsgamma, x) - test_rrule(SpecialFunctions.logabsgamma, x; output_tangent=(randn(), randn())) - end - end -end diff --git a/test/runtests.jl b/test/runtests.jl index 420892ed1..2adb2f02f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -57,15 +57,5 @@ println("Testing ChainRules.jl") include_test("rulesets/Random/random.jl") end println() - - @testset "packages" begin - include_test("rulesets/packages/NaNMath.jl") - # Note: drop SpecialFunctions dependency in next breaking release - # https://github.com/JuliaDiff/ChainRules.jl/issues/319 - if !isdefined(SpecialFunctions, :ChainRulesCore) - include_test("rulesets/packages/SpecialFunctions.jl") - end - end - println() end end