Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: adding more rules and fixing old ones #80

Merged
merged 15 commits into from
Dec 24, 2019
11 changes: 6 additions & 5 deletions src/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
@scalar_rule(acotd(x), -oftype(x, 180) / π / (1 + x^2))
@scalar_rule(sinh(x), cosh(x))
@scalar_rule(cosh(x), sinh(x))
@scalar_rule(tanh(x), sech(x)^2)
@scalar_rule(tanh(x), 1-Ω^2)
YingboMa marked this conversation as resolved.
Show resolved Hide resolved
@scalar_rule(coth(x), -(csch(x)^2))
@scalar_rule(asinh(x), inv(sqrt(x^2 + 1)))
@scalar_rule(acosh(x), inv(sqrt(x^2 - 1)))
Expand Down Expand Up @@ -67,16 +67,17 @@
@scalar_rule(atan(x, y), @setup(u = x^2 + y^2), (y / u, -x / u))
@scalar_rule(max(x, y), @setup(gt = x > y), (gt, !gt))
@scalar_rule(min(x, y), @setup(gt = x > y), (!gt, gt))
@scalar_rule(mod(x, y), @setup((u, nan) = promote(x / y, NaN16)),
@scalar_rule(mod(x, y), @setup((u, nan) = promote(x / y, NaN16), isint = isinteger(x / y)),
(ifelse(isint, nan, one(u)), ifelse(isint, nan, -floor(u))))
@scalar_rule(rem(x, y), @setup((u, nan) = promote(x / y, NaN16)),
@scalar_rule(rem(x, y), @setup((u, nan) = promote(x / y, NaN16), isint = isinteger(x / y)),
(ifelse(isint, nan, one(u)), ifelse(isint, nan, -trunc(u))))
@scalar_rule(fma(x, y, z), (y, x, One()))
oxinabox marked this conversation as resolved.
Show resolved Hide resolved

# product rule requires special care for arguments where `mul` is non-commutative

frule(::typeof(*), x, y) = x * y, Rule((Δx, Δy) -> Δx * y + x * Δy)
frule(::typeof(*), x::Number, y::Number) = x * y, Rule((Δx, Δy) -> Δx * y + x * Δy)

rrule(::typeof(*), x, y) = x * y, (Rule(ΔΩ -> ΔΩ * y'), Rule(ΔΩ -> x' * ΔΩ))
rrule(::typeof(*), x::Number, y::Number) = x * y, (Rule(ΔΩ -> ΔΩ * y'), Rule(ΔΩ -> x' * ΔΩ))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's important to think about the complex case here too. All other rrules don't return the complex conjugate of the derivative, so either the conjugate transposes should be changed to unconjugated transposes, or this convention needs to be changed everywhere else


frule(::typeof(identity), x) = x, Rule(identity)

Expand Down
9 changes: 9 additions & 0 deletions src/rulesets/packages/NaNMath.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
module NaNMathGlue
using ChainRulesCore
using NaNMath
using SpecialFunctions
oxinabox marked this conversation as resolved.
Show resolved Hide resolved

@scalar_rule(NaNMath.sin(x), NaNMath.cos(x))
@scalar_rule(NaNMath.cos(x), -NaNMath.sin(x))
@scalar_rule(NaNMath.tan(x), 1 + NaNMath.pow(NaNMath.tan(x), 2))
YingboMa marked this conversation as resolved.
Show resolved Hide resolved
@scalar_rule(NaNMath.asin(x), inv(NaNMath.sqrt(1 - NaNMath.pow(x, 2))))
@scalar_rule(NaNMath.acos(x), -inv(NaNMath.sqrt(1 - NaNMath.pow(x, 2))))
@scalar_rule(NaNMath.acosh(x), inv(NaNMath.sqrt(NaNMath.pow(x, 2) - 1)))
Expand All @@ -15,5 +17,12 @@ using NaNMath
@scalar_rule(NaNMath.lgamma(x), SpecialFunctions.digamma(x))
@scalar_rule(NaNMath.sqrt(x), inv(2 * Ω))
@scalar_rule(NaNMath.pow(x, y), (y * NaNMath.pow(x, y - 1), Ω * NaNMath.log(x)))
@scalar_rule(NaNMath.max(x, y),
mattBrzezinski marked this conversation as resolved.
Show resolved Hide resolved
(ifelse((y > x) | (signbit(y) < signbit(x)), ifelse(isnan(y), One(), Zero()), ifelse(isnan(x), Zero(), One())),
ifelse((y > x) | (signbit(y) < signbit(x)), ifelse(isnan(y), Zero(), One()), ifelse(isnan(x), One(), Zero()))))

@scalar_rule(NaNMath.min(x, y),
(ifelse((y < x) | (signbit(y) > signbit(x)), ifelse(isnan(y), One(), Zero()), ifelse(isnan(x), Zero(), One())),
ifelse((y < x) | (signbit(y) > signbit(x)), ifelse(isnan(y), Zero(), One()), ifelse(isnan(x), One(), Zero()))))

end #module
40 changes: 40 additions & 0 deletions src/rulesets/packages/SpecialFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,44 @@ using SpecialFunctions
@scalar_rule(SpecialFunctions.erfcx(x), (2 * x * Ω) - (2 / sqrt(π)))
@scalar_rule(SpecialFunctions.dawson(x), 1 - (2 * x * Ω))

# binary
@scalar_rule(SpecialFunctions.besselj(ν, x),
(NaN,
(SpecialFunctions.besselj(ν - 1, x) -
SpecialFunctions.besselj(ν + 1, x)) / 2))

@scalar_rule(SpecialFunctions.besseli(ν, x),
(NaN,
(SpecialFunctions.besseli(ν - 1, x) +
SpecialFunctions.besseli(ν + 1, x)) / 2))
@scalar_rule(SpecialFunctions.bessely(ν, x),
(NaN,
(SpecialFunctions.bessely(ν - 1, x) -
SpecialFunctions.bessely(ν + 1, x)) / 2))

@scalar_rule(SpecialFunctions.besselk(ν, x),
(NaN,
-(SpecialFunctions.besselk(ν - 1, x) +
SpecialFunctions.besselk(ν + 1, x)) / 2))

@scalar_rule(SpecialFunctions.hankelh1(ν, x),
(NaN,
(SpecialFunctions.hankelh1(ν - 1, x) -
SpecialFunctions.hankelh1(ν + 1, x)) / 2))
@scalar_rule(SpecialFunctions.hankelh2(ν, x),
(NaN,
(SpecialFunctions.hankelh2(ν - 1, x) -
SpecialFunctions.hankelh2(ν + 1, x)) / 2))

@scalar_rule(SpecialFunctions.polygamma(m, x),
(NaN, SpecialFunctions.polygamma(m + 1, x)))

# todo: setup for common expr
@scalar_rule(SpecialFunctions.beta(a, b),
(Ω*(SpecialFunctions.digamma(a) - SpecialFunctions.digamma(a + b)),
Ω*(SpecialFunctions.digamma(b) - SpecialFunctions.digamma(a + b))))

@scalar_rule(SpecialFunctions.lbeta(a, b),
(SpecialFunctions.digamma(a) - SpecialFunctions.digamma(a + b),
SpecialFunctions.digamma(b) - SpecialFunctions.digamma(a + b)))
end #module