Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial finite differencing testing #14

Merged
merged 10 commits into from
Apr 17, 2019
5 changes: 4 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,19 @@ uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"

[deps]
Cassette = "7057c7e9-c182-5462-911a-8362d720325c"
FDM = "e25cca7e-83ef-51fa-be6c-dfe2a3123128"
willtebbutt marked this conversation as resolved.
Show resolved Hide resolved
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"

[compat]
Cassette = "^0.2"
FDM = "^0.4.0"
willtebbutt marked this conversation as resolved.
Show resolved Hide resolved
julia = "^1.0"

[extras]
FDM = "e25cca7e-83ef-51fa-be6c-dfe2a3123128"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test"]
test = ["Test", "FDM"]
6 changes: 5 additions & 1 deletion src/rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,11 @@ Note that this function internally calls `Base.Broadcast.materialize!(Δ, ...)`.

See also: [`accumulate`](@ref), [`store!`](@ref), [`AbstractRule`](@ref)
"""
accumulate!(Δ, rule::AbstractRule, args...) = materialize!(Δ, broadcastable(add(cast(Δ), rule(args...))))
function accumulate!(Δ, rule::AbstractRule, args...)
    # Build the updated value first: the casted current Δ plus the rule's output.
    total = add(cast(Δ), rule(args...))
    # Then write it back into Δ through the broadcasting machinery.
    return materialize!(Δ, broadcastable(total))
end

function accumulate!(Δ::Number, rule::AbstractRule, args...)
    # A scalar Δ is routed to the out-of-place `accumulate`; the sum is returned to the
    # caller instead of being written into Δ.
    return accumulate(Δ, rule, args...)
end

"""
store!(Δ, rule::AbstractRule, args...)
Expand Down
4 changes: 2 additions & 2 deletions src/rules/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ end

