Skip to content

Commit

Permalink
Merge 3189653 into 9e4cb76
Browse files Browse the repository at this point in the history
  • Loading branch information
oxinabox committed Oct 18, 2019
2 parents 9e4cb76 + 3189653 commit 3674c38
Show file tree
Hide file tree
Showing 6 changed files with 98 additions and 32 deletions.
2 changes: 2 additions & 0 deletions src/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
@scalar_rule(one(x), Zero())
@scalar_rule(zero(x), Zero())
@scalar_rule(sign(x), Zero())

@scalar_rule(abs2(x), Wirtinger(x', x))
@scalar_rule(log(x), inv(x))
@scalar_rule(log10(x), inv(x) / log(oftype(x, 10)))
Expand Down
16 changes: 15 additions & 1 deletion src/rulesets/packages/SpecialFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ using ChainRulesCore
using ..SpecialFunctions


@scalar_rule(SpecialFunctions.lgamma(x), SpecialFunctions.digamma(x))
@scalar_rule(SpecialFunctions.erf(x), (2 / sqrt(π)) * exp(-x * x))
@scalar_rule(SpecialFunctions.erfc(x), -(2 / sqrt(π)) * exp(-x * x))
@scalar_rule(SpecialFunctions.erfi(x), (2 / sqrt(π)) * exp(x * x))
Expand All @@ -24,4 +23,19 @@ using ..SpecialFunctions
@scalar_rule(SpecialFunctions.erfcx(x), (2 * x * Ω) - (2 / sqrt(π)))
@scalar_rule(SpecialFunctions.dawson(x), 1 - (2 * x * Ω))

# Changes between SpecialFunctions 0.7 and 0.8
if isdefined(SpecialFunctions, :lgamma)
# actually is the absolute value of the logorithm of gamma
@scalar_rule(SpecialFunctions.lgamma(x), SpecialFunctions.digamma(x))
end

if isdefined(SpecialFunctions, :logabsgamma)
# actually is the absolute value of the logorithm of gamma, paired with sign gamma
@scalar_rule(SpecialFunctions.logabsgamma(x), SpecialFunctions.digamma(x), Zero())
end

if isdefined(SpecialFunctions, :loggamma)
@scalar_rule(SpecialFunctions.loggamma(x), SpecialFunctions.digamma(x))
end

end #module
64 changes: 37 additions & 27 deletions test/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@
test_scalar(acotd, 1/x)
end
@testset "Multivariate" begin
x, y = rand(2)
@testset "atan2" begin
# https://en.wikipedia.org/wiki/Atan2
x, y = rand(2)
ratan = atan(x, y)
u = x^2 + y^2
datan = y/u - 2x/u
Expand All @@ -71,19 +71,11 @@
end

@testset "sincos" begin
rsincos = sincos(x)
dsincos = cos(x) - 2sin(x)

r, pushforward = frule(sincos, x)
@test r === rsincos
df1, df2 = pushforward(NamedTuple(), 1)
@test df1 + 2df2 === dsincos

r, pullback = rrule(sincos, x)
@test r === rsincos
ds, df = pullback(1, 2)
@test df === dsincos
@test ds === NO_FIELDS
x, Δx, x̄ = randn(3)
Δz = (randn(), randn())

frule_test(sincos, (x, Δx))
rrule_test(sincos, Δz, (x, x̄))
end
end
end # Trig
Expand Down Expand Up @@ -114,17 +106,16 @@
end

@testset "Unary complex functions" begin
for x in (-6, rand.((Float32, Float64, Complex{Float32}, Complex{Float64}))...)
rtol = x isa Complex{Float32} ? 1e-6 : 1e-9
test_scalar(real, x; rtol=rtol)
test_scalar(imag, x; rtol=rtol)
for x in (-4.1, 6.4, 1.0+0.5im, -10.0+1.5im)
test_scalar(real, x)
test_scalar(imag, x)

test_scalar(abs, x; rtol=rtol)
test_scalar(hypot, x; rtol=rtol)
test_scalar(abs, x)
test_scalar(hypot, x)

test_scalar(angle, x; rtol=rtol)
test_scalar(abs2, x; rtol=rtol)
test_scalar(conj, x; rtol=rtol)
test_scalar(angle, x)
test_scalar(abs2, x)
test_scalar(conj, x)
end
end

Expand All @@ -146,14 +137,14 @@
test_accumulation(rand(2, 5), dy)
end

@testset "hypot(x, y)" begin
@testset "binary trig ($f)" for f in (hypot, atan)
rng = MersenneTwister(123456)
x, Δx, x̄ = randn(rng, 3)
x, Δx, x̄ = 10randn(rng, 3)
y, Δy, ȳ = randn(rng, 3)
Δz = randn(rng)

frule_test(hypot, (x, Δx), (y, Δy))
rrule_test(hypot, Δz, (x, x̄), (y, ȳ))
frule_test(f, (x, Δx), (y, Δy))
rrule_test(f, Δz, (x, x̄), (y, ȳ))
end

@testset "identity" begin
Expand All @@ -166,4 +157,23 @@
test_scalar(one, x)
test_scalar(zero, x)
end

@testset "sign" begin
@testset "at points" for x in (-1.1, -1.1, 0.5, 100)
test_scalar(sign, x)
end

@testset "Zero over the point discontinuity" begin
# Can't do finite differencing because we are lying
# following the subgradient convention.

_, pb = rrule(sign, 0.0)
_, x̄ = pb(10.5)
@test extern(x̄) == 0

_, pf = frule(sign, 0.0)
= pf(NamedTuple(), 10.5)
@test extern(ẏ) == 0
end
end
end
25 changes: 24 additions & 1 deletion test/rulesets/packages/SpecialFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,29 @@ using SpecialFunctions
test_scalar(SpecialFunctions.gamma, x)
test_scalar(SpecialFunctions.digamma, x)
test_scalar(SpecialFunctions.trigamma, x)
test_scalar(SpecialFunctions.lgamma, x)
end
end

# SpecialFunctions 0.7->0.8 changes:
@testset "log gamma and co" begin
#It is important that we have negative numbers with both odd and even integer parts
for x in (1.5, 2.5, 10.5, -0.6, -2.6, -3.3, 1.6+1.6im, 1.6-1.6im, -4.6+1.6im)
if isdefined(SpecialFunctions, :lgamma)
test_scalar(SpecialFunctions.lgamma, x)
end
if isdefined(SpecialFunctions, :loggamma)
isreal(x) && x < 0 && continue
test_scalar(SpecialFunctions.loggamma, x)
end

if isdefined(SpecialFunctions, :logabsgamma)
isreal(x) || continue

Δx, x̄ = randn(2)
Δz = (randn(), randn())

frule_test(SpecialFunctions.logabsgamma, (x, Δx))
rrule_test(SpecialFunctions.logabsgamma, Δz, (x, x̄))
end
end
end
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,15 @@ using ChainRulesCore: extern, accumulate, accumulate!, store!, @scalar_rule,
Wirtinger, wirtinger_primal, wirtinger_conjugate,
Zero, One, DNE, Thunk, AbstractDifferential

Random.seed!(1) # Set seed that all testsets should reset to.

include("test_util.jl")

println("Testing ChainRules.jl")
@testset "ChainRules" begin
include("helper_functions.jl")
@testset "rulesets" begin

@testset "Base" begin
include(joinpath("rulesets", "Base", "base.jl"))
include(joinpath("rulesets", "Base", "array.jl"))
Expand Down
20 changes: 17 additions & 3 deletions test/test_util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,13 @@ function frule_test(f, xẋs::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm=_fdm

# Correctness testing via finite differencing.
dΩ_fd = jvp(fdm, xs->f(xs...), (xs, ẋs))
@test isapprox(dΩ_ad, dΩ_fd; rtol=rtol, atol=atol, kwargs...)
@test isapprox(
collect(dΩ_ad), # Use collect so can use vector equality
collect(dΩ_fd);
rtol=rtol,
atol=atol,
kwargs...
)
end


Expand All @@ -108,6 +114,7 @@ end
# Arguments
- `f`: Function to which rule should be applied.
- `ȳ`: adjoint w.r.t. output of `f` (should generally be set randomly).
Should be same structure as `f(x)` (so if multiple returns should be a tuple)
- `x`: input at which to evaluate `f` (should generally be set to an arbitary point in the domain).
- `x̄`: currently accumulated adjoint (should generally be set randomly).
Expand All @@ -118,8 +125,15 @@ function rrule_test(f, ȳ, (x, x̄)::Tuple{Any, Any}; rtol=1e-9, atol=1e-9, fdm

# Check correctness of evaluation.
fx, pullback = ChainRules.rrule(f, x)
@test fx f(x)
(∂self, x̄_ad) = pullback(ȳ)
@test collect(fx) collect(f(x)) # use collect so can do vector equality
(∂self, x̄_ad) = if fx isa Tuple
# If the function returned multiple values,
# then it must have multiple seeds for propagating backwards
pullback(ȳ...)
else
pullback(ȳ)
end

@test ∂self === NO_FIELDS # No internal fields
# Correctness testing via finite differencing.
x̄_fd = j′vp(fdm, f, ȳ, x)
Expand Down

0 comments on commit 3674c38

Please sign in to comment.