Skip to content

Commit

Permalink
WIP: adding more rules and fixing old ones (#80)
Browse files Browse the repository at this point in the history
* fma

* add a few NaNMath rules

* add some SpecialFunctions rules

* isint wasn't defined

* fixes

* optimize rule for tanh like in JuliaDiff/DiffRules.jl#4861e3

* Fix pow.

Co-authored-by: YingboMa <mayingbo5@gmail.com>

* Remove tan from NaNMath.jl

* Fix `inv` and add `muladd`

* Address code review comments

* Test new base rules

* Patch version bump

Co-authored-by: Yingbo Ma <mayingbo5@gmail.com>
  • Loading branch information
shashi and YingboMa committed Dec 24, 2019
1 parent 90b08a4 commit 3a16dab
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 10 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "0.2.3"
version = "0.2.4"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand All @@ -13,7 +13,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
ChainRulesCore = "0.4"
FiniteDifferences = "^0.7"
Reexport = "0.2"
Requires = "0.5.2"
Requires = "0.5.2, 1"
julia = "^1.0"

[extras]
Expand Down
10 changes: 6 additions & 4 deletions src/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@

@scalar_rule(sinh(x), cosh(x))
@scalar_rule(cosh(x), sinh(x))
@scalar_rule(tanh(x), sech(x)^2)
@scalar_rule(tanh(x), 1-Ω^2)
@scalar_rule(coth(x), -(csch(x)^2))

@scalar_rule(asinh(x), inv(sqrt(x^2 + 1)))
Expand Down Expand Up @@ -66,7 +66,7 @@
@scalar_rule(-(x, y), (One(), -1))
@scalar_rule(/(x, y), (inv(y), -(x / y / y)))
@scalar_rule(\(x, y), (-(y / x / x), inv(x)))
@scalar_rule(^(x, y), (y * x^(y - 1), Ω * log(x)))
@scalar_rule(^(x, y), (ifelse(iszero(y), zero(Ω), y * x^(y - 1)), Ω * log(x)))

@scalar_rule(inv(x), -Ω^2)
@scalar_rule(sqrt(x), inv(2 * Ω))
Expand All @@ -92,10 +92,12 @@

@scalar_rule(max(x, y), @setup(gt = x > y), (gt, !gt))
@scalar_rule(min(x, y), @setup(gt = x > y), (!gt, gt))
@scalar_rule(mod(x, y), @setup((u, nan) = promote(x / y, NaN16)),
@scalar_rule(mod(x, y), @setup((u, nan) = promote(x / y, NaN16), isint = isinteger(x / y)),
(ifelse(isint, nan, one(u)), ifelse(isint, nan, -floor(u))))
@scalar_rule(rem(x, y), @setup((u, nan) = promote(x / y, NaN16)),
@scalar_rule(rem(x, y), @setup((u, nan) = promote(x / y, NaN16), isint = isinteger(x / y)),
(ifelse(isint, nan, one(u)), ifelse(isint, nan, -trunc(u))))
@scalar_rule(fma(x, y, z), (y, x, One()))
@scalar_rule(muladd(x, y, z), (y, x, One()))
@scalar_rule(angle(x::Complex), @setup(u = abs2(x)), Wirtinger(-im//2 * x' / u, im//2 * x / u))
@scalar_rule(angle(x::Real), Zero())
@scalar_rule(real(x::Complex), Wirtinger(1//2, 1//2))
Expand Down
7 changes: 7 additions & 0 deletions src/rulesets/packages/NaNMath.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module NaNMathGlue
using ChainRulesCore
using ..NaNMath
using ..SpecialFunctions

@scalar_rule(NaNMath.sin(x), NaNMath.cos(x))
@scalar_rule(NaNMath.cos(x), -NaNMath.sin(x))
Expand All @@ -15,5 +16,11 @@ using ..NaNMath
@scalar_rule(NaNMath.lgamma(x), SpecialFunctions.digamma(x))
@scalar_rule(NaNMath.sqrt(x), inv(2 * Ω))
@scalar_rule(NaNMath.pow(x, y), (y * NaNMath.pow(x, y - 1), Ω * NaNMath.log(x)))
@scalar_rule(NaNMath.max(x, y),
(ifelse((y > x) | (signbit(y) < signbit(x)), ifelse(isnan(y), One(), Zero()), ifelse(isnan(x), Zero(), One())),
ifelse((y > x) | (signbit(y) < signbit(x)), ifelse(isnan(y), Zero(), One()), ifelse(isnan(x), One(), Zero()))))
@scalar_rule(NaNMath.min(x, y),
(ifelse((y < x) | (signbit(y) > signbit(x)), ifelse(isnan(y), One(), Zero()), ifelse(isnan(x), Zero(), One())),
ifelse((y < x) | (signbit(y) > signbit(x)), ifelse(isnan(y), Zero(), One()), ifelse(isnan(x), One(), Zero()))))

end #module
41 changes: 41 additions & 0 deletions src/rulesets/packages/SpecialFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,47 @@ using ..SpecialFunctions
@scalar_rule(SpecialFunctions.erfcx(x), (2 * x * Ω) - (2 / sqrt(π)))
@scalar_rule(SpecialFunctions.dawson(x), 1 - (2 * x * Ω))

# binary
@scalar_rule(SpecialFunctions.besselj(ν, x),
(NaN,
(SpecialFunctions.besselj- 1, x) -
SpecialFunctions.besselj+ 1, x)) / 2))

@scalar_rule(SpecialFunctions.besseli(ν, x),
(NaN,
(SpecialFunctions.besseli- 1, x) +
SpecialFunctions.besseli+ 1, x)) / 2))
@scalar_rule(SpecialFunctions.bessely(ν, x),
(NaN,
(SpecialFunctions.bessely- 1, x) -
SpecialFunctions.bessely+ 1, x)) / 2))

@scalar_rule(SpecialFunctions.besselk(ν, x),
(NaN,
-(SpecialFunctions.besselk- 1, x) +
SpecialFunctions.besselk+ 1, x)) / 2))