function rrule(::typeof(inv), x::AbstractArray)
Ω = inv(x)
# NOTE(review): the hunk header reports 2 additions & 2 deletions, so the two
# `m = ...` / `return ...` pairs below appear to be the before and after of one change;
# as written, only the first pair can execute.
# First pair: the thunk holds -Ω and the adjoint is applied to the thunk inside the rule.
m = @thunk(-Ω)
return Ω, Rule(ΔΩ -> m' * ΔΩ * Ω')
# Second pair: the adjoint is taken inside the thunk, so the rule multiplies by `m` as is.
m = @thunk(-Ω')
return Ω, Rule(ΔΩ -> m * ΔΩ * Ω')
end

#####
Expand Down
46 changes: 46 additions & 0 deletions test/differentials.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Tests for the differential types (`Wirtinger`, `Zero`, `One`): accessors, arithmetic
# helpers, `extern`, iteration, broadcasting and `conj`.
@testset "Differentials" begin
    @testset "Wirtinger" begin
        w = Wirtinger(1+1im, 2+2im)
        @test wirtinger_primal(w) == 1+1im
        @test wirtinger_conjugate(w) == 2+2im
        @test add_wirtinger(w, w) == Wirtinger(2+2im, 4+4im)
        # TODO: other add_wirtinger methods stack overflow
        @test_throws ErrorException mul_wirtinger(w, w)
        @test_throws ErrorException extern(w)
        # Count iterations so the loop cannot pass vacuously on an empty iterator;
        # a differential is expected to iterate exactly once, yielding itself.
        n_iterations = 0
        for x in w
            @test x === w
            n_iterations += 1
        end
        @test n_iterations == 1
        @test broadcastable(w) == w
        @test_throws ErrorException conj(w)
    end
    @testset "Zero" begin
        z = Zero()
        @test extern(z) === false
        @test add_zero(z, z) == z
        @test add_zero(z, 1) == 1
        @test add_zero(1, z) == 1
        @test mul_zero(z, z) == z
        @test mul_zero(z, 1) == z
        @test mul_zero(1, z) == z
        # Same non-vacuous iteration check as for `Wirtinger`.
        n_iterations = 0
        for x in z
            @test x === z
            n_iterations += 1
        end
        @test n_iterations == 1
        @test broadcastable(z) isa Ref{Zero}
        @test conj(z) == z
    end
    @testset "One" begin
        o = One()
        @test extern(o) === true
        @test add_one(o, o) == 2
        @test add_one(o, 1) == 2
        @test add_one(1, o) == 2
        @test mul_one(o, o) == o
        @test mul_one(o, 1) == 1
        @test mul_one(1, o) == 1
        # Same non-vacuous iteration check as for `Wirtinger`.
        n_iterations = 0
        for x in o
            @test x === o
            n_iterations += 1
        end
        @test n_iterations == 1
        @test broadcastable(o) isa Ref{One}
        @test conj(o) == o
    end
end
24 changes: 24 additions & 0 deletions test/rules.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""
    cool(x)

Return `x + 1`. Toy function used by the tests below to exercise `frule` and `rrule`
lookups before and after a rule is defined for it.
"""
function cool(x)
    return x + 1
end

@testset "rules" begin
@testset "frule and rrule" begin
# Statement order matters: these two checks must run before the rule is defined below.
# With no rule for `cool` yet, both lookups are expected to return `nothing`.
@test frule(cool, 1) === nothing
@test rrule(cool, 1) === nothing
# Define a scalar rule for `cool` whose derivative expression is `one(x)`.
ChainRules.@scalar_rule(Main.cool(x), one(x))
# With the rule in place, each lookup yields the primal value and a callable rule.
frx, fr = frule(cool, 1)
@test frx == 2
@test fr(1) == 1
rrx, rr = rrule(cool, 1)
@test rrx == 2
@test rr(1) == 1
end
@testset "iterating rules" begin
    _, rule = frule(+, 1)
    # Walk the iteration protocol by hand, counting how many items come out.
    yielded = 0
    next = iterate(rule)
    while next !== nothing
        item, state = next
        @test item === rule
        yielded += 1
        next = iterate(rule, state)
    end
    @test yielded == 1 # rules only iterate once, yielding themselves
end
end
140 changes: 87 additions & 53 deletions test/rules/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,60 +8,94 @@ function test_scalar(f, f′, xs...)
end
end

@testset "Trig" begin
@testset "Basics" for x = (Float64(π), Complex(π, π/2))
test_scalar(sin, cos, x)
test_scalar(cos, x -> -sin(x), x)
test_scalar(tan, x -> 1 + tan(x)^2, x)
test_scalar(sec, x -> sec(x) * tan(x), x)
test_scalar(csc, x -> -csc(x) * cot(x), x)
test_scalar(cot, x -> -1 - cot(x)^2, x)
test_scalar(sinpi, x -> π * cospi(x), x)
test_scalar(cospi, x -> -π * sinpi(x), x)
@testset "base" begin
# Each inner testset walks a table of (function, analytic derivative) entries and hands
# every entry to `test_scalar`, in the listed order.
@testset "Trig" begin
    @testset "Basics" for x in (Float64(π), Complex(π, π/2))
        for (f, f′) in (
            (sin, cos),
            (cos, z -> -sin(z)),
            (tan, z -> 1 + tan(z)^2),
            (sec, z -> sec(z) * tan(z)),
            (csc, z -> -csc(z) * cot(z)),
            (cot, z -> -1 - cot(z)^2),
            (sinpi, z -> π * cospi(z)),
            (cospi, z -> -π * sinpi(z)),
        )
            test_scalar(f, f′, x)
        end
    end
    @testset "Hyperbolic" for x in (Float64(π), Complex(π, π/2))
        for (f, f′) in (
            (sinh, cosh),
            (cosh, sinh),
            (tanh, z -> sech(z)^2),
            (sech, z -> -tanh(z) * sech(z)),
            (csch, z -> -coth(z) * csch(z)),
            (coth, z -> -csch(z)^2),
        )
            test_scalar(f, f′, x)
        end
    end
    @testset "Degrees" begin
        x = 45.0
        for (f, f′) in (
            (sind, z -> (π / 180) * cosd(z)),
            (cosd, z -> (-π / 180) * sind(z)),
            (tand, z -> (π / 180) * (1 + tand(z)^2)),
            (secd, z -> (π / 180) * secd(z) * tand(z)),
            (cscd, z -> (-π / 180) * cscd(z) * cotd(z)),
            (cotd, z -> (-π / 180) * (1 + cotd(z)^2)),
        )
            test_scalar(f, f′, x)
        end
    end
    @testset "Inverses" for x in (1.0, Complex(1.0, 0.25))
        for (f, f′) in (
            (asin, z -> 1 / sqrt(1 - z^2)),
            (acos, z -> -1 / sqrt(1 - z^2)),
            (atan, z -> 1 / (1 + z^2)),
            (asec, z -> 1 / (abs(z) * sqrt(z^2 - 1))),
            (acsc, z -> -1 / (abs(z) * sqrt(z^2 - 1))),
            (acot, z -> -1 / (1 + z^2)),
        )
            test_scalar(f, f′, x)
        end
    end
    @testset "Inverse hyperbolic" for x in (0.0, Complex(0.0, 0.25))
        # Entries carry their own evaluation point: `acosh` and `acoth` are evaluated
        # at `x + 1` (+1 accounts for domain), the rest at `x`.
        for (f, point, f′) in (
            (asinh, x, z -> 1 / sqrt(z^2 + 1)),
            (acosh, x + 1, z -> 1 / sqrt(z^2 - 1)),
            (atanh, x, z -> 1 / (1 - z^2)),
            (asech, x, z -> -1 / z / sqrt(1 - z^2)),
            (acsch, x, z -> -1 / abs(z) / sqrt(1 + z^2)),
            (acoth, x + 1, z -> 1 / (1 - z^2)),
        )
            test_scalar(f, f′, point)
        end
    end
    @testset "Inverse degrees" begin
        x = 1.0
        for (f, f′) in (
            (asind, z -> 180 / π / sqrt(1 - z^2)),
            (acosd, z -> -180 / π / sqrt(1 - z^2)),
            (atand, z -> 180 / π / (1 + z^2)),
            (asecd, z -> 180 / π / abs(z) / sqrt(z^2 - 1)),
            (acscd, z -> -180 / π / abs(z) / sqrt(z^2 - 1)),
            (acotd, z -> -180 / π / (1 + z^2)),
        )
            test_scalar(f, f′, x)
        end
    end
    # TODO: atan2 sincos
end
@testset "Hyperbolic" for x = (Float64(π), Complex(π, π/2))
test_scalar(sinh, cosh, x)
test_scalar(cosh, sinh, x)
test_scalar(tanh, x -> sech(x)^2, x)
test_scalar(sech, x -> -tanh(x) * sech(x), x)
test_scalar(csch, x -> -coth(x) * csch(x), x)
test_scalar(coth, x -> -csch(x)^2, x)
end
@testset "Degrees" begin
x = 45.0
test_scalar(sind, x -> (π / 180) * cosd(x), x)
test_scalar(cosd, x -> (-π / 180) * sind(x), x)
test_scalar(tand, x -> (π / 180) * (1 + tand(x)^2), x)
test_scalar(secd, x -> (π / 180) * secd(x) * tand(x), x)
test_scalar(cscd, x -> (-π / 180) * cscd(x) * cotd(x), x)
test_scalar(cotd, x -> (-π / 180) * (1 + cotd(x)^2), x)
end
@testset "Inverses" for x = (1.0, Complex(1.0, 0.25))
test_scalar(asin, x -> 1 / sqrt(1 - x^2), x)
test_scalar(acos, x -> -1 / sqrt(1 - x^2), x)
test_scalar(atan, x -> 1 / (1 + x^2), x)
test_scalar(asec, x -> 1 / (abs(x) * sqrt(x^2 - 1)), x)
test_scalar(acsc, x -> -1 / (abs(x) * sqrt(x^2 - 1)), x)
test_scalar(acot, x -> -1 / (1 + x^2), x)
end
@testset "Inverse hyperbolic" for x = (0.0, Complex(0.0, 0.25))
test_scalar(asinh, x -> 1 / sqrt(x^2 + 1), x)
test_scalar(acosh, x -> 1 / sqrt(x^2 - 1), x + 1) # +1 accounts for domain
test_scalar(atanh, x -> 1 / (1 - x^2), x)
test_scalar(asech, x -> -1 / x / sqrt(1 - x^2), x)
test_scalar(acsch, x -> -1 / abs(x) / sqrt(1 + x^2), x)
test_scalar(acoth, x -> 1 / (1 - x^2), x + 1)
end
@testset "Inverse degrees" begin
x = 1.0
test_scalar(asind, x -> 180 / π / sqrt(1 - x^2), x)
test_scalar(acosd, x -> -180 / π / sqrt(1 - x^2), x)
test_scalar(atand, x -> 180 / π / (1 + x^2), x)
test_scalar(asecd, x -> 180 / π / abs(x) / sqrt(x^2 - 1), x)
test_scalar(acscd, x -> -180 / π / abs(x) / sqrt(x^2 - 1), x)
test_scalar(acotd, x -> -180 / π / (1 + x^2), x)
@testset "Misc. Tests" begin
@testset "*(x, y)" begin
    # Operands of a (3 × 2) * (2 × 5) product; shapes below are derived from them.
    x = rand(3, 2)
    y = rand(2, 5)
    z, (dx, dy) = rrule(*, x, y)

    @test z == x * y

    z_bar = rand(3, 5)

    # Accumulating a rule's output onto zeros must reproduce calling the rule directly.
    @test dx(z_bar) == extern(accumulate(zeros(size(x)), dx, z_bar))
    @test dy(z_bar) == extern(accumulate(zeros(size(y)), dy, z_bar))

    # Compare each rule against the closed-form adjoint of the product.
    test_adjoint!(rand(size(x)...), dx, z_bar, z_bar * y')
    test_adjoint!(rand(size(y)...), dy, z_bar, x' * z_bar)
end
@testset "hypot(x, y)" begin
x, y = rand(2)
h, dxy = frule(hypot, x, y)

# NOTE(review): mathematically ∂hypot/∂x = x / h and ∂hypot/∂y = y / h. Assuming the
# rule's first argument is the x tangent, the assertions below expect y / h for the
# x direction and x / h for the y direction — confirm the partials in the hypot rule
# are not swapped, since these tests would lock that in.
@test extern(dxy(One(), Zero())) === y / h
@test extern(dxy(Zero(), One())) === x / h

# Same expectations with the seeds passed as tuples wrapped by `cast`; the externed
# result is destructured into two values.
cx, cy = cast((One(), Zero())), cast((Zero(), One()))
dx, dy = extern(dxy(cx, cy))
@test dx === y / h
@test dy === x / h

# Random seed magnitudes: each expected value is scaled by the matching tuple entry.
cx, cy = cast((rand(), Zero())), cast((Zero(), rand()))
dx, dy = extern(dxy(cx, cy))
@test dx === y / h * cx.value[1]
@test dy === x / h * cy.value[2]
end
end
# TODO: atan2 sincos
end

# TODO: Non-trig stuff
20 changes: 20 additions & 0 deletions test/rules/broadcast.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Tests for the reverse rule of `broadcast(sin, x)`: primal value, derivative, and
# accumulation of the derivative onto scalar, `Zero` and casted seeds.
@testset "broadcast" begin
    @testset "Misc. Tests" begin
        @testset "sin.(x)" begin
            x = rand(3, 3)
            y, (dsin, dx) = rrule(broadcast, sin, x)

            # The primal is the same computation as the reference, so exact equality
            # is the right check here.
            @test y == sin.(x)
            # Derivative values pass through extra floating-point arithmetic whose
            # operation order is an implementation detail; compare with `≈` rather
            # than bit-for-bit.
            @test extern(dx(One())) ≈ cos.(x)

            x̄, ȳ = rand(), rand()
            @test extern(accumulate(x̄, dx, ȳ)) ≈ x̄ .+ ȳ .* cos.(x)

            x̄, ȳ = Zero(), rand(3, 3)
            @test extern(accumulate(x̄, dx, ȳ)) ≈ ȳ .* cos.(x)

            x̄, ȳ = Zero(), cast(rand(3, 3))
            @test extern(accumulate(x̄, dx, ȳ)) ≈ extern(ȳ) .* cos.(x)
        end
    end
end
73 changes: 73 additions & 0 deletions test/rules/linalg.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
"""
    generate_well_conditioned_matrix(rng, N)

Return `A * A' + I`, where `A` is an `N × N` matrix of standard-normal entries drawn
from `rng`. `A * A'` is positive semi-definite, so adding the identity keeps every
eigenvalue of the result at or above one.
"""
function generate_well_conditioned_matrix(rng, N)
    factor = randn(rng, N, N)
    gram = factor * factor'
    return gram + I
end

# Forward- and reverse-rule checks for the linear-algebra rules. Every testset starts
# from a freshly seeded RNG, and random inputs are drawn in a fixed order.
@testset "linalg" begin
    @testset "sum" begin
        @testset "$label" for (label, dims) in (
            ("Vector", (3,)),
            ("Matrix", (3, 4)),
            ("Array{T, 3}", (3, 7, 11)),
        )
            rng = MersenneTwister(123456)
            frule_test(sum, (randn(rng, dims...), randn(rng, dims...)))
            rrule_test(sum, randn(rng), (randn(rng, dims...), randn(rng, dims...)))
        end
    end
    @testset "dot" begin
        @testset "$label" for (label, dims) in (
            ("Vector", (3,)),
            ("Matrix", (3, 4)),
            ("Array{T, 3}", (3, 4, 5)),
        )
            rng = MersenneTwister(123456)
            # Draw order: primals, then tangents, then cotangents.
            x = randn(rng, dims...)
            y = randn(rng, dims...)
            x_dot = randn(rng, dims...)
            y_dot = randn(rng, dims...)
            x_bar = randn(rng, dims...)
            y_bar = randn(rng, dims...)
            frule_test(dot, (x, x_dot), (y, y_dot))
            rrule_test(dot, randn(rng), (x, x_bar), (y, y_bar))
        end
    end
    @testset "inv" begin
        rng = MersenneTwister(123456)
        N = 3
        B = generate_well_conditioned_matrix(rng, N)
        B_dot = randn(rng, N, N)
        frule_test(inv, (B, B_dot))
        out_bar = randn(rng, N, N)
        B_bar = randn(rng, N, N)
        rrule_test(inv, out_bar, (B, B_bar))
    end
    @testset "det" begin
        rng = MersenneTwister(123456)
        N = 3
        B = generate_well_conditioned_matrix(rng, N)
        B_dot = randn(rng, N, N)
        frule_test(det, (B, B_dot))
        out_bar = randn(rng)
        B_bar = randn(rng, N, N)
        rrule_test(det, out_bar, (B, B_bar))
    end
    @testset "logdet" begin
        rng = MersenneTwister(123456)
        N = 3
        B = generate_well_conditioned_matrix(rng, N)
        B_dot = randn(rng, N, N)
        frule_test(logdet, (B, B_dot))
        out_bar = randn(rng)
        B_bar = randn(rng, N, N)
        rrule_test(logdet, out_bar, (B, B_bar))
    end
    @testset "tr" begin
        rng = MersenneTwister(123456)
        N = 4
        X = randn(rng, N, N)
        X_dot = randn(rng, N, N)
        frule_test(tr, (X, X_dot))
        out_bar = randn(rng)
        Y = randn(rng, N, N)
        Y_bar = randn(rng, N, N)
        rrule_test(tr, out_bar, (Y, Y_bar))
    end
end
Loading