From d88cce36628f6a84f7983546281783848365097e Mon Sep 17 00:00:00 2001 From: ShashiGowda Date: Fri, 9 Aug 2019 14:24:42 -0400 Subject: [PATCH 01/12] fma --- src/rulesets/Base/base.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index ec5e5d0bd..4e77a994f 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -71,6 +71,7 @@ (ifelse(isint, nan, one(u)), ifelse(isint, nan, -floor(u)))) @scalar_rule(rem(x, y), @setup((u, nan) = promote(x / y, NaN16)), (ifelse(isint, nan, one(u)), ifelse(isint, nan, -trunc(u)))) +@scalar_rule(fma(x, y, z), (y, x, One())) # product rule requires special care for arguments where `mul` is non-commutative From c04a5d88c32931454c597ade4ed2f09a0260409b Mon Sep 17 00:00:00 2001 From: ShashiGowda Date: Fri, 9 Aug 2019 16:59:04 -0400 Subject: [PATCH 02/12] add a few NaNMath rules --- src/rulesets/packages/NaNMath.jl | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/rulesets/packages/NaNMath.jl b/src/rulesets/packages/NaNMath.jl index 301edee7c..48c183acc 100644 --- a/src/rulesets/packages/NaNMath.jl +++ b/src/rulesets/packages/NaNMath.jl @@ -15,5 +15,12 @@ using NaNMath @scalar_rule(NaNMath.lgamma(x), SpecialFunctions.digamma(x)) @scalar_rule(NaNMath.sqrt(x), inv(2 * Ω)) @scalar_rule(NaNMath.pow(x, y), (y * NaNMath.pow(x, y - 1), Ω * NaNMath.log(x))) +@scalar_rule(NaNMath.max(x, y), + (ifelse((y > x) | (signbit(y) < signbit(x)), ifelse(isnan(y), One(), Zero()), ifelse(isnan(x), Zero(), One())), + ifelse((y > x) | (signbit(y) < signbit(x)), ifelse(isnan(y), Zero(), One()), ifelse(isnan(x), One(), Zero())))) + +@scalar_rule(NaNMath.min(x, y), + (ifelse((y < x) | (signbit(y) > signbit(x)), ifelse(isnan(y), One(), Zero()), ifelse(isnan(x), Zero(), One())), + ifelse((y < x) | (signbit(y) > signbit(x)), ifelse(isnan(y), Zero(), One()), ifelse(isnan(x), One(), Zero())))) end #module From d59f15e30aede16a5bce36fcf881f641973ac8db Mon Sep 17 00:00:00 2001 From: ShashiGowda Date: Fri, 9 Aug 2019 16:59:15 -0400 Subject: [PATCH 03/12] add some SpecialFunctions rules --- src/rulesets/packages/SpecialFunctions.jl | 40 +++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/src/rulesets/packages/SpecialFunctions.jl b/src/rulesets/packages/SpecialFunctions.jl index cec2395f8..f62d8fa4d 100644 --- a/src/rulesets/packages/SpecialFunctions.jl +++ b/src/rulesets/packages/SpecialFunctions.jl @@ -24,4 +24,44 @@ using SpecialFunctions @scalar_rule(SpecialFunctions.erfcx(x), (2 * x * Ω) - (2 / sqrt(π))) @scalar_rule(SpecialFunctions.dawson(x), 1 - (2 * x * Ω)) +# binary +@scalar_rule(SpecialFunctions.besselj(ν, x), + (NaN, + (SpecialFunctions.besselj(ν - 1, x) - + SpecialFunctions.besselj(ν + 1, x)) / 2)) + +@scalar_rule(SpecialFunctions.besseli(ν, x), + (NaN, + (SpecialFunctions.besseli(ν - 1, x) + + SpecialFunctions.besseli(ν + 1, x)) / 2)) +@scalar_rule(SpecialFunctions.bessely(ν, x), + (NaN, + (SpecialFunctions.bessely(ν - 1, x) - + SpecialFunctions.bessely(ν + 1, x)) / 2)) + +@scalar_rule(SpecialFunctions.besselk(ν, x), + (NaN, + (SpecialFunctions.besselk(ν - 1, x) + + SpecialFunctions.besselk(ν + 1, x)) / 2)) + +@scalar_rule(SpecialFunctions.hankelh1(ν, x), + (NaN, + (SpecialFunctions.hankelh1(ν - 1, x) - + SpecialFunctions.hankelh1(ν + 1, x)) / 2)) +@scalar_rule(SpecialFunctions.hankelh2(ν, x), + (NaN, + (SpecialFunctions.hankelh2(ν - 1, x) - + SpecialFunctions.hankelh2(ν + 1, x)) / 2)) + +@scalar_rule(SpecialFunctions.polygamma(m, x), + (NaN, SpecialFunctions.polygamma(m + 1, x))) + +# todo: setup for common expr +@scalar_rule(SpecialFunctions.beta(a, b), + (Ω*(SpecialFunctions.digamma(a) - SpecialFunctions.digamma(a + b)), + Ω*(SpecialFunctions.digamma(b) - SpecialFunctions.digamma(a + b)))) + +@scalar_rule(SpecialFunctions.lbeta(a, b), + (SpecialFunctions.digamma(a) - SpecialFunctions.digamma(a + b), + SpecialFunctions.digamma(b) - SpecialFunctions.digamma(a + b))) end #module From 02664b0b04dd63b6655a39d0e19641f1c86517cd Mon Sep 17 00:00:00 2001 From: ShashiGowda Date: Fri, 9 Aug 2019 18:03:14 -0400 Subject: [PATCH 04/12] isint wasn't defined --- src/rulesets/Base/base.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index 4e77a994f..4047e57a2 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -67,9 +67,9 @@ @scalar_rule(atan(y, x), @setup(u = hypot(x, y)), (x / u, y / u)) @scalar_rule(max(x, y), @setup(gt = x > y), (gt, !gt)) @scalar_rule(min(x, y), @setup(gt = x > y), (!gt, gt)) -@scalar_rule(mod(x, y), @setup((u, nan) = promote(x / y, NaN16)), +@scalar_rule(mod(x, y), @setup((u, nan) = promote(x / y, NaN16), isint = isinteger(x / y)), (ifelse(isint, nan, one(u)), ifelse(isint, nan, -floor(u)))) -@scalar_rule(rem(x, y), @setup((u, nan) = promote(x / y, NaN16)), +@scalar_rule(rem(x, y), @setup((u, nan) = promote(x / y, NaN16), isint = isinteger(x / y)), (ifelse(isint, nan, one(u)), ifelse(isint, nan, -trunc(u)))) @scalar_rule(fma(x, y, z), (y, x, One())) From 2f9cdbef592ae30be5fd35c7028cf19a31396867 Mon Sep 17 00:00:00 2001 From: ShashiGowda Date: Fri, 9 Aug 2019 18:03:36 -0400 Subject: [PATCH 05/12] fixes --- src/rulesets/packages/NaNMath.jl | 2 ++ src/rulesets/packages/SpecialFunctions.jl | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/rulesets/packages/NaNMath.jl b/src/rulesets/packages/NaNMath.jl index 48c183acc..02979315a 100644 --- a/src/rulesets/packages/NaNMath.jl +++ b/src/rulesets/packages/NaNMath.jl @@ -1,9 +1,11 @@ module NaNMathGlue using ChainRulesCore using NaNMath +using SpecialFunctions @scalar_rule(NaNMath.sin(x), NaNMath.cos(x)) @scalar_rule(NaNMath.cos(x), -NaNMath.sin(x)) +@scalar_rule(NaNMath.tan(x), 1 + NaNMath.pow(NaNMath.tan(x), 2)) @scalar_rule(NaNMath.asin(x), inv(NaNMath.sqrt(1 - NaNMath.pow(x, 2)))) @scalar_rule(NaNMath.acos(x), -inv(NaNMath.sqrt(1 - NaNMath.pow(x, 2)))) @scalar_rule(NaNMath.acosh(x), inv(NaNMath.sqrt(NaNMath.pow(x, 2) - 1))) diff --git a/src/rulesets/packages/SpecialFunctions.jl b/src/rulesets/packages/SpecialFunctions.jl index f62d8fa4d..739917721 100644 --- a/src/rulesets/packages/SpecialFunctions.jl +++ b/src/rulesets/packages/SpecialFunctions.jl @@ -41,8 +41,8 @@ using SpecialFunctions @scalar_rule(SpecialFunctions.besselk(ν, x), (NaN, - (SpecialFunctions.besselk(ν - 1, x) + - SpecialFunctions.besselk(ν + 1, x)) / 2)) + -(SpecialFunctions.besselk(ν - 1, x) + + SpecialFunctions.besselk(ν + 1, x)) / 2)) @scalar_rule(SpecialFunctions.hankelh1(ν, x), (NaN, From 867362535d110e976a81e71a164af39681c471d3 Mon Sep 17 00:00:00 2001 From: ShashiGowda Date: Fri, 9 Aug 2019 18:13:47 -0400 Subject: [PATCH 06/12] optimize rule for tanh like in JuliaDiff/DiffRules.jl#4861e3 --- src/rulesets/Base/base.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index 21f8d4990..004e49817 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -24,7 +24,7 @@ @scalar_rule(acotd(x), -oftype(x, 180) / π / (1 + x^2)) @scalar_rule(sinh(x), cosh(x)) @scalar_rule(cosh(x), sinh(x)) -@scalar_rule(tanh(x), sech(x)^2) +@scalar_rule(tanh(x), 1-Ω^2) @scalar_rule(coth(x), -(csch(x)^2)) @scalar_rule(asinh(x), inv(sqrt(x^2 + 1))) @scalar_rule(acosh(x), inv(sqrt(x^2 - 1))) From 8f077431ef8ccf701854bcfb039b3795e51a921b Mon Sep 17 00:00:00 2001 From: ShashiGowda Date: Mon, 12 Aug 2019 11:55:30 -0400 Subject: [PATCH 07/12] Fix pow. Co-authored-by: YingboMa --- src/rulesets/Base/base.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index 004e49817..66d2d0f55 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -45,7 +45,7 @@ @scalar_rule(-(x, y), (One(), -1)) @scalar_rule(/(x, y), (inv(y), -(x / y / y))) @scalar_rule(\(x, y), (-(y / x / x), inv(x))) -@scalar_rule(^(x, y), (y * x^(y - 1), Ω * log(x))) +@scalar_rule(^(x, y), (ifelse(iszero(y), Zero(), y * x^(y - 1)), Ω * log(x))) @scalar_rule(inv(x), -abs2(Ω)) @scalar_rule(sqrt(x), inv(2 * Ω)) @scalar_rule(cbrt(x), inv(3 * Ω^2)) From 68e3e4e2e8493cfbb333c94fd95f90cfa6d64c93 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Sat, 16 Nov 2019 14:01:15 -0500 Subject: [PATCH 08/12] Remove tan from NaNMath.jl --- src/rulesets/packages/NaNMath.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/rulesets/packages/NaNMath.jl b/src/rulesets/packages/NaNMath.jl index 02979315a..a69c8271b 100644 --- a/src/rulesets/packages/NaNMath.jl +++ b/src/rulesets/packages/NaNMath.jl @@ -5,7 +5,6 @@ using SpecialFunctions @scalar_rule(NaNMath.sin(x), NaNMath.cos(x)) @scalar_rule(NaNMath.cos(x), -NaNMath.sin(x)) -@scalar_rule(NaNMath.tan(x), 1 + NaNMath.pow(NaNMath.tan(x), 2)) @scalar_rule(NaNMath.asin(x), inv(NaNMath.sqrt(1 - NaNMath.pow(x, 2)))) @scalar_rule(NaNMath.acos(x), -inv(NaNMath.sqrt(1 - NaNMath.pow(x, 2)))) @scalar_rule(NaNMath.acosh(x), inv(NaNMath.sqrt(NaNMath.pow(x, 2) - 1))) @@ -20,7 +19,6 @@ using SpecialFunctions @scalar_rule(NaNMath.max(x, y), (ifelse((y > x) | (signbit(y) < signbit(x)), ifelse(isnan(y), One(), Zero()), ifelse(isnan(x), Zero(), One())), ifelse((y > x) | (signbit(y) < signbit(x)), ifelse(isnan(y), Zero(), One()), ifelse(isnan(x), One(), Zero())))) - @scalar_rule(NaNMath.min(x, y), (ifelse((y < x) | (signbit(y) > signbit(x)), ifelse(isnan(y), One(), Zero()), ifelse(isnan(x), Zero(), One())), ifelse((y < x) | (signbit(y) > signbit(x)), ifelse(isnan(y), Zero(), One()), ifelse(isnan(x), One(), Zero())))) From 6369630cb197498fc7a14395eb21605b1753e547 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Mon, 18 Nov 2019 00:34:44 -0500 Subject: [PATCH 09/12] Fix `inv` and add `muladd` --- Project.toml | 4 ++-- src/rulesets/Base/base.jl | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index 3ab7b7dd2..cdfaba6d8 100644 --- a/Project.toml +++ b/Project.toml @@ -10,10 +10,10 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] -Reexport = "0.2" -Requires = "0.5.2" ChainRulesCore = "0.4" FiniteDifferences = "^0.7" +Reexport = "0.2" +Requires = "0.5.2" julia = "^1.0" [extras] diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index ed5540d6e..2d3995490 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -68,7 +68,7 @@ @scalar_rule(\(x, y), (-(y / x / x), inv(x))) @scalar_rule(^(x, y), (ifelse(iszero(y), Zero(), y * x^(y - 1)), Ω * log(x))) -@scalar_rule(inv(x), -abs2(Ω)) +@scalar_rule(inv(x), -Ω^2) @scalar_rule(sqrt(x), inv(2 * Ω)) @scalar_rule(cbrt(x), inv(3 * Ω^2)) @scalar_rule(exp(x), Ω) @@ -97,6 +97,7 @@ @scalar_rule(rem(x, y), @setup((u, nan) = promote(x / y, NaN16), isint = isinteger(x / y)), (ifelse(isint, nan, one(u)), ifelse(isint, nan, -trunc(u)))) @scalar_rule(fma(x, y, z), (y, x, One())) +@scalar_rule(muladd(x, y, z), (y, x, One())) @scalar_rule(angle(x::Complex), @setup(u = abs2(x)), Wirtinger(-im//2 * x' / u, im//2 * x / u)) @scalar_rule(angle(x::Real), Zero()) @scalar_rule(real(x::Complex), Wirtinger(1//2, 1//2)) From c4ba4982924cfca3713f2fc98ac42afdcfdb57a1 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Mon, 23 Dec 2019 12:44:40 -0500 Subject: [PATCH 10/12] Address code review comments --- Project.toml | 2 +- src/rulesets/Base/base.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 3857967c9..8e8b346a7 100644 --- a/Project.toml +++ b/Project.toml @@ -13,7 +13,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" ChainRulesCore = "0.4" FiniteDifferences = "^0.7" Reexport = "0.2" -Requires = "0.5.2" +Requires = "0.5.2, 1" julia = "^1.0" [extras] diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index 2d3995490..39e2ab1e7 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -66,7 +66,7 @@ @scalar_rule(-(x, y), (One(), -1)) @scalar_rule(/(x, y), (inv(y), -(x / y / y))) @scalar_rule(\(x, y), (-(y / x / x), inv(x))) -@scalar_rule(^(x, y), (ifelse(iszero(y), Zero(), y * x^(y - 1)), Ω * log(x))) +@scalar_rule(^(x, y), (ifelse(iszero(y), zero(Ω), y * x^(y - 1)), Ω * log(x))) @scalar_rule(inv(x), -Ω^2) @scalar_rule(sqrt(x), inv(2 * Ω)) From 9bdeb331e0d268f3968ebd26d19a4c602468b0f3 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Mon, 23 Dec 2019 14:02:33 -0500 Subject: [PATCH 11/12] Test new base rules --- test/rulesets/Base/base.jl | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index 492617512..5c7ba917b 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -137,11 +137,11 @@ test_accumulation(rand(2, 5), dy) end - @testset "binary trig ($f)" for f in (hypot, atan) + @testset "binary function ($f)" for f in (hypot, atan, mod, rem, ^) rng = MersenneTwister(123456) - x, Δx, x̄ = 10randn(rng, 3) - y, Δy, ȳ = randn(rng, 3) - Δz = randn(rng) + x, Δx, x̄ = 10rand(rng, 3) + y, Δy, ȳ = rand(rng, 3) + Δz = rand(rng) frule_test(f, (x, Δx), (y, Δy)) rrule_test(f, Δz, (x, x̄), (y, ȳ)) @@ -176,4 +176,15 @@ @test extern(ẏ) == 0 end end + + @testset "trinary ($f)" for f in (muladd, fma) + rng = MersenneTwister(123456) + x, Δx, x̄ = 10randn(rng, 3) + y, Δy, ȳ = randn(rng, 3) + z, Δz, z̄ = randn(rng, 3) + Δk = randn(rng) + + frule_test(f, (x, Δx), (y, Δy), (z, Δz)) + rrule_test(f, Δk, (x, x̄), (y, ȳ), (z, z̄)) + end end From 5a59b025f2bab795c004790eeb6521cf488bb14d Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Tue, 24 Dec 2019 11:56:35 -0500 Subject: [PATCH 12/12] Patch version bump --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 8e8b346a7..4db509f8f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "0.2.3" +version = "0.2.4" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"