@scalar_rule(SpecialFunctions.hankelh1(ν, x),
(NaN,
(SpecialFunctions.hankelh1- 1, x) -
SpecialFunctions.hankelh1+ 1, x)) / 2))
@scalar_rule(SpecialFunctions.hankelh2(ν, x),
(NaN,
(SpecialFunctions.hankelh2- 1, x) -
SpecialFunctions.hankelh2+ 1, x)) / 2))

@scalar_rule(SpecialFunctions.polygamma(m, x),
(NaN, SpecialFunctions.polygamma(m + 1, x)))

# todo: setup for common expr
@scalar_rule(SpecialFunctions.beta(a, b),
*(SpecialFunctions.digamma(a) - SpecialFunctions.digamma(a + b)),
Ω*(SpecialFunctions.digamma(b) - SpecialFunctions.digamma(a + b))))

@scalar_rule(SpecialFunctions.lbeta(a, b),
(SpecialFunctions.digamma(a) - SpecialFunctions.digamma(a + b),
SpecialFunctions.digamma(b) - SpecialFunctions.digamma(a + b)))

# Changes between SpecialFunctions 0.7 and 0.8
if isdefined(SpecialFunctions, :lgamma)
# actually is the absolute value of the logorithm of gamma
Expand Down
19 changes: 15 additions & 4 deletions test/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,11 +137,11 @@
test_accumulation(rand(2, 5), dy)
end

@testset "binary trig ($f)" for f in (hypot, atan)
@testset "binary function ($f)" for f in (hypot, atan, mod, rem, ^)
rng = MersenneTwister(123456)
x, Δx, x̄ = 10randn(rng, 3)
y, Δy, ȳ = randn(rng, 3)
Δz = randn(rng)
x, Δx, x̄ = 10rand(rng, 3)
y, Δy, ȳ = rand(rng, 3)
Δz = rand(rng)

frule_test(f, (x, Δx), (y, Δy))
rrule_test(f, Δz, (x, x̄), (y, ȳ))
Expand Down Expand Up @@ -176,4 +176,15 @@
@test extern(ẏ) == 0
end
end

@testset "trinary ($f)" for f in (muladd, fma)
rng = MersenneTwister(123456)
x, Δx, x̄ = 10randn(rng, 3)
y, Δy, ȳ = randn(rng, 3)
z, Δz, z̄ = randn(rng, 3)
Δk = randn(rng)

frule_test(f, (x, Δx), (y, Δy), (z, Δz))
rrule_test(f, Δk, (x, x̄), (y, ȳ), (z, z̄))
end
end

2 comments on commit 3a16dab

@YingboMa
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/7130

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if Julia TagBot is installed, or can be done manually through the github interface, or via:

git tag -a v0.2.4 -m "<description of version>" 3a16dab0099bf33ec0b3e06198f7a4dde0abab63
git push origin v0.2.4

Please sign in to comment.