diff --git a/Project.toml b/Project.toml index 3857967c9..4db509f8f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "0.2.3" +version = "0.2.4" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -13,7 +13,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" ChainRulesCore = "0.4" FiniteDifferences = "^0.7" Reexport = "0.2" -Requires = "0.5.2" +Requires = "0.5.2, 1" julia = "^1.0" [extras] diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index e3c446a59..39e2ab1e7 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -36,7 +36,7 @@ @scalar_rule(sinh(x), cosh(x)) @scalar_rule(cosh(x), sinh(x)) -@scalar_rule(tanh(x), sech(x)^2) +@scalar_rule(tanh(x), 1-Ω^2) @scalar_rule(coth(x), -(csch(x)^2)) @scalar_rule(asinh(x), inv(sqrt(x^2 + 1))) @@ -66,7 +66,7 @@ @scalar_rule(-(x, y), (One(), -1)) @scalar_rule(/(x, y), (inv(y), -(x / y / y))) @scalar_rule(\(x, y), (-(y / x / x), inv(x))) -@scalar_rule(^(x, y), (y * x^(y - 1), Ω * log(x))) +@scalar_rule(^(x, y), (ifelse(iszero(y), zero(Ω), y * x^(y - 1)), Ω * log(x))) @scalar_rule(inv(x), -Ω^2) @scalar_rule(sqrt(x), inv(2 * Ω)) @@ -92,10 +92,12 @@ @scalar_rule(max(x, y), @setup(gt = x > y), (gt, !gt)) @scalar_rule(min(x, y), @setup(gt = x > y), (!gt, gt)) -@scalar_rule(mod(x, y), @setup((u, nan) = promote(x / y, NaN16)), +@scalar_rule(mod(x, y), @setup((u, nan) = promote(x / y, NaN16), isint = isinteger(x / y)), (ifelse(isint, nan, one(u)), ifelse(isint, nan, -floor(u)))) -@scalar_rule(rem(x, y), @setup((u, nan) = promote(x / y, NaN16)), +@scalar_rule(rem(x, y), @setup((u, nan) = promote(x / y, NaN16), isint = isinteger(x / y)), (ifelse(isint, nan, one(u)), ifelse(isint, nan, -trunc(u)))) +@scalar_rule(fma(x, y, z), (y, x, One())) +@scalar_rule(muladd(x, y, z), (y, x, One())) @scalar_rule(angle(x::Complex), @setup(u = abs2(x)), Wirtinger(-im//2 * x' / u, im//2 * x / u)) @scalar_rule(angle(x::Real), Zero()) @scalar_rule(real(x::Complex), Wirtinger(1//2, 1//2)) diff --git a/src/rulesets/packages/NaNMath.jl b/src/rulesets/packages/NaNMath.jl index 028d2655a..e427e2630 100644 --- a/src/rulesets/packages/NaNMath.jl +++ b/src/rulesets/packages/NaNMath.jl @@ -1,6 +1,7 @@ module NaNMathGlue using ChainRulesCore using ..NaNMath +using ..SpecialFunctions @scalar_rule(NaNMath.sin(x), NaNMath.cos(x)) @scalar_rule(NaNMath.cos(x), -NaNMath.sin(x)) @@ -15,5 +16,11 @@ using ..NaNMath @scalar_rule(NaNMath.lgamma(x), SpecialFunctions.digamma(x)) @scalar_rule(NaNMath.sqrt(x), inv(2 * Ω)) @scalar_rule(NaNMath.pow(x, y), (y * NaNMath.pow(x, y - 1), Ω * NaNMath.log(x))) +@scalar_rule(NaNMath.max(x, y), + (ifelse((y > x) | (signbit(y) < signbit(x)), ifelse(isnan(y), One(), Zero()), ifelse(isnan(x), Zero(), One())), + ifelse((y > x) | (signbit(y) < signbit(x)), ifelse(isnan(y), Zero(), One()), ifelse(isnan(x), One(), Zero())))) +@scalar_rule(NaNMath.min(x, y), + (ifelse((y < x) | (signbit(y) > signbit(x)), ifelse(isnan(y), One(), Zero()), ifelse(isnan(x), Zero(), One())), + ifelse((y < x) | (signbit(y) > signbit(x)), ifelse(isnan(y), Zero(), One()), ifelse(isnan(x), One(), Zero())))) end #module diff --git a/src/rulesets/packages/SpecialFunctions.jl b/src/rulesets/packages/SpecialFunctions.jl index dc0a108c2..092da1eba 100644 --- a/src/rulesets/packages/SpecialFunctions.jl +++ b/src/rulesets/packages/SpecialFunctions.jl @@ -23,6 +23,47 @@ using ..SpecialFunctions @scalar_rule(SpecialFunctions.erfcx(x), (2 * x * Ω) - (2 / sqrt(π))) @scalar_rule(SpecialFunctions.dawson(x), 1 - (2 * x * Ω)) +# binary +@scalar_rule(SpecialFunctions.besselj(ν, x), + (NaN, + (SpecialFunctions.besselj(ν - 1, x) - + SpecialFunctions.besselj(ν + 1, x)) / 2)) + +@scalar_rule(SpecialFunctions.besseli(ν, x), + (NaN, + (SpecialFunctions.besseli(ν - 1, x) + + SpecialFunctions.besseli(ν + 1, x)) / 2)) +@scalar_rule(SpecialFunctions.bessely(ν, x), + (NaN, + (SpecialFunctions.bessely(ν - 1, x) - + SpecialFunctions.bessely(ν + 1, x)) / 2)) + +@scalar_rule(SpecialFunctions.besselk(ν, x), + (NaN, + -(SpecialFunctions.besselk(ν - 1, x) + + SpecialFunctions.besselk(ν + 1, x)) / 2)) + +@scalar_rule(SpecialFunctions.hankelh1(ν, x), + (NaN, + (SpecialFunctions.hankelh1(ν - 1, x) - + SpecialFunctions.hankelh1(ν + 1, x)) / 2)) +@scalar_rule(SpecialFunctions.hankelh2(ν, x), + (NaN, + (SpecialFunctions.hankelh2(ν - 1, x) - + SpecialFunctions.hankelh2(ν + 1, x)) / 2)) + +@scalar_rule(SpecialFunctions.polygamma(m, x), + (NaN, SpecialFunctions.polygamma(m + 1, x))) + +# todo: setup for common expr +@scalar_rule(SpecialFunctions.beta(a, b), + (Ω*(SpecialFunctions.digamma(a) - SpecialFunctions.digamma(a + b)), + Ω*(SpecialFunctions.digamma(b) - SpecialFunctions.digamma(a + b)))) + +@scalar_rule(SpecialFunctions.lbeta(a, b), + (SpecialFunctions.digamma(a) - SpecialFunctions.digamma(a + b), + SpecialFunctions.digamma(b) - SpecialFunctions.digamma(a + b))) + # Changes between SpecialFunctions 0.7 and 0.8 if isdefined(SpecialFunctions, :lgamma) # actually is the absolute value of the logorithm of gamma diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index 492617512..5c7ba917b 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -137,11 +137,11 @@ test_accumulation(rand(2, 5), dy) end - @testset "binary trig ($f)" for f in (hypot, atan) + @testset "binary function ($f)" for f in (hypot, atan, mod, rem, ^) rng = MersenneTwister(123456) - x, Δx, x̄ = 10randn(rng, 3) - y, Δy, ȳ = randn(rng, 3) - Δz = randn(rng) + x, Δx, x̄ = 10rand(rng, 3) + y, Δy, ȳ = rand(rng, 3) + Δz = rand(rng) frule_test(f, (x, Δx), (y, Δy)) rrule_test(f, Δz, (x, x̄), (y, ȳ)) @@ -176,4 +176,15 @@ @test extern(ẏ) == 0 end end + + @testset "trinary ($f)" for f in (muladd, fma) + rng = MersenneTwister(123456) + x, Δx, x̄ = 10randn(rng, 3) + y, Δy, ȳ = randn(rng, 3) + z, Δz, z̄ = randn(rng, 3) + Δk = randn(rng) + + frule_test(f, (x, Δx), (y, Δy), (z, Δz)) + rrule_test(f, Δk, (x, x̄), (y, ȳ), (z, z̄)) + end end